diff --git a/.cspell.json b/.cspell.json index 21199c0c5f5..af2f1401d95 100644 --- a/.cspell.json +++ b/.cspell.json @@ -66,8 +66,11 @@ "deoptimize", "emscripten", "excs", + "flufl", "fnfe", + "fsdefault", "ifexp", + "implicits", "interps", "jitted", "jitting", diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39481edaa9f..dfe7ac869dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: priority: 0 - id: ruff-check - args: [--select, I, --fix, --exit-non-zero-on-fix] + args: [--select, I, --fix, --exit-non-zero-on-fix, --config, "lint.isort.known-first-party = ['cpython', 'opcodes', 'utils']"] types_or: [python] require_serial: true priority: 1 diff --git a/Cargo.lock b/Cargo.lock index 60080587885..481cad630cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1391,9 +1391,9 @@ dependencies = [ [[package]] name = "get-size-derive2" -version = "0.7.4" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b6d1e2f75c16bfbcd0f95d84f99858a6e2f885c2287d1f5c3a96e8444a34b4" +checksum = "1da24fbda09ec01bca7cfa1797c0e520e75123bccb01dcdf9041f8aa65183bc2" dependencies = [ "attribute-derive", "quote", @@ -1402,15 +1402,16 @@ dependencies = [ [[package]] name = "get-size2" -version = "0.7.4" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49cf31a6d70300cf81461098f7797571362387ef4bf85d32ac47eaa59b3a5a1a" +checksum = "823645bc6404ae2915707777061a47d3a031a9ee0bff51b34ec973df3d8d2990" dependencies = [ "compact_str", "get-size-derive2", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "ordermap", "smallvec", + "thin-vec", ] [[package]] @@ -3315,6 +3316,74 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ruff_python_ast" +version = "0.0.0" +source = "git+https://github.com/astral-sh/ruff.git?rev=6686f63404207bfdffe8ab0acb25da72c3432190#6686f63404207bfdffe8ab0acb25da72c3432190" +dependencies = [ + "aho-corasick", + "bitflags 2.13.0", + "compact_str", + "get-size2", + "is-macro", + "memchr", + "ruff_python_trivia", + "ruff_source_file", + "ruff_text_size", + "rustc-hash", + "thin-vec", + "thiserror", +] + +[[package]] +name = "ruff_python_parser" +version = "0.0.0" +source = "git+https://github.com/astral-sh/ruff.git?rev=6686f63404207bfdffe8ab0acb25da72c3432190#6686f63404207bfdffe8ab0acb25da72c3432190" +dependencies = [ + "bitflags 2.13.0", + "bstr", + "compact_str", + "get-size2", + "memchr", + "ruff_python_ast", + "ruff_python_trivia", + "ruff_text_size", + "rustc-hash", + "static_assertions", + "thin-vec", + "unicode-ident", + "unicode-normalization", + "unicode_names2 1.3.0", +] + +[[package]] +name = "ruff_python_trivia" +version = "0.0.0" +source = "git+https://github.com/astral-sh/ruff.git?rev=6686f63404207bfdffe8ab0acb25da72c3432190#6686f63404207bfdffe8ab0acb25da72c3432190" +dependencies = [ + "itertools 0.14.0", + "ruff_source_file", + "ruff_text_size", + "unicode-ident", +] + +[[package]] +name = "ruff_source_file" +version = "0.0.0" +source = "git+https://github.com/astral-sh/ruff.git?rev=6686f63404207bfdffe8ab0acb25da72c3432190#6686f63404207bfdffe8ab0acb25da72c3432190" +dependencies = [ + "memchr", + "ruff_text_size", +] + +[[package]] +name = "ruff_text_size" +version = "0.0.0" +source = "git+https://github.com/astral-sh/ruff.git?rev=6686f63404207bfdffe8ab0acb25da72c3432190#6686f63404207bfdffe8ab0acb25da72c3432190" +dependencies = [ + "get-size2", +] + [[package]] name = "rustc-hash" version = "2.1.2" @@ -3459,12 +3528,12 @@ dependencies = [ "libc", "log", "pyo3", + "ruff_python_parser", "rustls", "rustls-graviola", "rustpython-capi", "rustpython-compiler", "rustpython-pylib", - "rustpython-ruff_python_parser", "rustpython-stdlib", "rustpython-vm", "rustyline", @@ -3497,11 +3566,11 @@ dependencies = [ "num-complex", "num-traits", "rapidhash", + "ruff_python_ast", + "ruff_python_parser", + "ruff_text_size", "rustpython-compiler-core", "rustpython-literal", - "rustpython-ruff_python_ast", - "rustpython-ruff_python_parser", - "rustpython-ruff_text_size", "rustpython-wtf8", "thiserror", "unicode_names2 2.0.0", @@ -3534,12 +3603,12 @@ dependencies = [ name = "rustpython-compiler" version = "0.5.0" dependencies = [ + "ruff_python_ast", + "ruff_python_parser", + "ruff_source_file", + "ruff_text_size", "rustpython-codegen", "rustpython-compiler-core", - "rustpython-ruff_python_ast", - "rustpython-ruff_python_parser", - "rustpython-ruff_source_file", - "rustpython-ruff_text_size", "thiserror", ] @@ -3554,7 +3623,7 @@ dependencies = [ "malachite-bigint", "num-complex", "num-traits", - "rustpython-ruff_source_file", + "ruff_source_file", "rustpython-wtf8", ] @@ -3562,8 +3631,8 @@ dependencies = [ name = "rustpython-compiler-source" version = "0.4.1+deprecated" dependencies = [ - "rustpython-ruff_source_file", - "rustpython-ruff_text_size", + "ruff_source_file", + "ruff_text_size", ] [[package]] @@ -3658,77 +3727,6 @@ dependencies = [ "rustpython-derive", ] -[[package]] -name = "rustpython-ruff_python_ast" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f021ff72cabf5e2cd6d8ec8813d376a8445a228dc610ab56c27bd9054cda70d4" -dependencies = [ - "aho-corasick", - "bitflags 2.13.0", - "compact_str", - "get-size2", - "is-macro", - "memchr", - "rustc-hash", - "rustpython-ruff_python_trivia", - "rustpython-ruff_source_file", - "rustpython-ruff_text_size", - "thiserror", -] - -[[package]] -name = "rustpython-ruff_python_parser" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01e6ee78bd9671fb5766664b2695fe1f2a92a961f4d9101646c570d8acdb1e0b" -dependencies = [ - "bitflags 2.13.0", - "bstr", - "compact_str", - "get-size2", - "memchr", - "rustc-hash", - "rustpython-ruff_python_ast", - "rustpython-ruff_python_trivia", - "rustpython-ruff_text_size", - "static_assertions", - "unicode-ident", - "unicode-normalization", - "unicode_names2 1.3.0", -] - -[[package]] -name = "rustpython-ruff_python_trivia" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e7cfd1056f3a02ff0d2d0e4474286ca963260782f878b7b81c1dd87432e682" -dependencies = [ - "itertools 0.14.0", - "rustpython-ruff_source_file", - "rustpython-ruff_text_size", - "unicode-ident", -] - -[[package]] -name = "rustpython-ruff_source_file" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "948107aad62ddb12a11fc7bf68a49e52a0b0a3737d415a2505e54f5a9edac737" -dependencies = [ - "memchr", - "rustpython-ruff_text_size", -] - -[[package]] -name = "rustpython-ruff_text_size" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8291ee0f5a779e54ccd4e0151a0c426f8b49a123f99b5b6545db17ccdd4277aa" -dependencies = [ - "get-size2", -] - [[package]] name = "rustpython-sre_engine" version = "0.5.0" @@ -3794,6 +3792,10 @@ dependencies = [ "pymath", "rand 0.10.1", "rapidhash", + "ruff_python_ast", + "ruff_python_parser", + "ruff_source_file", + "ruff_text_size", "rustls", "rustls-native-certs", "rustls-pemfile", @@ -3802,10 +3804,6 @@ dependencies = [ "rustpython-derive", "rustpython-host_env", "rustpython-pylib", - "rustpython-ruff_python_ast", - "rustpython-ruff_python_parser", - "rustpython-ruff_source_file", - "rustpython-ruff_text_size", "rustpython-vm", "sha1 0.11.0", "sha2", @@ -3866,6 +3864,9 @@ dependencies = [ "psm", "rapidhash", "result-like", + "ruff_python_ast", + "ruff_python_parser", + "ruff_text_size", "rustpython-codegen", "rustpython-common", "rustpython-compiler", @@ -3874,9 +3875,6 @@ dependencies = [ "rustpython-host_env", "rustpython-jit", "rustpython-literal", - "rustpython-ruff_python_ast", - "rustpython-ruff_python_parser", - "rustpython-ruff_text_size", "rustpython-sre_engine", "rustyline", "scopeguard", @@ -3884,6 +3882,7 @@ dependencies = [ "static_assertions", "strum", "strum_macros", + "thin-vec", "thiserror", "timsort", "wasm-bindgen", @@ -3908,6 +3907,7 @@ version = "0.5.0" dependencies = [ "console_error_panic_hook", "js-sys", + "ruff_text_size", "rustpython-common", "rustpython-pylib", "rustpython-stdlib", @@ -4392,6 +4392,12 @@ version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057" +[[package]] +name = "thin-vec" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0f7e269b48f0a7dd0146680fa24b50cc67fc0373f086a5b2f99bd084639b482" + [[package]] name = "thiserror" version = "2.0.18" diff --git a/Cargo.toml b/Cargo.toml index 677de28ef1c..f13dc023fab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -183,19 +183,16 @@ rustpython-sre_engine = { path = "crates/sre_engine", version = "0.5.0" } rustpython-wtf8 = { path = "crates/wtf8", version = "0.5.0" } rustpython-doc = { path = "crates/doc", version = "0.5.0" } -# Use RustPython-packaged Ruff crates from the published fork while keeping -# existing crate names in the codebase. -ruff_python_parser = { package = "rustpython-ruff_python_parser", version = "0.15.8" } -ruff_python_ast = { package = "rustpython-ruff_python_ast", version = "0.15.8" } -ruff_text_size = { package = "rustpython-ruff_text_size", version = "0.15.8" } -ruff_source_file = { package = "rustpython-ruff_source_file", version = "0.15.8" } -# To update ruff crates, comment out the above lines and uncomment the following lines to pull directly from the Ruff repository at the specified commit hash. -# Ruff tag 0.15.8 is based on commit c2a8815842f9dc5d24ec19385eae0f1a7188b0d9 -# at the time of this capture. We use the commit hash to ensure reproducible builds. -# ruff_python_parser = { git = "https://github.com/astral-sh/ruff.git", rev = "c2a8815842f9dc5d24ec19385eae0f1a7188b0d9" } -# ruff_python_ast = { git = "https://github.com/astral-sh/ruff.git", rev = "c2a8815842f9dc5d24ec19385eae0f1a7188b0d9" } -# ruff_text_size = { git = "https://github.com/astral-sh/ruff.git", rev = "c2a8815842f9dc5d24ec19385eae0f1a7188b0d9" } -# ruff_source_file = { git = "https://github.com/astral-sh/ruff.git", rev = "c2a8815842f9dc5d24ec19385eae0f1a7188b0d9" } +# Use upstream Ruff directly while the RustPython-packaged Ruff crates lag +# behind the parser API used by this branch. +ruff_python_parser = { git = "https://github.com/astral-sh/ruff.git", rev = "6686f63404207bfdffe8ab0acb25da72c3432190" } +ruff_python_ast = { git = "https://github.com/astral-sh/ruff.git", rev = "6686f63404207bfdffe8ab0acb25da72c3432190" } +ruff_text_size = { git = "https://github.com/astral-sh/ruff.git", rev = "6686f63404207bfdffe8ab0acb25da72c3432190" } +ruff_source_file = { git = "https://github.com/astral-sh/ruff.git", rev = "6686f63404207bfdffe8ab0acb25da72c3432190" } +# ruff_python_parser = { package = "rustpython-ruff_python_parser", version = "0.15.8" } +# ruff_python_ast = { package = "rustpython-ruff_python_ast", version = "0.15.8" } +# ruff_text_size = { package = "rustpython-ruff_text_size", version = "0.15.8" } +# ruff_source_file = { package = "rustpython-ruff_source_file", version = "0.15.8" } der = { version = "0.8", features = ["alloc", "oid", "pem", "zeroize"] } phf = { version = "0.13.1", default-features = false, features = ["macros"]} @@ -311,6 +308,7 @@ tcl-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.2.0" textwrap = { version = "0.16.2", default-features = false } termios = "0.3.3" thiserror = "2.0" +thin-vec = "0.2.14" timsort = "0.1.2" tk-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.2.0" } icu_casemap = "2" diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index 00283ca05a0..31fd6296451 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -150,7 +150,6 @@ def test_parse_invalid_ast(self): self.assertRaises(TypeError, ast.parse, ast.Constant(42), optimize=optval) - @unittest.expectedFailure # TODO: RUSTPYTHON; ValueError: compile() unrecognized flags def test_optimization_levels__debug__(self): cases = [(-1, '__debug__'), (0, '__debug__'), (1, False), (2, False)] for (optval, expected) in cases: @@ -586,7 +585,6 @@ def test_invalid_sum(self): compile(m, "", "exec") self.assertIn("but got expr()", str(cm.exception)) - @unittest.expectedFailure # TODO: RUSTPYTHON; ValueError: expected str for name def test_invalid_identifier(self): m = ast.Module([ast.Expr(ast.Name(42, ast.Load()))], []) ast.fix_missing_locations(m) @@ -1365,7 +1363,6 @@ def test_replace_ignore_known_custom_instance_fields(self): self.assertIs(repl.ctx, context) self.assertRaises(AttributeError, getattr, repl, 'extra') - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "Name\.__replace__\ missing\ 1\ keyword\ argument:\ 'id'\." does not match "replace() does not support Name objects" def test_replace_reject_missing_field(self): # case: warn if deleted field is not replaced node = ast.parse('x').body[0].value @@ -1700,7 +1697,6 @@ def check_text(code, empty, full, **kwargs): full="Module(body=[Import(names=[alias(name='_ast', asname='ast')]), ImportFrom(module='module', names=[alias(name='sub')], level=0)], type_ignores=[])", ) - @unittest.expectedFailure # TODO: RUSTPYTHON; ? ^^^^^^^^^ ^^^^^^^^^ def test_copy_location(self): src = ast.parse('1 + 1', mode='eval') src.body.right = ast.copy_location(ast.Constant(2), src.body.right) @@ -1737,7 +1733,6 @@ def test_fix_missing_locations(self): "end_col_offset=0), lineno=1, col_offset=0, end_lineno=1, end_col_offset=0)])" ) - @unittest.expectedFailure # TODO: RUSTPYTHON; ? ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ def test_increment_lineno(self): src = ast.parse('1 + 1', mode='eval') self.assertEqual(ast.increment_lineno(src, n=3), src) @@ -1959,7 +1954,6 @@ def test_literal_eval_syntax_errors(self): (\ \ ''') - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: required field "lineno" missing from alias def test_bad_integer(self): # issue13436: Bad error message with invalid numeric values body = [ast.ImportFrom(module='time', @@ -3259,7 +3253,6 @@ class MyAttrs(ast.AST): r"MyAttrs.__init__ got an unexpected keyword argument 'c'."): obj = MyAttrs(c=3) - @unittest.expectedFailure # TODO: RUSTPYTHON; DeprecationWarning not triggered def test_fields_and_types_no_default(self): class FieldsAndTypesNoDefault(ast.AST): _fields = ('a',) @@ -3273,7 +3266,6 @@ class FieldsAndTypesNoDefault(ast.AST): obj = FieldsAndTypesNoDefault(a=1) self.assertEqual(obj.a, 1) - @unittest.expectedFailure # TODO: RUSTPYTHON; DeprecationWarning not triggered def test_incomplete_field_types(self): class MoreFieldsThanTypes(ast.AST): _fields = ('a', 'b') @@ -3293,7 +3285,6 @@ class MoreFieldsThanTypes(ast.AST): self.assertEqual(obj.a, 1) self.assertEqual(obj.b, 2) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'str' but 'bytes' found. def test_malformed_fields_with_bytes(self): class BadFields(ast.AST): _fields = (b'\xff'*64,) @@ -3713,7 +3704,6 @@ def assert_ast(self, code, non_optimized_target, optimized_target): f"{ast.dump(optimized_tree)}", ) - @unittest.expectedFailure # TODO: RUSTPYTHON; ValueError: compile() unrecognized flags def test_folding_format(self): code = "'%s' % (a,)" diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py index d01d36ad3db..690a6e7434e 100644 --- a/Lib/test/test_audit.py +++ b/Lib/test/test_audit.py @@ -77,7 +77,6 @@ def test_monkeypatch(self): def test_open(self): self.do_test("test_open", os_helper.TESTFN) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_cantrace(self): self.do_test("test_cantrace") diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 163ebcfb5bd..10783cf33e2 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -486,7 +486,6 @@ def test_compile_top_level_await_no_coro(self): msg=f"source={source} mode={mode}") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_compile_top_level_await(self): """Test whether code with top level await can be compiled. @@ -627,7 +626,6 @@ def test_compile_async_generator(self): exec(co, glob) self.assertEqual(type(glob['ticker']()), AsyncGeneratorType) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: <_ast.Name object at 0xb40000731e3d1360> is not an instance of def test_compile_ast(self): args = ("a*__debug__", "f.py", "exec") raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0] @@ -1020,7 +1018,6 @@ def test_exec_redirected(self): finally: sys.stdout = savestdout - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument closure def test_exec_closure(self): def function_without_closures(): return 3 * 5 diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index 8b8c452f676..16df318ae8e 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -645,7 +645,6 @@ def test_syntaxerror_indented_caret_position(self): self.assertNotIn("\f", text) self.assertIn("\n 1 + 1 = 2\n ^^^^^\n", text) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_syntaxerror_multi_line_fstring(self): script = 'foo = f"""{}\nfoo"""\n' with os_helper.temp_dir() as script_dir: diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py index 12976122241..2e1568d5ea2 100644 --- a/Lib/test/test_codeop.py +++ b/Lib/test/test_codeop.py @@ -279,7 +279,6 @@ def test_filename(self): self.assertNotEqual(compile_command("a = 1\n", "abc").co_filename, compile("a = 1\n", "def", 'single').co_filename) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 2 def test_warning(self): # Test that the warning is only returned once. with warnings_helper.check_warnings( diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py index 4d117be1b88..a6542b396cc 100644 --- a/Lib/test/test_compile.py +++ b/Lib/test/test_compile.py @@ -209,7 +209,6 @@ def test_literals_with_leading_zeroes(self): self.assertEqual(eval("0o777"), 511) self.assertEqual(eval("-0o0000010"), -8) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised def test_int_literals_too_long(self): n = 3000 source = f"a = 1\nb = 2\nc = {'3'*n}\nd = 4" @@ -283,7 +282,6 @@ def test_none_assignment(self): self.assertRaises(SyntaxError, compile, stmt, 'tmp', 'single') self.assertRaises(SyntaxError, compile, stmt, 'tmp', 'exec') - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised by compile def test_import(self): succeed = [ 'import sys', @@ -348,7 +346,6 @@ def test_lambda_consts(self): l = lambda: "this is the only const" self.assertEqual(l.__code__.co_consts, ("this is the only const",)) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised by compile def test_encoding(self): code = b'# -*- coding: badencoding -*-\npass\n' self.assertRaises(SyntaxError, compile, code, 'tmp', 'exec') @@ -465,7 +462,6 @@ def test_condition_expression_with_dead_blocks_compiles(self): # See gh-113054 compile('if (5 if 5 else T): 0', '', 'exec') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_condition_expression_with_redundant_comparisons_compiles(self): # See gh-113054, gh-114083 exprs = [ @@ -580,7 +576,6 @@ def test_compile_redundant_jump_after_convert_pseudo_ops(self): compile(ast.fix_missing_locations(tree), "", "exec") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: at 0xb77555080 file "1", line 1> != at 0xb77554f00 file "3", line 1> def test_compile_ast(self): fname = __file__ if fname.lower().endswith('pyc'): @@ -696,7 +691,6 @@ def test_single_statement(self): self.compile_single("class T:\n pass") self.compile_single("c = '''\na=1\nb=2\nc=3\n'''") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised by compile_single def test_bad_single_statement(self): self.assertInvalidSingle('1\n2') self.assertInvalidSingle('def f(): pass') @@ -708,7 +702,6 @@ def test_bad_single_statement(self): self.assertInvalidSingle('x = 5 # comment\nx = 6\n') self.assertInvalidSingle("c = '''\nd=1\n'''\na = 1\n\nb = 2\n") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: b'source code cannot contain null bytes' not found in b'OSError: stream did not contain valid UTF-8\n' def test_particularly_evil_undecodable(self): # Issue 24022 src = b'0000\x00\n00000000000\n\x00\n\x9e\n' @@ -719,7 +712,6 @@ def test_particularly_evil_undecodable(self): res = script_helper.run_python_until_end(fn)[0] self.assertIn(b"source code cannot contain null bytes", res.err) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: b'source code cannot contain null bytes' not found in b'OSError: stream did not contain valid UTF-8\n' def test_yet_more_evil_still_undecodable(self): # Issue #25388 src = b"#\x00\n#\xfd\n" @@ -756,7 +748,6 @@ def check_limit(prefix, repeated, mode="single"): # check_limit("a", " if a else a") # check_limit("if a: pass", "\nelif a: pass", mode="exec") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "cannot contain null" does not match "invalid syntax (, line 1)" def test_null_terminated(self): # The source code is null-terminated internally, but bytes-like # objects are accepted, which could be not terminated. @@ -1673,7 +1664,6 @@ class WeirdDict(dict): self.assertRaises(NameError, ns['foo']) - @unittest.expectedFailure # TODO: RUSTPYTHON; + [3, 5, 3, 5] def test_compile_warnings(self): # Each invocation of compile() emits compiler warnings, even if they # have the same message and line number. @@ -1691,7 +1681,6 @@ def test_compile_warnings(self): self.assertEqual([wm.lineno for wm in caught], [3, 5] * 2) - @unittest.expectedFailure # TODO: RUSTPYTHON; + [5, 9] def test_compile_warning_in_finally(self): # Ensure that warnings inside finally blocks are # only emitted once despite the block being @@ -1742,7 +1731,6 @@ def test_compile_warning_in_finally(self): self.assertEqual(wm.category, SyntaxWarning) self.assertIn("\"is\" with 'int' literal", str(wm.message)) - @unittest.expectedFailure # TODO: RUSTPYTHON @support.subTests('src', [ textwrap.dedent(""" def f(): diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py index 7e79732a3b9..ae1f91e65b5 100644 --- a/Lib/test/test_exceptions.py +++ b/Lib/test/test_exceptions.py @@ -231,7 +231,6 @@ def check(self, src, lineno, offset, end_lineno=None, end_offset=None, encoding= line = line.removeprefix('\ufeff') self.assertIn(line, cm.exception.text) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_offset_continuation_characters(self): check = self.check check('"\\\n"(1 for c in I,\\\n\\', 2, 2) diff --git a/Lib/test/test_fstring.py b/Lib/test/test_fstring.py index f4fca1caec7..e35d5118f18 100644 --- a/Lib/test/test_fstring.py +++ b/Lib/test/test_fstring.py @@ -701,7 +701,6 @@ def test_double_braces(self): ["f'{ {{}} }'", # dict in a set ]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_compile_time_concat(self): x = 'def' self.assertEqual('abc' f'## {x}ghi', 'abc## defghi') @@ -816,7 +815,6 @@ def build_fstr(n, extra=''): s = "f'{1}' 'x' 'y'" * 1024 self.assertEqual(eval(s), '1xy' * 1024) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_format_specifier_expressions(self): width = 10 precision = 4 @@ -947,7 +945,6 @@ def test_parens_in_expressions(self): ["f'{3)+(4}'", ]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_newlines_before_syntax_error(self): self.assertAllRaise(SyntaxError, "f-string: expecting a valid expression after '{'", @@ -1031,7 +1028,6 @@ def test_misformed_unicode_character_name(self): r"'\N{GREEK CAPITAL LETTER DELTA'", ]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_backslashes_in_expression_part(self): self.assertEqual(f"{( 1 + @@ -1732,7 +1728,6 @@ def test_with_an_underscore_and_a_comma_in_format_specifier(self): with self.assertRaisesRegex(ValueError, error_msg): f'{1:_,}' - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "f-string: expecting a valid expression after '{'" does not match "invalid syntax (?, line 1)" def test_syntax_error_for_starred_expressions(self): with self.assertRaisesRegex(SyntaxError, "can't use starred expression here"): compile("f'{*a}'", "?", "exec") diff --git a/Lib/test/test_future_stmt/test_future.py b/Lib/test/test_future_stmt/test_future.py index faa5f4cc683..8d2050a3936 100644 --- a/Lib/test/test_future_stmt/test_future.py +++ b/Lib/test/test_future_stmt/test_future.py @@ -81,7 +81,6 @@ def test_future_multiple_features(self): ): from test.test_future_stmt import test_future_multiple_features # noqa: F401 - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 1 != 24 def test_unknown_future_flag(self): code = """ from __future__ import nested_scopes @@ -135,14 +134,12 @@ def test_multiple_import_statements_on_same_line(self): """ self.assertSyntaxError(code, offset=54) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 1 != 24 def test_future_import_star(self): code = """ from __future__ import * """ self.assertSyntaxError(code, message='future feature * is not defined', offset=24) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_future_import_braces(self): code = """ from __future__ import braces @@ -188,7 +185,6 @@ def test_syntactical_future_repl(self): out = kill_python(p) self.assertNotIn(b'SyntaxError: invalid syntax', out) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_future_dotted_import(self): with self.assertRaises(ImportError): exec("from .__future__ import spam") @@ -480,7 +476,6 @@ def bar(): self.assertEqual(foo.__code__.co_cellvars, ()) self.assertEqual(foo().__code__.co_freevars, ()) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised def test_annotations_forbidden(self): with self.assertRaises(SyntaxError): self._exec_future("test: (yield)") diff --git a/Lib/test/test_genexps.py b/Lib/test/test_genexps.py index fde12f13cdc..17d2d137074 100644 --- a/Lib/test/test_genexps.py +++ b/Lib/test/test_genexps.py @@ -159,7 +159,7 @@ ... SyntaxError: cannot assign to generator expression - >>> (y for y in (1,2)) += 10 # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + >>> (y for y in (1,2)) += 10 Traceback (most recent call last): ... SyntaxError: 'generator expression' is an illegal expression for augmented assignment diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py index 1f55dfbe1ac..11d0bd54e8b 100644 --- a/Lib/test/test_global.py +++ b/Lib/test/test_global.py @@ -28,7 +28,6 @@ def setUp(self): ### Syntax error cases as covered in Python/symtable.c ###################################################### - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 12 != 5 def test_name_param(self): prog_text = """\ def fn(name_param): @@ -36,7 +35,6 @@ def fn(name_param): """ check_syntax_error(self, prog_text, lineno=2, offset=5) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 12 != 5 def test_name_after_assign(self): prog_text = """\ def fn(): @@ -45,7 +43,6 @@ def fn(): """ check_syntax_error(self, prog_text, lineno=3, offset=5) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 12 != 5 def test_name_after_use(self): prog_text = """\ def fn(): @@ -54,7 +51,6 @@ def fn(): """ check_syntax_error(self, prog_text, lineno=3, offset=5) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 12 != 5 def test_name_annot(self): prog_text_3 = """\ def fn(): diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py index cf90de7b115..19440b10115 100644 --- a/Lib/test/test_grammar.py +++ b/Lib/test/test_grammar.py @@ -114,7 +114,6 @@ def test_underscore_literals(self): # Sanity check: no literal begins with an underscore self.assertRaises(NameError, eval, "_0") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_numerical_literals(self): check = self.check_syntax_error check("0b12", "invalid digit '2' in binary literal") @@ -137,7 +136,6 @@ def test_bad_numerical_literals(self): check("1e2_", "invalid decimal literal") check("1e+", "invalid decimal literal") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_end_of_numerical_literals(self): def check(test, error=False): with self.subTest(expr=test): @@ -251,7 +249,6 @@ def test_eof_error(self): compile(s, "", "exec") self.assertIn("was never closed", str(cm.exception)) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised @skip_wasi_stack_overflow() def test_max_level(self): # Macro defined in Parser/lexer/state.h @@ -298,7 +295,6 @@ def one(): my_lst[one()-1]: int = 5 self.assertEqual(my_lst, [5]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_var_annot_syntax_errors(self): # parser pass check_syntax_error(self, "def f: int") @@ -751,7 +747,6 @@ def test_expr_stmt(self): # Check the heuristic for print & exec covers significant cases # As well as placing some limits on false positives - @unittest.expectedFailure # TODO: RUSTPYTHON def test_former_statements_refer_to_builtins(self): keywords = "print", "exec" # Cases where we want the custom error @@ -1165,7 +1160,6 @@ def continue_in_finally_after_return2(x): """, True) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_yield(self): # Allowed as standalone statement def g(): yield 1 @@ -1205,7 +1199,6 @@ def g(): rest = 4, 5, 6; yield 1, 2, 3, *rest # Check annotation refleak on SyntaxError check_syntax_error(self, "def g(a:(yield)): pass") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_yield_in_comprehensions(self): # Check yield in comprehensions def g(): [x for x in [(yield 1)]] @@ -1302,7 +1295,6 @@ def test_assert_failures(self): else: self.fail("AssertionError not raised by 'assert False'") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_assert_syntax_warnings(self): # Ensure that we warn users if they provide a non-zero length tuple as # the assertion test. @@ -1317,7 +1309,6 @@ def test_assert_syntax_warnings(self): compile('assert x, "msg"', '', 'exec') compile('assert False, "msg"', '', 'exec') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_assert_warning_promotes_to_syntax_error(self): # If SyntaxWarning is configured to be an error, it actually raises a # SyntaxError. @@ -1496,7 +1487,6 @@ def test_comparison(self): if 1 not in (): pass if 1 < 1 > 1 == 1 >= 1 <= 1 != 1 in 1 not in x is x is not x: pass - @unittest.expectedFailure # TODO: RUSTPYTHON def test_comparison_is_literal(self): def check(test, msg): self.check_syntax_warning(test, msg) @@ -1526,7 +1516,6 @@ def check(test, msg): compile('True is x', '', 'exec') compile('... is x', '', 'exec') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_warn_missed_comma(self): def check(test): self.check_syntax_warning(test, msg) diff --git a/Lib/test/test_listcomps.py b/Lib/test/test_listcomps.py index 5dbc130b4c5..5e09fad72d8 100644 --- a/Lib/test/test_listcomps.py +++ b/Lib/test/test_listcomps.py @@ -219,7 +219,6 @@ class i: [__conditional_annotations__ for x in y] """ self._check_in_scopes(code, raises=NameError) - @unittest.expectedFailure # TODO: RUSTPYTHON; SyntaxError: compiler_make_closure: cannot find '__conditional_annotations__' in parent vars def test_references___conditional_annotations___nested(self): code = """ class i: [lambda: __conditional_annotations__ for x in y] diff --git a/Lib/test/test_named_expressions.py b/Lib/test/test_named_expressions.py index 4f92176b301..2e0643484fc 100644 --- a/Lib/test/test_named_expressions.py +++ b/Lib/test/test_named_expressions.py @@ -4,35 +4,30 @@ class NamedExpressionInvalidTest(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_01(self): code = """x := 0""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_02(self): code = """x = y := 0""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_03(self): code = """y := f(x)""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_04(self): code = """y0 = y1 := f(x)""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_06(self): code = """((a, b) := (1, 2))""" diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 8d359a646d9..5a06972fdde 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -82,7 +82,6 @@ class S4(collections.UserList, dict, C): self.assertEqual(self.check_mapping_then_sequence(S3()), "seq") self.assertEqual(self.check_mapping_then_sequence(S4()), "seq") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_late_registration_mapping(self): class Parent: pass @@ -106,7 +105,6 @@ class GrandchildPost(ChildPost): self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "map") self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "map") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_late_registration_sequence(self): class Parent: pass @@ -2246,7 +2244,6 @@ def f(w): self.assertEqual(f(None), {}) self.assertEqual(f((1, 2)), {}) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_patma_210(self): def f(w): match w: @@ -2955,7 +2952,6 @@ def test_invalid_syntax_2(self): pass """) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_syntax_3(self): self.assert_syntax_error(""" match ...: @@ -3075,7 +3071,6 @@ def test_name_capture_makes_remaining_patterns_unreachable_4(self): pass """) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_patterns_may_only_match_literals_and_attribute_lookups_0(self): self.assert_syntax_error(""" match ...: @@ -3083,7 +3078,6 @@ def test_patterns_may_only_match_literals_and_attribute_lookups_0(self): pass """) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_patterns_may_only_match_literals_and_attribute_lookups_1(self): self.assert_syntax_error(""" match ...: @@ -3126,7 +3120,6 @@ def test_real_number_multiple_ops(self): pass """) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_real_number_wrong_ops(self): for op in ["*", "/", "@", "**", "%", "//"]: with self.subTest(op=op): @@ -3202,7 +3195,6 @@ def test_mapping_pattern_duplicate_key(self): pass """) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_mapping_pattern_duplicate_key_edge_case0(self): self.assert_syntax_error(""" match ...: @@ -3210,7 +3202,6 @@ def test_mapping_pattern_duplicate_key_edge_case0(self): pass """) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_mapping_pattern_duplicate_key_edge_case1(self): self.assert_syntax_error(""" match ...: @@ -3225,8 +3216,6 @@ def test_mapping_pattern_duplicate_key_edge_case2(self): pass """) - - @unittest.expectedFailure # TODO: RUSTPYTHON def test_mapping_pattern_duplicate_key_edge_case3(self): self.assert_syntax_error(""" match ...: @@ -3258,7 +3247,6 @@ def test_accepts_positional_subpatterns_1(self): self.assertEqual(x, range(10)) self.assertIs(y, None) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_got_multiple_subpatterns_for_attribute_0(self): class Class: __match_args__ = ("a", "a") @@ -3273,7 +3261,6 @@ class Class: self.assertIs(y, None) self.assertIs(z, None) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_got_multiple_subpatterns_for_attribute_1(self): class Class: __match_args__ = ("a",) @@ -3379,7 +3366,6 @@ class A: class TestValueErrors(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON def test_mapping_pattern_checks_duplicate_key_1(self): class Keys: KEY = "a" diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py index 8b2806781af..97f084088c5 100644 --- a/Lib/test/test_pdb.py +++ b/Lib/test/test_pdb.py @@ -2165,7 +2165,7 @@ def test_pdb_await_support(): >>> def test_function(): ... asyncio.run(main(), loop_factory=asyncio.EventLoop) - >>> with PdbTestInput([ # TODO: RUSTPYTHON # doctest: +ELLIPSIS +EXPECTED_FAILURE + >>> with PdbTestInput([ # doctest: +ELLIPSIS ... 'x = await task', ... 'p x', ... 'x = await test()', @@ -2280,7 +2280,7 @@ def test_pdb_await_contextvar(): >>> def test_function(): ... asyncio.run(main(), loop_factory=asyncio.EventLoop) - >>> with PdbTestInput([ # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + >>> with PdbTestInput([ ... 'p var.get()', ... 'print(await get_var())', ... 'print(await asyncio.create_task(set_var(100)))', @@ -2768,7 +2768,7 @@ def test_pdb_multiline_statement(): >>> def test_function(): ... import pdb; pdb.Pdb(nosigint=True, readrc=False).set_trace() - >>> with PdbTestInput([ # TODO: RUSTPYTHON # doctest: +NORMALIZE_WHITESPACE +EXPECTED_FAILURE + >>> with PdbTestInput([ # doctest: +NORMALIZE_WHITESPACE ... 'def f(x):', ... ' return x * 2', ... '', diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py index eb251568767..14657dd2e77 100644 --- a/Lib/test/test_peepholer.py +++ b/Lib/test/test_peepholer.py @@ -157,7 +157,6 @@ def test_pack_unpack(self): self.assertNotInBytecode(code, 'UNPACK_SEQUENCE') self.check_lnotab(code) - @unittest.expectedFailure # TODO: RUSTPYTHON; LOAD_CONST count mismatch in long-tuple branch def test_constant_folding_tuples_of_constants(self): for line, elem in ( ('a = 1,2,3', (1, 2, 3)), diff --git a/Lib/test/test_pep646_syntax.py b/Lib/test/test_pep646_syntax.py index 8034bb9e935..d79196219fe 100644 --- a/Lib/test/test_pep646_syntax.py +++ b/Lib/test/test_pep646_syntax.py @@ -312,7 +312,7 @@ >>> f4.__annotations__ {'args': StarredB, 'arg1': } - >>> def f5(*args: *b = (1,)): pass # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + >>> def f5(*args: *b = (1,)): pass Traceback (most recent call last): ... SyntaxError: invalid syntax diff --git a/Lib/test/test_pydoc/test_pydoc.py b/Lib/test/test_pydoc/test_pydoc.py index 46f8ba60f8b..2a96ef4dd71 100644 --- a/Lib/test/test_pydoc/test_pydoc.py +++ b/Lib/test/test_pydoc/test_pydoc.py @@ -932,7 +932,6 @@ def test_synopsis(self): synopsis = pydoc.synopsis(TESTFN, {}) self.assertEqual(synopsis, 'line 1: h\xe9') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_source_synopsis(self): def check(source, expected, encoding=None): if isinstance(source, str): diff --git a/Lib/test/test_pyrepl/test_interact.py b/Lib/test/test_pyrepl/test_interact.py index e4f90db3304..65b1eed5bdd 100644 --- a/Lib/test/test_pyrepl/test_interact.py +++ b/Lib/test/test_pyrepl/test_interact.py @@ -117,7 +117,6 @@ def f(x, x): ... SyntaxError: duplicate argument 'x' in function definition""" self.assertIn(r, f.getvalue()) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_runsource_shows_syntax_error_for_failed_compilation(self): console = InteractiveColoredConsole() source = "print('Hello, world!'" @@ -133,7 +132,6 @@ def test_runsource_shows_syntax_error_for_failed_compilation(self): console.runsource(source) mock_showsyntaxerror.assert_called_once() - @unittest.expectedFailure # TODO: RUSTPYTHON def test_runsource_survives_null_bytes(self): console = InteractiveColoredConsole() source = "\x00\n" @@ -155,7 +153,6 @@ def test_no_active_future(self): self.assertFalse(result) self.assertEqual(f.getvalue(), "{'x': }\n") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_future_annotations(self): console = InteractiveColoredConsole() source = dedent("""\ @@ -210,7 +207,6 @@ def test_multiline_single_assignment(self): console = InteractiveColoredConsole(namespace, filename="") self.assertFalse(_more_lines(console, code)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_multiline_single_block(self): namespace = {} code = dedent("""\ @@ -227,7 +223,6 @@ def test_multiple_statements_single_line(self): console = InteractiveColoredConsole(namespace, filename="") self.assertFalse(_more_lines(console, code)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_multiple_statements(self): namespace = {} code = dedent("""\ @@ -237,7 +232,6 @@ def test_multiple_statements(self): console = InteractiveColoredConsole(namespace, filename="") self.assertTrue(_more_lines(console, code)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_multiple_blocks(self): namespace = {} code = dedent("""\ @@ -285,7 +279,6 @@ def test_incomplete_statement(self): class TestWarnings(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON def test_pep_765_warning(self): """ Test that a SyntaxWarning emitted from the diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index 74735ef3c84..1bf3f9715b4 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -466,7 +466,6 @@ def prepare_reader(self, events): reader = ReadlineAlikeReader(console=console, config=config) return reader - @unittest.expectedFailure # TODO: RUSTPYTHON def test_auto_indent_default(self): # fmt: off input_code = ( @@ -486,7 +485,6 @@ def test_auto_indent_default(self): output = multiline_input(reader) self.assertEqual(output, output_code) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_auto_indent_continuation(self): # auto indenting according to previous user indentation # fmt: off @@ -514,7 +512,6 @@ def test_auto_indent_continuation(self): output = multiline_input(reader) self.assertEqual(output, output_code) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_auto_indent_prev_block(self): # auto indenting according to indentation in different block # fmt: off @@ -546,7 +543,6 @@ def test_auto_indent_prev_block(self): output2 = multiline_input(reader) self.assertEqual(output2, output_code) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_auto_indent_multiline(self): # fmt: off events = itertools.chain( @@ -586,7 +582,6 @@ def test_auto_indent_multiline(self): output = multiline_input(reader) self.assertEqual(output, output_code) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_auto_indent_with_comment(self): # fmt: off events = code_to_events( @@ -605,7 +600,6 @@ def test_auto_indent_with_comment(self): output = multiline_input(reader) self.assertEqual(output, output_code) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_auto_indent_with_multicomment(self): # fmt: off events = code_to_events( @@ -680,7 +674,6 @@ def test_get_line_buffer_returns_str(self): wrapper = _ReadlineWrapper(f_in=None, f_out=None, reader=reader) self.assertIs(type(wrapper.get_line_buffer()), str) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_multiline_edit(self): events = itertools.chain( code_to_events("def f():\n...\n\n"), @@ -744,7 +737,6 @@ def test_history_navigation_with_up_arrow(self): self.assertEqual(output, "1+1") self.assert_screen_equal(reader, "1+1", clean=True) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_history_with_multiline_entries(self): code = "def foo():\nx = 1\ny = 2\nz = 3\n\ndef bar():\nreturn 42\n\n" events = list(itertools.chain( @@ -1426,7 +1418,6 @@ def test_paste_mid_newlines(self): output = multiline_input(reader) self.assertEqual(output, code) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_paste_mid_newlines_not_in_paste_mode(self): # fmt: off code = ( @@ -1448,7 +1439,6 @@ def test_paste_mid_newlines_not_in_paste_mode(self): output = multiline_input(reader) self.assertEqual(output, expected) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_paste_not_in_paste_mode(self): # fmt: off input_code = ( diff --git a/Lib/test/test_pyrepl/test_reader.py b/Lib/test/test_pyrepl/test_reader.py index 33ef95accbd..51644ec7ce4 100644 --- a/Lib/test/test_pyrepl/test_reader.py +++ b/Lib/test/test_pyrepl/test_reader.py @@ -180,7 +180,6 @@ def test_up_arrow_after_ctrl_r(self): reader, _ = handle_all_events(events) self.assert_screen_equal(reader, "") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: Lists differ def test_newline_within_block_trailing_whitespace(self): # fmt: off code = ( diff --git a/Lib/test/test_symtable.py b/Lib/test/test_symtable.py index f36bbcaea1f..b55adab6baf 100644 --- a/Lib/test/test_symtable.py +++ b/Lib/test/test_symtable.py @@ -198,7 +198,6 @@ class SymtableTest(unittest.TestCase): T = find_block(GenericMine, "T") U = find_block(GenericMine, "U") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: != 'type alias' def test_type(self): self.assertEqual(self.top.get_type(), "module") self.assertEqual(self.Mine.get_type(), "class") diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py index 0934f22d470..5013eb096f5 100644 --- a/Lib/test/test_syntax.py +++ b/Lib/test/test_syntax.py @@ -59,15 +59,15 @@ Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> def __debug__(): pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def __debug__(): pass Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> async def __debug__(): pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> async def __debug__(): pass Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> class __debug__: pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> class __debug__: pass Traceback (most recent call last): SyntaxError: cannot assign to __debug__ @@ -75,7 +75,7 @@ Traceback (most recent call last): SyntaxError: cannot delete __debug__ ->>> f() = 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f() = 1 Traceback (most recent call last): SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? @@ -83,11 +83,11 @@ Traceback (most recent call last): SyntaxError: assignment to yield expression not possible ->>> del f() # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> del f() Traceback (most recent call last): SyntaxError: cannot delete function call ->>> a + 1 = 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> a + 1 = 2 Traceback (most recent call last): SyntaxError: cannot assign to expression here. Maybe you meant '==' instead of '='? @@ -120,7 +120,7 @@ This test just checks a couple of cases rather than enumerating all of them. ->>> (a, "b", c) = (1, 2, 3) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> (a, "b", c) = (1, 2, 3) Traceback (most recent call last): SyntaxError: cannot assign to literal @@ -168,15 +168,15 @@ Traceback (most recent call last): SyntaxError: expected 'else' after 'if' expression ->>> x = 1 if 1 else pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> x = 1 if 1 else pass Traceback (most recent call last): SyntaxError: expected expression after 'else', but statement is given ->>> x = pass if 1 else 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> x = pass if 1 else 1 Traceback (most recent call last): SyntaxError: expected expression before 'if', but statement is given ->>> x = pass if 1 else pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> x = pass if 1 else pass Traceback (most recent call last): SyntaxError: expected expression before 'if', but statement is given @@ -200,15 +200,15 @@ Traceback (most recent call last): SyntaxError: assignment to yield expression not possible ->>> a, b += 1, 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> a, b += 1, 2 Traceback (most recent call last): SyntaxError: 'tuple' is an illegal expression for augmented assignment ->>> (a, b) += 1, 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> (a, b) += 1, 2 Traceback (most recent call last): SyntaxError: 'tuple' is an illegal expression for augmented assignment ->>> [a, b] += 1, 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> [a, b] += 1, 2 Traceback (most recent call last): SyntaxError: 'list' is an illegal expression for augmented assignment @@ -243,7 +243,7 @@ Traceback (most recent call last): SyntaxError: cannot assign to expression ->>> for i < (): pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> for i < (): pass Traceback (most recent call last): SyntaxError: invalid syntax @@ -285,11 +285,11 @@ Comprehensions without 'in' keyword: ->>> [x for x if range(1)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> [x for x if range(1)] Traceback (most recent call last): SyntaxError: 'in' expected after for-loop variables ->>> tuple(x for x if range(1)) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> tuple(x for x if range(1)) Traceback (most recent call last): SyntaxError: 'in' expected after for-loop variables @@ -301,7 +301,7 @@ Traceback (most recent call last): SyntaxError: cannot assign to expression ->>> [x for a, b, (c + 1, d()) if y] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> [x for a, b, (c + 1, d()) if y] Traceback (most recent call last): SyntaxError: 'in' expected after for-loop variables @@ -316,11 +316,11 @@ Comprehensions creating tuples without parentheses should produce a specialized error message: ->>> [x,y for x,y in range(100)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> [x,y for x,y in range(100)] Traceback (most recent call last): SyntaxError: did you forget parentheses around the comprehension target? ->>> {x,y for x,y in range(100)} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> {x,y for x,y in range(100)} Traceback (most recent call last): SyntaxError: did you forget parentheses around the comprehension target? @@ -385,7 +385,7 @@ # But prefixes of soft keywords should # still raise specialized errors ->>> (mat x) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> (mat x) Traceback (most recent call last): SyntaxError: invalid syntax. Perhaps you forgot a comma? @@ -413,7 +413,7 @@ Traceback (most recent call last): SyntaxError: invalid syntax ->>> def f(*None): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def f(*None): ... pass Traceback (most recent call last): SyntaxError: invalid syntax @@ -423,7 +423,7 @@ Traceback (most recent call last): SyntaxError: invalid syntax ->>> def foo(/,a,b=,c): +>>> def foo(/,a,b=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: at least one argument must precede / @@ -468,12 +468,12 @@ Traceback (most recent call last): SyntaxError: var-positional argument cannot have default value ->>> def foo(a,**b=3): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def foo(a,**b=3): ... pass Traceback (most recent call last): SyntaxError: var-keyword argument cannot have default value ->>> def foo(a,**b: int=3): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def foo(a,**b: int=3): ... pass Traceback (most recent call last): SyntaxError: var-keyword argument cannot have default value @@ -523,22 +523,22 @@ Traceback (most recent call last): SyntaxError: * argument may appear only once ->>> def foo(a=1,/*,b,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def foo(a=1,/*,b,c): ... pass Traceback (most recent call last): SyntaxError: expected comma between / and * ->>> def foo(a=1,d=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def foo(a=1,d=,c): ... pass Traceback (most recent call last): SyntaxError: expected default value expression ->>> def foo(a,d=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def foo(a,d=,c): ... pass Traceback (most recent call last): SyntaxError: expected default value expression ->>> def foo(a,d: int=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> def foo(a,d: int=,c): ... pass Traceback (most recent call last): SyntaxError: expected default value expression @@ -571,7 +571,7 @@ Traceback (most recent call last): SyntaxError: / must be ahead of * ->>> lambda a=1,/*,b,c: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> lambda a=1,/*,b,c: None Traceback (most recent call last): SyntaxError: expected comma between / and * @@ -579,7 +579,7 @@ Traceback (most recent call last): SyntaxError: var-positional argument cannot have default value ->>> lambda a,**b=3: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> lambda a,**b=3: None Traceback (most recent call last): SyntaxError: var-keyword argument cannot have default value @@ -619,11 +619,11 @@ Traceback (most recent call last): SyntaxError: * argument may appear only once ->>> lambda a=1,d=,c: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> lambda a=1,d=,c: None Traceback (most recent call last): SyntaxError: expected default value expression ->>> lambda a,d=,c: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> lambda a,d=,c: None Traceback (most recent call last): SyntaxError: expected default value expression @@ -641,7 +641,7 @@ ... a, # type: int ... ): ... pass -... ''', type_comments=True) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +... ''', type_comments=True) Traceback (most recent call last): SyntaxError: bare * has associated type comment @@ -784,7 +784,7 @@ ... 290, 291, 292, 293, 294, 295, 296, 297, 298, 299) # doctest: +ELLIPSIS (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ..., 297, 298, 299) ->>> f(lambda x: x[0] = 3) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(lambda x: x[0] = 3) Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? @@ -796,25 +796,25 @@ The grammar accepts any test (basically, any expression) in the keyword slot of a call site. Test a few different options. ->>> f(x()=2) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(x()=2) Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> f(a or b=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(a or b=1) Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> f(x.y=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(x.y=1) Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> f((x)=2) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f((x)=2) Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> f(True=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(True=1) Traceback (most recent call last): SyntaxError: cannot assign to True ->>> f(False=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(False=1) Traceback (most recent call last): SyntaxError: cannot assign to False ->>> f(None=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(None=1) Traceback (most recent call last): SyntaxError: cannot assign to None >>> f(__debug__=1) @@ -826,42 +826,42 @@ >>> x.__debug__: int Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> f(a=) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(a=) Traceback (most recent call last): SyntaxError: expected argument value expression ->>> f(a, b, c=) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(a, b, c=) Traceback (most recent call last): SyntaxError: expected argument value expression ->>> f(a, b, c=, d) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(a, b, c=, d) Traceback (most recent call last): SyntaxError: expected argument value expression ->>> f(*args=[0]) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(*args=[0]) Traceback (most recent call last): SyntaxError: cannot assign to iterable argument unpacking ->>> f(a, b, *args=[0]) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(a, b, *args=[0]) Traceback (most recent call last): SyntaxError: cannot assign to iterable argument unpacking ->>> f(**kwargs={'a': 1}) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(**kwargs={'a': 1}) Traceback (most recent call last): SyntaxError: cannot assign to keyword argument unpacking ->>> f(a, b, *args, **kwargs={'a': 1}) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> f(a, b, *args, **kwargs={'a': 1}) Traceback (most recent call last): SyntaxError: cannot assign to keyword argument unpacking More set_context(): ->>> (x for x in x) += 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> (x for x in x) += 1 Traceback (most recent call last): SyntaxError: 'generator expression' is an illegal expression for augmented assignment ->>> None += 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> None += 1 Traceback (most recent call last): SyntaxError: 'None' is an illegal expression for augmented assignment >>> __debug__ += 1 Traceback (most recent call last): SyntaxError: cannot assign to __debug__ >>> f() += 1 # TODO: RUSTPYTHON; Raises an exception # doctest: +SKIP -Traceback (most recent call last): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): SyntaxError: 'function call' is an illegal expression for augmented assignment @@ -957,7 +957,7 @@ elif can't come after an else. - >>> if a % 2 == 0: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> if a % 2 == 0: ... pass ... else: ... pass @@ -1185,7 +1185,7 @@ Traceback (most recent call last): SyntaxError: expected ':' - >>> with (blech as something) + >>> with (blech as something) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected ':' @@ -1195,12 +1195,12 @@ Traceback (most recent call last): SyntaxError: expected ':' - >>> with (blech, block as something) + >>> with (blech, block as something) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected ':' - >>> with (blech, block as something, bluch) + >>> with (blech, block as something, bluch) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected ':' @@ -1313,39 +1313,39 @@ Parenthesized arguments in function definitions - >>> def f(x, (y, z), w): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> def f(x, (y, z), w): ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> def f((x, y, z, w)): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> def f((x, y, z, w)): ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> def f(x, (y, z, w)): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> def f(x, (y, z, w)): ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> def f((x, y, z), w): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> def f((x, y, z), w): ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> lambda x, (y, z), w: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> lambda x, (y, z), w: None Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized - >>> lambda (x, y, z, w): None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> lambda (x, y, z, w): None Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized - >>> lambda x, (y, z, w): None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> lambda x, (y, z, w): None Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized - >>> lambda (x, y, z), w: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> lambda (x, y, z), w: None Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized @@ -1361,7 +1361,7 @@ >>> try: ... pass - ... except TypeError as __debug__: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... except TypeError as __debug__: ... pass Traceback (most recent call last): SyntaxError: cannot assign to __debug__ @@ -1410,28 +1410,28 @@ Better error message for using `except as` with not a name: - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... pass ... except TypeError as obj.attr: ... pass Traceback (most recent call last): SyntaxError: cannot use except statement with attribute - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... pass ... except TypeError as obj[1]: ... pass Traceback (most recent call last): SyntaxError: cannot use except statement with subscript - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... pass ... except* TypeError as (obj, name): ... pass Traceback (most recent call last): SyntaxError: cannot use except* statement with tuple - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... pass ... except* TypeError as 1: ... pass @@ -1440,18 +1440,18 @@ Regression tests for gh-133999: - >>> try: pass - ... except TypeError as name: raise from None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... except TypeError as name: raise from None Traceback (most recent call last): SyntaxError: invalid syntax - >>> try: pass - ... except* TypeError as name: raise from None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... except* TypeError as name: raise from None Traceback (most recent call last): SyntaxError: invalid syntax - >>> match 1: - ... case 1 | 2 as abc: raise from None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match 1: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... case 1 | 2 as abc: raise from None Traceback (most recent call last): SyntaxError: invalid syntax @@ -1464,7 +1464,7 @@ Traceback (most recent call last): SyntaxError: invalid syntax - >>> dict(x=34, (x for x in range 10), 1); x $ y + >>> dict(x=34, (x for x in range 10), 1); x $ y # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax @@ -1474,27 +1474,27 @@ Incomplete dictionary literals - >>> {1:2, 3:4, 5} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1:2, 3:4, 5} Traceback (most recent call last): SyntaxError: ':' expected after dictionary key - >>> {1:2, 3:4, 5:} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1:2, 3:4, 5:} Traceback (most recent call last): SyntaxError: expression expected after dictionary key and ':' - >>> {1: *12+1, 23: 1} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1: *12+1, 23: 1} Traceback (most recent call last): SyntaxError: cannot use a starred expression in a dictionary value - >>> {1: *12+1} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1: *12+1} Traceback (most recent call last): SyntaxError: cannot use a starred expression in a dictionary value - >>> {1: 23, 1: *12+1} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1: 23, 1: *12+1} Traceback (most recent call last): SyntaxError: cannot use a starred expression in a dictionary value - >>> {1:} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1:} Traceback (most recent call last): SyntaxError: expression expected after dictionary key and ':' @@ -1506,7 +1506,7 @@ # Ensure that the error is not raised for invalid expressions - >>> {1: 2, 3: foo(,), 4: 5} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> {1: 2, 3: foo(,), 4: 5} Traceback (most recent call last): SyntaxError: invalid syntax @@ -1516,48 +1516,48 @@ Specialized indentation errors: - >>> while condition: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> while condition: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'while' statement on line 1 - >>> for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> for x in range(10): ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'for' statement on line 1 - >>> for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> for x in range(10): ... pass ... else: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'else' statement on line 3 - >>> async for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> async for x in range(10): ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'for' statement on line 1 - >>> async for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> async for x in range(10): ... pass ... else: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'else' statement on line 3 - >>> if something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> if something: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'if' statement on line 1 - >>> if something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> if something: ... pass ... elif something_else: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'elif' statement on line 3 - >>> if something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> if something: ... pass ... elif something_else: ... pass @@ -1566,33 +1566,33 @@ Traceback (most recent call last): IndentationError: expected an indented block after 'else' statement on line 5 - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'try' statement on line 1 - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... something() ... except: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'except' statement on line 3 - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... something() ... except A: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'except' statement on line 3 - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... something() ... except* A: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'except*' statement on line 3 - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... something() ... except A: ... pass @@ -1601,7 +1601,7 @@ Traceback (most recent call last): IndentationError: expected an indented block after 'finally' statement on line 5 - >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> try: ... something() ... except* A: ... pass @@ -1610,57 +1610,57 @@ Traceback (most recent call last): IndentationError: expected an indented block after 'finally' statement on line 5 - >>> with A: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> with A: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> with A as a, B as b: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> with A as a, B as b: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> with (A as a, B as b): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> with (A as a, B as b): ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> async with A: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> async with A: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> async with A as a, B as b: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> async with A as a, B as b: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> async with (A as a, B as b): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> async with (A as a, B as b): ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> def foo(x, /, y, *, z=2): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> def foo(x, /, y, *, z=2): ... pass Traceback (most recent call last): IndentationError: expected an indented block after function definition on line 1 - >>> def foo[T](x, /, y, *, z=2): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> def foo[T](x, /, y, *, z=2): ... pass Traceback (most recent call last): IndentationError: expected an indented block after function definition on line 1 - >>> class Blech(A): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class Blech(A): ... pass Traceback (most recent call last): IndentationError: expected an indented block after class definition on line 1 - >>> class Blech[T](A): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class Blech[T](A): ... pass Traceback (most recent call last): IndentationError: expected an indented block after class definition on line 1 - >>> class C(__debug__=42): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class C(__debug__=42): ... Traceback (most recent call last): SyntaxError: cannot assign to __debug__ @@ -1668,23 +1668,23 @@ ... def __new__(*args, **kwargs): ... pass - >>> class C(metaclass=Meta, __debug__=42): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class C(metaclass=Meta, __debug__=42): ... pass Traceback (most recent call last): SyntaxError: cannot assign to __debug__ - >>> match something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match something: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'match' statement on line 1 - >>> match something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match something: ... case []: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'case' statement on line 2 - >>> match something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match something: ... case []: ... ... ... case {}: @@ -1981,23 +1981,23 @@ Traceback (most recent call last): SyntaxError: cannot assign to t-string expression here. Maybe you meant '==' instead of '='? ->>> (x, y, z=3, d, e) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> (x, y, z=3, d, e) Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> [x, y, z=3, d, e] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> [x, y, z=3, d, e] Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> [z=3] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> [z=3] Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> {x, y, z=3, d, e} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> {x, y, z=3, d, e} Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> {z=3} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> {z=3} Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? @@ -2009,35 +2009,35 @@ Traceback (most recent call last): SyntaxError: trailing comma not allowed without surrounding parentheses ->>> import a from b # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a from b Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z from b.y.z # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a.y.z from b.y.z Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a from b as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a from b as bar Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z from b.y.z as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a.y.z from b.y.z as bar Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a, b,c from b # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a, b,c from b Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z, b.y.z, c.y.z from b.y.z # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a.y.z, b.y.z, c.y.z from b.y.z Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a,b,c from b as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a,b,c from b as bar Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z, b.y.z, c.y.z from b.y.z as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a.y.z, b.y.z, c.y.z from b.y.z as bar Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? @@ -2061,19 +2061,19 @@ Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> import a as b.c # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a as b.c Traceback (most recent call last): SyntaxError: cannot use attribute as import target ->>> import a.b as (a, b) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a.b as (a, b) Traceback (most recent call last): SyntaxError: cannot use tuple as import target ->>> import a, a.b as 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a, a.b as 1 Traceback (most recent call last): SyntaxError: cannot use literal as import target ->>> import a.b as 'a', a # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> import a.b as 'a', a Traceback (most recent call last): SyntaxError: cannot use literal as import target @@ -2081,7 +2081,7 @@ Traceback (most recent call last): SyntaxError: cannot use attribute as import target ->>> from a import b as 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> from a import b as 1 Traceback (most recent call last): SyntaxError: cannot use literal as import target @@ -2103,11 +2103,11 @@ Traceback (most recent call last): SyntaxError: cannot use tuple as import target ->>> from a import b, с as d[e] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> from a import b, с as d[e] Traceback (most recent call last): SyntaxError: cannot use subscript as import target ->>> from a import с as d[e], b # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +>>> from a import с as d[e], b Traceback (most recent call last): SyntaxError: cannot use subscript as import target @@ -2239,7 +2239,7 @@ Invalid pattern matching constructs: - >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match ...: ... case 42 as _: ... ... Traceback (most recent call last): @@ -2251,13 +2251,13 @@ Traceback (most recent call last): SyntaxError: cannot use expression as pattern target - >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match ...: ... case 42 as a.b: ... ... Traceback (most recent call last): SyntaxError: cannot use attribute as pattern target - >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> match ...: ... case 42 as (a, b): ... ... Traceback (most recent call last): @@ -2307,7 +2307,7 @@ Traceback (most recent call last): ... SyntaxError: invalid syntax - >>> A[:(*b)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> A[:(*b)] Traceback (most recent call last): ... SyntaxError: cannot use starred expression here @@ -2326,7 +2326,7 @@ Traceback (most recent call last): ... SyntaxError: invalid syntax - >>> A[(*b):] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> A[(*b):] Traceback (most recent call last): ... SyntaxError: cannot use starred expression here @@ -2636,26 +2636,26 @@ def f(x: *b) Traceback (most recent call last): SyntaxError: cannot assign to __debug__ - >>> class A[__debug__]: pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class A[__debug__]: pass Traceback (most recent call last): SyntaxError: cannot assign to __debug__ - >>> class A[T]((x := 3)): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class A[T]((x := 3)): ... Traceback (most recent call last): ... SyntaxError: named expression cannot be used within the definition of a generic - >>> class A[T]((yield 3)): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class A[T]((yield 3)): ... Traceback (most recent call last): ... SyntaxError: yield expression cannot be used within the definition of a generic - >>> class A[T]((await 3)): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class A[T]((await 3)): ... Traceback (most recent call last): ... SyntaxError: await expression cannot be used within the definition of a generic - >>> class A[T]((yield from [])): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> class A[T]((yield from [])): ... Traceback (most recent call last): ... SyntaxError: yield expression cannot be used within the definition of a generic @@ -2664,23 +2664,23 @@ def f(x: *b) Traceback (most recent call last): SyntaxError: iterable argument unpacking follows keyword argument unpacking - >>> f(**x, *) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> f(**x, *) Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x, *:) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> f(x, *:) Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x, *) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> f(x, *) Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x = 5, *) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> f(x = 5, *) Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x = 5, *:) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + >>> f(x = 5, *:) Traceback (most recent call last): SyntaxError: Invalid star expression """ @@ -2702,7 +2702,6 @@ def check_warning(self, code, errtext, filename="", mode="exec"): with self.assertWarnsRegex(SyntaxWarning, errtext): compile(code, filename, mode) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxWarning not triggered def test_return_in_finally(self): source = textwrap.dedent(""" def f(): @@ -2737,7 +2736,6 @@ def f(): """) self.check_warning(source, "'return' in a 'finally' block") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxWarning not triggered def test_break_and_continue_in_finally(self): for kw in ('break', 'continue'): @@ -2807,7 +2805,6 @@ def _check_error(self, code, errtext, else: self.fail("compile() did not raise SyntaxError") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_expression_with_assignment(self): self._check_error( "print(end1 + end2 = ' ')", @@ -2821,7 +2818,6 @@ def test_curly_brace_after_primary_raises_immediately(self): def test_assign_call(self): self._check_error("f() = 1", "assign") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_assign_del(self): self._check_error("del (,)", "invalid syntax") self._check_error("del 1", "cannot delete literal") @@ -2955,13 +2951,11 @@ def test_generator_in_function_call(self): "Generator expression must be parenthesized", lineno=1, end_lineno=1, offset=11, end_offset=53) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_except_then_except_star(self): self._check_error("try: pass\nexcept ValueError: pass\nexcept* TypeError: pass", r"cannot have both 'except' and 'except\*' on the same 'try'", lineno=3, end_lineno=3, offset=1, end_offset=8) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_except_star_then_except(self): self._check_error("try: pass\nexcept* ValueError: pass\nexcept TypeError: pass", r"cannot have both 'except' and 'except\*' on the same 'try'", @@ -3109,7 +3103,6 @@ def func2(): """ self._check_error(code, "expected ':'") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_line_continuation_error_position(self): self._check_error(r"a = 3 \ 4", "unexpected character after line continuation character", @@ -3129,7 +3122,6 @@ def test_invalid_line_continuation_left_recursive(self): self._check_error("A.\u03bc\\\n", "unexpected EOF while parsing") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_parenthesis(self): for paren in "([{": self._check_error(paren + "1 + 2", f"\\{paren}' was never closed") @@ -3155,7 +3147,6 @@ def test_error_parenthesis(self): s = b'# coding=latin\n(aaaaaaaaaaaaaaaaa\naaaaaaaaaaa\xb5' self._check_error(s, r"'\(' was never closed") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_string_literal(self): self._check_error("'blech", r"unterminated string literal \(.*\)$") @@ -3169,7 +3160,6 @@ def test_error_string_literal(self): self._check_error("'''blech", "unterminated triple-quoted string literal") self._check_error('"""blech', "unterminated triple-quoted string literal") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_invisible_characters(self): self._check_error('print\x17("Hello")', "invalid non-printable character") self._check_error(b"with(0,,):\n\x01", "invalid non-printable character") @@ -3252,7 +3242,6 @@ def test_deep_invalid_rule(self): with self.assertRaises(SyntaxError): compile(source, "", "exec") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_except_stmt_invalid_as_expr(self): self._check_error( textwrap.dedent( @@ -3270,7 +3259,6 @@ def test_except_stmt_invalid_as_expr(self): end_offset=22 + len("obj.attr"), ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_match_stmt_invalid_as_expr(self): self._check_error( textwrap.dedent( @@ -3287,7 +3275,6 @@ def test_match_stmt_invalid_as_expr(self): end_offset=15 + len("obj.attr"), ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_ifexp_else_stmt(self): msg = "expected expression after 'else', but statement is given" @@ -3308,7 +3295,6 @@ def test_ifexp_else_stmt(self): ]: self._check_error(f"x = 1 if 1 else {stmt}", msg) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_ifexp_body_stmt_else_expression(self): msg = "expected expression before 'if', but statement is given" @@ -3319,7 +3305,6 @@ def test_ifexp_body_stmt_else_expression(self): ]: self._check_error(f"x = {stmt} if 1 else 1", msg) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_ifexp_body_stmt_else_stmt(self): msg = "expected expression before 'if', but statement is given" for lhs_stmt, rhs_stmt in [ diff --git a/Lib/test/test_sys_setprofile.py b/Lib/test/test_sys_setprofile.py index 813adff2a32..d0d2b0c3e01 100644 --- a/Lib/test/test_sys_setprofile.py +++ b/Lib/test/test_sys_setprofile.py @@ -169,7 +169,6 @@ def g(p): (1, 'return', g_ident), ]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_exception_propagation(self): def f(p): 1/0 diff --git a/Lib/test/test_type_comments.py b/Lib/test/test_type_comments.py index 0deb25f16d3..d827ac27108 100644 --- a/Lib/test/test_type_comments.py +++ b/Lib/test/test_type_comments.py @@ -252,7 +252,6 @@ def parse_all(self, source, minver=lowest, maxver=highest, expected_regex=""): def classic_parse(self, source): return ast.parse(source) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'FunctionDef' object has no attribute 'type_comment' def test_funcdef(self): for tree in self.parse_all(funcdef): self.assertEqual(tree.body[0].type_comment, "() -> int") @@ -261,7 +260,6 @@ def test_funcdef(self): self.assertEqual(tree.body[0].type_comment, None) self.assertEqual(tree.body[1].type_comment, None) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised : feature_version=(3, 4) def test_asyncdef(self): for tree in self.parse_all(asyncdef, minver=5): self.assertEqual(tree.body[0].type_comment, "() -> int") @@ -274,12 +272,10 @@ def test_asyncvar(self): with self.assertRaises(SyntaxError): self.classic_parse(asyncvar) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised : feature_version=(3, 4) def test_asynccomp(self): for tree in self.parse_all(asynccomp, minver=6): pass - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised : feature_version=(3, 4) def test_matmul(self): for tree in self.parse_all(matmul, minver=5): pass @@ -288,37 +284,31 @@ def test_fstring(self): for tree in self.parse_all(fstring): pass - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised : feature_version=(3, 4) def test_underscorednumber(self): for tree in self.parse_all(underscorednumber, minver=6): pass - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised : feature_version=(3, 4) def test_redundantdef(self): for tree in self.parse_all(redundantdef, maxver=0, expected_regex="^Cannot have two type comments on def"): pass - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'FunctionDef' object has no attribute 'type_comment' def test_nonasciidef(self): for tree in self.parse_all(nonasciidef): self.assertEqual(tree.body[0].type_comment, "() -> àçčéñt") - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'For' object has no attribute 'type_comment' def test_forstmt(self): for tree in self.parse_all(forstmt): self.assertEqual(tree.body[0].type_comment, "int") tree = self.classic_parse(forstmt) self.assertEqual(tree.body[0].type_comment, None) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'With' object has no attribute 'type_comment' def test_withstmt(self): for tree in self.parse_all(withstmt): self.assertEqual(tree.body[0].type_comment, "int") tree = self.classic_parse(withstmt) self.assertEqual(tree.body[0].type_comment, None) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'With' object has no attribute 'type_comment' def test_parenthesized_withstmt(self): for tree in self.parse_all(parenthesized_withstmt): self.assertEqual(tree.body[0].type_comment, "int") @@ -327,14 +317,12 @@ def test_parenthesized_withstmt(self): self.assertEqual(tree.body[0].type_comment, None) self.assertEqual(tree.body[1].type_comment, None) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: None != 'int' def test_vardecl(self): for tree in self.parse_all(vardecl): self.assertEqual(tree.body[0].type_comment, "int") tree = self.classic_parse(vardecl) self.assertEqual(tree.body[0].type_comment, None) - @unittest.expectedFailure # TODO: RUSTPYTHON; + (11, ' whatever')] def test_ignores(self): for tree in self.parse_all(ignores): self.assertEqual( @@ -350,7 +338,6 @@ def test_ignores(self): tree = self.classic_parse(ignores) self.assertEqual(tree.type_ignores, []) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised : feature_version=(3, 4) def test_longargs(self): for tree in self.parse_all(longargs, minver=8): for t in tree.body: @@ -381,7 +368,6 @@ def test_longargs(self): self.assertIsNone(arg.type_comment, "%s(%s:%r)" % (t.name, arg.arg, arg.type_comment)) - @unittest.expectedFailure # TODO: RUSTPYTHON; Tests for inappropriately-placed type comments. def test_inappropriate_type_comments(self): """Tests for inappropriately-placed type comments. @@ -416,7 +402,6 @@ def test_non_utf8_type_comment_with_ignore_cookie(self): _testcapi.Py_CompileStringExFlags( b"def a(f=8, #type: \x80\n\x80", "", 256, flags) - @unittest.expectedFailure # TODO: RUSTPYTHON; ValueError: mode must be "exec", "eval", "ipython", or "single" def test_func_type_input(self): def parse_func_type_input(source): diff --git a/Lib/test/test_type_params.py b/Lib/test/test_type_params.py index c63ea2d291c..65261fcb6ed 100644 --- a/Lib/test/test_type_params.py +++ b/Lib/test/test_type_params.py @@ -683,7 +683,6 @@ def foo[U: T](self): ... self.assertIs(X.foo.__type_params__[0].__bound__, float) self.assertIs(X.Alias.__value__, float) - @unittest.expectedFailure # TODO: RUSTPYTHON; + global def test_binding_uses_global(self): ns = run_code(""" x = "global" @@ -1076,7 +1075,6 @@ async def coroutine[B](): class TypeParamsTypeVarTupleTest(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "cannot use bound with TypeVarTuple" does not match "invalid syntax (, line 1)" def test_typevartuple_01(self): code = """def func1[*A: str](): pass""" check_syntax_error(self, code, "cannot use bound with TypeVarTuple") @@ -1100,7 +1098,6 @@ def func1[*A](): class TypeParamsTypeVarParamSpecTest(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "cannot use bound with ParamSpec" does not match "invalid syntax (, line 1)" def test_paramspec_01(self): code = """def func1[**A: str](): pass""" check_syntax_error(self, code, "cannot use bound with ParamSpec") diff --git a/Lib/test/test_unpack_ex.py b/Lib/test/test_unpack_ex.py index 13b789f52dc..e5d9ea97a5f 100644 --- a/Lib/test/test_unpack_ex.py +++ b/Lib/test/test_unpack_ex.py @@ -356,7 +356,7 @@ ... SyntaxError: can't use starred expression here - >>> (*x),y = 1, 2 # TODO: RUSTPYTHON # doctest:+ELLIPSIS +EXPECTED_FAILURE + >>> (*x),y = 1, 2 Traceback (most recent call last): ... SyntaxError: cannot use starred expression here @@ -366,12 +366,12 @@ ... SyntaxError: cannot use starred expression here - >>> z,(*x),y = 1, 2, 4 # TODO: RUSTPYTHON # doctest:+ELLIPSIS +EXPECTED_FAILURE + >>> z,(*x),y = 1, 2, 4 Traceback (most recent call last): ... SyntaxError: cannot use starred expression here - >>> z,(*x) = 1, 2 # TODO: RUSTPYTHON # doctest:+ELLIPSIS +EXPECTED_FAILURE + >>> z,(*x) = 1, 2 Traceback (most recent call last): ... SyntaxError: cannot use starred expression here diff --git a/crates/capi/src/ceval.rs b/crates/capi/src/ceval.rs index d28dad4d6df..867c39e0388 100644 --- a/crates/capi/src/ceval.rs +++ b/crates/capi/src/ceval.rs @@ -1,17 +1,13 @@ use crate::pystate::with_vm; +use crate::unicodeobject::decode_fsdefault_and_size; use core::ffi::{CStr, c_char, c_int}; use core::ptr::NonNull; use rustpython_vm::builtins::{PyCode, PyDict}; -use rustpython_vm::compiler::Mode; use rustpython_vm::function::ArgMapping; use rustpython_vm::scope::Scope; +use rustpython_vm::version; use rustpython_vm::{AsObject, PyObject, TryFromObject}; -const PY_SINGLE_INPUT: c_int = 256; -const PY_FILE_INPUT: c_int = 257; -const PY_EVAL_INPUT: c_int = 258; -const PY_FUNC_TYPE_INPUT: c_int = 345; - #[unsafe(no_mangle)] pub unsafe extern "C" fn Py_CompileString( code: *const c_char, @@ -19,27 +15,11 @@ pub unsafe extern "C" fn Py_CompileString( start: c_int, ) -> *mut PyObject { with_vm(|vm| { - let code = unsafe { CStr::from_ptr(code) }.to_str().map_err(|_| { - vm.new_system_error("Py_CompileString called with non UTF-8 code string") - })?; - let filename = unsafe { CStr::from_ptr(filename) } - .to_str() - .map_err(|_| vm.new_system_error("Py_CompileString called with non UTF-8 filename"))?; - - let mode = match start { - PY_SINGLE_INPUT => Mode::Single, - PY_FILE_INPUT => Mode::Exec, - PY_EVAL_INPUT => Mode::Eval, - PY_FUNC_TYPE_INPUT => Mode::BlockExpr, - _ => { - return Err( - vm.new_system_error("Invalid start argument passed to Py_CompileString") - ); - } - }; - - vm.compile(code, mode, filename) - .map_err(|err| vm.new_syntax_error(&err, Some(code))) + let code = unsafe { CStr::from_ptr(code) }.to_bytes(); + let filename_size = unsafe { CStr::from_ptr(filename) }.to_bytes().len(); + let filename = decode_fsdefault_and_size(vm, filename, filename_size)?; + let filename = filename.to_string_lossy(); + vm.compile_string_object_with_flags(code, &filename, start, 0, version::MINOR as c_int, -1) }) } diff --git a/crates/capi/src/unicodeobject.rs b/crates/capi/src/unicodeobject.rs index acc6e392c53..33d46692602 100644 --- a/crates/capi/src/unicodeobject.rs +++ b/crates/capi/src/unicodeobject.rs @@ -4,8 +4,8 @@ use core::ffi::{CStr, c_char, c_int}; use core::ptr::NonNull; use core::slice; use core::str; -use rustpython_vm::PyObjectRef; -use rustpython_vm::builtins::PyStr; +use rustpython_vm::builtins::{PyStr, PyStrRef}; +use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; define_py_check!(fn PyUnicode_Check, types.str_type); define_py_check!(exact fn PyUnicode_CheckExact, types.str_type); @@ -113,26 +113,42 @@ pub unsafe extern "C" fn PyUnicode_DecodeFSDefaultAndSize( .try_into() .map_err(|_| vm.new_system_error("size must be non-negative"))?; - let bytes = if s.is_null() { - if size != 0 { - return Err(vm.new_system_error( - "PyUnicode_DecodeFSDefaultAndSize called with null data and non-zero size", - )); - } - &[][..] - } else { - unsafe { slice::from_raw_parts(s.cast::(), size) } - }; + decode_fsdefault_and_size(vm, s, size) + }) +} - vm.state.codec_registry.decode_text( - vm.ctx.new_bytes(bytes.to_vec()).into(), - vm.fs_encoding().as_str(), - Some(vm.fs_encode_errors().to_owned()), - vm, - ) +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyUnicode_DecodeFSDefault(s: *const c_char) -> *mut PyObject { + with_vm(|vm| { + let size = unsafe { CStr::from_ptr(s) }.to_bytes().len(); + decode_fsdefault_and_size(vm, s, size) }) } +pub(crate) fn decode_fsdefault_and_size( + vm: &VirtualMachine, + s: *const c_char, + size: usize, +) -> PyResult { + let bytes = if s.is_null() { + if size != 0 { + return Err(vm.new_system_error( + "PyUnicode_DecodeFSDefaultAndSize called with null data and non-zero size", + )); + } + &[][..] + } else { + unsafe { slice::from_raw_parts(s.cast::(), size) } + }; + + vm.state.codec_registry.decode_text( + vm.ctx.new_bytes(bytes.to_vec()).into(), + vm.fs_encoding().as_str(), + Some(vm.fs_encode_errors().to_owned()), + vm, + ) +} + #[unsafe(no_mangle)] pub unsafe extern "C" fn PyUnicode_EncodeFSDefault(unicode: *mut PyObject) -> *mut PyObject { with_vm(|vm| { diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index a178a23bd2f..da030251e31 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -10,43 +10,41 @@ #![deny(clippy::cast_possible_truncation)] use crate::{ - IndexMap, IndexSet, ToPythonName, - error::{CodegenError, CodegenErrorType, InternalError, PatternUnreachableReason}, + IndexMap, IndexSet, PublicAstExprList, PublicAstFormattedValue, PublicAstInterpolation, + PublicAstNodeMap, ToPythonName, + error::{CodegenError, CodegenErrorType, InternalError}, ir::{self, Block, BlockIdx, Blocks}, preprocess, symboltable::{self, CompilerScope, Symbol, SymbolFlags, SymbolScope, SymbolTable}, unparse::UnparseExpr, }; -use alloc::borrow::Cow; -use core::mem; +use alloc::{borrow::Cow, sync::Arc}; +use core::{mem, slice}; use malachite_bigint::BigInt; use num_complex::Complex; use num_traits::{Num, ToPrimitive, Zero}; use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange, TextSize}; - use rustpython_compiler_core::{ Mode, OneIndexed, PositionEncoding, SourceFile, SourceLocation, bytecode::{ self, AnyInstruction, AnyOpcode, Arg as OpArgMarker, BinaryOperator, BuildSliceArgCount, - CodeFlags, CodeObject, ComparisonOperator, ConstantData, ConvertValueOparg, Instruction, - IntrinsicFunction1, Invert, LoadAttr, LoadSuperAttr, MakeFunctionFlag, MakeFunctionFlags, - OpArg, OpArgType, Opcode, PseudoInstruction, PseudoOpcode, SpecialMethod, UnpackExArgs, - oparg, + CodeObject, ComparisonOperator, ConstantData, ConvertValueOparg, Instruction, + IntrinsicFunction1, Invert, LoadAttr, LoadSuperAttr, OpArg, OpArgType, PseudoInstruction, + SpecialMethod, UnpackExArgs, oparg, }, }; +use rustpython_literal::{ + complex as literal_complex, + escape::{AsciiEscape, UnicodeEscape}, + float as literal_float, +}; use rustpython_wtf8::Wtf8Buf; /// Extension trait for `ast::Expr` to add constant checking methods trait ExprExt { /// Returns true if the expression is a constant literal with no side effects. fn is_constant(&self) -> bool; - - /// Check if a slice expression has all constant elements - fn is_constant_slice(&self) -> bool; - - /// Check if we should use BINARY_SLICE/STORE_SLICE optimization - fn should_use_slice_optimization(&self) -> bool; } impl ExprExt for ast::Expr { @@ -61,25 +59,9 @@ impl ExprExt for ast::Expr { | Self::EllipsisLiteral(_) ) } - - fn is_constant_slice(&self) -> bool { - match self { - Self::Slice(s) => { - let lower_const = s.lower.as_deref().is_none_or(|e| e.is_constant()); - let upper_const = s.upper.as_deref().is_none_or(|e| e.is_constant()); - let step_const = s.step.as_deref().is_none_or(|e| e.is_constant()); - lower_const && upper_const && step_const - } - _ => false, - } - } - - fn should_use_slice_optimization(&self) -> bool { - !self.is_constant_slice() && matches!(self, Self::Slice(s) if s.step.is_none()) - } } -const MAXBLOCKS: usize = 20; +const CO_MAXBLOCKS: usize = 21; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FBlockType { @@ -142,6 +124,44 @@ pub struct FBlockInfo { pub(crate) type InternalResult = Result; type CompileResult = Result; +pub type SyntaxWarningHandler<'a> = + dyn FnMut(SourceLocation, String) -> Result<(), CodegenError> + 'a; + +fn warn_ast_preprocess_syntax( + source_file: &SourceFile, + handler: &mut SyntaxWarningHandler<'_>, + range: TextRange, + message: String, +) -> CompileResult<()> { + let location = source_file + .to_source_code() + .source_location(range.start(), PositionEncoding::Utf8); + handler(location, message) +} + +fn checked_future_features( + ast: &ruff_python_ast::Mod, + source_file: &SourceFile, +) -> CompileResult { + preprocess::checked_future_features(ast).map_err(|err| { + let location = source_file + .to_source_code() + .source_location(err.range.start(), PositionEncoding::Utf8); + let error = match err.kind { + preprocess::FutureFeatureErrorKind::InvalidFeature(feature) => { + CodegenErrorType::InvalidFutureFeature(feature) + } + preprocess::FutureFeatureErrorKind::InvalidBraces => { + CodegenErrorType::InvalidFutureBraces + } + }; + CodegenError { + location: Some(location), + error, + source_path: source_file.name().to_owned(), + } + }) +} #[derive(PartialEq, Eq, Clone, Copy)] enum NameUsage { @@ -150,13 +170,15 @@ enum NameUsage { Delete, } /// Main structure holding the state of compilation. -struct Compiler { +struct Compiler<'a> { code_stack: Vec, symbol_table_stack: Vec, + annotation_symbol_sources: Vec, source_file: SourceFile, // current_source_location: SourceLocation, current_source_range: TextRange, done_with_future_stmts: DoneWithFuture, + future_features: bytecode::CodeFlags, future_annotations: bool, ctx: CompileContext, opts: CompileOpts, @@ -168,9 +190,12 @@ struct Compiler { /// When > 0, the compiler walks AST (consuming sub_tables) but emits no bytecode. /// Mirrors CPython's `c_do_not_emit_bytecode`. do_not_emit_bytecode: u32, + /// Mirrors CPython's `c_disable_warning` while compiling FINALLY_END copies. + disable_warning: u32, /// Disable constant tuple/list/set collection folding in contexts where /// CPython keeps the builder form for later assignment lowering. disable_const_collection_folding: bool, + syntax_warning_handler: Option<&'a mut SyntaxWarningHandler<'a>>, } #[derive(Clone, Copy)] @@ -180,13 +205,58 @@ enum DoneWithFuture { Yes, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy)] +enum AnnotationSymbolSource { + Sibling, + Hidden, +} + +#[derive(Clone, Copy)] +enum ComprehensionSymbolSource { + Child, + Inlined, +} + +#[derive(Clone, Copy)] +struct SymbolTableCursors { + sub_table: usize, + hidden_annotation_block: usize, + inlined_comprehension_block: usize, +} + +#[derive(Clone, Debug)] pub struct CompileOpts { /// How optimized the bytecode output should be; any optimize > 0 does /// not emit assert statements pub optimize: u8, /// Include column info in bytecode (-X no_debug_ranges disables) pub debug_ranges: bool, + /// Maximum decimal integer literal digits, matching sys.int_info/default. + pub int_max_str_digits: usize, + /// Allow module-level await/async-for/async-with, matching PyCF_ALLOW_TOP_LEVEL_AWAIT. + pub allow_top_level_await: bool, + /// Future compiler flags passed explicitly to compile(), matching cf_flags merge. + pub future_features: bytecode::CodeFlags, + /// Keep single-input blocks incomplete until a terminating newline is seen. + pub dont_imply_dedent: bool, + /// Recursion limit used by compiler tree walks, matching Py_EnterRecursiveCall. + pub recursion_limit: usize, + /// CPython Constant_kind stores value/kind directly; Ruff has no Constant + /// expr variant. Dense public node indexes make Vec lookup cheaper than + /// hashing, and compile order never observes map insertion order. + pub ast_constant_overrides: Option>>, + /// CPython Interpolation has raw str and expr? format_spec; Ruff t-string + /// elements do not. Dense node lookup keeps the CPython codegen path. + pub ast_interpolation_overrides: Option>>, + /// CPython FormattedValue has expr? format_spec; Ruff f-string specs are + /// nested string elements. Dense node lookup avoids hash overhead. + pub ast_formatted_value_overrides: Option>>, + /// CPython JoinedStr.values is expr*; Ruff stores f-string element trees. + /// Dense node lookup restores only public JoinedStr overrides. + pub ast_joined_str_overrides: Option>>, + /// CPython TemplateStr.values is expr*; Ruff stores t-string element trees. + /// Dense node lookup restores only public TemplateStr overrides. + pub ast_template_str_overrides: Option>>, } impl Default for CompileOpts { @@ -194,6 +264,16 @@ impl Default for CompileOpts { Self { optimize: 0, debug_ranges: true, + int_max_str_digits: 4300, + allow_top_level_await: false, + future_features: bytecode::CodeFlags::empty(), + dont_imply_dedent: false, + recursion_limit: 1000, + ast_constant_overrides: None, + ast_interpolation_overrides: None, + ast_formatted_value_overrides: None, + ast_joined_str_overrides: None, + ast_template_str_overrides: None, } } } @@ -258,19 +338,69 @@ fn validate_duplicate_params(params: &ast::Parameters) -> Result<(), CodegenErro /// Compile an Mod produced from ruff parser pub fn compile_top( - mut ast: ruff_python_ast::Mod, + ast: ruff_python_ast::Mod, source_file: SourceFile, mode: Mode, opts: CompileOpts, ) -> CompileResult { - preprocess::preprocess_mod(&mut ast); + compile_top_with_syntax_warning_handler(ast, source_file, mode, opts, None) +} + +pub fn compile_top_with_syntax_warning_handler<'a>( + mut ast: ruff_python_ast::Mod, + source_file: SourceFile, + mode: Mode, + mut opts: CompileOpts, + mut syntax_warning_handler: Option<&'a mut SyntaxWarningHandler<'a>>, +) -> CompileResult { + opts.future_features |= checked_future_features(&ast, &source_file)?; + let future_annotations = opts + .future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + if let Some(handler) = syntax_warning_handler.as_deref_mut() { + preprocess::warn_control_flow_in_finally(&ast, |range, message| { + warn_ast_preprocess_syntax(&source_file, handler, range, message) + })?; + } + if matches!(mode, Mode::Single) + && let ruff_python_ast::Mod::Module(module) = &mut ast + { + preprocess::preprocess_statements( + &mut module.body, + opts.optimize, + future_annotations, + false, + ); + } else { + preprocess::preprocess_mod(&mut ast, opts.optimize, future_annotations, false); + } match ast { ruff_python_ast::Mod::Module(module) => match mode { - Mode::Exec | Mode::Eval => compile_program(&module, source_file, opts), - Mode::Single => compile_program_single(&module, source_file, opts), - Mode::BlockExpr => compile_block_expression(&module, source_file, opts), + Mode::Exec | Mode::Eval => compile_program_with_syntax_warning_handler( + &module, + source_file, + opts, + syntax_warning_handler, + ), + Mode::Single => compile_program_single_with_syntax_warning_handler( + &module, + source_file, + opts, + syntax_warning_handler, + ), + Mode::BlockExpr => compile_block_expression_with_syntax_warning_handler( + &module, + source_file, + opts, + syntax_warning_handler, + ), }, - ruff_python_ast::Mod::Expression(expr) => compile_expression(&expr, source_file, opts), + ruff_python_ast::Mod::Expression(expr) => compile_expression_with_syntax_warning_handler( + &expr, + source_file, + opts, + syntax_warning_handler, + ), } } @@ -280,9 +410,35 @@ pub fn compile_program( source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { - let symbol_table = SymbolTable::scan_program(ast, source_file.clone()) - .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; - let mut compiler = Compiler::new(opts, source_file, ""); + compile_program_with_syntax_warning_handler(ast, source_file, opts, None) +} + +fn compile_program_with_syntax_warning_handler<'a>( + ast: &ast::ModModule, + source_file: SourceFile, + opts: CompileOpts, + syntax_warning_handler: Option<&'a mut SyntaxWarningHandler<'a>>, +) -> CompileResult { + let symbol_table = SymbolTable::scan_program_with_options( + ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; + let mut compiler = Compiler::new_with_syntax_warning_handler( + opts, + source_file, + "", + syntax_warning_handler, + ); compiler.compile_program(ast, symbol_table)?; let code = compiler.exit_scope(); trace!("Compilation completed: {code:?}"); @@ -295,9 +451,35 @@ pub fn compile_program_single( source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { - let symbol_table = SymbolTable::scan_program(ast, source_file.clone()) - .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; - let mut compiler = Compiler::new(opts, source_file, ""); + compile_program_single_with_syntax_warning_handler(ast, source_file, opts, None) +} + +fn compile_program_single_with_syntax_warning_handler<'a>( + ast: &ast::ModModule, + source_file: SourceFile, + opts: CompileOpts, + syntax_warning_handler: Option<&'a mut SyntaxWarningHandler<'a>>, +) -> CompileResult { + let symbol_table = SymbolTable::scan_program_with_options( + ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; + let mut compiler = Compiler::new_with_syntax_warning_handler( + opts, + source_file, + "", + syntax_warning_handler, + ); compiler.compile_program_single(&ast.body, symbol_table)?; let code = compiler.exit_scope(); trace!("Compilation completed: {code:?}"); @@ -309,9 +491,35 @@ pub fn compile_block_expression( source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { - let symbol_table = SymbolTable::scan_program(ast, source_file.clone()) - .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; - let mut compiler = Compiler::new(opts, source_file, ""); + compile_block_expression_with_syntax_warning_handler(ast, source_file, opts, None) +} + +fn compile_block_expression_with_syntax_warning_handler<'a>( + ast: &ast::ModModule, + source_file: SourceFile, + opts: CompileOpts, + syntax_warning_handler: Option<&'a mut SyntaxWarningHandler<'a>>, +) -> CompileResult { + let symbol_table = SymbolTable::scan_program_with_options( + ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; + let mut compiler = Compiler::new_with_syntax_warning_handler( + opts, + source_file, + "", + syntax_warning_handler, + ); compiler.compile_block_expr(&ast.body, symbol_table)?; let code = compiler.exit_scope(); trace!("Compilation completed: {code:?}"); @@ -323,9 +531,35 @@ pub fn compile_expression( source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { - let symbol_table = SymbolTable::scan_expr(ast, source_file.clone()) - .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; - let mut compiler = Compiler::new(opts, source_file, ""); + compile_expression_with_syntax_warning_handler(ast, source_file, opts, None) +} + +fn compile_expression_with_syntax_warning_handler<'a>( + ast: &ast::ModExpression, + source_file: SourceFile, + opts: CompileOpts, + syntax_warning_handler: Option<&'a mut SyntaxWarningHandler<'a>>, +) -> CompileResult { + let symbol_table = SymbolTable::scan_expr_with_options( + ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned()))?; + let mut compiler = Compiler::new_with_syntax_warning_handler( + opts, + source_file, + "", + syntax_warning_handler, + ); compiler.compile_eval(ast, symbol_table)?; let code = compiler.exit_scope(); Ok(code) @@ -353,7 +587,7 @@ macro_rules! emit { }; } -fn eprint_location(zelf: &Compiler) { +fn eprint_location(zelf: &Compiler<'_>) { let start = zelf .source_file .to_source_code() @@ -374,7 +608,7 @@ fn eprint_location(zelf: &Compiler) { /// Better traceback for internal error #[track_caller] -fn unwrap_internal(zelf: &Compiler, r: InternalResult) -> T { +fn unwrap_internal(zelf: &Compiler<'_>, r: InternalResult) -> T { if let Err(ref r_err) = r { eprintln!("=== CODEGEN PANIC INFO ==="); eprintln!("This IS an internal error: {r_err}"); @@ -384,7 +618,7 @@ fn unwrap_internal(zelf: &Compiler, r: InternalResult) -> T { r.unwrap() } -fn compiler_unwrap_option(zelf: &Compiler, o: Option) -> T { +fn compiler_unwrap_option(zelf: &Compiler<'_>, o: Option) -> T { if o.is_none() { eprintln!("=== CODEGEN PANIC INFO ==="); eprintln!("This IS an internal error, an option was unwrapped during codegen"); @@ -454,10 +688,358 @@ enum CollectionType { Set, } +#[derive(Clone, Copy, Eq, PartialEq)] +enum InferredType { + Tuple, + List, + Dict, + Set, + FrozenSet, + Generator, + Function, + Template, + Str, + Bytes, + Int, + Float, + Complex, + Bool, + NoneType, + Ellipsis, + Slice, +} + +impl InferredType { + const fn name(self) -> &'static str { + match self { + Self::Tuple => "tuple", + Self::List => "list", + Self::Dict => "dict", + Self::Set => "set", + Self::FrozenSet => "frozenset", + Self::Generator => "generator", + Self::Function => "function", + Self::Template => "string.templatelib.Template", + Self::Str => "str", + Self::Bytes => "bytes", + Self::Int => "int", + Self::Float => "float", + Self::Complex => "complex", + Self::Bool => "bool", + Self::NoneType => "NoneType", + Self::Ellipsis => "ellipsis", + Self::Slice => "slice", + } + } + + const fn is_long_subclass(self) -> bool { + matches!(self, Self::Int | Self::Bool) + } +} + const STACK_USE_GUIDELINE: u32 = 30; -impl Compiler { - fn new(opts: CompileOpts, source_file: SourceFile, code_name: &str) -> Self { +impl<'warnings> Compiler<'warnings> { + fn constant_truthiness(constant: &ConstantData) -> bool { + match constant { + ConstantData::Tuple { elements } | ConstantData::Frozenset { elements } => { + !elements.is_empty() + } + ConstantData::Integer { value } => !value.is_zero(), + ConstantData::Float { value } => *value != 0.0, + ConstantData::Complex { value } => value.re != 0.0 || value.im != 0.0, + ConstantData::Boolean { value } => *value, + ConstantData::Str { value } => !value.is_empty(), + ConstantData::Bytes { value } => !value.is_empty(), + ConstantData::Code { .. } | ConstantData::Slice { .. } | ConstantData::Ellipsis => true, + ConstantData::None => false, + } + } + + fn infer_type_constant(constant: &ConstantData) -> Option { + match constant { + ConstantData::Tuple { .. } => Some(InferredType::Tuple), + ConstantData::Frozenset { .. } => Some(InferredType::FrozenSet), + ConstantData::Integer { .. } => Some(InferredType::Int), + ConstantData::Float { .. } => Some(InferredType::Float), + ConstantData::Complex { .. } => Some(InferredType::Complex), + ConstantData::Boolean { .. } => Some(InferredType::Bool), + ConstantData::Str { .. } => Some(InferredType::Str), + ConstantData::Bytes { .. } => Some(InferredType::Bytes), + ConstantData::None => Some(InferredType::NoneType), + ConstantData::Ellipsis => Some(InferredType::Ellipsis), + ConstantData::Slice { .. } => Some(InferredType::Slice), + ConstantData::Code { .. } => None, + } + } + + fn infer_type(&self, expr: &ast::Expr) -> Option { + if let Some(constant) = self.public_ast_constant_override(expr) { + return Self::infer_type_constant(&constant); + } + match expr { + ast::Expr::Tuple(_) => Some(InferredType::Tuple), + ast::Expr::List(_) | ast::Expr::ListComp(_) => Some(InferredType::List), + ast::Expr::Dict(_) | ast::Expr::DictComp(_) => Some(InferredType::Dict), + ast::Expr::Set(_) | ast::Expr::SetComp(_) => Some(InferredType::Set), + ast::Expr::Generator(_) => Some(InferredType::Generator), + ast::Expr::Lambda(_) => Some(InferredType::Function), + ast::Expr::TString(_) => Some(InferredType::Template), + ast::Expr::FString(_) | ast::Expr::StringLiteral(_) => Some(InferredType::Str), + ast::Expr::BytesLiteral(_) => Some(InferredType::Bytes), + ast::Expr::NumberLiteral(number) => match number.value { + ast::Number::Int(_) => Some(InferredType::Int), + ast::Number::Float(_) => Some(InferredType::Float), + ast::Number::Complex { .. } => Some(InferredType::Complex), + }, + ast::Expr::BooleanLiteral(_) => Some(InferredType::Bool), + ast::Expr::NoneLiteral(_) => Some(InferredType::NoneType), + ast::Expr::EllipsisLiteral(_) => Some(InferredType::Ellipsis), + ast::Expr::Slice(_) => Some(InferredType::Slice), + _ => None, + } + } + + fn is_constant_expr(&self, expr: &ast::Expr) -> bool { + if self.public_ast_constant_override(expr).is_some() { + return true; + } + matches!( + expr, + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + ) + } + + fn is_constant_slice(&self, slice: &ast::Expr) -> bool { + match slice { + ast::Expr::Slice(s) => { + let lower_const = s.lower.is_none() + || s.lower.as_deref().is_some_and(|e| self.is_constant_expr(e)); + let upper_const = s.upper.is_none() + || s.upper.as_deref().is_some_and(|e| self.is_constant_expr(e)); + let step_const = + s.step.is_none() || s.step.as_deref().is_some_and(|e| self.is_constant_expr(e)); + lower_const && upper_const && step_const + } + _ => false, + } + } + + fn should_apply_two_element_slice_optimization(&self, slice: &ast::Expr) -> bool { + !self.is_constant_slice(slice) && matches!(slice, ast::Expr::Slice(s) if s.step.is_none()) + } + + fn check_is_arg(&self, expr: &ast::Expr) -> bool { + if let Some(constant) = self.public_ast_constant_override(expr) { + return matches!( + constant, + ConstantData::None | ConstantData::Boolean { .. } | ConstantData::Ellipsis + ); + } + if let ast::Expr::Tuple(tuple) = expr { + return !tuple.elts.iter().all(|expr| self.is_constant_expr(expr)); + } + if !self.is_constant_expr(expr) { + return true; + } + matches!( + expr, + ast::Expr::NoneLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::EllipsisLiteral(_) + ) + } + + fn warn_syntax(&mut self, range: TextRange, message: String) -> CompileResult<()> { + if self.disable_warning > 0 { + return Ok(()); + } + let Some(handler) = self.syntax_warning_handler.as_deref_mut() else { + return Ok(()); + }; + let location = self + .source_file + .to_source_code() + .source_location(range.start(), PositionEncoding::Utf8); + handler(location, message) + } + + fn check_caller(&mut self, func: &ast::Expr) -> CompileResult<()> { + let warns = self.public_ast_constant_override(func).is_some() + || matches!( + func, + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + | ast::Expr::Tuple(_) + | ast::Expr::List(_) + | ast::Expr::ListComp(_) + | ast::Expr::Dict(_) + | ast::Expr::DictComp(_) + | ast::Expr::Set(_) + | ast::Expr::SetComp(_) + | ast::Expr::Generator(_) + | ast::Expr::FString(_) + | ast::Expr::TString(_) + ); + if warns && let Some(inferred) = self.infer_type(func) { + self.warn_syntax( + func.range(), + format!( + "'{}' object is not callable; perhaps you missed a comma?", + inferred.name() + ), + )?; + } + Ok(()) + } + + fn check_compare( + &mut self, + range: TextRange, + left: &ast::Expr, + ops: &[ast::CmpOp], + comparators: &[ast::Expr], + ) -> CompileResult<()> { + let mut left_is_arg = self.check_is_arg(left); + let mut left_expr = left; + for (op, right_expr) in ops.iter().zip(comparators.iter()) { + let right_is_arg = self.check_is_arg(right_expr); + if matches!(op, ast::CmpOp::Is | ast::CmpOp::IsNot) && (!right_is_arg || !left_is_arg) { + let literal = if !left_is_arg { left_expr } else { right_expr }; + if let Some(inferred) = self.infer_type(literal) { + let is_op = matches!(op, ast::CmpOp::Is); + let op = if is_op { "\"is\"" } else { "\"is not\"" }; + let replacement = if is_op { "==" } else { "!=" }; + self.warn_syntax( + range, + format!( + "{op} with '{}' literal. Did you mean \"{replacement}\"?", + inferred.name() + ), + )?; + return Ok(()); + } + } + left_is_arg = right_is_arg; + left_expr = right_expr; + } + Ok(()) + } + + fn constant_warns_as_subscripter(constant: &ConstantData) -> bool { + matches!( + constant, + ConstantData::None + | ConstantData::Ellipsis + | ConstantData::Integer { .. } + | ConstantData::Float { .. } + | ConstantData::Complex { .. } + | ConstantData::Boolean { .. } + | ConstantData::Frozenset { .. } + ) + } + + fn check_subscripter(&mut self, value: &ast::Expr) -> CompileResult<()> { + let warns = self + .public_ast_constant_override(value) + .is_some_and(|constant| Self::constant_warns_as_subscripter(&constant)) + || matches!( + value, + ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::Set(_) + | ast::Expr::SetComp(_) + | ast::Expr::Generator(_) + | ast::Expr::TString(_) + | ast::Expr::Lambda(_) + ); + if warns && let Some(inferred) = self.infer_type(value) { + self.warn_syntax( + value.range(), + format!( + "'{}' object is not subscriptable; perhaps you missed a comma?", + inferred.name() + ), + )?; + } + Ok(()) + } + + fn check_index(&mut self, value: &ast::Expr, slice: &ast::Expr) -> CompileResult<()> { + let Some(index_type) = self.infer_type(slice) else { + return Ok(()); + }; + if index_type.is_long_subclass() || index_type == InferredType::Slice { + return Ok(()); + } + + let constant_warns = self + .public_ast_constant_override(value) + .is_some_and(|constant| { + matches!( + constant, + ConstantData::Str { .. } + | ConstantData::Bytes { .. } + | ConstantData::Tuple { .. } + ) + }); + let warns = constant_warns + || matches!( + value, + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::Tuple(_) + | ast::Expr::List(_) + | ast::Expr::ListComp(_) + | ast::Expr::FString(_) + ); + if warns && let Some(value_type) = self.infer_type(value) { + self.warn_syntax( + value.range(), + format!( + "{} indices must be integers or slices, not {}; perhaps you missed a comma?", + value_type.name(), + index_type.name() + ), + )?; + } + Ok(()) + } + + fn check_assert(&mut self, assert_stmt: &ast::StmtAssert) -> CompileResult<()> { + let warns = match &*assert_stmt.test { + ast::Expr::Tuple(tuple) => !tuple.elts.is_empty(), + _ => matches!( + self.public_ast_constant_override(&assert_stmt.test), + Some(ConstantData::Tuple { ref elements }) if !elements.is_empty() + ), + }; + if warns { + self.warn_syntax( + assert_stmt.range, + "assertion is always true, perhaps remove parentheses?".to_owned(), + )?; + } + Ok(()) + } + + fn new_with_syntax_warning_handler( + opts: CompileOpts, + source_file: SourceFile, + code_name: &str, + syntax_warning_handler: Option<&'warnings mut SyntaxWarningHandler<'warnings>>, + ) -> Self { let module_code = ir::CodeInfo { // CPython convention: top-level module / interactive / // expression code does not carry CO_NEWLOCALS or CO_OPTIMIZED. @@ -467,7 +1049,7 @@ impl Compiler { // empty flags. frame.rs:725-731 then binds locals to globals // for module/REPL frames whose `scope.locals` is None - the // correct semantics for `exec(code, globals)` and module init. - flags: CodeFlags::empty(), + flags: bytecode::CodeFlags::empty(), source_path: source_file.name().to_owned(), private: None, blocks: Blocks::from([Block::default()]), @@ -492,7 +1074,7 @@ impl Compiler { }, static_attributes: None, in_inlined_comp: false, - fblock: Vec::with_capacity(MAXBLOCKS), + fblock: Vec::with_capacity(CO_MAXBLOCKS), symbol_table_index: 0, // Module is always the first symbol table nparams: 0, in_conditional_block: 0, @@ -501,10 +1083,12 @@ impl Compiler { Self { code_stack: vec![module_code], symbol_table_stack: Vec::new(), + annotation_symbol_sources: Vec::new(), source_file, // current_source_location: SourceLocation::default(), current_source_range: TextRange::default(), done_with_future_stmts: DoneWithFuture::No, + future_features: opts.future_features, future_annotations: false, ctx: CompileContext { in_class: false, @@ -515,7 +1099,9 @@ impl Compiler { in_annotation: false, interactive: false, do_not_emit_bytecode: 0, + disable_warning: 0, disable_const_collection_folding: false, + syntax_warning_handler, } } @@ -569,8 +1155,12 @@ impl Compiler { mem::replace(&mut code.instr_sequence, saved_instr_sequence); code.current_block = saved_current_block; code.instr_sequence_label_map = saved_instr_sequence_label_map; - code.annotations_instr_sequence = Some(annotations_instr_sequence); debug_assert!(saved_annotations_instr_sequence.is_none()); + if matches!(result, Ok(true)) { + code.annotations_instr_sequence = Some(annotations_instr_sequence); + } else { + code.annotations_instr_sequence = saved_annotations_instr_sequence; + } }; result.map(|_| ()) @@ -607,13 +1197,17 @@ impl Compiler { ) -> CompileResult<()> { // Save full subscript expression range (set by compile_expression before this call) let subscript_range = self.current_source_range; + if matches!(ctx, ast::ExprContext::Load) { + self.check_subscripter(value)?; + self.check_index(value, slice)?; + } // VISIT(c, expr, e->v.Subscript.value) self.compile_expression(value)?; // Handle two-element non-constant slice with BINARY_SLICE/STORE_SLICE let use_slice_opt = matches!(ctx, ast::ExprContext::Load | ast::ExprContext::Store) - && slice.should_use_slice_optimization(); + && self.should_apply_two_element_slice_optimization(slice); if use_slice_opt { match slice { ast::Expr::Slice(s) => self.compile_slice_two_parts(s)?, @@ -659,9 +1253,10 @@ impl Compiler { /// - collection_type: What type of collection to build (tuple, list, set) /// // = starunpack_helper in compile.c - fn starunpack_helper( + fn starunpack_helper_impl( &mut self, elts: &[ast::Expr], + injected_arg: Option<&str>, pushed: u32, collection_type: CollectionType, ) -> CompileResult<()> { @@ -669,7 +1264,8 @@ impl Compiler { let n = elts.len().to_u32(); let seen_star = elts.iter().any(|e| matches!(e, ast::Expr::Starred(_))); - let big = n + pushed > STACK_USE_GUIDELINE; + let injected_count = u32::from(injected_arg.is_some()); + let big = n + pushed + injected_count > STACK_USE_GUIDELINE; // Match CPython's constant ordering by letting the late flowgraph-style // folding passes introduce tuple-backed constants after their operands @@ -708,7 +1304,11 @@ impl Compiler { for elt in elts { self.compile_expression(elt)?; } - let total_size = n + pushed; + if let Some(injected_arg) = injected_arg { + self.set_source_range(collection_range); + self.load_name(injected_arg)?; + } + let total_size = n + injected_count + pushed; self.set_source_range(collection_range); match collection_type { CollectionType::List => { @@ -729,6 +1329,7 @@ impl Compiler { let mut i = 0u32; if big { + self.set_source_range(collection_range); match collection_type { CollectionType::List => { emit!(self, Instruction::BuildList { count: pushed }); @@ -803,22 +1404,22 @@ impl Compiler { } } - // If we never built sequence (all non-starred), build it now - if !sequence_built { + debug_assert!(sequence_built); + if let Some(injected_arg) = injected_arg { + self.set_source_range(collection_range); + self.load_name(injected_arg)?; self.set_source_range(collection_range); match collection_type { - CollectionType::List => { - emit!(self, Instruction::BuildList { count: i + pushed }); + CollectionType::List | CollectionType::Tuple => { + emit!(self, Instruction::ListAppend { i: 1 }); } CollectionType::Set => { - emit!(self, Instruction::BuildSet { count: i + pushed }); - } - CollectionType::Tuple => { - emit!(self, Instruction::BuildTuple { count: i + pushed }); + emit!(self, Instruction::SetAdd { i: 1 }); } } - } else if collection_type == CollectionType::Tuple { - // For tuples, convert the list to tuple + } + + if collection_type == CollectionType::Tuple { self.set_source_range(collection_range); emit!( self, @@ -831,6 +1432,15 @@ impl Compiler { Ok(()) } + fn starunpack_helper( + &mut self, + elts: &[ast::Expr], + pushed: u32, + collection_type: CollectionType, + ) -> CompileResult<()> { + self.starunpack_helper_impl(elts, None, pushed, collection_type) + } + fn error(&mut self, error: CodegenErrorType) -> CodegenError { self.error_ranged(error, self.current_source_range) } @@ -847,6 +1457,21 @@ impl Compiler { } } + fn error_optional_range( + &mut self, + error: CodegenErrorType, + range: Option, + ) -> CodegenError { + match range { + Some(range) => self.error_ranged(error, range), + None => CodegenError { + error, + location: None, + source_path: self.source_file.name().to_owned(), + }, + } + } + /// Get the SymbolTable for the current scope. fn current_symbol_table(&self) -> &SymbolTable { self.symbol_table_stack @@ -929,6 +1554,20 @@ impl Compiler { )))); } + while current_table.next_sub_table < current_table.sub_tables.len() + && current_table.sub_tables[current_table.next_sub_table].typ + == CompilerScope::Annotation + { + current_table.next_sub_table += 1; + } + if current_table.next_sub_table >= current_table.sub_tables.len() { + let name = current_table.name.clone(); + let typ = current_table.typ; + return Err(self.error(CodegenErrorType::SyntaxError(format!( + "no symbol table available in {name} (type: {typ:?})" + )))); + } + let idx = current_table.next_sub_table; current_table.next_sub_table += 1; let table = current_table.sub_tables[idx].clone(); @@ -938,31 +1577,101 @@ impl Compiler { Ok(self.current_symbol_table()) } - /// Push the annotation symbol table from the next sub_table's annotation_block - /// The annotation_block is stored in the function's scope, which is the next sub_table - /// Returns true if annotation_block exists, false otherwise - fn push_annotation_symbol_table(&mut self) -> bool { + fn push_symbol_table_matching( + &mut self, + typ: CompilerScope, + table_name: &str, + ) -> CompileResult<&SymbolTable> { let current_table = self .symbol_table_stack .last_mut() .expect("no current symbol table"); - // The annotation_block is in the next sub_table (function scope) - let next_idx = current_table.next_sub_table; - if next_idx >= current_table.sub_tables.len() { - return false; + while current_table.next_sub_table < current_table.sub_tables.len() + && current_table.sub_tables[current_table.next_sub_table].typ + == CompilerScope::Annotation + { + current_table.next_sub_table += 1; } - let next_table = &mut current_table.sub_tables[next_idx]; - if let Some(annotation_block) = next_table.annotation_block.take() { - self.symbol_table_stack.push(*annotation_block); - true - } else { - false - } - } + let start = current_table.next_sub_table; + let Some(idx) = current_table.sub_tables[start..] + .iter() + .position(|table| table.typ == typ && table.name == table_name) + .map(|idx| start + idx) + else { + let name = current_table.name.clone(); + let current_typ = current_table.typ; + return Err(self.error(CodegenErrorType::SyntaxError(format!( + "no matching symbol table {table_name} ({typ:?}) available in {name} (type: {current_typ:?})" + )))); + }; - /// Push the annotation symbol table for module/class level annotations + let table = current_table.sub_tables[idx].clone(); + current_table.next_sub_table = idx + 1; + self.symbol_table_stack.push(table); + Ok(self.current_symbol_table()) + } + + /// Push the function annotation symbol table. + /// CPython stores signature annotation blocks in st_blocks keyed by the + /// arguments AST node. Without future annotations they are also children; + /// with future annotations they are hidden from children and consumed here. + fn push_annotation_symbol_table(&mut self) -> bool { + let Some((annotation_table, source)) = ({ + let current_table = self + .symbol_table_stack + .last_mut() + .expect("no current symbol table"); + + let next_idx = current_table.next_sub_table; + if next_idx < current_table.sub_tables.len() + && current_table.sub_tables[next_idx].typ == CompilerScope::Annotation + { + let next_table = current_table.sub_tables[next_idx].clone(); + current_table.next_sub_table += 1; + Some((next_table, AnnotationSymbolSource::Sibling)) + } else if current_table.next_hidden_annotation_block + < current_table.hidden_annotation_blocks.len() + { + let idx = current_table.next_hidden_annotation_block; + current_table.next_hidden_annotation_block += 1; + Some(( + current_table.hidden_annotation_blocks[idx].clone(), + AnnotationSymbolSource::Hidden, + )) + } else { + None + } + }) else { + return false; + }; + + self.symbol_table_stack.push(annotation_table); + self.annotation_symbol_sources.push(source); + true + } + + fn next_function_annotation_symbol_table_uses_annotations(&self) -> bool { + let current_table = self + .symbol_table_stack + .last() + .expect("no current symbol table"); + let next_idx = current_table.next_sub_table; + if next_idx < current_table.sub_tables.len() + && current_table.sub_tables[next_idx].typ == CompilerScope::Annotation + { + return current_table.sub_tables[next_idx].annotations_used; + } + + let hidden_idx = current_table.next_hidden_annotation_block; + current_table + .hidden_annotation_blocks + .get(hidden_idx) + .is_some_and(|table| table.annotations_used) + } + + /// Push the annotation symbol table for module/class level annotations /// This takes annotation_block from the current symbol table (not sub_tables) fn push_current_annotation_symbol_table(&mut self) -> bool { let current_table = self @@ -979,18 +1688,15 @@ impl Compiler { } } - /// Pop the annotation symbol table and restore it to the function scope's annotation_block + /// Pop the annotation symbol table. fn pop_annotation_symbol_table(&mut self) { - let annotation_table = self.symbol_table_stack.pop().expect("compiler bug"); - let current_table = self - .symbol_table_stack - .last_mut() - .expect("no current symbol table"); - - // Restore to the next sub_table (function scope) where it came from - let next_idx = current_table.next_sub_table; - if next_idx < current_table.sub_tables.len() { - current_table.sub_tables[next_idx].annotation_block = Some(Box::new(annotation_table)); + self.symbol_table_stack.pop().expect("compiler bug"); + let source = self + .annotation_symbol_sources + .pop() + .expect("missing annotation symbol source"); + match source { + AnnotationSymbolSource::Sibling | AnnotationSymbolSource::Hidden => {} } } @@ -1032,28 +1738,22 @@ impl Compiler { return None; } - // 5. Must be inside a function (not at module level or class body) - if !self.ctx.in_func() { - return None; - } - - // 6. "super" must be GlobalImplicit (not redefined locally or at module level) + // 5. "super" must be GlobalImplicit in the current scope. let table = self.current_symbol_table(); if let Some(symbol) = table.lookup("super") && symbol.scope != SymbolScope::GlobalImplicit { return None; } - // Also check top-level scope to detect module-level shadowing. - // Only block if super is actually *bound* at module level (not just used). + // CPython then checks the top-level scope and rejects any statically + // visible symbol for "super", not just local bindings. if let Some(top_table) = self.symbol_table_stack.first() - && let Some(sym) = top_table.lookup("super") - && sym.scope != SymbolScope::GlobalImplicit + && top_table.lookup("super").is_some() { return None; } - // 7. Check argument pattern + // 6. Check argument pattern let args = &arguments.args; // No starred expressions allowed @@ -1180,10 +1880,13 @@ impl Compiler { let source_path = self.source_file.name().to_owned(); // Lookup symbol table entry using key (_PySymtable_Lookup) - let Some(ste) = self.symbol_table_stack.get(key) else { - return Err(self.error(CodegenErrorType::SyntaxError( - "unknown symbol table entry".into(), - ))); + let ste = match self.symbol_table_stack.get(key) { + Some(v) => v, + None => { + return Err(self.error(CodegenErrorType::SyntaxError( + "unknown symbol table entry".to_owned(), + ))); + } }; // Use varnames from symbol table (already collected in definition order) @@ -1192,32 +1895,16 @@ impl Compiler { // Build cellvars using dictbytype (CELL scope or COMP_CELL flag, sorted) let mut cellvar_cache = IndexSet::default(); - // CPython ordering: parameter cells first (in parameter order), - // then non-parameter cells (alphabetically sorted) - let cell_symbols: Vec<_> = ste + let mut cell_names: Vec<_> = ste .symbols .iter() .filter(|(_, s)| { s.scope == SymbolScope::Cell || s.flags.contains(SymbolFlags::COMP_CELL) }) - .map(|(name, sym)| (name.clone(), sym.flags)) + .map(|(name, _)| name.clone()) .collect(); - let mut param_cells = Vec::new(); - let mut nonparam_cells = Vec::new(); - for (name, flags) in cell_symbols { - if flags.contains(SymbolFlags::PARAMETER) { - param_cells.push(name); - } else { - nonparam_cells.push(name); - } - } - // param_cells are already in parameter order (from varname_cache insertion order) - param_cells.sort_by_key(|n| varname_cache.get_index_of(n.as_str()).unwrap_or(usize::MAX)); - nonparam_cells.sort(); - for name in param_cells { - cellvar_cache.insert(name); - } - for name in nonparam_cells { + cell_names.sort(); + for name in cell_names { cellvar_cache.insert(name); } @@ -1254,21 +1941,16 @@ impl Compiler { .collect() }) .unwrap_or_default(); - - let mut free_names = ste + let mut free_names: Vec<_> = ste .symbols .iter() .filter(|(_, s)| { - if s.scope == SymbolScope::Free { - return true; - } - - let has_free_class = s.flags.contains(SymbolFlags::FREE_CLASS); - if scope_type == CompilerScope::Class { - has_free_class && self.has_enclosing_non_module_code_scope() - } else { - has_free_class - } + s.scope == SymbolScope::Free + || (scope_type != CompilerScope::Class + && s.flags.contains(SymbolFlags::FREE_CLASS)) + || (scope_type == CompilerScope::Class + && s.flags.contains(SymbolFlags::FREE_CLASS) + && self.has_enclosing_non_module_code_scope()) }) .filter(|(name, symbol)| { if !matches!( @@ -1280,8 +1962,7 @@ impl Compiler { !(annotation_free_names.contains(*name) && symbol.flags.is_empty()) }) .map(|(name, _)| name.clone()) - .collect::>(); - + .collect(); free_names.sort(); for name in free_names { freevar_cache.insert(name); @@ -1289,31 +1970,42 @@ impl Compiler { // Initialize u_metadata fields let (mut flags, posonlyarg_count, arg_count, kwonlyarg_count) = match scope_type { - CompilerScope::Module => (CodeFlags::empty(), 0, 0, 0), - CompilerScope::Class => (CodeFlags::empty(), 0, 0, 0), + CompilerScope::Module => (bytecode::CodeFlags::empty(), 0, 0, 0), + CompilerScope::Class => (bytecode::CodeFlags::empty(), 0, 0, 0), CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Lambda => ( - CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED, + bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, 0, // Will be set later in enter_function 0, // Will be set later in enter_function 0, // Will be set later in enter_function ), CompilerScope::Comprehension => ( - CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED, + bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, 0, 1, // comprehensions take one argument (.0) 0, ), - CompilerScope::TypeParams => (CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED, 0, 0, 0), + CompilerScope::TypeParams => ( + bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, + 0, + 0, + 0, + ), CompilerScope::Annotation => ( - CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED, + bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, + 1, // format is positional-only + 0, + 0, + ), + CompilerScope::TypeAlias | CompilerScope::TypeVariable => ( + bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, 1, // format is positional-only - 1, // annotation scope takes one argument (format) + 0, 0, ), }; if ste.is_method { - flags |= CodeFlags::METHOD; + flags |= bytecode::CodeFlags::METHOD; } // CPython sets CO_NESTED from symtable's ste_nested, not merely @@ -1327,15 +2019,15 @@ impl Compiler { | CompilerScope::Lambda | CompilerScope::Comprehension | CompilerScope::Annotation + | CompilerScope::TypeAlias + | CompilerScope::TypeVariable | CompilerScope::TypeParams ) { - flags | CodeFlags::NESTED + flags | bytecode::CodeFlags::NESTED } else { flags }; - if self.future_annotations { - flags |= CodeFlags::FUTURE_ANNOTATIONS; - } + flags |= self.future_features; // Get private name from parent scope let private = if !self.code_stack.is_empty() { @@ -1375,7 +2067,7 @@ impl Compiler { None }, in_inlined_comp: false, - fblock: Vec::with_capacity(MAXBLOCKS), + fblock: Vec::with_capacity(CO_MAXBLOCKS), symbol_table_index: key, nparams, in_conditional_block: 0, @@ -1418,7 +2110,10 @@ impl Compiler { let except_handler = None; self.cpython_cfg_builder_addop(ir::InstructionInfo { - instr: Opcode::Resume.into(), + instr: Instruction::Resume { + context: OpArgMarker::marker(), + } + .into(), arg: OpArg::new(oparg::ResumeLocation::AtFuncStart.into()), target: BlockIdx::NULL, location, @@ -1456,7 +2151,15 @@ impl Compiler { // Preserve flags computed from the symbol-table context. info.flags = flags | (info.flags - & (CodeFlags::NESTED | CodeFlags::METHOD | CodeFlags::FUTURE_ANNOTATIONS)); + & (bytecode::CodeFlags::NESTED + | bytecode::CodeFlags::METHOD + | bytecode::CodeFlags::FUTURE_DIVISION + | bytecode::CodeFlags::FUTURE_ABSOLUTE_IMPORT + | bytecode::CodeFlags::FUTURE_WITH_STATEMENT + | bytecode::CodeFlags::FUTURE_PRINT_FUNCTION + | bytecode::CodeFlags::FUTURE_UNICODE_LITERALS + | bytecode::CodeFlags::FUTURE_GENERATOR_STOP + | bytecode::CodeFlags::FUTURE_ANNOTATIONS)); info.metadata.argcount = arg_count; info.metadata.posonlyargcount = posonlyarg_count; info.metadata.kwonlyargcount = kwonlyarg_count; @@ -1479,17 +2182,25 @@ impl Compiler { unwrap_internal(self, stack_top.finalize_code(&self.opts)) } - /// Exit annotation scope - similar to exit_scope but restores annotation_block to parent + fn expose_annotation_format_parameter(code: &mut CodeObject) { + if let Some(first) = code.varnames.first_mut() { + *first = "format".to_owned(); + } + } + + /// Exit a function signature annotation scope. fn exit_annotation_scope(&mut self, saved_ctx: CompileContext) -> CodeObject { self.pop_annotation_symbol_table(); self.ctx = saved_ctx; let pop = self.code_stack.pop(); let stack_top = compiler_unwrap_option(self, pop); - unwrap_internal(self, stack_top.finalize_code(&self.opts)) + let mut code = unwrap_internal(self, stack_top.finalize_code(&self.opts)); + Self::expose_annotation_format_parameter(&mut code); + code } - /// Enter annotation scope using the symbol table's annotation_block. - /// Returns None if no annotation_block exists. + /// Enter a function signature annotation scope. + /// Returns None if no matching annotation symbol table exists. /// On success, returns the saved CompileContext to pass to exit_annotation_scope. fn enter_annotation_scope( &mut self, @@ -1518,12 +2229,12 @@ impl Compiler { lineno.to_u32(), )?; - // Override arg_count since enter_scope sets it to 1 but we need the varnames - // setup to be correct too + // Keep CPython's internal ".format" name; exit_annotation_scope() + // renames it to "format" on the final code object. self.current_code_info() .metadata .varnames - .insert("format".to_owned()); + .insert(".format".to_owned()); // Emit format validation: if format > VALUE_WITH_FAKE_GLOBALS: raise NotImplementedError // VALUE_WITH_FAKE_GLOBALS = 2 (from annotationlib.Format) @@ -1586,12 +2297,15 @@ impl Compiler { fb_datum: FBlockDatum, ) -> CompileResult<()> { let fb_range = self.current_source_range; - let code = self.current_code_info(); - if code.fblock.len() >= MAXBLOCKS { + if self.current_code_info().fblock.len() >= CO_MAXBLOCKS { return Err(self.error(CodegenErrorType::SyntaxError( "too many statically nested blocks".to_owned(), ))); } + if matches!(fb_type, FBlockType::FinallyEnd) { + self.disable_warning += 1; + } + let code = self.current_code_info(); code.fblock.push(FBlockInfo { fb_type, fb_block, @@ -1608,16 +2322,49 @@ impl Compiler { expected_type: FBlockType, expected_block: ir::InstructionSequenceLabel, ) -> FBlockInfo { - let code = self.current_code_info(); - let fblock = code.fblock.pop().expect("fblock stack underflow"); + let fblock = { + let code = self.current_code_info(); + code.fblock.pop().expect("fblock stack underflow") + }; debug_assert_eq!(fblock.fb_type, expected_type); debug_assert_eq!( fblock.fb_block, expected_block, "CPython _PyCompile_PopFBlock asserts the popped fb_block label" ); + if matches!(expected_type, FBlockType::FinallyEnd) { + self.disable_warning -= 1; + } fblock } + /// CPython `_PyCompile_PushFBlock()` call used by + /// `codegen_unwind_fblock_stack()` to restore the copied fblock after + /// recursive unwinding. + fn restore_fblock_info(&mut self, fblock: FBlockInfo) -> CompileResult<()> { + let FBlockInfo { + fb_type, + fb_block, + fb_exit, + fb_range, + fb_datum, + } = fblock; + let code = self.current_code_info(); + if code.fblock.len() >= CO_MAXBLOCKS { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError("too many statically nested blocks".to_owned()), + fb_range, + )); + } + code.fblock.push(FBlockInfo { + fb_type, + fb_block, + fb_exit, + fb_range, + fb_datum, + }); + Ok(()) + } + fn set_unwind_source_range(&mut self, loc: Option) { if let Some(range) = loc { self.set_source_range(range); @@ -1667,10 +2414,32 @@ impl Compiler { } FBlockType::FinallyTry => { - // FinallyTry is now handled specially in unwind_fblock_stack - // to avoid infinite recursion when the finally body contains return/break/continue. - // This branch should not be reached. - unreachable!("FinallyTry should be handled by unwind_fblock_stack"); + // codegen_unwind_fblock(FINALLY_TRY) + self.set_unwind_source_range(*loc); + emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); + + if preserve_tos { + self.push_fblock_labels( + FBlockType::PopValue, + ir::InstructionSequenceLabel::NO_LABEL, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; + } + + if let FBlockDatum::FinallyBody(ref body) = info.fb_datum { + self.compile_statements(body)?; + } + + if preserve_tos { + self.pop_fblock_label( + FBlockType::PopValue, + ir::InstructionSequenceLabel::NO_LABEL, + ); + } + + *loc = None; } FBlockType::FinallyEnd => { @@ -1802,96 +2571,40 @@ impl Compiler { preserve_tos: bool, stop_at_loop: bool, ) -> CompileResult<(Option, Option)> { - // Collect the info we need, with indices for FinallyTry blocks - #[derive(Clone)] - enum UnwindInfo { - Normal(FBlockInfo), - FinallyTry { - body: Vec, - fblock_idx: usize, - }, - } - let mut unwind_infos = Vec::new(); - let mut loop_fblock = None; - - { - let code = self.current_code_info(); - for i in (0..code.fblock.len()).rev() { - // Check for exception group handler (forbidden) - if matches!(code.fblock[i].fb_type, FBlockType::ExceptionGroupHandler) { - return Err(self.error(CodegenErrorType::BreakContinueReturnInExceptStar)); - } - - // Stop at loop if requested - if stop_at_loop - && matches!( - code.fblock[i].fb_type, - FBlockType::WhileLoop | FBlockType::ForLoop - ) - { - loop_fblock = Some(code.fblock[i].clone()); - break; - } - - if matches!(code.fblock[i].fb_type, FBlockType::FinallyTry) { - if let FBlockDatum::FinallyBody(ref body) = code.fblock[i].fb_datum { - unwind_infos.push(UnwindInfo::FinallyTry { - body: body.clone(), - fblock_idx: i, - }); - } - } else { - unwind_infos.push(UnwindInfo::Normal(code.fblock[i].clone())); - } - } - } - - // Process each fblock let mut unwind_loc = Some(self.current_source_range); - for info in unwind_infos { - match info { - UnwindInfo::Normal(fblock_info) => { - self.unwind_fblock(&fblock_info, preserve_tos, &mut unwind_loc)?; - } - UnwindInfo::FinallyTry { body, fblock_idx } => { - // codegen_unwind_fblock(FINALLY_TRY) - self.set_unwind_source_range(unwind_loc); - emit!(self, PseudoInstruction::PopBlock); - self.mark_unwind_no_location(unwind_loc); - - // Temporarily remove the FinallyTry fblock so nested return/break/continue - // in the finally body won't see it again - let code = self.current_code_info(); - let saved_fblock = code.fblock.remove(fblock_idx); - - // Push PopValue fblock if preserving tos - if preserve_tos { - self.push_fblock_labels( - FBlockType::PopValue, - ir::InstructionSequenceLabel::NO_LABEL, - ir::InstructionSequenceLabel::NO_LABEL, - FBlockDatum::None, - )?; - } - - self.compile_statements(&body)?; - unwind_loc = None; - - if preserve_tos { - self.pop_fblock_label( - FBlockType::PopValue, - ir::InstructionSequenceLabel::NO_LABEL, - ); - } + let loop_fblock = + self.unwind_fblock_stack_inner(preserve_tos, stop_at_loop, &mut unwind_loc)?; + Ok((unwind_loc, loop_fblock)) + } - // Restore the fblock - let code = self.current_code_info(); - code.fblock.insert(fblock_idx, saved_fblock); - } - } + fn unwind_fblock_stack_inner( + &mut self, + preserve_tos: bool, + stop_at_loop: bool, + unwind_loc: &mut Option, + ) -> CompileResult> { + let Some(top) = self.current_code_info().fblock.last().cloned() else { + return Ok(None); + }; + if matches!(top.fb_type, FBlockType::ExceptionGroupHandler) { + return Err(self.error_optional_range( + CodegenErrorType::BreakContinueReturnInExceptStar, + *unwind_loc, + )); + } + if stop_at_loop && matches!(top.fb_type, FBlockType::WhileLoop | FBlockType::ForLoop) { + return Ok(Some(top)); } - Ok((unwind_loc, loop_fblock)) + let copy = self + .current_code_info() + .fblock + .pop() + .expect("fblock stack underflow"); + self.unwind_fblock(©, preserve_tos, unwind_loc)?; + let loop_fblock = self.unwind_fblock_stack_inner(preserve_tos, stop_at_loop, unwind_loc)?; + self.restore_fblock_info(copy)?; + Ok(loop_fblock) } // could take impl Into>, but everything is borrowed from ast structs; we never @@ -1949,7 +2662,12 @@ impl Compiler { // when building qualnames for the contained function/class code object. if matches!( parent_scope, - Some(CompilerScope::TypeParams | CompilerScope::Annotation) + Some( + CompilerScope::TypeParams + | CompilerScope::Annotation + | CompilerScope::TypeAlias + | CompilerScope::TypeVariable, + ) ) || parent.metadata.name.starts_with(" CompileResult<()> { + let future_features = self.future_features; + self.current_code_info().flags |= future_features; + if symbol_table.is_coroutine { + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::COROUTINE); + } self.symbol_table_stack.push(symbol_table); self.emit_resume_for_scope(CompilerScope::Module, 1); @@ -2264,6 +2970,13 @@ impl Compiler { expression: &ast::ModExpression, symbol_table: SymbolTable, ) -> CompileResult<()> { + let future_features = self.future_features; + self.current_code_info().flags |= future_features; + if symbol_table.is_coroutine { + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::COROUTINE); + } self.symbol_table_stack.push(symbol_table); self.emit_resume_for_scope(CompilerScope::Module, 1); @@ -2365,7 +3078,6 @@ impl Compiler { fn emit_no_location_exception_name_cleanup(&mut self, name: &str) -> CompileResult<()> { // CPython codegen_try_except() emits `name = None; del name` // with NO_LOCATION for `except ... as name` cleanup. - self.set_no_location(); self.emit_load_const(ConstantData::None); self.set_no_location(); self.store_name(name)?; @@ -2423,7 +3135,10 @@ impl Compiler { let current_idx = self.symbol_table_stack.len() - 1; let current_table = &self.symbol_table_stack[current_idx]; let is_typeparams = current_table.typ == CompilerScope::TypeParams; - let is_annotation = current_table.typ == CompilerScope::Annotation; + let is_annotation = matches!( + current_table.typ, + CompilerScope::Annotation | CompilerScope::TypeAlias | CompilerScope::TypeVariable + ); let can_see_class = current_table.can_see_class_scope; // First try to find in current table @@ -2462,13 +3177,11 @@ impl Compiler { let current_table = self.current_symbol_table(); if current_table.typ == CompilerScope::Class && !self.current_code_info().in_inlined_comp - && matches!( - (usage, name.as_ref()), - ( - NameUsage::Load, - "__class__" | "__classdict__" | "__conditional_annotations__" - ) | (NameUsage::Store, "__conditional_annotations__") - ) + && ((usage == NameUsage::Load + && (name == "__class__" + || name == "__classdict__" + || name == "__conditional_annotations__")) + || (name == "__conditional_annotations__" && usage == NameUsage::Store)) { Some(SymbolScope::Cell) } else { @@ -2485,7 +3198,10 @@ impl Compiler { let current_table = self.current_symbol_table(); if matches!( current_table.typ, - CompilerScope::Annotation | CompilerScope::TypeParams + CompilerScope::Annotation + | CompilerScope::TypeAlias + | CompilerScope::TypeVariable + | CompilerScope::TypeParams ) { SymbolScope::GlobalImplicit } else if matches!( @@ -2539,7 +3255,7 @@ impl Compiler { // to check classdict first before globals if class_declared_global { NameOp::Global - } else if can_see_class_scope { + } else if can_see_class_scope && usage == NameUsage::Load { NameOp::DictOrGlobals } else if is_function_like { NameOp::Global @@ -2551,7 +3267,7 @@ impl Compiler { // A global declared in the owning class body must bypass the // classdict, but an explicit global inherited from an outer // function still participates in DictOrGlobals lookup. - if can_see_class_scope && !class_declared_global { + if can_see_class_scope && !class_declared_global && usage == NameUsage::Load { NameOp::DictOrGlobals } else { NameOp::Global @@ -2635,21 +3351,10 @@ impl Compiler { NameOp::DictOrGlobals => { // PEP 649: First check classdict (from __classdict__ freevar), then globals let idx = self.get_global_name_index(&name); - match usage { - NameUsage::Load => { - // Load __classdict__ first (it's a free variable in annotation scope) - let classdict_idx = self.get_free_var_index("__classdict__"); - emit!(self, Instruction::LoadDeref { i: classdict_idx }); - emit!(self, Instruction::LoadFromDictOrGlobals { i: idx }); - } - // Store/Delete in annotation scope should use Name ops - NameUsage::Store => { - emit!(self, Instruction::StoreName { namei: idx }); - } - NameUsage::Delete => { - emit!(self, Instruction::DeleteName { namei: idx }); - } - } + debug_assert!(usage == NameUsage::Load); + let classdict_idx = self.get_free_var_index("__classdict__"); + emit!(self, Instruction::LoadDeref { i: classdict_idx }); + emit!(self, Instruction::LoadFromDictOrGlobals { i: idx }); } } @@ -2664,9 +3369,12 @@ impl Compiler { match &statement { // we do this here because `from __future__` still executes that `from` statement at runtime, // we still need to compile the ImportFrom down below - ast::Stmt::ImportFrom(ast::StmtImportFrom { module, names, .. }) - if module.as_ref().map(|id| id.as_str()) == Some("__future__") => - { + ast::Stmt::ImportFrom(ast::StmtImportFrom { + module, + names, + level, + .. + }) if *level == 0 && module.as_ref().map(|id| id.as_str()) == Some("__future__") => { self.compile_future_features(names)? } // ignore module-level doc comments @@ -2716,24 +3424,14 @@ impl Compiler { names, .. }) => { - let import_star = names.iter().any(|n| &n.name == "*"); + let import_star = names.first().is_some_and(|n| &n.name == "*"); - let from_list = if import_star { - if self.ctx.in_func() { - return Err(self.error_ranged( - CodegenErrorType::FunctionImportStar, - statement.range(), - )); - } - vec![ConstantData::Str { value: "*".into() }] - } else { - names - .iter() - .map(|n| ConstantData::Str { - value: n.name.as_str().into(), - }) - .collect() - }; + let from_list = names + .iter() + .map(|n| ConstantData::Str { + value: n.name.as_str().into(), + }) + .collect(); // from .... import (*fromlist) self.emit_load_const(ConstantData::Integer { @@ -2811,17 +3509,21 @@ impl Compiler { .. }) => { self.enter_conditional_block(); - self.compile_if(test, body, elif_else_clauses, test.range())?; + self.compile_if(test, body, elif_else_clauses, statement.range())?; self.leave_conditional_block(); self.set_source_range(statement.range()); } ast::Stmt::While(ast::StmtWhile { - test, body, orelse, .. - }) => self.compile_while(test, body, orelse)?, - ast::Stmt::With(ast::StmtWith { - items, + test, body, - is_async, + orelse, + range, + .. + }) => self.compile_while(test, body, orelse, *range)?, + ast::Stmt::With(ast::StmtWith { + items, + body, + is_async, .. }) => self.compile_with(items, body, *is_async)?, ast::Stmt::For(ast::StmtFor { @@ -2909,13 +3611,15 @@ impl Compiler { arguments.as_deref(), false, )?, - ast::Stmt::Assert(ast::StmtAssert { - test, msg, range, .. - }) => { + ast::Stmt::Assert(assert_stmt) => { + let ast::StmtAssert { + test, msg, range, .. + } = assert_stmt; + self.check_assert(assert_stmt)?; // if some flag, ignore all assert statements! if self.opts.optimize == 0 { let after_block = self.new_block(); - self.compile_jump_if(test, true, after_block)?; + self.compile_jump_if_inner(test, true, after_block, Some(*range))?; self.set_source_range(*range); emit!( self, @@ -2973,7 +3677,13 @@ impl Compiler { statement.range(), )); } - let folded_constant = if v.is_constant() { + let debug_constant = matches!( + &**v, + ast::Expr::Name(ast::ExprName { id, ctx, .. }) + if matches!(ctx, ast::ExprContext::Load) + && id.as_str() == "__debug__" + ); + let folded_constant = if self.is_constant_expr(v) || debug_constant { self.try_fold_constant_expr(v)? } else { None @@ -3000,18 +3710,17 @@ impl Compiler { let unwind_loc = self.unwind_fblock_stack(preserve_tos, false)?; if let Some(loc) = unwind_loc { self.set_source_range(loc); - } - match folded_constant { - Some(constant) if unwind_loc.is_none() => { - self.emit_return_const_no_location(constant); - } - Some(constant) => { - self.emit_load_const(constant); - self.emit_return_value(); + match folded_constant { + Some(constant) => self.emit_return_const(constant), + None => { + self.emit_return_value(); + } } - None => { - self.emit_return_value(); - if unwind_loc.is_none() { + } else { + match folded_constant { + Some(constant) => self.emit_return_const_no_location(constant), + None => { + self.emit_return_value(); self.set_no_location(); } } @@ -3099,6 +3808,7 @@ impl Compiler { let name_string = name.id.to_string(); if let Some(type_params) = type_params { + self.set_source_range(*range); self.push_symbol_table()?; let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get().to_u32(); @@ -3113,11 +3823,13 @@ impl Compiler { in_async_scope: false, }; + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name_string.clone().into(), }); self.compile_type_params(type_params)?; self.compile_typealias_value_closure(&name_string, value, *range)?; + self.set_source_range(*range); emit!(self, Instruction::BuildTuple { count: 3 }); emit!( self, @@ -3129,15 +3841,19 @@ impl Compiler { let code = self.exit_scope(); self.ctx = prev_ctx; - self.make_closure(code, MakeFunctionFlags::new())?; + self.set_source_range(*range); + self.make_closure(code, bytecode::MakeFunctionFlags::new())?; + self.set_source_range(*range); emit!(self, Instruction::PushNull); emit!(self, Instruction::Call { argc: 0 }); } else { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name_string.clone().into(), }); self.emit_load_const(ConstantData::None); self.compile_typealias_value_closure(&name_string, value, *range)?; + self.set_source_range(*range); emit!(self, Instruction::BuildTuple { count: 3 }); emit!( self, @@ -3147,9 +3863,15 @@ impl Compiler { ); } + self.set_source_range(*range); self.store_name(&name_string)?; } - ast::Stmt::IpyEscapeCommand(_) => todo!(), + ast::Stmt::IpyEscapeCommand(stmt) => { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError("invalid syntax".to_owned()), + stmt.range, + )); + } } Ok(()) } @@ -3195,21 +3917,10 @@ impl Compiler { } fn enter_function(&mut self, name: &str, parameters: &ast::Parameters) -> CompileResult<()> { - // TODO: partition_in_place - let mut kw_without_defaults = vec![]; - let mut kw_with_defaults = vec![]; - for kwonlyarg in ¶meters.kwonlyargs { - if let Some(default) = &kwonlyarg.default { - kw_with_defaults.push((&kwonlyarg.parameter, default)); - } else { - kw_without_defaults.push(&kwonlyarg.parameter); - } - } - self.push_output( - CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED, + bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, parameters.posonlyargs.len().to_u32(), - (parameters.posonlyargs.len() + parameters.args.len()).to_u32(), + parameters.args.len().to_u32(), parameters.kwonlyargs.len().to_u32(), name, )?; @@ -3218,18 +3929,17 @@ impl Compiler { .chain(¶meters.posonlyargs) .chain(¶meters.args) .map(|arg| &arg.parameter) - .chain(kw_without_defaults) - .chain(kw_with_defaults.into_iter().map(|(arg, _)| arg)); + .chain(parameters.kwonlyargs.iter().map(|arg| &arg.parameter)); for name in args_iter { self.varname(name.name.as_str()); } if let Some(name) = parameters.vararg.as_deref() { - self.current_code_info().flags |= CodeFlags::VARARGS; + self.current_code_info().flags |= bytecode::CodeFlags::VARARGS; self.varname(name.name.as_str()); } if let Some(name) = parameters.kwarg.as_deref() { - self.current_code_info().flags |= CodeFlags::VARKEYWORDS; + self.current_code_info().flags |= bytecode::CodeFlags::VARKEYWORDS; self.varname(name.name.as_str()); } @@ -3275,7 +3985,7 @@ impl Compiler { let lineno = self.get_source_line_number().get().to_u32(); // Enter scope with the type parameter name - self.enter_scope(name, CompilerScope::Annotation, key, lineno)?; + self.enter_scope(name, CompilerScope::TypeVariable, key, lineno)?; self.current_code_info() .metadata @@ -3306,13 +4016,17 @@ impl Compiler { // Return value self.set_source_range(expr_range); emit!(self, Instruction::ReturnValue); + self.emit_return_const_no_location(ConstantData::None); // Exit scope and create closure let code = self.exit_scope(); self.ctx = prev_ctx; self.set_source_range(expr_range); - self.make_closure(code, MakeFunctionFlags::from([MakeFunctionFlag::Defaults]))?; + self.make_closure( + code, + bytecode::MakeFunctionFlags::from([bytecode::MakeFunctionFlag::Defaults]), + )?; Ok(()) } @@ -3331,7 +4045,7 @@ impl Compiler { self.push_symbol_table()?; let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get().to_u32(); - self.enter_scope(alias_name, CompilerScope::Annotation, key, lineno)?; + self.enter_scope(alias_name, CompilerScope::TypeAlias, key, lineno)?; self.current_code_info() .metadata .varnames @@ -3352,7 +4066,10 @@ impl Compiler { let code = self.exit_scope(); self.ctx = prev_ctx; self.set_source_range(alias_range); - self.make_closure(code, MakeFunctionFlags::from([MakeFunctionFlag::Defaults]))?; + self.make_closure( + code, + bytecode::MakeFunctionFlags::from([bytecode::MakeFunctionFlag::Defaults]), + )?; Ok(()) } @@ -3360,6 +4077,7 @@ impl Compiler { /// Store each type parameter so it is accessible to the current scope, and leave a tuple of /// all the type parameters on the stack. Handles default values per PEP 695. fn compile_type_params(&mut self, type_params: &ast::TypeParams) -> CompileResult<()> { + let mut seen_default = false; // First, compile each type parameter and store it for type_param in &type_params.type_params { match type_param { @@ -3395,6 +4113,7 @@ impl Compiler { } if let Some(default_expr) = default { + seen_default = true; self.compile_type_param_bound_or_default( default_expr, name.as_str(), @@ -3407,6 +4126,13 @@ impl Compiler { func: bytecode::IntrinsicFunction2::SetTypeparamDefault } ); + } else if seen_default { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "non-default type parameter '{name}' follows default type parameter" + )), + *range, + )); } self.set_source_range(*range); @@ -3431,6 +4157,7 @@ impl Compiler { ); if let Some(default_expr) = default { + seen_default = true; self.compile_type_param_bound_or_default( default_expr, name.as_str(), @@ -3443,6 +4170,13 @@ impl Compiler { func: bytecode::IntrinsicFunction2::SetTypeparamDefault } ); + } else if seen_default { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "non-default type parameter '{name}' follows default type parameter" + )), + *range, + )); } self.set_source_range(*range); @@ -3480,6 +4214,14 @@ impl Compiler { func: bytecode::IntrinsicFunction2::SetTypeparamDefault } ); + seen_default = true; + } else if seen_default { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "non-default type parameter '{name}' follows default type parameter" + )), + *range, + )); } self.set_source_range(*range); @@ -3534,7 +4276,6 @@ impl Compiler { if handlers.is_empty() { self.compile_statements(body)?; - self.compile_statements(orelse)?; } else { self.compile_try_except_no_finally(body, handlers, orelse)?; } @@ -3543,7 +4284,7 @@ impl Compiler { self.set_no_location(); self.pop_fblock_label(FBlockType::FinallyTry, body_label); - let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); + let symbol_table_cursors = self.current_symbol_table_cursors(); self.compile_statements(finalbody)?; emit!( @@ -3552,11 +4293,7 @@ impl Compiler { ); self.set_no_location(); - if let Some(cursor) = sub_table_cursor - && let Some(current_table) = self.symbol_table_stack.last_mut() - { - current_table.next_sub_table = cursor; - } + self.set_symbol_table_cursors(symbol_table_cursors); self.use_cpython_label_block(finally_except_block); emit!( @@ -3620,7 +4357,16 @@ impl Compiler { self.pop_fblock_label(FBlockType::TryExcept, body_label); emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); + + // CPython's symtable stores child scopes in AST visit order + // (body, handlers, orelse), while codegen_try_except() emits orelse + // before the exception handlers. Keep the symbol table in CPython order + // and only move the codegen cursor while compiling orelse. + let handler_symbol_table_cursors = self.current_symbol_table_cursors(); + self.consume_skipped_nested_scopes_in_except_handlers(handlers)?; self.compile_statements(orelse)?; + let after_orelse_symbol_table_cursors = self.current_symbol_table_cursors(); + self.set_symbol_table_cursors(handler_symbol_table_cursors); emit!( self, PseudoInstruction::JumpNoInterrupt { delta: end_block } @@ -3739,13 +4485,14 @@ impl Compiler { self.use_cpython_label_block(next_handler); } + self.set_symbol_table_cursors(after_orelse_symbol_table_cursors); - emit!(self, Instruction::Reraise { depth: 0 }); - self.set_no_location(); self.pop_fblock_label( FBlockType::ExceptionHandler, ir::InstructionSequenceLabel::NO_LABEL, ); + emit!(self, Instruction::Reraise { depth: 0 }); + self.set_no_location(); self.use_cpython_label_block(cleanup_block); emit!(self, Instruction::Copy { i: 3 }); @@ -3801,7 +4548,7 @@ impl Compiler { self.set_no_location(); self.pop_fblock_label(FBlockType::FinallyTry, body_label); - let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); + let symbol_table_cursors = self.current_symbol_table_cursors(); self.compile_statements(finalbody)?; emit!( @@ -3810,11 +4557,7 @@ impl Compiler { ); self.set_no_location(); - if let Some(cursor) = sub_table_cursor - && let Some(current_table) = self.symbol_table_stack.last_mut() - { - current_table.next_sub_table = cursor; - } + self.set_symbol_table_cursors(symbol_table_cursors); self.use_cpython_label_block(finally_except_block); emit!( @@ -3881,9 +4624,9 @@ impl Compiler { FBlockDatum::None, )?; self.compile_statements(body)?; + self.pop_fblock_label(FBlockType::TryExcept, body_label); emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - self.pop_fblock_label(FBlockType::TryExcept, body_label); emit!( self, PseudoInstruction::JumpNoInterrupt { delta: else_block } @@ -3938,30 +4681,29 @@ impl Compiler { emit!(self, Instruction::Copy { i: 2 }); } - // Compile exception type + // Compile exception type. CPython's public-AST validator allows a + // NULL type here, so codegen only emits CHECK_EG_MATCH when present. if let Some(exc_type) = type_ { self.compile_expression(exc_type)?; self.set_source_range(*handler_range); - } else { - return Err(self.error(CodegenErrorType::SyntaxError( - "except* must specify an exception type".to_owned(), - ))); } - // Stack: [prev_exc, orig, list, rest, type] - // ADDOP(c, loc, CHECK_EG_MATCH); - emit!(self, Instruction::CheckEgMatch); - // Stack: [prev_exc, orig, list, new_rest, match] + if type_.is_some() { + // Stack: [prev_exc, orig, list, rest, type] + // ADDOP(c, loc, CHECK_EG_MATCH); + emit!(self, Instruction::CheckEgMatch); + // Stack: [prev_exc, orig, list, new_rest, match] - // ADDOP_I(c, loc, COPY, 1); - // ADDOP_JUMP(c, loc, POP_JUMP_IF_NONE, no_match); - emit!(self, Instruction::Copy { i: 1 }); - emit!( - self, - Instruction::PopJumpIfNone { - delta: no_match_block - } - ); + // ADDOP_I(c, loc, COPY, 1); + // ADDOP_JUMP(c, loc, POP_JUMP_IF_NONE, no_match); + emit!(self, Instruction::Copy { i: 1 }); + emit!( + self, + Instruction::PopJumpIfNone { + delta: no_match_block + } + ); + } // Handler matched // Stack: [prev_exc, orig, list, new_rest, match] @@ -4001,9 +4743,9 @@ impl Compiler { self.compile_statements(body)?; // Handler body completed normally + self.pop_fblock_label(FBlockType::HandlerCleanup, cleanup_body_label); emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - self.pop_fblock_label(FBlockType::HandlerCleanup, cleanup_body_label); // Cleanup name binding if let Some(alias) = name { @@ -4179,7 +4921,7 @@ impl Compiler { parameters: &ast::Parameters, loc: TextRange, ) -> CompileResult { - let mut funcflags = MakeFunctionFlags::new(); + let mut funcflags = bytecode::MakeFunctionFlags::new(); // Handle positional defaults let defaults: Vec<_> = core::iter::empty() @@ -4200,7 +4942,7 @@ impl Compiler { count: defaults.len().to_u32() } ); - funcflags.insert(MakeFunctionFlag::Defaults); + funcflags.insert(bytecode::MakeFunctionFlag::Defaults); } // Handle keyword-only defaults @@ -4227,7 +4969,7 @@ impl Compiler { count: kw_with_defaults.len().to_u32(), } ); - funcflags.insert(MakeFunctionFlag::KwOnlyDefaults); + funcflags.insert(bytecode::MakeFunctionFlag::KwOnlyDefaults); } Ok(funcflags) @@ -4248,7 +4990,7 @@ impl Compiler { self.enter_function(name, parameters)?; self.current_code_info() .flags - .set(CodeFlags::COROUTINE, is_async); + .set(bytecode::CodeFlags::COROUTINE, is_async); // Set up context let prev_ctx = self.ctx; @@ -4267,7 +5009,7 @@ impl Compiler { self.set_qualname(); // Handle docstring - store in co_consts[0] if present - let (doc_info, body) = split_doc_with_range(body, self.opts); + let (doc_info, body) = split_doc_with_range(body, &self.opts); let doc_str = doc_info.as_ref().map(|(doc, _)| doc); if let Some(doc) = &doc_str { // Docstring present: store in co_consts[0] and set HAS_DOCSTRING flag @@ -4277,7 +5019,7 @@ impl Compiler { .insert_full(ConstantData::Str { value: (*doc).to_string().into(), }); - self.current_code_info().flags |= CodeFlags::HAS_DOCSTRING; + self.current_code_info().flags |= bytecode::CodeFlags::HAS_DOCSTRING; } let start_label = self.use_cpython_function_start_label(); @@ -4300,19 +5042,8 @@ impl Compiler { // Compile body statements self.compile_statements(body)?; - // Emit implicit `return None` if the body doesn't end with return. - // Also ensure None is in co_consts even when not emitting return - // (matching CPython: functions without explicit constants always - // have None in co_consts). - match body.last() { - Some(ast::Stmt::Return(_)) => {} - _ => { - self.emit_return_const_no_location(ConstantData::None); - } - } - // Functions with no other constants should still have None in co_consts - if self.current_code_info().metadata.consts.is_empty() { - self.arg_constant(ConstantData::None); + if stop_iteration_block.is_some() { + self.emit_return_const_no_location(ConstantData::None); } // Close StopIteration handler and emit handler code @@ -4329,6 +5060,7 @@ impl Compiler { emit!(self, Instruction::Reraise { depth: 1u32 }); self.set_no_location(); } + self.emit_return_const_no_location(ConstantData::None); // Exit scope and create function object let code = self.exit_scope(); @@ -4347,7 +5079,7 @@ impl Compiler { /// Compile function annotations as a closure (PEP 649) /// Returns true if an __annotate__ closure was created - /// Uses symbol table's annotation_block for proper scoping. + /// Uses the matching annotation symbol table for proper scoping. fn compile_annotations_closure( &mut self, func_name: &str, @@ -4355,21 +5087,11 @@ impl Compiler { returns: Option<&ast::Expr>, func_range: TextRange, ) -> CompileResult { - let has_signature_annotations = parameters - .args - .iter() - .map(|x| &x.parameter) - .chain(parameters.posonlyargs.iter().map(|x| &x.parameter)) - .chain(parameters.vararg.as_deref()) - .chain(parameters.kwonlyargs.iter().map(|x| &x.parameter)) - .chain(parameters.kwarg.as_deref()) - .any(|param| param.annotation.is_some()) - || returns.is_some(); - if !has_signature_annotations { + if !self.next_function_annotation_symbol_table_uses_annotations() { return Ok(false); } - // Try to enter annotation scope - returns None if no annotation_block exists + // Try to enter annotation scope - returns None if no matching symbol table exists. let Some(saved_ctx) = self.enter_annotation_scope(func_name, func_range)? else { return Ok(false); }; @@ -4427,13 +5149,14 @@ impl Compiler { } ); emit!(self, Instruction::ReturnValue); + self.emit_return_const_no_location(ConstantData::None); // Exit the annotation scope and get the code object let annotate_code = self.exit_annotation_scope(saved_ctx); // Make a closure from the code object self.set_source_range(func_range); - self.make_closure(annotate_code, MakeFunctionFlags::new())?; + self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; Ok(true) } @@ -4442,29 +5165,73 @@ impl Compiler { /// (including nested conditional blocks). This preserves the same walk /// order as symbol-table construction so the annotation scope's /// `sub_tables` cursor stays aligned. - fn collect_annotations(body: &[ast::Stmt]) -> Vec<&ast::StmtAnnAssign> { - use ast::visitor::Visitor; - - #[derive(Default)] - struct AnnotationsVisitor<'a> { - annotations: Vec<&'a ast::StmtAnnAssign>, - } - - impl<'a> Visitor<'a> for AnnotationsVisitor<'a> { - fn visit_stmt(&mut self, stmt: &'a ast::Stmt) { + fn collect_annotations( + body: &[ast::Stmt], + parent_scope_type: CompilerScope, + ) -> Vec<(&ast::StmtAnnAssign, bool)> { + fn walk<'a>( + stmts: &'a [ast::Stmt], + out: &mut Vec<(&'a ast::StmtAnnAssign, bool)>, + in_conditional_block: bool, + module_scope: bool, + ) { + for stmt in stmts { match stmt { - ast::Stmt::AnnAssign(ann_assign) => self.annotations.push(ann_assign), - ast::Stmt::ClassDef(_) | ast::Stmt::FunctionDef(_) => {} - _ => ast::visitor::walk_stmt(self, stmt), + ast::Stmt::AnnAssign(stmt) => { + out.push((stmt, module_scope || in_conditional_block)); + } + ast::Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => { + walk(body, out, true, module_scope); + for clause in elif_else_clauses { + walk(&clause.body, out, true, module_scope); + } + } + ast::Stmt::For(ast::StmtFor { body, orelse, .. }) + | ast::Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + walk(body, out, true, module_scope); + walk(orelse, out, true, module_scope); + } + ast::Stmt::With(ast::StmtWith { body, .. }) => { + walk(body, out, true, module_scope); + } + ast::Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + walk(body, out, true, module_scope); + for handler in handlers { + let ast::ExceptHandler::ExceptHandler( + ast::ExceptHandlerExceptHandler { body, .. }, + ) = handler; + walk(body, out, true, module_scope); + } + walk(orelse, out, true, module_scope); + walk(finalbody, out, true, module_scope); + } + ast::Stmt::Match(ast::StmtMatch { cases, .. }) => { + for case in cases { + walk(&case.body, out, true, module_scope); + } + } + _ => {} } } } - - let mut visitor = AnnotationsVisitor::default(); - for stmt in body { - visitor.visit_stmt(stmt); - } - visitor.annotations + let mut annotations = Vec::new(); + walk( + body, + &mut annotations, + false, + parent_scope_type == CompilerScope::Module, + ); + annotations } fn compile_annotation_for_symbol_cursor_only( @@ -4482,20 +5249,21 @@ impl Compiler { loc: Option, ) -> CompileResult { let loc = loc.unwrap_or(self.current_source_range); - let annotations = Self::collect_annotations(body); - let has_simple_annotation = annotations + // Get parent scope type BEFORE pushing annotation symbol table. + let parent_scope_type = self.current_symbol_table().typ; + let annotations = Self::collect_annotations(body, parent_scope_type); + let simple_annotation_count = annotations .iter() - .any(|stmt| stmt.simple && matches!(stmt.target.as_ref(), ast::Expr::Name(_))); + .filter(|(stmt, _)| stmt.simple && matches!(stmt.target.as_ref(), ast::Expr::Name(_))) + .count(); - if !has_simple_annotation { + if simple_annotation_count == 0 { return Ok(false); } // Check if we have conditional annotations let has_conditional = self.current_symbol_table().has_conditional_annotations; - // Get parent scope type BEFORE pushing annotation symbol table - let parent_scope_type = self.current_symbol_table().typ; // Try to push annotation symbol table from current scope if !self.push_current_annotation_symbol_table() { return Ok(false); @@ -4520,11 +5288,12 @@ impl Compiler { lineno.to_u32(), )?; - // Add 'format' parameter to varnames + // Keep CPython's internal ".format" name; the final code object + // exposes this parameter as "format". self.current_code_info() .metadata .varnames - .insert("format".to_owned()); + .insert(".format".to_owned()); // Emit format validation: if format > VALUE_WITH_FAKE_GLOBALS: raise NotImplementedError self.emit_format_validation(); @@ -4532,8 +5301,8 @@ impl Compiler { self.set_source_range(loc); emit!(self, Instruction::BuildMap { count: 0 }); - let mut simple_idx = 0usize; - for stmt in annotations { + let mut conditional_idx = 0usize; + for (stmt, is_conditional) in annotations { let ast::StmtAnnAssign { target, annotation, @@ -4557,16 +5326,17 @@ impl Compiler { continue; } - let not_set_block = has_conditional.then(|| self.new_block()); - let not_set_label = - (!has_conditional).then(|| self.current_code_info().new_instr_sequence_label()); + let not_set_block = (has_conditional && is_conditional).then(|| self.new_block()); + let not_set_label = (!has_conditional || !is_conditional) + .then(|| self.current_code_info().new_instr_sequence_label()); let name = simple_name.expect("missing simple annotation name"); - if has_conditional { + if let Some(not_set_block) = not_set_block { self.set_source_range(*range); self.emit_load_const(ConstantData::Integer { - value: simple_idx.into(), + value: conditional_idx.into(), }); + conditional_idx += 1; if parent_scope_type == CompilerScope::Class { let idx = self.get_free_var_index("__conditional_annotations__"); emit!(self, Instruction::LoadDeref { i: idx }); @@ -4583,7 +5353,7 @@ impl Compiler { emit!( self, Instruction::PopJumpIfFalse { - delta: not_set_block.expect("missing not_set block") + delta: not_set_block } ); } @@ -4596,7 +5366,6 @@ impl Compiler { }); self.set_source_range(loc); emit!(self, Instruction::StoreSubscr); - simple_idx += 1; if let Some(not_set_block) = not_set_block { self.use_cpython_label_block(not_set_block); @@ -4610,6 +5379,7 @@ impl Compiler { self.set_source_range(loc); emit!(self, Instruction::ReturnValue); + self.emit_return_const_no_location(ConstantData::None); // Exit annotation scope - pop symbol table, restore to parent's annotation_block, and get code let annotation_table = self.pop_symbol_table(); @@ -4622,14 +5392,15 @@ impl Compiler { self.ctx = saved_ctx; // Exit code scope let pop = self.code_stack.pop(); - let annotate_code = unwrap_internal( + let mut annotate_code = unwrap_internal( self, compiler_unwrap_option(self, pop).finalize_code(&self.opts), ); + Self::expose_annotation_format_parameter(&mut annotate_code); // Make a closure from the code object self.set_source_range(loc); - self.make_closure(annotate_code, MakeFunctionFlags::new())?; + self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; // Store as __annotate_func__ for classes, __annotate__ for modules let name = if parent_scope_type == CompilerScope::Class { @@ -4665,14 +5436,25 @@ impl Compiler { if is_async { "async def " } else { "def " }, ); + // CPython's symtable visits defaults before decorators, but + // codegen_function() emits decorators first. Keep the symbol table in + // CPython order and only move the codegen cursor while compiling + // decorators. + let defaults_symbol_table_cursors = self.current_symbol_table_cursors(); + self.consume_skipped_nested_scopes_in_parameter_defaults(parameters)?; self.prepare_decorators(decorator_list)?; + let after_decorators_symbol_table_cursors = self.current_symbol_table_cursors(); + self.set_symbol_table_cursors(defaults_symbol_table_cursors); + + // CPython uses the first decorator line for code objects created by + // this definition, but LOC(s) for the surrounding instructions. + let firstlineno_range = decorator_list + .first() + .map_or(stmt_source_range, |decorator| decorator.expression.range()); // compile defaults and return funcflags let funcflags = self.compile_default_arguments(parameters, def_source_range)?; - - // Restore the `def` line range so that enter_function → push_output → get_source_line_number() - // records the `def` keyword's line as co_firstlineno, not the last default-argument line. - self.set_source_range(def_source_range); + self.set_symbol_table_cursors(after_decorators_symbol_table_cursors); let is_generic = type_params.is_some(); let mut num_typeparam_args = 0u32; @@ -4682,20 +5464,22 @@ impl Compiler { if is_generic { // Count args to pass to type params scope - if funcflags.contains(&MakeFunctionFlag::Defaults) { + if funcflags.contains(&bytecode::MakeFunctionFlag::Defaults) { num_typeparam_args += 1; } - if funcflags.contains(&MakeFunctionFlag::KwOnlyDefaults) { + if funcflags.contains(&bytecode::MakeFunctionFlag::KwOnlyDefaults) { num_typeparam_args += 1; } if num_typeparam_args == 2 { + self.set_source_range(def_source_range); emit!(self, Instruction::Swap { i: 2 }); } // Enter type params scope let type_params_name = format!(""); + self.set_source_range(firstlineno_range); self.push_output( - CodeFlags::OPTIMIZED | CodeFlags::NEWLOCALS, + bytecode::CodeFlags::OPTIMIZED | bytecode::CodeFlags::NEWLOCALS, 0, num_typeparam_args, 0, @@ -4712,13 +5496,13 @@ impl Compiler { // Add parameter names to varnames for the type params scope // These will be passed as arguments when the closure is called let current_info = self.current_code_info(); - if funcflags.contains(&MakeFunctionFlag::Defaults) { + if funcflags.contains(&bytecode::MakeFunctionFlag::Defaults) { current_info .metadata .varnames .insert(".defaults".to_owned()); } - if funcflags.contains(&MakeFunctionFlag::KwOnlyDefaults) { + if funcflags.contains(&bytecode::MakeFunctionFlag::KwOnlyDefaults) { current_info .metadata .varnames @@ -4729,6 +5513,7 @@ impl Compiler { self.compile_type_params(type_params.unwrap())?; // Load defaults/kwdefaults with LOAD_FAST + self.set_source_range(def_source_range); for i in 0..num_typeparam_args { let var_num = oparg::VarNum::from(i); emit!(self, Instruction::LoadFast { var_num }); @@ -4736,13 +5521,14 @@ impl Compiler { } // Compile annotations as closure (PEP 649) - let mut annotations_flag = MakeFunctionFlags::new(); + let mut annotations_flag = bytecode::MakeFunctionFlags::new(); if self.compile_annotations_closure(name, parameters, returns, def_source_range)? { - annotations_flag.insert(MakeFunctionFlag::Annotate); + annotations_flag.insert(bytecode::MakeFunctionFlag::Annotate); } - // Compile function body - self.set_source_range(stmt_source_range); + // Compile function body. CPython's codegen_function() uses the first + // decorator line for co_firstlineno, but LOC(s) for MAKE_FUNCTION. + self.set_source_range(firstlineno_range); let final_funcflags = funcflags | annotations_flag; self.compile_function_body( name, @@ -4757,9 +5543,11 @@ impl Compiler { if is_generic { // SWAP to get function on top // Stack: [type_params_tuple, function] -> [function, type_params_tuple] + self.set_source_range(def_source_range); emit!(self, Instruction::Swap { i: 2 }); // Call INTRINSIC_SET_FUNCTION_TYPE_PARAMS + self.set_source_range(def_source_range); emit!( self, Instruction::CallIntrinsic2 { @@ -4769,6 +5557,7 @@ impl Compiler { // Return the function object from type params scope emit!(self, Instruction::ReturnValue); + self.set_no_location(); // Set argcount for type params scope self.current_code_info().metadata.argcount = num_typeparam_args; @@ -4779,15 +5568,18 @@ impl Compiler { self.ctx = saved_ctx; // Make closure for type params code - self.make_closure(type_params_code, MakeFunctionFlags::new())?; + self.set_source_range(def_source_range); + self.make_closure(type_params_code, bytecode::MakeFunctionFlags::new())?; if num_typeparam_args > 0 { + self.set_source_range(def_source_range); emit!( self, Instruction::Swap { i: num_typeparam_args + 1 } ); + self.set_source_range(def_source_range); emit!( self, Instruction::Call { @@ -4796,8 +5588,10 @@ impl Compiler { ); } else { // Stack: [closure] + self.set_source_range(def_source_range); emit!(self, Instruction::PushNull); // Stack: [closure, NULL] + self.set_source_range(def_source_range); emit!(self, Instruction::Call { argc: 0 }); } } @@ -4818,36 +5612,31 @@ impl Compiler { /// Determines if a variable should be CELL or FREE type // = get_ref_type fn get_ref_type(&self, name: &str) -> Result { - let table = self.current_symbol_table(); + let table = self.symbol_table_stack.last().unwrap(); // Special handling for __class__, __classdict__, and __conditional_annotations__ in class scope // This should only apply when we're actually IN a class body, // not when we're in a method nested inside a class. if table.typ == CompilerScope::Class - && matches!( - name, - "__class__" | "__classdict__" | "__conditional_annotations__" - ) + && (name == "__class__" + || name == "__classdict__" + || name == "__conditional_annotations__") { return Ok(SymbolScope::Cell); } - - let Some(symbol) = table.lookup(name) else { - return Err(CodegenErrorType::SyntaxError(format!( - "get_ref_type: cannot find symbol '{name}'" - ))); - }; - - Ok(match symbol.scope { - SymbolScope::Cell => SymbolScope::Cell, - SymbolScope::Free => SymbolScope::Free, - _ if symbol.flags.contains(SymbolFlags::FREE_CLASS) => SymbolScope::Free, - _ => { - return Err(CodegenErrorType::SyntaxError(format!( + match table.lookup(name) { + Some(symbol) => match symbol.scope { + SymbolScope::Cell => Ok(SymbolScope::Cell), + SymbolScope::Free => Ok(SymbolScope::Free), + _ if symbol.flags.contains(SymbolFlags::FREE_CLASS) => Ok(SymbolScope::Free), + _ => Err(CodegenErrorType::SyntaxError(format!( "get_ref_type: invalid scope for '{name}'" - ))); - } - }) + ))), + }, + None => Err(CodegenErrorType::SyntaxError(format!( + "get_ref_type: cannot find symbol '{name}'" + ))), + } } /// Loads closure variables if needed and creates a function object @@ -4940,57 +5729,47 @@ impl Compiler { emit!( self, Instruction::SetFunctionAttribute { - flag: MakeFunctionFlag::Closure + flag: bytecode::MakeFunctionFlag::Closure } ); } // Set annotations if present - if flags.contains(&MakeFunctionFlag::Annotations) { + if flags.contains(&bytecode::MakeFunctionFlag::Annotations) { emit!( self, Instruction::SetFunctionAttribute { - flag: MakeFunctionFlag::Annotations + flag: bytecode::MakeFunctionFlag::Annotations } ); } // Set __annotate__ closure if present (PEP 649) - if flags.contains(&MakeFunctionFlag::Annotate) { + if flags.contains(&bytecode::MakeFunctionFlag::Annotate) { emit!( self, Instruction::SetFunctionAttribute { - flag: MakeFunctionFlag::Annotate + flag: bytecode::MakeFunctionFlag::Annotate } ); } // Set kwdefaults if present - if flags.contains(&MakeFunctionFlag::KwOnlyDefaults) { + if flags.contains(&bytecode::MakeFunctionFlag::KwOnlyDefaults) { emit!( self, Instruction::SetFunctionAttribute { - flag: MakeFunctionFlag::KwOnlyDefaults + flag: bytecode::MakeFunctionFlag::KwOnlyDefaults } ); } // Set defaults if present - if flags.contains(&MakeFunctionFlag::Defaults) { - emit!( - self, - Instruction::SetFunctionAttribute { - flag: MakeFunctionFlag::Defaults - } - ); - } - - // Set type_params if present - if flags.contains(&MakeFunctionFlag::TypeParams) { + if flags.contains(&bytecode::MakeFunctionFlag::Defaults) { emit!( self, Instruction::SetFunctionAttribute { - flag: MakeFunctionFlag::TypeParams + flag: bytecode::MakeFunctionFlag::Defaults } ); } @@ -5017,55 +5796,6 @@ impl Compiler { } } - // Python/compile.c find_ann - fn find_ann(body: &[ast::Stmt]) -> bool { - for statement in body { - let res = match &statement { - ast::Stmt::AnnAssign(_) => true, - ast::Stmt::For(ast::StmtFor { body, orelse, .. }) => { - Self::find_ann(body) || Self::find_ann(orelse) - } - ast::Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - Self::find_ann(body) - || elif_else_clauses.iter().any(|x| Self::find_ann(&x.body)) - } - ast::Stmt::While(ast::StmtWhile { body, orelse, .. }) => { - Self::find_ann(body) || Self::find_ann(orelse) - } - ast::Stmt::With(ast::StmtWith { body, .. }) => Self::find_ann(body), - ast::Stmt::Match(ast::StmtMatch { cases, .. }) => { - cases.iter().any(|case| Self::find_ann(&case.body)) - } - ast::Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - Self::find_ann(body) - || handlers.iter().any(|h| { - let ast::ExceptHandler::ExceptHandler( - ast::ExceptHandlerExceptHandler { body, .. }, - ) = h; - Self::find_ann(body) - }) - || Self::find_ann(orelse) - || Self::find_ann(finalbody) - } - _ => false, - }; - if res { - return true; - } - } - false - } - /// Compile the class body into a code object // = compiler_class_body fn compile_class_body( @@ -5077,7 +5807,7 @@ impl Compiler { ) -> CompileResult { // 1. Enter class scope let key = self.symbol_table_stack.len(); - self.push_symbol_table()?; + self.push_symbol_table_matching(CompilerScope::Class, name)?; self.enter_scope(name, CompilerScope::Class, key, firstlineno)?; // Set qualname using the new method @@ -5087,7 +5817,7 @@ impl Compiler { self.code_stack.last_mut().unwrap().private = Some(name.to_owned()); // 2. Set up class namespace - let (doc_str, body) = split_doc_with_range(body, self.opts); + let (doc_str, body) = split_doc_with_range(body, &self.opts); let class_body_prefix_range = self.source_line_start_range(firstlineno); self.set_source_range(class_body_prefix_range); @@ -5123,15 +5853,14 @@ impl Compiler { } // Handle class annotation bookkeeping in CPython order. - if Self::find_ann(body) { - if Self::scope_needs_conditional_annotations_cell(self.current_symbol_table()) { - emit!(self, Instruction::BuildSet { count: 0 }); - self.store_name("__conditional_annotations__")?; - } + let annotations_used = self.current_symbol_table().annotations_used; + if Self::scope_needs_conditional_annotations_cell(self.current_symbol_table()) { + emit!(self, Instruction::BuildSet { count: 0 }); + self.store_name("__conditional_annotations__")?; + } - if self.future_annotations { - emit!(self, Instruction::SetupAnnotations); - } + if self.future_annotations && annotations_used { + emit!(self, Instruction::SetupAnnotations); } // Store __doc__ only if there's an explicit docstring. @@ -5140,13 +5869,14 @@ impl Compiler { self.set_source_range(range); self.emit_load_const(ConstantData::Str { value: doc.into() }); self.store_name("__doc__")?; + self.set_no_location(); self.set_source_range(saved_range); } // 3. Compile the class body self.compile_statements(body)?; - if Self::find_ann(body) && !self.future_annotations { + if annotations_used && !self.future_annotations { self.compile_module_annotate(body, Some(class_body_prefix_range))?; } @@ -5211,6 +5941,7 @@ impl Compiler { // Return the class namespace self.emit_return_value(); self.set_no_location(); + self.emit_return_const_no_location(ConstantData::None); // Exit scope and return the code object Ok(self.exit_scope()) @@ -5233,6 +5964,9 @@ impl Compiler { self.prepare_decorators(decorator_list)?; let is_generic = type_params.is_some(); + let firstlineno_range = decorator_list + .first() + .map_or(stmt_source_range, |decorator| decorator.expression.range()); #[expect(clippy::map_unwrap_or, reason = "Changing this will not compile")] let firstlineno = decorator_list .first() @@ -5251,8 +5985,9 @@ impl Compiler { // Step 1: If generic, enter type params scope and compile type params if is_generic { let type_params_name = format!(""); + self.set_source_range(firstlineno_range); self.push_output( - CodeFlags::OPTIMIZED | CodeFlags::NEWLOCALS, + bytecode::CodeFlags::OPTIMIZED | bytecode::CodeFlags::NEWLOCALS, 0, 0, 0, @@ -5283,7 +6018,10 @@ impl Compiler { in_class: true, in_async_scope: false, }; + let pre_class_body_symbol_table_cursors = self.current_symbol_table_cursors(); let class_code = self.compile_class_body(name, body, type_params, firstlineno)?; + let post_class_body_symbol_table_cursors = self.current_symbol_table_cursors(); + self.set_symbol_table_cursors(pre_class_body_symbol_table_cursors); self.ctx = prev_ctx; self.set_source_range(class_source_range); @@ -5295,7 +6033,7 @@ impl Compiler { // Create the class body function with the .type_params closure // captured through the class code object's freevars. - self.make_closure(class_code, MakeFunctionFlags::new())?; + self.make_closure(class_code, bytecode::MakeFunctionFlags::new())?; self.emit_load_const(ConstantData::Str { value: name.into() }); // Create .generic_base after the class function and name are on the @@ -5311,134 +6049,21 @@ impl Compiler { self.set_source_range(class_source_range); self.store_name(".generic_base")?; - // Compile bases and call __build_class__ - // Check for starred bases or **kwargs - let has_starred = arguments.is_some_and(|args| { - args.args - .iter() - .any(|arg| matches!(arg, ast::Expr::Starred(_))) + let (bases, keywords) = arguments.map_or((&[][..], &[][..]), |args| { + (&args.args[..], &args.keywords[..]) }); - let has_double_star = - arguments.is_some_and(|args| args.keywords.iter().any(|kw| kw.arg.is_none())); - - if has_starred { - // Use CallFunctionEx for *bases or **kwargs - // Stack has: [__build_class__, NULL, class_func, name] - // Need to build: args tuple = (class_func, name, *bases, .generic_base) - - // Build a list starting with class_func and name (2 elements already on stack) - emit!(self, Instruction::BuildList { count: 2 }); - - // Add bases to the list - if let Some(arguments) = arguments { - for arg in &arguments.args { - if let ast::Expr::Starred(ast::ExprStarred { value, .. }) = arg { - // Starred: compile and extend - self.compile_expression(value)?; - emit!(self, Instruction::ListExtend { i: 1 }); - } else { - // Non-starred: compile and append - self.compile_expression(arg)?; - emit!(self, Instruction::ListAppend { i: 1 }); - } - } - } - - // Add .generic_base as final element - self.set_source_range(class_source_range); - self.load_name(".generic_base")?; - self.set_source_range(class_source_range); - emit!(self, Instruction::ListAppend { i: 1 }); - - // Convert list to tuple - self.set_source_range(class_source_range); - emit!( - self, - Instruction::CallIntrinsic1 { - func: IntrinsicFunction1::ListToTuple - } - ); - - self.compile_call_function_ex_keywords( - arguments.map_or(&[][..], |args| &args.keywords[..]), - class_source_range, - )?; - emit!(self, Instruction::CallFunctionEx); - } else if has_double_star { - if let Some(arguments) = arguments { - for arg in &arguments.args { - self.compile_expression(arg)?; - } - } - self.set_source_range(class_source_range); - self.load_name(".generic_base")?; - self.set_source_range(class_source_range); - emit!( - self, - Instruction::BuildTuple { - count: 3 + arguments - .map_or(0, |args| u32::try_from(args.args.len()).unwrap()) - } - ); - self.compile_call_function_ex_keywords( - &arguments.unwrap().keywords[..], - class_source_range, - )?; - emit!(self, Instruction::CallFunctionEx); - } else { - // Simple case: no starred bases, no **kwargs - // Compile bases normally - let base_count = if let Some(arguments) = arguments { - for arg in &arguments.args { - self.compile_expression(arg)?; - } - arguments.args.len() - } else { - 0 - }; - - // Load .generic_base as the last base - self.set_source_range(class_source_range); - self.load_name(".generic_base")?; - - let nargs = 2 + u32::try_from(base_count).expect("too many base classes") + 1; - - // Handle keyword arguments (no **kwargs here) - if let Some(arguments) = arguments - && !arguments.keywords.is_empty() - { - let mut kwarg_names = vec![]; - for keyword in &arguments.keywords { - let name = keyword.arg.as_ref().expect( - "keyword argument name must be set (no **kwargs in this branch)", - ); - kwarg_names.push(ConstantData::Str { - value: name.as_str().into(), - }); - self.compile_expression(&keyword.value)?; - } - self.set_source_range(class_source_range); - self.emit_load_const(ConstantData::Tuple { - elements: kwarg_names, - }); - self.set_source_range(class_source_range); - emit!( - self, - Instruction::CallKw { - argc: nargs - + u32::try_from(arguments.keywords.len()) - .expect("too many keyword arguments") - } - ); - } else { - self.set_source_range(class_source_range); - emit!(self, Instruction::Call { argc: nargs }); - } - } + self.codegen_call_helper_impl( + 2, + bases, + keywords, + class_source_range, + None, + Some(".generic_base"), + )?; // Return the created class - self.set_source_range(class_source_range); self.emit_return_value(); + self.set_no_location(); // Exit type params scope and wrap in function let type_params_code = self.exit_scope(); @@ -5446,7 +6071,7 @@ impl Compiler { // Execute the type params function self.set_source_range(class_source_range); - self.make_closure(type_params_code, MakeFunctionFlags::new())?; + self.make_closure(type_params_code, bytecode::MakeFunctionFlags::new())?; self.set_source_range(class_source_range); emit!(self, Instruction::PushNull); self.set_source_range(class_source_range); @@ -5457,7 +6082,7 @@ impl Compiler { emit!(self, Instruction::PushNull); // Create class function with closure - self.make_closure(class_code, MakeFunctionFlags::new())?; + self.make_closure(class_code, bytecode::MakeFunctionFlags::new())?; self.emit_load_const(ConstantData::Str { value: name.into() }); if let Some(arguments) = arguments { @@ -5466,6 +6091,7 @@ impl Compiler { self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: 2 }); } + self.set_symbol_table_cursors(post_class_body_symbol_table_cursors); } // Step 4: Apply decorators and store (common to both paths) @@ -5484,7 +6110,7 @@ impl Compiler { test: &ast::Expr, body: &[ast::Stmt], elif_else_clauses: &[ast::ElifElseClause], - _stmt_range: TextRange, + stmt_range: TextRange, ) -> CompileResult<()> { let end_block = self.new_block(); let next_block = if elif_else_clauses.is_empty() { @@ -5493,7 +6119,7 @@ impl Compiler { self.new_block() }; - self.compile_jump_if(test, false, next_block)?; + self.compile_jump_if_inner(test, false, next_block, Some(stmt_range))?; self.compile_statements(body)?; let Some((clause, rest)) = elif_else_clauses.split_first() else { @@ -5509,7 +6135,7 @@ impl Compiler { self.use_cpython_label_block(next_block); if let Some(test) = &clause.test { - self.compile_if(test, &clause.body, rest, test.range())?; + self.compile_if(test, &clause.body, rest, clause.range)?; } else { debug_assert!(rest.is_empty()); self.compile_statements(&clause.body)?; @@ -5523,6 +6149,7 @@ impl Compiler { test: &ast::Expr, body: &[ast::Stmt], orelse: &[ast::Stmt], + while_range: TextRange, ) -> CompileResult<()> { self.enter_conditional_block(); @@ -5538,7 +6165,7 @@ impl Compiler { end_label, FBlockDatum::None, )?; - self.compile_jump_if(test, false, anchor_block)?; + self.compile_jump_if_inner(test, false, anchor_block, Some(while_range))?; self.compile_loop_body_statements(body)?; emit!(self, PseudoInstruction::Jump { delta: loop_block }); @@ -5560,7 +6187,17 @@ impl Compiler { is_async: bool, ) -> CompileResult<()> { self.enter_conditional_block(); + let result = self.compile_with_inner(items, body, is_async); + self.leave_conditional_block(); + result + } + fn compile_with_inner( + &mut self, + items: &[ast::WithItem], + body: &[ast::Stmt], + is_async: bool, + ) -> CompileResult<()> { // Python 3.12+ style with statement: // // BEFORE_WITH # TOS: ctx_mgr -> [__exit__, __enter__ result] @@ -5606,7 +6243,9 @@ impl Compiler { emit!(self, Instruction::Copy { i: 1 }); // [cm, cm] if is_async { - if self.ctx.func != FunctionContext::AsyncFunction { + if self.ctx.func != FunctionContext::AsyncFunction + && !self.allows_top_level_await_in_current_context() + { return Err(self.error(CodegenErrorType::InvalidAsyncWith)); } // Load __aexit__ and __aenter__, then call __aenter__ @@ -5691,7 +6330,7 @@ impl Compiler { self.compile_with_body_statements(body)?; } else { self.set_source_range(items[0].context_expr.range()); - self.compile_with(items, body, is_async)?; + self.compile_with_inner(items, body, is_async)?; } // CPython pops the async-with fblock before emitting POP_BLOCK, but @@ -5715,7 +6354,6 @@ impl Compiler { } emit!(self, Instruction::PopTop); // Pop __exit__ result emit!(self, PseudoInstruction::Jump { delta: after_block }); - self.set_no_location(); // ===== Exception handler path ===== // Stack at entry: [..., exit_func, self_exit, lasti, exc] @@ -5745,7 +6383,6 @@ impl Compiler { self.use_cpython_label_block(after_block); - self.leave_conditional_block(); Ok(()) } @@ -5786,10 +6423,12 @@ impl Compiler { } // The thing iterated: - self.compile_for_iterable_expression(iter, is_async)?; + self.compile_expression(iter)?; if is_async { - if self.ctx.func != FunctionContext::AsyncFunction { + if self.ctx.func != FunctionContext::AsyncFunction + && !self.allows_top_level_await_in_current_context() + { return Err(self.error(CodegenErrorType::InvalidAsyncFor)); } self.set_source_range(iter.range()); @@ -5821,6 +6460,7 @@ impl Compiler { self.compile_store(target)?; } else { // Retrieve Iterator + self.set_source_range(iter.range()); emit!(self, Instruction::GetIter); self.use_cpython_label_block(for_block); @@ -5842,12 +6482,19 @@ impl Compiler { emit!(self, PseudoInstruction::Jump { delta: for_block }); self.set_no_location(); + if is_async { + // CPython codegen_async_for() pops the loop fblock before the + // END_ASYNC_FOR exception block. Sync codegen_for() keeps the + // fblock through END_FOR/POP_ITER and pops below. + self.pop_fblock_label(FBlockType::ForLoop, for_label); + } + self.use_cpython_label_block(else_block); // Except block for __anext__ / end of sync for if is_async { // codegen_async_for emits END_ASYNC_FOR at the iterator location, - // then pops the for-loop fblock before the else block. + // after the for-loop fblock has already been popped. let saved_range = self.current_source_range; self.set_source_range(iter.range()); self.emit_end_async_for(end_async_for_target); @@ -5859,9 +6506,8 @@ impl Compiler { self.set_no_location(); emit!(self, Instruction::PopIter); self.set_no_location(); + self.pop_fblock_label(FBlockType::ForLoop, for_label); } - // No PopBlock here - for async, POP_BLOCK is already in for_block - self.pop_fblock_label(FBlockType::ForLoop, for_label); self.compile_statements(orelse)?; self.use_cpython_label_block(after_block); @@ -5873,39 +6519,9 @@ impl Compiler { Ok(()) } - fn compile_for_iterable_expression( - &mut self, - iter: &ast::Expr, - is_async: bool, - ) -> CompileResult<()> { - // Match CPython's iterable lowering for `for`/comprehension fronts: - // a non-starred list literal used only for iteration is emitted as a tuple. - // Skip async-for/async comprehension iteration because GET_AITER expects - // the original object semantics. - if !is_async - && let ast::Expr::List(ast::ExprList { elts, .. }) = iter - && elts.len() <= usize::try_from(STACK_USE_GUIDELINE).unwrap() - && !elts.iter().any(|e| matches!(e, ast::Expr::Starred(_))) - { - for elt in elts { - self.compile_expression(elt)?; - } - self.set_source_range(iter.range()); - emit!( - self, - Instruction::BuildList { - count: u32::try_from(elts.len()).expect("too many elements"), - } - ); - return Ok(()); - } - - self.compile_expression(iter) - } - fn compile_comprehension_iter(&mut self, generator: &ast::Comprehension) -> CompileResult<()> { let saved_range = self.current_source_range; - self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; + self.compile_expression(&generator.iter)?; self.set_source_range(generator.iter.range()); if generator.is_async { emit!(self, Instruction::GetAiter); @@ -5928,24 +6544,6 @@ impl Compiler { } } - fn forbidden_name(&mut self, name: &str, ctx: NameUsage) -> CompileResult { - if ctx == NameUsage::Store && name == "__debug__" { - return Err(self.error(CodegenErrorType::Assign("__debug__"))); - // return Ok(true); - } - if ctx == NameUsage::Delete && name == "__debug__" { - return Err(self.error(CodegenErrorType::Delete("__debug__"))); - // return Ok(true); - } - Ok(false) - } - - fn compile_error_forbidden_name(&mut self, name: &str) -> CodegenError { - self.error(CodegenErrorType::SyntaxError(format!( - "cannot use forbidden name '{name}' in pattern" - ))) - } - /// Ensures that `pc.fail_pop` has at least `n + 1` entries. /// If not, new labels are generated and pushed until the required size is reached. fn ensure_fail_pop(&mut self, pc: &mut PatternContext, n: usize) { @@ -5988,7 +6586,7 @@ impl Compiler { /// Emits the necessary POP instructions for all failure targets in the pattern context, /// then resets the fail_pop vector. - fn emit_and_reset_fail_pop(&mut self, pc: &mut PatternContext) { + fn emit_and_reset_fail_pop(&mut self, pc: &mut PatternContext, loc: TextRange) { // If the fail_pop vector is empty, nothing needs to be done. if pc.fail_pop.is_empty() { debug_assert!(pc.fail_pop.is_empty()); @@ -5999,6 +6597,7 @@ impl Compiler { // CPython emit_and_reset_fail_pop() uses USE_LABEL here. self.use_cpython_label_block(label); // Emit the POP instruction. + self.set_source_range(loc); emit!(self, Instruction::PopTop); } // Finally, use the first label. @@ -6010,7 +6609,7 @@ impl Compiler { } /// Duplicate the effect of Python 3.10's ROT_* instructions using SWAPs. - fn pattern_helper_rotate(&mut self, mut count: usize) { + fn pattern_helper_rotate(&mut self, loc: TextRange, mut count: usize) { // Rotate TOS (top of stack) to position `count` down // This is done by a series of swaps // For count=1, no rotation needed (already at top) @@ -6018,6 +6617,7 @@ impl Compiler { // For count=3, swap TOS with item 2 positions down, then with item 1 position down while count > 1 { // Emit a SWAP instruction with the current count. + self.set_source_range(loc); emit!( self, Instruction::Swap { @@ -6036,32 +6636,30 @@ impl Compiler { /// to the list of captured names. fn pattern_helper_store_name( &mut self, + loc: TextRange, n: Option<&ast::Identifier>, pc: &mut PatternContext, ) -> CompileResult<()> { match n { // If no name is provided, simply pop the top of the stack. None => { + self.set_source_range(loc); emit!(self, Instruction::PopTop); Ok(()) } Some(name) => { - // Check if the name is forbidden for storing. - if self.forbidden_name(name.as_str(), NameUsage::Store)? { - return Err(self.compile_error_forbidden_name(name.as_str())); - } - // Ensure we don't store the same name twice. // TODO: maybe pc.stores should be a set? if pc.stores.contains(&name.to_string()) { - return Err( - self.error(CodegenErrorType::DuplicateStore(name.as_str().to_string())) - ); + return Err(self.error_ranged( + CodegenErrorType::DuplicateStore(name.as_str().to_string()), + loc, + )); } // Calculate how many items to rotate: let rotations = pc.on_top + pc.stores.len() + 1; - self.pattern_helper_rotate(rotations); + self.pattern_helper_rotate(loc, rotations); // Append the name to the captured stores. pc.stores.push(name.to_string()); @@ -6070,30 +6668,51 @@ impl Compiler { } } - fn pattern_unpack_helper(&mut self, elts: &[ast::Pattern]) -> CompileResult<()> { + fn pattern_wildcard_check(pattern: &ast::Pattern) -> bool { + matches!( + pattern, + ast::Pattern::MatchAs(ast::PatternMatchAs { name: None, .. }) + ) + } + + fn pattern_wildcard_star_check(pattern: &ast::Pattern) -> bool { + matches!( + pattern, + ast::Pattern::MatchStar(ast::PatternMatchStar { name: None, .. }) + ) + } + + fn pattern_unpack_helper( + &mut self, + loc: TextRange, + elts: &[ast::Pattern], + ) -> CompileResult<()> { let n = elts.len(); let mut seen_star = false; for (i, elt) in elts.iter().enumerate() { - if elt.is_match_star() { - if !seen_star { - if i >= (1 << 8) || (n - i - 1) >= ((i32::MAX as usize) >> 8) { - todo!(); - // return self.compiler_error(loc, "too many expressions in star-unpacking sequence pattern"); - } - let counts = UnpackExArgs { - before: u8::try_from(i).unwrap(), - after: u8::try_from(n - i - 1).unwrap(), - }; - emit!(self, Instruction::UnpackEx { counts }); - seen_star = true; - } else { - // TODO: Fix error msg - return Err(self.error(CodegenErrorType::MultipleStarArgs)); - // return self.compiler_error(loc, "multiple starred expressions in sequence pattern"); + if elt.is_match_star() && !seen_star { + if i >= (1 << 8) || (n - i - 1) >= ((i32::MAX as usize) >> 8) { + return Err(self.error_ranged( + CodegenErrorType::TooManyExpressionsInStarUnpackingSequencePattern, + loc, + )); } + let counts = UnpackExArgs { + before: u8::try_from(i).unwrap(), + after: u32::try_from(n - i - 1).unwrap(), + }; + self.set_source_range(loc); + emit!(self, Instruction::UnpackEx { counts }); + seen_star = true; + } else if elt.is_match_star() { + return Err(self.error_ranged( + CodegenErrorType::MultipleStarredExpressionsInSequencePattern, + loc, + )); } } if !seen_star { + self.set_source_range(loc); emit!( self, Instruction::UnpackSequence { @@ -6106,12 +6725,13 @@ impl Compiler { fn pattern_helper_sequence_unpack( &mut self, + loc: TextRange, patterns: &[ast::Pattern], _star: Option, pc: &mut PatternContext, ) -> CompileResult<()> { // Unpack the sequence into individual subjects. - self.pattern_unpack_helper(patterns)?; + self.pattern_unpack_helper(loc, patterns)?; let size = patterns.len(); // Increase the on_top counter for the newly unpacked subjects. pc.on_top += size; @@ -6126,6 +6746,7 @@ impl Compiler { fn pattern_helper_sequence_subscr( &mut self, + loc: TextRange, patterns: &[ast::Pattern], star: usize, pc: &mut PatternContext, @@ -6133,35 +6754,32 @@ impl Compiler { // Keep the subject around for extracting elements. pc.on_top += 1; for (i, pattern) in patterns.iter().enumerate() { - let is_true_wildcard = matches!( - pattern, - ast::Pattern::MatchAs(ast::PatternMatchAs { - pattern: None, - name: None, - .. - }) - ); - if is_true_wildcard { + if Self::pattern_wildcard_check(pattern) { continue; } if i == star { // This must be a starred wildcard. - // assert!(pattern.is_star_wildcard()); + debug_assert!(Self::pattern_wildcard_star_check(pattern)); continue; } // Duplicate the subject. + self.set_source_range(loc); emit!(self, Instruction::Copy { i: 1 }); if i < star { // For indices before the star, use a nonnegative index equal to i. + self.set_source_range(loc); self.emit_load_const(ConstantData::Integer { value: i.into() }); } else { // For indices after the star, compute a nonnegative index: // index = len(subject) - (size - i) + self.set_source_range(loc); emit!(self, Instruction::GetLen); + self.set_source_range(loc); self.emit_load_const(ConstantData::Integer { value: (patterns.len() - i).into(), }); // Subtract to compute the correct index. + self.set_source_range(loc); emit!( self, Instruction::BinaryOp { @@ -6170,6 +6788,7 @@ impl Compiler { ); } // Use BINARY_OP/NB_SUBSCR to extract the element. + self.set_source_range(loc); emit!( self, Instruction::BinaryOp { @@ -6181,6 +6800,7 @@ impl Compiler { } // Pop the subject off the stack. pc.on_top -= 1; + self.set_source_range(loc); emit!(self, Instruction::PopTop); Ok(()) } @@ -6209,31 +6829,31 @@ impl Compiler { // If there is no sub-pattern, then it's an irrefutable match. if p.pattern.is_none() { if !pc.allow_irrefutable { - if let Some(_name) = p.name.as_ref() { - // TODO: This error message does not match cpython exactly - // A name capture makes subsequent patterns unreachable. - return Err(self.error(CodegenErrorType::UnreachablePattern( - PatternUnreachableReason::NameCapture, - ))); + if let Some(name) = p.name.as_ref() { + return Err(self.error_ranged( + CodegenErrorType::UnreachableNameCapturePattern(name.to_string()), + p.range, + )); } // A wildcard makes remaining patterns unreachable. - return Err(self.error(CodegenErrorType::UnreachablePattern( - PatternUnreachableReason::Wildcard, - ))); + return Err( + self.error_ranged(CodegenErrorType::UnreachableWildcardPattern, p.range) + ); } // If irrefutable matches are allowed, store the name (if any). - return self.pattern_helper_store_name(p.name.as_ref(), pc); + return self.pattern_helper_store_name(p.range, p.name.as_ref(), pc); } // Otherwise, there is a sub-pattern. Duplicate the object on top of the stack. pc.on_top += 1; + self.set_source_range(p.range); emit!(self, Instruction::Copy { i: 1 }); // Compile the sub-pattern. self.compile_pattern(p.pattern.as_ref().unwrap(), pc)?; // After success, decrement the on_top counter. pc.on_top -= 1; // Store the captured name (if any). - self.pattern_helper_store_name(p.name.as_ref(), pc)?; + self.pattern_helper_store_name(p.range, p.name.as_ref(), pc)?; Ok(()) } @@ -6242,7 +6862,7 @@ impl Compiler { p: &ast::PatternMatchStar, pc: &mut PatternContext, ) -> CompileResult<()> { - self.pattern_helper_store_name(p.name.as_ref(), pc)?; + self.pattern_helper_store_name(p.range, p.name.as_ref(), pc)?; Ok(()) } @@ -6251,21 +6871,19 @@ impl Compiler { fn validate_kwd_attrs( &mut self, attrs: &[ast::Identifier], - _patterns: &[ast::Pattern], + patterns: &[ast::Pattern], ) -> CompileResult<()> { let n_attrs = attrs.len(); for i in 0..n_attrs { let attr = attrs[i].as_str(); - // Check if the attribute name is forbidden in a Store context. - if self.forbidden_name(attr, NameUsage::Store)? { - // Return an error if the name is forbidden. - return Err(self.compile_error_forbidden_name(attr)); - } // Check for duplicates: compare with every subsequent attribute. - for ident in attrs.iter().take(n_attrs).skip(i + 1) { + for (j, ident) in attrs.iter().enumerate().take(n_attrs).skip(i + 1) { let other = ident.as_str(); if attr == other { - return Err(self.error(CodegenErrorType::RepeatedAttributePattern)); + return Err(self.error_ranged( + CodegenErrorType::RepeatedAttributePattern(attr.to_owned()), + patterns[j].range(), + )); } } } @@ -6292,12 +6910,27 @@ impl Compiler { let nargs = patterns.len(); let n_attrs = kwd_attrs.len(); + let n_kwd_patterns = kwd_patterns.len(); + if n_attrs != n_kwd_patterns { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "kwd_attrs ({n_attrs}) / kwd_patterns ({n_kwd_patterns}) length mismatch in class pattern" + )), + p.range, + )); + } // Check for too many sub-patterns. - if nargs > u32::MAX as usize || (nargs + n_attrs).saturating_sub(1) > i32::MAX as usize { - return Err(self.error(CodegenErrorType::SyntaxError( - "too many sub-patterns in class pattern".to_owned(), - ))); + if nargs > i32::MAX as usize + || nargs.saturating_add(n_attrs).saturating_sub(1) > i32::MAX as usize + { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "too many sub-patterns in class pattern {}", + UnparseExpr::new(&match_class.cls, &self.source_file) + )), + p.range, + )); } // Validate keyword attributes if any. @@ -6344,6 +6977,7 @@ impl Compiler { // At this point the TOS is a tuple of (nargs + n_attrs) attributes (or None). pc.on_top += 1; + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); // Unpack the tuple into (nargs + n_attrs) items. @@ -6354,24 +6988,19 @@ impl Compiler { count: u32::try_from(total).unwrap() } ); - pc.on_top += total; - pc.on_top -= 1; + if total == 0 { + pc.on_top -= 1; + } else { + pc.on_top += total - 1; + } // Process each sub-pattern. for subpattern in patterns.iter().chain(kwd_patterns.iter()) { - // Check if this is a true wildcard (underscore pattern without name binding) - let is_true_wildcard = match subpattern { - ast::Pattern::MatchAs(match_as) => { - // Only consider it wildcard if both pattern and name are None (i.e., "_") - match_as.pattern.is_none() && match_as.name.is_none() - } - _ => subpattern.is_wildcard(), - }; - // Decrement the on_top counter for each sub-pattern pc.on_top -= 1; - if is_true_wildcard { + if Self::pattern_wildcard_check(subpattern) { + self.set_source_range(p.range); emit!(self, Instruction::PopTop); continue; // Don't compile wildcard patterns } @@ -6395,27 +7024,36 @@ impl Compiler { // Validate pattern count matches key count if keys.len() != patterns.len() { - return Err(self.error(CodegenErrorType::SyntaxError(format!( - "keys ({}) / patterns ({}) length mismatch in mapping pattern", - keys.len(), - patterns.len() - )))); + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "keys ({}) / patterns ({}) length mismatch in mapping pattern", + keys.len(), + patterns.len() + )), + p.range, + )); } - // Validate rest pattern: '_' cannot be used as a rest target + // CPython rejects `case {**_}:` before codegen. RustPython's parser + // currently lets it through, so keep the compiler boundary equivalent. if let Some(rest) = star_target && rest.as_str() == "_" { - return Err(self.error(CodegenErrorType::SyntaxError("invalid syntax".to_string()))); + return Err(self.error_ranged( + CodegenErrorType::SyntaxError("invalid syntax".to_string()), + rest.range, + )); } // Step 1: Check if subject is a mapping // Stack: [subject] pc.on_top += 1; + self.set_source_range(p.range); emit!(self, Instruction::MatchMapping); // Stack: [subject, is_mapping] + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); // Stack: [subject] @@ -6430,64 +7068,40 @@ impl Compiler { // Length check for patterns with keys if size > 0 { // Check if the mapping has at least 'size' keys + self.set_source_range(p.range); emit!(self, Instruction::GetLen); + self.set_source_range(p.range); self.emit_load_const(ConstantData::Integer { value: size.into() }); // Stack: [subject, len, size] + self.set_source_range(p.range); emit!( self, Instruction::CompareOp { opname: ComparisonOperator::GreaterOrEqual } ); + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); // Stack: [subject] } // Check for overflow (INT_MAX < size - 1) - let size = u32::try_from(size).map_err(|_| { - self.error(CodegenErrorType::SyntaxError( - "too many sub-patterns in mapping pattern".to_string(), - )) - })?; - - // Step 2: If we have keys to match - if size > 0 { - // Validate and compile keys - let mut seen = IndexSet::default(); - for key in keys { - let is_attribute = matches!(key, ast::Expr::Attribute(_)); - let is_literal = matches!( - key, - ast::Expr::NumberLiteral(_) - | ast::Expr::StringLiteral(_) - | ast::Expr::BytesLiteral(_) - | ast::Expr::BooleanLiteral(_) - | ast::Expr::NoneLiteral(_) - ); - let key_repr = if is_literal { - UnparseExpr::new(key, &self.source_file).to_string() - } else if is_attribute { - String::new() - } else { - return Err(self.error(CodegenErrorType::SyntaxError( - "mapping pattern keys may only match literals and attribute lookups" - .to_string(), - ))); - }; - - if !key_repr.is_empty() && seen.contains(&key_repr) { - return Err(self.error(CodegenErrorType::SyntaxError(format!( - "mapping pattern checks duplicate key ({key_repr})" - )))); - } - if !key_repr.is_empty() { - seen.insert(key_repr); - } + if size.saturating_sub(1) > i32::MAX as usize { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError( + "too many sub-patterns in mapping pattern".to_string(), + ), + p.range, + )); + } + let size = size.to_u32(); - self.compile_match_pattern_expr(key)?; - } - self.set_source_range(p.range); + // Step 2: Validate and compile all keys. + let mut seen = Vec::new(); + for key in keys { + self.compile_pattern_mapping_key(&mut seen, p.range, key)?; } + self.set_source_range(p.range); // Stack: [subject, key1, key2, ..., key_n] // Build tuple of keys (empty tuple if size==0) @@ -6500,11 +7114,14 @@ impl Compiler { pc.on_top += 2; // subject and keys_tuple are underneath // Check if match succeeded + self.set_source_range(p.range); emit!(self, Instruction::Copy { i: 1 }); // Stack: [subject, keys_tuple, values_tuple, values_tuple_copy] // Check if copy is None (consumes the copy like POP_JUMP_IF_NONE) + self.set_source_range(p.range); self.emit_load_const(ConstantData::None); + self.set_source_range(p.range); emit!( self, Instruction::IsOp { @@ -6513,14 +7130,18 @@ impl Compiler { ); // Stack: [subject, keys_tuple, values_tuple, bool] + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); // Stack: [subject, keys_tuple, values_tuple] // Unpack values (the original values_tuple) emit!(self, Instruction::UnpackSequence { count: size }); // Stack after unpack: [subject, keys_tuple, ...unpacked values...] - pc.on_top += size as usize; // Unpacked size values, tuple replaced by values - pc.on_top -= 1; + if size == 0 { + pc.on_top -= 1; + } else { + pc.on_top += size as usize - 1; + } // Step 3: Process matched values for i in 0..size { @@ -6539,15 +7160,19 @@ impl Compiler { // Stack: [subject, keys_tuple] // Build rest dict exactly + self.set_source_range(p.range); emit!(self, Instruction::BuildMap { count: 0 }); // Stack: [subject, keys_tuple, {}] + self.set_source_range(p.range); emit!(self, Instruction::Swap { i: 3 }); // Stack: [{}, keys_tuple, subject] + self.set_source_range(p.range); emit!(self, Instruction::DictUpdate { i: 2 }); // Stack after DICT_UPDATE: [rest_dict, keys_tuple] // DICT_UPDATE consumes source (subject) and leaves dict in place // Unpack keys and delete from rest_dict + self.set_source_range(p.range); emit!(self, Instruction::UnpackSequence { count: size }); // Stack: [rest_dict, k1, k2, ..., kn] (if size==0, nothing pushed) @@ -6556,10 +7181,13 @@ impl Compiler { let mut remaining = size; while remaining > 0 { // Copy rest_dict which is at position (1 + remaining) from TOS + self.set_source_range(p.range); emit!(self, Instruction::Copy { i: 1 + remaining }); // Stack: [rest_dict, k1, ..., kn, rest_dict] + self.set_source_range(p.range); emit!(self, Instruction::Swap { i: 2 }); // Stack: [rest_dict, k1, ..., kn-1, rest_dict, kn] + self.set_source_range(p.range); emit!(self, Instruction::DeleteSubscr); // Stack: [rest_dict, k1, ..., kn-1] (removed kn from rest_dict) remaining -= 1; @@ -6568,22 +7196,216 @@ impl Compiler { // pattern_helper_store_name will handle the rotation correctly // Store the rest dict - self.pattern_helper_store_name(Some(rest_name), pc)?; - - // After storing all values, pc.on_top should be 0 - // The values are rotated to the bottom for later storage - pc.on_top = 0; + self.pattern_helper_store_name(p.range, Some(rest_name), pc)?; } else { // Non-rest pattern: just clean up the stack // Pop them as we're not using them + self.set_source_range(p.range); emit!(self, Instruction::PopTop); // Pop keys_tuple + self.set_source_range(p.range); emit!(self, Instruction::PopTop); // Pop subject } Ok(()) } + fn compile_pattern_mapping_key( + &mut self, + seen: &mut Vec, + pattern_range: TextRange, + key: &ast::Expr, + ) -> CompileResult<()> { + let is_attribute = matches!(key, ast::Expr::Attribute(_)); + let constant = match self.try_compile_match_mapping_key_constant(key)? { + Some(constant) => Some(constant), + None if is_attribute => None, + None => { + if Self::is_unexpected_match_literal_constant(key) { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError( + "unexpected constant inside of a literal pattern".to_string(), + ), + pattern_range, + )); + } + return Err(self.error_ranged( + CodegenErrorType::SyntaxError( + "mapping pattern keys may only match literals and attribute lookups" + .to_string(), + ), + pattern_range, + )); + } + }; + + if let Some(constant) = constant { + if seen + .iter() + .any(|seen| Self::match_mapping_keys_equal(seen, &constant)) + { + let key_repr = Self::match_mapping_key_repr(&constant); + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!( + "mapping pattern checks duplicate key ({key_repr})" + )), + pattern_range, + )); + } + seen.push(constant); + } + + self.compile_match_pattern_expr(key) + } + + fn try_compile_match_mapping_key_constant( + &mut self, + key: &ast::Expr, + ) -> CompileResult> { + if let Some(constant) = self.try_fold_match_pattern_const_expr(key)? { + return Ok(Some(constant)); + } + self.try_compile_match_mapping_key_direct_constant(key) + } + + fn try_compile_match_value_constant( + &mut self, + value: &ast::Expr, + ) -> CompileResult> { + if let Some(constant) = self.try_fold_match_pattern_const_expr(value)? { + return Ok(Some(constant)); + } + self.try_compile_match_pattern_direct_literal(value) + } + + fn match_mapping_keys_equal(left: &ConstantData, right: &ConstantData) -> bool { + use ConstantData::{Bytes, Ellipsis, None, Str}; + + if Self::match_mapping_numeric_keys_equal(left, right).unwrap_or(false) { + return true; + } + + match (left, right) { + (Str { value: left }, Str { value: right }) => left == right, + (Bytes { value: left }, Bytes { value: right }) => left == right, + (None, None) | (Ellipsis, Ellipsis) => true, + _ => false, + } + } + + fn match_mapping_key_repr(key: &ConstantData) -> String { + match key { + ConstantData::Integer { value } => value.to_string(), + ConstantData::Float { value } => literal_float::to_string(*value), + ConstantData::Complex { value } => literal_complex::to_string(value.re, value.im), + ConstantData::Boolean { value } => { + if *value { + "True".to_owned() + } else { + "False".to_owned() + } + } + ConstantData::Str { value } => UnicodeEscape::new_repr(value.as_ref()) + .str_repr() + .to_string() + .unwrap_or_else(|| value.to_string()), + ConstantData::Bytes { value } => AsciiEscape::new_repr(value) + .bytes_repr() + .to_string() + .unwrap_or_else(|| format!(r#"b"{}""#, value.escape_ascii())), + ConstantData::None => "None".to_owned(), + ConstantData::Ellipsis => "...".to_owned(), + other => other.to_string(), + } + } + + fn match_mapping_numeric_keys_equal(left: &ConstantData, right: &ConstantData) -> Option { + use ConstantData::{Boolean, Complex, Float, Integer}; + + match (left, right) { + (Integer { value: left }, Integer { value: right }) => Some(left == right), + (Boolean { value: left }, Boolean { value: right }) => Some(left == right), + (Boolean { value }, Integer { value: int }) + | (Integer { value: int }, Boolean { value }) => { + Some(BigInt::from(u8::from(*value)) == *int) + } + (Float { value: left }, Float { value: right }) => Some(left == right), + (Integer { value: int }, Float { value: float }) + | (Float { value: float }, Integer { value: int }) => { + Some(Self::match_mapping_float_integer_equal(*float, int)) + } + (Boolean { value }, Float { value: float }) + | (Float { value: float }, Boolean { value }) => Some( + Self::match_mapping_float_integer_equal(*float, &BigInt::from(u8::from(*value))), + ), + (Complex { value: left }, Complex { value: right }) => { + Some(left.re == right.re && left.im == right.im) + } + (Complex { value: complex }, other) | (other, Complex { value: complex }) => Some( + complex.im == 0.0 + && Self::match_mapping_float_real_constant_equal(complex.re, other) + .unwrap_or(false), + ), + _ => Option::None, + } + } + + fn match_mapping_float_real_constant_equal(float: f64, other: &ConstantData) -> Option { + match other { + ConstantData::Integer { value } => { + Some(Self::match_mapping_float_integer_equal(float, value)) + } + ConstantData::Boolean { value } => Some(Self::match_mapping_float_integer_equal( + float, + &BigInt::from(u8::from(*value)), + )), + ConstantData::Float { value } => Some(float == *value), + _ => None, + } + } + + fn match_mapping_float_integer_equal(float: f64, int: &BigInt) -> bool { + Self::match_mapping_float_to_integer(float).is_some_and(|float_int| &float_int == int) + } + + fn match_mapping_float_to_integer(value: f64) -> Option { + if !value.is_finite() { + return None; + } + if value == 0.0 { + return Some(BigInt::from(0)); + } + + let bits = value.to_bits(); + let negative = (bits >> 63) != 0; + let exponent_bits = i32::try_from((bits >> 52) & 0x7ff).ok()?; + let fraction = bits & ((1_u64 << 52) - 1); + let (mantissa, exponent) = if exponent_bits == 0 { + (fraction, -1074) + } else { + ((1_u64 << 52) | fraction, exponent_bits - 1023 - 52) + }; + + let mut integer = if exponent >= 0 { + BigInt::from(mantissa) << u32::try_from(exponent).ok()? + } else { + let shift = u32::try_from(-exponent).ok()?; + if shift >= u64::BITS { + return None; + } + let mask = (1_u64 << shift) - 1; + if mantissa & mask != 0 { + return None; + } + BigInt::from(mantissa >> shift) + }; + + if negative { + integer = -integer; + } + Some(integer) + } + fn compile_pattern_or( &mut self, p: &ast::PatternMatchOr, @@ -6625,7 +7447,9 @@ impl Compiler { } else { let control_vec = control.as_ref().unwrap(); if n_stores != control_vec.len() { - return Err(self.error(CodegenErrorType::ConflictingNameBindPattern)); + return Err( + self.error_ranged(CodegenErrorType::ConflictingNameBindPattern, p.range()) + ); } else if n_stores > 0 { // Check that the names occur in the same order. for i_control in (0..n_stores).rev() { @@ -6633,7 +7457,10 @@ impl Compiler { // Find the index of `name` in the current stores. let i_stores = pc.stores.iter().position(|n| n == name).ok_or_else(|| { - self.error(CodegenErrorType::ConflictingNameBindPattern) + self.error_ranged( + CodegenErrorType::ConflictingNameBindPattern, + p.range(), + ) })?; if i_control != i_stores { // The orders differ; we must reorder. @@ -6653,7 +7480,7 @@ impl Compiler { // Also perform the same rotation on the evaluation stack. self.set_source_range(alt.range()); for _ in 0..=i_stores { - self.pattern_helper_rotate(i_control + 1); + self.pattern_helper_rotate(alt.range(), i_control + 1); } } } @@ -6663,7 +7490,7 @@ impl Compiler { self.set_source_range(alt.range()); emit!(self, PseudoInstruction::Jump { delta: end }); self.set_source_range(alt.range()); - self.emit_and_reset_fail_pop(pc); + self.emit_and_reset_fail_pop(pc, alt.range()); } // Restore the original pattern context. @@ -6688,11 +7515,14 @@ impl Compiler { for i in 0..n_stores { // Rotate the capture to its proper place. self.set_source_range(p.range()); - self.pattern_helper_rotate(n_rots); + self.pattern_helper_rotate(p.range(), n_rots); let name = &control.as_ref().unwrap()[i]; // Check for duplicate binding. if pc.stores.contains(name) { - return Err(self.error(CodegenErrorType::DuplicateStore(name.to_string()))); + return Err(self.error_ranged( + CodegenErrorType::DuplicateStore(name.to_string()), + p.range(), + )); } pc.stores.push(name.clone()); } @@ -6720,47 +7550,59 @@ impl Compiler { for (i, pattern) in patterns.iter().enumerate() { if pattern.is_match_star() { if star.is_some() { - // TODO: Fix error msg - return Err(self.error(CodegenErrorType::MultipleStarArgs)); + return Err(self.error_ranged( + CodegenErrorType::MultipleStarredNamesInSequencePattern, + p.range, + )); } // star wildcard check - star_wildcard = pattern.as_match_star().is_some_and(|m| m.name.is_none()); + star_wildcard = Self::pattern_wildcard_star_check(pattern); only_wildcard &= star_wildcard; star = Some(i); continue; } // wildcard check - only_wildcard &= pattern.as_match_as().is_some_and(|m| m.name.is_none()); + only_wildcard &= Self::pattern_wildcard_check(pattern); } // Keep the subject on top during the sequence and length checks. pc.on_top += 1; + self.set_source_range(p.range); emit!(self, Instruction::MatchSequence); + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); if star.is_none() { // No star: len(subject) == size + self.set_source_range(p.range); emit!(self, Instruction::GetLen); + self.set_source_range(p.range); self.emit_load_const(ConstantData::Integer { value: size.into() }); + self.set_source_range(p.range); emit!( self, Instruction::CompareOp { opname: ComparisonOperator::Equal } ); + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); } else if size > 1 { // Star exists: len(subject) >= size - 1 + self.set_source_range(p.range); emit!(self, Instruction::GetLen); + self.set_source_range(p.range); self.emit_load_const(ConstantData::Integer { value: (size - 1).into(), }); + self.set_source_range(p.range); emit!( self, Instruction::CompareOp { opname: ComparisonOperator::GreaterOrEqual } ); + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); } @@ -6768,11 +7610,12 @@ impl Compiler { pc.on_top -= 1; if only_wildcard { // ast::Patterns like: [] / [_] / [_, _] / [*_] / [_, *_] / [_, _, *_] / etc. + self.set_source_range(p.range); emit!(self, Instruction::PopTop); } else if star_wildcard { - self.pattern_helper_sequence_subscr(patterns, star.unwrap(), pc)?; + self.pattern_helper_sequence_subscr(p.range, patterns, star.unwrap(), pc)?; } else { - self.pattern_helper_sequence_unpack(patterns, star, pc)?; + self.pattern_helper_sequence_unpack(p.range, patterns, star, pc)?; } Ok(()) } @@ -6785,14 +7628,37 @@ impl Compiler { // Match CPython codegen_pattern_value(): compare, then normalize to bool // before the fail jump. Late IR folding will collapse COMPARE_OP+TO_BOOL // into COMPARE_OP bool(...) when applicable. - self.compile_match_pattern_expr(&p.value)?; + if let Some(constant) = self.try_compile_match_value_constant(&p.value)? { + self.set_source_range(p.value.range()); + self.emit_load_const(constant); + } else if matches!(*p.value, ast::Expr::Attribute(_)) { + self.compile_expression(&p.value)?; + } else { + if Self::is_unexpected_match_literal_constant(&p.value) { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError( + "unexpected constant inside of a literal pattern".to_string(), + ), + p.range, + )); + } + return Err(self.error_ranged( + CodegenErrorType::SyntaxError( + "patterns may only match literals and attribute lookups".to_string(), + ), + p.range, + )); + } + self.set_source_range(p.range); emit!( self, Instruction::CompareOp { opname: bytecode::ComparisonOperator::Equal } ); + self.set_source_range(p.range); emit!(self, Instruction::ToBool); + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); Ok(()) } @@ -6803,14 +7669,17 @@ impl Compiler { pc: &mut PatternContext, ) { // Load the singleton constant value. + self.set_source_range(p.range); self.emit_load_const(match p.value { ast::Singleton::None => ConstantData::None, ast::Singleton::False => ConstantData::Boolean { value: false }, ast::Singleton::True => ConstantData::Boolean { value: true }, }); // Compare using the "Is" operator. + self.set_source_range(p.range); emit!(self, Instruction::IsOp { invert: Invert::No }); // Jump to the failure label if the comparison is false. + self.set_source_range(p.range); self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse); } @@ -6858,22 +7727,13 @@ impl Compiler { cases: &[ast::MatchCase], pattern_context: &mut PatternContext, ) -> CompileResult<()> { - fn is_trailing_wildcard_default(pattern: &ast::Pattern) -> bool { - match pattern { - ast::Pattern::MatchAs(match_as) => { - match_as.pattern.is_none() && match_as.name.is_none() - } - _ => false, - } - } - self.compile_expression(subject)?; let end = self.new_block(); let num_cases = cases.len(); assert!(num_cases > 0); let has_default = - num_cases > 1 && is_trailing_wildcard_default(&cases.last().unwrap().pattern); + num_cases > 1 && Self::pattern_wildcard_check(&cases.last().unwrap().pattern); let case_count = num_cases - usize::from(has_default); for (i, m) in cases.iter().enumerate().take(case_count) { @@ -6891,8 +7751,8 @@ impl Compiler { self.compile_pattern(&m.pattern, pattern_context)?; assert_eq!(pattern_context.on_top, 0); - self.set_source_range(m.pattern.range()); for name in &pattern_context.stores { + self.set_source_range(m.pattern.range()); self.compile_name(name, NameUsage::Store)?; } @@ -6920,7 +7780,7 @@ impl Compiler { emit!(self, PseudoInstruction::Jump { delta: end }); self.set_no_location(); self.set_source_range(m.pattern.range()); - self.emit_and_reset_fail_pop(pattern_context); + self.emit_and_reset_fail_pop(pattern_context, m.pattern.range()); } if has_default { @@ -6932,7 +7792,7 @@ impl Compiler { emit!(self, Instruction::Nop); } if let Some(ref guard) = m.guard { - self.compile_jump_if(guard, false, end)?; + self.compile_jump_if_inner(guard, false, end, Some(m.pattern.range()))?; } self.compile_statements(&m.body)?; } @@ -7037,6 +7897,7 @@ impl Compiler { ) -> CompileResult<()> { // Save the full Compare expression range for COMPARE_OP positions let compare_range = self.current_source_range; + self.check_compare(compare_range, left, ops, comparators)?; let (last_op, mid_ops) = ops.split_last().unwrap(); let (last_comparator, mid_comparators) = comparators.split_last().unwrap(); @@ -7097,6 +7958,7 @@ impl Compiler { target_block: BlockIdx, ) -> CompileResult<()> { let compare_range = self.current_source_range; + self.check_compare(compare_range, left, ops, comparators)?; let (last_op, mid_ops) = ops.split_last().unwrap(); let (last_comparator, mid_comparators) = comparators.split_last().unwrap(); @@ -7134,13 +7996,13 @@ impl Compiler { self.use_cpython_label_block(cleanup); emit!(self, Instruction::PopTop); if !condition { - self.set_no_location(); emit!( self, PseudoInstruction::JumpNoInterrupt { delta: target_block } ); + self.set_no_location(); } self.use_cpython_label_block(end); @@ -7182,7 +8044,9 @@ impl Compiler { ast::Expr::Starred(ast::ExprStarred { value, .. }) => { // *args: *Ts (where Ts is a TypeVarTuple). // Do [annotation_value] = [*Ts]. + let saved_range = self.current_source_range; self.compile_expression(value)?; + self.set_source_range(saved_range); emit!(self, Instruction::UnpackSequence { count: 1 }); Ok(()) } @@ -7197,6 +8061,7 @@ impl Compiler { fn compile_check_annotation_expression(&mut self, expression: &ast::Expr) -> CompileResult<()> { self.compile_expression(expression)?; + self.set_source_range(expression.range()); emit!(self, Instruction::PopTop); Ok(()) } @@ -7269,23 +8134,24 @@ impl Compiler { } else { // PEP 649: Handle conditional annotations if self.current_symbol_table().has_conditional_annotations { - // Allocate an index for every annotation when has_conditional_annotations - // This keeps indices aligned with compile_module_annotate's enumeration - let code_info = self.current_code_info(); - let annotation_index = code_info.next_conditional_annotation_index; - code_info.next_conditional_annotation_index += 1; - - // Determine if this annotation is conditional - // Module and Class scopes both need all annotations tracked let scope_type = self.current_symbol_table().typ; let in_conditional_block = self.current_code_info().in_conditional_block > 0; let is_conditional = - matches!(scope_type, CompilerScope::Module | CompilerScope::Class) - || in_conditional_block; + matches!(scope_type, CompilerScope::Module) || in_conditional_block; - // Only add to __conditional_annotations__ set if actually conditional if is_conditional { - self.load_name("__conditional_annotations__")?; + let code_info = self.current_code_info(); + let annotation_index = code_info.next_conditional_annotation_index; + code_info.next_conditional_annotation_index += 1; + + self.set_source_range(loc); + if matches!(scope_type, CompilerScope::Class) { + let i = self.get_cell_var_index("__conditional_annotations__"); + emit!(self, Instruction::LoadDeref { i }); + } else { + let namei = self.name("__conditional_annotations__"); + emit!(self, Instruction::LoadName { namei }); + } self.emit_load_const(ConstantData::Integer { value: annotation_index.into(), }); @@ -7340,23 +8206,27 @@ impl Compiler { // Scan for star args: for (i, element) in elts.iter().enumerate() { - if let ast::Expr::Starred(_) = &element { - if seen_star { - return Err(self.error(CodegenErrorType::MultipleStarArgs)); - } - - seen_star = true; + if matches!(element, ast::Expr::Starred(_)) && !seen_star { let before = i; let after = elts.len() - i - 1; - let (before, after) = (|| Some((before.to_u8()?, after.to_u8()?)))() - .ok_or_else(|| { - self.error_ranged( - CodegenErrorType::TooManyStarUnpack, - target.range(), - ) - })?; + if before >= (1 << 8) || after >= ((i32::MAX as usize) >> 8) { + return Err(self.error_ranged( + CodegenErrorType::TooManyStarUnpack, + target.range(), + )); + } + let before = before.to_u8().ok_or_else(|| { + self.error_ranged( + CodegenErrorType::TooManyStarUnpack, + target.range(), + ) + })?; + let after = after.to_u32(); let counts = bytecode::UnpackExArgs { before, after }; emit!(self, Instruction::UnpackEx { counts }); + seen_star = true; + } else if matches!(element, ast::Expr::Starred(_)) { + return Err(self.error(CodegenErrorType::MultipleStarArgs)); } } @@ -7427,7 +8297,7 @@ impl Compiler { ctx: _, .. }) => { - let use_slice_opt = slice.should_use_slice_optimization(); + let use_slice_opt = self.should_apply_two_element_slice_optimization(slice); self.compile_expression(value)?; self.set_source_range(target_range); if use_slice_opt { @@ -7461,7 +8331,7 @@ impl Compiler { self.compile_expression(value)?; let attr_range = self.update_start_location_to_match_attr(target_range, target_range, attr); - self.set_source_range(attr_range); + self.set_source_range(target_range); emit!(self, Instruction::Copy { i: 1 }); let idx = self.name(attr); self.set_source_range(attr_range); @@ -7596,6 +8466,7 @@ impl Compiler { comparators, .. }) if ops.len() > 1 => { + self.set_source_range(expression.range()); self.compile_jump_if_compare(left, ops, comparators, condition, target_block) } _ => { @@ -7668,130 +8539,50 @@ impl Compiler { } } - fn compile_dict(&mut self, items: &[ast::DictItem], range: TextRange) -> CompileResult<()> { - let has_unpacking = items.iter().any(|item| item.key.is_none()); - - if !has_unpacking { - // Match CPython's compiler_subdict chunking strategy: - // - n≤15: BUILD_MAP n (all pairs on stack) - // - n>15: BUILD_MAP 0 + MAP_ADD chunks of 17, last chunk uses - // BUILD_MAP n (if ≤15) or BUILD_MAP 0 + MAP_ADD - const STACK_LIMIT: usize = 15; - const BIG_MAP_CHUNK: usize = 17; - - if items.len() <= STACK_LIMIT { - for item in items { - self.compile_expression(item.key.as_ref().unwrap())?; - self.compile_expression(&item.value)?; - } - self.set_source_range(range); - emit!( - self, - Instruction::BuildMap { - count: u32::try_from(items.len()).expect("too many dict items"), - } - ); - } else { - // Split: leading full chunks of BIG_MAP_CHUNK via MAP_ADD, - // remainder via BUILD_MAP n or MAP_ADD depending on size - let n = items.len(); - let remainder = n % BIG_MAP_CHUNK; - let n_big_chunks = n / BIG_MAP_CHUNK; - // If remainder fits on stack (≤15), use BUILD_MAP n for it. - // Otherwise it becomes another MAP_ADD chunk. - let (big_count, tail_count) = if remainder > 0 && remainder <= STACK_LIMIT { - (n_big_chunks, remainder) - } else { - // remainder is 0 or >15: all chunks are MAP_ADD chunks - let total_map_add = if remainder == 0 { - n_big_chunks - } else { - n_big_chunks + 1 - }; - (total_map_add, 0usize) - }; - + fn compile_subdict( + &mut self, + items: &[ast::DictItem], + begin: usize, + end: usize, + range: TextRange, + ) -> CompileResult<()> { + let n = end - begin; + let big = n * 2 > STACK_USE_GUIDELINE as usize; + if big { + self.set_source_range(range); + emit!(self, Instruction::BuildMap { count: 0 }); + } + for item in &items[begin..end] { + self.compile_expression(item.key.as_ref().unwrap())?; + self.compile_expression(&item.value)?; + if big { self.set_source_range(range); - emit!(self, Instruction::BuildMap { count: 0 }); - - let mut idx = 0; - for chunk_i in 0..big_count { - if chunk_i > 0 { - self.set_source_range(range); - emit!(self, Instruction::BuildMap { count: 0 }); - } - let chunk_size = if idx + BIG_MAP_CHUNK <= n - tail_count { - BIG_MAP_CHUNK - } else { - n - tail_count - idx - }; - for item in &items[idx..idx + chunk_size] { - self.compile_expression(item.key.as_ref().unwrap())?; - self.compile_expression(&item.value)?; - self.set_source_range(range); - emit!(self, Instruction::MapAdd { i: 1 }); - } - if chunk_i > 0 { - self.set_source_range(range); - emit!(self, Instruction::DictUpdate { i: 1 }); - } - idx += chunk_size; - } - - // Tail: remaining pairs via BUILD_MAP n + DICT_UPDATE - if tail_count > 0 { - for item in &items[idx..idx + tail_count] { - self.compile_expression(item.key.as_ref().unwrap())?; - self.compile_expression(&item.value)?; - } - self.set_source_range(range); - emit!( - self, - Instruction::BuildMap { - count: tail_count.to_u32(), - } - ); - self.set_source_range(range); - emit!(self, Instruction::DictUpdate { i: 1 }); - } + emit!(self, Instruction::MapAdd { i: 1 }); } - return Ok(()); } + if !big { + self.set_source_range(range); + emit!(self, Instruction::BuildMap { count: n.to_u32() }); + } + Ok(()) + } - // Complex case with ** unpacking: preserve insertion order. - // Collect runs of regular k:v pairs and emit BUILD_MAP + DICT_UPDATE - // for each run, and DICT_UPDATE for each ** entry. + fn compile_dict(&mut self, items: &[ast::DictItem], range: TextRange) -> CompileResult<()> { + let n = items.len(); let mut have_dict = false; - let mut elements: u32 = 0; - - // Flush pending regular pairs as a BUILD_MAP, merging into the - // accumulator dict via DICT_UPDATE when one already exists. - macro_rules! flush_pending { - () => { - #[allow(unused_assignments)] - if elements > 0 { - self.set_source_range(range); - emit!(self, Instruction::BuildMap { count: elements }); + let mut elements = 0usize; + + for (i, item) in items.iter().enumerate() { + if item.key.is_none() { + if elements != 0 { + self.compile_subdict(items, i - elements, i, range)?; if have_dict { self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); - } else { - have_dict = true; } + have_dict = true; elements = 0; } - }; - } - - for item in items { - if let Some(key) = &item.key { - // Regular key: value pair - self.compile_expression(key)?; - self.compile_expression(&item.value)?; - elements += 1; - } else { - // ** unpacking entry - flush_pending!(); if !have_dict { self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); @@ -7800,10 +8591,27 @@ impl Compiler { self.compile_expression(&item.value)?; self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); + } else if elements * 2 > STACK_USE_GUIDELINE as usize { + self.compile_subdict(items, i - elements, i + 1, range)?; + if have_dict { + self.set_source_range(range); + emit!(self, Instruction::DictUpdate { i: 1 }); + } + have_dict = true; + elements = 0; + } else { + elements += 1; } } - flush_pending!(); + if elements != 0 { + self.compile_subdict(items, n - elements, n, range)?; + if have_dict { + self.set_source_range(range); + emit!(self, Instruction::DictUpdate { i: 1 }); + } + have_dict = true; + } if !have_dict { self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); @@ -7882,14 +8690,88 @@ impl Compiler { send_block } + fn public_ast_constant_override(&self, expr: &ast::Expr) -> Option { + let index = ast::HasNodeIndex::node_index(expr).load(); + if index == ast::NodeIndex::NONE { + return None; + } + self.opts + .ast_constant_overrides + .as_ref()? + .get(&index) + .cloned() + } + + fn public_ast_interpolation_override( + &self, + expr_tstring: &ast::ExprTString, + ) -> Option { + let index = expr_tstring.node_index.load(); + self.public_ast_interpolation_override_by_index(index) + } + + fn public_ast_interpolation_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.opts + .ast_interpolation_overrides + .as_ref()? + .get(&index) + .cloned() + } + + fn public_ast_formatted_value_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.opts + .ast_formatted_value_overrides + .as_ref()? + .get(&index) + .cloned() + } + + fn public_ast_joined_str_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.opts + .ast_joined_str_overrides + .as_ref()? + .get(&index) + .cloned() + } + + fn public_ast_template_str_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.opts + .ast_template_str_overrides + .as_ref()? + .get(&index) + .cloned() + } + fn compile_expression(&mut self, expression: &ast::Expr) -> CompileResult<()> { trace!("Compiling {expression:?}"); let range = expression.range(); self.set_source_range(range); - if matches!(expression, ast::Expr::BinOp(_)) - && let Some(constant) = self.try_fold_constant_expr(expression)? - { + if let Some(constant) = self.public_ast_constant_override(expression) { self.emit_load_const(constant); return Ok(()); } @@ -7917,22 +8799,7 @@ impl Compiler { self.compile_subscript(value, slice, *ctx)?; } ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. }) => { - if let ( - ast::UnaryOp::Not, - ast::Expr::Compare(ast::ExprCompare { - left, - ops, - comparators, - .. - }), - ) = (op, operand.as_ref()) - && ops.len() == 1 - { - self.set_source_range(range); - self.compile_compare(left, ops, comparators)?; - } else { - self.compile_expression(operand)?; - } + self.compile_expression(operand)?; // Restore full expression range before emitting the operation self.set_source_range(range); @@ -7963,6 +8830,8 @@ impl Compiler { unreachable!("can_optimize_super_call only accepts calls"); }; self.load_args_for_super(&super_type, super_func.range(), value.range())?; + let attr_access_range = + self.update_start_location_to_match_attr(range, range, attr.as_str()); self.set_source_range(range); let idx = self.name(attr.as_str()); match super_type { @@ -7973,6 +8842,8 @@ impl Compiler { self.emit_load_zero_super_attr(idx); } } + self.set_source_range(attr_access_range); + emit!(self, Instruction::Nop); } else { // Normal attribute access self.compile_expression(value)?; @@ -8077,7 +8948,9 @@ impl Compiler { ); } ast::Expr::Await(ast::ExprAwait { value, .. }) => { - if self.ctx.func != FunctionContext::AsyncFunction { + if self.ctx.func != FunctionContext::AsyncFunction + && !self.allows_top_level_await_in_current_context() + { return Err(self.error(CodegenErrorType::InvalidAwait)); } self.compile_expression(value)?; @@ -8162,12 +9035,12 @@ impl Compiler { } self.enter_function(&name, params)?; - let mut func_flags = MakeFunctionFlags::new(); + let mut func_flags = bytecode::MakeFunctionFlags::new(); if have_defaults { - func_flags.insert(MakeFunctionFlag::Defaults); + func_flags.insert(bytecode::MakeFunctionFlag::Defaults); } if have_kwdefaults { - func_flags.insert(MakeFunctionFlag::KwOnlyDefaults); + func_flags.insert(bytecode::MakeFunctionFlag::KwOnlyDefaults); } // Set qualname for lambda @@ -8181,15 +9054,20 @@ impl Compiler { }; self.compile_expression(body)?; - self.set_source_range(body.range()); - self.emit_return_value(); - // _PyCodegen_AddReturnAtEnd() appends a no-location - // return-None epilogue even after lambda's explicit - // RETURN_VALUE. It is later removed as unreachable, but - // remove_unused_consts() keeps None when it was the first - // constant in an otherwise constant-free lambda. - if self.current_code_info().metadata.consts.is_empty() { - self.arg_constant(ConstantData::None); + let is_generator = self + .current_code_info() + .flags + .contains(bytecode::CodeFlags::GENERATOR); + if is_generator { + // CPython codegen_lambda() calls OptimizeAndAssemble with + // addNone=0, so AddReturnAtEnd appends RETURN_VALUE without + // adding None to co_consts. + emit!(self, Instruction::ReturnValue); + self.set_no_location(); + } else { + self.set_source_range(body.range()); + self.emit_return_value(); + self.emit_return_const_no_location(ConstantData::None); } let code = self.exit_scope(); @@ -8207,7 +9085,12 @@ impl Compiler { }) => { self.compile_comprehension( "", - Some(Opcode::BuildList.into()), + Some( + Instruction::BuildList { + count: OpArgMarker::marker(), + } + .into(), + ), generators, &|compiler, collection_add_i| { compiler.compile_comprehension_element(elt)?; @@ -8235,7 +9118,12 @@ impl Compiler { }) => { self.compile_comprehension( "", - Some(Opcode::BuildSet.into()), + Some( + Instruction::BuildSet { + count: OpArgMarker::marker(), + } + .into(), + ), generators, &|compiler, collection_add_i| { compiler.compile_comprehension_element(elt)?; @@ -8262,9 +9150,17 @@ impl Compiler { range, .. }) => { + let Some(key) = key.as_deref() else { + return Err(self.error(CodegenErrorType::InvalidStarExpr)); + }; self.compile_comprehension( "", - Some(Opcode::BuildMap.into()), + Some( + Instruction::BuildMap { + count: OpArgMarker::marker(), + } + .into(), + ), generators, &|compiler, collection_add_i| { // changed evaluation order for Py38 named expression PEP 572 @@ -8356,9 +9252,24 @@ impl Compiler { self.set_source_range(target.range()); } ast::Expr::FString(fstring) => { + if let Some(joined_str) = + self.public_ast_joined_str_override_by_index(fstring.node_index.load()) + { + return self.compile_public_ast_joined_str(fstring, joined_str); + } self.compile_expr_fstring(fstring)?; } ast::Expr::TString(tstring) => { + if let Some(template_str) = + self.public_ast_template_str_override_by_index(tstring.node_index.load()) + { + return self.compile_public_ast_template_str(tstring, template_str); + } + if let Some(interpolation) = self.public_ast_interpolation_override(tstring) + && self.compile_public_ast_interpolation(tstring, interpolation)? + { + return Ok(()); + } self.compile_expr_tstring(tstring)?; } ast::Expr::StringLiteral(string) => { @@ -8393,8 +9304,11 @@ impl Compiler { ast::Expr::EllipsisLiteral(_) => { self.emit_load_const(ConstantData::Ellipsis); } - ast::Expr::IpyEscapeCommand(_) => { - panic!("unexpected ipy escape command"); + ast::Expr::IpyEscapeCommand(expr) => { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError("invalid syntax".to_owned()), + expr.range, + )); } } Ok(()) @@ -8486,17 +9400,13 @@ impl Compiler { emit!(self, Instruction::BuildList { count: 0 }); } - let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); + let symbol_table_cursors = self.current_symbol_table_cursors(); if let Some(range) = self.cpython_implicit_call_generator_range(generator_expr) { self.compile_expression_with_generator_range(generator_expr, range)?; } else { self.compile_expression(generator_expr)?; } - if let Some(cursor) = sub_table_cursor - && let Some(current_table) = self.symbol_table_stack.last_mut() - { - current_table.next_sub_table = cursor; - } + self.set_symbol_table_cursors(symbol_table_cursors); let loop_block = self.new_block(); let cleanup = self.new_block(); @@ -8515,9 +9425,9 @@ impl Compiler { self.set_source_range(loc); emit!(self, Instruction::ToBool); emit!(self, Instruction::PopJumpIfTrue { delta: loop_block }); - self.set_source_range(loc); emit!(self, Instruction::PopIter); self.set_no_location(); + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: false }); self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); @@ -8526,9 +9436,9 @@ impl Compiler { self.set_source_range(loc); emit!(self, Instruction::ToBool); emit!(self, Instruction::PopJumpIfFalse { delta: loop_block }); - self.set_source_range(loc); emit!(self, Instruction::PopIter); self.set_no_location(); + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: true }); self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); @@ -8536,10 +9446,8 @@ impl Compiler { } self.use_cpython_label_block(cleanup); - self.set_source_range(loc); emit!(self, Instruction::EndFor); self.set_no_location(); - self.set_source_range(loc); emit!(self, Instruction::PopIter); self.set_no_location(); match kind { @@ -8568,20 +9476,100 @@ impl Compiler { Ok(()) } + fn can_use_cpython_method_call(&self, value: &ast::Expr, args: &ast::Arguments) -> bool { + let is_import = matches!(value, ast::Expr::Name(ast::ExprName { id, .. }) + if self.is_name_imported(id.as_str())); + if is_import { + return false; + } + + if args.args.len() + args.keywords.len() + usize::from(!args.keywords.is_empty()) + >= STACK_USE_GUIDELINE as usize + { + return false; + } + + !args + .args + .iter() + .any(|arg| matches!(arg, ast::Expr::Starred(_))) + && args.keywords.iter().all(|kw| kw.arg.is_some()) + } + + fn compile_method_call_arguments( + &mut self, + args: &ast::Arguments, + call_range: TextRange, + kw_names_range: TextRange, + ) -> CompileResult<()> { + let implicit_generator_range = if args.args.len() == 1 && args.keywords.is_empty() { + self.cpython_implicit_call_generator_range(&args.args[0]) + } else { + None + }; + for arg in &args.args { + if let Some(range) = implicit_generator_range { + self.compile_expression_with_generator_range(arg, range)?; + } else { + self.compile_expression(arg)?; + } + } + + if args.keywords.is_empty() { + self.set_source_range(call_range); + emit!( + self, + Instruction::Call { + argc: args.args.len().to_u32() + } + ); + return Ok(()); + } + + let mut kwarg_names = Vec::with_capacity(args.keywords.len()); + for keyword in &args.keywords { + kwarg_names.push(ConstantData::Str { + value: keyword.arg.as_ref().unwrap().as_str().into(), + }); + self.compile_expression(&keyword.value)?; + } + self.set_source_range(kw_names_range); + self.emit_load_const(ConstantData::Tuple { + elements: kwarg_names, + }); + self.set_source_range(call_range); + emit!( + self, + Instruction::CallKw { + argc: (args.args.len() + args.keywords.len()).to_u32() + } + ); + Ok(()) + } + fn compile_call(&mut self, func: &ast::Expr, args: &ast::Arguments) -> CompileResult<()> { // Save the call expression's source range so CALL instructions use the // call start line, not the last argument's line. let call_range = self.current_source_range; + self.validate_keywords(&args.keywords)?; let uses_ex_call = self.call_uses_ex_call(args); // Method call: obj → LOAD_ATTR_METHOD → [method, self_or_null] → args → CALL // Regular call: func → PUSH_NULL → args → CALL if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = &func { + if !self.can_use_cpython_method_call(value, args) { + self.check_caller(func)?; + self.compile_expression(func)?; + self.set_source_range(func.range()); + emit!(self, Instruction::PushNull); + self.codegen_call_helper(0, args, call_range, None)?; + return Ok(()); + } + // Check for super() method call optimization if let Some(super_type) = self.can_optimize_super_call(value, attr.as_str()) { // super().method() or super(cls, self).method() optimization // CALL path: [global_super, class, self] → LOAD_SUPER_METHOD → [method, self] - // CALL_FUNCTION_EX path: [global_super, class, self] → LOAD_SUPER_ATTR → [attr] let ast::Expr::Call(ast::ExprCall { func: super_func, .. }) = value.as_ref() @@ -8599,41 +9587,21 @@ impl Compiler { func.range(), attr.as_str(), ); - self.set_source_range(attr_access_range); + self.set_source_range(func.range()); let idx = self.name(attr.as_str()); - if uses_ex_call { - self.set_source_range(func.range()); - match super_type { - SuperCallType::TwoArg { .. } => { - self.emit_load_super_attr(idx); - } - SuperCallType::ZeroArg => { - self.emit_load_zero_super_attr(idx); - } + match super_type { + SuperCallType::TwoArg { .. } => { + self.emit_load_super_method(idx); } - // CPython's Attribute_kind super path emits an attr-line - // NOP after LOAD_SUPER_ATTR, even when the call later uses - // CALL_FUNCTION_EX for starred arguments. - self.set_source_range(attr_access_range); - emit!(self, Instruction::Nop); - self.set_source_range(func.range()); - emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range, None)?; - } else { - match super_type { - SuperCallType::TwoArg { .. } => { - self.emit_load_super_method(idx); - } - SuperCallType::ZeroArg => { - self.emit_load_zero_super_method(idx); - } + SuperCallType::ZeroArg => { + self.emit_load_zero_super_method(idx); } - // NOP for line tracking at .method( line - self.set_source_range(attr_access_range); - emit!(self, Instruction::Nop); - // CALL at .method( line (not the full expression line) - self.codegen_call_helper(0, args, method_call_range, Some(attr_access_range))?; } + // NOP for line tracking at .method( line + self.set_source_range(attr_access_range); + emit!(self, Instruction::Nop); + // CALL at .method( line (not the full expression line) + self.compile_method_call_arguments(args, method_call_range, attr_access_range)?; } else { self.compile_expression(value)?; let idx = self.name(attr.as_str()); @@ -8648,28 +9616,15 @@ impl Compiler { attr.as_str(), ); self.set_source_range(attr_access_range); - // Imported names and CALL_FUNCTION_EX-style calls use plain - // LOAD_ATTR + PUSH_NULL; other names use method-call mode. - // Check current scope and enclosing scopes for IMPORTED flag. - let is_import = matches!(value.as_ref(), ast::Expr::Name(ast::ExprName { id, .. }) - if self.is_name_imported(id.as_str())); - if is_import || uses_ex_call { - self.emit_load_attr(idx); - emit!(self, Instruction::PushNull); - } else { - self.emit_load_attr_method(idx); - } - if is_import || uses_ex_call { - self.codegen_call_helper(0, args, call_range, None)?; - } else { - self.codegen_call_helper(0, args, method_call_range, Some(attr_access_range))?; - } + self.emit_load_attr_method(idx); + self.compile_method_call_arguments(args, method_call_range, attr_access_range)?; } } else if let Some(kind) = (!uses_ex_call) .then(|| self.detect_builtin_generator_call(func, args)) .flatten() { let skip_normal_call = self.new_block(); + self.check_caller(func)?; self.compile_expression(func)?; self.optimize_builtin_generator_call( kind, @@ -8692,6 +9647,7 @@ impl Compiler { .then(|| self.cpython_sync_genexpr_call_name(func, args)) .flatten() .is_some(); + self.check_caller(func)?; self.compile_expression(func)?; if sync_genexpr_call_name { // CPython `maybe_optimize_function_call()` creates and uses @@ -8703,6 +9659,7 @@ impl Compiler { .use_raw_instr_sequence_label(skip_optimization); unwrap_internal(self, result); } + self.set_source_range(func.range()); emit!(self, Instruction::PushNull); self.codegen_call_helper(0, args, call_range, None)?; let result = self @@ -8724,6 +9681,24 @@ impl Compiler { has_starred || has_double_star || too_big } + /// Compile subkwargs: emit key-value pairs for BUILD_MAP + fn validate_keywords(&mut self, keywords: &[ast::Keyword]) -> CompileResult<()> { + for (i, keyword) in keywords.iter().enumerate() { + let Some(arg) = &keyword.arg else { + continue; + }; + for other in &keywords[i + 1..] { + if other.arg.as_ref() == Some(arg) { + return Err(self.error_ranged( + CodegenErrorType::SyntaxError(format!("keyword argument repeated: {arg}")), + other.range, + )); + } + } + } + Ok(()) + } + /// Compile subkwargs: emit key-value pairs for BUILD_MAP fn codegen_subkwargs( &mut self, @@ -8739,8 +9714,8 @@ impl Compiler { let big = n * 2 > STACK_USE_GUIDELINE as usize; if big { - self.set_source_range(call_range); emit!(self, Instruction::BuildMap { count: 0 }); + self.set_no_location(); } for kw in &keywords[begin..end] { @@ -8752,8 +9727,8 @@ impl Compiler { self.compile_expression(&kw.value)?; if big { - self.set_source_range(call_range); emit!(self, Instruction::MapAdd { i: 1 }); + self.set_no_location(); } } @@ -8775,15 +9750,33 @@ impl Compiler { call_range: TextRange, kw_names_range: Option, ) -> CompileResult<()> { - let nelts = arguments.args.len(); - let nkwelts = arguments.keywords.len(); + self.codegen_call_helper_impl( + additional_positional, + &arguments.args, + &arguments.keywords, + call_range, + kw_names_range, + None, + ) + } + + fn codegen_call_helper_impl( + &mut self, + additional_positional: u32, + args: &[ast::Expr], + keywords: &[ast::Keyword], + call_range: TextRange, + kw_names_range: Option, + injected_arg: Option<&str>, + ) -> CompileResult<()> { + self.validate_keywords(keywords)?; + + let nelts = args.len(); + let nkwelts = keywords.len(); // Check if we have starred args or **kwargs - let has_starred = arguments - .args - .iter() - .any(|arg| matches!(arg, ast::Expr::Starred(_))); - let has_double_star = arguments.keywords.iter().any(|k| k.arg.is_none()); + let has_starred = args.iter().any(|arg| matches!(arg, ast::Expr::Starred(_))); + let has_double_star = keywords.iter().any(|k| k.arg.is_none()); // Check if exceeds CPython's stack-use guideline. // With CALL_KW, kwargs values go on stack but keys go in a const tuple, @@ -8794,22 +9787,29 @@ impl Compiler { // Simple call path: no * or ** args let implicit_generator_range = if additional_positional == 0 && nelts == 1 && nkwelts == 0 { - self.cpython_implicit_call_generator_range(&arguments.args[0]) + self.cpython_implicit_call_generator_range(&args[0]) } else { None }; - for arg in &arguments.args { + for arg in args { if let Some(range) = implicit_generator_range { self.compile_expression_with_generator_range(arg, range)?; } else { self.compile_expression(arg)?; } } + let injected_count = if let Some(injected_arg) = injected_arg { + self.set_source_range(call_range); + self.load_name(injected_arg)?; + 1 + } else { + 0 + }; if nkwelts > 0 { // Compile keyword values and build kwnames tuple let mut kwarg_names = Vec::with_capacity(nkwelts); - for keyword in &arguments.keywords { + for keyword in keywords { kwarg_names.push(ConstantData::Str { value: keyword.arg.as_ref().unwrap().as_str().into(), }); @@ -8823,24 +9823,23 @@ impl Compiler { }); self.set_source_range(call_range); - let argc = additional_positional + nelts.to_u32() + nkwelts.to_u32(); + let argc = + additional_positional + nelts.to_u32() + injected_count + nkwelts.to_u32(); emit!(self, Instruction::CallKw { argc }); } else { self.set_source_range(call_range); - let argc = additional_positional + nelts.to_u32(); + let argc = additional_positional + nelts.to_u32() + injected_count; emit!(self, Instruction::Call { argc }); } } else { // ex_call path: has * or ** args // Compile positional arguments - if additional_positional == 0 - && nelts == 1 - && matches!(arguments.args[0], ast::Expr::Starred(_)) + if additional_positional == 0 && nelts == 1 && matches!(args[0], ast::Expr::Starred(_)) { // Single starred arg: pass value directly to CallFunctionEx. // Runtime will convert to tuple and validate with function name. - if let ast::Expr::Starred(ast::ExprStarred { value, .. }) = &arguments.args[0] { + if let ast::Expr::Starred(ast::ExprStarred { value, .. }) = &args[0] { self.compile_expression(value)?; } } else { @@ -8850,14 +9849,15 @@ impl Compiler { // LIST_EXTEND, tuple=1)`, even when the only reason for the // ex-call path is too many non-starred positional arguments. self.set_source_range(call_range); - self.starunpack_helper( - &arguments.args, + self.starunpack_helper_impl( + args, + injected_arg, additional_positional, CollectionType::Tuple, )?; } - self.compile_call_function_ex_keywords(&arguments.keywords, call_range)?; + self.compile_call_function_ex_keywords(keywords, call_range)?; self.set_source_range(call_range); emit!(self, Instruction::CallFunctionEx); @@ -9106,26 +10106,83 @@ impl Compiler { Ok(()) } - fn consume_skipped_nested_scopes_in_expr( - &mut self, - expression: &ast::Expr, - ) -> CompileResult<()> { - use ast::visitor::Visitor; + fn consume_function_annotation_symbol_table_if_used(&mut self) -> CompileResult<()> { + if !self.next_function_annotation_symbol_table_uses_annotations() { + return Ok(()); + } + if !self.push_annotation_symbol_table() { + let current_table = self.current_symbol_table(); + return Err(self.error(CodegenErrorType::SyntaxError(format!( + "no annotation symbol table available in {} (type: {:?})", + current_table.name, current_table.typ + )))); + } + self.pop_annotation_symbol_table(); + Ok(()) + } + + fn consume_skipped_nested_scopes_in_expr( + &mut self, + expression: &ast::Expr, + ) -> CompileResult<()> { + use ast::visitor::Visitor; - struct SkippedScopeVisitor<'a> { - compiler: &'a mut Compiler, + struct SkippedScopeVisitor<'a, 'warnings> { + compiler: &'a mut Compiler<'warnings>, error: Option, } - impl SkippedScopeVisitor<'_> { + impl SkippedScopeVisitor<'_, '_> { fn consume_scope(&mut self) { if self.error.is_none() { self.error = self.compiler.consume_next_sub_table().err(); } } + + fn consume_inlined_comprehension_scope(&mut self) -> bool { + if self.error.is_some() { + return false; + } + let Some(current_table) = self.compiler.symbol_table_stack.last_mut() else { + return false; + }; + if current_table.next_inlined_comprehension_block + < current_table.inlined_comprehension_blocks.len() + { + current_table.next_inlined_comprehension_block += 1; + true + } else { + false + } + } + + fn visit_comprehension_tail( + &mut self, + elt1: &ast::Expr, + elt2: Option<&ast::Expr>, + generators: &[ast::Comprehension], + ) { + if let Some(outermost) = generators.first() { + self.visit_expr(&outermost.target); + for if_expr in &outermost.ifs { + self.visit_expr(if_expr); + } + } + for generator in generators.iter().skip(1) { + self.visit_expr(&generator.target); + self.visit_expr(&generator.iter); + for if_expr in &generator.ifs { + self.visit_expr(if_expr); + } + } + if let Some(elt2) = elt2 { + self.visit_expr(elt2); + } + self.visit_expr(elt1); + } } - impl ast::visitor::Visitor<'_> for SkippedScopeVisitor<'_> { + impl ast::visitor::Visitor<'_> for SkippedScopeVisitor<'_, '_> { fn visit_expr(&mut self, expr: &ast::Expr) { if self.error.is_some() { return; @@ -9149,19 +10206,45 @@ impl Compiler { } self.consume_scope(); } - ast::Expr::ListComp(ast::ExprListComp { generators, .. }) - | ast::Expr::SetComp(ast::ExprSetComp { generators, .. }) - | ast::Expr::Generator(ast::ExprGenerator { generators, .. }) => { + ast::Expr::Generator(ast::ExprGenerator { generators, .. }) => { if let Some(first) = generators.first() { self.visit_expr(&first.iter); } self.consume_scope(); } - ast::Expr::DictComp(ast::ExprDictComp { generators, .. }) => { + ast::Expr::ListComp(ast::ExprListComp { + elt, generators, .. + }) + | ast::Expr::SetComp(ast::ExprSetComp { + elt, generators, .. + }) => { if let Some(first) = generators.first() { self.visit_expr(&first.iter); } - self.consume_scope(); + if self.consume_inlined_comprehension_scope() { + self.visit_comprehension_tail(elt, None, generators); + } else { + self.consume_scope(); + } + } + ast::Expr::DictComp(ast::ExprDictComp { + key, + value, + generators, + .. + }) => { + if let Some(first) = generators.first() { + self.visit_expr(&first.iter); + } + if self.consume_inlined_comprehension_scope() { + if let Some(key) = key.as_deref() { + self.visit_comprehension_tail(key, Some(value), generators); + } else { + self.visit_comprehension_tail(value, None, generators); + } + } else { + self.consume_scope(); + } } _ => ast::visitor::walk_expr(self, expr), } @@ -9180,23 +10263,313 @@ impl Compiler { } } - fn peek_next_sub_table_after_skipped_nested_scopes_in_expr( + fn consume_skipped_nested_scopes_in_parameter_defaults( + &mut self, + parameters: &ast::Parameters, + ) -> CompileResult<()> { + for default in parameters + .posonlyargs + .iter() + .chain(¶meters.args) + .chain(¶meters.kwonlyargs) + .filter_map(|arg| arg.default.as_deref()) + { + self.consume_skipped_nested_scopes_in_expr(default)?; + } + Ok(()) + } + + fn consume_skipped_nested_scopes_in_statements( + &mut self, + statements: &[ast::Stmt], + ) -> CompileResult<()> { + use ast::visitor::Visitor; + + struct SkippedStatementScopeVisitor<'a, 'warnings> { + compiler: &'a mut Compiler<'warnings>, + error: Option, + } + + impl SkippedStatementScopeVisitor<'_, '_> { + fn consume_scope(&mut self) { + if self.error.is_none() { + self.error = self.compiler.consume_next_sub_table().err(); + } + } + + fn consume_function_annotation_scope_if_used(&mut self) { + if self.error.is_none() { + self.error = self + .compiler + .consume_function_annotation_symbol_table_if_used() + .err(); + } + } + + fn visit_parameter_defaults(&mut self, parameters: &ast::Parameters) { + for default in parameters + .posonlyargs + .iter() + .chain(¶meters.args) + .chain(¶meters.kwonlyargs) + .filter_map(|arg| arg.default.as_deref()) + { + self.visit_expr(default); + } + } + + fn visit_decorators(&mut self, decorators: &[ast::Decorator]) { + for decorator in decorators { + self.visit_expr(&decorator.expression); + } + } + + fn visit_arguments(&mut self, arguments: &ast::Arguments) { + for arg in &arguments.args { + self.visit_expr(arg); + } + for keyword in &arguments.keywords { + self.visit_expr(&keyword.value); + } + } + } + + impl ast::visitor::Visitor<'_> for SkippedStatementScopeVisitor<'_, '_> { + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + if self.error.is_some() { + return; + } + + match stmt { + ast::Stmt::FunctionDef(ast::StmtFunctionDef { + parameters, + decorator_list, + type_params, + .. + }) => { + self.visit_parameter_defaults(parameters); + self.visit_decorators(decorator_list); + if type_params.is_some() { + self.consume_scope(); + } else { + self.consume_function_annotation_scope_if_used(); + self.consume_scope(); + } + } + ast::Stmt::ClassDef(ast::StmtClassDef { + arguments, + decorator_list, + type_params, + .. + }) => { + self.visit_decorators(decorator_list); + if type_params.is_some() { + self.consume_scope(); + } + if let Some(arguments) = arguments { + self.visit_arguments(arguments); + } + self.consume_scope(); + } + ast::Stmt::TypeAlias(ast::StmtTypeAlias { type_params, .. }) => { + if type_params.is_some() { + self.consume_scope(); + } + self.consume_scope(); + } + ast::Stmt::AnnAssign(ast::StmtAnnAssign { target, value, .. }) => { + self.visit_expr(target); + if let Some(value) = value { + self.visit_expr(value); + } + } + ast::Stmt::If(ast::StmtIf { + test, + body, + elif_else_clauses, + .. + }) => { + self.visit_expr(test); + for stmt in body { + self.visit_stmt(stmt); + } + for clause in elif_else_clauses { + if let Some(test) = &clause.test { + self.visit_expr(test); + } + for stmt in &clause.body { + self.visit_stmt(stmt); + } + } + } + ast::Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + for stmt in body { + self.visit_stmt(stmt); + } + for handler in handlers { + self.visit_except_handler(handler); + } + for stmt in orelse { + self.visit_stmt(stmt); + } + for stmt in finalbody { + self.visit_stmt(stmt); + } + } + _ => ast::visitor::walk_stmt(self, stmt), + } + } + + fn visit_expr(&mut self, expr: &ast::Expr) { + if self.error.is_some() { + return; + } + self.error = self + .compiler + .consume_skipped_nested_scopes_in_expr(expr) + .err(); + } + + fn visit_except_handler(&mut self, handler: &ast::ExceptHandler) { + if self.error.is_some() { + return; + } + let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + type_, + body, + .. + }) = handler; + if let Some(type_) = type_ { + self.visit_expr(type_); + } + for stmt in body { + self.visit_stmt(stmt); + } + } + } + + let mut visitor = SkippedStatementScopeVisitor { + compiler: self, + error: None, + }; + for statement in statements { + visitor.visit_stmt(statement); + } + if let Some(err) = visitor.error { + Err(err) + } else { + Ok(()) + } + } + + fn consume_skipped_nested_scopes_in_except_handlers( + &mut self, + handlers: &[ast::ExceptHandler], + ) -> CompileResult<()> { + use ast::visitor::Visitor; + + struct SkippedHandlerScopeVisitor<'a, 'warnings> { + compiler: &'a mut Compiler<'warnings>, + error: Option, + } + + impl ast::visitor::Visitor<'_> for SkippedHandlerScopeVisitor<'_, '_> { + fn visit_expr(&mut self, expr: &ast::Expr) { + if self.error.is_some() { + return; + } + self.error = self + .compiler + .consume_skipped_nested_scopes_in_expr(expr) + .err(); + } + + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + if self.error.is_some() { + return; + } + self.error = self + .compiler + .consume_skipped_nested_scopes_in_statements(slice::from_ref(stmt)) + .err(); + } + } + + let mut visitor = SkippedHandlerScopeVisitor { + compiler: self, + error: None, + }; + for handler in handlers { + visitor.visit_except_handler(handler); + if visitor.error.is_some() { + break; + } + } + if let Some(err) = visitor.error { + Err(err) + } else { + Ok(()) + } + } + + fn current_symbol_table_cursors(&self) -> SymbolTableCursors { + let table = self + .symbol_table_stack + .last() + .expect("no current symbol table"); + SymbolTableCursors { + sub_table: table.next_sub_table, + hidden_annotation_block: table.next_hidden_annotation_block, + inlined_comprehension_block: table.next_inlined_comprehension_block, + } + } + + fn set_symbol_table_cursors(&mut self, cursors: SymbolTableCursors) { + let table = self + .symbol_table_stack + .last_mut() + .expect("no current symbol table"); + table.next_sub_table = cursors.sub_table; + table.next_hidden_annotation_block = cursors.hidden_annotation_block; + table.next_inlined_comprehension_block = cursors.inlined_comprehension_block; + } + + fn lookup_comprehension_symbol_table_after_skipped_nested_scopes_in_expr( &mut self, expression: &ast::Expr, - ) -> CompileResult { + comprehension_type: ComprehensionType, + ) -> CompileResult<(SymbolTable, ComprehensionSymbolSource)> { let saved_cursor = self .symbol_table_stack .last() .expect("no current symbol table") .next_sub_table; + let saved_inlined_cursor = self + .symbol_table_stack + .last() + .expect("no current symbol table") + .next_inlined_comprehension_block; let result = (|| { self.consume_skipped_nested_scopes_in_expr(expression)?; let current_table = self .symbol_table_stack .last() .expect("no current symbol table"); + if comprehension_type != ComprehensionType::Generator + && let Some(table) = current_table + .inlined_comprehension_blocks + .get(current_table.next_inlined_comprehension_block) + { + return Ok((table.clone(), ComprehensionSymbolSource::Inlined)); + } if let Some(table) = current_table.sub_tables.get(current_table.next_sub_table) { - Ok(table.clone()) + Ok((table.clone(), ComprehensionSymbolSource::Child)) } else { let name = current_table.name.clone(); let typ = current_table.typ; @@ -9209,6 +10582,10 @@ impl Compiler { .last_mut() .expect("no current symbol table") .next_sub_table = saved_cursor; + self.symbol_table_stack + .last_mut() + .expect("no current symbol table") + .next_inlined_comprehension_block = saved_inlined_cursor; result } @@ -9231,7 +10608,15 @@ impl Compiler { if let Some(info) = self.code_stack.last_mut() { info.flags = flags | (info.flags - & (CodeFlags::NESTED | CodeFlags::METHOD | CodeFlags::FUTURE_ANNOTATIONS)); + & (bytecode::CodeFlags::NESTED + | bytecode::CodeFlags::METHOD + | bytecode::CodeFlags::FUTURE_DIVISION + | bytecode::CodeFlags::FUTURE_ABSOLUTE_IMPORT + | bytecode::CodeFlags::FUTURE_WITH_STATEMENT + | bytecode::CodeFlags::FUTURE_PRINT_FUNCTION + | bytecode::CodeFlags::FUTURE_UNICODE_LITERALS + | bytecode::CodeFlags::FUTURE_GENERATOR_STOP + | bytecode::CodeFlags::FUTURE_ANNOTATIONS)); info.metadata.argcount = arg_count; info.metadata.posonlyargcount = posonlyarg_count; info.metadata.kwonlyargcount = kwonlyarg_count; @@ -9254,12 +10639,16 @@ impl Compiler { ) -> CompileResult<()> { let prev_ctx = self.ctx; let has_an_async_gen = generators.iter().any(|g| g.is_async); + let is_top_level_await_context = self.opts.allow_top_level_await + && prev_ctx.func == FunctionContext::NoFunction + && !prev_ctx.in_class; // Check for async comprehension outside async function (list/set/dict only, not generator expressions) // Use in_async_scope to allow nested async comprehensions inside an async function if comprehension_type != ComprehensionType::Generator && (has_an_async_gen || element_contains_await) && !prev_ctx.in_async_scope + && !is_top_level_await_context { return Err(self.error(CodegenErrorType::InvalidAsyncComprehension)); } @@ -9270,7 +10659,7 @@ impl Compiler { let is_async_list_set_dict_comprehension = comprehension_type != ComprehensionType::Generator && (has_an_async_gen || element_contains_await) - && prev_ctx.in_async_scope; + && (prev_ctx.in_async_scope || is_top_level_await_context); let is_async_generator_comprehension = comprehension_type == ComprehensionType::Generator && (has_an_async_gen || element_contains_await); @@ -9282,8 +10671,11 @@ impl Compiler { // We must have at least one generator: assert!(!generators.is_empty()); let outermost = &generators[0]; - let comp_table = - self.peek_next_sub_table_after_skipped_nested_scopes_in_expr(&outermost.iter)?; + let (comp_table, comp_source) = self + .lookup_comprehension_symbol_table_after_skipped_nested_scopes_in_expr( + &outermost.iter, + comprehension_type, + )?; let is_inlined = self.is_inlined_comprehension_context(comprehension_type, &comp_table); @@ -9299,6 +10691,7 @@ impl Compiler { generators, compile_element, (comprehension_range, element_range, outer_backedge_range), + comp_source, ); } @@ -9315,9 +10708,9 @@ impl Compiler { in_async_scope: prev_ctx.in_async_scope || is_async, }; - let flags = CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED; + let flags = bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED; let flags = if is_async { - flags | CodeFlags::COROUTINE + flags | bytecode::CodeFlags::COROUTINE } else { flags }; @@ -9459,9 +10852,6 @@ impl Compiler { is_async, end_async_for_target, } => { - self.set_source_range(backedge_range); - emit!(self, PseudoInstruction::Jump { delta: loop_block }); - self.use_cpython_label_block(if_cleanup_block); self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); @@ -9490,6 +10880,7 @@ impl Compiler { if return_none { self.emit_return_const_no_location(ConstantData::None); } else { + self.set_source_range(comprehension_range); self.emit_return_value(); } @@ -9506,6 +10897,7 @@ impl Compiler { emit!(self, Instruction::Reraise { depth: 1u32 }); self.set_no_location(); } + self.emit_return_const_no_location(ConstantData::None); let code = self.exit_scope(); @@ -9513,7 +10905,7 @@ impl Compiler { // Create comprehension function with closure self.set_source_range(comprehension_range); - self.make_closure(code, MakeFunctionFlags::new())?; + self.make_closure(code, bytecode::MakeFunctionFlags::new())?; // Evaluate iterated item and get its iterator. self.compile_comprehension_iter(outermost)?; @@ -9543,37 +10935,32 @@ impl Compiler { generators: &[ast::Comprehension], compile_element: &dyn Fn(&mut Self, usize) -> CompileResult<()>, ranges: (TextRange, TextRange, TextRange), + comp_source: ComprehensionSymbolSource, ) -> CompileResult<()> { let (comprehension_range, element_range, outer_backedge_range) = ranges; - fn collect_bound_names(target: &ast::Expr, out: &mut Vec) { - match target { - ast::Expr::Name(ast::ExprName { id, .. }) => out.push(id.to_string()), - ast::Expr::Tuple(ast::ExprTuple { elts, .. }) - | ast::Expr::List(ast::ExprList { elts, .. }) => { - for elt in elts { - collect_bound_names(elt, out); - } - } - ast::Expr::Starred(ast::ExprStarred { value, .. }) => { - collect_bound_names(value, out); - } - _ => {} - } - } - // Compile the outermost iterator first. Its expression may reference // nested scopes (e.g. lambdas) whose sub_tables sit at the current // position in the parent's list. Those must be consumed before we // splice in the comprehension's own children. self.compile_comprehension_iter(&generators[0])?; - self.symbol_table_stack - .last_mut() - .expect("no current symbol table") - .next_sub_table += 1; + match comp_source { + ComprehensionSymbolSource::Child => { + self.symbol_table_stack + .last_mut() + .expect("no current symbol table") + .next_sub_table += 1; + } + ComprehensionSymbolSource::Inlined => { + self.symbol_table_stack + .last_mut() + .expect("no current symbol table") + .next_inlined_comprehension_block += 1; + } + } let was_in_inlined_comp = self.current_code_info().in_inlined_comp; let saved_source_range = self.current_source_range; - let in_class_block = { + let tweak_in_class_block = { let ct = self.current_symbol_table(); ct.typ == CompilerScope::Class && !was_in_inlined_comp }; @@ -9583,9 +10970,13 @@ impl Compiler { let mut changed_fast_hidden = Vec::new(); let result = (|| { - // Splice the comprehension's children (e.g. nested inlined - // comprehensions) into the parent so the compiler can find them. - if !comp_table.sub_tables.is_empty() { + // If the symbol table still carries the inlined comprehension as + // a child, splice its children here. CPython's symtable normally + // performs this splice before codegen, and the Inlined source path + // has already done so. + if matches!(comp_source, ComprehensionSymbolSource::Child) + && !comp_table.sub_tables.is_empty() + { let current_table = self .symbol_table_stack .last_mut() @@ -9595,30 +10986,21 @@ impl Compiler { current_table.sub_tables.insert(insert_pos + i, st.clone()); } } - let mut source_order_bound_names = Vec::new(); - for generator in generators { - collect_bound_names(&generator.target, &mut source_order_bound_names); - } - let mut pushed_locals: Vec = Vec::new(); - for name in source_order_bound_names - .into_iter() - .chain(comp_table.symbols.keys().cloned()) - { - if pushed_locals.iter().any(|existing| existing == &name) { - continue; + let mut fast_hidden_locals: Vec = Vec::new(); + for (name, sym) in &comp_table.symbols { + if sym.flags.contains(SymbolFlags::PARAMETER) { + continue; // skip .0 } - if let Some(sym) = comp_table.symbols.get(&name) { - if sym.flags.contains(SymbolFlags::PARAMETER) { - continue; // skip .0 - } - let is_local = sym - .flags - .intersects(SymbolFlags::ASSIGNED | SymbolFlags::ITER) - && !sym.flags.contains(SymbolFlags::NONLOCAL); - if is_local { - pushed_locals.push(name); - } + let is_local = sym + .flags + .intersects(SymbolFlags::ASSIGNED | SymbolFlags::ITER) + && !sym.flags.contains(SymbolFlags::NONLOCAL); + if is_local { + pushed_locals.push(name.clone()); + } + if is_local || tweak_in_class_block { + fast_hidden_locals.push(name.clone()); } } @@ -9638,7 +11020,7 @@ impl Compiler { if (comp_scope != outer_scope && comp_scope != SymbolScope::Free && !(comp_scope == SymbolScope::Cell && outer_scope == SymbolScope::Free)) - || in_class_block + || tweak_in_class_block { temp_symbols.insert(name.clone(), outer_sym.clone()); let current_table = @@ -9648,7 +11030,7 @@ impl Compiler { } } if !self.ctx.in_func() { - for name in &pushed_locals { + for name in &fast_hidden_locals { if self .current_code_info() .metadata @@ -9953,12 +11335,28 @@ impl Compiler { // Python 3 features; we've already implemented them by default "nested_scopes" | "generators" | "division" | "absolute_import" | "with_statement" | "print_function" | "unicode_literals" | "generator_stop" => {} - "annotations" => self.future_annotations = true, - other => { + // Accept the CPython future feature name, but do not implement + // Barry-as-BDFL parser mode. + "barry_as_FLUFL" => {} + "annotations" => { + self.future_annotations = true; + self.future_features + .insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + } + "braces" => { return Err( - self.error(CodegenErrorType::InvalidFutureFeature(other.to_owned())) + self.error_ranged(CodegenErrorType::InvalidFutureBraces, feature.range) ); } + other => { + return Err(self.error_ranged( + CodegenErrorType::InvalidFutureFeature(other.to_owned()), + feature.range, + )); + } } } Ok(()) @@ -10012,22 +11410,18 @@ impl Compiler { } let instr = instr.into(); let opcode = AnyOpcode::from(instr); - debug_assert!( !instr.is_assembler(), "CPython codegen_addop_* must not emit assembler-only opcodes" ); - debug_assert!( opcode.has_arg() || instr.has_target() || u32::from(arg) == 0, "CPython _PyInstructionSequence_Addop requires either OPCODE_HAS_ARG, HAS_TARGET, or oparg == 0" ); - debug_assert!( target == BlockIdx::NULL || instr.has_target(), "CPython codegen_addop_j only accepts HAS_TARGET opcodes" ); - let range = self.current_source_range; let source = self.source_file.to_source_code(); let location = source.source_location(range.start(), PositionEncoding::Utf8); @@ -10111,7 +11505,6 @@ impl Compiler { .blocks .first_mut() .expect("code unit must have an entry block"); - debug_assert!( entry .used_instructions() @@ -10125,11 +11518,14 @@ impl Compiler { }), "scope entry must start with a function-start RESUME" ); - debug_assert!( !entry.used_instructions().iter().any(|info| matches!( - info.instr.real_opcode(), - Some(Opcode::ReturnGenerator | Opcode::MakeCell | Opcode::CopyFreeVars) + info.instr.real(), + Some( + Instruction::ReturnGenerator + | Instruction::MakeCell { .. } + | Instruction::CopyFreeVars { .. } + ) )), "CPython inserts StopIteration cleanup before CFG prefix instructions" ); @@ -10341,6 +11737,9 @@ impl Compiler { } fn try_fold_constant_expr(&mut self, expr: &ast::Expr) -> CompileResult> { + if let Some(constant) = self.public_ast_constant_override(expr) { + return Ok(Some(constant)); + } Ok(Some(match expr { ast::Expr::NumberLiteral(num) => match &num.value { ast::Number::Int(int) => ConstantData::Integer { @@ -10477,7 +11876,7 @@ impl Compiler { } (ast::UnaryOp::Not, ConstantData::Tuple { .. }) => return Ok(None), (ast::UnaryOp::Not, value) => ConstantData::Boolean { - value: !value.truthiness(), + value: !Self::constant_truthiness(&value), }, _ => return Ok(None), } @@ -10497,9 +11896,9 @@ impl Compiler { let mut selected = first; match op { ast::BoolOp::Or => { - if !selected.truthiness() { + if !Self::constant_truthiness(&selected) { for constant in iter { - let is_truthy = constant.truthiness(); + let is_truthy = Self::constant_truthiness(&constant); selected = constant; if is_truthy { break; @@ -10508,9 +11907,9 @@ impl Compiler { } } ast::BoolOp::And => { - if selected.truthiness() { + if Self::constant_truthiness(&selected) { for constant in iter { - let is_truthy = constant.truthiness(); + let is_truthy = Self::constant_truthiness(&constant); selected = constant; if !is_truthy { break; @@ -10552,24 +11951,45 @@ impl Compiler { })) } + fn try_compile_match_mapping_key_direct_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + if matches!( + expr, + ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) + ) { + return self.try_compile_ast_constant(expr); + } + self.try_compile_match_pattern_direct_literal(expr) + } + + fn is_unexpected_match_literal_constant(expr: &ast::Expr) -> bool { + matches!( + expr, + ast::Expr::BooleanLiteral(_) + | ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + ) + } + + fn try_compile_match_pattern_direct_literal( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + match expr { + ast::Expr::NumberLiteral(_) + | ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) => self.try_compile_ast_constant(expr), + _ => Ok(None), + } + } + fn try_negate_match_pattern_constant(constant: ConstantData) -> Option { match constant { ConstantData::Integer { value } => Some(ConstantData::Integer { value: -value }), ConstantData::Float { value } => Some(ConstantData::Float { value: -value }), ConstantData::Complex { value } => Some(ConstantData::Complex { value: -value }), - ConstantData::Boolean { value } => Some(ConstantData::Integer { - value: -BigInt::from(u8::from(value)), - }), - _ => None, - } - } - - fn constant_as_match_pattern_complex(constant: &ConstantData) -> Option> { - match constant { - ConstantData::Integer { value } => Some(Complex::new(value.to_f64()?, 0.0)), - ConstantData::Float { value } => Some(Complex::new(*value, 0.0)), - ConstantData::Complex { value } => Some(*value), - ConstantData::Boolean { value } => Some(Complex::new(f64::from(u8::from(*value)), 0.0)), _ => None, } } @@ -10579,51 +11999,20 @@ impl Compiler { left: &ConstantData, right: &ConstantData, ) -> Option { - if let (ConstantData::Integer { value: left }, ConstantData::Integer { value: right }) = - (left, right) - { - return match op { - ast::Operator::Add => Some(ConstantData::Integer { - value: left + right, - }), - ast::Operator::Sub => Some(ConstantData::Integer { - value: left - right, - }), - _ => None, - }; - } - - let left_is_complex = matches!(left, ConstantData::Complex { .. }); - let right_is_complex = matches!(right, ConstantData::Complex { .. }); - if left_is_complex || right_is_complex { - let left = Self::constant_as_match_pattern_complex(left)?; - let right = Self::constant_as_match_pattern_complex(right)?; - let value = match op { - ast::Operator::Add => Complex::new(left.re + right.re, left.im + right.im), - ast::Operator::Sub => { - let imag = if !left_is_complex && right_is_complex { - -right.im - } else { - left.im - right.im - }; - Complex::new(left.re - right.re, imag) - } - _ => return None, - }; - return Some(ConstantData::Complex { value }); - } - - let left = Self::constant_as_match_pattern_complex(left)?; - let right = Self::constant_as_match_pattern_complex(right)?; - match op { - ast::Operator::Add => Some(ConstantData::Float { - value: left.re + right.re, - }), - ast::Operator::Sub => Some(ConstantData::Float { - value: left.re - right.re, - }), - _ => None, - } + let left = match left { + ConstantData::Integer { value } => value.to_f64()?, + ConstantData::Float { value } => *value, + _ => return None, + }; + let ConstantData::Complex { value: right } = right else { + return None; + }; + let value = match op { + ast::Operator::Add => Complex::new(left + right.re, right.im), + ast::Operator::Sub => Complex::new(left - right.re, -right.im), + _ => return None, + }; + Some(ConstantData::Complex { value }) } fn try_fold_match_pattern_const_expr( @@ -10639,7 +12028,8 @@ impl Compiler { operand, .. }) => { - let Some(constant) = self.try_compile_ast_constant(operand)? else { + let Some(constant) = self.try_compile_match_pattern_number_constant(operand)? + else { return Ok(None); }; Self::try_negate_match_pattern_constant(constant) @@ -10647,13 +12037,10 @@ impl Compiler { ast::Expr::BinOp(ast::ExprBinOp { left, op, right, .. }) if matches!(op, ast::Operator::Add | ast::Operator::Sub) => { - let Some(left) = (match self.try_fold_match_pattern_const_expr(left)? { - Some(constant) => Some(constant), - None => self.try_compile_ast_constant(left)?, - }) else { + let Some(left) = self.try_compile_match_pattern_signed_real_constant(left)? else { return Ok(None); }; - let Some(right) = self.try_compile_ast_constant(right)? else { + let Some(right) = self.try_compile_match_pattern_imaginary_constant(right)? else { return Ok(None); }; Self::try_fold_match_pattern_binop(*op, &left, &right) @@ -10662,8 +12049,66 @@ impl Compiler { }) } + fn try_compile_match_pattern_signed_real_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + if let Some(constant) = self.try_compile_match_pattern_real_constant(expr)? { + return Ok(Some(constant)); + } + let ast::Expr::UnaryOp(ast::ExprUnaryOp { + op: ast::UnaryOp::USub, + operand, + .. + }) = expr + else { + return Ok(None); + }; + let Some(constant) = self.try_compile_match_pattern_real_constant(operand)? else { + return Ok(None); + }; + Ok(Self::try_negate_match_pattern_constant(constant)) + } + + fn try_compile_match_pattern_real_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + let Some(constant) = self.try_compile_match_pattern_number_constant(expr)? else { + return Ok(None); + }; + Ok(match constant { + ConstantData::Integer { .. } | ConstantData::Float { .. } => Some(constant), + _ => None, + }) + } + + fn try_compile_match_pattern_imaginary_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + let Some(constant) = self.try_compile_match_pattern_number_constant(expr)? else { + return Ok(None); + }; + Ok(match constant { + ConstantData::Complex { .. } => Some(constant), + _ => None, + }) + } + + fn try_compile_match_pattern_number_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + match expr { + ast::Expr::NumberLiteral(_) => self.try_compile_ast_constant(expr), + _ => Ok(None), + } + } + fn compile_match_pattern_expr(&mut self, expr: &ast::Expr) -> CompileResult<()> { if let Some(constant) = self.try_fold_match_pattern_const_expr(expr)? { + self.set_source_range(expr.range()); self.emit_load_const(constant); } else { self.compile_expression(expr)?; @@ -10685,7 +12130,7 @@ impl Compiler { if [lower, upper, step] .into_iter() .flatten() - .any(|expr| !expr.is_constant()) + .any(|expr| !self.is_constant_expr(expr)) { return Ok(None); } @@ -10792,6 +12237,12 @@ impl Compiler { emit!(self, Instruction::ReturnValue) } + fn allows_top_level_await_in_current_context(&self) -> bool { + self.opts.allow_top_level_await + && self.ctx.func == FunctionContext::NoFunction + && !self.ctx.in_class + } + fn current_code_info(&mut self) -> &mut ir::CodeInfo { self.code_stack.last_mut().expect("no code on stack") } @@ -10836,17 +12287,12 @@ impl Compiler { _ => {} } } - if !found_loop { - let err_type = if is_break { - CodegenErrorType::InvalidBreak - } else { - CodegenErrorType::InvalidContinue - }; - - return Err(self.error_ranged(err_type, range)); + if is_break { + return Err(self.error_ranged(CodegenErrorType::InvalidBreak, range)); + } + return Err(self.error_ranged(CodegenErrorType::InvalidContinue, range)); } - return Ok(()); } @@ -10878,13 +12324,19 @@ impl Compiler { debug_assert!(loop_fblock.fb_block.is_jump_target_label()); loop_fblock.fb_block }; - if let Some(loc) = unwind_loc { + let jump_is_artificial = if let Some(loc) = unwind_loc { self.set_source_range(loc); + false } else { - self.set_source_range(range); + true }; - self.emit_jump_label(PseudoOpcode::Jump, target_label); - if unwind_loc.is_none() { + self.emit_jump_label( + PseudoInstruction::Jump { + delta: OpArgMarker::marker(), + }, + target_label, + ); + if jump_is_artificial { self.set_no_location(); } self.set_source_range(prev_source_range); @@ -10900,7 +12352,7 @@ impl Compiler { let code = self.current_code_info(); let cur = code.current_block; if !code.blocks[cur.idx()] - .instructions + .used_instructions() .last() .is_some_and(|instr| instr.instr.is_terminator()) { @@ -11081,7 +12533,7 @@ impl Compiler { if source.line_index(loc_range.start()) == source.line_index(attr_range.end()) { return loc_range; } - let Ok(attr_len) = u32::try_from(attr.len()) else { + let Ok(attr_len) = u32::try_from(attr.chars().count()) else { return TextRange::new(loc_range.start(), loc_range.end()); }; let attr_len = TextSize::new(attr_len); @@ -11113,10 +12565,10 @@ impl Compiler { let is_async = self.ctx.func == FunctionContext::AsyncFunction; let flags = &mut self.current_code_info().flags; if is_async { - flags.remove(CodeFlags::COROUTINE); - flags.insert(CodeFlags::ASYNC_GENERATOR); + flags.remove(bytecode::CodeFlags::COROUTINE); + flags.insert(bytecode::CodeFlags::ASYNC_GENERATOR); } else { - flags.insert(CodeFlags::GENERATOR); + flags.insert(bytecode::CodeFlags::GENERATOR); } } @@ -11184,7 +12636,7 @@ impl Compiler { let fstring_range = fstring.range; let fstring = fstring.value.as_slice(); if self.count_fstring_parts(fstring) > STACK_USE_GUIDELINE { - return self.compile_fstring_parts_joined(fstring); + return self.compile_fstring_parts_joined(fstring, fstring_range); } let mut element_count = 0; @@ -11198,7 +12650,7 @@ impl Compiler { &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, - false, + None, )?; } self.finish_fstring( @@ -11211,7 +12663,55 @@ impl Compiler { Ok(()) } - fn compile_fstring_parts_joined(&mut self, fstring: &[ast::FStringPart]) -> CompileResult<()> { + fn compile_public_ast_joined_str( + &mut self, + fstring: &ast::ExprFString, + joined_str: PublicAstExprList, + ) -> CompileResult<()> { + let range = fstring.range; + let values = joined_str.values; + let value_count: u32 = values + .len() + .try_into() + .expect("JoinedStr value count overflowed"); + if value_count > STACK_USE_GUIDELINE { + self.set_source_range(range); + self.emit_load_const(ConstantData::Str { + value: Wtf8Buf::new(), + }); + let join_idx = self.get_global_name_index("join"); + self.emit_load_attr_method(join_idx); + emit!(self, Instruction::BuildList { count: 0 }); + for value in &values { + self.compile_expression(value)?; + self.set_source_range(range); + emit!(self, Instruction::ListAppend { i: 1 }); + } + self.set_source_range(range); + emit!(self, Instruction::Call { argc: 1 }); + } else { + for value in &values { + self.compile_expression(value)?; + } + if value_count > 1 { + self.set_source_range(range); + emit!(self, Instruction::BuildString { count: value_count }); + } else if value_count == 0 { + self.set_source_range(range); + self.emit_load_const(ConstantData::Str { + value: Wtf8Buf::new(), + }); + } + } + Ok(()) + } + + fn compile_fstring_parts_joined( + &mut self, + fstring: &[ast::FStringPart], + fstring_range: TextRange, + ) -> CompileResult<()> { + self.set_source_range(fstring_range); self.emit_load_const(ConstantData::Str { value: Wtf8Buf::new(), }); @@ -11230,7 +12730,7 @@ impl Compiler { &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, - true, + Some(fstring_range), )?; } self.finish_fstring_join( @@ -11238,6 +12738,7 @@ impl Compiler { pending_literal_range, pending_literal_no_location, element_count, + fstring_range, ); Ok(()) } @@ -11249,7 +12750,7 @@ impl Compiler { pending_literal_range: &mut Option, pending_literal_no_location: &mut bool, element_count: &mut u32, - append_to_join_list: bool, + join_append_range: Option, ) -> CompileResult<()> { match part { ast::FStringPart::Literal(string) => { @@ -11271,7 +12772,7 @@ impl Compiler { pending_literal, (pending_literal_range, pending_literal_no_location), element_count, - append_to_join_list, + join_append_range, ), } } @@ -11291,7 +12792,7 @@ impl Compiler { &mut pending_literal_no_location, &mut element_count, keep_empty, - false, + None, ); if element_count == 0 { @@ -11320,6 +12821,7 @@ impl Compiler { mut pending_literal_range: Option, mut pending_literal_no_location: bool, mut element_count: u32, + fstring_range: TextRange, ) { let keep_empty = element_count == 0; self.emit_pending_fstring_literal( @@ -11328,8 +12830,9 @@ impl Compiler { &mut pending_literal_no_location, &mut element_count, keep_empty, - true, + Some(fstring_range), ); + self.set_source_range(fstring_range); emit!(self, Instruction::Call { argc: 1 }); } @@ -11340,7 +12843,7 @@ impl Compiler { pending_literal_no_location: &mut bool, element_count: &mut u32, keep_empty: bool, - append_to_join_list: bool, + join_append_range: Option, ) { let Some(value) = pending_literal.take() else { return; @@ -11364,7 +12867,8 @@ impl Compiler { self.set_no_location(); } *element_count += 1; - if append_to_join_list { + if let Some(join_append_range) = join_append_range { + self.set_source_range(join_append_range); emit!(self, Instruction::ListAppend { i: 1 }); } } @@ -11439,7 +12943,8 @@ impl Compiler { fstring_range: Option, ) -> CompileResult<()> { if self.count_fstring_elements(flags, fstring_elements) > STACK_USE_GUIDELINE { - return self.compile_fstring_elements_joined(flags, fstring_elements); + let fstring_range = fstring_range.unwrap_or(self.current_source_range); + return self.compile_fstring_elements_joined(flags, fstring_elements, fstring_range); } let mut element_count = 0; @@ -11452,7 +12957,7 @@ impl Compiler { &mut pending_literal, (&mut pending_literal_range, &mut pending_literal_no_location), &mut element_count, - false, + None, )?; self.finish_fstring( pending_literal, @@ -11468,7 +12973,9 @@ impl Compiler { &mut self, flags: ast::FStringFlags, fstring_elements: &ast::InterpolatedStringElements, + fstring_range: TextRange, ) -> CompileResult<()> { + self.set_source_range(fstring_range); self.emit_load_const(ConstantData::Str { value: Wtf8Buf::new(), }); @@ -11486,13 +12993,14 @@ impl Compiler { &mut pending_literal, (&mut pending_literal_range, &mut pending_literal_no_location), &mut element_count, - true, + Some(fstring_range), )?; self.finish_fstring_join( pending_literal, pending_literal_range, pending_literal_no_location, element_count, + fstring_range, ); Ok(()) } @@ -11517,7 +13025,7 @@ impl Compiler { pending_literal: &mut Option, pending_literal_meta: (&mut Option, &mut bool), element_count: &mut u32, - append_to_join_list: bool, + join_append_range: Option, ) -> CompileResult<()> { let (pending_literal_range, pending_literal_no_location) = pending_literal_meta; for element in fstring_elements { @@ -11542,7 +13050,18 @@ impl Compiler { ast::ConversionFlag::Ascii => ConvertValueOparg::Ascii, }; - if let Some(ast::DebugText { leading, trailing }) = &fstring_expr.debug_text { + if let Some(debug_text) = &fstring_expr.debug_text { + let leading = debug_text.leading(); + let trailing = debug_text.trailing(); + self.emit_pending_fstring_literal( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + false, + join_append_range, + ); + let range = fstring_expr.expression.range(); let leading = strip_fstring_debug_comments(leading); let trailing = strip_fstring_debug_comments(trailing); @@ -11562,16 +13081,9 @@ impl Compiler { ); let text: Wtf8Buf = text.into(); - if pending_literal.is_none() { - *pending_literal_range = Some(debug_text_range); - *pending_literal_no_location = false; - *pending_literal = Some(Wtf8Buf::new()); - } else { - Self::extend_pending_literal_range( - pending_literal_range, - debug_text_range, - ); - } + *pending_literal_range = Some(debug_text_range); + *pending_literal_no_location = false; + *pending_literal = Some(Wtf8Buf::new()); pending_literal.as_mut().unwrap().push_wtf8(text.as_ref()); // If debug text is present, apply repr conversion when no `format_spec` specified. @@ -11590,7 +13102,7 @@ impl Compiler { pending_literal_no_location, element_count, false, - append_to_join_list, + join_append_range, ); self.compile_expression(&fstring_expr.expression)?; @@ -11606,27 +13118,40 @@ impl Compiler { } } - match &fstring_expr.format_spec { - Some(format_spec) => { - let format_spec_range = - self.cpython_format_spec_range(format_spec.range); - self.compile_fstring_elements( - flags, - &format_spec.elements, - Some(format_spec_range), - )?; + if let Some(formatted_value) = self + .public_ast_formatted_value_override_by_index( + fstring_expr.node_index.load(), + ) + && let Some(format_spec) = &formatted_value.format_spec + { + self.compile_expression(format_spec)?; - self.set_source_range(formatted_value_range); - emit!(self, Instruction::FormatWithSpec); - } - None => { - self.set_source_range(formatted_value_range); - emit!(self, Instruction::FormatSimple); + self.set_source_range(formatted_value_range); + emit!(self, Instruction::FormatWithSpec); + } else { + match &fstring_expr.format_spec { + Some(format_spec) => { + let format_spec_range = + self.cpython_format_spec_range(format_spec.range); + self.compile_fstring_elements( + flags, + &format_spec.elements, + Some(format_spec_range), + )?; + + self.set_source_range(formatted_value_range); + emit!(self, Instruction::FormatWithSpec); + } + None => { + self.set_source_range(formatted_value_range); + emit!(self, Instruction::FormatSimple); + } } } *element_count += 1; - if append_to_join_list { + if let Some(join_append_range) = join_append_range { + self.set_source_range(join_append_range); emit!(self, Instruction::ListAppend { i: 1 }); } } @@ -11672,7 +13197,10 @@ impl Compiler { } } ast::InterpolatedStringElement::Interpolation(fstring_expr) => { - if let Some(ast::DebugText { leading, trailing }) = &fstring_expr.debug_text { + if let Some(debug_text) = &fstring_expr.debug_text { + let leading = debug_text.leading(); + let trailing = debug_text.trailing(); + Self::count_pending_fstring_literal(pending_literal, element_count, false); let range = fstring_expr.expression.range(); let source = self.source_file.slice(range); let text = [ @@ -11683,9 +13211,9 @@ impl Compiler { .concat(); let text: Wtf8Buf = text.into(); - pending_literal - .get_or_insert_with(Wtf8Buf::new) - .push_wtf8(text.as_ref()); + let mut debug_text = Wtf8Buf::new(); + debug_text.push_wtf8(text.as_ref()); + *pending_literal = Some(debug_text); } Self::count_pending_fstring_literal(pending_literal, element_count, false); @@ -11701,8 +13229,9 @@ impl Compiler { // strings tuple first, then evaluating interpolations left-to-right. let tstring_value = &expr_tstring.value; - let mut all_strings: Vec = Vec::new(); + let mut all_strings: Vec<(Wtf8Buf, TextRange)> = Vec::new(); let mut current_string = Wtf8Buf::new(); + let mut current_string_range = None; let mut interp_count: u32 = 0; for tstring in tstring_value { @@ -11710,19 +13239,26 @@ impl Compiler { tstring, &mut all_strings, &mut current_string, + &mut current_string_range, &mut interp_count, + expr_tstring.range, ); } - all_strings.push(core::mem::take(&mut current_string)); + all_strings.push(( + core::mem::take(&mut current_string), + current_string_range.unwrap_or(expr_tstring.range), + )); let string_count: u32 = all_strings .len() .try_into() .expect("t-string string count overflowed"); - for s in &all_strings { + for (s, range) in &all_strings { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: s.clone() }); } + self.set_source_range(expr_tstring.range); emit!( self, Instruction::BuildTuple { @@ -11734,32 +13270,169 @@ impl Compiler { self.compile_tstring_interpolations(tstring)?; } + self.set_source_range(expr_tstring.range); emit!( self, Instruction::BuildTuple { count: interp_count } ); + self.set_source_range(expr_tstring.range); emit!(self, Instruction::BuildTemplate); Ok(()) } + fn compile_public_ast_template_str( + &mut self, + expr_tstring: &ast::ExprTString, + template_str: PublicAstExprList, + ) -> CompileResult<()> { + let values = template_str.values; + let mut last_was_interpolation = true; + let mut strings_len = 0; + for value in &values { + if self + .public_ast_template_value_interpolation(value) + .is_some() + { + if last_was_interpolation { + self.set_source_range(expr_tstring.range); + self.emit_load_const(ConstantData::Str { + value: Wtf8Buf::new(), + }); + strings_len += 1; + } + last_was_interpolation = true; + } else { + self.compile_expression(value)?; + strings_len += 1; + last_was_interpolation = false; + } + } + if last_was_interpolation { + self.set_source_range(expr_tstring.range); + self.emit_load_const(ConstantData::Str { + value: Wtf8Buf::new(), + }); + strings_len += 1; + } + self.set_source_range(expr_tstring.range); + emit!(self, Instruction::BuildTuple { count: strings_len }); + + let mut interpolations_len = 0; + for value in &values { + if let Some((tstring, interpolation)) = + self.public_ast_template_value_interpolation(value) + { + self.compile_public_ast_interpolation(tstring, interpolation)?; + interpolations_len += 1; + } + } + self.set_source_range(expr_tstring.range); + emit!( + self, + Instruction::BuildTuple { + count: interpolations_len + } + ); + self.set_source_range(expr_tstring.range); + emit!(self, Instruction::BuildTemplate); + Ok(()) + } + + fn public_ast_template_value_interpolation<'a>( + &self, + value: &'a ast::Expr, + ) -> Option<(&'a ast::ExprTString, PublicAstInterpolation)> { + let ast::Expr::TString(tstring) = value else { + return None; + }; + let interpolation = self.public_ast_interpolation_override(tstring)?; + Self::single_tstring_interpolation(tstring)?; + Some((tstring, interpolation)) + } + + fn compile_public_ast_interpolation( + &mut self, + expr_tstring: &ast::ExprTString, + interpolation: PublicAstInterpolation, + ) -> CompileResult { + let Some(interp) = Self::single_tstring_interpolation(expr_tstring) else { + return Ok(false); + }; + + self.compile_interpolation(interp, interpolation)?; + Ok(true) + } + + fn compile_interpolation( + &mut self, + interp: &ast::InterpolatedElement, + interpolation: PublicAstInterpolation, + ) -> CompileResult<()> { + self.compile_expression(&interp.expression)?; + self.set_source_range(interp.range); + self.emit_load_const(interpolation.str); + + let conversion = match interp.conversion { + ast::ConversionFlag::None => 0, + ast::ConversionFlag::Str => 1, + ast::ConversionFlag::Repr => 2, + ast::ConversionFlag::Ascii => 3, + }; + + let has_format_spec = interpolation.format_spec.is_some(); + if let Some(format_spec) = &interpolation.format_spec { + self.compile_expression(format_spec)?; + } + + let format = 2 | (conversion << 2) | u32::from(has_format_spec); + self.set_source_range(interp.range); + emit!(self, Instruction::BuildInterpolation { format }); + Ok(()) + } + + fn single_tstring_interpolation( + expr_tstring: &ast::ExprTString, + ) -> Option<&ast::InterpolatedElement> { + let [tstring] = expr_tstring.value.as_slice() else { + return None; + }; + let mut elements = tstring.elements.iter(); + let ast::InterpolatedStringElement::Interpolation(interp) = elements.next()? else { + return None; + }; + if elements.next().is_some() { + return None; + } + Some(interp) + } + fn collect_tstring_strings( &self, tstring: &ast::TString, - strings: &mut Vec, + strings: &mut Vec<(Wtf8Buf, TextRange)>, current_string: &mut Wtf8Buf, + current_string_range: &mut Option, interp_count: &mut u32, + template_range: TextRange, ) { for element in &tstring.elements { match element { ast::InterpolatedStringElement::Literal(lit) => { + if current_string_range.is_none() { + *current_string_range = Some(lit.range); + } else { + Self::extend_pending_literal_range(current_string_range, lit.range); + } current_string .push_wtf8(&self.compile_tstring_literal_value(lit, tstring.flags)); } ast::InterpolatedStringElement::Interpolation(interp) => { - if let Some(ast::DebugText { leading, trailing }) = &interp.debug_text { + if let Some(debug_text) = &interp.debug_text { + let leading = debug_text.leading(); + let trailing = debug_text.trailing(); let range = interp.expression.range(); let source = self.source_file.slice(range); let text = [ @@ -11768,9 +13441,37 @@ impl Compiler { strip_fstring_debug_comments(trailing).as_str(), ] .concat(); + let debug_text_range = TextRange::new( + range.start() + - TextSize::new( + u32::try_from(leading.len()) + .expect("debug t-string leading text too long"), + ), + range.end() + + TextSize::new( + u32::try_from(trailing.len()) + .expect("debug t-string trailing text too long"), + ), + ); + if current_string_range.is_none() { + *current_string_range = Some(debug_text_range); + } else { + Self::extend_pending_literal_range( + current_string_range, + debug_text_range, + ); + } current_string.push_str(&text); + strings.push(( + core::mem::take(current_string), + current_string_range.take().unwrap_or(template_range), + )); + } else { + strings.push(( + core::mem::take(current_string), + current_string_range.take().unwrap_or(template_range), + )); } - strings.push(core::mem::take(current_string)); *interp_count += 1; } } @@ -11783,6 +13484,13 @@ impl Compiler { continue; }; + if let Some(interpolation) = + self.public_ast_interpolation_override_by_index(interp.node_index.load()) + { + self.compile_interpolation(interp, interpolation)?; + continue; + } + self.compile_expression(&interp.expression)?; let expr_range = interp.expression.range(); @@ -11794,9 +13502,11 @@ impl Compiler { .slice(TextRange::new(after_brace, expr_range.end())) } else { self.source_file.slice(expr_range) - }; + } + .to_string(); + self.set_source_range(interp.range); self.emit_load_const(ConstantData::Str { - value: expr_source.to_string().into(), + value: expr_source.into(), }); let mut conversion: u32 = match interp.conversion { @@ -11812,16 +13522,18 @@ impl Compiler { let has_format_spec = interp.format_spec.is_some(); if let Some(format_spec) = &interp.format_spec { + let format_spec_range = self.cpython_format_spec_range(format_spec.range); self.compile_fstring_elements( ast::FStringFlags::empty(), &format_spec.elements, - Some(format_spec.range), + Some(format_spec_range), )?; } // CPython keeps bit 1 set in BUILD_INTERPOLATION's oparg and uses // bit 0 for the optional format spec. let format = 2 | (conversion << 2) | u32::from(has_format_spec); + self.set_source_range(interp.range); emit!(self, Instruction::BuildInterpolation { format }); } @@ -11922,10 +13634,10 @@ fn expandtabs(input: &str, tab_size: usize) -> String { expanded_str } -fn split_doc_with_range( - body: &[ast::Stmt], - opts: CompileOpts, -) -> (Option<(String, TextRange)>, &[ast::Stmt]) { +fn split_doc_with_range<'a>( + body: &'a [ast::Stmt], + opts: &CompileOpts, +) -> (Option<(String, TextRange)>, &'a [ast::Stmt]) { if let Some((ast::Stmt::Expr(expr), body_rest)) = body.split_first() { let doc_comment = match &*expr.value { ast::Expr::StringLiteral(value) => Some((&value.value, expr.value.range())), @@ -11945,7 +13657,7 @@ fn split_doc_with_range( } #[cfg(test)] -fn split_doc(body: &[ast::Stmt], opts: CompileOpts) -> (Option, &[ast::Stmt]) { +fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, &'a [ast::Stmt]) { let (doc, body) = split_doc_with_range(body, opts); (doc.map(|(doc, _)| doc), body) } @@ -12154,14 +13866,16 @@ mod tests { use super::*; use rustpython_compiler_core::{ SourceFileBuilder, - bytecode::{CodeUnit, OpArg}, + bytecode::{CO_FAST_ARG_KW, CO_FAST_ARG_POS, CodeUnit, OpArg}, }; fn assert_scope_exit_locations(code: &CodeObject) { for (instr, (location, _)) in code.instructions.iter().zip(code.locations.iter()) { if matches!( - instr.op.into(), - Opcode::ReturnValue | Opcode::RaiseVarargs | Opcode::Reraise + instr.op, + Instruction::ReturnValue + | Instruction::RaiseVarargs { .. } + | Instruction::Reraise { .. } ) { assert!( location.line.get() > 0, @@ -12201,7 +13915,7 @@ mod tests { compile_exec_with_options(source, opts) } - fn compile_exec_with_options(source: &str, opts: CompileOpts) -> CodeObject { + fn compile_exec_with_options(source: &str, mut opts: CompileOpts) -> CodeObject { let source_file = SourceFileBuilder::new("source_path", source).finish(); let parsed = ruff_python_parser::parse( source_file.source_text(), @@ -12209,28 +13923,844 @@ mod tests { ) .unwrap(); let mut ast = parsed.into_syntax(); - preprocess::preprocess_mod(&mut ast); + opts.future_features |= checked_future_features(&ast, &source_file).unwrap(); + let future_annotations = opts + .future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + preprocess::preprocess_mod(&mut ast, opts.optimize, future_annotations, false); let ast = match ast { ruff_python_ast::Mod::Module(stmts) => stmts, _ => unreachable!(), }; - let symbol_table = SymbolTable::scan_program(&ast, source_file.clone()) - .map_err(|e| e.into_codegen_error(source_file.name().to_owned())) - .unwrap(); - let mut compiler = Compiler::new(opts, source_file, ""); + let symbol_table = SymbolTable::scan_program_with_options( + &ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned())) + .unwrap(); + let mut compiler = + Compiler::new_with_syntax_warning_handler(opts, source_file, "", None); compiler.compile_program(&ast, symbol_table).unwrap(); compiler.exit_scope() } - #[test] - fn empty_module_implicit_return_inherits_resume_location_like_cpython() { - let code = compile_exec(""); - // CPython 3.14 codegen emits the implicit LOAD_CONST/RETURN_VALUE with - // NO_LOCATION, then flowgraph.c::propagate_line_numbers() propagates - // the module RESUME location, whose line is 0. + fn compile_module_instruction_infos(source: &str, mode: Mode) -> Vec { + let mut opts = CompileOpts::default(); + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let parsed = ruff_python_parser::parse( + source_file.source_text(), + ruff_python_parser::Mode::Module.into(), + ) + .unwrap(); + let mut ast = parsed.into_syntax(); + opts.future_features |= checked_future_features(&ast, &source_file).unwrap(); + let future_annotations = opts + .future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + if matches!(mode, Mode::Single) + && let ruff_python_ast::Mod::Module(module) = &mut ast + { + preprocess::preprocess_statements( + &mut module.body, + opts.optimize, + future_annotations, + false, + ); + } else { + preprocess::preprocess_mod(&mut ast, opts.optimize, future_annotations, false); + } + let ast = match ast { + ruff_python_ast::Mod::Module(stmts) => stmts, + _ => unreachable!(), + }; + let symbol_table = SymbolTable::scan_program_with_options( + &ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned())) + .unwrap(); + let mut compiler = + Compiler::new_with_syntax_warning_handler(opts, source_file, "", None); + match mode { + Mode::Single => compiler.compile_program_single(&ast.body, symbol_table), + _ => compiler.compile_program(&ast, symbol_table), + } + .unwrap(); + + compiler + .current_code_info() + .blocks + .iter() + .flat_map(|block| block.used_instructions().iter().copied()) + .collect() + } + + fn compile_eval_ast_with_options(expr: ast::Expr, opts: CompileOpts) -> CodeObject { + let source_file = SourceFileBuilder::new("source_path", "").finish(); + let parsed = ruff_python_ast::Mod::Expression(ast::ModExpression { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + body: Box::new(expr), + }); + compile_top(parsed, source_file, Mode::Eval, opts).unwrap() + } + + fn compile_public_constant_expr(expr: ast::Expr, constant: ConstantData) -> CodeObject { + let index = ast::NodeIndex::from(0); + ast::HasNodeIndex::node_index(&expr).set(index); + let mut ast_constant_overrides = PublicAstNodeMap::new(); + ast_constant_overrides.insert(index, constant); + compile_eval_ast_with_options( + expr, + CompileOpts { + ast_constant_overrides: Some(Arc::new(ast_constant_overrides)), + ..CompileOpts::default() + }, + ) + } + + fn first_public_constant_warning( + expr: ast::Expr, + index: ast::NodeIndex, + constant: ConstantData, + ) -> String { + let mut ast_constant_overrides = PublicAstNodeMap::new(); + ast_constant_overrides.insert(index, constant); + let opts = CompileOpts { + ast_constant_overrides: Some(Arc::new(ast_constant_overrides)), + ..CompileOpts::default() + }; + let source_file = SourceFileBuilder::new("source_path", "").finish(); + let parsed = ruff_python_ast::Mod::Expression(ast::ModExpression { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + body: Box::new(expr), + }); + let mut warning = None; + let mut handler = |location, message: String| { + warning = Some(message.clone()); + Err(CodegenError { + location: Some(location), + error: CodegenErrorType::SyntaxError(message), + source_path: "source_path".to_owned(), + }) + }; + let _ = compile_top_with_syntax_warning_handler( + parsed, + source_file, + Mode::Eval, + opts, + Some(&mut handler), + ) + .expect_err("expected SyntaxWarning handler to stop compilation"); + warning.expect("expected warning message") + } + + fn first_exec_warning(source: &str) -> String { + let opts = CompileOpts::default(); + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let parsed = ruff_python_parser::parse( + source_file.source_text(), + ruff_python_parser::Mode::Module.into(), + ) + .unwrap() + .into_syntax(); + let mut warning = None; + let mut handler = |location, message: String| { + warning = Some(message.clone()); + Err(CodegenError { + location: Some(location), + error: CodegenErrorType::SyntaxError(message), + source_path: "source_path".to_owned(), + }) + }; + let _ = compile_top_with_syntax_warning_handler( + parsed, + source_file, + Mode::Exec, + opts, + Some(&mut handler), + ) + .expect_err("expected SyntaxWarning handler to stop compilation"); + warning.expect("expected warning message") + } + + fn frozenset_call_expr() -> ast::Expr { + ast::Expr::Call(ast::ExprCall { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + func: Box::new(ast::Expr::Name(ast::ExprName { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + id: ast::name::Name::new_static("frozenset"), + ctx: ast::ExprContext::Load, + })), + arguments: ast::Arguments { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + args: Box::default(), + keywords: Default::default(), + }, + }) + } + + fn compile_exec_parsed_error( + source: &str, + parsed: ruff_python_parser::Parsed, + ) -> CodegenError { + let mut opts = CompileOpts::default(); + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let mut ast = parsed.into_syntax(); + opts.future_features |= match checked_future_features(&ast, &source_file) { + Ok(features) => features, + Err(err) => return err, + }; + let future_annotations = opts + .future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + preprocess::preprocess_mod(&mut ast, opts.optimize, future_annotations, false); + let ast = match ast { + ruff_python_ast::Mod::Module(stmts) => stmts, + _ => unreachable!(), + }; + let symbol_table = match SymbolTable::scan_program(&ast, source_file.clone()) { + Ok(symbol_table) => symbol_table, + Err(err) => return err.into_codegen_error(source_file.name().to_owned()), + }; + let mut compiler = + Compiler::new_with_syntax_warning_handler(opts, source_file, "", None); + compiler.compile_program(&ast, symbol_table).unwrap_err() + } + + fn compile_exec_error(source: &str) -> CodegenError { + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let parsed = ruff_python_parser::parse( + source_file.source_text(), + ruff_python_parser::Mode::Module.into(), + ) + .unwrap(); + compile_exec_parsed_error(source, parsed) + } + + fn compile_exec_error_message(source: &str) -> String { + compile_exec_error(source).error.to_string() + } + + fn compile_exec_unchecked_error_message(source: &str) -> String { + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let parsed = ruff_python_parser::parse_unchecked( + source_file.source_text(), + ruff_python_parser::Mode::Module.into(), + ); + compile_exec_parsed_error(source, parsed).error.to_string() + } + + #[test] + fn public_ast_frozenset_constant_compiles_as_load_const() { + let code = compile_public_constant_expr( + frozenset_call_expr(), + ConstantData::Frozenset { + elements: vec![ConstantData::Integer { + value: BigInt::from(1u8), + }], + }, + ); + let ops: Vec<_> = code + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.iter() + .any(|op| matches!(op, Instruction::LoadConst { .. })), + "public ast.Constant(frozenset(...)) must use CPython Constant_kind LOAD_CONST path, got {ops:?}" + ); + assert!( + !ops.iter().any(|op| matches!( + op, + Instruction::LoadName { .. } + | Instruction::Call { .. } + | Instruction::CallKw { .. } + )), + "public ast.Constant(frozenset(...)) must not compile as a frozenset() call, got {ops:?}" + ); + assert!( + code.constants.iter().any(|constant| matches!( + constant, + ConstantData::Frozenset { elements } + if matches!( + elements.as_slice(), + [ConstantData::Integer { value }] if *value == BigInt::from(1u8) + ) + )), + "missing frozenset constant in code constants" + ); + } + + #[test] + fn public_ast_constant_is_not_scanned_as_lowered_expression() { + let expr = frozenset_call_expr(); + let index = ast::NodeIndex::from(0); + ast::HasNodeIndex::node_index(&expr).set(index); + let mut ast_constant_overrides = PublicAstNodeMap::new(); + ast_constant_overrides.insert( + index, + ConstantData::Frozenset { + elements: Vec::new(), + }, + ); + let module = ast::ModExpression { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + body: Box::new(expr), + }; + let table = SymbolTable::scan_expr_with_options( + &module, + SourceFileBuilder::new("source_path", "").finish(), + false, + false, + Some(Arc::new(ast_constant_overrides)), + None, + None, + None, + None, + CompileOpts::default().recursion_limit, + ) + .unwrap(); + + assert!( + table.lookup("frozenset").is_none(), + "CPython symtable Constant_kind does not visit the lowered frozenset() expression" + ); + } + + #[test] + fn public_ast_frozenset_constant_call_warns_like_cpython_constant() { + let index = ast::NodeIndex::from(0); + let func = frozenset_call_expr(); + ast::HasNodeIndex::node_index(&func).set(index); + let message = first_public_constant_warning( + ast::Expr::Call(ast::ExprCall { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + func: Box::new(func), + arguments: ast::Arguments { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + args: Box::default(), + keywords: Default::default(), + }, + }), + index, + ConstantData::Frozenset { + elements: Vec::new(), + }, + ); + assert!( + message.contains("'frozenset' object is not callable"), + "expected public ast.Constant(frozenset()) callable warning, got {message:?}" + ); + } + + #[test] + fn public_ast_frozenset_constant_subscript_warns_like_cpython_constant() { + let index = ast::NodeIndex::from(0); + let value = frozenset_call_expr(); + ast::HasNodeIndex::node_index(&value).set(index); + let message = first_public_constant_warning( + ast::Expr::Subscript(ast::ExprSubscript { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + value: Box::new(value), + slice: Box::new(ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + value: ast::Number::Int(ast::Int::ZERO), + })), + ctx: ast::ExprContext::Load, + }), + index, + ConstantData::Frozenset { + elements: Vec::new(), + }, + ); + assert!( + message.contains("'frozenset' object is not subscriptable"), + "expected public ast.Constant(frozenset()) subscript warning, got {message:?}" + ); + } + + #[test] + fn public_ast_str_constant_bad_index_warns_like_cpython_constant() { + let index = ast::NodeIndex::from(0); + let value = frozenset_call_expr(); + ast::HasNodeIndex::node_index(&value).set(index); + let message = first_public_constant_warning( + ast::Expr::Subscript(ast::ExprSubscript { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + value: Box::new(value), + slice: Box::new(ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + value: ast::Number::Float(1.0), + })), + ctx: ast::ExprContext::Load, + }), + index, + ConstantData::Str { + value: "abc".into(), + }, + ); + assert!( + message.contains("str indices must be integers or slices, not float"), + "expected public ast.Constant(str) bad-index warning, got {message:?}" + ); + } + + #[test] + fn public_ast_frozenset_constant_is_warns_like_cpython_constant() { + let index = ast::NodeIndex::from(0); + let left = frozenset_call_expr(); + ast::HasNodeIndex::node_index(&left).set(index); + let message = first_public_constant_warning( + ast::Expr::Compare(ast::ExprCompare { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + left: Box::new(left), + ops: Box::new([ast::CmpOp::Is]), + comparators: Box::new([ast::Expr::NoneLiteral(ast::ExprNoneLiteral { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + })]), + }), + index, + ConstantData::Frozenset { + elements: Vec::new(), + }, + ); + assert!( + message.contains("\"is\" with 'frozenset' literal"), + "expected public ast.Constant(frozenset()) identity warning, got {message:?}" + ); + } + + #[test] + fn public_ast_tuple_constant_compiles_as_load_const() { + let expr = ast::Expr::Tuple(ast::ExprTuple { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + elts: Vec::new(), + ctx: ast::ExprContext::Load, + parenthesized: true, + }); + let code = compile_public_constant_expr( + expr, + ConstantData::Tuple { + elements: vec![ConstantData::Integer { + value: BigInt::from(1u8), + }], + }, + ); + let ops: Vec<_> = code + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.iter() + .any(|op| matches!(op, Instruction::LoadConst { .. })), + "public ast.Constant(tuple(...)) must use CPython Constant_kind LOAD_CONST path, got {ops:?}" + ); + assert!( + !ops.iter() + .any(|op| matches!(op, Instruction::BuildTuple { .. })), + "public ast.Constant(tuple(...)) must not compile as a tuple display, got {ops:?}" + ); + assert!( + code.constants.iter().any(|constant| matches!( + constant, + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ConstantData::Integer { value }] if *value == BigInt::from(1u8) + ) + )), + "missing tuple constant in code constants" + ); + } + + #[test] + fn public_ast_constant_slice_bound_uses_cpython_constant_slice_path() { + let index = ast::NodeIndex::from(0); + let lower = frozenset_call_expr(); + ast::HasNodeIndex::node_index(&lower).set(index); + let expr = ast::Expr::Subscript(ast::ExprSubscript { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + value: Box::new(ast::Expr::Name(ast::ExprName { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + id: ast::name::Name::new_static("obj"), + ctx: ast::ExprContext::Load, + })), + slice: Box::new(ast::Expr::Slice(ast::ExprSlice { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + lower: Some(Box::new(lower)), + upper: Some(Box::new(ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: ast::AtomicNodeIndex::NONE, + range: TextRange::default(), + value: ast::Number::Int(ast::Int::ZERO), + }))), + step: None, + })), + ctx: ast::ExprContext::Load, + }); + let mut ast_constant_overrides = PublicAstNodeMap::new(); + ast_constant_overrides.insert( + index, + ConstantData::Integer { + value: BigInt::from(1u8), + }, + ); + let code = compile_eval_ast_with_options( + expr, + CompileOpts { + ast_constant_overrides: Some(Arc::new(ast_constant_overrides)), + ..CompileOpts::default() + }, + ); + let ops: Vec<_> = code + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + assert!( + !ops.iter().any(|op| matches!( + op, + Instruction::BinarySlice | Instruction::BuildSlice { .. } + )), + "public ast.Constant slice bound must follow CPython Constant_kind folded slice path, got {ops:?}" + ); + assert!( + code.constants.iter().any(|constant| matches!( + constant, + ConstantData::Slice { elements } + if matches!( + elements.as_ref(), + [ + ConstantData::Integer { value }, + ConstantData::Integer { .. }, + ConstantData::None, + ] if *value == BigInt::from(1u8) + ) + )), + "missing folded slice constant for public ast.Constant bound" + ); + } + + #[test] + fn match_pattern_errors_use_cpython_sequence_messages() { + let many_names = (0..256) + .map(|i| format!("a{i}")) + .collect::>() + .join(", "); + let too_many = format!( + "\ +match x: + case [{many_names}, *rest]: + pass +" + ); + assert_eq!( + compile_exec_error_message(&too_many), + "too many expressions in star-unpacking sequence pattern" + ); + + assert_eq!( + compile_exec_error_message( + "\ +match x: + case [*a, *b]: + pass +" + ), + "multiple starred names in sequence pattern" + ); + + assert_eq!( + compile_exec_unchecked_error_message( + "\ +match x: + case {**_}: + pass +" + ), + "invalid syntax" + ); + } + + #[test] + fn match_mapping_duplicate_literal_keys_use_cpython_equality() { + for (source, expected) in [ + ( + "\ +match x: + case {1: a, True: b}: + pass +", + "mapping pattern checks duplicate key (True)", + ), + ( + "\ +match x: + case {1: a, 1.0: b}: + pass +", + "mapping pattern checks duplicate key (1.0)", + ), + ( + "\ +match x: + case {0.0: a, -0.0: b}: + pass +", + "mapping pattern checks duplicate key (-0.0)", + ), + ( + "\ +match x: + case {9007199254740992: a, 9007199254740992.0: b}: + pass +", + "mapping pattern checks duplicate key (9007199254740992.0)", + ), + ( + "\ +match x: + case {-9007199254740992: a, -9007199254740992.0: b}: + pass +", + "mapping pattern checks duplicate key (-9007199254740992.0)", + ), + ( + "\ +match x: + case {1 + 0j: a, 1: b}: + pass +", + "mapping pattern checks duplicate key (1)", + ), + ( + "\ +match x: + case {1: a, 1 + 0j: b}: + pass +", + "mapping pattern checks duplicate key ((1+0j))", + ), + ( + "\ +match x: + case {0j: a, -0.0: b}: + pass +", + "mapping pattern checks duplicate key (-0.0)", + ), + ] { + assert_eq!(compile_exec_error_message(source), expected); + } + } + + #[test] + fn match_mapping_accepts_folded_literal_keys_like_cpython() { + compile_exec( + "\ +def f(x): + match x: + case {-1: a, 1 + 0j: b}: + return a, b + case {9007199254740993: a, 9007199254740992.0: b}: + return a, b + case {-9007199254740993: a, -9007199254740992.0: b}: + return a, b + case {1 + 1j: a, 1: b}: + return a, b + case _: + return None +", + ); + } + + #[test] + fn match_literal_binop_folding_uses_cpython_complex_shape() { + assert!( + Compiler::try_fold_match_pattern_binop( + ast::Operator::Add, + &ConstantData::Integer { + value: BigInt::from(1) + }, + &ConstantData::Integer { + value: BigInt::from(2) + }, + ) + .is_none() + ); + assert!( + Compiler::try_fold_match_pattern_binop( + ast::Operator::Add, + &ConstantData::Float { value: 1.0 }, + &ConstantData::Float { value: 2.0 }, + ) + .is_none() + ); + assert!(matches!( + Compiler::try_fold_match_pattern_binop( + ast::Operator::Add, + &ConstantData::Integer { + value: BigInt::from(1) + }, + &ConstantData::Complex { + value: Complex::new(0.0, 2.0) + }, + ), + Some(ConstantData::Complex { value }) if value == Complex::new(1.0, 2.0) + )); + assert!(matches!( + Compiler::try_fold_match_pattern_binop( + ast::Operator::Sub, + &ConstantData::Float { value: 1.5 }, + &ConstantData::Complex { + value: Complex::new(0.0, 2.0) + }, + ), + Some(ConstantData::Complex { value }) if value == Complex::new(1.5, -2.0) + )); + } + + #[test] + fn match_literal_patterns_reject_unexpected_constants_like_cpython() { + assert!(Compiler::is_unexpected_match_literal_constant( + &ast::ExprEllipsisLiteral { + range: TextRange::default(), + node_index: ast::AtomicNodeIndex::NONE, + } + .into() + )); + } + + #[test] + fn unpack_ex_allows_large_after_count_like_cpython() { + let suffix = (0..256) + .map(|i| format!("a{i}")) + .collect::>() + .join(", "); + let code = compile_exec(&format!( + "\ +def assignment(values): + *rest, {suffix} = values + return a255 + +def pattern(values): + match values: + case [*rest, {suffix}]: + return a255 + case _: + return None +" + )); + + let assignment = find_code(&code, "assignment").expect("missing assignment code"); + assert_eq!( + full_opargs_for(assignment, |op| matches!(op, Instruction::UnpackEx { .. })), + vec![256 << 8] + ); + + let pattern = find_code(&code, "pattern").expect("missing pattern code"); + assert_eq!( + full_opargs_for(pattern, |op| matches!(op, Instruction::UnpackEx { .. })), + vec![256 << 8] + ); + } + + #[test] + fn match_irrefutable_pattern_errors_use_cpython_messages() { + assert_eq!( + compile_exec_error_message( + "\ +match x: + case y | 1: + pass +" + ), + "name capture 'y' makes remaining patterns unreachable" + ); + + assert_eq!( + compile_exec_error_message( + "\ +match x: + case _ | 1: + pass +" + ), + "wildcard makes remaining patterns unreachable" + ); + } + + #[test] + fn empty_module_implicit_return_inherits_resume_location_like_cpython() { + let code = compile_exec(""); + // CPython 3.14 codegen emits the implicit LOAD_CONST/RETURN_VALUE with + // NO_LOCATION, then flowgraph.c::propagate_line_numbers() propagates + // the module RESUME location, whose line is 0. assert_eq!(code.linetable.as_ref(), &[0xf2, 0x03, 0x01, 0x01, 0x01]); } + #[test] + fn module_docstring_load_uses_doc_location_like_cpython() { + let code = compile_exec( + "\ +\"doc\" +x = 1 +", + ); + + // CPython 3.14 codegen_body() emits the docstring LOAD_CONST at the + // string expression location, then emits STORE_NAME __doc__ with + // NO_LOCATION. + assert_eq!( + code.linetable.as_ref(), + &[ + 0xf0, 0x03, 0x01, 0x01, 0x01, 0xd9, 0x00, 0x05, 0xd8, 0x04, 0x05, 0x82, 0x01, + ], + ); + } + #[test] fn redundant_nop_location_copies_full_location_like_cpython() { let code = compile_exec( @@ -12305,7 +14835,8 @@ def f(x, y, z): let symbol_table = SymbolTable::scan_program(&ast, source_file.clone()) .map_err(|e| e.into_codegen_error(source_file.name().to_owned())) .unwrap(); - let mut compiler = Compiler::new(opts, source_file, ""); + let mut compiler = + Compiler::new_with_syntax_warning_handler(opts, source_file, "", None); compiler.compile_program(&ast, symbol_table).unwrap(); let _table = compiler.pop_symbol_table(); let stack_top = compiler.code_stack.pop().unwrap(); @@ -12351,7 +14882,8 @@ def f(x, y, z): let is_async = function.is_async; let range = function.range(); - let mut compiler = Compiler::new(opts, source_file, ""); + let mut compiler = + Compiler::new_with_syntax_warning_handler(opts, source_file, "", None); compiler.future_annotations = symbol_table.future_annotations; compiler.symbol_table_stack.push(symbol_table); compiler.set_source_range(range); @@ -12359,7 +14891,7 @@ def f(x, y, z): compiler .current_code_info() .flags - .set(CodeFlags::COROUTINE, is_async); + .set(bytecode::CodeFlags::COROUTINE, is_async); let prev_ctx = compiler.ctx; compiler.ctx = CompileContext { @@ -12372,7 +14904,7 @@ def f(x, y, z): in_async_scope: is_async, }; compiler.set_qualname(); - let (_doc_str, body) = split_doc(body, compiler.opts); + let (_doc_str, body) = split_doc(body, &compiler.opts); let start_label = compiler.use_cpython_function_start_label(); let is_gen = is_async || compiler.current_symbol_table().is_generator; let stop_iteration_block = if is_gen { @@ -14297,6 +16829,21 @@ def g(): ); } + #[test] + fn starred_arg_annotation_unpack_uses_function_location_like_cpython() { + let code = compile_exec("def f(*args: *Ts): pass\n"); + let annotate = find_code(&code, "__annotate__").expect("missing annotation code"); + + // CPython 3.14 codegen_argannotation() visits `Ts` at the annotation + // expression location, then emits UNPACK_SEQUENCE at LOC(function). + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd7, 0x00, 0x17, 0xd1, 0x00, 0x17, 0x8c, 0x62, 0xd3, 0x00, 0x17 + ], + ); + } + #[test] fn module_deferred_annotations_use_start_location_like_cpython() { let code = compile_exec( @@ -14360,33 +16907,216 @@ class C: } #[test] - fn lambda_return_uses_body_location_like_cpython() { + fn multiline_super_method_load_uses_expression_start_location_like_cpython() { + let code = compile_exec( + "\ +class C: + def f(self): + return super( + ).m() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let load_super_index = f + .instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadSuperAttr { namei } => namei + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .is_load_method(), + _ => false, + }) + .expect("missing LOAD_SUPER_METHOD"); + let (load_location, _) = f.locations[load_super_index]; + + assert_eq!( + load_location.line.get(), + 3, + "CPython maybe_optimize_method_call() emits LOAD_SUPER_METHOD at LOC(meth), before updating to the attribute start" + ); + } + + #[test] + fn multiline_non_ascii_attribute_uses_cpython_unicode_length() { + let code = compile_exec( + "\ +def f(obj): + return (obj + .é) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let load_attr_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::LoadAttr { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing LOAD_ATTR"); + + assert_eq!( + load_attr_position, + (3, 11, 3, 12), + "CPython update_start_location_to_match_attr() subtracts PyUnicode_GET_LENGTH(attr), not the UTF-8 byte length; Rust SourceLocation exposes the resulting columns as one-based" + ); + } + + #[test] + fn two_arg_super_attr_in_class_body_is_optimized_like_cpython() { + let code = compile_exec( + "\ +class C: + x = super(C, self).attr +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + + assert!( + class_code + .instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadSuperAttr { .. })), + "CPython can_optimize_super_call() does not require function scope for two-argument super()" + ); + } + + #[test] + fn module_super_symbol_blocks_zero_arg_super_optimization_like_cpython() { + let code = compile_exec( + "\ +super +class C: + def f(self): + return super().attr +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + assert!( + !f.instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadSuperAttr { .. })), + "CPython can_optimize_super_call() rejects any top-level symbol-table entry for super" + ); + } + + #[test] + fn lambda_return_uses_body_location_like_cpython() { + let code = compile_exec( + "\ +def outer(): + return lambda x: x if x else 1 +", + ); + let lambda = find_code(&code, "").expect("missing lambda code"); + let return_positions: Vec<_> = lambda + .instructions + .iter() + .zip(&lambda.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 22, 2, 35), (2, 22, 2, 35)], + "CPython codegen_lambda() emits RETURN_VALUE at LOC(lambda body)" + ); + } + + #[test] + fn explicit_return_value_locations_match_cpython_codegen_return() { + let code = compile_exec( + "\ +def dynamic(x): + return x + +def constant(): + return 1 + +def bare(): + return +", + ); + + let cases = [ + ("dynamic", vec![(2, 5, 2, 13)]), + ("constant", vec![(5, 12, 5, 13)]), + ("bare", vec![(8, 5, 8, 11)]), + ]; + for (name, expected) in cases { + let function = find_code(&code, name).expect("missing function code"); + let return_positions: Vec<_> = function + .instructions + .iter() + .zip(&function.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, expected, + "CPython codegen_return() emits explicit return at loc for {name}" + ); + } + } + + #[test] + fn continue_jump_keeps_statement_location_like_cpython() { let code = compile_exec( "\ -def outer(): - return lambda x: x if x else 1 +def continues(xs): + for x in xs: + if x: + continue + use(x) ", ); - let lambda = find_code(&code, "").expect("missing lambda code"); - let return_positions: Vec<_> = lambda - .instructions - .iter() - .zip(&lambda.locations) - .filter_map(|(unit, (location, end_location))| { - matches!(unit.op, Instruction::ReturnValue).then_some(( - location.line.get(), - location.character_offset.get(), - end_location.line.get(), - end_location.character_offset.get(), - )) - }) - .collect(); - assert_eq!( - return_positions, - vec![(2, 22, 2, 35), (2, 22, 2, 35)], - "CPython codegen_lambda() emits RETURN_VALUE at LOC(lambda body)" - ); + { + let (name, expected_position) = ("continues", (4, 13, 4, 21)); + let function = find_code(&code, name).expect("missing function code"); + let jump_positions: Vec<_> = function + .instructions + .iter() + .zip(&function.locations) + .filter_map(|(unit, (location, end_location))| { + matches!( + unit.op, + Instruction::JumpForward { .. } | Instruction::JumpBackward { .. } + ) + .then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert!( + jump_positions.contains(&expected_position), + "CPython codegen_continue() emits final jump at statement loc for {name}, got {jump_positions:?}" + ); + } } #[test] @@ -14466,6 +17196,73 @@ def f(c): ); } + #[test] + fn typealias_value_scope_has_single_return_like_cpython() { + let code = compile_exec("type Alias = int\n"); + let alias = find_direct_child_code(&code, "Alias").expect("missing alias code"); + let return_count = alias + .instructions + .iter() + .filter(|unit| matches!(unit.op, Instruction::ReturnValue)) + .count(); + assert_eq!( + return_count, 1, + "CPython codegen_typealias_body() emits one RETURN_VALUE and assembles with addNone=0, got instructions={:?}", + alias.instructions + ); + } + + #[test] + fn generic_typealias_wrapper_return_uses_alias_location_like_cpython() { + let code = compile_exec("type A[T] = T\n"); + let type_params = + find_code(&code, "").expect("missing type params code"); + + // CPython 3.14 codegen_typealias() assembles the generic-parameters + // wrapper with addNone=0 after codegen_typealias_body() leaves the type + // alias object on the stack. The final RETURN_VALUE keeps LOC(type alias). + assert_eq!( + type_params.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0x80, 0x0d, 0x84, 0x71, 0x87, 0x0d, 0x81, 0x0d + ], + ); + } + + #[test] + fn type_param_bound_scope_has_single_return_like_cpython() { + let code = compile_exec("type Alias[T: int] = T\n"); + let type_params = + find_code(&code, "").expect("missing type params code"); + let bound = find_direct_child_code(type_params, "T").expect("missing T bound code"); + let return_count = bound + .instructions + .iter() + .filter(|unit| matches!(unit.op, Instruction::ReturnValue)) + .count(); + assert_eq!( + return_count, 1, + "CPython codegen_type_param_bound_or_default() emits one explicit RETURN_VALUE before OptimizeAndAssemble(addNone=1), got instructions={:?}", + bound.instructions + ); + } + + #[test] + fn class_body_scope_has_single_return_like_cpython() { + let code = compile_exec("class C:\n pass\n"); + let class_code = find_code(&code, "C").expect("missing class code"); + let return_count = class_code + .instructions + .iter() + .filter(|unit| matches!(unit.op, Instruction::ReturnValue)) + .count(); + assert_eq!( + return_count, 1, + "CPython codegen_class_body() emits one explicit RETURN_VALUE before OptimizeAndAssemble(addNone=1), got instructions={:?}", + class_code.instructions + ); + } + #[test] fn generic_function_annotation_scope_uses_function_location_like_cpython() { let code = compile_exec("def f[T](x: int): ...\n"); @@ -14485,6 +17282,23 @@ def f(c): ); } + #[test] + fn decorated_generic_function_type_params_use_decorator_firstlineno_like_cpython() { + let code = compile_exec( + "\ +def deco(obj): return obj +@deco +def f[T](): pass +", + ); + let type_params = + find_code(&code, "").expect("missing type params code"); + + // CPython codegen_function() passes firstlineno, not LOC(s).lineno, to + // the generic-parameters scope. + assert_eq!(type_params.first_line_number.unwrap().get(), 2); + } + #[test] fn generic_class_type_params_store_uses_class_location_like_cpython() { let code = compile_exec( @@ -14576,6 +17390,160 @@ def f(): ); } + #[test] + fn try_except_else_finally_child_scopes_follow_cpython_symbol_order() { + let code = compile_exec( + "\ +def f(x): + try: + pass + except Exception: + y = 1 + def h(): + return y + else: + def e(): + return x + finally: + def z(): + return x +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let h = find_code(f, "h").expect("missing handler function code"); + let e = find_code(f, "e").expect("missing else function code"); + let z = find_code(f, "z").expect("missing finally function code"); + + assert_eq!( + h.freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["y"], + "handler child scope should consume the handler symbol table" + ); + assert_eq!( + e.freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["x"], + "else child scope should be consumed before handler scopes, matching CPython codegen_try_except()" + ); + assert_eq!( + z.freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["x"], + "finally child scope should remain after body/else/handler scopes" + ); + } + + #[test] + fn try_star_child_scopes_follow_codegen_order_like_cpython() { + let code = compile_exec( + "\ +def f(x): + try: + pass + except* Exception: + y = 1 + def h(): + return y + else: + def e(): + return x +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let h = find_code(f, "h").expect("missing except* handler function code"); + let e = find_code(f, "e").expect("missing else function code"); + + assert_eq!( + h.freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["y"], + "except* handler child scope should consume handler symbol table before else" + ); + assert_eq!( + e.freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["x"], + "except* else child scope should be consumed after handler scopes, matching CPython codegen_try_star_except()" + ); + } + + #[test] + fn function_default_and_decorator_child_scopes_follow_cpython_symbol_order() { + fn direct_child_codes<'a>(code: &'a CodeObject, name: &str) -> Vec<&'a CodeObject> { + code.constants + .iter() + .filter_map(|constant| { + if let ConstantData::Code { code } = constant + && code.obj_name == name + { + Some(code.as_ref()) + } else { + None + } + }) + .collect() + } + + let code = compile_exec( + "\ +def outer(x, deco): + @(lambda f: deco(f)) + def inner(a=(lambda: x)()): + return a +", + ); + let outer = find_code(&code, "outer").expect("missing outer function code"); + let lambdas = direct_child_codes(outer, ""); + + assert_eq!(lambdas.len(), 2); + assert_eq!( + lambdas[0] + .freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["deco"], + "decorator lambda is emitted first by codegen_function()" + ); + assert_eq!( + lambdas[1] + .freevars + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["x"], + "default lambda should still consume the default symbol table" + ); + } + + #[test] + fn decorated_generic_class_type_params_use_decorator_firstlineno_like_cpython() { + let code = compile_exec( + "\ +def deco(obj): return obj +@deco +class C[T]: pass +", + ); + let type_params = + find_code(&code, "").expect("missing type params code"); + + // CPython codegen_class() also enters the generic-parameters scope with + // firstlineno, which is the first decorator line when decorators exist. + assert_eq!(type_params.first_line_number.unwrap().get(), 2); + } + #[test] fn class_deferred_annotations_use_class_body_location_like_cpython() { let code = compile_exec( @@ -14649,6 +17617,45 @@ g = lambda i: {**i} ); } + #[test] + fn dict_unpacking_large_regular_run_uses_subdict_chunks_like_cpython() { + let pairs = (0..17) + .map(|i| format!("{i}: {i}")) + .collect::>() + .join(", "); + let source = format!("def f(x):\n return {{{pairs}, **x}}\n"); + let code = compile_exec(&source); + let f = find_code(&code, "f").expect("missing f code"); + let first_dict_update = f + .instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::DictUpdate { .. })) + .expect("missing DICT_UPDATE"); + let prefix = &f.instructions[..first_dict_update]; + let build_map_args: Vec<_> = prefix + .iter() + .filter_map(|unit| { + matches!(unit.op, Instruction::BuildMap { .. }).then_some(u8::from(unit.arg)) + }) + .collect(); + let map_adds = prefix + .iter() + .filter(|unit| matches!(unit.op, Instruction::MapAdd { .. })) + .count(); + + assert_eq!( + build_map_args, + vec![0], + "CPython codegen_dict() routes a 17-pair run before ** through codegen_subdict(), got instructions={:?}", + f.instructions + ); + assert_eq!( + map_adds, 17, + "CPython codegen_subdict() uses MAP_ADD for all 17 pairs before **, got instructions={:?}", + f.instructions + ); + } + #[test] fn class_function_like_scopes_set_method_flag_like_cpython() { let code = compile_exec_with_options( @@ -14677,13 +17684,13 @@ def f(): for code in [method, async_method, lambda, genexpr] { assert!( - code.flags.contains(CodeFlags::METHOD), + code.flags.contains(bytecode::CodeFlags::METHOD), "class-scope function-like code should carry CO_METHOD like CPython 3.14, got {:?}", code.flags ); } assert!( - !module_function.flags.contains(CodeFlags::METHOD), + !module_function.flags.contains(bytecode::CodeFlags::METHOD), "module-scope function must not carry CO_METHOD" ); } @@ -14702,15 +17709,47 @@ class C: let class_code = find_code(&code, "C").expect("missing class code"); let lambda = find_code(class_code, "").expect("missing lambda code"); assert!( - lambda.flags.contains(CodeFlags::NESTED), + lambda.flags.contains(bytecode::CodeFlags::NESTED), "lambda under inlined class comprehension should stay nested" ); assert!( - !lambda.flags.contains(CodeFlags::METHOD), + !lambda.flags.contains(bytecode::CodeFlags::METHOD), "CPython creates this lambda while the current symtable block is the comprehension, not the class" ); } + #[test] + fn class_inlined_comprehension_pushes_only_bound_locals_like_cpython() { + let code = compile_exec( + "\ +class C: + x = 1 + items = [x for i in range(3)] +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let cleared_names = class_code + .instructions + .iter() + .filter_map(|unit| match unit.op { + Instruction::LoadFastAndClear { var_num } => { + let idx = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + Some(class_code.varnames[usize::from(idx)].as_str()) + } + _ => None, + }) + .collect::>(); + + assert!( + cleared_names.contains(&"i"), + "the comprehension iteration variable should be isolated, got {cleared_names:?}" + ); + assert!( + !cleared_names.contains(&"x"), + "CPython applies the class-block special case while tweaking scopes, but codegen_push_inlined_comprehension_locals() runs after u_in_inlined_comp is set and only clears DEF_LOCAL names; got {cleared_names:?}" + ); + } + #[test] fn genexpr_implicit_iterator_is_not_posonly_like_cpython() { let code = compile_exec("x = (i for i in ())"); @@ -14723,6 +17762,29 @@ class C: ); } + #[test] + fn posonly_function_argcount_metadata_matches_cpython_assemble_split() { + let code = compile_exec("def f(a, /, b):\n pass\n"); + let func = find_code(&code, "f").expect("missing function code"); + + assert_eq!( + func.arg_count, 2, + "CPython assemble.c exposes co_argcount as u_posonlyargcount + u_argcount" + ); + assert_eq!(func.posonlyarg_count, 1); + assert_eq!(func.varnames.as_ref(), &["a".to_owned(), "b".to_owned()]); + assert_eq!( + func.localspluskinds[0] & (CO_FAST_ARG_POS | CO_FAST_ARG_KW), + CO_FAST_ARG_POS, + "CPython compute_localsplus_info marks only u_posonlyargcount slots as positional-only" + ); + assert_eq!( + func.localspluskinds[1] & (CO_FAST_ARG_POS | CO_FAST_ARG_KW), + CO_FAST_ARG_POS | CO_FAST_ARG_KW, + "CPython compute_localsplus_info marks u_argcount slots after posonly as positional-or-keyword" + ); + } + #[test] fn async_generator_uses_cpython_async_generator_flag() { let code = compile_exec_with_options( @@ -14742,17 +17804,37 @@ async def ag(): let coroutine = find_code(&code, "c").expect("missing coroutine code"); let async_generator = find_code(&code, "ag").expect("missing async generator code"); - assert!(generator.flags.contains(CodeFlags::GENERATOR)); - assert!(!generator.flags.contains(CodeFlags::COROUTINE)); - assert!(!generator.flags.contains(CodeFlags::ASYNC_GENERATOR)); + assert!(generator.flags.contains(bytecode::CodeFlags::GENERATOR)); + assert!(!generator.flags.contains(bytecode::CodeFlags::COROUTINE)); + assert!( + !generator + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); - assert!(coroutine.flags.contains(CodeFlags::COROUTINE)); - assert!(!coroutine.flags.contains(CodeFlags::GENERATOR)); - assert!(!coroutine.flags.contains(CodeFlags::ASYNC_GENERATOR)); + assert!(coroutine.flags.contains(bytecode::CodeFlags::COROUTINE)); + assert!(!coroutine.flags.contains(bytecode::CodeFlags::GENERATOR)); + assert!( + !coroutine + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); - assert!(async_generator.flags.contains(CodeFlags::ASYNC_GENERATOR)); - assert!(!async_generator.flags.contains(CodeFlags::GENERATOR)); - assert!(!async_generator.flags.contains(CodeFlags::COROUTINE)); + assert!( + async_generator + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + assert!( + !async_generator + .flags + .contains(bytecode::CodeFlags::GENERATOR) + ); + assert!( + !async_generator + .flags + .contains(bytecode::CodeFlags::COROUTINE) + ); } #[test] @@ -15089,6 +18171,27 @@ def f(a, b, c): .filter(|unit| !matches!(unit.op, Instruction::Cache)) } + fn full_opargs_for( + code: &CodeObject, + mut predicate: impl FnMut(Instruction) -> bool, + ) -> Vec { + let mut extended = 0u32; + let mut args = Vec::new(); + for unit in non_cache_instructions(code) { + let byte = u32::from(u8::from(unit.arg)); + if matches!(unit.op, Instruction::ExtendedArg) { + extended = (extended << 8) | byte; + continue; + } + let oparg = (extended << 8) | byte; + extended = 0; + if predicate(unit.op) { + args.push(oparg); + } + } + args + } + fn varname_index(code: &CodeObject, name: &str) -> usize { code.varnames .iter() @@ -15288,6 +18391,48 @@ def f(): ); } + #[test] + fn match_or_conflicting_bind_error_uses_or_pattern_location_like_cpython() { + let error = compile_exec_error( + "\ +def f(x): + match x: + case ( + a + | b + ): + pass +", + ); + let location = error.location.expect("missing error location"); + assert_eq!( + location.line.get(), + 4, + "CPython codegen_pattern_or() reports alternative binding mismatches at LOC(p), not LOC(alt)" + ); + } + + #[test] + fn match_or_duplicate_store_error_uses_or_pattern_location_like_cpython() { + let error = compile_exec_error( + "\ +def f(value): + match value: + case [ + x, + (x | x), + ]: + pass +", + ); + let location = error.location.expect("missing error location"); + assert_eq!( + location.line.get(), + 5, + "CPython codegen_pattern_or() reports merge-time duplicate stores at LOC(p)" + ); + } + #[test] fn match_success_jump_uses_no_location_like_cpython() { let code = compile_exec( @@ -15313,6 +18458,55 @@ def f(self): ); } + #[test] + fn match_default_simple_guard_jump_uses_guard_location_like_cpython() { + let code = compile_exec( + "\ +def f(x, y): + match x: + case 0: + return 1 + case _ if y: + return 2 + return 3 +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let guard_jump_location = f + .instructions + .iter() + .enumerate() + .find_map(|(idx, unit)| { + let (Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num }) = + unit.op + else { + return None; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + if f.varnames[usize::from(var_num.get(arg))] != "y" { + return None; + } + f.instructions + .iter() + .zip(&f.locations) + .skip(idx + 1) + .take(8) + .find_map(|(unit, (location, _))| { + matches!(unit.op, Instruction::PopJumpIfFalse { .. }).then_some(*location) + }) + }) + .expect("missing default guard jump"); + + assert_eq!( + ( + guard_jump_location.line.get(), + guard_jump_location.character_offset.get() + ), + (5, 19), + "CPython codegen_jump_if() receives LOC(pattern), but the simple guard fallback emits TO_BOOL/jump at LOC(guard)" + ); + } + #[test] fn match_mapping_keys_scaffolding_uses_mapping_location_like_cpython() { let code = compile_exec( @@ -15339,6 +18533,57 @@ def f(self): ); } + #[test] + fn match_mapping_rest_cleanup_uses_mapping_location_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case { + 0: _, + **rest, + }: + return rest +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let rest_cleanup_start = f + .instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::BuildMap { .. })) + .expect("missing BUILD_MAP"); + for expected in [ + "BUILD_MAP", + "DICT_UPDATE", + "DELETE_SUBSCR", + "rest cleanup COPY", + "rest cleanup SWAP", + ] { + let location = f + .instructions + .iter() + .zip(&f.locations) + .skip(rest_cleanup_start) + .find_map(|(unit, (location, _))| { + let found = matches!( + (expected, unit.op), + ("BUILD_MAP", Instruction::BuildMap { .. }) + | ("DICT_UPDATE", Instruction::DictUpdate { .. }) + | ("DELETE_SUBSCR", Instruction::DeleteSubscr) + | ("rest cleanup COPY", Instruction::Copy { .. }) + | ("rest cleanup SWAP", Instruction::Swap { .. }) + ); + found.then_some(*location) + }) + .unwrap_or_else(|| panic!("missing {expected}")); + assert_eq!( + location.line.get(), + 3, + "CPython codegen_pattern_mapping() emits {expected} with LOC(p)" + ); + } + } + #[test] fn match_class_scaffolding_uses_class_pattern_location_like_cpython() { let code = compile_exec( @@ -15362,6 +18607,40 @@ def f(x): ); } + #[test] + fn match_class_wildcard_pop_uses_class_pattern_location_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case bool( + _ + ): + return True +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let unpack_index = f + .instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::UnpackSequence { .. })) + .expect("missing class pattern UNPACK_SEQUENCE"); + let wildcard_pop_location = f + .instructions + .iter() + .zip(&f.locations) + .skip(unpack_index + 1) + .find_map(|(unit, (location, _))| { + matches!(unit.op, Instruction::PopTop).then_some(*location) + }) + .expect("missing wildcard POP_TOP"); + assert_eq!( + wildcard_pop_location.line.get(), + 3, + "CPython codegen_pattern_class() emits wildcard POP_TOP with LOC(p)" + ); + } + #[test] fn while_try_body_layout_keeps_false_jump_to_anchor() { let code = compile_exec( @@ -16161,6 +19440,132 @@ def f(self, node): ); } + #[test] + fn return_debug_in_finally_uses_cpython_preprocessed_constant_order() { + let code = compile_exec( + "\ +def f(close): + try: + return __debug__ + finally: + close() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let call_pos = f + .instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::Call { .. })) + .expect("missing finally-body call"); + let debug_load_pos = f + .instructions + .iter() + .position(|unit| { + let Instruction::LoadConst { consti } = unit.op else { + return false; + }; + let constant = &f.constants[consti.get(OpArg::new(u32::from(u8::from(unit.arg))))]; + matches!(constant, ConstantData::Boolean { value: true }) + }) + .expect("missing __debug__ constant load"); + + assert!( + call_pos < debug_load_pos, + "CPython ast_preprocess.c folds __debug__ to Constant before codegen_return(), so the return constant is loaded after finally cleanup; ops={:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn debug_statement_is_preprocessed_constant_like_cpython() { + for code in [ + compile_exec("__debug__\n"), + compile_exec_optimized("__debug__\n"), + ] { + let ops = non_cache_instructions(&code) + .map(|unit| unit.op) + .collect::>(); + assert!( + !ops.iter().any(|op| matches!(op, Instruction::PopTop)), + "CPython ast_preprocess.c folds __debug__ to Constant before codegen_stmt_expr(), so it must not compile as LOAD_CONST/POP_TOP; ops={ops:?}" + ); + } + } + + #[test] + fn statement_expr_pop_top_uses_no_location_like_cpython() { + let infos = compile_module_instruction_infos("x + 1\n", Mode::Exec); + let pop = infos + .iter() + .find(|info| matches!(info.instr.real(), Some(Instruction::PopTop))) + .expect("missing expression-statement POP_TOP"); + + assert_eq!( + pop.lineno_override, + Some(ir::NO_LOCATION_OVERRIDE), + "CPython codegen_stmt_expr() emits artificial expression-statement POP_TOP at NO_LOCATION" + ); + } + + #[test] + fn interactive_statement_expr_pop_top_uses_no_location_like_cpython() { + let infos = compile_module_instruction_infos("x + 1\n", Mode::Single); + let print = infos + .iter() + .position(|info| { + matches!( + info.instr.real(), + Some(Instruction::CallIntrinsic1 { func }) + if func.get(info.arg) == bytecode::IntrinsicFunction1::Print + ) + }) + .expect("missing interactive PRINT intrinsic"); + let pop = infos + .get(print + 1) + .expect("missing POP_TOP after interactive PRINT"); + + assert!( + matches!(pop.instr.real(), Some(Instruction::PopTop)), + "CPython codegen_stmt_expr() emits POP_TOP immediately after INTRINSIC_PRINT; got {pop:?}" + ); + assert_eq!( + pop.lineno_override, + Some(ir::NO_LOCATION_OVERRIDE), + "CPython codegen_stmt_expr() emits interactive PRINT cleanup POP_TOP at NO_LOCATION" + ); + } + + #[test] + fn import_star_pop_top_uses_no_location_like_cpython() { + let infos = compile_module_instruction_infos("from m import *\n", Mode::Exec); + let import_star = infos + .iter() + .position(|info| { + matches!( + info.instr.real(), + Some(Instruction::CallIntrinsic1 { func }) + if func.get(info.arg) == bytecode::IntrinsicFunction1::ImportStar + ) + }) + .expect("missing IMPORT_STAR intrinsic"); + let pop = infos + .get(import_star + 1) + .expect("missing POP_TOP after IMPORT_STAR"); + + assert!( + matches!(pop.instr.real(), Some(Instruction::PopTop)), + "CPython codegen_from_import() emits POP_TOP immediately after INTRINSIC_IMPORT_STAR; got {pop:?}" + ); + assert_eq!( + pop.lineno_override, + Some(ir::NO_LOCATION_OVERRIDE), + "CPython codegen_from_import() emits import-star cleanup POP_TOP at NO_LOCATION" + ); + } + #[test] fn adjacent_no_location_entries_merge_like_cpython() { let code = compile_exec( @@ -16272,9 +19677,9 @@ def prefixed(x): "CPython represents f'{{x=}}' debug text as a literal at the expression/debug-text location" ); assert_eq!( - string_load_position(prefixed, "a x="), - (5, 14, 5, 19), - "CPython extends a pending f-string literal through the debug text range" + string_load_position(prefixed, "x="), + (5, 17, 5, 19), + "CPython keeps debug text as a separate JoinedStr Constant instead of merging it with the preceding literal" ); } @@ -17015,6 +20420,98 @@ def f(cls, args, kwargs): } } + #[test] + fn method_call_at_stack_guideline_uses_plain_load_attr_like_cpython() { + let params = (0..STACK_USE_GUIDELINE) + .map(|i| format!("a{i}")) + .collect::>() + .join(", "); + let code = compile_exec(&format!( + "def f(obj, {params}):\n return obj.m({params})\n" + )); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + let plain_load_attr = f.instructions.iter().any(|unit| { + if let Instruction::LoadAttr { namei } = unit.op { + !namei + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .is_method() + } else { + false + } + }); + let direct_call_30 = f.instructions.iter().any(|unit| match unit.op { + Instruction::Call { argc } => { + argc.get(OpArg::new(u32::from(u8::from(unit.arg)))) == STACK_USE_GUIDELINE + } + _ => false, + }); + + assert!( + plain_load_attr && direct_call_30, + "CPython maybe_optimize_method_call rejects arg count at the guideline, got ops={ops:?}" + ); + assert!( + !ops.iter() + .any(|op| matches!(op, Instruction::CallFunctionEx)), + "exactly guideline-sized method call should stay direct after LOAD_ATTR fallback, got ops={ops:?}" + ); + } + + #[test] + fn method_call_many_keywords_stays_load_method_call_kw_like_cpython() { + let params = (0..16) + .map(|i| format!("a{i}")) + .collect::>() + .join(", "); + let keywords = (0..16) + .map(|i| format!("k{i}=a{i}")) + .collect::>() + .join(", "); + let code = compile_exec(&format!( + "def f(obj, {params}):\n return obj.m({keywords})\n" + )); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + let method_load_attr = f.instructions.iter().any(|unit| { + if let Instruction::LoadAttr { namei } = unit.op { + namei + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .is_method() + } else { + false + } + }); + let call_kw_16 = f.instructions.iter().any(|unit| match unit.op { + Instruction::CallKw { argc } => { + argc.get(OpArg::new(u32::from(u8::from(unit.arg)))) == 16 + } + _ => false, + }); + + assert!( + method_load_attr && call_kw_16, + "CPython maybe_optimize_method_call emits LOAD_METHOD/CALL_KW under its own stack threshold, got ops={ops:?}" + ); + assert!( + !ops.iter() + .any(|op| matches!(op, Instruction::CallFunctionEx)), + "method-call keyword path should not reuse codegen_call_helper_impl's lower kw threshold, got ops={ops:?}" + ); + } + #[test] fn large_plain_call_uses_direct_call_until_stack_guideline() { let code = compile_exec( @@ -17288,7 +20785,7 @@ def set_f(xs): ); assert!( !has_common_constant(list_f, bytecode::CommonConstant::BuiltinList), - "CPython 3.14.2 does not optimize list(genexpr)" + "CPython 3.14.5 does not optimize list(genexpr)" ); let set_f = find_code(&code, "set_f").expect("missing set_f code"); @@ -17301,7 +20798,7 @@ def set_f(xs): ); assert!( !has_common_constant(set_f, bytecode::CommonConstant::BuiltinSet), - "CPython 3.14.2 does not optimize set(genexpr)" + "CPython 3.14.5 does not optimize set(genexpr)" ); } @@ -17703,16 +21200,50 @@ def aug_const(x, y): x[1:2] += y ", ); - let aug_const = find_code(&code, "aug_const").expect("missing aug_const code"); + let aug_const = find_code(&code, "aug_const").expect("missing aug_const code"); + + // CPython 3.14 codegen_augassign() visits a constant slice, then emits + // COPY/COPY/BINARY_OP NB_SUBSCR at LOC(target), not at LOC(slice). + assert_eq!( + aug_const.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x04, 0x05, 0x80, 0x63, 0x87, 0x46, 0x88, 0x61, 0x85, 0x4b, 0x85, + 0x46, + ] + ); + } + + #[test] + fn augassign_attribute_copy_uses_target_location_like_cpython() { + let code = compile_exec( + "\ +def f(obj, value): + obj.attr += value +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let copy_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + let Instruction::Copy { i } = unit.op else { + return None; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + (i.get(arg) == 1).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing augmented attribute COPY"); - // CPython 3.14 codegen_augassign() visits a constant slice, then emits - // COPY/COPY/BINARY_OP NB_SUBSCR at LOC(target), not at LOC(slice). assert_eq!( - aug_const.linetable.as_ref(), - &[ - 0x80, 0x00, 0xd8, 0x04, 0x05, 0x80, 0x63, 0x87, 0x46, 0x88, 0x61, 0x85, 0x4b, 0x85, - 0x46, - ] + copy_position, + (2, 5, 2, 13), + "CPython codegen_augassign() emits COPY 1 at LOC(target) before updating to attr location" ); } @@ -18610,6 +22141,82 @@ t = t\"Value: {value=}\" ); } + #[test] + fn tstring_ops_restore_template_and_interpolation_locations_like_cpython() { + let code = compile_exec( + "\ +def f(x): + return t\"{x}\" +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let mut build_interpolation = None; + let mut build_tuple = None; + let mut build_template = None; + for (unit, (location, end_location)) in f.instructions.iter().zip(&f.locations) { + let range = ( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + ); + match unit.op { + Instruction::BuildInterpolation { .. } => build_interpolation = Some(range), + Instruction::BuildTuple { .. } => build_tuple = Some(range), + Instruction::BuildTemplate => build_template = Some(range), + _ => {} + } + } + + assert_eq!( + build_interpolation, + Some((2, 14, 2, 17)), + "CPython codegen_interpolation() restores LOC(Interpolation) after visiting the value; this direct codegen path uses the parser's Interpolation range" + ); + assert_eq!( + build_tuple, + Some((2, 12, 2, 18)), + "CPython codegen_template_str() emits the interpolations tuple at LOC(TemplateStr); this direct codegen path uses the parser's TemplateStr range" + ); + assert_eq!( + build_template, + Some((2, 12, 2, 18)), + "CPython codegen_template_str() emits BUILD_TEMPLATE at LOC(TemplateStr); this direct codegen path uses the parser's TemplateStr range" + ); + } + + #[test] + fn regular_call_push_null_uses_callee_location_like_cpython() { + let code = compile_exec( + "\ +def f(g, x): + return ( + g + )(x) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let push_null = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::PushNull).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing PUSH_NULL"); + + assert_eq!( + push_null, + (3, 9, 3, 10), + "CPython codegen_call() resets loc to LOC(func) before emitting PUSH_NULL; this direct codegen path uses the parser's callee range" + ); + } + #[test] fn tstring_literal_preserves_surrogate_wtf8() { let code = compile_exec("t = t\"\\ud800\""); @@ -20898,6 +24505,63 @@ def f(lines, close): ); } + #[test] + fn try_finally_return_inside_with_pops_unwound_fblocks_for_finalbody() { + let code = compile_exec( + "\ +def f(cm): + try: + with cm: + return 1 + finally: + return 2 +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert!( + f.instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::ReturnValue)), + "return inside with/try-finally should compile without leaving an invalid CFG" + ); + } + + #[test] + fn except_star_return_after_with_unwind_uses_no_location_like_cpython() { + let err = compile_exec_error( + "\ +def f(cm): + try: + pass + except* Exception: + with cm: + return 1 +", + ); + assert!(matches!( + err.error, + CodegenErrorType::BreakContinueReturnInExceptStar + )); + assert!( + err.location.is_none(), + "CPython codegen_unwind_fblock(WITH) sets *ploc = NO_LOCATION before the except* error" + ); + } + + #[test] + fn async_generator_return_value_error_message_matches_cpython() { + assert_eq!( + compile_exec_error_message( + "\ +async def f(): + yield 1 + return 2 +" + ), + "'return' with value in async generator" + ); + } + #[test] fn try_except_finally_handler_normal_exit_keeps_nointerrupt_jump() { let code = compile_exec( @@ -23981,6 +27645,66 @@ class C: ); } + #[test] + fn optimize_two_strips_docstrings_during_preprocess() { + let code = compile_exec_with_options( + "\ +\"module doc\" + +def f(): + \"function doc\" + return 1 + +class C: + \"class doc\" + x = 1 +", + CompileOpts { + optimize: 2, + ..CompileOpts::default() + }, + ); + + assert!( + !code.instructions.iter().any(|unit| { + matches!( + unit.op, + Instruction::StoreName { namei } + if code.names + [namei.get(OpArg::new(u32::from(u8::from(unit.arg)))) as usize] + .as_str() + == "__doc__" + ) + }), + "module docstring should be stripped before codegen, got instructions={:?}", + code.instructions + ); + + let function_code = find_code(&code, "f").expect("missing function code"); + assert!( + !function_code + .flags + .contains(bytecode::CodeFlags::HAS_DOCSTRING), + "function docstring should not set HAS_DOCSTRING when optimize=2" + ); + + let class_code = find_code(&code, "C").expect("missing class code"); + assert!( + !class_code.instructions.iter().any(|unit| { + matches!( + unit.op, + Instruction::StoreName { namei } + if class_code.names + [namei.get(OpArg::new(u32::from(u8::from(unit.arg)))) as usize] + .as_str() + == "__doc__" + ) + }), + "class docstring should be stripped before codegen, got instructions={:?}", + class_code.instructions + ); + } + #[test] fn future_annotations_flag_is_inherited_like_cpython() { let code = compile_exec( @@ -23993,11 +27717,117 @@ def f(): return C ", ); - assert!(code.flags.contains(CodeFlags::FUTURE_ANNOTATIONS)); + assert!(code.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); let f = find_code(&code, "f").expect("missing f code"); - assert!(f.flags.contains(CodeFlags::FUTURE_ANNOTATIONS)); + assert!(f.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); let class_code = find_code(f, "C").expect("missing C code"); - assert!(class_code.flags.contains(CodeFlags::FUTURE_ANNOTATIONS)); + assert!( + class_code + .flags + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS) + ); + } + + #[test] + fn future_flags_from_compile_options_are_merged_like_cpython() { + let opts = CompileOpts { + future_features: bytecode::CodeFlags::FUTURE_ANNOTATIONS + | bytecode::CodeFlags::FUTURE_DIVISION, + ..CompileOpts::default() + }; + let code = compile_exec_with_options( + "\ +x: int +def f(): + pass +", + opts, + ); + assert!(code.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + assert!(code.flags.contains(bytecode::CodeFlags::FUTURE_DIVISION)); + assert!( + code.instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::SetupAnnotations)) + ); + let f = find_code(&code, "f").expect("missing f code"); + assert!(f.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + assert!(f.flags.contains(bytecode::CodeFlags::FUTURE_DIVISION)); + } + + #[test] + fn future_barry_as_flufl_is_accepted_but_ignored() { + let code = compile_exec( + "\ +from __future__ import barry_as_FLUFL + +def f(): + pass +", + ); + let future_flags = bytecode::CodeFlags::FUTURE_DIVISION + | bytecode::CodeFlags::FUTURE_ABSOLUTE_IMPORT + | bytecode::CodeFlags::FUTURE_WITH_STATEMENT + | bytecode::CodeFlags::FUTURE_PRINT_FUNCTION + | bytecode::CodeFlags::FUTURE_UNICODE_LITERALS + | bytecode::CodeFlags::FUTURE_GENERATOR_STOP + | bytecode::CodeFlags::FUTURE_ANNOTATIONS; + assert!((code.flags & future_flags).is_empty()); + let f = find_code(&code, "f").expect("missing f code"); + assert!((f.flags & future_flags).is_empty()); + } + + #[test] + fn relative_future_import_does_not_enable_annotations_like_cpython() { + let code = compile_exec( + "\ +from .__future__ import annotations +x: int +", + ); + assert!(!code.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + } + + #[test] + fn future_braces_uses_cpython_special_error() { + assert_eq!( + compile_exec_error_message("from __future__ import braces\n"), + "not a chance" + ); + } + + #[test] + fn invalid_future_feature_is_checked_before_ast_preprocess_like_cpython() { + assert_eq!( + compile_exec_error_message("from __future__ import spam, annotations\nx: (y := int)\n"), + "future feature spam is not defined" + ); + } + + #[test] + fn allow_top_level_await_marks_module_coroutine_like_cpython() { + let opts = CompileOpts { + allow_top_level_await: true, + ..CompileOpts::default() + }; + let code = compile_exec_with_options("await f()\n", opts); + assert!(code.flags.contains(bytecode::CodeFlags::COROUTINE)); + } + + #[test] + fn allow_top_level_await_accepts_module_async_for_like_cpython() { + let opts = CompileOpts { + allow_top_level_await: true, + ..CompileOpts::default() + }; + let code = compile_exec_with_options( + "\ +async for x in y: + pass +", + opts, + ); + assert!(code.flags.contains(bytecode::CodeFlags::COROUTINE)); } #[test] @@ -24016,7 +27846,7 @@ def outer(): let class_annotate = find_code(class_code, "__annotate__").expect("missing class annotation code"); assert!( - !class_annotate.flags.contains(CodeFlags::NESTED), + !class_annotate.flags.contains(bytecode::CodeFlags::NESTED), "module-level class annotation scope should not be nested" ); @@ -24025,7 +27855,7 @@ def outer(): let nested_annotate = find_code(nested_class, "__annotate__").expect("missing nested annotation code"); assert!( - nested_annotate.flags.contains(CodeFlags::NESTED), + nested_annotate.flags.contains(bytecode::CodeFlags::NESTED), "annotation scope under a nested class should be nested" ); } @@ -24040,25 +27870,25 @@ type A[T] = T ); let outer_lambda = find_code(&code, "").expect("missing outer lambda code"); assert!( - !outer_lambda.flags.contains(CodeFlags::NESTED), + !outer_lambda.flags.contains(bytecode::CodeFlags::NESTED), "module-level lambda should not be nested" ); let inner_lambda = find_direct_child_code(outer_lambda, "").expect("missing inner lambda code"); assert!( - inner_lambda.flags.contains(CodeFlags::NESTED), + inner_lambda.flags.contains(bytecode::CodeFlags::NESTED), "lambda inside lambda should be nested" ); let type_params = find_code(&code, "").expect("missing type params code"); assert!( - !type_params.flags.contains(CodeFlags::NESTED), + !type_params.flags.contains(bytecode::CodeFlags::NESTED), "module-level type-parameter scope should not be nested" ); let type_alias = find_direct_child_code(type_params, "A").expect("missing type alias code"); assert!( - type_alias.flags.contains(CodeFlags::NESTED), + type_alias.flags.contains(bytecode::CodeFlags::NESTED), "type alias body inside type-parameter scope should be nested" ); } @@ -26699,21 +30529,95 @@ def f(cm, func, args, kwds): let return_positions: Vec<_> = f .instructions .iter() - .zip(&f.locations) - .filter_map(|(unit, (location, end_location))| { - matches!(unit.op, Instruction::ReturnValue).then_some(( - location.line.get(), - location.character_offset.get(), - end_location.line.get(), - end_location.character_offset.get(), + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 10, 2, 12), (2, 10, 2, 12)], + "CPython codegen_unwind_fblock(WITH) leaves RETURN_VALUE inheriting the context expression location" + ); + } + + #[test] + fn with_normal_cleanup_jump_uses_context_expr_location_like_cpython() { + let source = "\ +with cm: + pass +x = 1 +"; + let mut opts = CompileOpts::default(); + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let parsed = ruff_python_parser::parse( + source_file.source_text(), + ruff_python_parser::Mode::Module.into(), + ) + .unwrap(); + let mut ast = parsed.into_syntax(); + opts.future_features |= preprocess::future_features(&ast); + let future_annotations = opts + .future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + preprocess::preprocess_mod(&mut ast, opts.optimize, future_annotations, false); + let ast = match ast { + ruff_python_ast::Mod::Module(stmts) => stmts, + _ => unreachable!(), + }; + let symbol_table = SymbolTable::scan_program_with_options( + &ast, + source_file.clone(), + opts.allow_top_level_await, + opts.future_features + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS), + opts.ast_constant_overrides.clone(), + opts.ast_interpolation_overrides.clone(), + opts.ast_formatted_value_overrides.clone(), + opts.ast_joined_str_overrides.clone(), + opts.ast_template_str_overrides.clone(), + opts.recursion_limit, + ) + .map_err(|e| e.into_codegen_error(source_file.name().to_owned())) + .unwrap(); + let mut compiler = + Compiler::new_with_syntax_warning_handler(opts, source_file, "", None); + compiler.compile_program(&ast, symbol_table).unwrap(); + + let jump_positions = compiler + .current_code_info() + .blocks + .iter() + .flat_map(|block| block.used_instructions()) + .filter_map(|info| { + matches!( + info.instr, + AnyInstruction::Pseudo(PseudoInstruction::Jump { .. }) + ) + .then_some(( + ( + info.location.line.get(), + info.location.character_offset.get(), + info.end_location.line.get(), + info.end_location.character_offset.get(), + ), + info.lineno_override, )) }) - .collect(); + .collect::>(); - assert_eq!( - return_positions, - vec![(2, 10, 2, 12), (2, 10, 2, 12)], - "CPython codegen_unwind_fblock(WITH) leaves RETURN_VALUE inheriting the context expression location" + assert!( + jump_positions + .iter() + .any(|(position, override_)| *position == (1, 6, 1, 8) + && *override_ != Some(ir::NO_LOCATION_OVERRIDE)), + "CPython codegen_with_inner() emits the normal-exit JUMP at LOC(context_expr), not NO_LOCATION; got {jump_positions:?}" ); } @@ -27029,6 +30933,44 @@ def f(x): assert_eq!(join_attr_count, 1); } + #[test] + fn large_fstring_join_scaffolding_uses_joinedstr_location_like_cpython() { + let mut source = String::from("def f(x):\n return f\""); + for _ in 0..=STACK_USE_GUIDELINE { + source.push_str("{x}"); + } + source.push_str("\"\n"); + + let code = compile_exec(&source); + let f = find_code(&code, "f").expect("missing function code"); + let fstring_end = " return ".len() + + 3 + + 3 * usize::try_from(STACK_USE_GUIDELINE + 1).expect("guideline overflowed") + + 1; + let expected = (2, 12, 2, fstring_end); + + for (unit, (location, end_location)) in f.instructions.iter().zip(&f.locations) { + if matches!( + unit.op, + Instruction::BuildList { .. } + | Instruction::ListAppend { .. } + | Instruction::Call { .. } + ) { + assert_eq!( + ( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + ), + expected, + "CPython codegen_joined_str() emits join scaffolding at LOC(JoinedStr); this direct codegen path uses the parser's FString range, op={:?}", + unit.op + ); + } + } + } + #[test] fn large_power_is_not_constant_folded() { let code = compile_exec("x = 2**100\n"); @@ -27529,6 +31471,65 @@ class C: assert_eq!(varnames, vec!["format"]); } + #[test] + fn future_function_signature_annotation_uses_hidden_block_like_cpython() { + let code = compile_exec( + "\ +from __future__ import annotations +def f(x: T): pass +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + let varnames = annotate + .varnames + .iter() + .map(|name| name.as_str()) + .collect::>(); + assert_eq!(varnames, vec!["format"]); + assert!( + find_code(&code, "f").is_some(), + "function body symbol-table cursor must skip the hidden AnnotationBlock" + ); + } + + #[test] + fn deferred_annotation_format_name_does_not_capture_helper_parameter() { + let code = compile_exec( + "\ +format = object() +x: format +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + let varnames = annotate + .varnames + .iter() + .map(|name| name.as_str()) + .collect::>(); + assert_eq!(varnames, vec!["format"]); + assert!( + annotate.names.iter().any(|name| name.as_str() == "format"), + "CPython keeps the helper parameter as internal .format during symbol analysis, so annotation expression `format` must remain a separate name; got names={:?}", + annotate.names + ); + + let helper_param_loads = annotate + .instructions + .iter() + .filter(|unit| match unit.op { + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + annotate.varnames[usize::from(var_num.get(arg))].as_str() == "format" + } + _ => false, + }) + .count(); + assert_eq!( + helper_param_loads, 1, + "only the CPython format-validation prologue should load the helper parameter; annotation expression `format` must not compile as LOAD_FAST" + ); + } + #[test] fn non_simple_class_annotation_is_not_deferred_like_cpython() { let code = compile_exec( @@ -27576,6 +31577,106 @@ class C: ); } + #[test] + fn class_deferred_annotations_guard_only_conditional_entries_like_cpython() { + let code = compile_exec( + "\ +class C: + x: int + if flag: + y: str + z: float +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let class_ops: Vec<_> = class_code + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let class_set_adds = class_ops + .iter() + .filter(|op| matches!(op, Instruction::SetAdd { .. })) + .count(); + assert_eq!( + class_set_adds, 1, + "CPython _PyCompile_AddDeferredAnnotation() adds class annotations to __conditional_annotations__ only inside conditional blocks, got ops={class_ops:?}" + ); + assert!( + class_code.instructions.iter().any(|unit| match unit.op { + Instruction::LoadDeref { i } => { + let idx = i.get(OpArg::new(u32::from(u8::from(unit.arg)))).as_usize(); + localsplus_name(class_code, idx) == Some("__conditional_annotations__") + } + _ => false, + }), + "CPython codegen_annassign() emits LOAD_DEREF for class __conditional_annotations__, got ops={class_ops:?}" + ); + assert!( + !class_code + .instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadFromDictOrDeref { .. })), + "CPython codegen_annassign() bypasses codegen_nameop for class __conditional_annotations__, got ops={class_ops:?}" + ); + + let annotate = find_code(class_code, "__annotate__").expect("missing __annotate__ code"); + let annotate_ops: Vec<_> = annotate + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let annotation_body = annotate_ops + .iter() + .position(|op| matches!(op, Instruction::BuildMap { .. })) + .map(|idx| &annotate_ops[idx..]) + .expect("missing annotation map build"); + let guarded_entries = annotation_body + .iter() + .filter(|op| matches!(op, Instruction::PopJumpIfFalse { .. })) + .count(); + assert_eq!( + guarded_entries, 1, + "CPython codegen_deferred_annotations_body() guards only conditional class annotations, got ops={annotate_ops:?}" + ); + } + + #[test] + fn future_annotations_non_simple_target_checks_target_but_not_annotation_like_cpython() { + let code = compile_exec( + "\ +from __future__ import annotations +class C: + target[item]: missing +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let loaded_names: Vec<_> = class_code + .instructions + .iter() + .filter_map(|unit| match unit.op { + Instruction::LoadName { namei } => { + let idx = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + Some(class_code.names[usize::try_from(idx).unwrap()].as_str()) + } + _ => None, + }) + .collect(); + + assert!( + ["target", "item"] + .iter() + .all(|name| loaded_names.contains(name)), + "CPython codegen_annassign() still checks bare complex annotation targets under future annotations, got loaded_names={loaded_names:?}" + ); + assert!( + !loaded_names.contains(&"missing"), + "CPython codegen_check_annotation() skips the annotation expression under future annotations, got loaded_names={loaded_names:?}" + ); + } + #[test] fn type_param_evaluator_uses_dot_format_varname() { let code = compile_exec( @@ -27680,7 +31781,7 @@ def func[T](a: T = 'a', *, b: T = 'b'): } #[test] - fn generic_function_type_params_varnames_include_defaults_like_cpython() { + fn generic_function_type_params_omit_defaults_without_defaults_like_cpython() { let code = compile_exec( "\ def func[T](): @@ -27696,8 +31797,29 @@ def func[T](): .iter() .map(String::as_str) .collect::>(), - vec![".defaults", "T"] + vec!["T"] + ); + } + + #[test] + fn generic_function_type_params_split_defaults_like_cpython() { + let code = compile_exec( + "\ +def with_pos[T](a: T = 1): + pass +def with_kw[U](*, a: U = 1): + pass +", ); + let with_pos = + find_code(&code, "").expect("missing type params code"); + let with_kw = + find_code(&code, "").expect("missing type params code"); + + assert!(with_pos.varnames.iter().any(|name| name == ".defaults")); + assert!(!with_pos.varnames.iter().any(|name| name == ".kwdefaults")); + assert!(!with_kw.varnames.iter().any(|name| name == ".defaults")); + assert!(with_kw.varnames.iter().any(|name| name == ".kwdefaults")); } #[test] @@ -27925,6 +32047,39 @@ class C[T]: } } + #[test] + fn non_inlined_listcomp_return_uses_comprehension_location_like_cpython() { + let code = compile_exec( + "\ +class C[T]: + class Inner[U]( + make_base([T for _ in (1,)]) + ): + pass +", + ); + let listcomp = find_code(&code, "").expect("missing listcomp code"); + let return_positions: Vec<_> = listcomp + .instructions + .iter() + .zip(&listcomp.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(3, 19, 3, 36)], + "CPython codegen_comprehension() emits non-gen RETURN_VALUE at LOC(e)" + ); + } + #[test] fn class_annotation_global_resolution_matches_cpython() { let class_global = compile_exec( @@ -28280,6 +32435,21 @@ def tuple_or_tuple(): ); } + #[test] + fn chained_compare_jump_if_runs_cpython_check_compare_warning() { + let message = first_exec_warning( + "\ +def f(x): + if 1 is 1 < x: + return x +", + ); + assert!( + message.contains("\"is\" with 'int' literal"), + "CPython codegen_jump_if() checks chained comparisons before conditional lowering, got {message:?}" + ); + } + #[test] fn lambda_without_body_constants_keeps_none_like_cpython() { let code = compile_exec("f = lambda x: x"); @@ -28293,6 +32463,17 @@ def tuple_or_tuple(): ); } + #[test] + fn generator_lambda_without_body_constants_omits_none_like_cpython() { + let code = compile_exec("f = lambda x: (yield x)"); + let lambda = find_code(&code, "").expect("missing lambda code"); + + assert!( + lambda.constants.is_empty(), + "CPython codegen_lambda() assembles generator lambdas with addNone=0" + ); + } + #[test] fn call_function_ex_empty_args_tuple_is_folded_late_like_cpython() { let code = compile_exec( @@ -28425,6 +32606,39 @@ f = lambda x: x in {0} ))); } + #[test] + fn frozenset_membership_consts_deduplicate_like_cpython_constant_key() { + let code = compile_exec( + "\ +def f(x): + return x in {1, 2}, x in {2, 1}, x in {1, 1} +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let frozensets: Vec<_> = f + .constants + .iter() + .filter_map(|constant| match constant { + ConstantData::Frozenset { elements } => Some(elements.as_slice()), + _ => None, + }) + .collect(); + + assert_eq!( + frozensets.len(), + 2, + "CPython folds equal frozensets to the same const key and removes duplicate set items" + ); + assert!( + frozensets.iter().any(|elements| elements.len() == 2), + "missing shared frozenset constant for {{1, 2}} and {{2, 1}}" + ); + assert!( + frozensets.iter().any(|elements| elements.len() == 1), + "missing duplicate-collapsed frozenset constant for {{1, 1}}" + ); + } + #[test] fn nonconstant_list_membership_uses_tuple() { let code = compile_exec( @@ -29210,7 +33424,7 @@ deoptmap = { } let comp = symbol_table - .sub_tables + .inlined_comprehension_blocks .first() .expect("missing comprehension symbol table"); assert!(comp.comp_inlined, "expected comprehension to be inlined"); @@ -29609,6 +33823,34 @@ values = ( ); } + #[test] + fn single_mode_returns_none_after_print_like_cpython() { + let code = compile_single("1\n"); + let ops = code + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Resume { .. })) + .collect::>(); + + assert!( + !ops.iter() + .any(|unit| matches!(unit.op, Instruction::Copy { .. })), + "CPython codegen_stmt_expr() prints and pops interactive expressions; it does not preserve the final expression as the code object's return value, got ops={ops:?}" + ); + let Some(load_none) = ops.iter().rev().nth(1) else { + panic!("missing final LOAD_CONST None before RETURN_VALUE, got ops={ops:?}"); + }; + let Instruction::LoadConst { consti } = load_none.op else { + panic!("missing final LOAD_CONST None before RETURN_VALUE, got ops={ops:?}"); + }; + let constant = &code.constants[consti.get(OpArg::new(u32::from(u8::from(load_none.arg))))]; + assert!(matches!(constant, ConstantData::None)); + assert!(matches!( + ops.last().map(|unit| unit.op), + Some(Instruction::ReturnValue) + )); + } + #[test] fn folded_multiline_bytes_binop_does_not_leave_operand_nops() { let code = compile_exec( diff --git a/crates/codegen/src/error.rs b/crates/codegen/src/error.rs index fb848354e86..9f11eba946f 100644 --- a/crates/codegen/src/error.rs +++ b/crates/codegen/src/error.rs @@ -3,21 +3,6 @@ use core::fmt::Display; use rustpython_compiler_core::SourceLocation; use thiserror::Error; -#[derive(Clone, Copy, Debug)] -pub enum PatternUnreachableReason { - NameCapture, - Wildcard, -} - -impl Display for PatternUnreachableReason { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::NameCapture => write!(f, "name capture"), - Self::Wildcard => write!(f, "wildcard"), - } - } -} - // pub type CodegenError = rustpython_parser_core::source_code::LocatedError; #[derive(Error, Debug)] @@ -70,6 +55,8 @@ pub enum CodegenErrorType { SyntaxError(String), /// Multiple `*` detected MultipleStarArgs, + MultipleStarredExpressionsInSequencePattern, + MultipleStarredNamesInSequencePattern, /// Misplaced `*` expression InvalidStarExpr, /// Break statement outside of loop. @@ -87,14 +74,17 @@ pub enum CodegenErrorType { AsyncReturnValue, InvalidFuturePlacement, InvalidFutureFeature(String), - FunctionImportStar, + InvalidFutureBraces, + RecursionError, TooManyStarUnpack, + TooManyExpressionsInStarUnpackingSequencePattern, EmptyWithItems, EmptyWithBody, ForbiddenName, DuplicateStore(String), - UnreachablePattern(PatternUnreachableReason), - RepeatedAttributePattern, + UnreachableWildcardPattern, + UnreachableNameCapturePattern(String), + RepeatedAttributePattern(String), ConflictingNameBindPattern, /// break/continue/return inside except* block BreakContinueReturnInExceptStar, @@ -112,6 +102,12 @@ impl fmt::Display for CodegenErrorType { Self::MultipleStarArgs => { write!(f, "multiple starred expressions in assignment") } + Self::MultipleStarredExpressionsInSequencePattern => { + write!(f, "multiple starred expressions in sequence pattern") + } + Self::MultipleStarredNamesInSequencePattern => { + write!(f, "multiple starred names in sequence pattern") + } Self::InvalidStarExpr => write!(f, "can't use starred expression here"), Self::InvalidBreak => write!(f, "'break' outside loop"), Self::InvalidContinue => write!(f, "'continue' not properly in loop"), @@ -128,9 +124,7 @@ impl fmt::Display for CodegenErrorType { ) } Self::AsyncYieldFrom => write!(f, "'yield from' inside async function"), - Self::AsyncReturnValue => { - write!(f, "'return' with value inside async generator") - } + Self::AsyncReturnValue => write!(f, "'return' with value in async generator"), Self::InvalidFuturePlacement => write!( f, "from __future__ imports must occur at the beginning of the file" @@ -138,12 +132,16 @@ impl fmt::Display for CodegenErrorType { Self::InvalidFutureFeature(feat) => { write!(f, "future feature {feat} is not defined") } - Self::FunctionImportStar => { - write!(f, "import * only allowed at module level") + Self::InvalidFutureBraces => write!(f, "not a chance"), + Self::RecursionError => { + write!(f, "maximum recursion depth exceeded during compilation") } Self::TooManyStarUnpack => { write!(f, "too many expressions in star-unpacking assignment") } + Self::TooManyExpressionsInStarUnpackingSequencePattern => { + write!(f, "too many expressions in star-unpacking sequence pattern") + } Self::EmptyWithItems => { write!(f, "empty items on With") } @@ -153,14 +151,18 @@ impl fmt::Display for CodegenErrorType { Self::ForbiddenName => { write!(f, "forbidden attribute name") } - Self::DuplicateStore(s) => { - write!(f, "duplicate store {s}") + Self::DuplicateStore(s) => write!(f, "multiple assignments to name '{s}' in pattern"), + Self::UnreachableWildcardPattern => { + write!(f, "wildcard makes remaining patterns unreachable") } - Self::UnreachablePattern(reason) => { - write!(f, "{reason} makes remaining patterns unreachable") + Self::UnreachableNameCapturePattern(name) => { + write!( + f, + "name capture '{name}' makes remaining patterns unreachable" + ) } - Self::RepeatedAttributePattern => { - write!(f, "attribute name repeated in class pattern") + Self::RepeatedAttributePattern(name) => { + write!(f, "attribute name repeated in class pattern: {name}") } Self::ConflictingNameBindPattern => { write!(f, "alternative patterns bind different names") diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index e7b50659e8e..e60bfa9fff4 100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -78,14 +78,131 @@ impl ConstantPool { } } + fn frozenset_key_contains(elements: &[ConstantData], needle: &ConstantData) -> bool { + if Self::constant_contains_nan(needle) { + return false; + } + elements.iter().any(|element| { + !Self::constant_contains_nan(element) && Self::constant_key_eq(element, needle) + }) + } + + fn frozenset_key_eq(left: &[ConstantData], right: &[ConstantData]) -> bool { + left.iter() + .all(|element| Self::frozenset_key_contains(right, element)) + && right + .iter() + .all(|element| Self::frozenset_key_contains(left, element)) + } + + fn constant_key_eq(left: &ConstantData, right: &ConstantData) -> bool { + match (left, right) { + (ConstantData::Tuple { elements: left }, ConstantData::Tuple { elements: right }) => { + left.len() == right.len() + && left + .iter() + .zip(right.iter()) + .all(|(left, right)| Self::constant_key_eq(left, right)) + } + ( + ConstantData::Frozenset { elements: left }, + ConstantData::Frozenset { elements: right }, + ) => Self::frozenset_key_eq(left, right), + (ConstantData::Slice { elements: left }, ConstantData::Slice { elements: right }) => { + left.iter() + .zip(right.iter()) + .all(|(left, right)| Self::constant_key_eq(left, right)) + } + _ => left == right, + } + } + + fn canonicalize_constant_key(constant: ConstantData) -> crate::InternalResult { + match constant { + ConstantData::Tuple { elements } => { + let mut canonical = Vec::new(); + canonical + .try_reserve_exact(elements.len()) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + for element in elements { + canonical.push(Self::canonicalize_constant_key(element)?); + } + Ok(ConstantData::Tuple { + elements: canonical, + }) + } + ConstantData::Slice { elements } => { + let [start, stop, step] = *elements; + Ok(ConstantData::Slice { + elements: Box::new([ + Self::canonicalize_constant_key(start)?, + Self::canonicalize_constant_key(stop)?, + Self::canonicalize_constant_key(step)?, + ]), + }) + } + ConstantData::Frozenset { elements } => { + let mut canonical = Vec::new(); + canonical + .try_reserve_exact(elements.len()) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + for element in elements { + let element = Self::canonicalize_constant_key(element)?; + if !Self::frozenset_key_contains(&canonical, &element) { + canonical.push(element); + } + } + Ok(ConstantData::Frozenset { + elements: canonical, + }) + } + other => Ok(other), + } + } + + fn canonicalize_constant_key_infallible(constant: ConstantData) -> ConstantData { + match constant { + ConstantData::Tuple { elements } => ConstantData::Tuple { + elements: elements + .into_iter() + .map(Self::canonicalize_constant_key_infallible) + .collect(), + }, + ConstantData::Slice { elements } => { + let [start, stop, step] = *elements; + ConstantData::Slice { + elements: Box::new([ + Self::canonicalize_constant_key_infallible(start), + Self::canonicalize_constant_key_infallible(stop), + Self::canonicalize_constant_key_infallible(step), + ]), + } + } + ConstantData::Frozenset { elements } => { + let mut canonical = Vec::with_capacity(elements.len()); + for element in elements { + let element = Self::canonicalize_constant_key_infallible(element); + if !Self::frozenset_key_contains(&canonical, &element) { + canonical.push(element); + } + } + ConstantData::Frozenset { + elements: canonical, + } + } + other => other, + } + } + pub fn insert_full(&mut self, constant: ConstantData) -> (usize, bool) { + let constant = Self::canonicalize_constant_key_infallible(constant); // CPython's _PyCode_ConstantKey() keeps NaN-bearing constants distinct // because Python-level NaN keys do not compare equal. if !Self::constant_contains_nan(&constant) && let Some(idx) = self .constants .iter() - .position(|existing| existing == &constant) + .position(|existing| Self::constant_key_eq(existing, &constant)) { return (idx, false); } @@ -95,13 +212,14 @@ impl ConstantPool { } fn try_insert_full(&mut self, constant: ConstantData) -> crate::InternalResult<(usize, bool)> { + let constant = Self::canonicalize_constant_key(constant)?; // CPython's _PyCode_ConstantKey() keeps NaN-bearing constants distinct // because Python-level NaN keys do not compare equal. if !Self::constant_contains_nan(&constant) && let Some(idx) = self .constants .iter() - .position(|existing| existing == &constant) + .position(|existing| Self::constant_key_eq(existing, &constant)) { return Ok((idx, false)); } @@ -1399,7 +1517,7 @@ impl IndexMut for Blocks { } pub(crate) const START_DEPTH_UNSET: i32 = i32::MIN; -const CO_MAXBLOCKS: usize = 20; +const CO_MAXBLOCKS: usize = 21; /// flowgraph.c struct _PyCfgExceptStack #[derive(Clone, Debug)] @@ -1975,6 +2093,9 @@ impl CodeInfo { kwonlyargcount: kwonlyarg_count, firstlineno: first_line_number, } = metadata; + let code_arg_count = posonlyarg_count + .checked_add(arg_count) + .ok_or(InternalError::MalformedControlFlowGraph)?; resolve_unconditional_jumps(&mut instr_sequence)?; resolve_jump_offsets(&mut instr_sequence)?; @@ -1992,7 +2113,7 @@ impl CodeInfo { Ok(CodeObject { flags, posonlyarg_count, - arg_count, + arg_count: code_arg_count, kwonlyarg_count, source_path, first_line_number: Some(first_line_number), @@ -2502,8 +2623,13 @@ fn const_folding_safe_multiply(left: &ConstantData, right: &ConstantData) -> Opt const_folding_safe_multiply(right, left) } (ConstantData::Tuple { elements }, ConstantData::Integer { value: n }) => { + if elements.is_empty() { + return Some(ConstantData::Tuple { + elements: Vec::new(), + }); + } let n = n.to_usize()?; - if n != 0 && !elements.is_empty() { + if n != 0 { if n > MAX_COLLECTION_SIZE / elements.len() { return None; } @@ -3282,8 +3408,7 @@ fn optimize_lists_and_sets( if !contains_or_iter { debug_assert!(i >= 2); - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; + let folded_loc = instr_location(&block.instructions[i]); nop_out(block, &operand_indices); @@ -3294,9 +3419,7 @@ fn optimize_lists_and_sets( } .into(); instr_set_op1(&mut block.instructions[i - 2], build_instr, OpArg::new(0)); - block.instructions[i - 2].location = folded_loc; - block.instructions[i - 2].end_location = end_loc; - block.instructions[i - 2].lineno_override = None; + instr_set_location(&mut block.instructions[i - 2], folded_loc); instr_set_op1( &mut block.instructions[i - 1], @@ -3545,7 +3668,7 @@ fn basicblock_optimize_load_const( block: &mut Block, ) -> crate::InternalResult<()> { let mut i = 0; - let mut effective_opcode = None; + let mut effective_opcode = Instruction::Nop.into(); let mut effective_oparg = OpArg::new(0); while i < block.instruction_used { if matches!( @@ -3559,21 +3682,19 @@ fn basicblock_optimize_load_const( let curr = block.instructions[i]; let curr_arg = curr.arg; - // Only combine if the source is a real instruction. - let Some(curr_instr) = curr.instr.real() else { - i += 1; - continue; - }; - let is_copy_of_load_const = matches!( - (effective_opcode, curr_instr), - (Some(Instruction::LoadConst { .. }), Instruction::Copy { i }) if i.get(curr_arg) == 1 + (effective_opcode, curr.instr.real()), + (AnyInstruction::Real(Instruction::LoadConst { .. }), Some(Instruction::Copy { i })) + if i.get(curr_arg) == 1 ); if !is_copy_of_load_const { - effective_opcode = Some(curr_instr); + effective_opcode = curr.instr; effective_oparg = curr_arg; } - let Some(const_instr) = effective_opcode else { + debug_assert!(!effective_opcode.is_assembler()); + let Some(const_instr @ (Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. })) = + effective_opcode.real() + else { i += 1; continue; }; @@ -3666,7 +3787,7 @@ fn basicblock_optimize_load_const( Opcode::PopJumpIfNone } .into(); - i = jump_idx; + i += 1; continue; } } @@ -3795,6 +3916,7 @@ fn optimize_basic_block( | PseudoInstruction::JumpIfTrue { .. }), ) => { let opcode = pseudo.into(); + let opcode_is_false = matches!(pseudo, PseudoInstruction::JumpIfFalse { .. }); match target.instr.pseudo().map(Into::into) { Some(PseudoOpcode::Jump) if jump_thread(blocks, block_idx, i, &target, opcode)? => @@ -3802,22 +3924,25 @@ fn optimize_basic_block( continue; } Some(PseudoOpcode::JumpIfFalse) - if matches!( - opcode, - AnyInstruction::Pseudo(PseudoInstruction::JumpIfFalse { .. }) - ) && jump_thread(blocks, block_idx, i, &target, opcode)? => + if opcode_is_false + && jump_thread(blocks, block_idx, i, &target, opcode)? => { continue; } Some(PseudoOpcode::JumpIfTrue) - if matches!( - opcode, - AnyInstruction::Pseudo(PseudoInstruction::JumpIfTrue { .. }) - ) && jump_thread(blocks, block_idx, i, &target, opcode)? => + if !opcode_is_false + && jump_thread(blocks, block_idx, i, &target, opcode)? => { continue; } - Some(PseudoOpcode::JumpIfFalse | PseudoOpcode::JumpIfTrue) => { + Some(PseudoOpcode::JumpIfTrue) if opcode_is_false => { + let next = blocks[inst.target.idx()].next; + debug_assert!(next != BlockIdx::NULL); + debug_assert!(next != inst.target); + blocks[bi].instructions[i].target = next; + continue; + } + Some(PseudoOpcode::JumpIfFalse) if !opcode_is_false => { let next = blocks[inst.target.idx()].next; debug_assert!(next != BlockIdx::NULL); debug_assert!(next != inst.target); @@ -6825,6 +6950,58 @@ pub(crate) fn fix_cell_offsets( #[cfg(test)] mod tests { use super::*; + use rustpython_compiler_core::bytecode::Arg; + + fn int_const(value: i32) -> ConstantData { + ConstantData::Integer { + value: BigInt::from(value), + } + } + + fn nan_const() -> ConstantData { + ConstantData::Float { value: f64::NAN } + } + + #[test] + fn constant_pool_frozenset_key_ignores_order_and_duplicates_like_cpython() { + let mut pool = ConstantPool::default(); + let (first, inserted) = pool.insert_full(ConstantData::Frozenset { + elements: vec![int_const(1), int_const(2)], + }); + assert_eq!(first, 0); + assert!(inserted); + + let (second, inserted) = pool.insert_full(ConstantData::Frozenset { + elements: vec![int_const(2), int_const(1), int_const(1)], + }); + assert_eq!( + second, first, + "CPython _PyCode_ConstantKey uses frozenset item keys, not insertion order" + ); + assert!(!inserted); + assert!(matches!( + &pool.constants[first], + ConstantData::Frozenset { elements } if elements.len() == 2 + )); + } + + #[test] + fn constant_pool_frozenset_key_preserves_nan_duplicates_like_cpython() { + let mut pool = ConstantPool::default(); + let (idx, inserted) = pool.insert_full(ConstantData::Frozenset { + elements: vec![nan_const(), nan_const()], + }); + + assert_eq!(idx, 0); + assert!(inserted); + assert!(matches!( + &pool.constants[idx], + ConstantData::Frozenset { elements } + if elements.iter().filter(|constant| { + matches!(constant, ConstantData::Float { value } if value.is_nan()) + }).count() == 2 + )); + } fn test_location(line: u32) -> SourceLocation { SourceLocation { @@ -6859,6 +7036,13 @@ mod tests { instr } + fn test_true_cond_jump(target: BlockIdx, line: u32) -> InstructionInfo { + let mut instr = test_instr(Instruction::Nop, line); + instr.instr = PseudoOpcode::JumpIfTrue.into(); + instr.target = target; + instr + } + fn test_block_push(block: &mut Block, info: InstructionInfo) { let off = basicblock_next_instr(block).expect("test block instruction slot"); block.instructions[off] = info; @@ -7370,6 +7554,52 @@ mod tests { ); } + #[test] + fn optimize_load_const_pseudo_opcode_breaks_effective_load_const() { + let mut block = Block::default(); + test_block_push( + &mut block, + test_instr( + Instruction::LoadConst { + consti: Arg::marker(), + }, + 90, + ), + ); + test_block_push(&mut block, test_true_cond_jump(BlockIdx::new(0), 90)); + let mut copy = test_instr(Instruction::Copy { i: Arg::marker() }, 90); + copy.arg = OpArg::new(1); + test_block_push(&mut block, copy); + test_block_push(&mut block, test_instr(Instruction::ToBool, 90)); + + let mut code = test_code_info(block); + let (const_idx, _) = code.metadata.consts.insert_full(ConstantData::Tuple { + elements: vec![ConstantData::Integer { + value: BigInt::from(1), + }], + }); + code.blocks[0].instructions[0].arg = OpArg::new(const_idx as u32); + + optimize_load_const(&mut code.metadata, &mut code.blocks) + .expect("optimize_load_const succeeds"); + + // CPython `basicblock_optimize_load_const()` assigns the current + // pseudo opcode to its effective opcode slot, so the following COPY 1 + // is not treated as a copy of the earlier LOAD_CONST. + assert!(matches!( + code.blocks[0].instructions[1].instr.pseudo(), + Some(PseudoInstruction::Jump { .. }) + )); + assert!(matches!( + code.blocks[0].instructions[2].instr.real(), + Some(Instruction::Copy { .. }) + )); + assert!(matches!( + code.blocks[0].instructions[3].instr.real(), + Some(Instruction::ToBool) + )); + } + #[test] fn optimize_load_fast_records_no_input_opcode_ref_at_cpython_produced_index() { let mut block = Block::default(); @@ -7457,6 +7687,24 @@ mod tests { )); } + #[test] + fn empty_tuple_repeat_folds_negative_count_like_cpython() { + let folded = const_folding_safe_multiply( + &ConstantData::Tuple { + elements: Vec::new(), + }, + &ConstantData::Integer { + value: BigInt::from(-1), + }, + ) + .expect("CPython skips repeat-count checks for empty tuples"); + + assert!(matches!( + folded, + ConstantData::Tuple { elements } if elements.is_empty() + )); + } + #[test] fn resolve_line_numbers_duplicates_exit_blocks_like_cpython() { let exit = BlockIdx::new(2); @@ -7581,4 +7829,49 @@ mod tests { assert_eq!(threaded.target, BlockIdx::new(3)); assert_eq!(u32::from(threaded.arg), 3); } + + #[test] + fn same_direction_pseudo_conditional_jump_thread_false_keeps_target() { + let mut blocks = Blocks::from([Block::default(), Block::default(), Block::default()]); + for (i, block) in blocks.iter_mut().enumerate() { + block.cpython_label = InstructionSequenceLabel::from_index(i as i32); + } + blocks[0].next = BlockIdx::new(1); + blocks[1].next = BlockIdx::new(2); + test_block_push(&mut blocks[0], test_cond_jump(BlockIdx::new(1), 10)); + test_block_push(&mut blocks[1], test_cond_jump(BlockIdx::new(1), 20)); + test_block_push(&mut blocks[2], test_instr(Instruction::ReturnValue, 30)); + + let mut metadata = test_code_info(Block::default()).metadata; + optimize_basic_block(&mut blocks, &mut metadata, BlockIdx::new(0)) + .expect("valid conditional jump chain"); + + // CPython only rewrites JUMP_IF_FALSE -> JUMP_IF_TRUE through + // target->b_next. For same-direction jumps, a failed jump_thread() + // leaves the original target unchanged. + assert_eq!(blocks[0].instructions[0].target, BlockIdx::new(1)); + assert!(matches!( + blocks[0].instructions[0].instr.pseudo(), + Some(PseudoInstruction::JumpIfFalse { .. }) + )); + } + + #[test] + fn opposite_direction_pseudo_conditional_uses_target_fallthrough() { + let mut blocks = Blocks::from([Block::default(), Block::default(), Block::default()]); + for (i, block) in blocks.iter_mut().enumerate() { + block.cpython_label = InstructionSequenceLabel::from_index(i as i32); + } + blocks[0].next = BlockIdx::new(1); + blocks[1].next = BlockIdx::new(2); + test_block_push(&mut blocks[0], test_cond_jump(BlockIdx::new(1), 10)); + test_block_push(&mut blocks[1], test_true_cond_jump(BlockIdx::new(2), 20)); + test_block_push(&mut blocks[2], test_instr(Instruction::ReturnValue, 30)); + + let mut metadata = test_code_info(Block::default()).metadata; + optimize_basic_block(&mut blocks, &mut metadata, BlockIdx::new(0)) + .expect("valid conditional jump chain"); + + assert_eq!(blocks[0].instructions[0].target, BlockIdx::new(2)); + } } diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index b598ab7e933..2866e645ff3 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -8,13 +8,15 @@ extern crate log; extern crate alloc; +use rustpython_compiler_core::bytecode::ConstantData; + type IndexMap = indexmap::IndexMap; type IndexSet = indexmap::IndexSet; pub mod compile; pub mod error; pub mod ir; -mod preprocess; +pub mod preprocess; mod string_parser; pub mod symboltable; mod unparse; @@ -24,6 +26,73 @@ use ruff_python_ast as ast; pub(crate) use compile::InternalResult; +#[derive(Clone, Debug)] +pub struct PublicAstInterpolation { + pub str: ConstantData, + pub format_spec: Option>, +} + +#[derive(Clone, Debug)] +pub struct PublicAstFormattedValue { + pub format_spec: Option>, +} + +#[derive(Clone, Debug)] +pub struct PublicAstExprList { + pub values: Vec, +} + +/// Dense side table keyed by public-AST `NodeIndex`. +/// +/// Public `_ast` constructors allocate synthetic node indexes from zero, so a +/// `Vec>` gives O(1) lookup without hashing or insertion-order state. +#[derive(Clone, Debug, Default)] +pub struct PublicAstNodeMap { + values: Vec>, +} + +impl PublicAstNodeMap { + #[must_use] + pub fn new() -> Self { + Self { values: Vec::new() } + } + + pub fn insert(&mut self, index: ast::NodeIndex, value: T) -> Option { + let index = index + .as_u32() + .expect("public AST side table cannot store NodeIndex::NONE") + as usize; + if self.values.len() <= index { + self.values.resize_with(index + 1, || None); + } + self.values[index].replace(value) + } + + #[must_use] + pub fn get(&self, index: &ast::NodeIndex) -> Option<&T> { + let index = index.as_u32()? as usize; + self.values.get(index)?.as_ref() + } + + pub fn get_mut(&mut self, index: &ast::NodeIndex) -> Option<&mut T> { + let index = index.as_u32()? as usize; + self.values.get_mut(index)?.as_mut() + } + + #[must_use] + pub fn contains_key(&self, index: &ast::NodeIndex) -> bool { + self.get(index).is_some() + } + + pub fn values(&self) -> impl Iterator { + self.values.iter().filter_map(Option::as_ref) + } + + pub fn is_empty(&self) -> bool { + self.values.iter().all(Option::is_none) + } +} + pub trait ToPythonName { /// Returns a short name for the node suitable for use in error messages. fn python_name(&self) -> &'static str; @@ -65,7 +134,7 @@ impl ToPythonName for ast::Expr { Self::Lambda { .. } => "lambda", Self::If { .. } => "conditional expression", Self::Named { .. } => "named expression", - Self::IpyEscapeCommand(_) => todo!(), + Self::IpyEscapeCommand(_) => "expression", } } } diff --git a/crates/codegen/src/preprocess.rs b/crates/codegen/src/preprocess.rs index ae2e65bf3fe..ff04330de85 100644 --- a/crates/codegen/src/preprocess.rs +++ b/crates/codegen/src/preprocess.rs @@ -8,29 +8,450 @@ use ruff_python_ast::{ visitor::transformer::{self, Transformer}, }; use ruff_text_size::{Ranged, TextRange}; +use rustpython_compiler_core::bytecode; const MAXDIGITS: usize = 3; const F_LJUST: u8 = 1; -pub(crate) fn preprocess_mod(module: &mut ast::Mod) { - let preprocessor = AstPreprocessor; +/// ast_preprocess.c ControlFlowInFinallyContext +#[derive(Clone, Copy)] +struct ControlFlowInFinallyContext { + in_finally: bool, + in_funcdef: bool, + in_loop: bool, +} + +/// ast_preprocess.c before_return +fn before_return( + contexts: &[ControlFlowInFinallyContext], + range: TextRange, + warn: &mut impl FnMut(TextRange, String) -> Result<(), E>, +) -> Result<(), E> { + if let Some(ctx) = contexts.last() + && ctx.in_finally + && !ctx.in_funcdef + { + warn(range, "'return' in a 'finally' block".to_owned())?; + } + Ok(()) +} + +/// ast_preprocess.c before_loop_exit +fn before_loop_exit( + contexts: &[ControlFlowInFinallyContext], + range: TextRange, + kw: &str, + warn: &mut impl FnMut(TextRange, String) -> Result<(), E>, +) -> Result<(), E> { + if let Some(ctx) = contexts.last() + && ctx.in_finally + && !ctx.in_loop + { + warn(range, format!("'{kw}' in a 'finally' block"))?; + } + Ok(()) +} + +fn visit_body_with_control_flow_context( + body: &[ast::Stmt], + contexts: &mut Vec, + warn: &mut impl FnMut(TextRange, String) -> Result<(), E>, + in_finally: bool, + in_funcdef: bool, + in_loop: bool, +) -> Result<(), E> { + contexts.push(ControlFlowInFinallyContext { + in_finally, + in_funcdef, + in_loop, + }); + visit_body_for_control_flow_in_finally(body, contexts, warn)?; + contexts.pop(); + Ok(()) +} + +fn visit_body_for_control_flow_in_finally( + body: &[ast::Stmt], + contexts: &mut Vec, + warn: &mut impl FnMut(TextRange, String) -> Result<(), E>, +) -> Result<(), E> { + for stmt in body { + visit_stmt_for_control_flow_in_finally(stmt, contexts, warn)?; + } + Ok(()) +} + +/// ast_preprocess.c astfold_stmt control-flow warning traversal. +fn visit_stmt_for_control_flow_in_finally( + stmt: &ast::Stmt, + contexts: &mut Vec, + warn: &mut impl FnMut(TextRange, String) -> Result<(), E>, +) -> Result<(), E> { + match stmt { + ast::Stmt::FunctionDef(function) => { + visit_body_with_control_flow_context( + &function.body, + contexts, + warn, + false, + true, + false, + )?; + } + ast::Stmt::ClassDef(class) => { + visit_body_for_control_flow_in_finally(&class.body, contexts, warn)?; + } + ast::Stmt::Return(return_stmt) => { + before_return(contexts, return_stmt.range, warn)?; + } + ast::Stmt::For(for_stmt) => { + visit_body_with_control_flow_context( + &for_stmt.body, + contexts, + warn, + false, + false, + true, + )?; + visit_body_for_control_flow_in_finally(&for_stmt.orelse, contexts, warn)?; + } + ast::Stmt::While(while_stmt) => { + visit_body_with_control_flow_context( + &while_stmt.body, + contexts, + warn, + false, + false, + true, + )?; + visit_body_for_control_flow_in_finally(&while_stmt.orelse, contexts, warn)?; + } + ast::Stmt::If(if_stmt) => { + visit_body_for_control_flow_in_finally(&if_stmt.body, contexts, warn)?; + for clause in &if_stmt.elif_else_clauses { + visit_body_for_control_flow_in_finally(&clause.body, contexts, warn)?; + } + } + ast::Stmt::Try(try_stmt) => { + visit_body_for_control_flow_in_finally(&try_stmt.body, contexts, warn)?; + for handler in &try_stmt.handlers { + match handler { + ast::ExceptHandler::ExceptHandler(handler) => { + visit_body_for_control_flow_in_finally(&handler.body, contexts, warn)?; + } + } + } + visit_body_for_control_flow_in_finally(&try_stmt.orelse, contexts, warn)?; + visit_body_with_control_flow_context( + &try_stmt.finalbody, + contexts, + warn, + true, + false, + false, + )?; + } + ast::Stmt::With(with_stmt) => { + visit_body_for_control_flow_in_finally(&with_stmt.body, contexts, warn)?; + } + ast::Stmt::Match(match_stmt) => { + for case in &match_stmt.cases { + visit_body_for_control_flow_in_finally(&case.body, contexts, warn)?; + } + } + ast::Stmt::Break(break_stmt) => { + before_loop_exit(contexts, break_stmt.range, "break", warn)?; + } + ast::Stmt::Continue(continue_stmt) => { + before_loop_exit(contexts, continue_stmt.range, "continue", warn)?; + } + _ => {} + } + Ok(()) +} + +/// ast_preprocess.c control_flow_in_finally_warning +pub fn warn_control_flow_in_finally( + module: &ast::Mod, + mut warn: impl FnMut(TextRange, String) -> Result<(), E>, +) -> Result<(), E> { + let mut contexts = Vec::new(); + match module { + ast::Mod::Module(module) => { + visit_body_for_control_flow_in_finally(&module.body, &mut contexts, &mut warn)?; + } + ast::Mod::Expression(_) => {} + } + Ok(()) +} + +pub fn has_future_annotations(module: &ast::Mod) -> bool { + future_features(module).contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS) +} + +pub fn future_features(module: &ast::Mod) -> bytecode::CodeFlags { + checked_future_features(module).unwrap_or_else(|err| err.features) +} + +pub struct FutureFeatureError { + pub features: bytecode::CodeFlags, + pub range: TextRange, + pub kind: FutureFeatureErrorKind, +} + +pub enum FutureFeatureErrorKind { + InvalidFeature(String), + InvalidBraces, +} + +pub fn checked_future_features( + module: &ast::Mod, +) -> Result { + let ast::Mod::Module(module) = module else { + return Ok(bytecode::CodeFlags::empty()); + }; + checked_future_features_in_body(&module.body) +} + +pub fn has_future_annotations_in_body(body: &[ast::Stmt]) -> bool { + future_features_in_body(body).contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS) +} + +pub fn future_features_in_body(body: &[ast::Stmt]) -> bytecode::CodeFlags { + checked_future_features_in_body(body).unwrap_or_else(|err| err.features) +} + +pub fn checked_future_features_in_body( + body: &[ast::Stmt], +) -> Result { + let mut future_features = bytecode::CodeFlags::empty(); + let mut statements = body.iter(); + if let Some(ast::Stmt::Expr(ast::StmtExpr { value, .. })) = statements.clone().next() + && matches!(&**value, ast::Expr::StringLiteral(_)) + { + statements.next(); + } + for statement in statements { + match statement { + ast::Stmt::ImportFrom(ast::StmtImportFrom { + module, + names, + level, + .. + }) if *level == 0 && module.as_ref().map(|id| id.as_str()) == Some("__future__") => { + for alias in names { + match alias.name.as_str() { + "nested_scopes" | "generators" | "division" | "absolute_import" + | "with_statement" | "print_function" | "unicode_literals" + | "generator_stop" => {} + "annotations" => { + future_features.insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + } + // Accept the CPython future feature name, but leave it + // as a RustPython no-op. + "barry_as_FLUFL" => {} + "braces" => { + return Err(FutureFeatureError { + features: future_features, + range: alias.range, + kind: FutureFeatureErrorKind::InvalidBraces, + }); + } + other => { + return Err(FutureFeatureError { + features: future_features, + range: alias.range, + kind: FutureFeatureErrorKind::InvalidFeature(other.to_owned()), + }); + } + } + } + } + _ => return Ok(future_features), + } + } + Ok(future_features) +} + +pub fn preprocess_statements( + body: &mut [ast::Stmt], + optimize: u8, + future_annotations: bool, + syntax_check_only: bool, +) { + let preprocessor = AstPreprocessor { + optimize, + future_annotations, + constant_folding: !syntax_check_only, + }; + for stmt in body { + preprocessor.visit_stmt(stmt); + } +} + +pub fn preprocess_mod( + module: &mut ast::Mod, + optimize: u8, + future_annotations: bool, + syntax_check_only: bool, +) { + let preprocessor = AstPreprocessor { + optimize, + future_annotations, + constant_folding: !syntax_check_only, + }; match module { - ast::Mod::Module(module) => preprocessor.visit_body(&mut module.body), + ast::Mod::Module(module) => preprocessor.visit_astfold_body(&mut module.body), ast::Mod::Expression(expr) => preprocessor.visit_expr(&mut expr.body), } } -struct AstPreprocessor; +struct AstPreprocessor { + optimize: u8, + future_annotations: bool, + constant_folding: bool, +} + +impl AstPreprocessor { + fn visit_astfold_body(&self, body: &mut ast::Suite) { + let mut docstring = body_starts_with_docstring(body); + if docstring && self.optimize >= 2 { + remove_docstring_from_body(body); + docstring = false; + } + + for stmt in body.iter_mut() { + self.visit_stmt(stmt); + } + + if !docstring && body_starts_with_docstring(body) { + wrap_first_docstring_as_fstring(body); + } + } +} impl Transformer for AstPreprocessor { + fn visit_stmt(&self, stmt: &mut ast::Stmt) { + match stmt { + ast::Stmt::FunctionDef(function) => { + if let Some(type_params) = &mut function.type_params { + self.visit_type_params(type_params); + } + self.visit_parameters(&mut function.parameters); + self.visit_astfold_body(&mut function.body); + for decorator in &mut function.decorator_list { + self.visit_decorator(decorator); + } + if let Some(returns) = &mut function.returns { + self.visit_annotation(returns); + } + } + ast::Stmt::ClassDef(class) => { + if let Some(type_params) = &mut class.type_params { + self.visit_type_params(type_params); + } + if let Some(arguments) = &mut class.arguments { + self.visit_arguments(arguments); + } + self.visit_astfold_body(&mut class.body); + for decorator in &mut class.decorator_list { + self.visit_decorator(decorator); + } + } + _ => transformer::walk_stmt(self, stmt), + } + } + + fn visit_annotation(&self, expr: &mut Expr) { + if !self.future_annotations { + transformer::walk_annotation(self, expr); + } + } + + fn visit_parameters(&self, parameters: &mut ast::Parameters) { + for arg in &mut parameters.posonlyargs { + self.visit_parameter(&mut arg.parameter); + } + for arg in &mut parameters.args { + self.visit_parameter(&mut arg.parameter); + } + if let Some(arg) = &mut parameters.vararg { + self.visit_parameter(arg); + } + for arg in &mut parameters.kwonlyargs { + self.visit_parameter(&mut arg.parameter); + } + for arg in &mut parameters.kwonlyargs { + if let Some(default) = &mut arg.default { + self.visit_expr(default); + } + } + if let Some(arg) = &mut parameters.kwarg { + self.visit_parameter(arg); + } + for arg in &mut parameters.posonlyargs { + if let Some(default) = &mut arg.default { + self.visit_expr(default); + } + } + for arg in &mut parameters.args { + if let Some(default) = &mut arg.default { + self.visit_expr(default); + } + } + } + + fn visit_comprehension(&self, comprehension: &mut ast::Comprehension) { + self.visit_expr(&mut comprehension.target); + self.visit_expr(&mut comprehension.iter); + for if_expr in &mut comprehension.ifs { + self.visit_expr(if_expr); + } + } + + fn visit_pattern(&self, pattern: &mut ast::Pattern) { + transformer::walk_pattern(self, pattern); + if !self.constant_folding { + return; + } + match pattern { + ast::Pattern::MatchValue(value) => fold_match_value_constant_expr(&mut value.value), + ast::Pattern::MatchMapping(mapping) => { + for key in &mut mapping.keys { + fold_match_value_constant_expr(key); + } + } + _ => {} + } + } + fn visit_expr(&self, expr: &mut Expr) { transformer::walk_expr(self, expr); - if let Some(optimized) = optimize_format(expr) { - *expr = optimized; + if self.constant_folding { + if let Some(optimized) = optimize_format(expr) { + *expr = optimized; + } else if let Some(optimized) = fold_debug_constant(expr, self.optimize) { + *expr = optimized; + } } } } +fn fold_debug_constant(expr: &Expr, optimize: u8) -> Option { + let Expr::Name(name) = expr else { + return None; + }; + if !matches!(name.ctx, ast::ExprContext::Load) || name.id.as_str() != "__debug__" { + return None; + } + + Some(Expr::BooleanLiteral(ast::ExprBooleanLiteral { + node_index: name.node_index.clone(), + range: name.range, + value: optimize == 0, + })) +} + fn optimize_format(expr: &Expr) -> Option { let Expr::BinOp(binop) = expr else { return None; @@ -223,3 +644,385 @@ fn generated_literal(value: String) -> InterpolatedStringLiteralElement { value: value.into_boxed_str(), } } + +fn remove_docstring_from_body(body: &mut ast::Suite) { + if let Some(range) = take_docstring(body) { + if !body.is_empty() { + return; + } + let start = range.start(); + let pass_range = TextRange::new(start, start + ruff_text_size::TextSize::from(4)); + body.push(ast::Stmt::Pass(ast::StmtPass { + node_index: Default::default(), + range: pass_range, + })); + } +} + +fn take_docstring(body: &mut ast::Suite) -> Option { + let ast::Stmt::Expr(expr_stmt) = body.first()? else { + return None; + }; + if matches!(expr_stmt.value.as_ref(), ast::Expr::StringLiteral(_)) { + let range = expr_stmt.range; + body.remove(0); + return Some(range); + } + None +} + +fn body_starts_with_docstring(body: &[ast::Stmt]) -> bool { + let Some(ast::Stmt::Expr(expr_stmt)) = body.first() else { + return false; + }; + matches!(expr_stmt.value.as_ref(), ast::Expr::StringLiteral(_)) +} + +fn wrap_first_docstring_as_fstring(body: &mut [ast::Stmt]) { + let Some(ast::Stmt::Expr(expr_stmt)) = body.first_mut() else { + return; + }; + let ast::Expr::StringLiteral(string) = expr_stmt.value.as_ref() else { + return; + }; + let range = expr_stmt.value.range(); + let value = string.value.to_str().to_string(); + *expr_stmt.value = ast::Expr::FString(ast::ExprFString { + node_index: AtomicNodeIndex::NONE, + range, + value: FStringValue::single(FString { + range, + node_index: AtomicNodeIndex::NONE, + elements: InterpolatedStringElements::from(vec![InterpolatedStringElement::Literal( + InterpolatedStringLiteralElement { + range, + node_index: AtomicNodeIndex::NONE, + value: value.into_boxed_str(), + }, + )]), + flags: FStringFlags::empty(), + }), + }); +} + +fn fold_match_value_constant_expr(expr: &mut ast::Expr) { + match expr { + ast::Expr::UnaryOp(unary) + if matches!(unary.op, ast::UnaryOp::USub) + && matches!(unary.operand.as_ref(), ast::Expr::NumberLiteral(_)) => + { + if let Some(number) = negate_match_number(&unary.operand) { + *expr = ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: unary.node_index.clone(), + range: unary.range, + value: number, + }); + } + } + ast::Expr::BinOp(binop) if matches!(binop.op, ast::Operator::Add | ast::Operator::Sub) => { + fold_match_value_constant_expr(&mut binop.left); + if let Some(number) = fold_match_number_binop(&binop.left, binop.op, &binop.right) { + *expr = ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: binop.node_index.clone(), + range: binop.range, + value: number, + }); + } + } + _ => {} + } +} + +fn negate_match_number(expr: &ast::Expr) -> Option { + let ast::Expr::NumberLiteral(number) = expr else { + return None; + }; + Some(match &number.value { + ast::Number::Int(value) => { + if *value == ast::Int::ZERO { + ast::Number::Int(ast::Int::ZERO) + } else { + return None; + } + } + ast::Number::Float(value) => ast::Number::Float(-value), + ast::Number::Complex { real, imag } => ast::Number::Complex { + real: -real, + imag: -imag, + }, + }) +} + +fn fold_match_number_binop( + left: &ast::Expr, + op: ast::Operator, + right: &ast::Expr, +) -> Option { + let ast::Expr::NumberLiteral(left) = left else { + return None; + }; + let ast::Expr::NumberLiteral(right) = right else { + return None; + }; + let right = match right.value { + ast::Number::Complex { real, imag } => (real, imag), + _ => return None, + }; + enum MatchNumberLeft { + Real(f64), + Complex { real: f64, imag: f64 }, + } + let left = match &left.value { + ast::Number::Int(value) => MatchNumberLeft::Real(value.as_i64()? as f64), + ast::Number::Float(value) => MatchNumberLeft::Real(*value), + ast::Number::Complex { real, imag } => MatchNumberLeft::Complex { + real: *real, + imag: *imag, + }, + }; + let (real, imag) = match (left, op) { + (MatchNumberLeft::Real(left), ast::Operator::Add) => (left + right.0, right.1), + (MatchNumberLeft::Real(left), ast::Operator::Sub) => (left - right.0, -right.1), + (MatchNumberLeft::Complex { real, imag }, ast::Operator::Add) => { + (real + right.0, imag + right.1) + } + (MatchNumberLeft::Complex { real, imag }, ast::Operator::Sub) => { + (real - right.0, imag - right.1) + } + _ => return None, + }; + Some(ast::Number::Complex { real, imag }) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn first_match_value(source: &str) -> ast::Expr { + let parsed = ruff_python_parser::parse(source, ruff_python_parser::Mode::Module.into()) + .unwrap() + .into_syntax(); + let mut module = parsed; + let future_annotations = has_future_annotations(&module); + preprocess_mod(&mut module, 0, future_annotations, false); + let ast::Mod::Module(module) = module else { + panic!("expected module"); + }; + let [ast::Stmt::Match(match_stmt)] = &module.body[..] else { + panic!("expected a single match statement"); + }; + let ast::Pattern::MatchValue(value) = &match_stmt.cases[0].pattern else { + panic!("expected a value pattern"); + }; + *value.value.clone() + } + + fn preprocess_source(source: &str) -> ast::Mod { + let mut module = ruff_python_parser::parse(source, ruff_python_parser::Mode::Module.into()) + .unwrap() + .into_syntax(); + let future_annotations = has_future_annotations(&module); + preprocess_mod(&mut module, 0, future_annotations, false); + module + } + + fn preprocess_source_with_optimize(source: &str, optimize: u8) -> ast::Mod { + let mut module = ruff_python_parser::parse(source, ruff_python_parser::Mode::Module.into()) + .unwrap() + .into_syntax(); + let future_annotations = has_future_annotations(&module); + preprocess_mod(&mut module, optimize, future_annotations, false); + module + } + + fn preprocess_source_syntax_check_only(source: &str, optimize: u8) -> ast::Mod { + let mut module = ruff_python_parser::parse(source, ruff_python_parser::Mode::Module.into()) + .unwrap() + .into_syntax(); + let future_annotations = has_future_annotations(&module); + preprocess_mod(&mut module, optimize, future_annotations, true); + module + } + + #[test] + fn folds_match_value_negative_float_in_preprocess() { + let value = first_match_value( + "\ +match value: + case -1.5: + pass +", + ); + let ast::Expr::NumberLiteral(number) = value else { + panic!("expected folded number literal, got {value:?}"); + }; + assert!(matches!(number.value, ast::Number::Float(value) if value == -1.5)); + } + + #[test] + fn folds_match_value_complex_binop_in_preprocess() { + let value = first_match_value( + "\ +match value: + case 1 + 2j: + pass +", + ); + let ast::Expr::NumberLiteral(number) = value else { + panic!("expected folded number literal, got {value:?}"); + }; + assert!( + matches!(number.value, ast::Number::Complex { real, imag } if real == 1.0 && imag == 2.0) + ); + } + + #[test] + fn folds_match_value_complex_complex_binop_in_preprocess() { + let left = ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: AtomicNodeIndex::NONE, + range: TextRange::default(), + value: ast::Number::Complex { + real: 0.0, + imag: 1.0, + }, + }); + let right = ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + node_index: AtomicNodeIndex::NONE, + range: TextRange::default(), + value: ast::Number::Complex { + real: 0.0, + imag: 2.0, + }, + }); + let number = fold_match_number_binop(&left, ast::Operator::Add, &right) + .expect("CPython fold_const_match_patterns() uses PyNumber_Add"); + assert!( + matches!(number, ast::Number::Complex { real, imag } if real == 0.0 && imag == 3.0) + ); + } + + #[test] + fn folds_match_value_real_minus_zero_complex_preserves_negative_zero_in_preprocess() { + let value = first_match_value( + "\ +match value: + case 0 - 0j: + pass +", + ); + let ast::Expr::NumberLiteral(number) = value else { + panic!("expected folded number literal, got {value:?}"); + }; + assert!(matches!(number.value, ast::Number::Complex { real, imag } + if real == 0.0 && imag == 0.0 && imag.is_sign_negative())); + } + + #[test] + fn future_annotations_skip_annotation_preprocess_like_cpython() { + let module = preprocess_source( + "\ +from __future__ import annotations +def f(x: __debug__) -> __debug__: + pass +y: __debug__ +z = __debug__ +", + ); + let ast::Mod::Module(module) = module else { + panic!("expected module"); + }; + let ast::Stmt::FunctionDef(function) = &module.body[1] else { + panic!("expected function"); + }; + let annotation = function.parameters.args[0] + .parameter + .annotation + .as_deref() + .expect("missing parameter annotation"); + assert!( + matches!(annotation, ast::Expr::Name(name) if name.id.as_str() == "__debug__"), + "future annotations should skip parameter annotation folding, got {annotation:?}" + ); + let returns = function + .returns + .as_deref() + .expect("missing return annotation"); + assert!( + matches!(returns, ast::Expr::Name(name) if name.id.as_str() == "__debug__"), + "future annotations should skip return annotation folding, got {returns:?}" + ); + let ast::Stmt::AnnAssign(ann_assign) = &module.body[2] else { + panic!("expected annotated assignment"); + }; + assert!( + matches!(ann_assign.annotation.as_ref(), ast::Expr::Name(name) if name.id.as_str() == "__debug__"), + "future annotations should skip annotated assignment annotation folding, got {:?}", + ann_assign.annotation + ); + let ast::Stmt::Assign(assign) = &module.body[3] else { + panic!("expected assignment"); + }; + assert!( + matches!(assign.value.as_ref(), ast::Expr::BooleanLiteral(boolean) if boolean.value), + "non-annotation expression should still fold __debug__, got {:?}", + assign.value + ); + } + + #[test] + fn late_future_annotations_do_not_affect_preprocess_like_cpython() { + let module = preprocess_source( + "\ +x = 1 +from __future__ import annotations +y: __debug__ +", + ); + let ast::Mod::Module(module) = module else { + panic!("expected module"); + }; + let ast::Stmt::AnnAssign(ann_assign) = &module.body[2] else { + panic!("expected annotated assignment"); + }; + assert!( + matches!(ann_assign.annotation.as_ref(), ast::Expr::BooleanLiteral(boolean) if boolean.value), + "late future import should not disable annotation folding, got {:?}", + ann_assign.annotation + ); + } + + #[test] + fn optimize_two_wraps_new_docstring_after_removing_original() { + let module = preprocess_source_with_optimize("\"first\"\n\"second\"\n", 2); + let ast::Mod::Module(module) = module else { + panic!("expected module"); + }; + let [ast::Stmt::Expr(expr)] = &module.body[..] else { + panic!("expected only the second statement to remain"); + }; + assert!( + matches!(expr.value.as_ref(), ast::Expr::FString(_)), + "CPython wraps the new leading string as JoinedStr so it is not a docstring" + ); + } + + #[test] + fn syntax_check_only_disables_constant_folding_but_keeps_docstring_strip() { + let module = preprocess_source_syntax_check_only("\"doc\"\nvalue = __debug__\n", 2); + let ast::Mod::Module(module) = module else { + panic!("expected module"); + }; + assert!( + matches!(module.body[0], ast::Stmt::Assign(_)), + "optimize=2 should still strip docstrings in syntax_check_only mode" + ); + let ast::Stmt::Assign(assign) = &module.body[0] else { + panic!("expected assignment"); + }; + assert!( + matches!(assign.value.as_ref(), ast::Expr::Name(name) if name.id.as_str() == "__debug__"), + "syntax_check_only should skip __debug__ folding, got {:?}", + assign.value + ); + } +} diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 27ba10ebadb..032d6afdfd7 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -8,14 +8,20 @@ Inspirational file: https://github.com/python/cpython/blob/main/Python/symtable. */ use crate::{ - IndexMap, IndexSet, + IndexMap, IndexSet, PublicAstExprList, PublicAstFormattedValue, PublicAstInterpolation, + PublicAstNodeMap, error::{CodegenError, CodegenErrorType}, }; -use alloc::{borrow::Cow, fmt}; +use alloc::{borrow::Cow, fmt, sync::Arc}; use bitflags::bitflags; use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange}; -use rustpython_compiler_core::{PositionEncoding, SourceFile, SourceLocation}; +use rustpython_compiler_core::{ + PositionEncoding, SourceFile, SourceLocation, bytecode::ConstantData, +}; + +const DEFAULT_RECURSION_LIMIT: usize = 1000; +const RECURSION_ERROR: &str = "maximum recursion depth exceeded during compilation"; /// Captures all symbols in the current scope, and has a list of sub-scopes in this scope. #[derive(Clone)] @@ -42,6 +48,20 @@ pub struct SymbolTable { /// AST nodes. pub sub_tables: Vec, + /// Annotation scopes that CPython registers in st_blocks but does not add + /// to ste_children, e.g. future-annotation function signatures. + pub hidden_annotation_blocks: Vec, + + /// Cursor pointing to the next hidden annotation block to consume. + pub next_hidden_annotation_block: usize, + + /// Inlined comprehension scopes that CPython removes from ste_children but + /// can still find through st_blocks keyed by the comprehension expression. + pub inlined_comprehension_blocks: Vec, + + /// Cursor pointing to the next inlined comprehension block to consume. + pub next_inlined_comprehension_block: usize, + /// Cursor pointing to the next sub-table to consume during compilation. pub next_sub_table: usize, @@ -63,6 +83,19 @@ pub struct SymbolTable { /// Whether this scope contains await or async comprehension machinery. pub is_coroutine: bool, + /// Whether this scope contains a return statement with a value. + pub returns_value: bool, + + /// Whether this block visited at least one annotation expression. + pub annotations_used: bool, + + /// Optional description of the current type-variable evaluator context. + pub scope_info: Option<&'static str>, + + /// Whether this annotation block is currently visiting an unevaluated + /// function-local annotation. + pub in_unevaluated_annotation: bool, + /// Whether this comprehension scope should be inlined (PEP 709) /// True for list/set/dict comprehensions in non-generator expressions pub comp_inlined: bool, @@ -99,6 +132,10 @@ impl SymbolTable { is_method: false, symbols: IndexMap::default(), sub_tables: vec![], + hidden_annotation_blocks: vec![], + next_hidden_annotation_block: 0, + inlined_comprehension_blocks: vec![], + next_inlined_comprehension_block: 0, next_sub_table: 0, varnames: Vec::new(), needs_class_closure: false, @@ -106,6 +143,10 @@ impl SymbolTable { can_see_class_scope: false, is_generator: false, is_coroutine: false, + returns_value: false, + annotations_used: false, + scope_info: None, + in_unevaluated_annotation: false, comp_inlined: false, annotation_block: None, skip_enclosing_function_scope: false, @@ -115,11 +156,64 @@ impl SymbolTable { } } + fn add_format_parameter(&mut self) { + let name = ".format"; + let symbol = self + .symbols + .entry(name.to_owned()) + .or_insert_with(|| Symbol::new(name)); + symbol + .flags + .insert(SymbolFlags::PARAMETER | SymbolFlags::REFERENCED); + if !self.varnames.iter().any(|varname| varname == name) { + self.varnames.push(name.to_owned()); + } + } + pub fn scan_program( program: &ast::ModModule, source_file: SourceFile, + ) -> SymbolTableResult { + Self::scan_program_with_options( + program, + source_file, + false, + false, + None, + None, + None, + None, + None, + DEFAULT_RECURSION_LIMIT, + ) + } + + #[expect( + clippy::too_many_arguments, + reason = "passes compile options and public AST override tables" + )] + pub fn scan_program_with_options( + program: &ast::ModModule, + source_file: SourceFile, + allow_top_level_await: bool, + future_annotations: bool, + ast_constant_overrides: Option>>, + ast_interpolation_overrides: Option>>, + ast_formatted_value_overrides: Option>>, + ast_joined_str_overrides: Option>>, + ast_template_str_overrides: Option>>, + recursion_limit: usize, ) -> SymbolTableResult { let mut builder = SymbolTableBuilder::new(source_file); + builder.allow_top_level_await = allow_top_level_await; + builder.ast_constant_overrides = ast_constant_overrides; + builder.ast_interpolation_overrides = ast_interpolation_overrides; + builder.ast_formatted_value_overrides = ast_formatted_value_overrides; + builder.ast_joined_str_overrides = ast_joined_str_overrides; + builder.ast_template_str_overrides = ast_template_str_overrides; + builder.recursion_limit = recursion_limit; + builder.future_annotations = future_annotations + || SymbolTableBuilder::future_annotations_from_module_body(program.body.as_ref()); builder.scan_statements(program.body.as_ref())?; builder.finish() } @@ -127,8 +221,46 @@ impl SymbolTable { pub fn scan_expr( expr: &ast::ModExpression, source_file: SourceFile, + ) -> SymbolTableResult { + Self::scan_expr_with_options( + expr, + source_file, + false, + false, + None, + None, + None, + None, + None, + DEFAULT_RECURSION_LIMIT, + ) + } + + #[expect( + clippy::too_many_arguments, + reason = "passes compile options and public AST override tables" + )] + pub fn scan_expr_with_options( + expr: &ast::ModExpression, + source_file: SourceFile, + allow_top_level_await: bool, + future_annotations: bool, + ast_constant_overrides: Option>>, + ast_interpolation_overrides: Option>>, + ast_formatted_value_overrides: Option>>, + ast_joined_str_overrides: Option>>, + ast_template_str_overrides: Option>>, + recursion_limit: usize, ) -> SymbolTableResult { let mut builder = SymbolTableBuilder::new(source_file); + builder.allow_top_level_await = allow_top_level_await; + builder.ast_constant_overrides = ast_constant_overrides; + builder.ast_interpolation_overrides = ast_interpolation_overrides; + builder.ast_formatted_value_overrides = ast_formatted_value_overrides; + builder.ast_joined_str_overrides = ast_joined_str_overrides; + builder.ast_template_str_overrides = ast_template_str_overrides; + builder.recursion_limit = recursion_limit; + builder.future_annotations = future_annotations; builder.scan_expression(expr.body.as_ref(), ExpressionContext::Load)?; builder.finish() } @@ -150,6 +282,8 @@ pub enum CompilerScope { TypeParams, /// PEP 649: Annotation scope for deferred evaluation Annotation, + TypeAlias, + TypeVariable, } impl fmt::Display for CompilerScope { @@ -163,11 +297,8 @@ impl fmt::Display for CompilerScope { Self::Comprehension => write!(f, "comprehension"), Self::TypeParams => write!(f, "type parameter"), Self::Annotation => write!(f, "annotation"), - // TODO missing types from the C implementation - // if self._table.type == _symtable.TYPE_TYPE_VAR_BOUND: - // return "TypeVar bound" - // if self._table.type == _symtable.TYPE_TYPE_ALIAS: - // return "type alias" + Self::TypeAlias => write!(f, "type alias"), + Self::TypeVariable => write!(f, "TypeVar bound"), } } } @@ -287,9 +418,14 @@ pub struct SymbolTableError { impl SymbolTableError { #[must_use] pub fn into_codegen_error(self, source_path: String) -> CodegenError { + let error = if self.error == RECURSION_ERROR { + CodegenErrorType::RecursionError + } else { + CodegenErrorType::SyntaxError(self.error) + }; CodegenError { location: self.location, - error: CodegenErrorType::SyntaxError(self.error), + error, source_path, } } @@ -339,20 +475,6 @@ fn drop_class_free(symbol_table: &mut SymbolTable, newfree: &mut IndexSet, parent_type: CompilerScope, ) -> IndexSet { - let mut removed_class_implicit = IndexSet::default(); + let mut removed_class_implicits = IndexSet::default(); for (name, sub_symbol) in &comp.symbols { // Skip the .0 parameter if sub_symbol.flags.contains(SymbolFlags::PARAMETER) { @@ -383,15 +505,24 @@ fn inline_comprehension( inlined_cells.insert(name.clone()); } - // Handle __class__ in ClassBlock + // __class__, __classdict__ and __conditional_annotations__ are never + // allowed to be free through a class scope. let scope = if sub_symbol.scope == SymbolScope::Free && parent_type == CompilerScope::Class && matches!( name.as_str(), "__class__" | "__classdict__" | "__conditional_annotations__" ) { - comp_free.swap_remove(name); - removed_class_implicit.insert(name.clone()); + let is_free_in_child = comp.sub_tables.iter().any(|child| { + child + .symbols + .get(name) + .is_some_and(|s| s.scope == SymbolScope::Free) + }); + if !is_free_in_child { + comp_free.swap_remove(name); + } + removed_class_implicits.insert(name.clone()); SymbolScope::GlobalImplicit } else { sub_symbol.scope @@ -420,7 +551,7 @@ fn inline_comprehension( parent_symbols.insert(name.clone(), symbol); } } - removed_class_implicit + removed_class_implicits } type SymbolMap = IndexMap; @@ -614,6 +745,26 @@ impl SymbolTableAnalyzer { newfree.extend(ann_free); } + let mut inlined_blocks = Vec::new(); + let mut idx = 0; + while idx < symbol_table.sub_tables.len() { + if symbol_table.sub_tables[idx].comp_inlined { + let comp = symbol_table.sub_tables.remove(idx); + let nested_inlined_blocks = comp.inlined_comprehension_blocks.clone(); + let children = comp.sub_tables.clone(); + let inserted = children.len(); + inlined_blocks.push(comp); + inlined_blocks.extend(nested_inlined_blocks); + symbol_table.sub_tables.splice(idx..idx, children); + idx += inserted; + } else { + idx += 1; + } + } + symbol_table + .inlined_comprehension_blocks + .extend(inlined_blocks); + let sub_tables = &*symbol_table.sub_tables; for symbol in symbol_table.symbols.values_mut() { @@ -671,20 +822,12 @@ impl SymbolTableAnalyzer { // (for example GLOBAL_IMPLICIT via __classdict__) while still making // the name available as a closure cell for nested children such as // generator expressions. - if symbol_table.typ == CompilerScope::Class { + if symbol_table.typ == CompilerScope::Class || symbol_table.can_see_class_scope { for name in &newfree { if let Some(symbol) = symbol_table.symbols.get_mut(name) { symbol.flags.insert(SymbolFlags::FREE_CLASS); } } - } else if symbol_table.can_see_class_scope { - for name in &newfree { - if let Some(symbol) = symbol_table.symbols.get_mut(name) - && !symbol.is_local() - { - symbol.flags.insert(SymbolFlags::FREE_CLASS); - } - } } Ok(newfree) @@ -1039,12 +1182,8 @@ impl SymbolTableAnalyzer { location: None, }); } - CompilerScope::Annotation => { - // Named expression is not allowed in annotation scope - return Err(SymbolTableError { - error: "named expression cannot be used within an annotation".to_string(), - location: None, - }); + CompilerScope::Annotation | CompilerScope::TypeAlias | CompilerScope::TypeVariable => { + self.analyze_symbol_comprehension(symbol, parent_offset + 1)?; } } Ok(()) @@ -1071,6 +1210,12 @@ struct SymbolTableBuilder { // Scope stack. tables: Vec, future_annotations: bool, + allow_top_level_await: bool, + ast_constant_overrides: Option>>, + ast_interpolation_overrides: Option>>, + ast_formatted_value_overrides: Option>>, + ast_joined_str_overrides: Option>>, + ast_template_str_overrides: Option>>, source_file: SourceFile, // Current scope's varnames being collected (temporary storage) current_varnames: Vec, @@ -1078,19 +1223,16 @@ struct SymbolTableBuilder { varnames_stack: Vec>, // Track if we're inside an iterable definition expression (for nested comprehensions) in_iter_def_exp: bool, - // Track if we're inside an annotation (yield/await/named expr not allowed) - in_annotation: bool, - // CPython's ste_in_unevaluated_annotation: function-local AnnAssign - // annotations are not executed and do not contribute name bindings. - in_unevaluated_annotation: bool, - // Track if we're inside a type alias (yield/await/named expr not allowed) - in_type_alias: bool, // Track if we're scanning an inner loop iteration target (not the first generator) in_comp_inner_loop_target: bool, - // Scope info for error messages (e.g., "a TypeVar bound") - scope_info: Option<&'static str>, + // CPython rejects yield/yield from inside comprehension scopes with a + // message that names the comprehension kind. + comprehension_yield_context: Option<&'static str>, // PEP 649: Track if we're inside a conditional block (if/for/while/etc.) in_conditional_block: bool, + // Mirrors CPython symtable ENTER_RECURSIVE guards during compilation. + recursion_depth: usize, + recursion_limit: usize, } /// Enum to indicate in what mode an expression @@ -1112,16 +1254,21 @@ impl SymbolTableBuilder { class_name: None, tables: vec![], future_annotations: false, + allow_top_level_await: false, + ast_constant_overrides: None, + ast_interpolation_overrides: None, + ast_formatted_value_overrides: None, + ast_joined_str_overrides: None, + ast_template_str_overrides: None, source_file, current_varnames: Vec::new(), varnames_stack: Vec::new(), in_iter_def_exp: false, - in_annotation: false, - in_unevaluated_annotation: false, - in_type_alias: false, in_comp_inner_loop_target: false, - scope_info: None, + comprehension_yield_context: None, in_conditional_block: false, + recursion_depth: 0, + recursion_limit: DEFAULT_RECURSION_LIMIT, }; this.enter_scope("top", CompilerScope::Module, 0); this @@ -1135,10 +1282,91 @@ impl SymbolTableBuilder { | CompilerScope::Lambda | CompilerScope::Comprehension | CompilerScope::Annotation + | CompilerScope::TypeAlias + | CompilerScope::TypeVariable | CompilerScope::TypeParams ) } + fn future_annotations_from_module_body(body: &[ast::Stmt]) -> bool { + let mut statements = body.iter(); + if let Some(ast::Stmt::Expr(ast::StmtExpr { value, .. })) = statements.clone().next() + && matches!(&**value, ast::Expr::StringLiteral(_)) + { + statements.next(); + } + for statement in statements { + match statement { + ast::Stmt::ImportFrom(ast::StmtImportFrom { + module, + names, + level, + .. + }) if *level == 0 + && module.as_ref().map(|id| id.as_str()) == Some("__future__") => + { + if names + .iter() + .any(|future| future.name.as_str() == "annotations") + { + return true; + } + } + _ => return false, + } + } + false + } + + fn public_ast_interpolation_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.ast_interpolation_overrides + .as_ref()? + .get(&index) + .cloned() + } + + fn public_ast_formatted_value_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.ast_formatted_value_overrides + .as_ref()? + .get(&index) + .cloned() + } + + fn public_ast_joined_str_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.ast_joined_str_overrides.as_ref()?.get(&index).cloned() + } + + fn public_ast_template_str_override_by_index( + &self, + index: ast::NodeIndex, + ) -> Option { + if index == ast::NodeIndex::NONE { + return None; + } + self.ast_template_str_overrides + .as_ref()? + .get(&index) + .cloned() + } + fn finish(mut self) -> Result { assert_eq!(self.tables.len(), 1); let mut symbol_table = self.tables.pop().unwrap(); @@ -1183,7 +1411,7 @@ impl SymbolTableBuilder { fn enter_type_param_block( &mut self, name: &str, - line_number: u32, + range: TextRange, for_class: bool, has_defaults: bool, has_kwdefaults: bool, @@ -1194,7 +1422,11 @@ impl SymbolTableBuilder { .last() .is_some_and(|t| t.typ == CompilerScope::Class); - self.enter_scope(name, CompilerScope::TypeParams, line_number); + self.enter_scope( + name, + CompilerScope::TypeParams, + self.line_index_start(range), + ); // Set properties on the newly created type param scope if let Some(table) = self.tables.last_mut() { @@ -1208,19 +1440,22 @@ impl SymbolTableBuilder { // Add __classdict__ as a USE symbol in type param scope if in class if in_class { - self.register_name("__classdict__", SymbolUsage::Used, TextRange::default())?; + self.register_name("__classdict__", SymbolUsage::Used, range)?; } - // Register .type_params as a SET symbol (it will be converted to cell variable later) - self.register_name(".type_params", SymbolUsage::Assigned, TextRange::default())?; if for_class { - self.register_name(".generic_base", SymbolUsage::Assigned, TextRange::default())?; + // It gets set when we create the type params tuple and used when + // we build up the bases. + self.register_name(".type_params", SymbolUsage::Assigned, range)?; + self.register_name(".type_params", SymbolUsage::Used, range)?; + self.register_name(".generic_base", SymbolUsage::Assigned, range)?; + self.register_name(".generic_base", SymbolUsage::Used, range)?; } if has_defaults { - self.register_name(".defaults", SymbolUsage::Parameter, TextRange::default())?; + self.register_name(".defaults", SymbolUsage::Parameter, range)?; } if has_kwdefaults { - self.register_name(".kwdefaults", SymbolUsage::Parameter, TextRange::default())?; + self.register_name(".kwdefaults", SymbolUsage::Parameter, range)?; } Ok(()) @@ -1236,9 +1471,22 @@ impl SymbolTableBuilder { self.current_varnames = self.varnames_stack.pop().unwrap_or_default(); } + /// Pop symbol table without adding it to the parent children list. + fn discard_scope(&mut self) -> SymbolTable { + let mut table = self.tables.pop().unwrap(); + table.varnames = core::mem::take(&mut self.current_varnames); + self.current_varnames = self.varnames_stack.pop().unwrap_or_default(); + table + } + /// Enter annotation scope (PEP 649) /// Creates or reuses the annotation block for the current scope - fn enter_annotation_scope(&mut self, line_number: u32) { + fn enter_annotation_scope( + &mut self, + line_number: u32, + include_classdict_with_future: bool, + include_conditional_annotations: bool, + ) { let current = self.tables.last_mut().unwrap(); let can_see_class_scope = current.typ == CompilerScope::Class || current.can_see_class_scope; @@ -1256,8 +1504,7 @@ impl SymbolTableBuilder { // Annotation scope in class can see class scope annotation_table.can_see_class_scope = can_see_class_scope; annotation_table.skip_enclosing_function_scope = true; - // Add 'format' parameter - annotation_table.varnames.push("format".to_owned()); + annotation_table.add_format_parameter(); current.annotation_block = Some(Box::new(annotation_table)); } @@ -1269,10 +1516,10 @@ impl SymbolTableBuilder { .push(core::mem::take(&mut self.current_varnames)); self.current_varnames = self.tables.last().unwrap().varnames.clone(); - if can_see_class_scope && !self.future_annotations { + if can_see_class_scope && (include_classdict_with_future || !self.future_annotations) { self.add_classdict_freevar(); // Also add __conditional_annotations__ as free var if parent has conditional annotations - if has_conditional { + if include_conditional_annotations && has_conditional { self.add_conditional_annotations_freevar(); } } @@ -1321,11 +1568,6 @@ impl SymbolTableBuilder { /// Annotation and TypeParams scopes act as async barriers (always non-async). /// Comprehension scopes are transparent (inherit parent's async context). fn is_in_async_context(&self) -> bool { - // Annotations are evaluated in a non-async scope even when - // the enclosing function is async. - if self.in_annotation { - return false; - } for table in self.tables.iter().rev() { match table.typ { CompilerScope::AsyncFunction => return true, @@ -1334,6 +1576,8 @@ impl SymbolTableBuilder { | CompilerScope::Class | CompilerScope::Module | CompilerScope::Annotation + | CompilerScope::TypeAlias + | CompilerScope::TypeVariable | CompilerScope::TypeParams => return false, // Comprehension inherits parent's async context CompilerScope::Comprehension => continue, @@ -1342,6 +1586,14 @@ impl SymbolTableBuilder { false } + fn allows_top_level_await(&self) -> bool { + self.allow_top_level_await + && self + .tables + .last() + .is_some_and(|table| table.typ == CompilerScope::Module) + } + fn line_index_start(&self, range: TextRange) -> u32 { self.source_file .to_source_code() @@ -1364,12 +1616,6 @@ impl SymbolTableBuilder { } fn scan_parameter(&mut self, parameter: &ast::Parameter) -> SymbolTableResult { - self.check_name( - parameter.name.as_str(), - ExpressionContext::Store, - parameter.name.range, - )?; - let usage = if parameter.annotation.is_some() { SymbolUsage::AnnotationParameter } else { @@ -1395,15 +1641,85 @@ impl SymbolTableBuilder { self.register_ident(¶meter.name, usage) } - fn scan_annotation(&mut self, annotation: &ast::Expr) -> SymbolTableResult { - self.scan_annotation_inner(annotation, false) - } - /// Scan an annotation from an AnnAssign statement (can be conditional) fn scan_ann_assign_annotation(&mut self, annotation: &ast::Expr) -> SymbolTableResult { self.scan_annotation_inner(annotation, true) } + fn scan_function_annotations( + &mut self, + parameters: &ast::Parameters, + returns: Option<&ast::Expr>, + line_number: u32, + ) -> SymbolTableResult { + let current = self.tables.last().unwrap(); + let can_see_class_scope = + current.typ == CompilerScope::Class || current.can_see_class_scope; + self.enter_scope("__annotate__", CompilerScope::Annotation, line_number); + self.tables.last_mut().unwrap().can_see_class_scope = can_see_class_scope; + self.tables.last_mut().unwrap().add_format_parameter(); + if can_see_class_scope { + self.register_name("__classdict__", SymbolUsage::Used, TextRange::default())?; + } + + let was_in_unevaluated_annotation = self.tables.last().unwrap().in_unevaluated_annotation; + self.tables.last_mut().unwrap().in_unevaluated_annotation = false; + + let result = (|| { + for annotation in parameters + .posonlyargs + .iter() + .chain(parameters.args.iter()) + .filter_map(|arg| arg.parameter.annotation.as_ref()) + { + self.tables.last_mut().unwrap().annotations_used = true; + self.scan_expression(annotation, ExpressionContext::Load)?; + } + if let Some(annotation) = parameters + .vararg + .as_ref() + .and_then(|arg| arg.annotation.as_ref()) + { + self.tables.last_mut().unwrap().annotations_used = true; + self.scan_expression(annotation, ExpressionContext::Load)?; + } + if let Some(annotation) = parameters + .kwarg + .as_ref() + .and_then(|arg| arg.annotation.as_ref()) + { + self.tables.last_mut().unwrap().annotations_used = true; + self.scan_expression(annotation, ExpressionContext::Load)?; + } + for annotation in parameters + .kwonlyargs + .iter() + .filter_map(|arg| arg.parameter.annotation.as_ref()) + { + self.tables.last_mut().unwrap().annotations_used = true; + self.scan_expression(annotation, ExpressionContext::Load)?; + } + if let Some(annotation) = returns { + self.tables.last_mut().unwrap().annotations_used = true; + self.scan_expression(annotation, ExpressionContext::Load)?; + } + Ok(()) + })(); + + self.tables.last_mut().unwrap().in_unevaluated_annotation = was_in_unevaluated_annotation; + if self.future_annotations { + let annotation_block = self.discard_scope(); + self.tables + .last_mut() + .unwrap() + .hidden_annotation_blocks + .push(annotation_block); + } else { + self.leave_scope(); + } + result + } + fn scan_annotation_inner( &mut self, annotation: &ast::Expr, @@ -1431,11 +1747,6 @@ impl SymbolTableBuilder { } if should_register_conditional_annotations { - self.register_name( - "__conditional_annotations__", - SymbolUsage::Assigned, - annotation.range(), - )?; self.register_name( "__conditional_annotations__", SymbolUsage::Used, @@ -1445,25 +1756,14 @@ impl SymbolTableBuilder { // Create annotation scope for deferred evaluation let line_number = self.line_index_start(annotation.range()); - self.enter_annotation_scope(line_number); - - if self.future_annotations { - // PEP 563: annotations are stringified at compile time - // Don't scan expression - symbols would fail to resolve - // Just create the annotation_block structure - self.leave_annotation_scope(); - return Ok(()); - } + self.enter_annotation_scope(line_number, false, true); // PEP 649: scan expression for symbol references // Class annotations are evaluated in class locals (not module globals) - let was_in_annotation = self.in_annotation; - let was_in_unevaluated_annotation = self.in_unevaluated_annotation; - self.in_annotation = true; - self.in_unevaluated_annotation = is_unevaluated; + let was_in_unevaluated_annotation = self.tables.last().unwrap().in_unevaluated_annotation; + self.tables.last_mut().unwrap().in_unevaluated_annotation = is_unevaluated; let result = self.scan_expression(annotation, ExpressionContext::Load); - self.in_annotation = was_in_annotation; - self.in_unevaluated_annotation = was_in_unevaluated_annotation; + self.tables.last_mut().unwrap().in_unevaluated_annotation = was_in_unevaluated_annotation; self.leave_annotation_scope(); @@ -1471,468 +1771,515 @@ impl SymbolTableBuilder { } fn scan_statement(&mut self, statement: &ast::Stmt) -> SymbolTableResult { - use ast::*; - if let Stmt::ImportFrom(StmtImportFrom { module, names, .. }) = &statement - && module.as_ref().map(|id| id.as_str()) == Some("__future__") - { - self.future_annotations = - self.future_annotations || names.iter().any(|future| &future.name == "annotations"); + if self.recursion_depth >= self.recursion_limit { + return Err(SymbolTableError { + error: RECURSION_ERROR.to_owned(), + location: None, + }); } - - match &statement { - Stmt::Global(StmtGlobal { names, .. }) => { - for name in names { - self.register_ident(name, SymbolUsage::Global)?; + self.recursion_depth += 1; + let result = (|| { + use ast::*; + match &statement { + Stmt::Global(StmtGlobal { names, .. }) => { + for name in names { + self.register_name(name.as_str(), SymbolUsage::Global, statement.range())?; + } } - } - Stmt::Nonlocal(StmtNonlocal { names, .. }) => { - for name in names { - self.register_ident(name, SymbolUsage::Nonlocal)?; + Stmt::Nonlocal(StmtNonlocal { names, .. }) => { + for name in names { + self.register_name( + name.as_str(), + SymbolUsage::Nonlocal, + statement.range(), + )?; + } } - } - Stmt::FunctionDef(StmtFunctionDef { - name, - body, - parameters, - decorator_list, - type_params, - returns, - range, - is_async, - .. - }) => { - self.scan_decorators(decorator_list, ExpressionContext::Load)?; - self.register_ident(name, SymbolUsage::Assigned)?; - - // Save the parent's annotation_block before scanning function annotations, - // so function annotations don't interfere with parent scope annotations. - // This applies to both class scope (methods) and module scope (top-level functions). - let parent_scope_typ = self.tables.last().map(|t| t.typ); - let should_save_annotation_block = matches!( - parent_scope_typ, - Some( - CompilerScope::Class - | CompilerScope::Module - | CompilerScope::Function - | CompilerScope::AsyncFunction - ) - ); - let saved_annotation_block = if should_save_annotation_block { - self.tables.last_mut().unwrap().annotation_block.take() - } else { - None - }; + Stmt::FunctionDef(StmtFunctionDef { + name, + body, + parameters, + decorator_list, + type_params, + returns, + range, + is_async, + .. + }) => { + self.register_name(name.as_str(), SymbolUsage::Assigned, *range)?; - // For generic functions, scan defaults before entering type_param_block - // (defaults are evaluated in the enclosing scope, not the type param scope) - let has_type_params = type_params.is_some(); - if has_type_params { self.scan_parameter_defaults(parameters)?; - } + self.scan_decorators(decorator_list, ExpressionContext::Load)?; - // For generic functions, enter type_param block FIRST so that - // annotation scopes are nested inside and can see type parameters. - if let Some(type_params) = type_params { - self.enter_type_param_block( + // For generic functions, enter type_param block FIRST so that + // annotation scopes are nested inside and can see type parameters. + if let Some(type_params) = type_params { + self.enter_type_param_block( + name.as_str(), + *range, + false, + Self::has_positional_defaults(parameters), + Self::has_kwonlydefaults(parameters), + )?; + self.scan_type_params(type_params)?; + } + self.enter_scope_with_parameters( name.as_str(), - self.line_index_start(type_params.range), + parameters, + self.line_index_start(*range), + returns.as_deref(), + if *is_async { + CompilerScope::AsyncFunction + } else { + CompilerScope::Function + }, + true, // skip_defaults: already scanned above false, - true, - Self::has_kwonlydefaults(parameters), )?; - self.scan_type_params(type_params)?; - } - let has_return_annotation = if let Some(expression) = returns { - self.scan_annotation(expression)?; - true - } else { - false - }; - self.enter_scope_with_parameters( - name.as_str(), - parameters, - self.line_index_start(*range), - has_return_annotation, if *is_async { - CompilerScope::AsyncFunction - } else { - CompilerScope::Function - }, - has_type_params, // skip_defaults: already scanned above - )?; - if *is_async { - self.tables.last_mut().unwrap().is_coroutine = true; - } - self.scan_statements(body)?; - self.leave_scope(); - if type_params.is_some() { + self.tables.last_mut().unwrap().is_coroutine = true; + } + self.scan_statements(body)?; self.leave_scope(); + if type_params.is_some() { + self.leave_scope(); + } } + Stmt::ClassDef(StmtClassDef { + name, + body, + arguments, + decorator_list, + type_params, + range, + node_index: _, + }) => { + let prev_class = self.class_name.clone(); + self.register_name(name.as_str(), SymbolUsage::Assigned, *range)?; + self.scan_decorators(decorator_list, ExpressionContext::Load)?; - // Restore parent's annotation_block after processing the function - if let Some(block) = saved_annotation_block { - self.tables.last_mut().unwrap().annotation_block = Some(block); - } - } - Stmt::ClassDef(StmtClassDef { - name, - body, - arguments, - decorator_list, - type_params, - range, - node_index: _, - }) => { - // Save class_name for the entire ClassDef processing - let prev_class = self.class_name.take(); - if let Some(type_params) = type_params { - self.enter_type_param_block( + if let Some(type_params) = type_params { + self.enter_type_param_block( + name.as_str(), + *range, + true, // for_class: enable selective mangling + false, + false, + )?; + // Set class_name for mangling in type param scope + self.class_name = Some(name.to_string()); + self.scan_type_params(type_params)?; + } + + if type_params.is_none() { + self.class_name = prev_class.clone(); + } + if let Some(arguments) = arguments { + self.scan_expressions(&arguments.args, ExpressionContext::Load)?; + for keyword in &arguments.keywords { + if let Some(arg) = &keyword.arg { + self.check_name( + arg.as_str(), + ExpressionContext::Store, + keyword.range, + )?; + } + } + for keyword in &arguments.keywords { + self.scan_expression(&keyword.value, ExpressionContext::Load)?; + } + } + + self.enter_scope( name.as_str(), - self.line_index_start(type_params.range), - true, // for_class: enable selective mangling - false, - false, - )?; - // Set class_name for mangling in type param scope + CompilerScope::Class, + self.line_index_start(*range), + ); + // Reset in_conditional_block for new class scope + let saved_in_conditional = self.in_conditional_block; + self.in_conditional_block = false; self.class_name = Some(name.to_string()); - self.scan_type_params(type_params)?; - } - self.enter_scope( - name.as_str(), - CompilerScope::Class, - self.line_index_start(*range), - ); - // Reset in_conditional_block for new class scope - let saved_in_conditional = self.in_conditional_block; - self.in_conditional_block = false; - self.class_name = Some(name.to_string()); - self.register_name("__module__", SymbolUsage::Assigned, *range)?; - self.register_name("__qualname__", SymbolUsage::Assigned, *range)?; - self.register_name("__doc__", SymbolUsage::Assigned, *range)?; - self.register_name("__class__", SymbolUsage::Assigned, *range)?; - if type_params.is_some() { - self.register_name(".type_params", SymbolUsage::Used, *range)?; - self.register_name("__type_params__", SymbolUsage::Assigned, *range)?; - } - self.scan_statements(body)?; - self.leave_scope(); - self.in_conditional_block = saved_in_conditional; - // For non-generic classes, restore class_name before base scanning. - // Bases are evaluated in the enclosing scope, not the class scope. - // For generic classes, bases are scanned within the type_param scope - // where class_name is already correctly set. - if type_params.is_none() { - self.class_name = prev_class.clone(); - } - if let Some(arguments) = arguments { - self.scan_expressions(&arguments.args, ExpressionContext::Load)?; - for keyword in &arguments.keywords { - self.scan_expression(&keyword.value, ExpressionContext::Load)?; + if type_params.is_some() { + self.register_name(".type_params", SymbolUsage::Used, *range)?; + self.register_name("__type_params__", SymbolUsage::Assigned, *range)?; } - } - if type_params.is_some() { + self.scan_statements(body)?; self.leave_scope(); + self.in_conditional_block = saved_in_conditional; + if type_params.is_some() { + self.leave_scope(); + } + // Restore class_name after all ClassDef processing + self.class_name = prev_class; } - // Restore class_name after all ClassDef processing - self.class_name = prev_class; - self.scan_decorators(decorator_list, ExpressionContext::Load)?; - self.register_ident(name, SymbolUsage::Assigned)?; - } - Stmt::Expr(StmtExpr { value, .. }) => { - self.scan_expression(value, ExpressionContext::Load)? - } - Stmt::If(StmtIf { - test, - body, - elif_else_clauses, - .. - }) => { - self.scan_expression(test, ExpressionContext::Load)?; - // PEP 649: Track conditional block for annotations - let saved_in_conditional_block = self.in_conditional_block; - self.in_conditional_block = true; - self.scan_statements(body)?; - for elif in elif_else_clauses { - if let Some(test) = &elif.test { - self.scan_expression(test, ExpressionContext::Load)?; - } - self.scan_statements(&elif.body)?; - } - self.in_conditional_block = saved_in_conditional_block; - } - Stmt::For(StmtFor { - target, - iter, - body, - orelse, - .. - }) => { - self.scan_expression(target, ExpressionContext::Store)?; - self.scan_expression(iter, ExpressionContext::Load)?; - // PEP 649: Track conditional block for annotations - let saved_in_conditional_block = self.in_conditional_block; - self.in_conditional_block = true; - self.scan_statements(body)?; - self.scan_statements(orelse)?; - self.in_conditional_block = saved_in_conditional_block; - } - Stmt::While(StmtWhile { - test, body, orelse, .. - }) => { - self.scan_expression(test, ExpressionContext::Load)?; - // PEP 649: Track conditional block for annotations - let saved_in_conditional_block = self.in_conditional_block; - self.in_conditional_block = true; - self.scan_statements(body)?; - self.scan_statements(orelse)?; - self.in_conditional_block = saved_in_conditional_block; - } - Stmt::Break(_) | Stmt::Continue(_) | Stmt::Pass(_) => { - // No symbols here. - } - Stmt::Import(StmtImport { names, .. }) - | Stmt::ImportFrom(StmtImportFrom { names, .. }) => { - for name in names { - if let Some(alias) = &name.asname { - // `import my_module as my_alias` - self.check_name(alias.as_str(), ExpressionContext::Store, alias.range)?; - self.register_ident(alias, SymbolUsage::Imported)?; - } else if name.name.as_str() == "*" { - // Star imports are only allowed at module level - if self.tables.last().unwrap().typ != CompilerScope::Module { - return Err(SymbolTableError { - error: "'import *' only allowed at module level".to_string(), - location: Some(self.source_file.to_source_code().source_location( - name.name.range.start(), - PositionEncoding::Utf8, - )), - }); - } - // Don't register star imports as symbols - } else { - // `import module` or `from x import name` - let imported_name = name.name.split('.').next().unwrap(); - self.check_name(imported_name, ExpressionContext::Store, name.name.range)?; - self.register_name(imported_name, SymbolUsage::Imported, name.name.range)?; - } - } - } - Stmt::Return(StmtReturn { value, .. }) => { - if let Some(expression) = value { - self.scan_expression(expression, ExpressionContext::Load)?; - } - } - Stmt::Assert(StmtAssert { test, msg, .. }) => { - self.scan_expression(test, ExpressionContext::Load)?; - if let Some(expression) = msg { - self.scan_expression(expression, ExpressionContext::Load)?; - } - } - Stmt::Delete(StmtDelete { targets, .. }) => { - self.scan_expressions(targets, ExpressionContext::Delete)?; - } - Stmt::Assign(StmtAssign { targets, value, .. }) => { - self.scan_expressions(targets, ExpressionContext::Store)?; - self.scan_expression(value, ExpressionContext::Load)?; - } - Stmt::AugAssign(StmtAugAssign { target, value, .. }) => { - self.scan_expression(target, ExpressionContext::Store)?; - self.scan_expression(value, ExpressionContext::Load)?; - } - Stmt::AnnAssign(StmtAnnAssign { - target, - annotation, - value, - simple, - range, - node_index: _, - }) => { - // https://github.com/python/cpython/blob/main/Python/symtable.c#L1233 - match &**target { - Expr::Name(ast::ExprName { id, .. }) => { - let id_str = id.as_str(); - - if *simple { - self.check_name(id_str, ExpressionContext::Store, *range)?; - - self.register_name(id_str, SymbolUsage::AnnotationAssigned, *range)?; - // PEP 649: Register annotate function in module/class scope - let current_scope = self.tables.last().map(|t| t.typ); - match current_scope { - Some(CompilerScope::Module) => { - self.register_name( - "__annotate__", - SymbolUsage::Assigned, - *range, - )?; - } - Some(CompilerScope::Class) => { - self.register_name( - "__annotate_func__", - SymbolUsage::Assigned, - *range, - )?; - } - _ => {} - } - } else if value.is_some() { - self.check_name(id_str, ExpressionContext::Store, *range)?; - self.register_name(id_str, SymbolUsage::Assigned, *range)?; + Stmt::Expr(StmtExpr { value, .. }) => { + self.scan_expression(value, ExpressionContext::Load)? + } + Stmt::If(StmtIf { + test, + body, + elif_else_clauses, + .. + }) => { + self.scan_expression(test, ExpressionContext::Load)?; + // PEP 649: Track conditional block for annotations + let saved_in_conditional_block = self.in_conditional_block; + self.in_conditional_block = true; + self.scan_statements(body)?; + for elif in elif_else_clauses { + if let Some(test) = &elif.test { + self.scan_expression(test, ExpressionContext::Load)?; } + self.scan_statements(&elif.body)?; } - _ => { - self.scan_expression(target, ExpressionContext::Store)?; - } - } - self.scan_ann_assign_annotation(annotation)?; - if let Some(value) = value { - self.scan_expression(value, ExpressionContext::Load)?; + self.in_conditional_block = saved_in_conditional_block; } - } - Stmt::With(StmtWith { items, body, .. }) => { - for item in items { - self.scan_expression(&item.context_expr, ExpressionContext::Load)?; - if let Some(expression) = &item.optional_vars { - self.scan_expression(expression, ExpressionContext::Store)?; - } - } - // PEP 649: Track conditional block for annotations - let saved_in_conditional_block = self.in_conditional_block; - self.in_conditional_block = true; - self.scan_statements(body)?; - self.in_conditional_block = saved_in_conditional_block; - } - Stmt::Try(StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - // PEP 649: Track conditional block for annotations - let saved_in_conditional_block = self.in_conditional_block; - self.in_conditional_block = true; - self.scan_statements(body)?; - // Preserve source-order symbol analysis so `global`/`nonlocal` - // semantics match CPython, but reorder child scope storage to - // match the codegen order for plain try/except/else. - let body_subtables_len = self.tables.last().unwrap().sub_tables.len(); - for handler in handlers { - let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - type_, - name, - body, - .. - }) = &handler; - if let Some(expression) = type_ { - self.scan_expression(expression, ExpressionContext::Load)?; + Stmt::For(StmtFor { + target, + iter, + body, + orelse, + is_async, + .. + }) => { + if *is_async && self.allows_top_level_await() { + self.tables.last_mut().unwrap().is_coroutine = true; } - if let Some(name) = name { - self.register_ident(name, SymbolUsage::Assigned)?; + if *is_async && !self.tables.last().unwrap().is_coroutine { + return Err(SymbolTableError { + error: "'async for' outside async function".to_owned(), + location: Some(self.source_file.to_source_code().source_location( + statement.range().start(), + PositionEncoding::Utf8, + )), + }); } + self.scan_expression(target, ExpressionContext::Store)?; + self.scan_expression(iter, ExpressionContext::Load)?; + // PEP 649: Track conditional block for annotations + let saved_in_conditional_block = self.in_conditional_block; + self.in_conditional_block = true; self.scan_statements(body)?; - } - if finalbody.is_empty() { - let handler_subtables = self - .tables - .last_mut() - .unwrap() - .sub_tables - .split_off(body_subtables_len); self.scan_statements(orelse)?; - self.tables - .last_mut() - .unwrap() - .sub_tables - .extend(handler_subtables); - } else { + self.in_conditional_block = saved_in_conditional_block; + } + Stmt::While(StmtWhile { + test, body, orelse, .. + }) => { + self.scan_expression(test, ExpressionContext::Load)?; + // PEP 649: Track conditional block for annotations + let saved_in_conditional_block = self.in_conditional_block; + self.in_conditional_block = true; + self.scan_statements(body)?; self.scan_statements(orelse)?; + self.in_conditional_block = saved_in_conditional_block; } - self.scan_statements(finalbody)?; - self.in_conditional_block = saved_in_conditional_block; - } - Stmt::Match(StmtMatch { subject, cases, .. }) => { - self.scan_expression(subject, ExpressionContext::Load)?; - // PEP 649: Track conditional block for annotations - let saved_in_conditional_block = self.in_conditional_block; - self.in_conditional_block = true; - for case in cases { - self.scan_pattern(&case.pattern)?; - if let Some(guard) = &case.guard { - self.scan_expression(guard, ExpressionContext::Load)?; + Stmt::Break(_) | Stmt::Continue(_) | Stmt::Pass(_) => { + // No symbols here. + } + Stmt::Import(StmtImport { names, .. }) + | Stmt::ImportFrom(StmtImportFrom { names, .. }) => { + for name in names { + if let Some(alias) = &name.asname { + // `import my_module as my_alias` + self.register_name( + alias.as_str(), + SymbolUsage::Imported, + name.name.range, + )?; + } else if name.name.as_str() == "*" { + // Star imports are only allowed at module level + if self.tables.last().unwrap().typ != CompilerScope::Module { + return Err(SymbolTableError { + error: "import * only allowed at module level".to_string(), + location: Some( + self.source_file.to_source_code().source_location( + name.name.range.start(), + PositionEncoding::Utf8, + ), + ), + }); + } + // Don't register star imports as symbols + } else { + // `import module` or `from x import name` + let imported_name = name.name.split('.').next().unwrap(); + self.check_name( + imported_name, + ExpressionContext::Store, + name.name.range, + )?; + self.register_name( + imported_name, + SymbolUsage::Imported, + name.name.range, + )?; + } } - self.scan_statements(&case.body)?; } - self.in_conditional_block = saved_in_conditional_block; - } - Stmt::Raise(StmtRaise { exc, cause, .. }) => { - if let Some(expression) = exc { - self.scan_expression(expression, ExpressionContext::Load)?; + Stmt::Return(StmtReturn { value, .. }) => { + if let Some(expression) = value { + self.scan_expression(expression, ExpressionContext::Load)?; + self.tables.last_mut().unwrap().returns_value = true; + } } - if let Some(expression) = cause { - self.scan_expression(expression, ExpressionContext::Load)?; + Stmt::Assert(StmtAssert { test, msg, .. }) => { + self.scan_expression(test, ExpressionContext::Load)?; + if let Some(expression) = msg { + self.scan_expression(expression, ExpressionContext::Load)?; + } } - } - Stmt::TypeAlias(StmtTypeAlias { - name, - value, - type_params, - .. - }) => { - let Some(name_expr) = name.as_name_expr() else { + Stmt::Delete(StmtDelete { targets, .. }) => { + self.scan_expressions(targets, ExpressionContext::Delete)?; + } + Stmt::Assign(StmtAssign { targets, value, .. }) => { + self.scan_expressions(targets, ExpressionContext::Store)?; + self.scan_expression(value, ExpressionContext::Load)?; + } + Stmt::AugAssign(StmtAugAssign { target, value, .. }) => { + self.scan_expression(target, ExpressionContext::Store)?; + self.scan_expression(value, ExpressionContext::Load)?; + } + Stmt::AnnAssign(StmtAnnAssign { + target, + annotation, + value, + simple, + range, + node_index: _, + }) => { + self.tables.last_mut().unwrap().annotations_used = true; + // https://github.com/python/cpython/blob/main/Python/symtable.c#L1233 + match &**target { + Expr::Name(ast::ExprName { + id, + range: target_range, + .. + }) => { + let id_str = id.as_str(); + + if *simple { + let existing_flags = self.tables.last().and_then(|table| { + let name = maybe_mangle_name( + self.class_name.as_deref(), + table.mangled_names.as_ref(), + id_str, + ); + table.symbols.get(name.as_ref()).map(|symbol| symbol.flags) + }); + if self + .tables + .last() + .is_some_and(|table| table.typ != CompilerScope::Module) + && let Some(flags) = existing_flags + && flags.intersects(SymbolFlags::GLOBAL | SymbolFlags::NONLOCAL) + { + let usage = if flags.contains(SymbolFlags::GLOBAL) { + "global" + } else { + "nonlocal" + }; + return Err(SymbolTableError { + error: format!( + "annotated name '{id_str}' can't be {usage}" + ), + location: Some( + self.source_file.to_source_code().source_location( + range.start(), + PositionEncoding::Utf8, + ), + ), + }); + } + + self.register_name( + id_str, + SymbolUsage::AnnotationAssigned, + *target_range, + )?; + // PEP 649: Register annotate function in module/class scope + let current_scope = self.tables.last().map(|t| t.typ); + match current_scope { + Some(CompilerScope::Module) => { + self.register_name( + "__annotate__", + SymbolUsage::Assigned, + *range, + )?; + } + Some(CompilerScope::Class) => { + self.register_name( + "__annotate_func__", + SymbolUsage::Assigned, + *range, + )?; + } + _ => {} + } + } else if value.is_some() { + self.register_name(id_str, SymbolUsage::Assigned, *target_range)?; + } + } + _ => { + self.scan_expression(target, ExpressionContext::Store)?; + } + } + self.scan_ann_assign_annotation(annotation)?; + if let Some(value) = value { + self.scan_expression(value, ExpressionContext::Load)?; + } + } + Stmt::With(StmtWith { + items, + body, + is_async, + .. + }) => { + if *is_async && self.allows_top_level_await() { + self.tables.last_mut().unwrap().is_coroutine = true; + } + if *is_async && !self.tables.last().unwrap().is_coroutine { + return Err(SymbolTableError { + error: "'async with' outside async function".to_owned(), + location: Some(self.source_file.to_source_code().source_location( + statement.range().start(), + PositionEncoding::Utf8, + )), + }); + } + // PEP 649: Track conditional block for annotations + let saved_in_conditional_block = self.in_conditional_block; + self.in_conditional_block = true; + for item in items { + self.scan_expression(&item.context_expr, ExpressionContext::Load)?; + if let Some(expression) = &item.optional_vars { + self.scan_expression(expression, ExpressionContext::Store)?; + } + } + self.scan_statements(body)?; + self.in_conditional_block = saved_in_conditional_block; + } + Stmt::Try(StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + // PEP 649: Track conditional block for annotations + let saved_in_conditional_block = self.in_conditional_block; + self.in_conditional_block = true; + self.scan_statements(body)?; + for handler in handlers { + let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + type_, + name, + body, + .. + }) = &handler; + if let Some(expression) = type_ { + self.scan_expression(expression, ExpressionContext::Load)?; + } + if let Some(name) = name { + self.register_name( + name.as_str(), + SymbolUsage::Assigned, + handler.range(), + )?; + } + self.scan_statements(body)?; + } + self.scan_statements(orelse)?; + self.scan_statements(finalbody)?; + self.in_conditional_block = saved_in_conditional_block; + } + Stmt::Match(StmtMatch { subject, cases, .. }) => { + self.scan_expression(subject, ExpressionContext::Load)?; + // PEP 649: Track conditional block for annotations + let saved_in_conditional_block = self.in_conditional_block; + self.in_conditional_block = true; + for case in cases { + self.scan_pattern(&case.pattern)?; + if let Some(guard) = &case.guard { + self.scan_expression(guard, ExpressionContext::Load)?; + } + self.scan_statements(&case.body)?; + } + self.in_conditional_block = saved_in_conditional_block; + } + Stmt::Raise(StmtRaise { exc, cause, .. }) => { + if let Some(expression) = exc { + self.scan_expression(expression, ExpressionContext::Load)?; + if let Some(expression) = cause { + self.scan_expression(expression, ExpressionContext::Load)?; + } + } + } + Stmt::TypeAlias(StmtTypeAlias { + name, + value, + type_params, + range, + .. + }) => { + let Some(name_expr) = name.as_name_expr() else { + return Err(SymbolTableError { + error: "type alias expects name".to_owned(), + location: Some( + self.source_file + .to_source_code() + .source_location(name.range().start(), PositionEncoding::Utf8), + ), + }); + }; + let alias_name = name_expr.id.to_string(); + self.scan_expression(name, ExpressionContext::Store)?; + // Check before entering any sub-scopes + let in_class = self + .tables + .last() + .is_some_and(|t| t.typ == CompilerScope::Class); + let is_generic = type_params.is_some(); + if let Some(type_params) = type_params { + self.enter_type_param_block(&alias_name, *range, false, false, false)?; + self.scan_type_params(type_params)?; + } + // Value scope for lazy evaluation + self.enter_scope( + &alias_name, + CompilerScope::TypeAlias, + self.line_index_start(*range), + ); + // Evaluator takes a format parameter + self.register_name(".format", SymbolUsage::Parameter, *range)?; + self.register_name(".format", SymbolUsage::Used, *range)?; + if in_class { + if let Some(table) = self.tables.last_mut() { + table.can_see_class_scope = true; + } + self.register_name("__classdict__", SymbolUsage::Used, value.range())?; + } + self.scan_expression(value, ExpressionContext::Load)?; + self.leave_scope(); + if is_generic { + self.leave_scope(); + } + } + Stmt::IpyEscapeCommand(stmt) => { return Err(SymbolTableError { - error: "type alias expects name".to_owned(), + error: "invalid syntax".to_owned(), location: Some( self.source_file .to_source_code() - .source_location(name.range().start(), PositionEncoding::Utf8), + .source_location(stmt.range.start(), PositionEncoding::Utf8), ), }); - }; - let alias_name = name_expr.id.to_string(); - let was_in_type_alias = self.in_type_alias; - self.in_type_alias = true; - // Check before entering any sub-scopes - let in_class = self - .tables - .last() - .is_some_and(|t| t.typ == CompilerScope::Class); - let is_generic = type_params.is_some(); - if let Some(type_params) = type_params { - self.enter_type_param_block( - &alias_name, - self.line_index_start(type_params.range), - false, - false, - false, - )?; - self.scan_type_params(type_params)?; - } - // Value scope for lazy evaluation - self.enter_scope( - &alias_name, - CompilerScope::Annotation, - self.line_index_start(value.range()), - ); - // Evaluator takes a format parameter - self.register_name(".format", SymbolUsage::Parameter, TextRange::default())?; - if in_class { - if let Some(table) = self.tables.last_mut() { - table.can_see_class_scope = true; - } - self.register_name("__classdict__", SymbolUsage::Used, TextRange::default())?; - } - self.scan_expression(value, ExpressionContext::Load)?; - self.leave_scope(); - if is_generic { - self.leave_scope(); } - self.in_type_alias = was_in_type_alias; - self.scan_expression(name, ExpressionContext::Store)?; } - Stmt::IpyEscapeCommand(_) => todo!(), - } - Ok(()) + Ok(()) + })(); + self.recursion_depth -= 1; + result } fn scan_decorators( @@ -1962,389 +2309,541 @@ impl SymbolTableBuilder { expression: &ast::Expr, context: ExpressionContext, ) -> SymbolTableResult { - use ast::*; - - // Check for expressions not allowed in certain contexts - // (type parameters, annotations, type aliases, TypeVar bounds/defaults) - if let Some(keyword) = match expression { - Expr::Yield(_) | Expr::YieldFrom(_) => Some("yield"), - Expr::Await(_) => Some("await"), - Expr::Named(_) => Some("named"), - _ => None, - } { - // Determine the context name for the error message - // scope_info takes precedence (e.g., "a TypeVar bound") - let context_name = if let Some(scope_info) = self.scope_info { - Some(scope_info) - } else if let Some(table) = self.tables.last() - && table.typ == CompilerScope::TypeParams - { - Some("a type parameter") - } else if self.in_annotation { - Some("an annotation") - } else if self.in_type_alias { - Some("a type alias") - } else { - None - }; + if self.recursion_depth >= self.recursion_limit { + return Err(SymbolTableError { + error: RECURSION_ERROR.to_owned(), + location: None, + }); + } + self.recursion_depth += 1; + let result = (|| { + use ast::*; - if let Some(context_name) = context_name { - return Err(SymbolTableError { - error: format!("{keyword} expression cannot be used within {context_name}"), - location: Some( - self.source_file - .to_source_code() - .source_location(expression.range().start(), PositionEncoding::Utf8), - ), - }); + if self + .ast_constant_overrides + .as_ref() + .is_some_and(|overrides| { + let index = ast::HasNodeIndex::node_index(expression).load(); + index != ast::NodeIndex::NONE && overrides.contains_key(&index) + }) + { + return Ok(()); } - } - match expression { - Expr::BinOp(ExprBinOp { - left, - right, - range: _, - .. - }) => { - self.scan_expression(left, context)?; - self.scan_expression(right, context)?; - } - Expr::BoolOp(ExprBoolOp { - values, range: _, .. - }) => { - self.scan_expressions(values, context)?; - } - Expr::Compare(ExprCompare { - left, - comparators, - range: _, - .. - }) => { - self.scan_expression(left, context)?; - self.scan_expressions(comparators, context)?; - } - Expr::Subscript(ExprSubscript { - value, - slice, - range: _, - .. - }) => { - self.scan_expression(value, ExpressionContext::Load)?; - self.scan_expression(slice, ExpressionContext::Load)?; - } - Expr::Attribute(ExprAttribute { - value, attr, range, .. - }) => { - self.check_name(attr.as_str(), context, *range)?; - self.scan_expression(value, ExpressionContext::Load)?; - } - Expr::Dict(ExprDict { - items, - node_index: _, - range: _, - }) => { - for item in items { - if let Some(key) = &item.key { - self.scan_expression(key, context)?; - } - self.scan_expression(&item.value, context)?; - } - } - Expr::Await(ExprAwait { - value, - node_index: _, - range: _, - }) => { - self.scan_expression(value, context)?; - self.tables.last_mut().unwrap().is_coroutine = true; + // Check for expressions not allowed in certain contexts + // (type parameters, annotations, type aliases, TypeVar bounds/defaults) + if let Some(keyword) = match expression { + Expr::Yield(_) | Expr::YieldFrom(_) => Some("yield"), + Expr::Await(_) => Some("await"), + Expr::Named(_) => Some("named"), + _ => None, + } { + // Determine the context name for the error message from the + // current symbol table entry, matching CPython's ste_type checks. + let current_is_comprehension = self + .tables + .last() + .is_some_and(|table| table.typ == CompilerScope::Comprehension); + let context_name = if keyword == "named" && current_is_comprehension { + None + } else if let Some(table) = self.tables.last() { + match table.typ { + CompilerScope::Annotation => Some("an annotation"), + CompilerScope::TypeVariable => table.scope_info, + CompilerScope::TypeAlias => Some("a type alias"), + CompilerScope::TypeParams => Some("the definition of a generic"), + _ => None, + } + } else { + None + }; + + if let Some(context_name) = context_name { + return Err(SymbolTableError { + error: format!("{keyword} expression cannot be used within {context_name}"), + location: Some( + self.source_file.to_source_code().source_location( + expression.range().start(), + PositionEncoding::Utf8, + ), + ), + }); + } } - Expr::Yield(ExprYield { - value, - node_index: _, - range: _, - }) => { - self.tables.last_mut().unwrap().is_generator = true; - if let Some(expression) = value { - self.scan_expression(expression, context)?; - } - } - Expr::YieldFrom(ExprYieldFrom { - value, - node_index: _, - range: _, - }) => { - self.tables.last_mut().unwrap().is_generator = true; - self.scan_expression(value, context)?; - } - Expr::UnaryOp(ExprUnaryOp { - operand, range: _, .. - }) => { - self.scan_expression(operand, context)?; - } - Expr::Starred(ExprStarred { - value, range: _, .. - }) => { - self.scan_expression(value, context)?; - } - Expr::Tuple(ExprTuple { elts, range: _, .. }) - | Expr::Set(ExprSet { elts, range: _, .. }) - | Expr::List(ExprList { elts, range: _, .. }) => { - self.scan_expressions(elts, context)?; - } - Expr::Slice(ExprSlice { - lower, - upper, - step, - node_index: _, - range: _, - }) => { - if let Some(lower) = lower { - self.scan_expression(lower, context)?; - } - if let Some(upper) = upper { - self.scan_expression(upper, context)?; - } - if let Some(step) = step { - self.scan_expression(step, context)?; - } - } - Expr::Generator(ExprGenerator { - elt, - generators, - range, - .. - }) => { - let was_in_iter_def_exp = self.in_iter_def_exp; - if context == ExpressionContext::IterDefinitionExp { - self.in_iter_def_exp = true; - } - // Generator expression - is_generator = true - self.scan_comprehension("", elt, None, generators, *range, true)?; - self.in_iter_def_exp = was_in_iter_def_exp; - } - Expr::ListComp(ExprListComp { - elt, - generators, - range, - node_index: _, - }) => { - let was_in_iter_def_exp = self.in_iter_def_exp; - if context == ExpressionContext::IterDefinitionExp { - self.in_iter_def_exp = true; - } - // List comprehension - is_generator = false (can be inlined) - self.scan_comprehension("", elt, None, generators, *range, false)?; - self.in_iter_def_exp = was_in_iter_def_exp; - } - Expr::SetComp(ExprSetComp { - elt, - generators, - range, - node_index: _, - }) => { - let was_in_iter_def_exp = self.in_iter_def_exp; - if context == ExpressionContext::IterDefinitionExp { - self.in_iter_def_exp = true; - } - // Set comprehension - is_generator = false (can be inlined) - self.scan_comprehension("", elt, None, generators, *range, false)?; - self.in_iter_def_exp = was_in_iter_def_exp; - } - Expr::DictComp(ExprDictComp { - key, - value, - generators, - range, - node_index: _, - }) => { - let was_in_iter_def_exp = self.in_iter_def_exp; - if context == ExpressionContext::IterDefinitionExp { - self.in_iter_def_exp = true; - } - // Dict comprehension - is_generator = false (can be inlined) - self.scan_comprehension("", key, Some(value), generators, *range, false)?; - self.in_iter_def_exp = was_in_iter_def_exp; - } - Expr::Call(ExprCall { - func, - arguments, - node_index: _, - range: _, - }) => { - match context { - ExpressionContext::IterDefinitionExp => { - self.scan_expression(func, ExpressionContext::IterDefinitionExp)?; + + match expression { + Expr::BinOp(ExprBinOp { + left, + right, + range: _, + .. + }) => { + self.scan_expression(left, context)?; + self.scan_expression(right, context)?; + } + Expr::BoolOp(ExprBoolOp { + values, range: _, .. + }) => { + self.scan_expressions(values, context)?; + } + Expr::Compare(ExprCompare { + left, + comparators, + range: _, + .. + }) => { + self.scan_expression(left, context)?; + self.scan_expressions(comparators, context)?; + } + Expr::Subscript(ExprSubscript { + value, + slice, + range: _, + .. + }) => { + self.scan_expression(value, ExpressionContext::Load)?; + self.scan_expression(slice, ExpressionContext::Load)?; + } + Expr::Attribute(ExprAttribute { + value, attr, range, .. + }) => { + self.check_name(attr.as_str(), context, *range)?; + self.scan_expression(value, ExpressionContext::Load)?; + } + Expr::Dict(ExprDict { + items, + node_index: _, + range: _, + }) => { + for item in items { + if let Some(key) = &item.key { + self.scan_expression(key, context)?; + } } - _ => { - self.scan_expression(func, ExpressionContext::Load)?; + for item in items { + self.scan_expression(&item.value, context)?; } } - - self.scan_expressions(&arguments.args, ExpressionContext::Load)?; - for keyword in &arguments.keywords { - if let Some(arg) = &keyword.arg { - self.check_name(arg.as_str(), ExpressionContext::Store, keyword.range)?; + Expr::Await(ExprAwait { + value, + node_index: _, + range: _, + }) => { + let current_scope = self.tables.last().unwrap().typ; + if !self.allows_top_level_await() + && !Self::is_function_like_scope(current_scope) + { + return Err(SymbolTableError { + error: "'await' outside function".to_owned(), + location: Some(self.source_file.to_source_code().source_location( + expression.range().start(), + PositionEncoding::Utf8, + )), + }); + } + if current_scope != CompilerScope::AsyncFunction + && current_scope != CompilerScope::Comprehension + && !self.allows_top_level_await() + { + return Err(SymbolTableError { + error: "'await' outside async function".to_owned(), + location: Some(self.source_file.to_source_code().source_location( + expression.range().start(), + PositionEncoding::Utf8, + )), + }); } - self.scan_expression(&keyword.value, ExpressionContext::Load)?; + self.scan_expression(value, context)?; + self.tables.last_mut().unwrap().is_coroutine = true; } - } - Expr::Name(ExprName { id, range, .. }) => { - let id = id.as_str(); - - self.check_name(id, context, *range)?; - - if !self.in_unevaluated_annotation { - // Determine the contextual usage of this symbol: + Expr::Yield(ExprYield { + value, + node_index: _, + range: _, + }) => { + if let Some(expression) = value { + self.scan_expression(expression, context)?; + } + self.tables.last_mut().unwrap().is_generator = true; + if let Some(context_name) = self.comprehension_yield_context + && self + .tables + .last() + .is_some_and(|table| table.typ == CompilerScope::Comprehension) + { + return Err(SymbolTableError { + error: format!("'yield' inside {context_name}"), + location: Some(self.source_file.to_source_code().source_location( + expression.range().start(), + PositionEncoding::Utf8, + )), + }); + } + } + Expr::YieldFrom(ExprYieldFrom { + value, + node_index: _, + range: _, + }) => { + self.scan_expression(value, context)?; + self.tables.last_mut().unwrap().is_generator = true; + if let Some(context_name) = self.comprehension_yield_context + && self + .tables + .last() + .is_some_and(|table| table.typ == CompilerScope::Comprehension) + { + return Err(SymbolTableError { + error: format!("'yield' inside {context_name}"), + location: Some(self.source_file.to_source_code().source_location( + expression.range().start(), + PositionEncoding::Utf8, + )), + }); + } + } + Expr::UnaryOp(ExprUnaryOp { + operand, range: _, .. + }) => { + self.scan_expression(operand, context)?; + } + Expr::Starred(ExprStarred { + value, range: _, .. + }) => { + self.scan_expression(value, context)?; + } + Expr::Tuple(ExprTuple { elts, range: _, .. }) + | Expr::Set(ExprSet { elts, range: _, .. }) + | Expr::List(ExprList { elts, range: _, .. }) => { + self.scan_expressions(elts, context)?; + } + Expr::Slice(ExprSlice { + lower, + upper, + step, + node_index: _, + range: _, + }) => { + if let Some(lower) = lower { + self.scan_expression(lower, context)?; + } + if let Some(upper) = upper { + self.scan_expression(upper, context)?; + } + if let Some(step) = step { + self.scan_expression(step, context)?; + } + } + Expr::Generator(ExprGenerator { + elt, + generators, + range, + .. + }) => { + let was_in_iter_def_exp = self.in_iter_def_exp; + if context == ExpressionContext::IterDefinitionExp { + self.in_iter_def_exp = true; + } + // Generator expression - is_generator = true + self.scan_comprehension("", elt, None, generators, *range, true)?; + self.in_iter_def_exp = was_in_iter_def_exp; + } + Expr::ListComp(ExprListComp { + elt, + generators, + range, + node_index: _, + }) => { + let was_in_iter_def_exp = self.in_iter_def_exp; + if context == ExpressionContext::IterDefinitionExp { + self.in_iter_def_exp = true; + } + // List comprehension - is_generator = false (can be inlined) + self.scan_comprehension("", elt, None, generators, *range, false)?; + self.in_iter_def_exp = was_in_iter_def_exp; + } + Expr::SetComp(ExprSetComp { + elt, + generators, + range, + node_index: _, + }) => { + let was_in_iter_def_exp = self.in_iter_def_exp; + if context == ExpressionContext::IterDefinitionExp { + self.in_iter_def_exp = true; + } + // Set comprehension - is_generator = false (can be inlined) + self.scan_comprehension("", elt, None, generators, *range, false)?; + self.in_iter_def_exp = was_in_iter_def_exp; + } + Expr::DictComp(ExprDictComp { + key, + value, + generators, + range, + node_index: _, + }) => { + let was_in_iter_def_exp = self.in_iter_def_exp; + if context == ExpressionContext::IterDefinitionExp { + self.in_iter_def_exp = true; + } + // Dict comprehension - is_generator = false (can be inlined) + let Some(key) = key.as_deref() else { + self.scan_expression(value, ExpressionContext::Load)?; + self.in_iter_def_exp = was_in_iter_def_exp; + return Ok(()); + }; + self.scan_comprehension( + "", + key, + Some(value), + generators, + *range, + false, + )?; + self.in_iter_def_exp = was_in_iter_def_exp; + } + Expr::Call(ExprCall { + func, + arguments, + node_index: _, + range: _, + }) => { match context { - ExpressionContext::Delete => { - self.register_name(id, SymbolUsage::Assigned, *range)?; - self.register_name(id, SymbolUsage::Used, *range)?; + ExpressionContext::IterDefinitionExp => { + self.scan_expression(func, ExpressionContext::IterDefinitionExp)?; } - ExpressionContext::Load | ExpressionContext::IterDefinitionExp => { - self.register_name(id, SymbolUsage::Used, *range)?; + _ => { + self.scan_expression(func, ExpressionContext::Load)?; } - ExpressionContext::Store => { - self.register_name(id, SymbolUsage::Assigned, *range)?; - } - ExpressionContext::Iter => { - self.register_name(id, SymbolUsage::Iter, *range)?; + } + + self.scan_expressions(&arguments.args, ExpressionContext::Load)?; + for keyword in &arguments.keywords { + if let Some(arg) = &keyword.arg { + self.check_name(arg.as_str(), ExpressionContext::Store, keyword.range)?; } } - // Interesting stuff about the __class__ variable: - // https://docs.python.org/3/reference/datamodel.html?highlight=__class__#creating-the-class-object - if context == ExpressionContext::Load - && matches!( - self.tables.last().unwrap().typ, - CompilerScope::Function | CompilerScope::AsyncFunction - ) - && id == "super" + for keyword in &arguments.keywords { + self.scan_expression(&keyword.value, ExpressionContext::Load)?; + } + } + Expr::Name(ExprName { id, range, .. }) => { + let id = id.as_str(); + + self.check_name(id, context, *range)?; + + if !self + .tables + .last() + .is_some_and(|table| table.in_unevaluated_annotation) { - self.register_name("__class__", SymbolUsage::Used, *range)?; + // Determine the contextual usage of this symbol: + match context { + ExpressionContext::Delete => { + self.register_name(id, SymbolUsage::Assigned, *range)?; + } + ExpressionContext::Load | ExpressionContext::IterDefinitionExp => { + self.register_name(id, SymbolUsage::Used, *range)?; + } + ExpressionContext::Store => { + self.register_name(id, SymbolUsage::Assigned, *range)?; + } + ExpressionContext::Iter => { + self.register_name(id, SymbolUsage::Iter, *range)?; + } + } + // Interesting stuff about the __class__ variable: + // https://docs.python.org/3/reference/datamodel.html?highlight=__class__#creating-the-class-object + if context == ExpressionContext::Load + && Self::is_function_like_scope(self.tables.last().unwrap().typ) + && id == "super" + { + self.register_name("__class__", SymbolUsage::Used, *range)?; + } } } - } - Expr::Lambda(ExprLambda { - body, - parameters, - node_index: _, - range: _, - }) => { - if let Some(parameters) = parameters { - self.enter_scope_with_parameters( - "lambda", - parameters, - self.line_index_start(expression.range()), - false, // lambdas have no return annotation - CompilerScope::Lambda, - false, // don't skip defaults - )?; - } else { - self.enter_scope( - "lambda", - CompilerScope::Lambda, - self.line_index_start(expression.range()), - ); + Expr::Lambda(ExprLambda { + body, + parameters, + node_index: _, + range: _, + }) => { + let was_in_iter_def_exp = self.in_iter_def_exp; + if let Some(parameters) = parameters { + if was_in_iter_def_exp { + self.scan_parameter_defaults(parameters)?; + } + self.enter_scope_with_parameters( + "lambda", + parameters, + self.line_index_start(expression.range()), + None, // lambdas have no return annotation + CompilerScope::Lambda, + was_in_iter_def_exp, + false, + )?; + } else { + self.enter_scope( + "lambda", + CompilerScope::Lambda, + self.line_index_start(expression.range()), + ); + } + self.scan_expression(body, ExpressionContext::Load)?; + self.in_iter_def_exp = was_in_iter_def_exp; + self.leave_scope(); } - match context { - ExpressionContext::IterDefinitionExp => { - self.scan_expression(body, ExpressionContext::IterDefinitionExp)?; + Expr::FString(ExprFString { + node_index, value, .. + }) => { + if let Some(joined_str) = + self.public_ast_joined_str_override_by_index(node_index.load()) + { + for expr in &joined_str.values { + self.scan_expression(expr, ExpressionContext::Load)?; + } + return Ok(()); } - _ => { - self.scan_expression(body, ExpressionContext::Load)?; + for expr in value.elements().filter_map(|x| x.as_interpolation()) { + self.scan_expression(&expr.expression, ExpressionContext::Load)?; + if let Some(formatted_value) = self + .public_ast_formatted_value_override_by_index(expr.node_index.load()) + && let Some(format_spec) = &formatted_value.format_spec + { + self.scan_expression(format_spec, ExpressionContext::Load)?; + } else if let Some(format_spec) = &expr.format_spec { + for element in format_spec.elements.interpolations() { + self.scan_expression(&element.expression, ExpressionContext::Load)? + } + } } } - self.leave_scope(); - } - Expr::FString(ExprFString { value, .. }) => { - for expr in value.elements().filter_map(|x| x.as_interpolation()) { - self.scan_expression(&expr.expression, ExpressionContext::Load)?; - if let Some(format_spec) = &expr.format_spec { - for element in format_spec.elements.interpolations() { - self.scan_expression(&element.expression, ExpressionContext::Load)? + Expr::TString(tstring) => { + if let Some(template_str) = + self.public_ast_template_str_override_by_index(tstring.node_index.load()) + { + for expr in &template_str.values { + self.scan_expression(expr, ExpressionContext::Load)?; } + return Ok(()); } - } - } - Expr::TString(tstring) => { - // Scan t-string interpolation expressions (similar to f-strings) - for expr in tstring - .value - .elements() - .filter_map(|x| x.as_interpolation()) - { - self.scan_expression(&expr.expression, ExpressionContext::Load)?; - if let Some(format_spec) = &expr.format_spec { - for element in format_spec.elements.interpolations() { - self.scan_expression(&element.expression, ExpressionContext::Load)? + // Scan t-string interpolation expressions (similar to f-strings) + for expr in tstring + .value + .elements() + .filter_map(|x| x.as_interpolation()) + { + self.scan_expression(&expr.expression, ExpressionContext::Load)?; + if let Some(interpolation) = + self.public_ast_interpolation_override_by_index(expr.node_index.load()) + { + if let Some(format_spec) = &interpolation.format_spec { + self.scan_expression(format_spec, ExpressionContext::Load)?; + } + } else if let Some(format_spec) = &expr.format_spec { + for element in format_spec.elements.interpolations() { + self.scan_expression(&element.expression, ExpressionContext::Load)? + } } } } - } - // Constants - Expr::StringLiteral(_) - | Expr::BytesLiteral(_) - | Expr::NumberLiteral(_) - | Expr::BooleanLiteral(_) - | Expr::NoneLiteral(_) - | Expr::EllipsisLiteral(_) => {} - Expr::IpyEscapeCommand(_) => todo!(), - Expr::If(ExprIf { - test, - body, - orelse, - node_index: _, - range: _, - }) => { - self.scan_expression(test, ExpressionContext::Load)?; - self.scan_expression(body, ExpressionContext::Load)?; - self.scan_expression(orelse, ExpressionContext::Load)?; - } - - Expr::Named(ExprNamed { - target, - value, - range, - node_index: _, - }) => { - // named expressions are not allowed in the definition of - // comprehension iterator definitions (including nested comprehensions) - if context == ExpressionContext::IterDefinitionExp || self.in_iter_def_exp { + // Constants + Expr::StringLiteral(_) + | Expr::BytesLiteral(_) + | Expr::NumberLiteral(_) + | Expr::BooleanLiteral(_) + | Expr::NoneLiteral(_) + | Expr::EllipsisLiteral(_) => {} + Expr::IpyEscapeCommand(expr) => { return Err(SymbolTableError { - error: "assignment expression cannot be used in a comprehension iterable expression".to_string(), - location: Some(self.source_file.to_source_code().source_location(target.range().start(), PositionEncoding::Utf8)), - }); + error: "invalid syntax".to_owned(), + location: Some( + self.source_file + .to_source_code() + .source_location(expr.range.start(), PositionEncoding::Utf8), + ), + }); + } + Expr::If(ExprIf { + test, + body, + orelse, + node_index: _, + range: _, + }) => { + self.scan_expression(test, ExpressionContext::Load)?; + self.scan_expression(body, ExpressionContext::Load)?; + self.scan_expression(orelse, ExpressionContext::Load)?; } - self.scan_expression(value, ExpressionContext::Load)?; + Expr::Named(ExprNamed { + target, + value, + range, + node_index: _, + }) => { + // named expressions are not allowed in the definition of + // comprehension iterator definitions (including nested comprehensions) + if context == ExpressionContext::IterDefinitionExp || self.in_iter_def_exp { + return Err(SymbolTableError { + error: + "assignment expression cannot be used in a comprehension iterable expression" + .to_string(), + location: Some( + self.source_file + .to_source_code() + .source_location(range.start(), PositionEncoding::Utf8), + ), + }); + } - // special handling for assigned identifier in named expressions - // that are used in comprehensions. This required to correctly - // propagate the scope of the named assigned named and not to - // propagate inner names. - if let Expr::Name(ExprName { id, .. }) = &**target { - let id = id.as_str(); - self.check_name(id, ExpressionContext::Store, *range)?; - let table = self.tables.last().unwrap(); - if table.typ == CompilerScope::Comprehension { - self.extend_namedexpr_scope(id, *range)?; - self.register_name( - id, - SymbolUsage::AssignedNamedExprInComprehension, - *range, - )?; + let named_target = if let Expr::Name(ExprName { + id, + range: target_range, + .. + }) = &**target + { + let id = id.as_str(); + self.check_name(id, ExpressionContext::Store, *target_range)?; + let table = self.tables.last().unwrap(); + if table.typ == CompilerScope::Comprehension { + self.extend_namedexpr_scope(id, *target_range)?; + } + Some((id, *target_range)) } else { - // omit one recursion. When the handling of an store changes for - // Identifiers this needs adapted - more forward safe would be - // calling scan_expression directly. - self.register_name(id, SymbolUsage::Assigned, *range)?; + None + }; + + self.scan_expression(value, ExpressionContext::Load)?; + + // special handling for assigned identifier in named expressions + // that are used in comprehensions. This required to correctly + // propagate the scope of the named assigned named and not to + // propagate inner names. + if let Some((id, target_range)) = named_target { + let table = self.tables.last().unwrap(); + if table.typ == CompilerScope::Comprehension { + self.register_name( + id, + SymbolUsage::AssignedNamedExprInComprehension, + target_range, + )?; + } else { + // omit one recursion. When the handling of an store changes for + // Identifiers this needs adapted - more forward safe would be + // calling scan_expression directly. + self.register_name(id, SymbolUsage::Assigned, target_range)?; + } + } else { + self.scan_expression(target, ExpressionContext::Store)?; } - } else { - self.scan_expression(target, ExpressionContext::Store)?; } } - } - Ok(()) + Ok(()) + })(); + self.recursion_depth -= 1; + result } fn scan_comprehension( @@ -2356,26 +2855,15 @@ impl SymbolTableBuilder { range: TextRange, is_generator: bool, ) -> SymbolTableResult { - // Check for async comprehension outside async function - // (list/set/dict comprehensions only, not generator expressions) - let has_async_gen = generators.iter().any(|g| g.is_async); - if has_async_gen && !is_generator && !self.is_in_async_context() { - return Err(SymbolTableError { - error: "asynchronous comprehension outside of an asynchronous function".to_owned(), - location: Some( - self.source_file - .to_source_code() - .source_location(range.start(), PositionEncoding::Utf8), - ), - }); - } - assert!(!generators.is_empty()); let outermost = &generators[0]; // CPython evaluates the outermost iterator in the enclosing scope // before entering the comprehension scope. + let was_in_iter_def_exp = self.in_iter_def_exp; + self.in_iter_def_exp = true; self.scan_expression(&outermost.iter, ExpressionContext::IterDefinitionExp)?; + self.in_iter_def_exp = was_in_iter_def_exp; // Comprehensions are compiled as functions, so create a scope for them: self.enter_scope( @@ -2383,9 +2871,7 @@ impl SymbolTableBuilder { CompilerScope::Comprehension, self.line_index_start(range), ); - // Generator expressions need the is_generator flag - self.tables.last_mut().unwrap().is_generator = is_generator; - if generators.iter().any(|generator| generator.is_async) { + if outermost.is_async { self.tables.last_mut().unwrap().is_coroutine = true; } @@ -2404,6 +2890,15 @@ impl SymbolTableBuilder { // Register the passed argument to the generator function as the name ".0" self.register_name(".0", SymbolUsage::Parameter, range)?; + let saved_comprehension_yield_context = self.comprehension_yield_context; + self.comprehension_yield_context = Some(match scope_name { + "" => "list comprehension", + "" => "set comprehension", + "" => "dict comprehension", + "" => "generator expression", + _ => "comprehension", + }); + self.scan_expression(&outermost.target, ExpressionContext::Iter)?; for if_expr in &outermost.ifs { self.scan_expression(if_expr, ExpressionContext::Load)?; @@ -2413,22 +2908,47 @@ impl SymbolTableBuilder { self.in_comp_inner_loop_target = true; self.scan_expression(&generator.target, ExpressionContext::Iter)?; self.in_comp_inner_loop_target = false; + let was_in_iter_def_exp = self.in_iter_def_exp; + self.in_iter_def_exp = true; self.scan_expression(&generator.iter, ExpressionContext::IterDefinitionExp)?; + self.in_iter_def_exp = was_in_iter_def_exp; for if_expr in &generator.ifs { self.scan_expression(if_expr, ExpressionContext::Load)?; } + if generator.is_async { + self.tables.last_mut().unwrap().is_coroutine = true; + } } if let Some(elt2) = elt2 { self.scan_expression(elt2, ExpressionContext::Load)?; } self.scan_expression(elt1, ExpressionContext::Load)?; + self.tables.last_mut().unwrap().is_generator = is_generator; + self.comprehension_yield_context = saved_comprehension_yield_context; // CPython symtable_handle_comprehension(): non-generator async // comprehensions propagate ste_coroutine to the enclosing scope after // the comprehension block is exited. let propagate_coroutine = self.tables.last().unwrap().is_coroutine && !is_generator; self.leave_scope(); + if propagate_coroutine + && self + .tables + .last() + .is_none_or(|table| table.typ != CompilerScope::Comprehension) + && !self.is_in_async_context() + && !self.allows_top_level_await() + { + return Err(SymbolTableError { + error: "asynchronous comprehension outside of an asynchronous function".to_owned(), + location: Some( + self.source_file + .to_source_code() + .source_location(range.start(), PositionEncoding::Utf8), + ), + }); + } if propagate_coroutine { self.tables.last_mut().unwrap().is_coroutine = true; } @@ -2447,139 +2967,145 @@ impl SymbolTableBuilder { // Bounds/defaults are compiled as annotation scopes in CPython. let in_class = self.tables.last().is_some_and(|t| t.can_see_class_scope); let line_number = self.line_index_start(expr.range()); - self.enter_scope(scope_name, CompilerScope::Annotation, line_number); + self.enter_scope(scope_name, CompilerScope::TypeVariable, line_number); // Evaluator takes a format parameter - self.register_name(".format", SymbolUsage::Parameter, TextRange::default())?; + self.register_name(".format", SymbolUsage::Parameter, expr.range())?; + self.register_name(".format", SymbolUsage::Used, expr.range())?; if in_class { if let Some(table) = self.tables.last_mut() { table.can_see_class_scope = true; } - self.register_name("__classdict__", SymbolUsage::Used, TextRange::default())?; + self.register_name("__classdict__", SymbolUsage::Used, expr.range())?; } - // Set scope_info for better error messages - let old_scope_info = self.scope_info; - self.scope_info = Some(scope_info); + self.tables.last_mut().unwrap().scope_info = Some(scope_info); // Scan the expression in this new scope let result = self.scan_expression(expr, ExpressionContext::Load); - // Restore scope_info and exit the scope - self.scope_info = old_scope_info; self.leave_scope(); result } fn scan_type_params(&mut self, type_params: &ast::TypeParams) -> SymbolTableResult { - // Check for duplicate type parameter names - let mut seen_names: IndexSet<&str> = IndexSet::default(); - // Check for non-default type parameter after default type parameter - let mut default_seen = false; + // CPython visits each type parameter as: register name, scan bound, scan default. for type_param in &type_params.type_params { - let (name, range, has_default) = match type_param { - ast::TypeParam::TypeVar(tv) => (tv.name.as_str(), tv.range, tv.default.is_some()), - ast::TypeParam::ParamSpec(ps) => (ps.name.as_str(), ps.range, ps.default.is_some()), - ast::TypeParam::TypeVarTuple(tvt) => { - (tvt.name.as_str(), tvt.range, tvt.default.is_some()) - } - }; - if !seen_names.insert(name) { - return Err(SymbolTableError { - error: format!("duplicate type parameter '{name}'"), - location: Some( - self.source_file - .to_source_code() - .source_location(range.start(), PositionEncoding::Utf8), - ), - }); - } - if has_default { - default_seen = true; - } else if default_seen { + if self.recursion_depth >= self.recursion_limit { return Err(SymbolTableError { - error: format!( - "non-default type parameter '{name}' follows default type parameter" - ), - location: Some( - self.source_file - .to_source_code() - .source_location(range.start(), PositionEncoding::Utf8), - ), + error: RECURSION_ERROR.to_owned(), + location: None, }); } - } - - // Register .type_params as a type parameter (automatically becomes cell variable) - self.register_name(".type_params", SymbolUsage::TypeParam, type_params.range)?; - - // First register all type parameters - for type_param in &type_params.type_params { - match type_param { - ast::TypeParam::TypeVar(ast::TypeParamTypeVar { - name, - bound, - range: type_var_range, - default, - node_index: _, - }) => { - self.register_name(name.as_str(), SymbolUsage::TypeParam, *type_var_range)?; + self.recursion_depth += 1; + let result = (|| { + match type_param { + ast::TypeParam::TypeVar(ast::TypeParamTypeVar { + name, + bound, + range: type_var_range, + default, + node_index: _, + }) => { + self.register_name(name.as_str(), SymbolUsage::TypeParam, *type_var_range)?; + if name.as_str() == "__classdict__" { + return Err(SymbolTableError { + error: format!( + "reserved name '{}' cannot be used for type parameter", + name.as_str() + ), + location: Some(self.source_file.to_source_code().source_location( + type_var_range.start(), + PositionEncoding::Utf8, + )), + }); + } - // Process bound in a separate scope - if let Some(binding) = bound { - let scope_info = if binding.is_tuple_expr() { - "a TypeVar constraint" - } else { - "a TypeVar bound" - }; - self.scan_type_param_bound_or_default(binding, name.as_str(), scope_info)?; - } + // Process bound in a separate scope + if let Some(binding) = bound { + let scope_info = if binding.is_tuple_expr() { + "a TypeVar constraint" + } else { + "a TypeVar bound" + }; + self.scan_type_param_bound_or_default( + binding, + name.as_str(), + scope_info, + )?; + } - // Process default in a separate scope - if let Some(default_value) = default { - self.scan_type_param_bound_or_default( - default_value, - name.as_str(), - "a TypeVar default", - )?; + // Process default in a separate scope + if let Some(default_value) = default { + self.scan_type_param_bound_or_default( + default_value, + name.as_str(), + "a TypeVar default", + )?; + } } - } - ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { - name, - range: param_spec_range, - default, - node_index: _, - }) => { - self.register_name(name, SymbolUsage::TypeParam, *param_spec_range)?; + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { + name, + range: param_spec_range, + default, + node_index: _, + }) => { + self.register_name(name, SymbolUsage::TypeParam, *param_spec_range)?; + if name == "__classdict__" { + return Err(SymbolTableError { + error: format!( + "reserved name '{name}' cannot be used for type parameter" + ), + location: Some(self.source_file.to_source_code().source_location( + param_spec_range.start(), + PositionEncoding::Utf8, + )), + }); + } - // Process default in a separate scope - if let Some(default_value) = default { - self.scan_type_param_bound_or_default( - default_value, - name, - "a ParamSpec default", - )?; + // Process default in a separate scope + if let Some(default_value) = default { + self.scan_type_param_bound_or_default( + default_value, + name, + "a ParamSpec default", + )?; + } } - } - ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { - name, - range: type_var_tuple_range, - default, - node_index: _, - }) => { - self.register_name(name, SymbolUsage::TypeParam, *type_var_tuple_range)?; + ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { + name, + range: type_var_tuple_range, + default, + node_index: _, + }) => { + self.register_name(name, SymbolUsage::TypeParam, *type_var_tuple_range)?; + if name == "__classdict__" { + return Err(SymbolTableError { + error: format!( + "reserved name '{name}' cannot be used for type parameter" + ), + location: Some(self.source_file.to_source_code().source_location( + type_var_tuple_range.start(), + PositionEncoding::Utf8, + )), + }); + } - // Process default in a separate scope - if let Some(default_value) = default { - self.scan_type_param_bound_or_default( - default_value, - name, - "a TypeVarTuple default", - )?; + // Process default in a separate scope + if let Some(default_value) = default { + self.scan_type_param_bound_or_default( + default_value, + name, + "a TypeVarTuple default", + )?; + } } } - } + Ok(()) + })(); + self.recursion_depth -= 1; + result?; } Ok(()) } @@ -2592,51 +3118,86 @@ impl SymbolTableBuilder { } fn scan_pattern(&mut self, pattern: &ast::Pattern) -> SymbolTableResult { - match pattern { - ast::Pattern::MatchValue(ast::PatternMatchValue { value, .. }) => { - self.scan_expression(value, ExpressionContext::Load)? - } - ast::Pattern::MatchSingleton(_) => {} - ast::Pattern::MatchSequence(ast::PatternMatchSequence { patterns, .. }) => { - self.scan_patterns(patterns)? - } - ast::Pattern::MatchMapping(ast::PatternMatchMapping { - keys, - patterns, - rest, - .. - }) => { - self.scan_expressions(keys, ExpressionContext::Load)?; - self.scan_patterns(patterns)?; - if let Some(rest) = rest { - self.register_ident(rest, SymbolUsage::Assigned)?; + if self.recursion_depth >= self.recursion_limit { + return Err(SymbolTableError { + error: RECURSION_ERROR.to_owned(), + location: None, + }); + } + self.recursion_depth += 1; + let result = (|| { + use ast::Pattern::{ + MatchAs, MatchClass, MatchMapping, MatchOr, MatchSequence, MatchSingleton, + MatchStar, MatchValue, + }; + match pattern { + MatchValue(ast::PatternMatchValue { value, .. }) => { + self.scan_expression(value, ExpressionContext::Load)? } - } - ast::Pattern::MatchClass(ast::PatternMatchClass { cls, arguments, .. }) => { - self.scan_expression(cls, ExpressionContext::Load)?; - self.scan_patterns(&arguments.patterns)?; - for kw in &arguments.keywords { - self.scan_pattern(&kw.pattern)?; + MatchSingleton(_) => {} + MatchSequence(ast::PatternMatchSequence { patterns, .. }) => { + self.scan_patterns(patterns)? } - } - ast::Pattern::MatchStar(ast::PatternMatchStar { name, .. }) => { - if let Some(name) = name { - self.register_ident(name, SymbolUsage::Assigned)?; + MatchMapping(ast::PatternMatchMapping { + keys, + patterns, + rest, + .. + }) => { + self.scan_expressions(keys, ExpressionContext::Load)?; + self.scan_patterns(patterns)?; + if let Some(rest) = rest { + if rest.as_str() == "_" { + return Err(SymbolTableError { + error: "invalid syntax".to_owned(), + location: Some( + self.source_file.to_source_code().source_location( + rest.range.start(), + PositionEncoding::Utf8, + ), + ), + }); + } + self.register_name(rest.as_str(), SymbolUsage::Assigned, pattern.range())?; + } } - } - ast::Pattern::MatchAs(ast::PatternMatchAs { pattern, name, .. }) => { - if let Some(pattern) = pattern { - self.scan_pattern(pattern)?; + MatchClass(ast::PatternMatchClass { cls, arguments, .. }) => { + self.scan_expression(cls, ExpressionContext::Load)?; + self.scan_patterns(&arguments.patterns)?; + for kw in &arguments.keywords { + self.check_name( + kw.attr.as_str(), + ExpressionContext::Store, + kw.pattern.range(), + )?; + } + for kw in &arguments.keywords { + self.scan_pattern(&kw.pattern)?; + } } - if let Some(name) = name { - self.register_ident(name, SymbolUsage::Assigned)?; + MatchStar(ast::PatternMatchStar { name, .. }) => { + if let Some(name) = name { + self.register_name(name.as_str(), SymbolUsage::Assigned, pattern.range())?; + } } + MatchAs(ast::PatternMatchAs { + pattern: as_pattern, + name, + .. + }) => { + if let Some(as_pattern) = as_pattern { + self.scan_pattern(as_pattern)?; + } + if let Some(name) = name { + self.register_name(name.as_str(), SymbolUsage::Assigned, pattern.range())?; + } + } + MatchOr(ast::PatternMatchOr { patterns, .. }) => self.scan_patterns(patterns)?, } - ast::Pattern::MatchOr(ast::PatternMatchOr { patterns, .. }) => { - self.scan_patterns(patterns)? - } - } - Ok(()) + Ok(()) + })(); + self.recursion_depth -= 1; + result } /// Scan default parameter values (evaluated in the enclosing scope) @@ -2660,79 +3221,43 @@ impl SymbolTableBuilder { .any(|arg| arg.default.is_some()) } + fn has_positional_defaults(parameters: &ast::Parameters) -> bool { + parameters + .posonlyargs + .iter() + .chain(parameters.args.iter()) + .any(|arg| arg.default.is_some()) + } + + #[expect( + clippy::too_many_arguments, + reason = "keeps parameter/default scanning options explicit at call sites" + )] fn enter_scope_with_parameters( &mut self, name: &str, parameters: &ast::Parameters, line_number: u32, - has_return_annotation: bool, + returns: Option<&ast::Expr>, scope_type: CompilerScope, skip_defaults: bool, + skip_annotations: bool, ) -> SymbolTableResult { // Evaluate eventual default parameters (unless already scanned before type_param_block): if !skip_defaults { self.scan_parameter_defaults(parameters)?; } - // Annotations are scanned in outer scope: - for annotation in parameters - .posonlyargs - .iter() - .chain(parameters.args.iter()) - .chain(parameters.kwonlyargs.iter()) - .filter_map(|arg| arg.parameter.annotation.as_ref()) - { - self.scan_annotation(annotation)?; - } - if let Some(annotation) = parameters - .vararg - .as_ref() - .and_then(|arg| arg.annotation.as_ref()) - { - self.scan_annotation(annotation)?; - } - if let Some(annotation) = parameters - .kwarg - .as_ref() - .and_then(|arg| arg.annotation.as_ref()) - { - self.scan_annotation(annotation)?; + let is_function_scope = matches!( + scope_type, + CompilerScope::Function | CompilerScope::AsyncFunction + ); + if is_function_scope && !skip_annotations { + self.scan_function_annotations(parameters, returns, line_number)?; } - // Check if this function has any annotations (parameter or return) - let has_param_annotations = parameters - .posonlyargs - .iter() - .chain(parameters.args.iter()) - .chain(parameters.kwonlyargs.iter()) - .any(|p| p.parameter.annotation.is_some()) - || parameters - .vararg - .as_ref() - .is_some_and(|p| p.annotation.is_some()) - || parameters - .kwarg - .as_ref() - .is_some_and(|p| p.annotation.is_some()); - - let has_any_annotations = has_param_annotations || has_return_annotation; - - // Take annotation_block if this function has any annotations. - // When in class scope, the class's annotation_block was saved before scanning - // function annotations, so the current annotation_block belongs to this function. - let annotation_block = if has_any_annotations { - self.tables.last_mut().unwrap().annotation_block.take() - } else { - None - }; - self.enter_scope(name, scope_type, line_number); - // Move annotation_block to function scope only if we have one - if let Some(block) = annotation_block { - self.tables.last_mut().unwrap().annotation_block = Some(block); - } - // Fill scope with parameter names: self.scan_parameters(¶meters.posonlyargs)?; self.scan_parameters(¶meters.args)?; @@ -2871,12 +3396,23 @@ impl SymbolTableBuilder { location, }); } - CompilerScope::Annotation => { + CompilerScope::TypeAlias => { + return Err(SymbolTableError { + error: + "assignment expression within a comprehension cannot be used in a type alias" + .to_string(), + location, + }); + } + CompilerScope::TypeVariable => { return Err(SymbolTableError { - error: "named expression cannot be used within an annotation".to_string(), + error: + "assignment expression within a comprehension cannot be used in a TypeVar bound" + .to_string(), location, }); } + CompilerScope::Annotation => {} CompilerScope::Comprehension => unreachable!(), } } @@ -2896,10 +3432,27 @@ impl SymbolTableBuilder { .source_location(range.start(), PositionEncoding::Utf8); let location = Some(location); - // Note: __debug__ checks are handled by check_name function, so no check needed here. + // CPython's symtable_add_def_ctx() runs check_name() for definition + // roles covered by DEF_PARAM | DEF_LOCAL | DEF_IMPORT before adding + // the symbol. Several Rust callers reach register_name() directly + // instead of going through scan_expression(Name), so keep the guard here. + if matches!( + role, + SymbolUsage::Assigned + | SymbolUsage::Imported + | SymbolUsage::AnnotationAssigned + | SymbolUsage::Parameter + | SymbolUsage::AnnotationParameter + | SymbolUsage::AssignedNamedExprInComprehension + | SymbolUsage::Iter + | SymbolUsage::TypeParam + ) { + self.check_name(name, ExpressionContext::Store, range)?; + } let scope_depth = self.tables.len(); let table = self.tables.last_mut().unwrap(); + let current_scope = table.typ; // Add type param names to mangled_names set for selective mangling if matches!(role, SymbolUsage::TypeParam) @@ -2908,6 +3461,7 @@ impl SymbolTableBuilder { set.insert(name.to_owned()); } + let original_name = name; let name = maybe_mangle_name( self.class_name.as_deref(), table.mangled_names.as_ref(), @@ -2932,7 +3486,24 @@ impl SymbolTableBuilder { }); } + if matches!( + role, + SymbolUsage::Parameter | SymbolUsage::AnnotationParameter + ) && flags.contains(SymbolFlags::PARAMETER) + { + return Err(SymbolTableError { + error: format!("duplicate argument '{original_name}' in function definition"), + location, + }); + } + // Role already set.. + if matches!(role, SymbolUsage::TypeParam) && flags.contains(SymbolFlags::TYPE_PARAM) { + return Err(SymbolTableError { + error: format!("duplicate type parameter '{name}'"), + location, + }); + } match role { SymbolUsage::Global if !symbol.is_global() => { if flags.contains(SymbolFlags::PARAMETER) { @@ -2990,6 +3561,20 @@ impl SymbolTableBuilder { }); } } + SymbolUsage::AnnotationAssigned + if current_scope != CompilerScope::Module + && flags.intersects(SymbolFlags::GLOBAL | SymbolFlags::NONLOCAL) => + { + let usage = if flags.contains(SymbolFlags::GLOBAL) { + "global" + } else { + "nonlocal" + }; + return Err(SymbolTableError { + error: format!("annotated name '{name}' can't be {usage}"), + location, + }); + } _ => { // Ok? } @@ -3045,24 +3630,9 @@ impl SymbolTableBuilder { } SymbolUsage::Assigned => { flags.insert(SymbolFlags::ASSIGNED); - // Local variables (assigned) are added to varnames if they are local scope - // and not already in varnames - if symbol.scope == SymbolScope::Local { - let name_str = symbol.name.clone(); - if !self.current_varnames.contains(&name_str) { - self.current_varnames.push(name_str); - } - } } SymbolUsage::AssignedNamedExprInComprehension => { flags.insert(SymbolFlags::ASSIGNED | SymbolFlags::ASSIGNED_IN_COMPREHENSION); - // Named expressions in comprehensions might also be locals - if symbol.scope == SymbolScope::Local { - let name_str = symbol.name.clone(); - if !self.current_varnames.contains(&name_str) { - self.current_varnames.push(name_str); - } - } } SymbolUsage::Global => { symbol.scope = SymbolScope::GlobalExplicit; @@ -3072,7 +3642,7 @@ impl SymbolTableBuilder { flags.insert(SymbolFlags::REFERENCED); } SymbolUsage::Iter => { - flags.insert(SymbolFlags::ITER); + flags.insert(SymbolFlags::ITER | SymbolFlags::COMP_ITER); } SymbolUsage::TypeParam => { flags.insert(SymbolFlags::ASSIGNED | SymbolFlags::TYPE_PARAM); @@ -3134,7 +3704,27 @@ pub(crate) fn maybe_mangle_name<'a>( #[cfg(test)] mod tests { - use super::mangle_name; + use super::{CompilerScope, SymbolFlags, SymbolTable, mangle_name}; + use rustpython_compiler_core::SourceFileBuilder; + + fn scan_source(source: &str) -> SymbolTable { + scan_source_result(source).unwrap() + } + + fn scan_source_result(source: &str) -> Result { + let source_file = SourceFileBuilder::new("source_path", source).finish(); + let parsed = ruff_python_parser::parse( + source_file.source_text(), + ruff_python_parser::Mode::Module.into(), + ) + .unwrap() + .into_syntax(); + let module = match parsed { + ruff_python_ast::Mod::Module(module) => module, + _ => unreachable!(), + }; + SymbolTable::scan_program(&module, source_file) + } #[test] fn mangle_name_leaves_private_name_in_underscore_only_class() { @@ -3148,4 +3738,470 @@ mod tests { assert_eq!(mangle_name(Some("_a"), "__a"), "_a__a"); assert_eq!(mangle_name(Some("__a"), "__a"), "_a__a"); } + + #[test] + fn duplicate_parameter_check_uses_mangled_name_like_cpython() { + let err = scan_source_result("class C:\n def f(__x, _C__x):\n pass\n") + .expect_err("expected duplicate argument after class-private mangling"); + + assert_eq!( + err.error, + "duplicate argument '_C__x' in function definition" + ); + } + + #[test] + fn super_name_marks_class_use_in_lambda_scope_like_cpython() { + let table = scan_source("def f():\n return lambda: super()\n"); + let function = table + .sub_tables + .iter() + .find(|table| table.name == "f") + .expect("missing function scope"); + let lambda = function + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::Lambda) + .expect("missing lambda scope"); + + assert!( + lambda.lookup("__class__").is_some(), + "CPython symtable Name_kind treats super as a __class__ use in any function-like scope" + ); + } + + #[test] + fn comprehension_iteration_target_sets_comp_iter_flag_like_cpython() { + let table = scan_source("result = [i for i in xs]\n"); + let comprehension = table + .inlined_comprehension_blocks + .iter() + .find(|table| table.typ == CompilerScope::Comprehension) + .expect("missing comprehension scope"); + let symbol = comprehension + .lookup("i") + .expect("missing comprehension iteration target"); + + assert!( + symbol.flags.contains(SymbolFlags::COMP_ITER), + "CPython symtable_add_def_helper sets DEF_COMP_ITER on comprehension iteration targets" + ); + } + + #[test] + fn inlined_comprehension_children_are_spliced_like_cpython() { + let table = scan_source("result = [(lambda: i) for i in xs]\n"); + + assert!( + !table + .sub_tables + .iter() + .any(|table| table.typ == CompilerScope::Comprehension), + "CPython removes inlined comprehension entries from ste_children" + ); + assert!( + table + .sub_tables + .iter() + .any(|table| table.typ == CompilerScope::Lambda), + "CPython splices children of inlined comprehensions into the parent children list" + ); + + let comprehension = table + .inlined_comprehension_blocks + .iter() + .find(|table| table.typ == CompilerScope::Comprehension) + .expect("missing inlined comprehension block"); + assert!( + comprehension.comp_inlined, + "CPython keeps the comprehension entry addressable through st_blocks with ste_comp_inlined set" + ); + } + + #[test] + fn future_annotations_annassign_still_scans_annotation_symbols_like_cpython() { + let table = scan_source("from __future__ import annotations\nx: T\n"); + let annotation_block = table + .annotation_block + .as_ref() + .expect("CPython still creates an AnnotationBlock for future annotations"); + + assert!( + annotation_block.lookup("T").is_some(), + "CPython symtable_visit_annotation still visits the annotation expression with future annotations" + ); + } + + #[test] + fn annotation_like_format_parameter_is_marked_used_like_cpython() { + let table = scan_source("def f(x: T): pass\n"); + let annotation_block = table + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::Annotation) + .expect("missing function annotation block"); + let format = annotation_block + .lookup(".format") + .expect("missing annotation .format parameter"); + assert!( + format + .flags + .contains(SymbolFlags::PARAMETER | SymbolFlags::REFERENCED), + "CPython symtable_enter_block() adds both DEF_PARAM and USE for annotation-like .format" + ); + + let table = scan_source("type A = T\n"); + let alias = table + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::TypeAlias) + .expect("missing type alias scope"); + let format = alias + .lookup(".format") + .expect("missing type alias .format parameter"); + assert!( + format + .flags + .contains(SymbolFlags::PARAMETER | SymbolFlags::REFERENCED), + "CPython TypeAliasBlock .format has DEF_PARAM | USE" + ); + + let table = scan_source("def f[T: B](): pass\n"); + let type_params = table + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::TypeParams) + .expect("missing type params scope"); + let type_variable = type_params + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::TypeVariable) + .expect("missing type variable scope"); + let format = type_variable + .lookup(".format") + .expect("missing type variable .format parameter"); + assert!( + format + .flags + .contains(SymbolFlags::PARAMETER | SymbolFlags::REFERENCED), + "CPython TypeVariableBlock .format has DEF_PARAM | USE" + ); + } + + #[test] + fn function_signature_annotation_block_is_sibling_like_cpython() { + let table = scan_source("def f(x: T): pass\n"); + assert_eq!(table.sub_tables[0].typ, CompilerScope::Annotation); + assert!(table.sub_tables[0].annotations_used); + assert_eq!(table.sub_tables[1].typ, CompilerScope::Function); + assert!( + table.sub_tables[1].annotation_block.is_none(), + "CPython stores the function signature AnnotationBlock as a child keyed by arguments, not on the function block" + ); + + let table = scan_source("def f(x): pass\n"); + assert_eq!(table.sub_tables[0].typ, CompilerScope::Annotation); + assert!(!table.sub_tables[0].annotations_used); + assert_eq!(table.sub_tables[1].typ, CompilerScope::Function); + } + + #[test] + fn future_function_signature_annotation_block_is_hidden_like_cpython() { + let table = scan_source("from __future__ import annotations\ndef f(x: T): pass\n"); + assert_eq!(table.sub_tables[0].typ, CompilerScope::Function); + assert_eq!( + table.hidden_annotation_blocks[0].typ, + CompilerScope::Annotation + ); + assert!(table.hidden_annotation_blocks[0].annotations_used); + assert!( + table.sub_tables[0].annotation_block.is_none(), + "CPython future AnnotationBlock stays in st_blocks and is not attached to the FunctionBlock" + ); + + let table = scan_source("from __future__ import annotations\ndef f(x): pass\n"); + assert_eq!(table.sub_tables[0].typ, CompilerScope::Function); + assert_eq!( + table.hidden_annotation_blocks[0].typ, + CompilerScope::Annotation + ); + assert!(!table.hidden_annotation_blocks[0].annotations_used); + } + + #[test] + fn annassign_marks_current_scope_annotations_used_like_cpython() { + let table = scan_source("x: int\n"); + assert!( + table.annotations_used, + "CPython AnnAssign_kind sets ste_annotations_used on the current scope" + ); + + let table = scan_source("class C:\n x: int\n"); + let class = table + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::Class) + .expect("missing class scope"); + assert!( + class.annotations_used, + "CPython AnnAssign_kind sets ste_annotations_used on class scopes" + ); + + let table = scan_source("def f():\n x: int\n"); + let function = table + .sub_tables + .iter() + .find(|table| table.typ == CompilerScope::Function) + .expect("missing function scope"); + assert!( + function.annotations_used, + "CPython AnnAssign_kind also marks function-local annotations" + ); + } + + #[test] + fn class_base_child_scope_precedes_class_scope_like_cpython() { + let table = scan_source("class C((lambda: Base)()):\n pass\n"); + assert_eq!(table.sub_tables[0].typ, CompilerScope::Lambda); + assert_eq!(table.sub_tables[1].typ, CompilerScope::Class); + } + + #[test] + fn try_handler_child_scope_precedes_else_scope_like_cpython() { + let table = scan_source( + "\ +def f(x): + try: + pass + except Exception: + y = 1 + def h(): + return y + else: + def e(): + return x +", + ); + let function = table + .sub_tables + .iter() + .find(|table| table.name == "f") + .expect("missing function scope"); + + let function_child_names = function + .sub_tables + .iter() + .filter(|table| table.typ == CompilerScope::Function) + .map(|table| table.name.as_str()) + .collect::>(); + assert_eq!(function_child_names, vec!["h", "e"]); + } + + #[test] + fn function_default_child_scope_precedes_decorator_scope_like_cpython() { + let table = scan_source( + "\ +@(lambda decorator_arg: decorator_arg) +def f(x=(lambda: 1)()): + pass +", + ); + let lambdas = table + .sub_tables + .iter() + .filter(|table| table.typ == CompilerScope::Lambda) + .collect::>(); + + assert_eq!(lambdas.len(), 2); + assert!( + lambdas[0].varnames.is_empty(), + "CPython symtable visits function defaults before decorators" + ); + assert_eq!(lambdas[1].varnames, vec!["decorator_arg"]); + } + + #[test] + fn future_annotations_still_rejects_named_expr_in_annotation_like_cpython() { + let err = + scan_source_result("from __future__ import annotations\nx: (y := int)\n").unwrap_err(); + + assert_eq!( + err.error, + "named expression cannot be used within an annotation" + ); + } + + #[test] + fn import_star_outside_module_uses_cpython_symtable_message() { + let err = scan_source_result("def f():\n from m import *\n").unwrap_err(); + + assert_eq!(err.error, "import * only allowed at module level"); + } + + #[test] + fn import_as_error_location_uses_alias_location_like_cpython() { + let source = "import module as __debug__\n"; + let err = scan_source_result(source).unwrap_err(); + + assert_eq!(err.error, "cannot assign to __debug__"); + let location = err.location.unwrap(); + assert_eq!(location.line.get(), 1); + assert_eq!( + location.character_offset.get(), + 8, + "CPython reports LOCATION(a) for import aliases, at the imported name" + ); + } + + #[test] + fn function_def_error_location_uses_statement_location_like_cpython() { + let source = "def __debug__():\n pass\n"; + let err = scan_source_result(source).unwrap_err(); + + assert_eq!(err.error, "cannot assign to __debug__"); + let location = err.location.unwrap(); + assert_eq!(location.line.get(), 1); + assert_eq!( + location.character_offset.get(), + 1, + "CPython reports LOCATION(s) for FunctionDef, at 'def'" + ); + } + + #[test] + fn global_after_assign_error_location_uses_statement_location_like_cpython() { + let source = "def f():\n x = 1\n global x\n"; + let err = scan_source_result(source).unwrap_err(); + + assert_eq!( + err.error, + "name 'x' is assigned to before global declaration" + ); + let location = err.location.unwrap(); + assert_eq!(location.line.get(), 3); + assert_eq!( + location.character_offset.get(), + 5, + "CPython reports LOCATION(s) for global directives, at 'global'" + ); + } + + #[test] + fn type_param_debug_name_is_checked_like_cpython_add_def_ctx() { + let source = "class C[__debug__]:\n pass\n"; + let err = scan_source_result(source).unwrap_err(); + + assert_eq!(err.error, "cannot assign to __debug__"); + let location = err.location.unwrap(); + assert_eq!(location.line.get(), 1); + assert_eq!( + location.character_offset.get(), + 9, + "CPython symtable_add_def_ctx checks DEF_TYPE_PARAM | DEF_LOCAL at LOCATION(tp)" + ); + } + + #[test] + fn except_handler_name_error_location_uses_handler_location_like_cpython() { + let source = "try:\n pass\nexcept Exception as __debug__:\n pass\n"; + let err = scan_source_result(source).unwrap_err(); + + assert_eq!(err.error, "cannot assign to __debug__"); + let location = err.location.unwrap(); + assert_eq!(location.line.get(), 3); + assert_eq!( + location.character_offset.get(), + 1, + "CPython reports LOCATION(eh) for except-handler names, at 'except'" + ); + } + + #[test] + fn match_star_capture_error_location_uses_pattern_location_like_cpython() { + let source = "match subject:\n case [*__debug__]:\n pass\n"; + let err = scan_source_result(source).unwrap_err(); + + assert_eq!(err.error, "cannot assign to __debug__"); + let location = err.location.unwrap(); + assert_eq!(location.line.get(), 2); + assert_eq!( + location.character_offset.get(), + 11, + "CPython reports LOCATION(p) for MatchStar, at the '*'" + ); + } + + #[test] + fn named_expr_in_lambda_inside_comprehension_iter_is_rejected_like_cpython() { + let err = scan_source_result("[x for x in (lambda: (y := 1))()]\n").unwrap_err(); + + assert_eq!( + err.error, + "assignment expression cannot be used in a comprehension iterable expression" + ); + } + + #[test] + fn yield_in_lambda_inside_comprehension_body_is_not_comprehension_yield_like_cpython() { + scan_source_result("[(lambda: (yield x)) for x in xs]\n").expect( + "CPython checks ste_comprehension on the current lambda block, not the enclosing comprehension", + ); + } + + #[test] + fn yield_in_comprehension_scans_value_before_comprehension_error_like_cpython() { + let err = scan_source_result("[(yield (x := 1)) for x in xs]\n").unwrap_err(); + + assert_eq!( + err.error, + "assignment expression cannot rebind comprehension iteration variable 'x'" + ); + } + + #[test] + fn named_expr_in_function_annotation_comprehension_is_allowed_like_cpython() { + scan_source_result("def f(x: [(y := int) for _ in xs]): pass\n").expect( + "CPython skips AnnotationBlock while extending namedexpr scope from a comprehension", + ); + } + + #[test] + fn named_expr_in_class_annotation_comprehension_uses_cpython_message() { + let err = scan_source_result("class C:\n x: [(y := int) for _ in xs]\n").unwrap_err(); + + assert_eq!( + err.error, + "assignment expression within a comprehension cannot be used in a class body" + ); + } + + #[test] + fn named_expr_in_type_alias_comprehension_uses_cpython_message() { + let err = scan_source_result("type A = [(y := int) for _ in xs]\n").unwrap_err(); + + assert_eq!( + err.error, + "assignment expression within a comprehension cannot be used in a type alias" + ); + } + + #[test] + fn named_expr_in_type_parameters_block_uses_cpython_message() { + let err = scan_source_result("class C[T]((base := object)): pass\n").unwrap_err(); + + assert_eq!( + err.error, + "named expression cannot be used within the definition of a generic" + ); + } + + #[test] + fn named_expr_in_typevar_bound_comprehension_uses_cpython_message() { + let err = scan_source_result("def f[T: [(y := int) for _ in xs]](): pass\n").unwrap_err(); + + assert_eq!( + err.error, + "assignment expression within a comprehension cannot be used in a TypeVar bound" + ); + } } diff --git a/crates/codegen/src/unparse.rs b/crates/codegen/src/unparse.rs index d7f754e2f9d..8f05d884e00 100644 --- a/crates/codegen/src/unparse.rs +++ b/crates/codegen/src/unparse.rs @@ -253,8 +253,12 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { range: _range, }) => { self.p("{")?; - self.unparse_expr(key, precedence::TEST)?; - self.p(": ")?; + if let Some(key) = key { + self.unparse_expr(key, precedence::TEST)?; + self.p(": ")?; + } else { + self.p("**")?; + } self.unparse_expr(value, precedence::TEST)?; self.unparse_comp(generators)?; self.p("}")?; @@ -554,7 +558,9 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { let buffered = fmt::from_fn(|f| Unparser::new(f, self.source).unparse_expr(val, precedence::TEST + 1)) .to_string(); - if let Some(ast::DebugText { leading, trailing }) = debug_text { + if let Some(debug_text) = debug_text { + let leading = debug_text.leading(); + let trailing = debug_text.trailing(); self.p(leading)?; self.p(self.source.slice(val.range()))?; self.p(trailing)?; diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index ee6b6e5d96c..c01f2f05739 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -468,6 +468,13 @@ bitflags! { const COROUTINE = 0x0080; const ITERABLE_COROUTINE = 0x0100; const ASYNC_GENERATOR = 0x0200; + const FUTURE_DIVISION = 0x20000; + const FUTURE_ABSOLUTE_IMPORT = 0x40000; + const FUTURE_WITH_STATEMENT = 0x80000; + const FUTURE_PRINT_FUNCTION = 0x100000; + const FUTURE_UNICODE_LITERALS = 0x200000; + const FUTURE_BARRY_AS_BDFL = 0x400000; + const FUTURE_GENERATOR_STOP = 0x800000; const FUTURE_ANNOTATIONS = 0x1000000; /// If a code object represents a function and has a docstring, /// this bit is set and the first item in co_consts is the docstring. diff --git a/crates/compiler-core/src/bytecode/instruction.rs b/crates/compiler-core/src/bytecode/instruction.rs index 69714a0fe66..86f602087ab 100644 --- a/crates/compiler-core/src/bytecode/instruction.rs +++ b/crates/compiler-core/src/bytecode/instruction.rs @@ -754,9 +754,8 @@ impl Opcode { /// Stack effect when the instruction takes its branch (jump=true). /// /// CPython equivalent: `stack_effect(opcode, oparg, jump=True)`. - /// For most instructions this equals the fallthrough effect. - /// Override for instructions where branch and fallthrough differ - /// (e.g. [`Self::ForIter`]: fallthrough = +1, branch = −1). + /// Current CPython opcode metadata has the same real-opcode stack effect + /// for jump and fallthrough stack-depth calculation. #[must_use] pub fn stack_effect_jump(&self, oparg: u32) -> i32 { self.stack_effect(oparg) @@ -1415,6 +1414,26 @@ mod tests { assert!(!AnyInstruction::from(PseudoOpcode::Jump).has_const()); } + #[test] + fn stack_effects_match_cpython_opcode_metadata() { + assert_eq!(Opcode::ForIter.stack_effect_info(0).popped(), 1); + assert_eq!(Opcode::ForIter.stack_effect_info(0).pushed(), 2); + assert_eq!(Opcode::ForIter.stack_effect(0), 1); + assert_eq!(Opcode::ForIter.stack_effect_jump(0), 1); + + assert_eq!(Opcode::EndAsyncFor.stack_effect_info(0).popped(), 2); + assert_eq!(Opcode::EndAsyncFor.stack_effect_info(0).pushed(), 0); + assert_eq!(Opcode::PopJumpIfFalse.stack_effect(0), -1); + assert_eq!(Opcode::PopJumpIfFalse.stack_effect_jump(0), -1); + + assert_eq!(PseudoOpcode::SetupFinally.stack_effect_info(0).pushed(), 1); + assert_eq!(PseudoOpcode::SetupFinally.stack_effect(0), 0); + assert_eq!(PseudoOpcode::SetupFinally.stack_effect_jump(0), 1); + assert_eq!(PseudoOpcode::SetupCleanup.stack_effect_info(0).pushed(), 2); + assert_eq!(PseudoOpcode::SetupCleanup.stack_effect(0), 0); + assert_eq!(PseudoOpcode::SetupCleanup.stack_effect_jump(0), 2); + } + #[test] fn no_fallthrough_flags_match_cpython_basicblock_nofallthrough() { assert!(Opcode::JumpForward.is_no_fallthrough()); diff --git a/crates/compiler-core/src/bytecode/oparg.rs b/crates/compiler-core/src/bytecode/oparg.rs index 03628604a3f..8f706003091 100644 --- a/crates/compiler-core/src/bytecode/oparg.rs +++ b/crates/compiler-core/src/bytecode/oparg.rs @@ -777,19 +777,20 @@ oparg_enum!( #[derive(Copy, Clone)] pub struct UnpackExArgs { pub before: u8, - pub after: u8, + pub after: u32, } impl From for UnpackExArgs { fn from(value: u32) -> Self { - let [before, after, ..] = value.to_le_bytes(); + let before = (value & 0xFF) as u8; + let after = value >> 8; Self { before, after } } } impl From for u32 { fn from(value: UnpackExArgs) -> Self { - Self::from_le_bytes([value.before, value.after, 0, 0]) + Self::from(value.before) | (value.after << 8) } } diff --git a/crates/compiler/src/lib.rs b/crates/compiler/src/lib.rs index 9c0884c7520..c4e80e86d8b 100644 --- a/crates/compiler/src/lib.rs +++ b/crates/compiler/src/lib.rs @@ -1,244 +1,5119 @@ pub use ruff_python_ast::token::TokenKind; -use ruff_python_parser::{LexicalErrorType, ParseErrorType}; +use ruff_python_parser::ParseErrorType; use ruff_source_file::{PositionEncoding, SourceFile, SourceFileBuilder, SourceLocation}; -use ruff_text_size::TextSlice; +use ruff_text_size::{Ranged, TextSlice}; +use rustpython_codegen::{compile, symboltable}; use thiserror::Error; -use rustpython_codegen::{compile, symboltable}; +pub use rustpython_codegen::compile::CompileOpts; +pub use rustpython_compiler_core::{Mode, bytecode::CodeObject}; + +// these modules are out of repository. re-exporting them here for convenience. +pub use ruff_python_ast as ast; +pub use ruff_python_parser as parser; +pub use rustpython_codegen as codegen; +pub use rustpython_compiler_core as core; + +#[derive(Error, Debug)] +pub enum CompileErrorType { + #[error(transparent)] + Codegen(#[from] codegen::error::CodegenErrorType), + #[error(transparent)] + Parse(#[from] ParseErrorType), +} + +#[derive(Error, Debug)] +pub struct ParseError { + #[source] + pub error: ParseErrorType, + pub raw_location: ruff_text_size::TextRange, + pub location: SourceLocation, + pub end_location: SourceLocation, + pub source_path: String, + /// Set when the error is an unclosed bracket (converted from EOF). + pub is_unclosed_bracket: bool, +} + +impl ::core::fmt::Display for ParseError { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + self.error.fmt(f) + } +} + +#[derive(Error, Debug)] +pub enum CompileError { + #[error(transparent)] + Codegen(#[from] codegen::error::CodegenError), + #[error(transparent)] + Parse(#[from] ParseError), +} + +impl CompileError { + #[must_use] + pub fn from_ruff_parse_error(error: parser::ParseError, source_file: &SourceFile) -> Self { + let source_code = source_file.to_source_code(); + let source_text = source_file.source_text(); + let invalid_number = invalid_number_literal_error(source_text); + let invalid_legacy_statement = invalid_legacy_statement_error(source_text); + let non_printable_character = non_printable_character_error(source_text); + let invalid_interpolated_string = invalid_interpolated_string_error(source_text); + let unterminated_string = unterminated_string_error(source_text); + let bracket_error = bracket_syntax_error(source_text); + let indented_block = expected_indented_block_error(&error, source_text); + let invalid_type_param = invalid_type_param_error(source_text); + let invalid_comprehension = invalid_comprehension_error(source_text); + let invalid_parameter_star_annotation = + invalid_parameter_star_annotation_error(source_text); + let invalid_parameter_list = invalid_parameter_list_error(source_text); + let invalid_call_argument = invalid_call_argument_error(source_text); + let invalid_dict = invalid_dict_error(source_text); + let invalid_collection_assignment = invalid_collection_assignment_error(source_text); + let invalid_group = invalid_group_error(source_text); + let invalid_def_type_params = invalid_def_type_params_error(source_text); + let invalid_expression = invalid_expression_error(source_text); + let invalid_named_expression = invalid_named_expression_error(source_text); + let invalid_plain_assignment = invalid_plain_assignment_error(source_text); + let expression_assignment = expression_assignment_error(source_text); + let invalid_annotation_target = invalid_annotation_target_error(source_text); + let invalid_assignment_target = invalid_assignment_target_error(source_text); + let invalid_augassign_target = invalid_augassign_target_error(source_text); + let invalid_for_target = invalid_for_target_error(source_text); + let invalid_with_target = invalid_with_target_error(source_text); + let invalid_delete_target = invalid_delete_target_error(source_text); + let invalid_standalone_except = invalid_standalone_except_error(source_text); + let invalid_import_statement = invalid_import_statement_error(source_text); + let invalid_import_target = invalid_import_target_error(source_text); + let invalid_except_as_target = invalid_except_as_target_error(source_text); + let invalid_match_mapping_rest_wildcard = + invalid_match_mapping_rest_wildcard_error(source_text); + let invalid_match_as_target = invalid_match_as_target_error(source_text); + let invalid_for_if_clause = invalid_for_if_clause_error(source_text); + let invalid_if_expression_statement = invalid_if_expression_statement_error(source_text); + let invalid_else_elif = invalid_else_elif_error(source_text); + let mixed_except_handlers = mixed_except_handlers_error(source_text); + let missing_comma_between_literals = matches!( + &error.error, + parser::ParseErrorType::ExpectedToken { expected, found } + if matches!((expected, found), (TokenKind::Comma, TokenKind::Int)) + ); + + // For EOF errors (unclosed brackets), find the unclosed bracket position + // and adjust both the error location and message + let mut is_unclosed_bracket = false; + let (error_type, location, end_location) = if let Some((message, offset)) = invalid_number { + let text_size = ruff_text_size::TextSize::new(offset as u32); + let loc = source_code.source_location(text_size, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, loc) + } else if let Some((message, start, end)) = invalid_legacy_statement { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = non_printable_character { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_interpolated_string { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end, unclosed)) = bracket_error { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + is_unclosed_bracket = unclosed; + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if matches!( + &error.error, + parser::ParseErrorType::Lexical(parser::LexicalErrorType::LineContinuationError) + ) { + let start = error.location.start() + ruff_text_size::TextSize::from(1); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + (error.error, loc, loc) + } else if let Some((message, start, end)) = unterminated_string { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = indented_block { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if matches!( + &error.error, + parser::ParseErrorType::Lexical(parser::LexicalErrorType::Eof) + ) { + if let Some((bracket_char, bracket_offset)) = find_unclosed_bracket(source_text) { + let bracket_text_size = ruff_text_size::TextSize::new(bracket_offset as u32); + let loc = source_code.source_location(bracket_text_size, PositionEncoding::Utf8); + let end_loc = SourceLocation { + line: loc.line, + character_offset: loc.character_offset.saturating_add(1), + }; + let msg = format!("'{bracket_char}' was never closed"); + is_unclosed_bracket = true; + (parser::ParseErrorType::OtherError(msg), loc, end_loc) + } else { + let loc = + source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let line_idx = loc.line.to_zero_indexed(); + let line = source_text.split('\n').nth(line_idx).unwrap_or(""); + let line_end_col = line.chars().count() + 1; // 1-indexed, past last char + let end_loc = SourceLocation { + line: loc.line, + character_offset: ruff_source_file::OneIndexed::new(line_end_col) + .unwrap_or(loc.character_offset), + }; + (error.error, end_loc, end_loc) + } + } else if let Some((message, start, end)) = invalid_type_param { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_comprehension { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_parameter_star_annotation { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_parameter_list { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_call_argument { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if missing_comma_between_literals { + let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + let msg = "invalid syntax. Perhaps you forgot a comma?".into(); + (parser::ParseErrorType::OtherError(msg), loc, end_loc) + } else if let Some((message, start, end)) = invalid_dict { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_collection_assignment { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_group { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_def_type_params { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_expression { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_named_expression { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_plain_assignment { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = expression_assignment { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_annotation_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_assignment_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_augassign_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_for_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_with_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_delete_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_standalone_except { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_import_statement { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_import_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_except_as_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_match_mapping_rest_wildcard { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_match_as_target { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_for_if_clause { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_if_expression_statement { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = invalid_else_elif { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if let Some((message, start, end)) = mixed_except_handlers { + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let loc = source_code.source_location(start, PositionEncoding::Utf8); + let end_loc = source_code.source_location(end, PositionEncoding::Utf8); + (parser::ParseErrorType::OtherError(message), loc, end_loc) + } else if matches!( + &error.error, + parser::ParseErrorType::Lexical(parser::LexicalErrorType::IndentationError) + ) { + // For IndentationError, point the offset to the end of the line content + // instead of the beginning + let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let line_idx = loc.line.to_zero_indexed(); + let line = source_text.split('\n').nth(line_idx).unwrap_or(""); + let line_end_col = line.chars().count() + 1; // 1-indexed, past last char + let end_loc = SourceLocation { + line: loc.line, + character_offset: ruff_source_file::OneIndexed::new(line_end_col) + .unwrap_or(loc.character_offset), + }; + (error.error, end_loc, end_loc) + } else if matches!( + &error.error, + parser::ParseErrorType::InvalidAssignmentTarget + ) { + let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + let expr_str = source_file.source_text().slice(error.location); + + let msg = parser::parse_expression(expr_str).map_or_else( + |_| match expr_str { + "yield" => "assignment to yield expression not possible".into(), + _ => format!("cannot assign to {expr_str}"), + }, + |parsed| match *parsed.syntax().body { + ast::Expr::Call(_) => "cannot assign to function call".into(), + ast::Expr::BinOp(_) => "cannot assign to expression".into(), + ast::Expr::If(_) => "cannot assign to conditional expression".into(), + ast::Expr::Generator(_) => "cannot assign to generator expression".into(), + ast::Expr::FString(_) => "invalid syntax".into(), + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::NumberLiteral(_) => { + "cannot assign to literal here. Maybe you meant '==' instead of '='?".into() + } + ast::Expr::EllipsisLiteral(_) => { + "cannot assign to ellipsis here. Maybe you meant '==' instead of '='?" + .into() + } + _ => format!("cannot assign to {expr_str}"), + }, + ); + + (parser::ParseErrorType::OtherError(msg), loc, end_loc) + } else if matches!( + &error.error, + parser::ParseErrorType::InvalidNamedAssignmentTarget + ) { + let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + let target = source_file.source_text().slice(error.location); + let msg = format!("cannot use assignment expressions with {target}"); + (parser::ParseErrorType::OtherError(msg), loc, end_loc) + } else { + let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + (error.error, loc, end_loc) + }; + + Self::Parse(ParseError { + error: error_type, + raw_location: error.location, + location, + end_location, + source_path: source_file.name().to_owned(), + is_unclosed_bracket, + }) + } + + fn from_source_error( + source_file: &SourceFile, + message: String, + start: usize, + end: usize, + ) -> Self { + let source_code = source_file.to_source_code(); + let start = ruff_text_size::TextSize::new(start as u32); + let end = ruff_text_size::TextSize::new(end as u32); + let location = source_code.source_location(start, PositionEncoding::Utf8); + let end_location = source_code.source_location(end, PositionEncoding::Utf8); + Self::Parse(ParseError { + error: parser::ParseErrorType::OtherError(message), + raw_location: ruff_text_size::TextRange::new(start, end), + location, + end_location, + source_path: source_file.name().to_owned(), + is_unclosed_bracket: false, + }) + } + + #[must_use] + pub const fn location(&self) -> Option { + match self { + Self::Codegen(codegen_error) => codegen_error.location, + Self::Parse(parse_error) => Some(parse_error.location), + } + } + + #[must_use] + pub const fn python_location(&self) -> (usize, usize) { + if let Some(location) = self.location() { + (location.line.get(), location.character_offset.get()) + } else { + (0, 0) + } + } + + #[must_use] + pub fn python_end_location(&self) -> Option<(usize, usize)> { + match self { + Self::Codegen(_) => None, + Self::Parse(parse_error) => Some(( + parse_error.end_location.line.get(), + parse_error.end_location.character_offset.get(), + )), + } + } + + #[must_use] + pub fn source_path(&self) -> &str { + match self { + Self::Codegen(codegen_error) => &codegen_error.source_path, + Self::Parse(parse_error) => &parse_error.source_path, + } + } +} + +fn is_ascii_identifier_char(byte: u8) -> bool { + byte == b'_' || byte.is_ascii_alphanumeric() +} + +fn numeric_keyword_suffix(rest: &[u8]) -> bool { + rest.starts_with(b"and") + || rest.starts_with(b"else") + || rest.starts_with(b"for") + || rest.starts_with(b"if") + || rest.starts_with(b"in") + || rest.starts_with(b"is") + || rest.starts_with(b"or") + || rest.starts_with(b"not") +} + +fn consume_decimal_digits(bytes: &[u8], mut index: usize) -> usize { + while index < bytes.len() { + match bytes[index] { + b'0'..=b'9' => index += 1, + b'_' if bytes + .get(index + 1) + .is_some_and(|byte| byte.is_ascii_digit()) => + { + index += 2; + } + _ => break, + } + } + index +} + +fn consume_radix_digits(bytes: &[u8], mut index: usize, is_digit: impl Fn(u8) -> bool) -> usize { + while index < bytes.len() { + if is_digit(bytes[index]) { + index += 1; + } else if bytes.get(index) == Some(&b'_') + && bytes.get(index + 1).is_some_and(|&byte| is_digit(byte)) + { + index += 2; + } else { + break; + } + } + index +} + +fn invalid_radix_literal_error( + bytes: &[u8], + start: usize, + kind: &'static str, + is_digit: impl Fn(u8) -> bool, +) -> Option<(String, usize)> { + let mut index = start + 2; + let mut has_digit = false; + loop { + let Some(&byte) = bytes.get(index) else { + return Some((format!("invalid {kind} literal"), start + 1)); + }; + if byte == b'_' { + let Some(&next) = bytes.get(index + 1) else { + return Some((format!("invalid {kind} literal"), index)); + }; + if is_digit(next) { + has_digit = true; + index += 2; + continue; + } + if next.is_ascii_digit() && matches!(kind, "binary" | "octal") { + return Some(( + format!("invalid digit '{}' in {kind} literal", next as char), + index + 1, + )); + } + return Some((format!("invalid {kind} literal"), index)); + } + if is_digit(byte) { + has_digit = true; + index += 1; + continue; + } + if byte.is_ascii_digit() && matches!(kind, "binary" | "octal") { + return Some(( + format!("invalid digit '{}' in {kind} literal", byte as char), + index, + )); + } + if has_digit { + return None; + } + return Some((format!("invalid {kind} literal"), start + 1)); + } +} + +fn decimal_tail_error(bytes: &[u8], mut index: usize) -> Option { + loop { + while bytes.get(index).is_some_and(|byte| byte.is_ascii_digit()) { + index += 1; + } + if bytes.get(index) != Some(&b'_') { + return None; + } + let underscore = index; + index += 1; + if !bytes.get(index).is_some_and(|byte| byte.is_ascii_digit()) { + return Some(underscore); + } + } +} + +fn decimal_tail_end(bytes: &[u8], mut index: usize) -> usize { + loop { + while bytes.get(index).is_some_and(|byte| byte.is_ascii_digit()) { + index += 1; + } + if bytes.get(index) == Some(&b'_') + && bytes + .get(index + 1) + .is_some_and(|byte| byte.is_ascii_digit()) + { + index += 2; + } else { + return index; + } + } +} + +fn invalid_decimal_literal_error(bytes: &[u8], start: usize) -> Option<(String, usize)> { + if bytes.get(start) == Some(&b'.') { + return None; + } + let message = "invalid decimal literal".to_owned(); + if let Some(offset) = decimal_tail_error(bytes, start) { + return Some((message, offset)); + } + + let mut index = decimal_tail_end(bytes, start); + if bytes.get(index) == Some(&b'.') { + if bytes.get(index + 1) == Some(&b'_') { + return Some((message, index)); + } + if let Some(offset) = decimal_tail_error(bytes, index + 1) { + return Some((message, offset)); + } + index = decimal_tail_end(bytes, index + 1); + } + if matches!(bytes.get(index), Some(b'e' | b'E')) { + let exponent = index; + index += 1; + let sign = if matches!(bytes.get(index), Some(b'+' | b'-')) { + let sign = index; + index += 1; + Some(sign) + } else { + None + }; + if !bytes.get(index).is_some_and(|byte| byte.is_ascii_digit()) { + return Some((message, sign.unwrap_or(exponent))); + } + if let Some(offset) = decimal_tail_error(bytes, index) { + return Some((message, offset)); + } + } + None +} + +fn leading_zero_decimal_literal_error(bytes: &[u8], start: usize) -> Option<(String, usize)> { + if bytes.get(start) != Some(&b'0') { + return None; + } + let mut index = start; + loop { + match bytes.get(index) { + Some(b'0') => index += 1, + Some(b'_') + if bytes + .get(index + 1) + .is_some_and(|byte| byte.is_ascii_digit()) => + { + index += 1; + } + _ => break, + } + } + if bytes.get(index).is_some_and(|byte| byte.is_ascii_digit()) { + let after_digits = decimal_tail_end(bytes, index); + if !matches!( + bytes.get(after_digits), + Some(b'.' | b'e' | b'E' | b'j' | b'J') + ) { + return Some(( + "leading zeros in decimal integer literals are not permitted; use an 0o prefix for octal integers".to_owned(), + start, + )); + } + } + None +} + +fn invalid_numeric_literal_error(bytes: &[u8], start: usize) -> Option<(String, usize)> { + if bytes.get(start) == Some(&b'0') { + match bytes.get(start + 1) { + Some(b'x' | b'X') => { + return invalid_radix_literal_error(bytes, start, "hexadecimal", |byte| { + byte.is_ascii_hexdigit() + }); + } + Some(b'o' | b'O') => { + return invalid_radix_literal_error(bytes, start, "octal", |byte| { + matches!(byte, b'0'..=b'7') + }); + } + Some(b'b' | b'B') => { + return invalid_radix_literal_error(bytes, start, "binary", |byte| { + matches!(byte, b'0' | b'1') + }); + } + _ => {} + } + if let Some(err) = leading_zero_decimal_literal_error(bytes, start) { + return Some(err); + } + } + invalid_decimal_literal_error(bytes, start) +} + +fn consume_exponent(bytes: &[u8], index: usize) -> usize { + if !matches!(bytes.get(index), Some(b'e' | b'E')) { + return index; + } + let mut cursor = index + 1; + if matches!(bytes.get(cursor), Some(b'+' | b'-')) { + cursor += 1; + } + if bytes.get(cursor).is_some_and(|byte| byte.is_ascii_digit()) { + consume_decimal_digits(bytes, cursor) + } else { + index + } +} + +fn number_literal_end(bytes: &[u8], start: usize) -> Option<(&'static str, usize)> { + if bytes.get(start) == Some(&b'.') { + if !bytes + .get(start + 1) + .is_some_and(|byte| byte.is_ascii_digit()) + { + return None; + } + let mut index = consume_decimal_digits(bytes, start + 1); + index = consume_exponent(bytes, index); + if matches!(bytes.get(index), Some(b'j' | b'J')) { + return Some(("imaginary", index + 1)); + } + return Some(("decimal", index)); + } + + if !bytes.get(start).is_some_and(|byte| byte.is_ascii_digit()) { + return None; + } + + if bytes.get(start) == Some(&b'0') { + match bytes.get(start + 1) { + Some(b'x' | b'X') => { + let end = consume_radix_digits(bytes, start + 2, |byte| byte.is_ascii_hexdigit()); + return Some(("hexadecimal", end)); + } + Some(b'o' | b'O') => { + let end = + consume_radix_digits(bytes, start + 2, |byte| matches!(byte, b'0'..=b'7')); + return Some(("octal", end)); + } + Some(b'b' | b'B') => { + let end = + consume_radix_digits(bytes, start + 2, |byte| matches!(byte, b'0' | b'1')); + return Some(("binary", end)); + } + _ => {} + } + } + + let mut index = consume_decimal_digits(bytes, start); + if bytes.get(index) == Some(&b'.') { + index = consume_decimal_digits(bytes, index + 1); + } + index = consume_exponent(bytes, index); + if matches!(bytes.get(index), Some(b'j' | b'J')) { + return Some(("imaginary", index + 1)); + } + Some(("decimal", index)) +} + +fn skip_quoted_string(bytes: &[u8], mut index: usize) -> usize { + let quote = bytes[index]; + let triple = bytes.get(index + 1) == Some("e) && bytes.get(index + 2) == Some("e); + let quote_len = if triple { 3 } else { 1 }; + index += quote_len; + while index < bytes.len() { + if bytes[index] == b'\\' { + index = (index + 2).min(bytes.len()); + } else if triple + && bytes.get(index) == Some("e) + && bytes.get(index + 1) == Some("e) + && bytes.get(index + 2) == Some("e) + { + return index + 3; + } else if !triple && bytes[index] == quote { + return index + 1; + } else { + index += 1; + } + } + index +} + +fn invalid_number_literal_error(source: &str) -> Option<(String, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + byte if byte >= 0x80 || byte == b'_' || byte.is_ascii_alphabetic() => { + index += 1; + while index < bytes.len() + && (bytes[index] >= 0x80 || is_ascii_identifier_char(bytes[index])) + { + index += 1; + } + } + b'.' | b'0'..=b'9' => { + if let Some(err) = invalid_numeric_literal_error(bytes, index) { + return Some(err); + } + let Some((kind, end)) = number_literal_end(bytes, index) else { + index += 1; + continue; + }; + if end > index { + if source[end..].starts_with('⁄') { + return Some(("invalid character '⁄' (U+2044)".to_owned(), end)); + } + if bytes + .get(end) + .is_some_and(|byte| *byte < 128 && is_ascii_identifier_char(*byte)) + && !numeric_keyword_suffix(&bytes[end..]) + { + return Some((format!("invalid {kind} literal"), end.saturating_sub(1))); + } + } + index = end.max(index + 1); + } + _ => index += 1, + } + } + None +} + +fn cpython_indented_block_clause(message: &str) -> Option<&'static str> { + let clause = message.strip_prefix("Expected an indented block after ")?; + Some(match clause { + "`if` statement" => "'if' statement", + "`elif` clause" => "'elif' statement", + "`else` clause" => "'else' statement", + "`for` statement" => "'for' statement", + "`with` statement" => "'with' statement", + "`while` statement" => "'while' statement", + "`try` statement" => "'try' statement", + "`except` clause" => "'except' statement", + "`finally` clause" => "'finally' statement", + "`match` statement" => "'match' statement", + "`case` block" => "'case' statement", + "`class` definition" => "class definition", + "function definition" => "function definition", + _ => return None, + }) +} + +fn previous_non_empty_line_number(source: &str, offset: usize) -> Option { + let bytes = source.as_bytes(); + let mut index = offset.min(bytes.len()); + while index > 0 { + let line_end = index; + while index > 0 && bytes[index - 1] != b'\n' { + index -= 1; + } + let line_start = index; + let content_start = skip_horizontal_whitespace(bytes, line_start); + let mut content_end = line_end; + while content_end > content_start + && matches!( + bytes.get(content_end - 1), + Some(b' ' | b'\t' | b'\r' | b'\x0c') + ) + { + content_end -= 1; + } + if content_start < content_end { + return Some( + source[..line_start] + .bytes() + .filter(|byte| *byte == b'\n') + .count() + + 1, + ); + } + index = line_start.saturating_sub(1); + } + None +} + +fn expected_indented_block_error( + error: &parser::ParseError, + source: &str, +) -> Option<(String, usize, usize)> { + let parser::ParseErrorType::OtherError(message) = &error.error else { + return None; + }; + let mut clause = cpython_indented_block_clause(message)?; + let start = error.location.start().to_usize(); + let end = error.location.end().to_usize(); + let line = previous_non_empty_line_number(source, start)?; + if clause == "'except' statement" + && let Some(previous_line) = previous_non_empty_line(source, start) + && matches!( + previous_line.trim_start(), + line if line.starts_with("except*") || line.starts_with("except *") + ) + { + clause = "'except*' statement"; + } + Some(( + format!("expected an indented block after {clause} on line {line}"), + start, + end, + )) +} + +fn previous_non_empty_line(source: &str, offset: usize) -> Option<&str> { + let bytes = source.as_bytes(); + let mut index = offset.min(bytes.len()); + while index > 0 { + let line_end = index; + while index > 0 && bytes[index - 1] != b'\n' { + index -= 1; + } + let line_start = index; + let mut content_start = line_start; + while content_start < line_end + && matches!(bytes[content_start], b' ' | b'\t' | b'\n' | b'\r' | b'\x0c') + { + content_start += 1; + } + let mut content_end = line_end; + while content_end > content_start + && matches!(bytes[content_end - 1], b' ' | b'\t' | b'\r' | b'\x0c') + { + content_end -= 1; + } + if content_start < content_end { + return source.get(line_start..line_end); + } + index = line_start.saturating_sub(1); + } + None +} + +fn starts_identifier(bytes: &[u8], index: usize, word: &[u8]) -> bool { + bytes.get(index..index + word.len()) == Some(word) + && index + .checked_sub(1) + .and_then(|before| bytes.get(before)) + .is_none_or(|byte| !is_ascii_identifier_char(*byte)) + && bytes + .get(index + word.len()) + .is_none_or(|byte| !is_ascii_identifier_char(*byte)) +} + +fn is_plain_assignment_operator(bytes: &[u8], index: usize) -> bool { + bytes.get(index) == Some(&b'=') + && bytes.get(index + 1) != Some(&b'=') + && !matches!( + index.checked_sub(1).and_then(|before| bytes.get(before)), + Some(b'=' | b'!' | b'<' | b'>' | b':') + ) +} + +fn is_simple_keyword_name(bytes: &[u8], mut start: usize, mut end: usize) -> bool { + while matches!( + bytes.get(start), + Some(b' ' | b'\t' | b'\n' | b'\r' | b'\x0c') + ) { + start += 1; + } + while end > start + && matches!( + bytes.get(end - 1), + Some(b' ' | b'\t' | b'\n' | b'\r' | b'\x0c') + ) + { + end -= 1; + } + let Some(&first) = bytes.get(start) else { + return false; + }; + if !(first == b'_' || first.is_ascii_alphabetic() || first >= 0x80) { + return false; + } + let mut index = start + 1; + while index < end { + if bytes[index] < 0x80 && !is_ascii_identifier_char(bytes[index]) { + return false; + } + index += 1; + } + true +} + +fn is_function_parameter_list(bytes: &[u8], paren: usize) -> bool { + let mut cursor = paren; + while cursor > 0 && matches!(bytes.get(cursor - 1), Some(b' ' | b'\t' | b'\x0c')) { + cursor -= 1; + } + if cursor > 0 && bytes.get(cursor - 1) == Some(&b']') { + let mut bracket = cursor; + let mut level = 0usize; + while bracket > 0 { + bracket -= 1; + match bytes[bracket] { + b']' => level += 1, + b'[' => { + level = level.saturating_sub(1); + if level == 0 { + cursor = bracket; + break; + } + } + _ => {} + } + } + while cursor > 0 && matches!(bytes.get(cursor - 1), Some(b' ' | b'\t' | b'\x0c')) { + cursor -= 1; + } + } + while cursor > 0 + && bytes + .get(cursor - 1) + .is_some_and(|byte| *byte >= 0x80 || is_ascii_identifier_char(*byte)) + { + cursor -= 1; + } + while cursor > 0 && matches!(bytes.get(cursor - 1), Some(b' ' | b'\t' | b'\x0c')) { + cursor -= 1; + } + cursor >= 3 + && starts_identifier(bytes, cursor - 3, b"def") + && cursor + .checked_sub(4) + .and_then(|before| bytes.get(before)) + .is_none_or(|byte| !is_ascii_identifier_char(*byte)) +} + +#[derive(Clone, Copy)] +enum ParameterListKind { + Function, + Lambda, +} + +fn matching_delimiter(bytes: &[u8], open: usize, close: u8) -> Option { + let mut index = open; + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + byte if byte == close => { + level = level.saturating_sub(1); + if level == 0 { + return Some(index); + } + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ => index += 1, + } + } + None +} + +fn find_lambda_parameter_end(bytes: &[u8], mut index: usize) -> Option { + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' if level == 0 => return None, + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b':' if level == 0 => return Some(index), + _ => index += 1, + } + } + None +} + +fn top_level_byte(bytes: &[u8], mut index: usize, end: usize, needle: u8) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + byte if level == 0 && byte == needle => return Some(index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ => index += 1, + } + } + None +} + +fn identifier_end(bytes: &[u8], mut index: usize, end: usize) -> usize { + if !bytes + .get(index) + .is_some_and(|byte| *byte >= 0x80 || *byte == b'_' || byte.is_ascii_alphabetic()) + { + return index; + } + index += 1; + while index < end + && bytes + .get(index) + .is_some_and(|byte| *byte >= 0x80 || is_ascii_identifier_char(*byte)) + { + index += 1; + } + index +} + +fn expression_slice_is_tuple(source: &str, start: usize, end: usize) -> bool { + let bytes = source.as_bytes(); + let (start, end) = trim_target_range(bytes, start, end); + if start >= end { + return false; + } + let Ok(parsed) = parser::parse(&source[start..end], parser::Mode::Expression.into()) else { + return false; + }; + matches!(parsed.into_syntax(), ast::Mod::Expression(expression) if matches!(*expression.body, ast::Expr::Tuple(_))) +} + +fn type_param_list_open(bytes: &[u8], open: usize) -> bool { + let mut cursor = open; + while cursor > 0 && matches!(bytes.get(cursor - 1), Some(b' ' | b'\t' | b'\x0c')) { + cursor -= 1; + } + while cursor > 0 + && bytes + .get(cursor - 1) + .is_some_and(|byte| *byte >= 0x80 || is_ascii_identifier_char(*byte)) + { + cursor -= 1; + } + while cursor > 0 && matches!(bytes.get(cursor - 1), Some(b' ' | b'\t' | b'\x0c')) { + cursor -= 1; + } + (cursor >= 3 && starts_identifier(bytes, cursor - 3, b"def")) + || (cursor >= 5 && starts_identifier(bytes, cursor - 5, b"class")) + || (cursor >= 4 && starts_identifier(bytes, cursor - 4, b"type")) +} + +fn invalid_type_param_item_error( + source: &str, + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (start, end) = trim_target_range(bytes, start, end); + if start >= end || bytes.get(start) != Some(&b'*') { + return None; + } + let is_param_spec = bytes.get(start + 1) == Some(&b'*'); + let name_start = start + if is_param_spec { 2 } else { 1 }; + let name_end = identifier_end(bytes, name_start, end); + if name_start == name_end { + return None; + } + let colon = next_non_horizontal_whitespace(bytes, name_end); + if colon >= end || bytes.get(colon) != Some(&b':') { + return None; + } + let has_constraints = expression_slice_is_tuple(source, colon + 1, end); + let message = match (is_param_spec, has_constraints) { + (false, false) => "cannot use bound with TypeVarTuple", + (false, true) => "cannot use constraints with TypeVarTuple", + (true, false) => "cannot use bound with ParamSpec", + (true, true) => "cannot use constraints with ParamSpec", + }; + Some((message.to_owned(), colon, colon + 1)) +} + +fn invalid_type_param_list_error( + source: &str, + open: usize, + close: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut item_start = open + 1; + let mut index = item_start; + let mut level = 0usize; + while index <= close { + if index == close || (level == 0 && bytes.get(index) == Some(&b',')) { + if let Some(error) = invalid_type_param_item_error(source, item_start, index) { + return Some(error); + } + item_start = index + 1; + index += 1; + continue; + } + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_type_param_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'[' if type_param_list_open(bytes, index) => { + let Some(close) = matching_delimiter(bytes, index, b']') else { + index += 1; + continue; + }; + if let Some(error) = invalid_type_param_list_error(source, index, close) { + return Some(error); + } + index = close + 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_comprehension_in_slice( + bytes: &[u8], + open: usize, + close: usize, +) -> Option<(String, usize, usize)> { + let for_index = find_keyword_at_level(bytes, open + 1, close, b"for")?; + let item_start = next_non_horizontal_whitespace(bytes, open + 1); + if item_start >= for_index { + return None; + } + if bytes.get(item_start..item_start + 2) == Some(b"**") && bytes.get(open) == Some(&b'{') { + return Some(( + "dict unpacking cannot be used in dict comprehension".to_owned(), + item_start, + item_start + 2, + )); + } + if bytes.get(item_start..item_start + 2) == Some(b"**") && bytes.get(open) == Some(&b'(') { + return Some(("invalid syntax".to_owned(), for_index, for_index + 3)); + } + if bytes.get(item_start) == Some(&b'*') { + return Some(( + "iterable unpacking cannot be used in comprehension".to_owned(), + item_start, + item_start + 1, + )); + } + if !matches!(bytes.get(open), Some(b'[' | b'{')) { + return None; + } + if top_level_colon(bytes, open + 1, for_index).is_none() + && let Some(comma) = top_level_byte(bytes, open + 1, for_index, b',') + { + let (start, _) = trim_target_range(bytes, open + 1, comma); + return Some(( + "did you forget parentheses around the comprehension target?".to_owned(), + start, + comma + 1, + )); + } + None +} + +fn invalid_comprehension_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + let close_byte = match bytes[index] { + b'(' => b')', + b'[' => b']', + _ => b'}', + }; + let Some(close) = matching_delimiter(bytes, index, close_byte) else { + index += 1; + continue; + }; + if let Some(error) = invalid_comprehension_in_slice(bytes, index, close) { + return Some(error); + } + index = close + 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_group_in_slice( + bytes: &[u8], + open: usize, + close: usize, +) -> Option<(String, usize, usize)> { + let (item_start, item_end) = trim_target_range(bytes, open + 1, close); + if item_start >= item_end + || top_level_byte(bytes, item_start, item_end, b',').is_some() + || top_level_colon(bytes, item_start, item_end).is_some() + || find_keyword_at_level(bytes, item_start, item_end, b"for").is_some() + { + return None; + } + if bytes.get(item_start..item_start + 2) == Some(b"**") { + return Some(( + "cannot use double starred expression here".to_owned(), + item_start, + item_start + 2, + )); + } + if bytes.get(item_start) == Some(&b'*') { + return Some(( + "cannot use starred expression here".to_owned(), + item_start, + item_start + 1, + )); + } + None +} + +fn invalid_group_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' => { + let Some(close) = matching_delimiter(bytes, index, b')') else { + index += 1; + continue; + }; + if let Some(error) = invalid_group_in_slice(bytes, index, close) { + return Some(error); + } + index = close + 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_parameter_star_annotation_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' => { + let Some(close) = matching_delimiter(bytes, index, b')') else { + index += 1; + continue; + }; + let mut param_start = index + 1; + while param_start < close { + let param_end = + find_byte_at_level(bytes, param_start, close, b',').unwrap_or(close); + if let Some(colon) = top_level_colon(bytes, param_start, param_end) { + let value_start = next_non_horizontal_whitespace(bytes, colon + 1); + if bytes.get(value_start) == Some(&b'*') { + return Some(( + "invalid syntax".to_owned(), + value_start, + value_start + 1, + )); + } + } + param_start = param_end.saturating_add(1); + } + index = close + 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_def_type_params_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + _ if starts_identifier(bytes, index, b"def") => { + let name_start = skip_horizontal_whitespace(bytes, index + 3); + let name_end = identifier_end(bytes, name_start, bytes.len()); + let bracket = skip_horizontal_whitespace(bytes, name_end); + if bytes.get(bracket) == Some(&b'[') { + let Some(close) = matching_delimiter(bytes, bracket, b']') else { + index = bracket + 1; + continue; + }; + let after_close = skip_horizontal_whitespace(bytes, close + 1); + if bytes.get(after_close) == Some(&b'(') + && type_param_list_is_malformed(bytes, bracket + 1, close) + { + return Some(("expected '('".to_owned(), bracket, bracket + 1)); + } + } + index = name_end.max(index + 3); + } + _ => index += 1, + } + } + None +} + +fn type_param_list_is_malformed(bytes: &[u8], start: usize, end: usize) -> bool { + let mut index = start; + let mut expect_item = true; + while index < end { + index = skip_horizontal_whitespace(bytes, index); + if index >= end { + break; + } + if bytes[index] == b',' { + if expect_item { + return true; + } + expect_item = true; + index += 1; + continue; + } + if !expect_item { + return true; + } + if bytes.get(index..index + 2) == Some(b"**") { + index += 2; + } else if bytes.get(index) == Some(&b'*') { + index += 1; + } + let item_start = skip_horizontal_whitespace(bytes, index); + let item_end = identifier_end(bytes, item_start, end); + if item_end == item_start { + return true; + } + index = item_end; + if bytes.get(skip_horizontal_whitespace(bytes, index)) == Some(&b':') { + index = skip_horizontal_whitespace(bytes, index) + 1; + while index < end && bytes[index] != b',' { + index = match bytes[index] { + b'\'' | b'"' => skip_quoted_string(bytes, index), + _ => index + 1, + }; + } + } + expect_item = false; + } + false +} + +fn invalid_parameter_list_slice_error( + source: &str, + start: usize, + end: usize, + kind: ParameterListKind, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = start; + let mut level = 0usize; + let mut default_seen = false; + let mut keyword_only = false; + let mut slash_seen = false; + let mut var_keyword_seen = false; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + _ if level == 0 + && var_keyword_seen + && bytes.get(index).is_some_and(|byte| { + *byte >= 0x80 || *byte == b'_' || byte.is_ascii_alphabetic() + }) => + { + let name_end = identifier_end(bytes, index, end); + return Some(( + "arguments cannot follow var-keyword argument".to_owned(), + index, + name_end, + )); + } + _ if level == 0 + && !keyword_only + && bytes.get(index).is_some_and(|byte| { + *byte >= 0x80 || *byte == b'_' || byte.is_ascii_alphabetic() + }) => + { + let param_end = find_byte_at_level(bytes, index, end, b',') + .or_else(|| top_level_byte(bytes, index, end, b')')) + .or_else(|| { + matches!(kind, ParameterListKind::Lambda) + .then(|| top_level_byte(bytes, index, end, b':')) + .flatten() + }) + .unwrap_or(end); + let name_end = identifier_end(bytes, index, param_end); + if top_level_byte(bytes, index, param_end, b'=').is_some() { + default_seen = true; + } else if default_seen { + return Some(( + "parameter without a default follows parameter with a default".to_owned(), + index, + name_end, + )); + } + index = name_end; + } + b'(' if level == 0 => { + let close = matching_delimiter(bytes, index, b')') + .filter(|close| *close <= end) + .unwrap_or(index + 1); + let message = match kind { + ParameterListKind::Function => "Function parameters cannot be parenthesized", + ParameterListKind::Lambda => { + "Lambda expression parameters cannot be parenthesized" + } + }; + return Some((message.to_owned(), index, close + 1)); + } + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b'/' if level == 0 => { + if var_keyword_seen { + return Some(( + "arguments cannot follow var-keyword argument".to_owned(), + index, + index + 1, + )); + } + if slash_seen { + return Some(("/ may appear only once".to_owned(), index, index + 1)); + } + slash_seen = true; + let next = next_non_horizontal_whitespace(bytes, index + 1); + if bytes.get(next) == Some(&b'*') { + return Some(("expected comma between / and *".to_owned(), next, next + 1)); + } + index += 1; + } + b'*' if level == 0 => { + if var_keyword_seen { + return Some(( + "arguments cannot follow var-keyword argument".to_owned(), + index, + index + 1, + )); + } + keyword_only = true; + let stars = usize::from(bytes.get(index + 1) == Some(&b'*')) + 1; + let name_start = next_non_horizontal_whitespace(bytes, index + stars); + for keyword in [b"True".as_slice(), b"False".as_slice(), b"None".as_slice()] { + if starts_identifier(bytes, name_start, keyword) { + return Some(( + "invalid syntax".to_owned(), + name_start, + name_start + keyword.len(), + )); + } + } + let param_end = find_byte_at_level(bytes, name_start, end, b',') + .or_else(|| top_level_byte(bytes, name_start, end, b')')) + .or_else(|| { + matches!(kind, ParameterListKind::Lambda) + .then(|| top_level_byte(bytes, name_start, end, b':')) + .flatten() + }) + .unwrap_or(end); + if stars == 1 && matches!(bytes.get(name_start), Some(b')' | b',' | b':')) { + return Some(( + "named arguments must follow bare *".to_owned(), + index, + index + 1, + )); + } + if stars == 1 && top_level_byte(bytes, name_start, param_end, b'=').is_some() { + return Some(( + "var-positional argument cannot have default value".to_owned(), + index, + index + 1, + )); + } + if stars == 2 && top_level_byte(bytes, name_start, param_end, b'=').is_some() { + return Some(( + "var-keyword argument cannot have default value".to_owned(), + index, + index + 2, + )); + } + if stars == 2 { + var_keyword_seen = true; + index = param_end; + continue; + } + index += stars; + } + b'=' if level == 0 => { + let value_start = next_non_horizontal_whitespace(bytes, index + 1); + if value_start >= end || matches!(bytes.get(value_start), Some(b',' | b')' | b':')) + { + if matches!(kind, ParameterListKind::Lambda) + && matches!(bytes.get(value_start), Some(b':')) + { + return Some(("invalid syntax".to_owned(), index, index + 1)); + } + return Some(( + "expected default value expression".to_owned(), + index, + index + 1, + )); + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_parameter_list_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + _ if starts_identifier(bytes, index, b"def") => { + let Some(paren) = top_level_byte(bytes, index + 3, bytes.len(), b'(') else { + index += 3; + continue; + }; + let Some(close) = matching_delimiter(bytes, paren, b')') else { + index = paren + 1; + continue; + }; + if let Some(error) = invalid_parameter_list_slice_error( + source, + paren + 1, + close, + ParameterListKind::Function, + ) { + return Some(error); + } + index = close + 1; + } + _ if starts_identifier(bytes, index, b"lambda") => { + let params_start = index + 6; + let Some(params_end) = find_lambda_parameter_end(bytes, params_start) else { + index = params_start; + continue; + }; + if let Some(error) = invalid_parameter_list_slice_error( + source, + params_start, + params_end, + ParameterListKind::Lambda, + ) { + return Some(error); + } + index = params_end + 1; + } + _ => index += 1, + } + } + None +} + +#[derive(Clone, Copy)] +struct CallArgFrame { + level: usize, + arg_start: Option, + in_call: bool, +} + +fn next_non_horizontal_whitespace(bytes: &[u8], mut index: usize) -> usize { + while matches!(bytes.get(index), Some(b' ' | b'\t' | b'\x0c')) { + index += 1; + } + index +} + +fn invalid_call_argument_assignment_error( + source: &str, + arg_start: usize, + equal: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let start = bytes[arg_start..equal] + .iter() + .rposition(|byte| *byte == b'\n') + .map_or(arg_start, |newline| arg_start + newline + 1); + let (target_start, target_end) = trim_target_range(bytes, start, equal); + if target_start >= target_end { + return None; + } + let value_start = next_non_horizontal_whitespace(bytes, equal + 1); + if matches!(bytes.get(value_start), None | Some(b',' | b')')) { + return Some(( + "expected argument value expression".to_owned(), + target_start, + equal + 1, + )); + } + if bytes.get(target_start..target_start + 2) == Some(b"**") { + return Some(( + "cannot assign to keyword argument unpacking".to_owned(), + target_start, + value_start, + )); + } + if bytes.get(target_start) == Some(&b'*') { + return Some(( + "cannot assign to iterable argument unpacking".to_owned(), + target_start, + value_start, + )); + } + for keyword in [b"True".as_slice(), b"False".as_slice(), b"None".as_slice()] { + if bytes.get(target_start..target_end) == Some(keyword) { + let keyword = ::core::str::from_utf8(keyword).ok()?; + return Some(( + format!("cannot assign to {keyword}"), + target_start, + target_end, + )); + } + } + if is_simple_keyword_name(bytes, target_start, target_end) { + return None; + } + Some(( + "expression cannot contain assignment, perhaps you meant \"==\"?".to_owned(), + target_start, + equal, + )) +} + +fn invalid_call_star_expression_error( + bytes: &[u8], + arg_start: usize, + index: usize, +) -> Option<(String, usize, usize)> { + let start = next_non_horizontal_whitespace(bytes, arg_start); + if start != index || bytes.get(index) != Some(&b'*') { + return None; + } + let after_star = next_non_horizontal_whitespace(bytes, index + 1); + if matches!(bytes.get(after_star), None | Some(b',' | b')' | b':')) { + return Some(( + "Invalid star expression".to_owned(), + index, + (index + 1).min(bytes.len()), + )); + } + None +} + +fn invalid_call_argument_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + let mut level = 0usize; + let mut frames: Vec = Vec::new(); + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + _ if starts_identifier(bytes, index, b"lambda") => { + let params_start = index + 6; + if let Some(params_end) = find_lambda_parameter_end(bytes, params_start) { + index = params_end + 1; + } else { + index = params_start; + } + } + b'(' => { + level += 1; + let in_call = opening_paren_is_call(bytes, index) + || frames.last().is_some_and(|frame| frame.in_call); + frames.push(CallArgFrame { + level, + arg_start: (in_call && !is_function_parameter_list(bytes, index)) + .then_some(index + 1), + in_call, + }); + index += 1; + } + b')' => { + if matches!(frames.last(), Some(frame) if frame.level == level) { + frames.pop(); + } + level = level.saturating_sub(1); + index += 1; + } + b'[' | b'{' => { + level += 1; + index += 1; + } + b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b',' => { + if let Some(frame) = frames.last_mut() + && frame.level == level + && frame.arg_start.is_some() + { + frame.arg_start = Some(index + 1); + } + index += 1; + } + b'*' => { + if let Some(CallArgFrame { + level: frame_level, + arg_start: Some(arg_start), + in_call: true, + }) = frames.last().copied() + && frame_level == level + && let Some(error) = invalid_call_star_expression_error(bytes, arg_start, index) + { + return Some(error); + } + index += 1; + } + b'=' if is_plain_assignment_operator(bytes, index) => { + if let Some(CallArgFrame { + level: frame_level, + arg_start: Some(arg_start), + in_call: true, + }) = frames.last().copied() + && frame_level == level + && let Some(error) = + invalid_call_argument_assignment_error(source, arg_start, index) + { + return Some(error); + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn top_level_colon(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b':' if level == 0 => return Some(index), + _ => index += 1, + } + } + None +} + +fn expression_slice_is_valid(source: &str, start: usize, end: usize) -> bool { + let bytes = source.as_bytes(); + let (start, end) = trim_target_range(bytes, start, end); + start < end + && parser::parse(&source[start..end], parser::Mode::Expression.into()) + .is_ok_and(|parsed| matches!(parsed.into_syntax(), ast::Mod::Expression(_))) +} + +fn invalid_dict_entry_error( + source: &str, + item_start: usize, + item_end: usize, + colon: Option, + saw_dict_item: bool, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (item_start, item_end) = trim_target_range(bytes, item_start, item_end); + if item_start >= item_end { + return None; + } + if let Some(colon) = colon { + let value_start = next_non_horizontal_whitespace(bytes, colon + 1); + if value_start >= item_end { + return Some(( + "expression expected after dictionary key and ':'".to_owned(), + colon, + colon + 1, + )); + } + if bytes.get(value_start) == Some(&b'*') { + return Some(( + "cannot use a starred expression in a dictionary value".to_owned(), + value_start, + value_start + 1, + )); + } + if !expression_slice_is_valid(source, value_start, item_end) { + return Some(("invalid syntax".to_owned(), value_start, value_start)); + } + } else if saw_dict_item { + return Some(( + "':' expected after dictionary key".to_owned(), + item_end.saturating_sub(1), + item_end, + )); + } + None +} + +fn invalid_dict_literal_error( + source: &str, + open: usize, + close: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut item_start = open + 1; + let mut index = item_start; + let mut level = 0usize; + let mut saw_dict_item = false; + let mut item_colon = None; + while index <= close { + if index == close || (level == 0 && bytes.get(index) == Some(&b',')) { + if let Some(error) = + invalid_dict_entry_error(source, item_start, index, item_colon, saw_dict_item) + { + return Some(error); + } + saw_dict_item |= item_colon.is_some(); + item_start = index + 1; + item_colon = None; + index += 1; + continue; + } + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b':' if level == 0 && item_colon.is_none() => { + item_colon = Some(index); + saw_dict_item = true; + index += 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_dict_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'{' => { + let Some(close) = matching_delimiter(bytes, index, b'}') else { + index += 1; + continue; + }; + if top_level_colon(bytes, index + 1, close).is_some() + && let Some(error) = invalid_dict_literal_error(source, index, close) + { + return Some(error); + } + index = close + 1; + } + _ => index += 1, + } + } + None +} + +fn collection_open_is_call(bytes: &[u8], open: usize) -> bool { + if bytes.get(open) != Some(&b'(') { + return false; + } + let mut cursor = open; + while cursor > 0 && matches!(bytes.get(cursor - 1), Some(b' ' | b'\t' | b'\x0c')) { + cursor -= 1; + } + matches!( + cursor.checked_sub(1).and_then(|before| bytes.get(before)), + Some(b')' | b']' | b'_' | b'a'..=b'z' | b'A'..=b'Z' | 0x80..=0xff) + ) +} + +fn invalid_collection_assignment_in_slice( + source: &str, + bytes: &[u8], + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + let mut item_start = start; + let mut index = start; + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b',' if level == 0 => { + item_start = index + 1; + index += 1; + } + b'=' if level == 0 && is_plain_assignment_operator(bytes, index) => { + if top_level_colon(bytes, item_start, index).is_none() { + let start = next_non_horizontal_whitespace(bytes, item_start); + let target_end = trim_end_horizontal_whitespace(bytes, start, index); + if start < target_end + && let Some((expr_name, expr_start, expr_end, _)) = + expression_name_and_range(&source[start..target_end]) + { + if matches!(expr_name, "list" | "tuple") { + return None; + } + if matches!(expr_name, "expression" | "attribute" | "subscript") { + return Some(( + format!( + "cannot assign to {expr_name} here. Maybe you meant '==' instead of '='?" + ), + start + expr_start, + start + expr_end, + )); + } + } + return Some(( + "invalid syntax. Maybe you meant '==' or ':=' instead of '='?".to_owned(), + start, + index + 1, + )); + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_collection_assignment_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + let close_byte = match bytes[index] { + b'(' => b')', + b'[' => b']', + _ => b'}', + }; + let Some(close) = matching_delimiter(bytes, index, close_byte) else { + index += 1; + continue; + }; + if !collection_open_is_call(bytes, index) + && let Some(error) = + invalid_collection_assignment_in_slice(source, bytes, index + 1, close) + { + return Some(error); + } + index = close + 1; + } + _ => index += 1, + } + } + None +} + +fn expression_assignment_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + let mut paren_arg_starts: Vec<(Option, bool)> = Vec::new(); + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + _ if starts_identifier(bytes, index, b"lambda") => { + let params_start = index + 6; + if let Some(params_end) = find_lambda_parameter_end(bytes, params_start) { + index = params_end + 1; + } else { + index = params_start; + } + } + b'(' => { + let in_call_context = opening_paren_is_call(bytes, index) + || paren_arg_starts.last().is_some_and(|(_, in_call)| *in_call); + paren_arg_starts.push(( + (!is_function_parameter_list(bytes, index)).then_some(index + 1), + in_call_context, + )); + index += 1; + } + b')' => { + paren_arg_starts.pop(); + index += 1; + } + b',' => { + if let Some((start, _)) = paren_arg_starts.last_mut() + && start.is_some() + { + *start = Some(index + 1); + } + index += 1; + } + b'=' if is_plain_assignment_operator(bytes, index) => { + if let Some((Some(start), true)) = paren_arg_starts.last().copied() + && !is_simple_keyword_name(bytes, start, index) + { + let mut expr_start = start; + while matches!(bytes.get(expr_start), Some(b' ' | b'\t' | b'\x0c')) { + expr_start += 1; + } + return Some(( + "expression cannot contain assignment, perhaps you meant \"==\"?" + .to_owned(), + expr_start, + index, + )); + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn invalid_named_expression_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index + 1 < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b':' if bytes.get(index + 1) == Some(&b'=') => { + let target_start = named_expression_target_start(bytes, index); + let target_end = trim_end_horizontal_whitespace(bytes, target_start, index); + if target_start < target_end + && let Some((expr_name, start, end, is_name)) = + expression_name_and_range(&source[target_start..target_end]) + && !is_name + { + return Some(( + format!("cannot use assignment expressions with {expr_name}"), + target_start + start, + target_start + end, + )); + } + index += 2; + } + _ => index += 1, + } + } + None +} + +#[derive(Clone, Copy)] +struct AssignmentContext { + start: usize, + call: bool, +} + +fn invalid_plain_assignment_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut stack: Vec = Vec::new(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + stack.push(AssignmentContext { + start: index + 1, + call: bytes[index] == b'(' && opening_paren_is_call(bytes, index), + }); + index += 1; + } + b')' | b']' | b'}' => { + stack.pop(); + index += 1; + } + b',' => { + if let Some(context) = stack.last_mut() + && !context.call + { + context.start = index + 1; + } + index += 1; + } + b'=' if is_plain_assignment_operator(bytes, index) => { + if let Some(context) = stack.last().copied() + && !context.call + { + let target_start = skip_horizontal_whitespace(bytes, context.start); + let target_end = trim_end_horizontal_whitespace(bytes, target_start, index); + if target_start < target_end + && let Some((expr_name, start, end, _)) = + expression_name_and_range(&source[target_start..target_end]) + && matches!(expr_name, "expression" | "attribute" | "subscript") + { + return Some(( + format!( + "cannot assign to {expr_name} here. Maybe you meant '==' instead of '='?" + ), + target_start + start, + target_start + end, + )); + } + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn opening_paren_is_call(bytes: &[u8], paren: usize) -> bool { + let mut cursor = paren; + while cursor > 0 && matches!(bytes[cursor - 1], b' ' | b'\t' | b'\x0c') { + cursor -= 1; + } + cursor > 0 + && (bytes[cursor - 1] >= 0x80 + || is_ascii_identifier_char(bytes[cursor - 1]) + || matches!(bytes[cursor - 1], b')' | b']')) +} + +fn named_expression_target_start(bytes: &[u8], walrus: usize) -> usize { + let mut index = walrus; + let mut level = 0usize; + while index > 0 { + index -= 1; + match bytes[index] { + b')' | b']' | b'}' => level += 1, + b'(' | b'[' | b'{' if level > 0 => level -= 1, + b'(' | b'[' | b'{' if level == 0 => return index + 1, + b',' | b'\n' | b';' if level == 0 => return index + 1, + _ => {} + } + } + 0 +} + +fn trim_end_horizontal_whitespace(bytes: &[u8], start: usize, mut end: usize) -> usize { + while end > start && matches!(bytes[end - 1], b' ' | b'\t' | b'\x0c') { + end -= 1; + } + end +} + +fn annotation_target_error_for_slice( + source: &str, + start: usize, + colon: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (target_start, target_end) = trim_target_range(bytes, start, colon); + if target_start >= target_end { + return None; + } + let target_text = &source[target_start..target_end]; + let Ok(parsed) = parser::parse(target_text, parser::Mode::Expression.into()) else { + return None; + }; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + match expression.body.as_ref() { + ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => None, + ast::Expr::List(_) => Some(( + "only single target (not list) can be annotated".to_owned(), + target_start, + target_end, + )), + ast::Expr::Tuple(_) => Some(( + "only single target (not tuple) can be annotated".to_owned(), + target_start, + target_end, + )), + _ => Some(( + "illegal target for annotation".to_owned(), + target_start, + target_end, + )), + } +} + +fn invalid_annotation_line_start(bytes: &[u8], line_start: usize) -> bool { + let column = skip_horizontal_whitespace(bytes, line_start); + for keyword in [ + b"async".as_slice(), + b"case", + b"class", + b"def", + b"elif", + b"else", + b"except", + b"finally", + b"for", + b"if", + b"match", + b"try", + b"while", + b"with", + ] { + if starts_identifier(bytes, column, keyword) { + return false; + } + } + true +} + +fn invalid_annotation_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + if invalid_annotation_line_start(bytes, line_start) + && let Some(colon) = find_byte_at_level(bytes, line_start, line_end, b':') + && bytes.get(colon + 1) != Some(&b'=') + && colon.checked_sub(1).and_then(|before| bytes.get(before)) != Some(&b':') + && let Some(error) = annotation_target_error_for_slice(source, line_start, colon) + { + return Some(error); + } + line_start = line_end; + } + None +} + +fn statement_target_end(bytes: &[u8], mut index: usize) -> usize { + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' if level == 0 => return index, + b'\n' | b';' if level == 0 => return index, + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ => index += 1, + } + } + index +} + +fn invalid_assignment_target(expression: &ast::Expr) -> Option<&ast::Expr> { + match expression { + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { + elts.iter().find_map(invalid_assignment_target) + } + ast::Expr::Starred(ast::ExprStarred { value, .. }) => invalid_assignment_target(value), + ast::Expr::Name(_) | ast::Expr::Subscript(_) | ast::Expr::Attribute(_) => None, + _ => Some(expression), + } +} + +fn invalid_for_target(expression: &ast::Expr) -> Option<&ast::Expr> { + match expression { + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => elts.iter().find_map(invalid_for_target), + ast::Expr::Starred(ast::ExprStarred { value, .. }) => invalid_for_target(value), + ast::Expr::Compare(ast::ExprCompare { left, ops, .. }) => { + if matches!(ops.first(), Some(ast::CmpOp::In)) { + invalid_for_target(left) + } else { + None + } + } + ast::Expr::Name(_) | ast::Expr::Subscript(_) | ast::Expr::Attribute(_) => None, + _ => Some(expression), + } +} + +fn invalid_delete_target(expression: &ast::Expr) -> Option<&ast::Expr> { + match expression { + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { + elts.iter().find_map(invalid_delete_target) + } + ast::Expr::Name(_) | ast::Expr::Subscript(_) | ast::Expr::Attribute(_) => None, + ast::Expr::Starred(_) => Some(expression), + ast::Expr::Compare(_) => Some(expression), + _ => Some(expression), + } +} + +fn delete_target_expr_name(expression: &ast::Expr) -> &'static str { + match expression { + ast::Expr::Attribute(_) => "attribute", + ast::Expr::Subscript(_) => "subscript", + ast::Expr::Starred(_) => "starred", + ast::Expr::Name(_) => "name", + ast::Expr::List(_) => "list", + ast::Expr::Tuple(_) => "tuple", + ast::Expr::Lambda(_) => "lambda", + ast::Expr::Call(_) => "function call", + ast::Expr::BoolOp(_) | ast::Expr::BinOp(_) | ast::Expr::UnaryOp(_) => "expression", + ast::Expr::Generator(_) => "generator expression", + ast::Expr::Yield(_) | ast::Expr::YieldFrom(_) => "yield expression", + ast::Expr::Await(_) => "await expression", + ast::Expr::ListComp(_) => "list comprehension", + ast::Expr::SetComp(_) => "set comprehension", + ast::Expr::DictComp(_) => "dict comprehension", + ast::Expr::Dict(_) => "dict literal", + ast::Expr::Set(_) => "set display", + ast::Expr::FString(_) => "f-string expression", + ast::Expr::TString(_) => "t-string expression", + ast::Expr::NumberLiteral(_) | ast::Expr::StringLiteral(_) | ast::Expr::BytesLiteral(_) => { + "literal" + } + ast::Expr::BooleanLiteral(boolean) => { + if boolean.value { + "True" + } else { + "False" + } + } + ast::Expr::NoneLiteral(_) => "None", + ast::Expr::EllipsisLiteral(_) => "ellipsis", + ast::Expr::Compare(_) => "comparison", + ast::Expr::If(_) => "conditional expression", + ast::Expr::Named(_) => "named expression", + ast::Expr::Slice(_) | ast::Expr::IpyEscapeCommand(_) => "expression", + } +} + +fn parenthesized_single_starred_delete_target(bytes: &[u8], start: usize, end: usize) -> bool { + let mut cursor = start; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\x0c')) { + cursor += 1; + } + if bytes.get(cursor) != Some(&b'(') { + return false; + } + cursor += 1; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\x0c')) { + cursor += 1; + } + if bytes.get(cursor) != Some(&b'*') { + return false; + } + let mut level = 1usize; + cursor += 1; + while cursor < end { + match bytes[cursor] { + b'\'' | b'"' => { + cursor = skip_quoted_string(bytes, cursor); + } + b'(' | b'[' | b'{' => { + level += 1; + cursor += 1; + } + b')' => { + level = level.saturating_sub(1); + if level == 0 { + cursor += 1; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\x0c')) { + cursor += 1; + } + return cursor == end; + } + cursor += 1; + } + b',' if level == 1 => return false, + b']' | b'}' => { + level = level.saturating_sub(1); + cursor += 1; + } + _ => cursor += 1, + } + } + false +} + +fn trim_target_range(bytes: &[u8], mut start: usize, mut end: usize) -> (usize, usize) { + while start < end + && matches!( + bytes.get(start), + Some(b' ' | b'\t' | b'\n' | b'\r' | b'\x0c') + ) + { + start += 1; + } + while end > start + && matches!( + bytes.get(end - 1), + Some(b' ' | b'\t' | b'\n' | b'\r' | b'\x0c') + ) + { + end -= 1; + } + (start, end) +} + +fn invalid_assignment_message(name: &'static str, top_level_bitwise: bool) -> String { + if top_level_bitwise { + format!("cannot assign to {name} here. Maybe you meant '==' instead of '='?") + } else { + format!("cannot assign to {name}") + } +} + +fn assignment_target_error_for_slice( + source: &str, + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (target_start, target_end) = trim_target_range(bytes, start, end); + if target_start >= target_end { + return None; + } + if starts_identifier(bytes, target_start, b"yield") { + return Some(( + "assignment to yield expression not possible".to_owned(), + target_start, + target_start + 5, + )); + } + let target_text = &source[target_start..target_end]; + let Ok(parsed) = parser::parse(target_text, parser::Mode::Expression.into()) else { + return None; + }; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + let invalid_target = invalid_assignment_target(&expression.body)?; + let invalid_start = target_start + invalid_target.range().start().to_usize(); + let invalid_end = target_start + invalid_target.range().end().to_usize(); + if matches!(invalid_target, ast::Expr::FString(_)) { + return Some(("invalid syntax".to_owned(), invalid_start, invalid_end)); + } + let name = delete_target_expr_name(invalid_target); + let top_level = invalid_target.range() == expression.body.range(); + let bitwise_like = matches!( + invalid_target, + ast::Expr::Call(_) + | ast::Expr::BoolOp(_) + | ast::Expr::BinOp(_) + | ast::Expr::UnaryOp(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::EllipsisLiteral(_) + ); + Some(( + invalid_assignment_message(name, top_level && bitwise_like), + invalid_start, + invalid_end, + )) +} + +fn star_target_error_for_slice( + source: &str, + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + invalid_target_error_for_slice(source, start, end, invalid_assignment_target) +} + +fn for_target_error_for_slice( + source: &str, + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + invalid_target_error_for_slice(source, start, end, invalid_for_target) +} + +fn invalid_target_error_for_slice( + source: &str, + start: usize, + end: usize, + invalid_target: for<'a> fn(&'a ast::Expr) -> Option<&'a ast::Expr>, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (target_start, target_end) = trim_target_range(bytes, start, end); + if target_start >= target_end { + return None; + } + let target_text = &source[target_start..target_end]; + let Ok(parsed) = parser::parse(target_text, parser::Mode::Expression.into()) else { + return None; + }; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + let invalid_target = invalid_target(&expression.body)?; + let name = delete_target_expr_name(invalid_target); + let invalid_start = target_start + invalid_target.range().start().to_usize(); + let invalid_end = target_start + invalid_target.range().end().to_usize(); + Some(( + format!("cannot assign to {name}"), + invalid_start, + invalid_end, + )) +} + +fn first_compare_operator_at_level(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b'<' | b'>' if level == 0 => return Some(index), + b'=' if level == 0 && bytes.get(index + 1) == Some(&b'=') => return Some(index), + b'!' if level == 0 && bytes.get(index + 1) == Some(&b'=') => return Some(index), + _ if level == 0 && starts_identifier(bytes, index, b"is") => return Some(index), + _ if level == 0 && starts_identifier(bytes, index, b"not") => return Some(index), + _ => index += 1, + } + } + None +} + +fn non_in_compare_for_target_error( + source: &str, + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (target_start, target_end) = trim_target_range(bytes, start, end); + if target_start >= target_end { + return None; + } + let target_text = &source[target_start..target_end]; + let Ok(parsed) = parser::parse(target_text, parser::Mode::Expression.into()) else { + return None; + }; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + let ast::Expr::Compare(ast::ExprCompare { ops, .. }) = expression.body.as_ref() else { + return None; + }; + if matches!(ops.first(), Some(ast::CmpOp::In)) { + return None; + } + let operator = first_compare_operator_at_level(bytes, target_start, target_end)?; + Some(( + "invalid syntax".to_owned(), + operator, + (operator + 1).min(target_end), + )) +} + +fn top_level_plain_assignment_offsets(bytes: &[u8]) -> Vec { + let mut offsets = Vec::new(); + let mut index = 0usize; + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' if level == 0 => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b'=' if level == 0 && is_plain_assignment_operator(bytes, index) => { + offsets.push(index); + index += 1; + } + _ => index += 1, + } + } + offsets +} + +fn invalid_assignment_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let offsets = top_level_plain_assignment_offsets(bytes); + if offsets.is_empty() { + return None; + } + let mut start = 0usize; + for offset in offsets { + if let Some(error) = assignment_target_error_for_slice(source, start, offset) { + return Some(error); + } + start = offset + 1; + } + None +} + +fn top_level_augassign_offset(bytes: &[u8]) -> Option<(usize, usize)> { + let mut index = 0usize; + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' if level == 0 => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b'+' | b'-' | b'*' | b'@' | b'/' | b'%' | b'&' | b'|' | b'^' + if level == 0 && bytes.get(index + 1) == Some(&b'=') => + { + return Some((index, 2)); + } + b'<' | b'>' + if level == 0 + && bytes.get(index + 1) == Some(&bytes[index]) + && bytes.get(index + 2) == Some(&b'=') => + { + return Some((index, 3)); + } + b'*' if level == 0 + && bytes.get(index + 1) == Some(&b'*') + && bytes.get(index + 2) == Some(&b'=') => + { + return Some((index, 3)); + } + b'/' if level == 0 + && bytes.get(index + 1) == Some(&b'/') + && bytes.get(index + 2) == Some(&b'=') => + { + return Some((index, 3)); + } + _ => index += 1, + } + } + None +} + +fn invalid_augassign_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (operator, _) = top_level_augassign_offset(bytes)?; + let (target_start, target_end) = trim_target_range(bytes, 0, operator); + if target_start >= target_end { + return None; + } + let target_text = &source[target_start..target_end]; + let Ok(parsed) = parser::parse(target_text, parser::Mode::Expression.into()) else { + return None; + }; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + let name = delete_target_expr_name(&expression.body); + Some(( + format!("'{name}' is an illegal expression for augmented assignment"), + target_start, + target_end, + )) +} + +fn find_for_target_delimiter(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'#' if level == 0 => return None, + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ if level == 0 && starts_identifier(bytes, index, b"in") => return Some(index), + b':' if level == 0 => return Some(index), + _ => index += 1, + } + } + None +} + +fn invalid_for_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + _ if starts_identifier(bytes, index, b"for") => { + let target_start = skip_horizontal_whitespace(bytes, index + 3); + let line_end = source[index..] + .find('\n') + .map_or(bytes.len(), |newline| index + newline); + if let Some(target_end) = find_for_target_delimiter(bytes, target_start, line_end) { + if let Some(error) = + for_target_error_for_slice(source, target_start, target_end) + { + return Some(error); + } + if let Some(error) = + non_in_compare_for_target_error(source, target_start, target_end) + { + return Some(error); + } + } + index = target_start.max(index + 3); + } + _ => index += 1, + } + } + None +} + +fn find_with_target_delimiter(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'#' if level == 0 => return None, + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' if level == 0 => return Some(index), + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b',' | b':' if level == 0 => return Some(index), + _ => index += 1, + } + } + None +} + +fn invalid_with_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let mut column = skip_horizontal_whitespace(bytes, line_start); + if starts_identifier(bytes, column, b"async") { + column = skip_horizontal_whitespace(bytes, column + 5); + } + if !starts_identifier(bytes, column, b"with") { + line_start = line_end; + continue; + } + let mut index = column + 4; + while let Some(as_index) = find_keyword_at_level(bytes, index, line_end, b"as") { + let target_start = skip_horizontal_whitespace(bytes, as_index + 2); + if let Some(target_end) = find_with_target_delimiter(bytes, target_start, line_end) { + if let Some(error) = star_target_error_for_slice(source, target_start, target_end) { + return Some(error); + } + index = target_end.saturating_add(1); + } else { + break; + } + } + line_start = line_end; + } + None +} + +fn find_missing_in_if_keyword(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'#' if level == 0 => return None, + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' if level == 0 => return None, + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ if level == 0 && starts_identifier(bytes, index, b"in") => return None, + _ if level == 0 && starts_identifier(bytes, index, b"if") => return Some(index), + _ => index += 1, + } + } + None +} + +fn invalid_for_if_clause_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0usize; + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' if level == 0 => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ if level > 0 && starts_identifier(bytes, index, b"for") => { + let target_start = skip_horizontal_whitespace(bytes, index + 3); + let line_end = source[index..] + .find('\n') + .map_or(bytes.len(), |newline| index + newline); + if let Some(if_index) = find_missing_in_if_keyword(bytes, target_start, line_end) { + return Some(( + "'in' expected after for-loop variables".to_owned(), + if_index, + (if_index + 2).min(line_end), + )); + } + index = target_start.max(index + 3); + } + _ => index += 1, + } + } + None +} + +fn invalid_delete_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'd' if starts_identifier(bytes, index, b"del") => { + let mut target_start = index + 3; + if !matches!(bytes.get(target_start), Some(b' ' | b'\t' | b'\x0c')) { + index += 3; + continue; + } + while matches!(bytes.get(target_start), Some(b' ' | b'\t' | b'\x0c')) { + target_start += 1; + } + let mut target_end = statement_target_end(bytes, target_start); + while target_end > target_start + && matches!(bytes.get(target_end - 1), Some(b' ' | b'\t' | b'\x0c')) + { + target_end -= 1; + } + if target_start >= target_end { + index = target_end.max(index + 3); + continue; + } + if parenthesized_single_starred_delete_target(bytes, target_start, target_end) { + return Some(( + "cannot use starred expression here".to_owned(), + target_start, + target_end, + )); + } + if bytes.get(target_start) == Some(&b'*') { + return Some(( + "cannot delete starred".to_owned(), + target_start, + (target_start + 1).min(target_end), + )); + } + let target_text = &source[target_start..target_end]; + let Ok(parsed) = parser::parse(target_text, parser::Mode::Expression.into()) else { + index = target_end; + continue; + }; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + index = target_end; + continue; + }; + let Some(invalid_target) = invalid_delete_target(&expression.body) else { + index = target_end; + continue; + }; + let start = target_start + invalid_target.range().start().to_usize(); + let end = target_start + invalid_target.range().end().to_usize(); + if matches!(invalid_target, ast::Expr::FString(_)) { + return Some(("invalid syntax".to_owned(), start, end)); + } + let name = delete_target_expr_name(invalid_target); + return Some((format!("cannot delete {name}"), start, end)); + } + _ => index += 1, + } + } + None +} + +fn skip_horizontal_whitespace(bytes: &[u8], mut index: usize) -> usize { + while matches!(bytes.get(index), Some(b' ' | b'\t' | b'\x0c')) { + index += 1; + } + index +} + +fn find_keyword_at_level( + bytes: &[u8], + mut index: usize, + end: usize, + keyword: &[u8], +) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'#' if level == 0 => return None, + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ if level == 0 && starts_identifier(bytes, index, keyword) => return Some(index), + _ => index += 1, + } + } + None +} + +fn find_byte_at_level(bytes: &[u8], mut index: usize, end: usize, needle: u8) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'#' if level == 0 => return None, + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + byte if level == 0 && byte == needle => return Some(index), + _ => index += 1, + } + } + None +} + +fn expression_name_and_range(source: &str) -> Option<(&'static str, usize, usize, bool)> { + let parsed = parser::parse(source, parser::Mode::Expression.into()).ok()?; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + let is_name = matches!(expression.body.as_ref(), ast::Expr::Name(_)); + Some(( + delete_target_expr_name(&expression.body), + expression.body.range().start().to_usize(), + expression.body.range().end().to_usize(), + is_name, + )) +} + +fn invalid_standalone_except_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + let mut seen_try = false; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let column = skip_horizontal_whitespace(bytes, line_start); + if column >= line_end { + line_start = line_end; + continue; + } + if starts_identifier(bytes, column, b"try") { + seen_try = true; + } else if (bytes.get(column..column + 7) == Some(b"except*") + || starts_identifier(bytes, column, b"except")) + && !seen_try + { + let end = if bytes.get(column..column + 7) == Some(b"except*") { + column + 7 + } else { + column + 6 + }; + return Some(("invalid syntax".to_owned(), column, end)); + } + line_start = line_end; + } + None +} + +fn invalid_import_statement_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let column = skip_horizontal_whitespace(bytes, line_start); + if column < line_end + && starts_identifier(bytes, column, b"import") + && find_keyword_at_level(bytes, column + 6, line_end, b"from").is_some() + { + return Some(( + "Did you mean to use 'from ... import ...' instead?".to_owned(), + column, + column + 6, + )); + } + line_start = line_end; + } + None +} + +fn import_as_target_end(bytes: &[u8], mut index: usize) -> usize { + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' if level == 0 => return index, + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' if level == 0 => return index, + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + b',' | b';' | b'\n' if level == 0 => return index, + _ => index += 1, + } + } + index +} + +fn valid_import_alias_name(bytes: &[u8], mut start: usize, end: usize) -> bool { + start = skip_horizontal_whitespace(bytes, start); + let Some(&first) = bytes.get(start) else { + return false; + }; + if !(first == b'_' || first.is_ascii_alphabetic() || first >= 0x80) { + return false; + } + let mut index = start + 1; + while index < end { + match bytes[index] { + b' ' | b'\t' | b'\x0c' => break, + byte if byte >= 0x80 || is_ascii_identifier_char(byte) => index += 1, + _ => return false, + } + } + let index = skip_horizontal_whitespace(bytes, index); + matches!( + bytes.get(index), + None | Some(b',' | b')' | b';' | b'\n' | b'\r') + ) +} + +fn import_target_error_for_slice( + source: &str, + start: usize, + end: usize, +) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let (target_start, target_end) = trim_target_range(bytes, start, end); + if target_start >= target_end || valid_import_alias_name(bytes, target_start, target_end) { + return None; + } + let parsed = parser::parse( + &source[target_start..target_end], + parser::Mode::Expression.into(), + ) + .ok()?; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + return None; + }; + let name = delete_target_expr_name(&expression.body); + let start = target_start + expression.body.range().start().to_usize(); + let end = target_start + expression.body.range().end().to_usize(); + Some((format!("cannot use {name} as import target"), start, end)) +} + +fn statement_starts_import(bytes: &[u8], line_start: usize, line_end: usize) -> bool { + let column = skip_horizontal_whitespace(bytes, line_start); + if starts_identifier(bytes, column, b"import") { + return true; + } + starts_identifier(bytes, column, b"from") + && find_keyword_at_level(bytes, column + 4, line_end, b"import").is_some() +} + +fn invalid_import_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + let mut in_parenthesized_from_import = false; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let starts_import = statement_starts_import(bytes, line_start, line_end); + if starts_import && bytes[line_start..line_end].contains(&b'(') { + in_parenthesized_from_import = true; + } + if starts_import || in_parenthesized_from_import { + let mut index = line_start; + while index < line_end { + if starts_identifier(bytes, index, b"as") { + let target_start = skip_horizontal_whitespace(bytes, index + 2); + let target_end = import_as_target_end(bytes, target_start); + if let Some(error) = + import_target_error_for_slice(source, target_start, target_end) + { + return Some(error); + } + index = target_end.max(index + 2); + } else { + index += 1; + } + } + } + if in_parenthesized_from_import && bytes[line_start..line_end].contains(&b')') { + in_parenthesized_from_import = false; + } + line_start = line_end; + } + None +} + +fn invalid_except_as_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + let mut seen_try = false; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let mut column = skip_horizontal_whitespace(bytes, line_start); + if column >= line_end { + line_start = line_end; + continue; + } + if starts_identifier(bytes, column, b"try") { + seen_try = true; + line_start = line_end; + continue; + } + let (keyword_len, starred) = if bytes.get(column..column + 7) == Some(b"except*") { + (7, true) + } else if starts_identifier(bytes, column, b"except") { + (6, false) + } else { + line_start = line_end; + continue; + }; + if !seen_try { + line_start = line_end; + continue; + } + column += keyword_len; + let Some(as_index) = find_keyword_at_level(bytes, column, line_end, b"as") else { + line_start = line_end; + continue; + }; + let target_start = skip_horizontal_whitespace(bytes, as_index + 2); + let Some(delimiter) = find_byte_at_level(bytes, target_start, line_end, b':') + .into_iter() + .chain(find_byte_at_level(bytes, target_start, line_end, b',')) + .min() + else { + line_start = line_end; + continue; + }; + let mut target_end = delimiter; + while target_end > target_start + && matches!(bytes.get(target_end - 1), Some(b' ' | b'\t' | b'\x0c')) + { + target_end -= 1; + } + let Some((expr_name, start, end, is_name)) = + expression_name_and_range(&source[target_start..target_end]) + else { + line_start = line_end; + continue; + }; + if !is_name { + let statement = if starred { "except*" } else { "except" }; + return Some(( + format!("cannot use {statement} statement with {expr_name}"), + target_start + start, + target_start + end, + )); + } + line_start = line_end; + } + None +} + +fn invalid_match_as_target_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let mut column = skip_horizontal_whitespace(bytes, line_start); + if column >= line_end || !starts_identifier(bytes, column, b"case") { + line_start = line_end; + continue; + } + column += 4; + let Some(as_index) = find_keyword_at_level(bytes, column, line_end, b"as") else { + line_start = line_end; + continue; + }; + let target_start = skip_horizontal_whitespace(bytes, as_index + 2); + let Some(delimiter) = find_byte_at_level(bytes, target_start, line_end, b':') + .into_iter() + .chain(find_byte_at_level(bytes, target_start, line_end, b',')) + .min() + else { + line_start = line_end; + continue; + }; + let mut target_end = delimiter; + while target_end > target_start + && matches!(bytes.get(target_end - 1), Some(b' ' | b'\t' | b'\x0c')) + { + target_end -= 1; + } + if source[target_start..target_end].trim() == "_" { + return Some(( + "cannot use '_' as a target".to_owned(), + target_start, + target_end, + )); + } + let Some((expr_name, start, end, is_name)) = + expression_name_and_range(&source[target_start..target_end]) + else { + line_start = line_end; + continue; + }; + if !is_name { + if matches!(expr_name, "expression" | "subscript") { + line_start = line_end; + continue; + } + return Some(( + format!("cannot use {expr_name} as pattern target"), + target_start + start, + target_start + end, + )); + } + line_start = line_end; + } + None +} + +fn invalid_match_mapping_rest_wildcard_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let next_line_end = |line_start: usize| { + line_start + + bytes[line_start..] + .iter() + .position(|byte| *byte == b'\n') + .unwrap_or(bytes.len() - line_start) + }; + let mut index = 0usize; + let mut line_start = 0usize; + let mut line_end = next_line_end(line_start); + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'\n' => { + index += 1; + line_start = index; + line_end = next_line_end(line_start); + } + _ => { + let column = skip_horizontal_whitespace(bytes, line_start); + if index != column + || column >= line_end + || !starts_identifier(bytes, column, b"case") + { + index += 1; + continue; + } + let mut cursor = column + 4; + while cursor < line_end { + match bytes[cursor] { + b'#' => break, + b'\'' | b'"' => cursor = skip_quoted_string(bytes, cursor), + b'{' => { + let rest = next_non_horizontal_whitespace(bytes, cursor + 1); + if bytes.get(rest..rest + 2) == Some(b"**") { + let name_start = next_non_horizontal_whitespace(bytes, rest + 2); + let name_end = identifier_end(bytes, name_start, line_end); + if source.get(name_start..name_end) == Some("_") { + return Some(( + "invalid syntax".to_owned(), + name_start, + name_end, + )); + } + } + cursor += 1; + } + _ => cursor += 1, + } + } + index = line_end; + } + } + } + None +} + +fn invalid_if_expression_statement_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + if let Some(if_index) = find_keyword_at_level(bytes, line_start, line_end, b"if") + && let Some((start, end)) = statement_before_if_expression(bytes, line_start, if_index) + && find_keyword_at_level(bytes, if_index + 2, line_end, b"else").is_some() + { + return Some(( + "expected expression before 'if', but statement is given".to_owned(), + start, + end, + )); + } + if let Some(else_index) = find_keyword_at_level(bytes, line_start, line_end, b"else") + && find_keyword_at_level(bytes, line_start, else_index, b"if").is_some() + && let Some((start, end)) = + statement_after_else_expression(bytes, else_index + 4, line_end) + { + return Some(( + "expected expression after 'else', but statement is given".to_owned(), + start, + end, + )); + } + line_start = line_end; + } + None +} + +fn statement_before_if_expression( + bytes: &[u8], + line_start: usize, + if_index: usize, +) -> Option<(usize, usize)> { + let mut start = if_index; + while start > line_start && matches!(bytes.get(start - 1), Some(b' ' | b'\t' | b'\x0c')) { + start -= 1; + } + while start > line_start + && !matches!( + bytes.get(start - 1), + Some(b'=' | b':' | b',' | b'(' | b'[' | b'{') + ) + { + start -= 1; + } + start = skip_horizontal_whitespace(bytes, start); + for keyword in [b"pass".as_slice(), b"break", b"continue"] { + if starts_identifier(bytes, start, keyword) { + return Some((start, start + keyword.len())); + } + } + None +} + +fn statement_after_else_expression( + bytes: &[u8], + else_end: usize, + line_end: usize, +) -> Option<(usize, usize)> { + let start = skip_horizontal_whitespace(bytes, else_end); + for keyword in [ + b"pass".as_slice(), + b"return", + b"raise", + b"del", + b"yield", + b"assert", + b"break", + b"continue", + b"import", + b"from", + ] { + if starts_identifier(bytes, start, keyword) { + let end = statement_target_end(bytes, start).min(line_end); + return Some((start, end.max(start + keyword.len()))); + } + } + None +} + +fn invalid_else_elif_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut line_start = 0usize; + let mut else_indents: Vec = Vec::new(); + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let column = skip_horizontal_whitespace(bytes, line_start); + let line_column = column.saturating_sub(line_start); + if column >= line_end { + line_start = line_end; + continue; + } + while else_indents + .last() + .is_some_and(|indent| line_column < *indent) + { + else_indents.pop(); + } + if starts_identifier(bytes, column, b"else") + && find_byte_at_level(bytes, column + 4, line_end, b':').is_some() + { + else_indents.push(line_column); + } else if starts_identifier(bytes, column, b"elif") && else_indents.contains(&line_column) { + return Some(( + "'elif' block follows an 'else' block".to_owned(), + column, + column + 4, + )); + } + line_start = line_end; + } + None +} + +fn mixed_except_handlers_error(source: &str) -> Option<(String, usize, usize)> { + let message = "cannot have both 'except' and 'except*' on the same 'try'".to_owned(); + let mut seen_except = false; + let mut seen_except_star = false; + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let bytes = line.as_bytes(); + let mut column = 0usize; + while matches!(bytes.get(column), Some(b' ' | b'\t' | b'\x0c')) { + column += 1; + } + let token_start = line_start + column; + if bytes.get(column..column + 7) == Some(b"except*") { + if seen_except { + return Some((message, token_start, token_start + 7)); + } + seen_except_star = true; + } else if starts_identifier(bytes, column, b"except") { + if seen_except_star { + return Some((message, token_start, token_start + 6)); + } + seen_except = true; + } + line_start += line.len(); + } + None +} + +fn non_printable_character_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + byte if byte.is_ascii_control() && !matches!(byte, b'\t' | b'\n' | b'\r' | b'\x0c') => { + return Some(( + format!("invalid non-printable character U+{byte:04X}"), + index, + index + 1, + )); + } + byte if byte >= 0x80 => { + let ch = source[index..].chars().next()?; + if ch.is_control() { + return Some(( + format!("invalid non-printable character U+{:04X}", ch as u32), + index, + index + ch.len_utf8(), + )); + } + index += ch.len_utf8(); + } + _ => index += 1, + } + } + None +} + +fn unterminated_string_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + let mut line = 1usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\n' => { + line += 1; + index += 1; + } + quote @ (b'\'' | b'"') => { + let start = index; + let start_line = line; + let quote_size = if bytes.get(index + 1) == Some("e) + && bytes.get(index + 2) == Some("e) + { + 3 + } else { + 1 + }; + index += quote_size; + let mut has_escaped_quote = false; + let mut closed = false; + while index < bytes.len() { + let c = bytes[index]; + if c == b'\n' { + if quote_size == 1 { + return Some(( + unterminated_string_message(line, false, has_escaped_quote), + start, + start + 1, + )); + } + line += 1; + index += 1; + } else if c == quote { + if quote_size == 3 { + if bytes.get(index + 1) == Some("e) + && bytes.get(index + 2) == Some("e) + { + index += 3; + closed = true; + break; + } + index += 1; + } else { + index += 1; + closed = true; + break; + } + } else if c == b'\\' { + if bytes.get(index + 1) == Some("e) { + has_escaped_quote = true; + } + index = (index + 2).min(bytes.len()); + } else { + index += 1; + } + } + if !closed { + let detected_line = if quote_size == 3 { line } else { start_line }; + return Some(( + unterminated_string_message( + detected_line, + quote_size == 3, + has_escaped_quote, + ), + start, + start + 1, + )); + } + } + _ => index += 1, + } + } + None +} + +fn invalid_interpolated_string_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + quote @ (b'\'' | b'"') => { + let Some(prefix) = interpolated_string_prefix(bytes, index) else { + index = skip_quoted_string(bytes, index); + continue; + }; + if let Some(error) = + single_quoted_format_spec_newline_error(bytes, index, quote, prefix) + { + return Some(error); + } + let Some((content_start, content_end)) = + quoted_string_content_range(bytes, index, quote) + else { + index = skip_quoted_string(bytes, index); + continue; + }; + if let Some(error) = + invalid_replacement_field_error(bytes, content_start, content_end, prefix) + { + return Some(error); + } + index = skip_quoted_string(bytes, index); + } + _ => index += 1, + } + } + None +} + +fn single_quoted_format_spec_newline_error( + bytes: &[u8], + quote_index: usize, + quote: u8, + prefix: &str, +) -> Option<(String, usize, usize)> { + if bytes.get(quote_index + 1) == Some("e) && bytes.get(quote_index + 2) == Some("e) { + return None; + } + + let (content_start, content_end) = quoted_string_content_range(bytes, quote_index, quote)?; + let mut index = content_start; + while index < content_end { + match bytes[index] { + b'{' if bytes.get(index + 1) == Some(&b'{') => index += 2, + b'}' if bytes.get(index + 1) == Some(&b'}') => index += 2, + b'{' => { + let expr_start = skip_ascii_whitespace(bytes, index + 1, content_end); + if let Some(separator) = replacement_field_separator(bytes, expr_start, content_end) + && bytes[separator] == b':' + { + let format_end = + replacement_field_closing_brace(bytes, separator + 1, content_end) + .unwrap_or(content_end); + if bytes[separator + 1..format_end].contains(&b'\n') { + return Some(( + format!( + "{prefix}: newlines are not allowed in format specifiers for single quoted {prefix}s" + ), + quote_index, + quote_index + 1, + )); + } + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn interpolated_string_prefix(bytes: &[u8], quote: usize) -> Option<&'static str> { + let prev = quote.checked_sub(1).and_then(|index| bytes.get(index))?; + let lower_prev = prev.to_ascii_lowercase(); + let (prefix_start, marker) = if matches!(lower_prev, b'f' | b't') { + if quote >= 2 && bytes[quote - 2].eq_ignore_ascii_case(&b'r') { + (quote - 2, lower_prev) + } else { + (quote - 1, lower_prev) + } + } else if lower_prev == b'r' + && quote >= 2 + && matches!(bytes[quote - 2].to_ascii_lowercase(), b'f' | b't') + { + (quote - 2, bytes[quote - 2].to_ascii_lowercase()) + } else { + return None; + }; + + if prefix_start > 0 && is_ascii_identifier_char(bytes[prefix_start - 1]) { + return None; + } + + Some(if marker == b'f' { + "f-string" + } else { + "t-string" + }) +} + +fn quoted_string_content_range( + bytes: &[u8], + quote_index: usize, + quote: u8, +) -> Option<(usize, usize)> { + let triple = + bytes.get(quote_index + 1) == Some("e) && bytes.get(quote_index + 2) == Some("e); + let quote_len = if triple { 3 } else { 1 }; + let content_start = quote_index + quote_len; + let mut index = content_start; + while index < bytes.len() { + if bytes[index] == b'\\' { + index = (index + 2).min(bytes.len()); + } else if (triple + && bytes.get(index) == Some("e) + && bytes.get(index + 1) == Some("e) + && bytes.get(index + 2) == Some("e)) + || (!triple && bytes[index] == quote) + { + return Some((content_start, index)); + } else { + index += 1; + } + } + None +} + +fn invalid_replacement_field_error( + bytes: &[u8], + start: usize, + end: usize, + prefix: &str, +) -> Option<(String, usize, usize)> { + let mut index = start; + while index < end { + match bytes[index] { + b'{' if bytes.get(index + 1) == Some(&b'{') => index += 2, + b'}' if bytes.get(index + 1) == Some(&b'}') => index += 2, + b'{' => { + if let Some(error) = replacement_field_error(bytes, index, end, prefix) { + return Some(error); + } + index += 1; + } + _ => index += 1, + } + } + None +} + +fn replacement_field_error( + bytes: &[u8], + open: usize, + end: usize, + prefix: &str, +) -> Option<(String, usize, usize)> { + let expr_start = skip_ascii_whitespace(bytes, open + 1, end); + if let Some(backslash) = replacement_field_line_continuation(bytes, expr_start, end) { + return Some(( + "unexpected character after line continuation character".to_owned(), + backslash + 1, + (backslash + 2).min(end), + )); + } + if let Some(quote) = unterminated_string_in_replacement_field(bytes, expr_start, end) { + return Some(( + unterminated_string_message(1, false, false), + quote, + quote + 1, + )); + } + match bytes.get(expr_start).copied() { + Some(marker @ (b'=' | b'!' | b':' | b'}')) => { + return Some(( + format!( + "{prefix}: valid expression required before '{}'", + marker as char + ), + expr_start, + expr_start + 1, + )); + } + Some(_) => {} + None => { + return Some(( + format!("{prefix}: expecting a valid expression after '{{'"), + open, + open + 1, + )); + } + } + + if starts_identifier(bytes, expr_start, b"lambda") { + return Some(( + format!("{prefix}: lambda expressions are not allowed without parentheses"), + expr_start, + expr_start + b"lambda".len(), + )); + } + + if invalid_replacement_expression_start(bytes, expr_start, end) { + return Some(( + format!("{prefix}: expecting a valid expression after '{{'"), + open, + open + 1, + )); + } + + let Some(separator) = replacement_field_separator(bytes, expr_start, end) else { + return Some((format!("{prefix}: expecting '}}'"), open, open + 1)); + }; + + if bytes[separator] == b':' + && replacement_expression_has_parse_error(bytes, expr_start, separator) + { + return Some(("invalid syntax".to_owned(), expr_start, separator)); + } + + match bytes[separator] { + b'=' => invalid_debug_expression_error(bytes, separator, end, prefix), + b'!' => invalid_conversion_error(bytes, separator, end, prefix), + b':' => invalid_format_spec_error(bytes, separator, end, prefix), + b'}' => None, + _ => unreachable!(), + } +} + +fn replacement_field_line_continuation( + bytes: &[u8], + mut index: usize, + end: usize, +) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'\\' => return Some(index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' if level > 0 => { + level -= 1; + index += 1; + } + b'=' | b'!' | b':' | b'}' if level == 0 => return None, + _ => index += 1, + } + } + None +} + +fn unterminated_string_in_replacement_field( + bytes: &[u8], + mut index: usize, + end: usize, +) -> Option { + while index < end { + match bytes[index] { + quote @ (b'\'' | b'"') => { + let string_end = skip_quoted_string(bytes, index); + if string_end >= end && !bytes[index + 1..end].contains("e) { + return Some(index); + } + index = string_end; + } + _ => index += 1, + } + } + None +} + +fn replacement_expression_has_parse_error(bytes: &[u8], start: usize, end: usize) -> bool { + let Ok(expression) = ::core::str::from_utf8(&bytes[start..end]) else { + return false; + }; + parser::parse_expression(expression).is_err() +} + +fn invalid_replacement_expression_start(bytes: &[u8], index: usize, end: usize) -> bool { + if index >= end { + return true; + } + + if matches!( + bytes[index], + b'.' | b',' | b'*' | b'/' | b'%' | b'&' | b'|' | b'^' | b'<' | b'>' | b'@' + ) { + return true; + } + + if matches!(bytes[index], b'+' | b'-' | b'~') { + let operand = skip_ascii_whitespace(bytes, index + 1, end); + return !bytes.get(operand).is_some_and(|byte| { + *byte >= 0x80 + || *byte == b'_' + || byte.is_ascii_alphabetic() + || byte.is_ascii_digit() + || matches!(*byte, b'\'' | b'"' | b'(' | b'[' | b'{') + }); + } + + [ + b"and".as_slice(), + b"as".as_slice(), + b"else".as_slice(), + b"for".as_slice(), + b"if".as_slice(), + b"in".as_slice(), + b"is".as_slice(), + b"or".as_slice(), + ] + .iter() + .any(|keyword| starts_identifier(bytes, index, keyword)) +} + +fn replacement_field_separator(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' if level > 0 => { + level -= 1; + index += 1; + } + b'=' | b'!' | b':' | b'}' if level == 0 => return Some(index), + _ => index += 1, + } + } + None +} + +fn invalid_debug_expression_error( + bytes: &[u8], + equals: usize, + end: usize, + prefix: &str, +) -> Option<(String, usize, usize)> { + let next = equals + 1; + if next >= end || matches!(bytes[next], b'!' | b':' | b'}') { + return None; + } + Some(( + format!("{prefix}: expecting '!', or ':', or '}}'"), + next, + next.saturating_add(1).min(end), + )) +} + +fn invalid_conversion_error( + bytes: &[u8], + bang: usize, + end: usize, + prefix: &str, +) -> Option<(String, usize, usize)> { + let next = bang + 1; + if next >= end { + return Some((format!("{prefix}: expecting '}}'"), bang, bang + 1)); + } + + if bytes[next].is_ascii_whitespace() { + let following = skip_ascii_whitespace(bytes, next, end); + let message = if bytes + .get(following) + .is_some_and(|byte| byte.is_ascii_alphabetic() || *byte == b'_') + { + "conversion type must come right after the exclamation mark" + } else { + "missing conversion character" + }; + return Some((format!("{prefix}: {message}"), next, next + 1)); + } + + if matches!(bytes[next], b':' | b'}') { + return Some(( + format!("{prefix}: missing conversion character"), + next, + next + 1, + )); + } + + if !bytes[next].is_ascii_alphabetic() && bytes[next] != b'_' { + return Some(( + format!("{prefix}: invalid conversion character"), + next, + next + 1, + )); + } + + let conversion_end = identifier_end(bytes, next, end); + let conversion = &bytes[next..conversion_end]; + if !matches!(conversion, b"s" | b"r" | b"a") { + let conversion = ::core::str::from_utf8(conversion).unwrap_or(""); + return Some(( + format!( + "{prefix}: invalid conversion character '{conversion}': expected 's', 'r', or 'a'" + ), + next, + conversion_end, + )); + } + + if conversion_end >= end || matches!(bytes[conversion_end], b':' | b'}') { + return None; + } + + Some(( + format!("{prefix}: expecting ':' or '}}'"), + conversion_end, + conversion_end + 1, + )) +} + +fn invalid_format_spec_error( + bytes: &[u8], + colon: usize, + end: usize, + prefix: &str, +) -> Option<(String, usize, usize)> { + if replacement_field_closing_brace(bytes, colon + 1, end).is_some() { + return None; + } + Some(( + format!("{prefix}: expecting '}}', or format specs"), + colon, + colon + 1, + )) +} + +fn replacement_field_closing_brace(bytes: &[u8], mut index: usize, end: usize) -> Option { + let mut level = 0usize; + while index < end { + match bytes[index] { + b'\'' | b'"' => index = skip_quoted_string(bytes, index), + b'{' => { + level += 1; + index += 1; + } + b'}' if level > 0 => { + level -= 1; + index += 1; + } + b'}' => return Some(index), + _ => index += 1, + } + } + None +} + +fn skip_ascii_whitespace(bytes: &[u8], mut index: usize, end: usize) -> usize { + while index < end && matches!(bytes[index], b' ' | b'\t' | b'\r' | b'\n' | 0x0c) { + index += 1; + } + index +} + +fn string_literal_end_at(bytes: &[u8], index: usize) -> Option { + match bytes.get(index).copied()? { + b'\'' | b'"' => Some(skip_quoted_string(bytes, index)), + first if first.is_ascii_alphabetic() => { + if matches!(bytes.get(index + 1), Some(b'\'' | b'"')) { + return string_literal_prefix(bytes, index, index + 1) + .then(|| skip_quoted_string(bytes, index + 1)); + } + if matches!(bytes.get(index + 2), Some(b'\'' | b'"')) { + return string_literal_prefix(bytes, index, index + 2) + .then(|| skip_quoted_string(bytes, index + 2)); + } + None + } + _ => None, + } +} + +fn string_literal_prefix(bytes: &[u8], start: usize, quote: usize) -> bool { + let prefix = &bytes[start..quote]; + let valid = matches!( + prefix, + b"b" | b"B" + | b"r" + | b"R" + | b"u" + | b"U" + | b"f" + | b"F" + | b"t" + | b"T" + | b"br" + | b"bR" + | b"Br" + | b"BR" + | b"rb" + | b"rB" + | b"Rb" + | b"RB" + | b"fr" + | b"fR" + | b"Fr" + | b"FR" + | b"rf" + | b"rF" + | b"Rf" + | b"RF" + | b"tr" + | b"tR" + | b"Tr" + | b"TR" + | b"rt" + | b"rT" + | b"Rt" + | b"RT" + ); + valid && (start == 0 || !is_ascii_identifier_char(bytes[start - 1])) +} + +fn invalid_expression_error(source: &str) -> Option<(String, usize, usize)> { + invalid_string_expression_error(source).or_else(|| missing_comma_expression_error(source)) +} + +fn invalid_string_expression_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + if let Some(first_string_end) = string_literal_end_at(bytes, index) { + let expr_start = skip_ascii_whitespace(bytes, first_string_end, bytes.len()); + if expression_atom_start(bytes, expr_start) + && let Some(expr_end) = adjacent_atom_end(bytes, expr_start) + { + let next = skip_ascii_whitespace(bytes, expr_end, bytes.len()); + if string_literal_end_at(bytes, next).is_some() { + return Some(( + "invalid syntax. Is this intended to be part of the string?".to_owned(), + expr_start, + expr_end, + )); + } + } + index = first_string_end; + } else { + index += 1; + } + } + None +} + +fn missing_comma_expression_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut stack: Vec = Vec::new(); + let mut index = 0; + while index < bytes.len() { + if bytes[index] == b'#' { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } else if let Some(string_end) = string_literal_end_at(bytes, index) { + index = string_end; + } else { + match bytes[index] { + b'(' | b'[' | b'{' => { + if bytes[index] == b'[' && opening_bracket_is_class_type_params(bytes, index) { + let Some(close) = matching_delimiter(bytes, index, b']') else { + index += 1; + continue; + }; + index = close + 1; + continue; + } + stack.push(bytes[index]); + index += 1; + } + b')' | b']' | b'}' => { + stack.pop(); + index += 1; + } + _ if !stack.is_empty() && expression_continuation_keyword(bytes, index) => { + index = identifier_end(bytes, index, bytes.len()); + } + byte if !stack.is_empty() && expression_atom_start_byte(byte) => { + let atom_end = adjacent_atom_end(bytes, index).unwrap_or(index + 1); + let next = skip_ascii_whitespace(bytes, atom_end, bytes.len()); + if next > atom_end + && expression_atom_start(bytes, next) + && !expression_continuation_keyword(bytes, next) + { + return Some(( + "invalid syntax. Perhaps you forgot a comma?".to_owned(), + index, + next + 1, + )); + } + index = atom_end; + } + _ => index += 1, + } + } + } + None +} + +fn opening_bracket_is_class_type_params(bytes: &[u8], bracket: usize) -> bool { + let mut cursor = bracket; + while cursor > 0 && matches!(bytes[cursor - 1], b' ' | b'\t' | b'\x0c') { + cursor -= 1; + } + while cursor > 0 + && bytes + .get(cursor - 1) + .is_some_and(|byte| *byte >= 0x80 || is_ascii_identifier_char(*byte)) + { + cursor -= 1; + } + while cursor > 0 && matches!(bytes[cursor - 1], b' ' | b'\t' | b'\x0c') { + cursor -= 1; + } + cursor >= 5 && starts_identifier(bytes, cursor - 5, b"class") +} + +fn expression_continuation_keyword(bytes: &[u8], index: usize) -> bool { + [ + b"and".as_slice(), + b"else".as_slice(), + b"for".as_slice(), + b"if".as_slice(), + b"in".as_slice(), + b"is".as_slice(), + b"not".as_slice(), + b"or".as_slice(), + ] + .iter() + .any(|keyword| starts_identifier(bytes, index, keyword)) +} + +fn expression_atom_start(bytes: &[u8], index: usize) -> bool { + bytes + .get(index) + .is_some_and(|byte| expression_atom_start_byte(*byte)) + || string_literal_end_at(bytes, index).is_some() +} + +fn expression_atom_start_byte(byte: u8) -> bool { + byte >= 0x80 + || byte == b'_' + || byte.is_ascii_alphabetic() + || byte.is_ascii_digit() + || matches!(byte, b'\'' | b'"' | b'(' | b'[' | b'{') +} + +fn adjacent_atom_end(bytes: &[u8], index: usize) -> Option { + if let Some(string_end) = string_literal_end_at(bytes, index) { + return Some(string_end); + } + match bytes.get(index).copied()? { + byte if byte >= 0x80 || byte == b'_' || byte.is_ascii_alphabetic() => { + Some(identifier_end(bytes, index, bytes.len())) + } + byte if byte.is_ascii_digit() => { + let mut end = index + 1; + while end < bytes.len() && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') { + end += 1; + } + Some(end) + } + b'(' | b'[' | b'{' => Some(index + 1), + _ => None, + } +} + +fn unterminated_string_message( + detected_line: usize, + triple: bool, + has_escaped_quote: bool, +) -> String { + if triple { + format!("unterminated triple-quoted string literal (detected at line {detected_line})") + } else if has_escaped_quote { + format!( + "unterminated string literal (detected at line {detected_line}); perhaps you escaped the end quote?" + ) + } else { + format!("unterminated string literal (detected at line {detected_line})") + } +} + +fn expected_opening_bracket(closing: char) -> char { + match closing { + ')' => '(', + ']' => '[', + '}' => '{', + _ => unreachable!(), + } +} + +fn bracket_syntax_error(source: &str) -> Option<(String, usize, usize, bool)> { + let mut stack: Vec<(char, usize, usize)> = Vec::new(); + let mut in_string = false; + let mut string_quote = '\0'; + let mut triple_quote = false; + let mut escape_next = false; + let mut is_raw_string = false; + let mut line = 1usize; + + let chars: Vec<(usize, char)> = source.char_indices().collect(); + let mut index = 0; + while index < chars.len() { + let (byte_offset, ch) = chars[index]; + + if ch == '\n' { + line += 1; + } + + if escape_next { + escape_next = false; + index += 1; + continue; + } + + if in_string { + if ch == '\\' && !is_raw_string { + escape_next = true; + } else if triple_quote { + if ch == string_quote + && index + 2 < chars.len() + && chars[index + 1].1 == string_quote + && chars[index + 2].1 == string_quote + { + in_string = false; + index += 3; + continue; + } + } else if ch == string_quote { + in_string = false; + } + index += 1; + continue; + } + + if ch == '#' { + while index < chars.len() && chars[index].1 != '\n' { + index += 1; + } + continue; + } + + if ch == '\'' || ch == '"' { + is_raw_string = false; + for look_back in 1..=2.min(index) { + let prev = chars[index - look_back].1; + if matches!(prev, 'r' | 'R') { + is_raw_string = true; + break; + } + if !matches!(prev, 'b' | 'B' | 'f' | 'F' | 'u' | 'U') { + break; + } + } + string_quote = ch; + if index + 2 < chars.len() && chars[index + 1].1 == ch && chars[index + 2].1 == ch { + triple_quote = true; + in_string = true; + index += 3; + continue; + } + triple_quote = false; + in_string = true; + index += 1; + continue; + } -pub use rustpython_codegen::compile::CompileOpts; -pub use rustpython_compiler_core::{Mode, bytecode::CodeObject}; + match ch { + '(' | '[' | '{' => stack.push((ch, byte_offset, line)), + ')' | ']' | '}' => { + let expected = expected_opening_bracket(ch); + let Some(&(opening, _, opening_line)) = stack.last() else { + return Some((format!("unmatched '{ch}'"), byte_offset, byte_offset, false)); + }; + if opening == expected { + stack.pop(); + } else { + let suffix = if opening_line != line { + format!(" on line {opening_line}") + } else { + String::new() + }; + return Some(( + format!( + "closing parenthesis '{ch}' does not match opening parenthesis '{opening}'{suffix}" + ), + byte_offset, + byte_offset, + false, + )); + } + } + _ => {} + } -// these modules are out of repository. re-exporting them here for convenience. -pub use ruff_python_ast as ast; -pub use ruff_python_parser as parser; -pub use rustpython_codegen as codegen; -pub use rustpython_compiler_core as core; + index += 1; + } -#[derive(Error, Debug)] -pub enum CompileErrorType { - #[error(transparent)] - Codegen(#[from] codegen::error::CodegenErrorType), - #[error(transparent)] - Parse(#[from] ParseErrorType), + stack.last().map(|(opening, byte_offset, _)| { + ( + format!("'{opening}' was never closed"), + *byte_offset, + *byte_offset, + true, + ) + }) } -#[derive(Error, Debug)] -pub struct ParseError { - #[source] - pub error: ParseErrorType, - pub raw_location: ruff_text_size::TextRange, - pub location: SourceLocation, - pub end_location: SourceLocation, - pub source_path: String, - /// Set when the error is an unclosed bracket (converted from EOF). - pub is_unclosed_bracket: bool, +fn is_legacy_statement_expression_start(byte: u8) -> bool { + byte >= 0x80 + || byte == b'_' + || byte.is_ascii_alphabetic() + || byte.is_ascii_digit() + || matches!(byte, b'\'' | b'"' | b'{' | b'[') } -impl ::core::fmt::Display for ParseError { - fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { - self.error.fmt(f) +fn legacy_statement_container_has_invalid_attribute(bytes: &[u8], start: usize) -> bool { + let Some(&opening) = bytes.get(start) else { + return false; + }; + if !matches!(opening, b'{' | b'[') { + return false; } -} -#[derive(Error, Debug)] -pub enum CompileError { - #[error(transparent)] - Codegen(#[from] codegen::error::CodegenError), - #[error(transparent)] - Parse(#[from] ParseError), + let mut index = start; + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\n' | b';' if level == 0 => return false, + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'(' | b'[' | b'{' => { + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + if level == 0 { + return false; + } + } + b'.' => { + let mut cursor = index + 1; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\x0c')) { + cursor += 1; + } + if matches!(bytes.get(cursor), Some(b')' | b']' | b'}')) { + return true; + } + index += 1; + } + _ => index += 1, + } + } + false } -impl CompileError { - #[must_use] - pub fn from_ruff_parse_error(error: parser::ParseError, source_file: &SourceFile) -> Self { - let source_code = source_file.to_source_code(); - let source_text = source_file.source_text(); - - // For EOF errors (unclosed brackets), find the unclosed bracket position - // and adjust both the error location and message - let mut is_unclosed_bracket = false; - let (error_type, location, end_location) = match &error.error { - ParseErrorType::Lexical(LexicalErrorType::Eof) => { - if let Some((bracket_char, bracket_offset)) = find_unclosed_bracket(source_text) { - let bracket_text_size = ruff_text_size::TextSize::new(bracket_offset as u32); - let loc = - source_code.source_location(bracket_text_size, PositionEncoding::Utf8); - let end_loc = SourceLocation { - line: loc.line, - character_offset: loc.character_offset.saturating_add(1), - }; - let msg = format!("'{bracket_char}' was never closed"); - is_unclosed_bracket = true; - (ParseErrorType::OtherError(msg), loc, end_loc) - } else { - let loc = - source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let end_loc = - source_code.source_location(error.location.end(), PositionEncoding::Utf8); - (error.error, loc, end_loc) +fn invalid_legacy_statement_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; } } - - ParseErrorType::Lexical(LexicalErrorType::IndentationError) => { - // For IndentationError, point the offset to the end of the line content - // instead of the beginning - let loc = - source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let line_idx = loc.line.to_zero_indexed(); - let line = source_text.split('\n').nth(line_idx).unwrap_or(""); - let line_end_col = line.chars().count() + 1; // 1-indexed, past last char - let end_loc = SourceLocation { - line: loc.line, - character_offset: ruff_source_file::OneIndexed::new(line_end_col) - .unwrap_or(loc.character_offset), + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'p' | b'e' => { + let keyword = if starts_identifier(bytes, index, b"print") { + Some("print") + } else if starts_identifier(bytes, index, b"exec") { + Some("exec") + } else { + None }; - (error.error, end_loc, end_loc) + let Some(keyword) = keyword else { + index += 1; + continue; + }; + let after_keyword = index + keyword.len(); + if !matches!(bytes.get(after_keyword), Some(b' ' | b'\t' | b'\x0c')) { + index = after_keyword; + continue; + } + let mut cursor = after_keyword; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\x0c')) { + cursor += 1; + } + if legacy_statement_container_has_invalid_attribute(bytes, cursor) { + index = after_keyword; + continue; + } + if bytes.get(cursor).is_some_and(|byte| { + *byte != b'(' && is_legacy_statement_expression_start(*byte) + }) { + return Some(( + format!( + "Missing parentheses in call to '{keyword}'. Did you mean {keyword}(...)?" + ), + index, + after_keyword, + )); + } + index = after_keyword; } - ParseErrorType::ExpectedToken { expected, found } - if matches!((expected, found), (TokenKind::Comma, TokenKind::Int)) => - { - let loc = - source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let mut end_loc = - source_code.source_location(error.location.end(), PositionEncoding::Utf8); + _ => index += 1, + } + } + None +} - // If the error range ends at the start of a new line (column 1), - // adjust it to the end of the previous line - if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { - let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); - end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); - end_loc.character_offset = end_loc.character_offset.saturating_add(1); +fn long_decimal_integer_literal_error( + source: &str, + max_str_digits: usize, +) -> Option<(String, usize, usize)> { + if max_str_digits == 0 { + return None; + } + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + byte if byte >= 0x80 || byte == b'_' || byte.is_ascii_alphabetic() => { + index += 1; + while index < bytes.len() + && (bytes[index] >= 0x80 || is_ascii_identifier_char(bytes[index])) + { + index += 1; + } + } + b'.' => { + if bytes + .get(index + 1) + .is_some_and(|byte| byte.is_ascii_digit()) + { + let (_, end) = number_literal_end(bytes, index)?; + index = end.max(index + 1); + } else { + index += 1; } - let msg = "invalid syntax. Perhaps you forgot a comma?".into(); - (ParseErrorType::OtherError(msg), loc, end_loc) } + b'0'..=b'9' => { + if bytes.get(index) == Some(&b'0') + && matches!( + bytes.get(index + 1), + Some(b'x' | b'X' | b'o' | b'O' | b'b' | b'B') + ) + { + let Some((_, end)) = number_literal_end(bytes, index) else { + index += 1; + continue; + }; + index = end.max(index + 1); + continue; + } - ParseErrorType::InvalidAssignmentTarget => { - let loc = - source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let mut end_loc = - source_code.source_location(error.location.end(), PositionEncoding::Utf8); - - // If the error range ends at the start of a new line (column 1), - // adjust it to the end of the previous line - if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { - let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); - end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); - end_loc.character_offset = end_loc.character_offset.saturating_add(1); - } - - let expr_str = source_file.source_text().slice(error.location); - - let msg = parser::parse_expression(expr_str).map_or_else( - |_| match expr_str { - "yield" => "assignment to yield expression not possible".into(), - _ => format!("cannot assign to {expr_str}"), - }, - |parsed| match *parsed.syntax().body { - ast::Expr::Call(_) => "cannot assign to function call".into(), - ast::Expr::BinOp(_) => "cannot assign to expression".into(), - ast::Expr::If(_) => "cannot assign to conditional expression".into(), - ast::Expr::Generator(_) => "cannot assign to generator expression".into(), - ast::Expr::StringLiteral(_) - | ast::Expr::BytesLiteral(_) - | ast::Expr::NumberLiteral(_) => { - "cannot assign to literal here. Maybe you meant '==' instead of '='?" - .into() + let start = index; + let mut digits = 0usize; + while index < bytes.len() { + match bytes[index] { + b'0'..=b'9' => { + digits += 1; + index += 1; } - ast::Expr::EllipsisLiteral(_) => { - "cannot assign to ellipsis here. Maybe you meant '==' instead of '='?" - .into() + b'_' if bytes + .get(index + 1) + .is_some_and(|byte| byte.is_ascii_digit()) => + { + index += 1; } - _ => format!("cannot assign to {expr_str}"), - }, - ); - - (ParseErrorType::OtherError(msg), loc, end_loc) + _ => break, + } + } + if matches!(bytes.get(index), Some(b'.' | b'e' | b'E' | b'j' | b'J')) { + let Some((_, end)) = number_literal_end(bytes, start) else { + continue; + }; + index = end.max(index + 1); + continue; + } + if digits > max_str_digits { + return Some(( + format!( + "Exceeds the limit ({max_str_digits} digits) for integer string conversion: value has {digits} digits; use sys.set_int_max_str_digits() to increase the limit - Consider hexadecimal for huge integer literals to avoid decimal conversion limits." + ), + start, + start, + )); + } } + _ => index += 1, + } + } + None +} - ParseErrorType::InvalidNamedAssignmentTarget => { - let loc = - source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let mut end_loc = - source_code.source_location(error.location.end(), PositionEncoding::Utf8); - - // If the error range ends at the start of a new line (column 1), - // adjust it to the end of the previous line - if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { - let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); - end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); - end_loc.character_offset = end_loc.character_offset.saturating_add(1); +fn invalid_parenthesized_import_star_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; } - - let target = source_file.source_text().slice(error.location); - let msg = format!("cannot use assignment expressions with {target}"); - (ParseErrorType::OtherError(msg), loc, end_loc) } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'f' if starts_identifier(bytes, index, b"from") => { + let mut cursor = index + 4; + while cursor < bytes.len() && !matches!(bytes[cursor], b'\n' | b';') { + if starts_identifier(bytes, cursor, b"import") { + cursor += 6; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\r')) { + cursor += 1; + } + if bytes.get(cursor) == Some(&b'(') { + cursor += 1; + while cursor < bytes.len() + && !matches!(bytes[cursor], b')' | b'\n' | b';') + { + if bytes[cursor] == b'*' { + return Some(("invalid syntax".to_owned(), cursor, cursor + 1)); + } + cursor += 1; + } + } + break; + } + cursor += 1; + } + index = cursor; + } + _ => index += 1, + } + } + None +} - _ => { - let loc = - source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let mut end_loc = - source_code.source_location(error.location.end(), PositionEncoding::Utf8); +fn too_many_nested_parentheses_error(source: &str) -> Option<(String, usize, usize)> { + const MAXLEVEL: usize = 200; - // If the error range ends at the start of a new line (column 1), - // adjust it to the end of the previous line - if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { - let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); - end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); - end_loc.character_offset = end_loc.character_offset.saturating_add(1); + let bytes = source.as_bytes(); + let mut index = 0; + let mut level = 0usize; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b'(' | b'[' | b'{' => { + if level >= MAXLEVEL { + return Some(("too many nested parentheses".to_owned(), index, index + 1)); } + level += 1; + index += 1; + } + b')' | b']' | b'}' => { + level = level.saturating_sub(1); + index += 1; + } + _ => index += 1, + } + } + None +} - (error.error, loc, end_loc) +fn invalid_unparenthesized_yield_after_comma_error(source: &str) -> Option<(String, usize, usize)> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } } - }; + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + b',' => { + let mut cursor = index + 1; + while matches!(bytes.get(cursor), Some(b' ' | b'\t' | b'\x0c')) { + cursor += 1; + } + if starts_identifier(bytes, cursor, b"yield") { + return Some(("invalid syntax".to_owned(), cursor, cursor + 5)); + } + index += 1; + } + _ => index += 1, + } + } + None +} - Self::Parse(ParseError { - error: error_type, - raw_location: error.location, - location, - end_location, - source_path: source_file.name().to_owned(), - is_unclosed_bracket, +fn post_parse_source_error(source_file: &SourceFile, opts: &CompileOpts) -> Option { + too_many_nested_parentheses_error(source_file.source_text()) + .or_else(|| { + long_decimal_integer_literal_error(source_file.source_text(), opts.int_max_str_digits) }) - } + .or_else(|| invalid_call_argument_error(source_file.source_text())) + .or_else(|| invalid_match_mapping_rest_wildcard_error(source_file.source_text())) + .or_else(|| invalid_unparenthesized_yield_after_comma_error(source_file.source_text())) + .or_else(|| invalid_parenthesized_import_star_error(source_file.source_text())) + .map(|(message, start, end)| { + CompileError::from_source_error(source_file, message, start, end) + }) +} - #[must_use] - pub const fn location(&self) -> Option { - match self { - Self::Codegen(codegen_error) => codegen_error.location, - Self::Parse(parse_error) => Some(parse_error.location), - } +fn is_compound_stmt(stmt: &ast::Stmt) -> bool { + matches!( + stmt, + ast::Stmt::FunctionDef(_) + | ast::Stmt::ClassDef(_) + | ast::Stmt::If(_) + | ast::Stmt::For(_) + | ast::Stmt::While(_) + | ast::Stmt::With(_) + | ast::Stmt::Try(_) + | ast::Stmt::Match(_) + ) +} + +fn single_mode_body_error(body: &[ast::Stmt], source_file: &SourceFile) -> Option { + let first = body.first()?; + let source_code = source_file.to_source_code(); + let first_start = source_code.source_location(first.range().start(), PositionEncoding::Utf8); + let first_end = source_code.source_location(first.range().end(), PositionEncoding::Utf8); + + if body.iter().skip(1).any(|stmt| { + source_code + .source_location(stmt.range().start(), PositionEncoding::Utf8) + .line + > first_start.line + }) { + return Some(CompileError::from_source_error( + source_file, + "multiple statements found while compiling a single statement".to_owned(), + first.range().end().to_usize(), + first.range().end().to_usize(), + )); } - #[must_use] - pub const fn python_location(&self) -> (usize, usize) { - if let Some(location) = self.location() { - (location.line.get(), location.character_offset.get()) - } else { - (0, 0) - } + if is_compound_stmt(first) + && first_start.line == first_end.line + && !ends_with_line_break(source_file.source_text()) + { + return Some(CompileError::from_source_error( + source_file, + "invalid syntax".to_owned(), + first.range().start().to_usize(), + first.range().start().to_usize(), + )); } + None +} - #[must_use] - pub fn python_end_location(&self) -> Option<(usize, usize)> { - match self { - Self::Codegen(_) => None, - Self::Parse(parse_error) => Some(( - parse_error.end_location.line.get(), - parse_error.end_location.character_offset.get(), - )), +fn single_mode_source_error(ast: &ast::Mod, source_file: &SourceFile) -> Option { + let ast::Mod::Module(module) = ast else { + return None; + }; + single_mode_body_error(&module.body, source_file) +} + +fn ends_with_line_break(source: &str) -> bool { + source.ends_with('\n') || source.ends_with('\r') +} + +fn ends_with_implied_dedent(source: &str) -> bool { + let mut lexer = parser::lexer::lex(source, parser::Mode::Module); + let mut last_kind = TokenKind::EndOfFile; + loop { + let kind = lexer.next_token(); + if kind.is_eof() { + break; } + last_kind = kind; } + matches!(last_kind, TokenKind::Dedent) +} - #[must_use] - pub fn source_path(&self) -> &str { - match self { - Self::Codegen(codegen_error) => &codegen_error.source_path, - Self::Parse(parse_error) => &parse_error.source_path, - } +/// Detect input that only parses because Ruff's lexer closes indentation at EOF. +/// +/// CPython's `PyCF_DONT_IMPLY_DEDENT` is used by `codeop` and interactive compile +/// paths to keep an indented block incomplete until a terminating newline is seen. +#[must_use] +pub fn dont_imply_dedent_source_error(source_file: &SourceFile) -> Option { + let source = source_file.source_text(); + if ends_with_line_break(source) || !ends_with_implied_dedent(source) { + return None; } + let eof = source.len(); + Some(CompileError::from_source_error( + source_file, + "incomplete input".to_owned(), + eof, + eof, + )) } /// Find the last unclosed opening bracket in source code. @@ -378,6 +5253,15 @@ fn _compile( source_file: SourceFile, mode: Mode, opts: CompileOpts, +) -> Result { + _compile_with_syntax_warning_handler(source_file, mode, opts, None) +} + +fn _compile_with_syntax_warning_handler<'a>( + source_file: SourceFile, + mode: Mode, + opts: CompileOpts, + syntax_warning_handler: Option<&'a mut compile::SyntaxWarningHandler<'a>>, ) -> Result { let parser_mode = match mode { Mode::Exec => parser::Mode::Module, @@ -386,10 +5270,49 @@ fn _compile( // since these are only different in terms of compilation Mode::Single | Mode::BlockExpr => parser::Mode::Module, }; - let parsed = parser::parse(source_file.source_text(), parser_mode.into()) + let parser_options = parser::ParseOptions::from(parser_mode); + let parsed = parser::parse(source_file.source_text(), parser_options) .map_err(|err| CompileError::from_ruff_parse_error(err, &source_file))?; + if opts.dont_imply_dedent + && matches!(mode, Mode::Single) + && let Some(error) = dont_imply_dedent_source_error(&source_file) + { + return Err(error); + } + if let Some(error) = post_parse_source_error(&source_file, &opts) { + return Err(error); + } let ast = parsed.into_syntax(); - compile::compile_top(ast, source_file, mode, opts).map_err(|e| e.into()) + let single_mode_error = matches!(mode, Mode::Single) + .then(|| single_mode_source_error(&ast, &source_file)) + .flatten(); + let code = compile::compile_top_with_syntax_warning_handler( + ast, + source_file, + mode, + opts, + syntax_warning_handler, + ) + .map_err(CompileError::from)?; + if let Some(error) = single_mode_error { + return Err(error); + } + Ok(code) +} + +pub fn compile_with_syntax_warning_handler<'a>( + source: &str, + mode: Mode, + source_path: &str, + opts: CompileOpts, + syntax_warning_handler: &'a mut compile::SyntaxWarningHandler<'a>, +) -> Result { + let source = source.replace("\r\n", "\n"); + #[cfg(windows)] + let source = source.as_str(); + + let source_file = SourceFileBuilder::new(source_path, source).finish(); + _compile_with_syntax_warning_handler(source_file, mode, opts, Some(syntax_warning_handler)) } pub fn compile_symtable( @@ -409,7 +5332,16 @@ pub fn _compile_symtable( Mode::Exec | Mode::Single | Mode::BlockExpr => { let ast = ruff_python_parser::parse_module(source_file.source_text()) .map_err(|e| CompileError::from_ruff_parse_error(e, &source_file))?; - symboltable::SymbolTable::scan_program(&ast.into_syntax(), source_file.clone()) + if let Some(error) = post_parse_source_error(&source_file, &CompileOpts::default()) { + return Err(error); + } + let ast = ast.into_syntax(); + if matches!(mode, Mode::Single) + && let Some(error) = single_mode_body_error(&ast.body, &source_file) + { + return Err(error); + } + symboltable::SymbolTable::scan_program(&ast, source_file.clone()) } Mode::Eval => { let ast = ruff_python_parser::parse( @@ -417,6 +5349,9 @@ pub fn _compile_symtable( parser::Mode::Expression.into(), ) .map_err(|e| CompileError::from_ruff_parse_error(e, &source_file))?; + if let Some(error) = post_parse_source_error(&source_file, &CompileOpts::default()) { + return Err(error); + } symboltable::SymbolTable::scan_expr( &ast.into_syntax().expect_expression(), source_file.clone(), @@ -437,6 +5372,21 @@ mod tests { dbg!(compiled.expect("compile error")); } + #[test] + fn dont_imply_dedent_requires_terminating_newline() { + let code = "if True:\n pass"; + + let opts = CompileOpts { + dont_imply_dedent: true, + ..CompileOpts::default() + }; + let err = compile(code, Mode::Single, "<>", opts.clone()).expect_err("compile succeeded"); + assert_eq!(err.to_string(), "incomplete input"); + + compile("if True:\n pass\n", Mode::Single, "<>", opts).expect("compile error"); + compile(code, Mode::Single, "<>", CompileOpts::default()).expect("compile error"); + } + #[test] fn compile_phello() { let code = r#" @@ -501,6 +5451,20 @@ def f(): dbg!(compiled.expect("compile error")); } + #[test] + fn compile_call_arg_lambda_default() { + let code = "signature((lambda a=10: a))"; + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } + + #[test] + fn compile_generic_function_parameter_default() { + let code = "def __repr__[T: str](self, default: T = '') -> str: pass"; + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } + #[test] fn compile_int() { let code = r#" diff --git a/crates/stdlib/src/_opcode.rs b/crates/stdlib/src/_opcode.rs index 2b2a70b5572..e6fc3276c31 100644 --- a/crates/stdlib/src/_opcode.rs +++ b/crates/stdlib/src/_opcode.rs @@ -205,7 +205,7 @@ mod tests { let scope = vm.new_scope_with_builtins(); let code_obj = vm .compile(source.trim(), Mode::Exec, FNAME) - .map_err(|err| vm.new_syntax_error(&err, Some(source))) + .map_err(|err| err.into_pyexception(vm, Some(source))) .unwrap(); scope.globals.set_item("code", code_obj.into(), vm).unwrap(); @@ -228,7 +228,7 @@ output = re.sub(r'(0xdeadbeef', tmp let py_code_obj = vm .compile(py_source, Mode::Exec, FNAME) - .map_err(|err| vm.new_syntax_error(&err, Some(py_source))) + .map_err(|err| err.into_pyexception(vm, Some(py_source))) .unwrap(); vm.run_code_obj(py_code_obj, scope.clone()).unwrap(); diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__bare_function_annotations_check_attribute_and_subscript_expressions.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__bare_function_annotations_check_attribute_and_subscript_expressions.snap index 3274352b920..4d78128b5e6 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__bare_function_annotations_check_attribute_and_subscript_expressions.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__bare_function_annotations_check_attribute_and_subscript_expressions.snap @@ -1,5 +1,6 @@ --- source: crates/stdlib/src/_opcode.rs +assertion_line: 318 expression: "dis(r#\"\ndef f(one: int):\n int.new_attr: int\n [list][0].new_attr: [int, str]\n my_lst = [1]\n my_lst[one]: int\n return my_lst\n\"#)" --- 0 RESUME 0 @@ -15,7 +16,7 @@ expression: "dis(r#\"\ndef f(one: int):\n int.new_attr: int\n [list][0].ne Disassembly of ", line 1>: 1 RESUME 0 - LOAD_FAST_BORROW 0 (format) + LOAD_FAST_CHECK 0 (format) LOAD_SMALL_INT 2 COMPARE_OP 132 (>) POP_JUMP_IF_FALSE 3 (to L1) diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__const_no_op.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__const_no_op.snap index 347e58767ae..dc97f6b79c1 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__const_no_op.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__const_no_op.snap @@ -1,6 +1,7 @@ --- source: crates/stdlib/src/_opcode.rs -expression: x = not True +assertion_line: 281 +expression: "dis(r#\"\nx = not True\n\"#)" --- 0 RESUME 0 diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__constant_true_if_pass_keeps_line_anchor_nop.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__constant_true_if_pass_keeps_line_anchor_nop.snap index 02e2473501d..3de37ce2009 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__constant_true_if_pass_keeps_line_anchor_nop.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__constant_true_if_pass_keeps_line_anchor_nop.snap @@ -1,6 +1,7 @@ --- source: crates/stdlib/src/_opcode.rs -expression: "if 1:\n pass" +assertion_line: 290 +expression: "dis(r#\"\nif 1:\n pass\n\"#)" --- 0 RESUME 0 diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ands.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ands.snap index b5957dda5e5..5c58a2b6b85 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ands.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ands.snap @@ -1,6 +1,7 @@ --- source: crates/stdlib/src/_opcode.rs -expression: "if True and False and False:\n pass" +assertion_line: 252 +expression: "dis(r#\"\nif True and False and False:\n pass\n\"#)" --- 0 RESUME 0 diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_mixed.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_mixed.snap index f8976b8c6e5..6bef04ee143 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_mixed.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_mixed.snap @@ -1,6 +1,7 @@ --- source: crates/stdlib/src/_opcode.rs -expression: "if (True and False) or (False and True):\n pass" +assertion_line: 262 +expression: "dis(r#\"\nif (True and False) or (False and True):\n pass\n\"#)" --- 0 RESUME 0 diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ors.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ors.snap index f8cc3a1f28f..065d893732e 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ors.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__if_ors.snap @@ -1,6 +1,7 @@ --- source: crates/stdlib/src/_opcode.rs -expression: "if True or False or False:\n pass" +assertion_line: 242 +expression: "dis(r#\"\nif True or False or False:\n pass\n\"#)" --- 0 RESUME 0 diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_bool_op.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_bool_op.snap index c0e3659487b..00eeb277455 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_bool_op.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_bool_op.snap @@ -1,6 +1,7 @@ --- source: crates/stdlib/src/_opcode.rs -expression: x = Test() and False or False +assertion_line: 272 +expression: "dis(r#\"\nx = Test() and False or False\n\"#)" --- 0 RESUME 0 diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 83e41fa1f5f..b3479e017a1 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -70,6 +70,7 @@ static_assertions = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } thiserror = { workspace = true } +thin-vec = { workspace = true } memchr = { workspace = true } flamer = { workspace = true, optional = true } diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 89321189b09..19fca5cf473 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -34,7 +34,7 @@ use core::{ ops::Deref, pin::Pin, ptr::NonNull, - sync::atomic::{AtomicBool, AtomicPtr, AtomicU32, Ordering}, + sync::atomic::{AtomicBool, AtomicPtr, AtomicU32, AtomicU64, Ordering}, }; use indexmap::{IndexMap, map::Entry}; use itertools::Itertools; @@ -53,6 +53,7 @@ pub struct PyType { pub heaptype_ext: Option>>, /// Type version tag for inline caching. 0 means unassigned/invalidated. pub tp_version_tag: AtomicU32, + pub abc_tpflags: AtomicU64, } /// Monotonic counter for type version tags. Once it reaches `u32::MAX`, @@ -590,7 +591,9 @@ impl PyType { // Check each base in order and inherit the first collection flag found for base in bases { - let base_flags = base.slots.flags & COLLECTION_FLAGS; + let base_flags = (base.slots.flags + | PyTypeFlags::from_bits_truncate(base.abc_tpflags.load(Ordering::Acquire))) + & COLLECTION_FLAGS; if !base_flags.is_empty() { slots.flags |= base_flags; return; @@ -598,6 +601,58 @@ impl PyType { } } + fn inherited_abc_tpflags(bases: &[PyRef]) -> u64 { + const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate( + PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(), + ); + for base in bases { + let base_flags = + PyTypeFlags::from_bits_truncate(base.abc_tpflags.load(Ordering::Acquire)) + & COLLECTION_FLAGS; + if !base_flags.is_empty() { + return base_flags.bits(); + } + } + 0 + } + + pub fn has_patma_collection_flag(&self, flag: PyTypeFlags) -> bool { + debug_assert!(matches!(flag, PyTypeFlags::SEQUENCE | PyTypeFlags::MAPPING)); + const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate( + PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(), + ); + let slot_flags = self.slots.flags & COLLECTION_FLAGS; + if !slot_flags.is_empty() { + return slot_flags.contains(flag); + } + PyTypeFlags::from_bits_truncate(self.abc_tpflags.load(Ordering::Acquire)).contains(flag) + } + + pub fn set_abc_collection_flags_recursive(&self, flags: PyTypeFlags) { + const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate( + PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(), + ); + let flags = flags & COLLECTION_FLAGS; + if flags.is_empty() { + return; + } + let collection_bits = COLLECTION_FLAGS.bits(); + let flags_bits = flags.bits(); + let _ = self + .abc_tpflags + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |old| { + Some((old & !collection_bits) | flags_bits) + }); + self.modified(); + for weak_ref in self.subclasses.read().iter() { + if let Some(subclass) = weak_ref.upgrade() + && let Some(subclass) = subclass.downcast_ref::() + { + subclass.set_abc_collection_flags_recursive(flags); + } + } + } + /// Check for __abc_tpflags__ and set the appropriate flags /// This checks in attrs and all base classes for __abc_tpflags__ fn check_abc_tpflags( @@ -626,21 +681,19 @@ impl PyType { .to_owned(), ); } - // Don't override flags already inherited from a base class. - if !slots.flags.intersects(COLLECTION_FLAGS) { - slots.flags |= masked; - } + slots.flags.remove(COLLECTION_FLAGS); + slots.flags |= masked; return Ok(()); } - // No __abc_tpflags__ on this class — inheritance already happened - // in inherit_patma_flags, so nothing more to do if those bits are set. + // No __abc_tpflags__ on this class. Inheritance already happened in + // inherit_patma_flags, using base order and including ABC markers. if slots.flags.intersects(COLLECTION_FLAGS) { return Ok(()); } - // Then check in base classes (legacy path for cases that bypass - // inherit_patma_flags). + // Then check in base classes for legacy paths that bypassed + // inherit_patma_flags. for base in bases { if let Some(abc_tpflags_obj) = base.find_name_in_mro(abc_tpflags_name) && let Some(int_obj) = abc_tpflags_obj.downcast_ref::() @@ -654,6 +707,7 @@ impl PyType { .to_owned(), ); } + slots.flags.remove(COLLECTION_FLAGS); slots.flags |= masked; return Ok(()); } @@ -716,6 +770,7 @@ impl PyType { )); } + let inherited_abc_tpflags = Self::inherited_abc_tpflags(&bases); let new_type = PyRef::new_ref( Self { base: Some(base), @@ -726,6 +781,7 @@ impl PyType { slots, heaptype_ext: Some(Pin::new(Box::new(heaptype_ext))), tp_version_tag: AtomicU32::new(0), + abc_tpflags: AtomicU64::new(inherited_abc_tpflags), }, metaclass, None, @@ -775,6 +831,7 @@ impl PyType { slots.flags |= PyTypeFlags::MANAGED_WEAKREF; } + let inherited_abc_tpflags = Self::inherited_abc_tpflags(core::slice::from_ref(&base)); let bases = PyRwLock::new(vec![base.clone()]); let mro = base.mro_map_collect(|x| x.to_owned()); @@ -788,6 +845,7 @@ impl PyType { slots, heaptype_ext: None, tp_version_tag: AtomicU32::new(0), + abc_tpflags: AtomicU64::new(inherited_abc_tpflags), }, metaclass, None, diff --git a/crates/vm/src/eval.rs b/crates/vm/src/eval.rs index 5f52799d0b9..5a3a804688f 100644 --- a/crates/vm/src/eval.rs +++ b/crates/vm/src/eval.rs @@ -6,7 +6,7 @@ pub fn eval(vm: &VirtualMachine, source: &str, scope: Scope, source_path: &str) debug!("Code object: {bytecode:?}"); vm.run_code_obj(bytecode, scope) } - Err(err) => Err(vm.new_syntax_error(&err, Some(source))), + Err(err) => Err(err.into_pyexception(vm, Some(source))), } } diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index df22f4d822d..1ad86a35aed 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -2480,12 +2480,18 @@ pub(super) mod types { let maybe_lineno = zelf .as_object() .get_attr("lineno", vm) - .and_then(|obj| obj.str_utf8(vm)) - .ok(); - let maybe_filename = zelf.as_object().get_attr("filename", vm).ok().map(|obj| { - obj.str(vm) - .unwrap_or_else(|_| vm.ctx.new_str("")) - }); + .ok() + .filter(|obj| !vm.is_none(obj)) + .and_then(|obj| obj.str_utf8(vm).ok()); + let maybe_filename = zelf + .as_object() + .get_attr("filename", vm) + .ok() + .filter(|obj| !vm.is_none(obj)) + .map(|obj| { + obj.str(vm) + .unwrap_or_else(|_| vm.ctx.new_str("")) + }); let msg = match zelf.as_object().get_attr("msg", vm) { Ok(obj) => obj diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 7050d851743..a05bf47c0c8 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -3050,9 +3050,15 @@ impl ExecutingFrame<'_> { let subject = self.pop_value(); let nargs_val = nargs.get(arg) as usize; + let Some(cls_type) = cls.downcast_ref::() else { + return Err(vm.new_type_error("called match pattern must be a class")); + }; + let type_name = cls_type.name().to_string(); + // Check if subject is an instance of cls if subject.is_instance(cls.as_ref(), vm)? { let mut extracted = vec![]; + let seen_attrs = PySet::default().into_ref(&vm.ctx); // Get __match_args__ for positional arguments if nargs > 0 if nargs_val > 0 { @@ -3081,9 +3087,11 @@ impl ExecutingFrame<'_> { // Check if we have enough match args if match_args.len() < nargs_val { + let plural = if match_args.len() == 1 { "" } else { "s" }; return Err(vm.new_type_error(format!( - "class pattern accepts at most {} positional sub-patterns ({} given)", + "{type_name}() accepts {} positional sub-pattern{} ({} given)", match_args.len(), + plural, nargs_val ))); } @@ -3094,11 +3102,19 @@ impl ExecutingFrame<'_> { let attr_name_str = match attr_name.downcast_ref::() { Some(s) => s, None => { - return Err(vm.new_type_error( - "__match_args__ elements must be strings", - )); + let attr_type_name = attr_name.class().name(); + return Err(vm.new_type_error(format!( + "__match_args__ elements must be strings (got {attr_type_name})" + ))); } }; + if seen_attrs.__contains__(attr_name.as_object(), vm)? { + let attr_repr = attr_name.as_object().repr(vm)?; + return Err(vm.new_type_error(format!( + "{type_name}() got multiple sub-patterns for attribute {attr_repr}" + ))); + } + seen_attrs.add(attr_name.clone(), vm)?; match subject.get_attr(attr_name_str, vm) { Ok(value) => extracted.push(value), Err(e) @@ -3115,9 +3131,8 @@ impl ExecutingFrame<'_> { // No __match_args__, check if this is a type with MATCH_SELF behavior // For built-in types like bool, int, str, list, tuple, dict, etc. // they match the subject itself as the single positional argument - let is_match_self_type = cls - .downcast::() - .is_ok_and(|t| t.slots.flags.contains(PyTypeFlags::_MATCH_SELF)); + let is_match_self_type = + cls_type.slots.flags.contains(PyTypeFlags::_MATCH_SELF); if is_match_self_type { if nargs_val == 1 { @@ -3125,16 +3140,16 @@ impl ExecutingFrame<'_> { extracted.push(subject.clone()); } else if nargs_val > 1 { // Too many positional arguments for MATCH_SELF - return Err(vm.new_type_error( - "class pattern accepts at most 1 positional sub-pattern for MATCH_SELF types", - )); + return Err(vm.new_type_error(format!( + "{type_name}() accepts 1 positional sub-pattern ({nargs_val} given)" + ))); } } else { // No __match_args__ and not a MATCH_SELF type if nargs_val > 0 { - return Err(vm.new_type_error( - "class pattern defines no positional sub-patterns (__match_args__ missing)", - )); + return Err(vm.new_type_error(format!( + "{type_name}() accepts 0 positional sub-patterns ({nargs_val} given)" + ))); } } } @@ -3143,6 +3158,13 @@ impl ExecutingFrame<'_> { // Extract keyword attributes for name in kwd_attrs { let name_str = name.downcast_ref::().unwrap(); + if seen_attrs.__contains__(name_str.as_object(), vm)? { + let attr_repr = name.as_object().repr(vm)?; + return Err(vm.new_type_error(format!( + "{type_name}() got multiple sub-patterns for attribute {attr_repr}" + ))); + } + seen_attrs.add(name.clone(), vm)?; match subject.get_attr(name_str, vm) { Ok(value) => extracted.push(value), Err(e) if e.fast_isinstance(vm.ctx.exceptions.attribute_error) => { @@ -3166,10 +3188,14 @@ impl ExecutingFrame<'_> { let subject = self.nth_value(1); // stack[-2] // Check if subject is a mapping and extract values for keys - if subject.class().slots.flags.contains(PyTypeFlags::MAPPING) { + if subject + .class() + .has_patma_collection_flag(PyTypeFlags::MAPPING) + { let keys = keys_tuple.downcast_ref::().unwrap(); let mut values = Vec::new(); let mut all_match = true; + let seen_keys = PySet::default().into_ref(&vm.ctx); // We use the two argument form of map.get(key, default) for two reasons: // - Atomically check for a key and get its value without error handling. @@ -3186,6 +3212,13 @@ impl ExecutingFrame<'_> { .new_base_object(vm.ctx.types.object_type.to_owned(), None); for key in keys { + if seen_keys.__contains__(key.as_object(), vm)? { + return Err(vm.new_value_error(format!( + "mapping pattern checks duplicate key ({})", + key.as_object().repr(vm)? + ))); + } + seen_keys.add(key.as_object().to_owned(), vm)?; // value = map.get(key, dummy) match get_method.call((key.as_object(), dummy.clone()), vm) { Ok(value) => { @@ -3202,6 +3235,13 @@ impl ExecutingFrame<'_> { } else { // Fallback if .get() method is not available (shouldn't happen for mappings) for key in keys { + if seen_keys.__contains__(key.as_object(), vm)? { + return Err(vm.new_value_error(format!( + "mapping pattern checks duplicate key ({})", + key.as_object().repr(vm)? + ))); + } + seen_keys.add(key.as_object().to_owned(), vm)?; match subject.get_item(key.as_object(), vm) { Ok(value) => values.push(value), Err(e) if e.fast_isinstance(vm.ctx.exceptions.key_error) => { @@ -3231,7 +3271,9 @@ impl ExecutingFrame<'_> { let subject = self.pop_value(); // Check if the type has the MAPPING flag - let is_mapping = subject.class().slots.flags.contains(PyTypeFlags::MAPPING); + let is_mapping = subject + .class() + .has_patma_collection_flag(PyTypeFlags::MAPPING); self.push_value(subject); self.push_value(vm.ctx.new_bool(is_mapping).into()); @@ -3242,7 +3284,9 @@ impl ExecutingFrame<'_> { let subject = self.pop_value(); // Check if the type has the SEQUENCE flag - let is_sequence = subject.class().slots.flags.contains(PyTypeFlags::SEQUENCE); + let is_sequence = subject + .class() + .has_patma_collection_flag(PyTypeFlags::SEQUENCE); self.push_value(subject); self.push_value(vm.ctx.new_bool(is_sequence).into()); @@ -6845,7 +6889,7 @@ impl ExecutingFrame<'_> { } } - fn execute_unpack_ex(&mut self, vm: &VirtualMachine, before: u8, after: u8) -> FrameResult { + fn execute_unpack_ex(&mut self, vm: &VirtualMachine, before: u8, after: u32) -> FrameResult { let (before, after) = (before as usize, after as usize); let value = self.pop_value(); let not_iterable = value.class().slots.iter.load().is_none() diff --git a/crates/vm/src/import.rs b/crates/vm/src/import.rs index 5c418b35d67..f7cc03d991e 100644 --- a/crates/vm/src/import.rs +++ b/crates/vm/src/import.rs @@ -141,7 +141,7 @@ pub fn import_file( file_path, vm.compile_opts(), ) - .map_err(|err| vm.new_syntax_error(&err, Some(content)))?; + .map_err(|err| err.into_pyexception(vm, Some(content)))?; import_code_obj(vm, module_name, code, true) } @@ -154,7 +154,7 @@ pub fn import_source(vm: &VirtualMachine, module_name: &str, content: &str) -> P "", vm.compile_opts(), ) - .map_err(|err| vm.new_syntax_error(&err, Some(content)))?; + .map_err(|err| err.into_pyexception(vm, Some(content)))?; import_code_obj(vm, module_name, code, false) } diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index 36bc0df0c74..88ca646a4f1 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -2479,6 +2479,7 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { slots: PyType::make_slots(), heaptype_ext: None, tp_version_tag: core::sync::atomic::AtomicU32::new(0), + abc_tpflags: core::sync::atomic::AtomicU64::new(0), }; let object_payload = PyType { base: None, @@ -2489,6 +2490,7 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { slots: object::PyBaseObject::make_slots(), heaptype_ext: None, tp_version_tag: core::sync::atomic::AtomicU32::new(0), + abc_tpflags: core::sync::atomic::AtomicU64::new(0), }; // Both type_type and object_type are instances of `type`, which has // HAS_DICT and HAS_WEAKREF, so they need both ObjExt and WeakRefList prefixes. @@ -2585,6 +2587,7 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { slots: PyWeak::make_slots(), heaptype_ext: None, tp_version_tag: core::sync::atomic::AtomicU32::new(0), + abc_tpflags: core::sync::atomic::AtomicU64::new(0), }; let weakref_type = PyRef::new_ref(weakref_type, type_type.clone(), None); // Static type: untrack from GC (was tracked by new_ref because PyType has HAS_TRAVERSE) diff --git a/crates/vm/src/protocol/callable.rs b/crates/vm/src/protocol/callable.rs index c9afbd5afb0..70e0a54dcec 100644 --- a/crates/vm/src/protocol/callable.rs +++ b/crates/vm/src/protocol/callable.rs @@ -207,7 +207,7 @@ impl VirtualMachine { event: TraceEvent, arg: Option, ) -> PyResult> { - if self.use_tracing.get() { + if self.use_tracing.get() && !self.tracing_is_suppressed() { self._trace_event_inner(event, arg) } else { Ok(None) @@ -247,7 +247,9 @@ impl VirtualMachine { // tracing function itself. if is_trace_event && !self.is_none(&trace_func) { self.use_tracing.set(false); + self.enter_tracing(); let res = trace_func.call(args.clone(), self); + self.leave_tracing(); self.use_tracing.set(true); match res { Ok(result) => { @@ -268,7 +270,9 @@ impl VirtualMachine { if is_profile_event && !self.is_none(&profile_func) { self.use_tracing.set(false); + self.enter_tracing(); let res = profile_func.call(args, self); + self.leave_tracing(); self.use_tracing.set(true); if res.is_err() { *self.profile_func.borrow_mut() = self.ctx.none(); diff --git a/crates/vm/src/stdlib/_abc.rs b/crates/vm/src/stdlib/_abc.rs index 6cdef861253..58646035521 100644 --- a/crates/vm/src/stdlib/_abc.rs +++ b/crates/vm/src/stdlib/_abc.rs @@ -9,11 +9,11 @@ pub(crate) use _abc::module_def; mod _abc { use crate::{ AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyFrozenSet, PyList, PySet, PyStr, PyTupleRef, PyTypeRef, PyWeak}, + builtins::{PyFrozenSet, PyList, PySet, PyStr, PyTupleRef, PyType, PyTypeRef, PyWeak}, common::lock::PyRwLock, convert::ToPyObject, protocol::PyIterReturn, - types::Constructor, + types::{Constructor, PyTypeFlags}, }; use core::sync::atomic::{AtomicU64, Ordering}; @@ -238,6 +238,23 @@ mod _abc { // Invalidate negative cache increment_invalidation_counter(); + if let Some(cls_type) = cls.downcast_ref::() + && let Some(subclass_type) = subclass.downcast_ref::() + { + // CPython _abc_register propagates Py_TPFLAGS_SEQUENCE/MAPPING + // recursively so MATCH_SEQUENCE/MATCH_MAPPING see ABC registration. + let collection_mask = PyTypeFlags::SEQUENCE | PyTypeFlags::MAPPING; + let collection_flags = (cls_type.slots.flags + | PyTypeFlags::from_bits_truncate(cls_type.abc_tpflags.load(Ordering::Acquire))) + & collection_mask; + if !subclass_type.is(vm.ctx.types.str_type) + && !subclass_type.is(vm.ctx.types.bytes_type) + && !subclass_type.is(vm.ctx.types.bytearray_type) + { + subclass_type.set_abc_collection_flags_recursive(collection_flags); + } + } + Ok(subclass) } diff --git a/crates/vm/src/stdlib/_ast.rs b/crates/vm/src/stdlib/_ast.rs index 6bbbbefe504..5f22e87e3d0 100644 --- a/crates/vm/src/stdlib/_ast.rs +++ b/crates/vm/src/stdlib/_ast.rs @@ -4,18 +4,19 @@ //! This module makes use of the parser logic, and translates all ast nodes //! into python ast.AST objects. +use alloc::sync::Arc; + pub(crate) use python::_ast::module_def; mod pyast; use crate::builtins::{PyInt, PyStr}; -use crate::stdlib::_ast::module::{Mod, ModFunctionType, ModInteractive}; +use crate::stdlib::_ast::module::{Mod, ModFunctionType, ModInteractive, ModModule}; use crate::stdlib::_ast::node::BoxedSlice; use crate::{ - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, - TryFromObject, VirtualMachine, - builtins::PyIntRef, - builtins::{PyDict, PyModule, PyType, PyUtf8StrRef}, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, + builtins::{PyDict, PyList, PyModule, PyTuple, PyType, PyUtf8StrRef}, class::{PyClassImpl, StaticType}, compiler::{CompileError, ParseError}, convert::ToPyObject, @@ -68,36 +69,65 @@ fn singleton_node_to_object(vm: &VirtualMachine, node_type: &'static Py) .into() } +fn is_node_instance( + vm: &VirtualMachine, + object: &PyObjectRef, + node_type: &'static Py, +) -> PyResult { + object.is_instance(node_type.as_object(), vm) +} + +fn is_ast_instance(vm: &VirtualMachine, object: &PyObjectRef) -> PyResult { + let ast_type = NodeAst::make_static_type(); + object.is_instance(ast_type.as_object(), vm) +} + fn get_node_field(vm: &VirtualMachine, obj: &PyObject, field: &'static str, typ: &str) -> PyResult { vm.get_attribute_opt(obj.to_owned(), field)? .ok_or_else(|| vm.new_type_error(format!(r#"required field "{field}" missing from {typ}"#))) } -/// Read a required scalar field, rejecting both attribute absence and `None` value -/// with CPython-compatible error messages. Pairs with `get_node_field_opt` (which -/// returns `Option::None` for the same conditions): both filter `None`, but diverge -/// on whether to raise or return `None`. -/// -/// Errors: -/// - Attribute absent: `TypeError("required field \"X\" missing from Y")` (via `get_node_field`). -/// - Attribute present but `None`: `ValueError("field 'X' is required for Y")`, -/// matching CPython's `Python/ast.c` validator output. -/// -/// Use for required scalar fields where `None` is invalid (e.g. `comprehension.target`, -/// `keyword.value`, `match_case.pattern`). Do NOT use for fields where `None` is -/// legitimate (e.g. `Constant.value` representing the `None` literal — use plain -/// `get_node_field`); or for optional fields (use `get_node_field_opt`). +/// Read a required scalar field. CPython's generated `obj2ast_*` converters only +/// reject a missing required attribute here; if the field exists but is `None`, +/// the nested converter handles it. fn get_node_field_required( vm: &VirtualMachine, obj: &PyObject, field: &'static str, typ: &str, ) -> PyResult { - let value = get_node_field(vm, obj, field, typ)?; + get_node_field(vm, obj, field, typ) +} + +fn get_required_identifier_field( + vm: &VirtualMachine, + source_file: &SourceFile, + obj: &PyObject, + field: &'static str, + typ: &str, +) -> PyResult { + let value = get_node_field_required(vm, obj, field, typ)?; if vm.is_none(&value) { return Err(vm.new_value_error(format!("field '{field}' is required for {typ}"))); } - Ok(value) + Node::ast_from_object(vm, source_file, value) +} + +fn get_required_node_field( + vm: &VirtualMachine, + source_file: &SourceFile, + obj: &PyObject, + field: &'static str, + typ: &str, +) -> PyResult { + let value = get_node_field_required(vm, obj, field, typ)?; + if vm.is_none(&value) { + return Err(vm.new_value_error(format!("field '{field}' is required for {typ}"))); + } + let recursion_context = format!(" while traversing '{typ}' node"); + vm.with_recursion(&recursion_context, || { + Node::ast_from_object(vm, source_file, value) + }) } fn get_node_field_opt( @@ -110,15 +140,220 @@ fn get_node_field_opt( .filter(|obj| !vm.is_none(obj))) } +fn get_node_list_field( + vm: &VirtualMachine, + source_file: &SourceFile, + obj: &PyObject, + field: &'static str, + typ: &str, +) -> PyResult> { + let value = get_node_list_field_object(vm, obj, field, typ)?; + let list = value.downcast_ref::().unwrap(); + convert_node_list_field(vm, source_file, list, field, typ) +} + +fn get_node_list_field_object( + vm: &VirtualMachine, + obj: &PyObject, + field: &'static str, + typ: &str, +) -> PyResult { + let Some(value) = vm.get_attribute_opt(obj.to_owned(), field)? else { + return Ok(vm.ctx.new_list(Vec::new()).into()); + }; + value.downcast_ref::().ok_or_else(|| { + vm.new_type_error(format!( + r#"{typ} field "{field}" must be a list, not a {}"#, + value.class().name() + )) + })?; + Ok(value) +} + +fn convert_node_list_field( + vm: &VirtualMachine, + source_file: &SourceFile, + list: &PyList, + field: &'static str, + typ: &str, +) -> PyResult> { + let len = list.borrow_vec().len(); + let mut result = Vec::with_capacity(len); + let recursion_context = format!(" while traversing '{typ}' node"); + for i in 0..len { + let item = { + let items = list.borrow_vec(); + if items.len() != len { + return Err(vm.new_runtime_error(format!( + r#"{typ} field "{field}" changed size during iteration"# + ))); + } + items[i].clone() + }; + result.push(vm.with_recursion(&recursion_context, || { + Node::ast_from_object(vm, source_file, item) + })?); + if list.borrow_vec().len() != len { + return Err(vm.new_runtime_error(format!( + r#"{typ} field "{field}" changed size during iteration"# + ))); + } + } + Ok(result) +} + +fn get_node_boxed_slice_field( + vm: &VirtualMachine, + source_file: &SourceFile, + obj: &PyObject, + field: &'static str, + typ: &str, +) -> PyResult> { + Ok(get_node_list_field(vm, source_file, obj, field, typ)?.into_boxed_slice()) +} + +fn public_expr_list_from_values( + field: constant::PublicAstExprListField, + values: Vec>, +) -> (ast::AtomicNodeIndex, Vec) { + let node_index = public_expr_lists_node_index([(field, &values)]); + (node_index, lower_public_expr_list(values)) +} + +fn public_expr_boxed_slice_from_values( + field: constant::PublicAstExprListField, + values: Vec>, +) -> (ast::AtomicNodeIndex, Box<[ast::Expr]>) { + let (node_index, values) = public_expr_list_from_values(field, values); + (node_index, values.into_boxed_slice()) +} + +fn public_expr_lists_node_index<'a>( + fields: impl IntoIterator>)>, +) -> ast::AtomicNodeIndex { + public_node_list_overrides_node_index(Vec::new(), fields.into_iter().collect(), None, None) +} + +fn public_stmt_list_from_values( + field: constant::PublicAstStmtListField, + values: Vec>, +) -> (ast::AtomicNodeIndex, ast::Suite) { + let node_index = public_stmt_lists_node_index([(field, &values)]); + (node_index, lower_public_stmt_list(values)) +} + +fn public_stmt_lists_node_index<'a>( + fields: impl IntoIterator>)>, +) -> ast::AtomicNodeIndex { + public_node_list_overrides_node_index(fields.into_iter().collect(), Vec::new(), None, None) +} + +fn public_node_list_overrides_node_index<'a>( + stmt_fields: Vec<(constant::PublicAstStmtListField, &'a Vec>)>, + expr_fields: Vec<(constant::PublicAstExprListField, &'a Vec>)>, + except_handler_values: Option>>, + comprehension_is_async: Option, +) -> ast::AtomicNodeIndex { + let stmt_values: Vec<_> = stmt_fields + .into_iter() + .filter(|(_, values)| values.iter().any(Option::is_none)) + .map(|(field, values)| (field, values.clone())) + .collect(); + let expr_values: Vec<_> = expr_fields + .into_iter() + .filter(|(_, values)| values.iter().any(Option::is_none)) + .map(|(field, values)| (field, values.clone())) + .collect(); + let node_index = ast::AtomicNodeIndex::NONE; + if !stmt_values.is_empty() + || !expr_values.is_empty() + || except_handler_values.is_some() + || comprehension_is_async.is_some() + { + node_index.set(constant::register_public_ast_node_list_overrides( + stmt_values, + expr_values, + except_handler_values, + comprehension_is_async, + )); + } + node_index +} + +fn lower_public_stmt_list(values: Vec>) -> ast::Suite { + values + .into_iter() + .map(|value| value.unwrap_or_else(public_null_stmt_placeholder)) + .collect() +} + +fn lower_public_expr_list(values: Vec>) -> Vec { + values + .into_iter() + .map(|value| value.unwrap_or_else(public_null_expr_placeholder)) + .collect() +} + +fn public_null_stmt_placeholder() -> ast::Stmt { + ast::Stmt::Pass(ast::StmtPass { + range: Default::default(), + node_index: Default::default(), + }) +} + +fn public_null_expr_placeholder() -> ast::Expr { + ast::Expr::NoneLiteral(ast::ExprNoneLiteral { + range: Default::default(), + node_index: Default::default(), + }) +} + fn get_int_field( vm: &VirtualMachine, obj: &PyObject, field: &'static str, typ: &str, -) -> PyResult> { - get_node_field(vm, obj, field, typ)? - .downcast_exact(vm) - .map_err(|_| vm.new_type_error(format!(r#"field "{field}" must have integer type"#))) +) -> PyResult { + node_object_to_i32(vm, get_node_field(vm, obj, field, typ)?) +} + +pub(super) fn node_object_to_i32(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + if obj.is(&vm.ctx.true_value) { + return Ok(1); + } + if obj.is(&vm.ctx.false_value) { + return Ok(0); + } + let int: PyRef = match obj.clone().try_into_value(vm) { + Ok(int) => int, + Err(_) => { + return Err(vm.new_value_error(format!("invalid integer value: {}", obj.repr(vm)?))); + } + }; + i32::try_from(int.as_bigint()) + .map_err(|_| vm.new_overflow_error("Python int too large to convert to C int")) +} + +pub(super) fn node_object_to_ast_string( + vm: &VirtualMachine, + obj: PyObjectRef, +) -> PyResult { + let cls = obj.class(); + if cls.is(vm.ctx.types.str_type) || cls.is(vm.ctx.types.bytes_type) { + Ok(obj) + } else { + Err(vm.new_type_error("AST string must be of type str or bytes")) + } +} + +fn get_ast_string_field_opt( + vm: &VirtualMachine, + obj: &PyObject, + field: &'static str, +) -> PyResult> { + get_node_field_opt(vm, obj, field)? + .map(|obj| node_object_to_ast_string(vm, obj)) + .transpose() } struct PySourceRange { @@ -188,7 +423,17 @@ fn text_range_to_source_range(source_file: &SourceFile, text_range: TextRange) - let start_row = index.line_index(text_range.start()); let end_row = index.line_index(text_range.end()); let start_col = text_range.start() - index.line_start(start_row, source); - let end_col = text_range.end() - index.line_start(end_row, source); + let (end_row, end_col) = { + let end_col = text_range.end() - index.line_start(end_row, source); + if end_col == TextSize::new(0) && end_row > start_row { + let prev_line_end = text_range.end() - TextSize::new(1); + let row = index.line_index(prev_line_end); + let col = prev_line_end - index.line_start(row, source) + TextSize::new(1); + (row, col) + } else { + (end_row, end_col) + } + }; PySourceRange { start: PySourceLocation { @@ -204,138 +449,1361 @@ fn text_range_to_source_range(source_file: &SourceFile, text_range: TextRange) - fn get_opt_int_field( vm: &VirtualMachine, - obj: &PyObject, + obj: &PyObject, + field: &'static str, +) -> PyResult> { + match get_node_field_opt(vm, obj, field)? { + Some(val) => node_object_to_i32(vm, val).map(Some), + None => Ok(None), + } +} + +fn get_attribute_from_field( + vm: &VirtualMachine, + obj: &PyObjectRef, + field: PyObjectRef, +) -> PyResult> { + let field = field + .downcast::() + .map_err(|_| vm.new_type_error("attribute name must be string"))?; + vm.get_attribute_opt(obj.clone(), &field) +} + +#[derive(Default)] +struct AstSourceExtent { + max_line: usize, + max_col: usize, +} + +impl AstSourceExtent { + fn update_location(&mut self, vm: &VirtualMachine, obj: &PyObject) -> PyResult<()> { + if let Some(lineno) = get_opt_int_field(vm, obj, "lineno")? + && lineno > 0 + { + self.max_line = self.max_line.max(lineno as usize); + } + if let Some(end_lineno) = get_opt_int_field(vm, obj, "end_lineno")? + && end_lineno > 0 + { + self.max_line = self.max_line.max(end_lineno as usize); + } + if let Some(col_offset) = get_opt_int_field(vm, obj, "col_offset")? + && col_offset > 0 + { + self.max_col = self.max_col.max(col_offset as usize); + } + if let Some(end_col_offset) = get_opt_int_field(vm, obj, "end_col_offset")? + && end_col_offset > 0 + { + self.max_col = self.max_col.max(end_col_offset as usize); + } + Ok(()) + } +} + +fn scan_ast_source_extent( + vm: &VirtualMachine, + object: &PyObjectRef, + extent: &mut AstSourceExtent, +) -> PyResult<()> { + if is_ast_instance(vm, object)? { + extent.update_location(vm, object)?; + if let Some(fields) = object.class().get_attr(vm.ctx.intern_str("_fields")) { + let fields = fields.sequence_unchecked(); + let len = fields.length(vm)?; + for i in 0..len { + let field = fields.get_item(i as isize, vm)?; + if let Some(value) = get_attribute_from_field(vm, object, field)? { + vm.with_recursion(" while scanning AST node", || { + scan_ast_source_extent(vm, &value, extent) + })?; + } + } + } + } else if let Some(list) = object.downcast_ref::() { + let items = list.borrow_vec().to_vec(); + for item in items { + vm.with_recursion(" while scanning AST node", || { + scan_ast_source_extent(vm, &item, extent) + })?; + } + } else if let Some(tuple) = object.downcast_ref::() { + for item in tuple.as_slice() { + vm.with_recursion(" while scanning AST node", || { + scan_ast_source_extent(vm, item, extent) + })?; + } + } + Ok(()) +} + +fn copy_public_ast_passthrough_fields( + vm: &VirtualMachine, + source: &PyObjectRef, + target: &PyObjectRef, +) -> PyResult<()> { + if !is_ast_instance(vm, source)? + || !is_ast_instance(vm, target)? + || !source.is_instance(target.class().as_object(), vm)? + { + return Ok(()); + } + + let fields: &[&str] = + if is_node_instance(vm, target, pyast::NodeStmtFunctionDef::static_type())? + || is_node_instance(vm, target, pyast::NodeStmtAsyncFunctionDef::static_type())? + || is_node_instance(vm, target, pyast::NodeStmtAssign::static_type())? + || is_node_instance(vm, target, pyast::NodeStmtFor::static_type())? + || is_node_instance(vm, target, pyast::NodeStmtAsyncFor::static_type())? + || is_node_instance(vm, target, pyast::NodeStmtWith::static_type())? + || is_node_instance(vm, target, pyast::NodeStmtAsyncWith::static_type())? + || is_node_instance(vm, target, pyast::NodeArg::static_type())? + { + &["type_comment"] + } else if is_node_instance(vm, target, pyast::NodeComprehension::static_type())? { + &["is_async"] + } else if is_node_instance(vm, target, pyast::NodeExprConstant::static_type())? { + &["kind"] + } else if is_node_instance(vm, target, pyast::NodeExprInterpolation::static_type())? { + &["str"] + } else { + &[] + }; + + for field in fields { + if let Some(value) = vm.get_attribute_opt(source.clone(), *field)? { + target.set_attr(*field, value, vm)?; + } + } + + let Some(source_fields) = source.class().get_attr(vm.ctx.intern_str("_fields")) else { + return Ok(()); + }; + let Some(target_fields) = target.class().get_attr(vm.ctx.intern_str("_fields")) else { + return Ok(()); + }; + let source_fields = source_fields.sequence_unchecked(); + let target_fields = target_fields.sequence_unchecked(); + let len = source_fields.length(vm)?; + if len != target_fields.length(vm)? { + return Ok(()); + } + + for i in 0..len { + let source_field = source_fields.get_item(i as isize, vm)?; + let target_field = target_fields.get_item(i as isize, vm)?; + if !vm.bool_eq(&source_field, &target_field)? { + return Ok(()); + } + let Some(source_value) = get_attribute_from_field(vm, source, source_field)? else { + continue; + }; + let Some(target_value) = get_attribute_from_field(vm, target, target_field)? else { + continue; + }; + copy_public_ast_passthrough_children(vm, &source_value, &target_value)?; + } + + Ok(()) +} + +fn get_ast_location_field( + vm: &VirtualMachine, + object: &PyObjectRef, + field: &'static str, +) -> PyResult> { + Ok(vm + .get_attribute_opt(object.clone(), field)? + .filter(|value| !vm.is_none(value))) +} + +fn ast_start_location_matches( + vm: &VirtualMachine, + source: &PyObjectRef, + target: &PyObjectRef, +) -> PyResult { + for field in ["lineno", "col_offset"] { + let Some(source_value) = get_ast_location_field(vm, source, field)? else { + return Ok(false); + }; + let Some(target_value) = get_ast_location_field(vm, target, field)? else { + return Ok(false); + }; + if !vm.bool_eq(&source_value, &target_value)? { + return Ok(false); + } + } + + for field in ["end_lineno", "end_col_offset"] { + let Some(source_value) = get_ast_location_field(vm, source, field)? else { + continue; + }; + let Some(target_value) = get_ast_location_field(vm, target, field)? else { + continue; + }; + if !vm.bool_eq(&source_value, &target_value)? { + return Ok(false); + } + } + + Ok(true) +} + +fn ast_passthrough_location_candidate_matches( + vm: &VirtualMachine, + source: &PyObjectRef, + target: &PyObjectRef, +) -> PyResult { + Ok(is_ast_instance(vm, source)? + && is_ast_instance(vm, target)? + && source.is_instance(target.class().as_object(), vm)? + && ast_start_location_matches(vm, source, target)?) +} + +fn copy_public_ast_passthrough_list_items_by_location( + vm: &VirtualMachine, + source_items: &[PyObjectRef], + target_items: &[PyObjectRef], +) -> PyResult<()> { + let mut used_source_items = vec![false; source_items.len()]; + for target_item in target_items { + for (index, source_item) in source_items.iter().enumerate() { + if used_source_items[index] { + continue; + } + if ast_passthrough_location_candidate_matches(vm, source_item, target_item)? { + used_source_items[index] = true; + copy_public_ast_passthrough_fields(vm, source_item, target_item)?; + break; + } + } + } + Ok(()) +} + +fn copy_public_ast_passthrough_children( + vm: &VirtualMachine, + source: &PyObjectRef, + target: &PyObjectRef, +) -> PyResult<()> { + if is_ast_instance(vm, source)? && is_ast_instance(vm, target)? { + return copy_public_ast_passthrough_fields(vm, source, target); + } + + if let (Some(source_list), Some(target_list)) = ( + source.downcast_ref::(), + target.downcast_ref::(), + ) { + let source_items = source_list.borrow_vec().to_vec(); + let target_items = target_list.borrow_vec().to_vec(); + if source_items.len() == target_items.len() { + for (source_item, target_item) in source_items.iter().zip(target_items.iter()) { + copy_public_ast_passthrough_children(vm, source_item, target_item)?; + } + } else { + copy_public_ast_passthrough_list_items_by_location(vm, &source_items, &target_items)?; + } + } else if let (Some(source_tuple), Some(target_tuple)) = ( + source.downcast_ref::(), + target.downcast_ref::(), + ) && source_tuple.as_slice().len() == target_tuple.as_slice().len() + { + for (source_item, target_item) in source_tuple + .as_slice() + .iter() + .zip(target_tuple.as_slice().iter()) + { + copy_public_ast_passthrough_children(vm, source_item, target_item)?; + } + } + + Ok(()) +} + +fn synthetic_source_from_ast_object(vm: &VirtualMachine, object: &PyObjectRef) -> PyResult { + let mut extent = AstSourceExtent::default(); + scan_ast_source_extent(vm, object, &mut extent)?; + if extent.max_line == 0 { + return Ok(String::new()); + } + + let line_len = extent.max_col.saturating_add(1); + let line_width = line_len + .checked_add(1) + .ok_or_else(|| vm.new_memory_error("source location is too large"))?; + let capacity = line_width + .checked_mul(extent.max_line) + .ok_or_else(|| vm.new_memory_error("source location is too large"))?; + let mut source = String::new(); + source + .try_reserve(capacity) + .map_err(|_| vm.new_memory_error("source location is too large"))?; + + for _ in 0..extent.max_line { + source.extend(core::iter::repeat_n(' ', line_len)); + source.push('\n'); + } + Ok(source) +} + +fn range_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + name: &str, +) -> PyResult { + range_from_object_impl(vm, source_file, object, name, false) +} + +fn type_param_range_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + range_from_object_impl(vm, source_file, object, "type_param", true) +} + +fn expr_range_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + range_from_object_impl(vm, source_file, object, "expr", false) +} + +fn stmt_range_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + range_from_object_impl(vm, source_file, object, "stmt", false) +} + +fn pattern_range_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + range_from_object_impl(vm, source_file, object, "pattern", true) +} + +fn excepthandler_range_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + range_from_object_impl(vm, source_file, object, "excepthandler", false) +} + +fn excepthandler_range_from_object_unvalidated( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + let start_row = get_int_field(vm, &object, "lineno", "excepthandler")?; + let start_column = get_int_field(vm, &object, "col_offset", "excepthandler")?; + let end_row = get_opt_int_field(vm, &object, "end_lineno")?.unwrap_or(start_row); + let end_column = get_opt_int_field(vm, &object, "end_col_offset")?.unwrap_or(start_column); + + let location = PySourceRange { + start: PySourceLocation { + row: Row(if start_row > 0 { + OneIndexed::new(start_row as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }), + column: Column(TextSize::new(start_column.max(0) as u32)), + }, + end: PySourceLocation { + row: Row(if end_row > 0 { + OneIndexed::new(end_row as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }), + column: Column(TextSize::new(end_column.max(0) as u32)), + }, + }; + + Ok(source_range_to_text_range_unvalidated( + source_file, + location, + )) +} + +fn range_from_object_impl( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + name: &str, + end_required: bool, +) -> PyResult { + let start_row = get_int_field(vm, &object, "lineno", name)?; + let start_column = get_int_field(vm, &object, "col_offset", name)?; + let end_row = if end_required { + get_int_field(vm, &object, "end_lineno", name)? + } else { + get_opt_int_field(vm, &object, "end_lineno")?.unwrap_or(start_row) + }; + let end_column = if end_required { + get_int_field(vm, &object, "end_col_offset", name)? + } else { + get_opt_int_field(vm, &object, "end_col_offset")?.unwrap_or(start_column) + }; + + // lineno=0 or negative values as a special case (no location info). + // Use default values (line 1, col 0) when lineno <= 0. + let start_row_val = start_row; + let end_row_val = end_row; + let start_col_val = start_column; + let end_col_val = end_column; + + if start_row_val > end_row_val { + return Err(vm.new_value_error(format!( + "AST node line range ({start_row_val}, {end_row_val}) is not valid" + ))); + } + if (start_row_val < 0 && end_row_val != start_row_val) + || (start_col_val < 0 && end_col_val != start_col_val) + { + return Err(vm.new_value_error(format!( + "AST node column range ({start_col_val}, {end_col_val}) for line range ({start_row_val}, {end_row_val}) is not valid" + ))); + } + if start_row_val == end_row_val && start_col_val > end_col_val { + return Err(vm.new_value_error(format!( + "line {start_row_val}, column {start_col_val}-{end_col_val} is not a valid range" + ))); + } + + let location = PySourceRange { + start: PySourceLocation { + row: Row(if start_row_val > 0 { + OneIndexed::new(start_row_val as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }), + column: Column(TextSize::new(start_col_val.max(0) as u32)), + }, + end: PySourceLocation { + row: Row(if end_row_val > 0 { + OneIndexed::new(end_row_val as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }), + column: Column(TextSize::new(end_col_val.max(0) as u32)), + }, + }; + + Ok(source_range_to_text_range(source_file, location)) +} + +fn source_range_to_text_range(source_file: &SourceFile, location: PySourceRange) -> TextRange { + let index = LineIndex::from_source_text(source_file.clone().source_text()); + let source = &source_file.source_text(); + + if source.is_empty() { + return TextRange::new(TextSize::new(0), TextSize::new(0)); + } + + let start = index.offset( + location.start.to_source_location(), + source, + PositionEncoding::Utf8, + ); + let end = index.offset( + location.end.to_source_location(), + source, + PositionEncoding::Utf8, + ); + + TextRange::new(start, end) +} + +fn source_range_to_text_range_unvalidated( + source_file: &SourceFile, + location: PySourceRange, +) -> TextRange { + let index = LineIndex::from_source_text(source_file.clone().source_text()); + let source = &source_file.source_text(); + + if source.is_empty() { + return TextRange::new(TextSize::new(0), TextSize::new(0)); + } + + let start = index.offset( + location.start.to_source_location(), + source, + PositionEncoding::Utf8, + ); + let end = index.offset( + location.end.to_source_location(), + source, + PositionEncoding::Utf8, + ); + + if start <= end { + TextRange::new(start, end) + } else { + TextRange::empty(start) + } +} + +fn node_add_location( + dict: &Py, + range: TextRange, + vm: &VirtualMachine, + source_file: &SourceFile, +) { + let range = text_range_to_source_range(source_file, range); + dict.set_item("lineno", vm.ctx.new_int(range.start.row.get()).into(), vm) + .unwrap(); + dict.set_item( + "col_offset", + vm.ctx.new_int(range.start.column.get()).into(), + vm, + ) + .unwrap(); + dict.set_item("end_lineno", vm.ctx.new_int(range.end.row.get()).into(), vm) + .unwrap(); + dict.set_item( + "end_col_offset", + vm.ctx.new_int(range.end.column.get()).into(), + vm, + ) + .unwrap(); +} + +/// Return the expected public AST root type class for a compile() mode string. +/// +/// CPython's builtin compile() accepts func_type only with PyCF_ONLY_AST. +/// Source-string func_type parsing is handled separately, but public AST +/// FunctionType still uses the mode check before obj-to-AST conversion. +pub(crate) fn mode_type_and_name(mode: &str) -> Option<(PyRef, &'static str)> { + match mode { + "exec" => Some((pyast::NodeModModule::make_static_type(), "Module")), + "eval" => Some((pyast::NodeModExpression::make_static_type(), "Expression")), + "single" => Some((pyast::NodeModInteractive::make_static_type(), "Interactive")), + "func_type" => Some(( + pyast::NodeModFunctionType::make_static_type(), + "FunctionType", + )), + _ => None, + } +} + +struct TypeCommentLine<'a> { + text: &'a str, + comment_start: Option, +} + +struct TypeCommentSource<'a> { + lines: Vec>, +} + +impl<'a> TypeCommentSource<'a> { + fn new(source: &'a str, tokens: &ast::token::Tokens) -> Self { + let mut comment_offsets = Vec::new(); + for token in tokens { + if matches!(token.kind(), ast::token::TokenKind::Comment) { + comment_offsets.push(token.start().to_usize()); + } + } + + let mut comment_offsets = comment_offsets.into_iter().peekable(); + let mut line_start = 0usize; + let mut lines = Vec::new(); + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + let comment_start = comment_offsets.next_if(|offset| *offset < line_end); + lines.push(TypeCommentLine { + text: line, + comment_start: comment_start.map(|offset| offset - line_start), + }); + line_start = line_end; + } + + Self { lines } + } +} + +fn type_comment_position(line: &TypeCommentLine<'_>) -> Option { + let comment = line.comment_start?; + line.text[comment + 1..] + .trim_start() + .starts_with("type:") + .then_some(comment) +} + +fn type_comment_text<'a>(line: &'a TypeCommentLine<'a>) -> Option<&'a str> { + let comment = line.comment_start?; + let text = line.text.trim_end_matches(['\n', '\r']); + let mut rest = text[comment + 1..].trim_start_matches([' ', '\t']); + rest = rest.strip_prefix("type:")?; + Some(rest.trim_start_matches([' ', '\t'])) +} + +fn type_ignore_tag(comment: &str) -> Option<&str> { + let rest = comment.strip_prefix("ignore")?; + if let Some(next) = rest.as_bytes().first() + && (next.is_ascii_alphanumeric() || !next.is_ascii()) + { + return None; + } + Some(rest) +} + +fn regular_type_comment_text<'a>(line: &'a TypeCommentLine<'a>) -> Option<&'a str> { + let comment = type_comment_text(line)?; + type_ignore_tag(comment).is_none().then_some(comment) +} + +fn type_comment_parse_error( + source_file: &SourceFile, + message: &str, + start: usize, + end: usize, +) -> CompileError { + let range = TextRange::new(TextSize::new(start as u32), TextSize::new(end as u32)); + let source_range = text_range_to_source_range(source_file, range); + ParseError { + error: parser::ParseErrorType::OtherError(message.to_owned()), + raw_location: range, + location: source_range.start.to_source_location(), + end_location: source_range.end.to_source_location(), + source_path: "".to_string(), + is_unclosed_bracket: false, + } + .into() +} + +#[cfg(feature = "codegen")] +fn future_feature_compile_error( + source_file: &SourceFile, + error: codegen::preprocess::FutureFeatureError, +) -> CompileError { + let location = source_file + .to_source_code() + .source_location(error.range.start(), PositionEncoding::Utf8); + let error = match error.kind { + codegen::preprocess::FutureFeatureErrorKind::InvalidFeature(feature) => { + codegen::error::CodegenErrorType::InvalidFutureFeature(feature) + } + codegen::preprocess::FutureFeatureErrorKind::InvalidBraces => { + codegen::error::CodegenErrorType::InvalidFutureBraces + } + }; + codegen::error::CodegenError { + location: Some(location), + error, + source_path: source_file.name().to_owned(), + } + .into() +} + +fn trimmed_line_end(line: &str) -> usize { + line.trim_end_matches(['\n', '\r']).len() +} + +fn line_end_error( + source_file: &SourceFile, + message: &str, + line_start: usize, + line: &str, +) -> CompileError { + let start = line_start + trimmed_line_end(line); + type_comment_parse_error(source_file, message, start, start + 1) +} + +fn point_error_end(source: &str, start: usize) -> usize { + match source.as_bytes().get(start) { + None => start, + Some(_) => start + 1, + } +} + +fn line_after_colon_error( + source_file: &SourceFile, + message: &str, + line_start: usize, + line: &str, +) -> Option { + let code = &line[..trimmed_line_end(line)]; + let colon = code.rfind(':')?; + (!code[colon + 1..].trim().is_empty()) + .then(|| line_end_error(source_file, message, line_start, line)) +} + +fn find_line_containing_offset(source: &str, offset: usize) -> Option<(usize, &str)> { + let mut line_start = 0usize; + for line in source.split_inclusive('\n') { + let line_end = line_start + line.len(); + if offset < line_end { + return Some((line_start, line)); + } + line_start = line_end; + } + (offset == source.len()).then_some((line_start, "")) +} + +fn find_next_nonempty_line_end(source: &str, offset: usize) -> Option { + let (mut line_start, line) = find_line_containing_offset(source, offset)?; + line_start += line.len(); + for line in source[line_start..].split_inclusive('\n') { + if !line.trim().is_empty() { + return Some(line_start + trimmed_line_end(line)); + } + line_start += line.len(); + } + None +} + +fn find_numeric_literal_containing_underscore(code: &str) -> Option<(usize, usize)> { + let bytes = code.as_bytes(); + for idx in 1..bytes.len().saturating_sub(1) { + if bytes[idx] == b'_' && bytes[idx - 1].is_ascii_digit() && bytes[idx + 1].is_ascii_digit() + { + let mut start = idx - 1; + while start > 0 + && (bytes[start - 1].is_ascii_alphanumeric() || bytes[start - 1] == b'_') + { + start -= 1; + } + let mut end = idx + 2; + while end < bytes.len() && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') { + end += 1; + } + return Some((start, end)); + } + } + None +} + +fn bracket_delta(code: &str) -> i32 { + code.chars().fold(0, |depth, ch| match ch { + '(' | '[' | '{' => depth + 1, + ')' | ']' | '}' => depth - 1, + _ => depth, + }) +} + +fn def_header_complete(code: &str, depth: i32) -> bool { + depth <= 0 && code.trim_end().ends_with(':') +} + +fn is_assignment_stmt_line(code: &str) -> bool { + let bytes = code.as_bytes(); + for (idx, byte) in bytes.iter().enumerate() { + if *byte != b'=' { + continue; + } + let prev = idx.checked_sub(1).and_then(|idx| bytes.get(idx)).copied(); + let next = bytes.get(idx + 1).copied(); + if matches!( + prev, + Some( + b'=' | b'!' + | b'<' + | b'>' + | b':' + | b'+' + | b'-' + | b'*' + | b'/' + | b'%' + | b'&' + | b'|' + | b'^' + ) + ) || matches!(next, Some(b'=')) + { + continue; + } + return true; + } + false +} + +fn line_allows_stmt_type_comment(code: &str) -> bool { + let stripped = code.trim_start(); + (stripped.starts_with("for ") || stripped.starts_with("async for ")) && stripped.ends_with(':') + || (stripped.starts_with("with ") || stripped.starts_with("async with ")) + && stripped.ends_with(':') + || is_assignment_stmt_line(code) +} + +fn invalid_type_comment_syntax_error( + source_file: &SourceFile, + type_comment_source: &TypeCommentSource<'_>, +) -> Option { + let mut line_start = 0usize; + let mut in_def_header = false; + let mut def_depth = 0i32; + let mut pending_func_type_comment = false; + let mut previous_def_had_type_comment = false; + for line in &type_comment_source.lines { + let line_end = line_start + line.text.len(); + let stripped = line.text.trim_start(); + let code_end = type_comment_position(line).unwrap_or(line.text.len()); + let code = line.text[..code_end].trim(); + let has_regular_type_comment = regular_type_comment_text(line).is_some(); + + if let Some(comment) = type_comment_position(line) { + if code == "*" || code == "*," || code.ends_with("*,") { + return Some(type_comment_parse_error( + source_file, + "bare * has associated type comment", + line_start + comment, + line_start + line.text.len(), + )); + } + if previous_def_had_type_comment && code.is_empty() { + return Some(type_comment_parse_error( + source_file, + "Cannot have two type comments on def", + line_start + comment, + line_start + line.text.len(), + )); + } + let allowed = !has_regular_type_comment + || in_def_header + || line_allows_stmt_type_comment(code) + || stripped.starts_with("def ") + || stripped.starts_with("async def ") + || (pending_func_type_comment && code.is_empty()); + if !allowed { + return Some(type_comment_parse_error( + source_file, + "invalid syntax", + line_start + comment, + line_start + line.text.len(), + )); + } + } + + let starts_def = stripped.starts_with("def ") || stripped.starts_with("async def "); + if starts_def && !in_def_header { + def_depth = bracket_delta(code); + let complete = def_header_complete(code, def_depth); + in_def_header = !complete; + previous_def_had_type_comment = complete && has_regular_type_comment; + pending_func_type_comment = complete && !has_regular_type_comment; + } else if in_def_header { + def_depth += bracket_delta(code); + let complete = def_header_complete(code, def_depth); + if complete { + in_def_header = false; + previous_def_had_type_comment = has_regular_type_comment; + pending_func_type_comment = !has_regular_type_comment; + } + } else if (pending_func_type_comment && code.is_empty() && has_regular_type_comment) + || (!stripped.trim().is_empty() && !starts_def && !code.is_empty()) + { + pending_func_type_comment = false; + previous_def_had_type_comment = false; + } + + line_start = line_end; + } + None +} + +fn feature_version_syntax_error( + source: &str, + source_file: &SourceFile, + target_version: ast::PythonVersion, +) -> Option { + let mut line_start = 0usize; + let mut async_def_error = None; + let mut pending_async_def = false; + let mut pending_block_error = None; + for line in source.split_inclusive('\n') { + let code_end = line.find('#').unwrap_or(line.len()); + let code = &line[..code_end]; + let stripped = code.trim_start(); + if pending_async_def && !stripped.trim().is_empty() { + if async_def_error.is_none() { + async_def_error = Some(line_end_error( + source_file, + "Async functions are only supported in Python 3.5 and greater", + line_start, + line, + )); + } + pending_async_def = false; + } + if let Some(message) = pending_block_error.take() { + if !stripped.trim().is_empty() { + return Some(line_end_error(source_file, message, line_start, line)); + } + pending_block_error = Some(message); + } + + if target_version.minor < 5 { + if stripped.starts_with("async def ") && async_def_error.is_none() { + let message = "Async functions are only supported in Python 3.5 and greater"; + if let Some(error) = line_after_colon_error(source_file, message, line_start, line) + { + async_def_error = Some(error); + } else { + pending_async_def = true; + } + } + if stripped.starts_with("async for ") { + let message = "Async for loops are only supported in Python 3.5 and greater"; + if let Some(error) = line_after_colon_error(source_file, message, line_start, line) + { + return Some(error); + } + pending_block_error = Some(message); + } + if stripped.starts_with("async with ") { + let message = "Async with statements are only supported in Python 3.5 and greater"; + if let Some(error) = line_after_colon_error(source_file, message, line_start, line) + { + return Some(error); + } + pending_block_error = Some(message); + } + if stripped.starts_with("await ") { + return Some(line_end_error( + source_file, + "Await expressions are only supported in Python 3.5 and greater", + line_start, + line, + )); + } + if let Some(pos) = code.find('@') + && !stripped.starts_with('@') + { + let is_augassign = code.as_bytes().get(pos + 1) == Some(&b'='); + let (start, end) = if is_augassign { + (line_start + pos, line_start + pos + 2) + } else { + let start = line_start + trimmed_line_end(line); + (start, start + 1) + }; + return Some(type_comment_parse_error( + source_file, + "The '@' operator is only supported in Python 3.5 and greater", + start, + end, + )); + } + } + + if target_version.minor < 6 { + if !stripped.starts_with("async for ") && code.contains(" async for ") { + let start = line_start + trimmed_line_end(line).saturating_sub(1); + return Some(type_comment_parse_error( + source_file, + "Async comprehensions are only supported in Python 3.6 and greater", + start, + point_error_end(source_file.source_text(), start), + )); + } + if let Some((start, end)) = find_numeric_literal_containing_underscore(code) { + return Some(type_comment_parse_error( + source_file, + "Underscores in numeric literals are only supported in Python 3.6 and greater", + line_start + start, + line_start + end, + )); + } + } + + line_start += line.len(); + } + async_def_error +} + +fn ann_assign_feature_error(stmts: &[ast::Stmt], source_file: &SourceFile) -> Option { + for stmt in stmts { + match stmt { + ast::Stmt::AnnAssign(ann) => { + let start = ann.range().end().to_usize(); + return Some(type_comment_parse_error( + source_file, + "Variable annotation syntax is only supported in Python 3.6 and greater", + start, + point_error_end(source_file.source_text(), start), + )); + } + ast::Stmt::FunctionDef(def) => { + if let Some(error) = ann_assign_feature_error(&def.body, source_file) { + return Some(error); + } + } + ast::Stmt::ClassDef(class_def) => { + if let Some(error) = ann_assign_feature_error(&class_def.body, source_file) { + return Some(error); + } + } + ast::Stmt::For(for_stmt) => { + if let Some(error) = ann_assign_feature_error(&for_stmt.body, source_file) + .or_else(|| ann_assign_feature_error(&for_stmt.orelse, source_file)) + { + return Some(error); + } + } + ast::Stmt::While(while_stmt) => { + if let Some(error) = ann_assign_feature_error(&while_stmt.body, source_file) + .or_else(|| ann_assign_feature_error(&while_stmt.orelse, source_file)) + { + return Some(error); + } + } + ast::Stmt::If(if_stmt) => { + if let Some(error) = ann_assign_feature_error(&if_stmt.body, source_file) { + return Some(error); + } + for clause in &if_stmt.elif_else_clauses { + if let Some(error) = ann_assign_feature_error(&clause.body, source_file) { + return Some(error); + } + } + } + ast::Stmt::With(with_stmt) => { + if let Some(error) = ann_assign_feature_error(&with_stmt.body, source_file) { + return Some(error); + } + } + ast::Stmt::Match(match_stmt) => { + for case in &match_stmt.cases { + if let Some(error) = ann_assign_feature_error(&case.body, source_file) { + return Some(error); + } + } + } + ast::Stmt::Try(try_stmt) => { + if let Some(error) = ann_assign_feature_error(&try_stmt.body, source_file) + .or_else(|| ann_assign_feature_error(&try_stmt.orelse, source_file)) + .or_else(|| ann_assign_feature_error(&try_stmt.finalbody, source_file)) + { + return Some(error); + } + for handler in &try_stmt.handlers { + let ast::ExceptHandler::ExceptHandler(handler) = handler; + if let Some(error) = ann_assign_feature_error(&handler.body, source_file) { + return Some(error); + } + } + } + _ => {} + } + } + None +} + +fn feature_version_ast_syntax_error( + top: &ast::Mod, + source_file: &SourceFile, + target_version: ast::PythonVersion, +) -> Option { + if target_version.minor >= 6 { + return None; + } + match top { + ast::Mod::Module(module) => ann_assign_feature_error(&module.body, source_file), + ast::Mod::Expression(_) => None, + } +} + +fn cpython_unsupported_syntax_message( + error: &parser::UnsupportedSyntaxError, +) -> Option<&'static str> { + match error.kind { + parser::UnsupportedSyntaxErrorKind::Match => { + Some("Pattern matching is only supported in Python 3.10 and greater") + } + parser::UnsupportedSyntaxErrorKind::Walrus => { + Some("Assignment expressions are only supported in Python 3.8 and greater") + } + parser::UnsupportedSyntaxErrorKind::ExceptStar => { + Some("Exception groups are only supported in Python 3.11 and greater") + } + parser::UnsupportedSyntaxErrorKind::PositionalOnlyParameter => { + Some("Positional-only parameters are only supported in Python 3.8 and greater") + } + parser::UnsupportedSyntaxErrorKind::TypeParameterList => { + Some("Type parameter lists are only supported in Python 3.12 and greater") + } + parser::UnsupportedSyntaxErrorKind::TypeAliasStatement => { + Some("Type statement is only supported in Python 3.12 and greater") + } + parser::UnsupportedSyntaxErrorKind::TypeParamDefault => { + Some("Type parameter defaults are only supported in Python 3.13 and greater") + } + parser::UnsupportedSyntaxErrorKind::TemplateStrings => { + Some("t-strings are only supported in Python 3.14 and greater") + } + parser::UnsupportedSyntaxErrorKind::UnparenthesizedExceptionTypes => Some( + "except expressions without parentheses are only supported in Python 3.14 and greater", + ), + _ => None, + } +} + +fn cpython_unsupported_syntax_error( + error: &parser::UnsupportedSyntaxError, + source: &str, + source_file: &SourceFile, +) -> Option { + let message = cpython_unsupported_syntax_message(error)?; + let start = match error.kind { + parser::UnsupportedSyntaxErrorKind::Match + | parser::UnsupportedSyntaxErrorKind::ExceptStar + | parser::UnsupportedSyntaxErrorKind::UnparenthesizedExceptionTypes => { + find_next_nonempty_line_end(source, error.range.start().to_usize()) + .unwrap_or_else(|| error.range.end().to_usize()) + } + parser::UnsupportedSyntaxErrorKind::Walrus + | parser::UnsupportedSyntaxErrorKind::PositionalOnlyParameter + | parser::UnsupportedSyntaxErrorKind::TypeParamDefault => error.range.end().to_usize(), + parser::UnsupportedSyntaxErrorKind::TypeAliasStatement => { + let (line_start, line) = + find_line_containing_offset(source, error.range.start().to_usize())?; + line_start + trimmed_line_end(line) + } + parser::UnsupportedSyntaxErrorKind::TypeParameterList => { + let (line_start, line) = + find_line_containing_offset(source, error.range.start().to_usize())?; + let code = &line[..trimmed_line_end(line)]; + line_start + + code + .as_bytes() + .iter() + .rposition(|byte| *byte == b']') + .unwrap_or_else(|| error.range.end().to_usize() - line_start) + } + parser::UnsupportedSyntaxErrorKind::TemplateStrings => { + let (line_start, line) = + find_line_containing_offset(source, error.range.start().to_usize())?; + line_start + trimmed_line_end(line).saturating_sub(1) + } + _ => error.range.start().to_usize(), + }; + Some(type_comment_parse_error( + source_file, + message, + start, + point_error_end(source, start), + )) +} + +fn should_report_unsupported_syntax_error(error: &parser::UnsupportedSyntaxError) -> bool { + cpython_unsupported_syntax_message(error).is_some() + || matches!( + error.kind, + parser::UnsupportedSyntaxErrorKind::LazyImportStatement + | parser::UnsupportedSyntaxErrorKind::UnpackingInComprehension(_) + | parser::UnsupportedSyntaxErrorKind::ParenthesizedKeywordArgumentName + ) +} + +fn node_list_field( + vm: &VirtualMachine, + object: &PyObjectRef, field: &'static str, -) -> PyResult>> { - match get_node_field_opt(vm, obj, field)? { - Some(val) => val - .downcast_exact(vm) - .map(Some) - .map_err(|_| vm.new_type_error(format!(r#"field "{field}" must have integer type"#))), - None => Ok(None), - } +) -> Vec { + vm.get_attribute_opt(object.clone(), field) + .ok() + .flatten() + .and_then(|value| { + value + .downcast_ref::() + .map(|list| list.borrow_vec().to_vec()) + }) + .unwrap_or_default() } -fn range_from_object( +fn node_optional_field( vm: &VirtualMachine, - source_file: &SourceFile, - object: PyObjectRef, - name: &str, -) -> PyResult { - let start_row = get_int_field(vm, &object, "lineno", name)?; - let start_column = get_int_field(vm, &object, "col_offset", name)?; - // end_lineno and end_col_offset are optional, default to start values - let end_row = - get_opt_int_field(vm, &object, "end_lineno")?.unwrap_or_else(|| start_row.clone()); - let end_column = - get_opt_int_field(vm, &object, "end_col_offset")?.unwrap_or_else(|| start_column.clone()); + object: &PyObjectRef, + field: &'static str, +) -> Option { + vm.get_attribute_opt(object.clone(), field) + .ok() + .flatten() + .filter(|value| !vm.is_none(value)) +} - // lineno=0 or negative values as a special case (no location info). - // Use default values (line 1, col 0) when lineno <= 0. - let start_row_val: i32 = start_row.try_to_primitive(vm)?; - let end_row_val: i32 = end_row.try_to_primitive(vm)?; - let start_col_val: i32 = start_column.try_to_primitive(vm)?; - let end_col_val: i32 = end_column.try_to_primitive(vm)?; +fn node_lineno(vm: &VirtualMachine, object: &PyObjectRef) -> Option { + node_optional_field(vm, object, "lineno")? + .try_into_value(vm) + .ok() +} - if start_row_val > end_row_val { - return Err(vm.new_value_error(format!( - "AST node line range ({start_row_val}, {end_row_val}) is not valid" - ))); - } - if (start_row_val < 0 && end_row_val != start_row_val) - || (start_col_val < 0 && end_col_val != start_col_val) - { - return Err(vm.new_value_error(format!( - "AST node column range ({start_col_val}, {end_col_val}) for line range ({start_row_val}, {end_row_val}) is not valid" - ))); - } - if start_row_val == end_row_val && start_col_val > end_col_val { - return Err(vm.new_value_error(format!( - "line {start_row_val}, column {start_col_val}-{end_col_val} is not a valid range" - ))); - } +fn source_line<'a>( + lines: &'a TypeCommentSource<'a>, + lineno: usize, +) -> Option<&'a TypeCommentLine<'a>> { + lineno.checked_sub(1).and_then(|idx| lines.lines.get(idx)) +} - let location = PySourceRange { - start: PySourceLocation { - row: Row(if start_row_val > 0 { - OneIndexed::new(start_row_val as usize).unwrap_or(OneIndexed::MIN) - } else { - OneIndexed::MIN - }), - column: Column(TextSize::new(start_col_val.max(0) as u32)), - }, - end: PySourceLocation { - row: Row(if end_row_val > 0 { - OneIndexed::new(end_row_val as usize).unwrap_or(OneIndexed::MIN) - } else { - OneIndexed::MIN - }), - column: Column(TextSize::new(end_col_val.max(0) as u32)), - }, - }; +fn set_type_comment(vm: &VirtualMachine, object: &PyObjectRef, comment: Option<&str>) { + let value = comment.map_or_else(|| vm.ctx.none(), |comment| vm.ctx.new_str(comment).into()); + object + .as_object() + .dict() + .unwrap() + .set_item("type_comment", value, vm) + .unwrap(); +} - Ok(source_range_to_text_range(source_file, location)) +fn same_line_type_comment<'a>( + vm: &VirtualMachine, + lines: &'a TypeCommentSource<'a>, + object: &PyObjectRef, +) -> Option<&'a str> { + let lineno = node_lineno(vm, object)?; + regular_type_comment_text(source_line(lines, lineno)?) } -fn source_range_to_text_range(source_file: &SourceFile, location: PySourceRange) -> TextRange { - let index = LineIndex::from_source_text(source_file.clone().source_text()); - let source = &source_file.source_text(); +fn function_type_comment<'a>( + vm: &VirtualMachine, + lines: &'a TypeCommentSource<'a>, + object: &PyObjectRef, +) -> Option<&'a str> { + let lineno = node_lineno(vm, object)?; + if let Some(comment) = regular_type_comment_text(source_line(lines, lineno)?) { + return Some(comment); + } - if source.is_empty() { - return TextRange::new(TextSize::new(0), TextSize::new(0)); + let next_line = source_line(lines, lineno + 1)?; + let comment_pos = type_comment_position(next_line)?; + next_line.text[..comment_pos] + .trim() + .is_empty() + .then(|| regular_type_comment_text(next_line)) + .flatten() +} + +fn apply_type_comments_to_arguments( + vm: &VirtualMachine, + lines: &TypeCommentSource<'_>, + arguments: &PyObjectRef, +) { + for field in ["posonlyargs", "args", "kwonlyargs"] { + for arg in node_list_field(vm, arguments, field) { + set_type_comment(vm, &arg, same_line_type_comment(vm, lines, &arg)); + } + } + for field in ["vararg", "kwarg"] { + if let Some(arg) = node_optional_field(vm, arguments, field) { + set_type_comment(vm, &arg, same_line_type_comment(vm, lines, &arg)); + } } +} - let start = index.offset( - location.start.to_source_location(), - source, - PositionEncoding::Utf8, - ); - let end = index.offset( - location.end.to_source_location(), - source, - PositionEncoding::Utf8, - ); +fn apply_type_comments_to_node( + vm: &VirtualMachine, + lines: &TypeCommentSource<'_>, + object: &PyObjectRef, +) { + let cls = object.class(); + if cls.is(pyast::NodeStmtFunctionDef::static_type()) + || cls.is(pyast::NodeStmtAsyncFunctionDef::static_type()) + { + set_type_comment(vm, object, function_type_comment(vm, lines, object)); + if let Some(arguments) = node_optional_field(vm, object, "args") { + apply_type_comments_to_arguments(vm, lines, &arguments); + } + } else if cls.is(pyast::NodeStmtAssign::static_type()) + || cls.is(pyast::NodeStmtFor::static_type()) + || cls.is(pyast::NodeStmtAsyncFor::static_type()) + || cls.is(pyast::NodeStmtWith::static_type()) + || cls.is(pyast::NodeStmtAsyncWith::static_type()) + { + set_type_comment(vm, object, same_line_type_comment(vm, lines, object)); + } - TextRange::new(start, end) + for field in ["body", "orelse", "finalbody"] { + for child in node_list_field(vm, object, field) { + apply_type_comments_to_node(vm, lines, &child); + } + } + for field in ["handlers", "cases"] { + for child in node_list_field(vm, object, field) { + apply_type_comments_to_node(vm, lines, &child); + } + } } -fn node_add_location( - dict: &Py, - range: TextRange, +fn apply_type_comments_to_module( vm: &VirtualMachine, - source_file: &SourceFile, + lines: &TypeCommentSource<'_>, + module: &PyObjectRef, ) { - let range = text_range_to_source_range(source_file, range); - dict.set_item("lineno", vm.ctx.new_int(range.start.row.get()).into(), vm) - .unwrap(); - dict.set_item( - "col_offset", - vm.ctx.new_int(range.start.column.get()).into(), - vm, - ) - .unwrap(); - dict.set_item("end_lineno", vm.ctx.new_int(range.end.row.get()).into(), vm) - .unwrap(); - dict.set_item( - "end_col_offset", - vm.ctx.new_int(range.end.column.get()).into(), - vm, - ) - .unwrap(); + for statement in node_list_field(vm, module, "body") { + apply_type_comments_to_node(vm, lines, &statement); + } } -/// Return the expected AST mod type class for a compile() mode string. -pub(crate) fn mode_type_and_name(mode: &str) -> Option<(PyRef, &'static str)> { - match mode { - "exec" => Some((pyast::NodeModModule::make_static_type(), "Module")), - "eval" => Some((pyast::NodeModExpression::make_static_type(), "Expression")), - "single" => Some((pyast::NodeModInteractive::make_static_type(), "Interactive")), - "func_type" => Some(( - pyast::NodeModFunctionType::make_static_type(), - "FunctionType", - )), - _ => None, +#[cfg(feature = "parser")] +fn ipython_escape_command_syntax_error( + top: &ast::Mod, + source_file: &SourceFile, +) -> Option { + use ast::visitor::{Visitor, walk_expr, walk_stmt}; + + #[derive(Default)] + struct IpyEscapeCommandVisitor { + range: Option, } + + impl Visitor<'_> for IpyEscapeCommandVisitor { + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + if self.range.is_some() { + return; + } + match stmt { + ast::Stmt::IpyEscapeCommand(stmt) => { + self.range = Some(stmt.range); + } + _ => walk_stmt(self, stmt), + } + } + + fn visit_expr(&mut self, expr: &ast::Expr) { + if self.range.is_some() { + return; + } + match expr { + ast::Expr::IpyEscapeCommand(expr) => { + self.range = Some(expr.range); + } + _ => walk_expr(self, expr), + } + } + } + + let mut visitor = IpyEscapeCommandVisitor::default(); + match top { + ast::Mod::Module(module) => { + for statement in &module.body { + visitor.visit_stmt(statement); + if visitor.range.is_some() { + break; + } + } + } + ast::Mod::Expression(expression) => { + visitor.visit_expr(&expression.body); + } + } + let range = visitor.range?; + let source_range = text_range_to_source_range(source_file, range); + Some( + ParseError { + error: parser::ParseErrorType::OtherError("invalid syntax".to_owned()), + raw_location: range, + location: source_range.start.to_source_location(), + end_location: source_range.end.to_source_location(), + source_path: "".to_owned(), + is_unclosed_bracket: false, + } + .into(), + ) } /// Create an empty `arguments` AST node (no parameters). @@ -361,6 +1829,7 @@ fn empty_arguments_object(vm: &VirtualMachine) -> PyObjectRef { } #[cfg(feature = "parser")] +#[allow(clippy::too_many_arguments)] pub(crate) fn parse( vm: &VirtualMachine, source: &str, @@ -368,14 +1837,30 @@ pub(crate) fn parse( optimize: u8, target_version: Option, type_comments: bool, + optimized_ast: bool, + interactive: bool, + explicit_future_annotations: bool, + dont_imply_dedent: bool, ) -> Result { let source_file = SourceFileBuilder::new("".to_owned(), source.to_owned()).finish(); let mut options = parser::ParseOptions::from(mode); let target_version = target_version.unwrap_or(ast::PythonVersion::PY314); + if let Some(error) = feature_version_syntax_error(source, &source_file, target_version) { + return Err(error); + } options = options.with_target_version(target_version); - let parsed = parser::parse(source, options).map_err(|parse_error| { + let parsed = parser::parse_unchecked(source, options); + let type_comment_source = + type_comments.then(|| TypeCommentSource::new(source, parsed.tokens())); + if let Some(lines) = &type_comment_source + && let Some(error) = invalid_type_comment_syntax_error(&source_file, lines) + { + return Err(error); + } + if let Err(errors) = parsed.as_result() { + let parse_error = errors[0].clone(); let range = text_range_to_source_range(&source_file, parse_error.location); - ParseError { + return Err(ParseError { error: parse_error.error, raw_location: parse_error.location, location: range.start.to_source_location(), @@ -383,9 +1868,23 @@ pub(crate) fn parse( source_path: "".to_string(), is_unclosed_bracket: false, } - })?; + .into()); + } + if dont_imply_dedent + && interactive + && let Some(error) = rustpython_compiler::dont_imply_dedent_source_error(&source_file) + { + return Err(error); + } - if let Some(error) = parsed.unsupported_syntax_errors().first() { + if let Some(error) = parsed + .unsupported_syntax_errors() + .iter() + .find(|error| should_report_unsupported_syntax_error(error)) + { + if let Some(error) = cpython_unsupported_syntax_error(error, source, &source_file) { + return Err(error); + } let range = text_range_to_source_range(&source_file, error.range()); return Err(ParseError { error: parser::ParseErrorType::OtherError(error.to_string()), @@ -399,19 +1898,56 @@ pub(crate) fn parse( } let mut top = parsed.into_syntax(); - if optimize > 0 { - fold_match_value_constants(&mut top); + if let Some(error) = ipython_escape_command_syntax_error(&top, &source_file) { + return Err(error); + } + if let Some(error) = feature_version_ast_syntax_error(&top, &source_file, target_version) { + return Err(error); + } + #[cfg(feature = "codegen")] + { + let future_features = codegen::preprocess::checked_future_features(&top) + .map_err(|err| future_feature_compile_error(&source_file, err))?; + let future_annotations = explicit_future_annotations + || future_features.contains(crate::bytecode::CodeFlags::FUTURE_ANNOTATIONS); + if interactive && let ast::Mod::Module(module) = &mut top { + codegen::preprocess::preprocess_statements( + &mut module.body, + optimize, + future_annotations, + !optimized_ast, + ); + } else { + codegen::preprocess::preprocess_mod( + &mut top, + optimize, + future_annotations, + !optimized_ast, + ); + } } - if optimize >= 2 { - strip_docstrings(&mut top); + #[cfg(not(feature = "codegen"))] + { + if optimized_ast && optimize > 0 { + fold_match_value_constants(&mut top); + } + if optimize >= 2 { + strip_docstrings(&mut top); + } } let top = match top { - ast::Mod::Module(m) => Mod::Module(m), + ast::Mod::Module(m) => Mod::Module(ModModule { + module: m, + type_ignores: Vec::new(), + }), ast::Mod::Expression(e) => Mod::Expression(e), }; let obj = top.ast_to_object(vm, &source_file); - if type_comments && obj.class().is(pyast::NodeModModule::static_type()) { - let type_ignores = type_ignores_from_source(vm, source); + if let Some(lines) = &type_comment_source + && obj.class().is(pyast::NodeModModule::static_type()) + { + apply_type_comments_to_module(vm, lines, &obj); + let type_ignores = type_ignores_from_source(vm, lines); let dict = obj.as_object().dict().unwrap(); dict.set_item("type_ignores", vm.ctx.new_list(type_ignores).into(), vm) .unwrap(); @@ -441,8 +1977,18 @@ pub(crate) fn parse_func_type( target_version: Option, ) -> Result { let _ = optimize; - let _ = target_version; let source = source.trim(); + let invalid_func_type = || -> CompileError { + ParseError { + error: parser::ParseErrorType::OtherError("invalid syntax".to_owned()), + raw_location: TextRange::default(), + location: SourceLocation::default(), + end_location: SourceLocation::default(), + source_path: "".to_owned(), + is_unclosed_bracket: false, + } + .into() + }; let mut depth = 0i32; let mut split_at = None; let mut chars = source.chars().peekable(); @@ -477,7 +2023,9 @@ pub(crate) fn parse_func_type( let parse_expr = |expr_src: &str| -> Result { let source_file = SourceFileBuilder::new("".to_owned(), expr_src.to_owned()).finish(); - let parsed = parser::parse_expression(expr_src).map_err(|parse_error| { + let options = parser::ParseOptions::from(parser::Mode::Expression) + .with_target_version(target_version.unwrap_or(ast::PythonVersion::PY314)); + let parsed = parser::parse(expr_src, options).map_err(|parse_error| { let range = text_range_to_source_range(&source_file, parse_error.location); ParseError { error: parse_error.error, @@ -488,19 +2036,73 @@ pub(crate) fn parse_func_type( is_unclosed_bracket: false, } })?; - Ok(*parsed.into_syntax().body) + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + unreachable!(); + }; + Ok(*expression.body) }; - let arg_expr = parse_expr(left)?; - let returns = parse_expr(right)?; - - let argtypes: Vec = match arg_expr { - ast::Expr::Tuple(tup) => tup.elts, - ast::Expr::Name(_) | ast::Expr::Subscript(_) | ast::Expr::Attribute(_) => vec![arg_expr], - other => vec![other], + if !left.starts_with('(') || !left.ends_with(')') { + return Err(invalid_func_type()); + } + let inner = left[1..left.len() - 1].trim(); + let argtypes = if inner.is_empty() { + Vec::new() + } else { + if inner.ends_with(',') { + return Err(invalid_func_type()); + } + let call_source = format!("__rustpython_func_type__({inner})"); + let source_file = SourceFileBuilder::new("".to_owned(), call_source.clone()).finish(); + let options = parser::ParseOptions::from(parser::Mode::Expression) + .with_target_version(target_version.unwrap_or(ast::PythonVersion::PY314)); + let parsed = parser::parse(&call_source, options).map_err(|parse_error| { + let range = text_range_to_source_range(&source_file, parse_error.location); + ParseError { + error: parse_error.error, + raw_location: parse_error.location, + location: range.start.to_source_location(), + end_location: range.end.to_source_location(), + source_path: "".to_string(), + is_unclosed_bracket: false, + } + })?; + let ast::Mod::Expression(expression) = parsed.into_syntax() else { + unreachable!(); + }; + let ast::Expr::Call(call) = *expression.body else { + return Err(invalid_func_type()); + }; + let mut args = Vec::new(); + let positional_len = call.arguments.args.len(); + let mut seen_star = false; + for (index, arg) in call.arguments.args.into_iter().enumerate() { + match arg { + ast::Expr::Starred(starred) => { + if seen_star || index + 1 != positional_len { + return Err(invalid_func_type()); + } + seen_star = true; + args.push(*starred.value); + } + expr => args.push(expr), + } + } + let mut seen_kw_star = false; + for keyword in call.arguments.keywords { + if keyword.arg.is_some() || seen_kw_star { + return Err(invalid_func_type()); + } + seen_kw_star = true; + args.push(keyword.value); + } + args }; + let returns = parse_expr(right)?; + let func_type = ModFunctionType { + node_index: ast::AtomicNodeIndex::NONE, argtypes: argtypes.into_boxed_slice(), returns, range: TextRange::default(), @@ -509,22 +2111,18 @@ pub(crate) fn parse_func_type( Ok(func_type.ast_to_object(vm, &source_file)) } -fn type_ignores_from_source(vm: &VirtualMachine, source: &str) -> Vec { +fn type_ignores_from_source( + vm: &VirtualMachine, + lines: &TypeCommentSource<'_>, +) -> Vec { let mut ignores = Vec::new(); - for (idx, line) in source.lines().enumerate() { - let Some(pos) = line.find('#') else { + for (idx, line) in lines.lines.iter().enumerate() { + let Some(comment) = type_comment_text(line) else { continue; }; - - let comment = &line[pos + 1..]; - let comment = comment.trim_start(); - - let Some(rest) = comment.strip_prefix("type: ignore") else { + let Some(tag) = type_ignore_tag(comment) else { continue; }; - - let tag = rest.trim_start(); - let tag = if tag.is_empty() { "" } else { tag }; let node = NodeAst .into_ref_with_type( vm, @@ -542,7 +2140,7 @@ fn type_ignores_from_source(vm: &VirtualMachine, source: &str) -> Vec fold_stmts(&mut module.body), @@ -550,7 +2148,7 @@ fn fold_match_value_constants(top: &mut ast::Mod) { } } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn strip_docstrings(top: &mut ast::Mod) { match top { ast::Mod::Module(module) => strip_docstring_in_body(&mut module.body), @@ -558,8 +2156,8 @@ fn strip_docstrings(top: &mut ast::Mod) { } } -#[cfg(feature = "parser")] -fn strip_docstring_in_body(body: &mut Vec) { +#[cfg(all(feature = "parser", not(feature = "codegen")))] +fn strip_docstring_in_body(body: &mut ast::Suite) { if let Some(range) = take_docstring(body) && body.is_empty() { @@ -580,8 +2178,8 @@ fn strip_docstring_in_body(body: &mut Vec) { } } -#[cfg(feature = "parser")] -fn take_docstring(body: &mut Vec) -> Option { +#[cfg(all(feature = "parser", not(feature = "codegen")))] +fn take_docstring(body: &mut ast::Suite) -> Option { let ast::Stmt::Expr(expr_stmt) = body.first()? else { return None; }; @@ -593,14 +2191,14 @@ fn take_docstring(body: &mut Vec) -> Option { None } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn fold_stmts(stmts: &mut [ast::Stmt]) { for stmt in stmts { fold_stmt(stmt); } } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn fold_stmt(stmt: &mut ast::Stmt) { use ast::Stmt; match stmt { @@ -641,7 +2239,7 @@ fn fold_stmt(stmt: &mut ast::Stmt) { } } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn fold_pattern(pattern: &mut ast::Pattern) { use ast::Pattern; match pattern { @@ -681,7 +2279,7 @@ fn fold_pattern(pattern: &mut ast::Pattern) { } } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn fold_expr(expr: &mut ast::Expr) { use ast::Expr; if let Expr::UnaryOp(unary) = expr { @@ -735,7 +2333,7 @@ fn fold_expr(expr: &mut ast::Expr) { } } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn fold_number_binop( left: &ast::Number, op: ast::Operator, @@ -761,7 +2359,7 @@ fn fold_number_binop( } } -#[cfg(feature = "parser")] +#[cfg(all(feature = "parser", not(feature = "codegen")))] fn number_to_complex(number: &ast::Number) -> Option<(f64, f64, bool)> { match number { ast::Number::Complex { real, imag } => Some((*real, *imag, true)), @@ -770,94 +2368,307 @@ fn number_to_complex(number: &ast::Number) -> Option<(f64, f64, bool)> { } } +#[cfg(feature = "codegen")] +pub(crate) fn preprocess_ast_object( + vm: &VirtualMachine, + object: PyObjectRef, + filename: &str, + optimize: u8, + optimized_ast: bool, + explicit_future_annotations: bool, +) -> PyResult { + let original_object = object.clone(); + let text = synthetic_source_from_ast_object(vm, &object)?; + let source_file = SourceFileBuilder::new(filename.to_owned(), text).finish(); + let ( + ast, + ast_constant_overrides, + ast_interpolation_overrides, + ast_formatted_value_overrides, + ast_import_from_level_overrides, + ast_invalid_constant_overrides, + ast_joined_str_overrides, + ast_template_str_overrides, + ast_comprehension_is_async_overrides, + ast_pattern_list_overrides, + ast_expr_option_list_overrides, + ast_expr_list_overrides, + ast_stmt_list_overrides, + ast_except_handler_list_overrides, + ast_type_param_list_overrides, + ast_match_class_overrides, + ast_ann_assign_simple_overrides, + ast_arg_type_comment_overrides, + ast_stmt_type_comment_overrides, + ): (Mod, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _) = + constant::collect_public_ast_overrides(|| Node::ast_from_object(vm, &source_file, object))?; + validate::validate_mod( + vm, + &ast, + Some(&ast_constant_overrides), + Some(&ast_interpolation_overrides), + Some(&ast_formatted_value_overrides), + Some(&ast_import_from_level_overrides), + Some(&ast_invalid_constant_overrides), + Some(&ast_joined_str_overrides), + Some(&ast_template_str_overrides), + Some(&ast_pattern_list_overrides), + Some(&ast_expr_option_list_overrides), + Some(&ast_expr_list_overrides), + Some(&ast_stmt_list_overrides), + Some(&ast_except_handler_list_overrides), + Some(&ast_type_param_list_overrides), + Some(&ast_match_class_overrides), + )?; + let syntax_check_only = !optimized_ast; + + let ast = match ast { + Mod::Module(mut module) => { + let mut ast = ast::Mod::Module(module.module); + let future_features = + codegen::preprocess::checked_future_features(&ast).map_err(|err| { + vm.new_syntax_error(&future_feature_compile_error(&source_file, err), None) + })?; + let future_annotations = explicit_future_annotations + || future_features.contains(crate::bytecode::CodeFlags::FUTURE_ANNOTATIONS); + codegen::preprocess::preprocess_mod( + &mut ast, + optimize, + future_annotations, + syntax_check_only, + ); + let ast::Mod::Module(processed_module) = ast else { + unreachable!(); + }; + module.module = processed_module; + Mod::Module(module) + } + Mod::Interactive(mut interactive) => { + let future_features = codegen::preprocess::checked_future_features_in_body( + &interactive.body, + ) + .map_err(|err| { + vm.new_syntax_error(&future_feature_compile_error(&source_file, err), None) + })?; + let future_annotations = explicit_future_annotations + || future_features.contains(crate::bytecode::CodeFlags::FUTURE_ANNOTATIONS); + codegen::preprocess::preprocess_statements( + &mut interactive.body, + optimize, + future_annotations, + syntax_check_only, + ); + Mod::Interactive(interactive) + } + Mod::Expression(expression) => { + let mut ast = ast::Mod::Expression(expression); + codegen::preprocess::preprocess_mod( + &mut ast, + optimize, + explicit_future_annotations, + syntax_check_only, + ); + let ast::Mod::Expression(expression) = ast else { + unreachable!(); + }; + Mod::Expression(expression) + } + Mod::FunctionType(function_type) => Mod::FunctionType(function_type), + }; + let result = constant::with_public_ast_interpolation_objects( + &ast_constant_overrides, + &ast_interpolation_overrides, + &ast_formatted_value_overrides, + &ast_joined_str_overrides, + &ast_template_str_overrides, + &ast_comprehension_is_async_overrides, + &ast_pattern_list_overrides, + &ast_expr_option_list_overrides, + &ast_expr_list_overrides, + &ast_stmt_list_overrides, + &ast_except_handler_list_overrides, + &ast_type_param_list_overrides, + &ast_match_class_overrides, + &ast_ann_assign_simple_overrides, + &ast_arg_type_comment_overrides, + &ast_stmt_type_comment_overrides, + || ast.ast_to_object(vm, &source_file), + ); + copy_public_ast_passthrough_fields(vm, &original_object, &result)?; + Ok(result) +} + #[cfg(feature = "codegen")] pub(crate) fn compile( vm: &VirtualMachine, object: PyObjectRef, filename: &str, mode: crate::compiler::Mode, - optimize: Option, + mut opts: codegen::CompileOpts, ) -> PyResult { - let mut opts = vm.compile_opts(); - if let Some(optimize) = optimize { - opts.optimize = optimize; + let text = synthetic_source_from_ast_object(vm, &object)?; + let source_file = SourceFileBuilder::new(filename.to_owned(), text.clone()).finish(); + let ( + ast, + ast_constant_overrides, + ast_interpolation_overrides, + ast_formatted_value_overrides, + ast_import_from_level_overrides, + ast_invalid_constant_overrides, + ast_joined_str_overrides, + ast_template_str_overrides, + _ast_comprehension_is_async_overrides, + ast_pattern_list_overrides, + ast_expr_option_list_overrides, + ast_expr_list_overrides, + ast_stmt_list_overrides, + ast_except_handler_list_overrides, + ast_type_param_list_overrides, + ast_match_class_overrides, + _ast_ann_assign_simple_overrides, + _ast_arg_type_comment_overrides, + _ast_stmt_type_comment_overrides, + ): (Mod, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _) = + constant::collect_public_ast_overrides(|| Node::ast_from_object(vm, &source_file, object))?; + validate::validate_mod( + vm, + &ast, + Some(&ast_constant_overrides), + Some(&ast_interpolation_overrides), + Some(&ast_formatted_value_overrides), + Some(&ast_import_from_level_overrides), + Some(&ast_invalid_constant_overrides), + Some(&ast_joined_str_overrides), + Some(&ast_template_str_overrides), + Some(&ast_pattern_list_overrides), + Some(&ast_expr_option_list_overrides), + Some(&ast_expr_list_overrides), + Some(&ast_stmt_list_overrides), + Some(&ast_except_handler_list_overrides), + Some(&ast_type_param_list_overrides), + Some(&ast_match_class_overrides), + )?; + if !ast_constant_overrides.is_empty() { + opts.ast_constant_overrides = Some(Arc::new(ast_constant_overrides)); + } + if !ast_interpolation_overrides.is_empty() { + opts.ast_interpolation_overrides = Some(Arc::new(ast_interpolation_overrides)); + } + if !ast_formatted_value_overrides.is_empty() { + opts.ast_formatted_value_overrides = Some(Arc::new(ast_formatted_value_overrides)); + } + if !ast_joined_str_overrides.is_empty() { + opts.ast_joined_str_overrides = Some(Arc::new(ast_joined_str_overrides)); + } + if !ast_template_str_overrides.is_empty() { + opts.ast_template_str_overrides = Some(Arc::new(ast_template_str_overrides)); } - - let source_file = SourceFileBuilder::new(filename.to_owned(), "".to_owned()).finish(); - let ast: Mod = Node::ast_from_object(vm, &source_file, object)?; - validate::validate_mod(vm, &ast)?; let ast = match ast { - Mod::Module(m) => ast::Mod::Module(m), - Mod::Interactive(ModInteractive { range, body }) => ast::Mod::Module(ast::ModModule { + Mod::Module(m) => ast::Mod::Module(m.module), + Mod::Interactive(ModInteractive { range, body, .. }) => ast::Mod::Module(ast::ModModule { node_index: Default::default(), range, body, }), Mod::Expression(e) => ast::Mod::Expression(e), - Mod::FunctionType(_) => todo!(), + Mod::FunctionType(_) => { + return Err(vm.new_runtime_error("this compiler does not handle FunctionTypes")); + } }; - // TODO: create a textual representation of the ast - let text = ""; + opts.future_features |= codegen::preprocess::future_features(&ast); let source_file = SourceFileBuilder::new(filename, text).finish(); - let code = codegen::compile::compile_top(ast, source_file, mode, opts) - .map_err(|err| vm.new_syntax_error(&err.into(), None))?; // FIXME source + #[cfg(feature = "parser")] + let code = { + let source_path = filename.to_owned(); + let mut syntax_warning_handler = |location: SourceLocation, message: String| { + let fname = vm.ctx.new_str(source_path.as_str()); + let message = vm.ctx.new_str(message); + crate::warn::warn_explicit( + Some(vm.ctx.exceptions.syntax_warning.to_owned()), + message.into(), + fname, + location.line.get(), + None, + vm.ctx.none(), + None, + None, + vm, + ) + .map_err(|exception| { + let message = exception.as_object().str(vm).map_or_else( + |_| "compiler warning raised as an exception".to_owned(), + |message| message.as_wtf8().to_string(), + ); + codegen::error::CodegenError { + location: Some(location), + error: codegen::error::CodegenErrorType::SyntaxError(message), + source_path: source_path.clone(), + } + }) + }; + codegen::compile::compile_top_with_syntax_warning_handler( + ast, + source_file, + mode, + opts, + Some(&mut syntax_warning_handler), + ) + }; + #[cfg(not(feature = "parser"))] + let code = codegen::compile::compile_top(ast, source_file, mode, opts); + let code = code.map_err(|err| vm.new_syntax_error(&err.into(), None))?; // FIXME source Ok(crate::builtins::PyCode::new_ref_from_bytecode(vm, code).into()) } -#[cfg(feature = "codegen")] +#[cfg(not(feature = "rustpython-codegen"))] pub(crate) fn validate_ast_object(vm: &VirtualMachine, object: PyObjectRef) -> PyResult<()> { let source_file = SourceFileBuilder::new("".to_owned(), "".to_owned()).finish(); - let ast: Mod = Node::ast_from_object(vm, &source_file, object)?; - validate::validate_mod(vm, &ast)?; + let ( + ast, + ast_constant_overrides, + ast_interpolation_overrides, + ast_formatted_value_overrides, + ast_import_from_level_overrides, + ast_invalid_constant_overrides, + _ast_joined_str_overrides, + _ast_template_str_overrides, + _ast_comprehension_is_async_overrides, + ast_pattern_list_overrides, + ast_expr_option_list_overrides, + ast_expr_list_overrides, + ast_stmt_list_overrides, + ast_except_handler_list_overrides, + ast_type_param_list_overrides, + ast_match_class_overrides, + _ast_ann_assign_simple_overrides, + _ast_arg_type_comment_overrides, + _ast_stmt_type_comment_overrides, + ): (Mod, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _) = + constant::collect_public_ast_overrides(|| Node::ast_from_object(vm, &source_file, object))?; + validate::validate_mod( + vm, + &ast, + Some(&ast_constant_overrides), + Some(&ast_interpolation_overrides), + Some(&ast_formatted_value_overrides), + Some(&ast_import_from_level_overrides), + Some(&ast_invalid_constant_overrides), + Some(&_ast_joined_str_overrides), + Some(&_ast_template_str_overrides), + Some(&ast_pattern_list_overrides), + Some(&ast_expr_option_list_overrides), + Some(&ast_expr_list_overrides), + Some(&ast_stmt_list_overrides), + Some(&ast_except_handler_list_overrides), + Some(&ast_type_param_list_overrides), + Some(&ast_match_class_overrides), + )?; Ok(()) } -// Used by builtins::compile() -pub(crate) const PY_CF_ONLY_AST: i32 = 0x0400; - // The following flags match the values from Include/cpython/compile.h -// Caveat emptor: These flags are undocumented on purpose and depending -// on their effect outside the standard library is **unsupported**. -pub(crate) const PY_CF_SOURCE_IS_UTF8: i32 = 0x0100; -pub(crate) const PY_CF_DONT_IMPLY_DEDENT: i32 = 0x200; -pub(crate) const PY_CF_IGNORE_COOKIE: i32 = 0x0800; -pub(crate) const PY_CF_ALLOW_INCOMPLETE_INPUT: i32 = 0x4000; -pub(crate) const PY_CF_OPTIMIZED_AST: i32 = 0x8000 | PY_CF_ONLY_AST; -pub(crate) const PY_CF_TYPE_COMMENTS: i32 = 0x1000; -pub(crate) const PY_CF_ALLOW_TOP_LEVEL_AWAIT: i32 = 0x2000; - -// __future__ flags - sync with Lib/__future__.py -// TODO: These flags aren't being used in rust code -// CO_FUTURE_ANNOTATIONS does make a difference in the codegen, -// so it should be used in compile(). -// see compiler/codegen/src/compile.rs -const CO_NESTED: i32 = 0x0010; -const CO_GENERATOR_ALLOWED: i32 = 0; -const CO_FUTURE_DIVISION: i32 = 0x20000; -const CO_FUTURE_ABSOLUTE_IMPORT: i32 = 0x40000; -const CO_FUTURE_WITH_STATEMENT: i32 = 0x80000; -const CO_FUTURE_PRINT_FUNCTION: i32 = 0x100000; -const CO_FUTURE_UNICODE_LITERALS: i32 = 0x200000; -const CO_FUTURE_BARRY_AS_BDFL: i32 = 0x400000; -const CO_FUTURE_GENERATOR_STOP: i32 = 0x800000; -const CO_FUTURE_ANNOTATIONS: i32 = 0x1000000; - -// Used by builtins::compile() - the summary of all flags -pub(crate) const PY_COMPILE_FLAGS_MASK: i32 = PY_CF_ONLY_AST - | PY_CF_SOURCE_IS_UTF8 - | PY_CF_DONT_IMPLY_DEDENT - | PY_CF_IGNORE_COOKIE - | PY_CF_ALLOW_TOP_LEVEL_AWAIT - | PY_CF_ALLOW_INCOMPLETE_INPUT - | PY_CF_OPTIMIZED_AST - | PY_CF_TYPE_COMMENTS - | CO_NESTED - | CO_GENERATOR_ALLOWED - | CO_FUTURE_DIVISION - | CO_FUTURE_ABSOLUTE_IMPORT - | CO_FUTURE_WITH_STATEMENT - | CO_FUTURE_PRINT_FUNCTION - | CO_FUTURE_UNICODE_LITERALS - | CO_FUTURE_BARRY_AS_BDFL - | CO_FUTURE_GENERATOR_STOP - | CO_FUTURE_ANNOTATIONS; +pub(crate) use crate::vm::compile_mode::{ + PY_CF_ALLOW_INCOMPLETE_INPUT, PY_CF_ALLOW_TOP_LEVEL_AWAIT, PY_CF_DONT_IMPLY_DEDENT, + PY_CF_IGNORE_COOKIE, PY_CF_ONLY_AST, PY_CF_OPTIMIZED_AST, PY_CF_SOURCE_IS_UTF8, + PY_CF_TYPE_COMMENTS, +}; diff --git a/crates/vm/src/stdlib/_ast/argument.rs b/crates/vm/src/stdlib/_ast/argument.rs index 626024f5bd6..75bd34c86bf 100644 --- a/crates/vm/src/stdlib/_ast/argument.rs +++ b/crates/vm/src/stdlib/_ast/argument.rs @@ -2,14 +2,48 @@ use super::*; use rustpython_compiler_core::SourceFile; pub(super) struct PositionalArguments { + pub node_index: ast::AtomicNodeIndex, + pub field: super::constant::PublicAstExprListField, pub range: TextRange, pub args: Box<[ast::Expr]>, } +impl PositionalArguments { + pub(super) fn ast_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + let args: Vec> = + get_node_list_field(vm, source_file, object, field, typ)?; + let public_field = match field { + "bases" => super::constant::PublicAstExprListField::Bases, + _ => super::constant::PublicAstExprListField::Args, + }; + let (node_index, args) = public_expr_boxed_slice_from_values(public_field, args); + Ok(Self { + node_index, + field: public_field, + args, + range: TextRange::default(), + }) + } +} + impl Node for PositionalArguments { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - let Self { args, range: _ } = self; - BoxedSlice(args).ast_to_object(vm, source_file) + let Self { + node_index, + field, + args, + range: _, + } = self; + super::constant::public_ast_expr_list_object(node_index.load(), field).map_or_else( + || BoxedSlice(args).ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ) } fn ast_from_object( @@ -19,6 +53,8 @@ impl Node for PositionalArguments { ) -> PyResult { let args: BoxedSlice<_> = Node::ast_from_object(vm, source_file, object)?; Ok(Self { + node_index: Default::default(), + field: super::constant::PublicAstExprListField::Args, args: args.0, range: TextRange::default(), // TODO }) @@ -30,6 +66,21 @@ pub(super) struct KeywordArguments { pub keywords: Box<[ast::Keyword]>, } +impl KeywordArguments { + pub(super) fn ast_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + Ok(Self { + keywords: get_node_boxed_slice_field(vm, source_file, object, field, typ)?, + range: TextRange::default(), + }) + } +} + impl Node for KeywordArguments { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { keywords, range: _ } = self; @@ -57,10 +108,10 @@ pub(super) fn merge_function_call_arguments( let range = pos_args.range.cover(key_args.range); ast::Arguments { - node_index: Default::default(), + node_index: pos_args.node_index, range, args: pos_args.args, - keywords: key_args.keywords, + keywords: key_args.keywords.into(), } } @@ -68,7 +119,7 @@ pub(super) fn split_function_call_arguments( args: ast::Arguments, ) -> (PositionalArguments, KeywordArguments) { let ast::Arguments { - node_index: _, + node_index, range: _, args, keywords, @@ -81,6 +132,8 @@ pub(super) fn split_function_call_arguments( .unwrap_or_default(); // debug_assert!(range.contains_range(positional_arguments_range)); let positional_arguments = PositionalArguments { + node_index, + field: super::constant::PublicAstExprListField::Args, range: positional_arguments_range, args, }; @@ -93,7 +146,7 @@ pub(super) fn split_function_call_arguments( // debug_assert!(range.contains_range(keyword_arguments_range)); let keyword_arguments = KeywordArguments { range: keyword_arguments_range, - keywords, + keywords: keywords.into(), }; (positional_arguments, keyword_arguments) @@ -107,7 +160,7 @@ pub(super) fn split_class_def_args( Some(args) => *args, }; let ast::Arguments { - node_index: _, + node_index, range: _, args, keywords, @@ -120,6 +173,8 @@ pub(super) fn split_class_def_args( .unwrap_or_default(); // debug_assert!(range.contains_range(positional_arguments_range)); let positional_arguments = PositionalArguments { + node_index, + field: super::constant::PublicAstExprListField::Bases, range: positional_arguments_range, args, }; @@ -132,7 +187,7 @@ pub(super) fn split_class_def_args( // debug_assert!(range.contains_range(keyword_arguments_range)); let keyword_arguments = KeywordArguments { range: keyword_arguments_range, - keywords, + keywords: keywords.into(), }; (Some(positional_arguments), Some(keyword_arguments)) @@ -146,10 +201,10 @@ pub(super) fn merge_class_def_args( return None; } - let args = if let Some(positional_arguments) = positional_arguments { - positional_arguments.args + let (node_index, args) = if let Some(positional_arguments) = positional_arguments { + (positional_arguments.node_index, positional_arguments.args) } else { - vec![].into_boxed_slice() + (Default::default(), vec![].into_boxed_slice()) }; let keywords = if let Some(keyword_arguments) = keyword_arguments { keyword_arguments.keywords @@ -158,9 +213,9 @@ pub(super) fn merge_class_def_args( }; Some(Box::new(ast::Arguments { - node_index: Default::default(), + node_index, range: Default::default(), // TODO args, - keywords, + keywords: keywords.into(), })) } diff --git a/crates/vm/src/stdlib/_ast/basic.rs b/crates/vm/src/stdlib/_ast/basic.rs index 28e4a6803ee..b60f8ab75a8 100644 --- a/crates/vm/src/stdlib/_ast/basic.rs +++ b/crates/vm/src/stdlib/_ast/basic.rs @@ -1,4 +1,5 @@ use super::*; +use crate::builtins::PyIntRef; use rustpython_codegen::compile::ruff_int_to_bigint; use rustpython_compiler_core::SourceFile; @@ -13,7 +14,11 @@ impl Node for ast::Identifier { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let py_str = PyUtf8StrRef::try_from_object(vm, object)?; + if !object.class().is(vm.ctx.types.str_type) { + return Err(vm.new_type_error("AST identifier must be of type str")); + } + let py_str = PyUtf8StrRef::try_from_object(vm, object) + .map_err(|_| vm.new_type_error("AST identifier must be of type str"))?; Ok(Self::new(py_str.as_str(), TextRange::default())) } } @@ -45,6 +50,6 @@ impl Node for bool { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - i32::try_from_object(vm, object).map(|i| i != 0) + node_object_to_i32(vm, object).map(|i| i != 0) } } diff --git a/crates/vm/src/stdlib/_ast/constant.rs b/crates/vm/src/stdlib/_ast/constant.rs index b1a8a015689..598997fe163 100644 --- a/crates/vm/src/stdlib/_ast/constant.rs +++ b/crates/vm/src/stdlib/_ast/constant.rs @@ -1,12 +1,136 @@ use super::*; use crate::builtins::{PyComplex, PyFrozenSet, PyTuple}; use ast::str_prefix::StringLiteralPrefix; -use rustpython_compiler_core::SourceFile; +use core::cell::RefCell; +use rustpython_codegen::{ + PublicAstExprList, PublicAstFormattedValue, PublicAstInterpolation, PublicAstNodeMap, + compile::ruff_int_to_bigint, +}; +use rustpython_compiler_core::{SourceFile, bytecode::ConstantData}; + +#[derive(Clone)] +pub(super) struct PublicAstPatternList { + pub(super) values: Vec>, +} + +#[derive(Clone)] +pub(super) struct PublicAstExprOptionList { + pub(super) values: Vec>, +} + +#[derive(Clone)] +pub(super) struct PublicAstStmtList { + pub(super) values: Vec>, +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub(super) enum PublicAstStmtListField { + Body, + Orelse, + FinalBody, +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub(super) enum PublicAstExprListField { + Args, + Bases, + DecoratorList, + Targets, + Values, + Elts, + Comparators, + Ifs, +} + +#[derive(Clone)] +pub(super) struct PublicAstExceptHandlerList { + pub(super) values: Vec>, +} + +#[derive(Clone)] +pub(super) struct PublicAstTypeParamList { + pub(super) values: Vec>, +} + +#[derive(Clone)] +pub(super) struct PublicAstMatchClass { + pub(super) patterns: Vec>, + pub(super) kwd_attrs: Vec, + pub(super) kwd_patterns: Vec>, +} + +#[derive(Clone, Default)] +pub(super) struct PublicAstExprListFields { + args: Option, + bases: Option, + decorator_list: Option, + targets: Option, + values: Option, + elts: Option, + comparators: Option, + ifs: Option, +} + +impl PublicAstExprListFields { + fn insert(&mut self, field: PublicAstExprListField, values: PublicAstExprOptionList) { + let slot = match field { + PublicAstExprListField::Args => &mut self.args, + PublicAstExprListField::Bases => &mut self.bases, + PublicAstExprListField::DecoratorList => &mut self.decorator_list, + PublicAstExprListField::Targets => &mut self.targets, + PublicAstExprListField::Values => &mut self.values, + PublicAstExprListField::Elts => &mut self.elts, + PublicAstExprListField::Comparators => &mut self.comparators, + PublicAstExprListField::Ifs => &mut self.ifs, + }; + *slot = Some(values); + } + + pub(super) fn get(&self, field: PublicAstExprListField) -> Option<&PublicAstExprOptionList> { + match field { + PublicAstExprListField::Args => self.args.as_ref(), + PublicAstExprListField::Bases => self.bases.as_ref(), + PublicAstExprListField::DecoratorList => self.decorator_list.as_ref(), + PublicAstExprListField::Targets => self.targets.as_ref(), + PublicAstExprListField::Values => self.values.as_ref(), + PublicAstExprListField::Elts => self.elts.as_ref(), + PublicAstExprListField::Comparators => self.comparators.as_ref(), + PublicAstExprListField::Ifs => self.ifs.as_ref(), + } + } +} + +#[derive(Clone, Default)] +pub(super) struct PublicAstStmtListFields { + body: Option, + orelse: Option, + finalbody: Option, +} + +impl PublicAstStmtListFields { + fn insert(&mut self, field: PublicAstStmtListField, values: PublicAstStmtList) { + let slot = match field { + PublicAstStmtListField::Body => &mut self.body, + PublicAstStmtListField::Orelse => &mut self.orelse, + PublicAstStmtListField::FinalBody => &mut self.finalbody, + }; + *slot = Some(values); + } + + pub(super) fn get(&self, field: PublicAstStmtListField) -> Option<&PublicAstStmtList> { + match field { + PublicAstStmtListField::Body => self.body.as_ref(), + PublicAstStmtListField::Orelse => self.orelse.as_ref(), + PublicAstStmtListField::FinalBody => self.finalbody.as_ref(), + } + } +} #[derive(Debug)] pub(super) struct Constant { pub(super) range: TextRange, pub(super) value: ConstantLiteral, + invalid_type: Option, } impl Constant { @@ -19,6 +143,7 @@ impl Constant { Self { range, value: ConstantLiteral::Str { value, prefix }, + invalid_type: None, } } @@ -26,6 +151,7 @@ impl Constant { Self { range, value: ConstantLiteral::Int(value), + invalid_type: None, } } @@ -33,6 +159,7 @@ impl Constant { Self { range, value: ConstantLiteral::Float(value), + invalid_type: None, } } @@ -40,6 +167,7 @@ impl Constant { Self { range, value: ConstantLiteral::Complex { real, imag }, + invalid_type: None, } } @@ -47,6 +175,7 @@ impl Constant { Self { range, value: ConstantLiteral::Bytes(value), + invalid_type: None, } } @@ -54,6 +183,7 @@ impl Constant { Self { range, value: ConstantLiteral::Bool(value), + invalid_type: None, } } @@ -61,6 +191,7 @@ impl Constant { Self { range, value: ConstantLiteral::None, + invalid_type: None, } } @@ -68,15 +199,27 @@ impl Constant { Self { range, value: ConstantLiteral::Ellipsis, + invalid_type: None, } } pub(crate) fn into_expr(self) -> ast::Expr { - constant_to_ruff_expr(self) + let invalid_type = self.invalid_type.clone(); + let constant = self + .invalid_type + .is_none() + .then(|| constant_literal_to_constant_data(&self.value)); + let expr = constant_to_ruff_expr(self); + if let Some(invalid_type) = invalid_type { + register_public_ast_invalid_constant(&expr, invalid_type); + } else if let Some(constant) = constant { + register_public_ast_constant(&expr, constant); + } + expr } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum ConstantLiteral { None, Bool(bool), @@ -96,10 +239,1005 @@ pub(crate) enum ConstantLiteral { Ellipsis, } +struct PublicAstConstantState { + next_index: u32, + // CPython AST has Constant_kind.value/kind fields; Ruff has separate + // literal expr variants. Dense node indexes make Vec lookup cheaper than + // hashing, and insertion order is never observed. + constants: PublicAstNodeMap, + // CPython Interpolation has raw str and expr? format_spec; Ruff t-string + // elements do not. Dense node lookup avoids hashing these synthetic nodes. + interpolations: PublicAstNodeMap, + // CPython FormattedValue has expr? format_spec; Ruff f-string specs are + // parsed as string elements. Dense node lookup is the direct hot path. + formatted_values: PublicAstNodeMap, + // CPython ImportFrom.level accepts a public signed int; Ruff only stores + // parser-valid unsigned levels. Dense lookup preserves only overrides. + import_from_levels: PublicAstNodeMap, + // CPython validates Constant.value after object conversion; Ruff has no + // invalid Constant node. Dense lookup stores only rejected public values. + invalid_constants: PublicAstNodeMap, + // CPython JoinedStr.values is expr*; Ruff stores f-string element trees. + // Dense lookup restores the public expr list without ordered-map overhead. + joined_strs: PublicAstNodeMap, + // CPython TemplateStr.values is expr*; Ruff stores t-string element trees. + // Dense lookup restores the public expr list without ordered-map overhead. + template_strs: PublicAstNodeMap, + // CPython comprehension has is_async; Ruff folds it into generator data. + // Dense lookup keeps the raw public flag on affected nodes only. + comprehension_is_async: PublicAstNodeMap, + // CPython permits nullable public pattern lists during validation; Ruff + // pattern lists are non-null. Dense lookup stores only nullable lists. + pattern_lists: PublicAstNodeMap, + // CPython has nullable expr?* slots such as defaults; Ruff omits null list + // entries. Dense lookup stores only public nullable-list nodes. + expr_option_lists: PublicAstNodeMap, + // CPython public expr* fields may contain None until validation; Ruff + // Vec cannot represent null entries. Per-node bundles avoid hashing. + expr_lists: PublicAstNodeMap, + // CPython public stmt* fields may contain None until validation; Ruff + // Vec cannot represent null entries. Per-node bundles avoid hashing. + stmt_lists: PublicAstNodeMap, + // CPython nullable excepthandler* lists cannot be represented in Ruff. + // Dense lookup stores only public nodes that need nullable validation. + except_handler_lists: PublicAstNodeMap, + // CPython nullable type_param* lists cannot be represented in Ruff. Dense + // lookup stores only public nodes that need nullable validation. + type_param_lists: PublicAstNodeMap, + // CPython MatchClass splits patterns/kwd_attrs/kwd_patterns; Ruff stores + // PatternArguments. Dense lookup restores the public split shape. + match_classes: PublicAstNodeMap, + // CPython AnnAssign.simple is a raw int; Ruff has no equivalent field. + // Dense lookup stores only public AnnAssign overrides. + ann_assign_simple: PublicAstNodeMap, + // CPython arg nodes have type_comment; Ruff parameters do not. Dense lookup + // stores only public arg comments. + arg_type_comments: PublicAstNodeMap, + // CPython selected stmt nodes have type_comment; Ruff omits them. Dense + // lookup stores only public stmt comments. + stmt_type_comments: PublicAstNodeMap, +} + +type PublicAstOverrideMap = PublicAstNodeMap; +type PublicAstInterpolationOverrideMap = PublicAstNodeMap; +type PublicAstFormattedValueOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstImportFromLevelOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstInvalidConstantOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstExprListOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstComprehensionIsAsyncOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstPatternListOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstExprOptionListOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstExprListFieldOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstStmtListOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstExceptHandlerListOverrideMap = + PublicAstNodeMap; +pub(super) type PublicAstTypeParamListOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstMatchClassOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstAnnAssignSimpleOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstArgTypeCommentOverrideMap = PublicAstNodeMap; +pub(super) type PublicAstStmtTypeCommentOverrideMap = PublicAstNodeMap; +type PublicAstOverrideCollection = ( + T, + PublicAstOverrideMap, + PublicAstInterpolationOverrideMap, + PublicAstFormattedValueOverrideMap, + PublicAstImportFromLevelOverrideMap, + PublicAstInvalidConstantOverrideMap, + PublicAstExprListOverrideMap, + PublicAstExprListOverrideMap, + PublicAstComprehensionIsAsyncOverrideMap, + PublicAstPatternListOverrideMap, + PublicAstExprOptionListOverrideMap, + PublicAstExprListFieldOverrideMap, + PublicAstStmtListOverrideMap, + PublicAstExceptHandlerListOverrideMap, + PublicAstTypeParamListOverrideMap, + PublicAstMatchClassOverrideMap, + PublicAstAnnAssignSimpleOverrideMap, + PublicAstArgTypeCommentOverrideMap, + PublicAstStmtTypeCommentOverrideMap, +); + +thread_local! { + static PUBLIC_AST_CONSTANTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_CONSTANT_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_INTERPOLATION_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_FORMATTED_VALUE_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_JOINED_STR_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_TEMPLATE_STR_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_COMPREHENSION_IS_ASYNC_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_PATTERN_LIST_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_EXPR_OPTION_LIST_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_EXPR_LIST_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_STMT_LIST_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_EXCEPT_HANDLER_LIST_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_TYPE_PARAM_LIST_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_MATCH_CLASS_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_ANN_ASSIGN_SIMPLE_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_ARG_TYPE_COMMENT_OBJECTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_STMT_TYPE_COMMENT_OBJECTS: RefCell> = const { RefCell::new(None) }; +} + +pub(super) fn collect_public_ast_overrides( + f: impl FnOnce() -> PyResult, +) -> PyResult> { + PUBLIC_AST_CONSTANTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(PublicAstConstantState { + next_index: 0, + constants: PublicAstNodeMap::new(), + interpolations: PublicAstNodeMap::new(), + formatted_values: PublicAstNodeMap::new(), + import_from_levels: PublicAstNodeMap::new(), + invalid_constants: PublicAstNodeMap::new(), + joined_strs: PublicAstNodeMap::new(), + template_strs: PublicAstNodeMap::new(), + comprehension_is_async: PublicAstNodeMap::new(), + pattern_lists: PublicAstNodeMap::new(), + expr_option_lists: PublicAstNodeMap::new(), + expr_lists: PublicAstNodeMap::new(), + stmt_lists: PublicAstNodeMap::new(), + except_handler_lists: PublicAstNodeMap::new(), + type_param_lists: PublicAstNodeMap::new(), + match_classes: PublicAstNodeMap::new(), + ann_assign_simple: PublicAstNodeMap::new(), + arg_type_comments: PublicAstNodeMap::new(), + stmt_type_comments: PublicAstNodeMap::new(), + }); + }); + + let result = f(); + let ( + constants, + interpolations, + formatted_values, + import_from_levels, + invalid_constants, + joined_strs, + template_strs, + comprehension_is_async, + pattern_lists, + expr_option_lists, + expr_lists, + stmt_lists, + except_handler_lists, + type_param_lists, + match_classes, + ann_assign_simple, + arg_type_comments, + stmt_type_comments, + ) = PUBLIC_AST_CONSTANTS.with(|cell| { + let state = cell + .borrow_mut() + .take() + .expect("public AST constant collection state missing"); + ( + state.constants, + state.interpolations, + state.formatted_values, + state.import_from_levels, + state.invalid_constants, + state.joined_strs, + state.template_strs, + state.comprehension_is_async, + state.pattern_lists, + state.expr_option_lists, + state.expr_lists, + state.stmt_lists, + state.except_handler_lists, + state.type_param_lists, + state.match_classes, + state.ann_assign_simple, + state.arg_type_comments, + state.stmt_type_comments, + ) + }); + result.map(|value| { + ( + value, + constants, + interpolations, + formatted_values, + import_from_levels, + invalid_constants, + joined_strs, + template_strs, + comprehension_is_async, + pattern_lists, + expr_option_lists, + expr_lists, + stmt_lists, + except_handler_lists, + type_param_lists, + match_classes, + ann_assign_simple, + arg_type_comments, + stmt_type_comments, + ) + }) +} + +fn register_public_ast_constant(expr: &ast::Expr, constant: ConstantData) { + let index = register_public_ast_override(|state, index| { + state.constants.insert(index, constant); + }); + ast::HasNodeIndex::node_index(expr).set(index); +} + +fn register_public_ast_invalid_constant(expr: &ast::Expr, invalid_type: String) { + let index = register_public_ast_override(|state, index| { + state.invalid_constants.insert(index, invalid_type); + }); + ast::HasNodeIndex::node_index(expr).set(index); +} + +pub(super) fn register_public_ast_interpolation( + str_constant: ConstantData, + format_spec: Option>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state.interpolations.insert( + index, + PublicAstInterpolation { + str: str_constant, + format_spec, + }, + ); + }) +} + +pub(super) fn register_public_ast_formatted_value( + format_spec: Option>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .formatted_values + .insert(index, PublicAstFormattedValue { format_spec }); + }) +} + +pub(super) fn register_public_ast_joined_str(values: Vec) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .joined_strs + .insert(index, PublicAstExprList { values }); + }) +} + +pub(super) fn register_public_ast_template_str(values: Vec) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .template_strs + .insert(index, PublicAstExprList { values }); + }) +} + +pub(super) fn register_public_ast_pattern_list( + values: Vec>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .pattern_lists + .insert(index, PublicAstPatternList { values }); + }) +} + +pub(super) fn register_public_ast_match_mapping( + keys: Vec>, + patterns: Vec>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .expr_option_lists + .insert(index, PublicAstExprOptionList { values: keys }); + state + .pattern_lists + .insert(index, PublicAstPatternList { values: patterns }); + }) +} + +pub(super) fn register_public_ast_expr_option_list( + values: Vec>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .expr_option_lists + .insert(index, PublicAstExprOptionList { values }); + }) +} + +pub(super) fn register_public_ast_stmt_list( + field: PublicAstStmtListField, + values: Vec>, +) -> ast::NodeIndex { + register_public_ast_stmt_lists([(field, values)]) +} + +pub(super) fn register_public_ast_stmt_lists( + values: impl IntoIterator>)>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + for (field, values) in values { + public_ast_stmt_fields_mut(&mut state.stmt_lists, index) + .insert(field, PublicAstStmtList { values }); + } + }) +} + +pub(super) fn register_public_ast_try_lists( + stmt_values: Vec<(PublicAstStmtListField, Vec>)>, + except_handler_values: Option>>, +) -> ast::NodeIndex { + register_public_ast_node_list_overrides(stmt_values, Vec::new(), except_handler_values, None) +} + +pub(super) fn register_public_ast_node_list_overrides( + stmt_values: Vec<(PublicAstStmtListField, Vec>)>, + expr_values: Vec<(PublicAstExprListField, Vec>)>, + except_handler_values: Option>>, + comprehension_is_async: Option, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + for (field, values) in stmt_values { + public_ast_stmt_fields_mut(&mut state.stmt_lists, index) + .insert(field, PublicAstStmtList { values }); + } + for (field, values) in expr_values { + public_ast_expr_fields_mut(&mut state.expr_lists, index) + .insert(field, PublicAstExprOptionList { values }); + } + if let Some(values) = except_handler_values { + state + .except_handler_lists + .insert(index, PublicAstExceptHandlerList { values }); + } + if let Some(value) = comprehension_is_async { + state.comprehension_is_async.insert(index, value); + } + }) +} + +fn public_ast_expr_fields_mut( + values: &mut PublicAstExprListFieldOverrideMap, + index: ast::NodeIndex, +) -> &mut PublicAstExprListFields { + if !values.contains_key(&index) { + values.insert(index, PublicAstExprListFields::default()); + } + values.get_mut(&index).unwrap() +} + +fn public_ast_stmt_fields_mut( + values: &mut PublicAstStmtListOverrideMap, + index: ast::NodeIndex, +) -> &mut PublicAstStmtListFields { + if !values.contains_key(&index) { + values.insert(index, PublicAstStmtListFields::default()); + } + values.get_mut(&index).unwrap() +} + +pub(super) fn register_public_ast_type_param_list( + values: Vec>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state + .type_param_lists + .insert(index, PublicAstTypeParamList { values }); + }) +} + +pub(super) fn register_public_ast_match_class( + patterns: Vec>, + kwd_attrs: Vec, + kwd_patterns: Vec>, +) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state.match_classes.insert( + index, + PublicAstMatchClass { + patterns, + kwd_attrs, + kwd_patterns, + }, + ); + }) +} + +pub(super) fn register_public_ast_import_from_level(level: i32) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state.import_from_levels.insert(index, level); + }) +} + +pub(super) fn register_public_ast_ann_assign_simple(simple: i32) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state.ann_assign_simple.insert(index, simple); + }) +} + +pub(super) fn register_public_ast_arg_type_comment(type_comment: PyObjectRef) -> ast::NodeIndex { + register_public_ast_override(|state, index| { + state.arg_type_comments.insert(index, type_comment); + }) +} + +pub(super) fn register_public_ast_stmt_type_comment( + node_index: &ast::AtomicNodeIndex, + type_comment: PyObjectRef, +) { + register_public_ast_node_override(node_index, |state, index| { + state.stmt_type_comments.insert(index, type_comment); + }); +} + +fn register_public_ast_override( + insert: impl FnOnce(&mut PublicAstConstantState, ast::NodeIndex), +) -> ast::NodeIndex { + PUBLIC_AST_CONSTANTS.with(|cell| { + let mut state = cell.borrow_mut(); + let Some(state) = state.as_mut() else { + return ast::NodeIndex::NONE; + }; + let index = ast::NodeIndex::from(state.next_index); + state.next_index = state + .next_index + .checked_add(1) + .expect("too many public AST constants"); + insert(state, index); + index + }) +} + +fn register_public_ast_node_override( + node_index: &ast::AtomicNodeIndex, + insert: impl FnOnce(&mut PublicAstConstantState, ast::NodeIndex), +) { + PUBLIC_AST_CONSTANTS.with(|cell| { + let mut state = cell.borrow_mut(); + let Some(state) = state.as_mut() else { + return; + }; + let mut index = node_index.load(); + if index == ast::NodeIndex::NONE { + index = ast::NodeIndex::from(state.next_index); + state.next_index = state + .next_index + .checked_add(1) + .expect("too many public AST constants"); + node_index.set(index); + } + insert(state, index); + }); +} + +#[expect( + clippy::too_many_arguments, + reason = "public AST conversion installs independent override tables" +)] +pub(super) fn with_public_ast_interpolation_objects( + constants: &PublicAstOverrideMap, + interpolations: &PublicAstInterpolationOverrideMap, + formatted_values: &PublicAstFormattedValueOverrideMap, + joined_strs: &PublicAstExprListOverrideMap, + template_strs: &PublicAstExprListOverrideMap, + comprehension_is_async: &PublicAstComprehensionIsAsyncOverrideMap, + pattern_lists: &PublicAstPatternListOverrideMap, + expr_option_lists: &PublicAstExprOptionListOverrideMap, + expr_lists: &PublicAstExprListFieldOverrideMap, + stmt_lists: &PublicAstStmtListOverrideMap, + except_handler_lists: &PublicAstExceptHandlerListOverrideMap, + type_param_lists: &PublicAstTypeParamListOverrideMap, + match_classes: &PublicAstMatchClassOverrideMap, + ann_assign_simple: &PublicAstAnnAssignSimpleOverrideMap, + arg_type_comments: &PublicAstArgTypeCommentOverrideMap, + stmt_type_comments: &PublicAstStmtTypeCommentOverrideMap, + f: impl FnOnce() -> T, +) -> T { + PUBLIC_AST_CONSTANT_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(constants.clone()); + }); + PUBLIC_AST_INTERPOLATION_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(interpolations.clone()); + }); + PUBLIC_AST_FORMATTED_VALUE_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(formatted_values.clone()); + }); + PUBLIC_AST_JOINED_STR_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(joined_strs.clone()); + }); + PUBLIC_AST_TEMPLATE_STR_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(template_strs.clone()); + }); + PUBLIC_AST_COMPREHENSION_IS_ASYNC_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(comprehension_is_async.clone()); + }); + PUBLIC_AST_PATTERN_LIST_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(pattern_lists.clone()); + }); + PUBLIC_AST_EXPR_OPTION_LIST_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(expr_option_lists.clone()); + }); + PUBLIC_AST_EXPR_LIST_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(expr_lists.clone()); + }); + PUBLIC_AST_STMT_LIST_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(stmt_lists.clone()); + }); + PUBLIC_AST_EXCEPT_HANDLER_LIST_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(except_handler_lists.clone()); + }); + PUBLIC_AST_TYPE_PARAM_LIST_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(type_param_lists.clone()); + }); + PUBLIC_AST_MATCH_CLASS_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(match_classes.clone()); + }); + PUBLIC_AST_ANN_ASSIGN_SIMPLE_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(ann_assign_simple.clone()); + }); + PUBLIC_AST_ARG_TYPE_COMMENT_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(arg_type_comments.clone()); + }); + PUBLIC_AST_STMT_TYPE_COMMENT_OBJECTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = Some(stmt_type_comments.clone()); + }); + let result = f(); + PUBLIC_AST_CONSTANT_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_INTERPOLATION_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_FORMATTED_VALUE_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_JOINED_STR_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_TEMPLATE_STR_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_COMPREHENSION_IS_ASYNC_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_PATTERN_LIST_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_EXPR_OPTION_LIST_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_EXPR_LIST_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_STMT_LIST_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_EXCEPT_HANDLER_LIST_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_TYPE_PARAM_LIST_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_MATCH_CLASS_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_ANN_ASSIGN_SIMPLE_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_ARG_TYPE_COMMENT_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_STMT_TYPE_COMMENT_OBJECTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + result +} + +pub(super) fn public_ast_constant_object( + vm: &VirtualMachine, + source_file: &SourceFile, + node_index: ast::NodeIndex, + range: TextRange, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + let constant = PUBLIC_AST_CONSTANT_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|constants| constants.get(&node_index).cloned()) + })?; + let node = NodeAst + .into_ref_with_type(vm, pyast::NodeExprConstant::static_type().to_owned()) + .unwrap(); + let dict = node.as_object().dict().unwrap(); + dict.set_item("value", constant_data_to_object(vm, constant), vm) + .unwrap(); + dict.set_item("kind", vm.ctx.none(), vm).unwrap(); + node_add_location(&dict, range, vm, source_file); + Some(node.into()) +} + +pub(super) fn public_ast_interpolation_object( + vm: &VirtualMachine, + node_index: ast::NodeIndex, +) -> Option<(PyObjectRef, Option>)> { + if node_index == ast::NodeIndex::NONE { + return None; + } + let interpolation = PUBLIC_AST_INTERPOLATION_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|interpolations| interpolations.get(&node_index).cloned()) + })?; + Some(( + constant_data_to_object(vm, interpolation.str), + interpolation.format_spec, + )) +} + +pub(super) fn public_ast_formatted_value_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_FORMATTED_VALUE_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|formatted_values| formatted_values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_joined_str_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_JOINED_STR_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|joined_strs| joined_strs.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_template_str_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_TEMPLATE_STR_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|template_strs| template_strs.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_comprehension_is_async_object(node_index: ast::NodeIndex) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_COMPREHENSION_IS_ASYNC_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).copied()) + }) +} + +pub(super) fn public_ast_pattern_list_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_PATTERN_LIST_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_expr_option_list_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_EXPR_OPTION_LIST_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_expr_list_object( + node_index: ast::NodeIndex, + field: PublicAstExprListField, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_EXPR_LIST_OBJECTS.with(|cell| { + cell.borrow().as_ref().and_then(|values| { + values + .get(&node_index) + .and_then(|values| values.get(field)) + .cloned() + }) + }) +} + +pub(super) fn public_ast_stmt_list_object( + node_index: ast::NodeIndex, + field: PublicAstStmtListField, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_STMT_LIST_OBJECTS.with(|cell| { + cell.borrow().as_ref().and_then(|values| { + values + .get(&node_index) + .and_then(|values| values.get(field)) + .cloned() + }) + }) +} + +pub(super) fn public_ast_except_handler_list_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_EXCEPT_HANDLER_LIST_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_type_param_list_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_TYPE_PARAM_LIST_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_match_class_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_MATCH_CLASS_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_ann_assign_simple_object(node_index: ast::NodeIndex) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_ANN_ASSIGN_SIMPLE_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).copied()) + }) +} + +pub(super) fn public_ast_arg_type_comment_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_ARG_TYPE_COMMENT_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +pub(super) fn public_ast_stmt_type_comment_object( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_STMT_TYPE_COMMENT_OBJECTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +fn constant_literal_to_constant_data(value: &ConstantLiteral) -> ConstantData { + match value { + ConstantLiteral::None => ConstantData::None, + ConstantLiteral::Bool(value) => ConstantData::Boolean { value: *value }, + ConstantLiteral::Str { value, .. } => ConstantData::Str { + value: value.as_ref().into(), + }, + ConstantLiteral::Bytes(value) => ConstantData::Bytes { + value: value.to_vec(), + }, + ConstantLiteral::Int(value) => ConstantData::Integer { + value: ruff_int_to_bigint(value).unwrap(), + }, + ConstantLiteral::Tuple(value) => ConstantData::Tuple { + elements: value + .iter() + .map(constant_literal_to_constant_data) + .collect(), + }, + ConstantLiteral::FrozenSet(value) => ConstantData::Frozenset { + elements: value + .iter() + .map(constant_literal_to_constant_data) + .collect(), + }, + ConstantLiteral::Float(value) => ConstantData::Float { value: *value }, + ConstantLiteral::Complex { real, imag } => ConstantData::Complex { + value: num_complex::Complex::new(*real, *imag), + }, + ConstantLiteral::Ellipsis => ConstantData::Ellipsis, + } +} + +pub(super) fn constant_object_to_constant_data( + vm: &VirtualMachine, + source_file: &SourceFile, + value_object: PyObjectRef, +) -> PyResult { + let value = ConstantLiteral::ast_from_object(vm, source_file, value_object)?; + Ok(constant_literal_to_constant_data(&value)) +} + +fn first_invalid_constant_type(vm: &VirtualMachine, value_object: PyObjectRef) -> PyResult { + let cls = value_object.class(); + let class_name = cls.name().to_owned(); + if cls.is(vm.ctx.types.tuple_type) { + vm.with_recursion(" during compilation", || { + let tuple = value_object.clone().downcast::().map_err(|obj| { + vm.new_type_error(format!( + "Expected type {}, not {}", + PyTuple::static_type().name(), + obj.class().name() + )) + })?; + for item in tuple.iter() { + if let Some(invalid_type) = first_invalid_constant_type_opt(vm, item.clone())? { + return Ok(invalid_type); + } + } + Ok(class_name) + }) + } else if cls.is(vm.ctx.types.frozenset_type) { + vm.with_recursion(" during compilation", || { + let set = value_object.clone().downcast::().unwrap(); + for item in set.elements() { + if let Some(invalid_type) = first_invalid_constant_type_opt(vm, item)? { + return Ok(invalid_type); + } + } + Ok(class_name) + }) + } else { + Ok(class_name) + } +} + +fn first_invalid_constant_type_opt( + vm: &VirtualMachine, + value_object: PyObjectRef, +) -> PyResult> { + let cls = value_object.class(); + if cls.is(vm.ctx.types.none_type) + || cls.is(vm.ctx.types.bool_type) + || cls.is(vm.ctx.types.str_type) + || cls.is(vm.ctx.types.bytes_type) + || cls.is(vm.ctx.types.int_type) + || cls.is(vm.ctx.types.float_type) + || cls.is(vm.ctx.types.complex_type) + || cls.is(vm.ctx.types.ellipsis_type) + { + return Ok(None); + } + if cls.is(vm.ctx.types.tuple_type) || cls.is(vm.ctx.types.frozenset_type) { + return first_invalid_constant_type(vm, value_object).map(Some); + } + Ok(Some(cls.name().to_owned())) +} + +fn constant_data_to_object(vm: &VirtualMachine, constant: ConstantData) -> PyObjectRef { + match constant { + ConstantData::None => vm.ctx.none(), + ConstantData::Boolean { value } => vm.ctx.new_bool(value).to_pyobject(vm), + ConstantData::Str { value } => vm.ctx.new_str(value.to_string()).to_pyobject(vm), + ConstantData::Bytes { value } => vm.ctx.new_bytes(value).to_pyobject(vm), + ConstantData::Integer { value } => vm.ctx.new_int(value).into(), + ConstantData::Tuple { elements } => { + let value = elements + .into_iter() + .map(|c| constant_data_to_object(vm, c)) + .collect(); + vm.ctx.new_tuple(value).to_pyobject(vm) + } + ConstantData::Frozenset { elements } => PyFrozenSet::from_iter( + vm, + elements.into_iter().map(|c| constant_data_to_object(vm, c)), + ) + .unwrap() + .into_pyobject(vm), + ConstantData::Float { value } => vm.ctx.new_float(value).into_pyobject(vm), + ConstantData::Complex { value } => vm.ctx.new_complex(value).into_pyobject(vm), + ConstantData::Ellipsis => vm.ctx.ellipsis.clone().into(), + ConstantData::Code { .. } | ConstantData::Slice { .. } => { + unreachable!("public AST constants cannot contain code objects or slices") + } + } +} + // constructor +pub(super) fn constant_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let value_object = get_node_field(vm, &object, "value", "Constant")?; + let (value, invalid_type) = + match ConstantLiteral::ast_from_object(vm, source_file, value_object.clone()) { + Ok(value) => (value, None), + Err(_) => ( + ConstantLiteral::None, + Some(first_invalid_constant_type(vm, value_object)?), + ), + }; + let _kind = get_ast_string_field_opt(vm, &object, "kind")?; + + Ok(Constant { + range, + value, + invalid_type, + }) +} + impl Node for Constant { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - let Self { range, value } = self; + let Self { + range, + value, + invalid_type: _, + } = self; let node = NodeAst .into_ref_with_type(vm, pyast::NodeExprConstant::static_type().to_owned()) .unwrap(); @@ -123,16 +1261,8 @@ impl Node for Constant { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let value_object = get_node_field(vm, &object, "value", "Constant")?; - let value = Node::ast_from_object(vm, source_file, value_object)?; - - Ok(Self { - value, - // kind: get_node_field_opt(_vm, &_object, "kind")? - // .map(|obj| Node::ast_from_object(_vm, obj)) - // .transpose()?, - range: range_from_object(vm, source_file, object, "Constant")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Constant")?; + constant_from_object_with_range(vm, source_file, object, range) } } @@ -201,8 +1331,12 @@ impl Node for ConstantLiteral { })?; let tuple = tuple .into_iter() - .cloned() - .map(|object| Node::ast_from_object(vm, source_file, object)) + .map(|object| { + let object = object.clone(); + vm.with_recursion("during compilation", || { + Node::ast_from_object(vm, source_file, object) + }) + }) .collect::>()?; Self::Tuple(tuple) } else if cls.is(vm.ctx.types.frozenset_type) { @@ -210,7 +1344,11 @@ impl Node for ConstantLiteral { let elements = set .elements() .into_iter() - .map(|object| Node::ast_from_object(vm, source_file, object)) + .map(|object| { + vm.with_recursion("during compilation", || { + Node::ast_from_object(vm, source_file, object) + }) + }) .collect::>()?; Self::FrozenSet(elements) } else if cls.is(vm.ctx.types.float_type) { @@ -245,7 +1383,11 @@ impl Node for ConstantLiteral { } fn constant_to_ruff_expr(value: Constant) -> ast::Expr { - let Constant { value, range } = value; + let Constant { + value, + range, + invalid_type: _, + } = value; match value { ConstantLiteral::None => ast::Expr::NoneLiteral(ast::ExprNoneLiteral { node_index: Default::default(), @@ -294,6 +1436,7 @@ fn constant_to_ruff_expr(value: Constant) -> ast::Expr { constant_to_ruff_expr(Constant { range: TextRange::default(), value, + invalid_type: None, }) }) .collect(), @@ -314,6 +1457,7 @@ fn constant_to_ruff_expr(value: Constant) -> ast::Expr { constant_to_ruff_expr(Constant { range: TextRange::default(), value, + invalid_type: None, }) }) .collect(), @@ -332,7 +1476,7 @@ fn constant_to_ruff_expr(value: Constant) -> ast::Expr { node_index: Default::default(), range, args: args.into(), - keywords: Box::default(), + keywords: Default::default(), }, }) } diff --git a/crates/vm/src/stdlib/_ast/elif_else_clause.rs b/crates/vm/src/stdlib/_ast/elif_else_clause.rs index 0afdbc02ac1..4e926195d29 100644 --- a/crates/vm/src/stdlib/_ast/elif_else_clause.rs +++ b/crates/vm/src/stdlib/_ast/elif_else_clause.rs @@ -8,7 +8,7 @@ pub(super) fn ast_to_object( source_file: &SourceFile, ) -> PyObjectRef { let ast::ElifElseClause { - node_index: _, + node_index, range, test, body, @@ -24,10 +24,22 @@ pub(super) fn ast_to_object( dict.set_item("test", test.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("body", body.ast_to_object(vm, source_file), vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("body", body, vm).unwrap(); - let orelse = if let Some(next) = rest.next() { + let orelse = if let Some(values) = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + ) { + values.values.ast_to_object(vm, source_file) + } else if let Some(next) = rest.next() { if next.test.is_some() { let next = ast::ElifElseClause { range: TextRange::new(next.range.start(), range.end()), @@ -48,25 +60,28 @@ pub(super) fn ast_to_object( node.into() } -pub(super) fn ast_from_object( +pub(super) fn ast_from_object_with_range( vm: &VirtualMachine, source_file: &SourceFile, object: PyObjectRef, + range: TextRange, ) -> PyResult { - let test = Node::ast_from_object(vm, source_file, get_node_field(vm, &object, "test", "If")?)?; - let body = Node::ast_from_object(vm, source_file, get_node_field(vm, &object, "body", "If")?)?; - let orelse: Vec = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "orelse", "If")?, - )?; - let range = range_from_object(vm, source_file, object, "If")?; + let test = get_required_node_field(vm, source_file, &object, "test", "If")?; + let body: Vec> = get_node_list_field(vm, source_file, &object, "body", "If")?; + let orelse: Vec> = + get_node_list_field(vm, source_file, &object, "orelse", "If")?; + let node_index = public_stmt_lists_node_index([ + (super::constant::PublicAstStmtListField::Body, &body), + (super::constant::PublicAstStmtListField::Orelse, &orelse), + ]); + let body = lower_public_stmt_list(body); + let orelse = lower_public_stmt_list(orelse); let elif_else_clauses = if orelse.is_empty() { vec![] } else if let [ast::Stmt::If(_)] = &*orelse { let Some(ast::Stmt::If(ast::StmtIf { - node_index: _, + node_index, range, test, body, @@ -78,7 +93,7 @@ pub(super) fn ast_from_object( elif_else_clauses.insert( 0, ast::ElifElseClause { - node_index: Default::default(), + node_index, range, test: Some(*test), body, @@ -95,7 +110,7 @@ pub(super) fn ast_from_object( }; Ok(ast::StmtIf { - node_index: Default::default(), + node_index, test, body, elif_else_clauses, diff --git a/crates/vm/src/stdlib/_ast/exception.rs b/crates/vm/src/stdlib/_ast/exception.rs index 2daabecc84c..e144c94f7e8 100644 --- a/crates/vm/src/stdlib/_ast/exception.rs +++ b/crates/vm/src/stdlib/_ast/exception.rs @@ -13,29 +13,87 @@ impl Node for ast::ExceptHandler { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok( - if cls.is(pyast::NodeExceptHandlerExceptHandler::static_type()) { - Self::ExceptHandler(ast::ExceptHandlerExceptHandler::ast_from_object( - vm, - source_file, - object, - )?) - } else { - return Err(vm.new_type_error(format!( - "expected some sort of excepthandler, but got {}", - object.repr(vm)? - ))); - }, - ) + if vm.is_none(&object) { + return Err(vm.new_type_error(format!( + "expected some sort of excepthandler, but got {}", + object.repr(vm)? + ))); + } + if !is_node_instance( + vm, + &object, + pyast::NodeExceptHandlerExceptHandler::static_type(), + )? { + return Err(vm.new_type_error(format!( + "expected some sort of excepthandler, but got {}", + object.repr(vm)? + ))); + } + let range = excepthandler_range_from_object(vm, source_file, object.clone())?; + Ok(Self::ExceptHandler(except_handler_from_object_with_range( + vm, + source_file, + object, + range, + )?)) } } // constructor +fn except_handler_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let body: Vec> = + get_node_list_field(vm, source_file, &object, "body", "ExceptHandler")?; + let (node_index, body) = + public_stmt_list_from_values(super::constant::PublicAstStmtListField::Body, body); + Ok(ast::ExceptHandlerExceptHandler { + node_index, + type_: get_node_field_opt(vm, &object, "type")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + name: get_node_field_opt(vm, &object, "name")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + body, + range, + }) +} + +pub(super) fn except_handler_from_object_unvalidated_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, +) -> PyResult { + if vm.is_none(&object) { + return Err(vm.new_type_error(format!( + "expected some sort of excepthandler, but got {}", + object.repr(vm)? + ))); + } + if !is_node_instance( + vm, + &object, + pyast::NodeExceptHandlerExceptHandler::static_type(), + )? { + return Err(vm.new_type_error(format!( + "expected some sort of excepthandler, but got {}", + object.repr(vm)? + ))); + } + let range = excepthandler_range_from_object_unvalidated(vm, source_file, object.clone())?; + Ok(ast::ExceptHandler::ExceptHandler( + except_handler_from_object_with_range(vm, source_file, object, range)?, + )) +} + impl Node for ast::ExceptHandlerExceptHandler { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, type_, name, body, @@ -52,8 +110,15 @@ impl Node for ast::ExceptHandlerExceptHandler { .unwrap(); dict.set_item("name", name.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("body", body.ast_to_object(vm, source_file), vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("body", body, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -63,20 +128,7 @@ impl Node for ast::ExceptHandlerExceptHandler { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - type_: get_node_field_opt(vm, &object, "type")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - name: get_node_field_opt(vm, &object, "name")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "ExceptHandler")?, - )?, - range: range_from_object(vm, source_file, object, "ExceptHandler")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "ExceptHandler")?; + except_handler_from_object_with_range(vm, source_file, object, range) } } diff --git a/crates/vm/src/stdlib/_ast/expression.rs b/crates/vm/src/stdlib/_ast/expression.rs index 2b32a33f34d..686cfd62018 100644 --- a/crates/vm/src/stdlib/_ast/expression.rs +++ b/crates/vm/src/stdlib/_ast/expression.rs @@ -1,14 +1,20 @@ use super::*; -use crate::stdlib::_ast::{ - argument::{merge_function_call_arguments, split_function_call_arguments}, - constant::Constant, - string::JoinedStr, +use crate::stdlib::_ast::argument::{ + KeywordArguments, PositionalArguments, merge_function_call_arguments, + split_function_call_arguments, }; use rustpython_compiler_core::SourceFile; // sum impl Node for ast::Expr { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { + let node_index = ast::HasNodeIndex::node_index(&self).load(); + let range = self.range(); + if let Some(object) = + constant::public_ast_constant_object(vm, source_file, node_index, range) + { + return object; + } match self { Self::BoolOp(cons) => cons.ast_to_object(vm, source_file), Self::Name(cons) => cons.ast_to_object(vm, source_file), @@ -47,7 +53,7 @@ impl Node for ast::Expr { } Self::Named(cons) => cons.ast_to_object(vm, source_file), Self::IpyEscapeCommand(_) => { - unimplemented!("IPython escape command is not allowed in Python AST") + unreachable!("IPython escape command is not part of Python AST") } } } @@ -57,98 +63,311 @@ impl Node for ast::Expr { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeExprBoolOp::static_type()) { - Self::BoolOp(ast::ExprBoolOp::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprNamedExpr::static_type()) { - Self::Named(ast::ExprNamed::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprBinOp::static_type()) { - Self::BinOp(ast::ExprBinOp::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprUnaryOp::static_type()) { - Self::UnaryOp(ast::ExprUnaryOp::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprLambda::static_type()) { - Self::Lambda(ast::ExprLambda::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprIfExp::static_type()) { - Self::If(ast::ExprIf::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprDict::static_type()) { - Self::Dict(ast::ExprDict::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprSet::static_type()) { - Self::Set(ast::ExprSet::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprListComp::static_type()) { - Self::ListComp(ast::ExprListComp::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprSetComp::static_type()) { - Self::SetComp(ast::ExprSetComp::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprDictComp::static_type()) { - Self::DictComp(ast::ExprDictComp::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprGeneratorExp::static_type()) { - Self::Generator(ast::ExprGenerator::ast_from_object( + if vm.is_none(&object) { + return Err(vm.new_type_error(format!( + "expected some sort of expr, but got {}", + object.repr(vm)? + ))); + } + enum ExprKind { + BoolOp, + Named, + BinOp, + UnaryOp, + Lambda, + If, + Dict, + Set, + ListComp, + SetComp, + DictComp, + Generator, + Await, + Yield, + YieldFrom, + Compare, + Call, + FormattedValue, + Interpolation, + JoinedStr, + TemplateStr, + Constant, + Attribute, + Subscript, + Starred, + Name, + List, + Tuple, + Slice, + } + let kind = if is_node_instance(vm, &object, pyast::NodeExprBoolOp::static_type())? { + ExprKind::BoolOp + } else if is_node_instance(vm, &object, pyast::NodeExprNamedExpr::static_type())? { + ExprKind::Named + } else if is_node_instance(vm, &object, pyast::NodeExprBinOp::static_type())? { + ExprKind::BinOp + } else if is_node_instance(vm, &object, pyast::NodeExprUnaryOp::static_type())? { + ExprKind::UnaryOp + } else if is_node_instance(vm, &object, pyast::NodeExprLambda::static_type())? { + ExprKind::Lambda + } else if is_node_instance(vm, &object, pyast::NodeExprIfExp::static_type())? { + ExprKind::If + } else if is_node_instance(vm, &object, pyast::NodeExprDict::static_type())? { + ExprKind::Dict + } else if is_node_instance(vm, &object, pyast::NodeExprSet::static_type())? { + ExprKind::Set + } else if is_node_instance(vm, &object, pyast::NodeExprListComp::static_type())? { + ExprKind::ListComp + } else if is_node_instance(vm, &object, pyast::NodeExprSetComp::static_type())? { + ExprKind::SetComp + } else if is_node_instance(vm, &object, pyast::NodeExprDictComp::static_type())? { + ExprKind::DictComp + } else if is_node_instance(vm, &object, pyast::NodeExprGeneratorExp::static_type())? { + ExprKind::Generator + } else if is_node_instance(vm, &object, pyast::NodeExprAwait::static_type())? { + ExprKind::Await + } else if is_node_instance(vm, &object, pyast::NodeExprYield::static_type())? { + ExprKind::Yield + } else if is_node_instance(vm, &object, pyast::NodeExprYieldFrom::static_type())? { + ExprKind::YieldFrom + } else if is_node_instance(vm, &object, pyast::NodeExprCompare::static_type())? { + ExprKind::Compare + } else if is_node_instance(vm, &object, pyast::NodeExprCall::static_type())? { + ExprKind::Call + } else if is_node_instance(vm, &object, pyast::NodeExprFormattedValue::static_type())? { + ExprKind::FormattedValue + } else if is_node_instance(vm, &object, pyast::NodeExprInterpolation::static_type())? { + ExprKind::Interpolation + } else if is_node_instance(vm, &object, pyast::NodeExprJoinedStr::static_type())? { + ExprKind::JoinedStr + } else if is_node_instance(vm, &object, pyast::NodeExprTemplateStr::static_type())? { + ExprKind::TemplateStr + } else if is_node_instance(vm, &object, pyast::NodeExprConstant::static_type())? { + ExprKind::Constant + } else if is_node_instance(vm, &object, pyast::NodeExprAttribute::static_type())? { + ExprKind::Attribute + } else if is_node_instance(vm, &object, pyast::NodeExprSubscript::static_type())? { + ExprKind::Subscript + } else if is_node_instance(vm, &object, pyast::NodeExprStarred::static_type())? { + ExprKind::Starred + } else if is_node_instance(vm, &object, pyast::NodeExprName::static_type())? { + ExprKind::Name + } else if is_node_instance(vm, &object, pyast::NodeExprList::static_type())? { + ExprKind::List + } else if is_node_instance(vm, &object, pyast::NodeExprTuple::static_type())? { + ExprKind::Tuple + } else if is_node_instance(vm, &object, pyast::NodeExprSlice::static_type())? { + ExprKind::Slice + } else { + return Err(vm.new_type_error(format!( + "expected some sort of expr, but got {}", + object.repr(vm)? + ))); + }; + let range = expr_range_from_object(vm, source_file, object.clone())?; + Ok(match kind { + ExprKind::BoolOp => Self::BoolOp(expr_bool_op_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeExprAwait::static_type()) { - Self::Await(ast::ExprAwait::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprYield::static_type()) { - Self::Yield(ast::ExprYield::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprYieldFrom::static_type()) { - Self::YieldFrom(ast::ExprYieldFrom::ast_from_object( + range, + )?), + ExprKind::Named => Self::Named(expr_named_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeExprCompare::static_type()) { - Self::Compare(ast::ExprCompare::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprCall::static_type()) { - Self::Call(ast::ExprCall::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprAttribute::static_type()) { - Self::Attribute(ast::ExprAttribute::ast_from_object( + range, + )?), + ExprKind::BinOp => Self::BinOp(expr_bin_op_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeExprSubscript::static_type()) { - Self::Subscript(ast::ExprSubscript::ast_from_object( + range, + )?), + ExprKind::UnaryOp => Self::UnaryOp(expr_unary_op_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeExprStarred::static_type()) { - Self::Starred(ast::ExprStarred::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprName::static_type()) { - Self::Name(ast::ExprName::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprList::static_type()) { - Self::List(ast::ExprList::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprTuple::static_type()) { - Self::Tuple(ast::ExprTuple::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprSlice::static_type()) { - Self::Slice(ast::ExprSlice::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeExprConstant::static_type()) { - Constant::ast_from_object(vm, source_file, object)?.into_expr() - } else if cls.is(pyast::NodeExprJoinedStr::static_type()) { - JoinedStr::ast_from_object(vm, source_file, object)?.into_expr() - } else if cls.is(pyast::NodeExprTemplateStr::static_type()) { - let template = string::TemplateStr::ast_from_object(vm, source_file, object)?; - return string::template_str_to_expr(vm, template); - } else if cls.is(pyast::NodeExprInterpolation::static_type()) { - let interpolation = - string::TStringInterpolation::ast_from_object(vm, source_file, object)?; - return string::interpolation_to_expr(vm, interpolation); - } else if vm.is_none(&object) { - return Err(vm.new_value_error("None disallowed in expression list")); - } else { - return Err(vm.new_type_error(format!( - "expected some sort of expr, but got {}", - object.repr(vm)? - ))); + range, + )?), + ExprKind::Lambda => Self::Lambda(expr_lambda_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::If => Self::If(expr_if_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Dict => Self::Dict(expr_dict_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Set => Self::Set(expr_set_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::ListComp => Self::ListComp(expr_list_comp_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::SetComp => Self::SetComp(expr_set_comp_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::DictComp => Self::DictComp(expr_dict_comp_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Generator => Self::Generator(expr_generator_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Await => Self::Await(expr_await_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Yield => Self::Yield(expr_yield_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::YieldFrom => Self::YieldFrom(expr_yield_from_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Compare => Self::Compare(expr_compare_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Call => Self::Call(expr_call_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::FormattedValue => { + let formatted = + string::formatted_value_from_object_with_range(vm, source_file, object, range)?; + string::formatted_value_to_expr(formatted) + } + ExprKind::Interpolation => { + let interpolation = string::tstring_interpolation_from_object_with_range( + vm, + source_file, + object, + range, + )?; + string::interpolation_to_expr(vm, source_file, interpolation)? + } + ExprKind::JoinedStr => { + string::joined_str_from_object_with_range(vm, source_file, object, range)? + .into_expr() + } + ExprKind::TemplateStr => { + let template = + string::template_str_from_object_with_range(vm, source_file, object, range)?; + string::template_str_to_expr(vm, source_file, template)? + } + ExprKind::Constant => { + constant::constant_from_object_with_range(vm, source_file, object, range)? + .into_expr() + } + ExprKind::Attribute => Self::Attribute(expr_attribute_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Subscript => Self::Subscript(expr_subscript_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Starred => Self::Starred(expr_starred_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Name => Self::Name(expr_name_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::List => Self::List(expr_list_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Tuple => Self::Tuple(expr_tuple_from_object_with_range( + vm, + source_file, + object, + range, + )?), + ExprKind::Slice => Self::Slice(expr_slice_from_object_with_range( + vm, + source_file, + object, + range, + )?), }) } } // constructor +fn expr_bool_op_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let values: Vec> = + get_node_list_field(vm, source_file, &object, "values", "BoolOp")?; + let (node_index, values) = + public_expr_list_from_values(super::constant::PublicAstExprListField::Values, values); + Ok(ast::ExprBoolOp { + node_index, + op: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "op", "BoolOp")?, + )?, + values, + range, + }) +} + impl Node for ast::ExprBoolOp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, op, values, range, @@ -159,8 +378,15 @@ impl Node for ast::ExprBoolOp { let dict = node.as_object().dict().unwrap(); dict.set_item("op", op.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("values", values.ast_to_object(vm, source_file), vm) - .unwrap(); + let values = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Values, + ) + .map_or_else( + || values.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("values", values, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -170,24 +396,26 @@ impl Node for ast::ExprBoolOp { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - op: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "op", "BoolOp")?, - )?, - values: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "values", "BoolOp")?, - )?, - range: range_from_object(vm, source_file, object, "BoolOp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "BoolOp")?; + expr_bool_op_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_named_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprNamed { + node_index: Default::default(), + target: get_required_node_field(vm, source_file, &object, "target", "NamedExpr")?, + value: get_required_node_field(vm, source_file, &object, "value", "NamedExpr")?, + range, + }) +} + impl Node for ast::ExprNamed { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -213,24 +441,31 @@ impl Node for ast::ExprNamed { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - target: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "target", "NamedExpr")?, - )?, - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "NamedExpr")?, - )?, - range: range_from_object(vm, source_file, object, "NamedExpr")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "NamedExpr")?; + expr_named_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_bin_op_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprBinOp { + node_index: Default::default(), + left: get_required_node_field(vm, source_file, &object, "left", "BinOp")?, + op: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "op", "BinOp")?, + )?, + right: get_required_node_field(vm, source_file, &object, "right", "BinOp")?, + range, + }) +} + impl Node for ast::ExprBinOp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -259,29 +494,30 @@ impl Node for ast::ExprBinOp { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - left: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "left", "BinOp")?, - )?, - op: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "op", "BinOp")?, - )?, - right: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "right", "BinOp")?, - )?, - range: range_from_object(vm, source_file, object, "BinOp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "BinOp")?; + expr_bin_op_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_unary_op_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprUnaryOp { + node_index: Default::default(), + op: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "op", "UnaryOp")?, + )?, + operand: get_required_node_field(vm, source_file, &object, "operand", "UnaryOp")?, + range, + }) +} + impl Node for ast::ExprUnaryOp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -306,24 +542,30 @@ impl Node for ast::ExprUnaryOp { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - op: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "op", "UnaryOp")?, - )?, - operand: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "operand", "UnaryOp")?, - )?, - range: range_from_object(vm, source_file, object, "UnaryOp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "UnaryOp")?; + expr_unary_op_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_lambda_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprLambda { + node_index: Default::default(), + parameters: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "args", "Lambda")?, + )?, + body: get_required_node_field(vm, source_file, &object, "body", "Lambda")?, + range, + }) +} + impl Node for ast::ExprLambda { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -353,24 +595,27 @@ impl Node for ast::ExprLambda { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - parameters: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "args", "Lambda")?, - )?, - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "Lambda")?, - )?, - range: range_from_object(vm, source_file, object, "Lambda")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Lambda")?; + expr_lambda_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_if_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprIf { + node_index: Default::default(), + test: get_required_node_field(vm, source_file, &object, "test", "IfExp")?, + body: get_required_node_field(vm, source_file, &object, "body", "IfExp")?, + orelse: get_required_node_field(vm, source_file, &object, "orelse", "IfExp")?, + range, + }) +} + impl Node for ast::ExprIf { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -399,33 +644,43 @@ impl Node for ast::ExprIf { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - test: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "test", "IfExp")?, - )?, - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "IfExp")?, - )?, - orelse: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "orelse", "IfExp")?, - )?, - range: range_from_object(vm, source_file, object, "IfExp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "IfExp")?; + expr_if_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_dict_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let keys: Vec> = + get_node_list_field(vm, source_file, &object, "keys", "Dict")?; + let values: Vec> = + get_node_list_field(vm, source_file, &object, "values", "Dict")?; + if keys.len() != values.len() { + return Err(vm.new_value_error("Dict doesn't have the same number of keys as values")); + } + let node_index = + public_expr_lists_node_index([(super::constant::PublicAstExprListField::Values, &values)]); + let items = keys + .into_iter() + .zip(lower_public_expr_list(values)) + .map(|(key, value)| ast::DictItem { key, value }) + .collect(); + Ok(ast::ExprDict { + node_index, + items, + range, + }) +} + impl Node for ast::ExprDict { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, items, range, } = self; @@ -443,8 +698,15 @@ impl Node for ast::ExprDict { let dict = node.as_object().dict().unwrap(); dict.set_item("keys", keys.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("values", values.ast_to_object(vm, source_file), vm) - .unwrap(); + let values = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Values, + ) + .map_or_else( + || values.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("values", values, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -454,37 +716,33 @@ impl Node for ast::ExprDict { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let keys: Vec> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "keys", "Dict")?, - )?; - let values: Vec<_> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "values", "Dict")?, - )?; - if keys.len() != values.len() { - return Err(vm.new_value_error("Dict doesn't have the same number of keys as values")); - } - let items = keys - .into_iter() - .zip(values) - .map(|(key, value)| ast::DictItem { key, value }) - .collect(); - Ok(Self { - node_index: Default::default(), - items, - range: range_from_object(vm, source_file, object, "Dict")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Dict")?; + expr_dict_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_set_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let elts: Vec> = + get_node_list_field(vm, source_file, &object, "elts", "Set")?; + let (node_index, elts) = + public_expr_list_from_values(super::constant::PublicAstExprListField::Elts, elts); + Ok(ast::ExprSet { + node_index, + elts, + range, + }) +} + impl Node for ast::ExprSet { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, elts, range, } = self; @@ -492,8 +750,15 @@ impl Node for ast::ExprSet { .into_ref_with_type(vm, pyast::NodeExprSet::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("elts", elts.ast_to_object(vm, source_file), vm) - .unwrap(); + let elts = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Elts, + ) + .map_or_else( + || elts.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("elts", elts, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -502,19 +767,26 @@ impl Node for ast::ExprSet { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - elts: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "elts", "Set")?, - )?, - range: range_from_object(vm, source_file, object, "Set")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Set")?; + expr_set_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_list_comp_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprListComp { + node_index: Default::default(), + elt: get_required_node_field(vm, source_file, &object, "elt", "ListComp")?, + generators: get_node_list_field(vm, source_file, &object, "generators", "ListComp")?, + range, + }) +} + impl Node for ast::ExprListComp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -540,24 +812,26 @@ impl Node for ast::ExprListComp { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - elt: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "elt", "ListComp")?, - )?, - generators: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "generators", "ListComp")?, - )?, - range: range_from_object(vm, source_file, object, "ListComp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "ListComp")?; + expr_list_comp_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_set_comp_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprSetComp { + node_index: Default::default(), + elt: get_required_node_field(vm, source_file, &object, "elt", "SetComp")?, + generators: get_node_list_field(vm, source_file, &object, "generators", "SetComp")?, + range, + }) +} + impl Node for ast::ExprSetComp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -583,24 +857,33 @@ impl Node for ast::ExprSetComp { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - elt: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "elt", "SetComp")?, - )?, - generators: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "generators", "SetComp")?, - )?, - range: range_from_object(vm, source_file, object, "SetComp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "SetComp")?; + expr_set_comp_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_dict_comp_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprDictComp { + node_index: Default::default(), + key: Some(get_required_node_field( + vm, + source_file, + &object, + "key", + "DictComp", + )?), + value: get_required_node_field(vm, source_file, &object, "value", "DictComp")?, + generators: get_node_list_field(vm, source_file, &object, "generators", "DictComp")?, + range, + }) +} + impl Node for ast::ExprDictComp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -614,8 +897,8 @@ impl Node for ast::ExprDictComp { .into_ref_with_type(vm, pyast::NodeExprDictComp::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("key", key.ast_to_object(vm, source_file), vm) - .unwrap(); + let key = key.map_or_else(|| vm.ctx.none(), |key| key.ast_to_object(vm, source_file)); + dict.set_item("key", key, vm).unwrap(); dict.set_item("value", value.ast_to_object(vm, source_file), vm) .unwrap(); dict.set_item("generators", generators.ast_to_object(vm, source_file), vm) @@ -629,29 +912,28 @@ impl Node for ast::ExprDictComp { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - key: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "key", "DictComp")?, - )?, - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "DictComp")?, - )?, - generators: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "generators", "DictComp")?, - )?, - range: range_from_object(vm, source_file, object, "DictComp")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "DictComp")?; + expr_dict_comp_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_generator_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprGenerator { + node_index: Default::default(), + elt: get_required_node_field(vm, source_file, &object, "elt", "GeneratorExp")?, + generators: get_node_list_field(vm, source_file, &object, "generators", "GeneratorExp")?, + range, + // TODO: Is this correct? + parenthesized: true, + }) +} + impl Node for ast::ExprGenerator { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -688,26 +970,25 @@ impl Node for ast::ExprGenerator { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - elt: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "elt", "GeneratorExp")?, - )?, - generators: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "generators", "GeneratorExp")?, - )?, - range: range_from_object(vm, source_file, object, "GeneratorExp")?, - // TODO: Is this correct? - parenthesized: true, - }) + let range = range_from_object(vm, source_file, object.clone(), "GeneratorExp")?; + expr_generator_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_await_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprAwait { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "Await")?, + range, + }) +} + impl Node for ast::ExprAwait { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -729,19 +1010,27 @@ impl Node for ast::ExprAwait { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "Await")?, - )?, - range: range_from_object(vm, source_file, object, "Await")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Await")?; + expr_await_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_yield_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprYield { + node_index: Default::default(), + value: get_node_field_opt(vm, &object, "value")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::ExprYield { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -764,17 +1053,25 @@ impl Node for ast::ExprYield { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: get_node_field_opt(vm, &object, "value")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "Yield")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Yield")?; + expr_yield_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_yield_from_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprYieldFrom { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "YieldFrom")?, + range, + }) +} + impl Node for ast::ExprYieldFrom { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -797,23 +1094,37 @@ impl Node for ast::ExprYieldFrom { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field_required(vm, &object, "value", "YieldFrom")?, - )?, - range: range_from_object(vm, source_file, object, "YieldFrom")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "YieldFrom")?; + expr_yield_from_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_compare_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let comparators: Vec> = + get_node_list_field(vm, source_file, &object, "comparators", "Compare")?; + let (node_index, comparators) = public_expr_boxed_slice_from_values( + super::constant::PublicAstExprListField::Comparators, + comparators, + ); + Ok(ast::ExprCompare { + node_index, + left: get_required_node_field(vm, source_file, &object, "left", "Compare")?, + ops: get_node_boxed_slice_field(vm, source_file, &object, "ops", "Compare")?, + comparators, + range, + }) +} + impl Node for ast::ExprCompare { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, left, ops, comparators, @@ -827,12 +1138,15 @@ impl Node for ast::ExprCompare { .unwrap(); dict.set_item("ops", BoxedSlice(ops).ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item( - "comparators", - BoxedSlice(comparators).ast_to_object(vm, source_file), - vm, + let comparators = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Comparators, ) - .unwrap(); + .map_or_else( + || BoxedSlice(comparators).ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("comparators", comparators, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -842,35 +1156,29 @@ impl Node for ast::ExprCompare { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - left: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "left", "Compare")?, - )?, - ops: { - let ops: BoxedSlice<_> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ops", "Compare")?, - )?; - ops.0 - }, - comparators: { - let comparators: BoxedSlice<_> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "comparators", "Compare")?, - )?; - comparators.0 - }, - range: range_from_object(vm, source_file, object, "Compare")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Compare")?; + expr_compare_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_call_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprCall { + node_index: Default::default(), + func: get_required_node_field(vm, source_file, &object, "func", "Call")?, + arguments: merge_function_call_arguments( + PositionalArguments::ast_from_field(vm, source_file, &object, "args", "Call")?, + KeywordArguments::ast_from_field(vm, source_file, &object, "keywords", "Call")?, + ), + range, + }) +} + impl Node for ast::ExprCall { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -907,31 +1215,31 @@ impl Node for ast::ExprCall { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - func: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "func", "Call")?, - )?, - arguments: merge_function_call_arguments( - Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "args", "Call")?, - )?, - Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "keywords", "Call")?, - )?, - ), - range: range_from_object(vm, source_file, object, "Call")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Call")?; + expr_call_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_attribute_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprAttribute { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "Attribute")?, + attr: get_required_identifier_field(vm, source_file, &object, "attr", "Attribute")?, + ctx: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "ctx", "Attribute")?, + )?, + range, + }) +} + impl Node for ast::ExprAttribute { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -960,29 +1268,31 @@ impl Node for ast::ExprAttribute { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "Attribute")?, - )?, - attr: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "attr", "Attribute")?, - )?, - ctx: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ctx", "Attribute")?, - )?, - range: range_from_object(vm, source_file, object, "Attribute")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Attribute")?; + expr_attribute_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_subscript_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprSubscript { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "Subscript")?, + slice: get_required_node_field(vm, source_file, &object, "slice", "Subscript")?, + ctx: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "ctx", "Subscript")?, + )?, + range, + }) +} + impl Node for ast::ExprSubscript { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1010,29 +1320,30 @@ impl Node for ast::ExprSubscript { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "Subscript")?, - )?, - slice: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "slice", "Subscript")?, - )?, - ctx: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ctx", "Subscript")?, - )?, - range: range_from_object(vm, source_file, object, "Subscript")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Subscript")?; + expr_subscript_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_starred_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprStarred { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "Starred")?, + ctx: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "ctx", "Starred")?, + )?, + range, + }) +} + impl Node for ast::ExprStarred { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1057,24 +1368,30 @@ impl Node for ast::ExprStarred { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "Starred")?, - )?, - ctx: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ctx", "Starred")?, - )?, - range: range_from_object(vm, source_file, object, "Starred")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Starred")?; + expr_starred_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_name_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprName { + node_index: Default::default(), + id: get_required_identifier_field(vm, source_file, &object, "id", "Name")?, + ctx: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "ctx", "Name")?, + )?, + range, + }) +} + impl Node for ast::ExprName { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1099,24 +1416,38 @@ impl Node for ast::ExprName { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - id: Node::ast_from_object(vm, source_file, get_node_field(vm, &object, "id", "Name")?)?, - ctx: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ctx", "Name")?, - )?, - range: range_from_object(vm, source_file, object, "Name")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Name")?; + expr_name_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_list_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let elts: Vec> = + get_node_list_field(vm, source_file, &object, "elts", "List")?; + let (node_index, elts) = + public_expr_list_from_values(super::constant::PublicAstExprListField::Elts, elts); + Ok(ast::ExprList { + node_index, + elts, + ctx: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "ctx", "List")?, + )?, + range, + }) +} + impl Node for ast::ExprList { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, elts, ctx, range, @@ -1125,8 +1456,15 @@ impl Node for ast::ExprList { .into_ref_with_type(vm, pyast::NodeExprList::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("elts", elts.ast_to_object(vm, source_file), vm) - .unwrap(); + let elts = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Elts, + ) + .map_or_else( + || elts.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("elts", elts, vm).unwrap(); dict.set_item("ctx", ctx.ast_to_object(vm, source_file), vm) .unwrap(); node_add_location(&dict, range, vm, source_file); @@ -1138,28 +1476,39 @@ impl Node for ast::ExprList { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - elts: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "elts", "List")?, - )?, - ctx: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ctx", "List")?, - )?, - range: range_from_object(vm, source_file, object, "List")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "List")?; + expr_list_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_tuple_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let elts: Vec> = + get_node_list_field(vm, source_file, &object, "elts", "Tuple")?; + let (node_index, elts) = + public_expr_list_from_values(super::constant::PublicAstExprListField::Elts, elts); + Ok(ast::ExprTuple { + node_index, + elts, + ctx: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "ctx", "Tuple")?, + )?, + range, + parenthesized: true, // TODO: is this correct? + }) +} + impl Node for ast::ExprTuple { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, elts, ctx, range: _range, @@ -1169,8 +1518,15 @@ impl Node for ast::ExprTuple { .into_ref_with_type(vm, pyast::NodeExprTuple::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("elts", elts.ast_to_object(vm, source_file), vm) - .unwrap(); + let elts = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Elts, + ) + .map_or_else( + || elts.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("elts", elts, vm).unwrap(); dict.set_item("ctx", ctx.ast_to_object(vm, source_file), vm) .unwrap(); node_add_location(&dict, _range, vm, source_file); @@ -1182,25 +1538,33 @@ impl Node for ast::ExprTuple { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - elts: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "elts", "Tuple")?, - )?, - ctx: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ctx", "Tuple")?, - )?, - range: range_from_object(vm, source_file, object, "Tuple")?, - parenthesized: true, // TODO: is this correct? - }) + let range = range_from_object(vm, source_file, object.clone(), "Tuple")?; + expr_tuple_from_object_with_range(vm, source_file, object, range) } } // constructor +fn expr_slice_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::ExprSlice { + node_index: Default::default(), + lower: get_node_field_opt(vm, &object, "lower")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + upper: get_node_field_opt(vm, &object, "upper")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + step: get_node_field_opt(vm, &object, "step")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::ExprSlice { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1229,19 +1593,8 @@ impl Node for ast::ExprSlice { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - lower: get_node_field_opt(vm, &object, "lower")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - upper: get_node_field_opt(vm, &object, "upper")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - step: get_node_field_opt(vm, &object, "step")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "Slice")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Slice")?; + expr_slice_from_object_with_range(vm, source_file, object, range) } } @@ -1253,7 +1606,7 @@ impl Node for ast::ExprContext { Self::Store => pyast::NodeExprContextStore::static_type(), Self::Del => pyast::NodeExprContextDel::static_type(), Self::Invalid => { - unimplemented!("Invalid expression context is not allowed in Python AST") + unreachable!("invalid expression context is not part of Python AST") } }; singleton_node_to_object(vm, node_type) @@ -1264,19 +1617,20 @@ impl Node for ast::ExprContext { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeExprContextLoad::static_type()) { - Self::Load - } else if cls.is(pyast::NodeExprContextStore::static_type()) { - Self::Store - } else if cls.is(pyast::NodeExprContextDel::static_type()) { - Self::Del - } else { - return Err(vm.new_type_error(format!( - "expected some sort of expr_context, but got {}", - object.repr(vm)? - ))); - }) + Ok( + if is_node_instance(vm, &object, pyast::NodeExprContextLoad::static_type())? { + Self::Load + } else if is_node_instance(vm, &object, pyast::NodeExprContextStore::static_type())? { + Self::Store + } else if is_node_instance(vm, &object, pyast::NodeExprContextDel::static_type())? { + Self::Del + } else { + return Err(vm.new_type_error(format!( + "expected some sort of expr_context, but got {}", + object.repr(vm)? + ))); + }, + ) } } @@ -1284,7 +1638,7 @@ impl Node for ast::ExprContext { impl Node for ast::Comprehension { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, target, iter, ifs, @@ -1299,10 +1653,21 @@ impl Node for ast::Comprehension { .unwrap(); dict.set_item("iter", iter.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("ifs", ifs.ast_to_object(vm, source_file), vm) - .unwrap(); - dict.set_item("is_async", is_async.ast_to_object(vm, source_file), vm) - .unwrap(); + let ifs = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Ifs, + ) + .map_or_else( + || ifs.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("ifs", ifs, vm).unwrap(); + let is_async = super::constant::public_ast_comprehension_is_async_object(node_index.load()) + .map_or_else( + || is_async.ast_to_object(vm, source_file), + |value| vm.ctx.new_int(value).into(), + ); + dict.set_item("is_async", is_async, vm).unwrap(); node.into() } @@ -1311,28 +1676,24 @@ impl Node for ast::Comprehension { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { + let ifs: Vec> = + get_node_list_field(vm, source_file, &object, "ifs", "comprehension")?; + let is_async = node_object_to_i32( + vm, + get_node_field(vm, &object, "is_async", "comprehension")?, + )?; + let node_index = public_node_list_overrides_node_index( + Vec::new(), + vec![(super::constant::PublicAstExprListField::Ifs, &ifs)], + None, + Some(is_async), + ); Ok(Self { - node_index: Default::default(), - target: Node::ast_from_object( - vm, - source_file, - get_node_field_required(vm, &object, "target", "comprehension")?, - )?, - iter: Node::ast_from_object( - vm, - source_file, - get_node_field_required(vm, &object, "iter", "comprehension")?, - )?, - ifs: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "ifs", "comprehension")?, - )?, - is_async: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "is_async", "comprehension")?, - )?, + node_index, + target: get_required_node_field(vm, source_file, &object, "target", "comprehension")?, + iter: get_required_node_field(vm, source_file, &object, "iter", "comprehension")?, + ifs: lower_public_expr_list(ifs), + is_async: is_async != 0, range: Default::default(), }) } diff --git a/crates/vm/src/stdlib/_ast/module.rs b/crates/vm/src/stdlib/_ast/module.rs index b4c2468d33b..e15ec55fda6 100644 --- a/crates/vm/src/stdlib/_ast/module.rs +++ b/crates/vm/src/stdlib/_ast/module.rs @@ -18,7 +18,7 @@ use rustpython_compiler_core::SourceFile; /// - `FunctionType`: A function signature with argument and return type /// annotations, representing the type hints of a function (e.g., `def add(x: int, y: int) -> int`). pub(super) enum Mod { - Module(ast::ModModule), + Module(ModModule), Interactive(ModInteractive), Expression(ast::ModExpression), FunctionType(ModFunctionType), @@ -40,46 +40,63 @@ impl Node for Mod { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeModModule::static_type()) { - Self::Module(ast::ModModule::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeModInteractive::static_type()) { - Self::Interactive(ModInteractive::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeModExpression::static_type()) { - Self::Expression(ast::ModExpression::ast_from_object( - vm, - source_file, - object, - )?) - } else if cls.is(pyast::NodeModFunctionType::static_type()) { - Self::FunctionType(ModFunctionType::ast_from_object(vm, source_file, object)?) - } else { - return Err(vm.new_type_error(format!( - "expected some sort of mod, but got {}", - object.repr(vm)? - ))); - }) + Ok( + if object.is_instance(pyast::NodeModModule::static_type().as_object(), vm)? { + Self::Module(ModModule::ast_from_object(vm, source_file, object)?) + } else if object + .is_instance(pyast::NodeModInteractive::static_type().as_object(), vm)? + { + Self::Interactive(ModInteractive::ast_from_object(vm, source_file, object)?) + } else if object.is_instance(pyast::NodeModExpression::static_type().as_object(), vm)? { + Self::Expression(ast::ModExpression::ast_from_object( + vm, + source_file, + object, + )?) + } else if object + .is_instance(pyast::NodeModFunctionType::static_type().as_object(), vm)? + { + Self::FunctionType(ModFunctionType::ast_from_object(vm, source_file, object)?) + } else { + return Err(vm.new_type_error(format!( + "expected some sort of mod, but got {}", + object.repr(vm)? + ))); + }, + ) } } +pub(super) struct ModModule { + pub(crate) module: ast::ModModule, + pub(crate) type_ignores: Vec, +} + // constructor -impl Node for ast::ModModule { +impl Node for ModModule { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + module, + type_ignores, + } = self; + let ast::ModModule { + node_index, body, - // type_ignores, range, - } = self; + } = module; let node = NodeAst .into_ref_with_type(vm, pyast::NodeModModule::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("body", body.ast_to_object(vm, source_file), vm) - .unwrap(); - // TODO: Improve ruff API - // ruff ignores type_ignore comments currently. - let type_ignores: Vec = vec![]; + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("body", body, vm).unwrap(); dict.set_item( "type_ignores", type_ignores.ast_to_object(vm, source_file), @@ -95,37 +112,48 @@ impl Node for ast::ModModule { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { + let body: Vec> = + get_node_list_field(vm, source_file, &object, "body", "Module")?; + let (node_index, body) = stmt_list_from_values(body); + let type_ignores = get_node_list_field(vm, source_file, &object, "type_ignores", "Module")?; Ok(Self { - node_index: Default::default(), - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "Module")?, - )?, - // type_ignores: Node::ast_from_object( - // _vm, - // get_node_field(_vm, &_object, "type_ignores", "Module")?, - // )?, - range: Default::default(), + module: ast::ModModule { + node_index, + body, + range: Default::default(), + }, + type_ignores, }) } } pub(super) struct ModInteractive { + pub(crate) node_index: ast::AtomicNodeIndex, pub(crate) range: TextRange, - pub(crate) body: Vec, + pub(crate) body: ast::Suite, } // constructor impl Node for ModInteractive { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - let Self { body, range } = self; + let Self { + node_index, + body, + range, + } = self; let node = NodeAst .into_ref_with_type(vm, pyast::NodeModInteractive::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("body", body.ast_to_object(vm, source_file), vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("body", body, vm).unwrap(); let _ = range; node.into() } @@ -135,12 +163,12 @@ impl Node for ModInteractive { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { + let body: Vec> = + get_node_list_field(vm, source_file, &object, "body", "Interactive")?; + let (node_index, body) = stmt_list_from_values(body); Ok(Self { - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "Interactive")?, - )?, + node_index, + body, range: Default::default(), }) } @@ -171,17 +199,14 @@ impl Node for ast::ModExpression { ) -> PyResult { Ok(Self { node_index: Default::default(), - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "Expression")?, - )?, + body: get_required_node_field(vm, source_file, &object, "body", "Expression")?, range: Default::default(), }) } } pub(super) struct ModFunctionType { + pub(crate) node_index: ast::AtomicNodeIndex, pub(crate) argtypes: Box<[ast::Expr]>, pub(crate) returns: ast::Expr, pub(crate) range: TextRange, @@ -191,6 +216,7 @@ pub(super) struct ModFunctionType { impl Node for ModFunctionType { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { + node_index, argtypes, returns, range, @@ -199,12 +225,12 @@ impl Node for ModFunctionType { .into_ref_with_type(vm, pyast::NodeModFunctionType::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item( - "argtypes", - BoxedSlice(argtypes).ast_to_object(vm, source_file), - vm, - ) - .unwrap(); + let argtypes = super::constant::public_ast_expr_option_list_object(node_index.load()) + .map_or_else( + || BoxedSlice(argtypes).ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("argtypes", argtypes, vm).unwrap(); dict.set_item("returns", returns.ast_to_object(vm, source_file), vm) .unwrap(); let _ = range; @@ -216,21 +242,69 @@ impl Node for ModFunctionType { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { + let argtypes: Vec> = + get_node_list_field(vm, source_file, &object, "argtypes", "FunctionType")?; + let (node_index, argtypes) = expr_list_from_values(argtypes); Ok(Self { - argtypes: { - let argtypes: BoxedSlice<_> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "argtypes", "FunctionType")?, - )?; - argtypes.0 - }, - returns: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "returns", "FunctionType")?, - )?, + node_index, + argtypes: argtypes.into_boxed_slice(), + returns: get_required_node_field(vm, source_file, &object, "returns", "FunctionType")?, range: Default::default(), }) } } + +fn stmt_list_from_values(values: Vec>) -> (ast::AtomicNodeIndex, ast::Suite) { + let index = if values.iter().any(Option::is_none) { + super::constant::register_public_ast_stmt_list( + super::constant::PublicAstStmtListField::Body, + values.clone(), + ) + } else { + ast::NodeIndex::NONE + }; + let node_index = ast::AtomicNodeIndex::NONE; + if index != ast::NodeIndex::NONE { + node_index.set(index); + } + ( + node_index, + values + .into_iter() + .map(|value| value.unwrap_or_else(null_stmt_placeholder)) + .collect(), + ) +} + +fn expr_list_from_values(values: Vec>) -> (ast::AtomicNodeIndex, Vec) { + let index = if values.iter().any(Option::is_none) { + super::constant::register_public_ast_expr_option_list(values.clone()) + } else { + ast::NodeIndex::NONE + }; + let node_index = ast::AtomicNodeIndex::NONE; + if index != ast::NodeIndex::NONE { + node_index.set(index); + } + ( + node_index, + values + .into_iter() + .map(|value| value.unwrap_or_else(null_expr_placeholder)) + .collect(), + ) +} + +fn null_stmt_placeholder() -> ast::Stmt { + ast::Stmt::Pass(ast::StmtPass { + range: Default::default(), + node_index: Default::default(), + }) +} + +fn null_expr_placeholder() -> ast::Expr { + ast::Expr::NoneLiteral(ast::ExprNoneLiteral { + range: Default::default(), + node_index: Default::default(), + }) +} diff --git a/crates/vm/src/stdlib/_ast/node.rs b/crates/vm/src/stdlib/_ast/node.rs index 4ee3893b665..d07732ef645 100644 --- a/crates/vm/src/stdlib/_ast/node.rs +++ b/crates/vm/src/stdlib/_ast/node.rs @@ -1,5 +1,6 @@ -use crate::{PyObjectRef, PyResult, VirtualMachine}; +use crate::{PyObjectRef, PyResult, VirtualMachine, builtins::PyList}; use rustpython_compiler_core::SourceFile; +use thin_vec::ThinVec; pub(crate) trait Node: Sized { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef; @@ -31,14 +32,52 @@ impl Node for Vec { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - // Recursion guard for each element: prevents stack overflow when a - // sequence element transitively references the sequence itself - // (e.g. `l = ast.List(...); l.elts = [l]`). See issue #4862. - vm.extract_elements_with(&object, |obj| { - vm.with_recursion("while traversing AST node", || { - Node::ast_from_object(vm, source_file, obj) - }) - }) + let list = object.downcast_ref::().ok_or_else(|| { + vm.new_type_error(format!( + "AST list field must be a list, not a {}", + object.class().name() + )) + })?; + let len = list.borrow_vec().len(); + let mut result = Self::with_capacity(len); + for i in 0..len { + let item = { + let items = list.borrow_vec(); + if items.len() != len { + return Err( + vm.new_runtime_error("AST list field changed size during iteration") + ); + } + items[i].clone() + }; + result.push(vm.with_recursion("while traversing AST node", || { + Node::ast_from_object(vm, source_file, item) + })?); + if list.borrow_vec().len() != len { + return Err(vm.new_runtime_error("AST list field changed size during iteration")); + } + } + Ok(result) + } +} + +impl Node for ThinVec { + fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { + vm.ctx + .new_list( + self.into_iter() + .map(|node| node.ast_to_object(vm, source_file)) + .collect(), + ) + .into() + } + + fn ast_from_object( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + ) -> PyResult { + Vec::::ast_from_object(vm, source_file, object).map(Into::into) } } diff --git a/crates/vm/src/stdlib/_ast/operator.rs b/crates/vm/src/stdlib/_ast/operator.rs index 09e63b5d6ce..e05e490bb84 100644 --- a/crates/vm/src/stdlib/_ast/operator.rs +++ b/crates/vm/src/stdlib/_ast/operator.rs @@ -16,17 +16,18 @@ impl Node for ast::BoolOp { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeBoolOpAnd::static_type()) { - Self::And - } else if cls.is(pyast::NodeBoolOpOr::static_type()) { - Self::Or - } else { - return Err(vm.new_type_error(format!( - "expected some sort of boolop, but got {}", - object.repr(vm)? - ))); - }) + Ok( + if is_node_instance(vm, &object, pyast::NodeBoolOpAnd::static_type())? { + Self::And + } else if is_node_instance(vm, &object, pyast::NodeBoolOpOr::static_type())? { + Self::Or + } else { + return Err(vm.new_type_error(format!( + "expected some sort of boolop, but got {}", + object.repr(vm)? + ))); + }, + ) } } @@ -56,39 +57,40 @@ impl Node for ast::Operator { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeOperatorAdd::static_type()) { - Self::Add - } else if cls.is(pyast::NodeOperatorSub::static_type()) { - Self::Sub - } else if cls.is(pyast::NodeOperatorMult::static_type()) { - Self::Mult - } else if cls.is(pyast::NodeOperatorMatMult::static_type()) { - Self::MatMult - } else if cls.is(pyast::NodeOperatorDiv::static_type()) { - Self::Div - } else if cls.is(pyast::NodeOperatorMod::static_type()) { - Self::Mod - } else if cls.is(pyast::NodeOperatorPow::static_type()) { - Self::Pow - } else if cls.is(pyast::NodeOperatorLShift::static_type()) { - Self::LShift - } else if cls.is(pyast::NodeOperatorRShift::static_type()) { - Self::RShift - } else if cls.is(pyast::NodeOperatorBitOr::static_type()) { - Self::BitOr - } else if cls.is(pyast::NodeOperatorBitXor::static_type()) { - Self::BitXor - } else if cls.is(pyast::NodeOperatorBitAnd::static_type()) { - Self::BitAnd - } else if cls.is(pyast::NodeOperatorFloorDiv::static_type()) { - Self::FloorDiv - } else { - return Err(vm.new_type_error(format!( - "expected some sort of operator, but got {}", - object.repr(vm)? - ))); - }) + Ok( + if is_node_instance(vm, &object, pyast::NodeOperatorAdd::static_type())? { + Self::Add + } else if is_node_instance(vm, &object, pyast::NodeOperatorSub::static_type())? { + Self::Sub + } else if is_node_instance(vm, &object, pyast::NodeOperatorMult::static_type())? { + Self::Mult + } else if is_node_instance(vm, &object, pyast::NodeOperatorMatMult::static_type())? { + Self::MatMult + } else if is_node_instance(vm, &object, pyast::NodeOperatorDiv::static_type())? { + Self::Div + } else if is_node_instance(vm, &object, pyast::NodeOperatorMod::static_type())? { + Self::Mod + } else if is_node_instance(vm, &object, pyast::NodeOperatorPow::static_type())? { + Self::Pow + } else if is_node_instance(vm, &object, pyast::NodeOperatorLShift::static_type())? { + Self::LShift + } else if is_node_instance(vm, &object, pyast::NodeOperatorRShift::static_type())? { + Self::RShift + } else if is_node_instance(vm, &object, pyast::NodeOperatorBitOr::static_type())? { + Self::BitOr + } else if is_node_instance(vm, &object, pyast::NodeOperatorBitXor::static_type())? { + Self::BitXor + } else if is_node_instance(vm, &object, pyast::NodeOperatorBitAnd::static_type())? { + Self::BitAnd + } else if is_node_instance(vm, &object, pyast::NodeOperatorFloorDiv::static_type())? { + Self::FloorDiv + } else { + return Err(vm.new_type_error(format!( + "expected some sort of operator, but got {}", + object.repr(vm)? + ))); + }, + ) } } @@ -109,21 +111,22 @@ impl Node for ast::UnaryOp { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeUnaryOpInvert::static_type()) { - Self::Invert - } else if cls.is(pyast::NodeUnaryOpNot::static_type()) { - Self::Not - } else if cls.is(pyast::NodeUnaryOpUAdd::static_type()) { - Self::UAdd - } else if cls.is(pyast::NodeUnaryOpUSub::static_type()) { - Self::USub - } else { - return Err(vm.new_type_error(format!( - "expected some sort of unaryop, but got {}", - object.repr(vm)? - ))); - }) + Ok( + if is_node_instance(vm, &object, pyast::NodeUnaryOpInvert::static_type())? { + Self::Invert + } else if is_node_instance(vm, &object, pyast::NodeUnaryOpNot::static_type())? { + Self::Not + } else if is_node_instance(vm, &object, pyast::NodeUnaryOpUAdd::static_type())? { + Self::UAdd + } else if is_node_instance(vm, &object, pyast::NodeUnaryOpUSub::static_type())? { + Self::USub + } else { + return Err(vm.new_type_error(format!( + "expected some sort of unaryop, but got {}", + object.repr(vm)? + ))); + }, + ) } } @@ -150,32 +153,33 @@ impl Node for ast::CmpOp { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeCmpOpEq::static_type()) { - Self::Eq - } else if cls.is(pyast::NodeCmpOpNotEq::static_type()) { - Self::NotEq - } else if cls.is(pyast::NodeCmpOpLt::static_type()) { - Self::Lt - } else if cls.is(pyast::NodeCmpOpLtE::static_type()) { - Self::LtE - } else if cls.is(pyast::NodeCmpOpGt::static_type()) { - Self::Gt - } else if cls.is(pyast::NodeCmpOpGtE::static_type()) { - Self::GtE - } else if cls.is(pyast::NodeCmpOpIs::static_type()) { - Self::Is - } else if cls.is(pyast::NodeCmpOpIsNot::static_type()) { - Self::IsNot - } else if cls.is(pyast::NodeCmpOpIn::static_type()) { - Self::In - } else if cls.is(pyast::NodeCmpOpNotIn::static_type()) { - Self::NotIn - } else { - return Err(vm.new_type_error(format!( - "expected some sort of cmpop, but got {}", - object.repr(vm)? - ))); - }) + Ok( + if is_node_instance(vm, &object, pyast::NodeCmpOpEq::static_type())? { + Self::Eq + } else if is_node_instance(vm, &object, pyast::NodeCmpOpNotEq::static_type())? { + Self::NotEq + } else if is_node_instance(vm, &object, pyast::NodeCmpOpLt::static_type())? { + Self::Lt + } else if is_node_instance(vm, &object, pyast::NodeCmpOpLtE::static_type())? { + Self::LtE + } else if is_node_instance(vm, &object, pyast::NodeCmpOpGt::static_type())? { + Self::Gt + } else if is_node_instance(vm, &object, pyast::NodeCmpOpGtE::static_type())? { + Self::GtE + } else if is_node_instance(vm, &object, pyast::NodeCmpOpIs::static_type())? { + Self::Is + } else if is_node_instance(vm, &object, pyast::NodeCmpOpIsNot::static_type())? { + Self::IsNot + } else if is_node_instance(vm, &object, pyast::NodeCmpOpIn::static_type())? { + Self::In + } else if is_node_instance(vm, &object, pyast::NodeCmpOpNotIn::static_type())? { + Self::NotIn + } else { + return Err(vm.new_type_error(format!( + "expected some sort of cmpop, but got {}", + object.repr(vm)? + ))); + }, + ) } } diff --git a/crates/vm/src/stdlib/_ast/other.rs b/crates/vm/src/stdlib/_ast/other.rs index 5009c588cfc..9d0256942ab 100644 --- a/crates/vm/src/stdlib/_ast/other.rs +++ b/crates/vm/src/stdlib/_ast/other.rs @@ -11,14 +11,13 @@ impl Node for ast::ConversionFlag { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - // Python's AST uses ASCII codes: 's', 'r', 'a', -1=None - // Note: 255 is -1i8 as u8 (ruff's ConversionFlag::None) - match i32::try_from_object(vm, object)? { - -1 | 255 => Ok(Self::None), + // Python's AST uses ASCII codes: 's', 'r', 'a', -1=None. + match node_object_to_i32(vm, object)? { + -1 => Ok(Self::None), x if x == b's' as i32 => Ok(Self::Str), x if x == b'r' as i32 => Ok(Self::Repr), x if x == b'a' as i32 => Ok(Self::Ascii), - _ => Err(vm.new_value_error("invalid conversion flag")), + x => Err(vm.new_system_error(format!("Unrecognized conversion character {x}"))), } } } @@ -34,10 +33,13 @@ impl Node for ast::name::Name { _source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - match object.downcast::() { - Ok(name) => Ok(Self::new(name)), - Err(_) => Err(vm.new_value_error("expected str for name")), + if !object.class().is(vm.ctx.types.str_type) { + return Err(vm.new_type_error("AST identifier must be of type str")); } + object + .downcast::() + .map(Self::new) + .map_err(|_| vm.new_type_error("AST identifier must be of type str")) } } @@ -89,11 +91,7 @@ impl Node for ast::Alias { ) -> PyResult { Ok(Self { node_index: Default::default(), - name: Node::ast_from_object( - vm, - source_file, - get_node_field_required(vm, &object, "name", "alias")?, - )?, + name: get_required_identifier_field(vm, source_file, &object, "name", "alias")?, asname: get_node_field_opt(vm, &object, "asname")? .map(|obj| Node::ast_from_object(vm, source_file, obj)) .transpose()?, @@ -137,10 +135,12 @@ impl Node for ast::WithItem { ) -> PyResult { Ok(Self { node_index: Default::default(), - context_expr: Node::ast_from_object( + context_expr: get_required_node_field( vm, source_file, - get_node_field_required(vm, &object, "context_expr", "withitem")?, + &object, + "context_expr", + "withitem", )?, optional_vars: get_node_field_opt(vm, &object, "optional_vars")? .map(|obj| Node::ast_from_object(vm, source_file, obj)) diff --git a/crates/vm/src/stdlib/_ast/parameter.rs b/crates/vm/src/stdlib/_ast/parameter.rs index b0c807a2922..67817475a03 100644 --- a/crates/vm/src/stdlib/_ast/parameter.rs +++ b/crates/vm/src/stdlib/_ast/parameter.rs @@ -5,7 +5,7 @@ use rustpython_compiler_core::SourceFile; impl Node for ast::Parameters { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, posonlyargs, args, vararg, @@ -40,8 +40,12 @@ impl Node for ast::Parameters { .unwrap(); dict.set_item("kwarg", kwarg.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("defaults", defaults.ast_to_object(vm, source_file), vm) - .unwrap(); + let defaults = super::constant::public_ast_expr_option_list_object(node_index.load()) + .map_or_else( + || defaults.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("defaults", defaults, vm).unwrap(); let _ = range; node.into() } @@ -51,47 +55,57 @@ impl Node for ast::Parameters { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let kwonlyargs = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "kwonlyargs", "arguments")?, - )?; - let kw_defaults = Node::ast_from_object( + let posonlyargs = PositionalParameters::ast_from_field( vm, source_file, - get_node_field(vm, &object, "kw_defaults", "arguments")?, - )?; - let kwonlyargs = merge_keyword_parameter_defaults(vm, kwonlyargs, kw_defaults)?; - - let posonlyargs = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "posonlyargs", "arguments")?, + &object, + "posonlyargs", + "arguments", )?; - let args = Node::ast_from_object( + let args = + PositionalParameters::ast_from_field(vm, source_file, &object, "args", "arguments")?; + let vararg = get_node_field_opt(vm, &object, "vararg")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?; + let kwonlyargs = + KeywordParameters::ast_from_field(vm, source_file, &object, "kwonlyargs", "arguments")?; + let kw_defaults = ParameterDefaults::ast_from_field( vm, source_file, - get_node_field(vm, &object, "args", "arguments")?, + &object, + "kw_defaults", + "arguments", )?; - let defaults = Node::ast_from_object( + let kwarg = get_node_field_opt(vm, &object, "kwarg")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?; + let defaults = ParameterDefaults::ast_from_field_preserve_none( vm, source_file, - get_node_field(vm, &object, "defaults", "arguments")?, + &object, + "defaults", + "arguments", )?; + + let kwonlyargs = merge_keyword_parameter_defaults(vm, kwonlyargs, kw_defaults)?; + let defaults_node_index = defaults.node_index; let (posonlyargs, args) = merge_positional_parameter_defaults(vm, posonlyargs, args, defaults)?; + let node_index = { + let node_index = ast::AtomicNodeIndex::NONE; + if defaults_node_index != ast::NodeIndex::NONE { + node_index.set(defaults_node_index); + } + node_index + }; Ok(Self { - node_index: Default::default(), + node_index, posonlyargs, args, - vararg: get_node_field_opt(vm, &object, "vararg")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, + vararg, kwonlyargs, - kwarg: get_node_field_opt(vm, &object, "kwarg")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, + kwarg, range: Default::default(), }) } @@ -105,7 +119,7 @@ impl Node for ast::Parameters { impl Node for ast::Parameter { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, name, annotation, // type_comment, @@ -127,8 +141,9 @@ impl Node for ast::Parameter { _vm, ) .unwrap(); - // Ruff AST doesn't track type_comment, so always set to None - dict.set_item("type_comment", _vm.ctx.none(), _vm).unwrap(); + let type_comment = super::constant::public_ast_arg_type_comment_object(node_index.load()) + .unwrap_or_else(|| _vm.ctx.none()); + dict.set_item("type_comment", type_comment, _vm).unwrap(); node_add_location(&dict, range, _vm, source_file); node.into() } @@ -138,20 +153,23 @@ impl Node for ast::Parameter { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { + let name = get_required_identifier_field(_vm, source_file, &_object, "arg", "arg")?; + let annotation = get_node_field_opt(_vm, &_object, "annotation")? + .map(|obj| Node::ast_from_object(_vm, source_file, obj)) + .transpose()?; + let type_comment = get_ast_string_field_opt(_vm, &_object, "type_comment")?; + let node_index = ast::AtomicNodeIndex::NONE; + if let Some(type_comment) = type_comment { + node_index.set(super::constant::register_public_ast_arg_type_comment( + type_comment, + )); + } + let range = range_from_object(_vm, source_file, _object, "arg")?; Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - _vm, - source_file, - get_node_field_required(_vm, &_object, "arg", "arg")?, - )?, - annotation: get_node_field_opt(_vm, &_object, "annotation")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - // type_comment: get_node_field_opt(_vm, &_object, "type_comment")? - // .map(|obj| Node::ast_from_object(_vm, obj)) - // .transpose()?, - range: range_from_object(_vm, source_file, _object, "arg")?, + node_index, + name, + annotation, + range, }) } } @@ -186,11 +204,7 @@ impl Node for ast::Keyword { arg: get_node_field_opt(_vm, &_object, "arg")? .map(|obj| Node::ast_from_object(_vm, source_file, obj)) .transpose()?, - value: Node::ast_from_object( - _vm, - source_file, - get_node_field_required(_vm, &_object, "value", "keyword")?, - )?, + value: get_required_node_field(_vm, source_file, &_object, "value", "keyword")?, range: range_from_object(_vm, source_file, _object, "keyword")?, }) } @@ -201,6 +215,21 @@ struct PositionalParameters { pub args: Box<[ast::Parameter]>, } +impl PositionalParameters { + fn ast_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + Ok(Self { + args: get_node_boxed_slice_field(vm, source_file, object, field, typ)?, + _range: TextRange::default(), + }) + } +} + impl Node for PositionalParameters { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { BoxedSlice(self.args).ast_to_object(vm, source_file) @@ -224,6 +253,21 @@ struct KeywordParameters { pub keywords: Box<[ast::Parameter]>, } +impl KeywordParameters { + fn ast_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + Ok(Self { + keywords: get_node_boxed_slice_field(vm, source_file, object, field, typ)?, + _range: TextRange::default(), + }) + } +} + impl Node for KeywordParameters { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { BoxedSlice(self.keywords).ast_to_object(vm, source_file) @@ -244,9 +288,52 @@ impl Node for KeywordParameters { struct ParameterDefaults { pub _range: TextRange, // TODO: Use this + node_index: ast::NodeIndex, defaults: Box<[Option>]>, } +impl ParameterDefaults { + fn ast_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + Ok(Self { + defaults: get_node_boxed_slice_field(vm, source_file, object, field, typ)?, + node_index: ast::NodeIndex::NONE, + _range: TextRange::default(), + }) + } + + fn ast_from_field_preserve_none( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + let defaults: Vec>> = + get_node_list_field(vm, source_file, object, field, typ)?; + let node_index = if defaults.iter().any(Option::is_none) { + super::constant::register_public_ast_expr_option_list( + defaults + .iter() + .map(|default| default.as_deref().cloned()) + .collect(), + ) + } else { + ast::NodeIndex::NONE + }; + Ok(Self { + defaults: defaults.into_boxed_slice(), + node_index, + _range: TextRange::default(), + }) + } +} + impl Node for ParameterDefaults { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { BoxedSlice(self.defaults).ast_to_object(vm, source_file) @@ -260,14 +347,15 @@ impl Node for ParameterDefaults { let defaults: BoxedSlice<_> = Node::ast_from_object(vm, source_file, object)?; Ok(Self { defaults: defaults.0, + node_index: ast::NodeIndex::NONE, _range: TextRange::default(), // TODO }) } } fn extract_positional_parameter_defaults( - pos_only_args: Vec, - args: Vec, + pos_only_args: ast::ParameterWithDefaults, + args: ast::ParameterWithDefaults, ) -> ( PositionalParameters, PositionalParameters, @@ -287,6 +375,7 @@ fn extract_positional_parameter_defaults( .map(|item| item.range()) .reduce(|acc, next| acc.cover(next)) .unwrap_or_default(), + node_index: ast::NodeIndex::NONE, defaults: defaults.into_boxed_slice(), }; @@ -326,10 +415,7 @@ fn merge_positional_parameter_defaults( posonlyargs: PositionalParameters, args: PositionalParameters, defaults: ParameterDefaults, -) -> PyResult<( - Vec, - Vec, -)> { +) -> PyResult<(ast::ParameterWithDefaults, ast::ParameterWithDefaults)> { let posonlyargs = posonlyargs.args; let args = args.args; let defaults = defaults.defaults; @@ -368,11 +454,11 @@ fn merge_positional_parameter_defaults( arg.default = default; } - Ok((posonlyargs, args)) + Ok((posonlyargs.into(), args.into())) } fn extract_keyword_parameter_defaults( - kw_only_args: Vec, + kw_only_args: ast::ParameterWithDefaults, ) -> (KeywordParameters, ParameterDefaults) { let mut defaults = vec![]; defaults.extend(kw_only_args.iter().map(|item| item.default.clone())); @@ -383,6 +469,7 @@ fn extract_keyword_parameter_defaults( .map(|item| item.range()) .reduce(|acc, next| acc.cover(next)) .unwrap_or_default(), + node_index: ast::NodeIndex::NONE, defaults: defaults.into_boxed_slice(), }; @@ -409,7 +496,7 @@ fn merge_keyword_parameter_defaults( vm: &VirtualMachine, kw_only_args: KeywordParameters, defaults: ParameterDefaults, -) -> PyResult> { +) -> PyResult { if kw_only_args.keywords.len() != defaults.defaults.len() { return Err( vm.new_value_error("length of kwonlyargs is not the same as kw_defaults on arguments") @@ -422,5 +509,6 @@ fn merge_keyword_parameter_defaults( default, range: Default::default(), }) - .collect()) + .collect::>() + .into()) } diff --git a/crates/vm/src/stdlib/_ast/pattern.rs b/crates/vm/src/stdlib/_ast/pattern.rs index a383c25a6b7..d8b851b5bab 100644 --- a/crates/vm/src/stdlib/_ast/pattern.rs +++ b/crates/vm/src/stdlib/_ast/pattern.rs @@ -5,7 +5,7 @@ use rustpython_compiler_core::SourceFile; impl Node for ast::MatchCase { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, pattern, guard, body, @@ -19,8 +19,15 @@ impl Node for ast::MatchCase { .unwrap(); dict.set_item("guard", guard.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("body", body.ast_to_object(vm, source_file), vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("body", body, vm).unwrap(); node.into() } @@ -29,21 +36,17 @@ impl Node for ast::MatchCase { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { + let body: Vec> = + get_node_list_field(vm, source_file, &object, "body", "match_case")?; + let (node_index, body) = + public_stmt_list_from_values(super::constant::PublicAstStmtListField::Body, body); Ok(Self { - node_index: Default::default(), - pattern: Node::ast_from_object( - vm, - source_file, - get_node_field_required(vm, &object, "pattern", "match_case")?, - )?, + node_index, + pattern: get_required_node_field(vm, source_file, &object, "pattern", "match_case")?, guard: get_node_field_opt(vm, &object, "guard")? .map(|obj| Node::ast_from_object(vm, source_file, obj)) .transpose()?, - body: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "body", "match_case")?, - )?, + body, range: Default::default(), }) } @@ -68,64 +71,161 @@ impl Node for ast::Pattern { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodePatternMatchValue::static_type()) { - Self::MatchValue(ast::PatternMatchValue::ast_from_object( - vm, - source_file, - object, - )?) - } else if cls.is(pyast::NodePatternMatchSingleton::static_type()) { - Self::MatchSingleton(ast::PatternMatchSingleton::ast_from_object( - vm, - source_file, - object, - )?) - } else if cls.is(pyast::NodePatternMatchSequence::static_type()) { - Self::MatchSequence(ast::PatternMatchSequence::ast_from_object( - vm, - source_file, - object, - )?) - } else if cls.is(pyast::NodePatternMatchMapping::static_type()) { - Self::MatchMapping(ast::PatternMatchMapping::ast_from_object( + if vm.is_none(&object) { + return Err(vm.new_type_error(format!( + "expected some sort of pattern, but got {}", + object.repr(vm)? + ))); + } + enum PatternKind { + Value, + Singleton, + Sequence, + Mapping, + Class, + Star, + As, + Or, + } + let kind = if is_node_instance(vm, &object, pyast::NodePatternMatchValue::static_type())? { + PatternKind::Value + } else if is_node_instance(vm, &object, pyast::NodePatternMatchSingleton::static_type())? { + PatternKind::Singleton + } else if is_node_instance(vm, &object, pyast::NodePatternMatchSequence::static_type())? { + PatternKind::Sequence + } else if is_node_instance(vm, &object, pyast::NodePatternMatchMapping::static_type())? { + PatternKind::Mapping + } else if is_node_instance(vm, &object, pyast::NodePatternMatchClass::static_type())? { + PatternKind::Class + } else if is_node_instance(vm, &object, pyast::NodePatternMatchStar::static_type())? { + PatternKind::Star + } else if is_node_instance(vm, &object, pyast::NodePatternMatchAs::static_type())? { + PatternKind::As + } else if is_node_instance(vm, &object, pyast::NodePatternMatchOr::static_type())? { + PatternKind::Or + } else { + return Err(vm.new_type_error(format!( + "expected some sort of pattern, but got {}", + object.repr(vm)? + ))); + }; + let range = pattern_range_from_object(vm, source_file, object.clone())?; + Ok(match kind { + PatternKind::Value => Self::MatchValue(pattern_match_value_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodePatternMatchClass::static_type()) { - Self::MatchClass(ast::PatternMatchClass::ast_from_object( + range, + )?), + PatternKind::Singleton => Self::MatchSingleton( + pattern_match_singleton_from_object_with_range(vm, source_file, object, range)?, + ), + PatternKind::Sequence => Self::MatchSequence( + pattern_match_sequence_from_object_with_range(vm, source_file, object, range)?, + ), + PatternKind::Mapping => Self::MatchMapping( + pattern_match_mapping_from_object_with_range(vm, source_file, object, range)?, + ), + PatternKind::Class => Self::MatchClass(pattern_match_class_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodePatternMatchStar::static_type()) { - Self::MatchStar(ast::PatternMatchStar::ast_from_object( + range, + )?), + PatternKind::Star => Self::MatchStar(pattern_match_star_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodePatternMatchAs::static_type()) { - Self::MatchAs(ast::PatternMatchAs::ast_from_object( + range, + )?), + PatternKind::As => Self::MatchAs(pattern_match_as_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodePatternMatchOr::static_type()) { - Self::MatchOr(ast::PatternMatchOr::ast_from_object( + range, + )?), + PatternKind::Or => Self::MatchOr(pattern_match_or_from_object_with_range( vm, source_file, object, - )?) - } else { - return Err(vm.new_type_error(format!( - "expected some sort of pattern, but got {}", - object.repr(vm)? - ))); + range, + )?), }) } } + +fn pattern_node_index(index: ast::NodeIndex) -> ast::AtomicNodeIndex { + let node_index = ast::AtomicNodeIndex::NONE; + node_index.set(index); + node_index +} + +fn null_pattern_placeholder(range: TextRange) -> ast::Pattern { + ast::Pattern::MatchAs(ast::PatternMatchAs { + node_index: Default::default(), + range, + pattern: None, + name: None, + }) +} + +fn lower_nullable_patterns(values: &[Option], range: TextRange) -> ast::Patterns { + values + .iter() + .cloned() + .map(|value| value.unwrap_or_else(|| null_pattern_placeholder(range))) + .collect() +} + +fn null_expr_placeholder(range: TextRange) -> ast::Expr { + ast::Expr::NoneLiteral(ast::ExprNoneLiteral { + node_index: Default::default(), + range, + }) +} + +fn lower_nullable_exprs(values: &[Option], range: TextRange) -> ast::PatternKeys { + values + .iter() + .cloned() + .map(|value| value.unwrap_or_else(|| null_expr_placeholder(range))) + .collect() +} + +fn pattern_list_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + range: TextRange, +) -> PyResult<(ast::AtomicNodeIndex, ast::Patterns)> { + let values: Vec> = + get_node_list_field(vm, source_file, object, field, typ)?; + let node_index = if values.iter().any(Option::is_none) { + pattern_node_index(super::constant::register_public_ast_pattern_list( + values.clone(), + )) + } else { + Default::default() + }; + Ok((node_index, lower_nullable_patterns(&values, range))) +} + // constructor +fn pattern_match_value_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::PatternMatchValue { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "MatchValue")?, + range, + }) +} + impl Node for ast::PatternMatchValue { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -148,19 +248,29 @@ impl Node for ast::PatternMatchValue { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "MatchValue")?, - )?, - range: range_from_object(vm, source_file, object, "MatchValue")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchValue")?; + pattern_match_value_from_object_with_range(vm, source_file, object, range) } } // constructor +fn pattern_match_singleton_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::PatternMatchSingleton { + node_index: Default::default(), + value: Node::ast_from_object( + vm, + source_file, + get_node_field(vm, &object, "value", "MatchSingleton")?, + )?, + range, + }) +} + impl Node for ast::PatternMatchSingleton { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -186,15 +296,8 @@ impl Node for ast::PatternMatchSingleton { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "MatchSingleton")?, - )?, - range: range_from_object(vm, source_file, object, "MatchSingleton")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchSingleton")?; + pattern_match_singleton_from_object_with_range(vm, source_file, object, range) } } @@ -219,19 +322,31 @@ impl Node for ast::Singleton { } else if object.is(&vm.ctx.false_value) { Ok(Self::False) } else { - Err(vm.new_value_error(format!( - "Expected None, True, or False, got {:?}", - object.class().name() - ))) + Err(vm.new_value_error("MatchSingleton can only contain True, False and None")) } } } // constructor +fn pattern_match_sequence_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let (node_index, patterns) = + pattern_list_from_field(vm, source_file, &object, "patterns", "MatchSequence", range)?; + Ok(ast::PatternMatchSequence { + node_index, + patterns: patterns.to_vec(), + range, + }) +} + impl Node for ast::PatternMatchSequence { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, patterns, range, } = self; @@ -242,8 +357,12 @@ impl Node for ast::PatternMatchSequence { ) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("patterns", patterns.ast_to_object(vm, source_file), vm) - .unwrap(); + let patterns = super::constant::public_ast_pattern_list_object(node_index.load()) + .map_or_else( + || patterns.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("patterns", patterns, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -253,23 +372,47 @@ impl Node for ast::PatternMatchSequence { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - patterns: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "patterns", "MatchSequence")?, - )?, - range: range_from_object(vm, source_file, object, "MatchSequence")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchSequence")?; + pattern_match_sequence_from_object_with_range(vm, source_file, object, range) } } // constructor +fn pattern_match_mapping_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let keys: Vec> = + get_node_list_field(vm, source_file, &object, "keys", "MatchMapping")?; + let patterns: Vec> = + get_node_list_field(vm, source_file, &object, "patterns", "MatchMapping")?; + let has_public_override = + keys.iter().any(Option::is_none) || patterns.iter().any(Option::is_none); + let node_index = if has_public_override { + pattern_node_index(super::constant::register_public_ast_match_mapping( + keys.clone(), + patterns.clone(), + )) + } else { + Default::default() + }; + Ok(ast::PatternMatchMapping { + node_index, + keys: lower_nullable_exprs(&keys, range), + patterns: lower_nullable_patterns(&patterns, range), + rest: get_node_field_opt(vm, &object, "rest")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::PatternMatchMapping { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, keys, patterns, rest, @@ -279,10 +422,18 @@ impl Node for ast::PatternMatchMapping { .into_ref_with_type(vm, pyast::NodePatternMatchMapping::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("keys", keys.ast_to_object(vm, source_file), vm) - .unwrap(); - dict.set_item("patterns", patterns.ast_to_object(vm, source_file), vm) - .unwrap(); + let keys = super::constant::public_ast_expr_option_list_object(node_index.load()) + .map_or_else( + || keys.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("keys", keys, vm).unwrap(); + let patterns = super::constant::public_ast_pattern_list_object(node_index.load()) + .map_or_else( + || patterns.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("patterns", patterns, vm).unwrap(); dict.set_item("rest", rest.ast_to_object(vm, source_file), vm) .unwrap(); node_add_location(&dict, range, vm, source_file); @@ -294,52 +445,93 @@ impl Node for ast::PatternMatchMapping { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - keys: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "keys", "MatchMapping")?, - )?, - patterns: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "patterns", "MatchMapping")?, - )?, - rest: get_node_field_opt(vm, &object, "rest")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "MatchMapping")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchMapping")?; + pattern_match_mapping_from_object_with_range(vm, source_file, object, range) } } // constructor +fn pattern_match_class_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let cls = get_required_node_field(vm, source_file, &object, "cls", "MatchClass")?; + let patterns: Vec> = + get_node_list_field(vm, source_file, &object, "patterns", "MatchClass")?; + let kwd_attrs = PatternMatchClassKeywordAttributes::ast_from_field( + vm, + source_file, + &object, + "kwd_attrs", + "MatchClass", + )?; + let kwd_patterns: Vec> = + get_node_list_field(vm, source_file, &object, "kwd_patterns", "MatchClass")?; + let has_public_override = kwd_attrs.0.len() != kwd_patterns.len() + || patterns.iter().any(Option::is_none) + || kwd_patterns.iter().any(Option::is_none); + let node_index = if has_public_override { + pattern_node_index(super::constant::register_public_ast_match_class( + patterns.clone(), + kwd_attrs.0.clone(), + kwd_patterns.clone(), + )) + } else { + Default::default() + }; + let patterns = PatternMatchClassPatterns(lower_nullable_patterns(&patterns, range)); + let kwd_patterns = + PatternMatchClassKeywordPatterns(lower_nullable_patterns(&kwd_patterns, range)); + let (patterns, keywords) = merge_pattern_match_class(patterns, kwd_attrs, kwd_patterns); + + Ok(ast::PatternMatchClass { + node_index, + cls, + range, + arguments: ast::PatternArguments { + node_index: Default::default(), + range: Default::default(), + patterns, + keywords, + }, + }) +} + impl Node for ast::PatternMatchClass { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, cls, arguments, range, } = self; - let (patterns, kwd_attrs, kwd_patterns) = split_pattern_match_class(arguments); let node = NodeAst .into_ref_with_type(vm, pyast::NodePatternMatchClass::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); dict.set_item("cls", cls.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("patterns", patterns.ast_to_object(vm, source_file), vm) - .unwrap(); - dict.set_item("kwd_attrs", kwd_attrs.ast_to_object(vm, source_file), vm) - .unwrap(); - dict.set_item( - "kwd_patterns", - kwd_patterns.ast_to_object(vm, source_file), - vm, - ) - .unwrap(); + let (patterns, kwd_attrs, kwd_patterns) = if let Some(values) = + super::constant::public_ast_match_class_object(node_index.load()) + { + ( + values.patterns.ast_to_object(vm, source_file), + values.kwd_attrs.ast_to_object(vm, source_file), + values.kwd_patterns.ast_to_object(vm, source_file), + ) + } else { + let (patterns, kwd_attrs, kwd_patterns) = split_pattern_match_class(arguments); + ( + patterns.ast_to_object(vm, source_file), + kwd_attrs.ast_to_object(vm, source_file), + kwd_patterns.ast_to_object(vm, source_file), + ) + }; + dict.set_item("patterns", patterns, vm).unwrap(); + dict.set_item("kwd_attrs", kwd_attrs, vm).unwrap(); + dict.set_item("kwd_patterns", kwd_patterns, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -349,45 +541,12 @@ impl Node for ast::PatternMatchClass { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let patterns = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "patterns", "MatchClass")?, - )?; - let kwd_attrs: PatternMatchClassKeywordAttributes = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "kwd_attrs", "MatchClass")?, - )?; - let kwd_patterns: PatternMatchClassKeywordPatterns = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "kwd_patterns", "MatchClass")?, - )?; - if kwd_attrs.0.len() != kwd_patterns.0.len() { - return Err(vm.new_value_error("MatchClass has mismatched kwd_attrs and kwd_patterns")); - } - let (patterns, keywords) = merge_pattern_match_class(patterns, kwd_attrs, kwd_patterns); - - Ok(Self { - node_index: Default::default(), - cls: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "cls", "MatchClass")?, - )?, - range: range_from_object(vm, source_file, object, "MatchClass")?, - arguments: ast::PatternArguments { - node_index: Default::default(), - range: Default::default(), - patterns, - keywords, - }, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchClass")?; + pattern_match_class_from_object_with_range(vm, source_file, object, range) } } -struct PatternMatchClassPatterns(Vec); +struct PatternMatchClassPatterns(ast::Patterns); impl Node for PatternMatchClassPatterns { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -405,6 +564,24 @@ impl Node for PatternMatchClassPatterns { struct PatternMatchClassKeywordAttributes(Vec); +impl PatternMatchClassKeywordAttributes { + fn ast_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, + ) -> PyResult { + Ok(Self(get_node_list_field( + vm, + source_file, + object, + field, + typ, + )?)) + } +} + impl Node for PatternMatchClassKeywordAttributes { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { self.0.ast_to_object(vm, source_file) @@ -419,7 +596,7 @@ impl Node for PatternMatchClassKeywordAttributes { } } -struct PatternMatchClassKeywordPatterns(Vec); +struct PatternMatchClassKeywordPatterns(ast::Patterns); impl Node for PatternMatchClassKeywordPatterns { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -435,6 +612,21 @@ impl Node for PatternMatchClassKeywordPatterns { } } // constructor +fn pattern_match_star_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::PatternMatchStar { + node_index: Default::default(), + name: get_node_field_opt(vm, &object, "name")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::PatternMatchStar { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -457,17 +649,30 @@ impl Node for ast::PatternMatchStar { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - name: get_node_field_opt(vm, &object, "name")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "MatchStar")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchStar")?; + pattern_match_star_from_object_with_range(vm, source_file, object, range) } } // constructor +fn pattern_match_as_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::PatternMatchAs { + node_index: Default::default(), + pattern: get_node_field_opt(vm, &object, "pattern")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + name: get_node_field_opt(vm, &object, "name")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::PatternMatchAs { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -493,24 +698,31 @@ impl Node for ast::PatternMatchAs { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - pattern: get_node_field_opt(vm, &object, "pattern")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - name: get_node_field_opt(vm, &object, "name")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "MatchAs")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchAs")?; + pattern_match_as_from_object_with_range(vm, source_file, object, range) } } // constructor +fn pattern_match_or_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let (node_index, patterns) = + pattern_list_from_field(vm, source_file, &object, "patterns", "MatchOr", range)?; + Ok(ast::PatternMatchOr { + node_index, + patterns: patterns.to_vec(), + range, + }) +} + impl Node for ast::PatternMatchOr { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, patterns, range, } = self; @@ -518,8 +730,12 @@ impl Node for ast::PatternMatchOr { .into_ref_with_type(vm, pyast::NodePatternMatchOr::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("patterns", patterns.ast_to_object(vm, source_file), vm) - .unwrap(); + let patterns = super::constant::public_ast_pattern_list_object(node_index.load()) + .map_or_else( + || patterns.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("patterns", patterns, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -528,15 +744,8 @@ impl Node for ast::PatternMatchOr { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - patterns: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "patterns", "MatchOr")?, - )?, - range: range_from_object(vm, source_file, object, "MatchOr")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "MatchOr")?; + pattern_match_or_from_object_with_range(vm, source_file, object, range) } } @@ -562,7 +771,7 @@ fn merge_pattern_match_class( patterns: PatternMatchClassPatterns, kwd_attrs: PatternMatchClassKeywordAttributes, kwd_patterns: PatternMatchClassKeywordPatterns, -) -> (Vec, Vec) { +) -> (ast::Patterns, Vec) { let keywords = kwd_attrs .0 .into_iter() diff --git a/crates/vm/src/stdlib/_ast/pyast.rs b/crates/vm/src/stdlib/_ast/pyast.rs index 7eae6f00986..eb97eec8024 100644 --- a/crates/vm/src/stdlib/_ast/pyast.rs +++ b/crates/vm/src/stdlib/_ast/pyast.rs @@ -3,7 +3,6 @@ use crate::builtins::{PyGenericAlias, PyTuple, PyTupleRef, PyTypeRef, make_union use crate::common::ascii; use crate::convert::ToPyObject; use crate::function::FuncArgs; -use crate::types::Initializer; macro_rules! impl_node { ( @@ -61,6 +60,12 @@ macro_rules! impl_node { macro_rules! impl_base_node { // Base node without fields/attributes (e.g. NodeMod, NodeExpr) ($name:ident) => { + impl_base_node!($name, attributes: []); + }; + ($name:ident, attributes: [$($attr:expr),* $(,)?]) => { + impl_base_node!($name, attributes: [$($attr),*], optional_end_location: false); + }; + ($name:ident, attributes: [$($attr:expr),* $(,)?], optional_end_location: $optional_end_location:expr) => { #[pyclass(flags(HAS_DICT, BASETYPE))] impl $name { #[pymethod] @@ -83,9 +88,24 @@ macro_rules! impl_base_node { (*flags).remove(crate::types::PyTypeFlags::IMMUTABLETYPE); } class.set_attr( - identifier!(ctx, _attributes), + identifier!(ctx, _fields), ctx.empty_tuple.clone().into(), ); + class.set_str_attr("__match_args__", ctx.empty_tuple.clone(), ctx); + class.set_attr( + identifier!(ctx, _attributes), + ctx.new_tuple(vec![ + $( + ctx.new_str(ascii!($attr)).into() + ),* + ]) + .into(), + ); + if $optional_end_location { + let none = ctx.none(); + class.set_str_attr("end_lineno", none.clone(), ctx); + class.set_str_attr("end_col_offset", none, ctx); + } } } }; @@ -179,7 +199,11 @@ impl_node!( #[repr(transparent)] pub(crate) struct NodeStmt(NodeAst); -impl_base_node!(NodeStmt); +impl_base_node!( + NodeStmt, + attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"], + optional_end_location: true +); impl_node!( #[pyclass(module = "_ast", name = "FunctionType", base = NodeMod)] @@ -378,7 +402,11 @@ impl_node!( #[repr(transparent)] pub(crate) struct NodeExpr(NodeAst); -impl_base_node!(NodeExpr); +impl_base_node!( + NodeExpr, + attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"], + optional_end_location: true +); impl_node!( #[pyclass(module = "_ast", name = "Continue", base = NodeStmt)] @@ -533,12 +561,11 @@ impl_node!( attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"], ); -// NodeExprConstant needs custom Initializer to default kind to None #[pyclass(module = "_ast", name = "Constant", base = NodeExpr)] #[repr(transparent)] pub(crate) struct NodeExprConstant(NodeExpr); -#[pyclass(flags(HAS_DICT, BASETYPE), with(Initializer))] +#[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeExprConstant { #[extend_class] fn extend_class_with_fields(ctx: &Context, class: &'static Py) { @@ -580,24 +607,6 @@ impl NodeExprConstant { } } -impl Initializer for NodeExprConstant { - type Args = FuncArgs; - - fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { - ::slot_init(zelf.clone(), args, vm)?; - // kind defaults to None if not provided - let dict = zelf.as_object().dict().unwrap(); - if !dict.contains_key("kind", vm) { - dict.set_item("kind", vm.ctx.none(), vm)?; - } - Ok(()) - } - - fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { - unreachable!("slot_init is defined") - } -} - impl_node!( #[pyclass(module = "_ast", name = "Attribute", base = NodeExpr)] pub(crate) struct NodeExprAttribute, @@ -841,7 +850,11 @@ impl_node!( #[repr(transparent)] pub(crate) struct NodeExceptHandler(NodeAst); -impl_base_node!(NodeExceptHandler); +impl_base_node!( + NodeExceptHandler, + attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"], + optional_end_location: true +); impl_node!( #[pyclass(module = "_ast", name = "comprehension", base = NodeAst)] @@ -893,7 +906,10 @@ impl_node!( #[repr(transparent)] pub(crate) struct NodePattern(NodeAst); -impl_base_node!(NodePattern); +impl_base_node!( + NodePattern, + attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"] +); impl_node!( #[pyclass(module = "_ast", name = "match_case", base = NodeAst)] @@ -967,7 +983,10 @@ impl_node!( #[repr(transparent)] pub(crate) struct NodeTypeParam(NodeAst); -impl_base_node!(NodeTypeParam); +impl_base_node!( + NodeTypeParam, + attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"] +); impl_node!( #[pyclass(module = "_ast", name = "TypeIgnore", base = NodeTypeIgnore)] @@ -1681,7 +1700,6 @@ fn populate_field_types(vm: &VirtualMachine, module: &Py) { let field_types_attr = vm.ctx.intern_str("_field_types"); let annotations_attr = vm.ctx.intern_str("__annotations__"); - let empty_dict: PyObjectRef = vm.ctx.new_dict().into(); for &(class_name, fields) in FIELD_TYPES { if fields.is_empty() { @@ -1752,36 +1770,6 @@ fn populate_field_types(vm: &VirtualMachine, module: &Py) { type_obj.set_attr(annotations_attr, field_types); } } - - // Base AST classes (e.g., expr, stmt) should still expose __annotations__. - const BASE_AST_TYPES: &[&str] = &[ - "mod", - "stmt", - "expr", - "expr_context", - "boolop", - "operator", - "unaryop", - "cmpop", - "excepthandler", - "pattern", - "type_ignore", - "type_param", - ]; - for &class_name in BASE_AST_TYPES { - let class = module - .get_attr(class_name, vm) - .unwrap_or_else(|_| panic!("AST class '{class_name}' not found in module")); - let Some(type_obj) = class.downcast_ref::() else { - continue; - }; - if type_obj.get_attr(field_types_attr).is_none() { - type_obj.set_attr(field_types_attr, empty_dict.clone()); - } - if type_obj.get_attr(annotations_attr).is_none() { - type_obj.set_attr(annotations_attr, empty_dict.clone()); - } - } } fn populate_singletons(vm: &VirtualMachine, module: &Py) { diff --git a/crates/vm/src/stdlib/_ast/python.rs b/crates/vm/src/stdlib/_ast/python.rs index ee5588acb87..e7697e4cf4b 100644 --- a/crates/vm/src/stdlib/_ast/python.rs +++ b/crates/vm/src/stdlib/_ast/python.rs @@ -7,13 +7,10 @@ use super::{ #[pymodule] pub(crate) mod _ast { use crate::{ - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{ - PyDictRef, PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef, PyUtf8Str, PyUtf8StrRef, - }, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyDictRef, PySet, PyStr, PyTupleRef, PyType, PyTypeRef, PyUtf8Str}, class::{PyClassImpl, StaticType}, - common::wtf8::Wtf8, - function::{FuncArgs, KwArgs, PyMethodDef, PyMethodFlags}, + function::{ArgIterable, FuncArgs, KwArgs, PyMethodDef, PyMethodFlags}, stdlib::_ast::repr, types::{Constructor, Initializer}, warn, @@ -117,10 +114,12 @@ pub(crate) mod _ast { let fields = cls.get_attr(vm.ctx.intern_str("_fields")); if let Some(fields) = fields { - let fields: Vec = fields.try_to_value(vm)?; + let fields = fields.sequence_unchecked(); + let numfields = fields.length(vm)?; let mut positional: Vec = Vec::new(); - for field in fields { - if dict.get_item_opt::(field.as_wtf8(), vm)?.is_some() { + for i in 0..numfields { + let field = fields.get_item(i as isize, vm)?; + if dict.get_item_opt(&*field, vm)?.is_some() { positional.push(vm.ctx.none()); } else { break; @@ -136,6 +135,85 @@ pub(crate) mod _ast { .new_tuple(vec![type_obj, vm.ctx.new_tuple(vec![]).into(), dict.into()])) } + fn ast_replace_update_payload( + payload: &PyDictRef, + keys: Option<&PyObjectRef>, + dict: &PyDictRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(keys) = keys else { + return Ok(()); + }; + let keys = keys.sequence_unchecked(); + let num_keys = keys.length(vm)?; + for i in 0..num_keys { + let key = keys.get_item(i as isize, vm)?; + if let Some(value) = dict.get_item_opt(&*key, vm)? { + payload.set_item(&*key, value, vm)?; + } + } + Ok(()) + } + + fn ast_replace_set_update( + expecting: &PyRef, + iterable: Option<&PyObjectRef>, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(iterable) = iterable else { + return Ok(()); + }; + let iterable = iterable.clone().try_into_value::(vm)?; + for item in iterable.iter(vm)? { + expecting.add(item?, vm)?; + } + Ok(()) + } + + fn ast_replace_set_discard( + expecting: &PyRef, + key: &PyObject, + vm: &VirtualMachine, + ) -> PyResult { + let contained = expecting + .as_object() + .sequence_unchecked() + .contains(key, vm)?; + if contained { + vm.call_method(expecting.as_object(), "discard", (key.to_owned(),))?; + } + Ok(contained) + } + + fn ast_replace_set_difference_update( + expecting: &PyRef, + iterable: Option<&PyObjectRef>, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(iterable) = iterable else { + return Ok(()); + }; + let iterable = iterable.clone().try_into_value::(vm)?; + for item in iterable.iter(vm)? { + let item = item?; + ast_replace_set_discard(expecting, &item, vm)?; + } + Ok(()) + } + + fn ast_set_attr( + obj: &PyObject, + name: &PyObject, + value: impl Into, + vm: &VirtualMachine, + ) -> PyResult<()> { + let name = name + .to_owned() + .downcast::() + .map_err(|_| vm.new_type_error("attribute name must be string"))?; + obj.set_attr(&name, value, vm) + } + pub(crate) fn ast_replace(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { if !args.args.is_empty() { return Err(vm.new_type_error("__replace__() takes no positional arguments")); @@ -146,22 +224,13 @@ pub(crate) mod _ast { let attributes = cls.get_attr(vm.ctx.intern_str("_attributes")); let dict = zelf.as_object().dict(); - let mut expecting: std::collections::HashSet = std::collections::HashSet::new(); - if let Some(fields) = fields.clone() { - let fields: Vec = fields.try_to_value(vm)?; - for field in fields { - expecting.insert(field.as_str().to_owned()); - } - } - if let Some(attributes) = attributes.clone() { - let attributes: Vec = attributes.try_to_value(vm)?; - for attr in attributes { - expecting.insert(attr.as_str().to_owned()); - } - } + let expecting = PySet::default().into_ref(&vm.ctx); + ast_replace_set_update(&expecting, fields.as_ref(), vm)?; + ast_replace_set_update(&expecting, attributes.as_ref(), vm)?; for (key, _value) in &args.kwargs { - if !expecting.remove(key) { + let key_obj: PyObjectRef = vm.ctx.new_str(key.as_str()).into(); + if !ast_replace_set_discard(&expecting, &key_obj, vm)? { return Err(vm.new_type_error(format!( "{}.__replace__ got an unexpected keyword argument '{}'.", cls.name(), @@ -172,16 +241,9 @@ pub(crate) mod _ast { if let Some(dict) = dict.as_ref() { for (key, _value) in dict.items_vec() { - if let Ok(key) = key.downcast::() { - expecting.remove(key.as_str()); - } - } - if let Some(attributes) = attributes.clone() { - let attributes: Vec = attributes.try_to_value(vm)?; - for attr in attributes { - expecting.remove(attr.as_str()); - } + ast_replace_set_discard(&expecting, &key, vm)?; } + ast_replace_set_difference_update(&expecting, attributes.as_ref(), vm)?; } // Discard optional fields (T | None). @@ -189,20 +251,18 @@ pub(crate) mod _ast { && let Ok(field_types) = field_types.downcast::() { for (key, value) in field_types.items_vec() { - let Ok(key) = key.downcast::() else { - continue; - }; if value.fast_isinstance(vm.ctx.types.union_type) { - expecting.remove(key.as_str()); + ast_replace_set_discard(&expecting, &key, vm)?; } } } - if !expecting.is_empty() { - let mut names: Vec = expecting - .into_iter() - .map(|name| format!("{name:?}")) - .collect(); + let remaining = expecting.elements(); + if !remaining.is_empty() { + let mut names = Vec::with_capacity(remaining.len()); + for name in &remaining { + names.push(name.repr(vm)?.to_string()); + } names.sort(); let missing = names.join(", "); let count = names.len(); @@ -217,22 +277,8 @@ pub(crate) mod _ast { let payload = vm.ctx.new_dict(); if let Some(dict) = dict { - if let Some(fields) = fields { - let fields: Vec = fields.try_to_value(vm)?; - for field in fields { - if let Some(value) = dict.get_item_opt::(field.as_wtf8(), vm)? { - payload.set_item(field.as_object(), value, vm)?; - } - } - } - if let Some(attributes) = attributes { - let attributes: Vec = attributes.try_to_value(vm)?; - for attr in attributes { - if let Some(value) = dict.get_item_opt::(attr.as_wtf8(), vm)? { - payload.set_item(attr.as_object(), value, vm)?; - } - } - } + ast_replace_update_payload(&payload, fields.as_ref(), &dict, vm)?; + ast_replace_update_payload(&payload, attributes.as_ref(), &dict, vm)?; } for (key, value) in args.kwargs { payload.set_item(vm.ctx.intern_str(key), value, vm)?; @@ -327,7 +373,7 @@ pub(crate) mod _ast { } fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { - unimplemented!("use slot_new") + unreachable!("NodeAst construction is handled by slot_new") } } @@ -350,52 +396,55 @@ pub(crate) mod _ast { zelf.class().name() )) })?; - let fields: Vec = fields.try_to_value(vm)?; + let fields_seq = fields.sequence_unchecked(); + let numfields = fields_seq.length(vm)?; + let remaining_fields = PySet::default().into_ref(&vm.ctx); + ast_replace_set_update(&remaining_fields, Some(&fields), vm)?; let n_args = args.args.len(); - if n_args > fields.len() { + if n_args > numfields { return Err(vm.new_type_error(format!( "{} constructor takes at most {} positional argument{}", zelf.class().name(), - fields.len(), - if fields.len() == 1 { "" } else { "s" }, + numfields, + if numfields == 1 { "" } else { "s" }, ))); } - // Track which fields were set - let mut set_fields = std::collections::HashSet::new(); - let mut attributes: Option> = None; + let mut attributes: Option = None; - for (name, arg) in fields.iter().zip(args.args) { - zelf.set_attr(name, arg, vm)?; - set_fields.insert(name.as_str().to_owned()); + for (i, arg) in args.args.into_iter().enumerate() { + let name = fields_seq.get_item(i as isize, vm)?; + ast_set_attr(&zelf, &name, arg, vm)?; + ast_replace_set_discard(&remaining_fields, &name, vm)?; } for (key, value) in args.kwargs { - if let Some(pos) = fields.iter().position(|f| f.as_bytes() == key.as_bytes()) - && pos < n_args - { - return Err(vm.new_type_error(format!( - "{} got multiple values for argument '{}'", - zelf.class().name(), - key - ))); - } - - if fields - .iter() - .all(|field| field.as_bytes() != key.as_bytes()) - { - let attrs = if let Some(attrs) = &attributes { - attrs + let key_obj: PyObjectRef = vm.ctx.new_str(key.as_str()).into(); + let contains = fields_seq.contains(&key_obj, vm)?; + if contains { + if !ast_replace_set_discard(&remaining_fields, &key_obj, vm)? { + return Err(vm.new_type_error(format!( + "{} got multiple values for argument '{}'", + zelf.class().name(), + key + ))); + } + } else { + let attrs = if let Some(attributes) = &attributes { + attributes } else { let attrs = zelf .class() .get_attr(vm.ctx.intern_str("_attributes")) - .and_then(|attr| attr.try_to_value::>(vm).ok()) - .unwrap_or_default(); + .ok_or_else(|| { + vm.new_attribute_error(format!( + "type object '{}' has no attribute '_attributes'", + zelf.class().name() + )) + })?; attributes = Some(attrs); attributes.as_ref().unwrap() }; - if attrs.iter().all(|attr| attr.as_bytes() != key.as_bytes()) { + if !attrs.sequence_unchecked().contains(&key_obj, vm)? { let message = vm.ctx.new_str(format!( "{}.__init__ got an unexpected keyword argument '{}'. \ Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.", @@ -412,7 +461,6 @@ Support for arbitrary keyword arguments is deprecated and will be removed in Pyt } } - set_fields.insert(key.clone()); zelf.set_attr(vm.ctx.intern_str(key), value, vm)?; } @@ -425,17 +473,14 @@ Support for arbitrary keyword arguments is deprecated and will be removed in Pyt let expr_ctx_type: PyObjectRef = super::super::pyast::NodeExprContext::make_static_type().into(); - for field in &fields { - if set_fields.contains(field.as_str()) { - continue; - } - if let Some(ftype) = ft_dict.get_item_opt::(field.as_wtf8(), vm)? { + for field in remaining_fields.elements() { + if let Some(ftype) = ft_dict.get_item_opt(&*field, vm)? { if ftype.fast_isinstance(vm.ctx.types.union_type) { // Optional field (T | None) — no default } else if ftype.fast_isinstance(vm.ctx.types.generic_alias_type) { // List field (list[T]) — default to [] let empty_list: PyObjectRef = vm.ctx.new_list(vec![]).into(); - zelf.set_attr(vm.ctx.intern_str(field.as_wtf8()), empty_list, vm)?; + ast_set_attr(&zelf, &field, empty_list, vm)?; } else if ftype.is(&expr_ctx_type) { // expr_context — default to Load() let load_type = @@ -445,13 +490,15 @@ Support for arbitrary keyword arguments is deprecated and will be removed in Pyt .unwrap_or_else(|| { vm.ctx.new_base_object(load_type, Some(vm.ctx.new_dict())) }); - zelf.set_attr(vm.ctx.intern_str(field.as_wtf8()), load_instance, vm)?; + ast_set_attr(&zelf, &field, load_instance, vm)?; } else { // Required field missing: emit DeprecationWarning. + let field_repr = field.repr(vm)?; let message = vm.ctx.new_str(format!( - "{}.__init__ missing 1 required positional argument: '{}'", + "{}.__init__ missing 1 required positional argument: {}. \ +This will become an error in Python 3.15.", zelf.class().name(), - field.as_wtf8() + field_repr )); warn::warn( message.into(), @@ -461,6 +508,21 @@ Support for arbitrary keyword arguments is deprecated and will be removed in Pyt vm, )?; } + } else { + let field_repr = field.repr(vm)?; + let message = vm.ctx.new_str(format!( + "Field {} is missing from {}._field_types. \ +This will become an error in Python 3.15.", + field_repr, + zelf.class().name() + )); + warn::warn( + message.into(), + Some(vm.ctx.exceptions.deprecation_warning.to_owned()), + 1, + None, + vm, + )?; } } } @@ -510,9 +572,29 @@ Support for arbitrary keyword arguments is deprecated and will be removed in Pyt .map_err(|_| vm.new_type_error("AST is not a type"))?; let ctx = &vm.ctx; let empty_tuple = ctx.empty_tuple.clone(); + let set_empty_annotations = |typ: &Py| { + typ.set_str_attr("__annotations__", ctx.new_dict(), ctx); + }; + set_empty_annotations(&ast_type); ast_type.set_str_attr("_fields", empty_tuple.clone(), ctx); ast_type.set_str_attr("_attributes", empty_tuple.clone(), ctx); ast_type.set_str_attr("__match_args__", empty_tuple, ctx); + for typ in [ + super::super::pyast::NodeMod::static_type(), + super::super::pyast::NodeStmt::static_type(), + super::super::pyast::NodeExpr::static_type(), + super::super::pyast::NodeExprContext::static_type(), + super::super::pyast::NodeBoolOp::static_type(), + super::super::pyast::NodeOperator::static_type(), + super::super::pyast::NodeUnaryOp::static_type(), + super::super::pyast::NodeCmpOp::static_type(), + super::super::pyast::NodeExceptHandler::static_type(), + super::super::pyast::NodePattern::static_type(), + super::super::pyast::NodeTypeIgnore::static_type(), + super::super::pyast::NodeTypeParam::static_type(), + ] { + set_empty_annotations(typ); + } const AST_REDUCE: PyMethodDef = PyMethodDef::new_const( "__reduce__", diff --git a/crates/vm/src/stdlib/_ast/repr.rs b/crates/vm/src/stdlib/_ast/repr.rs index 2897447fbec..57f00c095e0 100644 --- a/crates/vm/src/stdlib/_ast/repr.rs +++ b/crates/vm/src/stdlib/_ast/repr.rs @@ -1,7 +1,8 @@ use crate::{ AsObject, PyObjectRef, PyResult, VirtualMachine, - builtins::{PyList, PyTuple}, + builtins::{PyList, PyStr, PyTuple}, class::PyClassImpl, + recursion::ReprGuard, stdlib::_ast::NodeAst, }; use rustpython_common::wtf8::Wtf8Buf; @@ -33,9 +34,7 @@ fn repr_ast_list(vm: &VirtualMachine, items: Vec, depth: usize) -> rendered.push_wtf8(&parts[0]); } if items.len() > 2 { - if !parts[0].is_empty() { - rendered.push_wtf8(", ...".as_ref()); - } + rendered.push_wtf8(", ...".as_ref()); if parts.len() > 1 { rendered.push_wtf8(", ".as_ref()); rendered.push_wtf8(&parts[1]); @@ -75,9 +74,7 @@ fn repr_ast_tuple(vm: &VirtualMachine, items: Vec, depth: usize) -> rendered.push_wtf8(&parts[0]); } if items.len() > 2 { - if !parts[0].is_empty() { - rendered.push_wtf8(", ...".as_ref()); - } + rendered.push_wtf8(", ...".as_ref()); if parts.len() > 1 { rendered.push_wtf8(", ".as_ref()); rendered.push_wtf8(&parts[1]); @@ -86,9 +83,6 @@ fn repr_ast_tuple(vm: &VirtualMachine, items: Vec, depth: usize) -> rendered.push_wtf8(", ".as_ref()); rendered.push_wtf8(&parts[1]); } - if items.len() == 1 { - rendered.push_wtf8(",".as_ref()); - } rendered.push_wtf8(")".as_ref()); Ok(rendered) } @@ -104,18 +98,24 @@ pub(crate) fn repr_ast_node( s.push_wtf8("(...)".as_ref()); return Ok(s); } + let Some(_guard) = ReprGuard::enter(vm, obj.as_object()) else { + let mut s = Wtf8Buf::from(&*cls.name()); + s.push_wtf8("(...)".as_ref()); + return Ok(s); + }; - let fields = cls.get_attr(vm.ctx.intern_str("_fields")); - let fields = match fields { - Some(fields) => fields.try_to_value::>(vm)?, + let fields = match cls.get_attr(vm.ctx.intern_str("_fields")) { + Some(fields) => fields, None => { let mut s = Wtf8Buf::from(&*cls.name()); s.push_wtf8("(...)".as_ref()); return Ok(s); } }; + let fields = fields.sequence_unchecked(); + let numfields = fields.length(vm)?; - if fields.is_empty() { + if numfields == 0 { let mut s = Wtf8Buf::from(&*cls.name()); s.push_wtf8("()".as_ref()); return Ok(s); @@ -124,8 +124,12 @@ pub(crate) fn repr_ast_node( let mut rendered = Wtf8Buf::from(&*cls.name()); rendered.push_wtf8("(".as_ref()); - for (idx, field) in fields.iter().enumerate() { - let value = obj.get_attr(field, vm)?; + for idx in 0..numfields { + let field = fields.get_item(idx as isize, vm)?; + let field = field + .downcast::() + .map_err(|_| vm.new_type_error("attribute name must be string"))?; + let value = obj.get_attr(&field, vm)?; let value_repr = if value.fast_isinstance(vm.ctx.types.list_type) { let list = value .downcast::() diff --git a/crates/vm/src/stdlib/_ast/statement.rs b/crates/vm/src/stdlib/_ast/statement.rs index 43d1162a402..7f94dce5da0 100644 --- a/crates/vm/src/stdlib/_ast/statement.rs +++ b/crates/vm/src/stdlib/_ast/statement.rs @@ -1,7 +1,66 @@ use super::*; -use crate::stdlib::_ast::argument::{merge_class_def_args, split_class_def_args}; +use crate::stdlib::_ast::argument::{ + KeywordArguments, PositionalArguments, merge_class_def_args, split_class_def_args, +}; +use crate::stdlib::_ast::exception::except_handler_from_object_unvalidated_range; +use crate::stdlib::_ast::type_parameters::type_params_from_field; use rustpython_compiler_core::SourceFile; +fn public_decorator_expr_list(values: &[Option]) -> Vec> { + values + .iter() + .map(|value| value.as_ref().map(|decorator| decorator.expression.clone())) + .collect() +} + +fn lower_public_decorator_list(values: Vec>) -> ast::DecoratorList { + values + .into_iter() + .map(|value| { + value.unwrap_or_else(|| ast::Decorator { + range: Default::default(), + node_index: Default::default(), + expression: public_null_expr_placeholder(), + }) + }) + .collect() +} + +fn definition_range_from_name( + source_file: &SourceFile, + name_start: TextSize, + end: TextSize, + keyword: &str, +) -> TextRange { + let source_code = source_file.to_source_code(); + let line = source_code.line_index(name_start); + let line_start = source_code.line_start(line); + let keyword_start = source_code + .slice(TextRange::new(line_start, name_start)) + .rfind(keyword) + .map_or(line_start, |offset| { + line_start + TextSize::new(offset as u32) + }); + TextRange::new(keyword_start, end) +} + +fn register_public_ast_stmt_type_comment( + node_index: &ast::AtomicNodeIndex, + type_comment: Option, +) { + if let Some(type_comment) = type_comment { + super::constant::register_public_ast_stmt_type_comment(node_index, type_comment); + } +} + +fn public_ast_stmt_type_comment_object( + vm: &VirtualMachine, + node_index: ast::NodeIndex, +) -> PyObjectRef { + super::constant::public_ast_stmt_type_comment_object(node_index) + .unwrap_or_else(|| vm.ctx.none()) +} + // sum impl Node for ast::Stmt { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -31,127 +90,314 @@ impl Node for ast::Stmt { Self::Break(cons) => cons.ast_to_object(vm, source_file), Self::Continue(cons) => cons.ast_to_object(vm, source_file), Self::IpyEscapeCommand(_) => { - unimplemented!("IPython escape command is not allowed in Python AST") + unreachable!("IPython escape command is not part of Python AST") } } } - #[expect(clippy::if_same_then_else, reason = "Looks better here")] fn ast_from_object( vm: &VirtualMachine, source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeStmtFunctionDef::static_type()) { - Self::FunctionDef(ast::StmtFunctionDef::ast_from_object( + if vm.is_none(&object) { + return Err(vm.new_value_error("None disallowed in statement list")); + } + enum StmtKind { + FunctionDef { is_async: bool }, + ClassDef, + Return, + Delete, + Assign, + TypeAlias, + AugAssign, + AnnAssign, + For { is_async: bool }, + While, + If, + With { is_async: bool }, + Match, + Raise, + Try { is_star: bool }, + Assert, + Import, + ImportFrom, + Global, + Nonlocal, + Expr, + Pass, + Break, + Continue, + } + let kind = if is_node_instance(vm, &object, pyast::NodeStmtFunctionDef::static_type())? { + StmtKind::FunctionDef { is_async: false } + } else if is_node_instance(vm, &object, pyast::NodeStmtAsyncFunctionDef::static_type())? { + StmtKind::FunctionDef { is_async: true } + } else if is_node_instance(vm, &object, pyast::NodeStmtClassDef::static_type())? { + StmtKind::ClassDef + } else if is_node_instance(vm, &object, pyast::NodeStmtReturn::static_type())? { + StmtKind::Return + } else if is_node_instance(vm, &object, pyast::NodeStmtDelete::static_type())? { + StmtKind::Delete + } else if is_node_instance(vm, &object, pyast::NodeStmtAssign::static_type())? { + StmtKind::Assign + } else if is_node_instance(vm, &object, pyast::NodeStmtTypeAlias::static_type())? { + StmtKind::TypeAlias + } else if is_node_instance(vm, &object, pyast::NodeStmtAugAssign::static_type())? { + StmtKind::AugAssign + } else if is_node_instance(vm, &object, pyast::NodeStmtAnnAssign::static_type())? { + StmtKind::AnnAssign + } else if is_node_instance(vm, &object, pyast::NodeStmtFor::static_type())? { + StmtKind::For { is_async: false } + } else if is_node_instance(vm, &object, pyast::NodeStmtAsyncFor::static_type())? { + StmtKind::For { is_async: true } + } else if is_node_instance(vm, &object, pyast::NodeStmtWhile::static_type())? { + StmtKind::While + } else if is_node_instance(vm, &object, pyast::NodeStmtIf::static_type())? { + StmtKind::If + } else if is_node_instance(vm, &object, pyast::NodeStmtWith::static_type())? { + StmtKind::With { is_async: false } + } else if is_node_instance(vm, &object, pyast::NodeStmtAsyncWith::static_type())? { + StmtKind::With { is_async: true } + } else if is_node_instance(vm, &object, pyast::NodeStmtMatch::static_type())? { + StmtKind::Match + } else if is_node_instance(vm, &object, pyast::NodeStmtRaise::static_type())? { + StmtKind::Raise + } else if is_node_instance(vm, &object, pyast::NodeStmtTry::static_type())? { + StmtKind::Try { is_star: false } + } else if is_node_instance(vm, &object, pyast::NodeStmtTryStar::static_type())? { + StmtKind::Try { is_star: true } + } else if is_node_instance(vm, &object, pyast::NodeStmtAssert::static_type())? { + StmtKind::Assert + } else if is_node_instance(vm, &object, pyast::NodeStmtImport::static_type())? { + StmtKind::Import + } else if is_node_instance(vm, &object, pyast::NodeStmtImportFrom::static_type())? { + StmtKind::ImportFrom + } else if is_node_instance(vm, &object, pyast::NodeStmtGlobal::static_type())? { + StmtKind::Global + } else if is_node_instance(vm, &object, pyast::NodeStmtNonlocal::static_type())? { + StmtKind::Nonlocal + } else if is_node_instance(vm, &object, pyast::NodeStmtExpr::static_type())? { + StmtKind::Expr + } else if is_node_instance(vm, &object, pyast::NodeStmtPass::static_type())? { + StmtKind::Pass + } else if is_node_instance(vm, &object, pyast::NodeStmtBreak::static_type())? { + StmtKind::Break + } else if is_node_instance(vm, &object, pyast::NodeStmtContinue::static_type())? { + StmtKind::Continue + } else { + return Err(vm.new_type_error(format!( + "expected some sort of stmt, but got {}", + object.repr(vm)? + ))); + }; + let range = stmt_range_from_object(vm, source_file, object.clone())?; + Ok(match kind { + StmtKind::FunctionDef { is_async } => Self::FunctionDef( + stmt_function_def_from_object_with_range(vm, source_file, object, range, is_async)?, + ), + StmtKind::ClassDef => Self::ClassDef(stmt_class_def_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeStmtAsyncFunctionDef::static_type()) { - Self::FunctionDef(ast::StmtFunctionDef::ast_from_object( + range, + )?), + StmtKind::Return => Self::Return(stmt_return_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeStmtClassDef::static_type()) { - Self::ClassDef(ast::StmtClassDef::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtReturn::static_type()) { - Self::Return(ast::StmtReturn::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtDelete::static_type()) { - Self::Delete(ast::StmtDelete::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtAssign::static_type()) { - Self::Assign(ast::StmtAssign::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtTypeAlias::static_type()) { - Self::TypeAlias(ast::StmtTypeAlias::ast_from_object( + range, + )?), + StmtKind::Delete => Self::Delete(stmt_delete_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeStmtAugAssign::static_type()) { - Self::AugAssign(ast::StmtAugAssign::ast_from_object( + range, + )?), + StmtKind::Assign => Self::Assign(stmt_assign_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeStmtAnnAssign::static_type()) { - Self::AnnAssign(ast::StmtAnnAssign::ast_from_object( + range, + )?), + StmtKind::TypeAlias => Self::TypeAlias(stmt_type_alias_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeStmtFor::static_type()) { - Self::For(ast::StmtFor::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtAsyncFor::static_type()) { - Self::For(ast::StmtFor::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtWhile::static_type()) { - Self::While(ast::StmtWhile::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtIf::static_type()) { - Self::If(ast::StmtIf::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtWith::static_type()) { - Self::With(ast::StmtWith::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtAsyncWith::static_type()) { - Self::With(ast::StmtWith::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtMatch::static_type()) { - Self::Match(ast::StmtMatch::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtRaise::static_type()) { - Self::Raise(ast::StmtRaise::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtTry::static_type()) { - Self::Try(ast::StmtTry::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtTryStar::static_type()) { - Self::Try(ast::StmtTry::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtAssert::static_type()) { - Self::Assert(ast::StmtAssert::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtImport::static_type()) { - Self::Import(ast::StmtImport::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtImportFrom::static_type()) { - Self::ImportFrom(ast::StmtImportFrom::ast_from_object( + range, + )?), + StmtKind::AugAssign => Self::AugAssign(stmt_aug_assign_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeStmtGlobal::static_type()) { - Self::Global(ast::StmtGlobal::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtNonlocal::static_type()) { - Self::Nonlocal(ast::StmtNonlocal::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtExpr::static_type()) { - Self::Expr(ast::StmtExpr::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtPass::static_type()) { - Self::Pass(ast::StmtPass::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtBreak::static_type()) { - Self::Break(ast::StmtBreak::ast_from_object(vm, source_file, object)?) - } else if cls.is(pyast::NodeStmtContinue::static_type()) { - Self::Continue(ast::StmtContinue::ast_from_object(vm, source_file, object)?) - } else if vm.is_none(&object) { - return Err(vm.new_value_error("None disallowed in statement list")); - } else { - return Err(vm.new_type_error(format!( - "expected some sort of stmt, but got {}", - object.repr(vm)? - ))); + range, + )?), + StmtKind::AnnAssign => Self::AnnAssign(stmt_ann_assign_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::For { is_async } => Self::For(stmt_for_from_object_with_range( + vm, + source_file, + object, + range, + is_async, + )?), + StmtKind::While => Self::While(stmt_while_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::If => Self::If(elif_else_clause::ast_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::With { is_async } => Self::With(stmt_with_from_object_with_range( + vm, + source_file, + object, + range, + is_async, + )?), + StmtKind::Match => Self::Match(stmt_match_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Raise => Self::Raise(stmt_raise_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Try { is_star } => Self::Try(stmt_try_from_object_with_range( + vm, + source_file, + object, + range, + is_star, + )?), + StmtKind::Assert => Self::Assert(stmt_assert_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Import => Self::Import(stmt_import_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::ImportFrom => Self::ImportFrom(stmt_import_from_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Global => Self::Global(stmt_global_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Nonlocal => Self::Nonlocal(stmt_nonlocal_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Expr => Self::Expr(stmt_expr_from_object_with_range( + vm, + source_file, + object, + range, + )?), + StmtKind::Pass => Self::Pass(stmt_pass_from_object_with_range(range)), + StmtKind::Break => Self::Break(stmt_break_from_object_with_range(range)), + StmtKind::Continue => Self::Continue(stmt_continue_from_object_with_range(range)), }) } } // constructor +fn stmt_function_def_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, + is_async: bool, +) -> PyResult { + let typ = if is_async { + "AsyncFunctionDef" + } else { + "FunctionDef" + }; + let name = get_required_identifier_field(vm, source_file, &object, "name", typ)?; + let parameters = Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "args", typ)?, + )?; + let body: Vec> = get_node_list_field(vm, source_file, &object, "body", typ)?; + let decorator_list: Vec> = + get_node_list_field(vm, source_file, &object, "decorator_list", typ)?; + let public_decorator_list = public_decorator_expr_list(&decorator_list); + let node_index = public_node_list_overrides_node_index( + vec![(super::constant::PublicAstStmtListField::Body, &body)], + vec![( + super::constant::PublicAstExprListField::DecoratorList, + &public_decorator_list, + )], + None, + None, + ); + let body = lower_public_stmt_list(body); + let decorator_list = lower_public_decorator_list(decorator_list); + let returns = get_node_field_opt(vm, &object, "returns")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?; + let type_comment = get_ast_string_field_opt(vm, &object, "type_comment")?; + register_public_ast_stmt_type_comment(&node_index, type_comment); + let type_params = type_params_from_field(vm, source_file, &object, "type_params", typ)?; + Ok(ast::StmtFunctionDef { + node_index, + name, + parameters, + body, + decorator_list, + returns, + type_params, + range, + is_async, + }) +} + impl Node for ast::StmtFunctionDef { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, name, parameters, body, decorator_list, returns, - // type_comment, type_params, is_async, - range: _range, + range, } = self; - let source_code = source_file.to_source_code(); - let def_line = source_code.line_index(name.range.start()); - let range = TextRange::new(source_code.line_start(def_line), _range.end()); + let range = definition_range_from_name( + source_file, + name.range.start(), + range.end(), + if is_async { "async" } else { "def" }, + ); let cls = if !is_async { pyast::NodeStmtFunctionDef::static_type().to_owned() @@ -165,18 +411,36 @@ impl Node for ast::StmtFunctionDef { .unwrap(); dict.set_item("args", parameters.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("body", body.ast_to_object(vm, source_file), vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("body", body, vm).unwrap(); dict.set_item( "decorator_list", - decorator_list.ast_to_object(vm, source_file), + super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::DecoratorList, + ) + .map_or_else( + || decorator_list.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ), vm, ) .unwrap(); dict.set_item("returns", returns.ast_to_object(vm, source_file), vm) .unwrap(); - // Ruff AST doesn't track type_comment, so always set to None - dict.set_item("type_comment", vm.ctx.none(), vm).unwrap(); + dict.set_item( + "type_comment", + public_ast_stmt_type_comment_object(vm, node_index.load()), + vm, + ) + .unwrap(); dict.set_item( "type_params", type_params.map_or_else( @@ -195,66 +459,74 @@ impl Node for ast::StmtFunctionDef { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - let _cls = _object.class(); - let is_async = _cls.is(pyast::NodeStmtAsyncFunctionDef::static_type()); - let range = range_from_object(_vm, source_file, _object.clone(), "FunctionDef")?; - Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "name", "FunctionDef")?, - )?, - parameters: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "args", "FunctionDef")?, - )?, - body: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "body", "FunctionDef")?, - )?, - decorator_list: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "decorator_list", "FunctionDef")?, - )?, - returns: get_node_field_opt(_vm, &_object, "returns")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - // TODO: Ruff ignores type_comment during parsing - // type_comment: get_node_field_opt(_vm, &_object, "type_comment")? - // .map(|obj| Node::ast_from_object(_vm, obj)) - // .transpose()?, - type_params: Node::ast_from_object( - _vm, - source_file, - get_node_field_opt(_vm, &_object, "type_params")? - .unwrap_or_else(|| _vm.ctx.new_list(Vec::new()).into()), - )?, - range, - is_async, - }) + let is_async = is_node_instance( + _vm, + &_object, + pyast::NodeStmtAsyncFunctionDef::static_type(), + )?; + let typ = if is_async { + "AsyncFunctionDef" + } else { + "FunctionDef" + }; + let range = range_from_object(_vm, source_file, _object.clone(), typ)?; + stmt_function_def_from_object_with_range(_vm, source_file, _object, range, is_async) } } // constructor +fn stmt_class_def_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let name = get_required_identifier_field(vm, source_file, &object, "name", "ClassDef")?; + let bases = PositionalArguments::ast_from_field(vm, source_file, &object, "bases", "ClassDef")?; + let keywords = + KeywordArguments::ast_from_field(vm, source_file, &object, "keywords", "ClassDef")?; + let body: Vec> = + get_node_list_field(vm, source_file, &object, "body", "ClassDef")?; + let decorator_list: Vec> = + get_node_list_field(vm, source_file, &object, "decorator_list", "ClassDef")?; + let public_decorator_list = public_decorator_expr_list(&decorator_list); + let node_index = public_node_list_overrides_node_index( + vec![(super::constant::PublicAstStmtListField::Body, &body)], + vec![( + super::constant::PublicAstExprListField::DecoratorList, + &public_decorator_list, + )], + None, + None, + ); + let body = lower_public_stmt_list(body); + let decorator_list = lower_public_decorator_list(decorator_list); + let type_params = type_params_from_field(vm, source_file, &object, "type_params", "ClassDef")?; + Ok(ast::StmtClassDef { + node_index, + name, + arguments: merge_class_def_args(Some(bases), Some(keywords)), + body, + decorator_list, + type_params, + range, + }) +} + impl Node for ast::StmtClassDef { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, name, arguments, body, decorator_list, type_params, - range: _range, + range, } = self; let (bases, keywords) = split_class_def_args(arguments); - let source_code = source_file.to_source_code(); - let class_line = source_code.line_index(name.range.start()); - let range = TextRange::new(source_code.line_start(class_line), _range.end()); + let range = + definition_range_from_name(source_file, name.range.start(), range.end(), "class"); let node = NodeAst .into_ref_with_type(_vm, pyast::NodeStmtClassDef::static_type().to_owned()) .unwrap(); @@ -279,11 +551,25 @@ impl Node for ast::StmtClassDef { _vm, ) .unwrap(); - dict.set_item("body", body.ast_to_object(_vm, source_file), _vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("body", body, _vm).unwrap(); dict.set_item( "decorator_list", - decorator_list.ast_to_object(_vm, source_file), + super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::DecoratorList, + ) + .map_or_else( + || decorator_list.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ), _vm, ) .unwrap(); @@ -304,45 +590,26 @@ impl Node for ast::StmtClassDef { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - let bases = Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "bases", "ClassDef")?, - )?; - let keywords = Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "keywords", "ClassDef")?, - )?; - Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "name", "ClassDef")?, - )?, - arguments: merge_class_def_args(bases, keywords), - body: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "body", "ClassDef")?, - )?, - decorator_list: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "decorator_list", "ClassDef")?, - )?, - type_params: Node::ast_from_object( - _vm, - source_file, - get_node_field_opt(_vm, &_object, "type_params")? - .unwrap_or_else(|| _vm.ctx.new_list(Vec::new()).into()), - )?, - range: range_from_object(_vm, source_file, _object, "ClassDef")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "ClassDef")?; + stmt_class_def_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_return_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtReturn { + node_index: Default::default(), + value: get_node_field_opt(vm, &object, "value")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::StmtReturn { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -364,20 +631,32 @@ impl Node for ast::StmtReturn { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: get_node_field_opt(_vm, &_object, "value")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - range: range_from_object(_vm, source_file, _object, "Return")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Return")?; + stmt_return_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_delete_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let targets: Vec> = + get_node_list_field(vm, source_file, &object, "targets", "Delete")?; + let (node_index, targets) = + public_expr_list_from_values(super::constant::PublicAstExprListField::Targets, targets); + Ok(ast::StmtDelete { + node_index, + targets, + range, + }) +} + impl Node for ast::StmtDelete { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, targets, range: _range, } = self; @@ -385,8 +664,15 @@ impl Node for ast::StmtDelete { .into_ref_with_type(_vm, pyast::NodeStmtDelete::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("targets", targets.ast_to_object(_vm, source_file), _vm) - .unwrap(); + let targets = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Targets, + ) + .map_or_else( + || targets.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("targets", targets, _vm).unwrap(); node_add_location(&dict, _range, _vm, source_file); node.into() } @@ -395,38 +681,62 @@ impl Node for ast::StmtDelete { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - targets: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "targets", "Delete")?, - )?, - range: range_from_object(_vm, source_file, _object, "Delete")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Delete")?; + stmt_delete_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_assign_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let targets: Vec> = + get_node_list_field(vm, source_file, &object, "targets", "Assign")?; + let (node_index, targets) = + public_expr_list_from_values(super::constant::PublicAstExprListField::Targets, targets); + let value = get_required_node_field(vm, source_file, &object, "value", "Assign")?; + let type_comment = get_ast_string_field_opt(vm, &object, "type_comment")?; + register_public_ast_stmt_type_comment(&node_index, type_comment); + Ok(ast::StmtAssign { + node_index, + targets, + value, + range, + }) +} + impl Node for ast::StmtAssign { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, targets, value, - // type_comment, range, } = self; let node = NodeAst .into_ref_with_type(vm, pyast::NodeStmtAssign::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("targets", targets.ast_to_object(vm, source_file), vm) - .unwrap(); + let targets = super::constant::public_ast_expr_list_object( + node_index.load(), + super::constant::PublicAstExprListField::Targets, + ) + .map_or_else( + || targets.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ); + dict.set_item("targets", targets, vm).unwrap(); dict.set_item("value", value.ast_to_object(vm, source_file), vm) .unwrap(); - // TODO - dict.set_item("type_comment", vm.ctx.none(), vm).unwrap(); + dict.set_item( + "type_comment", + public_ast_stmt_type_comment_object(vm, node_index.load()), + vm, + ) + .unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -435,27 +745,27 @@ impl Node for ast::StmtAssign { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - targets: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "targets", "Assign")?, - )?, - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "Assign")?, - )?, - // type_comment: get_node_field_opt(_vm, &_object, "type_comment")? - // .map(|obj| Node::ast_from_object(_vm, obj)) - // .transpose()?, - range: range_from_object(vm, source_file, object, "Assign")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Assign")?; + stmt_assign_from_object_with_range(vm, source_file, object, range) } } // constructor +fn stmt_type_alias_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtTypeAlias { + node_index: Default::default(), + name: get_required_node_field(vm, source_file, &object, "name", "TypeAlias")?, + type_params: type_params_from_field(vm, source_file, &object, "type_params", "TypeAlias")?, + value: get_required_node_field(vm, source_file, &object, "value", "TypeAlias")?, + range, + }) +} + impl Node for ast::StmtTypeAlias { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -491,29 +801,31 @@ impl Node for ast::StmtTypeAlias { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "name", "TypeAlias")?, - )?, - type_params: Node::ast_from_object( - _vm, - source_file, - get_node_field_opt(_vm, &_object, "type_params")?.unwrap_or_else(|| _vm.ctx.none()), - )?, - value: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "value", "TypeAlias")?, - )?, - range: range_from_object(_vm, source_file, _object, "TypeAlias")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "TypeAlias")?; + stmt_type_alias_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_aug_assign_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtAugAssign { + node_index: Default::default(), + target: get_required_node_field(vm, source_file, &object, "target", "AugAssign")?, + op: Node::ast_from_object( + vm, + source_file, + get_node_field_required(vm, &object, "op", "AugAssign")?, + )?, + value: get_required_node_field(vm, source_file, &object, "value", "AugAssign")?, + range, + }) +} + impl Node for ast::StmtAugAssign { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -541,33 +853,41 @@ impl Node for ast::StmtAugAssign { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - target: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "target", "AugAssign")?, - )?, - op: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "op", "AugAssign")?, - )?, - value: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "value", "AugAssign")?, - )?, - range: range_from_object(_vm, source_file, _object, "AugAssign")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "AugAssign")?; + stmt_aug_assign_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_ann_assign_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let simple = node_object_to_i32(vm, get_node_field(vm, &object, "simple", "AnnAssign")?)?; + let node_index = ast::AtomicNodeIndex::NONE; + if simple != 0 && simple != 1 { + node_index.set(super::constant::register_public_ast_ann_assign_simple( + simple, + )); + } + Ok(ast::StmtAnnAssign { + node_index, + target: get_required_node_field(vm, source_file, &object, "target", "AnnAssign")?, + annotation: get_required_node_field(vm, source_file, &object, "annotation", "AnnAssign")?, + value: get_node_field_opt(vm, &object, "value")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + simple: simple != 0, + range, + }) +} + impl Node for ast::StmtAnnAssign { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, target, annotation, value, @@ -588,8 +908,12 @@ impl Node for ast::StmtAnnAssign { .unwrap(); dict.set_item("value", value.ast_to_object(_vm, source_file), _vm) .unwrap(); - dict.set_item("simple", simple.ast_to_object(_vm, source_file), _vm) - .unwrap(); + let simple = super::constant::public_ast_ann_assign_simple_object(node_index.load()) + .map_or_else( + || simple.ast_to_object(_vm, source_file), + |simple| _vm.ctx.new_int(simple).into(), + ); + dict.set_item("simple", simple, _vm).unwrap(); node_add_location(&dict, _range, _vm, source_file); node.into() } @@ -598,42 +922,53 @@ impl Node for ast::StmtAnnAssign { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - target: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "target", "AnnAssign")?, - )?, - annotation: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "annotation", "AnnAssign")?, - )?, - value: get_node_field_opt(_vm, &_object, "value")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - simple: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "simple", "AnnAssign")?, - )?, - range: range_from_object(_vm, source_file, _object, "AnnAssign")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "AnnAssign")?; + stmt_ann_assign_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_for_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, + is_async: bool, +) -> PyResult { + let typ = if is_async { "AsyncFor" } else { "For" }; + let target = get_required_node_field(vm, source_file, &object, "target", typ)?; + let iter = get_required_node_field(vm, source_file, &object, "iter", typ)?; + let body: Vec> = get_node_list_field(vm, source_file, &object, "body", typ)?; + let orelse: Vec> = + get_node_list_field(vm, source_file, &object, "orelse", typ)?; + let node_index = public_stmt_lists_node_index([ + (super::constant::PublicAstStmtListField::Body, &body), + (super::constant::PublicAstStmtListField::Orelse, &orelse), + ]); + let body = lower_public_stmt_list(body); + let orelse = lower_public_stmt_list(orelse); + let type_comment = get_ast_string_field_opt(vm, &object, "type_comment")?; + register_public_ast_stmt_type_comment(&node_index, type_comment); + Ok(ast::StmtFor { + node_index, + target, + iter, + body, + orelse, + range, + is_async, + }) +} + impl Node for ast::StmtFor { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, is_async, target, iter, body, orelse, - // type_comment, range: _range, } = self; @@ -649,12 +984,30 @@ impl Node for ast::StmtFor { .unwrap(); dict.set_item("iter", iter.ast_to_object(_vm, source_file), _vm) .unwrap(); - dict.set_item("body", body.ast_to_object(_vm, source_file), _vm) - .unwrap(); - dict.set_item("orelse", orelse.ast_to_object(_vm, source_file), _vm) - .unwrap(); - // Ruff AST doesn't track type_comment, so always set to None - dict.set_item("type_comment", _vm.ctx.none(), _vm).unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("body", body, _vm).unwrap(); + let orelse = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + ) + .map_or_else( + || orelse.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("orelse", orelse, _vm).unwrap(); + dict.set_item( + "type_comment", + public_ast_stmt_type_comment_object(_vm, node_index.load()), + _vm, + ) + .unwrap(); node_add_location(&dict, _range, _vm, source_file); node.into() } @@ -664,48 +1017,45 @@ impl Node for ast::StmtFor { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - let _cls = _object.class(); debug_assert!( - _cls.is(pyast::NodeStmtFor::static_type()) - || _cls.is(pyast::NodeStmtAsyncFor::static_type()) + is_node_instance(_vm, &_object, pyast::NodeStmtFor::static_type())? + || is_node_instance(_vm, &_object, pyast::NodeStmtAsyncFor::static_type())? ); - let is_async = _cls.is(pyast::NodeStmtAsyncFor::static_type()); - Ok(Self { - node_index: Default::default(), - target: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "target", "For")?, - )?, - iter: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "iter", "For")?, - )?, - body: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "body", "For")?, - )?, - orelse: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "orelse", "For")?, - )?, - // type_comment: get_node_field_opt(_vm, &_object, "type_comment")? - // .map(|obj| Node::ast_from_object(_vm, obj)) - // .transpose()?, - range: range_from_object(_vm, source_file, _object, "For")?, - is_async, - }) + let is_async = is_node_instance(_vm, &_object, pyast::NodeStmtAsyncFor::static_type())?; + let typ = if is_async { "AsyncFor" } else { "For" }; + let range = range_from_object(_vm, source_file, _object.clone(), typ)?; + stmt_for_from_object_with_range(_vm, source_file, _object, range, is_async) } } // constructor +fn stmt_while_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let body: Vec> = + get_node_list_field(vm, source_file, &object, "body", "While")?; + let orelse: Vec> = + get_node_list_field(vm, source_file, &object, "orelse", "While")?; + let node_index = public_stmt_lists_node_index([ + (super::constant::PublicAstStmtListField::Body, &body), + (super::constant::PublicAstStmtListField::Orelse, &orelse), + ]); + Ok(ast::StmtWhile { + node_index, + test: get_required_node_field(vm, source_file, &object, "test", "While")?, + body: lower_public_stmt_list(body), + orelse: lower_public_stmt_list(orelse), + range, + }) +} + impl Node for ast::StmtWhile { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, test, body, orelse, @@ -717,10 +1067,24 @@ impl Node for ast::StmtWhile { let dict = node.as_object().dict().unwrap(); dict.set_item("test", test.ast_to_object(_vm, source_file), _vm) .unwrap(); - dict.set_item("body", body.ast_to_object(_vm, source_file), _vm) - .unwrap(); - dict.set_item("orelse", orelse.ast_to_object(_vm, source_file), _vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("body", body, _vm).unwrap(); + let orelse = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + ) + .map_or_else( + || orelse.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("orelse", orelse, _vm).unwrap(); node_add_location(&dict, _range, _vm, source_file); node.into() } @@ -730,32 +1094,15 @@ impl Node for ast::StmtWhile { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - test: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "test", "While")?, - )?, - body: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "body", "While")?, - )?, - orelse: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "orelse", "While")?, - )?, - range: range_from_object(_vm, source_file, _object, "While")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "While")?; + stmt_while_from_object_with_range(_vm, source_file, _object, range) } } // constructor impl Node for ast::StmtIf { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, test, body, range, @@ -763,7 +1110,7 @@ impl Node for ast::StmtIf { } = self; elif_else_clause::ast_to_object( ast::ElifElseClause { - node_index: Default::default(), + node_index, range, test: Some(*test), body, @@ -778,18 +1125,41 @@ impl Node for ast::StmtIf { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - elif_else_clause::ast_from_object(vm, source_file, object) + let range = range_from_object(vm, source_file, object.clone(), "If")?; + elif_else_clause::ast_from_object_with_range(vm, source_file, object, range) } } // constructor +fn stmt_with_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, + is_async: bool, +) -> PyResult { + let typ = if is_async { "AsyncWith" } else { "With" }; + let items = get_node_list_field(vm, source_file, &object, "items", typ)?; + let body: Vec> = get_node_list_field(vm, source_file, &object, "body", typ)?; + let (node_index, body) = + public_stmt_list_from_values(super::constant::PublicAstStmtListField::Body, body); + let type_comment = get_ast_string_field_opt(vm, &object, "type_comment")?; + register_public_ast_stmt_type_comment(&node_index, type_comment); + Ok(ast::StmtWith { + node_index, + items, + body, + range, + is_async, + }) +} + impl Node for ast::StmtWith { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, is_async, items, body, - // type_comment, range: _range, } = self; @@ -803,10 +1173,21 @@ impl Node for ast::StmtWith { let dict = node.as_object().dict().unwrap(); dict.set_item("items", items.ast_to_object(_vm, source_file), _vm) .unwrap(); - dict.set_item("body", body.ast_to_object(_vm, source_file), _vm) - .unwrap(); - // Ruff AST doesn't track type_comment, so always set to None - dict.set_item("type_comment", _vm.ctx.none(), _vm).unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("body", body, _vm).unwrap(); + dict.set_item( + "type_comment", + public_ast_stmt_type_comment_object(_vm, node_index.load()), + _vm, + ) + .unwrap(); node_add_location(&dict, _range, _vm, source_file); node.into() } @@ -815,33 +1196,31 @@ impl Node for ast::StmtWith { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - let _cls = _object.class(); debug_assert!( - _cls.is(pyast::NodeStmtWith::static_type()) - || _cls.is(pyast::NodeStmtAsyncWith::static_type()) + is_node_instance(_vm, &_object, pyast::NodeStmtWith::static_type())? + || is_node_instance(_vm, &_object, pyast::NodeStmtAsyncWith::static_type())? ); - let is_async = _cls.is(pyast::NodeStmtAsyncWith::static_type()); - Ok(Self { - node_index: Default::default(), - items: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "items", "With")?, - )?, - body: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "body", "With")?, - )?, - // type_comment: get_node_field_opt(_vm, &_object, "type_comment")? - // .map(|obj| Node::ast_from_object(_vm, obj)) - // .transpose()?, - range: range_from_object(_vm, source_file, _object, "With")?, - is_async, - }) + let is_async = is_node_instance(_vm, &_object, pyast::NodeStmtAsyncWith::static_type())?; + let typ = if is_async { "AsyncWith" } else { "With" }; + let range = range_from_object(_vm, source_file, _object.clone(), typ)?; + stmt_with_from_object_with_range(_vm, source_file, _object, range, is_async) } } // constructor +fn stmt_match_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtMatch { + node_index: Default::default(), + subject: get_required_node_field(vm, source_file, &object, "subject", "Match")?, + cases: get_node_list_field(vm, source_file, &object, "cases", "Match")?, + range, + }) +} + impl Node for ast::StmtMatch { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -866,23 +1245,29 @@ impl Node for ast::StmtMatch { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - subject: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "subject", "Match")?, - )?, - cases: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "cases", "Match")?, - )?, - range: range_from_object(_vm, source_file, _object, "Match")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Match")?; + stmt_match_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_raise_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtRaise { + node_index: Default::default(), + exc: get_node_field_opt(vm, &object, "exc")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + cause: get_node_field_opt(vm, &object, "cause")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::StmtRaise { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -907,23 +1292,119 @@ impl Node for ast::StmtRaise { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - exc: get_node_field_opt(_vm, &_object, "exc")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - cause: get_node_field_opt(_vm, &_object, "cause")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - range: range_from_object(_vm, source_file, _object, "Raise")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Raise")?; + stmt_raise_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_try_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, + is_star: bool, +) -> PyResult { + let typ = if is_star { "TryStar" } else { "Try" }; + let body: Vec> = get_node_list_field(vm, source_file, &object, "body", typ)?; + let orelse: Vec> = + get_node_list_field(vm, source_file, &object, "orelse", typ)?; + let finalbody: Vec> = + get_node_list_field(vm, source_file, &object, "finalbody", typ)?; + let (public_handlers, handlers) = + except_handler_list_from_field(vm, source_file, &object, typ, is_star, range)?; + let stmt_lists: Vec<_> = [ + (super::constant::PublicAstStmtListField::Body, &body), + (super::constant::PublicAstStmtListField::Orelse, &orelse), + ( + super::constant::PublicAstStmtListField::FinalBody, + &finalbody, + ), + ] + .into_iter() + .filter(|(_, values)| values.iter().any(Option::is_none)) + .map(|(field, values)| (field, values.clone())) + .collect(); + let public_handlers = public_handlers + .iter() + .any(Option::is_none) + .then_some(public_handlers); + let node_index = ast::AtomicNodeIndex::NONE; + if !stmt_lists.is_empty() || public_handlers.is_some() { + node_index.set(super::constant::register_public_ast_try_lists( + stmt_lists, + public_handlers, + )); + } + Ok(ast::StmtTry { + node_index, + body: lower_public_stmt_list(body), + handlers, + orelse: lower_public_stmt_list(orelse), + finalbody: lower_public_stmt_list(finalbody), + range, + is_star, + }) +} + +fn except_handler_list_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + typ: &str, + is_try_star: bool, + range: TextRange, +) -> PyResult<(Vec>, Vec)> { + let value = get_node_list_field_object(vm, object, "handlers", typ)?; + let list = value.downcast_ref::().unwrap(); + let len = list.borrow_vec().len(); + let mut result = Vec::with_capacity(len); + let mut public_values = Vec::with_capacity(len); + let recursion_context = format!(" while traversing '{typ}' node"); + for i in 0..len { + let item = { + let items = list.borrow_vec(); + if items.len() != len { + return Err(vm.new_runtime_error(format!( + r#"{typ} field "handlers" changed size during iteration"# + ))); + } + items[i].clone() + }; + let public_handler = if vm.is_none(&item) { + None + } else { + Some(vm.with_recursion(&recursion_context, || { + if is_try_star { + except_handler_from_object_unvalidated_range(vm, source_file, item) + } else { + Node::ast_from_object(vm, source_file, item) + } + })?) + }; + let handler = public_handler.clone().unwrap_or_else(|| { + ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + node_index: Default::default(), + range, + type_: None, + name: None, + body: Vec::new().into(), + }) + }); + public_values.push(public_handler); + result.push(handler); + if list.borrow_vec().len() != len { + return Err(vm.new_runtime_error(format!( + r#"{typ} field "handlers" changed size during iteration"# + ))); + } + } + Ok((public_values, result)) +} + impl Node for ast::StmtTry { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { - node_index: _, + node_index, body, handlers, orelse, @@ -942,14 +1423,39 @@ impl Node for ast::StmtTry { let node = NodeAst.into_ref_with_type(_vm, cls).unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("body", body.ast_to_object(_vm, source_file), _vm) - .unwrap(); - dict.set_item("handlers", handlers.ast_to_object(_vm, source_file), _vm) - .unwrap(); - dict.set_item("orelse", orelse.ast_to_object(_vm, source_file), _vm) - .unwrap(); - dict.set_item("finalbody", finalbody.ast_to_object(_vm, source_file), _vm) - .unwrap(); + let body = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Body, + ) + .map_or_else( + || body.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("body", body, _vm).unwrap(); + let handlers = super::constant::public_ast_except_handler_list_object(node_index.load()) + .map_or_else( + || handlers.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("handlers", handlers, _vm).unwrap(); + let orelse = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + ) + .map_or_else( + || orelse.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("orelse", orelse, _vm).unwrap(); + let finalbody = super::constant::public_ast_stmt_list_object( + node_index.load(), + super::constant::PublicAstStmtListField::FinalBody, + ) + .map_or_else( + || finalbody.ast_to_object(_vm, source_file), + |values| values.values.ast_to_object(_vm, source_file), + ); + dict.set_item("finalbody", finalbody, _vm).unwrap(); node_add_location(&dict, _range, _vm, source_file); node.into() } @@ -958,42 +1464,34 @@ impl Node for ast::StmtTry { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - let _cls = _object.class(); - let is_star = _cls.is(pyast::NodeStmtTryStar::static_type()); - let _cls = _object.class(); + let is_star = is_node_instance(_vm, &_object, pyast::NodeStmtTryStar::static_type())?; debug_assert!( - _cls.is(pyast::NodeStmtTry::static_type()) - || _cls.is(pyast::NodeStmtTryStar::static_type()) + is_node_instance(_vm, &_object, pyast::NodeStmtTry::static_type())? + || is_node_instance(_vm, &_object, pyast::NodeStmtTryStar::static_type())? ); - - Ok(Self { - node_index: Default::default(), - body: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "body", "Try")?, - )?, - handlers: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "handlers", "Try")?, - )?, - orelse: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "orelse", "Try")?, - )?, - finalbody: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "finalbody", "Try")?, - )?, - range: range_from_object(_vm, source_file, _object, "Try")?, - is_star, - }) + let typ = if is_star { "TryStar" } else { "Try" }; + let range = range_from_object(_vm, source_file, _object.clone(), typ)?; + stmt_try_from_object_with_range(_vm, source_file, _object, range, is_star) } } + // constructor +fn stmt_assert_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtAssert { + node_index: Default::default(), + test: get_required_node_field(vm, source_file, &object, "test", "Assert")?, + msg: get_node_field_opt(vm, &object, "msg")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::StmtAssert { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1018,21 +1516,25 @@ impl Node for ast::StmtAssert { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - test: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "test", "Assert")?, - )?, - msg: get_node_field_opt(_vm, &_object, "msg")? - .map(|obj| Node::ast_from_object(_vm, source_file, obj)) - .transpose()?, - range: range_from_object(_vm, source_file, _object, "Assert")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Assert")?; + stmt_assert_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_import_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtImport { + node_index: Default::default(), + names: get_node_list_field(vm, source_file, &object, "names", "Import")?, + range, + is_lazy: false, + }) +} + impl Node for ast::StmtImport { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1056,19 +1558,55 @@ impl Node for ast::StmtImport { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - names: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "names", "Import")?, - )?, - range: range_from_object(_vm, source_file, _object, "Import")?, - is_lazy: false, // Placeholder - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Import")?; + stmt_import_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_import_from_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let (level, raw_level) = import_from_level_from_field(vm, &object)?; + let node_index = { + let node_index = ast::AtomicNodeIndex::NONE; + if let Some(raw_level) = raw_level.filter(|level| *level < 0) { + node_index.set(super::constant::register_public_ast_import_from_level( + raw_level, + )); + } + node_index + }; + Ok(ast::StmtImportFrom { + node_index, + module: get_node_field_opt(vm, &object, "module")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + names: get_node_list_field(vm, source_file, &object, "names", "ImportFrom")?, + level, + range, + is_lazy: false, + }) +} + +fn import_from_level_from_field( + vm: &VirtualMachine, + object: &PyObjectRef, +) -> PyResult<(u32, Option)> { + let Some(value) = get_node_field_opt(vm, object, "level")? else { + return Ok((0, None)); + }; + let level = vm.with_recursion(" while traversing 'ImportFrom' node", || { + node_object_to_i32(vm, value) + })?; + if level < 0 { + return Ok((0, Some(level))); + } + Ok((level as u32, None)) +} + impl Node for ast::StmtImportFrom { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1098,34 +1636,24 @@ impl Node for ast::StmtImportFrom { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - module: get_node_field_opt(vm, &_object, "module")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - names: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &_object, "names", "ImportFrom")?, - )?, - level: get_node_field_opt(vm, &_object, "level")? - .map(|obj| -> PyResult { - let int: PyRef = obj.try_into_value(vm)?; - let value: i64 = int.try_to_primitive(vm)?; - if value < 0 { - return Err(vm.new_value_error("Negative ImportFrom level")); - } - u32::try_from(value) - .map_err(|_| vm.new_overflow_error("ImportFrom level out of range")) - }) - .transpose()? - .unwrap_or(0), - range: range_from_object(vm, source_file, _object, "ImportFrom")?, - is_lazy: false, // Placeholder - }) + let range = range_from_object(vm, source_file, _object.clone(), "ImportFrom")?; + stmt_import_from_from_object_with_range(vm, source_file, _object, range) } } // constructor +fn stmt_global_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtGlobal { + node_index: Default::default(), + names: get_node_list_field(vm, source_file, &object, "names", "Global")?, + range, + }) +} + impl Node for ast::StmtGlobal { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1147,18 +1675,24 @@ impl Node for ast::StmtGlobal { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - names: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "names", "Global")?, - )?, - range: range_from_object(_vm, source_file, _object, "Global")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Global")?; + stmt_global_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_nonlocal_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtNonlocal { + node_index: Default::default(), + names: get_node_list_field(vm, source_file, &object, "names", "Nonlocal")?, + range, + }) +} + impl Node for ast::StmtNonlocal { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1180,18 +1714,24 @@ impl Node for ast::StmtNonlocal { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - names: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "names", "Nonlocal")?, - )?, - range: range_from_object(_vm, source_file, _object, "Nonlocal")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Nonlocal")?; + stmt_nonlocal_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_expr_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::StmtExpr { + node_index: Default::default(), + value: get_required_node_field(vm, source_file, &object, "value", "Expr")?, + range, + }) +} + impl Node for ast::StmtExpr { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1213,18 +1753,18 @@ impl Node for ast::StmtExpr { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - value: Node::ast_from_object( - _vm, - source_file, - get_node_field(_vm, &_object, "value", "Expr")?, - )?, - range: range_from_object(_vm, source_file, _object, "Expr")?, - }) + let range = range_from_object(_vm, source_file, _object.clone(), "Expr")?; + stmt_expr_from_object_with_range(_vm, source_file, _object, range) } } // constructor +fn stmt_pass_from_object_with_range(range: TextRange) -> ast::StmtPass { + ast::StmtPass { + node_index: Default::default(), + range, + } +} + impl Node for ast::StmtPass { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1265,13 +1805,18 @@ impl Node for ast::StmtPass { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - range: range_from_object(_vm, source_file, _object, "Pass")?, - }) + let range = range_from_object(_vm, source_file, _object, "Pass")?; + Ok(stmt_pass_from_object_with_range(range)) } } // constructor +fn stmt_break_from_object_with_range(range: TextRange) -> ast::StmtBreak { + ast::StmtBreak { + node_index: Default::default(), + range, + } +} + impl Node for ast::StmtBreak { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1291,14 +1836,19 @@ impl Node for ast::StmtBreak { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - range: range_from_object(_vm, source_file, _object, "Break")?, - }) + let range = range_from_object(_vm, source_file, _object, "Break")?; + Ok(stmt_break_from_object_with_range(range)) } } // constructor +fn stmt_continue_from_object_with_range(range: TextRange) -> ast::StmtContinue { + ast::StmtContinue { + node_index: Default::default(), + range, + } +} + impl Node for ast::StmtContinue { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -1317,9 +1867,7 @@ impl Node for ast::StmtContinue { source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - range: range_from_object(_vm, source_file, _object, "Continue")?, - }) + let range = range_from_object(_vm, source_file, _object, "Continue")?; + Ok(stmt_continue_from_object_with_range(range)) } } diff --git a/crates/vm/src/stdlib/_ast/string.rs b/crates/vm/src/stdlib/_ast/string.rs index 24cae476694..08686cbf752 100644 --- a/crates/vm/src/stdlib/_ast/string.rs +++ b/crates/vm/src/stdlib/_ast/string.rs @@ -37,13 +37,22 @@ fn ruff_fstring_element_to_joined_str_part( debug_text: _, // TODO: What is this? conversion, format_spec, - node_index: _, - }) => JoinedStrPart::FormattedValue(FormattedValue { - value: expression, - conversion, - format_spec: ruff_format_spec_to_joined_str(format_spec), - range, - }), + node_index, + }) => { + let override_format_spec = + super::constant::public_ast_formatted_value_object(node_index.load()) + .and_then(|formatted| formatted.format_spec) + .or_else(|| { + ruff_format_spec_to_joined_str(format_spec) + .map(|joined_str| Box::new(joined_str.into_expr())) + }); + JoinedStrPart::FormattedValue(FormattedValue { + value: expression, + conversion, + format_spec: override_format_spec, + range, + }) + } } } @@ -257,7 +266,11 @@ fn ruff_format_spec_to_joined_str( .map(ruff_fstring_element_to_joined_str_part) .collect(); let values = normalize_joined_str_parts(values).into_boxed_slice(); - Some(Box::new(JoinedStr { range, values })) + Some(Box::new(JoinedStr { + range, + values, + public_values: None, + })) } } } @@ -290,38 +303,87 @@ fn ruff_fstring_element_to_ruff_fstring_part( } } -fn joined_str_to_ruff_format_spec( - joined_str: Option>, +fn format_spec_expr_to_ruff_format_spec( + format_spec: Option>, ) -> Option> { - match joined_str { - None => None, - Some(joined_str) => { - let JoinedStr { range, values } = *joined_str; - let elements: Vec<_> = Box::into_iter(values) - .map(joined_str_part_to_ruff_fstring_element) - .collect(); - let format_spec = ast::InterpolatedStringFormatSpec { - node_index: Default::default(), + let format_spec = format_spec?; + let ast::Expr::FString(mut fstring) = *format_spec else { + return None; + }; + let ast::ExprFString { + range, + ref mut value, + node_index: _, + } = fstring; + let default_part = ast::FStringPart::FString(ast::FString { + node_index: Default::default(), + range: Default::default(), + elements: Default::default(), + flags: ast::FStringFlags::empty(), + }); + let mut elements = Vec::new(); + for i in 0..value.as_slice().len() { + let part = core::mem::replace(value.iter_mut().nth(i).unwrap(), default_part.clone()); + match part { + ast::FStringPart::Literal(ast::StringLiteral { range, - elements: elements.into(), - }; - Some(Box::new(format_spec)) + value, + node_index: _, + flags: _, + }) => elements.push(ast::InterpolatedStringElement::Literal( + ast::InterpolatedStringLiteralElement { + node_index: Default::default(), + range, + value, + }, + )), + ast::FStringPart::FString(ast::FString { + elements: fstring_elements, + .. + }) => { + elements.extend(ruff_fstring_element_into_iter(fstring_elements)); + } } } + Some(Box::new(ast::InterpolatedStringFormatSpec { + node_index: Default::default(), + range, + elements: elements.into(), + })) } #[derive(Debug)] pub(super) struct JoinedStr { pub(super) range: TextRange, pub(super) values: Box<[JoinedStrPart]>, + pub(super) public_values: Option>>, } impl JoinedStr { pub(super) fn into_expr(self) -> ast::Expr { - let Self { range, values } = self; - ast::Expr::FString(ast::ExprFString { + let Self { + range, + values, + mut public_values, + } = self; + let values = if values.iter().any(joined_str_part_requires_public_values) { + if public_values.is_none() { + public_values = Some( + values + .into_vec() + .into_iter() + .map(joined_str_part_to_expr) + .map(Some) + .collect(), + ); + } + Vec::new().into_boxed_slice() + } else { + values + }; + let expr = ast::Expr::FString(ast::ExprFString { node_index: Default::default(), - range: Default::default(), + range, value: match values.len() { // ruff represents an empty fstring like this: 0 => ast::FStringValue::single(ast::FString { @@ -333,6 +395,7 @@ impl JoinedStr { 1 => ast::FStringValue::single( Box::<[_]>::into_iter(values) .map(joined_str_part_to_ruff_fstring_element) + .map(Option::unwrap) .map(|element| ast::FString { node_index: Default::default(), range, @@ -345,53 +408,120 @@ impl JoinedStr { _ => ast::FStringValue::concatenated( Box::<[_]>::into_iter(values) .map(joined_str_part_to_ruff_fstring_element) + .map(Option::unwrap) .map(ruff_fstring_element_to_ruff_fstring_part) .collect(), ), }, - }) + }); + if let Some(values) = public_values { + let index = if values.iter().any(Option::is_none) { + super::constant::register_public_ast_node_list_overrides( + Vec::new(), + vec![(super::constant::PublicAstExprListField::Values, values)], + None, + None, + ) + } else { + super::constant::register_public_ast_joined_str( + values.into_iter().flatten().collect(), + ) + }; + ast::HasNodeIndex::node_index(&expr).set(index); + } + expr + } +} + +fn joined_str_part_requires_public_values(part: &JoinedStrPart) -> bool { + matches!( + part, + JoinedStrPart::Constant(Constant { + value, + .. + }) if !matches!(value, ConstantLiteral::Str { .. }) + ) +} + +fn joined_str_part_to_expr(part: JoinedStrPart) -> ast::Expr { + match part { + JoinedStrPart::FormattedValue(value) => formatted_value_to_expr(value), + JoinedStrPart::Constant(value) => value.into_expr(), } } -fn joined_str_part_to_ruff_fstring_element(part: JoinedStrPart) -> ast::InterpolatedStringElement { +fn joined_str_part_to_ruff_fstring_element( + part: JoinedStrPart, +) -> Option { match part { JoinedStrPart::FormattedValue(value) => { - ast::InterpolatedStringElement::Interpolation(ast::InterpolatedElement { - node_index: Default::default(), - range: value.range, - expression: value.value.clone(), - debug_text: None, // TODO: What is this? - conversion: value.conversion, - format_spec: joined_str_to_ruff_format_spec(value.format_spec), - }) + let format_spec = value.format_spec.clone(); + let node_index = + super::constant::register_public_ast_formatted_value(format_spec.clone()); + Some(ast::InterpolatedStringElement::Interpolation( + ast::InterpolatedElement { + node_index: { + let node_index_field = ast::AtomicNodeIndex::NONE; + node_index_field.set(node_index); + node_index_field + }, + range: value.range, + expression: value.value.clone(), + debug_text: None, // TODO: What is this? + conversion: value.conversion, + format_spec: format_spec_expr_to_ruff_format_spec(format_spec), + }, + )) } JoinedStrPart::Constant(value) => { - ast::InterpolatedStringElement::Literal(ast::InterpolatedStringLiteralElement { - node_index: Default::default(), - range: value.range, - value: match value.value { - ConstantLiteral::Str { value, .. } => value, - _ => todo!(), + let Constant { range, value, .. } = value; + let ConstantLiteral::Str { value, .. } = value else { + return None; + }; + Some(ast::InterpolatedStringElement::Literal( + ast::InterpolatedStringLiteralElement { + node_index: Default::default(), + range, + value, }, - }) + )) } } } // constructor +pub(super) fn joined_str_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let values: Vec> = + get_node_list_field(vm, source_file, &object, "values", "JoinedStr")?; + Ok(JoinedStr { + values: Vec::new().into_boxed_slice(), + public_values: Some(values), + range, + }) +} + impl Node for JoinedStr { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - let Self { values, range } = self; + let Self { + values, + public_values, + range, + } = self; let node = NodeAst .into_ref_with_type(vm, pyast::NodeExprJoinedStr::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item( - "values", - BoxedSlice(values).ast_to_object(vm, source_file), - vm, - ) - .unwrap(); + let values = if let Some(public_values) = public_values { + BoxedSlice(public_values.into_boxed_slice()).ast_to_object(vm, source_file) + } else { + BoxedSlice(values).ast_to_object(vm, source_file) + }; + dict.set_item("values", values, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -400,15 +530,8 @@ impl Node for JoinedStr { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let values: BoxedSlice<_> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "values", "JoinedStr")?, - )?; - Ok(Self { - values: values.0, - range: range_from_object(vm, source_file, object, "JoinedStr")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "JoinedStr")?; + joined_str_from_object_with_range(vm, source_file, object, range) } } @@ -431,8 +554,7 @@ impl Node for JoinedStrPart { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - if cls.is(pyast::NodeExprFormattedValue::static_type()) { + if is_node_instance(vm, &object, pyast::NodeExprFormattedValue::static_type())? { Ok(Self::FormattedValue(Node::ast_from_object( vm, source_file, @@ -452,11 +574,31 @@ impl Node for JoinedStrPart { pub(super) struct FormattedValue { value: Box, conversion: ast::ConversionFlag, - format_spec: Option>, + format_spec: Option>, range: TextRange, } // constructor +pub(super) fn formatted_value_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(FormattedValue { + value: get_required_node_field(vm, source_file, &object, "value", "FormattedValue")?, + conversion: Node::ast_from_object( + vm, + source_file, + get_node_field(vm, &object, "conversion", "FormattedValue")?, + )?, + format_spec: get_node_field_opt(vm, &object, "format_spec")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for FormattedValue { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -487,23 +629,19 @@ impl Node for FormattedValue { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "FormattedValue")?, - )?, - conversion: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "conversion", "FormattedValue")?, - )?, - format_spec: get_node_field_opt(vm, &object, "format_spec")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "FormattedValue")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "FormattedValue")?; + formatted_value_from_object_with_range(vm, source_file, object, range) + } +} + +pub(super) fn formatted_value_to_expr(formatted: FormattedValue) -> ast::Expr { + let range = formatted.range; + JoinedStr { + range, + values: vec![JoinedStrPart::FormattedValue(formatted)].into_boxed_slice(), + public_values: None, } + .into_expr() } pub(super) fn fstring_to_object( @@ -511,6 +649,29 @@ pub(super) fn fstring_to_object( source_file: &SourceFile, expression: ast::ExprFString, ) -> PyObjectRef { + if let Some(joined_str) = + super::constant::public_ast_joined_str_object(expression.node_index.load()) + { + return JoinedStr { + range: expression.range, + values: Vec::new().into_boxed_slice(), + public_values: Some(joined_str.values.into_iter().map(Some).collect()), + } + .ast_to_object(vm, source_file); + } + + if let Some(joined_str) = super::constant::public_ast_expr_list_object( + expression.node_index.load(), + super::constant::PublicAstExprListField::Values, + ) { + return JoinedStr { + range: expression.range, + values: Vec::new().into_boxed_slice(), + public_values: Some(joined_str.values), + } + .ast_to_object(vm, source_file); + } + let ast::ExprFString { range, mut value, @@ -555,12 +716,13 @@ pub(super) fn fstring_to_object( if let JoinedStrPart::FormattedValue(value) = part && let Some(format_spec) = &value.format_spec { - warn_invalid_escape_sequences_in_format_spec(vm, source_file, format_spec.range); + warn_invalid_escape_sequences_in_format_spec(vm, source_file, format_spec.range()); } } let c = JoinedStr { range, values: values.into_boxed_slice(), + public_values: None, }; c.ast_to_object(vm, source_file) } @@ -568,6 +730,7 @@ pub(super) fn fstring_to_object( // ===== TString (Template String) Support ===== fn ruff_tstring_element_to_template_str_part( + vm: &VirtualMachine, element: ast::InterpolatedStringElement, source_file: &SourceFile, ) -> TemplateStrPart { @@ -587,7 +750,7 @@ fn ruff_tstring_element_to_template_str_part( debug_text, conversion, format_spec, - node_index: _, + node_index, }) => { let expr_range = extend_expr_range_with_wrapping_parens(source_file, range, expression.range()) @@ -595,20 +758,29 @@ fn ruff_tstring_element_to_template_str_part( let expr_str = if let Some(debug_text) = debug_text { let expr_source = source_file.slice(expr_range); let mut expr_with_debug = String::with_capacity( - debug_text.leading.len() + expr_source.len() + debug_text.trailing.len(), + debug_text.leading().len() + expr_source.len() + debug_text.trailing().len(), ); - expr_with_debug.push_str(&debug_text.leading); + expr_with_debug.push_str(debug_text.leading()); expr_with_debug.push_str(expr_source); - expr_with_debug.push_str(&debug_text.trailing); + expr_with_debug.push_str(debug_text.trailing()); strip_interpolation_expr(&expr_with_debug) } else { tstring_interpolation_expr_str(source_file, range, expr_range) }; + let override_interpolation = + super::constant::public_ast_interpolation_object(vm, node_index.load()); TemplateStrPart::Interpolation(TStringInterpolation { value: expression, - str: expr_str, + str: override_interpolation + .as_ref() + .map_or_else(|| vm.ctx.new_str(expr_str).into(), |(str, _)| str.clone()), conversion, - format_spec: ruff_format_spec_to_joined_str(format_spec), + format_spec: override_interpolation + .and_then(|(_, format_spec)| format_spec) + .or_else(|| { + ruff_format_spec_to_joined_str(format_spec) + .map(|joined_str| Box::new(joined_str.into_expr())) + }), range, }) } @@ -695,60 +867,94 @@ fn strip_interpolation_expr(expr_source: &str) -> String { pub(super) struct TemplateStr { pub(super) range: TextRange, pub(super) values: Box<[TemplateStrPart]>, + pub(super) public_values: Option>>, } pub(super) fn template_str_to_expr( vm: &VirtualMachine, + source_file: &SourceFile, template: TemplateStr, ) -> PyResult { - let TemplateStr { range, values } = template; - let elements = template_parts_to_elements(vm, values)?; + let TemplateStr { + range, + values, + public_values, + } = template; + let elements = template_parts_to_elements(vm, source_file, values)?; let tstring = ast::TString { range, node_index: Default::default(), elements, flags: ast::TStringFlags::empty(), }; - Ok(ast::Expr::TString(ast::ExprTString { + let expr = ast::Expr::TString(ast::ExprTString { node_index: Default::default(), range, value: ast::TStringValue::single(tstring), - })) + }); + if let Some(values) = public_values { + let index = if values.iter().any(Option::is_none) { + super::constant::register_public_ast_node_list_overrides( + Vec::new(), + vec![(super::constant::PublicAstExprListField::Values, values)], + None, + None, + ) + } else { + super::constant::register_public_ast_template_str( + values.into_iter().flatten().collect(), + ) + }; + ast::HasNodeIndex::node_index(&expr).set(index); + } + Ok(expr) } pub(super) fn interpolation_to_expr( vm: &VirtualMachine, + source_file: &SourceFile, interpolation: TStringInterpolation, ) -> PyResult { + let range = interpolation.range; + let format_spec = interpolation.format_spec.clone(); + let str_constant = super::constant::constant_object_to_constant_data( + vm, + source_file, + interpolation.str.clone(), + )?; let part = TemplateStrPart::Interpolation(interpolation); - let elements = template_parts_to_elements(vm, vec![part].into_boxed_slice())?; - let range = TextRange::default(); + let elements = template_parts_to_elements(vm, source_file, vec![part].into_boxed_slice())?; let tstring = ast::TString { range, node_index: Default::default(), elements, flags: ast::TStringFlags::empty(), }; - Ok(ast::Expr::TString(ast::ExprTString { + let expr = ast::Expr::TString(ast::ExprTString { node_index: Default::default(), range, value: ast::TStringValue::single(tstring), - })) + }); + let index = super::constant::register_public_ast_interpolation(str_constant, format_spec); + ast::HasNodeIndex::node_index(&expr).set(index); + Ok(expr) } fn template_parts_to_elements( vm: &VirtualMachine, + source_file: &SourceFile, values: Box<[TemplateStrPart]>, ) -> PyResult { let mut elements = Vec::with_capacity(values.len()); for value in values.into_vec() { - elements.push(template_part_to_element(vm, value)?); + elements.push(template_part_to_element(vm, source_file, value)?); } Ok(ast::InterpolatedStringElements::from(elements)) } fn template_part_to_element( vm: &VirtualMachine, + source_file: &SourceFile, part: TemplateStrPart, ) -> PyResult { match part { @@ -767,16 +973,26 @@ fn template_part_to_element( TemplateStrPart::Interpolation(interpolation) => { let TStringInterpolation { value, + str, conversion, format_spec, range, - .. } = interpolation; - let format_spec = joined_str_to_ruff_format_spec(format_spec); + let str_constant = + super::constant::constant_object_to_constant_data(vm, source_file, str)?; + let node_index = super::constant::register_public_ast_interpolation( + str_constant, + format_spec.clone(), + ); + let format_spec = format_spec_expr_to_ruff_format_spec(format_spec); Ok(ast::InterpolatedStringElement::Interpolation( ast::InterpolatedElement { range, - node_index: Default::default(), + node_index: { + let node_index_field = ast::AtomicNodeIndex::NONE; + node_index_field.set(node_index); + node_index_field + }, expression: value, debug_text: None, conversion, @@ -788,19 +1004,38 @@ fn template_part_to_element( } // constructor +pub(super) fn template_str_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let values: Vec> = + get_node_list_field(vm, source_file, &object, "values", "TemplateStr")?; + Ok(TemplateStr { + values: Vec::new().into_boxed_slice(), + public_values: Some(values), + range, + }) +} + impl Node for TemplateStr { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - let Self { values, range } = self; + let Self { + values, + public_values, + range, + } = self; let node = NodeAst .into_ref_with_type(vm, pyast::NodeExprTemplateStr::static_type().to_owned()) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item( - "values", - BoxedSlice(values).ast_to_object(vm, source_file), - vm, - ) - .unwrap(); + let values = if let Some(public_values) = public_values { + BoxedSlice(public_values.into_boxed_slice()).ast_to_object(vm, source_file) + } else { + BoxedSlice(values).ast_to_object(vm, source_file) + }; + dict.set_item("values", values, vm).unwrap(); node_add_location(&dict, range, vm, source_file); node.into() } @@ -809,15 +1044,8 @@ impl Node for TemplateStr { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let values: BoxedSlice<_> = Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "values", "TemplateStr")?, - )?; - Ok(Self { - values: values.0, - range: range_from_object(vm, source_file, object, "TemplateStr")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "TemplateStr")?; + template_str_from_object_with_range(vm, source_file, object, range) } } @@ -840,8 +1068,7 @@ impl Node for TemplateStrPart { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - if cls.is(pyast::NodeExprInterpolation::static_type()) { + if is_node_instance(vm, &object, pyast::NodeExprInterpolation::static_type())? { Ok(Self::Interpolation(Node::ast_from_object( vm, source_file, @@ -860,13 +1087,38 @@ impl Node for TemplateStrPart { #[derive(Debug)] pub(super) struct TStringInterpolation { value: Box, - str: String, + str: PyObjectRef, conversion: ast::ConversionFlag, - format_spec: Option>, + format_spec: Option>, range: TextRange, } // constructor +pub(super) fn tstring_interpolation_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + let value = get_required_node_field(vm, source_file, &object, "value", "Interpolation")?; + let str = get_node_field(vm, &object, "str", "Interpolation")?; + let conversion = Node::ast_from_object( + vm, + source_file, + get_node_field(vm, &object, "conversion", "Interpolation")?, + )?; + let format_spec: Option> = get_node_field_opt(vm, &object, "format_spec")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?; + Ok(TStringInterpolation { + value, + str, + conversion, + format_spec, + range, + }) +} + impl Node for TStringInterpolation { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -882,8 +1134,7 @@ impl Node for TStringInterpolation { let dict = node.as_object().dict().unwrap(); dict.set_item("value", value.ast_to_object(vm, source_file), vm) .unwrap(); - dict.set_item("str", vm.ctx.new_str(str).into(), vm) - .unwrap(); + dict.set_item("str", str, vm).unwrap(); dict.set_item("conversion", conversion.ast_to_object(vm, source_file), vm) .unwrap(); dict.set_item( @@ -900,25 +1151,8 @@ impl Node for TStringInterpolation { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let str_obj = get_node_field(vm, &object, "str", "Interpolation")?; - let str_val: String = str_obj.try_into_value(vm)?; - Ok(Self { - value: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "value", "Interpolation")?, - )?, - str: str_val, - conversion: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "conversion", "Interpolation")?, - )?, - format_spec: get_node_field_opt(vm, &object, "format_spec")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - range: range_from_object(vm, source_file, object, "Interpolation")?, - }) + let range = range_from_object(vm, source_file, object.clone(), "Interpolation")?; + tstring_interpolation_from_object_with_range(vm, source_file, object, range) } } @@ -927,6 +1161,42 @@ pub(super) fn tstring_to_object( source_file: &SourceFile, expression: ast::ExprTString, ) -> PyObjectRef { + if let Some(template_str) = + super::constant::public_ast_template_str_object(expression.node_index.load()) + { + return TemplateStr { + range: expression.range, + values: Vec::new().into_boxed_slice(), + public_values: Some(template_str.values.into_iter().map(Some).collect()), + } + .ast_to_object(vm, source_file); + } + + if let Some(template_str) = super::constant::public_ast_expr_list_object( + expression.node_index.load(), + super::constant::PublicAstExprListField::Values, + ) { + return TemplateStr { + range: expression.range, + values: Vec::new().into_boxed_slice(), + public_values: Some(template_str.values), + } + .ast_to_object(vm, source_file); + } + + if let Some((str, format_spec)) = + super::constant::public_ast_interpolation_object(vm, expression.node_index.load()) + && let Some(interpolation) = standalone_tstring_interpolation_to_object( + vm, + source_file, + &expression, + str, + format_spec, + ) + { + return interpolation; + } + let ast::ExprTString { range, mut value, @@ -943,6 +1213,7 @@ pub(super) fn tstring_to_object( let tstring = core::mem::replace(value.iter_mut().nth(i).unwrap(), default_tstring.clone()); for element in ruff_fstring_element_into_iter(tstring.elements) { values.push(ruff_tstring_element_to_template_str_part( + vm, element, source_file, )); @@ -952,6 +1223,37 @@ pub(super) fn tstring_to_object( let c = TemplateStr { range, values: values.into_boxed_slice(), + public_values: None, }; c.ast_to_object(vm, source_file) } + +fn standalone_tstring_interpolation_to_object( + vm: &VirtualMachine, + source_file: &SourceFile, + expression: &ast::ExprTString, + str: PyObjectRef, + format_spec: Option>, +) -> Option { + let [tstring] = expression.value.as_slice() else { + return None; + }; + let mut elements = tstring.elements.iter(); + let ast::InterpolatedStringElement::Interpolation(interp) = elements.next()? else { + return None; + }; + if elements.next().is_some() { + return None; + } + let interpolation = TStringInterpolation { + value: interp.expression.clone(), + str, + conversion: interp.conversion, + format_spec: format_spec.or_else(|| { + ruff_format_spec_to_joined_str(interp.format_spec.clone()) + .map(|joined_str| Box::new(joined_str.into_expr())) + }), + range: interp.range, + }; + Some(interpolation.ast_to_object(vm, source_file)) +} diff --git a/crates/vm/src/stdlib/_ast/type_ignore.rs b/crates/vm/src/stdlib/_ast/type_ignore.rs index 6e90ba9b80e..6810df05057 100644 --- a/crates/vm/src/stdlib/_ast/type_ignore.rs +++ b/crates/vm/src/stdlib/_ast/type_ignore.rs @@ -2,6 +2,7 @@ use super::*; use rustpython_compiler_core::SourceFile; pub(super) enum TypeIgnore { + None, TypeIgnore(TypeIgnoreTypeIgnore), } @@ -9,6 +10,7 @@ pub(super) enum TypeIgnore { impl Node for TypeIgnore { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { match self { + Self::None => vm.ctx.none(), Self::TypeIgnore(cons) => cons.ast_to_object(vm, source_file), } } @@ -17,8 +19,9 @@ impl Node for TypeIgnore { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeTypeIgnoreTypeIgnore::static_type()) { + Ok(if vm.is_none(&object) { + Self::None + } else if is_node_instance(vm, &object, pyast::NodeTypeIgnoreTypeIgnore::static_type())? { Self::TypeIgnore(TypeIgnoreTypeIgnore::ast_from_object( vm, source_file, @@ -34,15 +37,14 @@ impl Node for TypeIgnore { } pub(super) struct TypeIgnoreTypeIgnore { - range: TextRange, - lineno: PyRefExact, - tag: PyRefExact, + lineno: i32, + tag: PyObjectRef, } // constructor impl Node for TypeIgnoreTypeIgnore { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - let Self { lineno, tag, range } = self; + let Self { lineno, tag } = self; let node = NodeAst .into_ref_with_type( vm, @@ -50,9 +52,10 @@ impl Node for TypeIgnoreTypeIgnore { ) .unwrap(); let dict = node.as_object().dict().unwrap(); - dict.set_item("lineno", lineno.to_pyobject(vm), vm).unwrap(); - dict.set_item("tag", tag.to_pyobject(vm), vm).unwrap(); - node_add_location(&dict, range, vm, source_file); + dict.set_item("lineno", vm.ctx.new_int(lineno).into(), vm) + .unwrap(); + dict.set_item("tag", tag, vm).unwrap(); + let _ = source_file; node.into() } @@ -61,14 +64,10 @@ impl Node for TypeIgnoreTypeIgnore { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { + let _ = source_file; Ok(Self { - lineno: get_node_field(vm, &object, "lineno", "TypeIgnore")? - .downcast_exact(vm) - .unwrap(), - tag: get_node_field(vm, &object, "tag", "TypeIgnore")? - .downcast_exact(vm) - .unwrap(), - range: range_from_object(vm, source_file, object, "TypeIgnore")?, + lineno: get_int_field(vm, &object, "lineno", "TypeIgnore")?, + tag: node_object_to_ast_string(vm, get_node_field(vm, &object, "tag", "TypeIgnore")?)?, }) } } diff --git a/crates/vm/src/stdlib/_ast/type_parameters.rs b/crates/vm/src/stdlib/_ast/type_parameters.rs index 0424ffbd768..f3b5025e322 100644 --- a/crates/vm/src/stdlib/_ast/type_parameters.rs +++ b/crates/vm/src/stdlib/_ast/type_parameters.rs @@ -3,7 +3,10 @@ use rustpython_compiler_core::SourceFile; impl Node for ast::TypeParams { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - self.type_params.ast_to_object(vm, source_file) + super::constant::public_ast_type_param_list_object(self.node_index.load()).map_or_else( + || self.type_params.ast_to_object(vm, source_file), + |values| values.values.ast_to_object(vm, source_file), + ) } fn ast_from_object( @@ -11,22 +14,55 @@ impl Node for ast::TypeParams { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let type_params: Vec = Node::ast_from_object(vm, source_file, object)?; - let range = Option::zip(type_params.first(), type_params.last()) - .map(|(first, last)| first.range().cover(last.range())) - .unwrap_or_default(); - Ok(Self { - node_index: Default::default(), - type_params, - range, - }) + Ok(type_params_from_values(Node::ast_from_object( + vm, + source_file, + object, + )?)) } fn is_none(&self) -> bool { - self.type_params.is_empty() + self.type_params.is_empty() && self.node_index.load() == ast::NodeIndex::NONE + } +} + +pub(super) fn type_params_from_field( + vm: &VirtualMachine, + source_file: &SourceFile, + object: &PyObject, + field: &'static str, + typ: &str, +) -> PyResult>> { + let type_params: Vec> = + get_node_list_field(vm, source_file, object, field, typ)?; + let type_params = type_params_from_values(type_params); + Ok((!type_params.is_none()).then_some(Box::new(type_params))) +} + +fn type_params_from_values(values: Vec>) -> ast::TypeParams { + let node_index = if values.iter().any(Option::is_none) { + let index = super::constant::register_public_ast_type_param_list(values.clone()); + let node_index = ast::AtomicNodeIndex::NONE; + node_index.set(index); + node_index + } else { + Default::default() + }; + let type_params = lower_nullable_type_params(&values); + let range = Option::zip(type_params.first(), type_params.last()) + .map(|(first, last)| first.range().cover(last.range())) + .unwrap_or_default(); + ast::TypeParams { + node_index, + type_params, + range, } } +fn lower_nullable_type_params(values: &[Option]) -> Vec { + values.iter().filter_map(Clone::clone).collect() +} + // sum impl Node for ast::TypeParam { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -42,35 +78,70 @@ impl Node for ast::TypeParam { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let cls = object.class(); - Ok(if cls.is(pyast::NodeTypeParamTypeVar::static_type()) { - Self::TypeVar(ast::TypeParamTypeVar::ast_from_object( - vm, - source_file, - object, - )?) - } else if cls.is(pyast::NodeTypeParamParamSpec::static_type()) { - Self::ParamSpec(ast::TypeParamParamSpec::ast_from_object( + if vm.is_none(&object) { + return Err(vm.new_type_error(format!( + "expected some sort of type_param, but got {}", + object.repr(vm)? + ))); + } + enum TypeParamKind { + TypeVar, + ParamSpec, + TypeVarTuple, + } + let kind = if is_node_instance(vm, &object, pyast::NodeTypeParamTypeVar::static_type())? { + TypeParamKind::TypeVar + } else if is_node_instance(vm, &object, pyast::NodeTypeParamParamSpec::static_type())? { + TypeParamKind::ParamSpec + } else if is_node_instance(vm, &object, pyast::NodeTypeParamTypeVarTuple::static_type())? { + TypeParamKind::TypeVarTuple + } else { + return Err(vm.new_type_error(format!( + "expected some sort of type_param, but got {}", + object.repr(vm)? + ))); + }; + let range = type_param_range_from_object(vm, source_file, object.clone())?; + Ok(match kind { + TypeParamKind::TypeVar => Self::TypeVar(type_var_from_object_with_range( vm, source_file, object, - )?) - } else if cls.is(pyast::NodeTypeParamTypeVarTuple::static_type()) { - Self::TypeVarTuple(ast::TypeParamTypeVarTuple::ast_from_object( + range, + )?), + TypeParamKind::ParamSpec => Self::ParamSpec(param_spec_from_object_with_range( vm, source_file, object, - )?) - } else { - return Err(vm.new_type_error(format!( - "expected some sort of type_param, but got {}", - object.repr(vm)? - ))); + range, + )?), + TypeParamKind::TypeVarTuple => Self::TypeVarTuple( + type_var_tuple_from_object_with_range(vm, source_file, object, range)?, + ), }) } } // constructor +fn type_var_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::TypeParamTypeVar { + node_index: Default::default(), + name: get_required_identifier_field(vm, source_file, &object, "name", "TypeVar")?, + bound: get_node_field_opt(vm, &object, "bound")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + default: get_node_field_opt(vm, &object, "default_value")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::TypeParamTypeVar { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -99,27 +170,28 @@ impl Node for ast::TypeParamTypeVar { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "name", "TypeVar")?, - )?, - bound: get_node_field_opt(vm, &object, "bound")? - .map(|obj| Node::ast_from_object(vm, source_file, obj)) - .transpose()?, - default: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "default_value", "TypeVar")?, - )?, - range: range_from_object(vm, source_file, object, "TypeVar")?, - }) + let range = type_param_range_from_object(vm, source_file, object.clone())?; + type_var_from_object_with_range(vm, source_file, object, range) } } // constructor +fn param_spec_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::TypeParamParamSpec { + node_index: Default::default(), + name: get_required_identifier_field(vm, source_file, &object, "name", "ParamSpec")?, + default: get_node_field_opt(vm, &object, "default_value")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::TypeParamParamSpec { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -145,24 +217,28 @@ impl Node for ast::TypeParamParamSpec { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "name", "ParamSpec")?, - )?, - default: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "default_value", "ParamSpec")?, - )?, - range: range_from_object(vm, source_file, object, "ParamSpec")?, - }) + let range = type_param_range_from_object(vm, source_file, object.clone())?; + param_spec_from_object_with_range(vm, source_file, object, range) } } // constructor +fn type_var_tuple_from_object_with_range( + vm: &VirtualMachine, + source_file: &SourceFile, + object: PyObjectRef, + range: TextRange, +) -> PyResult { + Ok(ast::TypeParamTypeVarTuple { + node_index: Default::default(), + name: get_required_identifier_field(vm, source_file, &object, "name", "TypeVarTuple")?, + default: get_node_field_opt(vm, &object, "default_value")? + .map(|obj| Node::ast_from_object(vm, source_file, obj)) + .transpose()?, + range, + }) +} + impl Node for ast::TypeParamTypeVarTuple { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { @@ -191,19 +267,7 @@ impl Node for ast::TypeParamTypeVarTuple { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - Ok(Self { - node_index: Default::default(), - name: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "name", "TypeVarTuple")?, - )?, - default: Node::ast_from_object( - vm, - source_file, - get_node_field(vm, &object, "default_value", "TypeVarTuple")?, - )?, - range: range_from_object(vm, source_file, object, "TypeVarTuple")?, - }) + let range = type_param_range_from_object(vm, source_file, object.clone())?; + type_var_tuple_from_object_with_range(vm, source_file, object, range) } } diff --git a/crates/vm/src/stdlib/_ast/validate.rs b/crates/vm/src/stdlib/_ast/validate.rs index cad37d38610..11baa65777d 100644 --- a/crates/vm/src/stdlib/_ast/validate.rs +++ b/crates/vm/src/stdlib/_ast/validate.rs @@ -1,8 +1,83 @@ // spell-checker: ignore assignlist ifexp use super::module::Mod; -use crate::{PyResult, VirtualMachine}; +use crate::{PyResult, VirtualMachine, compiler::CompileError}; +use core::cell::RefCell; use ruff_python_ast as ast; +use rustpython_codegen::error::{CodegenError, CodegenErrorType}; +use rustpython_codegen::{ + PublicAstExprList, PublicAstFormattedValue, PublicAstInterpolation, PublicAstNodeMap, +}; +use rustpython_compiler_core::bytecode::ConstantData; + +type AstConstantOverrides<'a> = Option<&'a PublicAstNodeMap>; +type AstInterpolationOverrides<'a> = Option<&'a PublicAstNodeMap>; +type AstFormattedValueOverrides<'a> = Option<&'a PublicAstNodeMap>; +type AstImportFromLevelOverrides<'a> = + Option<&'a super::constant::PublicAstImportFromLevelOverrideMap>; +type AstInvalidConstantOverrides<'a> = + Option<&'a super::constant::PublicAstInvalidConstantOverrideMap>; +type AstExprListOverrides<'a> = Option<&'a super::constant::PublicAstExprListOverrideMap>; +type AstPatternListOverrides<'a> = Option<&'a super::constant::PublicAstPatternListOverrideMap>; +type AstExprOptionListOverrides<'a> = + Option<&'a super::constant::PublicAstExprOptionListOverrideMap>; +type AstExprListFieldOverrides<'a> = Option<&'a super::constant::PublicAstExprListFieldOverrideMap>; +type AstStmtListOverrides<'a> = Option<&'a super::constant::PublicAstStmtListOverrideMap>; +type AstExceptHandlerListOverrides<'a> = + Option<&'a super::constant::PublicAstExceptHandlerListOverrideMap>; +type AstTypeParamListOverrides<'a> = Option<&'a super::constant::PublicAstTypeParamListOverrideMap>; +type AstMatchClassOverrides<'a> = Option<&'a super::constant::PublicAstMatchClassOverrideMap>; + +thread_local! { + // Validation borrows the same public-AST side tables created in constant.rs; + // these TLS slots add no new storage policy. + static PUBLIC_AST_INVALID_CONSTANTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_JOINED_STRS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_TEMPLATE_STRS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_PATTERN_LISTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_EXPR_OPTION_LISTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_EXPR_LISTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_STMT_LISTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_EXCEPT_HANDLER_LISTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_TYPE_PARAM_LISTS: RefCell> = const { RefCell::new(None) }; + static PUBLIC_AST_MATCH_CLASSES: RefCell> = const { RefCell::new(None) }; +} + +fn public_ast_invalid_constant_type(expr: &ast::Expr) -> Option { + let index = ast::HasNodeIndex::node_index(expr).load(); + if index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_INVALID_CONSTANTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|invalid_constants| invalid_constants.get(&index).cloned()) + }) +} + +fn public_ast_joined_str_values(expr: &ast::ExprFString) -> Option { + let index = expr.node_index.load(); + if index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_JOINED_STRS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|joined_strs| joined_strs.get(&index).cloned()) + }) +} + +fn public_ast_template_str_values(expr: &ast::ExprTString) -> Option { + let index = expr.node_index.load(); + if index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_TEMPLATE_STRS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|template_strs| template_strs.get(&index).cloned()) + }) +} fn expr_context_name(ctx: ast::ExprContext) -> &'static str { match ctx { @@ -13,6 +88,17 @@ fn expr_context_name(ctx: ast::ExprContext) -> &'static str { } } +fn invalid_syntax_error(vm: &VirtualMachine) -> crate::builtins::PyBaseExceptionRef { + vm.new_syntax_error( + &CompileError::Codegen(CodegenError { + location: None, + error: CodegenErrorType::SyntaxError("invalid syntax".to_owned()), + source_path: "".to_owned(), + }), + None, + ) +} + fn validate_name(vm: &VirtualMachine, name: &ast::name::Name) -> PyResult<()> { match name.as_str() { "None" | "True" | "False" => Err(vm.new_value_error(format!( @@ -30,6 +116,12 @@ fn validate_comprehension(vm: &VirtualMachine, gens: &[ast::Comprehension]) -> P for comp in gens { validate_expr(vm, &comp.target, ast::ExprContext::Store)?; validate_expr(vm, &comp.iter, ast::ExprContext::Load)?; + validate_public_expr_list_slots( + vm, + comp.node_index.load(), + super::constant::PublicAstExprListField::Ifs, + ast::ExprContext::Load, + )?; validate_exprs(vm, &comp.ifs, ast::ExprContext::Load, false)?; } Ok(()) @@ -42,30 +134,49 @@ fn validate_keywords(vm: &VirtualMachine, keywords: &[ast::Keyword]) -> PyResult Ok(()) } +fn validate_parameter_annotation(vm: &VirtualMachine, parameter: &ast::Parameter) -> PyResult<()> { + if let Some(annotation) = ¶meter.annotation { + validate_expr(vm, annotation, ast::ExprContext::Load)?; + } + Ok(()) +} + fn validate_parameters(vm: &VirtualMachine, params: &ast::Parameters) -> PyResult<()> { - for param in params - .posonlyargs - .iter() - .chain(¶ms.args) - .chain(¶ms.kwonlyargs) - { - if let Some(annotation) = ¶m.parameter.annotation { - validate_expr(vm, annotation, ast::ExprContext::Load)?; - } - if let Some(default) = ¶m.default { - validate_expr(vm, default, ast::ExprContext::Load)?; - } + for param in params.posonlyargs.iter().chain(¶ms.args) { + validate_parameter_annotation(vm, ¶m.parameter)?; } if let Some(vararg) = ¶ms.vararg && let Some(annotation) = &vararg.annotation { validate_expr(vm, annotation, ast::ExprContext::Load)?; } + for param in ¶ms.kwonlyargs { + validate_parameter_annotation(vm, ¶m.parameter)?; + } if let Some(kwarg) = ¶ms.kwarg && let Some(annotation) = &kwarg.annotation { validate_expr(vm, annotation, ast::ExprContext::Load)?; } + if let Some(defaults) = public_expr_option_list(params.node_index.load()) { + for default in defaults.values { + let Some(default) = default else { + return Err(vm.new_value_error("None disallowed in expression list")); + }; + validate_expr(vm, &default, ast::ExprContext::Load)?; + } + } else { + for param in params.posonlyargs.iter().chain(¶ms.args) { + if let Some(default) = ¶m.default { + validate_expr(vm, default, ast::ExprContext::Load)?; + } + } + } + for param in ¶ms.kwonlyargs { + if let Some(default) = ¶m.default { + validate_expr(vm, default, ast::ExprContext::Load)?; + } + } Ok(()) } @@ -99,9 +210,26 @@ fn validate_assignlist( validate_exprs(vm, targets, ctx, false) } -fn validate_body(vm: &VirtualMachine, body: &[ast::Stmt], owner: &'static str) -> PyResult<()> { +fn validate_body( + vm: &VirtualMachine, + body: &[ast::Stmt], + node_index: ast::NodeIndex, + owner: &'static str, + ast_constant_overrides: AstConstantOverrides<'_>, + ast_import_from_level_overrides: AstImportFromLevelOverrides<'_>, +) -> PyResult<()> { validate_nonempty_seq(vm, body.len(), "body", owner)?; - validate_stmts(vm, body) + validate_public_stmt_list_slots( + vm, + node_index, + super::constant::PublicAstStmtListField::Body, + )?; + validate_stmts( + vm, + body, + ast_constant_overrides, + ast_import_from_level_overrides, + ) } fn validate_interpolated_elements<'a>( @@ -112,33 +240,85 @@ fn validate_interpolated_elements<'a>( if let ast::InterpolatedStringElementRef::Interpolation(interpolation) = element { validate_expr(vm, &interpolation.expression, ast::ExprContext::Load)?; if let Some(format_spec) = &interpolation.format_spec { - for spec_element in &format_spec.elements { - if let ast::InterpolatedStringElement::Interpolation(spec_interp) = spec_element - { - validate_expr(vm, &spec_interp.expression, ast::ExprContext::Load)?; - } - } + validate_interpolated_elements( + vm, + format_spec + .elements + .iter() + .map(ast::InterpolatedStringElementRef::from), + )?; } } } Ok(()) } -fn validate_pattern_match_value(vm: &VirtualMachine, expr: &ast::Expr) -> PyResult<()> { +fn ensure_literal_number(expr: &ast::Expr, allow_real: bool, allow_imaginary: bool) -> bool { + let ast::Expr::NumberLiteral(number) = expr else { + return false; + }; + match number.value { + ast::Number::Int(_) | ast::Number::Float(_) => allow_real, + ast::Number::Complex { .. } => allow_imaginary, + } +} + +fn ensure_literal_negative(expr: &ast::Expr, allow_real: bool, allow_imaginary: bool) -> bool { + let ast::Expr::UnaryOp(unary) = expr else { + return false; + }; + if unary.op != ast::UnaryOp::USub { + return false; + } + ensure_literal_number(&unary.operand, allow_real, allow_imaginary) +} + +fn ensure_literal_complex(expr: &ast::Expr) -> bool { + let ast::Expr::BinOp(bin) = expr else { + return false; + }; + if !matches!(bin.op, ast::Operator::Add | ast::Operator::Sub) { + return false; + } + let real_left = ensure_literal_number(&bin.left, true, false) + || ensure_literal_negative(&bin.left, true, false); + real_left && ensure_literal_number(&bin.right, false, true) +} + +fn public_ast_constant_override<'a>( + overrides: AstConstantOverrides<'a>, + expr: &ast::Expr, +) -> Option<&'a ConstantData> { + let index = ast::HasNodeIndex::node_index(expr).load(); + if index == ast::NodeIndex::NONE { + return None; + } + overrides?.get(&index) +} + +fn validate_pattern_match_value( + vm: &VirtualMachine, + expr: &ast::Expr, + ast_constant_overrides: AstConstantOverrides<'_>, +) -> PyResult<()> { validate_expr(vm, expr, ast::ExprContext::Load)?; + if let Some(constant) = public_ast_constant_override(ast_constant_overrides, expr) { + return match constant { + ConstantData::Integer { .. } + | ConstantData::Float { .. } + | ConstantData::Bytes { .. } + | ConstantData::Complex { .. } + | ConstantData::Str { .. } => Ok(()), + _ => Err(vm.new_value_error("unexpected constant inside of a literal pattern")), + }; + } match expr { ast::Expr::NumberLiteral(_) | ast::Expr::StringLiteral(_) | ast::Expr::BytesLiteral(_) => { Ok(()) } ast::Expr::Attribute(_) => Ok(()), - ast::Expr::UnaryOp(op) => match &*op.operand { - ast::Expr::NumberLiteral(_) => Ok(()), - _ => Err(vm.new_value_error("patterns may only match literals and attribute lookups")), - }, - ast::Expr::BinOp(bin) => match (&*bin.left, &*bin.right) { - (ast::Expr::NumberLiteral(_), ast::Expr::NumberLiteral(_)) => Ok(()), - _ => Err(vm.new_value_error("patterns may only match literals and attribute lookups")), - }, + ast::Expr::UnaryOp(_) if ensure_literal_negative(expr, true, true) => Ok(()), + ast::Expr::BinOp(_) if ensure_literal_complex(expr) => Ok(()), ast::Expr::FString(_) | ast::Expr::TString(_) => Ok(()), ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) @@ -156,13 +336,23 @@ fn validate_capture(vm: &VirtualMachine, name: &ast::Identifier) -> PyResult<()> validate_name(vm, name.id()) } -fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) -> PyResult<()> { +fn validate_pattern( + vm: &VirtualMachine, + pattern: &ast::Pattern, + star_ok: bool, + ast_constant_overrides: AstConstantOverrides<'_>, +) -> PyResult<()> { match pattern { - ast::Pattern::MatchValue(value) => validate_pattern_match_value(vm, &value.value), + ast::Pattern::MatchValue(value) => { + validate_pattern_match_value(vm, &value.value, ast_constant_overrides) + } ast::Pattern::MatchSingleton(singleton) => match singleton.value { ast::Singleton::None | ast::Singleton::True | ast::Singleton::False => Ok(()), }, - ast::Pattern::MatchSequence(seq) => validate_patterns(vm, &seq.patterns, true), + ast::Pattern::MatchSequence(seq) => { + validate_public_pattern_list_slots(vm, seq.node_index.load())?; + validate_patterns(vm, &seq.patterns, true, ast_constant_overrides) + } ast::Pattern::MatchMapping(mapping) => { if mapping.keys.len() != mapping.patterns.len() { return Err(vm.new_value_error( @@ -172,15 +362,25 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) if let Some(rest) = &mapping.rest { validate_capture(vm, rest)?; } + validate_public_expr_option_list_slots(vm, mapping.node_index.load())?; for key in &mapping.keys { if let ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) = key { continue; } - validate_pattern_match_value(vm, key)?; + validate_pattern_match_value(vm, key, ast_constant_overrides)?; } - validate_patterns(vm, &mapping.patterns, false) + validate_public_pattern_list_slots(vm, mapping.node_index.load())?; + validate_patterns(vm, &mapping.patterns, false, ast_constant_overrides) } ast::Pattern::MatchClass(match_class) => { + let public_match_class = public_match_class(match_class.node_index.load()); + if let Some(values) = &public_match_class + && values.kwd_attrs.len() != values.kwd_patterns.len() + { + return Err(vm.new_value_error( + "MatchClass doesn't have the same number of keyword attributes as patterns", + )); + } validate_expr(vm, &match_class.cls, ast::ExprContext::Load)?; let mut cls = match_class.cls.as_ref(); loop { @@ -199,9 +399,20 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) for keyword in &match_class.arguments.keywords { validate_name(vm, keyword.attr.id())?; } - validate_patterns(vm, &match_class.arguments.patterns, false)?; + if let Some(values) = &public_match_class { + validate_public_nullable_patterns(vm, &values.patterns)?; + } + validate_patterns( + vm, + &match_class.arguments.patterns, + false, + ast_constant_overrides, + )?; + if let Some(values) = &public_match_class { + validate_public_nullable_patterns(vm, &values.kwd_patterns)?; + } for keyword in &match_class.arguments.keywords { - validate_pattern(vm, &keyword.pattern, false)?; + validate_pattern(vm, &keyword.pattern, false, ast_constant_overrides)?; } Ok(()) } @@ -226,7 +437,7 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) "MatchAs must specify a target name if a pattern is given", )); } - validate_pattern(vm, pattern, false) + validate_pattern(vm, pattern, false, ast_constant_overrides) } } } @@ -234,18 +445,184 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) if match_or.patterns.len() < 2 { return Err(vm.new_value_error("MatchOr requires at least 2 patterns")); } - validate_patterns(vm, &match_or.patterns, false) + validate_public_pattern_list_slots(vm, match_or.node_index.load())?; + validate_patterns(vm, &match_or.patterns, false, ast_constant_overrides) + } + } +} + +fn public_pattern_list_has_null(node_index: ast::NodeIndex) -> bool { + if node_index == ast::NodeIndex::NONE { + return false; + } + PUBLIC_AST_PATTERN_LISTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index)) + .is_some_and(|values| values.values.iter().any(Option::is_none)) + }) +} + +fn validate_public_pattern_list_slots( + vm: &VirtualMachine, + node_index: ast::NodeIndex, +) -> PyResult<()> { + if public_pattern_list_has_null(node_index) { + return Err(vm.new_value_error("unexpected pattern")); + } + Ok(()) +} + +fn public_expr_option_list_has_null(node_index: ast::NodeIndex) -> bool { + if node_index == ast::NodeIndex::NONE { + return false; + } + PUBLIC_AST_EXPR_OPTION_LISTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index)) + .is_some_and(|values| values.values.iter().any(Option::is_none)) + }) +} + +fn public_expr_option_list( + node_index: ast::NodeIndex, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_EXPR_OPTION_LISTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +fn public_expr_list( + node_index: ast::NodeIndex, + field: super::constant::PublicAstExprListField, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_EXPR_LISTS.with(|cell| { + cell.borrow().as_ref().and_then(|values| { + values + .get(&node_index) + .and_then(|values| values.get(field)) + .cloned() + }) + }) +} + +fn public_stmt_list( + node_index: ast::NodeIndex, + field: super::constant::PublicAstStmtListField, +) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_STMT_LISTS.with(|cell| { + cell.borrow().as_ref().and_then(|values| { + values + .get(&node_index) + .and_then(|values| values.get(field)) + .cloned() + }) + }) +} + +fn validate_public_expr_option_list_slots( + vm: &VirtualMachine, + node_index: ast::NodeIndex, +) -> PyResult<()> { + if public_expr_option_list_has_null(node_index) { + return Err(vm.new_value_error("None disallowed in expression list")); + } + Ok(()) +} + +fn validate_public_expr_list_slots( + vm: &VirtualMachine, + node_index: ast::NodeIndex, + field: super::constant::PublicAstExprListField, + ctx: ast::ExprContext, +) -> PyResult<()> { + if let Some(values) = public_expr_list(node_index, field) { + for value in values.values { + let Some(value) = value else { + return Err(vm.new_value_error("None disallowed in expression list")); + }; + validate_expr(vm, &value, ctx)?; } } + Ok(()) +} + +fn validate_public_stmt_list_slots( + vm: &VirtualMachine, + node_index: ast::NodeIndex, + field: super::constant::PublicAstStmtListField, +) -> PyResult<()> { + if let Some(values) = public_stmt_list(node_index, field) + && values.values.iter().any(Option::is_none) + { + return Err(vm.new_value_error("None disallowed in statement list")); + } + Ok(()) +} + +fn public_except_handler_list_has_null(node_index: ast::NodeIndex) -> bool { + if node_index == ast::NodeIndex::NONE { + return false; + } + PUBLIC_AST_EXCEPT_HANDLER_LISTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index)) + .is_some_and(|values| values.values.iter().any(Option::is_none)) + }) +} + +fn validate_public_except_handler_list_slots( + vm: &VirtualMachine, + node_index: ast::NodeIndex, +) -> PyResult<()> { + if public_except_handler_list_has_null(node_index) { + return Err(vm.new_value_error("unexpected excepthandler")); + } + Ok(()) +} + +fn public_match_class(node_index: ast::NodeIndex) -> Option { + if node_index == ast::NodeIndex::NONE { + return None; + } + PUBLIC_AST_MATCH_CLASSES.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) +} + +fn validate_public_nullable_patterns( + vm: &VirtualMachine, + patterns: &[Option], +) -> PyResult<()> { + if patterns.iter().any(Option::is_none) { + return Err(vm.new_value_error("unexpected pattern")); + } + Ok(()) } fn validate_patterns( vm: &VirtualMachine, patterns: &[ast::Pattern], star_ok: bool, + ast_constant_overrides: AstConstantOverrides<'_>, ) -> PyResult<()> { for pattern in patterns { - validate_pattern(vm, pattern, star_ok)?; + validate_pattern(vm, pattern, star_ok, ast_constant_overrides)?; } Ok(()) } @@ -282,6 +659,19 @@ fn validate_type_params( type_params: Option<&ast::TypeParams>, ) -> PyResult<()> { if let Some(type_params) = type_params { + let node_index = type_params.node_index.load(); + if node_index != ast::NodeIndex::NONE + && let Some(values) = PUBLIC_AST_TYPE_PARAM_LISTS.with(|cell| { + cell.borrow() + .as_ref() + .and_then(|values| values.get(&node_index).cloned()) + }) + { + for tp in values.values.iter().flatten() { + validate_typeparam(vm, tp)?; + } + return Ok(()); + } for tp in &type_params.type_params { validate_typeparam(vm, tp)?; } @@ -337,6 +727,12 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - if op.values.len() < 2 { return Err(vm.new_value_error("BoolOp with less than 2 values")); } + validate_public_expr_list_slots( + vm, + op.node_index.load(), + super::constant::PublicAstExprListField::Values, + ast::ExprContext::Load, + )?; validate_exprs(vm, &op.values, ast::ExprContext::Load, false) } ast::Expr::Named(named) => { @@ -362,6 +758,12 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - validate_expr(vm, &ifexp.orelse, ast::ExprContext::Load) } ast::Expr::Dict(dict) => { + validate_public_expr_list_slots( + vm, + dict.node_index.load(), + super::constant::PublicAstExprListField::Values, + ast::ExprContext::Load, + )?; for item in &dict.items { if let Some(key) = &item.key { validate_expr(vm, key, ast::ExprContext::Load)?; @@ -370,7 +772,15 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - } Ok(()) } - ast::Expr::Set(set) => validate_exprs(vm, &set.elts, ast::ExprContext::Load, false), + ast::Expr::Set(set) => { + validate_public_expr_list_slots( + vm, + set.node_index.load(), + super::constant::PublicAstExprListField::Elts, + ast::ExprContext::Load, + )?; + validate_exprs(vm, &set.elts, ast::ExprContext::Load, false) + } ast::Expr::ListComp(list) => { validate_comprehension(vm, &list.generators)?; validate_expr(vm, &list.elt, ast::ExprContext::Load) @@ -381,7 +791,10 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - } ast::Expr::DictComp(dict) => { validate_comprehension(vm, &dict.generators)?; - validate_expr(vm, &dict.key, ast::ExprContext::Load)?; + let Some(key) = &dict.key else { + return Err(vm.new_value_error("DictComp with no key")); + }; + validate_expr(vm, key, ast::ExprContext::Load)?; validate_expr(vm, &dict.value, ast::ExprContext::Load) } ast::Expr::Generator(generator) => { @@ -409,34 +822,76 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - "Compare has a different number of comparators and operands", )); } + validate_public_expr_list_slots( + vm, + compare.node_index.load(), + super::constant::PublicAstExprListField::Comparators, + ast::ExprContext::Load, + )?; validate_exprs(vm, &compare.comparators, ast::ExprContext::Load, false)?; validate_expr(vm, &compare.left, ast::ExprContext::Load) } ast::Expr::Call(call) => { validate_expr(vm, &call.func, ast::ExprContext::Load)?; + validate_public_expr_list_slots( + vm, + call.arguments.node_index.load(), + super::constant::PublicAstExprListField::Args, + ast::ExprContext::Load, + )?; validate_exprs(vm, &call.arguments.args, ast::ExprContext::Load, false)?; validate_keywords(vm, &call.arguments.keywords) } - ast::Expr::FString(fstring) => validate_interpolated_elements( - vm, - fstring - .value - .elements() - .map(ast::InterpolatedStringElementRef::from), - ), - ast::Expr::TString(tstring) => validate_interpolated_elements( - vm, - tstring - .value - .elements() - .map(ast::InterpolatedStringElementRef::from), - ), + ast::Expr::FString(fstring) => { + validate_public_expr_list_slots( + vm, + fstring.node_index.load(), + super::constant::PublicAstExprListField::Values, + ast::ExprContext::Load, + )?; + if let Some(joined_str) = public_ast_joined_str_values(fstring) { + validate_exprs(vm, &joined_str.values, ast::ExprContext::Load, false) + } else { + validate_interpolated_elements( + vm, + fstring + .value + .elements() + .map(ast::InterpolatedStringElementRef::from), + ) + } + } + ast::Expr::TString(tstring) => { + validate_public_expr_list_slots( + vm, + tstring.node_index.load(), + super::constant::PublicAstExprListField::Values, + ast::ExprContext::Load, + )?; + if let Some(template_str) = public_ast_template_str_values(tstring) { + validate_exprs(vm, &template_str.values, ast::ExprContext::Load, false) + } else { + validate_interpolated_elements( + vm, + tstring + .value + .elements() + .map(ast::InterpolatedStringElementRef::from), + ) + } + } ast::Expr::StringLiteral(_) | ast::Expr::BytesLiteral(_) | ast::Expr::NumberLiteral(_) | ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) - | ast::Expr::EllipsisLiteral(_) => Ok(()), + | ast::Expr::EllipsisLiteral(_) => { + if let Some(invalid_type) = public_ast_invalid_constant_type(expr) { + Err(vm.new_type_error(format!("got an invalid type in Constant: {invalid_type}"))) + } else { + Ok(()) + } + } ast::Expr::Attribute(attr) => validate_expr(vm, &attr.value, ast::ExprContext::Load), ast::Expr::Subscript(sub) => { validate_expr(vm, &sub.slice, ast::ExprContext::Load)?; @@ -444,8 +899,24 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - } ast::Expr::Starred(star) => validate_expr(vm, &star.value, ctx), ast::Expr::Name(_) => Ok(()), - ast::Expr::List(list) => validate_exprs(vm, &list.elts, ctx, false), - ast::Expr::Tuple(tuple) => validate_exprs(vm, &tuple.elts, ctx, false), + ast::Expr::List(list) => { + validate_public_expr_list_slots( + vm, + list.node_index.load(), + super::constant::PublicAstExprListField::Elts, + ctx, + )?; + validate_exprs(vm, &list.elts, ctx, false) + } + ast::Expr::Tuple(tuple) => { + validate_public_expr_list_slots( + vm, + tuple.node_index.load(), + super::constant::PublicAstExprListField::Elts, + ctx, + )?; + validate_exprs(vm, &tuple.elts, ctx, false) + } ast::Expr::Slice(slice) => { if let Some(lower) = &slice.lower { validate_expr(vm, lower, ast::ExprContext::Load)?; @@ -458,7 +929,7 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - } Ok(()) } - ast::Expr::IpyEscapeCommand(_) => Ok(()), + ast::Expr::IpyEscapeCommand(_) => Err(invalid_syntax_error(vm)), } } @@ -469,7 +940,12 @@ fn validate_decorators(vm: &VirtualMachine, decorators: &[ast::Decorator]) -> Py Ok(()) } -fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { +fn validate_stmt( + vm: &VirtualMachine, + stmt: &ast::Stmt, + ast_constant_overrides: AstConstantOverrides<'_>, + ast_import_from_level_overrides: AstImportFromLevelOverrides<'_>, +) -> PyResult<()> { match stmt { ast::Stmt::FunctionDef(func) => { let owner = if func.is_async { @@ -477,9 +953,22 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { } else { "FunctionDef" }; - validate_body(vm, &func.body, owner)?; + validate_body( + vm, + &func.body, + func.node_index.load(), + owner, + ast_constant_overrides, + ast_import_from_level_overrides, + )?; validate_type_params(vm, func.type_params.as_deref())?; validate_parameters(vm, &func.parameters)?; + validate_public_expr_list_slots( + vm, + func.node_index.load(), + super::constant::PublicAstExprListField::DecoratorList, + ast::ExprContext::Load, + )?; validate_decorators(vm, &func.decorator_list)?; if let Some(returns) = &func.returns { validate_expr(vm, returns, ast::ExprContext::Load)?; @@ -487,12 +976,31 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { Ok(()) } ast::Stmt::ClassDef(class_def) => { - validate_body(vm, &class_def.body, "ClassDef")?; + validate_body( + vm, + &class_def.body, + class_def.node_index.load(), + "ClassDef", + ast_constant_overrides, + ast_import_from_level_overrides, + )?; validate_type_params(vm, class_def.type_params.as_deref())?; if let Some(arguments) = &class_def.arguments { + validate_public_expr_list_slots( + vm, + arguments.node_index.load(), + super::constant::PublicAstExprListField::Bases, + ast::ExprContext::Load, + )?; validate_exprs(vm, &arguments.args, ast::ExprContext::Load, false)?; validate_keywords(vm, &arguments.keywords)?; } + validate_public_expr_list_slots( + vm, + class_def.node_index.load(), + super::constant::PublicAstExprListField::DecoratorList, + ast::ExprContext::Load, + )?; validate_decorators(vm, &class_def.decorator_list) } ast::Stmt::Return(ret) => { @@ -501,8 +1009,22 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { } Ok(()) } - ast::Stmt::Delete(del) => validate_assignlist(vm, &del.targets, ast::ExprContext::Del), + ast::Stmt::Delete(del) => { + validate_public_expr_list_slots( + vm, + del.node_index.load(), + super::constant::PublicAstExprListField::Targets, + ast::ExprContext::Del, + )?; + validate_assignlist(vm, &del.targets, ast::ExprContext::Del) + } ast::Stmt::Assign(assign) => { + validate_public_expr_list_slots( + vm, + assign.node_index.load(), + super::constant::PublicAstExprListField::Targets, + ast::ExprContext::Store, + )?; validate_assignlist(vm, &assign.targets, ast::ExprContext::Store)?; validate_expr(vm, &assign.value, ast::ExprContext::Load) } @@ -532,22 +1054,75 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { let owner = if for_stmt.is_async { "AsyncFor" } else { "For" }; validate_expr(vm, &for_stmt.target, ast::ExprContext::Store)?; validate_expr(vm, &for_stmt.iter, ast::ExprContext::Load)?; - validate_body(vm, &for_stmt.body, owner)?; - validate_stmts(vm, &for_stmt.orelse) + validate_body( + vm, + &for_stmt.body, + for_stmt.node_index.load(), + owner, + ast_constant_overrides, + ast_import_from_level_overrides, + )?; + validate_public_stmt_list_slots( + vm, + for_stmt.node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + )?; + validate_stmts( + vm, + &for_stmt.orelse, + ast_constant_overrides, + ast_import_from_level_overrides, + ) } ast::Stmt::While(while_stmt) => { validate_expr(vm, &while_stmt.test, ast::ExprContext::Load)?; - validate_body(vm, &while_stmt.body, "While")?; - validate_stmts(vm, &while_stmt.orelse) + validate_body( + vm, + &while_stmt.body, + while_stmt.node_index.load(), + "While", + ast_constant_overrides, + ast_import_from_level_overrides, + )?; + validate_public_stmt_list_slots( + vm, + while_stmt.node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + )?; + validate_stmts( + vm, + &while_stmt.orelse, + ast_constant_overrides, + ast_import_from_level_overrides, + ) } ast::Stmt::If(if_stmt) => { validate_expr(vm, &if_stmt.test, ast::ExprContext::Load)?; - validate_body(vm, &if_stmt.body, "If")?; + validate_body( + vm, + &if_stmt.body, + if_stmt.node_index.load(), + "If", + ast_constant_overrides, + ast_import_from_level_overrides, + )?; + validate_public_stmt_list_slots( + vm, + if_stmt.node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + )?; for clause in &if_stmt.elif_else_clauses { if let Some(test) = &clause.test { validate_expr(vm, test, ast::ExprContext::Load)?; } - validate_body(vm, &clause.body, "If")?; + validate_body( + vm, + &clause.body, + clause.node_index.load(), + "If", + ast_constant_overrides, + ast_import_from_level_overrides, + )?; } Ok(()) } @@ -564,17 +1139,31 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { validate_expr(vm, optional_vars, ast::ExprContext::Store)?; } } - validate_body(vm, &with_stmt.body, owner) + validate_body( + vm, + &with_stmt.body, + with_stmt.node_index.load(), + owner, + ast_constant_overrides, + ast_import_from_level_overrides, + ) } ast::Stmt::Match(match_stmt) => { validate_expr(vm, &match_stmt.subject, ast::ExprContext::Load)?; validate_nonempty_seq(vm, match_stmt.cases.len(), "cases", "Match")?; for case in &match_stmt.cases { - validate_pattern(vm, &case.pattern, false)?; + validate_pattern(vm, &case.pattern, false, ast_constant_overrides)?; if let Some(guard) = &case.guard { validate_expr(vm, guard, ast::ExprContext::Load)?; } - validate_body(vm, &case.body, "match_case")?; + validate_body( + vm, + &case.body, + case.node_index.load(), + "match_case", + ast_constant_overrides, + ast_import_from_level_overrides, + )?; } Ok(()) } @@ -591,7 +1180,14 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { } ast::Stmt::Try(try_stmt) => { let owner = if try_stmt.is_star { "TryStar" } else { "Try" }; - validate_body(vm, &try_stmt.body, owner)?; + validate_body( + vm, + &try_stmt.body, + try_stmt.node_index.load(), + owner, + ast_constant_overrides, + ast_import_from_level_overrides, + )?; if try_stmt.handlers.is_empty() && try_stmt.finalbody.is_empty() { return Err(vm.new_value_error(format!( "{owner} has neither except handlers nor finalbody" @@ -602,15 +1198,43 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { vm.new_value_error(format!("{owner} has orelse but no except handlers")) ); } + validate_public_except_handler_list_slots(vm, try_stmt.node_index.load())?; for handler in &try_stmt.handlers { let ast::ExceptHandler::ExceptHandler(handler) = handler; if let Some(type_expr) = &handler.type_ { validate_expr(vm, type_expr, ast::ExprContext::Load)?; } - validate_body(vm, &handler.body, "ExceptHandler")?; + validate_body( + vm, + &handler.body, + handler.node_index.load(), + "ExceptHandler", + ast_constant_overrides, + ast_import_from_level_overrides, + )?; } - validate_stmts(vm, &try_stmt.finalbody)?; - validate_stmts(vm, &try_stmt.orelse) + validate_public_stmt_list_slots( + vm, + try_stmt.node_index.load(), + super::constant::PublicAstStmtListField::FinalBody, + )?; + validate_stmts( + vm, + &try_stmt.finalbody, + ast_constant_overrides, + ast_import_from_level_overrides, + )?; + validate_public_stmt_list_slots( + vm, + try_stmt.node_index.load(), + super::constant::PublicAstStmtListField::Orelse, + )?; + validate_stmts( + vm, + &try_stmt.orelse, + ast_constant_overrides, + ast_import_from_level_overrides, + ) } ast::Stmt::Assert(assert_stmt) => { validate_expr(vm, &assert_stmt.test, ast::ExprContext::Load)?; @@ -624,6 +1248,12 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { Ok(()) } ast::Stmt::ImportFrom(import) => { + if let Some(level) = ast_import_from_level_overrides + .and_then(|overrides| overrides.get(&import.node_index.load())) + && *level < 0 + { + return Err(vm.new_value_error("Negative ImportFrom level")); + } validate_nonempty_seq(vm, import.names.len(), "names", "ImportFrom")?; Ok(()) } @@ -636,28 +1266,169 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { Ok(()) } ast::Stmt::Expr(expr) => validate_expr(vm, &expr.value, ast::ExprContext::Load), - ast::Stmt::Pass(_) - | ast::Stmt::Break(_) - | ast::Stmt::Continue(_) - | ast::Stmt::IpyEscapeCommand(_) => Ok(()), + ast::Stmt::Pass(_) | ast::Stmt::Break(_) | ast::Stmt::Continue(_) => Ok(()), + ast::Stmt::IpyEscapeCommand(_) => Err(invalid_syntax_error(vm)), } } -fn validate_stmts(vm: &VirtualMachine, stmts: &[ast::Stmt]) -> PyResult<()> { +fn validate_stmts( + vm: &VirtualMachine, + stmts: &[ast::Stmt], + ast_constant_overrides: AstConstantOverrides<'_>, + ast_import_from_level_overrides: AstImportFromLevelOverrides<'_>, +) -> PyResult<()> { for stmt in stmts { - validate_stmt(vm, stmt)?; + validate_stmt( + vm, + stmt, + ast_constant_overrides, + ast_import_from_level_overrides, + )?; } Ok(()) } -pub(super) fn validate_mod(vm: &VirtualMachine, module: &Mod) -> PyResult<()> { - match module { - Mod::Module(module) => validate_stmts(vm, &module.body), - Mod::Interactive(module) => validate_stmts(vm, &module.body), - Mod::Expression(expr) => validate_expr(vm, &expr.body, ast::ExprContext::Load), - Mod::FunctionType(func_type) => { - validate_exprs(vm, &func_type.argtypes, ast::ExprContext::Load, false)?; - validate_expr(vm, &func_type.returns, ast::ExprContext::Load) +#[expect( + clippy::too_many_arguments, + reason = "public AST validation installs independent override tables" +)] +pub(super) fn validate_mod( + vm: &VirtualMachine, + module: &Mod, + ast_constant_overrides: AstConstantOverrides<'_>, + ast_interpolation_overrides: AstInterpolationOverrides<'_>, + ast_formatted_value_overrides: AstFormattedValueOverrides<'_>, + ast_import_from_level_overrides: AstImportFromLevelOverrides<'_>, + ast_invalid_constant_overrides: AstInvalidConstantOverrides<'_>, + ast_joined_str_overrides: AstExprListOverrides<'_>, + ast_template_str_overrides: AstExprListOverrides<'_>, + ast_pattern_list_overrides: AstPatternListOverrides<'_>, + ast_expr_option_list_overrides: AstExprOptionListOverrides<'_>, + ast_expr_list_overrides: AstExprListFieldOverrides<'_>, + ast_stmt_list_overrides: AstStmtListOverrides<'_>, + ast_except_handler_list_overrides: AstExceptHandlerListOverrides<'_>, + ast_type_param_list_overrides: AstTypeParamListOverrides<'_>, + ast_match_class_overrides: AstMatchClassOverrides<'_>, +) -> PyResult<()> { + PUBLIC_AST_INVALID_CONSTANTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_invalid_constant_overrides.cloned(); + }); + PUBLIC_AST_JOINED_STRS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_joined_str_overrides.cloned(); + }); + PUBLIC_AST_TEMPLATE_STRS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_template_str_overrides.cloned(); + }); + PUBLIC_AST_PATTERN_LISTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_pattern_list_overrides.cloned(); + }); + PUBLIC_AST_EXPR_OPTION_LISTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_expr_option_list_overrides.cloned(); + }); + PUBLIC_AST_EXPR_LISTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_expr_list_overrides.cloned(); + }); + PUBLIC_AST_STMT_LISTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_stmt_list_overrides.cloned(); + }); + PUBLIC_AST_EXCEPT_HANDLER_LISTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_except_handler_list_overrides.cloned(); + }); + PUBLIC_AST_TYPE_PARAM_LISTS.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_type_param_list_overrides.cloned(); + }); + PUBLIC_AST_MATCH_CLASSES.with(|cell| { + debug_assert!(cell.borrow().is_none()); + *cell.borrow_mut() = ast_match_class_overrides.cloned(); + }); + let result = (|| { + if let Some(overrides) = ast_interpolation_overrides { + for interpolation in overrides.values() { + if let Some(format_spec) = &interpolation.format_spec { + validate_expr(vm, format_spec, ast::ExprContext::Load)?; + } + } } - } + if let Some(overrides) = ast_formatted_value_overrides { + for formatted_value in overrides.values() { + if let Some(format_spec) = &formatted_value.format_spec { + validate_expr(vm, format_spec, ast::ExprContext::Load)?; + } + } + } + match module { + Mod::Module(module) => { + validate_public_stmt_list_slots( + vm, + module.module.node_index.load(), + super::constant::PublicAstStmtListField::Body, + )?; + validate_stmts( + vm, + &module.module.body, + ast_constant_overrides, + ast_import_from_level_overrides, + ) + } + Mod::Interactive(module) => { + validate_public_stmt_list_slots( + vm, + module.node_index.load(), + super::constant::PublicAstStmtListField::Body, + )?; + validate_stmts( + vm, + &module.body, + ast_constant_overrides, + ast_import_from_level_overrides, + ) + } + Mod::Expression(expr) => validate_expr(vm, &expr.body, ast::ExprContext::Load), + Mod::FunctionType(func_type) => { + validate_public_expr_option_list_slots(vm, func_type.node_index.load())?; + validate_exprs(vm, &func_type.argtypes, ast::ExprContext::Load, false)?; + validate_expr(vm, &func_type.returns, ast::ExprContext::Load) + } + } + })(); + PUBLIC_AST_INVALID_CONSTANTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_JOINED_STRS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_TEMPLATE_STRS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_PATTERN_LISTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_EXPR_OPTION_LISTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_EXPR_LISTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_STMT_LISTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_EXCEPT_HANDLER_LISTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_TYPE_PARAM_LISTS.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + PUBLIC_AST_MATCH_CLASSES.with(|cell| { + let _ = cell.borrow_mut().take(); + }); + result } diff --git a/crates/vm/src/stdlib/_symtable.rs b/crates/vm/src/stdlib/_symtable.rs index 4c6ec75f5ac..adcddeacc1f 100644 --- a/crates/vm/src/stdlib/_symtable.rs +++ b/crates/vm/src/stdlib/_symtable.rs @@ -153,7 +153,9 @@ mod _symtable { CompilerScope::Class => TYPE_CLASS, CompilerScope::Module => TYPE_MODULE, CompilerScope::Annotation => TYPE_ANNOTATION, + CompilerScope::TypeAlias => TYPE_TYPE_ALIAS, CompilerScope::TypeParams => TYPE_TYPE_PARAMETERS, + CompilerScope::TypeVariable => TYPE_TYPE_VARIABLE, } } diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index d0ed32b22d6..d8ef47a7a55 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -13,11 +13,12 @@ mod builtins { PyByteArray, PyBytes, PyDictRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyUtf8StrRef, enumerate::PyReverseSequenceIterator, - function::{PyCellRef, PyFunction}, + function::{PyCell, PyCellRef, PyFunction}, int::PyIntRef, iter::PyCallableIterator, list::{PyList, SortOptions}, }, + bytecode, common::hash::PyHash, function::{ ArgBytesLike, ArgCallable, ArgIndex, ArgIntoBool, ArgIterable, ArgMapping, @@ -29,6 +30,11 @@ mod builtins { readline::{Readline, ReadlineResult}, stdlib::sys, types::PyComparisonOp, + vm::compile_mode::{ + PY_CF_ALLOW_TOP_LEVEL_AWAIT, PY_CF_ALLOWED_FLAGS, PY_CF_IGNORE_COOKIE, PY_CF_ONLY_AST, + PY_CF_OPTIMIZED_AST, PY_EVAL_INPUT, PY_FILE_INPUT, PY_FUNC_TYPE_INPUT, PY_SINGLE_INPUT, + compile_future_feature_mask, compile_future_features_from_flags, + }, }; use itertools::Itertools; use num_traits::{Signed, ToPrimitive, Zero}; @@ -103,8 +109,8 @@ mod builtins { filename: PyObjectRef, mode: PyUtf8StrRef, // CPython parity: flags / optimize accept any object with __index__, - // not just exact int. Matches the behavior of `int(x)` arg conversion - // used by Python/Python-ast.c::compile. + // not just exact int. Matches the argument conversion used by + // CPython's builtin_compile_impl. #[pyarg(any, optional)] flags: OptionalArg>, // CPython parity: dont_inherit goes through PyObject_IsTrue, so @@ -114,174 +120,67 @@ mod builtins { dont_inherit: OptionalArg, #[pyarg(any, optional)] optimize: OptionalArg>, - #[pyarg(any, optional)] + #[pyarg(named, optional)] _feature_version: OptionalArg, } - /// Detect PEP 263 encoding cookie from source bytes. - /// Checks first two lines for `# coding[:=] ` pattern. - /// Returns the encoding name if found, or None for default (UTF-8). - #[cfg(feature = "parser")] - fn detect_source_encoding(source: &[u8]) -> Option { - fn find_encoding_in_line(line: &[u8]) -> Option { - // PEP 263: '#' must be preceded only by whitespace/formfeed - let hash_pos = line.iter().position(|&b| b == b'#')?; - if !line[..hash_pos] - .iter() - .all(|&b| matches!(b, b' ' | b'\t' | b'\x0c' | b'\r')) - { - return None; - } - let after_hash = &line[hash_pos..]; - - // Find "coding" after the # - let coding_pos = after_hash.windows(6).position(|w| w == b"coding")?; - let after_coding = &after_hash[coding_pos + 6..]; - - // Next char must be ':' or '=' - let rest = if matches!(after_coding.first(), Some(b':' | b'=')) { - &after_coding[1..] - } else { - return None; - }; - - // Skip whitespace - let rest = rest - .iter() - .copied() - .skip_while(|&b| matches!(b, b' ' | b'\t')) - .collect::>(); - - // Read encoding name: [-\w.]+ - let name = rest - .iter() - .take_while(|&&b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.')) - .map(|&b| b as char) - .collect::(); - - if name.is_empty() { - None - } else { - Some(normalize_source_encoding(&name)) - } - } - - // Split into lines (first two only) - let mut lines = source.splitn(3, |&b| b == b'\n'); - - if let Some(first) = lines.next() { - // Strip BOM if present - let first = first.strip_prefix(b"\xef\xbb\xbf").unwrap_or(first); - if let Some(enc) = find_encoding_in_line(first) { - return Some(enc); - } - // Only check second line if first line is blank or a comment - let trimmed = first - .iter() - .find(|&&b| !matches!(b, b' ' | b'\t' | b'\x0c' | b'\r')) - .copied(); - - if trimmed.is_some_and(|b| b != b'#') { - return None; - } + fn merge_compile_future_features( + flags: i32, + dont_inherit: bool, + vm: &VirtualMachine, + ) -> bytecode::CodeFlags { + let mut future_features = compile_future_features_from_flags(flags); + if !dont_inherit && let Some(frame) = vm.current_frame() { + future_features |= bytecode::CodeFlags::from_bits_truncate( + frame.code.flags.bits() & compile_future_feature_mask().bits(), + ); } - - lines.next().and_then(find_encoding_in_line) + future_features } - /// Match CPython's Parser/tokenizer/helpers.c:get_normal_name(). - #[cfg(feature = "parser")] - fn normalize_source_encoding(name: &str) -> String { - let mut normalized = String::with_capacity(name.len().min(12)); - for ch in name.chars().take(12) { - if ch == '_' { - normalized.push('-'); - } else { - normalized.push(ch.to_ascii_lowercase()); - } - } + fn audit_compile_source(vm: &VirtualMachine, source: &[u8], filename: &str) -> PyResult<()> { + vm.sys_module.get_attr("audit", vm)?.call( + ( + vm.ctx.new_str("compile"), + vm.ctx.new_bytes(source.to_vec()), + vm.ctx.new_str(filename), + ), + vm, + )?; + Ok(()) + } - if normalized == "utf-8" || normalized.starts_with("utf-8-") { - "utf-8".to_owned() - } else if normalized == "latin-1" - || normalized == "iso-8859-1" - || normalized == "iso-latin-1" - || normalized.starts_with("latin-1-") - || normalized.starts_with("iso-8859-1-") - || normalized.starts_with("iso-latin-1-") + fn trim_eval_source_bytes(mut source: &[u8]) -> &[u8] { + while let Some((&first, rest)) = source.split_first() + && matches!(first, b' ' | b'\t') { - "iso-8859-1".to_owned() - } else { - name.to_owned() + source = rest; } + source } - /// Decode source bytes to a string, handling PEP 263 encoding declarations - /// and BOM. Raises SyntaxError for invalid UTF-8 without an encoding - /// declaration. - #[cfg(feature = "parser")] - fn is_utf8_encoding(name: &str) -> bool { - name == "utf-8" - } - - #[cfg(feature = "parser")] - fn decode_source_bytes(source: &[u8], filename: &str, vm: &VirtualMachine) -> PyResult { - let has_bom = source.starts_with(b"\xef\xbb\xbf"); - let encoding = detect_source_encoding(source); - - let is_utf8 = encoding.as_deref().is_none_or(is_utf8_encoding); - - // Validate BOM + encoding combination - if has_bom && !is_utf8 { - let enc = encoding.as_deref().unwrap_or("utf-8"); - return Err(vm.new_exception_msg( - vm.ctx.exceptions.syntax_error.to_owned(), - format!("encoding problem: {enc} with BOM").into(), - )); + fn decode_eval_exec_source_bytes( + vm: &VirtualMachine, + source: &[u8], + filename: &str, + ) -> PyResult { + #[cfg(feature = "parser")] + { + vm.decode_source_bytes(source, filename, false) } - - if is_utf8 { - let src = if has_bom { &source[3..] } else { source }; - match core::str::from_utf8(src) { - Ok(s) => Ok(s.to_owned()), - Err(e) => { - let bad_byte = src[e.valid_up_to()]; - let line = src[..e.valid_up_to()] - .iter() - .filter(|&&b| b == b'\n') - .count() - + 1; - Err(vm.new_exception_msg( - vm.ctx.exceptions.syntax_error.to_owned(), - format!( - "Non-UTF-8 code starting with '\\x{bad_byte:02x}' \ - on line {line}, but no encoding declared; \ - see https://peps.python.org/pep-0263/ for details \ - ({filename}, line {line})" - ) - .into(), - )) - } - } - } else { - // Use codec registry for non-UTF-8 encodings - let enc = encoding.as_deref().unwrap(); - let bytes_obj = vm.ctx.new_bytes(source.to_vec()); - let decoded = vm - .state - .codec_registry - .decode_text(bytes_obj.into(), enc, None, vm) - .map_err(|exc| { - if exc.fast_isinstance(vm.ctx.exceptions.lookup_error) { - vm.new_exception_msg( - vm.ctx.exceptions.syntax_error.to_owned(), - format!("unknown encoding for '{filename}': {enc}").into(), - ) - } else { - exc - } - })?; - Ok(decoded.to_string_lossy().into_owned()) + #[cfg(not(feature = "parser"))] + { + _ = filename; + core::str::from_utf8(source) + .map(str::to_owned) + .map_err(|err| { + let msg = format!( + "(unicode error) 'utf-8' codec can't decode byte 0x{:x?} in position {}: invalid start byte", + source[err.valid_up_to()], + err.valid_up_to() + ); + vm.new_exception_msg(vm.ctx.exceptions.syntax_error.to_owned(), msg.into()) + }) } } @@ -303,30 +202,59 @@ mod builtins { use crate::{class::PyClassImpl, stdlib::_ast}; - let feature_version = feature_version_from_arg(args._feature_version, vm)?; + let feature_version = args._feature_version.into_option().unwrap_or(-1); let mode_str = args.mode.as_str(); + let flags: i32 = args.flags.map_or(0, |v| v.value); + + if !(flags & !PY_CF_ALLOWED_FLAGS).is_zero() { + return Err(vm.new_value_error("compile(): unrecognised flags")); + } let optimize: i32 = args.optimize.map_or(-1, |v| v.value); let optimize: u8 = match optimize { - -1 => vm.state.config.settings.optimize, + -1 => vm.state.config.settings.optimize.min(2), 0..=2 => optimize as u8, _ => return Err(vm.new_value_error("compile(): invalid optimize value")), }; - - if args - .source - .fast_isinstance(&_ast::NodeAst::make_static_type()) - { - let flags: i32 = args.flags.map_or(0, |v| v.value); - let is_ast_only = !(flags & _ast::PY_CF_ONLY_AST).is_zero(); - - // func_type mode requires PyCF_ONLY_AST - if mode_str == "func_type" && !is_ast_only { + let dont_inherit = args.dont_inherit.map_or(false, ArgIntoBool::into_bool); + let is_ast_only = !(flags & PY_CF_ONLY_AST).is_zero(); + let future_features = merge_compile_future_features(flags, dont_inherit, vm); + + let start = if mode_str == "exec" { + PY_FILE_INPUT + } else if mode_str == "eval" { + PY_EVAL_INPUT + } else if mode_str == "single" { + PY_SINGLE_INPUT + } else if mode_str == "func_type" { + if !is_ast_only { return Err(vm.new_value_error( "compile() mode 'func_type' requires flag PyCF_ONLY_AST", )); } + PY_FUNC_TYPE_INPUT + } else { + let msg = if is_ast_only { + "compile() mode must be 'exec', 'eval', 'single' or 'func_type'" + } else { + "compile() mode must be 'exec', 'eval' or 'single'" + }; + return Err(vm.new_value_error(msg)); + }; + + let ast_type = _ast::NodeAst::make_static_type().as_object().to_owned(); + if args.source.is_instance(&ast_type, vm)? { + let explicit_future_annotations = + future_features.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + vm.sys_module.get_attr("audit", vm)?.call( + ( + vm.ctx.new_str("compile"), + args.source.clone(), + vm.ctx.none(), + ), + vm, + )?; // compile(ast_node, ..., PyCF_ONLY_AST) returns the AST after validation if is_ast_only { @@ -336,15 +264,29 @@ mod builtins { "compile() mode must be 'exec', 'eval', 'single' or 'func_type'", ) })?; - if !args.source.fast_isinstance(&expected_type) { + if !args.source.is_instance(expected_type.as_object(), vm)? { return Err(vm.new_type_error(format!( "expected {} node, got {}", expected_name, args.source.class().name() ))); } - _ast::validate_ast_object(vm, args.source.clone())?; - return Ok(args.source); + #[cfg(not(feature = "rustpython-codegen"))] + { + _ast::validate_ast_object(vm, args.source.clone())?; + return Ok(args.source); + } + #[cfg(feature = "rustpython-codegen")] + { + return _ast::preprocess_ast_object( + vm, + args.source, + &filename.to_string_lossy(), + optimize, + (flags & PY_CF_OPTIMIZED_AST) == PY_CF_OPTIMIZED_AST, + explicit_future_annotations, + ); + } } #[cfg(not(feature = "rustpython-codegen"))] @@ -353,133 +295,103 @@ mod builtins { } #[cfg(feature = "rustpython-codegen")] { + let (expected_type, expected_name) = _ast::mode_type_and_name(mode_str) + .ok_or_else(|| { + vm.new_value_error("compile() mode must be 'exec', 'eval' or 'single'") + })?; + if !args.source.is_instance(expected_type.as_object(), vm)? { + return Err(vm.new_type_error(format!( + "expected {} node, got {}", + expected_name, + args.source.class().name() + ))); + } let mode = mode_str .parse::() .map_err(|err| vm.new_value_error(err.to_string()))?; - return _ast::compile( - vm, - args.source, - &filename.to_string_lossy(), - mode, - Some(optimize), - ); + let mut opts = vm.compile_opts(); + opts.optimize = optimize; + opts.allow_top_level_await = !(flags & PY_CF_ALLOW_TOP_LEVEL_AWAIT).is_zero(); + opts.future_features = future_features; + return _ast::compile(vm, args.source, &filename.to_string_lossy(), mode, opts); } } #[cfg(not(feature = "parser"))] - return Err(vm.new_type_error( - "can't compile() source code when the `parser` feature of rustpython is disabled", - )); - + { + const PARSER_NOT_SUPPORTED: &str = "can't compile() source code when the `parser` feature of rustpython is disabled"; + Err(vm.new_type_error(PARSER_NOT_SUPPORTED)) + } #[cfg(feature = "parser")] { - use crate::convert::ToPyException; - - use ruff_python_parser as parser; - let source = ArgStrOrBytesLike::try_from_object(vm, args.source)?; - let source = source.borrow_bytes(); - - let source = decode_source_bytes(&source, &filename.to_string_lossy(), vm)?; - let source = source.as_str(); - - let flags: i32 = args.flags.map_or(0, |v| v.value); - - if !(flags & !_ast::PY_COMPILE_FLAGS_MASK).is_zero() { - return Err(vm.new_value_error("compile(): unrecognised flags")); - } - - let allow_incomplete = !(flags & _ast::PY_CF_ALLOW_INCOMPLETE_INPUT).is_zero(); - let type_comments = !(flags & _ast::PY_CF_TYPE_COMMENTS).is_zero(); - - let optimize_level = optimize; - if (flags & _ast::PY_CF_ONLY_AST).is_zero() { - #[cfg(not(feature = "compiler"))] - { - Err(vm.new_value_error(CODEGEN_NOT_SUPPORTED)) - } - #[cfg(feature = "compiler")] - { - if let Some(feature_version) = feature_version { - let mode = mode_str - .parse::() - .map_err(|err| vm.new_value_error(err.to_string()))?; - let _ = _ast::parse( - vm, - source, - mode, - optimize_level, - Some(feature_version), - type_comments, - ) - .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(vm))?; - } - - let mode = mode_str - .parse::() - .map_err(|err| vm.new_value_error(err.to_string()))?; - - let mut opts = vm.compile_opts(); - opts.optimize = optimize; - - let code = vm - .compile_with_opts(source, mode, &filename.to_string_lossy(), opts) - .map_err(|err| { - (err, Some(source), allow_incomplete).to_pyexception(vm) - })?; - Ok(code.into()) - } - } else { - if mode_str == "func_type" { - return _ast::parse_func_type(vm, source, optimize_level, feature_version) - .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(vm)); - } - - let mode = mode_str - .parse::() - .map_err(|err| vm.new_value_error(err.to_string()))?; - let parsed = _ast::parse( - vm, + let mut compile_flags = flags | future_features.bits() as i32; + #[cfg(feature = "rustpython-compiler")] + let compile_source = |source: &[u8], compile_flags: i32| { + vm.compile_string_object_with_flags( source, - mode, - optimize_level, + &filename.to_string_lossy(), + start, + compile_flags, feature_version, - type_comments, + optimize as i32, ) - .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(vm))?; - - if mode_str == "single" { - return _ast::wrap_interactive(vm, parsed); + }; + match &source { + ArgStrOrBytesLike::Str(source) => { + if source.as_bytes().contains(&0) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + "source code string cannot contain null bytes".into(), + )); + } + audit_compile_source( + vm, + source.as_bytes(), + filename.to_string_lossy().as_ref(), + )?; + compile_flags |= PY_CF_IGNORE_COOKIE; + #[cfg(feature = "rustpython-compiler")] + { + compile_source(source.as_bytes(), compile_flags) + } + #[cfg(not(feature = "rustpython-compiler"))] + { + Err(vm.new_value_error(CODEGEN_NOT_SUPPORTED)) + } + } + ArgStrOrBytesLike::Buf(source) => { + let source_bytes = source.borrow_buf(); + let source_bytes: &[u8] = &source_bytes; + if source_bytes.contains(&0) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + "source code string cannot contain null bytes".into(), + )); + } + audit_compile_source( + vm, + source_bytes, + filename.to_string_lossy().as_ref(), + )?; + #[cfg(feature = "rustpython-compiler")] + { + compile_source(source_bytes, compile_flags) + } + #[cfg(not(feature = "rustpython-compiler"))] + { + Err(vm.new_value_error(CODEGEN_NOT_SUPPORTED)) + } } - - Ok(parsed) } } } } - #[cfg(feature = "ast")] - fn feature_version_from_arg( - feature_version: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult> { - let Some(minor) = feature_version.into_option() else { - return Ok(None); - }; - - if minor < 0 { - return Ok(None); - } - - u8::try_from(minor) - .map(|v| Some(ruff_python_ast::PythonVersion { major: 3, minor: v })) - .map_err(|_| vm.new_value_error("compile() _feature_version out of range")) - } - #[pyfunction] fn delattr(obj: PyObjectRef, attr: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let attr = attr.try_to_ref::(vm).map_err(|_| { + let attr = attr.try_to_ref::(vm).map_err(|_e| { vm.new_type_error(format!( "attribute name must be string, not '{}'", attr.class().name() @@ -507,42 +419,39 @@ mod builtins { } impl ScopeArgs { - fn validate_globals_dict( - globals: &PyObject, - vm: &VirtualMachine, - func_name: &'static str, - ) -> PyResult<()> { - if globals.fast_isinstance(vm.ctx.types.dict_type) { - return Ok(()); - } - - let msg = match func_name { - "eval" => { - let is_mapping = globals.mapping_unchecked().check(); - if is_mapping { - "globals must be a real dict; try eval(expr, {}, mapping)".into() - } else { - "globals must be a dict".into() - } - } - "exec" => format!( - "exec() globals must be a dict, not {}", - globals.class().name() - ), - _ => "globals must be a dict".into(), - }; - - Err(vm.new_type_error(msg)) - } - fn make_scope( self, vm: &VirtualMachine, func_name: &'static str, ) -> PyResult { + fn validate_globals_dict( + globals: &PyObject, + vm: &VirtualMachine, + func_name: &'static str, + ) -> PyResult<()> { + if !globals.fast_isinstance(vm.ctx.types.dict_type) { + return Err(match func_name { + "eval" => { + let is_mapping = globals.mapping_unchecked().check(); + vm.new_type_error(if is_mapping { + "globals must be a real dict; try eval(expr, {}, mapping)" + } else { + "globals must be a dict" + }) + } + "exec" => vm.new_type_error(format!( + "exec() globals must be a dict, not {}", + globals.class().name() + )), + _ => vm.new_type_error("globals must be a dict"), + }); + } + Ok(()) + } + let (globals, locals) = match self.globals { Some(globals) => { - Self::validate_globals_dict(&globals, vm, func_name)?; + validate_globals_dict(&globals, vm, func_name)?; let globals = PyDictRef::try_from_object(vm, globals)?; if !globals.contains_key(identifier!(vm, __builtins__), vm) { @@ -570,6 +479,61 @@ mod builtins { } } + #[derive(FromArgs)] + struct ExecArgs { + #[pyarg(positional)] + source: Either>, + #[pyarg(any, default)] + globals: Option, + #[pyarg(any, default)] + locals: Option, + #[pyarg(named, optional)] + closure: OptionalOption, + } + + fn exec_closure( + code_obj: &PyRef, + closure: Option, + vm: &VirtualMachine, + ) -> PyResult>>> { + let num_free = code_obj.freevars.len(); + let Some(closure) = closure else { + if num_free == 0 { + return Ok(None); + } + return Err(vm.new_type_error(format!( + "code object requires a closure of exactly length {num_free}" + ))); + }; + + if num_free == 0 { + return Err(vm.new_type_error("cannot use a closure with this code object")); + } + + let closure_tuple = closure + .downcast_exact::(vm) + .map_err(|_| { + vm.new_type_error(format!( + "code object requires a closure of exactly length {num_free}" + )) + })? + .into_pyref(); + if closure_tuple.len() != num_free { + return Err(vm.new_type_error(format!( + "code object requires a closure of exactly length {num_free}" + ))); + } + + closure_tuple + .try_into_typed::(vm) + .map(Some) + .map_err(|_| { + vm.new_type_error(format!( + "code object requires a closure of exactly length {num_free}" + )) + }) + } + #[pyfunction] fn eval( source: Either>, @@ -581,38 +545,93 @@ mod builtins { // source as string let code = match source { Either::A(either) => { - let source: &[u8] = &either.borrow_bytes(); - if source.contains(&0) { - return Err(vm.new_exception_msg( - vm.ctx.exceptions.syntax_error.to_owned(), - "source code string cannot contain null bytes".into(), - )); - } - - let source = core::str::from_utf8(source).map_err(|err| { - let msg = format!( - "(unicode error) 'utf-8' codec can't decode byte 0x{:x?} in position {}: invalid start byte", - source[err.valid_up_to()], - err.valid_up_to() - ); - - vm.new_exception_msg(vm.ctx.exceptions.syntax_error.to_owned(), msg.into()) - })?; - Ok(Either::A(vm.ctx.new_utf8_str(source.trim_start()))) + let source = match &either { + ArgStrOrBytesLike::Str(source) => { + if source.as_bytes().contains(&0) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + "source code string cannot contain null bytes".into(), + )); + } + let source = source.expect_str().trim_start_matches([' ', '\t']); + audit_compile_source(vm, source.as_bytes(), "")?; + source.to_owned() + } + ArgStrOrBytesLike::Buf(source) => { + let source: &[u8] = &source.borrow_buf(); + if source.contains(&0) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + "source code string cannot contain null bytes".into(), + )); + } + let source = trim_eval_source_bytes(source); + audit_compile_source(vm, source, "")?; + decode_eval_exec_source_bytes(vm, source, "eval")? + } + }; + Ok(Either::A(vm.ctx.new_utf8_str(source))) } Either::B(code) => Ok(Either::B(code)), }?; - run_code(vm, code, scope, crate::compiler::Mode::Eval, "eval") + run_code(vm, code, scope, crate::compiler::Mode::Eval, "eval", None) } #[pyfunction] - fn exec( - source: Either>, - scope: ScopeArgs, - vm: &VirtualMachine, - ) -> PyResult { - let scope = scope.make_scope(vm, "exec")?; - run_code(vm, source, scope, crate::compiler::Mode::Exec, "exec") + fn exec(args: ExecArgs, vm: &VirtualMachine) -> PyResult { + let ExecArgs { + source, + globals, + locals, + closure, + } = args; + let scope = ScopeArgs { globals, locals }.make_scope(vm, "exec")?; + let closure = closure.flatten(); + let (source, closure) = match source { + Either::A(either) => { + if closure.is_some() { + return Err( + vm.new_type_error("closure can only be used when source is a code object") + ); + } + let source = match &either { + ArgStrOrBytesLike::Str(source) => { + if source.as_bytes().contains(&0) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + "source code string cannot contain null bytes".into(), + )); + } + audit_compile_source(vm, source.as_bytes(), "")?; + source.expect_str().to_owned() + } + ArgStrOrBytesLike::Buf(source) => { + let source: &[u8] = &source.borrow_buf(); + if source.contains(&0) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + "source code string cannot contain null bytes".into(), + )); + } + audit_compile_source(vm, source, "")?; + decode_eval_exec_source_bytes(vm, source, "exec")? + } + }; + (Either::A(vm.ctx.new_utf8_str(source)), None) + } + Either::B(code) => { + let closure = exec_closure(&code, closure, vm)?; + (Either::B(code), closure) + } + }; + run_code( + vm, + source, + scope, + crate::compiler::Mode::Exec, + "exec", + closure, + ) } fn run_code( @@ -621,28 +640,39 @@ mod builtins { scope: crate::scope::Scope, #[allow(unused_variables)] mode: crate::compiler::Mode, func: &str, + closure: Option>>, ) -> PyResult { // Determine code object: let code_obj = match source { #[cfg(feature = "rustpython-compiler")] Either::A(string) => { let source = string.as_str(); - vm.compile(source, mode, "") - .map_err(|err| vm.new_syntax_error(&err, Some(source)))? + let mut opts = vm.compile_opts(); + if let Some(frame) = vm.current_frame() { + opts.future_features = bytecode::CodeFlags::from_bits_truncate( + frame.code.flags.bits() & compile_future_feature_mask().bits(), + ); + } + vm.compile_with_opts(source, mode, "", opts) + .map_err(|err| err.into_pyexception(vm, Some(source)))? } #[cfg(not(feature = "rustpython-compiler"))] Either::A(_) => return Err(vm.new_type_error(CODEGEN_NOT_SUPPORTED)), Either::B(code_obj) => code_obj, }; - if !code_obj.freevars.is_empty() { + vm.sys_module + .get_attr("audit", vm)? + .call((vm.ctx.new_str("exec"), code_obj.clone()), vm)?; + + if closure.is_none() && !code_obj.freevars.is_empty() { return Err(vm.new_type_error(format!( "code object passed to {func}() may not contain free variables" ))); } // Run the code: - vm.run_code_obj(code_obj, scope) + vm.run_code_obj_with_closure(code_obj, scope, closure) } #[pyfunction] @@ -1002,8 +1032,8 @@ mod builtins { modulus, } = args; let modulus = modulus - .as_ref() - .map_or_else(|| vm.ctx.none.as_object(), |m| m); + .as_deref() + .unwrap_or_else(|| vm.ctx.none.as_object()); vm._pow(&x, &y, modulus) } diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 58324a3c071..c1c5997d08a 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -751,7 +751,7 @@ pub mod sys { .read_to_string(&mut source) .map_err(|e| vm.new_os_error(format!("Error reading from stdin: {e}")))?; vm.compile(&source, crate::compiler::Mode::Single, "") - .map_err(|e| vm.new_os_error(format!("Error running stdin: {e}")))?; + .map_err(|e| e.into_pyexception(vm, Some(&source)))?; Ok(()) } @@ -1228,7 +1228,7 @@ pub mod sys { vm.state.int_max_str_digits.store(maxdigits); Ok(()) } else { - let error = format!("maxdigits must be 0 or larger than {threshold:?}"); + let error = format!("maxdigits must be >= {threshold} or 0 for unlimited"); Err(vm.new_value_error(error)) } } @@ -1744,12 +1744,54 @@ pub mod sys { } for hook in hooks { - hook.call((event.clone(), args.clone()), vm)?; + call_audit_hook(&hook, event.clone().into(), args, vm)?; } Ok(()) } + fn audit_hook_can_trace(hook: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + match hook.get_attr("__cantrace__", vm) { + Ok(can_trace) => can_trace.try_to_bool(vm), + Err(exc) + if exc + .class() + .fast_issubclass(vm.ctx.exceptions.attribute_error) => + { + Ok(false) + } + Err(exc) => Err(exc), + } + } + + fn call_audit_hook( + hook: &PyObjectRef, + event: PyObjectRef, + args: &PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + // CPython suppresses tracing while dispatching Python audit hooks, + // except for hooks that explicitly opt in with __cantrace__. + vm.enter_tracing(); + let can_trace = audit_hook_can_trace(hook, vm); + let result = match can_trace { + Ok(can_trace) => { + if can_trace { + vm.leave_tracing(); + } + let result = hook.call((event, args.clone()), vm).map(|_| ()); + if can_trace { + vm.enter_tracing(); + } + result + } + Err(exc) => Err(exc), + }; + + vm.leave_tracing(); + result + } + #[pyfunction] fn audit(event: PyStrRef, args: PosArgs, vm: &VirtualMachine) -> PyResult<()> { if vm.audit_hooks.borrow().is_empty() { @@ -1773,10 +1815,13 @@ pub mod sys { let event: PyObjectRef = vm.ctx.new_str("sys.addaudithook").into(); for existing_hook in hooks { - let Err(exc) = existing_hook.call((event.clone(), args.clone()), vm) else { + let Err(exc) = call_audit_hook(&existing_hook, event.clone(), &args, vm) else { continue; }; - if exc.class().fast_issubclass(vm.ctx.exceptions.runtime_error) { + if exc + .class() + .fast_issubclass(vm.ctx.exceptions.exception_type) + { return Ok(()); } return Err(exc); diff --git a/crates/vm/src/stdlib/sys/monitoring.rs b/crates/vm/src/stdlib/sys/monitoring.rs index 6e61692507d..accf2001675 100644 --- a/crates/vm/src/stdlib/sys/monitoring.rs +++ b/crates/vm/src/stdlib/sys/monitoring.rs @@ -747,7 +747,7 @@ fn fire( cb_extra: &[PyObjectRef], ) -> PyResult<()> { // Prevent recursive event firing - if FIRING.with(|f| f.get()) { + if vm.tracing_is_suppressed() || FIRING.with(|f| f.get()) { return Ok(()); } @@ -795,6 +795,7 @@ fn fire( let args = FuncArgs::from(args_vec); FIRING.with(|f| f.set(true)); + vm.enter_tracing(); let result = (|| { for (tool, cb) in callbacks { let result = cb.call(args.clone(), vm)?; @@ -817,6 +818,7 @@ fn fire( } Ok(()) })(); + vm.leave_tracing(); FIRING.with(|f| f.set(false)); result } diff --git a/crates/vm/src/vm/compile.rs b/crates/vm/src/vm/compile.rs index 2dbdb17ff4a..bfe987ef007 100644 --- a/crates/vm/src/vm/compile.rs +++ b/crates/vm/src/vm/compile.rs @@ -2,19 +2,403 @@ //! //! For code execution functions, see python_run.rs +use core::fmt; + use crate::{ - PyRef, VirtualMachine, - builtins::PyCode, + AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine, + builtins::{PyBaseExceptionRef, PyCode}, compiler::{self, CompileError, CompileOpts}, + vm::compile_mode::{ + PY_CF_ALLOW_INCOMPLETE_INPUT, PY_CF_ALLOW_TOP_LEVEL_AWAIT, PY_CF_DONT_IMPLY_DEDENT, + PY_CF_IGNORE_COOKIE, PY_CF_ONLY_AST, PY_CF_OPTIMIZED_AST, PY_CF_TYPE_COMMENTS, + PY_EVAL_INPUT, PY_FILE_INPUT, PY_FUNC_TYPE_INPUT, PY_SINGLE_INPUT, + compile_future_features_from_flags, + }, }; +#[derive(Debug)] +pub enum VmCompileError { + Compile(CompileError), + Warning(CompileWarningError), +} + +#[derive(Debug)] +pub struct CompileWarningError { + exception: PyBaseExceptionRef, + filename: String, + lineno: usize, + offset: usize, +} + +impl From for VmCompileError { + fn from(err: CompileError) -> Self { + Self::Compile(err) + } +} + +impl fmt::Display for VmCompileError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Compile(err) => err.fmt(f), + Self::Warning(_) => f.write_str("compiler warning raised as an exception"), + } + } +} + +impl VmCompileError { + pub fn into_pyexception(self, vm: &VirtualMachine, source: Option<&str>) -> PyBaseExceptionRef { + match self { + Self::Compile(err) => vm.new_syntax_error(&err, source), + Self::Warning(err) => err.into_pyexception(vm, source), + } + } + + pub fn into_pyexception_maybe_incomplete( + self, + vm: &VirtualMachine, + source: Option<&str>, + allow_incomplete: bool, + ) -> PyBaseExceptionRef { + match self { + Self::Compile(err) => { + vm.new_syntax_error_maybe_incomplete(&err, source, allow_incomplete) + } + Self::Warning(err) => err.into_pyexception(vm, source), + } + } +} + +impl CompileWarningError { + fn into_codegen_error( + self, + location: compiler::core::SourceLocation, + source_path: String, + vm: &VirtualMachine, + ) -> compiler::codegen::error::CodegenError { + let message = self.exception.as_object().str(vm).map_or_else( + |_| "compiler warning raised as an exception".to_owned(), + |message| message.as_wtf8().to_string(), + ); + compiler::codegen::error::CodegenError { + location: Some(location), + error: compiler::codegen::error::CodegenErrorType::SyntaxError(message), + source_path, + } + } + + fn into_pyexception(self, vm: &VirtualMachine, source: Option<&str>) -> PyBaseExceptionRef { + if !self + .exception + .fast_isinstance(vm.ctx.exceptions.syntax_warning) + { + return self.exception; + } + let Ok(message) = self.exception.as_object().str(vm) else { + return self.exception; + }; + let syntax_error = vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + message.as_wtf8().to_owned(), + ); + syntax_error + .as_object() + .set_attr("lineno", vm.ctx.new_int(self.lineno), vm) + .unwrap(); + syntax_error + .as_object() + .set_attr("offset", vm.ctx.new_int(self.offset), vm) + .unwrap(); + syntax_error + .as_object() + .set_attr("filename", vm.ctx.new_str(self.filename), vm) + .unwrap(); + let text = source + .and_then(|source| source.split('\n').nth(self.lineno.saturating_sub(1))) + .map_or_else( + || vm.ctx.none(), + |line| { + vm.ctx + .new_str(format!("{}\n", line.trim_end_matches('\r'))) + .into() + }, + ); + syntax_error.as_object().set_attr("text", text, vm).unwrap(); + syntax_error + } +} + impl VirtualMachine { + #[cfg(feature = "parser")] + fn detect_source_encoding(source: &[u8]) -> Option { + fn find_encoding_in_line(line: &[u8]) -> Option { + let hash_pos = line.iter().position(|&b| b == b'#')?; + if !line[..hash_pos] + .iter() + .all(|&b| b == b' ' || b == b'\t' || b == b'\x0c' || b == b'\r') + { + return None; + } + let after_hash = &line[hash_pos..]; + let coding_pos = after_hash.windows(6).position(|w| w == b"coding")?; + let after_coding = &after_hash[coding_pos + 6..]; + let rest = if after_coding.first() == Some(&b':') || after_coding.first() == Some(&b'=') + { + &after_coding[1..] + } else { + return None; + }; + let name: String = rest + .iter() + .copied() + .skip_while(|&b| b == b' ' || b == b'\t') + .take_while(|&b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_' || b == b'.') + .map(|b| b as char) + .collect(); + (!name.is_empty()).then(|| VirtualMachine::normalize_source_encoding(&name)) + } + + let mut lines = source.splitn(3, |&b| b == b'\n'); + if let Some(first) = lines.next() { + let first = first.strip_prefix(b"\xef\xbb\xbf").unwrap_or(first); + if let Some(enc) = find_encoding_in_line(first) { + return Some(enc); + } + let trimmed = first + .iter() + .skip_while(|&&b| b == b' ' || b == b'\t' || b == b'\x0c' || b == b'\r') + .copied() + .collect::>(); + if !trimmed.is_empty() && trimmed[0] != b'#' { + return None; + } + } + lines.next().and_then(find_encoding_in_line) + } + + #[cfg(feature = "parser")] + fn normalize_source_encoding(name: &str) -> String { + let mut normalized = String::with_capacity(name.len().min(12)); + for ch in name.chars().take(12) { + if ch == '_' { + normalized.push('-'); + } else { + normalized.push(ch.to_ascii_lowercase()); + } + } + + if normalized == "utf-8" || normalized.starts_with("utf-8-") { + "utf-8".to_owned() + } else if normalized == "latin-1" + || normalized == "iso-8859-1" + || normalized == "iso-latin-1" + || normalized.starts_with("latin-1-") + || normalized.starts_with("iso-8859-1-") + || normalized.starts_with("iso-latin-1-") + { + "iso-8859-1".to_owned() + } else { + name.to_owned() + } + } + + #[cfg(feature = "parser")] + fn is_utf8_encoding(name: &str) -> bool { + name == "utf-8" + } + + #[cfg(feature = "parser")] + pub(crate) fn decode_source_bytes( + &self, + source: &[u8], + filename: &str, + ignore_cookie: bool, + ) -> PyResult { + let has_bom = source.starts_with(b"\xef\xbb\xbf"); + let encoding = if ignore_cookie { + None + } else { + Self::detect_source_encoding(source) + }; + let is_utf8 = encoding.as_deref().is_none_or(Self::is_utf8_encoding); + if has_bom && !is_utf8 { + let enc = encoding.as_deref().unwrap_or("utf-8"); + return Err(self.new_exception_msg( + self.ctx.exceptions.syntax_error.to_owned(), + format!("encoding problem: {enc} with BOM").into(), + )); + } + + if is_utf8 { + let src = if has_bom { &source[3..] } else { source }; + match core::str::from_utf8(src) { + Ok(s) => Ok(s.to_owned()), + Err(e) => { + let bad_byte = src[e.valid_up_to()]; + let line = src[..e.valid_up_to()] + .iter() + .filter(|&&b| b == b'\n') + .count() + + 1; + Err(self.new_exception_msg( + self.ctx.exceptions.syntax_error.to_owned(), + format!( + "Non-UTF-8 code starting with '\\x{bad_byte:02x}' \ + on line {line}, but no encoding declared; \ + see https://peps.python.org/pep-0263/ for details \ + ({filename}, line {line})" + ) + .into(), + )) + } + } + } else { + let encoding = encoding.as_deref().unwrap(); + let bytes = self.ctx.new_bytes(source.to_vec()); + let decoded = self + .state + .codec_registry + .decode_text(bytes.into(), encoding, None, self) + .map_err(|exc| { + if exc.fast_isinstance(self.ctx.exceptions.lookup_error) { + self.new_exception_msg( + self.ctx.exceptions.syntax_error.to_owned(), + format!("unknown encoding for '{filename}': {encoding}").into(), + ) + } else { + exc + } + })?; + Ok(decoded.to_string_lossy().into_owned()) + } + } + + #[cfg(feature = "parser")] + pub fn compile_string_object_with_flags( + &self, + source: &[u8], + filename: &str, + start: i32, + flags: i32, + feature_version: i32, + optimize: i32, + ) -> PyResult { + use crate::convert::ToPyException; + use crate::stdlib::_ast; + + let source = + self.decode_source_bytes(source, filename, (flags & PY_CF_IGNORE_COOKIE) != 0)?; + let source = source.as_str(); + let optimize = match optimize { + -1 => self.state.config.settings.optimize.min(2), + 0..=2 => optimize as u8, + _ => return Err(self.new_value_error("compile(): invalid optimize value")), + }; + let allow_incomplete = (flags & PY_CF_ALLOW_INCOMPLETE_INPUT) != 0; + let type_comments = (flags & PY_CF_TYPE_COMMENTS) != 0; + let dont_imply_dedent = (flags & PY_CF_DONT_IMPLY_DEDENT) != 0; + let is_ast_only = (flags & PY_CF_ONLY_AST) != 0; + let optimized_ast = (flags & PY_CF_OPTIMIZED_AST) == PY_CF_OPTIMIZED_AST; + let future_features = compile_future_features_from_flags(flags); + let explicit_future_annotations = + future_features.contains(crate::bytecode::CodeFlags::FUTURE_ANNOTATIONS); + let target_version = if is_ast_only { + Some(ruff_python_ast::PythonVersion { + major: 3, + minor: u8::try_from(feature_version).unwrap_or(crate::version::MINOR as u8), + }) + } else { + None + }; + + if is_ast_only { + if start == PY_FUNC_TYPE_INPUT { + return _ast::parse_func_type(self, source, optimize, target_version) + .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(self)); + } + let (parser_mode, interactive) = match start { + PY_SINGLE_INPUT => (ruff_python_parser::Mode::Module, true), + PY_FILE_INPUT => (ruff_python_parser::Mode::Module, false), + PY_EVAL_INPUT => (ruff_python_parser::Mode::Expression, false), + _ => { + return Err( + self.new_system_error("Invalid start argument passed to Py_CompileString") + ); + } + }; + let parsed = _ast::parse( + self, + source, + parser_mode, + optimize, + target_version, + type_comments, + optimized_ast, + interactive, + explicit_future_annotations, + dont_imply_dedent, + ) + .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(self))?; + if start == PY_SINGLE_INPUT { + return _ast::wrap_interactive(self, parsed); + } + return Ok(parsed); + } + + if type_comments { + let parser_mode = match start { + PY_SINGLE_INPUT | PY_FILE_INPUT => ruff_python_parser::Mode::Module, + PY_EVAL_INPUT => ruff_python_parser::Mode::Expression, + _ => { + return Err( + self.new_system_error("Invalid start argument passed to Py_CompileString") + ); + } + }; + let _ = _ast::parse( + self, + source, + parser_mode, + optimize, + None, + type_comments, + false, + start == PY_SINGLE_INPUT, + explicit_future_annotations, + dont_imply_dedent, + ) + .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(self))?; + } + + let mode = match start { + PY_SINGLE_INPUT => compiler::Mode::Single, + PY_FILE_INPUT => compiler::Mode::Exec, + PY_EVAL_INPUT => compiler::Mode::Eval, + PY_FUNC_TYPE_INPUT => compiler::Mode::BlockExpr, + _ => { + return Err( + self.new_system_error("Invalid start argument passed to Py_CompileString") + ); + } + }; + let mut opts = self.compile_opts(); + opts.optimize = optimize; + opts.allow_top_level_await = (flags & PY_CF_ALLOW_TOP_LEVEL_AWAIT) != 0; + opts.future_features = future_features; + opts.dont_imply_dedent = dont_imply_dedent; + let code = self + .compile_with_opts(source, mode, filename, opts) + .map_err(|err| { + err.into_pyexception_maybe_incomplete(self, Some(source), allow_incomplete) + })?; + Ok(code.into()) + } + pub fn compile( &self, source: &str, mode: compiler::Mode, - source_path: &str, - ) -> Result, CompileError> { + source_path: impl Into, + ) -> Result, VmCompileError> { self.compile_with_opts(source, mode, source_path, self.compile_opts()) } @@ -22,18 +406,37 @@ impl VirtualMachine { &self, source: &str, mode: compiler::Mode, - source_path: &str, + source_path: impl Into, opts: CompileOpts, - ) -> Result, CompileError> { - let code = compiler::compile(source, mode, source_path, opts) - .map(|code| PyCode::new_ref_from_bytecode(self, code)); - + ) -> Result, VmCompileError> { + let source_path = source_path.into(); #[cfg(feature = "parser")] - if code.is_ok() { - self.emit_string_escape_warnings(source, source_path); + { + self.emit_tokenizer_syntax_warnings(source, &source_path) + .map_err(VmCompileError::Warning)?; + self.emit_string_escape_warnings(source, &source_path) + .map_err(VmCompileError::Warning)?; } - - code + #[cfg(feature = "parser")] + let code = { + let mut syntax_warning_handler = |location, message| { + escape_warnings::warn_syntax_at_location(&source_path, location, message, self) + .map_err(|err| err.into_codegen_error(location, source_path.clone(), self)) + }; + compiler::compile_with_syntax_warning_handler( + source, + mode, + &source_path, + opts, + &mut syntax_warning_handler, + ) + }; + #[cfg(not(feature = "parser"))] + let code = compiler::compile(source, mode, &source_path, opts); + let code = code + .map(|code| PyCode::new_ref_from_bytecode(self, code)) + .map_err(VmCompileError::Compile)?; + Ok(code) } } @@ -48,6 +451,8 @@ mod escape_warnings { use super::*; use crate::warn; use ruff_python_ast::{self as ast, visitor::Visitor}; + #[cfg(test)] + use ruff_text_size::Ranged; use ruff_text_size::TextRange; /// Calculate 1-indexed line number at byte offset in source. @@ -59,6 +464,30 @@ mod escape_warnings { + 1 } + fn line_offset_at(source: &str, offset: usize) -> (usize, usize) { + let offset = offset.min(source.len()); + let prefix = &source[..offset]; + let lineno = prefix.bytes().filter(|&b| b == b'\n').count() + 1; + let line_start = prefix.rfind('\n').map_or(0, |index| index + 1); + let column = source[line_start..offset].chars().count() + 1; + (lineno, column) + } + + fn compile_warning_error( + exception: PyBaseExceptionRef, + source: &str, + filename: &str, + offset: usize, + ) -> CompileWarningError { + let (lineno, offset) = line_offset_at(source, offset); + CompileWarningError { + exception, + filename: filename.to_owned(), + lineno, + offset, + } + } + /// Get content bounds (start, end byte offsets) of a quoted string literal, /// excluding prefix characters and quote delimiters. fn content_bounds(source: &str, range: TextRange) -> Option<(usize, usize)> { @@ -180,7 +609,7 @@ mod escape_warnings { offset: usize, filename: &str, vm: &VirtualMachine, - ) { + ) -> Result<(), CompileWarningError> { let lineno = line_number_at(source, offset); let message = vm.ctx.new_str(format!( "\"\\{ch}\" is an invalid escape sequence. \ @@ -188,7 +617,7 @@ mod escape_warnings { Did you mean \"\\\\{ch}\"? A raw string is also an option." )); let fname = vm.ctx.new_str(filename); - let _ = warn::warn_explicit( + warn::warn_explicit( Some(vm.ctx.exceptions.syntax_warning.to_owned()), message.into(), fname, @@ -198,23 +627,820 @@ mod escape_warnings { None, None, vm, - ); + ) + .map_err(|err| compile_warning_error(err, source, filename, offset)) + } + + #[cfg(test)] + #[derive(Copy, Clone, Eq, PartialEq)] + enum InferredType { + Tuple, + List, + Dict, + Set, + Generator, + Function, + Template, + Str, + Bytes, + Int, + Float, + Complex, + Bool, + NoneType, + Ellipsis, + Slice, + } + + #[cfg(test)] + impl InferredType { + fn name(self) -> &'static str { + match self { + Self::Tuple => "tuple", + Self::List => "list", + Self::Dict => "dict", + Self::Set => "set", + Self::Generator => "generator", + Self::Function => "function", + Self::Template => "string.templatelib.Template", + Self::Str => "str", + Self::Bytes => "bytes", + Self::Int => "int", + Self::Float => "float", + Self::Complex => "complex", + Self::Bool => "bool", + Self::NoneType => "NoneType", + Self::Ellipsis => "ellipsis", + Self::Slice => "slice", + } + } + + fn is_long_subclass(self) -> bool { + matches!(self, Self::Int | Self::Bool) + } + } + + #[cfg(test)] + fn infer_type(expr: &ast::Expr) -> Option { + match expr { + ast::Expr::Tuple(_) => Some(InferredType::Tuple), + ast::Expr::List(_) | ast::Expr::ListComp(_) => Some(InferredType::List), + ast::Expr::Dict(_) | ast::Expr::DictComp(_) => Some(InferredType::Dict), + ast::Expr::Set(_) | ast::Expr::SetComp(_) => Some(InferredType::Set), + ast::Expr::Generator(_) => Some(InferredType::Generator), + ast::Expr::Lambda(_) => Some(InferredType::Function), + ast::Expr::TString(_) => Some(InferredType::Template), + ast::Expr::FString(_) | ast::Expr::StringLiteral(_) => Some(InferredType::Str), + ast::Expr::BytesLiteral(_) => Some(InferredType::Bytes), + ast::Expr::NumberLiteral(number) => match number.value { + ast::Number::Int(_) => Some(InferredType::Int), + ast::Number::Float(_) => Some(InferredType::Float), + ast::Number::Complex { .. } => Some(InferredType::Complex), + }, + ast::Expr::BooleanLiteral(_) => Some(InferredType::Bool), + ast::Expr::NoneLiteral(_) => Some(InferredType::NoneType), + ast::Expr::EllipsisLiteral(_) => Some(InferredType::Ellipsis), + ast::Expr::Slice(_) => Some(InferredType::Slice), + _ => None, + } + } + + #[cfg(test)] + fn is_constant_expr(expr: &ast::Expr) -> bool { + matches!( + expr, + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + ) + } + + #[cfg(test)] + fn check_is_arg(expr: &ast::Expr) -> bool { + if let ast::Expr::Tuple(tuple) = expr { + return !tuple.elts.iter().all(is_constant_expr); + } + if !is_constant_expr(expr) { + return true; + } + matches!( + expr, + ast::Expr::NoneLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::EllipsisLiteral(_) + ) + } + + #[cfg(test)] + fn warn_syntax( + source: &str, + filename: &str, + range: TextRange, + message: String, + vm: &VirtualMachine, + ) -> Result<(), CompileWarningError> { + warn_syntax_at_offset(source, filename, range.start().to_usize(), message, vm) + } + + fn warn_syntax_at_offset( + source: &str, + filename: &str, + offset: usize, + message: String, + vm: &VirtualMachine, + ) -> Result<(), CompileWarningError> { + let lineno = line_number_at(source, offset); + let fname = vm.ctx.new_str(filename); + let message = vm.ctx.new_str(message); + warn::warn_explicit( + Some(vm.ctx.exceptions.syntax_warning.to_owned()), + message.into(), + fname, + lineno, + None, + vm.ctx.none(), + None, + None, + vm, + ) + .map_err(|err| compile_warning_error(err, source, filename, offset)) + } + + pub(super) fn warn_syntax_at_location( + filename: &str, + location: compiler::core::SourceLocation, + message: String, + vm: &VirtualMachine, + ) -> Result<(), CompileWarningError> { + let fname = vm.ctx.new_str(filename); + let message = vm.ctx.new_str(message); + warn::warn_explicit( + Some(vm.ctx.exceptions.syntax_warning.to_owned()), + message.into(), + fname, + location.line.get(), + None, + vm.ctx.none(), + None, + None, + vm, + ) + .map_err(|exception| CompileWarningError { + exception, + filename: filename.to_owned(), + lineno: location.line.get(), + offset: location.character_offset.get(), + }) + } + + fn is_ascii_identifier_char(byte: u8) -> bool { + byte == b'_' || byte.is_ascii_alphanumeric() + } + + fn numeric_keyword_suffix(rest: &[u8]) -> bool { + rest.starts_with(b"and") + || rest.starts_with(b"else") + || rest.starts_with(b"for") + || rest.starts_with(b"if") + || rest.starts_with(b"in") + || rest.starts_with(b"is") + || rest.starts_with(b"or") + || rest.starts_with(b"not") + } + + fn consume_decimal_digits(bytes: &[u8], mut index: usize) -> usize { + while index < bytes.len() { + match bytes[index] { + b'0'..=b'9' => index += 1, + b'_' if bytes + .get(index + 1) + .is_some_and(|byte| byte.is_ascii_digit()) => + { + index += 2; + } + _ => break, + } + } + index + } + + fn consume_radix_digits( + bytes: &[u8], + mut index: usize, + is_digit: impl Fn(u8) -> bool, + ) -> usize { + while index < bytes.len() { + if is_digit(bytes[index]) { + index += 1; + } else if bytes.get(index) == Some(&b'_') + && bytes.get(index + 1).is_some_and(|&byte| is_digit(byte)) + { + index += 2; + } else { + break; + } + } + index + } + + fn number_literal_end(bytes: &[u8], start: usize) -> Option<(&'static str, usize)> { + if bytes.get(start) == Some(&b'.') { + if !bytes + .get(start + 1) + .is_some_and(|byte| byte.is_ascii_digit()) + { + return None; + } + let mut index = consume_decimal_digits(bytes, start + 1); + index = consume_exponent(bytes, index); + if matches!(bytes.get(index), Some(b'j' | b'J')) { + return Some(("imaginary", index + 1)); + } + return Some(("decimal", index)); + } + + if !bytes.get(start).is_some_and(|byte| byte.is_ascii_digit()) { + return None; + } + + if bytes.get(start) == Some(&b'0') { + match bytes.get(start + 1) { + Some(b'x' | b'X') => { + let end = + consume_radix_digits(bytes, start + 2, |byte| byte.is_ascii_hexdigit()); + return Some(("hexadecimal", end)); + } + Some(b'o' | b'O') => { + let end = + consume_radix_digits(bytes, start + 2, |byte| matches!(byte, b'0'..=b'7')); + return Some(("octal", end)); + } + Some(b'b' | b'B') => { + let end = + consume_radix_digits(bytes, start + 2, |byte| matches!(byte, b'0' | b'1')); + return Some(("binary", end)); + } + _ => {} + } + } + + let mut index = consume_decimal_digits(bytes, start); + if bytes.get(index) == Some(&b'.') { + index = consume_decimal_digits(bytes, index + 1); + } + index = consume_exponent(bytes, index); + if matches!(bytes.get(index), Some(b'j' | b'J')) { + return Some(("imaginary", index + 1)); + } + Some(("decimal", index)) + } + + fn consume_exponent(bytes: &[u8], index: usize) -> usize { + if !matches!(bytes.get(index), Some(b'e' | b'E')) { + return index; + } + let mut cursor = index + 1; + if matches!(bytes.get(cursor), Some(b'+' | b'-')) { + cursor += 1; + } + if bytes.get(cursor).is_some_and(|byte| byte.is_ascii_digit()) { + consume_decimal_digits(bytes, cursor) + } else { + index + } + } + + fn skip_quoted_string(bytes: &[u8], mut index: usize) -> usize { + let quote = bytes[index]; + let triple = bytes.get(index + 1) == Some("e) && bytes.get(index + 2) == Some("e); + let quote_len = if triple { 3 } else { 1 }; + index += quote_len; + while index < bytes.len() { + if bytes[index] == b'\\' { + index = (index + 2).min(bytes.len()); + } else if triple + && bytes.get(index) == Some("e) + && bytes.get(index + 1) == Some("e) + && bytes.get(index + 2) == Some("e) + { + return index + 3; + } else if !triple && bytes[index] == quote { + return index + 1; + } else { + index += 1; + } + } + index + } + + fn emit_numeric_literal_warnings( + source: &str, + filename: &str, + vm: &VirtualMachine, + ) -> Result<(), CompileWarningError> { + let bytes = source.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'#' => { + while index < bytes.len() && bytes[index] != b'\n' { + index += 1; + } + } + b'\'' | b'"' => { + index = skip_quoted_string(bytes, index); + } + byte if byte >= 0x80 || byte == b'_' || byte.is_ascii_alphabetic() => { + index += 1; + while index < bytes.len() + && (bytes[index] >= 0x80 || is_ascii_identifier_char(bytes[index])) + { + index += 1; + } + } + b'.' | b'0'..=b'9' => { + let Some((kind, end)) = number_literal_end(bytes, index) else { + index += 1; + continue; + }; + if end > index && numeric_keyword_suffix(&bytes[end..]) { + warn_syntax_at_offset( + source, + filename, + index, + format!("invalid {kind} literal"), + vm, + )?; + } + index = end.max(index + 1); + } + _ => index += 1, + } + } + Ok(()) + } + + #[cfg(test)] + #[derive(Copy, Clone)] + struct ControlFlowInFinallyContext { + in_finally: bool, + in_funcdef: bool, + in_loop: bool, + } + + #[cfg(test)] + struct CompilerSyntaxWarningVisitor<'a> { + source: &'a str, + filename: &'a str, + vm: &'a VirtualMachine, + error: Option, + cf_finally: Vec, + future_annotations: bool, + skip_codegen_warnings: bool, + } + + #[cfg(test)] + impl<'a> CompilerSyntaxWarningVisitor<'a> { + fn record_warning(&mut self, result: Result<(), CompileWarningError>) { + if self.error.is_none() + && let Err(err) = result + { + self.error = Some(err); + } + } + + fn push_cf_context(&mut self, in_finally: bool, in_funcdef: bool, in_loop: bool) { + self.cf_finally.push(ControlFlowInFinallyContext { + in_finally, + in_funcdef, + in_loop, + }); + } + + fn before_return(&mut self, range: TextRange) { + if let Some(ctx) = self.cf_finally.last() + && ctx.in_finally + && !ctx.in_funcdef + { + let result = warn_syntax( + self.source, + self.filename, + range, + "'return' in a 'finally' block".to_owned(), + self.vm, + ); + self.record_warning(result); + } + } + + fn before_loop_exit(&mut self, range: TextRange, kw: &str) { + if let Some(ctx) = self.cf_finally.last() + && ctx.in_finally + && !ctx.in_loop + { + let result = warn_syntax( + self.source, + self.filename, + range, + format!("'{kw}' in a 'finally' block"), + self.vm, + ); + self.record_warning(result); + } + } + + fn visit_body_with_context( + &mut self, + body: &'a [ast::Stmt], + in_finally: bool, + in_funcdef: bool, + in_loop: bool, + ) { + self.push_cf_context(in_finally, in_funcdef, in_loop); + self.visit_body(body); + self.cf_finally.pop(); + } + + fn visit_parameter_annotation(&mut self, parameter: &'a ast::Parameter) { + if !self.future_annotations + && let Some(annotation) = ¶meter.annotation + { + self.visit_expr(annotation); + } + } + + fn visit_positional_defaults(&mut self, parameters: &'a ast::Parameters) { + for parameter in parameters.posonlyargs.iter().chain(¶meters.args) { + if let Some(default) = ¶meter.default { + self.visit_expr(default); + } + } + } + + fn visit_kwonly_defaults(&mut self, parameters: &'a ast::Parameters) { + for parameter in ¶meters.kwonlyargs { + if let Some(default) = ¶meter.default { + self.visit_expr(default); + } + } + } + + fn visit_function_annotations_cpython_order( + &mut self, + parameters: &'a ast::Parameters, + returns: Option<&'a ast::Expr>, + ) { + if self.future_annotations { + return; + } + for parameter in ¶meters.args { + self.visit_parameter_annotation(¶meter.parameter); + } + for parameter in ¶meters.posonlyargs { + self.visit_parameter_annotation(¶meter.parameter); + } + if let Some(vararg) = ¶meters.vararg { + self.visit_parameter_annotation(vararg); + } + for parameter in ¶meters.kwonlyargs { + self.visit_parameter_annotation(¶meter.parameter); + } + if let Some(kwarg) = ¶meters.kwarg { + self.visit_parameter_annotation(kwarg); + } + if let Some(returns) = returns { + self.visit_annotation(returns); + } + } + + fn check_compare(&mut self, compare: &'a ast::ExprCompare) { + let mut left = check_is_arg(&compare.left); + let mut left_expr = compare.left.as_ref(); + for (op, right_expr) in compare.ops.iter().zip(compare.comparators.iter()) { + if self.error.is_some() { + return; + } + let right = check_is_arg(right_expr); + if matches!(op, ast::CmpOp::Is | ast::CmpOp::IsNot) && (!right || !left) { + let literal = if !left { left_expr } else { right_expr }; + if let Some(inferred) = infer_type(literal) { + let is_op = matches!(op, ast::CmpOp::Is); + let op = if is_op { "\"is\"" } else { "\"is not\"" }; + let replacement = if is_op { "==" } else { "!=" }; + let result = warn_syntax( + self.source, + self.filename, + compare.range, + format!( + "{op} with '{}' literal. Did you mean \"{replacement}\"?", + inferred.name() + ), + self.vm, + ); + self.record_warning(result); + return; + } + } + left = right; + left_expr = right_expr; + } + } + + fn check_caller(&mut self, func: &'a ast::Expr) { + if matches!( + func, + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + | ast::Expr::Tuple(_) + | ast::Expr::List(_) + | ast::Expr::ListComp(_) + | ast::Expr::Dict(_) + | ast::Expr::DictComp(_) + | ast::Expr::Set(_) + | ast::Expr::SetComp(_) + | ast::Expr::Generator(_) + | ast::Expr::FString(_) + | ast::Expr::TString(_) + ) && let Some(inferred) = infer_type(func) + { + let result = warn_syntax( + self.source, + self.filename, + func.range(), + format!( + "'{}' object is not callable; perhaps you missed a comma?", + inferred.name() + ), + self.vm, + ); + self.record_warning(result); + } + } + + fn check_subscripter(&mut self, value: &'a ast::Expr) { + let warns = matches!( + value, + ast::Expr::NoneLiteral(_) + | ast::Expr::EllipsisLiteral(_) + | ast::Expr::NumberLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::Set(_) + | ast::Expr::SetComp(_) + | ast::Expr::Generator(_) + | ast::Expr::TString(_) + | ast::Expr::Lambda(_) + ); + if warns && let Some(inferred) = infer_type(value) { + let result = warn_syntax( + self.source, + self.filename, + value.range(), + format!( + "'{}' object is not subscriptable; perhaps you missed a comma?", + inferred.name() + ), + self.vm, + ); + self.record_warning(result); + } + } + + fn check_index(&mut self, value: &'a ast::Expr, slice: &'a ast::Expr) { + let Some(index_type) = infer_type(slice) else { + return; + }; + if index_type.is_long_subclass() || index_type == InferredType::Slice { + return; + } + + let warns = matches!( + value, + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::Tuple(_) + | ast::Expr::List(_) + | ast::Expr::ListComp(_) + | ast::Expr::FString(_) + ); + if warns && let Some(value_type) = infer_type(value) { + let result = warn_syntax( + self.source, + self.filename, + value.range(), + format!( + "{} indices must be integers or slices, not {}; perhaps you missed a comma?", + value_type.name(), + index_type.name() + ), + self.vm, + ); + self.record_warning(result); + } + } + + fn check_assert(&mut self, assert_stmt: &'a ast::StmtAssert) { + if matches!(&*assert_stmt.test, ast::Expr::Tuple(tuple) if !tuple.elts.is_empty()) { + let result = warn_syntax( + self.source, + self.filename, + assert_stmt.range, + "assertion is always true, perhaps remove parentheses?".to_owned(), + self.vm, + ); + self.record_warning(result); + } + } + } + + #[cfg(test)] + impl<'a> Visitor<'a> for CompilerSyntaxWarningVisitor<'a> { + fn visit_stmt(&mut self, stmt: &'a ast::Stmt) { + if self.error.is_some() { + return; + } + + match stmt { + ast::Stmt::FunctionDef(function) => { + for decorator in &function.decorator_list { + self.visit_decorator(decorator); + } + self.visit_positional_defaults(&function.parameters); + self.visit_kwonly_defaults(&function.parameters); + if let Some(type_params) = &function.type_params { + self.visit_type_params(type_params); + } + self.visit_function_annotations_cpython_order( + &function.parameters, + function.returns.as_deref(), + ); + self.visit_body_with_context(&function.body, false, true, false); + } + ast::Stmt::ClassDef(class) => { + for decorator in &class.decorator_list { + self.visit_decorator(decorator); + } + if let Some(type_params) = &class.type_params { + self.visit_type_params(type_params); + } + self.visit_body(&class.body); + if let Some(arguments) = &class.arguments { + for base in &arguments.args { + self.visit_expr(base); + } + for keyword in &arguments.keywords { + self.visit_keyword(keyword); + } + } + } + ast::Stmt::TypeAlias(type_alias) => { + if let Some(type_params) = &type_alias.type_params { + self.visit_type_params(type_params); + } + self.visit_expr(&type_alias.value); + } + ast::Stmt::Return(return_stmt) => { + self.before_return(return_stmt.range); + if let Some(value) = &return_stmt.value { + self.visit_expr(value); + } + } + ast::Stmt::AnnAssign(ann_assign) => { + self.visit_expr(&ann_assign.target); + if !self.future_annotations { + self.visit_annotation(&ann_assign.annotation); + } + if let Some(value) = &ann_assign.value { + self.visit_expr(value); + } + } + ast::Stmt::For(for_stmt) => { + self.visit_expr(&for_stmt.target); + self.visit_expr(&for_stmt.iter); + self.visit_body_with_context(&for_stmt.body, false, false, true); + self.visit_body(&for_stmt.orelse); + } + ast::Stmt::While(while_stmt) => { + self.visit_expr(&while_stmt.test); + self.visit_body_with_context(&while_stmt.body, false, false, true); + self.visit_body(&while_stmt.orelse); + } + ast::Stmt::Try(try_stmt) => { + self.visit_body(&try_stmt.body); + for handler in &try_stmt.handlers { + self.visit_except_handler(handler); + } + self.visit_body(&try_stmt.orelse); + self.visit_body_with_context(&try_stmt.finalbody, true, false, false); + } + ast::Stmt::Assert(assert_stmt) => { + if !self.skip_codegen_warnings { + self.check_assert(assert_stmt); + } + self.visit_expr(&assert_stmt.test); + if let Some(msg) = &assert_stmt.msg { + self.visit_expr(msg); + } + } + ast::Stmt::Break(break_stmt) => { + self.before_loop_exit(break_stmt.range, "break"); + } + ast::Stmt::Continue(continue_stmt) => { + self.before_loop_exit(continue_stmt.range, "continue"); + } + _ => ast::visitor::walk_stmt(self, stmt), + } + } + + fn visit_expr(&mut self, expr: &'a ast::Expr) { + if self.error.is_some() { + return; + } + match expr { + ast::Expr::Compare(compare) if !self.skip_codegen_warnings => { + self.check_compare(compare); + } + ast::Expr::Call(call) if !self.skip_codegen_warnings => { + self.check_caller(&call.func); + } + ast::Expr::Subscript(subscript) + if !self.skip_codegen_warnings + && matches!(subscript.ctx, ast::ExprContext::Load) => + { + self.check_subscripter(&subscript.value); + self.check_index(&subscript.value, &subscript.slice); + } + _ => {} + } + ast::visitor::walk_expr(self, expr); + } + } + + #[cfg(test)] + fn has_future_annotations(ast: &ast::Mod) -> bool { + let ast::Mod::Module(module) = ast else { + return false; + }; + let mut statements = module.body.iter(); + if let Some(ast::Stmt::Expr(ast::StmtExpr { value, .. })) = statements.clone().next() + && matches!(&**value, ast::Expr::StringLiteral(_)) + { + statements.next(); + } + for statement in statements { + match statement { + ast::Stmt::ImportFrom(ast::StmtImportFrom { + module, + names, + level, + .. + }) if *level == 0 + && module.as_ref().map(|id| id.as_str()) == Some("__future__") => + { + if names + .iter() + .any(|future| future.name.as_str() == "annotations") + { + return true; + } + } + _ => return false, + } + } + false } struct EscapeWarningVisitor<'a> { source: &'a str, filename: &'a str, vm: &'a VirtualMachine, + error: Option, } impl<'a> EscapeWarningVisitor<'a> { + fn record_warning(&mut self, result: Result<(), CompileWarningError>) { + if self.error.is_none() + && let Err(err) = result + { + self.error = Some(err); + } + } + /// Check a quoted string/bytes literal for invalid escapes. /// The range must include the prefix and quote delimiters. - fn check_quoted_literal(&self, range: TextRange, is_bytes: bool) { + fn check_quoted_literal(&mut self, range: TextRange, is_bytes: bool) { if let Some((start, end)) = content_bounds(self.source, range) && let Some((ch, offset)) = first_invalid_escape(self.source, start, end, is_bytes) { - warn_invalid_escape_sequence(self.source, ch, offset, self.filename, self.vm); + let result = + warn_invalid_escape_sequence(self.source, ch, offset, self.filename, self.vm); + self.record_warning(result); } } @@ -224,14 +1450,16 @@ mod escape_warnings { /// Also handles `\{` / `\}` at the literal–interpolation boundary, /// equivalent to `_PyTokenizer_warn_invalid_escape_sequence` handling /// `FSTRING_MIDDLE` / `FSTRING_END` tokens. - fn check_fstring_literal(&self, range: TextRange) { + fn check_fstring_literal(&mut self, range: TextRange) { let start = range.start().to_usize(); let end = range.end().to_usize(); if start >= end || end > self.source.len() { return; } if let Some((ch, offset)) = first_invalid_escape(self.source, start, end, false) { - warn_invalid_escape_sequence(self.source, ch, offset, self.filename, self.vm); + let result = + warn_invalid_escape_sequence(self.source, ch, offset, self.filename, self.vm); + self.record_warning(result); return; } // In CPython, _PyTokenizer_warn_invalid_escape_sequence handles @@ -249,13 +1477,14 @@ mod escape_warnings { && let Some(&after) = self.source.as_bytes().get(end) && (after == b'{' || after == b'}') { - warn_invalid_escape_sequence( + let result = warn_invalid_escape_sequence( self.source, after as char, end - 1, self.filename, self.vm, ); + self.record_warning(result); } } @@ -263,6 +1492,9 @@ mod escape_warnings { /// interpolation expressions and format specs. fn visit_fstring_elements(&mut self, elements: &'a ast::InterpolatedStringElements) { for element in elements { + if self.error.is_some() { + return; + } match element { ast::InterpolatedStringElement::Literal(lit) => { self.check_fstring_literal(lit.range); @@ -280,6 +1512,9 @@ mod escape_warnings { impl<'a> Visitor<'a> for EscapeWarningVisitor<'a> { fn visit_expr(&mut self, expr: &'a ast::Expr) { + if self.error.is_some() { + return; + } match expr { // Regular string literals — decode_unicode_with_escapes path ast::Expr::StringLiteral(string) => { @@ -334,19 +1569,99 @@ mod escape_warnings { } impl VirtualMachine { + /// Emit tokenizer-level SyntaxWarnings that CPython raises before + /// code generation. + pub(super) fn emit_tokenizer_syntax_warnings( + &self, + source: &str, + filename: &str, + ) -> Result<(), CompileWarningError> { + emit_numeric_literal_warnings(source, filename, self) + } + /// Walk all string literals in `source` and emit `SyntaxWarning` for /// each that contains an invalid escape sequence. - pub(super) fn emit_string_escape_warnings(&self, source: &str, filename: &str) { + pub(super) fn emit_string_escape_warnings( + &self, + source: &str, + filename: &str, + ) -> Result<(), CompileWarningError> { let Ok(parsed) = ruff_python_parser::parse(source, ruff_python_parser::Mode::Module.into()) else { - return; + return Ok(()); }; let ast = parsed.into_syntax(); let mut visitor = EscapeWarningVisitor { source, filename, vm: self, + error: None, + }; + match &ast { + ast::Mod::Module(module) => { + for stmt in &module.body { + visitor.visit_stmt(stmt); + } + } + ast::Mod::Expression(expr) => { + visitor.visit_expr(&expr.body); + } + } + visitor.error.map_or(Ok(()), Err) + } + + /// Emit CPython codegen SyntaxWarnings for suspicious compare, call, + /// and subscript forms. + #[cfg(test)] + pub(super) fn emit_compiler_syntax_warnings( + &self, + source: &str, + filename: &str, + ) -> Result<(), CompileWarningError> { + self.emit_compiler_syntax_warnings_with_options(source, filename, false, false) + } + + #[cfg(test)] + fn emit_compiler_syntax_warnings_with_options( + &self, + source: &str, + filename: &str, + skip_codegen_warnings: bool, + explicit_future_annotations: bool, + ) -> Result<(), CompileWarningError> { + let Ok(parsed) = + ruff_python_parser::parse(source, ruff_python_parser::Mode::Module.into()) + else { + return Ok(()); + }; + let ast = parsed.into_syntax(); + self.emit_compiler_syntax_warnings_from_ast( + &ast, + source, + filename, + skip_codegen_warnings, + explicit_future_annotations, + ) + } + + #[cfg(test)] + fn emit_compiler_syntax_warnings_from_ast( + &self, + ast: &ast::Mod, + source: &str, + filename: &str, + skip_codegen_warnings: bool, + explicit_future_annotations: bool, + ) -> Result<(), CompileWarningError> { + let mut visitor = CompilerSyntaxWarningVisitor { + source, + filename, + vm: self, + error: None, + cf_finally: Vec::new(), + future_annotations: explicit_future_annotations || has_future_annotations(ast), + skip_codegen_warnings, }; match ast { ast::Mod::Module(module) => { @@ -358,6 +1673,226 @@ mod escape_warnings { visitor.visit_expr(&expr.body); } } + visitor.error.map_or(Ok(()), Err) + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::{Interpreter, builtins::PyTuple}; + + fn install_syntax_warning_error_filter(vm: &VirtualMachine) { + let error_filter = PyTuple::new_ref( + vec![ + vm.ctx.new_str("error").into(), + vm.ctx.none(), + vm.ctx.exceptions.syntax_warning.as_object().to_owned(), + vm.ctx.none(), + vm.ctx.new_int(0).into(), + ], + &vm.ctx, + ); + vm.state + .warnings + .filters + .borrow_vec_mut() + .insert(0, error_filter.into()); + vm.state.warnings.filters_mutated(); + } + + fn first_compiler_warning(source: &str) -> String { + Interpreter::without_stdlib(Default::default()).enter(|vm| { + install_syntax_warning_error_filter(vm); + let err = vm + .emit_compiler_syntax_warnings(source, "") + .expect_err("expected compiler SyntaxWarning"); + err.exception + .as_object() + .str(vm) + .expect("warning message should stringify") + .as_wtf8() + .to_string() + }) + } + + fn compile_error_message(source: &str) -> String { + Interpreter::without_stdlib(Default::default()).enter(|vm| { + install_syntax_warning_error_filter(vm); + let err = match vm.compile(source, compiler::Mode::Exec, "") { + Ok(_) => panic!("expected compile error"), + Err(err) => err, + }; + err.into_pyexception(vm, Some(source)) + .as_object() + .str(vm) + .expect("compile error should stringify") + .as_wtf8() + .to_string() + }) + } + + #[test] + fn codegen_caller_warning_precedes_later_return_error() { + let message = compile_error_message("(1)()\nreturn\n"); + assert!( + message.contains("'int' object is not callable"), + "expected caller SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn symboltable_error_still_precedes_codegen_caller_warning() { + let message = compile_error_message("(1)()\ndef f():\n from x import *\n"); + assert!( + message.contains("import * only allowed at module level"), + "expected symboltable error first, got {message:?}" + ); + } + + #[test] + fn codegen_compare_warning_precedes_later_return_error() { + let message = compile_error_message("1 is 1\nreturn\n"); + assert!( + message.contains("\"is\" with 'int' literal"), + "expected compare SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn codegen_assert_warning_precedes_later_return_error() { + let message = compile_error_message("assert (1,)\nreturn\n"); + assert!( + message.contains("assertion is always true"), + "expected assert SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn codegen_subscript_warning_precedes_later_return_error() { + let message = compile_error_message("(1)[None]\nreturn\n"); + assert!( + message.contains("'int' object is not subscriptable"), + "expected subscript SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn codegen_index_warning_precedes_later_return_error() { + let message = compile_error_message("'x'[None]\nreturn\n"); + assert!( + message.contains("str indices must be integers or slices, not NoneType"), + "expected index SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn string_escape_warning_precedes_later_return_error() { + let message = compile_error_message("\"\\z\"\nreturn\n"); + assert!( + message.contains("\"\\z\" is an invalid escape sequence"), + "expected invalid escape SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn string_escape_warning_precedes_later_symboltable_error() { + let message = compile_error_message("\"\\z\"\ndef f():\n from x import *\n"); + assert!( + message.contains("\"\\z\" is an invalid escape sequence"), + "expected invalid escape SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn ast_preprocess_finally_warning_precedes_later_return_error() { + let message = compile_error_message("try:\n pass\nfinally:\n return\nreturn\n"); + assert!( + message.contains("'return' in a 'finally' block"), + "expected finally SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn ast_preprocess_finally_warning_precedes_symboltable_error() { + let message = compile_error_message( + "def f():\n from x import *\ntry:\n pass\nfinally:\n return\n", + ); + assert!( + message.contains("'return' in a 'finally' block"), + "expected finally SyntaxWarning first, got {message:?}" + ); + } + + #[test] + fn compiler_warning_visits_function_decorators_before_defaults_and_body() { + let message = first_compiler_warning( + r#" +@(b"decorator")() +def f(x=(1)()): + assert (1,) +"#, + ); + assert!( + message.contains("'bytes' object is not callable"), + "expected decorator warning first, got {message:?}" + ); + } + + #[test] + fn compiler_warning_visits_function_defaults_before_annotations() { + let message = first_compiler_warning( + r#" +def f(x: (1)() = ("default")()): + pass +"#, + ); + assert!( + message.contains("'str' object is not callable"), + "expected default warning before annotation warning, got {message:?}" + ); + } + + #[test] + fn compiler_warning_visits_class_decorators_before_body_and_bases() { + let message = first_compiler_warning( + r#" +@(b"decorator")() +class C((1)()): + assert (1,) +"#, + ); + assert!( + message.contains("'bytes' object is not callable"), + "expected class decorator warning first, got {message:?}" + ); + } + + #[test] + fn compiler_warning_visits_class_body_before_bases() { + let message = first_compiler_warning( + r#" +class C((1)()): + assert (1,) +"#, + ); + assert!( + message.contains("assertion is always true"), + "expected class body warning before base warning, got {message:?}" + ); + } + + #[test] + fn compiler_warning_visits_type_alias_type_params_before_value() { + let message = first_compiler_warning( + r#" +type Alias[T: (1)()] = ("value")() +"#, + ); + assert!( + message.contains("'int' object is not callable"), + "expected type parameter warning before alias value warning, got {message:?}" + ); } } } diff --git a/crates/vm/src/vm/compile_mode.rs b/crates/vm/src/vm/compile_mode.rs new file mode 100644 index 00000000000..cbff7f714d4 --- /dev/null +++ b/crates/vm/src/vm/compile_mode.rs @@ -0,0 +1,60 @@ +use crate::bytecode; + +pub(crate) const PY_SINGLE_INPUT: i32 = 256; +pub(crate) const PY_FILE_INPUT: i32 = 257; +pub(crate) const PY_EVAL_INPUT: i32 = 258; +pub(crate) const PY_FUNC_TYPE_INPUT: i32 = 345; + +// Caveat emptor: These flags are undocumented on purpose and depending +// on their effect outside the standard library is **unsupported**. +pub(crate) const PY_CF_SOURCE_IS_UTF8: i32 = 0x0100; +pub(crate) const PY_CF_DONT_IMPLY_DEDENT: i32 = 0x0200; +pub(crate) const PY_CF_ONLY_AST: i32 = 0x0400; +pub(crate) const PY_CF_IGNORE_COOKIE: i32 = 0x0800; +pub(crate) const PY_CF_TYPE_COMMENTS: i32 = 0x1000; +pub(crate) const PY_CF_ALLOW_TOP_LEVEL_AWAIT: i32 = 0x2000; +pub(crate) const PY_CF_ALLOW_INCOMPLETE_INPUT: i32 = 0x4000; +pub(crate) const PY_CF_OPTIMIZED_AST: i32 = 0x8000 | PY_CF_ONLY_AST; + +// __future__ flags - sync with Lib/__future__.py and Include/cpython/compile.h. +const CO_NESTED: i32 = 0x0010; +const CO_FUTURE_DIVISION: i32 = 0x20000; +const CO_FUTURE_ABSOLUTE_IMPORT: i32 = 0x40000; +const CO_FUTURE_WITH_STATEMENT: i32 = 0x80000; +const CO_FUTURE_PRINT_FUNCTION: i32 = 0x100000; +const CO_FUTURE_UNICODE_LITERALS: i32 = 0x200000; +const CO_FUTURE_BARRY_AS_BDFL: i32 = 0x400000; +const CO_FUTURE_GENERATOR_STOP: i32 = 0x800000; +const CO_FUTURE_ANNOTATIONS: i32 = 0x1000000; + +const PY_CF_MASK: i32 = CO_FUTURE_DIVISION + | CO_FUTURE_ABSOLUTE_IMPORT + | CO_FUTURE_WITH_STATEMENT + | CO_FUTURE_PRINT_FUNCTION + | CO_FUTURE_UNICODE_LITERALS + | CO_FUTURE_BARRY_AS_BDFL + | CO_FUTURE_GENERATOR_STOP + | CO_FUTURE_ANNOTATIONS; +const PY_CF_MASK_OBSOLETE: i32 = CO_NESTED; +pub(crate) const PY_CF_COMPILE_MASK: i32 = PY_CF_ONLY_AST + | PY_CF_ALLOW_TOP_LEVEL_AWAIT + | PY_CF_TYPE_COMMENTS + | PY_CF_DONT_IMPLY_DEDENT + | PY_CF_ALLOW_INCOMPLETE_INPUT + | PY_CF_OPTIMIZED_AST; +pub(crate) const PY_CF_ALLOWED_FLAGS: i32 = PY_CF_MASK | PY_CF_MASK_OBSOLETE | PY_CF_COMPILE_MASK; + +pub(crate) fn compile_future_feature_mask() -> bytecode::CodeFlags { + // RustPython accepts barry_as_FLUFL but leaves its parser mode disabled. + bytecode::CodeFlags::FUTURE_DIVISION + | bytecode::CodeFlags::FUTURE_ABSOLUTE_IMPORT + | bytecode::CodeFlags::FUTURE_WITH_STATEMENT + | bytecode::CodeFlags::FUTURE_PRINT_FUNCTION + | bytecode::CodeFlags::FUTURE_UNICODE_LITERALS + | bytecode::CodeFlags::FUTURE_GENERATOR_STOP + | bytecode::CodeFlags::FUTURE_ANNOTATIONS +} + +pub(crate) fn compile_future_features_from_flags(flags: i32) -> bytecode::CodeFlags { + bytecode::CodeFlags::from_bits_truncate(flags as u32 & compile_future_feature_mask().bits()) +} diff --git a/crates/vm/src/vm/interpreter.rs b/crates/vm/src/vm/interpreter.rs index 79e3e190a2f..56c9606f7de 100644 --- a/crates/vm/src/vm/interpreter.rs +++ b/crates/vm/src/vm/interpreter.rs @@ -300,8 +300,10 @@ impl Default for InterpreterBuilder { /// let scope = vm.new_scope_with_builtins(); /// let source = r#"print("Hello World!")"#; /// let code_obj = vm.compile( -/// source, Mode::Exec, "" -/// ).map_err(|err| vm.new_syntax_error(&err, Some(source))).unwrap(); +/// source, +/// Mode::Exec, +/// "", +/// ).map_err(|err| err.into_pyexception(vm, Some(source))).unwrap(); /// vm.run_code_obj(code_obj, scope).unwrap(); /// }); /// ``` diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index eb6546c02fe..9fa5754d54d 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -5,6 +5,9 @@ #[cfg(feature = "rustpython-compiler")] mod compile; +pub(crate) mod compile_mode; +#[cfg(feature = "rustpython-compiler")] +pub use compile::VmCompileError; mod context; mod interpreter; mod method; @@ -82,6 +85,7 @@ pub struct VirtualMachine { pub profile_func: RefCell, pub trace_func: RefCell, pub use_tracing: Cell, + tracing_depth: Cell, pub recursion_limit: Cell, pub(crate) signal_handlers: OnceCell, pub(crate) signal_rx: Option, @@ -739,6 +743,7 @@ impl VirtualMachine { profile_func, trace_func, use_tracing: Cell::new(false), + tracing_depth: Cell::new(0), recursion_limit: Cell::new(if cfg!(debug_assertions) { 256 } else { 1000 }), signal_handlers, signal_rx: None, @@ -1095,30 +1100,22 @@ impl VirtualMachine { } pub fn run_code_obj(&self, code: PyRef, scope: Scope) -> PyResult { - use crate::builtins::{PyFunction, PyModule}; - - // Create a function object for module code, similar to CPython's PyEval_EvalCode - let func = PyFunction::new(code.clone(), scope.globals.clone(), self)?; - let func_obj = func.into_ref(&self.ctx).into(); + self.run_code_obj_with_closure(code, scope, None) + } - // Extract builtins from globals["__builtins__"], like PyEval_EvalCode - let builtins = match scope - .globals - .get_item_opt(identifier!(self, __builtins__), self)? - { - Some(b) => { - if let Some(module) = b.downcast_ref::() { - module.dict().into() - } else { - b - } - } - None => self.builtins.dict().into(), - }; + pub(crate) fn run_code_obj_with_closure( + &self, + code: PyRef, + scope: Scope, + closure: Option>>, + ) -> PyResult { + use crate::builtins::PyFunction; - let frame = - Frame::new(code, scope, builtins, &[], Some(func_obj), false, self).into_ref(&self.ctx); - self.run_frame(frame) + // Create a function object for module code, similar to CPython's PyEval_EvalCode + let mut func = PyFunction::new(code, scope.globals.clone(), self)?; + func.closure = closure; + let func = func.into_ref(&self.ctx); + func.invoke_with_locals(FuncArgs::default(), scope.locals, self) } #[cold] @@ -1434,9 +1431,11 @@ impl VirtualMachine { } /// Stack margin bytes (like _PyOS_STACK_MARGIN_BYTES). - /// 2048 * sizeof(void*) = 16KB for 64-bit. + /// CPython doubles the margin for debug/sanitized builds because frame + /// evaluation consumes more native stack in those configurations. #[cfg_attr(any(miri, target_env = "musl"), allow(dead_code))] - const STACK_MARGIN_BYTES: usize = 2048 * core::mem::size_of::(); + const STACK_MARGIN_BYTES: usize = + (if cfg!(debug_assertions) { 4096 } else { 2048 }) * core::mem::size_of::(); /// Get the stack boundaries using platform-specific APIs. /// Returns (base, top) where base is the lowest address and top is the highest. @@ -1709,11 +1708,38 @@ impl VirtualMachine { #[cfg(feature = "rustpython-codegen")] pub fn compile_opts(&self) -> crate::compiler::CompileOpts { crate::compiler::CompileOpts { - optimize: self.state.config.settings.optimize, + optimize: self.state.config.settings.optimize.min(2), debug_ranges: self.state.config.settings.code_debug_ranges, + int_max_str_digits: self.state.int_max_str_digits.load(), + allow_top_level_await: false, + future_features: crate::bytecode::CodeFlags::empty(), + dont_imply_dedent: false, + recursion_limit: self.recursion_limit.get(), + ast_constant_overrides: None, + ast_interpolation_overrides: None, + ast_formatted_value_overrides: None, + ast_joined_str_overrides: None, + ast_template_str_overrides: None, } } + #[inline] + pub(crate) fn enter_tracing(&self) { + self.tracing_depth.set(self.tracing_depth.get() + 1); + } + + #[inline] + pub(crate) fn leave_tracing(&self) { + let depth = self.tracing_depth.get(); + debug_assert!(depth > 0); + self.tracing_depth.set(depth.saturating_sub(1)); + } + + #[inline] + pub(crate) fn tracing_is_suppressed(&self) -> bool { + self.tracing_depth.get() != 0 + } + // To be called right before raising the recursion depth. fn check_recursive_call(&self, _where: &str) -> PyResult<()> { if self.recursion_depth.get() >= self.recursion_limit.get() { @@ -2286,7 +2312,7 @@ mod tests { let source = "from dir_module.dir_module_inner import value2"; let code_obj = vm .compile(source, vm::compiler::Mode::Exec, "") - .map_err(|err| vm.new_syntax_error(&err, Some(source))) + .map_err(|err| err.into_pyexception(vm, Some(source))) .unwrap(); if let Err(e) = vm.run_code_obj(code_obj, scope) { diff --git a/crates/vm/src/vm/python_run.rs b/crates/vm/src/vm/python_run.rs index c21b437b575..91d5885e740 100644 --- a/crates/vm/src/vm/python_run.rs +++ b/crates/vm/src/vm/python_run.rs @@ -22,7 +22,7 @@ impl VirtualMachine { pub fn run_string(&self, scope: Scope, source: &str, source_path: &str) -> PyResult { let code_obj = self .compile(source, compiler::Mode::Exec, source_path) - .map_err(|err| self.new_syntax_error(&err, Some(source)))?; + .map_err(|err| err.into_pyexception(self, Some(source)))?; // linecache._register_code(code, source, filename) let _ = self.register_code_in_linecache(&code_obj, source); self.run_code_obj(code_obj, scope) @@ -47,7 +47,7 @@ impl VirtualMachine { pub fn run_block_expr(&self, scope: Scope, source: &str) -> PyResult { let code_obj = self .compile(source, compiler::Mode::BlockExpr, "") - .map_err(|err| self.new_syntax_error(&err, Some(source)))?; + .map_err(|err| err.into_pyexception(self, Some(source)))?; self.run_code_obj(code_obj, scope) } } @@ -105,11 +105,19 @@ mod file_run { if path != "" { set_main_loader(module_dict, path, "SourceFileLoader", self)?; } - match crate::host_env::fs::read_to_string(path) { - Ok(source) => { + match crate::host_env::fs::read(path) { + Ok(source_bytes) => { + if source_bytes.contains(&0) { + return Err(self.new_exception_msg( + self.ctx.exceptions.syntax_error.to_owned(), + "source code cannot contain null bytes".into(), + )); + } + let source = String::from_utf8(source_bytes) + .map_err(|err| self.new_os_error(err.to_string()))?; let code_obj = self .compile(&source, compiler::Mode::Exec, path) - .map_err(|err| self.new_syntax_error(&err, Some(&source)))?; + .map_err(|err| err.into_pyexception(self, Some(&source)))?; self.run_code_obj(code_obj, scope)?; } Err(err) => { diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs index 26d8db9d764..5e73cc5f618 100644 --- a/crates/vm/src/vm/thread.rs +++ b/crates/vm/src/vm/thread.rs @@ -719,6 +719,7 @@ impl VirtualMachine { profile_func: RefCell::new(global_profile.unwrap_or_else(|| self.ctx.none())), trace_func: RefCell::new(global_trace.unwrap_or_else(|| self.ctx.none())), use_tracing: Cell::new(use_tracing), + tracing_depth: Cell::new(0), recursion_limit: self.recursion_limit.clone(), signal_handlers: core::cell::OnceCell::new(), signal_rx: None, diff --git a/crates/vm/src/vm/vm_new.rs b/crates/vm/src/vm/vm_new.rs index 4db965855df..d93905f203e 100644 --- a/crates/vm/src/vm/vm_new.rs +++ b/crates/vm/src/vm/vm_new.rs @@ -179,12 +179,22 @@ impl SyntaxErrorInfo { | ParseErrorType::SimpleAndCompoundStatementOnSameLine | ParseErrorType::ExpectedExpression => "invalid syntax".into(), + ParseErrorType::OtherError(s) if s.starts_with("Expected an identifier") => { + "invalid syntax".into() + } + ParseErrorType::OtherError(s) - if s.starts_with("Expected an identifier, but found a keyword") => + if s.eq_ignore_ascii_case( + "Expected a type parameter or the end of the type parameter list", + ) => { "invalid syntax".into() } + ParseErrorType::OtherError(s) if s.eq_ignore_ascii_case("Expected a statement") => { + "invalid syntax".into() + } + ParseErrorType::OtherError(s) if s.eq_ignore_ascii_case( "bytes literal cannot be mixed with non-bytes literals", @@ -430,7 +440,6 @@ impl VirtualMachine { name_error.as_object(), self, unwrap, "name" => name, ); - name_error } @@ -517,7 +526,6 @@ impl VirtualMachine { reason.clone().into(), ], ); - set_attrs!( exc.as_object(), self, unwrap, "encoding" => encoding, @@ -526,7 +534,6 @@ impl VirtualMachine { "end" => end, "reason" => reason, ); - exc } @@ -550,7 +557,6 @@ impl VirtualMachine { reason.clone().into(), ], ); - set_attrs!( exc.as_object(), self, unwrap, "encoding" => encoding, @@ -559,7 +565,6 @@ impl VirtualMachine { "end" => end, "reason" => reason, ); - exc } @@ -576,6 +581,16 @@ impl VirtualMachine { source: Option<&str>, allow_incomplete: bool, ) -> PyBaseExceptionRef { + if matches!( + error, + crate::compiler::CompileError::Codegen(crate::compiler::codegen::error::CodegenError { + error: crate::compiler::codegen::error::CodegenErrorType::RecursionError, + .. + }) + ) { + return self.new_recursion_error(error.to_string()); + } + let incomplete_or_syntax = |allow| -> &'static Py { if allow { self.ctx.exceptions.incomplete_input_error @@ -687,7 +702,9 @@ impl VirtualMachine { raw_location, .. }) => { - if s.starts_with("Expected an indented block after") { + if s.starts_with("Expected an indented block after") + || s.starts_with("expected an indented block after") + { if allow_incomplete { // Check that all chars in the error are whitespace, if so, the source is // incomplete. Otherwise, we've found code that might violates @@ -717,6 +734,12 @@ impl VirtualMachine { } else { self.ctx.exceptions.indentation_error } + } else if allow_incomplete + && source.is_some_and(|source| { + raw_location.end().to_usize() >= source.len() && !source.ends_with('\n') + }) + { + self.ctx.exceptions.incomplete_input_error } else { self.ctx.exceptions.syntax_error } @@ -738,7 +761,29 @@ impl VirtualMachine { let statement = source.and_then(|src| get_statement(src, error.location())); let mut msg = error.to_string(); - if let Some(msg) = msg.get_mut(..1) { + if !msg.starts_with("Exceeds the limit ") + && !msg.starts_with("Did you mean ") + && !msg.starts_with("Invalid star expression") + && !msg.starts_with("Function parameters cannot be parenthesized") + && !msg.starts_with("Lambda expression parameters cannot be parenthesized") + && !msg.starts_with("Cannot have two type comments on def") + && !msg.starts_with("Variable annotation syntax is") + && !msg.starts_with("The '@' operator is") + && !msg.starts_with("Async functions are") + && !msg.starts_with("Async comprehensions are") + && !msg.starts_with("Async for loops are") + && !msg.starts_with("Async with statements are") + && !msg.starts_with("Exception groups are") + && !msg.starts_with("Positional-only parameters are") + && !msg.starts_with("Pattern matching is") + && !msg.starts_with("Type statement is") + && !msg.starts_with("Type parameter lists are") + && !msg.starts_with("Type parameter defaults are") + && !msg.starts_with("Assignment expressions are") + && !msg.starts_with("Await expressions are") + && !msg.starts_with("Underscores in numeric literals are") + && let Some(msg) = msg.get_mut(..1) + { msg.make_ascii_lowercase(); } @@ -755,15 +800,24 @@ impl VirtualMachine { if syntax_error_type.is(self.ctx.exceptions.tab_error) { syntax_error_info.with_msg("inconsistent use of tabs and spaces in indentation"); } + if syntax_error_type.is(self.ctx.exceptions.incomplete_input_error) { + syntax_error_info.with_msg("incomplete input"); + } let SyntaxErrorInfo { msg, narrow_caret } = syntax_error_info; + let check_version_suite_error = msg.starts_with("Async functions are") + || msg.starts_with("Async for loops are") + || msg.starts_with("Async with statements are") + || msg.starts_with("Exception groups are") + || msg.starts_with("except expressions without parentheses are") + || msg.starts_with("Pattern matching is"); + let line_end_binary_operator_error = msg.starts_with("The '@' operator is"); let syntax_error = self.new_exception_msg(syntax_error_type, msg.into()); - let (lineno, offset) = error.python_location(); - let lineno = self.ctx.new_int(lineno); - let offset = self.ctx.new_int(offset); - + let (lineno_raw, offset_raw) = error.python_location(); + let lineno = self.ctx.new_int(lineno_raw); + let offset = self.ctx.new_int(offset_raw); set_attrs!( syntax_error.as_object(), self, unwrap, "lineno" => lineno, @@ -772,15 +826,23 @@ impl VirtualMachine { // Set end_lineno and end_offset if available if let Some((end_lineno, end_offset)) = error.python_end_location() { - let (end_lineno, end_offset) = if narrow_caret { + let (end_lineno, end_offset) = if check_version_suite_error + && statement + .as_deref() + .and_then(|line| line.chars().next()) + .is_some_and(|ch| ch.is_ascii_whitespace()) + { + (end_lineno, -1) + } else if line_end_binary_operator_error && end_offset == offset_raw { + (end_lineno, (end_offset + 1) as isize) + } else if narrow_caret { let (l, o) = error.python_location(); - (l, o + 1) + (l, (o + 1) as isize) } else { - (end_lineno, end_offset) + (end_lineno, end_offset as isize) }; let end_lineno = self.ctx.new_int(end_lineno); let end_offset = self.ctx.new_int(end_offset); - set_attrs!( syntax_error.as_object(), self, unwrap, "end_lineno" => end_lineno, @@ -833,7 +895,6 @@ impl VirtualMachine { exc.as_object(), self, unwrap, "name" => name.into(), ); - exc } diff --git a/crates/wasm/Cargo.toml b/crates/wasm/Cargo.toml index 4150beaa81c..5a5e8d77fd6 100644 --- a/crates/wasm/Cargo.toml +++ b/crates/wasm/Cargo.toml @@ -20,6 +20,7 @@ no-start-func = [] rustpython-common = { workspace = true } rustpython-pylib = { workspace = true, optional = true } rustpython-stdlib = { workspace = true, default-features = false, optional = true } +ruff_text_size = { workspace = true } # make sure no threading! otherwise wasm build will fail rustpython-vm = { workspace = true, features = ["compiler", "encodings", "serde", "wasmbind"] } diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 99668df2855..ae4e52f21b0 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -70,7 +70,11 @@ pub mod eval { if let Some(js_vars) = js_vars { vm.add_to_scope("js_vars".into(), js_vars.into())?; } - vm.run(source, mode, None) + if matches!(mode, Mode::Single) { + vm.exec_single(source, None) + } else { + vm.run(source, mode, None) + } } /// Evaluate Python code diff --git a/crates/wasm/src/vm_class.rs b/crates/wasm/src/vm_class.rs index 08cb49ecfca..cd0af10f6df 100644 --- a/crates/wasm/src/vm_class.rs +++ b/crates/wasm/src/vm_class.rs @@ -6,9 +6,14 @@ use crate::{ use alloc::rc::{Rc, Weak}; use core::cell::RefCell; use js_sys::{Object, TypeError}; +use ruff_text_size::Ranged; use rustpython_vm::{ - Interpreter, PyObjectRef, PyRef, PyResult, Settings, VirtualMachine, builtins::PyWeak, - compiler::Mode, function::ArgMapping, scope::Scope, + Interpreter, PyObjectRef, PyRef, PyResult, Settings, VirtualMachine, + builtins::PyWeak, + compiler::{self, Mode}, + function::ArgMapping, + scope::Scope, + vm::VmCompileError, }; use std::collections::HashMap; use wasm_bindgen::prelude::*; @@ -21,6 +26,25 @@ pub(crate) struct StoredVirtualMachine { held_objects: RefCell>, } +fn compile_err_to_js(vm: &VirtualMachine, err: VmCompileError) -> JsValue { + match err { + VmCompileError::Compile(err) => convert::syntax_err(err).into(), + err => convert::py_err_to_js_err(vm, &err.into_pyexception(vm, None)), + } +} + +fn statement_chunks(source: &str) -> Option> { + let module = compiler::parser::parse_module(source).ok()?.into_syntax(); + module + .body + .iter() + .map(|stmt| { + let range = stmt.range(); + source.get(range.start().to_usize()..range.end().to_usize()) + }) + .collect() +} + #[pymodule] mod _window { use super::{js_module, wasm_builtins}; @@ -263,8 +287,8 @@ impl WASMVirtualMachine { ) -> Result<(), JsValue> { self.with_vm(|vm, _| { let code = vm - .compile(source, Mode::Exec, &name) - .map_err(convert::syntax_err)?; + .compile(source, Mode::Exec, name.as_str()) + .map_err(|err| compile_err_to_js(vm, err))?; let attrs = vm.ctx.new_dict(); attrs .set_item("__name__", vm.new_pyobj(name.as_str()), vm) @@ -327,13 +351,46 @@ impl WASMVirtualMachine { ) -> Result { self.with_vm(|vm, StoredVirtualMachine { scope, .. }| { let source_path = source_path.unwrap_or_else(|| "".to_owned()); - let code = vm.compile(source, mode, &source_path); - let code = code.map_err(convert::syntax_err)?; + let code = vm.compile(source, mode, source_path.as_str()); + let code = code.map_err(|err| compile_err_to_js(vm, err))?; let result = vm.run_code_obj(code, scope.clone()); convert::pyresult_to_js_result(vm, result) })? } + pub(crate) fn run_single( + &self, + source: &str, + source_path: Option, + ) -> Result { + self.with_vm(|vm, StoredVirtualMachine { scope, .. }| { + let source_path = source_path.unwrap_or_else(|| "".to_owned()); + let Some(chunks) = statement_chunks(source) else { + let code = vm.compile(source, Mode::Single, source_path.as_str()); + let code = code.map_err(|err| compile_err_to_js(vm, err))?; + let result = vm.run_code_obj(code, scope.clone()); + return convert::pyresult_to_js_result(vm, result); + }; + + if chunks.is_empty() { + return Ok(convert::py_to_js(vm, vm.ctx.none())); + } + + let displayhook = vm + .sys_module + .get_attr("displayhook", vm) + .map_err(|_| TypeError::new("lost sys.displayhook"))?; + let mut result = vm.ctx.none(); + for chunk in chunks { + let code = vm.compile(chunk, Mode::BlockExpr, source_path.as_str()); + let code = code.map_err(|err| compile_err_to_js(vm, err))?; + result = vm.run_code_obj(code, scope.clone()).into_js(vm)?; + displayhook.call((result.clone(),), vm).into_js(vm)?; + } + Ok(convert::py_to_js(vm, result)) + })? + } + pub fn exec(&self, source: &str, source_path: Option) -> Result { self.run(source, Mode::Exec, source_path) } @@ -348,6 +405,6 @@ impl WASMVirtualMachine { source: &str, source_path: Option, ) -> Result { - self.run(source, Mode::Single, source_path) + self.run_single(source, source_path) } } diff --git a/examples/hello_embed.rs b/examples/hello_embed.rs index 9e1cdb829d6..ae56ed21bf1 100644 --- a/examples/hello_embed.rs +++ b/examples/hello_embed.rs @@ -6,7 +6,7 @@ fn main() -> vm::PyResult<()> { let source = r#"print("Hello World!")"#; let code_obj = vm .compile(source, vm::compiler::Mode::Exec, "") - .map_err(|err| vm.new_syntax_error(&err, Some(source)))?; + .map_err(|err| err.into_pyexception(vm, Some(source)))?; vm.run_code_obj(code_obj, scope)?; diff --git a/examples/mini_repl.rs b/examples/mini_repl.rs index 40d111732ae..edbd6e1495c 100644 --- a/examples/mini_repl.rs +++ b/examples/mini_repl.rs @@ -66,7 +66,7 @@ def fib(n): // (note that this is only the case when compiler::Mode::Single is passed to vm.compile) match vm .compile(&input, vm::compiler::Mode::Single, "") - .map_err(|err| vm.new_syntax_error(&err, Some(&input))) + .map_err(|err| err.into_pyexception(vm, Some(&input))) .and_then(|code_obj| vm.run_code_obj(code_obj, scope.clone())) { Ok(output) => { diff --git a/examples/parse_folder.rs b/examples/parse_folder.rs index 440bcdb9b5f..7ece0c74065 100644 --- a/examples/parse_folder.rs +++ b/examples/parse_folder.rs @@ -131,4 +131,4 @@ struct ParsedFile { result: ParseResult, } -type ParseResult = Result, String>; +type ParseResult = Result; diff --git a/extra_tests/snippets/builtin_compile.py b/extra_tests/snippets/builtin_compile.py index 15095c0eede..49295bf26d2 100644 --- a/extra_tests/snippets/builtin_compile.py +++ b/extra_tests/snippets/builtin_compile.py @@ -1,3 +1,8 @@ +import __future__ + +import ast +import sys + from testutils import assert_raises # compile() basic mode acceptance @@ -43,4 +48,100 @@ def _check_flags_error(flags): _check_flags_error(99999) +_check_flags_error(0x100) +_check_flags_error(0x800) _check_flags_error(0x10000) + + +ns = {} +exec( + "from __future__ import annotations\n" + "inherited = compile('x: __debug__\\n', '', 'exec')\n" + "not_inherited = compile('x: __debug__\\n', '', 'exec', dont_inherit=True)\n", + ns, +) +assert ns["inherited"].co_flags & 0x1000000 +assert not (ns["not_inherited"].co_flags & 0x1000000) + +barry_flag = __future__.barry_as_FLUFL.compiler_flag +barry_code = compile("x = 1", "", "exec", flags=barry_flag) +compile("from __future__ import barry_as_FLUFL\nx = 1\n", "", "exec") +if sys.implementation.name == "rustpython": + assert not (barry_code.co_flags & barry_flag) + +n = ast.parse('x = "# type: int"\n', type_comments=True) +assert n.body[0].type_comment is None +n = ast.parse("x = '# type: int'\n", type_comments=True) +assert n.body[0].type_comment is None +n = ast.parse('x = "abc" # type: str\n', type_comments=True) +assert n.body[0].type_comment == "str" +n = ast.parse("x = 1 # type: ignore[excuse]\n", type_comments=True) +assert [(ti.lineno, ti.tag) for ti in n.type_ignores] == [(1, "[excuse]")] + + +compile("() -> int", "", "func_type", flags=ast.PyCF_ONLY_AST) +func_type_tree = compile( + '("a,b", str) -> int', "", "func_type", flags=ast.PyCF_ONLY_AST +) +assert len(func_type_tree.argtypes) == 2 +assert func_type_tree.argtypes[0].value == "a,b" +func_type_tree = compile( + "(int, *str, **Any) -> float", + "", + "func_type", + flags=ast.PyCF_ONLY_AST, +) +assert [arg.id for arg in func_type_tree.argtypes] == ["int", "str", "Any"] +assert_raises( + SyntaxError, + compile, + "int -> str", + "", + "func_type", + flags=ast.PyCF_ONLY_AST, +) +assert_raises( + SyntaxError, + compile, + "(x=1) -> str", + "", + "func_type", + flags=ast.PyCF_ONLY_AST, +) +assert_raises( + SyntaxError, + compile, + "(int,) -> str", + "", + "func_type", + flags=ast.PyCF_ONLY_AST, +) +PY_CF_DONT_IMPLY_DEDENT = 0x0200 +PY_CF_ALLOW_INCOMPLETE_INPUT = 0x4000 +compile(b"# coding: latin-1\nx = '\xe9'\n", "", "exec") +compile("if 1:\n pass", "", "single") +assert_raises( + SyntaxError, + compile, + "if 1:\n pass", + "", + "single", + flags=PY_CF_DONT_IMPLY_DEDENT, +) +compile( + "if 1:\n pass\n", + "", + "single", + flags=PY_CF_DONT_IMPLY_DEDENT | PY_CF_ALLOW_INCOMPLETE_INPUT, +) +try: + compile( + "if 1:\n pass", + "", + "single", + flags=PY_CF_DONT_IMPLY_DEDENT | PY_CF_ALLOW_INCOMPLETE_INPUT, + ) +except _IncompleteInputError as exc: + assert exc.args[0] == "incomplete input", repr(exc) +else: + raise AssertionError("expected _IncompleteInputError") diff --git a/src/shell.rs b/src/shell.rs index bc7ccec5c9d..7fb9336af4b 100644 --- a/src/shell.rs +++ b/src/shell.rs @@ -10,6 +10,7 @@ use rustpython_vm::{ compiler::{self}, readline::{Readline, ReadlineResult}, scope::Scope, + vm::VmCompileError, }; enum ShellExecResult { @@ -45,25 +46,25 @@ fn shell_exec( ShellExecResult::Ok } } - Err(CompileError::Parse(ParseError { + Err(VmCompileError::Compile(CompileError::Parse(ParseError { error: ParseErrorType::Lexical(LexicalErrorType::Eof), .. - })) => ShellExecResult::ContinueLine, - Err(CompileError::Parse(ParseError { + }))) => ShellExecResult::ContinueLine, + Err(VmCompileError::Compile(CompileError::Parse(ParseError { error: ParseErrorType::Lexical(LexicalErrorType::FStringError( InterpolatedStringErrorType::UnterminatedTripleQuotedString, )), .. - })) => ShellExecResult::ContinueLine, + }))) => ShellExecResult::ContinueLine, Err(err) => { // Check if the error is from an unclosed triple quoted string (which should always // continue) - if let CompileError::Parse(ParseError { + if let VmCompileError::Compile(CompileError::Parse(ParseError { error: ParseErrorType::Lexical(LexicalErrorType::UnclosedStringError), raw_location, .. - }) = err + })) = &err { let loc = raw_location.start().to_usize(); let mut iter = source.chars(); @@ -80,8 +81,8 @@ fn shell_exec( // since indentations errors on columns other than 0 should be ignored. // if its an unrecognized token for dedent, set to false - let bad_error = match err { - CompileError::Parse(ref p) => { + let bad_error = match &err { + VmCompileError::Compile(CompileError::Parse(p)) => { match &p.error { ParseErrorType::Lexical(LexicalErrorType::IndentationError) => { continuing_block @@ -97,7 +98,7 @@ fn shell_exec( // If we are handling an error on an empty line or an error worthy of throwing if empty_line_given || bad_error { - ShellExecResult::PyErr(vm.new_syntax_error(&err, Some(source))) + ShellExecResult::PyErr(err.into_pyexception(vm, Some(source))) } else { ShellExecResult::ContinueBlock } diff --git a/tools/opcode_metadata/generate_rs_opcode_metadata.py b/tools/opcode_metadata/generate_rs_opcode_metadata.py index df2476c5e08..fd13e026613 100644 --- a/tools/opcode_metadata/generate_rs_opcode_metadata.py +++ b/tools/opcode_metadata/generate_rs_opcode_metadata.py @@ -11,6 +11,7 @@ import typing import tomllib + from cpython import Analysis, get_analysis, get_stack_effect from opcodes import OpcodeInfo from utils import DEFAULT_INPUT, ROOT, get_conf, to_pascal_case