diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index c756a93c9b9..84265bda609 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -5,6 +5,7 @@ argtypes asdl asname atopen +atext attro augassign badcert @@ -36,6 +37,7 @@ CFWS CLASSDEREF classdict cmpop +CNOTAB codedepth CODEUNIT CONIN @@ -57,7 +59,9 @@ dictoffset distpoint dynload elts +eooh eofs +EOOH evalloop excepthandler exceptiontable @@ -95,27 +99,34 @@ HASUNION heaptype hexdigit HIGHRES +ialloc IFUNC IMMUTABLETYPE INCREF inlinedepth inplace inpos +ioffset +isbytecode +ishidden ismine ISPOINTER isoctal iteminfo Itertool +iused keeped kwnames kwonlyarg kwonlyargs +kwonlydefaults lasti libffi linearise lineful lineiterator linetable +LNOTAB loadfast localsplus localspluskinds @@ -130,17 +141,23 @@ mult multibytecodec nameobj nameop +nargsf +nblocks ncells +ncellsused +ncellvars nconsts newargs newfree NEWLOCALS newsemlockobject +nextop nfrees nkwargs nkwelts nlocalsplus nointerrupt +noffsets Nondescriptor noninteger nops @@ -148,6 +165,7 @@ noraise nseen NSIGNALS numer +nvars opname opnames orelse @@ -160,6 +178,7 @@ patma peepholer phcount platstdlib +ploc posonlyarg posonlyargs prec @@ -205,12 +224,14 @@ staticbase stginfo storefast stringlib +stringized structseq subkwargs subparams subscr sval swappedbytes +swaptimize sysdict tbstderr templatelib @@ -231,6 +252,7 @@ uncollectable Unhandle unparse unparser +untargeted untracking VARKEYWORDS varkwarg diff --git a/.cspell.dict/python-more.txt b/.cspell.dict/python-more.txt index 934529a7165..7ea660a5d1f 100644 --- a/.cspell.dict/python-more.txt +++ b/.cspell.dict/python-more.txt @@ -67,6 +67,7 @@ fillchar fillvalue finallyhandler firstiter +fobj firstlineno fnctl frombytes @@ -111,12 +112,14 @@ idfunc idiv idxs impls +infd indexgroup infj inittab Inittab instancecheck instanceof +instrs interpchannels interpqueues irepeat @@ -175,6 +178,7 @@ Nonprintable onceregistry origname ospath +outfd pendingcr phello platlibdir @@ 
-185,6 +189,7 @@ posonlyargcount prepending profilefunc pycache +pycapsule pycodecs pycs pydatetime diff --git a/.cspell.json b/.cspell.json index f05f2adcd65..21199c0c5f5 100644 --- a/.cspell.json +++ b/.cspell.json @@ -49,13 +49,15 @@ "ignorePaths": [ "**/__pycache__/**", "target/**", - "Lib/**" + "Lib/**", + "crates/host_env/**" ], // words - list of words to be always considered correct // (compound words like pyarg, baseclass, microbenchmark are handled by allowCompoundWords) "words": [ "aiterable", "alnum", + "csock", "coro", "dedentations", "dedents", @@ -65,6 +67,7 @@ "emscripten", "excs", "fnfe", + "ifexp", "interps", "jitted", "jitting", @@ -79,6 +82,8 @@ "reraising", "significand", "summands", + "TESTFN", + "TZPATH", "unraisable", "wasi", "weaked", diff --git a/.gitattributes b/.gitattributes index d076a34f977..b3a782dbe9c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -58,13 +58,14 @@ Lib/venv/scripts/posix/* text eol=lf # [attr]generated linguist-generated=true diff=generated -Lib/_opcode_metadata.py generated -Lib/keyword.py generated -Lib/idlelib/help.html generated -Lib/test/certdata/*.pem generated -Lib/test/certdata/*.0 generated -Lib/test/levenshtein_examples.json generated -Lib/test/test_stable_abi_ctypes.py generated -Lib/token.py generated +Lib/_opcode_metadata.py generated +Lib/keyword.py generated +Lib/idlelib/help.html generated +Lib/test/certdata/*.pem generated +Lib/test/certdata/*.0 generated +Lib/test/levenshtein_examples.json generated +Lib/test/test_stable_abi_ctypes.py generated +Lib/token.py generated +crates/compiler-core/src/bytecode/opcode_metadata.rs generated -.github/workflows/*.lock.yml linguist-generated=true merge=ours \ No newline at end of file +.github/workflows/*.lock.yml linguist-generated=true merge=ours diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000000..18ba1b6951f --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,10 @@ + + +- [ 
] Closes #xxxx +- [ ] This PR follows our [AI policy](https://github.com/RustPython/.github/blob/main/AI_POLICY.md) + +## Summary + + diff --git a/.github/dependabot.yml b/.github/dependabot.yml index e3f9ba3b7ab..7533ce36803 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -85,7 +85,6 @@ updates: - "quote-use*" random: patterns: - - "ahash" - "getrandom" - "mt19937" - "rand*" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6261296c74b..30c5b54099a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -19,7 +19,7 @@ concurrency: cancel-in-progress: true env: - CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,ssl-rustls,host_env + CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,ssl-rustls-aws-lc,host_env CARGO_ARGS_NO_SSL: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,host_env # Crates excluded from workspace builds: # - rustpython_wasm: requires wasm target @@ -123,9 +123,6 @@ jobs: run: cargo test if: runner.os != 'Windows' # Requires pyo3 0.29+ on Windows - - run: cargo doc --locked - if: runner.os == 'Linux' - - name: check compilation without host_env (sandbox mode) run: | cargo check -p rustpython-vm --no-default-features --features compiler @@ -143,6 +140,10 @@ jobs: run: cargo build --no-default-features --features ssl-openssl if: runner.os == 'Linux' + - name: Test vendored OpenSSL build + run: cargo build --no-default-features --features ssl-openssl-vendor + if: runner.os == 'Linux' + # - name: Install tk-dev for tkinter build # run: sudo apt-get update && sudo apt-get install -y tk-dev # if: runner.os == 'Linux' @@ -189,6 +190,7 @@ jobs: skip_ssl: true - os: ubuntu-latest target: wasm32-wasip2 + skip_ssl: true - os: ubuntu-latest target: x86_64-unknown-freebsd skip_ssl: true @@ -293,18 +295,21 @@ jobs: - os: macos-latest extra_test_args: - '-u all' - env_polluting_tests: [] + 
env_polluting_tests: + - test_set skips: [] timeout: 50 - os: ubuntu-latest extra_test_args: - '-u all' - env_polluting_tests: [] + env_polluting_tests: + - test_set skips: [] timeout: 60 - os: windows-2025 extra_test_args: [] # TODO: Enable '-u all' - env_polluting_tests: [] + env_polluting_tests: + - test_set skips: - test_rlcompleter - test_pathlib # panic by surrogate chars @@ -501,7 +506,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable - - uses: cargo-bins/cargo-binstall@4852a15cf01e4f33958ce547326406fe78f27c38 # v1.19.0 + - uses: cargo-bins/cargo-binstall@aaa84a43aec4955a42c5ffc65d258961e39f276e # v1.19.1 - name: cargo shear run: | @@ -519,6 +524,8 @@ jobs: security-events: write # for zizmor steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 @@ -530,7 +537,7 @@ jobs: uses: reviewdog/action-actionlint@6fb7acc99f4a1008869fa8a0f09cfca740837d9d # v1.72.0 - name: zizmor - uses: zizmorcore/zizmor-action@b1d7e1fb5de872772f31590499237e7cce841e8e # v0.5.3 + uses: zizmorcore/zizmor-action@5f14fd08f7cf1cb1609c1e344975f152c7ee938d # v0.5.6 - name: restore prek cache uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 @@ -544,12 +551,35 @@ jobs: package-manager-cache: false node-version: "24" - - name: prek + - name: install prek id: prek - uses: j178/prek-action@6ad80277337ad479fe43bd70701c3f7f8aa74db3 # v2.0.3 + uses: j178/prek-action@bdca6f102f98e2b4c7029491a53dfd366469e33d # v2.0.4 with: cache: false show-verbose-logs: false + install-only: true + + - name: prek run + run: prek run --show-diff-on-failure --color=always --all-files + + - name: Get target CPython version + id: cpython-version + run: | + version=$(cat .python-version) + echo "version=${version}" >> "$GITHUB_OUTPUT" + + - name: Clone CPython + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + 
repository: python/cpython + path: cpython + ref: "v${{ steps.cpython-version.outputs.version }}" + persist-credentials: false + + - name: prek run (manual stage) + run: prek run --show-diff-on-failure --color=always --all-files --hook-stage manual + env: + CPYTHON_ROOT: ${{ github.workspace }}/cpython - name: save prek cache if: ${{ github.ref == 'refs/heads/main' }} # only save on main @@ -558,6 +588,10 @@ jobs: key: prek-${{ hashFiles('.pre-commit-config.yaml') }} path: ~/.cache/prek + - name: restore git permissions + if: ${{ !cancelled() }} + run: sudo chown -R "$(id -u):$(id -g)" .git + - name: reviewdog if: ${{ !cancelled() }} uses: reviewdog/action-suggester@aa38384ceb608d00f84b4690cacc83a5aba307ff # v1.24.0 @@ -702,7 +736,7 @@ jobs: - name: Deploy demo to Github Pages if: success() && github.ref == 'refs/heads/release' - uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + uses: peaceiris/actions-gh-pages@84c30a85c19949d7eee79c4ff27748b70285e453 # v4.1.0 env: ACTIONS_DEPLOY_KEY: ${{ secrets.ACTIONS_DEMO_DEPLOY_KEY }} PUBLISH_DIR: ./wasm/demo/dist @@ -756,8 +790,44 @@ jobs: clang: true - name: build rustpython - run: cargo build --release --target wasm32-wasip1 --features freeze-stdlib,stdlib --verbose + run: cargo build --release --target wasm32-wasip1 --no-default-features --features freeze-stdlib,stdlib,stdio,importlib,host_env --verbose - name: run snippets run: wasmer run --dir "$(pwd)" target/wasm32-wasip1/release/rustpython.wasm -- "$(pwd)/extra_tests/snippets/stdlib_random.py" - name: run cpython unittest run: wasmer run --dir "$(pwd)" target/wasm32-wasip1/release/rustpython.wasm -- "$(pwd)/Lib/test/test_int.py" + + cargo_doc: + needs: + - determine_changes + if: | + ( + !contains(github.event.pull_request.labels.*.name, 'skip:ci') && + needs.determine_changes.outputs.rust_code == 'true' + ) || github.ref == 'refs/heads/main' + env: + RUST_BACKTRACE: full + name: cargo doc + runs-on: ubuntu-latest + steps: + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - uses: dtolnay/rust-toolchain@stable + + - name: Restore cache + uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}- + restore-keys: | + ${{ runner.os }}-stable--${{ hashFiles('**/Cargo.toml') }}- + ${{ runner.os }}-stable-- + + - name: cargo doc + run: cargo doc --locked diff --git a/.github/workflows/cron-ci.yaml b/.github/workflows/cron-ci.yaml index b9e237e300a..ac1a1cadbfb 100644 --- a/.github/workflows/cron-ci.yaml +++ b/.github/workflows/cron-ci.yaml @@ -14,7 +14,7 @@ on: - .github/workflows/cron-ci.yaml env: - CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,ssl-rustls,jit,host_env + CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,ssl-rustls-aws-lc,jit,host_env FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: 'true' # TODO: Remove on 2026/06/02 jobs: @@ -32,7 +32,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable - - uses: taiki-e/install-action@cf525cb33f51aca27cd6fa02034117ab963ff9f1 # v2.75.22 + - uses: taiki-e/install-action@b550161ef8a7bc4f2a671c0b03a18ac9ccedea1e # v2.79.1 with: tool: cargo-llvm-cov @@ -41,7 +41,7 @@ jobs: - run: sudo apt-get update && sudo apt-get -y install lcov - name: Run cargo-llvm-cov with Rust tests. 
- run: cargo llvm-cov --no-report --workspace --exclude rustpython_wasm --exclude rustpython-compiler-source --exclude rustpython-venvlauncher --verbose --no-default-features --features stdlib,importlib,stdio,encodings,ssl-rustls,jit,host_env + run: cargo llvm-cov --no-report --workspace --exclude rustpython_wasm --exclude rustpython-compiler-source --exclude rustpython-venvlauncher --verbose --no-default-features --features stdlib,importlib,stdio,encodings,ssl-rustls-aws-lc,jit,host_env - name: Run cargo-llvm-cov with Python snippets. run: python scripts/cargo-llvm-cov.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 703b2acb9ef..be92c5e4704 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -68,7 +68,7 @@ jobs: libtool: true - name: Build RustPython - run: cargo build --release --target=${{ matrix.target }} --verbose --no-default-features --features stdlib,stdio,importlib,encodings,sqlite,host_env,ssl-rustls,threading,jit + run: cargo build --release --target=${{ matrix.target }} --verbose --no-default-features --features stdlib,stdio,importlib,encodings,sqlite,host_env,ssl-rustls-aws-lc,threading,jit - name: Rename Binary run: cp target/${{ matrix.target }}/release/rustpython target/rustpython-release-${{ runner.os }}-${{ matrix.target }} @@ -141,7 +141,7 @@ jobs: - name: Deploy demo to Github Pages if: ${{ github.repository == 'RustPython/RustPython' }} - uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + uses: peaceiris/actions-gh-pages@84c30a85c19949d7eee79c4ff27748b70285e453 # v4.1.0 with: deploy_key: ${{ secrets.ACTIONS_DEMO_DEPLOY_KEY }} publish_dir: ./wasm/demo/dist diff --git a/.github/workflows/update-caches.yml b/.github/workflows/update-caches.yml index 585563c45e6..392717e5a2c 100644 --- a/.github/workflows/update-caches.yml +++ b/.github/workflows/update-caches.yml @@ -19,7 +19,7 @@ env: CARGO_PROFILE_TEST_DEBUG: 0 CARGO_PROFILE_DEV_DEBUG: 0 
CARGO_PROFILE_RELEASE_DEBUG: 0 - CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,ssl-rustls,host_env,threading,jit + CARGO_ARGS: --workspace --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,ssl-rustls-aws-lc,host_env,threading,jit --exclude rustpython_wasm --exclude rustpython-compiler-source --exclude rustpython-venvlauncher jobs: build-caches: diff --git a/.github/workflows/upgrade-pylib.lock.yml b/.github/workflows/upgrade-pylib.lock.yml index 1be61c4bc92..ec7a12be1f4 100644 --- a/.github/workflows/upgrade-pylib.lock.yml +++ b/.github/workflows/upgrade-pylib.lock.yml @@ -58,7 +58,7 @@ jobs: comment_repo: "" steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@2f2a6f572b9038823081cb9d408f235e1a109a0b # v0.71.3 + uses: github/gh-aw/actions/setup@2c1a237d2048b0e2412e7d7528892ea1257840e2 # v0.74.4 with: destination: /opt/gh-aw/actions - name: Check workflow file timestamps @@ -99,7 +99,7 @@ jobs: secret_verification_result: ${{ steps.validate-secret.outputs.verification_result }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@2f2a6f572b9038823081cb9d408f235e1a109a0b # v0.71.3 + uses: github/gh-aw/actions/setup@2c1a237d2048b0e2412e7d7528892ea1257840e2 # v0.74.4 with: destination: /opt/gh-aw/actions - name: Checkout repository @@ -806,7 +806,7 @@ jobs: total_count: ${{ steps.missing_tool.outputs.total_count }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@2f2a6f572b9038823081cb9d408f235e1a109a0b # v0.71.3 + uses: github/gh-aw/actions/setup@2c1a237d2048b0e2412e7d7528892ea1257840e2 # v0.74.4 with: destination: /opt/gh-aw/actions - name: Download agent output artifact @@ -927,7 +927,7 @@ jobs: success: ${{ steps.parse_results.outputs.success }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@2f2a6f572b9038823081cb9d408f235e1a109a0b # v0.71.3 + uses: github/gh-aw/actions/setup@2c1a237d2048b0e2412e7d7528892ea1257840e2 # v0.74.4 with: destination: 
/opt/gh-aw/actions - name: Download agent artifacts @@ -1039,7 +1039,7 @@ jobs: process_safe_outputs_temporary_id_map: ${{ steps.process_safe_outputs.outputs.temporary_id_map }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@2f2a6f572b9038823081cb9d408f235e1a109a0b # v0.71.3 + uses: github/gh-aw/actions/setup@2c1a237d2048b0e2412e7d7528892ea1257840e2 # v0.74.4 with: destination: /opt/gh-aw/actions - name: Download agent output artifact diff --git a/.gitignore b/.gitignore index 338a6437ca2..b5887be53b5 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ Lib/site-packages/* Lib/test/data/* !Lib/test/data/README cpython/ - +.claude/scheduled_tasks.lock \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8778283dd5f..e66122542f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,14 +40,27 @@ repos: types: [rust] priority: 0 - - id: generate-opcode-metadata - name: generate opcode metadata - entry: python scripts/generate_opcode_metadata.py - files: '^(crates/compiler-core/src/bytecode/instruction\.rs|scripts/generate_opcode_metadata\.py)$' + - id: generate-rs-opcode-metadata + name: generate rust opcode metadata + entry: python tools/opcode_metadata/generate_rs_opcode_metadata.py + files: '^(crates/compiler-core/src/bytecode/instruction\.rs|tools/opcode_metadata/*)$' pass_filenames: false language: system require_serial: true priority: 1 # so rustfmt runs first + stages: + - manual + + - id: generate-py-opcode-metadata + name: generate python opcode metadata + entry: python tools/opcode_metadata/generate_py_opcode_metadata.py + files: '^(crates/compiler-core/src/bytecode/instruction\.rs|tools/opcode_metadata/*)$' + pass_filenames: false + language: system + require_serial: true + priority: 1 # so rustfmt runs first + stages: + - manual - repo: https://github.com/streetsidesoftware/cspell-cli rev: v10.0.0 diff --git a/AGENTS.md b/AGENTS.md index b407328cffb..c89b2a4d3a4 100644 
--- a/AGENTS.md +++ b/AGENTS.md @@ -38,6 +38,12 @@ RustPython is a Python 3 interpreter written in Rust, implementing Python 3.14.0 - Always ask the user before performing any git operations that affect the remote repository - Commits can be created locally when requested, but pushing and PR creation require explicit approval +**CRITICAL: Pre-commit Checks** +- Before creating ANY commit, you MUST run `prek run --all-files` (or `pre-commit run --all-files`) AND the full test suite. Both must pass — do not commit if either fails. +- Test commands are documented in the [Testing](#testing) section below. At minimum run `cargo test --workspace --exclude rustpython_wasm --exclude rustpython-venvlauncher`; if the change touches `extra_tests/snippets/` run `pytest -v` there too, and if it touches `Lib/` or interpreter behavior, run the relevant `cargo run --release -- -m test ` modules. +- If a hook auto-fixes files (e.g. `ruff-format`, `rustfmt`), re-stage the fixes, re-run `prek` until it reports a clean pass, then re-run the tests, then commit. +- NEVER bypass these checks with `--no-verify`, `--no-gpg-sign`, or by skipping tests "because the change is small". If a hook or test fails, fix the underlying issue and create a new commit — do not amend or force the failing commit through. 
+ ## Important Development Notes ### Running Python Code @@ -81,6 +87,35 @@ The `Lib/` directory contains Python standard library files copied from the CPyt - `unittest.skip("TODO: RustPython ")` - `unittest.expectedFailure` with `# TODO: RUSTPYTHON ` comment +#### Choosing the right marker + +When marking a test that fails on RustPython, prefer one of the following forms: + +```python +@unittest.expectedFailure # TODO: RUSTPYTHON; +# or +@unittest.expectedFailureIf(, "TODO: RUSTPYTHON; ") +``` + +If the test would crash the interpreter (segfault, Rust panic, abort, infinite loop), use `skip` instead so the rest of the suite can still run: + +```python +@unittest.skip("TODO: RUSTPYTHON; ") +# or +@unittest.skipIf(, "TODO: RUSTPYTHON; ") +``` + +**When to use which:** + +- **Prefer `expectedFailure` / `expectedFailureIf`** by default. The test body still runs, so if RustPython is later fixed, the unexpected pass surfaces immediately and the decorator can be removed. Use the conditional `*If` form when the failure is environment-specific (e.g., a platform or build flag). +- **Use `skip` / `skipIf` only when running the test would take down the test process** — segfaults, Rust panics, aborts, or hangs that block subsequent tests. Skipping keeps the suite usable; `expectedFailure` cannot help here, because the test body still executes. + +To find WIP entries that are partly modified and may need follow-up: + +```bash +grep -d recurse 'TODO: RUSTPYTHON' Lib/test/ +``` + ### Clean Build When you modify bytecode instructions, a full clean is required: @@ -129,6 +164,7 @@ Run `./scripts/whats_left.py` to get a list of unimplemented methods, which is h - Do not delete or rewrite existing comments unless they are factually wrong or directly contradict the new code. - Do not add decorative section separators (e.g. `// -----------`, `// ===`, `/* *** */`). Use `///` doc-comments or short `//` comments only when they add value. 
+- Do not put `///` doc comments on items annotated with `#[pyattr]`, `#[pyclass]`, or `#[pyfunction]`. The derive macros pull authoritative docstrings from CPython via the `rustpython-doc` crate; a Rust doc comment overrides that source, and on `#[pyattr]` it is silently dropped. #### Avoid Duplicate Code in Branches @@ -258,9 +294,14 @@ See DEVELOPMENT.md "CPython Version Upgrade Checklist" section. - Document that it requires PEP 695 support - Focus on tests that can be fixed through Rust code changes only +## CI Workflows + +If you modify any file under `.github/workflows/`, the change must pass a [zizmor](https://docs.zizmor.sh/) scan in CI. + ## Documentation - Check the [architecture document](/architecture/architecture.md) for a high-level overview - Read the [development guide](/DEVELOPMENT.md) for detailed setup instructions - Generate documentation with `cargo doc --no-deps --all` - Online documentation is available at [docs.rs/rustpython](https://docs.rs/rustpython/) +- [How to update test files](https://github.com/RustPython/RustPython/wiki/How-to-update-test-files#checkout-cpython-source-code-initial-setup) — guide for syncing test cases from upstream CPython into the `Lib/` directory diff --git a/DEVELOPMENT.md b/CONTRIBUTING.md similarity index 88% rename from DEVELOPMENT.md rename to CONTRIBUTING.md index 7573f0f2640..58954486eaf 100644 --- a/DEVELOPMENT.md +++ b/CONTRIBUTING.md @@ -1,4 +1,28 @@ -# RustPython Development Guide and Tips +# Contributing to RustPython + +Contributions are more than welcome, and in many cases we are happy to guide +contributors through PRs or on [**Discord**](https://discord.gg/vru8NypEhv). + +## Finding ways to help + +We label issues that would be good for a first-time contributor as [`good first issue`](https://github.com/RustPython/RustPython/issues?q=label%3A%22good+first+issue%22+is%3Aissue+is%3Aopen+). +Also check out the [issue tracker](https://github.com/RustPython/RustPython/issues) for all open issues. 
+ +You can enhance CPython compatibility by increasing our unittest coverage; see [this pinned issue](https://github.com/RustPython/RustPython/issues/6839) for which libs and tests need to be updated to our currently supported Python version. + +Another approach is to check out the source code: builtin functions and object +methods are often the simplest and easiest way to contribute. + +You can also simply run `python -I scripts/whats_left.py` to assist in finding any unimplemented method. + +## Use of AI + +We **require all use of AI in contributions to follow our +[AI Policy](https://github.com/RustPython/.github/blob/main/AI_POLICY.md)**. + +If your contribution does not follow the policy, it will be closed. + +## RustPython Development Guide and Tips RustPython attracts developers with interest and experience in Rust, Python, or WebAssembly. Whether you are familiar with Rust, Python, or diff --git a/Cargo.lock b/Cargo.lock index 4e9b54bb8aa..d632ce7f741 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16,26 +16,13 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "aes" -version = "0.8.4" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +checksum = "66bd29a732b644c0431c6140f370d097879203d79b80c94a6747ba0872adaef8" dependencies = [ - "cfg-if", "cipher", - "cpufeatures", -] - -[[package]] -name = "ahash" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" -dependencies = [ - "cfg-if", - "getrandom 0.3.4", - "once_cell", - "version_check", - "zerocopy", + "cpubits", + "cpufeatures 0.3.0", ] [[package]] @@ -94,9 +81,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" @@ -129,9 +116,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.100" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "approx" @@ -175,7 +162,7 @@ dependencies = [ "nom", "num-traits", "rusticata-macros", - "thiserror 2.0.18", + "thiserror", "time", ] @@ -249,9 +236,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-fips-sys" -version = "0.13.13" +version = "0.13.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8bce4948d2520386c6d92a6ea2d472300257702242e5a1d01d6add52bd2e7c1" +checksum = "d3d619165468401dec3caa3366ebffbcb83f2f31883e5b3932f8e2dec2ddc568" dependencies = [ "bindgen 0.72.1", "cc", @@ -269,7 +256,6 @@ checksum = "0ec6fb3fe69024a75fa7e1bfb48aa6cf59706a101658ea01bfd33b2b248a038f" dependencies = [ "aws-lc-fips-sys", "aws-lc-sys", - "untrusted 0.7.1", "zeroize", ] @@ -303,7 +289,7 @@ version = "0.71.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cexpr", "clang-sys", "itertools 0.13.0", @@ -323,7 +309,7 @@ version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cexpr", "clang-sys", "itertools 0.13.0", @@ -345,9 +331,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = 
"2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" [[package]] name = "bitflagset" @@ -367,7 +353,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest", + "digest 0.10.7", ] [[package]] @@ -379,13 +365,22 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" +dependencies = [ + "hybrid-array", +] + [[package]] name = "block-padding" -version = "0.3.3" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +checksum = "710f1dd022ef4e93f8a438b4ba958de7f64308434fa6a87104481645cc30068b" dependencies = [ - "generic-array", + "hybrid-array", ] [[package]] @@ -410,9 +405,9 @@ dependencies = [ [[package]] name = "bytemuck" -version = "1.24.0" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" [[package]] name = "bytes" @@ -429,15 +424,6 @@ dependencies = [ "libbz2-rs-sys", ] -[[package]] -name = "caseless" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6fd507454086c8edfd769ca6ada439193cdb209c7681712ef6275cccbfe5d8" -dependencies = [ - "unicode-normalization", -] - [[package]] name = "cast" version = "0.3.0" @@ -455,18 +441,18 @@ dependencies = [ [[package]] name = "cbc" -version = "0.1.2" +version = "0.2.1" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +checksum = "ce2dc9ee5f88d11e0beb842c88b33c8a5cf0d1329c4b19494af42b07dbfe8896" dependencies = [ "cipher", ] [[package]] name = "cc" -version = "1.2.54" +version = "1.2.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6354c81bbfd62d9cfa9cb3c773c2b7b2a3a482d569de977fd0e961f6e7c00583" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" dependencies = [ "find-msvc-tools", "jobserver", @@ -537,11 +523,12 @@ dependencies = [ [[package]] name = "cipher" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +checksum = "e8cf2a2c93cd704877c0858356ed03480ff301ee950b43f1cbe4573b088bfa6c" dependencies = [ - "crypto-common", + "block-buffer 0.12.0", + "crypto-common 0.2.2", "inout", ] @@ -558,18 +545,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.54" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.54" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstyle", "clap_lex", @@ -577,9 +564,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.7" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" +checksum = 
"c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clipboard-win" @@ -592,13 +579,19 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +checksum = "c0f78a02292a74a88ac736019ab962ece0bc380e3f977bf72e376c5d78ff0678" dependencies = [ "cc", ] +[[package]] +name = "cmov" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746" + [[package]] name = "collection_literals" version = "1.0.3" @@ -607,9 +600,9 @@ checksum = "2550f75b8cfac212855f6b1885455df8eaee8fe8e246b647d69146142e016084" [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "combine" @@ -700,6 +693,23 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-models" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "657f625ff361906f779745d08375ae3cc9fef87a35fba5f22874cf773010daf4" +dependencies = [ + "hax-lib", + "pastey", + "rand 0.9.4", +] + +[[package]] +name = "cpubits" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b85f9c39137c3a891689859392b1bd49812121d0d61c9caf00d46ed5ce06ae" + [[package]] name = "cpufeatures" version = "0.2.17" @@ -709,11 +719,20 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "cranelift" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe07d3554a07c7e60c112ddf4aa5eeab69a4eb632881a59c2add0fdc4596cc7" +checksum = "4df9a80b0e50668699b52cc4a4ffa10e423df6ad4fd0150e3f983efc2bd7e876" dependencies = [ "cranelift-codegen", "cranelift-frontend", @@ -722,27 +741,27 @@ dependencies = [ [[package]] name = "cranelift-assembler-x64" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8628cc4ba7f88a9205a7ee42327697abc61195a1e3d92cfae172d6a946e722e" +checksum = "008f1a8d1da5074ad858f398775a6d1989031892e46927df5ed18d3be1ed8717" dependencies = [ "cranelift-assembler-x64-meta", ] [[package]] name = "cranelift-assembler-x64-meta" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d582754487e6c9a065a91c42ccf1bdd8d5977af33468dac5ae9bec0ce88acb3e" +checksum = "9fd76237df1f4e26edb5ad7971d20280ed1e193331fd257f1b4e4dfefd88dda2" dependencies = [ "cranelift-srcgen", ] [[package]] name = "cranelift-bforest" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb59c81ace12ee7c33074db7903d4d75d1f40b28cd3e8e6f491de57b29129eb9" +checksum = "380f0bc43e535df6855bbee649efb00bde39c3f33434c47c8e10ac836d21bf47" dependencies = [ "cranelift-entity", "wasmtime-internal-core", @@ -750,18 +769,18 @@ dependencies = [ [[package]] name = "cranelift-bitset" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f25c06993a681be9cf3140798a3d4ac5bec955e7444416a2fdc87fda8567285d" +checksum = "4811e3e4502de04257e90c0a93225b56d9b85e0f9ad10b81446b415511009610" dependencies = [ "wasmtime-internal-core", ] [[package]] name = 
"cranelift-codegen" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b61f95c5a211918f5d336254a61a488b36a5818de47a868e8c4658dce9cccc" +checksum = "82ffadb34d497f3e76fb3b4baf764c24ba8a51512976a1b77f78bdbf8f4aa687" dependencies = [ "bumpalo", "cranelift-assembler-x64", @@ -786,9 +805,9 @@ dependencies = [ [[package]] name = "cranelift-codegen-meta" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b85aa822fce72080d041d7c2cf7c3f5c6ecdea7afae68379ba4ef85269c4fa5" +checksum = "be4f6992eb6faf086ddc7deaaa5f279abfe7f5fd5ae5709bd38253450fc7b945" dependencies = [ "cranelift-assembler-x64-meta", "cranelift-codegen-shared", @@ -798,24 +817,24 @@ dependencies = [ [[package]] name = "cranelift-codegen-shared" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "833eb9fc89326cd072cc19e96892f09b5692c0dfe17cd4da2858ba30c2cd85c0" +checksum = "70e1b2aad7d055925a4ea9cdbfa9d1d987f9dfc8ad6b708be28f901ac620a298" [[package]] name = "cranelift-control" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d005320f487e6e8a3edcc7f2fd4f43fcc9946d1013bf206ea649789ac1617fc" +checksum = "89a355348325e0a63b65c00def3871597b9fcc79d25456397010d16d872b3772" dependencies = [ "arbitrary", ] [[package]] name = "cranelift-entity" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e62ef34c6e720f347a79ece043e8584e242d168911da640bac654a33a6aaaf5" +checksum = "43f4847d93ce2c80d2bff929aa1004dfb3ce2cf5d881f6ced54b8d654d967ba3" dependencies = [ "cranelift-bitset", "wasmtime-internal-core", @@ -823,9 +842,9 @@ dependencies = [ [[package]] name = "cranelift-frontend" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "dfa2ad00399dd47e7e7e33cb1dc23b0e39ed9dcd01e8f026fc37af91655031b8" +checksum = "ba24e5fe5242cc445e7892ef0a51a4351cf716e3a04ac7a3a05820d056c39818" dependencies = [ "cranelift-codegen", "log", @@ -835,15 +854,15 @@ dependencies = [ [[package]] name = "cranelift-isle" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02c51975ed217b4e8e5a7fd11e9ec83a96104bdff311dddcb505d1d8a9fd7fc6" +checksum = "89bc2035de85c4f04ba7bd57eb5bd3a8b775235bf28852dbf87105115cb8919a" [[package]] name = "cranelift-jit" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4070d2acc5c976887c10c273d38dd2bebebb472fd1ba65fa17b548adf5f56350" +checksum = "9bf7a3c46c8b1ba6f4818f0cfe971d0cf875a28c7fded25b9fc0b75acbbb677a" dependencies = [ "anyhow", "cranelift-codegen", @@ -861,9 +880,9 @@ dependencies = [ [[package]] name = "cranelift-module" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052a396328dbc7dadc29d1bd27d4aa57d9e9e493ead8ef6e2ab3b75c1bf9644c" +checksum = "680bd0df1ea88dc543eaa6aadf326200640be7603c5f36f9d5c1230b784ad8bc" dependencies = [ "anyhow", "cranelift-codegen", @@ -872,9 +891,9 @@ dependencies = [ [[package]] name = "cranelift-native" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9b1889e00da9729d8f8525f3c12998ded86ea709058ff844ebe00b97548de0e" +checksum = "5ea6630c16921ab087792750f239d0c0173411e80179ca7c0ce0710ce9e7646a" dependencies = [ "cranelift-codegen", "libc", @@ -883,9 +902,9 @@ dependencies = [ [[package]] name = "cranelift-srcgen" -version = "0.131.1" +version = "0.131.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5a8f82fd5124f009f72167e60139245cd3b56cfd4b53050f22110c48c5f4da1" +checksum = 
"faa4bbad54fc28cc0da1f9a5d7f7f826ec8cafda3d503b401b2daaaa93c63ef0" [[package]] name = "crc32fast" @@ -972,6 +991,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "crypto-common" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6e4c961d6cd6c9a86db418387425e8bdeaf05b3c8bc1411e6dca4c252f1453" +dependencies = [ + "hybrid-array", +] + [[package]] name = "csv-core" version = "0.1.13" @@ -981,11 +1009,20 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctutils" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e" +dependencies = [ + "cmov", +] + [[package]] name = "data-encoding" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" +checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8" [[package]] name = "der" @@ -1038,18 +1075,18 @@ dependencies = [ [[package]] name = "deranged" -version = "0.5.5" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" dependencies = [ "powerfmt", ] [[package]] name = "derive-where" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f" +checksum = "d08b3a0bcc0d079199cd476b2cae8435016ec11d1c0986c6901c5ac223041534" dependencies = [ "proc-macro2", "quote", @@ -1062,30 +1099,41 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", - "crypto-common", + "block-buffer 
0.10.4", + "crypto-common 0.1.7", "subtle", ] [[package]] -name = "dirs-next" -version = "2.0.0" +name = "digest" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ - "cfg-if", - "dirs-sys-next", + "block-buffer 0.12.0", + "crypto-common 0.2.2", + "ctutils", ] [[package]] -name = "dirs-sys-next" -version = "0.1.2" +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" dependencies = [ "libc", + "option-ext", "redox_users", - "winapi", + "windows-sys 0.61.2", ] [[package]] @@ -1143,9 +1191,9 @@ checksum = "869b0adbda23651a9c5c0c3d270aac9fcb52e8622a8f2b17e57802d7791962f2" [[package]] name = "env_filter" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", "regex", @@ -1194,15 +1242,15 @@ checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193" [[package]] name = "fastrand" -version = "2.3.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" [[package]] name = "find-msvc-tools" -version = "0.1.8" 
+version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "flagset" @@ -1262,6 +1310,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foldhash" version = "0.2.0" @@ -1274,7 +1328,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared 0.1.1", + "foreign-types-shared", ] [[package]] @@ -1283,12 +1337,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" -[[package]] -name = "foreign-types-shared" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" - [[package]] name = "fs_extra" version = "1.3.0" @@ -1297,26 +1345,25 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-task", "pin-project-lite", - "pin-utils", "slab", ] @@ -1393,11 +1440,24 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + [[package]] name = "gimli" version = "0.33.0" @@ -1416,6 +1476,16 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "graviola" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4387e0458389da24c6fe732531e65595c7c4a32b027f98f4789e512e28224465" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", +] + [[package]] name = "half" version = "2.7.1" @@ -1432,6 +1502,9 @@ name = "hashbrown" version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash 0.1.5", +] [[package]] name = "hashbrown" @@ -1439,7 +1512,7 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" dependencies = [ - "foldhash", + "foldhash 0.2.0", ] [[package]] @@ 
-1448,6 +1521,43 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +[[package]] +name = "hax-lib" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "543f93241d32b3f00569201bfce9d7a93c92c6421b23c77864ac929dc947b9fc" +dependencies = [ + "hax-lib-macros", + "num-bigint", + "num-traits", +] + +[[package]] +name = "hax-lib-macros" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8755751e760b11021765bb04cb4a6c4e24742688d9f3aa14c2079638f537b0f" +dependencies = [ + "hax-lib-macros-types", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "hax-lib-macros-types" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f177c9ae8ea456e2f71ff3c1ea47bf4464f772a05133fcbba56cd5ba169035a2" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "heck" version = "0.5.0" @@ -1478,7 +1588,16 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "digest", + "digest 0.10.7", +] + +[[package]] +name = "hmac" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" +dependencies = [ + "digest 0.11.3", ] [[package]] @@ -1490,11 +1609,20 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "hybrid-array" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9155a582abd142abc056962c29e3ce5ff2ad5469f4246b537ed42c5deba857da" +dependencies = [ + "typenum", +] + [[package]] name = "iana-time-zone" -version = "0.1.64" +version = "0.1.65" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1646,6 +1774,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "indexmap" version = "2.14.0" @@ -1654,16 +1788,18 @@ checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", "hashbrown 0.17.0", + "serde", + "serde_core", ] [[package]] name = "inout" -version = "0.1.4" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +checksum = "4250ce6452e92010fdf7268ccc5d14faa80bb12fc741938534c58f16804e03c7" dependencies = [ "block-padding", - "generic-array", + "hybrid-array", ] [[package]] @@ -1722,15 +1858,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" dependencies = [ "jiff-static", "log", @@ -1741,9 +1877,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" dependencies = [ "proc-macro2", "quote", @@ -1762,7 +1898,7 @@ dependencies = [ "jni-sys", "log", "simd_cesu8", - "thiserror 2.0.18", + "thiserror", "walkdir", "windows-link", ] @@ -1811,10 +1947,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.85" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" dependencies = [ + "cfg-if", + "futures-util", "once_cell", "wasm-bindgen", ] @@ -1835,7 +1973,7 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -1850,6 +1988,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "lexical-parse-float" version = "1.0.6" @@ -1883,9 +2027,9 @@ checksum = "803ec87c9cfb29b9d2633f20cba1f488db3fd53f2158b1024cbefb47ba05d413" [[package]] name = "libbz2-rs-sys" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c4a545a15244c7d945065b5d392b2d2d7f21526fba56ce51467b06ed445e8f7" +checksum = "b3a6a8c165077efc8f3a971534c50ea6a1a18b329ef4a66e897a7e3a1494565f" [[package]] name = "libc" @@ -1893,6 +2037,70 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "libcrux-intrinsics" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1b5db005ff8001e026b73a6842ee81bbef8ec5ff0e1915a67ae65fd2a9fafa5" +dependencies = [ + "core-models", + "hax-lib", +] + +[[package]] +name = "libcrux-ml-kem" +version = "0.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a14ab3e477de9df6ee1273a114018ff62c4996ca9220070c4e5cb1743f94a67d" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-sha3", + "libcrux-traits", +] + +[[package]] +name = "libcrux-platform" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9e21d7ed31a92ac539bd69a8c970b183ee883872d2d19ce27036e24cb8ecc4" +dependencies = [ + "libc", +] + +[[package]] +name = "libcrux-secrets" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce650f3041b44ba40d4263852347d007cd2cd9d1cc856a6f6c8b2e10c3fd40b" +dependencies = [ + "hax-lib", +] + +[[package]] +name = "libcrux-sha3" +version = "0.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1ae0b7d0e1cc4793a609fd0ff2ca3b3a3fabae523770c619a3d4bc86417b0d7" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-traits", +] + +[[package]] +name = "libcrux-traits" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e4fa89f3f5e34b47f928b22b1b78395a0d4ec23b1f583db635f128159d65f" +dependencies = [ + "libcrux-secrets", + "rand 0.9.4", +] + [[package]] name = "libffi" version = "5.1.0" @@ -1960,11 +2168,10 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" -version = "0.1.12" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" dependencies = [ - "bitflags 2.11.0", "libc", ] @@ -1996,9 +2203,9 @@ checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" [[package]] name = "lock_api" @@ -2114,12 +2321,6 @@ dependencies = [ "quote", ] -[[package]] -name = "maplit" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" - [[package]] name = "md-5" version = "0.10.6" @@ -2127,7 +2328,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ "cfg-if", - "digest", + "digest 0.10.7", ] [[package]] @@ -2194,7 +2395,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "cfg_aliases", "libc", @@ -2203,11 +2404,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.31.2" +version = "0.31.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +checksum = "cf20d2fde8ff38632c426f1165ed7436270b44f199fc55284c38276f9db47c3d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "cfg_aliases", "libc", @@ -2245,9 +2446,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" [[package]] name = "num-integer" @@ -2318,9 +2519,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "once_cell_polyfill" @@ -2336,11 +2537,11 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl" -version = "0.10.79" +version = "0.10.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0b434746ee2832f4f0baf10137e1cabb18cbe6912c69e2e33263c45250f542" +checksum = "a45fa2aa886c42762255da344f0a0d313e254066c46aad76f300c3d3da62d967" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "foreign-types", "libc", @@ -2367,18 +2568,18 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-src" -version = "300.5.4+3.5.4" +version = "300.6.0+3.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507b3792995dae9b0df8a1c1e3771e8418b7c2d9f0baeba32e6fe8b06c7cb72" +checksum = "a8e8cbfd3a4a8c8f089147fd7aaa33cf8c7450c4d09f8f80698a0cf093abeff4" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.115" +version = "0.9.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "158fe5b292746440aa6e7a7e690e55aeb72d41505e2804c23c6973ad0e9c9781" +checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4" dependencies = [ "cc", "libc", @@ -2387,6 +2588,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "option-ext" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "optional" version = "0.5.0" @@ -2395,9 +2602,9 @@ checksum = "978aa494585d3ca4ad74929863093e87cac9790d81fe7aba2b3dc2890643a0fc" [[package]] name = "ordermap" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa78c92071bbd3628c22b1a964f7e0eb201dc1456555db072beb1662ecd6715" +checksum = "7f7476a5b122ff1fce7208e7ee9dccd0a516e835f5b8b19b8f3c98a34cf757c1" dependencies = [ "indexmap", ] @@ -2440,14 +2647,30 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pastey" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee67f1008b1ba2321834326597b8e186293b049a023cdef258527550b9935b4" + [[package]] name = "pbkdf2" version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" dependencies = [ - "digest", - "hmac", + "digest 0.10.7", + "hmac 0.12.1", +] + +[[package]] +name = "pbkdf2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112d82ceb8c5bf524d9af484d4e4970c9fd5a0cc15ba14ad93dccd28873b0629" +dependencies = [ + "digest 0.11.3", + "hmac 0.13.0", ] [[package]] @@ -2505,7 +2728,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared 0.11.3", - "rand 0.8.5", + "rand 0.8.6", ] [[package]] @@ -2551,48 +2774,43 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" - -[[package]] -name = "pin-utils" -version = "0.1.0" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pkcs5" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e847e2c91a18bfa887dd028ec33f2fe6f25db77db3619024764914affe8b69a6" +checksum = "279a91971a1d8eb1260a30938eae3be9cb67b472dffecb222fbbbe2fd2dc1453" dependencies = [ "aes", "cbc", - "der 0.7.10", - "pbkdf2", + "der 0.8.0", + "pbkdf2 0.13.0", + "rand_core 0.10.1", "scrypt", - "sha2", - "spki", + "sha2 0.11.0", + "spki 0.8.0", ] [[package]] name = "pkcs8" -version = "0.10.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +checksum = "451913da69c775a56034ea8d9003d27ee8948e12443eae7c038ba100a4f21cb7" dependencies = [ - "der 0.7.10", + "der 0.8.0", "pkcs5", - "rand_core 0.6.4", - "spki", + "rand_core 0.10.1", + "spki 0.8.0", ] [[package]] name = "pkg-config" -version = "0.3.32" +version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" [[package]] name = "plotters" @@ -2635,24 +2853,24 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.4" +version = "0.2.7" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" dependencies = [ "portable-atomic", ] [[package]] name = "potential_utf" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ "serde_core", "writeable", @@ -2684,6 +2902,28 @@ dependencies = [ "syn", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proc-macro-utils" version = "0.10.0" @@ -2823,6 +3063,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "radium" version = "1.1.1" @@ -2844,9 +3090,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = 
"5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha 0.3.1", @@ -2901,11 +3147,26 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + +[[package]] +name = "rapidhash" +version = "4.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e48930979c155e2f33aa36ab3119b5ee81332beb6482199a8ecd6029b80b59" +dependencies = [ + "rustversion", +] + [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -2933,18 +3194,18 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", ] [[package]] name = "redox_users" -version = "0.4.6" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ "getrandom 0.2.17", "libredox", - "thiserror 1.0.69", + "thiserror", ] [[package]] @@ -2969,13 +3230,13 @@ dependencies = [ [[package]] name = "regalloc2" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "952ddbfc6f9f64d006c3efd8c9851a6ba2f2b944ba94730db255d55006e0ffda" +checksum = "de2c52737737f8609e94f975dee22854a2d5c125772d4b1cf292120f4d45c186" dependencies = [ "allocator-api2", "bumpalo", - "hashbrown 0.15.5", + "hashbrown 
0.17.0", "log", "rustc-hash", "smallvec", @@ -2983,9 +3244,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.2" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -2995,9 +3256,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -3006,9 +3267,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "region" @@ -3053,15 +3314,15 @@ dependencies = [ "cfg-if", "getrandom 0.2.17", "libc", - "untrusted 0.9.0", + "untrusted", "windows-sys 0.52.0", ] [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustc_version" @@ -3087,7 +3348,7 @@ version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "errno", "libc", "linux-raw-sys", @@ -3096,9 +3357,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.39" +version = 
"0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "aws-lc-rs", "once_cell", @@ -3108,6 +3369,17 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-graviola" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "323c712e50c59ceb2ba9ad4d79dcfd3e0046a082d61efa87fcdf8f59af04473c" +dependencies = [ + "graviola", + "libcrux-ml-kem", + "rustls", +] + [[package]] name = "rustls-native-certs" version = "0.8.3" @@ -3131,9 +3403,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" dependencies = [ "zeroize", ] @@ -3174,7 +3446,7 @@ dependencies = [ "aws-lc-rs", "ring", "rustls-pki-types", - "untrusted 0.9.0", + "untrusted", ] [[package]] @@ -3182,7 +3454,7 @@ name = "rustpython" version = "0.5.0" dependencies = [ "criterion", - "dirs-next", + "dirs", "env_logger", "flame", "flamescope", @@ -3190,6 +3462,8 @@ dependencies = [ "libc", "log", "pyo3", + "rustls", + "rustls-graviola", "rustpython-capi", "rustpython-compiler", "rustpython-pylib", @@ -3204,6 +3478,8 @@ dependencies = [ name = "rustpython-capi" version = "0.5.0" dependencies = [ + "bitflags 2.11.1", + "num-complex", "pyo3", "rustpython-stdlib", "rustpython-vm", @@ -3213,8 +3489,7 @@ dependencies = [ name = "rustpython-codegen" version = "0.5.0" dependencies = [ - "ahash", - "bitflags 2.11.0", + "bitflags 2.11.1", "indexmap", "itertools 0.14.0", "log", @@ -3222,13 +3497,14 @@ dependencies = [ "memchr", "num-complex", "num-traits", + "rapidhash", "rustpython-compiler-core", "rustpython-literal", 
"rustpython-ruff_python_ast", "rustpython-ruff_python_parser", "rustpython-ruff_text_size", "rustpython-wtf8", - "thiserror 2.0.18", + "thiserror", "unicode_names2 2.0.0", ] @@ -3237,7 +3513,7 @@ name = "rustpython-common" version = "0.5.0" dependencies = [ "ascii", - "bitflags 2.11.0", + "bitflags 2.11.1", "getrandom 0.3.4", "itertools 0.14.0", "libc", @@ -3265,14 +3541,14 @@ dependencies = [ "rustpython-ruff_python_parser", "rustpython-ruff_source_file", "rustpython-ruff_text_size", - "thiserror 2.0.18", + "thiserror", ] [[package]] name = "rustpython-compiler-core" version = "0.5.0" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "bitflagset", "itertools 0.14.0", "lz4_flex", @@ -3304,7 +3580,6 @@ name = "rustpython-derive-impl" version = "0.5.0" dependencies = [ "itertools 0.14.0", - "maplit", "proc-macro2", "quote", "rustpython-compiler-core", @@ -3325,10 +3600,21 @@ dependencies = [ name = "rustpython-host_env" version = "0.5.0" dependencies = [ + "bitflags 2.11.1", + "getrandom 0.3.4", + "junction", "libc", - "nix 0.31.2", + "libffi", + "libloading 0.9.0", + "memmap2", + "nix 0.31.3", "num-traits", + "num_cpus", + "parking_lot", + "paste", + "rustix", "rustpython-wtf8", + "schannel", "termios", "widestring", "windows-sys 0.61.2", @@ -3347,7 +3633,7 @@ dependencies = [ "rustpython-compiler-core", "rustpython-derive", "rustpython-wtf8", - "thiserror 2.0.18", + "thiserror", ] [[package]] @@ -3379,7 +3665,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f021ff72cabf5e2cd6d8ec8813d376a8445a228dc610ab56c27bd9054cda70d4" dependencies = [ "aho-corasick", - "bitflags 2.11.0", + "bitflags 2.11.1", "compact_str", "get-size2", "is-macro", @@ -3388,7 +3674,7 @@ dependencies = [ "rustpython-ruff_python_trivia", "rustpython-ruff_source_file", "rustpython-ruff_text_size", - "thiserror 2.0.18", + "thiserror", ] [[package]] @@ -3397,7 +3683,7 @@ version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "01e6ee78bd9671fb5766664b2695fe1f2a92a961f4d9101646c570d8acdb1e0b" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "bstr", "compact_str", "get-size2", @@ -3447,7 +3733,7 @@ dependencies = [ name = "rustpython-sre_engine" version = "0.5.0" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "criterion", "icu_properties", "num_enum", @@ -3460,9 +3746,7 @@ name = "rustpython-stdlib" version = "0.5.0" dependencies = [ "adler32", - "ahash", "ascii", - "aws-lc-rs", "base64", "blake2", "bzip2", @@ -3472,15 +3756,15 @@ dependencies = [ "crossbeam-utils", "csv-core", "der 0.8.0", - "digest", + "digest 0.10.7", "dns-lookup", "dyn-clone", "flame", "flate2", - "foreign-types-shared 0.3.1", + "foreign-types-shared", "gethostname", "hex", - "hmac", + "hmac 0.12.1", "icu_normalizer", "icu_properties", "indexmap", @@ -3495,9 +3779,7 @@ dependencies = [ "malachite-bigint", "md-5", "memchr", - "memmap2", "mt19937", - "nix 0.31.2", "num-complex", "num-traits", "num_enum", @@ -3505,16 +3787,15 @@ dependencies = [ "openssl", "openssl-probe", "openssl-sys", - "page_size", "parking_lot", "paste", - "pbkdf2", + "pbkdf2 0.12.2", "pem-rfc7468 1.0.0", "phf 0.13.1", "pkcs8", "pymath", "rand_core 0.9.5", - "rustix", + "rapidhash", "rustls", "rustls-native-certs", "rustls-pemfile", @@ -3528,14 +3809,12 @@ dependencies = [ "rustpython-ruff_source_file", "rustpython-ruff_text_size", "rustpython-vm", - "schannel", "sha-1", - "sha2", + "sha2 0.10.9", "sha3", "socket2", "system-configuration", "tcl-sys", - "termios", "tk-sys", "ucd", "unic-ucd-age", @@ -3543,7 +3822,6 @@ dependencies = [ "uuid", "webpki-roots", "widestring", - "windows-sys 0.61.2", "x509-cert", "x509-parser", "xml", @@ -3557,19 +3835,15 @@ version = "0.5.0" name = "rustpython-vm" version = "0.5.0" dependencies = [ - "ahash", "ascii", - "bitflags 2.11.0", + "bitflags 2.11.1", "bstr", - "caseless", "chrono", "constant_time_eq", "crossbeam-utils", - "errno", "exitcode", "flame", "flamer", - "getrandom 
0.3.4", "glob", "half", "hex", @@ -3579,25 +3853,20 @@ dependencies = [ "indexmap", "is-macro", "itertools 0.14.0", - "junction", "libc", - "libffi", - "libloading 0.9.0", "log", "malachite-bigint", "memchr", - "nix 0.31.2", "num-complex", "num-integer", "num-traits", - "num_cpus", "num_enum", "optional", "parking_lot", "paste", "psm", + "rapidhash", "result-like", - "rustix", "rustpython-codegen", "rustpython-common", "rustpython-compiler", @@ -3616,12 +3885,11 @@ dependencies = [ "static_assertions", "strum", "strum_macros", - "thiserror 2.0.18", + "thiserror", "timsort", "wasm-bindgen", "which", "widestring", - "windows-sys 0.61.2", "writeable", ] @@ -3664,14 +3932,14 @@ version = "18.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a990b25f351b25139ddc7f21ee3f6f56f86d6846b74ac8fad3a719a287cd4a0" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "clipboard-win", "home", "libc", "log", "memchr", - "nix 0.31.2", + "nix 0.31.3", "radix_trie", "unicode-segmentation", "unicode-width", @@ -3681,9 +3949,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] name = "safe_arch" @@ -3696,10 +3964,11 @@ dependencies = [ [[package]] name = "salsa20" -version = "0.10.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" +checksum = "2f874456e72520ff1375a06c588eaf074b0f01f9e9e1aada45bd9b7954a6e42c" dependencies = [ + "cfg-if", "cipher", ] @@ -3729,22 +3998,23 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "scrypt" -version = "0.11.0" +version = "0.12.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f" +checksum = "d87af57419b594aa23fa95f09f0e06d80d84ba01c26148c43844cad6ff4485f0" dependencies = [ - "pbkdf2", + "cfg-if", + "pbkdf2 0.13.0", "salsa20", - "sha2", + "sha2 0.11.0", ] [[package]] name = "security-framework" -version = "3.5.1" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "core-foundation 0.10.1", "core-foundation-sys", "libc", @@ -3753,9 +4023,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.15.0" +version = "2.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" dependencies = [ "core-foundation-sys", "libc", @@ -3823,9 +4093,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "876ac351060d4f882bb1032b6369eb0aef79ad9df1ea8bc404874d8cc3d0cd98" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" dependencies = [ "serde_core", ] @@ -3837,8 +4107,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" dependencies = [ "cfg-if", - "cpufeatures", - "digest", + "cpufeatures 0.2.17", + "digest 0.10.7", ] [[package]] @@ -3848,8 +4118,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", - "digest", + 
"cpufeatures 0.2.17", + "digest 0.10.7", ] [[package]] @@ -3859,17 +4129,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", - "digest", + "cpufeatures 0.2.17", + "digest 0.10.7", +] + +[[package]] +name = "sha2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest 0.11.3", ] [[package]] name = "sha3" -version = "0.10.8" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +checksum = "77fd7028345d415a4034cf8777cd4f8ab1851274233b45f84e3d955502d93874" dependencies = [ - "digest", + "digest 0.10.7", "keccak", ] @@ -3898,9 +4179,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] name = "simd_cesu8" @@ -3926,15 +4207,15 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" 
[[package]] name = "smallvec" @@ -3962,6 +4243,16 @@ dependencies = [ "der 0.7.10", ] +[[package]] +name = "spki" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9efca8738c78ee9484207732f728b1ef517bbb1833d6fc0879ca898a522f6f" +dependencies = [ + "base64ct", + "der 0.8.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -4037,7 +4328,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "core-foundation 0.9.4", "system-configuration-sys", ] @@ -4054,9 +4345,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1dd07eb858a2067e2f3c7155d54e929265c264e6f37efe3ee7a8d1b5a1dd0ba" +checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" [[package]] name = "tcl-sys" @@ -4069,12 +4360,12 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.24.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.2", "once_cell", "rustix", "windows-sys 0.61.2", @@ -4095,33 +4386,13 @@ version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057" -[[package]] -name = "thiserror" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" -dependencies = [ - "thiserror-impl 1.0.69", -] - [[package]] name = "thiserror" version = "2.0.18" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl 2.0.18", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" -dependencies = [ - "proc-macro2", - "quote", - "syn", + "thiserror-impl", ] [[package]] @@ -4206,9 +4477,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -4251,9 +4522,9 @@ dependencies = [ [[package]] name = "toml" -version = "1.1.0+spec-1.1.0" +version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8195ca05e4eb728f4ba94f3e3291661320af739c4e43779cbdfae82ab239fcc" +checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" dependencies = [ "indexmap", "serde_core", @@ -4266,27 +4537,27 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "1.1.0+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" dependencies = [ "serde_core", ] [[package]] name = "toml_parser" -version = "1.1.0+spec-1.1.0" +version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ "winnow", ] [[package]] name = 
"toml_writer" -version = "1.1.0+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed" +checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" [[package]] name = "twox-hash" @@ -4296,9 +4567,9 @@ checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "ucd" @@ -4349,9 +4620,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.22" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-normalization" @@ -4364,9 +4635,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" [[package]] name = "unicode-width" @@ -4374,6 +4645,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unicode_names2" version = "1.3.0" @@ -4403,7 +4680,7 @@ dependencies = [ "getopts", "log", 
"phf_codegen", - "rand 0.8.5", + "rand 0.8.6", ] [[package]] @@ -4413,15 +4690,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1262662dc96937c71115228ce2e1d30f41db71a7a45d3459e98783ef94052214" dependencies = [ "phf_codegen", - "rand 0.8.5", + "rand 0.8.6", ] -[[package]] -name = "untrusted" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" - [[package]] name = "untrusted" version = "0.9.0" @@ -4453,6 +4724,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ "atomic", + "getrandom 0.4.2", "js-sys", "wasm-bindgen", ] @@ -4487,18 +4759,27 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", ] [[package]] name = "wasm-bindgen" -version = "0.2.108" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" dependencies = [ "cfg-if", "once_cell", @@ -4509,23 +4790,19 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.58" +version = "0.4.70" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a6e77fd0ae8029c9ea0063f87c46fde723e7d887703d74ad2616d792e51e6f" +checksum = "af934872acec734c2d80e6617bbb5ff4f12b052dd8e6332b0817bce889516084" dependencies = [ - "cfg-if", - "futures-util", "js-sys", - "once_cell", "wasm-bindgen", - "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.108" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4533,9 +4810,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.108" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" dependencies = [ "bumpalo", "proc-macro2", @@ -4546,18 +4823,52 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.108" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + 
+[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.11.1", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "wasmtime-internal-core" -version = "44.0.1" +version = "44.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f2c7fa6523647262bfb4095dbdf4087accefe525813e783f81a0c682f418ce4" +checksum = "4364d345719bba7fc4c435992ea1cb0c118f1e90a88c6e6f22a7a4fc507700c6" dependencies = [ "hashbrown 0.16.1", "libm", @@ -4565,9 +4876,9 @@ dependencies = [ [[package]] name = "wasmtime-internal-jit-icache-coherence" -version = "44.0.1" +version = "44.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a1859e920871515d324fb9757c3e448d6ed1512ca6ccdff14b6e016505d6ada" +checksum = "c3ba98c1492f530833e0d3cc17dbb0c3c57c9f1bb3b078ae44bb55a233e43eba" dependencies = [ "cfg-if", "libc", @@ -4577,9 +4888,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.85" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" dependencies = [ "js-sys", "wasm-bindgen", @@ -4587,9 +4898,9 @@ dependencies = [ [[package]] name = "webpki-root-certs" -version = "1.0.5" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" +checksum = "f31141ce3fc3e300ae89b78c0dd67f9708061d1d2eda54b8209346fd6be9a92c" dependencies = [ "rustls-pki-types", ] @@ -4614,9 +4925,9 @@ dependencies = [ [[package]] name = "wide" -version = "1.1.1" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ac11b009ebeae802ed758530b6496784ebfee7a87b9abfbcaf3bbe25b814eb25" +checksum = "c9479f84a757f819cfab37295955906479181395de83add28f74975fde083141" dependencies = [ "bytemuck", "safe_arch", @@ -4876,9 +5187,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" [[package]] name = "winresource" @@ -4895,6 +5206,94 @@ name = "wit-bindgen" version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + 
"wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.11.1", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] [[package]] name = "write16" @@ -4904,9 +5303,9 @@ checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" [[package]] name = "writeable" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "x509-cert" @@ -4918,7 +5317,7 @@ dependencies = [ "der 0.7.10", "sha1", "signature", - "spki", + "spki 0.7.3", "tls_codec", ] @@ -4935,7 +5334,7 @@ dependencies = [ "nom", "oid-registry", "rusticata-macros", - "thiserror 2.0.18", + "thiserror", "time", ] @@ -4970,18 +5369,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.34" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71ddd76bcebeed25db614f82bf31a9f4222d3fbba300e6fb6c00afa26cbd4d9d" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.34" +version = "0.8.48" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8187381b52e32220d50b255276aa16a084ec0a9017a0ca2152a1f55c539758d" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", @@ -4990,18 +5389,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", @@ -5072,6 +5471,6 @@ checksum = "3be3d40e40a133f9c916ee3f9f4fa2d9d63435b5fbe1bfc6d9dae0aa0ada1513" [[package]] name = "zmij" -version = "1.0.17" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02aae0f83f69aafc94776e879363e9771d7ecbffe2c7fbb6c14c5e00dfe88439" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml index 80976b1e1c7..de2704842c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ license.workspace = true [features] capi = ["dep:rustpython-capi", "threading"] -default = ["threading", "stdlib", "stdio", "importlib", "ssl-rustls", "host_env"] +default = ["threading", "stdlib", "stdio", "importlib", "ssl-rustls-aws-lc", "host_env"] host_env = ["rustpython-vm/host_env", "rustpython-stdlib?/host_env"] importlib = ["rustpython-vm/importlib"] encodings = ["rustpython-vm/encodings"] @@ -22,10 +22,12 @@ freeze-stdlib = ["stdlib", "rustpython-vm/freeze-stdlib", "rustpython-pylib?/fre jit = 
["rustpython-vm/jit"] threading = ["rustpython-vm/threading", "rustpython-stdlib/threading"] sqlite = ["rustpython-stdlib/sqlite"] -ssl = [] +ssl = ["host_env"] ssl-rustls = ["ssl", "rustpython-stdlib/ssl-rustls"] +ssl-rustls-aws-lc = ["ssl-rustls", "dep:rustls", "rustls/aws_lc_rs"] +ssl-rustls-aws-lc-fips = ["ssl-rustls-aws-lc", "rustls/fips"] ssl-openssl = ["ssl", "rustpython-stdlib/ssl-openssl"] -ssl-vendor = ["ssl-openssl", "rustpython-stdlib/ssl-vendor"] +ssl-openssl-vendor = ["ssl-openssl", "rustpython-stdlib/ssl-openssl-vendor"] tkinter = ["rustpython-stdlib/tkinter"] [build-dependencies] @@ -42,10 +44,13 @@ log = { workspace = true } flame = { workspace = true, optional = true } lexopt = "0.3" -dirs = { package = "dirs-next", version = "2.0" } +dirs = "6" env_logger = "0.11" flamescope = { version = "0.1.2", optional = true } +rustls = { workspace = true, optional = true } +rustls-graviola = { workspace = true, optional = true } + [target.'cfg(windows)'.dependencies] libc = { workspace = true } @@ -70,6 +75,17 @@ harness = false name = "rustpython" path = "src/main.rs" +[[example]] +name = "custom_tls_providers" +path = "examples/custom_tls_providers.rs" +required-features = [ + "rustls-graviola", + "rustls/ring", + "rustpython-pylib/freeze-stdlib", + "rustpython-stdlib/ssl-rustls", + "rustpython-vm/freeze-stdlib", +] + [profile.dev.package."*"] opt-level = 3 @@ -176,22 +192,19 @@ ruff_source_file = { package = "rustpython-ruff_source_file", version = "0.15.8" der = { version = "0.8", features = ["alloc", "oid", "pem", "zeroize"] } phf = { version = "0.13.1", default-features = false, features = ["macros"]} adler32 = "1.2.0" -ahash = "0.8.12" approx = "0.5.1" ascii = "1.1" -aws-lc-rs = "1.16.3" base64 = "0.22" blake2 = "0.10.4" bitflags = "2.11.0" bitflagset = "0.0.3" bstr = "1" bzip2 = "0.6" -caseless = "0.2.2" chrono = { version = "0.4.44", default-features = false, features = ["clock", "std"] } console_error_panic_hook = "0.1" constant_time_eq = "0.4" 
-cranelift = "0.131.1" -cranelift-jit = "0.131.1" +cranelift = "0.131.2" +cranelift-jit = "0.131.2" cranelift-module = "0.131.0" crc32fast = "1.3.2" criterion = { version = "0.8", features = ["html_reports"] } @@ -201,11 +214,11 @@ digest = "0.10.7" dns-lookup = "3.0" dyn-clone = "1.0.10" exitcode = "1.1.2" -errno = "0.3" flame = "0.2.2" flamer = "0.5" flate2 = { version = "1.1.9", default-features = false } -foreign-types-shared = "0.3.1" +# Bump only when the openssl crate bumps it +foreign-types-shared = "0.1" gethostname = "1.0.2" getrandom = { version = "0.3", features = ["std"] } glob = "0.3" @@ -235,7 +248,6 @@ mac_address = "1.1.3" malachite-bigint = "0.9.1" malachite-q = "0.9.1" malachite-base = "0.9.1" -maplit = "1.0.2" md-5 = "0.10.1" memchr = "2.8.0" memmap2 = "0.9.10" @@ -246,16 +258,15 @@ num-traits = "0.2" num_cpus = "1.17.0" num_enum = { version = "0.7", default-features = false } oid-registry = "0.8" -openssl = "0.10.79" +openssl = "0.10.80" openssl-sys = "0.9.110" openssl-probe = "0.2.1" optional = "0.5" -page_size = "0.6" parking_lot = "0.12.3" paste = "1.0.15" pbkdf2 = "0.12" pem-rfc7468 = "1.0" -pkcs8 = "0.10" +pkcs8 = "0.11" proc-macro2 = "1.0.105" psm = "0.1" pymath = { version = "0.2.0", features = ["mul_add", "malachite-bigint", "complex"] } @@ -264,9 +275,11 @@ quote = "1.0.45" radium = "1.1.1" rand = "0.9" rand_core = { version = "0.9", features = ["os_rng"] } +rapidhash = "4.4.1" result-like = "0.5.0" -rustix = { version = "1.1", features = ["event", "system"] } +rustix = { version = "1.1", features = ["event", "param", "system"] } rustls = { version = "0.23.39", default-features = false } +rustls-graviola = "0.3" rustls-native-certs = "0.8" rustls-pemfile = "2.2" rustls-platform-verifier = "0.7" @@ -337,8 +350,11 @@ similar_names = "allow" # restriction lints alloc_instead_of_core = "warn" +cfg_not_test = "warn" +redundant_test_prefix = "warn" std_instead_of_alloc = "warn" std_instead_of_core = "warn" +tests_outside_test_module = "warn" 
# nursery lints to enforce gradually debug_assert_with_mut_call = "warn" diff --git a/Lib/_opcode_metadata.py b/Lib/_opcode_metadata.py index 4da6e507736..001a016621c 100644 --- a/Lib/_opcode_metadata.py +++ b/Lib/_opcode_metadata.py @@ -1,4 +1,4 @@ -# This file is generated by scripts/generate_opcode_metadata.py +# This file is generated by tools/opcode_metadata/generate_py_opcode_metadata.py # for RustPython bytecode format (CPython 3.14 compatible opcode numbers). # Do not edit! diff --git a/Lib/ctypes/__init__.py b/Lib/ctypes/__init__.py index 04ec0270148..7223fc49977 100644 --- a/Lib/ctypes/__init__.py +++ b/Lib/ctypes/__init__.py @@ -470,6 +470,8 @@ def _load_library(self, name, mode, handle, winmode): if name and name.endswith(")") and ".a(" in name: mode |= _os.RTLD_MEMBER | _os.RTLD_NOW self._name = name + if handle is not None: + return handle return _dlopen(name, mode) def __repr__(self): diff --git a/Lib/ctypes/util.py b/Lib/ctypes/util.py index 378f12167c6..3b21658433b 100644 --- a/Lib/ctypes/util.py +++ b/Lib/ctypes/util.py @@ -85,15 +85,10 @@ def find_library(name): wintypes.DWORD, ) - _psapi = ctypes.WinDLL('psapi', use_last_error=True) - _enum_process_modules = _psapi["EnumProcessModules"] - _enum_process_modules.restype = wintypes.BOOL - _enum_process_modules.argtypes = ( - wintypes.HANDLE, - ctypes.POINTER(wintypes.HMODULE), - wintypes.DWORD, - wintypes.LPDWORD, - ) + # gh-145307: We defer loading psapi.dll until _get_module_handles is called. + # Loading additional DLLs at startup for functionality that may never be + # used is wasteful. 
+ _enum_process_modules = None def _get_module_filename(module: wintypes.HMODULE): name = (wintypes.WCHAR * 32767)() # UNICODE_STRING_MAX_CHARS @@ -101,8 +96,19 @@ def _get_module_filename(module: wintypes.HMODULE): return name.value return None - def _get_module_handles(): + global _enum_process_modules + if _enum_process_modules is None: + _psapi = ctypes.WinDLL('psapi', use_last_error=True) + _enum_process_modules = _psapi["EnumProcessModules"] + _enum_process_modules.restype = wintypes.BOOL + _enum_process_modules.argtypes = ( + wintypes.HANDLE, + ctypes.POINTER(wintypes.HMODULE), + wintypes.DWORD, + wintypes.LPDWORD, + ) + process = _get_current_process() space_needed = wintypes.DWORD() n = 1024 diff --git a/Lib/email/__init__.py b/Lib/email/__init__.py index 9fa47783004..6d597006e5e 100644 --- a/Lib/email/__init__.py +++ b/Lib/email/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2007 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/_encoded_words.py b/Lib/email/_encoded_words.py index 6795a606de0..05a34a4c105 100644 --- a/Lib/email/_encoded_words.py +++ b/Lib/email/_encoded_words.py @@ -219,7 +219,7 @@ def encode(string, charset='utf-8', encoding=None, lang=''): """ if charset == 'unknown-8bit': - bstring = string.encode('ascii', 'surrogateescape') + bstring = string.encode('utf-8', 'surrogateescape') else: bstring = string.encode(charset) if encoding is None: diff --git a/Lib/email/_header_value_parser.py b/Lib/email/_header_value_parser.py index 91243378dc0..1367e34195b 100644 --- a/Lib/email/_header_value_parser.py +++ b/Lib/email/_header_value_parser.py @@ -80,7 +80,8 @@ # Useful constants and functions # -WSP = set(' \t') +_WSP = ' \t' +WSP = set(_WSP) CFWS_LEADER = WSP | set('(') SPECIALS = set(r'()<>@,:;.\"[]') ATOM_ENDS = SPECIALS | WSP @@ -101,6 +102,12 @@ def make_quoted_pairs(value): return str(value).replace('\\', '\\\\').replace('"', '\\"') 
+def make_parenthesis_pairs(value): + """Escape parenthesis and backslash for use within a comment.""" + return str(value).replace('\\', '\\\\') \ + .replace('(', '\\(').replace(')', '\\)') + + def quote_string(value): escaped = make_quoted_pairs(value) return f'"{escaped}"' @@ -632,11 +639,11 @@ def local_part(self): for tok in self[0] + [DOT]: if tok.token_type == 'cfws': continue - if (last_is_tl and tok.token_type == 'dot' and + if (last_is_tl and tok.token_type == 'dot' and last and last[-1].token_type == 'cfws'): res[-1] = TokenList(last[:-1]) is_tl = isinstance(tok, TokenList) - if (is_tl and last.token_type == 'dot' and + if (is_tl and last.token_type == 'dot' and tok and tok[0].token_type == 'cfws'): res.append(TokenList(tok[1:])) else: @@ -874,6 +881,12 @@ class MessageID(MsgID): class InvalidMessageID(MessageID): token_type = 'invalid-message-id' +class MessageIDList(TokenList): + token_type = 'message-id-list' + + @property + def message_ids(self): + return [x for x in self if x.token_type=='msg-id'] class Header(TokenList): token_type = 'header' @@ -933,7 +946,7 @@ def value(self): return ' ' def startswith_fws(self): - return True + return self and self[0] in WSP class ValueTerminal(Terminal): @@ -1232,8 +1245,7 @@ def get_bare_quoted_string(value): bare_quoted_string = BareQuotedString() value = value[1:] if value and value[0] == '"': - token, value = get_qcontent(value) - bare_quoted_string.append(token) + return bare_quoted_string, value[1:] while value and value[0] != '"': if value[0] in WSP: token, value = get_fws(value) @@ -2046,12 +2058,10 @@ def get_address_list(value): address_list.defects.append(errors.InvalidHeaderDefect( "invalid address in address-list")) if value and value[0] != ',': - # Crap after address; treat it as an invalid mailbox. - # The mailbox info will still be available. 
- mailbox = address_list[-1][0] - mailbox.token_type = 'invalid-mailbox' + # Crap after address: add it to the address list + # as an invalid mailbox token, value = get_invalid_mailbox(value, ',') - mailbox.extend(token) + address_list.append(Address([token])) address_list.defects.append(errors.InvalidHeaderDefect( "invalid address in address-list")) if value: # Must be a , at this point. @@ -2171,6 +2181,32 @@ def parse_message_id(value): return message_id +def parse_message_ids(value): + """in-reply-to = "In-Reply-To:" 1*msg-id CRLF + references = "References:" 1*msg-id CRLF + """ + message_id_list = MessageIDList() + while value: + if value[0] == ',': + # message id list separated with commas - this is invalid, + # but happens rather frequently in the wild + message_id_list.defects.append( + errors.InvalidHeaderDefect("comma in msg-id list")) + message_id_list.append( + WhiteSpaceTerminal(' ', 'invalid-comma-replacement')) + value = value[1:] + continue + try: + token, value = get_msg_id(value) + message_id_list.append(token) + except errors.HeaderParseError as ex: + token = get_unstructured(value) + message_id_list.append(InvalidMessageID(token)) + message_id_list.defects.append( + errors.InvalidHeaderDefect("Invalid msg-id: {!r}".format(ex))) + break + return message_id_list + # # XXX: As I begin to add additional header parsers, I'm realizing we probably # have two level of parser routines: the get_XXX methods that get a token in @@ -2788,8 +2824,12 @@ def _steal_trailing_WSP_if_exists(lines): if lines and lines[-1] and lines[-1][-1] in WSP: wsp = lines[-1][-1] lines[-1] = lines[-1][:-1] + # gh-142006: if the line is now empty, remove it entirely. + if not lines[-1]: + lines.pop() return wsp + def _refold_parse_tree(parse_tree, *, policy): """Return string of contents of parse_tree folded according to RFC rules. 
@@ -2798,11 +2838,9 @@ def _refold_parse_tree(parse_tree, *, policy): maxlen = policy.max_line_length or sys.maxsize encoding = 'utf-8' if policy.utf8 else 'us-ascii' lines = [''] # Folded lines to be output - leading_whitespace = '' # When we have whitespace between two encoded - # words, we may need to encode the whitespace - # at the beginning of the second word. - last_ew = None # Points to the last encoded character if there's an ew on - # the line + last_word_is_ew = False + last_ew = None # if there is an encoded word in the last line of lines, + # points to the encoded word's first character last_charset = None wrap_as_ew_blocked = 0 want_encoding = False # This is set to True if we need to encode this part @@ -2837,6 +2875,7 @@ def _refold_parse_tree(parse_tree, *, policy): if part.token_type == 'mime-parameters': # Mime parameter folding (using RFC2231) is extra special. _fold_mime_parameters(part, lines, maxlen, encoding) + last_word_is_ew = False continue if want_encoding and not wrap_as_ew_blocked: @@ -2853,6 +2892,7 @@ def _refold_parse_tree(parse_tree, *, policy): # XXX what if encoded_part has no leading FWS? lines.append(newline) lines[-1] += encoded_part + last_word_is_ew = False continue # Either this is not a major syntactic break, so we don't # want it on a line by itself even if it fits, or it @@ -2871,11 +2911,16 @@ def _refold_parse_tree(parse_tree, *, policy): (last_charset == 'unknown-8bit' or last_charset == 'utf-8' and charset != 'us-ascii')): last_ew = None - last_ew = _fold_as_ew(tstr, lines, maxlen, last_ew, - part.ew_combine_allowed, charset, leading_whitespace) - # This whitespace has been added to the lines in _fold_as_ew() - # so clear it now. 
- leading_whitespace = '' + last_ew = _fold_as_ew( + tstr, + lines, + maxlen, + last_ew, + part.ew_combine_allowed, + charset, + last_word_is_ew, + ) + last_word_is_ew = True last_charset = charset want_encoding = False continue @@ -2888,28 +2933,19 @@ def _refold_parse_tree(parse_tree, *, policy): if len(tstr) <= maxlen - len(lines[-1]): lines[-1] += tstr + last_word_is_ew = last_word_is_ew and not bool(tstr.strip(_WSP)) continue # This part is too long to fit. The RFC wants us to break at # "major syntactic breaks", so unless we don't consider this # to be one, check if it will fit on the next line by itself. - leading_whitespace = '' if (part.syntactic_break and len(tstr) + 1 <= maxlen): newline = _steal_trailing_WSP_if_exists(lines) if newline or part.startswith_fws(): - # We're going to fold the data onto a new line here. Due to - # the way encoded strings handle continuation lines, we need to - # be prepared to encode any whitespace if the next line turns - # out to start with an encoded word. lines.append(newline + tstr) - - whitespace_accumulator = [] - for char in lines[-1]: - if char not in WSP: - break - whitespace_accumulator.append(char) - leading_whitespace = ''.join(whitespace_accumulator) + last_word_is_ew = (last_word_is_ew + and not bool(lines[-1].strip(_WSP))) last_ew = None continue if not hasattr(part, 'encode'): @@ -2924,6 +2960,13 @@ def _refold_parse_tree(parse_tree, *, policy): [ValueTerminal(make_quoted_pairs(p), 'ptext') for p in newparts] + [ValueTerminal('"', 'ptext')]) + if part.token_type == 'comment': + newparts = ( + [ValueTerminal('(', 'ptext')] + + [ValueTerminal(make_parenthesis_pairs(p), 'ptext') + if p.token_type == 'ptext' else p + for p in newparts] + + [ValueTerminal(')', 'ptext')]) if not part.as_ew_allowed: wrap_as_ew_blocked += 1 newparts.append(end_ew_not_allowed) @@ -2942,10 +2985,11 @@ def _refold_parse_tree(parse_tree, *, policy): else: # We can't fold it onto the next line either... 
lines[-1] += tstr + last_word_is_ew = last_word_is_ew and not bool(tstr.strip(_WSP)) return policy.linesep.join(lines) + policy.linesep -def _fold_as_ew(to_encode, lines, maxlen, last_ew, ew_combine_allowed, charset, leading_whitespace): +def _fold_as_ew(to_encode, lines, maxlen, last_ew, ew_combine_allowed, charset, last_word_is_ew): """Fold string to_encode into lines as encoded word, combining if allowed. Return the new value for last_ew, or None if ew_combine_allowed is False. @@ -2960,6 +3004,16 @@ def _fold_as_ew(to_encode, lines, maxlen, last_ew, ew_combine_allowed, charset, to_encode = str( get_unstructured(lines[-1][last_ew:] + to_encode)) lines[-1] = lines[-1][:last_ew] + elif last_word_is_ew: + # If we are following up an encoded word with another encoded word, + # any white space between the two will be ignored when decoded. + # Therefore, we encode all to-be-displayed whitespace in the second + # encoded word. + len_without_wsp = len(lines[-1].rstrip(_WSP)) + leading_whitespace = lines[-1][len_without_wsp:] + lines[-1] = (lines[-1][:len_without_wsp] + + (' ' if leading_whitespace else '')) + to_encode = leading_whitespace + to_encode elif to_encode[0] in WSP: # We're joining this to non-encoded text, so don't encode # the leading blank. @@ -2988,20 +3042,13 @@ def _fold_as_ew(to_encode, lines, maxlen, last_ew, ew_combine_allowed, charset, while to_encode: remaining_space = maxlen - len(lines[-1]) - text_space = remaining_space - chrome_len - len(leading_whitespace) + text_space = remaining_space - chrome_len if text_space <= 0: - lines.append(' ') + newline = _steal_trailing_WSP_if_exists(lines) + lines.append(newline or ' ') + new_last_ew = len(lines[-1]) continue - # If we are at the start of a continuation line, prepend whitespace - # (we only want to do this when the line starts with an encoded word - # but if we're folding in this helper function, then we know that we - # are going to be writing out an encoded word.) 
- if len(lines) > 1 and len(lines[-1]) == 1 and leading_whitespace: - encoded_word = _ew.encode(leading_whitespace, charset=encode_as) - lines[-1] += encoded_word - leading_whitespace = '' - to_encode_word = to_encode[:text_space] encoded_word = _ew.encode(to_encode_word, charset=encode_as) excess = len(encoded_word) - remaining_space @@ -3013,7 +3060,6 @@ def _fold_as_ew(to_encode, lines, maxlen, last_ew, ew_combine_allowed, charset, excess = len(encoded_word) - remaining_space lines[-1] += encoded_word to_encode = to_encode[len(to_encode_word):] - leading_whitespace = '' if to_encode: lines.append(' ') diff --git a/Lib/email/_parseaddr.py b/Lib/email/_parseaddr.py index 565af0cf361..6a7c5fa06d2 100644 --- a/Lib/email/_parseaddr.py +++ b/Lib/email/_parseaddr.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2007 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # Contact: email-sig@python.org """Email address parsing code. @@ -225,7 +225,7 @@ class AddrlistClass: def __init__(self, field): """Initialize a new instance. - `field' is an unparsed address header field, containing + 'field' is an unparsed address header field, containing one or more addresses. """ self.specials = '()<>@,:;.\"[]' @@ -426,14 +426,14 @@ def getdomain(self): def getdelimited(self, beginchar, endchars, allowcomments=True): """Parse a header fragment delimited by special characters. - `beginchar' is the start character for the fragment. - If self is not looking at an instance of `beginchar' then + 'beginchar' is the start character for the fragment. + If self is not looking at an instance of 'beginchar' then getdelimited returns the empty string. - `endchars' is a sequence of allowable end-delimiting characters. + 'endchars' is a sequence of allowable end-delimiting characters. Parsing stops when one of these is encountered. 
- If `allowcomments' is non-zero, embedded RFC 2822 comments are allowed + If 'allowcomments' is non-zero, embedded RFC 2822 comments are allowed within the parsed fragment. """ if self.field[self.pos] != beginchar: @@ -477,7 +477,7 @@ def getatom(self, atomends=None): Optional atomends specifies a different set of end token delimiters (the default is to use self.atomends). This is used e.g. in - getphraselist() since phrase endings must not include the `.' (which + getphraselist() since phrase endings must not include the '.' (which is legal in phrases).""" atomlist = [''] if atomends is None: diff --git a/Lib/email/_policybase.py b/Lib/email/_policybase.py index 0d486c90a9c..e23843df448 100644 --- a/Lib/email/_policybase.py +++ b/Lib/email/_policybase.py @@ -4,6 +4,7 @@ """ import abc +import re from email import header from email import charset as _charset from email.utils import _has_surrogates @@ -14,6 +15,14 @@ 'compat32', ] +# validation regex from RFC 5322, equivalent to pattern re.compile("[!-9;-~]+$") +valid_header_name_re = re.compile("[\041-\071\073-\176]+$") + +def validate_header_name(name): + # Validate header name according to RFC 5322 + if not valid_header_name_re.match(name): + raise ValueError( + f"Header field name contains invalid characters: {name!r}") class _PolicyBase: @@ -150,7 +159,7 @@ class Policy(_PolicyBase, metaclass=abc.ABCMeta): wrapping is done. Default is 78. mangle_from_ -- a flag that, when True escapes From_ lines in the - body of the message by putting a `>' in front of + body of the message by putting a '>' in front of them. This is used when the message is being serialized by a generator. Default: False. @@ -314,6 +323,7 @@ def header_store_parse(self, name, value): """+ The name and value are returned unmodified. 
""" + validate_header_name(name) return (name, value) def header_fetch_parse(self, name, value): diff --git a/Lib/email/base64mime.py b/Lib/email/base64mime.py index 4cdf22666e3..a5a3f737a97 100644 --- a/Lib/email/base64mime.py +++ b/Lib/email/base64mime.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2007 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # Author: Ben Gertzfield # Contact: email-sig@python.org @@ -15,7 +15,7 @@ with Base64 encoding. RFC 2045 defines a method for including character set information in an -`encoded-word' in a header. This method is commonly used for 8-bit real names +'encoded-word' in a header. This method is commonly used for 8-bit real names in To:, From:, Cc:, etc. fields, as well as Subject: lines. This module does not do the line wrapping or end-of-line character conversion diff --git a/Lib/email/charset.py b/Lib/email/charset.py index 043801107b6..5036c3f58a5 100644 --- a/Lib/email/charset.py +++ b/Lib/email/charset.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2007 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Ben Gertzfield, Barry Warsaw # Contact: email-sig@python.org @@ -175,7 +175,7 @@ class Charset: module expose the following information about a character set: input_charset: The initial character set specified. Common aliases - are converted to their `official' email names (e.g. latin_1 + are converted to their 'official' email names (e.g. latin_1 is converted to iso-8859-1). Defaults to 7-bit us-ascii. header_encoding: If the character set must be encoded before it can be @@ -245,7 +245,7 @@ def __eq__(self, other): def get_body_encoding(self): """Return the content-transfer-encoding used for body encoding. 
- This is either the string `quoted-printable' or `base64' depending on + This is either the string 'quoted-printable' or 'base64' depending on the encoding used, or it is a function in which case you should call the function with a single argument, the Message object being encoded. The function should then set the Content-Transfer-Encoding diff --git a/Lib/email/encoders.py b/Lib/email/encoders.py index 17bd1ab7b19..55741a22a07 100644 --- a/Lib/email/encoders.py +++ b/Lib/email/encoders.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/errors.py b/Lib/email/errors.py index 02aa5eced6a..6bc744bd59c 100644 --- a/Lib/email/errors.py +++ b/Lib/email/errors.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/feedparser.py b/Lib/email/feedparser.py index bc773f38030..ae8ef32792b 100644 --- a/Lib/email/feedparser.py +++ b/Lib/email/feedparser.py @@ -1,4 +1,4 @@ -# Copyright (C) 2004-2006 Python Software Foundation +# Copyright (C) 2004 Python Software Foundation # Authors: Baxter, Wouters and Warsaw # Contact: email-sig@python.org @@ -30,7 +30,7 @@ NLCRE = re.compile(r'\r\n|\r|\n') NLCRE_bol = re.compile(r'(\r\n|\r|\n)') -NLCRE_eol = re.compile(r'(\r\n|\r|\n)\Z') +NLCRE_eol = re.compile(r'(\r\n|\r|\n)\z') NLCRE_crack = re.compile(r'(\r\n|\r|\n)') # RFC 5322 section 3.6.8 Optional fields. ftext is %d33-57 / %d59-126, Any character # except controls, SP, and ":". @@ -504,10 +504,9 @@ def _parse_headers(self, lines): self._input.unreadline(line) return else: - # Weirdly placed unix-from line. Note this as a defect - # and ignore it. + # Weirdly placed unix-from line. 
defect = errors.MisplacedEnvelopeHeaderDefect(line) - self._cur.defects.append(defect) + self.policy.handle_defect(self._cur, defect) continue # Split the line on the colon separating field name from value. # There will always be a colon, because if there wasn't the part of @@ -519,7 +518,7 @@ def _parse_headers(self, lines): # message. Track the error but keep going. if i == 0: defect = errors.InvalidHeaderDefect("Missing header name.") - self._cur.defects.append(defect) + self.policy.handle_defect(self._cur, defect) continue assert i>0, "_parse_headers fed line with no : and no leading WS" diff --git a/Lib/email/generator.py b/Lib/email/generator.py index ce94f5c56fe..ba11d63fba6 100644 --- a/Lib/email/generator.py +++ b/Lib/email/generator.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2010 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org @@ -22,6 +22,7 @@ NLCRE = re.compile(r'\r\n|\r|\n') fcre = re.compile(r'^From ', re.MULTILINE) NEWLINE_WITHOUT_FWSP = re.compile(r'\r\n[^ \t]|\r[^ \n\t]|\n[^ \t]') +NEWLINE_WITHOUT_FWSP_BYTES = re.compile(br'\r\n[^ \t]|\r[^ \n\t]|\n[^ \t]') class Generator: @@ -43,7 +44,7 @@ def __init__(self, outfp, mangle_from_=None, maxheaderlen=None, *, Optional mangle_from_ is a flag that, when True (the default if policy is not set), escapes From_ lines in the body of the message by putting - a `>' in front of them. + a '>' in front of them. Optional maxheaderlen specifies the longest length for a non-continued header. When a header line is longer (in characters, with tabs @@ -76,7 +77,7 @@ def flatten(self, msg, unixfrom=False, linesep=None): unixfrom is a flag that forces the printing of a Unix From_ delimiter before the first object in the message tree. If the original message - has no From_ delimiter, a `standard' one is crafted. By default, this + has no From_ delimiter, a 'standard' one is crafted. 
By default, this is False to inhibit the printing of any From_ delimiter. Note that for subobjects, no From_ line is printed. @@ -227,7 +228,7 @@ def _write_headers(self, msg): folded = self.policy.fold(h, v) if self.policy.verify_generated_headers: linesep = self.policy.linesep - if not folded.endswith(self.policy.linesep): + if not folded.endswith(linesep): raise HeaderWriteError( f'folded header does not end with {linesep!r}: {folded!r}') if NEWLINE_WITHOUT_FWSP.search(folded.removesuffix(linesep)): @@ -391,7 +392,7 @@ def _make_boundary(cls, text=None): b = boundary counter = 0 while True: - cre = cls._compile_re('^--' + re.escape(b) + '(--)?$', re.MULTILINE) + cre = cls._compile_re('^--' + re.escape(b) + '(--)?\r?$', re.MULTILINE) if not cre.search(text): break b = boundary + '.' + str(counter) @@ -429,7 +430,16 @@ def _write_headers(self, msg): # This is almost the same as the string version, except for handling # strings with 8bit bytes. for h, v in msg.raw_items(): - self._fp.write(self.policy.fold_binary(h, v)) + folded = self.policy.fold_binary(h, v) + if self.policy.verify_generated_headers: + linesep = self.policy.linesep.encode() + if not folded.endswith(linesep): + raise HeaderWriteError( + f'folded header does not end with {linesep!r}: {folded!r}') + if NEWLINE_WITHOUT_FWSP_BYTES.search(folded.removesuffix(linesep)): + raise HeaderWriteError( + f'folded header contains newline: {folded!r}') + self._fp.write(folded) # A blank line always separates headers from body self.write(self._NL) @@ -467,7 +477,7 @@ def __init__(self, outfp, mangle_from_=None, maxheaderlen=None, fmt=None, *, argument is allowed. Walks through all subparts of a message. If the subpart is of main - type `text', then it prints the decoded payload of the subpart. + type 'text', then it prints the decoded payload of the subpart. Otherwise, fmt is a format string that is used instead of the message payload. 
fmt is expanded with the following keywords (in diff --git a/Lib/email/header.py b/Lib/email/header.py index a0aadb97ca6..220a84a7454 100644 --- a/Lib/email/header.py +++ b/Lib/email/header.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2007 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # Author: Ben Gertzfield, Barry Warsaw # Contact: email-sig@python.org @@ -201,7 +201,7 @@ def __init__(self, s=None, charset=None, The maximum line length can be specified explicitly via maxlinelen. For splitting the first line to a shorter value (to account for the field - header which isn't included in s, e.g. `Subject') pass in the name of + header which isn't included in s, e.g. 'Subject') pass in the name of the field in header_name. The default maxlinelen is 78 as recommended by RFC 2822. @@ -285,7 +285,7 @@ def append(self, s, charset=None, errors='strict'): output codec of the charset. If the string cannot be encoded to the output codec, a UnicodeError will be raised. - Optional `errors' is passed as the errors argument to the decode + Optional 'errors' is passed as the errors argument to the decode call if s is a byte string. """ if charset is None: @@ -335,7 +335,7 @@ def encode(self, splitchars=';, \t', maxlinelen=None, linesep='\n'): Optional splitchars is a string containing characters which should be given extra weight by the splitting algorithm during normal header - wrapping. This is in very rough support of RFC 2822's `higher level + wrapping. This is in very rough support of RFC 2822's 'higher level syntactic breaks': split points preceded by a splitchar are preferred during line splitting, with the characters preferred in the order in which they appear in the string. 
Space and tab may be included in the diff --git a/Lib/email/headerregistry.py b/Lib/email/headerregistry.py index 543141dc427..0e8698efc0b 100644 --- a/Lib/email/headerregistry.py +++ b/Lib/email/headerregistry.py @@ -534,6 +534,18 @@ def parse(cls, value, kwds): kwds['defects'].extend(parse_tree.all_defects) +class ReferencesHeader: + + max_count = 1 + value_parser = staticmethod(parser.parse_message_ids) + + @classmethod + def parse(cls, value, kwds): + kwds['parse_tree'] = parse_tree = cls.value_parser(value) + kwds['decoded'] = str(parse_tree) + kwds['defects'].extend(parse_tree.all_defects) + + # The header factory # _default_header_map = { @@ -557,6 +569,8 @@ def parse(cls, value, kwds): 'content-disposition': ContentDispositionHeader, 'content-transfer-encoding': ContentTransferEncodingHeader, 'message-id': MessageIDHeader, + 'in-reply-to': ReferencesHeader, + 'references': ReferencesHeader, } class HeaderRegistry: diff --git a/Lib/email/iterators.py b/Lib/email/iterators.py index 3410935e38f..08ede3ec679 100644 --- a/Lib/email/iterators.py +++ b/Lib/email/iterators.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org @@ -43,8 +43,8 @@ def body_line_iterator(msg, decode=False): def typed_subpart_iterator(msg, maintype='text', subtype=None): """Iterate over the subparts with a given MIME type. - Use `maintype' as the main MIME type to match against; this defaults to - "text". Optional `subtype' is the MIME subtype to match against; if + Use 'maintype' as the main MIME type to match against; this defaults to + "text". Optional 'subtype' is the MIME subtype to match against; if omitted, only the main type is matched. 
""" for subpart in msg.walk(): diff --git a/Lib/email/message.py b/Lib/email/message.py index 80f01d66a33..641fb2e944d 100644 --- a/Lib/email/message.py +++ b/Lib/email/message.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2007 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org @@ -21,7 +21,7 @@ SEMISPACE = '; ' -# Regular expression that matches `special' characters in parameters, the +# Regular expression that matches 'special' characters in parameters, the # existence of which force quoting of the parameter value. tspecials = re.compile(r'[ \(\)<>@,;:\\"/\[\]\?=]') @@ -147,7 +147,7 @@ class Message: multipart or a message/rfc822), then the payload is a list of Message objects, otherwise it is a string. - Message objects implement part of the `mapping' interface, which assumes + Message objects implement part of the 'mapping' interface, which assumes there is exactly one occurrence of the header per message. Some headers do in fact appear multiple times (e.g. Received) and for those headers, you must use the explicit API to set or get all the headers. Not all of @@ -609,7 +609,7 @@ def get_content_type(self): """Return the message's content type. The returned string is coerced to lower case of the form - `maintype/subtype'. If there was no Content-Type header in the + 'maintype/subtype'. If there was no Content-Type header in the message, the default type as given by get_default_type() will be returned. Since according to RFC 2045, messages always have a default type this will always return a value. @@ -632,7 +632,7 @@ def get_content_type(self): def get_content_maintype(self): """Return the message's main content type. - This is the `maintype' part of the string returned by + This is the 'maintype' part of the string returned by get_content_type(). 
""" ctype = self.get_content_type() @@ -641,14 +641,14 @@ def get_content_maintype(self): def get_content_subtype(self): """Returns the message's sub-content type. - This is the `subtype' part of the string returned by + This is the 'subtype' part of the string returned by get_content_type(). """ ctype = self.get_content_type() return ctype.split('/')[1] def get_default_type(self): - """Return the `default' content type. + """Return the 'default' content type. Most messages have a default content type of text/plain, except for messages that are subparts of multipart/digest containers. Such @@ -657,7 +657,7 @@ def get_default_type(self): return self._default_type def set_default_type(self, ctype): - """Set the `default' content type. + """Set the 'default' content type. ctype should be either "text/plain" or "message/rfc822", although this is not enforced. The default content type is not stored in the @@ -690,8 +690,8 @@ def get_params(self, failobj=None, header='content-type', unquote=True): """Return the message's Content-Type parameters, as a list. The elements of the returned list are 2-tuples of key/value pairs, as - split on the `=' sign. The left hand side of the `=' is the key, - while the right hand side is the value. If there is no `=' sign in + split on the '=' sign. The left hand side of the '=' is the key, + while the right hand side is the value. If there is no '=' sign in the parameter the value is the empty string. The value is as described in the get_param() method. @@ -851,9 +851,9 @@ def get_filename(self, failobj=None): """Return the filename associated with the payload if present. The filename is extracted from the Content-Disposition header's - `filename' parameter, and it is unquoted. If that header is missing - the `filename' parameter, this method falls back to looking for the - `name' parameter. + 'filename' parameter, and it is unquoted. 
If that header is missing + the 'filename' parameter, this method falls back to looking for the + 'name' parameter. """ missing = object() filename = self.get_param('filename', missing, 'content-disposition') @@ -866,7 +866,7 @@ def get_filename(self, failobj=None): def get_boundary(self, failobj=None): """Return the boundary associated with the payload if present. - The boundary is extracted from the Content-Type header's `boundary' + The boundary is extracted from the Content-Type header's 'boundary' parameter, and it is unquoted. """ missing = object() diff --git a/Lib/email/mime/application.py b/Lib/email/mime/application.py index f67cbad3f03..9a9d213d2a9 100644 --- a/Lib/email/mime/application.py +++ b/Lib/email/mime/application.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Keith Dart # Contact: email-sig@python.org diff --git a/Lib/email/mime/audio.py b/Lib/email/mime/audio.py index aa0c4905cbb..85f4a955238 100644 --- a/Lib/email/mime/audio.py +++ b/Lib/email/mime/audio.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2007 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Anthony Baxter # Contact: email-sig@python.org diff --git a/Lib/email/mime/base.py b/Lib/email/mime/base.py index f601f621cec..da4c6e591a5 100644 --- a/Lib/email/mime/base.py +++ b/Lib/email/mime/base.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/mime/image.py b/Lib/email/mime/image.py index 4b7f2f9cbad..dab96858481 100644 --- a/Lib/email/mime/image.py +++ b/Lib/email/mime/image.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/mime/message.py b/Lib/email/mime/message.py index 
61836b5a786..13d9ff599f8 100644 --- a/Lib/email/mime/message.py +++ b/Lib/email/mime/message.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/mime/multipart.py b/Lib/email/mime/multipart.py index 94d81c771a4..1abb84d5fed 100644 --- a/Lib/email/mime/multipart.py +++ b/Lib/email/mime/multipart.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2006 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org @@ -21,7 +21,7 @@ def __init__(self, _subtype='mixed', boundary=None, _subparts=None, Content-Type and MIME-Version headers. _subtype is the subtype of the multipart content type, defaulting to - `mixed'. + 'mixed'. boundary is the multipart boundary string. By default it is calculated as needed. diff --git a/Lib/email/mime/nonmultipart.py b/Lib/email/mime/nonmultipart.py index a41386eb148..5beab3a441e 100644 --- a/Lib/email/mime/nonmultipart.py +++ b/Lib/email/mime/nonmultipart.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2006 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/mime/text.py b/Lib/email/mime/text.py index 7672b789138..aa4da7f8217 100644 --- a/Lib/email/mime/text.py +++ b/Lib/email/mime/text.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org diff --git a/Lib/email/parser.py b/Lib/email/parser.py index e3003118ce1..c6a51dd8e37 100644 --- a/Lib/email/parser.py +++ b/Lib/email/parser.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2007 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw, Thomas Wouters, Anthony Baxter # Contact: email-sig@python.org diff --git a/Lib/email/policy.py 
b/Lib/email/policy.py index 6e109b65011..4169150101a 100644 --- a/Lib/email/policy.py +++ b/Lib/email/policy.py @@ -4,7 +4,13 @@ import re import sys -from email._policybase import Policy, Compat32, compat32, _extend_docstrings +from email._policybase import ( + Compat32, + Policy, + _extend_docstrings, + compat32, + validate_header_name +) from email.utils import _has_surrogates from email.headerregistry import HeaderRegistry as HeaderRegistry from email.contentmanager import raw_data_manager @@ -138,6 +144,7 @@ def header_store_parse(self, name, value): CR or LF characters. """ + validate_header_name(name) if hasattr(value, 'name') and value.name.lower() == name.lower(): return (name, value) if isinstance(value, str) and len(value.splitlines())>1: diff --git a/Lib/email/quoprimime.py b/Lib/email/quoprimime.py index 27fcbb5a26e..bc53b376821 100644 --- a/Lib/email/quoprimime.py +++ b/Lib/email/quoprimime.py @@ -1,11 +1,11 @@ -# Copyright (C) 2001-2006 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Ben Gertzfield # Contact: email-sig@python.org """Quoted-printable content transfer encoding per RFCs 2045-2047. This module handles the content transfer encoding method defined in RFC 2045 -to encode US ASCII-like 8-bit data called `quoted-printable'. It is used to +to encode US ASCII-like 8-bit data called 'quoted-printable'. It is used to safely encode text that is in a character set similar to the 7-bit US ASCII character set, but that includes some 8-bit characters that are normally not allowed in email bodies or headers. @@ -17,7 +17,7 @@ with quoted-printable encoding. RFC 2045 defines a method for including character set information in an -`encoded-word' in a header. This method is commonly used for 8-bit real names +'encoded-word' in a header. This method is commonly used for 8-bit real names in To:/From:/Cc: etc. fields, as well as Subject: lines. 
This module does not do the line wrapping or end-of-line character @@ -127,7 +127,7 @@ def quote(c): def header_encode(header_bytes, charset='iso-8859-1'): """Encode a single header line with quoted-printable (like) encoding. - Defined in RFC 2045, this `Q' encoding is similar to quoted-printable, but + Defined in RFC 2045, this 'Q' encoding is similar to quoted-printable, but used specifically for email header fields to allow charsets with mostly 7 bit characters (and some 8 bit) to remain more or less readable in non-RFC 2045 aware mail clients. @@ -272,7 +272,7 @@ def decode(encoded, eol=NL): decoded += eol # Special case if original string did not end with eol if encoded[-1] not in '\r\n' and decoded.endswith(eol): - decoded = decoded[:-1] + decoded = decoded[:-len(eol)] return decoded @@ -290,7 +290,7 @@ def _unquote_match(match): # Header decoding is done a bit differently def header_decode(s): - """Decode a string encoded with RFC 2045 MIME header `Q' encoding. + """Decode a string encoded with RFC 2045 MIME header 'Q' encoding. This function does not parse a full MIME header value encoded with quoted-printable (like =?iso-8859-1?q?Hello_World?=) -- please use diff --git a/Lib/email/utils.py b/Lib/email/utils.py index e4d35f06abc..3de1f0d24a1 100644 --- a/Lib/email/utils.py +++ b/Lib/email/utils.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2010 Python Software Foundation +# Copyright (C) 2001 Python Software Foundation # Author: Barry Warsaw # Contact: email-sig@python.org @@ -472,23 +472,15 @@ def collapse_rfc2231_value(value, errors='replace', # better than not having it. # -def localtime(dt=None, isdst=None): +def localtime(dt=None): """Return local time as an aware datetime object. If called without arguments, return current time. Otherwise *dt* argument should be a datetime instance, and it is converted to the local time zone according to the system time zone database. If *dt* is naive (that is, dt.tzinfo is None), it is assumed to be in local time. 
- The isdst parameter is ignored. """ - if isdst is not None: - import warnings - warnings._deprecated( - "The 'isdst' parameter to 'localtime'", - message='{name} is deprecated and slated for removal in Python {remove}', - remove=(3, 14), - ) if dt is None: dt = datetime.datetime.now() return dt.astimezone() diff --git a/Lib/ensurepip/__init__.py b/Lib/ensurepip/__init__.py index 715389ea6c5..a8040457abf 100644 --- a/Lib/ensurepip/__init__.py +++ b/Lib/ensurepip/__init__.py @@ -10,7 +10,7 @@ __all__ = ["version", "bootstrap"] -_PIP_VERSION = "26.0.1" +_PIP_VERSION = "26.1.1" # Directory of system wheel packages. Some Linux distribution packaging # policies recommend against bundling dependencies. For example, Fedora diff --git a/Lib/ensurepip/_bundled/pip-26.0.1-py3-none-any.whl b/Lib/ensurepip/_bundled/pip-26.1.1-py3-none-any.whl similarity index 73% rename from Lib/ensurepip/_bundled/pip-26.0.1-py3-none-any.whl rename to Lib/ensurepip/_bundled/pip-26.1.1-py3-none-any.whl index 580d09a9204..ab0307c7716 100644 Binary files a/Lib/ensurepip/_bundled/pip-26.0.1-py3-none-any.whl and b/Lib/ensurepip/_bundled/pip-26.1.1-py3-none-any.whl differ diff --git a/Lib/http/client.py b/Lib/http/client.py index 77f8d26291d..6fb7d254ea9 100644 --- a/Lib/http/client.py +++ b/Lib/http/client.py @@ -972,13 +972,22 @@ def _wrap_ipv6(self, ip): return ip def _tunnel(self): + if _contains_disallowed_url_pchar_re.search(self._tunnel_host): + raise ValueError('Tunnel host can\'t contain control characters %r' + % (self._tunnel_host,)) connect = b"CONNECT %s:%d %s\r\n" % ( self._wrap_ipv6(self._tunnel_host.encode("idna")), self._tunnel_port, self._http_vsn_str.encode("ascii")) headers = [connect] for header, value in self._tunnel_headers.items(): - headers.append(f"{header}: {value}\r\n".encode("latin-1")) + header_bytes = header.encode("latin-1") + value_bytes = value.encode("latin-1") + if not _is_legal_header_name(header_bytes): + raise ValueError('Invalid header name %r' % 
(header_bytes,)) + if _is_illegal_header_value(value_bytes): + raise ValueError('Invalid header value %r' % (value_bytes,)) + headers.append(b"%s: %s\r\n" % (header_bytes, value_bytes)) headers.append(b"\r\n") # Making a single send() call instead of one per line encourages # the host OS to use a more optimal packet size instead of diff --git a/Lib/http/cookies.py b/Lib/http/cookies.py index d5b8ba939be..5c5b14788dc 100644 --- a/Lib/http/cookies.py +++ b/Lib/http/cookies.py @@ -391,17 +391,21 @@ def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.OutputString()) def js_output(self, attrs=None): + import base64 # Print javascript output_string = self.OutputString(attrs) if _has_control_character(output_string): raise CookieError("Control characters are not allowed in cookies") + # Base64-encode value to avoid template + # injection in cookie values. + output_encoded = base64.b64encode(output_string.encode('utf-8')).decode("ascii") return """ - """ % (output_string.replace('"', r'\"')) + """ % (output_encoded,) def OutputString(self, attrs=None): # Build up our result diff --git a/Lib/pickle.py b/Lib/pickle.py index beaefae0479..7b951858604 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -904,17 +904,11 @@ def save_picklebuffer(self, obj): # Write data in-band # XXX The C implementation avoids a copy here buf = m.tobytes() - in_memo = id(buf) in self.memo if m.readonly: - if in_memo: - self._save_bytes_no_memo(buf) - else: - self.save_bytes(buf) + self._save_bytes_no_memo(buf) else: - if in_memo: - self._save_bytearray_no_memo(buf) - else: - self.save_bytearray(buf) + self._save_bytearray_no_memo(buf) + self.memoize(obj) else: # Write data out-of-band self.write(NEXT_BUFFER) diff --git a/Lib/random.py b/Lib/random.py index 86d562f0b8a..726a71e7828 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -836,7 +836,11 @@ def binomialvariate(self, n=1, p=0.5): if not c: return x while True: - y += _floor(_log2(random()) / c) + 1 + try: + y += 
_floor(_log2(random()) / c) + 1 + except ValueError: + # Reject case where random() returned 0.0 + continue if y > n: return x x += 1 @@ -844,8 +848,8 @@ def binomialvariate(self, n=1, p=0.5): # BTRS: Transformed rejection with squeeze method by Wolfgang Hörmann # https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.47.8407&rep=rep1&type=pdf assert n*p >= 10.0 and p <= 0.5 - setup_complete = False + setup_complete = False spq = _sqrt(n * p * (1.0 - p)) # Standard deviation of the distribution b = 1.15 + 2.53 * spq a = -0.0873 + 0.0248 * b + 0.01 * p @@ -860,22 +864,23 @@ def binomialvariate(self, n=1, p=0.5): k = _floor((2.0 * a / us + b) * u + c) if k < 0 or k > n: continue + v = random() # The early-out "squeeze" test substantially reduces # the number of acceptance condition evaluations. - v = random() if us >= 0.07 and v <= vr: return k - # Acceptance-rejection test. - # Note, the original paper erroneously omits the call to log(v) - # when comparing to the log of the rescaled binomial distribution. if not setup_complete: alpha = (2.83 + 5.1 / b) * spq lpq = _log(p / (1.0 - p)) m = _floor((n + 1) * p) # Mode of the distribution h = _lgamma(m + 1) + _lgamma(n - m + 1) setup_complete = True # Only needs to be done once + + # Acceptance-rejection test. + # Note, the original paper erroneously omits the call to log(v) + # when comparing to the log of the rescaled binomial distribution. v *= alpha / (a / (us * us) + b) if _log(v) <= h - _lgamma(k + 1) - _lgamma(n - k + 1) + (k - m) * lpq: return k diff --git a/Lib/shutil.py b/Lib/shutil.py index 8d8fe145567..b7608f7edfc 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -1314,27 +1314,9 @@ def _unpack_zipfile(filename, extract_dir): if not zipfile.is_zipfile(filename): raise ReadError("%s is not a zip file" % filename) - zip = zipfile.ZipFile(filename) - try: - for info in zip.infolist(): - name = info.filename - - # don't extract absolute paths or ones with .. in them - if name.startswith('/') or '..' 
in name: - continue - - targetpath = os.path.join(extract_dir, *name.split('/')) - if not targetpath: - continue - - _ensure_directory(targetpath) - if not name.endswith('/'): - # file - with zip.open(name, 'r') as source, \ - open(targetpath, 'wb') as target: - copyfileobj(source, target) - finally: - zip.close() + with zipfile.ZipFile(filename) as zip: + zip._ignore_invalid_names = True + zip.extractall(extract_dir) def _unpack_tarfile(filename, extract_dir, *, filter=None): """Unpack tar/tar.gz/tar.bz2/tar.xz/tar.zst `filename` to `extract_dir` diff --git a/Lib/tarfile.py b/Lib/tarfile.py index c7e9f7d681a..414aefe9744 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -1278,6 +1278,20 @@ def _create_pax_generic_header(cls, pax_headers, type, encoding): @classmethod def frombuf(cls, buf, encoding, errors): """Construct a TarInfo object from a 512 byte bytes object. + + To support the old v7 tar format AREGTYPE headers are + transformed to DIRTYPE headers if their name ends in '/'. + """ + return cls._frombuf(buf, encoding, errors) + + @classmethod + def _frombuf(cls, buf, encoding, errors, *, dircheck=True): + """Construct a TarInfo object from a 512 byte bytes object. + + If ``dircheck`` is set to ``True`` then ``AREGTYPE`` headers will + be normalized to ``DIRTYPE`` if the name ends in a trailing slash. + ``dircheck`` must be set to ``False`` if this function is called + on a follow-up header such as ``GNUTYPE_LONGNAME``. """ if len(buf) == 0: raise EmptyHeaderError("empty header") @@ -1308,7 +1322,7 @@ def frombuf(cls, buf, encoding, errors): # Old V7 tar format represents a directory as a regular # file with a trailing slash. - if obj.type == AREGTYPE and obj.name.endswith("/"): + if dircheck and obj.type == AREGTYPE and obj.name.endswith("/"): obj.type = DIRTYPE # The old GNU sparse format occupies some of the unused @@ -1343,8 +1357,15 @@ def fromtarfile(cls, tarfile): """Return the next TarInfo object from TarFile object tarfile. 
""" + return cls._fromtarfile(tarfile) + + @classmethod + def _fromtarfile(cls, tarfile, *, dircheck=True): + """ + See dircheck documentation in _frombuf(). + """ buf = tarfile.fileobj.read(BLOCKSIZE) - obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors) + obj = cls._frombuf(buf, tarfile.encoding, tarfile.errors, dircheck=dircheck) obj.offset = tarfile.fileobj.tell() - BLOCKSIZE return obj._proc_member(tarfile) @@ -1402,7 +1423,7 @@ def _proc_gnulong(self, tarfile): # Fetch the next header and process it. try: - next = self.fromtarfile(tarfile) + next = self._fromtarfile(tarfile, dircheck=False) except HeaderError as e: raise SubsequentHeaderError(str(e)) from None @@ -1537,7 +1558,7 @@ def _proc_pax(self, tarfile): # Fetch the next header. try: - next = self.fromtarfile(tarfile) + next = self._fromtarfile(tarfile, dircheck=False) except HeaderError as e: raise SubsequentHeaderError(str(e)) from None diff --git a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py new file mode 100644 index 00000000000..6884ac0dbe6 --- /dev/null +++ b/Lib/test/audit-tests.py @@ -0,0 +1,681 @@ +"""This script contains the actual auditing tests. + +It should not be imported directly, but should be run by the test_audit +module with arguments identifying each test. + +""" + +import contextlib +import os +import sys + + +class TestHook: + """Used in standard hook tests to collect any logged events. + + Should be used in a with block to ensure that it has no impact + after the test completes. 
+ """ + + def __init__(self, raise_on_events=None, exc_type=RuntimeError): + self.raise_on_events = raise_on_events or () + self.exc_type = exc_type + self.seen = [] + self.closed = False + + def __enter__(self, *a): + sys.addaudithook(self) + return self + + def __exit__(self, *a): + self.close() + + def close(self): + self.closed = True + + @property + def seen_events(self): + return [i[0] for i in self.seen] + + def __call__(self, event, args): + if self.closed: + return + self.seen.append((event, args)) + if event in self.raise_on_events: + raise self.exc_type("saw event " + event) + + +# Simple helpers, since we are not in unittest here +def assertEqual(x, y): + if x != y: + raise AssertionError(f"{x!r} should equal {y!r}") + + +def assertIn(el, series): + if el not in series: + raise AssertionError(f"{el!r} should be in {series!r}") + + +def assertNotIn(el, series): + if el in series: + raise AssertionError(f"{el!r} should not be in {series!r}") + + +def assertSequenceEqual(x, y): + if len(x) != len(y): + raise AssertionError(f"{x!r} should equal {y!r}") + if any(ix != iy for ix, iy in zip(x, y)): + raise AssertionError(f"{x!r} should equal {y!r}") + + +@contextlib.contextmanager +def assertRaises(ex_type): + try: + yield + assert False, f"expected {ex_type}" + except BaseException as ex: + if isinstance(ex, AssertionError): + raise + assert type(ex) is ex_type, f"{ex} should be {ex_type}" + + +def test_basic(): + with TestHook() as hook: + sys.audit("test_event", 1, 2, 3) + assertEqual(hook.seen[0][0], "test_event") + assertEqual(hook.seen[0][1], (1, 2, 3)) + + +def test_block_add_hook(): + # Raising an exception should prevent a new hook from being added, + # but will not propagate out. 
+ with TestHook(raise_on_events="sys.addaudithook") as hook1: + with TestHook() as hook2: + sys.audit("test_event") + assertIn("test_event", hook1.seen_events) + assertNotIn("test_event", hook2.seen_events) + + +def test_block_add_hook_baseexception(): + # Raising BaseException will propagate out when adding a hook + with assertRaises(BaseException): + with TestHook( + raise_on_events="sys.addaudithook", exc_type=BaseException + ) as hook1: + # Adding this next hook should raise BaseException + with TestHook() as hook2: + pass + + +def test_marshal(): + import marshal + o = ("a", "b", "c", 1, 2, 3) + payload = marshal.dumps(o) + + with TestHook() as hook: + assertEqual(o, marshal.loads(marshal.dumps(o))) + + try: + with open("test-marshal.bin", "wb") as f: + marshal.dump(o, f) + with open("test-marshal.bin", "rb") as f: + assertEqual(o, marshal.load(f)) + finally: + os.unlink("test-marshal.bin") + + actual = [(a[0], a[1]) for e, a in hook.seen if e == "marshal.dumps"] + assertSequenceEqual(actual, [(o, marshal.version)] * 2) + + actual = [a[0] for e, a in hook.seen if e == "marshal.loads"] + assertSequenceEqual(actual, [payload]) + + actual = [e for e, a in hook.seen if e == "marshal.load"] + assertSequenceEqual(actual, ["marshal.load"]) + + +def test_pickle(): + import pickle + + class PicklePrint: + def __reduce_ex__(self, p): + return str, ("Pwned!",) + + payload_1 = pickle.dumps(PicklePrint()) + payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3)) + + # Before we add the hook, ensure our malicious pickle loads + assertEqual("Pwned!", pickle.loads(payload_1)) + + with TestHook(raise_on_events="pickle.find_class") as hook: + with assertRaises(RuntimeError): + # With the hook enabled, loading globals is not allowed + pickle.loads(payload_1) + # pickles with no globals are okay + pickle.loads(payload_2) + + +def test_monkeypatch(): + class A: + pass + + class B: + pass + + class C(A): + pass + + a = A() + + with TestHook() as hook: + # Catch name changes + C.__name__ 
= "X" + # Catch type changes + C.__bases__ = (B,) + # Ensure bypassing __setattr__ is still caught + type.__dict__["__bases__"].__set__(C, (B,)) + # Catch attribute replacement + C.__init__ = B.__init__ + # Catch attribute addition + C.new_attr = 123 + # Catch class changes + a.__class__ = B + + actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"] + assertSequenceEqual( + [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual + ) + + +def test_open(testfn): + # SSLContext.load_dh_params uses Py_fopen() rather than normal open() + try: + import ssl + + load_dh_params = ssl.create_default_context().load_dh_params + except ImportError: + load_dh_params = None + + try: + import readline + except ImportError: + readline = None + + def rl(name): + if readline: + return getattr(readline, name, None) + else: + return None + + # Try a range of "open" functions. + # All of them should fail + with TestHook(raise_on_events={"open"}) as hook: + for fn, *args in [ + (open, testfn, "r"), + (open, sys.executable, "rb"), + (open, 3, "wb"), + (open, testfn, "w", -1, None, None, None, False, lambda *a: 1), + (load_dh_params, testfn), + (rl("read_history_file"), testfn), + (rl("read_history_file"), None), + (rl("write_history_file"), testfn), + (rl("write_history_file"), None), + (rl("append_history_file"), 0, testfn), + (rl("append_history_file"), 0, None), + (rl("read_init_file"), testfn), + (rl("read_init_file"), None), + ]: + if not fn: + continue + with assertRaises(RuntimeError): + try: + fn(*args) + except NotImplementedError: + if fn == load_dh_params: + # Not callable in some builds + load_dh_params = None + raise RuntimeError + else: + raise + + actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]] + actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]] + assertSequenceEqual( + [ + i + for i in [ + (testfn, "r"), + (sys.executable, "r"), + (3, "w"), + (testfn, "w"), + (testfn, "rb") 
if load_dh_params else None, + (testfn, "r") if readline else None, + ("~/.history", "r") if readline else None, + (testfn, "w") if readline else None, + ("~/.history", "w") if readline else None, + (testfn, "a") if rl("append_history_file") else None, + ("~/.history", "a") if rl("append_history_file") else None, + (testfn, "r") if readline else None, + ("", "r") if readline else None, + ] + if i is not None + ], + actual_mode, + ) + assertSequenceEqual([], actual_flag) + + +def test_cantrace(): + traced = [] + + def trace(frame, event, *args): + if frame.f_code == TestHook.__call__.__code__: + traced.append(event) + + old = sys.settrace(trace) + try: + with TestHook() as hook: + # No traced call + eval("1") + + # No traced call + hook.__cantrace__ = False + eval("2") + + # One traced call + hook.__cantrace__ = True + eval("3") + + # Two traced calls (writing to private member, eval) + hook.__cantrace__ = 1 + eval("4") + + # One traced call (writing to private member) + hook.__cantrace__ = 0 + finally: + sys.settrace(old) + + assertSequenceEqual(["call"] * 4, traced) + + +def test_mmap(): + import mmap + + with TestHook() as hook: + mmap.mmap(-1, 8) + assertEqual(hook.seen[0][1][:2], (-1, 8)) + + +def test_ctypes_call_function(): + import ctypes + import _ctypes + + with TestHook() as hook: + _ctypes.call_function(ctypes._memmove_addr, (0, 0, 0)) + assert ("ctypes.call_function", (ctypes._memmove_addr, (0, 0, 0))) in hook.seen, f"{ctypes._memmove_addr=} {hook.seen=}" + + ctypes.CFUNCTYPE(ctypes.c_voidp)(ctypes._memset_addr)(1, 0, 0) + assert ("ctypes.call_function", (ctypes._memset_addr, (1, 0, 0))) in hook.seen, f"{ctypes._memset_addr=} {hook.seen=}" + + with TestHook() as hook: + ctypes.cast(ctypes.c_voidp(0), ctypes.POINTER(ctypes.c_char)) + assert "ctypes.call_function" in hook.seen_events + + with TestHook() as hook: + ctypes.string_at(id("ctypes.string_at") + 40) + assert "ctypes.call_function" in hook.seen_events + assert "ctypes.string_at" in 
hook.seen_events + + +def test_posixsubprocess(): + import multiprocessing.util + + exe = b"xxx" + args = [b"yyy", b"zzz"] + with TestHook() as hook: + multiprocessing.util.spawnv_passfds(exe, args, ()) + assert ("_posixsubprocess.fork_exec", ([exe], args, None)) in hook.seen + + +def test_excepthook(): + def excepthook(exc_type, exc_value, exc_tb): + if exc_type is not RuntimeError: + sys.__excepthook__(exc_type, exc_value, exc_tb) + + def hook(event, args): + if event == "sys.excepthook": + if not isinstance(args[2], args[1]): + raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})") + if args[0] != excepthook: + raise ValueError(f"Expected {args[0]} == {excepthook}") + print(event, repr(args[2])) + + sys.addaudithook(hook) + sys.excepthook = excepthook + raise RuntimeError("fatal-error") + + +def test_unraisablehook(): + from _testcapi import err_formatunraisable + + def unraisablehook(hookargs): + pass + + def hook(event, args): + if event == "sys.unraisablehook": + if args[0] != unraisablehook: + raise ValueError(f"Expected {args[0]} == {unraisablehook}") + print(event, repr(args[1].exc_value), args[1].err_msg) + + sys.addaudithook(hook) + sys.unraisablehook = unraisablehook + err_formatunraisable(RuntimeError("nonfatal-error"), + "Exception ignored for audit hook test") + + +def test_winreg(): + from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE + + def hook(event, args): + if not event.startswith("winreg."): + return + print(event, *args) + + sys.addaudithook(hook) + + k = OpenKey(HKEY_LOCAL_MACHINE, "Software") + EnumKey(k, 0) + try: + EnumKey(k, 10000) + except OSError: + pass + else: + raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail") + + kv = k.Detach() + CloseKey(kv) + + +def test_socket(): + import socket + + def hook(event, args): + if event.startswith("socket."): + print(event, *args) + + sys.addaudithook(hook) + + socket.gethostname() + + # Don't care if this fails, we just want the audit message + sock = 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + # Don't care if this fails, we just want the audit message + sock.bind(('127.0.0.1', 8080)) + except Exception: + pass + finally: + sock.close() + + +def test_gc(): + import gc + + def hook(event, args): + if event.startswith("gc."): + print(event, *args) + + sys.addaudithook(hook) + + gc.get_objects(generation=1) + + x = object() + y = [x] + + gc.get_referrers(x) + gc.get_referents(y) + + +def test_http_client(): + import http.client + + def hook(event, args): + if event.startswith("http.client."): + print(event, *args[1:]) + + sys.addaudithook(hook) + + conn = http.client.HTTPConnection('www.python.org') + try: + conn.request('GET', '/') + except OSError: + print('http.client.send', '[cannot send]') + finally: + conn.close() + + +def test_sqlite3(): + import sqlite3 + + def hook(event, *args): + if event.startswith("sqlite3."): + print(event, *args) + + sys.addaudithook(hook) + cx1 = sqlite3.connect(":memory:") + cx2 = sqlite3.Connection(":memory:") + + # Configured without --enable-loadable-sqlite-extensions + try: + if hasattr(sqlite3.Connection, "enable_load_extension"): + cx1.enable_load_extension(False) + try: + cx1.load_extension("test") + except sqlite3.OperationalError: + pass + else: + raise RuntimeError("Expected sqlite3.load_extension to fail") + finally: + cx1.close() + cx2.close() + +def test_sys_getframe(): + import sys + + def hook(event, args): + if event.startswith("sys."): + print(event, args[0].f_code.co_name) + + sys.addaudithook(hook) + sys._getframe() + + +def test_sys_getframemodulename(): + import sys + + def hook(event, args): + if event.startswith("sys."): + print(event, *args) + + sys.addaudithook(hook) + sys._getframemodulename() + + +def test_threading(): + import _thread + + def hook(event, args): + if event.startswith(("_thread.", "cpython.PyThreadState", "test.")): + print(event, args) + + sys.addaudithook(hook) + + lock = _thread.allocate_lock() + lock.acquire() + + class 
test_func: + def __repr__(self): return "" + def __call__(self): + sys.audit("test.test_func") + lock.release() + + i = _thread.start_new_thread(test_func(), ()) + lock.acquire() + + handle = _thread.start_joinable_thread(test_func()) + handle.join() + + +def test_threading_abort(): + # Ensures that aborting PyThreadState_New raises the correct exception + import _thread + + class ThreadNewAbortError(Exception): + pass + + def hook(event, args): + if event == "cpython.PyThreadState_New": + raise ThreadNewAbortError() + + sys.addaudithook(hook) + + try: + _thread.start_new_thread(lambda: None, ()) + except ThreadNewAbortError: + # Other exceptions are raised and the test will fail + pass + + +def test_wmi_exec_query(): + import _wmi + + def hook(event, args): + if event.startswith("_wmi."): + print(event, args[0]) + + sys.addaudithook(hook) + try: + _wmi.exec_query("SELECT * FROM Win32_OperatingSystem") + except WindowsError as e: + # gh-112278: WMI may be slow response when first called, but we still + # get the audit event, so just ignore the timeout + if e.winerror != 258: + raise + +def test_syslog(): + import syslog + + def hook(event, args): + if event.startswith("syslog."): + print(event, *args) + + sys.addaudithook(hook) + syslog.openlog('python') + syslog.syslog('test') + syslog.setlogmask(syslog.LOG_DEBUG) + syslog.closelog() + # implicit open + syslog.syslog('test2') + # open with default ident + syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0) + sys.argv = None + syslog.openlog() + syslog.closelog() + + +def test_not_in_gc(): + import gc + + hook = lambda *a: None + sys.addaudithook(hook) + + for o in gc.get_objects(): + if isinstance(o, list): + assert hook not in o + + +def test_time(mode): + import time + + def hook(event, args): + if event.startswith("time."): + if mode == 'print': + print(event, *args) + elif mode == 'fail': + raise AssertionError('hook failed') + sys.addaudithook(hook) + + time.sleep(0) + time.sleep(0.0625) # 
1/16, a small exact float + try: + time.sleep(-1) + except ValueError: + pass + +def test_sys_monitoring_register_callback(): + import sys + + def hook(event, args): + if event.startswith("sys.monitoring"): + print(event, args) + + sys.addaudithook(hook) + sys.monitoring.register_callback(1, 1, None) + + +def test_winapi_createnamedpipe(pipe_name): + import _winapi + + def hook(event, args): + if event == "_winapi.CreateNamedPipe": + print(event, args) + + sys.addaudithook(hook) + _winapi.CreateNamedPipe(pipe_name, _winapi.PIPE_ACCESS_DUPLEX, 8, 2, 0, 0, 0, 0) + + +def test_assert_unicode(): + import sys + sys.addaudithook(lambda *args: None) + try: + sys.audit(9) + except TypeError: + pass + else: + raise RuntimeError("Expected sys.audit(9) to fail.") + +def test_sys_remote_exec(): + import tempfile + + pid = os.getpid() + event_pid = -1 + event_script_path = "" + remote_event_script_path = "" + def hook(event, args): + if event not in ["sys.remote_exec", "cpython.remote_debugger_script"]: + return + print(event, args) + match event: + case "sys.remote_exec": + nonlocal event_pid, event_script_path + event_pid = args[0] + event_script_path = args[1] + case "cpython.remote_debugger_script": + nonlocal remote_event_script_path + remote_event_script_path = args[0] + + sys.addaudithook(hook) + with tempfile.NamedTemporaryFile(mode='w+', delete=True) as tmp_file: + tmp_file.write("a = 1+1\n") + tmp_file.flush() + sys.remote_exec(pid, tmp_file.name) + assertEqual(event_pid, pid) + assertEqual(event_script_path, tmp_file.name) + assertEqual(remote_event_script_path, tmp_file.name) + +if __name__ == "__main__": + from test.support import suppress_msvcrt_asserts + + suppress_msvcrt_asserts() + + test = sys.argv[1] + globals()[test](*sys.argv[2:]) diff --git a/Lib/test/exception_hierarchy.txt b/Lib/test/exception_hierarchy.txt index 1eca123be0f..f2649aa2d41 100644 --- a/Lib/test/exception_hierarchy.txt +++ b/Lib/test/exception_hierarchy.txt @@ -40,6 +40,7 @@ BaseException 
├── ReferenceError ├── RuntimeError │ ├── NotImplementedError + │ ├── PythonFinalizationError │ └── RecursionError ├── StopAsyncIteration ├── StopIteration diff --git a/Lib/test/picklecommon.py b/Lib/test/picklecommon.py new file mode 100644 index 00000000000..4c19b6c421f --- /dev/null +++ b/Lib/test/picklecommon.py @@ -0,0 +1,390 @@ +# Classes used for pickle testing. +# They are moved to separate file, so they can be loaded +# in other Python version for test_xpickle. + +import sys + +class C: + def __eq__(self, other): + return self.__dict__ == other.__dict__ + +# For test_load_classic_instance +class D(C): + def __init__(self, arg): + pass + +class E(C): + def __getinitargs__(self): + return () + +import __main__ +__main__.C = C +C.__module__ = "__main__" +__main__.D = D +D.__module__ = "__main__" +__main__.E = E +E.__module__ = "__main__" + +# Simple mutable object. +class Object(object): + pass + +# Hashable immutable key object containing unheshable mutable data. +class K: + def __init__(self, value): + self.value = value + + def __reduce__(self): + # Shouldn't support the recursion itself + return K, (self.value,) + +class WithSlots(object): + __slots__ = ('a', 'b') + +class WithSlotsSubclass(WithSlots): + __slots__ = ('c',) + +class WithSlotsAndDict(object): + __slots__ = ('a', '__dict__') + +class WithPrivateAttrs(object): + def __init__(self, a): + self.__private = a + def get(self): + return self.__private + +class WithPrivateAttrsSubclass(WithPrivateAttrs): + def __init__(self, a, b): + super().__init__(a) + self.__private = b + def get2(self): + return self.__private + +class WithPrivateSlots(object): + __slots__ = ('__private',) + def __init__(self, a): + self.__private = a + def get(self): + return self.__private + +class WithPrivateSlotsSubclass(WithPrivateSlots): + __slots__ = ('__private',) + def __init__(self, a, b): + super().__init__(a) + self.__private = b + def get2(self): + return self.__private + +# For test_misc +class myint(int): + def 
__init__(self, x): + self.str = str(x) + +# For test_misc and test_getinitargs +class initarg(C): + + def __init__(self, a, b): + self.a = a + self.b = b + + def __getinitargs__(self): + return self.a, self.b + +# For test_metaclass +class metaclass(type): + pass + +if sys.version_info >= (3,): + # Syntax not compatible with Python 2 + exec(''' +class use_metaclass(object, metaclass=metaclass): + pass +''') +else: + class use_metaclass(object): + __metaclass__ = metaclass + + +# Test classes for reduce_ex + +class R: + def __init__(self, reduce=None): + self.reduce = reduce + def __reduce__(self, proto): + return self.reduce + +class REX: + def __init__(self, reduce_ex=None): + self.reduce_ex = reduce_ex + def __reduce_ex__(self, proto): + return self.reduce_ex + +class REX_one(object): + """No __reduce_ex__ here, but inheriting it from object""" + _reduce_called = 0 + def __reduce__(self): + self._reduce_called = 1 + return REX_one, () + +class REX_two(object): + """No __reduce__ here, but inheriting it from object""" + _proto = None + def __reduce_ex__(self, proto): + self._proto = proto + return REX_two, () + +class REX_three(object): + _proto = None + def __reduce_ex__(self, proto): + self._proto = proto + return REX_two, () + def __reduce__(self): + raise AssertionError("This __reduce__ shouldn't be called") + +class REX_four(object): + """Calling base class method should succeed""" + _proto = None + def __reduce_ex__(self, proto): + self._proto = proto + return object.__reduce_ex__(self, proto) + +class REX_five(object): + """This one used to fail with infinite recursion""" + _reduce_called = 0 + def __reduce__(self): + self._reduce_called = 1 + return object.__reduce__(self) + +class REX_six(object): + """This class is used to check the 4th argument (list iterator) of + the reduce protocol. 
+ """ + def __init__(self, items=None): + self.items = items if items is not None else [] + def __eq__(self, other): + return type(self) is type(other) and self.items == other.items + def append(self, item): + self.items.append(item) + def __reduce__(self): + return type(self), (), None, iter(self.items), None + +class REX_seven(object): + """This class is used to check the 5th argument (dict iterator) of + the reduce protocol. + """ + def __init__(self, table=None): + self.table = table if table is not None else {} + def __eq__(self, other): + return type(self) is type(other) and self.table == other.table + def __setitem__(self, key, value): + self.table[key] = value + def __reduce__(self): + return type(self), (), None, None, iter(self.table.items()) + +class REX_state(object): + """This class is used to check the 3th argument (state) of + the reduce protocol. + """ + def __init__(self, state=None): + self.state = state + def __eq__(self, other): + return type(self) is type(other) and self.state == other.state + def __setstate__(self, state): + self.state = state + def __reduce__(self): + return type(self), (), self.state + +# For test_reduce_ex_None +class REX_None: + """ Setting __reduce_ex__ to None should fail """ + __reduce_ex__ = None + +# For test_reduce_None +class R_None: + """ Setting __reduce__ to None should fail """ + __reduce__ = None + +# For test_pickle_setstate_None +class C_None_setstate: + """ Setting __setstate__ to None should fail """ + def __getstate__(self): + return 1 + + __setstate__ = None + + +# Test classes for newobj + +# For test_newobj_generic and test_newobj_proxies + +class MyInt(int): + sample = 1 + +if sys.version_info >= (3,): + class MyLong(int): + sample = 1 +else: + class MyLong(long): + sample = long(1) + +class MyFloat(float): + sample = 1.0 + +class MyComplex(complex): + sample = 1.0 + 0.0j + +class MyStr(str): + sample = "hello" + +if sys.version_info >= (3,): + class MyUnicode(str): + sample = "hello \u1234" +else: + 
class MyUnicode(unicode): + sample = unicode(r"hello \u1234", "raw-unicode-escape") + +class MyTuple(tuple): + sample = (1, 2, 3) + +class MyList(list): + sample = [1, 2, 3] + +class MyDict(dict): + sample = {"a": 1, "b": 2} + +class MySet(set): + sample = {"a", "b"} + +class MyFrozenSet(frozenset): + sample = frozenset({"a", "b"}) + +myclasses = [MyInt, MyLong, MyFloat, + MyComplex, + MyStr, MyUnicode, + MyTuple, MyList, MyDict, MySet, MyFrozenSet] + +# For test_newobj_overridden_new +class MyIntWithNew(int): + def __new__(cls, value): + raise AssertionError + +class MyIntWithNew2(MyIntWithNew): + __new__ = int.__new__ + + +# For test_newobj_list_slots +class SlotList(MyList): + __slots__ = ["foo"] + +# Ruff "redefined while unused" false positive here due to `global` variables +# being assigned (and then restored) from within test methods earlier in the file +class SimpleNewObj(int): # noqa: F811 + def __init__(self, *args, **kwargs): + # raise an error, to make sure this isn't called + raise TypeError("SimpleNewObj.__init__() didn't expect to get called") + def __eq__(self, other): + return int(self) == int(other) and self.__dict__ == other.__dict__ + +class ComplexNewObj(SimpleNewObj): + def __getnewargs__(self): + return ('%X' % self, 16) + +class ComplexNewObjEx(SimpleNewObj): + def __getnewargs_ex__(self): + return ('%X' % self,), {'base': 16} + + +class ZeroCopyBytes(bytes): + readonly = True + c_contiguous = True + f_contiguous = True + zero_copy_reconstruct = True + + def __reduce_ex__(self, protocol): + if protocol >= 5: + import pickle + return type(self)._reconstruct, (pickle.PickleBuffer(self),), None + else: + return type(self)._reconstruct, (bytes(self),) + + def __repr__(self): + return "{}({!r})".format(self.__class__.__name__, bytes(self)) + + __str__ = __repr__ + + @classmethod + def _reconstruct(cls, obj): + with memoryview(obj) as m: + obj = m.obj + if type(obj) is cls: + # Zero-copy + return obj + else: + return cls(obj) + + +class 
ZeroCopyBytearray(bytearray): + readonly = False + c_contiguous = True + f_contiguous = True + zero_copy_reconstruct = True + + def __reduce_ex__(self, protocol): + if protocol >= 5: + import pickle + return type(self)._reconstruct, (pickle.PickleBuffer(self),), None + else: + return type(self)._reconstruct, (bytes(self),) + + def __repr__(self): + return "{}({!r})".format(self.__class__.__name__, bytes(self)) + + __str__ = __repr__ + + @classmethod + def _reconstruct(cls, obj): + with memoryview(obj) as m: + obj = m.obj + if type(obj) is cls: + # Zero-copy + return obj + else: + return cls(obj) + + +# For test_nested_names +class Nested: + class A: + class B: + class C: + pass + +# For test_py_methods +class PyMethodsTest: + @staticmethod + def cheese(): + return "cheese" + @classmethod + def wine(cls): + assert cls is PyMethodsTest + return "wine" + def biscuits(self): + assert isinstance(self, PyMethodsTest) + return "biscuits" + class Nested: + "Nested class" + @staticmethod + def ketchup(): + return "ketchup" + @classmethod + def maple(cls): + assert cls is PyMethodsTest.Nested + return "maple" + def pie(self): + assert isinstance(self, PyMethodsTest.Nested) + return "pie" + +# For test_c_methods +class Subclass(tuple): + class Nested(str): + pass diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 9a3a26a8400..7c81ba8db3e 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -15,6 +15,7 @@ import types import unittest import weakref +import __main__ from textwrap import dedent from http.cookies import SimpleCookie @@ -26,13 +27,15 @@ from test import support from test.support import os_helper from test.support import ( - TestFailed, run_with_locales, no_tracing, + run_with_locales, no_tracing, _2G, _4G, bigmemtest ) from test.support.import_helper import forget from test.support.os_helper import TESTFN from test.support import threading_helper from test.support.warnings_helper import save_restore_warnings_filters +from 
test import picklecommon +from test.picklecommon import * from pickle import bytes_types @@ -54,6 +57,8 @@ # kind of outer loop. protocols = range(pickle.HIGHEST_PROTOCOL + 1) +FAST_NESTING_LIMIT = 50 + # Return True if opcode code appears in the pickle, else False. def opcode_in_pickle(code, pickle): @@ -132,58 +137,6 @@ def restore(self): if pair is not None: copyreg.add_extension(pair[0], pair[1], code) -class C: - def __eq__(self, other): - return self.__dict__ == other.__dict__ - -class D(C): - def __init__(self, arg): - pass - -class E(C): - def __getinitargs__(self): - return () - -import __main__ -__main__.C = C -C.__module__ = "__main__" -__main__.D = D -D.__module__ = "__main__" -__main__.E = E -E.__module__ = "__main__" - -# Simple mutable object. -class Object: - pass - -# Hashable immutable key object containing unheshable mutable data. -class K: - def __init__(self, value): - self.value = value - - def __reduce__(self): - # Shouldn't support the recursion itself - return K, (self.value,) - -class myint(int): - def __init__(self, x): - self.str = str(x) - -class initarg(C): - - def __init__(self, a, b): - self.a = a - self.b = b - - def __getinitargs__(self): - return self.a, self.b - -class metaclass(type): - pass - -class use_metaclass(object, metaclass=metaclass): - pass - class pickling_metaclass(type): def __eq__(self, other): return (type(self) == type(other) and @@ -198,62 +151,6 @@ def create_dynamic_class(name, bases): return result -class ZeroCopyBytes(bytes): - readonly = True - c_contiguous = True - f_contiguous = True - zero_copy_reconstruct = True - - def __reduce_ex__(self, protocol): - if protocol >= 5: - return type(self)._reconstruct, (pickle.PickleBuffer(self),), None - else: - return type(self)._reconstruct, (bytes(self),) - - def __repr__(self): - return "{}({!r})".format(self.__class__.__name__, bytes(self)) - - __str__ = __repr__ - - @classmethod - def _reconstruct(cls, obj): - with memoryview(obj) as m: - obj = m.obj - if 
type(obj) is cls: - # Zero-copy - return obj - else: - return cls(obj) - - -class ZeroCopyBytearray(bytearray): - readonly = False - c_contiguous = True - f_contiguous = True - zero_copy_reconstruct = True - - def __reduce_ex__(self, protocol): - if protocol >= 5: - return type(self)._reconstruct, (pickle.PickleBuffer(self),), None - else: - return type(self)._reconstruct, (bytes(self),) - - def __repr__(self): - return "{}({!r})".format(self.__class__.__name__, bytes(self)) - - __str__ = __repr__ - - @classmethod - def _reconstruct(cls, obj): - with memoryview(obj) as m: - obj = m.obj - if type(obj) is cls: - # Zero-copy - return obj - else: - return cls(obj) - - if _testbuffer is not None: class PicklableNDArray: @@ -298,9 +195,10 @@ def __ne__(self, other): return not (self == other) def __repr__(self): - return (f"{type(self)}(shape={self.array.shape}," - f"strides={self.array.strides}, " - f"bytes={self.array.tobytes()})") + return ("{name}(shape={array.shape}," + "strides={array.strides}, " + "bytes={array.tobytes()})").format( + name=type(self).__name__, array=self.array.shape) def __reduce_ex__(self, protocol): if not self.array.contiguous: @@ -1359,9 +1257,9 @@ def check(key, exc): self.loads(b'\x82\x01.') check(None, ValueError) check((), ValueError) - check((__name__,), (TypeError, ValueError)) - check((__name__, "MyList", "x"), (TypeError, ValueError)) - check((__name__, None), (TypeError, ValueError)) + check((MyList.__module__,), (TypeError, ValueError)) + check((MyList.__module__, "MyList", "x"), (TypeError, ValueError)) + check((MyList.__module__, None), (TypeError, ValueError)) check((None, "MyList"), (TypeError, ValueError)) def test_bad_reduce(self): @@ -1675,7 +1573,7 @@ def test_bad_reduce_result(self): self.assertEqual(str(cm.exception), '__reduce__ must return a string or tuple, not list') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) obj = 
REX((print,)) for proto in protocols: @@ -1685,7 +1583,7 @@ def test_bad_reduce_result(self): self.assertEqual(str(cm.exception), 'tuple returned by __reduce__ must contain 2 through 6 elements') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) obj = REX((print, (), None, None, None, None, None)) for proto in protocols: @@ -1695,7 +1593,7 @@ def test_bad_reduce_result(self): self.assertEqual(str(cm.exception), 'tuple returned by __reduce__ must contain 2 through 6 elements') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_bad_reconstructor(self): obj = REX((42, ())) @@ -1707,7 +1605,7 @@ def test_bad_reconstructor(self): 'first item of the tuple returned by __reduce__ ' 'must be callable, not int') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_reconstructor(self): obj = REX((UnpickleableCallable(), ())) @@ -1716,8 +1614,8 @@ def test_unpickleable_reconstructor(self): with self.assertRaises(CustomError) as cm: self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX reconstructor', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor', + f'when serializing {REX.__module__}.REX object']) def test_bad_reconstructor_args(self): obj = REX((print, [])) @@ -1729,7 +1627,7 @@ def test_bad_reconstructor_args(self): 'second item of the tuple returned by __reduce__ ' 'must be a tuple, not list') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_reconstructor_args(self): obj = REX((print, (1, 2, UNPICKLEABLE))) @@ -1739,8 +1637,8 @@ 
def test_unpickleable_reconstructor_args(self): self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ 'when serializing tuple item 2', - 'when serializing test.pickletester.REX reconstructor arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor arguments', + f'when serializing {REX.__module__}.REX object']) def test_bad_newobj_args(self): obj = REX((copyreg.__newobj__, ())) @@ -1752,7 +1650,7 @@ def test_bad_newobj_args(self): 'tuple index out of range', '__newobj__ expected at least 1 argument, got 0'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) obj = REX((copyreg.__newobj__, [REX])) for proto in protocols[2:]: @@ -1763,7 +1661,7 @@ def test_bad_newobj_args(self): 'second item of the tuple returned by __reduce__ ' 'must be a tuple, not list') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_bad_newobj_class(self): obj = REX((copyreg.__newobj__, (NoNew(),))) @@ -1773,9 +1671,9 @@ def test_bad_newobj_class(self): self.dumps(obj, proto) self.assertIn(str(cm.exception), { 'first argument to __newobj__() has no __new__', - f'first argument to __newobj__() must be a class, not {__name__}.NoNew'}) + f'first argument to __newobj__() must be a class, not {NoNew.__module__}.NoNew'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_wrong_newobj_class(self): obj = REX((copyreg.__newobj__, (str,))) @@ -1786,7 +1684,7 @@ def test_wrong_newobj_class(self): self.assertEqual(str(cm.exception), f'first argument to __newobj__() must be {REX!r}, not {str!r}') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX 
object']) def test_unpickleable_newobj_class(self): class LocalREX(REX): pass @@ -1814,13 +1712,13 @@ def test_unpickleable_newobj_args(self): if proto >= 2: self.assertEqual(cm.exception.__notes__, [ 'when serializing tuple item 2', - 'when serializing test.pickletester.REX __new__ arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX __new__ arguments', + f'when serializing {REX.__module__}.REX object']) else: self.assertEqual(cm.exception.__notes__, [ 'when serializing tuple item 3', - 'when serializing test.pickletester.REX reconstructor arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor arguments', + f'when serializing {REX.__module__}.REX object']) def test_bad_newobj_ex_args(self): obj = REX((copyreg.__newobj_ex__, ())) @@ -1832,7 +1730,7 @@ def test_bad_newobj_ex_args(self): 'not enough values to unpack (expected 3, got 0)', '__newobj_ex__ expected 3 arguments, got 0'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) obj = REX((copyreg.__newobj_ex__, 42)) for proto in protocols[2:]: @@ -1843,7 +1741,7 @@ def test_bad_newobj_ex_args(self): 'second item of the tuple returned by __reduce__ ' 'must be a tuple, not int') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) obj = REX((copyreg.__newobj_ex__, (REX, 42, {}))) if self.pickler is pickle._Pickler: @@ -1854,7 +1752,7 @@ def test_bad_newobj_ex_args(self): self.assertEqual(str(cm.exception), 'Value after * must be an iterable, not int') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) else: for proto in protocols[2:]: with self.subTest(proto=proto): @@ -1863,7 +1761,7 @@ def 
test_bad_newobj_ex_args(self): self.assertEqual(str(cm.exception), 'second argument to __newobj_ex__() must be a tuple, not int') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) obj = REX((copyreg.__newobj_ex__, (REX, (), []))) if self.pickler is pickle._Pickler: @@ -1874,7 +1772,7 @@ def test_bad_newobj_ex_args(self): self.assertEqual(str(cm.exception), 'functools.partial() argument after ** must be a mapping, not list') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) else: for proto in protocols[2:]: with self.subTest(proto=proto): @@ -1883,7 +1781,7 @@ def test_bad_newobj_ex_args(self): self.assertEqual(str(cm.exception), 'third argument to __newobj_ex__() must be a dict, not list') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_bad_newobj_ex__class(self): obj = REX((copyreg.__newobj_ex__, (NoNew(), (), {}))) @@ -1893,9 +1791,9 @@ def test_bad_newobj_ex__class(self): self.dumps(obj, proto) self.assertIn(str(cm.exception), { 'first argument to __newobj_ex__() has no __new__', - f'first argument to __newobj_ex__() must be a class, not {__name__}.NoNew'}) + f'first argument to __newobj_ex__() must be a class, not {NoNew.__module__}.NoNew'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_wrong_newobj_ex_class(self): if self.pickler is not pickle._Pickler: @@ -1908,7 +1806,7 @@ def test_wrong_newobj_ex_class(self): self.assertEqual(str(cm.exception), f'first argument to __newobj_ex__() must be {REX}, not {str}') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) 
def test_unpickleable_newobj_ex_class(self): class LocalREX(REX): pass @@ -1944,22 +1842,22 @@ def test_unpickleable_newobj_ex_args(self): if proto >= 4: self.assertEqual(cm.exception.__notes__, [ 'when serializing tuple item 2', - 'when serializing test.pickletester.REX __new__ arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX __new__ arguments', + f'when serializing {REX.__module__}.REX object']) elif proto >= 2: self.assertEqual(cm.exception.__notes__, [ 'when serializing tuple item 3', 'when serializing tuple item 1', 'when serializing functools.partial state', 'when serializing functools.partial object', - 'when serializing test.pickletester.REX reconstructor', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor', + f'when serializing {REX.__module__}.REX object']) else: self.assertEqual(cm.exception.__notes__, [ 'when serializing tuple item 2', 'when serializing tuple item 1', - 'when serializing test.pickletester.REX reconstructor arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor arguments', + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_newobj_ex_kwargs(self): obj = REX((copyreg.__newobj_ex__, (REX, (), {'a': UNPICKLEABLE}))) @@ -1970,22 +1868,22 @@ def test_unpickleable_newobj_ex_kwargs(self): if proto >= 4: self.assertEqual(cm.exception.__notes__, [ "when serializing dict item 'a'", - 'when serializing test.pickletester.REX __new__ arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX __new__ arguments', + f'when serializing {REX.__module__}.REX object']) elif proto >= 2: self.assertEqual(cm.exception.__notes__, [ "when serializing dict item 'a'", 'when serializing tuple item 2', 'when serializing functools.partial state', 'when serializing functools.partial object', - 'when serializing 
test.pickletester.REX reconstructor', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor', + f'when serializing {REX.__module__}.REX object']) else: self.assertEqual(cm.exception.__notes__, [ "when serializing dict item 'a'", 'when serializing tuple item 2', - 'when serializing test.pickletester.REX reconstructor arguments', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX reconstructor arguments', + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_state(self): obj = REX_state(UNPICKLEABLE) @@ -1994,8 +1892,8 @@ def test_unpickleable_state(self): with self.assertRaises(CustomError) as cm: self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX_state state', - 'when serializing test.pickletester.REX_state object']) + f'when serializing {REX_state.__module__}.REX_state state', + f'when serializing {REX_state.__module__}.REX_state object']) def test_bad_state_setter(self): if self.pickler is pickle._Pickler: @@ -2009,7 +1907,7 @@ def test_bad_state_setter(self): 'sixth item of the tuple returned by __reduce__ ' 'must be callable, not int') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_state_setter(self): obj = REX((print, (), 'state', None, None, UnpickleableCallable())) @@ -2018,8 +1916,8 @@ def test_unpickleable_state_setter(self): with self.assertRaises(CustomError) as cm: self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX state setter', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX state setter', + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_state_with_state_setter(self): obj = REX((print, (), UNPICKLEABLE, None, None, print)) @@ -2028,8 
+1926,8 @@ def test_unpickleable_state_with_state_setter(self): with self.assertRaises(CustomError) as cm: self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX state', - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX state', + f'when serializing {REX.__module__}.REX object']) def test_bad_object_list_items(self): # Issue4176: crash when 4th and 5th items of __reduce__() @@ -2044,7 +1942,7 @@ def test_bad_object_list_items(self): 'fourth item of the tuple returned by __reduce__ ' 'must be an iterator, not int'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) if self.pickler is not pickle._Pickler: # Python implementation is less strict and also accepts iterables. @@ -2057,7 +1955,7 @@ def test_bad_object_list_items(self): 'fourth item of the tuple returned by __reduce__ ' 'must be an iterator, not int') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_object_list_items(self): obj = REX_six([1, 2, UNPICKLEABLE]) @@ -2066,8 +1964,8 @@ def test_unpickleable_object_list_items(self): with self.assertRaises(CustomError) as cm: self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX_six item 2', - 'when serializing test.pickletester.REX_six object']) + f'when serializing {REX_six.__module__}.REX_six item 2', + f'when serializing {REX_six.__module__}.REX_six object']) def test_bad_object_dict_items(self): # Issue4176: crash when 4th and 5th items of __reduce__() @@ -2082,7 +1980,7 @@ def test_bad_object_dict_items(self): 'fifth item of the tuple returned by __reduce__ ' 'must be an iterator, not int'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when 
serializing {REX.__module__}.REX object']) for proto in protocols: obj = REX((dict, (), None, None, iter([('a',)]))) @@ -2093,7 +1991,7 @@ def test_bad_object_dict_items(self): 'not enough values to unpack (expected 2, got 1)', 'dict items iterator must return 2-tuples'}) self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) if self.pickler is not pickle._Pickler: # Python implementation is less strict and also accepts iterables. @@ -2105,7 +2003,7 @@ def test_bad_object_dict_items(self): self.assertEqual(str(cm.exception), 'dict items iterator must return 2-tuples') self.assertEqual(cm.exception.__notes__, [ - 'when serializing test.pickletester.REX object']) + f'when serializing {REX.__module__}.REX object']) def test_unpickleable_object_dict_items(self): obj = REX_seven({'a': UNPICKLEABLE}) @@ -2114,8 +2012,8 @@ def test_unpickleable_object_dict_items(self): with self.assertRaises(CustomError) as cm: self.dumps(obj, proto) self.assertEqual(cm.exception.__notes__, [ - "when serializing test.pickletester.REX_seven item 'a'", - 'when serializing test.pickletester.REX_seven object']) + f"when serializing {REX_seven.__module__}.REX_seven item 'a'", + f'when serializing {REX_seven.__module__}.REX_seven object']) def test_unpickleable_list_items(self): obj = [1, [2, 3, UNPICKLEABLE]] @@ -2208,15 +2106,15 @@ def test_unpickleable_frozenset_items(self): def test_global_lookup_error(self): # Global name does not exist obj = REX('spam') - obj.__module__ = __name__ + obj.__module__ = 'test.picklecommon' for proto in protocols: with self.subTest(proto=proto): with self.assertRaises(pickle.PicklingError) as cm: self.dumps(obj, proto) self.assertEqual(str(cm.exception), - f"Can't pickle {obj!r}: it's not found as {__name__}.spam") + f"Can't pickle {obj!r}: it's not found as test.picklecommon.spam") self.assertEqual(str(cm.exception.__context__), - f"module '{__name__}' has no attribute 
'spam'") + "module 'test.picklecommon' has no attribute 'spam'") obj.__module__ = 'nonexisting' for proto in protocols: @@ -2371,6 +2269,7 @@ def test_reduce_None(self): with self.assertRaises(TypeError): self.dumps(c) + @support.skip_if_unlimited_stack_size @no_tracing def test_bad_getattr(self): # Issue #3514: crash when there is an infinite loop in __getattr__ @@ -2381,6 +2280,7 @@ def test_bad_getattr(self): for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): s = self.dumps(x, proto) + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'pickle' has no attribute 'PickleBuffer' def test_picklebuffer_error(self): # PickleBuffer forbidden with protocol < 5 pb = pickle.PickleBuffer(b"foobar") @@ -2437,7 +2337,7 @@ def persistent_id(self, obj): def test_bad_ext_code(self): # This should never happen in normal circumstances, because the type # and the value of the extension code is checked in copyreg.add_extension(). - key = (__name__, 'MyList') + key = (MyList.__module__, 'MyList') def check(code, exc): assert key not in copyreg._extension_registry assert code not in copyreg._inverted_registry @@ -2459,6 +2359,7 @@ def check(code, exc): class AbstractPickleTests: # Subclass must define self.dumps, self.loads. 
+ py_version = sys.version_info # for test_xpickle optimized = False _testdata = AbstractUnpickleTests._testdata @@ -2471,24 +2372,33 @@ def setUp(self): def test_misc(self): # test various datatypes not tested by testdata for proto in protocols: - x = myint(4) - s = self.dumps(x, proto) - y = self.loads(s) - self.assert_is_copy(x, y) + with self.subTest('myint', proto=proto): + if self.py_version < (3, 0) and proto < 2: + self.skipTest('int subclasses are not interoperable with Python 2') + x = myint(4) + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) - x = (1, ()) - s = self.dumps(x, proto) - y = self.loads(s) - self.assert_is_copy(x, y) + with self.subTest('tuple', proto=proto): + x = (1, ()) + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) - x = initarg(1, x) - s = self.dumps(x, proto) - y = self.loads(s) - self.assert_is_copy(x, y) + with self.subTest('initarg', proto=proto): + if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') + x = initarg(1, x) + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) # XXX test __reduce__ protocol? def test_roundtrip_equality(self): + if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') expected = self._testdata for proto in protocols: s = self.dumps(expected, proto) @@ -2685,6 +2595,8 @@ def test_recursive_tuple_and_dict_like_key(self): self._test_recursive_tuple_and_dict_key(REX_seven, asdict=lambda x: x.table) def test_recursive_set(self): + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') # Set containing an immutable object containing the original set. y = set() y.add(K(y)) @@ -2708,6 +2620,8 @@ def test_recursive_set(self): def test_recursive_inst(self): # Mutable object containing itself. 
+ if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') i = Object() i.attr = i for proto in protocols: @@ -2718,6 +2632,8 @@ def test_recursive_inst(self): self.assertIs(x.attr, x) def test_recursive_multi(self): + if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') l = [] d = {1:l} i = Object() @@ -2732,39 +2648,49 @@ def test_recursive_multi(self): self.assertEqual(list(x[0].attr.keys()), [1]) self.assertIs(x[0].attr[1], x) - def _test_recursive_collection_and_inst(self, factory): + def _test_recursive_collection_and_inst(self, factory, oldminproto=None): + if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') # Mutable object containing a collection containing the original # object. o = Object() o.attr = factory([o]) t = type(o.attr) - for proto in protocols: - s = self.dumps(o, proto) - x = self.loads(s) - self.assertIsInstance(x.attr, t) - self.assertEqual(len(x.attr), 1) - self.assertIsInstance(list(x.attr)[0], Object) - self.assertIs(list(x.attr)[0], x) + with self.subTest('obj -> {t.__name__} -> obj'): + for proto in protocols: + with self.subTest(proto=proto): + s = self.dumps(o, proto) + x = self.loads(s) + self.assertIsInstance(x.attr, t) + self.assertEqual(len(x.attr), 1) + self.assertIsInstance(list(x.attr)[0], Object) + self.assertIs(list(x.attr)[0], x) # Collection containing a mutable object containing the original # collection. 
o = o.attr - for proto in protocols: - s = self.dumps(o, proto) - x = self.loads(s) - self.assertIsInstance(x, t) - self.assertEqual(len(x), 1) - self.assertIsInstance(list(x)[0], Object) - self.assertIs(list(x)[0].attr, x) + with self.subTest(f'{t.__name__} -> obj -> {t.__name__}'): + if self.py_version < (3, 4) and oldminproto is None: + self.skipTest('not supported in Python < 3.4') + for proto in protocols: + with self.subTest(proto=proto): + if self.py_version < (3, 4) and proto < oldminproto: + self.skipTest(f'requires protocol {oldminproto} in Python < 3.4') + s = self.dumps(o, proto) + x = self.loads(s) + self.assertIsInstance(x, t) + self.assertEqual(len(x), 1) + self.assertIsInstance(list(x)[0], Object) + self.assertIs(list(x)[0].attr, x) def test_recursive_list_and_inst(self): - self._test_recursive_collection_and_inst(list) + self._test_recursive_collection_and_inst(list, oldminproto=0) def test_recursive_tuple_and_inst(self): - self._test_recursive_collection_and_inst(tuple) + self._test_recursive_collection_and_inst(tuple, oldminproto=0) def test_recursive_dict_and_inst(self): - self._test_recursive_collection_and_inst(dict.fromkeys) + self._test_recursive_collection_and_inst(dict.fromkeys, oldminproto=0) def test_recursive_set_and_inst(self): self._test_recursive_collection_and_inst(set) @@ -2773,13 +2699,13 @@ def test_recursive_frozenset_and_inst(self): self._test_recursive_collection_and_inst(frozenset) def test_recursive_list_subclass_and_inst(self): - self._test_recursive_collection_and_inst(MyList) + self._test_recursive_collection_and_inst(MyList, oldminproto=2) def test_recursive_tuple_subclass_and_inst(self): self._test_recursive_collection_and_inst(MyTuple) def test_recursive_dict_subclass_and_inst(self): - self._test_recursive_collection_and_inst(MyDict.fromkeys) + self._test_recursive_collection_and_inst(MyDict.fromkeys, oldminproto=2) def test_recursive_set_subclass_and_inst(self): self._test_recursive_collection_and_inst(MySet) @@ 
-2839,6 +2765,8 @@ def test_unicode_high_plane(self): def test_unicode_memoization(self): # Repeated str is re-used (even when escapes added). + if self.py_version < (3, 0): + self.skipTest('not supported in Python < 3.0') for proto in protocols: for s in '', 'xyz', 'xyz\n', 'x\\yz', 'x\xa1yz\r': p = self.dumps((s, s), proto) @@ -2858,23 +2786,27 @@ def test_bytes(self): self.assert_is_copy(s, self.loads(p)) def test_bytes_memoization(self): + array_types = [bytes] + if self.py_version >= (3, 4): + array_types += [ZeroCopyBytes] for proto in protocols: - for array_type in [bytes, ZeroCopyBytes]: + for array_type in array_types: for s in b'', b'xyz', b'xyz'*100: + b = array_type(s) + expected = (b, b) if self.py_version >= (3, 0) else (b.decode(),)*2 with self.subTest(proto=proto, array_type=array_type, s=s, independent=False): - b = array_type(s) p = self.dumps((b, b), proto) x, y = self.loads(p) self.assertIs(x, y) - self.assert_is_copy((b, b), (x, y)) + self.assert_is_copy(expected, (x, y)) + b2 = array_type(s) with self.subTest(proto=proto, array_type=array_type, s=s, independent=True): - b1, b2 = array_type(s), array_type(s) - p = self.dumps((b1, b2), proto) - # Note that (b1, b2) = self.loads(p) might have identical - # components, i.e., b1 is b2, but this is not always the + p = self.dumps((b, b2), proto) + # Note that (b, b2) = self.loads(p) might have identical + # components, i.e., b is b2, but this is not always the # case if the content is large (equality still holds). 
- self.assert_is_copy((b1, b2), self.loads(p)) + self.assert_is_copy(expected, self.loads(p)) def test_bytearray(self): for proto in protocols: @@ -2896,8 +2828,11 @@ def test_bytearray(self): self.assertTrue(opcode_in_pickle(pickle.BYTEARRAY8, p)) def test_bytearray_memoization(self): + array_types = [bytearray] + if self.py_version >= (3, 4): + array_types += [ZeroCopyBytearray] for proto in protocols: - for array_type in [bytearray, ZeroCopyBytearray]: + for array_type in array_types: for s in b'', b'xyz', b'xyz'*100: with self.subTest(proto=proto, array_type=array_type, s=s, independent=False): b = array_type(s) @@ -2921,6 +2856,53 @@ def test_bytearray_memoization(self): self.assertIsNot(b2a, b2b) self.assert_is_copy(b2a, b2b) + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'pickle' has no attribute 'PickleBuffer' + def test_picklebuffer_memoization(self): + if self.py_version < (3, 8): + self.skipTest('not supported in Python < 3.8') + array_types = [bytes, bytearray] + for proto in range(5, pickle.HIGHEST_PROTOCOL + 1): + for array_type in array_types: + for s in b'', b'xyz', b'xyz'*100: + with self.subTest(proto=proto, array_type=array_type, s=s, independent=False): + b = pickle.PickleBuffer(array_type(s)) + p = self.dumps((b, b), proto) + b1, b2 = self.loads(p) + self.assertIs(b1, b2) + + with self.subTest(proto=proto, array_type=array_type, s=s, independent=True): + b = array_type(s) + b1a = pickle.PickleBuffer(b) + b2a = pickle.PickleBuffer(b) + p = self.dumps((b1a, b2a), proto) + b1b, b2b = self.loads(p) + if array_type is not bytes: + self.assertIsNot(b1b, b2b) + self.assert_is_copy(b1b, b) + self.assert_is_copy(b2b, b) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'pickle' has no attribute 'PickleBuffer' + def test_empty_picklebuffer_memoization(self): + # gh-148914: Empty writable PickleBuffer memoized an empty bytearray + # with the id of b'' (a singleton in CPython). 
+ if self.py_version < (3, 8): + self.skipTest('not supported in Python < 3.8') + for proto in range(5, pickle.HIGHEST_PROTOCOL + 1): + for readonly in False, True: + with self.subTest(proto=proto, readonly=readonly): + b = b'' + ba = bytearray() + buf = pickle.PickleBuffer(b if readonly else ba) + p = self.dumps((buf, b, ba), proto) + buf, b, ba = self.loads(p) + array_type = bytes if readonly else bytearray + self.assertIsInstance(buf, array_type) + self.assertIsInstance(b, bytes) + self.assertIsInstance(ba, bytearray) + self.assertEqual(buf, b'') + self.assertEqual(b, b'') + self.assertEqual(ba, b'') + def test_ints(self): for proto in protocols: n = sys.maxsize @@ -2971,12 +2953,17 @@ def test_float_format(self): def test_reduce(self): for proto in protocols: - inst = AAA() - dumped = self.dumps(inst, proto) - loaded = self.loads(dumped) - self.assertEqual(loaded, REDUCE_A) + with self.subTest(proto=proto): + if self.py_version < (3, 4) and proto < 3: + self.skipTest('str is not interoperable with Python < 3.4') + inst = AAA() + dumped = self.dumps(inst, proto) + loaded = self.loads(dumped) + self.assertEqual(loaded, REDUCE_A) def test_getinitargs(self): + if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') for proto in protocols: inst = initarg(1, 2) dumped = self.dumps(inst, proto) @@ -2984,6 +2971,7 @@ def test_getinitargs(self): self.assert_is_copy(inst, loaded) def test_metaclass(self): + self.assertEqual(type(use_metaclass), metaclass) a = use_metaclass() for proto in protocols: s = self.dumps(a, proto) @@ -3008,6 +2996,10 @@ def test_structseq(self): s = self.dumps(t, proto) u = self.loads(s) self.assert_is_copy(t, u) + if self.py_version < (3, 4): + # module 'os' has no attributes '_make_stat_result' and + # '_make_statvfs_result' + continue t = os.stat(os.curdir) s = self.dumps(t, proto) u = self.loads(s) @@ -3019,52 +3011,111 @@ def test_structseq(self): self.assert_is_copy(t, u) def 
test_ellipsis(self): + if self.py_version < (3, 3): + self.skipTest('not supported in Python < 3.3') for proto in protocols: - s = self.dumps(..., proto) - u = self.loads(s) - self.assertIs(..., u) + with self.subTest(proto=proto): + s = self.dumps(..., proto) + u = self.loads(s) + self.assertIs(..., u) def test_notimplemented(self): + if self.py_version < (3, 3): + self.skipTest('not supported in Python < 3.3') for proto in protocols: - s = self.dumps(NotImplemented, proto) - u = self.loads(s) - self.assertIs(NotImplemented, u) + with self.subTest(proto=proto): + s = self.dumps(NotImplemented, proto) + u = self.loads(s) + self.assertIs(NotImplemented, u) def test_singleton_types(self): # Issue #6477: Test that types of built-in singletons can be pickled. + if self.py_version < (3, 3): + self.skipTest('not supported in Python < 3.3') singletons = [None, ..., NotImplemented] for singleton in singletons: + t = type(singleton) for proto in protocols: - s = self.dumps(type(singleton), proto) - u = self.loads(s) - self.assertIs(type(singleton), u) + with self.subTest(name=t.__name__, proto=proto): + s = self.dumps(t, proto) + u = self.loads(s) + self.assertIs(t, u) def test_builtin_types(self): + new_names = { + 'bytes': (3, 0), + 'BuiltinImporter': (3, 3), + 'str': (3, 4), # not interoperable with Python < 3.4 + } for t in builtins.__dict__.values(): if isinstance(t, type) and not issubclass(t, BaseException): + if t.__name__ in new_names and self.py_version < new_names[t.__name__]: + continue for proto in protocols: - s = self.dumps(t, proto) - self.assertIs(self.loads(s), t) + with self.subTest(name=t.__name__, proto=proto): + s = self.dumps(t, proto) + self.assertIs(self.loads(s), t) def test_builtin_exceptions(self): + new_names = { + 'BlockingIOError': (3, 3), + 'BrokenPipeError': (3, 3), + 'ChildProcessError': (3, 3), + 'ConnectionError': (3, 3), + 'ConnectionAbortedError': (3, 3), + 'ConnectionRefusedError': (3, 3), + 'ConnectionResetError': (3, 3), + 
'FileExistsError': (3, 3), + 'FileNotFoundError': (3, 3), + 'InterruptedError': (3, 3), + 'IsADirectoryError': (3, 3), + 'NotADirectoryError': (3, 3), + 'PermissionError': (3, 3), + 'ProcessLookupError': (3, 3), + 'TimeoutError': (3, 3), + 'RecursionError': (3, 5), + 'StopAsyncIteration': (3, 5), + 'ModuleNotFoundError': (3, 6), + 'EncodingWarning': (3, 10), + 'BaseExceptionGroup': (3, 11), + 'ExceptionGroup': (3, 11), + '_IncompleteInputError': (3, 13), + 'PythonFinalizationError': (3, 13), + } for t in builtins.__dict__.values(): if isinstance(t, type) and issubclass(t, BaseException): + if t.__name__ in new_names and self.py_version < new_names[t.__name__]: + continue for proto in protocols: - s = self.dumps(t, proto) - u = self.loads(s) - if proto <= 2 and issubclass(t, OSError) and t is not BlockingIOError: - self.assertIs(u, OSError) - elif proto <= 2 and issubclass(t, ImportError): - self.assertIs(u, ImportError) - else: - self.assertIs(u, t) + with self.subTest(name=t.__name__, proto=proto): + if self.py_version < (3, 3) and proto < 3: + self.skipTest('exception classes are not interoperable with Python < 3.3') + s = self.dumps(t, proto) + u = self.loads(s) + if proto <= 2 and issubclass(t, OSError) and t is not BlockingIOError: + self.assertIs(u, OSError) + elif proto <= 2 and issubclass(t, ImportError): + self.assertIs(u, ImportError) + else: + self.assertIs(u, t) def test_builtin_functions(self): + new_names = { + '__build_class__': (3, 0), + 'ascii': (3, 0), + 'exec': (3, 0), + 'breakpoint': (3, 7), + 'aiter': (3, 10), + 'anext': (3, 10), + } for t in builtins.__dict__.values(): if isinstance(t, types.BuiltinFunctionType): + if t.__name__ in new_names and self.py_version < new_names[t.__name__]: + continue for proto in protocols: - s = self.dumps(t, proto) - self.assertIs(self.loads(s), t) + with self.subTest(name=t.__name__, proto=proto): + s = self.dumps(t, proto) + self.assertIs(self.loads(s), t) # Tests for protocol 2 @@ -3077,6 +3128,9 @@ def 
test_proto(self): else: self.assertEqual(count_opcode(pickle.PROTO, pickled), 0) + def test_bad_proto(self): + if self.py_version < (3, 8): + self.skipTest('no protocol validation in Python < 3.8') oob = protocols[-1] + 1 # a future protocol build_none = pickle.NONE + pickle.STOP badpickle = pickle.PROTO + bytes([oob]) + build_none @@ -3188,57 +3242,65 @@ def test_newobj_list(self): def test_newobj_generic(self): for proto in protocols: for C in myclasses: - B = C.__base__ - x = C(C.sample) - x.foo = 42 - s = self.dumps(x, proto) - y = self.loads(s) - detail = (proto, C, B, x, y, type(y)) - self.assert_is_copy(x, y) # XXX revisit - self.assertEqual(B(x), B(y), detail) - self.assertEqual(x.__dict__, y.__dict__, detail) + with self.subTest(proto=proto, C=C): + if self.py_version < (3, 0) and proto < 2 and C in (MyInt, MyStr): + self.skipTest('int and str subclasses are not interoperable with Python 2') + if (3, 0) <= self.py_version < (3, 4) and proto < 2 and C in (MyStr, MyUnicode): + self.skipTest('str subclasses are not interoperable with Python < 3.4') + B = C.__base__ + x = C(C.sample) + x.foo = 42 + s = self.dumps(x, proto) + y = self.loads(s) + detail = (proto, C, B, x, y, type(y)) + self.assert_is_copy(x, y) # XXX revisit + self.assertEqual(B(x), B(y), detail) + self.assertEqual(x.__dict__, y.__dict__, detail) def test_newobj_proxies(self): # NEWOBJ should use the __class__ rather than the raw type classes = myclasses[:] # Cannot create weakproxies to these classes - for c in (MyInt, MyTuple): + for c in (MyInt, MyLong, MyTuple): classes.remove(c) for proto in protocols: for C in classes: - B = C.__base__ - x = C(C.sample) - x.foo = 42 - p = weakref.proxy(x) - s = self.dumps(p, proto) - y = self.loads(s) - self.assertEqual(type(y), type(x)) # rather than type(p) - detail = (proto, C, B, x, y, type(y)) - self.assertEqual(B(x), B(y), detail) - self.assertEqual(x.__dict__, y.__dict__, detail) + with self.subTest(proto=proto, C=C): + if self.py_version < (3, 4) 
and proto < 3 and C in (MyStr, MyUnicode): + self.skipTest('str subclasses are not interoperable with Python < 3.4') + B = C.__base__ + x = C(C.sample) + x.foo = 42 + p = weakref.proxy(x) + s = self.dumps(p, proto) + y = self.loads(s) + self.assertEqual(type(y), type(x)) # rather than type(p) + detail = (proto, C, B, x, y, type(y)) + self.assertEqual(B(x), B(y), detail) + self.assertEqual(x.__dict__, y.__dict__, detail) def test_newobj_overridden_new(self): # Test that Python class with C implemented __new__ is pickleable for proto in protocols: - x = MyIntWithNew2(1) - x.foo = 42 - s = self.dumps(x, proto) - y = self.loads(s) - self.assertIs(type(y), MyIntWithNew2) - self.assertEqual(int(y), 1) - self.assertEqual(y.foo, 42) + with self.subTest(proto=proto): + if self.py_version < (3, 0) and proto < 2: + self.skipTest('int subclasses are not interoperable with Python 2') + x = MyIntWithNew2(1) + x.foo = 42 + s = self.dumps(x, proto) + y = self.loads(s) + self.assertIs(type(y), MyIntWithNew2) + self.assertEqual(int(y), 1) + self.assertEqual(y.foo, 42) def test_newobj_not_class(self): # Issue 24552 - global SimpleNewObj - save = SimpleNewObj + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') o = SimpleNewObj.__new__(SimpleNewObj) b = self.dumps(o, 4) - try: - SimpleNewObj = 42 + with support.swap_attr(picklecommon, 'SimpleNewObj', 42): self.assertRaises((TypeError, pickle.UnpicklingError), self.loads, b) - finally: - SimpleNewObj = save # Register a type with copyreg, with extension code extcode. Pickle # an object of that type. Check that the resulting pickle uses opcode @@ -3247,14 +3309,14 @@ def test_newobj_not_class(self): def produce_global_ext(self, extcode, opcode): e = ExtensionSaver(extcode) try: - copyreg.add_extension(__name__, "MyList", extcode) + copyreg.add_extension(MyList.__module__, "MyList", extcode) x = MyList([1, 2, 3]) x.foo = 42 x.bar = "hello" # Dump using protocol 1 for comparison. 
s1 = self.dumps(x, 1) - self.assertIn(__name__.encode("utf-8"), s1) + self.assertIn(MyList.__module__.encode(), s1) self.assertIn(b"MyList", s1) self.assertFalse(opcode_in_pickle(opcode, s1)) @@ -3263,7 +3325,7 @@ def produce_global_ext(self, extcode, opcode): # Dump using protocol 2 for test. s2 = self.dumps(x, 2) - self.assertNotIn(__name__.encode("utf-8"), s2) + self.assertNotIn(MyList.__module__.encode(), s2) self.assertNotIn(b"MyList", s2) self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2)) @@ -3361,14 +3423,20 @@ def test_simple_newobj(self): x.abc = 666 for proto in protocols: with self.subTest(proto=proto): + if self.py_version < (3, 0) and proto < 2: + self.skipTest('int subclasses are not interoperable with Python 2') s = self.dumps(x, proto) if proto < 1: - self.assertIn(b'\nI64206', s) # INT + if self.py_version >= (3, 7): + self.assertIn(b'\nI64206', s) # INT + else: # for test_xpickle + self.assertIn(b'64206', s) # INT or LONG else: self.assertIn(b'M\xce\xfa', s) # BININT2 - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), - 2 <= proto) - self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + if not (self.py_version < (3, 5) and proto == 4): + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) y = self.loads(s) # will raise TypeError if __init__ called self.assert_is_copy(x, y) @@ -3377,29 +3445,45 @@ def test_complex_newobj(self): x.abc = 666 for proto in protocols: with self.subTest(proto=proto): + if self.py_version < (3, 0) and proto < 2: + self.skipTest('int subclasses are not interoperable with Python 2') s = self.dumps(x, proto) if proto < 1: - self.assertIn(b'\nI64206', s) # INT + if self.py_version >= (3, 7): + self.assertIn(b'\nI64206', s) # INT + else: # for test_xpickle + self.assertIn(b'64206', s) # INT or LONG elif proto < 2: self.assertIn(b'M\xce\xfa', s) # BININT2 elif proto < 4: - self.assertIn(b'X\x04\x00\x00\x00FACE', s) # BINUNICODE + if 
self.py_version >= (3, 0): + self.assertIn(b'X\x04\x00\x00\x00FACE', s) # BINUNICODE + else: # for test_xpickle + self.assertIn(b'U\x04FACE', s) # SHORT_BINSTRING else: self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), - 2 <= proto) - self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + if not (self.py_version < (3, 5) and proto == 4): + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) y = self.loads(s) # will raise TypeError if __init__ called self.assert_is_copy(x, y) def test_complex_newobj_ex(self): + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') x = ComplexNewObjEx.__new__(ComplexNewObjEx, 0xface) # avoid __init__ x.abc = 666 for proto in protocols: with self.subTest(proto=proto): + if self.py_version < (3, 6) and proto < 4: + self.skipTest('requires protocol 4 in Python < 3.6') s = self.dumps(x, proto) if proto < 1: - self.assertIn(b'\nI64206', s) # INT + if self.py_version >= (3, 7): + self.assertIn(b'\nI64206', s) # INT + else: # for test_xpickle + self.assertIn(b'64206', s) # INT or LONG elif proto < 2: self.assertIn(b'M\xce\xfa', s) # BININT2 elif proto < 4: @@ -3487,6 +3571,8 @@ def test_many_puts_and_gets(self): def test_attribute_name_interning(self): # Test that attribute names of pickled objects are interned when # unpickling. 
+ if self.py_version < (3, 0): + self.skipTest('"classic" classes are not interoperable with Python 2') for proto in protocols: x = C() x.foo = 42 @@ -3516,10 +3602,14 @@ def test_large_pickles(self): dumped = self.dumps(data, proto) loaded = self.loads(dumped) self.assertEqual(len(loaded), len(data)) + if self.py_version < (3, 0): + data = (1, min, 'xy' * (30 * 1024), len) self.assertEqual(loaded, data) def test_int_pickling_efficiency(self): # Test compacity of int representation (see issue #12744) + if self.py_version < (3, 3): + self.skipTest('not supported in Python < 3.3') for proto in protocols: with self.subTest(proto=proto): pickles = [self.dumps(2**n, proto) for n in range(70)] @@ -3540,10 +3630,13 @@ def test_appends_on_non_lists(self): # Issue #17720 obj = REX_six([1, 2, 3]) for proto in protocols: - if proto == 0: - self._check_pickling_with_opcode(obj, pickle.APPEND, proto) - else: - self._check_pickling_with_opcode(obj, pickle.APPENDS, proto) + with self.subTest(proto=proto): + if proto == 0: + self._check_pickling_with_opcode(obj, pickle.APPEND, proto) + else: + if self.py_version < (3, 0): + self.skipTest('not supported in Python 2') + self._check_pickling_with_opcode(obj, pickle.APPENDS, proto) def test_setitems_on_non_dicts(self): obj = REX_seven({1: -1, 2: -2, 3: -3}) @@ -3608,6 +3701,8 @@ def check_frame_opcodes(self, pickled): @support.skip_if_pgo_task @support.requires_resource('cpu') def test_framing_many_objects(self): + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') obj = list(range(10**5)) for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): with self.subTest(proto=proto): @@ -3623,6 +3718,8 @@ def test_framing_many_objects(self): self.check_frame_opcodes(pickled) def test_framing_large_objects(self): + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') N = 1024 * 1024 small_items = [[i] for i in range(10)] obj = [b'x' * N, *small_items, b'y' * N, 'z' * N] @@ -3648,15 +3745,16 @@ 
def test_framing_large_objects(self): [len(x) for x in unpickled]) # Perform full equality check if the lengths match. self.assertEqual(obj, unpickled) - n_frames = count_opcode(pickle.FRAME, pickled) - # A single frame for small objects between - # first two large objects. - self.assertEqual(n_frames, 1) - self.check_frame_opcodes(pickled) + if self.py_version >= (3, 7): + n_frames = count_opcode(pickle.FRAME, pickled) + # A single frame for small objects between + # first two large objects. + self.assertEqual(n_frames, 1) + self.check_frame_opcodes(pickled) def test_optional_frames(self): - if pickle.HIGHEST_PROTOCOL < 4: - return + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') def remove_frames(pickled, keep_frame=None): """Remove frame opcodes from the given pickle.""" @@ -3698,6 +3796,9 @@ def remove_frames(pickled, keep_frame=None): @support.skip_if_pgo_task def test_framed_write_sizes_with_delayed_writer(self): + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') + class ChunkAccumulator: """Accumulate pickler output in a list of raw chunks.""" def __init__(self): @@ -3763,13 +3864,12 @@ def concatenate_chunks(self): chunk_sizes) def test_nested_names(self): - global Nested - class Nested: - class A: - class B: - class C: - pass + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') + # required protocol 4 in Python 3.4 for proto in range(pickle.HIGHEST_PROTOCOL + 1): + if self.py_version < (3, 5) and proto < 4: + continue for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: with self.subTest(proto=proto, obj=obj): unpickled = self.loads(self.dumps(obj, proto)) @@ -3800,35 +3900,21 @@ class Recursive: del Recursive.ref # break reference loop def test_py_methods(self): - global PyMethodsTest - class PyMethodsTest: - @staticmethod - def cheese(): - return "cheese" - @classmethod - def wine(cls): - assert cls is PyMethodsTest - return "wine" - def biscuits(self): - assert 
isinstance(self, PyMethodsTest) - return "biscuits" - class Nested: - "Nested class" - @staticmethod - def ketchup(): - return "ketchup" - @classmethod - def maple(cls): - assert cls is PyMethodsTest.Nested - return "maple" - def pie(self): - assert isinstance(self, PyMethodsTest.Nested) - return "pie" - + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') py_methods = ( - PyMethodsTest.cheese, PyMethodsTest.wine, PyMethodsTest().biscuits, + ) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + for method in py_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(), unpickled()) + + # required protocol 4 in Python 3.4 + py_methods = ( + PyMethodsTest.cheese, PyMethodsTest.Nested.ketchup, PyMethodsTest.Nested.maple, PyMethodsTest.Nested().pie @@ -3838,6 +3924,8 @@ def pie(self): (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) ) for proto in range(pickle.HIGHEST_PROTOCOL + 1): + if self.py_version < (3, 5) and proto < 4: + continue for method in py_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -3858,11 +3946,8 @@ def pie(self): self.assertRaises(TypeError, self.dumps, descr, proto) def test_c_methods(self): - global Subclass - class Subclass(tuple): - class Nested(str): - pass - + if self.py_version < (3, 4): + self.skipTest('not supported in Python < 3.4') c_methods = ( # bound built-in method ("abcd".index, ("c",)), @@ -3883,7 +3968,6 @@ class Nested(str): # subclass methods (Subclass([1,2,2]).count, (2,)), (Subclass.count, (Subclass([1,2,2]), 2)), - (Subclass.Nested("sweet").count, ("e",)), (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), ) for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -3892,6 +3976,18 @@ class Nested(str): unpickled = self.loads(self.dumps(method, proto)) self.assertEqual(method(*args), unpickled(*args)) + # required protocol 4 in Python 3.4 + c_methods = ( + 
(Subclass.Nested("sweet").count, ("e",)), + ) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + if self.py_version < (3, 5) and proto < 4: + continue + for method, args in c_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(*args), unpickled(*args)) + descriptors = ( bytearray.__dict__['maketrans'], # built-in static method descriptor dict.__dict__['fromkeys'], # built-in class method descriptor @@ -3901,7 +3997,85 @@ class Nested(str): with self.subTest(proto=proto, descr=descr): self.assertRaises(TypeError, self.dumps, descr, proto) + def test_object_with_attrs(self): + obj = Object() + obj.a = 1 + for proto in protocols: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.a, obj.a) + + def test_object_with_slots(self): + obj = WithSlots() + obj.a = 1 + self.assertRaises(TypeError, self.dumps, obj, 0) + self.assertRaises(TypeError, self.dumps, obj, 1) + for proto in protocols[2:]: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.a, obj.a) + self.assertNotHasAttr(unpickled, 'b') + + obj = WithSlotsSubclass() + obj.a = 1 + obj.c = 2 + self.assertRaises(TypeError, self.dumps, obj, 0) + self.assertRaises(TypeError, self.dumps, obj, 1) + for proto in protocols[2:]: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.a, obj.a) + self.assertEqual(unpickled.c, obj.c) + self.assertNotHasAttr(unpickled, 'b') + + obj = WithSlotsAndDict() + obj.a = 1 + obj.c = 2 + self.assertRaises(TypeError, self.dumps, obj, 0) + self.assertRaises(TypeError, self.dumps, obj, 1) + for proto in protocols[2:]: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.a, obj.a) + self.assertEqual(unpickled.c, obj.c) + self.assertEqual(unpickled.__dict__, obj.__dict__) + 
self.assertNotHasAttr(unpickled, 'b') + + def test_object_with_private_attrs(self): + obj = WithPrivateAttrs(1) + for proto in protocols: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.get(), obj.get()) + + obj = WithPrivateAttrsSubclass(1, 2) + for proto in protocols: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.get(), obj.get()) + self.assertEqual(unpickled.get2(), obj.get2()) + + def test_object_with_private_slots(self): + obj = WithPrivateSlots(1) + self.assertRaises(TypeError, self.dumps, obj, 0) + self.assertRaises(TypeError, self.dumps, obj, 1) + for proto in protocols[2:]: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.get(), obj.get()) + + obj = WithPrivateSlotsSubclass(1, 2) + self.assertRaises(TypeError, self.dumps, obj, 0) + self.assertRaises(TypeError, self.dumps, obj, 1) + for proto in protocols[2:]: + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertEqual(unpickled.get(), obj.get()) + self.assertEqual(unpickled.get2(), obj.get2()) + def test_compat_pickle(self): + if self.py_version < (3, 4): + self.skipTest("doesn't work in Python < 3.4'") tests = [ (range(1, 7), '__builtin__', 'xrange'), (map(int, '123'), 'itertools', 'imap'), @@ -4090,6 +4264,37 @@ def check_array(arr): # 2-D, non-contiguous check_array(arr[::2]) + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'pickle' has no attribute 'PickleBuffer' + def test_concurrent_mutation_in_buffer_with_bytearray(self): + def factory(): + s = b"a" * 16 + return bytearray(s), s + self.do_test_concurrent_mutation_in_buffer_callback(factory) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'pickle' has no attribute 'PickleBuffer' + def test_concurrent_mutation_in_buffer_with_memoryview(self): + def factory(): + obj = 
memoryview(b"a" * 32)[10:26] + sub = b"a" * len(obj) + return obj, sub + self.do_test_concurrent_mutation_in_buffer_callback(factory) + + def do_test_concurrent_mutation_in_buffer_callback(self, factory): + # See: https://github.com/python/cpython/issues/143308. + class R: + def __bool__(self): + buf.release() + return True + + for proto in range(5, pickle.HIGHEST_PROTOCOL + 1): + obj, sub = factory() + buf = pickle.PickleBuffer(obj) + buffer_callback = lambda _: R() + + with self.subTest(proto=proto, obj=obj, sub=sub): + res = self.dumps(buf, proto, buffer_callback=buffer_callback) + self.assertIn(sub, res) + def test_evil_class_mutating_dict(self): # https://github.com/python/cpython/issues/92930 from random import getrandbits @@ -4120,6 +4325,94 @@ def __reduce__(self): expected = "changed size during iteration" self.assertIn(expected, str(e)) + def fast_save_enter(self, create_data, minprotocol=0): + # gh-146059: Check that fast_save_leave() is called when + # fast_save_enter() is called. 
+ if not hasattr(self, "pickler"): + self.skipTest("need Pickler class") + + data = [create_data(i) for i in range(FAST_NESTING_LIMIT * 2)] + protocols = range(minprotocol, pickle.HIGHEST_PROTOCOL + 1) + for proto in protocols: + with self.subTest(proto=proto): + buf = io.BytesIO() + pickler = self.pickler(buf, protocol=proto) + # Enable fast mode (disables memo, enables cycle detection) + pickler.fast = 1 + pickler.dump(data) + + buf.seek(0) + data2 = self.unpickler(buf).load() + self.assertEqual(data2, data) + + def test_fast_save_enter_tuple(self): + self.fast_save_enter(lambda i: (i,)) + + def test_fast_save_enter_list(self): + self.fast_save_enter(lambda i: [i]) + + def test_fast_save_enter_frozenset(self): + self.fast_save_enter(lambda i: frozenset([i])) + + def test_fast_save_enter_set(self): + self.fast_save_enter(lambda i: set([i])) + + def test_fast_save_enter_frozendict(self): + if self.py_version < (3, 15): + self.skipTest('need frozendict') + self.fast_save_enter(lambda i: frozendict(key=i), minprotocol=2) + + def test_fast_save_enter_dict(self): + self.fast_save_enter(lambda i: {"key": i}) + + def deep_nested_struct(self, create_nested, + minprotocol=0, compare_equal=True, + depth=FAST_NESTING_LIMIT * 2): + # gh-146059: Check that fast_save_leave() is called when + # fast_save_enter() is called. 
+ if not hasattr(self, "pickler"): + self.skipTest("need Pickler class") + + data = None + for i in range(depth): + data = create_nested(data) + protocols = range(minprotocol, pickle.HIGHEST_PROTOCOL + 1) + for proto in protocols: + with self.subTest(proto=proto): + buf = io.BytesIO() + pickler = self.pickler(buf, protocol=proto) + # Enable fast mode (disables memo, enables cycle detection) + pickler.fast = 1 + pickler.dump(data) + + buf.seek(0) + data2 = self.unpickler(buf).load() + if compare_equal: + self.assertEqual(data2, data) + + def test_deep_nested_struct_tuple(self): + self.deep_nested_struct(lambda data: (data,)) + + def test_deep_nested_struct_list(self): + self.deep_nested_struct(lambda data: [data]) + + def test_deep_nested_struct_frozenset(self): + self.deep_nested_struct(lambda data: frozenset((1, data))) + + def test_deep_nested_struct_set(self): + self.deep_nested_struct(lambda data: {K(data)}, + depth=FAST_NESTING_LIMIT+1, + compare_equal=False) + + def test_deep_nested_struct_frozendict(self): + if self.py_version < (3, 15): + self.skipTest('need frozendict') + self.deep_nested_struct(lambda data: frozendict(x=data), + minprotocol=2) + + def test_deep_nested_struct_dict(self): + self.deep_nested_struct(lambda data: {'x': data}) + class BigmemPickleTests: @@ -4248,110 +4541,6 @@ def test_huge_str_64b(self, size): data = None -# Test classes for reduce_ex - -class R: - def __init__(self, reduce=None): - self.reduce = reduce - def __reduce__(self, proto): - return self.reduce - -class REX: - def __init__(self, reduce_ex=None): - self.reduce_ex = reduce_ex - def __reduce_ex__(self, proto): - return self.reduce_ex - -class REX_one(object): - """No __reduce_ex__ here, but inheriting it from object""" - _reduce_called = 0 - def __reduce__(self): - self._reduce_called = 1 - return REX_one, () - -class REX_two(object): - """No __reduce__ here, but inheriting it from object""" - _proto = None - def __reduce_ex__(self, proto): - self._proto = proto - 
return REX_two, () - -class REX_three(object): - _proto = None - def __reduce_ex__(self, proto): - self._proto = proto - return REX_two, () - def __reduce__(self): - raise TestFailed("This __reduce__ shouldn't be called") - -class REX_four(object): - """Calling base class method should succeed""" - _proto = None - def __reduce_ex__(self, proto): - self._proto = proto - return object.__reduce_ex__(self, proto) - -class REX_five(object): - """This one used to fail with infinite recursion""" - _reduce_called = 0 - def __reduce__(self): - self._reduce_called = 1 - return object.__reduce__(self) - -class REX_six(object): - """This class is used to check the 4th argument (list iterator) of - the reduce protocol. - """ - def __init__(self, items=None): - self.items = items if items is not None else [] - def __eq__(self, other): - return type(self) is type(other) and self.items == other.items - def append(self, item): - self.items.append(item) - def __reduce__(self): - return type(self), (), None, iter(self.items), None - -class REX_seven(object): - """This class is used to check the 5th argument (dict iterator) of - the reduce protocol. - """ - def __init__(self, table=None): - self.table = table if table is not None else {} - def __eq__(self, other): - return type(self) is type(other) and self.table == other.table - def __setitem__(self, key, value): - self.table[key] = value - def __reduce__(self): - return type(self), (), None, None, iter(self.table.items()) - -class REX_state(object): - """This class is used to check the 3th argument (state) of - the reduce protocol. 
- """ - def __init__(self, state=None): - self.state = state - def __eq__(self, other): - return type(self) is type(other) and self.state == other.state - def __setstate__(self, state): - self.state = state - def __reduce__(self): - return type(self), (), self.state - -class REX_None: - """ Setting __reduce_ex__ to None should fail """ - __reduce_ex__ = None - -class R_None: - """ Setting __reduce__ to None should fail """ - __reduce__ = None - -class C_None_setstate: - """ Setting __setstate__ to None should fail """ - def __getstate__(self): - return 1 - - __setstate__ = None - class CustomError(Exception): pass @@ -4361,80 +4550,17 @@ def __reduce__(self): UNPICKLEABLE = Unpickleable() +# For test_unpickleable_reconstructor and test_unpickleable_state_setter class UnpickleableCallable(Unpickleable): def __call__(self, *args, **kwargs): pass - -# Test classes for newobj - -class MyInt(int): - sample = 1 - -class MyFloat(float): - sample = 1.0 - -class MyComplex(complex): - sample = 1.0 + 0.0j - -class MyStr(str): - sample = "hello" - -class MyUnicode(str): - sample = "hello \u1234" - -class MyTuple(tuple): - sample = (1, 2, 3) - -class MyList(list): - sample = [1, 2, 3] - -class MyDict(dict): - sample = {"a": 1, "b": 2} - -class MySet(set): - sample = {"a", "b"} - -class MyFrozenSet(frozenset): - sample = frozenset({"a", "b"}) - -myclasses = [MyInt, MyFloat, - MyComplex, - MyStr, MyUnicode, - MyTuple, MyList, MyDict, MySet, MyFrozenSet] - -class MyIntWithNew(int): - def __new__(cls, value): - raise AssertionError - -class MyIntWithNew2(MyIntWithNew): - __new__ = int.__new__ - - -class SlotList(MyList): - __slots__ = ["foo"] - -# Ruff "redefined while unused" false positive here due to `global` variables -# being assigned (and then restored) from within test methods earlier in the file -class SimpleNewObj(int): # noqa: F811 - def __init__(self, *args, **kwargs): - # raise an error, to make sure this isn't called - raise TypeError("SimpleNewObj.__init__() didn't 
expect to get called") - def __eq__(self, other): - return int(self) == int(other) and self.__dict__ == other.__dict__ - -class ComplexNewObj(SimpleNewObj): - def __getnewargs__(self): - return ('%X' % self, 16) - -class ComplexNewObjEx(SimpleNewObj): - def __getnewargs_ex__(self): - return ('%X' % self,), {'base': 16} - +# For test_bad_getattr class BadGetattr: def __getattr__(self, key): self.foo +# For test_bad_newobj_class and test_bad_newobj_ex__class class NoNew: def __getattribute__(self, name): if name == '__new__': diff --git a/Lib/test/seq_tests.py b/Lib/test/seq_tests.py index 7b2d6521b1e..e8834c2bafc 100644 --- a/Lib/test/seq_tests.py +++ b/Lib/test/seq_tests.py @@ -439,7 +439,7 @@ def test_pickle(self): self.assertEqual(lst2, lst) self.assertNotEqual(id(lst2), id(lst)) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON; hangs") def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, self.type2test) support.check_free_after_iterating(self, reversed, self.type2test) diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 6b3f2c447e8..6635ec3474e 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -548,7 +548,6 @@ def requires_lzma(reason='requires lzma'): import lzma except ImportError: lzma = None - lzma = None # XXX: RUSTPYTHON; xz is not supported yet return unittest.skipUnless(lzma, reason) def requires_zstd(reason='requires zstd'): @@ -856,8 +855,6 @@ def gc_collect(): longer than expected. This function tries its best to force all garbage objects to disappear. 
""" - return # TODO: RUSTPYTHON - import gc gc.collect() gc.collect() @@ -865,13 +862,6 @@ def gc_collect(): @contextlib.contextmanager def disable_gc(): - # TODO: RUSTPYTHON; GC is not supported yet - try: - yield - finally: - pass - return - import gc have_gc = gc.isenabled() gc.disable() @@ -2004,10 +1994,6 @@ def _check_tracemalloc(): def check_free_after_iterating(test, iter, cls, args=()): - # TODO: RUSTPYTHON; GC is not supported yet - test.assertTrue(False) - return - done = False def wrapper(): class A(cls): @@ -3048,6 +3034,10 @@ def get_signal_name(exitcode): except KeyError: pass + # Format Windows exit status as hexadecimal + if 0xC0000000 <= exitcode: + return f"0x{exitcode:X}" + return None class BrokenIter: diff --git a/Lib/test/support/_hypothesis_stubs/__init__.py b/Lib/test/support/_hypothesis_stubs/__init__.py index 9a57c309616..6ba5bb814b9 100644 --- a/Lib/test/support/_hypothesis_stubs/__init__.py +++ b/Lib/test/support/_hypothesis_stubs/__init__.py @@ -1,6 +1,6 @@ +from enum import Enum import functools import unittest -from enum import Enum __all__ = [ "given", diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py index 98eaf0b2721..173d299afee 100644 --- a/Lib/test/support/ast_helper.py +++ b/Lib/test/support/ast_helper.py @@ -1,6 +1,5 @@ import ast - class ASTTestMixin: """Test mixing to have basic assertions for AST nodes.""" diff --git a/Lib/test/support/asyncore.py b/Lib/test/support/asyncore.py index 658c22fdcee..870e4283764 100644 --- a/Lib/test/support/asyncore.py +++ b/Lib/test/support/asyncore.py @@ -51,27 +51,17 @@ sophisticated high-performance network servers and clients a snap. 
""" -import os import select import socket import sys import time import warnings -from errno import ( - EAGAIN, - EALREADY, - EBADF, - ECONNABORTED, - ECONNRESET, - EINPROGRESS, - EINVAL, - EISCONN, - ENOTCONN, - EPIPE, - ESHUTDOWN, - EWOULDBLOCK, - errorcode, -) + +import os +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ + ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \ + errorcode + _DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}) diff --git a/Lib/test/support/bytecode_helper.py b/Lib/test/support/bytecode_helper.py index 4a3c8c2c4f1..f6426c3e285 100644 --- a/Lib/test/support/bytecode_helper.py +++ b/Lib/test/support/bytecode_helper.py @@ -1,10 +1,9 @@ """bytecode_helper - support tools for testing correct bytecode generation""" +import unittest import dis import io import opcode -import unittest - try: import _testinternalcapi except ImportError: diff --git a/Lib/test/support/channels.py b/Lib/test/support/channels.py index 3f7b46030fd..5352f7d4da3 100644 --- a/Lib/test/support/channels.py +++ b/Lib/test/support/channels.py @@ -1,23 +1,19 @@ """Cross-interpreter Channels High Level Module.""" import time -from concurrent.interpreters import _crossinterp -from concurrent.interpreters._crossinterp import ( - UNBOUND_ERROR, - UNBOUND_REMOVE, -) - import _interpchannels as _channels +from concurrent.interpreters import _crossinterp # aliases: from _interpchannels import ( - ChannelClosedError, - ChannelEmptyError, - ChannelError, - ChannelNotEmptyError, - ChannelNotFoundError, + ChannelError, ChannelNotFoundError, ChannelClosedError, + ChannelEmptyError, ChannelNotEmptyError, +) +from concurrent.interpreters._crossinterp import ( + UNBOUND_ERROR, UNBOUND_REMOVE, ) + __all__ = [ 'UNBOUND', 'UNBOUND_ERROR', 'UNBOUND_REMOVE', 'create', 'list_all', diff --git a/Lib/test/support/hashlib_helper.py b/Lib/test/support/hashlib_helper.py index 75dc2ba7506..7032257b068 100644 --- 
a/Lib/test/support/hashlib_helper.py +++ b/Lib/test/support/hashlib_helper.py @@ -2,7 +2,6 @@ import hashlib import importlib import unittest - from test.support.import_helper import import_module try: diff --git a/Lib/test/support/hypothesis_helper.py b/Lib/test/support/hypothesis_helper.py index 6e9e168f63a..a99a4963ffe 100644 --- a/Lib/test/support/hypothesis_helper.py +++ b/Lib/test/support/hypothesis_helper.py @@ -7,10 +7,9 @@ else: # Regrtest changes to use a tempdir as the working directory, so we have # to tell Hypothesis to use the original in order to persist the database. - from hypothesis.configuration import set_hypothesis_home_dir - from test.support import has_socket_support from test.support.os_helper import SAVEDCWD + from hypothesis.configuration import set_hypothesis_home_dir set_hypothesis_home_dir(os.path.join(SAVEDCWD, ".hypothesis")) diff --git a/Lib/test/support/i18n_helper.py b/Lib/test/support/i18n_helper.py index af97cdc9cb5..2e304f29e8b 100644 --- a/Lib/test/support/i18n_helper.py +++ b/Lib/test/support/i18n_helper.py @@ -3,10 +3,10 @@ import sys import unittest from pathlib import Path - from test.support import REPO_ROOT, TEST_HOME_DIR, requires_subprocess from test.test_tools import skip_if_missing + pygettext = Path(REPO_ROOT) / 'Tools' / 'i18n' / 'pygettext.py' msgid_pattern = re.compile(r'msgid(.*?)(?:msgid_plural|msgctxt|msgstr)', diff --git a/Lib/test/support/logging_helper.py b/Lib/test/support/logging_helper.py index db556c7f5ad..12fcca4f0f0 100644 --- a/Lib/test/support/logging_helper.py +++ b/Lib/test/support/logging_helper.py @@ -1,6 +1,5 @@ import logging.handlers - class TestHandler(logging.handlers.BufferingHandler): def __init__(self, matcher): # BufferingHandler takes a "capacity" argument diff --git a/Lib/test/support/os_helper.py b/Lib/test/support/os_helper.py index d3d6fa632f9..2c45fe2369e 100644 --- a/Lib/test/support/os_helper.py +++ b/Lib/test/support/os_helper.py @@ -13,6 +13,7 @@ from test import support + # 
Filename used for testing TESTFN_ASCII = '@test' diff --git a/Lib/test/support/pty_helper.py b/Lib/test/support/pty_helper.py index 7e1ae9e59b8..dbe7fa42909 100644 --- a/Lib/test/support/pty_helper.py +++ b/Lib/test/support/pty_helper.py @@ -10,12 +10,19 @@ from test.support.import_helper import import_module - def run_pty(script, input=b"dummy input\r", env=None): pty = import_module('pty') output = bytearray() [master, slave] = pty.openpty() args = (sys.executable, '-c', script) + + # Isolate readline from personal init files by setting INPUTRC + # to an empty file. See also GH-142353. + if env is None: + env = {**os.environ.copy(), "INPUTRC": os.devnull} + else: + env.setdefault("INPUTRC", os.devnull) + proc = subprocess.Popen(args, stdin=slave, stdout=slave, stderr=slave, env=env) os.close(slave) with ExitStack() as cleanup: diff --git a/Lib/test/support/script_helper.py b/Lib/test/support/script_helper.py index a338f484449..46ce950433d 100644 --- a/Lib/test/support/script_helper.py +++ b/Lib/test/support/script_helper.py @@ -3,16 +3,17 @@ import collections import importlib +import sys import os import os.path -import py_compile import subprocess -import sys -from importlib.util import source_from_cache +import py_compile +from importlib.util import source_from_cache from test import support from test.support.import_helper import make_legacy_pyc + # Cached result of the expensive test performed in the function below. 
__cached_interp_requires_environment = None diff --git a/Lib/test/support/smtpd.py b/Lib/test/support/smtpd.py index cf333aaf6b0..6537679db9a 100755 --- a/Lib/test/support/smtpd.py +++ b/Lib/test/support/smtpd.py @@ -70,17 +70,16 @@ # - Handle more ESMTP extensions # - handle error codes from the backend smtpd -import collections +import sys +import os import errno import getopt -import os -import socket -import sys import time -from email._header_value_parser import get_addr_spec, get_angle_addr +import socket +import collections +from test.support import asyncore, asynchat from warnings import warn - -from test.support import asynchat, asyncore +from email._header_value_parser import get_addr_spec, get_angle_addr __all__ = [ "SMTPChannel", "SMTPServer", "DebuggingServer", "PureProxy", diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 655ffbea0db..a41e487f3e4 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -2,8 +2,8 @@ import errno import os.path import socket -import subprocess import sys +import subprocess import tempfile import unittest diff --git a/Lib/test/support/strace_helper.py b/Lib/test/support/strace_helper.py index abc93dee2ce..cf95f7bdc7d 100644 --- a/Lib/test/support/strace_helper.py +++ b/Lib/test/support/strace_helper.py @@ -1,11 +1,10 @@ -import os import re import sys import textwrap +import os import unittest from dataclasses import dataclass from functools import cache - from test import support from test.support.script_helper import run_python_until_end diff --git a/Lib/test/support/testcase.py b/Lib/test/support/testcase.py index e617b19b6ac..fad1e4cb349 100644 --- a/Lib/test/support/testcase.py +++ b/Lib/test/support/testcase.py @@ -1,64 +1,6 @@ from math import copysign, isnan -# XXX: RUSTPYTHON: removed in 3.14 -class ExtraAssertions: - - def assertIsSubclass(self, cls, superclass, msg=None): - if issubclass(cls, superclass): - return - standardMsg = f'{cls!r} 
is not a subclass of {superclass!r}' - self.fail(self._formatMessage(msg, standardMsg)) - - def assertNotIsSubclass(self, cls, superclass, msg=None): - if not issubclass(cls, superclass): - return - standardMsg = f'{cls!r} is a subclass of {superclass!r}' - self.fail(self._formatMessage(msg, standardMsg)) - - def assertHasAttr(self, obj, name, msg=None): - if not hasattr(obj, name): - if isinstance(obj, types.ModuleType): - standardMsg = f'module {obj.__name__!r} has no attribute {name!r}' - elif isinstance(obj, type): - standardMsg = f'type object {obj.__name__!r} has no attribute {name!r}' - else: - standardMsg = f'{type(obj).__name__!r} object has no attribute {name!r}' - self.fail(self._formatMessage(msg, standardMsg)) - - def assertNotHasAttr(self, obj, name, msg=None): - if hasattr(obj, name): - if isinstance(obj, types.ModuleType): - standardMsg = f'module {obj.__name__!r} has unexpected attribute {name!r}' - elif isinstance(obj, type): - standardMsg = f'type object {obj.__name__!r} has unexpected attribute {name!r}' - else: - standardMsg = f'{type(obj).__name__!r} object has unexpected attribute {name!r}' - self.fail(self._formatMessage(msg, standardMsg)) - - def assertStartsWith(self, s, prefix, msg=None): - if s.startswith(prefix): - return - standardMsg = f"{s!r} doesn't start with {prefix!r}" - self.fail(self._formatMessage(msg, standardMsg)) - - def assertNotStartsWith(self, s, prefix, msg=None): - if not s.startswith(prefix): - return - self.fail(self._formatMessage(msg, f"{s!r} starts with {prefix!r}")) - - def assertEndsWith(self, s, suffix, msg=None): - if s.endswith(suffix): - return - standardMsg = f"{s!r} doesn't end with {suffix!r}" - self.fail(self._formatMessage(msg, standardMsg)) - - def assertNotEndsWith(self, s, suffix, msg=None): - if not s.endswith(suffix): - return - self.fail(self._formatMessage(msg, f"{s!r} ends with {suffix!r}")) - - class ExceptionIsLikeMixin: def assertExceptionIsLike(self, exc, template): """ diff --git 
a/Lib/test/support/threading_helper.py b/Lib/test/support/threading_helper.py index 9b2b8f2dff0..cf87233f0e2 100644 --- a/Lib/test/support/threading_helper.py +++ b/Lib/test/support/threading_helper.py @@ -8,6 +8,7 @@ from test import support + #======================================================================= # Threading support to prevent reporting refleaks when running regrtest.py -R @@ -249,21 +250,32 @@ def requires_working_threading(*, module=False): return unittest.skipUnless(can_start_thread, msg) -def run_concurrently(worker_func, nthreads, args=(), kwargs={}): +def run_concurrently(worker_func, nthreads=None, args=(), kwargs={}): """ - Run the worker function concurrently in multiple threads. + Run the worker function(s) concurrently in multiple threads. + + If `worker_func` is a single callable, it is used for all threads. + If it is a list of callables, each callable is used for one thread. """ + from collections.abc import Iterable + + if nthreads is None: + nthreads = len(worker_func) + if not isinstance(worker_func, Iterable): + worker_func = [worker_func] * nthreads + assert len(worker_func) == nthreads + barrier = threading.Barrier(nthreads) - def wrapper_func(*args, **kwargs): + def wrapper_func(func, *args, **kwargs): # Wait for all threads to reach this point before proceeding. 
barrier.wait() - worker_func(*args, **kwargs) + func(*args, **kwargs) with catch_threading_exception() as cm: workers = [ - threading.Thread(target=wrapper_func, args=args, kwargs=kwargs) - for _ in range(nthreads) + threading.Thread(target=wrapper_func, args=(func, *args), kwargs=kwargs) + for func in worker_func ] with start_threads(workers): pass diff --git a/Lib/test/support/venv.py b/Lib/test/support/venv.py index b60f6097e65..757392b51c8 100644 --- a/Lib/test/support/venv.py +++ b/Lib/test/support/venv.py @@ -1,8 +1,8 @@ import contextlib import logging import os -import shlex import subprocess +import shlex import sys import sysconfig import tempfile diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index db09e50e8f4..d300337b915 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -1200,7 +1200,7 @@ def test_obsolete_write_lock(self): a = array.array('B', b"") self.assertRaises(BufferError, _testcapi.getbuffer_with_null_view, a) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON; hangs") def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, array.array, (self.typecode,)) diff --git a/Lib/test/test_asdl_parser.py b/Lib/test/test_asdl_parser.py new file mode 100644 index 00000000000..b9df6568123 --- /dev/null +++ b/Lib/test/test_asdl_parser.py @@ -0,0 +1,131 @@ +"""Tests for the asdl parser in Parser/asdl.py""" + +import importlib.machinery +import importlib.util +import os +from os.path import dirname +import sys +import sysconfig +import unittest + + +# This test is only relevant for from-source builds of Python. 
+if not sysconfig.is_python_build(): + raise unittest.SkipTest('test irrelevant for an installed Python') + +src_base = dirname(dirname(dirname(__file__))) +parser_dir = os.path.join(src_base, 'Parser') + + +class TestAsdlParser(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Loads the asdl module dynamically, since it's not in a real importable + # package. + # Parses Python.asdl into an ast.Module and run the check on it. + # There's no need to do this for each test method, hence setUpClass. + sys.path.insert(0, parser_dir) + loader = importlib.machinery.SourceFileLoader( + 'asdl', os.path.join(parser_dir, 'asdl.py')) + spec = importlib.util.spec_from_loader('asdl', loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + cls.asdl = module + cls.mod = cls.asdl.parse(os.path.join(parser_dir, 'Python.asdl')) + cls.assertTrue(cls.asdl.check(cls.mod), 'Module validation failed') + + @classmethod + def tearDownClass(cls): + del sys.path[0] + + def setUp(self): + # alias stuff from the class, for convenience + self.asdl = TestAsdlParser.asdl + self.mod = TestAsdlParser.mod + self.types = self.mod.types + + def test_module(self): + self.assertEqual(self.mod.name, 'Python') + self.assertIn('stmt', self.types) + self.assertIn('expr', self.types) + self.assertIn('mod', self.types) + + def test_definitions(self): + defs = self.mod.dfns + self.assertIsInstance(defs[0], self.asdl.Type) + self.assertIsInstance(defs[0].value, self.asdl.Sum) + + self.assertIsInstance(self.types['withitem'], self.asdl.Product) + self.assertIsInstance(self.types['alias'], self.asdl.Product) + + def test_product(self): + alias = self.types['alias'] + self.assertEqual( + str(alias), + 'Product([Field(identifier, name), Field(identifier, asname, quantifiers=[OPTIONAL])], ' + '[Field(int, lineno), Field(int, col_offset), ' + 'Field(int, end_lineno, quantifiers=[OPTIONAL]), Field(int, end_col_offset, quantifiers=[OPTIONAL])])') + + def test_attributes(self): 
+ stmt = self.types['stmt'] + self.assertEqual(len(stmt.attributes), 4) + self.assertEqual(repr(stmt.attributes[0]), 'Field(int, lineno)') + self.assertEqual(repr(stmt.attributes[1]), 'Field(int, col_offset)') + self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, quantifiers=[OPTIONAL])') + self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, quantifiers=[OPTIONAL])') + + def test_constructor_fields(self): + ehandler = self.types['excepthandler'] + self.assertEqual(len(ehandler.types), 1) + self.assertEqual(len(ehandler.attributes), 4) + + cons = ehandler.types[0] + self.assertIsInstance(cons, self.asdl.Constructor) + self.assertEqual(len(cons.fields), 3) + + f0 = cons.fields[0] + self.assertEqual(f0.type, 'expr') + self.assertEqual(f0.name, 'type') + self.assertTrue(f0.opt) + + f1 = cons.fields[1] + self.assertEqual(f1.type, 'identifier') + self.assertEqual(f1.name, 'name') + self.assertTrue(f1.opt) + + f2 = cons.fields[2] + self.assertEqual(f2.type, 'stmt') + self.assertEqual(f2.name, 'body') + self.assertFalse(f2.opt) + self.assertTrue(f2.seq) + + def test_visitor(self): + class CustomVisitor(self.asdl.VisitorBase): + def __init__(self): + super().__init__() + self.names_with_seq = [] + + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type): + self.visit(type.value) + + def visitSum(self, sum): + for t in sum.types: + self.visit(t) + + def visitConstructor(self, cons): + for f in cons.fields: + if f.seq: + self.names_with_seq.append(cons.name) + + v = CustomVisitor() + v.visit(self.types['mod']) + self.assertEqual(v.names_with_seq, + ['Module', 'Module', 'Interactive', 'FunctionType']) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index 0578b755d65..00283ca05a0 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -114,7 +114,6 @@ def cleanup(): with 
self.assertRaisesRegex(AttributeError, msg): ast.AST() - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: .X object at 0x7e85c3a80> is not None def test_AST_garbage_collection(self): class X: pass @@ -1984,7 +1983,6 @@ def test_level_as_none(self): exec(code, ns) self.assertIn('sleep', ns) - @unittest.skip("TODO: RUSTPYTHON; crash") @skip_if_unlimited_stack_size @skip_emscripten_stack_overflow() def test_recursion_direct(self): @@ -1994,7 +1992,6 @@ def test_recursion_direct(self): with support.infinite_recursion(): compile(ast.Expression(e), "", "eval") - @unittest.skip("TODO: RUSTPYTHON; crash") @skip_if_unlimited_stack_size @skip_emscripten_stack_overflow() def test_recursion_indirect(self): @@ -2357,7 +2354,6 @@ def test_yield(self): self.expr(ast.Yield(ast.Name("x", ast.Store())), "must have Load") self.expr(ast.YieldFrom(ast.Name("x", ast.Store())), "must have Load") - @unittest.skip("TODO: RUSTPYTHON; thread 'main' panicked") def test_compare(self): left = ast.Name("x", ast.Load()) comp = ast.Compare(left, [ast.In()], []) diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py index dcd963b3355..e266d57742a 100644 --- a/Lib/test/test_asyncio/test_sendfile.py +++ b/Lib/test/test_asyncio/test_sendfile.py @@ -566,6 +566,10 @@ class EPollEventLoopTests(SendfileTestsBase, def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.EpollSelector()) + @unittest.skipIf(sys.platform != "win32", "TODO: RUSTPYTHON; Flaky on CI") + def test_sendfile_ssl_pre_and_post_data(self): + return super().test_sendfile_ssl_pre_and_post_data() + if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(SendfileTestsBase, test_utils.TestCase): @@ -573,6 +577,10 @@ class PollEventLoopTests(SendfileTestsBase, def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.PollSelector()) + @unittest.skipIf(sys.platform != "win32", "TODO: RUSTPYTHON; Flaky on CI") + def 
test_sendfile_ssl_pre_and_post_data(self): + return super().test_sendfile_ssl_pre_and_post_data() + # Should always exist. class SelectEventLoopTests(SendfileTestsBase, test_utils.TestCase): @@ -580,6 +588,10 @@ class SelectEventLoopTests(SendfileTestsBase, def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.SelectSelector()) + @unittest.skipIf(sys.platform != "win32", "TODO: RUSTPYTHON; Flaky on CI") + def test_sendfile_ssl_pre_and_post_data(self): + return super().test_sendfile_ssl_pre_and_post_data() + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 7ab6e1511d7..3e304c16642 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -282,7 +282,6 @@ def buffer_updated(self, nsize): with self.assertRaisesRegex(RuntimeError, 'empty buffer'): protocols._feed_data_to_buffered_proto(proto, b'12345') - @unittest.expectedFailure # TODO: RUSTPYTHON; - gc.collect() doesn't release SSLContext properly def test_start_tls_client_reg_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE @@ -349,7 +348,6 @@ async def client(addr): support.gc_collect() self.assertIsNone(client_context()) - @unittest.expectedFailure # TODO: RUSTPYTHON; - gc.collect() doesn't release SSLContext properly def test_create_connection_memory_leak(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE @@ -668,7 +666,6 @@ async def main(): self.loop.run_until_complete(main()) - @unittest.expectedFailure # TODO: RUSTPYTHON; - gc.collect() doesn't release SSLContext properly def test_handshake_timeout(self): # bpo-29970: Check that a connection is aborted if handshake is not # completed in timeout period, instead of remaining open indefinitely diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 8a291f1cb7e..9fbad3d0ea8 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -2989,10 
+2989,6 @@ class PyTask_CFutureSubclass_Tests(BaseTaskTests, test_utils.TestCase): all_tasks = staticmethod(tasks._py_all_tasks) current_task = staticmethod(tasks._py_current_task) - @unittest.expectedFailure # TODO: RUSTPYTHON; Actual: not called. - def test_log_destroyed_pending_task(self): - return super().test_log_destroyed_pending_task() - @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') @@ -3008,7 +3004,6 @@ def test_log_destroyed_pending_task(self): return super().test_log_destroyed_pending_task() - @unittest.skipUnless(hasattr(futures, '_CFuture'), 'requires the C _asyncio module') class PyTask_CFuture_Tests(BaseTaskTests, test_utils.TestCase): @@ -3018,10 +3013,6 @@ class PyTask_CFuture_Tests(BaseTaskTests, test_utils.TestCase): all_tasks = staticmethod(tasks._py_all_tasks) current_task = staticmethod(tasks._py_current_task) - @unittest.expectedFailure # TODO: RUSTPYTHON; Actual: not called. - def test_log_destroyed_pending_task(self): - return super().test_log_destroyed_pending_task() - class PyTask_PyFuture_Tests(BaseTaskTests, SetMethodsTest, test_utils.TestCase): @@ -3031,10 +3022,6 @@ class PyTask_PyFuture_Tests(BaseTaskTests, SetMethodsTest, all_tasks = staticmethod(tasks._py_all_tasks) current_task = staticmethod(tasks._py_current_task) - @unittest.expectedFailure # TODO: RUSTPYTHON; Actual: not called. - def test_log_destroyed_pending_task(self): - return super().test_log_destroyed_pending_task() - @add_subclass_tests class PyTask_PyFuture_SubclassTests(BaseTaskTests, test_utils.TestCase): @@ -3043,9 +3030,6 @@ class PyTask_PyFuture_SubclassTests(BaseTaskTests, test_utils.TestCase): all_tasks = staticmethod(tasks._py_all_tasks) current_task = staticmethod(tasks._py_current_task) - @unittest.expectedFailure # TODO: RUSTPYTHON; Actual: not called. 
- def test_log_destroyed_pending_task(self): - return super().test_log_destroyed_pending_task() @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py index ddd9f951143..d01d36ad3db 100644 --- a/Lib/test/test_audit.py +++ b/Lib/test/test_audit.py @@ -23,6 +23,7 @@ def run_test_in_subprocess(self, *args): with subprocess.Popen( [sys.executable, "-X utf8", AUDIT_TESTS_PY, *args], encoding="utf-8", + errors="backslashreplace", stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) as p: @@ -57,6 +58,7 @@ def test_block_add_hook(self): def test_block_add_hook_baseexception(self): self.do_test("test_block_add_hook_baseexception") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_marshal(self): import_helper.import_module("marshal") @@ -67,18 +69,33 @@ def test_pickle(self): self.do_test("test_pickle") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_monkeypatch(self): self.do_test("test_monkeypatch") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_open(self): self.do_test("test_open", os_helper.TESTFN) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cantrace(self): self.do_test("test_cantrace") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_mmap(self): self.do_test("test_mmap") + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_ctypes_call_function(self): + import_helper.import_module("ctypes") + self.do_test("test_ctypes_call_function") + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_posixsubprocess(self): + import_helper.import_module("_posixsubprocess") + self.do_test("test_posixsubprocess") + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_excepthook(self): returncode, events, stderr = self.run_python("test_excepthook") if not returncode: @@ -100,6 +117,7 @@ def test_unraisablehook(self): "RuntimeError('nonfatal-error') Exception ignored for audit hook test", ) + @unittest.expectedFailure # TODO: RUSTPYTHON def 
test_winreg(self): import_helper.import_module("winreg") returncode, events, stderr = self.run_python("test_winreg") @@ -125,8 +143,9 @@ def test_socket(self): self.assertEqual(events[0][0], "socket.gethostname") self.assertEqual(events[1][0], "socket.__new__") self.assertEqual(events[2][0], "socket.bind") - self.assertTrue(events[2][2].endswith("('127.0.0.1', 8080)")) + self.assertEndsWith(events[2][2], "('127.0.0.1', 8080)") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_gc(self): returncode, events, stderr = self.run_python("test_gc") if returncode: @@ -156,6 +175,7 @@ def test_http(self): self.assertIn('HTTP', events[1][2]) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_sqlite3(self): sqlite3 = import_helper.import_module("sqlite3") returncode, events, stderr = self.run_python("test_sqlite3") @@ -200,6 +220,7 @@ def test_sys_getframemodulename(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_threading(self): returncode, events, stderr = self.run_python("test_threading") if returncode: @@ -218,6 +239,7 @@ def test_threading(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_wmi_exec_query(self): import_helper.import_module("_wmi") returncode, events, stderr = self.run_python("test_wmi_exec_query") @@ -231,6 +253,7 @@ def test_wmi_exec_query(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_syslog(self): syslog = import_helper.import_module("syslog") @@ -292,6 +315,7 @@ def test_sys_monitoring_register_callback(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_winapi_createnamedpipe(self): winapi = import_helper.import_module("_winapi") @@ -313,6 +337,14 @@ def test_assert_unicode(self): if returncode: self.fail(stderr) + @support.support_remote_exec_only + @support.cpython_only + def test_sys_remote_exec(self): + returncode, events, stderr = 
self.run_python("test_sys_remote_exec") + self.assertTrue(any(["sys.remote_exec" in event for event in events])) + self.assertTrue(any(["cpython.remote_debugger_script" in event for event in events])) + if returncode: + self.fail(stderr) if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_bdb.py b/Lib/test/test_bdb.py index f15dae13eb3..590d8166b68 100644 --- a/Lib/test/test_bdb.py +++ b/Lib/test/test_bdb.py @@ -728,6 +728,7 @@ def test_until_in_caller_frame(self): with TracerRun(self) as tracer: tracer.runcall(tfunc_main) + @unittest.skipIf(hasattr(__import__("sys"), "addaudithook"), "TODO: RUSTPYTHON; Currently no conditional tracing toggle") @patch_list(sys.meta_path) def test_skip(self): # Check that tracing is skipped over the import statement in diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index cf0268c2ce5..13b0d4a2a22 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -2696,6 +2696,7 @@ def detach_readline(self): else: yield + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty(self): # Test input() functionality when wired to a tty self.check_input_tty("prompt", b"quux") @@ -2710,17 +2711,20 @@ def test_input_tty_non_ascii_unicode_errors(self): # Check stdin/stdout error handler is used when invoking PyOS_Readline() self.check_input_tty("prompté", b"quux\xe9", "ascii") + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty_null_in_prompt(self): self.check_input_tty("prompt\0", b"", expected='ValueError: input: prompt string cannot contain ' 'null characters') + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty_nonencodable_prompt(self): self.check_input_tty("prompté", b"quux", "ascii", stdout_errors='strict', expected="UnicodeEncodeError: 
'ascii' codec can't encode " "character '\\xe9' in position 6: ordinal not in " "range(128)") + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty_nondecodable_input(self): self.check_input_tty("prompt", b"quux\xe9", "ascii", stdin_errors='strict', expected="UnicodeDecodeError: 'ascii' codec can't decode " diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index df8a4d68892..a72bc03c329 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -1042,7 +1042,7 @@ def test_find_etc_raise_correct_error_messages(self): self.assertRaisesRegex(TypeError, r'\bendswith\b', b.endswith, x, None, None, None) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON; hangs") def test_free_after_iterating(self): test.support.check_free_after_iterating(self, iter, self.type2test) test.support.check_free_after_iterating(self, reversed, self.type2test) diff --git a/Lib/test/test_class.py b/Lib/test/test_class.py index 1aee6fb73d2..f2f99d366d9 100644 --- a/Lib/test/test_class.py +++ b/Lib/test/test_class.py @@ -887,14 +887,14 @@ class VarSizedSubclass(tuple): class TestInlineValues(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON; NameError: name 'has_inline_values' is not defined. + @unittest.expectedFailure # TODO: RUSTPYTHON; NameError: name 'has_inline_values' is not defined. 
def test_no_flags_for_slots_class(self): flags = NoManagedDict.__flags__ self.assertEqual(flags & Py_TPFLAGS_MANAGED_DICT, 0) self.assertEqual(flags & Py_TPFLAGS_INLINE_VALUES, 0) self.assertFalse(has_inline_values(NoManagedDict())) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 4 + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 4 def test_both_flags_for_regular_class(self): for cls in (Plain, WithAttrs): with self.subTest(cls=cls.__name__): @@ -903,7 +903,7 @@ def test_both_flags_for_regular_class(self): self.assertEqual(flags & Py_TPFLAGS_INLINE_VALUES, Py_TPFLAGS_INLINE_VALUES) self.assertTrue(has_inline_values(cls())) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 4 + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 4 def test_managed_dict_only_for_varsized_subclass(self): flags = VarSizedSubclass.__flags__ self.assertEqual(flags & Py_TPFLAGS_MANAGED_DICT, Py_TPFLAGS_MANAGED_DICT) @@ -1056,6 +1056,5 @@ def __init__(self): self.assertFalse(out, msg=out.decode('utf-8')) self.assertFalse(err, msg=err.decode('utf-8')) - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py new file mode 100644 index 00000000000..73bb942af7c --- /dev/null +++ b/Lib/test/test_clinic.py @@ -0,0 +1,4594 @@ +# Argument Clinic +# Copyright 2012-2013 by Larry Hastings. +# Licensed to the PSF under a contributor agreement. 
+ +from functools import partial +from test import support, test_tools +from test.support import os_helper +from test.support.os_helper import TESTFN, unlink, rmtree +from textwrap import dedent +from unittest import TestCase +import inspect +import os.path +import re +import sys +import unittest + +test_tools.skip_if_missing('clinic') +with test_tools.imports_under_tool('clinic'): + import libclinic + from libclinic import ClinicError, unspecified, NULL, fail + from libclinic.converters import int_converter, str_converter, self_converter + from libclinic.function import ( + Module, Class, Function, FunctionKind, Parameter, + permute_optional_groups, permute_right_option_groups, + permute_left_option_groups) + import clinic + from libclinic.clanguage import CLanguage + from libclinic.converter import converters, legacy_converters + from libclinic.return_converters import return_converters, int_return_converter + from libclinic.block_parser import Block, BlockParser + from libclinic.codegen import BlockPrinter, Destination + from libclinic.dsl_parser import DSLParser + from libclinic.cli import parse_file, Clinic + + +def repeat_fn(*functions): + def wrapper(test): + def wrapped(self): + for fn in functions: + with self.subTest(fn=fn): + test(self, fn) + return wrapped + return wrapper + +def _make_clinic(*, filename='clinic_tests', limited_capi=False): + clang = CLanguage(filename) + c = Clinic(clang, filename=filename, limited_capi=limited_capi) + c.block_parser = BlockParser('', clang) + return c + + +def _expect_failure(tc, parser, code, errmsg, *, filename=None, lineno=None, + strip=True): + """Helper for the parser tests. 
+ + tc: unittest.TestCase; passed self in the wrapper + parser: the clinic parser used for this test case + code: a str with input text (clinic code) + errmsg: the expected error message + filename: str, optional filename + lineno: int, optional line number + """ + code = dedent(code) + if strip: + code = code.strip() + errmsg = re.escape(errmsg) + with tc.assertRaisesRegex(ClinicError, errmsg) as cm: + parser(code) + if filename is not None: + tc.assertEqual(cm.exception.filename, filename) + if lineno is not None: + tc.assertEqual(cm.exception.lineno, lineno) + return cm.exception + + +def restore_dict(converters, old_converters): + converters.clear() + converters.update(old_converters) + + +def save_restore_converters(testcase): + testcase.addCleanup(restore_dict, converters, + converters.copy()) + testcase.addCleanup(restore_dict, legacy_converters, + legacy_converters.copy()) + testcase.addCleanup(restore_dict, return_converters, + return_converters.copy()) + + +class ClinicWholeFileTest(TestCase): + maxDiff = None + + def expect_failure(self, raw, errmsg, *, filename=None, lineno=None): + _expect_failure(self, self.clinic.parse, raw, errmsg, + filename=filename, lineno=lineno) + + def setUp(self): + save_restore_converters(self) + self.clinic = _make_clinic(filename="test.c") + + def test_eol(self): + # regression test: + # clinic's block parser didn't recognize + # the "end line" for the block if it + # didn't end in "\n" (as in, the last) + # byte of the file was '/'. + # so it would spit out an end line for you. + # and since you really already had one, + # the last line of the block got corrupted. 
+ raw = "/*[clinic]\nfoo\n[clinic]*/" + cooked = self.clinic.parse(raw).splitlines() + end_line = cooked[2].rstrip() + # this test is redundant, it's just here explicitly to catch + # the regression test so we don't forget what it looked like + self.assertNotEqual(end_line, "[clinic]*/[clinic]*/") + self.assertEqual(end_line, "[clinic]*/") + + def test_mangled_marker_line(self): + raw = """ + /*[clinic input] + [clinic start generated code]*/ + /*[clinic end generated code: foo]*/ + """ + err = ( + "Mangled Argument Clinic marker line: " + "'/*[clinic end generated code: foo]*/'" + ) + self.expect_failure(raw, err, filename="test.c", lineno=3) + + def test_checksum_mismatch(self): + raw = """ + /*[clinic input] + [clinic start generated code]*/ + /*[clinic end generated code: output=0123456789abcdef input=fedcba9876543210]*/ + """ + err = ("Checksum mismatch! " + "Expected '0123456789abcdef', computed 'da39a3ee5e6b4b0d'") + self.expect_failure(raw, err, filename="test.c", lineno=3) + + def test_garbage_after_stop_line(self): + raw = """ + /*[clinic input] + [clinic start generated code]*/foobarfoobar! 
+ """ + err = "Garbage after stop line: 'foobarfoobar!'" + self.expect_failure(raw, err, filename="test.c", lineno=2) + + def test_whitespace_before_stop_line(self): + raw = """ + /*[clinic input] + [clinic start generated code]*/ + """ + err = ( + "Whitespace is not allowed before the stop line: " + "' [clinic start generated code]*/'" + ) + self.expect_failure(raw, err, filename="test.c", lineno=2) + + def test_parse_with_body_prefix(self): + clang = CLanguage(None) + clang.body_prefix = "//" + clang.start_line = "//[{dsl_name} start]" + clang.stop_line = "//[{dsl_name} stop]" + cl = Clinic(clang, filename="test.c", limited_capi=False) + raw = dedent(""" + //[clinic start] + //module test + //[clinic stop] + """).strip() + out = cl.parse(raw) + expected = dedent(""" + //[clinic start] + //module test + // + //[clinic stop] + /*[clinic end generated code: output=da39a3ee5e6b4b0d input=65fab8adff58cf08]*/ + """).lstrip() # Note, lstrip() because of the newline + self.assertEqual(out, expected) + + def test_cpp_monitor_fail_nested_block_comment(self): + raw = """ + /* start + /* nested + */ + */ + """ + err = 'Nested block comment!' + self.expect_failure(raw, err, filename="test.c", lineno=2) + + def test_cpp_monitor_fail_invalid_format_noarg(self): + raw = """ + #if + a() + #endif + """ + err = 'Invalid format for #if line: no argument!' + self.expect_failure(raw, err, filename="test.c", lineno=1) + + def test_cpp_monitor_fail_invalid_format_toomanyargs(self): + raw = """ + #ifdef A B + a() + #endif + """ + err = 'Invalid format for #ifdef line: should be exactly one argument!' + self.expect_failure(raw, err, filename="test.c", lineno=1) + + def test_cpp_monitor_fail_no_matching_if(self): + raw = '#else' + err = '#else without matching #if / #ifdef / #ifndef!' 
+ self.expect_failure(raw, err, filename="test.c", lineno=1) + + def test_directive_output_unknown_preset(self): + raw = """ + /*[clinic input] + output preset nosuchpreset + [clinic start generated code]*/ + """ + err = "Unknown preset 'nosuchpreset'" + self.expect_failure(raw, err) + + def test_directive_output_cant_pop(self): + raw = """ + /*[clinic input] + output pop + [clinic start generated code]*/ + """ + err = "Can't 'output pop', stack is empty" + self.expect_failure(raw, err) + + def test_directive_output_print(self): + raw = dedent(""" + /*[clinic input] + output print 'I told you once.' + [clinic start generated code]*/ + """) + out = self.clinic.parse(raw) + # The generated output will differ for every run, but we can check that + # it starts with the clinic block, we check that it contains all the + # expected fields, and we check that it contains the checksum line. + self.assertStartsWith(out, dedent(""" + /*[clinic input] + output print 'I told you once.' + [clinic start generated code]*/ + """)) + fields = { + "cpp_endif", + "cpp_if", + "docstring_definition", + "docstring_prototype", + "impl_definition", + "impl_prototype", + "methoddef_define", + "methoddef_ifndef", + "parser_definition", + "parser_prototype", + } + for field in fields: + with self.subTest(field=field): + self.assertIn(field, out) + last_line = out.rstrip().split("\n")[-1] + self.assertStartsWith(last_line, "/*[clinic end generated code: output=") + + def test_directive_wrong_arg_number(self): + raw = dedent(""" + /*[clinic input] + preserve foo bar baz eggs spam ham mushrooms + [clinic start generated code]*/ + """) + err = "takes 1 positional argument but 8 were given" + self.expect_failure(raw, err) + + def test_unknown_destination_command(self): + raw = """ + /*[clinic input] + destination buffer nosuchcommand + [clinic start generated code]*/ + """ + err = "unknown destination command 'nosuchcommand'" + self.expect_failure(raw, err) + + def 
test_no_access_to_members_in_converter_init(self): + raw = """ + /*[python input] + class Custom_converter(CConverter): + converter = "some_c_function" + def converter_init(self): + self.function.noaccess + [python start generated code]*/ + /*[clinic input] + module test + test.fn + a: Custom + [clinic start generated code]*/ + """ + err = ( + "accessing self.function inside converter_init is disallowed!" + ) + self.expect_failure(raw, err) + + def test_clone_mismatch(self): + err = "'kind' of function and cloned function don't match!" + block = """ + /*[clinic input] + module m + @classmethod + m.f1 + a: object + [clinic start generated code]*/ + /*[clinic input] + @staticmethod + m.f2 = m.f1 + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=9) + + def test_badly_formed_return_annotation(self): + err = "Badly formed annotation for 'm.f': 'Custom'" + block = """ + /*[python input] + class Custom_return_converter(CReturnConverter): + def __init__(self): + raise ValueError("abc") + [python start generated code]*/ + /*[clinic input] + module m + m.f -> Custom + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=8) + + def test_star_after_vararg(self): + err = "'my_test_func' uses '*' more than once." + block = """ + /*[clinic input] + my_test_func + + pos_arg: object + *args: tuple + * + kw_arg: object + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=6) + + def test_vararg_after_star(self): + err = "'my_test_func' uses '*' more than once." + block = """ + /*[clinic input] + my_test_func + + pos_arg: object + * + *args: tuple + kw_arg: object + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=6) + + def test_module_already_got_one(self): + err = "Already defined module 'm'!" 
+ block = """ + /*[clinic input] + module m + module m + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=3) + + def test_destination_already_got_one(self): + err = "Destination already exists: 'test'" + block = """ + /*[clinic input] + destination test new buffer + destination test new buffer + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=3) + + def test_destination_does_not_exist(self): + err = "Destination does not exist: '/dev/null'" + block = """ + /*[clinic input] + output everything /dev/null + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=2) + + def test_class_already_got_one(self): + err = "Already defined class 'C'!" + block = """ + /*[clinic input] + class C "" "" + class C "" "" + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=3) + + def test_cant_nest_module_inside_class(self): + err = "Can't nest a module inside a class!" + block = """ + /*[clinic input] + class C "" "" + module C.m + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=3) + + def test_dest_buffer_not_empty_at_eof(self): + expected_warning = ("Destination buffer 'buffer' not empty at " + "end of file, emptying.") + expected_generated = dedent(""" + /*[clinic input] + output everything buffer + fn + a: object + / + [clinic start generated code]*/ + /*[clinic end generated code: output=da39a3ee5e6b4b0d input=1c4668687f5fd002]*/ + + /*[clinic input] + dump buffer + [clinic start generated code]*/ + + PyDoc_VAR(fn__doc__); + + PyDoc_STRVAR(fn__doc__, + "fn($module, a, /)\\n" + "--\\n" + "\\n"); + + #define FN_METHODDEF \\ + {"fn", (PyCFunction)fn, METH_O, fn__doc__}, + + static PyObject * + fn(PyObject *module, PyObject *a) + /*[clinic end generated code: output=be6798b148ab4e53 input=524ce2e021e4eba6]*/ + """) + block = dedent(""" + /*[clinic input] + output everything buffer + fn + a: object + / + [clinic start generated code]*/ + 
""") + with support.captured_stdout() as stdout: + generated = self.clinic.parse(block) + self.assertIn(expected_warning, stdout.getvalue()) + self.assertEqual(generated, expected_generated) + + def test_dest_clear(self): + err = "Can't clear destination 'file': it's not of type 'buffer'" + block = """ + /*[clinic input] + destination file clear + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=2) + + def test_directive_set_misuse(self): + err = "unknown variable 'ets'" + block = """ + /*[clinic input] + set ets tse + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=2) + + def test_directive_set_prefix(self): + block = dedent(""" + /*[clinic input] + set line_prefix '// ' + output everything suppress + output docstring_prototype buffer + fn + a: object + / + [clinic start generated code]*/ + /* We need to dump the buffer. + * If not, Argument Clinic will emit a warning */ + /*[clinic input] + dump buffer + [clinic start generated code]*/ + """) + generated = self.clinic.parse(block) + expected_docstring_prototype = "// PyDoc_VAR(fn__doc__);" + self.assertIn(expected_docstring_prototype, generated) + + def test_directive_set_suffix(self): + block = dedent(""" + /*[clinic input] + set line_suffix ' // test' + output everything suppress + output docstring_prototype buffer + fn + a: object + / + [clinic start generated code]*/ + /* We need to dump the buffer. 
+ * If not, Argument Clinic will emit a warning */ + /*[clinic input] + dump buffer + [clinic start generated code]*/ + """) + generated = self.clinic.parse(block) + expected_docstring_prototype = "PyDoc_VAR(fn__doc__); // test" + self.assertIn(expected_docstring_prototype, generated) + + def test_directive_set_prefix_and_suffix(self): + block = dedent(""" + /*[clinic input] + set line_prefix '{block comment start} ' + set line_suffix ' {block comment end}' + output everything suppress + output docstring_prototype buffer + fn + a: object + / + [clinic start generated code]*/ + /* We need to dump the buffer. + * If not, Argument Clinic will emit a warning */ + /*[clinic input] + dump buffer + [clinic start generated code]*/ + """) + generated = self.clinic.parse(block) + expected_docstring_prototype = "/* PyDoc_VAR(fn__doc__); */" + self.assertIn(expected_docstring_prototype, generated) + + def test_directive_printout(self): + block = dedent(""" + /*[clinic input] + output everything buffer + printout test + [clinic start generated code]*/ + """) + expected = dedent(""" + /*[clinic input] + output everything buffer + printout test + [clinic start generated code]*/ + test + /*[clinic end generated code: output=4e1243bd22c66e76 input=898f1a32965d44ca]*/ + """) + generated = self.clinic.parse(block) + self.assertEqual(generated, expected) + + def test_directive_preserve_twice(self): + err = "Can't have 'preserve' twice in one block!" + block = """ + /*[clinic input] + preserve + preserve + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=3) + + def test_directive_preserve_input(self): + err = "'preserve' only works for blocks that don't produce any output!" 
+ block = """ + /*[clinic input] + preserve + fn + a: object + / + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=6) + + def test_directive_preserve_output(self): + block = dedent(""" + /*[clinic input] + output everything buffer + preserve + [clinic start generated code]*/ + // Preserve this + /*[clinic end generated code: output=eaa49677ae4c1f7d input=559b5db18fddae6a]*/ + /*[clinic input] + dump buffer + [clinic start generated code]*/ + /*[clinic end generated code: output=da39a3ee5e6b4b0d input=524ce2e021e4eba6]*/ + """) + generated = self.clinic.parse(block) + self.assertEqual(generated, block) + + def test_directive_output_invalid_command(self): + err = dedent(""" + Invalid command or destination name 'cmd'. Must be one of: + - 'preset' + - 'push' + - 'pop' + - 'print' + - 'everything' + - 'cpp_if' + - 'docstring_prototype' + - 'docstring_definition' + - 'methoddef_define' + - 'impl_prototype' + - 'parser_prototype' + - 'parser_definition' + - 'cpp_endif' + - 'methoddef_ifndef' + - 'impl_definition' + """).strip() + block = """ + /*[clinic input] + output cmd buffer + [clinic start generated code]*/ + """ + self.expect_failure(block, err, lineno=2) + + def test_validate_cloned_init(self): + block = """ + /*[clinic input] + class C "void *" "" + C.meth + a: int + [clinic start generated code]*/ + /*[clinic input] + @classmethod + C.__init__ = C.meth + [clinic start generated code]*/ + """ + err = "'__init__' must be a normal method; got 'FunctionKind.CLASS_METHOD'!" 
+ self.expect_failure(block, err, lineno=8) + + def test_validate_cloned_new(self): + block = """ + /*[clinic input] + class C "void *" "" + C.meth + a: int + [clinic start generated code]*/ + /*[clinic input] + C.__new__ = C.meth + [clinic start generated code]*/ + """ + err = "'__new__' must be a class method" + self.expect_failure(block, err, lineno=7) + + def test_no_c_basename_cloned(self): + block = """ + /*[clinic input] + foo2 + [clinic start generated code]*/ + /*[clinic input] + foo as = foo2 + [clinic start generated code]*/ + """ + err = "No C basename provided after 'as' keyword" + self.expect_failure(block, err, lineno=5) + + def test_cloned_with_custom_c_basename(self): + raw = dedent(""" + /*[clinic input] + # Make sure we don't create spurious clinic/ directories. + output everything suppress + foo2 + [clinic start generated code]*/ + + /*[clinic input] + foo as foo1 = foo2 + [clinic start generated code]*/ + """) + self.clinic.parse(raw) + funcs = self.clinic.functions + self.assertEqual(len(funcs), 2) + self.assertEqual(funcs[1].name, "foo") + self.assertEqual(funcs[1].c_basename, "foo1") + + def test_cloned_with_illegal_c_basename(self): + block = """ + /*[clinic input] + class C "void *" "" + foo1 + [clinic start generated code]*/ + + /*[clinic input] + foo2 as .illegal. 
= foo1 + [clinic start generated code]*/ + """ + err = "Illegal C basename: '.illegal.'" + self.expect_failure(block, err, lineno=7) + + def test_cloned_forced_text_signature(self): + block = dedent(""" + /*[clinic input] + @text_signature "($module, a[, b])" + src + a: object + param a + b: object = NULL + / + + docstring + [clinic start generated code]*/ + + /*[clinic input] + dst = src + [clinic start generated code]*/ + """) + self.clinic.parse(block) + self.addCleanup(rmtree, "clinic") + funcs = self.clinic.functions + self.assertEqual(len(funcs), 2) + + src_docstring_lines = funcs[0].docstring.split("\n") + dst_docstring_lines = funcs[1].docstring.split("\n") + + # Signatures are copied. + self.assertEqual(src_docstring_lines[0], "src($module, a[, b])") + self.assertEqual(dst_docstring_lines[0], "dst($module, a[, b])") + + # Param docstrings are copied. + self.assertIn(" param a", src_docstring_lines) + self.assertIn(" param a", dst_docstring_lines) + + # Docstrings are not copied. + self.assertIn("docstring", src_docstring_lines) + self.assertNotIn("docstring", dst_docstring_lines) + + def test_cloned_forced_text_signature_illegal(self): + block = """ + /*[clinic input] + @text_signature "($module, a[, b])" + src + a: object + b: object = NULL + / + [clinic start generated code]*/ + + /*[clinic input] + @text_signature "($module, a_override[, b])" + dst = src + [clinic start generated code]*/ + """ + err = "Cannot use @text_signature when cloning a function" + self.expect_failure(block, err, lineno=11) + + def test_ignore_preprocessor_in_comments(self): + for dsl in "clinic", "python": + raw = dedent(f"""\ + /*[{dsl} input] + # CPP directives, valid or not, should be ignored in C comments. 
+ # + [{dsl} start generated code]*/ + """) + self.clinic.parse(raw) + + +class ParseFileUnitTest(TestCase): + def expect_parsing_failure( + self, *, filename, expected_error, verify=True, output=None + ): + errmsg = re.escape(dedent(expected_error).strip()) + with self.assertRaisesRegex(ClinicError, errmsg): + parse_file(filename, limited_capi=False) + + def test_parse_file_no_extension(self) -> None: + self.expect_parsing_failure( + filename="foo", + expected_error="Can't extract file type for file 'foo'" + ) + + def test_parse_file_strange_extension(self) -> None: + filenames_to_errors = { + "foo.rs": "Can't identify file type for file 'foo.rs'", + "foo.hs": "Can't identify file type for file 'foo.hs'", + "foo.js": "Can't identify file type for file 'foo.js'", + } + for filename, errmsg in filenames_to_errors.items(): + with self.subTest(filename=filename): + self.expect_parsing_failure(filename=filename, expected_error=errmsg) + + +class ClinicGroupPermuterTest(TestCase): + def _test(self, l, m, r, output): + computed = permute_optional_groups(l, m, r) + self.assertEqual(output, computed) + + def test_range(self): + self._test([['start']], ['stop'], [['step']], + ( + ('stop',), + ('start', 'stop',), + ('start', 'stop', 'step',), + )) + + def test_add_window(self): + self._test([['x', 'y']], ['ch'], [['attr']], + ( + ('ch',), + ('ch', 'attr'), + ('x', 'y', 'ch',), + ('x', 'y', 'ch', 'attr'), + )) + + def test_ludicrous(self): + self._test([['a1', 'a2', 'a3'], ['b1', 'b2']], ['c1'], [['d1', 'd2'], ['e1', 'e2', 'e3']], + ( + ('c1',), + ('b1', 'b2', 'c1'), + ('b1', 'b2', 'c1', 'd1', 'd2'), + ('a1', 'a2', 'a3', 'b1', 'b2', 'c1'), + ('a1', 'a2', 'a3', 'b1', 'b2', 'c1', 'd1', 'd2'), + ('a1', 'a2', 'a3', 'b1', 'b2', 'c1', 'd1', 'd2', 'e1', 'e2', 'e3'), + )) + + def test_right_only(self): + self._test([], [], [['a'],['b'],['c']], + ( + (), + ('a',), + ('a', 'b'), + ('a', 'b', 'c') + )) + + def test_have_left_options_but_required_is_empty(self): + def fn(): + 
permute_optional_groups(['a'], [], []) + self.assertRaises(ValueError, fn) + + +class ClinicLinearFormatTest(TestCase): + def _test(self, input, output, **kwargs): + computed = libclinic.linear_format(input, **kwargs) + self.assertEqual(output, computed) + + def test_empty_strings(self): + self._test('', '') + + def test_solo_newline(self): + self._test('\n', '\n') + + def test_no_substitution(self): + self._test(""" + abc + """, """ + abc + """) + + def test_empty_substitution(self): + self._test(""" + abc + {name} + def + """, """ + abc + def + """, name='') + + def test_single_line_substitution(self): + self._test(""" + abc + {name} + def + """, """ + abc + GARGLE + def + """, name='GARGLE') + + def test_multiline_substitution(self): + self._test(""" + abc + {name} + def + """, """ + abc + bingle + bungle + + def + """, name='bingle\nbungle\n') + + def test_text_before_block_marker(self): + regex = re.escape("found before '{marker}'") + with self.assertRaisesRegex(ClinicError, regex): + libclinic.linear_format("no text before marker for you! 
{marker}", + marker="not allowed!") + + def test_text_after_block_marker(self): + regex = re.escape("found after '{marker}'") + with self.assertRaisesRegex(ClinicError, regex): + libclinic.linear_format("{marker} no text after marker for you!", + marker="not allowed!") + + +class InertParser: + def __init__(self, clinic): + pass + + def parse(self, block): + pass + +class CopyParser: + def __init__(self, clinic): + pass + + def parse(self, block): + block.output = block.input + + +class ClinicBlockParserTest(TestCase): + def _test(self, input, output): + language = CLanguage(None) + + blocks = list(BlockParser(input, language)) + writer = BlockPrinter(language) + for block in blocks: + writer.print_block(block) + output = writer.f.getvalue() + assert output == input, "output != input!\n\noutput " + repr(output) + "\n\n input " + repr(input) + + def round_trip(self, input): + return self._test(input, input) + + def test_round_trip_1(self): + self.round_trip(""" + verbatim text here + lah dee dah + """) + def test_round_trip_2(self): + self.round_trip(""" + verbatim text here + lah dee dah +/*[inert] +abc +[inert]*/ +def +/*[inert checksum: 7b18d017f89f61cf17d47f92749ea6930a3f1deb]*/ +xyz +""") + + def _test_clinic(self, input, output): + language = CLanguage(None) + c = Clinic(language, filename="file", limited_capi=False) + c.parsers['inert'] = InertParser(c) + c.parsers['copy'] = CopyParser(c) + computed = c.parse(input) + self.assertEqual(output, computed) + + def test_clinic_1(self): + self._test_clinic(""" + verbatim text here + lah dee dah +/*[copy input] +def +[copy start generated code]*/ +abc +/*[copy end generated code: output=03cfd743661f0797 input=7b18d017f89f61cf]*/ +xyz +""", """ + verbatim text here + lah dee dah +/*[copy input] +def +[copy start generated code]*/ +def +/*[copy end generated code: output=7b18d017f89f61cf input=7b18d017f89f61cf]*/ +xyz +""") + + +class ClinicParserTest(TestCase): + + def parse(self, text): + c = _make_clinic() + parser 
= DSLParser(c) + block = Block(text) + parser.parse(block) + return block + + def parse_function(self, text, signatures_in_block=2, function_index=1): + block = self.parse(text) + s = block.signatures + self.assertEqual(len(s), signatures_in_block) + assert isinstance(s[0], Module) + assert isinstance(s[function_index], Function) + return s[function_index] + + def expect_failure(self, block, err, *, + filename=None, lineno=None, strip=True): + return _expect_failure(self, self.parse_function, block, err, + filename=filename, lineno=lineno, strip=strip) + + def checkDocstring(self, fn, expected): + self.assertTrue(hasattr(fn, "docstring")) + self.assertEqual(dedent(expected).strip(), + fn.docstring.strip()) + + def test_trivial(self): + parser = DSLParser(_make_clinic()) + block = Block(""" + module os + os.access + """) + parser.parse(block) + module, function = block.signatures + self.assertEqual("access", function.name) + self.assertEqual("os", module.name) + + def test_ignore_line(self): + block = self.parse(dedent(""" + # + module os + os.access + """)) + module, function = block.signatures + self.assertEqual("access", function.name) + self.assertEqual("os", module.name) + + def test_param(self): + function = self.parse_function(""" + module os + os.access + path: int + """) + self.assertEqual("access", function.name) + self.assertEqual(2, len(function.parameters)) + p = function.parameters['path'] + self.assertEqual('path', p.name) + self.assertIsInstance(p.converter, int_converter) + + def test_param_default(self): + function = self.parse_function(""" + module os + os.access + follow_symlinks: bool = True + """) + p = function.parameters['follow_symlinks'] + self.assertEqual(True, p.default) + + def test_param_with_continuations(self): + function = self.parse_function(r""" + module os + os.access + follow_symlinks: \ + bool \ + = \ + True + """) + p = function.parameters['follow_symlinks'] + self.assertEqual(True, p.default) + + def 
test_param_default_none(self): + function = self.parse_function(r""" + module test + test.func + obj: object = None + str: str(accept={str, NoneType}) = None + buf: Py_buffer(accept={str, buffer, NoneType}) = None + """) + p = function.parameters['obj'] + self.assertIs(p.default, None) + self.assertEqual(p.converter.py_default, 'None') + self.assertEqual(p.converter.c_default, 'Py_None') + + p = function.parameters['str'] + self.assertIs(p.default, None) + self.assertEqual(p.converter.py_default, 'None') + self.assertEqual(p.converter.c_default, 'NULL') + + p = function.parameters['buf'] + self.assertIs(p.default, None) + self.assertEqual(p.converter.py_default, 'None') + self.assertEqual(p.converter.c_default, '{NULL, NULL}') + + def test_param_default_null(self): + function = self.parse_function(r""" + module test + test.func + obj: object = NULL + str: str = NULL + buf: Py_buffer = NULL + fsencoded: unicode_fs_encoded = NULL + fsdecoded: unicode_fs_decoded = NULL + """) + p = function.parameters['obj'] + self.assertIs(p.default, NULL) + self.assertEqual(p.converter.py_default, '') + self.assertEqual(p.converter.c_default, 'NULL') + + p = function.parameters['str'] + self.assertIs(p.default, NULL) + self.assertEqual(p.converter.py_default, '') + self.assertEqual(p.converter.c_default, 'NULL') + + p = function.parameters['buf'] + self.assertIs(p.default, NULL) + self.assertEqual(p.converter.py_default, '') + self.assertEqual(p.converter.c_default, '{NULL, NULL}') + + p = function.parameters['fsencoded'] + self.assertIs(p.default, NULL) + self.assertEqual(p.converter.py_default, '') + self.assertEqual(p.converter.c_default, 'NULL') + + p = function.parameters['fsdecoded'] + self.assertIs(p.default, NULL) + self.assertEqual(p.converter.py_default, '') + self.assertEqual(p.converter.c_default, 'NULL') + + def test_param_default_str_literal(self): + function = self.parse_function(r""" + module test + test.func + str: str = ' \t\n\r\v\f\xa0' + buf: 
Py_buffer(accept={str, buffer}) = ' \t\n\r\v\f\xa0' + """) + p = function.parameters['str'] + self.assertEqual(p.default, ' \t\n\r\v\f\xa0') + self.assertEqual(p.converter.py_default, r"' \t\n\r\x0b\x0c\xa0'") + self.assertEqual(p.converter.c_default, r'" \t\n\r\v\f\u00a0"') + + p = function.parameters['buf'] + self.assertEqual(p.default, ' \t\n\r\v\f\xa0') + self.assertEqual(p.converter.py_default, r"' \t\n\r\x0b\x0c\xa0'") + self.assertEqual(p.converter.c_default, + r'{.buf = " \t\n\r\v\f\302\240", .obj = NULL, .len = 8}') + + def test_param_default_bytes_literal(self): + function = self.parse_function(r""" + module test + test.func + str: str(accept={robuffer}) = b' \t\n\r\v\f\xa0' + buf: Py_buffer = b' \t\n\r\v\f\xa0' + """) + p = function.parameters['str'] + self.assertEqual(p.default, b' \t\n\r\v\f\xa0') + self.assertEqual(p.converter.py_default, r"b' \t\n\r\x0b\x0c\xa0'") + self.assertEqual(p.converter.c_default, r'" \t\n\r\v\f\240"') + + p = function.parameters['buf'] + self.assertEqual(p.default, b' \t\n\r\v\f\xa0') + self.assertEqual(p.converter.py_default, r"b' \t\n\r\x0b\x0c\xa0'") + self.assertEqual(p.converter.c_default, + r'{.buf = " \t\n\r\v\f\240", .obj = NULL, .len = 7}') + + def test_param_default_byte_literal(self): + function = self.parse_function(r""" + module test + test.func + zero: char = b'\0' + one: char = b'\1' + lf: char = b'\n' + nbsp: char = b'\xa0' + """) + p = function.parameters['zero'] + self.assertEqual(p.default, b'\0') + self.assertEqual(p.converter.py_default, r"b'\x00'") + self.assertEqual(p.converter.c_default, r"'\0'") + + p = function.parameters['one'] + self.assertEqual(p.default, b'\1') + self.assertEqual(p.converter.py_default, r"b'\x01'") + self.assertEqual(p.converter.c_default, r"'\001'") + + p = function.parameters['lf'] + self.assertEqual(p.default, b'\n') + self.assertEqual(p.converter.py_default, r"b'\n'") + self.assertEqual(p.converter.c_default, r"'\n'") + + p = function.parameters['nbsp'] + 
self.assertEqual(p.default, b'\xa0') + self.assertEqual(p.converter.py_default, r"b'\xa0'") + self.assertEqual(p.converter.c_default, r"'\240'") + + def test_param_default_unicode_char(self): + function = self.parse_function(r""" + module test + test.func + zero: int(accept={str}) = '\0' + one: int(accept={str}) = '\1' + lf: int(accept={str}) = '\n' + nbsp: int(accept={str}) = '\xa0' + snake: int(accept={str}) = '\U0001f40d' + """) + p = function.parameters['zero'] + self.assertEqual(p.default, '\0') + self.assertEqual(p.converter.py_default, r"'\x00'") + self.assertEqual(p.converter.c_default, '0') + + p = function.parameters['one'] + self.assertEqual(p.default, '\1') + self.assertEqual(p.converter.py_default, r"'\x01'") + self.assertEqual(p.converter.c_default, '0x01') + + p = function.parameters['lf'] + self.assertEqual(p.default, '\n') + self.assertEqual(p.converter.py_default, r"'\n'") + self.assertEqual(p.converter.c_default, r"'\n'") + + p = function.parameters['nbsp'] + self.assertEqual(p.default, '\xa0') + self.assertEqual(p.converter.py_default, r"'\xa0'") + self.assertEqual(p.converter.c_default, '0xa0') + + p = function.parameters['snake'] + self.assertEqual(p.default, '\U0001f40d') + self.assertEqual(p.converter.py_default, "'\U0001f40d'") + self.assertEqual(p.converter.c_default, '0x1f40d') + + def test_param_default_bool(self): + function = self.parse_function(r""" + module test + test.func + bool: bool = True + intbool: bool(accept={int}) = True + intbool2: bool(accept={int}) = 2 + """) + p = function.parameters['bool'] + self.assertIs(p.default, True) + self.assertEqual(p.converter.py_default, 'True') + self.assertEqual(p.converter.c_default, '1') + + p = function.parameters['intbool'] + self.assertIs(p.default, True) + self.assertEqual(p.converter.py_default, 'True') + self.assertEqual(p.converter.c_default, '1') + + p = function.parameters['intbool2'] + self.assertEqual(p.default, 2) + self.assertEqual(p.converter.py_default, '2') + 
self.assertEqual(p.converter.c_default, '2') + + def test_param_default_expr_named_constant(self): + function = self.parse_function(""" + module os + os.access + follow_symlinks: int(c_default='MAXSIZE') = sys.maxsize + """) + p = function.parameters['follow_symlinks'] + self.assertEqual(sys.maxsize, p.default) + self.assertEqual("MAXSIZE", p.converter.c_default) + + err = ( + "When you specify a named constant ('sys.maxsize') as your default value, " + "you MUST specify a valid c_default." + ) + block = """ + module os + os.access + follow_symlinks: int = sys.maxsize + """ + self.expect_failure(block, err, lineno=2) + + def test_param_with_bizarre_default_fails_correctly(self): + template = """ + module os + os.access + follow_symlinks: int = {default} + """ + err = "Unsupported expression as default value" + for bad_default_value in ( + "{1, 2, 3}", + "3 if bool() else 4", + "[x for x in range(42)]" + ): + with self.subTest(bad_default=bad_default_value): + block = template.format(default=bad_default_value) + self.expect_failure(block, err, lineno=2) + + def test_unspecified_not_allowed_as_default_value(self): + block = """ + module os + os.access + follow_symlinks: int(c_default='MAXSIZE') = unspecified + """ + err = "'unspecified' is not a legal default value!" + exc = self.expect_failure(block, err, lineno=2) + self.assertNotIn('Malformed expression given as default value', str(exc)) + + def test_malformed_expression_as_default_value(self): + block = """ + module os + os.access + follow_symlinks: int(c_default='MAXSIZE') = 1/0 + """ + err = "Malformed expression given as default value" + self.expect_failure(block, err, lineno=2) + + def test_param_default_expr_binop(self): + err = ( + "When you specify an expression ('a + b') as your default value, " + "you MUST specify a valid c_default." 
+ ) + block = """ + fn + follow_symlinks: int = a + b + """ + self.expect_failure(block, err, lineno=1) + + def test_param_no_docstring(self): + function = self.parse_function(""" + module os + os.access + follow_symlinks: bool = True + something_else: str = '' + """) + self.assertEqual(3, len(function.parameters)) + conv = function.parameters['something_else'].converter + self.assertIsInstance(conv, str_converter) + + def test_param_default_parameters_out_of_order(self): + err = ( + "Can't have a parameter without a default ('something_else') " + "after a parameter with a default!" + ) + block = """ + module os + os.access + follow_symlinks: bool = True + something_else: str + """ + self.expect_failure(block, err, lineno=3) + + def disabled_test_converter_arguments(self): + function = self.parse_function(""" + module os + os.access + path: path_t(allow_fd=1) + """) + p = function.parameters['path'] + self.assertEqual(1, p.converter.args['allow_fd']) + + def test_function_docstring(self): + function = self.parse_function(""" + module os + os.stat as os_stat_fn + + path: str + Path to be examined + Ensure that multiple lines are indented correctly. + + Perform a stat system call on the given path. + + Ensure that multiple lines are indented correctly. + Ensure that multiple lines are indented correctly. + """) + self.checkDocstring(function, """ + stat($module, /, path) + -- + + Perform a stat system call on the given path. + + path + Path to be examined + Ensure that multiple lines are indented correctly. + + Ensure that multiple lines are indented correctly. + Ensure that multiple lines are indented correctly. 
+ """) + + def test_docstring_trailing_whitespace(self): + function = self.parse_function( + "module t\n" + "t.s\n" + " a: object\n" + " Param docstring with trailing whitespace \n" + "Func docstring summary with trailing whitespace \n" + " \n" + "Func docstring body with trailing whitespace \n" + ) + self.checkDocstring(function, """ + s($module, /, a) + -- + + Func docstring summary with trailing whitespace + + a + Param docstring with trailing whitespace + + Func docstring body with trailing whitespace + """) + + def test_explicit_parameters_in_docstring(self): + function = self.parse_function(dedent(""" + module foo + foo.bar + x: int + Documentation for x. + y: int + + This is the documentation for foo. + + Okay, we're done here. + """)) + self.checkDocstring(function, """ + bar($module, /, x, y) + -- + + This is the documentation for foo. + + x + Documentation for x. + + Okay, we're done here. + """) + + def test_docstring_with_comments(self): + function = self.parse_function(dedent(""" + module foo + foo.bar + x: int + # We're about to have + # the documentation for x. + Documentation for x. + # We've just had + # the documentation for x. + y: int + + # We're about to have + # the documentation for foo. + This is the documentation for foo. + # We've just had + # the documentation for foo. + + Okay, we're done here. + """)) + self.checkDocstring(function, """ + bar($module, /, x, y) + -- + + This is the documentation for foo. + + x + Documentation for x. + + Okay, we're done here. + """) + + def test_parser_regression_special_character_in_parameter_column_of_docstring_first_line(self): + function = self.parse_function(dedent(""" + module os + os.stat + path: str + This/used to break Clinic! + """)) + self.checkDocstring(function, """ + stat($module, /, path) + -- + + This/used to break Clinic! 
+ """) + + def test_c_name(self): + function = self.parse_function(""" + module os + os.stat as os_stat_fn + """) + self.assertEqual("os_stat_fn", function.c_basename) + + def test_base_invalid_syntax(self): + block = """ + module os + os.stat + invalid syntax: int = 42 + """ + err = "Function 'stat' has an invalid parameter declaration: 'invalid syntax: int = 42'" + self.expect_failure(block, err, lineno=2) + + def test_param_default_invalid_syntax(self): + block = """ + module os + os.stat + x: int = invalid syntax + """ + err = "Function 'stat' has an invalid parameter declaration:" + self.expect_failure(block, err, lineno=2) + + def test_cloning_nonexistent_function_correctly_fails(self): + block = """ + cloned = fooooooooooooooooo + This is trying to clone a nonexistent function!! + """ + err = "Couldn't find existing function 'fooooooooooooooooo'!" + with support.captured_stderr() as stderr: + self.expect_failure(block, err, lineno=0) + expected_debug_print = dedent("""\ + cls=None, module=, existing='fooooooooooooooooo' + (cls or module).functions=[] + """) + stderr = stderr.getvalue() + self.assertIn(expected_debug_print, stderr) + + def test_return_converter(self): + function = self.parse_function(""" + module os + os.stat -> int + """) + self.assertIsInstance(function.return_converter, int_return_converter) + + def test_return_converter_invalid_syntax(self): + block = """ + module os + os.stat -> invalid syntax + """ + err = "Badly formed annotation for 'os.stat': 'invalid syntax'" + self.expect_failure(block, err) + + def test_legacy_converter_disallowed_in_return_annotation(self): + block = """ + module os + os.stat -> "s" + """ + err = "Legacy converter 's' not allowed as a return converter" + self.expect_failure(block, err) + + def test_unknown_return_converter(self): + block = """ + module os + os.stat -> fooooooooooooooooooooooo + """ + err = "No available return converter called 'fooooooooooooooooooooooo'" + self.expect_failure(block, err) + + def 
test_star(self): + function = self.parse_function(""" + module os + os.access + * + follow_symlinks: bool = True + """) + p = function.parameters['follow_symlinks'] + self.assertEqual(inspect.Parameter.KEYWORD_ONLY, p.kind) + self.assertEqual(0, p.group) + + def test_group(self): + function = self.parse_function(""" + module window + window.border + [ + ls: int + ] + / + """) + p = function.parameters['ls'] + self.assertEqual(1, p.group) + + def test_left_group(self): + function = self.parse_function(""" + module curses + curses.addch + [ + y: int + Y-coordinate. + x: int + X-coordinate. + ] + ch: char + Character to add. + [ + attr: long + Attributes for the character. + ] + / + """) + dataset = ( + ('y', -1), ('x', -1), + ('ch', 0), + ('attr', 1), + ) + for name, group in dataset: + with self.subTest(name=name, group=group): + p = function.parameters[name] + self.assertEqual(p.group, group) + self.assertEqual(p.kind, inspect.Parameter.POSITIONAL_ONLY) + self.checkDocstring(function, """ + addch([y, x,] ch, [attr]) + + + y + Y-coordinate. + x + X-coordinate. + ch + Character to add. + attr + Attributes for the character. + """) + + def test_nested_groups(self): + function = self.parse_function(""" + module curses + curses.imaginary + [ + [ + y1: int + Y-coordinate. + y2: int + Y-coordinate. + ] + x1: int + X-coordinate. + x2: int + X-coordinate. + ] + ch: char + Character to add. + [ + attr1: long + Attributes for the character. + attr2: long + Attributes for the character. + attr3: long + Attributes for the character. + [ + attr4: long + Attributes for the character. + attr5: long + Attributes for the character. + attr6: long + Attributes for the character. 
+ ] + ] + / + """) + dataset = ( + ('y1', -2), ('y2', -2), + ('x1', -1), ('x2', -1), + ('ch', 0), + ('attr1', 1), ('attr2', 1), ('attr3', 1), + ('attr4', 2), ('attr5', 2), ('attr6', 2), + ) + for name, group in dataset: + with self.subTest(name=name, group=group): + p = function.parameters[name] + self.assertEqual(p.group, group) + self.assertEqual(p.kind, inspect.Parameter.POSITIONAL_ONLY) + + self.checkDocstring(function, """ + imaginary([[y1, y2,] x1, x2,] ch, [attr1, attr2, attr3, [attr4, attr5, + attr6]]) + + + y1 + Y-coordinate. + y2 + Y-coordinate. + x1 + X-coordinate. + x2 + X-coordinate. + ch + Character to add. + attr1 + Attributes for the character. + attr2 + Attributes for the character. + attr3 + Attributes for the character. + attr4 + Attributes for the character. + attr5 + Attributes for the character. + attr6 + Attributes for the character. + """) + + def test_disallowed_grouping__two_top_groups_on_left(self): + err = ( + "Function 'two_top_groups_on_left' has an unsupported group " + "configuration. (Unexpected state 2.b)" + ) + block = """ + module foo + foo.two_top_groups_on_left + [ + group1 : int + ] + [ + group2 : int + ] + param: int + """ + self.expect_failure(block, err, lineno=5) + + def test_disallowed_grouping__two_top_groups_on_right(self): + block = """ + module foo + foo.two_top_groups_on_right + param: int + [ + group1 : int + ] + [ + group2 : int + ] + """ + err = ( + "Function 'two_top_groups_on_right' has an unsupported group " + "configuration. (Unexpected state 6.b)" + ) + self.expect_failure(block, err) + + def test_disallowed_grouping__parameter_after_group_on_right(self): + block = """ + module foo + foo.parameter_after_group_on_right + param: int + [ + [ + group1 : int + ] + group2 : int + ] + """ + err = ( + "Function parameter_after_group_on_right has an unsupported group " + "configuration. 
(Unexpected state 6.a)" + ) + self.expect_failure(block, err) + + def test_disallowed_grouping__group_after_parameter_on_left(self): + block = """ + module foo + foo.group_after_parameter_on_left + [ + group2 : int + [ + group1 : int + ] + ] + param: int + """ + err = ( + "Function 'group_after_parameter_on_left' has an unsupported group " + "configuration. (Unexpected state 2.b)" + ) + self.expect_failure(block, err) + + def test_disallowed_grouping__empty_group_on_left(self): + block = """ + module foo + foo.empty_group + [ + [ + ] + group2 : int + ] + param: int + """ + err = ( + "Function 'empty_group' has an empty group. " + "All groups must contain at least one parameter." + ) + self.expect_failure(block, err) + + def test_disallowed_grouping__empty_group_on_right(self): + block = """ + module foo + foo.empty_group + param: int + [ + [ + ] + group2 : int + ] + """ + err = ( + "Function 'empty_group' has an empty group. " + "All groups must contain at least one parameter." + ) + self.expect_failure(block, err) + + def test_disallowed_grouping__no_matching_bracket(self): + block = """ + module foo + foo.empty_group + param: int + ] + group2: int + ] + """ + err = "Function 'empty_group' has a ']' without a matching '['" + self.expect_failure(block, err) + + def test_disallowed_grouping__must_be_position_only(self): + dataset = (""" + with_kwds + [ + * + a: object + ] + """, """ + with_kwds + [ + a: object + ] + """) + err = ( + "You cannot use optional groups ('[' and ']') unless all " + "parameters are positional-only ('/')" + ) + for block in dataset: + with self.subTest(block=block): + self.expect_failure(block, err) + + def test_no_parameters(self): + function = self.parse_function(""" + module foo + foo.bar + + Docstring + + """) + self.assertEqual("bar($module, /)\n--\n\nDocstring", function.docstring) + self.assertEqual(1, len(function.parameters)) # self! 
+ + def test_init_with_no_parameters(self): + function = self.parse_function(""" + module foo + class foo.Bar "unused" "notneeded" + foo.Bar.__init__ + + Docstring + + """, signatures_in_block=3, function_index=2) + + # self is not in the signature + self.assertEqual("Bar()\n--\n\nDocstring", function.docstring) + # but it *is* a parameter + self.assertEqual(1, len(function.parameters)) + + def test_illegal_module_line(self): + block = """ + module foo + foo.bar => int + / + """ + err = "Illegal function name: 'foo.bar => int'" + self.expect_failure(block, err) + + def test_illegal_c_basename(self): + block = """ + module foo + foo.bar as 935 + / + """ + err = "Illegal C basename: '935'" + self.expect_failure(block, err) + + def test_no_c_basename(self): + block = "foo as " + err = "No C basename provided after 'as' keyword" + self.expect_failure(block, err, strip=False) + + def test_single_star(self): + block = """ + module foo + foo.bar + * + * + """ + err = "Function 'bar' uses '*' more than once." + self.expect_failure(block, err) + + def test_parameters_required_after_star(self): + dataset = ( + "module foo\nfoo.bar\n *", + "module foo\nfoo.bar\n *\nDocstring here.", + "module foo\nfoo.bar\n this: int\n *", + "module foo\nfoo.bar\n this: int\n *\nDocstring.", + ) + err = "Function 'bar' specifies '*' without following parameters." 
+ for block in dataset: + with self.subTest(block=block): + self.expect_failure(block, err) + + def test_fulldisplayname_class(self): + dataset = ( + ("T", """ + class T "void *" "" + T.__init__ + """), + ("m.T", """ + module m + class m.T "void *" "" + @classmethod + m.T.__new__ + """), + ("m.T.C", """ + module m + class m.T "void *" "" + class m.T.C "void *" "" + m.T.C.__init__ + """), + ) + for name, code in dataset: + with self.subTest(name=name, code=code): + block = self.parse(code) + func = block.signatures[-1] + self.assertEqual(func.fulldisplayname, name) + + def test_fulldisplayname_meth(self): + dataset = ( + ("func", "func"), + ("m.func", """ + module m + m.func + """), + ("T.meth", """ + class T "void *" "" + T.meth + """), + ("m.T.meth", """ + module m + class m.T "void *" "" + m.T.meth + """), + ("m.T.C.meth", """ + module m + class m.T "void *" "" + class m.T.C "void *" "" + m.T.C.meth + """), + ) + for name, code in dataset: + with self.subTest(name=name, code=code): + block = self.parse(code) + func = block.signatures[-1] + self.assertEqual(func.fulldisplayname, name) + + def test_depr_star_invalid_format_1(self): + block = """ + module foo + foo.bar + this: int + * [from 3] + Docstring. + """ + err = ( + "Function 'bar': expected format '[from major.minor]' " + "where 'major' and 'minor' are integers; got '3'" + ) + self.expect_failure(block, err, lineno=3) + + def test_depr_star_invalid_format_2(self): + block = """ + module foo + foo.bar + this: int + * [from a.b] + Docstring. + """ + err = ( + "Function 'bar': expected format '[from major.minor]' " + "where 'major' and 'minor' are integers; got 'a.b'" + ) + self.expect_failure(block, err, lineno=3) + + def test_depr_star_invalid_format_3(self): + block = """ + module foo + foo.bar + this: int + * [from 1.2.3] + Docstring. 
+ """ + err = ( + "Function 'bar': expected format '[from major.minor]' " + "where 'major' and 'minor' are integers; got '1.2.3'" + ) + self.expect_failure(block, err, lineno=3) + + def test_parameters_required_after_depr_star(self): + block = """ + module foo + foo.bar + this: int + * [from 3.14] + Docstring. + """ + err = ( + "Function 'bar' specifies '* [from ...]' without " + "following parameters." + ) + self.expect_failure(block, err, lineno=4) + + def test_parameters_required_after_depr_star2(self): + block = """ + module foo + foo.bar + a: int + * [from 3.14] + * + b: int + Docstring. + """ + err = ( + "Function 'bar' specifies '* [from ...]' without " + "following parameters." + ) + self.expect_failure(block, err, lineno=4) + + def test_parameters_required_after_depr_star3(self): + block = """ + module foo + foo.bar + a: int + * [from 3.14] + *args: tuple + b: int + Docstring. + """ + err = ( + "Function 'bar' specifies '* [from ...]' without " + "following parameters." + ) + self.expect_failure(block, err, lineno=4) + + def test_depr_star_must_come_before_star(self): + block = """ + module foo + foo.bar + a: int + * + * [from 3.14] + b: int + Docstring. + """ + err = "Function 'bar': '* [from ...]' must precede '*'" + self.expect_failure(block, err, lineno=4) + + def test_depr_star_must_come_before_vararg(self): + block = """ + module foo + foo.bar + a: int + *args: tuple + * [from 3.14] + b: int + Docstring. + """ + err = "Function 'bar': '* [from ...]' must precede '*'" + self.expect_failure(block, err, lineno=4) + + def test_depr_star_duplicate(self): + block = """ + module foo + foo.bar + a: int + * [from 3.14] + b: int + * [from 3.14] + c: int + Docstring. + """ + err = "Function 'bar' uses '* [from 3.14]' more than once." + self.expect_failure(block, err, lineno=5) + + def test_depr_star_duplicate2(self): + block = """ + module foo + foo.bar + a: int + * [from 3.14] + b: int + * [from 3.15] + c: int + Docstring. 
+ """ + err = "Function 'bar': '* [from 3.15]' must precede '* [from 3.14]'" + self.expect_failure(block, err, lineno=5) + + def test_depr_slash_duplicate(self): + block = """ + module foo + foo.bar + a: int + / [from 3.14] + b: int + / [from 3.14] + c: int + Docstring. + """ + err = "Function 'bar' uses '/ [from 3.14]' more than once." + self.expect_failure(block, err, lineno=5) + + def test_depr_slash_duplicate2(self): + block = """ + module foo + foo.bar + a: int + / [from 3.15] + b: int + / [from 3.14] + c: int + Docstring. + """ + err = "Function 'bar': '/ [from 3.14]' must precede '/ [from 3.15]'" + self.expect_failure(block, err, lineno=5) + + def test_single_slash(self): + block = """ + module foo + foo.bar + / + / + """ + err = ( + "Function 'bar' has an unsupported group configuration. " + "(Unexpected state 0.d)" + ) + self.expect_failure(block, err) + + def test_parameters_required_before_depr_slash(self): + block = """ + module foo + foo.bar + / [from 3.14] + Docstring. + """ + err = ( + "Function 'bar' specifies '/ [from ...]' without " + "preceding parameters." + ) + self.expect_failure(block, err, lineno=2) + + def test_parameters_required_before_depr_slash2(self): + block = """ + module foo + foo.bar + a: int + / + / [from 3.14] + Docstring. + """ + err = ( + "Function 'bar' specifies '/ [from ...]' without " + "preceding parameters." + ) + self.expect_failure(block, err, lineno=4) + + def test_double_slash(self): + block = """ + module foo + foo.bar + a: int + / + b: int + / + """ + err = "Function 'bar' uses '/' more than once." 
+ self.expect_failure(block, err) + + def test_slash_after_star(self): + block = """ + module foo + foo.bar + x: int + y: int + * + z: int + / + """ + err = "Function 'bar': '/' must precede '*'" + self.expect_failure(block, err) + + def test_slash_after_vararg(self): + block = """ + module foo + foo.bar + x: int + y: int + *args: tuple + z: int + / + """ + err = "Function 'bar': '/' must precede '*'" + self.expect_failure(block, err) + + def test_depr_star_must_come_after_slash(self): + block = """ + module foo + foo.bar + a: int + * [from 3.14] + / + b: int + Docstring. + """ + err = "Function 'bar': '/' must precede '* [from ...]'" + self.expect_failure(block, err, lineno=4) + + def test_depr_star_must_come_after_depr_slash(self): + block = """ + module foo + foo.bar + a: int + * [from 3.14] + / [from 3.14] + b: int + Docstring. + """ + err = "Function 'bar': '/ [from ...]' must precede '* [from ...]'" + self.expect_failure(block, err, lineno=4) + + def test_star_must_come_after_depr_slash(self): + block = """ + module foo + foo.bar + a: int + * + / [from 3.14] + b: int + Docstring. + """ + err = "Function 'bar': '/ [from ...]' must precede '*'" + self.expect_failure(block, err, lineno=4) + + def test_vararg_must_come_after_depr_slash(self): + block = """ + module foo + foo.bar + a: int + *args: tuple + / [from 3.14] + b: int + Docstring. + """ + err = "Function 'bar': '/ [from ...]' must precede '*'" + self.expect_failure(block, err, lineno=4) + + def test_depr_slash_must_come_after_slash(self): + block = """ + module foo + foo.bar + a: int + / [from 3.14] + / + b: int + Docstring. + """ + err = "Function 'bar': '/' must precede '/ [from ...]'" + self.expect_failure(block, err, lineno=4) + + def test_parameters_not_permitted_after_slash_for_now(self): + block = """ + module foo + foo.bar + / + x: int + """ + err = ( + "Function 'bar' has an unsupported group configuration. 
" + "(Unexpected state 0.d)" + ) + self.expect_failure(block, err) + + def test_parameters_no_more_than_one_vararg(self): + err = "Function 'bar' uses '*' more than once." + block = """ + module foo + foo.bar + *vararg1: tuple + *vararg2: tuple + """ + self.expect_failure(block, err, lineno=3) + + def test_function_not_at_column_0(self): + function = self.parse_function(""" + module foo + foo.bar + x: int + Nested docstring here, goeth. + * + y: str + Not at column 0! + """) + self.checkDocstring(function, """ + bar($module, /, x, *, y) + -- + + Not at column 0! + + x + Nested docstring here, goeth. + """) + + def test_docstring_only_summary(self): + function = self.parse_function(""" + module m + m.f + summary + """) + self.checkDocstring(function, """ + f($module, /) + -- + + summary + """) + + def test_docstring_empty_lines(self): + function = self.parse_function(""" + module m + m.f + + + """) + self.checkDocstring(function, """ + f($module, /) + -- + """) + + def test_docstring_explicit_params_placement(self): + function = self.parse_function(""" + module m + m.f + a: int + Param docstring for 'a' will be included + b: int + c: int + Param docstring for 'c' will be included + This is the summary line. + + We'll now place the params section here: + {parameters} + And now for something completely different! + (Note the added newline) + """) + self.checkDocstring(function, """ + f($module, /, a, b, c) + -- + + This is the summary line. + + We'll now place the params section here: + a + Param docstring for 'a' will be included + c + Param docstring for 'c' will be included + + And now for something completely different! 
+ (Note the added newline) + """) + + def test_indent_stack_no_tabs(self): + block = """ + module foo + foo.bar + *vararg1: tuple + \t*vararg2: tuple + """ + err = ("Tab characters are illegal in the Clinic DSL: " + r"'\t*vararg2: tuple'") + self.expect_failure(block, err) + + def test_indent_stack_illegal_outdent(self): + block = """ + module foo + foo.bar + a: object + b: object + """ + err = "Illegal outdent" + self.expect_failure(block, err) + + def test_directive(self): + parser = DSLParser(_make_clinic()) + parser.flag = False + parser.directives['setflag'] = lambda : setattr(parser, 'flag', True) + block = Block("setflag") + parser.parse(block) + self.assertTrue(parser.flag) + + def test_legacy_converters(self): + block = self.parse('module os\nos.access\n path: "s"') + module, function = block.signatures + conv = (function.parameters['path']).converter + self.assertIsInstance(conv, str_converter) + + def test_legacy_converters_non_string_constant_annotation(self): + err = "Annotations must be either a name, a function call, or a string" + dataset = ( + 'module os\nos.access\n path: 42', + 'module os\nos.access\n path: 42.42', + 'module os\nos.access\n path: 42j', + 'module os\nos.access\n path: b"42"', + ) + for block in dataset: + with self.subTest(block=block): + self.expect_failure(block, err, lineno=2) + + def test_other_bizarre_things_in_annotations_fail(self): + err = "Annotations must be either a name, a function call, or a string" + dataset = ( + 'module os\nos.access\n path: {"some": "dictionary"}', + 'module os\nos.access\n path: ["list", "of", "strings"]', + 'module os\nos.access\n path: (x for x in range(42))', + ) + for block in dataset: + with self.subTest(block=block): + self.expect_failure(block, err, lineno=2) + + def test_kwarg_splats_disallowed_in_function_call_annotations(self): + err = "Cannot use a kwarg splat in a function-call annotation" + dataset = ( + 'module fo\nfo.barbaz\n o: bool(**{None: "bang!"})', + 'module fo\nfo.barbaz -> 
bool(**{None: "bang!"})', + 'module fo\nfo.barbaz -> bool(**{"bang": 42})', + 'module fo\nfo.barbaz\n o: bool(**{"bang": None})', + ) + for block in dataset: + with self.subTest(block=block): + self.expect_failure(block, err) + + def test_self_param_placement(self): + err = ( + "A 'self' parameter, if specified, must be the very first thing " + "in the parameter block." + ) + block = """ + module foo + foo.func + a: int + self: self(type="PyObject *") + """ + self.expect_failure(block, err, lineno=3) + + def test_self_param_cannot_be_optional(self): + err = "A 'self' parameter cannot be marked optional." + block = """ + module foo + foo.func + self: self(type="PyObject *") = None + """ + self.expect_failure(block, err, lineno=2) + + def test_defining_class_param_placement(self): + err = ( + "A 'defining_class' parameter, if specified, must either be the " + "first thing in the parameter block, or come just after 'self'." + ) + block = """ + module foo + foo.func + self: self(type="PyObject *") + a: int + cls: defining_class + """ + self.expect_failure(block, err, lineno=4) + + def test_defining_class_param_cannot_be_optional(self): + err = "A 'defining_class' parameter cannot be marked optional." + block = """ + module foo + foo.func + cls: defining_class(type="PyObject *") = None + """ + self.expect_failure(block, err, lineno=2) + + def test_slot_methods_cannot_access_defining_class(self): + block = """ + module foo + class Foo "" "" + Foo.__init__ + cls: defining_class + a: object + """ + err = "Slot methods cannot access their defining class." + with self.assertRaisesRegex(ValueError, err): + self.parse_function(block) + + def test_new_must_be_a_class_method(self): + err = "'__new__' must be a class method!" + block = """ + module foo + class Foo "" "" + Foo.__new__ + """ + self.expect_failure(block, err, lineno=2) + + def test_init_must_be_a_normal_method(self): + err_template = "'__init__' must be a normal method; got 'FunctionKind.{}'!" 
+ annotations = { + "@classmethod": "CLASS_METHOD", + "@staticmethod": "STATIC_METHOD", + "@getter": "GETTER", + } + for annotation, invalid_kind in annotations.items(): + with self.subTest(annotation=annotation, invalid_kind=invalid_kind): + block = f""" + module foo + class Foo "" "" + {annotation} + Foo.__init__ + """ + expected_error = err_template.format(invalid_kind) + self.expect_failure(block, expected_error, lineno=3) + + def test_init_cannot_define_a_return_type(self): + block = """ + class Foo "" "" + Foo.__init__ -> long + """ + expected_error = "__init__ methods cannot define a return type" + self.expect_failure(block, expected_error, lineno=1) + + def test_invalid_getset(self): + annotations = ["@getter", "@setter"] + for annotation in annotations: + with self.subTest(annotation=annotation): + block = f""" + module foo + class Foo "" "" + {annotation} + Foo.property -> int + """ + expected_error = f"{annotation} method cannot define a return type" + self.expect_failure(block, expected_error, lineno=3) + + block = f""" + module foo + class Foo "" "" + {annotation} + Foo.property + obj: int + / + """ + expected_error = f"{annotation} methods cannot define parameters" + self.expect_failure(block, expected_error) + + def test_setter_docstring(self): + block = """ + module foo + class Foo "" "" + @setter + Foo.property + + foo + + bar + [clinic start generated code]*/ + """ + expected_error = "docstrings are only supported for @getter, not @setter" + self.expect_failure(block, expected_error) + + def test_duplicate_getset(self): + annotations = ["@getter", "@setter"] + for annotation in annotations: + with self.subTest(annotation=annotation): + block = f""" + module foo + class Foo "" "" + {annotation} + {annotation} + Foo.property -> int + """ + expected_error = f"Cannot apply {annotation} twice to the same function!" 
+ self.expect_failure(block, expected_error, lineno=3) + + def test_getter_and_setter_disallowed_on_same_function(self): + dup_annotations = [("@getter", "@setter"), ("@setter", "@getter")] + for dup in dup_annotations: + with self.subTest(dup=dup): + block = f""" + module foo + class Foo "" "" + {dup[0]} + {dup[1]} + Foo.property -> int + """ + expected_error = "Cannot apply both @getter and @setter to the same function!" + self.expect_failure(block, expected_error, lineno=3) + + def test_getset_no_class(self): + for annotation in "@getter", "@setter": + with self.subTest(annotation=annotation): + block = f""" + module m + {annotation} + m.func + """ + expected_error = "@getter and @setter must be methods" + self.expect_failure(block, expected_error, lineno=2) + + def test_duplicate_coexist(self): + err = "Called @coexist twice" + block = """ + module m + @coexist + @coexist + m.fn + """ + self.expect_failure(block, err, lineno=2) + + def test_unused_param(self): + block = self.parse(""" + module foo + foo.func + fn: object + k: float + i: float(unused=True) + / + * + flag: bool(unused=True) = False + """) + sig = block.signatures[1] # Function index == 1 + params = sig.parameters + conv = lambda fn: params[fn].converter + dataset = ( + {"name": "fn", "unused": False}, + {"name": "k", "unused": False}, + {"name": "i", "unused": True}, + {"name": "flag", "unused": True}, + ) + for param in dataset: + name, unused = param.values() + with self.subTest(name=name, unused=unused): + p = conv(name) + # Verify that the unused flag is parsed correctly. + self.assertEqual(unused, p.unused) + + # Now, check that we'll produce correct code. + decl = p.simple_declaration(in_parser=False) + if unused: + self.assertIn("Py_UNUSED", decl) + else: + self.assertNotIn("Py_UNUSED", decl) + + # Make sure the Py_UNUSED macro is not used in the parser body. 
+ parser_decl = p.simple_declaration(in_parser=True) + self.assertNotIn("Py_UNUSED", parser_decl) + + def test_scaffolding(self): + # test repr on special values + self.assertEqual(repr(unspecified), '') + self.assertEqual(repr(NULL), '') + + # test that fail fails + with support.captured_stdout() as stdout: + errmsg = 'The igloos are melting' + with self.assertRaisesRegex(ClinicError, errmsg) as cm: + fail(errmsg, filename='clown.txt', line_number=69) + exc = cm.exception + self.assertEqual(exc.filename, 'clown.txt') + self.assertEqual(exc.lineno, 69) + self.assertEqual(stdout.getvalue(), "") + + def test_non_ascii_character_in_docstring(self): + block = """ + module test + test.fn + a: int + á param docstring + docstring fü bár baß + """ + with support.captured_stdout() as stdout: + self.parse(block) + # The line numbers are off; this is a known limitation. + expected = dedent("""\ + Warning: + Non-ascii characters are not allowed in docstrings: 'á' + + Warning: + Non-ascii characters are not allowed in docstrings: 'ü', 'á', 'ß' + + """) + self.assertEqual(stdout.getvalue(), expected) + + def test_illegal_c_identifier(self): + err = "Illegal C identifier: 17a" + block = """ + module test + test.fn + a as 17a: int + """ + self.expect_failure(block, err, lineno=2) + + def test_cannot_convert_special_method(self): + err = "'__len__' is a special method and cannot be converted" + block = """ + class T "" "" + T.__len__ + """ + self.expect_failure(block, err, lineno=1) + + def test_cannot_specify_pydefault_without_default(self): + err = "You can't specify py_default without specifying a default value!" 
+ block = """ + fn + a: object(py_default='NULL') + """ + self.expect_failure(block, err, lineno=1) + + def test_vararg_cannot_take_default_value(self): + err = "Function 'fn' has an invalid parameter declaration:" + block = """ + fn + *args: tuple = None + """ + self.expect_failure(block, err, lineno=1) + + def test_default_is_not_of_correct_type(self): + err = ("int_converter: default value 2.5 for field 'a' " + "is not of type 'int'") + block = """ + fn + a: int = 2.5 + """ + self.expect_failure(block, err, lineno=1) + + def test_invalid_legacy_converter(self): + err = "'fhi' is not a valid legacy converter" + block = """ + fn + a: 'fhi' + """ + self.expect_failure(block, err, lineno=1) + + def test_parent_class_or_module_does_not_exist(self): + err = "Parent class or module 'baz' does not exist" + block = """ + module m + baz.func + """ + self.expect_failure(block, err, lineno=1) + + def test_duplicate_param_name(self): + err = "You can't have two parameters named 'a'" + block = """ + module m + m.func + a: int + a: float + """ + self.expect_failure(block, err, lineno=3) + + def test_param_requires_custom_c_name(self): + err = "Parameter 'module' requires a custom C name" + block = """ + module m + m.func + module: int + """ + self.expect_failure(block, err, lineno=2) + + def test_state_func_docstring_assert_no_group(self): + err = "Function 'func' has a ']' without a matching '['" + block = """ + module m + m.func + ] + docstring + """ + self.expect_failure(block, err, lineno=2) + + def test_state_func_docstring_no_summary(self): + err = "Docstring for 'm.func' does not have a summary line!" + block = """ + module m + m.func + docstring1 + docstring2 + """ + self.expect_failure(block, err, lineno=3) + + def test_state_func_docstring_only_one_param_template(self): + err = "You may not specify {parameters} more than once in a docstring!" 
+ block = """ + module m + m.func + docstring summary + + these are the params: + {parameters} + these are the params again: + {parameters} + """ + self.expect_failure(block, err, lineno=7) + + def test_kind_defining_class(self): + function = self.parse_function(""" + module m + class m.C "PyObject *" "" + m.C.meth + cls: defining_class + """, signatures_in_block=3, function_index=2) + p = function.parameters['cls'] + self.assertEqual(p.kind, inspect.Parameter.POSITIONAL_ONLY) + + def test_disallow_defining_class_at_module_level(self): + err = "A 'defining_class' parameter cannot be defined at module level." + block = """ + module m + m.func + cls: defining_class + """ + self.expect_failure(block, err, lineno=2) + + +class ClinicExternalTest(TestCase): + maxDiff = None + + def setUp(self): + save_restore_converters(self) + + def run_clinic(self, *args): + with ( + support.captured_stdout() as out, + support.captured_stderr() as err, + self.assertRaises(SystemExit) as cm + ): + clinic.main(args) + return out.getvalue(), err.getvalue(), cm.exception.code + + def expect_success(self, *args): + out, err, code = self.run_clinic(*args) + if code != 0: + self.fail("\n".join([f"Unexpected failure: {args=}", out, err])) + self.assertEqual(err, "") + return out + + def expect_failure(self, *args): + out, err, code = self.run_clinic(*args) + self.assertNotEqual(code, 0, f"Unexpected success: {args=}") + return out, err + + def test_external(self): + CLINIC_TEST = 'clinic.test.c' + source = support.findfile(CLINIC_TEST) + with open(source, encoding='utf-8') as f: + orig_contents = f.read() + + # Run clinic CLI and verify that it does not complain. 
+ self.addCleanup(unlink, TESTFN) + out = self.expect_success("-f", "-o", TESTFN, source) + self.assertEqual(out, "") + + with open(TESTFN, encoding='utf-8') as f: + new_contents = f.read() + + self.assertEqual(new_contents, orig_contents) + + def test_no_change(self): + # bpo-42398: Test that the destination file is left unchanged if the + # content does not change. Moreover, check also that the file + # modification time does not change in this case. + code = dedent(""" + /*[clinic input] + [clinic start generated code]*/ + /*[clinic end generated code: output=da39a3ee5e6b4b0d input=da39a3ee5e6b4b0d]*/ + """) + with os_helper.temp_dir() as tmp_dir: + fn = os.path.join(tmp_dir, "test.c") + with open(fn, "w", encoding="utf-8") as f: + f.write(code) + pre_mtime = os.stat(fn).st_mtime_ns + self.expect_success(fn) + post_mtime = os.stat(fn).st_mtime_ns + # Don't change the file modification time + # if the content does not change + self.assertEqual(pre_mtime, post_mtime) + + def test_cli_force(self): + invalid_input = dedent(""" + /*[clinic input] + output preset block + module test + test.fn + a: int + [clinic start generated code]*/ + + const char *hand_edited = "output block is overwritten"; + /*[clinic end generated code: output=bogus input=bogus]*/ + """) + fail_msg = ( + "Checksum mismatch! Expected 'bogus', computed '2ed19'. " + "Suggested fix: remove all generated code including the end marker, " + "or use the '-f' option.\n" + ) + with os_helper.temp_dir() as tmp_dir: + fn = os.path.join(tmp_dir, "test.c") + with open(fn, "w", encoding="utf-8") as f: + f.write(invalid_input) + # First, run the CLI without -f and expect failure. + # Note, we cannot check the entire fail msg, because the path to + # the tmp file will change for every run. + _, err = self.expect_failure(fn) + self.assertEndsWith(err, fail_msg) + # Then, force regeneration; success expected. + out = self.expect_success("-f", fn) + self.assertEqual(out, "") + # Verify by checking the checksum. 
+ checksum = ( + "/*[clinic end generated code: " + "output=a2957bc4d43a3c2f input=9543a8d2da235301]*/\n" + ) + with open(fn, encoding='utf-8') as f: + generated = f.read() + self.assertEndsWith(generated, checksum) + + def test_cli_make(self): + c_code = dedent(""" + /*[clinic input] + [clinic start generated code]*/ + """) + py_code = "pass" + c_files = "file1.c", "file2.c" + py_files = "file1.py", "file2.py" + + def create_files(files, srcdir, code): + for fn in files: + path = os.path.join(srcdir, fn) + with open(path, "w", encoding="utf-8") as f: + f.write(code) + + with os_helper.temp_dir() as tmp_dir: + # add some folders, some C files and a Python file + create_files(c_files, tmp_dir, c_code) + create_files(py_files, tmp_dir, py_code) + + # create C files in externals/ dir + ext_path = os.path.join(tmp_dir, "externals") + with os_helper.temp_dir(path=ext_path) as externals: + create_files(c_files, externals, c_code) + + # run clinic in verbose mode with --make on tmpdir + out = self.expect_success("-v", "--make", "--srcdir", tmp_dir) + + # expect verbose mode to only mention the C files in tmp_dir + for filename in c_files: + with self.subTest(filename=filename): + path = os.path.join(tmp_dir, filename) + self.assertIn(path, out) + for filename in py_files: + with self.subTest(filename=filename): + path = os.path.join(tmp_dir, filename) + self.assertNotIn(path, out) + # don't expect C files from the externals dir + for filename in c_files: + with self.subTest(filename=filename): + path = os.path.join(ext_path, filename) + self.assertNotIn(path, out) + + def test_cli_make_exclude(self): + code = dedent(""" + /*[clinic input] + [clinic start generated code]*/ + """) + with os_helper.temp_dir(quiet=False) as tmp_dir: + # add some folders, some C files and a Python file + for fn in "file1.c", "file2.c", "file3.c", "file4.c": + path = os.path.join(tmp_dir, fn) + with open(path, "w", encoding="utf-8") as f: + f.write(code) + + # Run clinic in verbose mode with 
--make on tmpdir. + # Exclude file2.c and file3.c. + out = self.expect_success( + "-v", "--make", "--srcdir", tmp_dir, + "--exclude", os.path.join(tmp_dir, "file2.c"), + # The added ./ should be normalised away. + "--exclude", os.path.join(tmp_dir, "./file3.c"), + # Relative paths should also work. + "--exclude", "file4.c" + ) + + # expect verbose mode to only mention the C files in tmp_dir + self.assertIn("file1.c", out) + self.assertNotIn("file2.c", out) + self.assertNotIn("file3.c", out) + self.assertNotIn("file4.c", out) + + def test_cli_verbose(self): + with os_helper.temp_dir() as tmp_dir: + fn = os.path.join(tmp_dir, "test.c") + with open(fn, "w", encoding="utf-8") as f: + f.write("") + out = self.expect_success("-v", fn) + self.assertEqual(out.strip(), fn) + + @support.force_not_colorized + def test_cli_help(self): + out = self.expect_success("-h") + self.assertIn("usage: clinic.py", out) + + def test_cli_converters(self): + prelude = dedent(""" + Legacy converters: + B C D L O S U Y Z Z# + b c d f h i l p s s# s* u u# w* y y# y* z z# z* + + Converters: + """) + expected_converters = ( + "bool", + "byte", + "char", + "defining_class", + "double", + "fildes", + "float", + "int", + "long", + "long_long", + "object", + "Py_buffer", + "Py_complex", + "Py_ssize_t", + "Py_UNICODE", + "PyByteArrayObject", + "PyBytesObject", + "self", + "short", + "size_t", + "slice_index", + "str", + "uint16", + "uint32", + "uint64", + "uint8", + "unicode", + "unicode_fs_decoded", + "unicode_fs_encoded", + "unsigned_char", + "unsigned_int", + "unsigned_long", + "unsigned_long_long", + "unsigned_short", + ) + finale = dedent(""" + Return converters: + bool() + double() + float() + int() + long() + object() + Py_ssize_t() + size_t() + unsigned_int() + unsigned_long() + + All converters also accept (c_default=None, py_default=None, annotation=None). + All return converters also accept (py_default=None). 
+ """) + out = self.expect_success("--converters") + # We cannot simply compare the output, because the repr of the *accept* + # param may change (it's a set, thus unordered). So, let's compare the + # start and end of the expected output, and then assert that the + # converters appear lined up in alphabetical order. + self.assertStartsWith(out, prelude) + self.assertEndsWith(out, finale) + + out = out.removeprefix(prelude) + out = out.removesuffix(finale) + lines = out.split("\n") + for converter, line in zip(expected_converters, lines): + line = line.lstrip() + with self.subTest(converter=converter): + self.assertStartsWith(line, converter) + + def test_cli_fail_converters_and_filename(self): + _, err = self.expect_failure("--converters", "test.c") + msg = "can't specify --converters and a filename at the same time" + self.assertIn(msg, err) + + def test_cli_fail_no_filename(self): + _, err = self.expect_failure() + self.assertIn("no input files", err) + + def test_cli_fail_output_and_multiple_files(self): + _, err = self.expect_failure("-o", "out.c", "input.c", "moreinput.c") + msg = "error: can't use -o with multiple filenames" + self.assertIn(msg, err) + + def test_cli_fail_filename_or_output_and_make(self): + msg = "can't use -o or filenames with --make" + for opts in ("-o", "out.c"), ("filename.c",): + with self.subTest(opts=opts): + _, err = self.expect_failure("--make", *opts) + self.assertIn(msg, err) + + def test_cli_fail_make_without_srcdir(self): + _, err = self.expect_failure("--make", "--srcdir", "") + msg = "error: --srcdir must not be empty with --make" + self.assertIn(msg, err) + + def test_file_dest(self): + block = dedent(""" + /*[clinic input] + destination test new file {path}.h + output everything test + func + a: object + / + [clinic start generated code]*/ + """) + expected_checksum_line = ( + "/*[clinic end generated code: " + "output=da39a3ee5e6b4b0d input=b602ab8e173ac3bd]*/\n" + ) + expected_output = dedent("""\ + /*[clinic input] + 
preserve + [clinic start generated code]*/ + + PyDoc_VAR(func__doc__); + + PyDoc_STRVAR(func__doc__, + "func($module, a, /)\\n" + "--\\n" + "\\n"); + + #define FUNC_METHODDEF \\ + {"func", (PyCFunction)func, METH_O, func__doc__}, + + static PyObject * + func(PyObject *module, PyObject *a) + /*[clinic end generated code: output=3dde2d13002165b9 input=a9049054013a1b77]*/ + """) + with os_helper.temp_dir() as tmp_dir: + in_fn = os.path.join(tmp_dir, "test.c") + out_fn = os.path.join(tmp_dir, "test.c.h") + with open(in_fn, "w", encoding="utf-8") as f: + f.write(block) + with open(out_fn, "w", encoding="utf-8") as f: + f.write("") # Write an empty output file! + # Clinic should complain about the empty output file. + _, err = self.expect_failure(in_fn) + expected_err = (f"Modified destination file {out_fn!r}; " + "not overwriting!") + self.assertIn(expected_err, err) + # Run clinic again, this time with the -f option. + _ = self.expect_success("-f", in_fn) + # Read back the generated output. + with open(in_fn, encoding="utf-8") as f: + data = f.read() + expected_block = f"{block}{expected_checksum_line}" + self.assertEqual(data, expected_block) + with open(out_fn, encoding="utf-8") as f: + data = f.read() + self.assertEqual(data, expected_output) + +try: + import _testclinic as ac_tester +except ImportError: + ac_tester = None + +@unittest.skipIf(ac_tester is None, "_testclinic is missing") +class ClinicFunctionalTest(unittest.TestCase): + locals().update((name, getattr(ac_tester, name)) + for name in dir(ac_tester) if name.startswith('test_')) + + def check_depr(self, regex, fn, /, *args, **kwds): + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + # Record the line number, so we're sure we've got the correct stack + # level on the deprecation warning. 
+ _, lineno = fn(*args, **kwds), sys._getframe().f_lineno + self.assertEqual(cm.filename, __file__) + self.assertEqual(cm.lineno, lineno) + + def check_depr_star(self, pnames, fn, /, *args, name=None, **kwds): + if name is None: + name = fn.__qualname__ + if isinstance(fn, type): + name = f'{fn.__module__}.{name}' + regex = ( + fr"Passing( more than)?( [0-9]+)? positional argument(s)? to " + fr"{re.escape(name)}\(\) is deprecated. Parameters? {pnames} will " + fr"become( a)? keyword-only parameters? in Python 3\.14" + ) + self.check_depr(regex, fn, *args, **kwds) + + def check_depr_kwd(self, pnames, fn, *args, name=None, **kwds): + if name is None: + name = fn.__qualname__ + if isinstance(fn, type): + name = f'{fn.__module__}.{name}' + pl = 's' if ' ' in pnames else '' + regex = ( + fr"Passing keyword argument{pl} {pnames} to " + fr"{re.escape(name)}\(\) is deprecated. Parameter{pl} {pnames} " + fr"will become positional-only in Python 3\.14." + ) + self.check_depr(regex, fn, *args, **kwds) + + def test_objects_converter(self): + with self.assertRaises(TypeError): + ac_tester.objects_converter() + self.assertEqual(ac_tester.objects_converter(1, 2), (1, 2)) + self.assertEqual(ac_tester.objects_converter([], 'whatever class'), ([], 'whatever class')) + self.assertEqual(ac_tester.objects_converter(1), (1, None)) + + def test_bytes_object_converter(self): + with self.assertRaises(TypeError): + ac_tester.bytes_object_converter(1) + self.assertEqual(ac_tester.bytes_object_converter(b'BytesObject'), (b'BytesObject',)) + + def test_byte_array_object_converter(self): + with self.assertRaises(TypeError): + ac_tester.byte_array_object_converter(1) + byte_arr = bytearray(b'ByteArrayObject') + self.assertEqual(ac_tester.byte_array_object_converter(byte_arr), (byte_arr,)) + + def test_unicode_converter(self): + with self.assertRaises(TypeError): + ac_tester.unicode_converter(1) + self.assertEqual(ac_tester.unicode_converter('unicode'), ('unicode',)) + + def 
test_bool_converter(self): + with self.assertRaises(TypeError): + ac_tester.bool_converter(False, False, 'not a int') + self.assertEqual(ac_tester.bool_converter(), (True, True, True)) + self.assertEqual(ac_tester.bool_converter('', [], 5), (False, False, True)) + self.assertEqual(ac_tester.bool_converter(('not empty',), {1: 2}, 0), (True, True, False)) + + def test_bool_converter_c_default(self): + self.assertEqual(ac_tester.bool_converter_c_default(), (1, 0, -2, -3)) + self.assertEqual(ac_tester.bool_converter_c_default(False, True, False, True), + (0, 1, 0, 1)) + + def test_char_converter(self): + with self.assertRaises(TypeError): + ac_tester.char_converter(1) + with self.assertRaises(TypeError): + ac_tester.char_converter(b'ab') + chars = [b'A', b'\a', b'\b', b'\t', b'\n', b'\v', b'\f', b'\r', b'"', b"'", b'?', b'\\', b'\000', b'\377'] + expected = tuple(ord(c) for c in chars) + self.assertEqual(ac_tester.char_converter(), expected) + chars = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'0', b'a', b'b', b'c', b'd'] + expected = tuple(ord(c) for c in chars) + self.assertEqual(ac_tester.char_converter(*chars), expected) + + def test_unsigned_char_converter(self): + from _testcapi import UCHAR_MAX + with self.assertRaises(OverflowError): + ac_tester.unsigned_char_converter(-1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_char_converter(UCHAR_MAX + 1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_char_converter(0, UCHAR_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.unsigned_char_converter([]) + self.assertEqual(ac_tester.unsigned_char_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.unsigned_char_converter(0, 0, UCHAR_MAX + 1), (0, 0, 0)) + self.assertEqual(ac_tester.unsigned_char_converter(0, 0, (UCHAR_MAX + 1) * 3 + 123), (0, 0, 123)) + + def test_short_converter(self): + from _testcapi import SHRT_MIN, SHRT_MAX + with self.assertRaises(OverflowError): + ac_tester.short_converter(SHRT_MIN - 1) + with 
self.assertRaises(OverflowError): + ac_tester.short_converter(SHRT_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.short_converter([]) + self.assertEqual(ac_tester.short_converter(-1234), (-1234,)) + self.assertEqual(ac_tester.short_converter(4321), (4321,)) + + def test_unsigned_short_converter(self): + from _testcapi import USHRT_MAX + with self.assertRaises(ValueError): + ac_tester.unsigned_short_converter(-1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_short_converter(USHRT_MAX + 1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_short_converter(0, USHRT_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.unsigned_short_converter([]) + self.assertEqual(ac_tester.unsigned_short_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.unsigned_short_converter(0, 0, USHRT_MAX + 1), (0, 0, 0)) + self.assertEqual(ac_tester.unsigned_short_converter(0, 0, (USHRT_MAX + 1) * 3 + 123), (0, 0, 123)) + + def test_int_converter(self): + from _testcapi import INT_MIN, INT_MAX + with self.assertRaises(OverflowError): + ac_tester.int_converter(INT_MIN - 1) + with self.assertRaises(OverflowError): + ac_tester.int_converter(INT_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.int_converter(1, 2, 3) + with self.assertRaises(TypeError): + ac_tester.int_converter([]) + self.assertEqual(ac_tester.int_converter(), (12, 34, 45)) + self.assertEqual(ac_tester.int_converter(1, 2, '3'), (1, 2, ord('3'))) + + def test_unsigned_int_converter(self): + from _testcapi import UINT_MAX + with self.assertRaises(ValueError): + ac_tester.unsigned_int_converter(-1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_int_converter(UINT_MAX + 1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_int_converter(0, UINT_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.unsigned_int_converter([]) + self.assertEqual(ac_tester.unsigned_int_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.unsigned_int_converter(0, 0, 
UINT_MAX + 1), (0, 0, 0)) + self.assertEqual(ac_tester.unsigned_int_converter(0, 0, (UINT_MAX + 1) * 3 + 123), (0, 0, 123)) + + def test_long_converter(self): + from _testcapi import LONG_MIN, LONG_MAX + with self.assertRaises(OverflowError): + ac_tester.long_converter(LONG_MIN - 1) + with self.assertRaises(OverflowError): + ac_tester.long_converter(LONG_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.long_converter([]) + self.assertEqual(ac_tester.long_converter(), (12,)) + self.assertEqual(ac_tester.long_converter(-1234), (-1234,)) + + def test_unsigned_long_converter(self): + from _testcapi import ULONG_MAX + with self.assertRaises(ValueError): + ac_tester.unsigned_long_converter(-1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_long_converter(ULONG_MAX + 1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_long_converter(0, ULONG_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.unsigned_long_converter([]) + self.assertEqual(ac_tester.unsigned_long_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.unsigned_long_converter(0, 0, ULONG_MAX + 1), (0, 0, 0)) + self.assertEqual(ac_tester.unsigned_long_converter(0, 0, (ULONG_MAX + 1) * 3 + 123), (0, 0, 123)) + + def test_long_long_converter(self): + from _testcapi import LLONG_MIN, LLONG_MAX + with self.assertRaises(OverflowError): + ac_tester.long_long_converter(LLONG_MIN - 1) + with self.assertRaises(OverflowError): + ac_tester.long_long_converter(LLONG_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.long_long_converter([]) + self.assertEqual(ac_tester.long_long_converter(), (12,)) + self.assertEqual(ac_tester.long_long_converter(-1234), (-1234,)) + + def test_unsigned_long_long_converter(self): + from _testcapi import ULLONG_MAX + with self.assertRaises(ValueError): + ac_tester.unsigned_long_long_converter(-1) + with self.assertRaises(OverflowError): + ac_tester.unsigned_long_long_converter(ULLONG_MAX + 1) + with self.assertRaises(OverflowError): + 
ac_tester.unsigned_long_long_converter(0, ULLONG_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.unsigned_long_long_converter([]) + self.assertEqual(ac_tester.unsigned_long_long_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.unsigned_long_long_converter(0, 0, ULLONG_MAX + 1), (0, 0, 0)) + self.assertEqual(ac_tester.unsigned_long_long_converter(0, 0, (ULLONG_MAX + 1) * 3 + 123), (0, 0, 123)) + + def test_py_ssize_t_converter(self): + from _testcapi import PY_SSIZE_T_MIN, PY_SSIZE_T_MAX + with self.assertRaises(OverflowError): + ac_tester.py_ssize_t_converter(PY_SSIZE_T_MIN - 1) + with self.assertRaises(OverflowError): + ac_tester.py_ssize_t_converter(PY_SSIZE_T_MAX + 1) + with self.assertRaises(TypeError): + ac_tester.py_ssize_t_converter([]) + self.assertEqual(ac_tester.py_ssize_t_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.py_ssize_t_converter(1, 2, None), (1, 2, 56)) + + def test_slice_index_converter(self): + from _testcapi import PY_SSIZE_T_MIN, PY_SSIZE_T_MAX + with self.assertRaises(TypeError): + ac_tester.slice_index_converter([]) + self.assertEqual(ac_tester.slice_index_converter(), (12, 34, 56)) + self.assertEqual(ac_tester.slice_index_converter(1, 2, None), (1, 2, 56)) + self.assertEqual(ac_tester.slice_index_converter(PY_SSIZE_T_MAX, PY_SSIZE_T_MAX + 1, PY_SSIZE_T_MAX + 1234), + (PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX)) + self.assertEqual(ac_tester.slice_index_converter(PY_SSIZE_T_MIN, PY_SSIZE_T_MIN - 1, PY_SSIZE_T_MIN - 1234), + (PY_SSIZE_T_MIN, PY_SSIZE_T_MIN, PY_SSIZE_T_MIN)) + + def test_size_t_converter(self): + with self.assertRaises(ValueError): + ac_tester.size_t_converter(-1) + with self.assertRaises(TypeError): + ac_tester.size_t_converter([]) + self.assertEqual(ac_tester.size_t_converter(), (12,)) + + def test_float_converter(self): + with self.assertRaises(TypeError): + ac_tester.float_converter([]) + self.assertEqual(ac_tester.float_converter(), (12.5,)) + 
self.assertEqual(ac_tester.float_converter(-0.5), (-0.5,)) + + def test_double_converter(self): + with self.assertRaises(TypeError): + ac_tester.double_converter([]) + self.assertEqual(ac_tester.double_converter(), (12.5,)) + self.assertEqual(ac_tester.double_converter(-0.5), (-0.5,)) + + def test_py_complex_converter(self): + with self.assertRaises(TypeError): + ac_tester.py_complex_converter([]) + self.assertEqual(ac_tester.py_complex_converter(complex(1, 2)), (complex(1, 2),)) + self.assertEqual(ac_tester.py_complex_converter(complex('-1-2j')), (complex('-1-2j'),)) + self.assertEqual(ac_tester.py_complex_converter(-0.5), (-0.5,)) + self.assertEqual(ac_tester.py_complex_converter(10), (10,)) + + def test_str_converter(self): + with self.assertRaises(TypeError): + ac_tester.str_converter(1) + with self.assertRaises(TypeError): + ac_tester.str_converter('a', 'b', 'c') + with self.assertRaises(ValueError): + ac_tester.str_converter('a', b'b\0b', 'c') + self.assertEqual(ac_tester.str_converter('a', b'b', 'c'), ('a', 'b', 'c')) + self.assertEqual(ac_tester.str_converter('a', b'b', b'c'), ('a', 'b', 'c')) + self.assertEqual(ac_tester.str_converter('a', b'b', 'c\0c'), ('a', 'b', 'c\0c')) + + def test_str_converter_encoding(self): + with self.assertRaises(TypeError): + ac_tester.str_converter_encoding(1) + self.assertEqual(ac_tester.str_converter_encoding('a', 'b', 'c'), ('a', 'b', 'c')) + with self.assertRaises(TypeError): + ac_tester.str_converter_encoding('a', b'b\0b', 'c') + self.assertEqual(ac_tester.str_converter_encoding('a', b'b', bytearray([ord('c')])), ('a', 'b', 'c')) + self.assertEqual(ac_tester.str_converter_encoding('a', b'b', bytearray([ord('c'), 0, ord('c')])), + ('a', 'b', 'c\x00c')) + self.assertEqual(ac_tester.str_converter_encoding('a', b'b', b'c\x00c'), ('a', 'b', 'c\x00c')) + + def test_py_buffer_converter(self): + with self.assertRaises(TypeError): + ac_tester.py_buffer_converter('a', 'b') + self.assertEqual(ac_tester.py_buffer_converter('abc', 
bytearray([1, 2, 3])), (b'abc', b'\x01\x02\x03')) + + def test_keywords(self): + self.assertEqual(ac_tester.keywords(1, 2), (1, 2)) + self.assertEqual(ac_tester.keywords(1, b=2), (1, 2)) + self.assertEqual(ac_tester.keywords(a=1, b=2), (1, 2)) + + def test_keywords_kwonly(self): + with self.assertRaises(TypeError): + ac_tester.keywords_kwonly(1, 2) + self.assertEqual(ac_tester.keywords_kwonly(1, b=2), (1, 2)) + self.assertEqual(ac_tester.keywords_kwonly(a=1, b=2), (1, 2)) + + def test_keywords_opt(self): + self.assertEqual(ac_tester.keywords_opt(1), (1, None, None)) + self.assertEqual(ac_tester.keywords_opt(1, 2), (1, 2, None)) + self.assertEqual(ac_tester.keywords_opt(1, 2, 3), (1, 2, 3)) + self.assertEqual(ac_tester.keywords_opt(1, b=2), (1, 2, None)) + self.assertEqual(ac_tester.keywords_opt(1, 2, c=3), (1, 2, 3)) + self.assertEqual(ac_tester.keywords_opt(a=1, c=3), (1, None, 3)) + self.assertEqual(ac_tester.keywords_opt(a=1, b=2, c=3), (1, 2, 3)) + + def test_keywords_opt_kwonly(self): + self.assertEqual(ac_tester.keywords_opt_kwonly(1), (1, None, None, None)) + self.assertEqual(ac_tester.keywords_opt_kwonly(1, 2), (1, 2, None, None)) + with self.assertRaises(TypeError): + ac_tester.keywords_opt_kwonly(1, 2, 3) + self.assertEqual(ac_tester.keywords_opt_kwonly(1, b=2), (1, 2, None, None)) + self.assertEqual(ac_tester.keywords_opt_kwonly(1, 2, c=3), (1, 2, 3, None)) + self.assertEqual(ac_tester.keywords_opt_kwonly(a=1, c=3), (1, None, 3, None)) + self.assertEqual(ac_tester.keywords_opt_kwonly(a=1, b=2, c=3, d=4), (1, 2, 3, 4)) + + def test_keywords_kwonly_opt(self): + self.assertEqual(ac_tester.keywords_kwonly_opt(1), (1, None, None)) + with self.assertRaises(TypeError): + ac_tester.keywords_kwonly_opt(1, 2) + self.assertEqual(ac_tester.keywords_kwonly_opt(1, b=2), (1, 2, None)) + self.assertEqual(ac_tester.keywords_kwonly_opt(a=1, c=3), (1, None, 3)) + self.assertEqual(ac_tester.keywords_kwonly_opt(a=1, b=2, c=3), (1, 2, 3)) + + def test_posonly_keywords(self): 
+ with self.assertRaises(TypeError): + ac_tester.posonly_keywords(1) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords(a=1, b=2) + self.assertEqual(ac_tester.posonly_keywords(1, 2), (1, 2)) + self.assertEqual(ac_tester.posonly_keywords(1, b=2), (1, 2)) + + def test_posonly_kwonly(self): + with self.assertRaises(TypeError): + ac_tester.posonly_kwonly(1) + with self.assertRaises(TypeError): + ac_tester.posonly_kwonly(1, 2) + with self.assertRaises(TypeError): + ac_tester.posonly_kwonly(a=1, b=2) + self.assertEqual(ac_tester.posonly_kwonly(1, b=2), (1, 2)) + + def test_posonly_keywords_kwonly(self): + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly(1) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly(1, 2, 3) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly(a=1, b=2, c=3) + self.assertEqual(ac_tester.posonly_keywords_kwonly(1, 2, c=3), (1, 2, 3)) + self.assertEqual(ac_tester.posonly_keywords_kwonly(1, b=2, c=3), (1, 2, 3)) + + def test_posonly_keywords_opt(self): + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_opt(1) + self.assertEqual(ac_tester.posonly_keywords_opt(1, 2), (1, 2, None, None)) + self.assertEqual(ac_tester.posonly_keywords_opt(1, 2, 3), (1, 2, 3, None)) + self.assertEqual(ac_tester.posonly_keywords_opt(1, 2, 3, 4), (1, 2, 3, 4)) + self.assertEqual(ac_tester.posonly_keywords_opt(1, b=2), (1, 2, None, None)) + self.assertEqual(ac_tester.posonly_keywords_opt(1, 2, c=3), (1, 2, 3, None)) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_opt(a=1, b=2, c=3, d=4) + self.assertEqual(ac_tester.posonly_keywords_opt(1, b=2, c=3, d=4), (1, 2, 3, 4)) + + def test_posonly_opt_keywords_opt(self): + self.assertEqual(ac_tester.posonly_opt_keywords_opt(1), (1, None, None, None)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt(1, 2), (1, 2, None, None)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt(1, 2, 3), (1, 2, 3, None)) + 
self.assertEqual(ac_tester.posonly_opt_keywords_opt(1, 2, 3, 4), (1, 2, 3, 4)) + with self.assertRaises(TypeError): + ac_tester.posonly_opt_keywords_opt(1, b=2) + self.assertEqual(ac_tester.posonly_opt_keywords_opt(1, 2, c=3), (1, 2, 3, None)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt(1, 2, c=3, d=4), (1, 2, 3, 4)) + with self.assertRaises(TypeError): + ac_tester.posonly_opt_keywords_opt(a=1, b=2, c=3, d=4) + + def test_posonly_kwonly_opt(self): + with self.assertRaises(TypeError): + ac_tester.posonly_kwonly_opt(1) + with self.assertRaises(TypeError): + ac_tester.posonly_kwonly_opt(1, 2) + self.assertEqual(ac_tester.posonly_kwonly_opt(1, b=2), (1, 2, None, None)) + self.assertEqual(ac_tester.posonly_kwonly_opt(1, b=2, c=3), (1, 2, 3, None)) + self.assertEqual(ac_tester.posonly_kwonly_opt(1, b=2, c=3, d=4), (1, 2, 3, 4)) + with self.assertRaises(TypeError): + ac_tester.posonly_kwonly_opt(a=1, b=2, c=3, d=4) + + def test_posonly_opt_kwonly_opt(self): + self.assertEqual(ac_tester.posonly_opt_kwonly_opt(1), (1, None, None, None)) + self.assertEqual(ac_tester.posonly_opt_kwonly_opt(1, 2), (1, 2, None, None)) + with self.assertRaises(TypeError): + ac_tester.posonly_opt_kwonly_opt(1, 2, 3) + with self.assertRaises(TypeError): + ac_tester.posonly_opt_kwonly_opt(1, b=2) + self.assertEqual(ac_tester.posonly_opt_kwonly_opt(1, 2, c=3), (1, 2, 3, None)) + self.assertEqual(ac_tester.posonly_opt_kwonly_opt(1, 2, c=3, d=4), (1, 2, 3, 4)) + + def test_posonly_keywords_kwonly_opt(self): + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly_opt(1) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly_opt(1, 2) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly_opt(1, b=2) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly_opt(1, 2, 3) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_kwonly_opt(a=1, b=2, c=3) + self.assertEqual(ac_tester.posonly_keywords_kwonly_opt(1, 2, 
c=3), (1, 2, 3, None, None)) + self.assertEqual(ac_tester.posonly_keywords_kwonly_opt(1, b=2, c=3), (1, 2, 3, None, None)) + self.assertEqual(ac_tester.posonly_keywords_kwonly_opt(1, 2, c=3, d=4), (1, 2, 3, 4, None)) + self.assertEqual(ac_tester.posonly_keywords_kwonly_opt(1, 2, c=3, d=4, e=5), (1, 2, 3, 4, 5)) + + def test_posonly_keywords_opt_kwonly_opt(self): + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_opt_kwonly_opt(1) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, 2), (1, 2, None, None, None)) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, b=2), (1, 2, None, None, None)) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_opt_kwonly_opt(1, 2, 3, 4) + with self.assertRaises(TypeError): + ac_tester.posonly_keywords_opt_kwonly_opt(a=1, b=2) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, 2, c=3), (1, 2, 3, None, None)) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, b=2, c=3), (1, 2, 3, None, None)) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, 2, 3, d=4), (1, 2, 3, 4, None)) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, 2, c=3, d=4), (1, 2, 3, 4, None)) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, 2, 3, d=4, e=5), (1, 2, 3, 4, 5)) + self.assertEqual(ac_tester.posonly_keywords_opt_kwonly_opt(1, 2, c=3, d=4, e=5), (1, 2, 3, 4, 5)) + + def test_posonly_opt_keywords_opt_kwonly_opt(self): + self.assertEqual(ac_tester.posonly_opt_keywords_opt_kwonly_opt(1), (1, None, None, None)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, 2), (1, 2, None, None)) + with self.assertRaises(TypeError): + ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, b=2) + self.assertEqual(ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, 2, 3), (1, 2, 3, None)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, 2, c=3), (1, 2, 3, None)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, 2, 
3, d=4), (1, 2, 3, 4)) + self.assertEqual(ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, 2, c=3, d=4), (1, 2, 3, 4)) + with self.assertRaises(TypeError): + ac_tester.posonly_opt_keywords_opt_kwonly_opt(1, 2, 3, 4) + + def test_keyword_only_parameter(self): + with self.assertRaises(TypeError): + ac_tester.keyword_only_parameter() + with self.assertRaises(TypeError): + ac_tester.keyword_only_parameter(1) + self.assertEqual(ac_tester.keyword_only_parameter(a=1), (1,)) + + if ac_tester is not None: + @repeat_fn(ac_tester.varpos, + ac_tester.varpos_array, + ac_tester.TestClass.varpos_no_fastcall, + ac_tester.TestClass.varpos_array_no_fastcall) + def test_varpos(self, fn): + # fn(*args) + self.assertEqual(fn(), ()) + self.assertEqual(fn(1, 2), (1, 2)) + + @repeat_fn(ac_tester.posonly_varpos, + ac_tester.posonly_varpos_array, + ac_tester.TestClass.posonly_varpos_no_fastcall, + ac_tester.TestClass.posonly_varpos_array_no_fastcall) + def test_posonly_varpos(self, fn): + # fn(a, b, /, *args) + self.assertRaises(TypeError, fn) + self.assertRaises(TypeError, fn, 1) + self.assertRaises(TypeError, fn, 1, b=2) + self.assertEqual(fn(1, 2), (1, 2, ())) + self.assertEqual(fn(1, 2, 3, 4), (1, 2, (3, 4))) + + @repeat_fn(ac_tester.posonly_req_opt_varpos, + ac_tester.posonly_req_opt_varpos_array, + ac_tester.TestClass.posonly_req_opt_varpos_no_fastcall, + ac_tester.TestClass.posonly_req_opt_varpos_array_no_fastcall) + def test_posonly_req_opt_varpos(self, fn): + # fn(a, b=False, /, *args) + self.assertRaises(TypeError, fn) + self.assertRaises(TypeError, fn, a=1) + self.assertEqual(fn(1), (1, False, ())) + self.assertEqual(fn(1, 2), (1, 2, ())) + self.assertEqual(fn(1, 2, 3, 4), (1, 2, (3, 4))) + + @repeat_fn(ac_tester.posonly_poskw_varpos, + ac_tester.posonly_poskw_varpos_array, + ac_tester.TestClass.posonly_poskw_varpos_no_fastcall, + ac_tester.TestClass.posonly_poskw_varpos_array_no_fastcall) + def test_posonly_poskw_varpos(self, fn): + # fn(a, /, b, *args) + 
self.assertRaises(TypeError, fn) + self.assertEqual(fn(1, 2), (1, 2, ())) + self.assertEqual(fn(1, b=2), (1, 2, ())) + self.assertEqual(fn(1, 2, 3, 4), (1, 2, (3, 4))) + self.assertRaises(TypeError, fn, b=4) + errmsg = re.escape("given by name ('b') and position (2)") + self.assertRaisesRegex(TypeError, errmsg, fn, 1, 2, 3, b=4) + + def test_poskw_varpos(self): + # fn(a, *args) + fn = ac_tester.poskw_varpos + self.assertRaises(TypeError, fn) + self.assertRaises(TypeError, fn, 1, b=2) + self.assertEqual(fn(a=1), (1, ())) + errmsg = re.escape("given by name ('a') and position (1)") + self.assertRaisesRegex(TypeError, errmsg, fn, 1, a=2) + self.assertEqual(fn(1), (1, ())) + self.assertEqual(fn(1, 2, 3, 4), (1, (2, 3, 4))) + + def test_poskw_varpos_kwonly_opt(self): + # fn(a, *args, b=False) + fn = ac_tester.poskw_varpos_kwonly_opt + self.assertRaises(TypeError, fn) + errmsg = re.escape("given by name ('a') and position (1)") + self.assertRaisesRegex(TypeError, errmsg, fn, 1, a=2) + self.assertEqual(fn(1, b=2), (1, (), True)) + self.assertEqual(fn(1, 2, 3, 4), (1, (2, 3, 4), False)) + self.assertEqual(fn(1, 2, 3, 4, b=5), (1, (2, 3, 4), True)) + self.assertEqual(fn(a=1), (1, (), False)) + self.assertEqual(fn(a=1, b=2), (1, (), True)) + + def test_poskw_varpos_kwonly_opt2(self): + # fn(a, *args, b=False, c=False) + fn = ac_tester.poskw_varpos_kwonly_opt2 + self.assertRaises(TypeError, fn) + errmsg = re.escape("given by name ('a') and position (1)") + self.assertRaisesRegex(TypeError, errmsg, fn, 1, a=2) + self.assertEqual(fn(1, b=2), (1, (), 2, False)) + self.assertEqual(fn(1, b=2, c=3), (1, (), 2, 3)) + self.assertEqual(fn(1, 2, 3), (1, (2, 3), False, False)) + self.assertEqual(fn(1, 2, 3, b=4), (1, (2, 3), 4, False)) + self.assertEqual(fn(1, 2, 3, b=4, c=5), (1, (2, 3), 4, 5)) + self.assertEqual(fn(a=1), (1, (), False, False)) + self.assertEqual(fn(a=1, b=2), (1, (), 2, False)) + self.assertEqual(fn(a=1, b=2, c=3), (1, (), 2, 3)) + + def test_varpos_kwonly_opt(self): 
+ # fn(*args, b=False) + fn = ac_tester.varpos_kwonly_opt + self.assertEqual(fn(), ((), False)) + self.assertEqual(fn(b=2), ((), 2)) + self.assertEqual(fn(1, b=2), ((1, ), 2)) + self.assertEqual(fn(1, 2, 3, 4), ((1, 2, 3, 4), False)) + self.assertEqual(fn(1, 2, 3, 4, b=5), ((1, 2, 3, 4), 5)) + + def test_varpos_kwonly_req_opt(self): + fn = ac_tester.varpos_kwonly_req_opt + self.assertRaises(TypeError, fn) + self.assertEqual(fn(a=1), ((), 1, False, False)) + self.assertEqual(fn(a=1, b=2), ((), 1, 2, False)) + self.assertEqual(fn(a=1, b=2, c=3), ((), 1, 2, 3)) + self.assertRaises(TypeError, fn, 1) + self.assertEqual(fn(1, a=2), ((1,), 2, False, False)) + self.assertEqual(fn(1, a=2, b=3), ((1,), 2, 3, False)) + self.assertEqual(fn(1, a=2, b=3, c=4), ((1,), 2, 3, 4)) + + def test_gh_32092_oob(self): + ac_tester.gh_32092_oob(1, 2, 3, 4, kw1=5, kw2=6) + + def test_gh_32092_kw_pass(self): + ac_tester.gh_32092_kw_pass(1, 2, 3) + + def test_gh_99233_refcount(self): + arg = '*A unique string is not referenced by anywhere else.*' + arg_refcount_origin = sys.getrefcount(arg) + ac_tester.gh_99233_refcount(arg) + arg_refcount_after = sys.getrefcount(arg) + self.assertEqual(arg_refcount_origin, arg_refcount_after) + + def test_gh_99240_double_free(self): + err = re.escape( + "gh_99240_double_free() argument 2 must be encoded string " + "without null bytes, not str" + ) + with self.assertRaisesRegex(TypeError, err): + ac_tester.gh_99240_double_free('a', '\0b') + + def test_null_or_tuple_for_varargs(self): + # fn(name, *constraints, covariant=False) + fn = ac_tester.null_or_tuple_for_varargs + # All of these should not crash: + self.assertEqual(fn('a'), ('a', (), False)) + self.assertEqual(fn('a', 1, 2, 3, covariant=True), ('a', (1, 2, 3), True)) + self.assertEqual(fn(name='a'), ('a', (), False)) + self.assertEqual(fn(name='a', covariant=True), ('a', (), True)) + self.assertEqual(fn(covariant=True, name='a'), ('a', (), True)) + + self.assertRaises(TypeError, fn, covariant=True) + 
errmsg = re.escape("given by name ('name') and position (1)") + self.assertRaisesRegex(TypeError, errmsg, fn, 1, name='a') + self.assertRaisesRegex(TypeError, errmsg, fn, 1, 2, 3, name='a', covariant=True) + self.assertRaisesRegex(TypeError, errmsg, fn, 1, 2, 3, covariant=True, name='a') + + def test_cloned_func_exception_message(self): + incorrect_arg = -1 # f1() and f2() accept a single str + with self.assertRaisesRegex(TypeError, "clone_f1"): + ac_tester.clone_f1(incorrect_arg) + with self.assertRaisesRegex(TypeError, "clone_f2"): + ac_tester.clone_f2(incorrect_arg) + + def test_cloned_func_with_converter_exception_message(self): + for name in "clone_with_conv_f1", "clone_with_conv_f2": + with self.subTest(name=name): + func = getattr(ac_tester, name) + self.assertEqual(func(), name) + + def test_get_defining_class(self): + obj = ac_tester.TestClass() + meth = obj.get_defining_class + self.assertIs(obj.get_defining_class(), ac_tester.TestClass) + + # 'defining_class' argument is a positional only argument + with self.assertRaises(TypeError): + obj.get_defining_class_arg(cls=ac_tester.TestClass) + + check = partial(self.assertRaisesRegex, TypeError, "no arguments") + check(meth, 1) + check(meth, a=1) + + def test_get_defining_class_capi(self): + from _testcapi import pyobject_vectorcall + obj = ac_tester.TestClass() + meth = obj.get_defining_class + pyobject_vectorcall(meth, None, None) + pyobject_vectorcall(meth, (), None) + pyobject_vectorcall(meth, (), ()) + pyobject_vectorcall(meth, None, ()) + self.assertIs(pyobject_vectorcall(meth, (), ()), ac_tester.TestClass) + + check = partial(self.assertRaisesRegex, TypeError, "no arguments") + check(pyobject_vectorcall, meth, (1,), None) + check(pyobject_vectorcall, meth, (1,), ("a",)) + + def test_get_defining_class_arg(self): + obj = ac_tester.TestClass() + self.assertEqual(obj.get_defining_class_arg("arg"), + (ac_tester.TestClass, "arg")) + self.assertEqual(obj.get_defining_class_arg(arg=123), + 
(ac_tester.TestClass, 123)) + + # 'defining_class' argument is a positional only argument + with self.assertRaises(TypeError): + obj.get_defining_class_arg(cls=ac_tester.TestClass, arg="arg") + + # wrong number of arguments + with self.assertRaises(TypeError): + obj.get_defining_class_arg() + with self.assertRaises(TypeError): + obj.get_defining_class_arg("arg1", "arg2") + + def test_defclass_varpos(self): + # fn(*args) + cls = ac_tester.TestClass + obj = cls() + fn = obj.defclass_varpos + self.assertEqual(fn(), (cls, ())) + self.assertEqual(fn(1, 2), (cls, (1, 2))) + fn = cls.defclass_varpos + self.assertRaises(TypeError, fn) + self.assertEqual(fn(obj), (cls, ())) + self.assertEqual(fn(obj, 1, 2), (cls, (1, 2))) + + def test_defclass_posonly_varpos(self): + # fn(a, b, /, *args) + cls = ac_tester.TestClass + obj = cls() + fn = obj.defclass_posonly_varpos + errmsg = 'takes at least 2 positional arguments' + self.assertRaisesRegex(TypeError, errmsg, fn) + self.assertRaisesRegex(TypeError, errmsg, fn, 1) + self.assertEqual(fn(1, 2), (cls, 1, 2, ())) + self.assertEqual(fn(1, 2, 3, 4), (cls, 1, 2, (3, 4))) + fn = cls.defclass_posonly_varpos + self.assertRaises(TypeError, fn) + self.assertRaisesRegex(TypeError, errmsg, fn, obj) + self.assertRaisesRegex(TypeError, errmsg, fn, obj, 1) + self.assertEqual(fn(obj, 1, 2), (cls, 1, 2, ())) + self.assertEqual(fn(obj, 1, 2, 3, 4), (cls, 1, 2, (3, 4))) + + def test_depr_star_new(self): + cls = ac_tester.DeprStarNew + cls() + cls(a=None) + self.check_depr_star("'a'", cls, None) + + def test_depr_star_new_cloned(self): + fn = ac_tester.DeprStarNew().cloned + fn() + fn(a=None) + self.check_depr_star("'a'", fn, None, name='_testclinic.DeprStarNew.cloned') + + def test_depr_star_init(self): + cls = ac_tester.DeprStarInit + cls() + cls(a=None) + self.check_depr_star("'a'", cls, None) + + def test_depr_star_init_cloned(self): + fn = ac_tester.DeprStarInit().cloned + fn() + fn(a=None) + self.check_depr_star("'a'", fn, None, 
name='_testclinic.DeprStarInit.cloned') + + def test_depr_star_init_noinline(self): + cls = ac_tester.DeprStarInitNoInline + self.assertRaises(TypeError, cls, "a") + cls(a="a", b="b") + cls(a="a", b="b", c="c") + cls("a", b="b") + cls("a", b="b", c="c") + check = partial(self.check_depr_star, "'b' and 'c'", cls) + check("a", "b") + check("a", "b", "c") + check("a", "b", c="c") + self.assertRaises(TypeError, cls, "a", "b", "c", "d") + + def test_depr_kwd_new(self): + cls = ac_tester.DeprKwdNew + cls() + cls(None) + self.check_depr_kwd("'a'", cls, a=None) + + def test_depr_kwd_init(self): + cls = ac_tester.DeprKwdInit + cls() + cls(None) + self.check_depr_kwd("'a'", cls, a=None) + + def test_depr_kwd_init_noinline(self): + cls = ac_tester.DeprKwdInitNoInline + cls = ac_tester.depr_star_noinline + self.assertRaises(TypeError, cls, "a") + cls(a="a", b="b") + cls(a="a", b="b", c="c") + cls("a", b="b") + cls("a", b="b", c="c") + check = partial(self.check_depr_star, "'b' and 'c'", cls) + check("a", "b") + check("a", "b", "c") + check("a", "b", c="c") + self.assertRaises(TypeError, cls, "a", "b", "c", "d") + + def test_depr_star_pos0_len1(self): + fn = ac_tester.depr_star_pos0_len1 + fn(a=None) + self.check_depr_star("'a'", fn, "a") + + def test_depr_star_pos0_len2(self): + fn = ac_tester.depr_star_pos0_len2 + fn(a=0, b=0) + check = partial(self.check_depr_star, "'a' and 'b'", fn) + check("a", b=0) + check("a", "b") + + def test_depr_star_pos0_len3_with_kwd(self): + fn = ac_tester.depr_star_pos0_len3_with_kwd + fn(a=0, b=0, c=0, d=0) + check = partial(self.check_depr_star, "'a', 'b' and 'c'", fn) + check("a", b=0, c=0, d=0) + check("a", "b", c=0, d=0) + check("a", "b", "c", d=0) + + def test_depr_star_pos1_len1_opt(self): + fn = ac_tester.depr_star_pos1_len1_opt + fn(a=0, b=0) + fn("a", b=0) + fn(a=0) # b is optional + check = partial(self.check_depr_star, "'b'", fn) + check("a", "b") + + def test_depr_star_pos1_len1(self): + fn = ac_tester.depr_star_pos1_len1 + fn(a=0, 
b=0) + fn("a", b=0) + check = partial(self.check_depr_star, "'b'", fn) + check("a", "b") + + def test_depr_star_pos1_len2_with_kwd(self): + fn = ac_tester.depr_star_pos1_len2_with_kwd + fn(a=0, b=0, c=0, d=0), + fn("a", b=0, c=0, d=0), + check = partial(self.check_depr_star, "'b' and 'c'", fn) + check("a", "b", c=0, d=0), + check("a", "b", "c", d=0), + + def test_depr_star_pos2_len1(self): + fn = ac_tester.depr_star_pos2_len1 + fn(a=0, b=0, c=0) + fn("a", b=0, c=0) + fn("a", "b", c=0) + check = partial(self.check_depr_star, "'c'", fn) + check("a", "b", "c") + + def test_depr_star_pos2_len2(self): + fn = ac_tester.depr_star_pos2_len2 + fn(a=0, b=0, c=0, d=0) + fn("a", b=0, c=0, d=0) + fn("a", "b", c=0, d=0) + check = partial(self.check_depr_star, "'c' and 'd'", fn) + check("a", "b", "c", d=0) + check("a", "b", "c", "d") + + def test_depr_star_pos2_len2_with_kwd(self): + fn = ac_tester.depr_star_pos2_len2_with_kwd + fn(a=0, b=0, c=0, d=0, e=0) + fn("a", b=0, c=0, d=0, e=0) + fn("a", "b", c=0, d=0, e=0) + check = partial(self.check_depr_star, "'c' and 'd'", fn) + check("a", "b", "c", d=0, e=0) + check("a", "b", "c", "d", e=0) + + def test_depr_star_noinline(self): + fn = ac_tester.depr_star_noinline + self.assertRaises(TypeError, fn, "a") + fn(a="a", b="b") + fn(a="a", b="b", c="c") + fn("a", b="b") + fn("a", b="b", c="c") + check = partial(self.check_depr_star, "'b' and 'c'", fn) + check("a", "b") + check("a", "b", "c") + check("a", "b", c="c") + self.assertRaises(TypeError, fn, "a", "b", "c", "d") + + def test_depr_star_multi(self): + fn = ac_tester.depr_star_multi + self.assertRaises(TypeError, fn, "a") + fn("a", b="b", c="c", d="d", e="e", f="f", g="g", h="h") + errmsg = ( + "Passing more than 1 positional argument to depr_star_multi() is deprecated. " + "Parameter 'b' will become a keyword-only parameter in Python 3.16. " + "Parameters 'c' and 'd' will become keyword-only parameters in Python 3.15. 
" + "Parameters 'e', 'f' and 'g' will become keyword-only parameters in Python 3.14.") + check = partial(self.check_depr, re.escape(errmsg), fn) + check("a", "b", c="c", d="d", e="e", f="f", g="g", h="h") + check("a", "b", "c", d="d", e="e", f="f", g="g", h="h") + check("a", "b", "c", "d", e="e", f="f", g="g", h="h") + check("a", "b", "c", "d", "e", f="f", g="g", h="h") + check("a", "b", "c", "d", "e", "f", g="g", h="h") + check("a", "b", "c", "d", "e", "f", "g", h="h") + self.assertRaises(TypeError, fn, "a", "b", "c", "d", "e", "f", "g", "h") + + def test_depr_kwd_required_1(self): + fn = ac_tester.depr_kwd_required_1 + fn("a", "b") + self.assertRaises(TypeError, fn, "a") + self.assertRaises(TypeError, fn, "a", "b", "c") + check = partial(self.check_depr_kwd, "'b'", fn) + check("a", b="b") + self.assertRaises(TypeError, fn, a="a", b="b") + + def test_depr_kwd_required_2(self): + fn = ac_tester.depr_kwd_required_2 + fn("a", "b", "c") + self.assertRaises(TypeError, fn, "a", "b") + self.assertRaises(TypeError, fn, "a", "b", "c", "d") + check = partial(self.check_depr_kwd, "'b' and 'c'", fn) + check("a", "b", c="c") + check("a", b="b", c="c") + self.assertRaises(TypeError, fn, a="a", b="b", c="c") + + def test_depr_kwd_optional_1(self): + fn = ac_tester.depr_kwd_optional_1 + fn("a") + fn("a", "b") + self.assertRaises(TypeError, fn) + self.assertRaises(TypeError, fn, "a", "b", "c") + check = partial(self.check_depr_kwd, "'b'", fn) + check("a", b="b") + self.assertRaises(TypeError, fn, a="a", b="b") + + def test_depr_kwd_optional_2(self): + fn = ac_tester.depr_kwd_optional_2 + fn("a") + fn("a", "b") + fn("a", "b", "c") + self.assertRaises(TypeError, fn) + self.assertRaises(TypeError, fn, "a", "b", "c", "d") + check = partial(self.check_depr_kwd, "'b' and 'c'", fn) + check("a", b="b") + check("a", c="c") + check("a", b="b", c="c") + check("a", c="c", b="b") + check("a", "b", c="c") + self.assertRaises(TypeError, fn, a="a", b="b", c="c") + + def 
test_depr_kwd_optional_3(self): + fn = ac_tester.depr_kwd_optional_3 + fn() + fn("a") + fn("a", "b") + fn("a", "b", "c") + self.assertRaises(TypeError, fn, "a", "b", "c", "d") + check = partial(self.check_depr_kwd, "'a', 'b' and 'c'", fn) + check("a", "b", c="c") + check("a", b="b") + check(a="a") + + def test_depr_kwd_required_optional(self): + fn = ac_tester.depr_kwd_required_optional + fn("a", "b") + fn("a", "b", "c") + self.assertRaises(TypeError, fn) + self.assertRaises(TypeError, fn, "a") + self.assertRaises(TypeError, fn, "a", "b", "c", "d") + check = partial(self.check_depr_kwd, "'b' and 'c'", fn) + check("a", b="b") + check("a", b="b", c="c") + check("a", c="c", b="b") + check("a", "b", c="c") + self.assertRaises(TypeError, fn, "a", c="c") + self.assertRaises(TypeError, fn, a="a", b="b", c="c") + + def test_depr_kwd_noinline(self): + fn = ac_tester.depr_kwd_noinline + fn("a", "b") + fn("a", "b", "c") + self.assertRaises(TypeError, fn, "a") + check = partial(self.check_depr_kwd, "'b' and 'c'", fn) + check("a", b="b") + check("a", b="b", c="c") + check("a", c="c", b="b") + check("a", "b", c="c") + self.assertRaises(TypeError, fn, "a", c="c") + self.assertRaises(TypeError, fn, a="a", b="b", c="c") + + def test_depr_kwd_multi(self): + fn = ac_tester.depr_kwd_multi + fn("a", "b", "c", "d", "e", "f", "g", h="h") + errmsg = ( + "Passing keyword arguments 'b', 'c', 'd', 'e', 'f' and 'g' to depr_kwd_multi() is deprecated. " + "Parameter 'b' will become positional-only in Python 3.14. " + "Parameters 'c' and 'd' will become positional-only in Python 3.15. 
" + "Parameters 'e', 'f' and 'g' will become positional-only in Python 3.16.") + check = partial(self.check_depr, re.escape(errmsg), fn) + check("a", "b", "c", "d", "e", "f", g="g", h="h") + check("a", "b", "c", "d", "e", f="f", g="g", h="h") + check("a", "b", "c", "d", e="e", f="f", g="g", h="h") + check("a", "b", "c", d="d", e="e", f="f", g="g", h="h") + check("a", "b", c="c", d="d", e="e", f="f", g="g", h="h") + check("a", b="b", c="c", d="d", e="e", f="f", g="g", h="h") + self.assertRaises(TypeError, fn, a="a", b="b", c="c", d="d", e="e", f="f", g="g", h="h") + + def test_depr_multi(self): + fn = ac_tester.depr_multi + self.assertRaises(TypeError, fn, "a", "b", "c", "d", "e", "f", "g") + errmsg = ( + "Passing more than 4 positional arguments to depr_multi() is deprecated. " + "Parameter 'e' will become a keyword-only parameter in Python 3.15. " + "Parameter 'f' will become a keyword-only parameter in Python 3.14.") + check = partial(self.check_depr, re.escape(errmsg), fn) + check("a", "b", "c", "d", "e", "f", g="g") + check("a", "b", "c", "d", "e", f="f", g="g") + fn("a", "b", "c", "d", e="e", f="f", g="g") + fn("a", "b", "c", d="d", e="e", f="f", g="g") + errmsg = ( + "Passing keyword arguments 'b' and 'c' to depr_multi() is deprecated. " + "Parameter 'b' will become positional-only in Python 3.14. 
" + "Parameter 'c' will become positional-only in Python 3.15.") + check = partial(self.check_depr, re.escape(errmsg), fn) + check("a", "b", c="c", d="d", e="e", f="f", g="g") + check("a", b="b", c="c", d="d", e="e", f="f", g="g") + self.assertRaises(TypeError, fn, a="a", b="b", c="c", d="d", e="e", f="f", g="g") + + +class LimitedCAPIOutputTests(unittest.TestCase): + + def setUp(self): + self.clinic = _make_clinic(limited_capi=True) + + @staticmethod + def wrap_clinic_input(block): + return dedent(f""" + /*[clinic input] + output everything buffer + {block} + [clinic start generated code]*/ + /*[clinic input] + dump buffer + [clinic start generated code]*/ + """) + + def test_limited_capi_float(self): + block = self.wrap_clinic_input(""" + func + f: float + / + """) + generated = self.clinic.parse(block) + self.assertNotIn("PyFloat_AS_DOUBLE", generated) + self.assertIn("float f;", generated) + self.assertIn("f = (float) PyFloat_AsDouble", generated) + + def test_limited_capi_double(self): + block = self.wrap_clinic_input(""" + func + f: double + / + """) + generated = self.clinic.parse(block) + self.assertNotIn("PyFloat_AS_DOUBLE", generated) + self.assertIn("double f;", generated) + self.assertIn("f = PyFloat_AsDouble", generated) + + +try: + import _testclinic_limited +except ImportError: + _testclinic_limited = None + +@unittest.skipIf(_testclinic_limited is None, "_testclinic_limited is missing") +class LimitedCAPIFunctionalTest(unittest.TestCase): + locals().update((name, getattr(_testclinic_limited, name)) + for name in dir(_testclinic_limited) if name.startswith('test_')) + + def test_my_int_func(self): + with self.assertRaises(TypeError): + _testclinic_limited.my_int_func() + self.assertEqual(_testclinic_limited.my_int_func(3), 3) + with self.assertRaises(TypeError): + _testclinic_limited.my_int_func(1.0) + with self.assertRaises(TypeError): + _testclinic_limited.my_int_func("xyz") + + def test_my_int_sum(self): + with self.assertRaises(TypeError): + 
_testclinic_limited.my_int_sum() + with self.assertRaises(TypeError): + _testclinic_limited.my_int_sum(1) + self.assertEqual(_testclinic_limited.my_int_sum(1, 2), 3) + with self.assertRaises(TypeError): + _testclinic_limited.my_int_sum(1.0, 2) + with self.assertRaises(TypeError): + _testclinic_limited.my_int_sum(1, "str") + + def test_my_double_sum(self): + for func in ( + _testclinic_limited.my_float_sum, + _testclinic_limited.my_double_sum, + ): + with self.subTest(func=func.__name__): + self.assertEqual(func(1.0, 2.5), 3.5) + with self.assertRaises(TypeError): + func() + with self.assertRaises(TypeError): + func(1) + with self.assertRaises(TypeError): + func(1., "2") + + def test_get_file_descriptor(self): + # test 'file descriptor' converter: call PyObject_AsFileDescriptor() + get_fd = _testclinic_limited.get_file_descriptor + + class MyInt(int): + pass + + class MyFile: + def __init__(self, fd): + self._fd = fd + def fileno(self): + return self._fd + + for fd in (0, 1, 2, 5, 123_456): + self.assertEqual(get_fd(fd), fd) + + myint = MyInt(fd) + self.assertEqual(get_fd(myint), fd) + + myfile = MyFile(fd) + self.assertEqual(get_fd(myfile), fd) + + with self.assertRaises(OverflowError): + get_fd(2**256) + with self.assertWarnsRegex(RuntimeWarning, + "bool is used as a file descriptor"): + get_fd(True) + with self.assertRaises(TypeError): + get_fd(1.0) + with self.assertRaises(TypeError): + get_fd("abc") + with self.assertRaises(TypeError): + get_fd(None) + + +class PermutationTests(unittest.TestCase): + """Test permutation support functions.""" + + def test_permute_left_option_groups(self): + expected = ( + (), + (3,), + (2, 3), + (1, 2, 3), + ) + data = list(zip([1, 2, 3])) # Generate a list of 1-tuples. + actual = tuple(permute_left_option_groups(data)) + self.assertEqual(actual, expected) + + def test_permute_right_option_groups(self): + expected = ( + (), + (1,), + (1, 2), + (1, 2, 3), + ) + data = list(zip([1, 2, 3])) # Generate a list of 1-tuples. 
+ actual = tuple(permute_right_option_groups(data)) + self.assertEqual(actual, expected) + + def test_permute_optional_groups(self): + empty = { + "left": (), "required": (), "right": (), + "expected": ((),), + } + noleft1 = { + "left": (), "required": ("b",), "right": ("c",), + "expected": ( + ("b",), + ("b", "c"), + ), + } + noleft2 = { + "left": (), "required": ("b", "c",), "right": ("d",), + "expected": ( + ("b", "c"), + ("b", "c", "d"), + ), + } + noleft3 = { + "left": (), "required": ("b", "c",), "right": ("d", "e"), + "expected": ( + ("b", "c"), + ("b", "c", "d"), + ("b", "c", "d", "e"), + ), + } + noright1 = { + "left": ("a",), "required": ("b",), "right": (), + "expected": ( + ("b",), + ("a", "b"), + ), + } + noright2 = { + "left": ("a",), "required": ("b", "c"), "right": (), + "expected": ( + ("b", "c"), + ("a", "b", "c"), + ), + } + noright3 = { + "left": ("a", "b"), "required": ("c",), "right": (), + "expected": ( + ("c",), + ("b", "c"), + ("a", "b", "c"), + ), + } + leftandright1 = { + "left": ("a",), "required": ("b",), "right": ("c",), + "expected": ( + ("b",), + ("a", "b"), # Prefer left. + ("a", "b", "c"), + ), + } + leftandright2 = { + "left": ("a", "b"), "required": ("c", "d"), "right": ("e", "f"), + "expected": ( + ("c", "d"), + ("b", "c", "d"), # Prefer left. + ("a", "b", "c", "d"), # Prefer left. + ("a", "b", "c", "d", "e"), + ("a", "b", "c", "d", "e", "f"), + ), + } + dataset = ( + empty, + noleft1, noleft2, noleft3, + noright1, noright2, noright3, + leftandright1, leftandright2, + ) + for params in dataset: + with self.subTest(**params): + left, required, right, expected = params.values() + permutations = permute_optional_groups(left, required, right) + actual = tuple(permutations) + self.assertEqual(actual, expected) + + +class FormatHelperTests(unittest.TestCase): + + def test_strip_leading_and_trailing_blank_lines(self): + dataset = ( + # Input lines, expected output. 
+ ("a\nb", "a\nb"), + ("a\nb\n", "a\nb"), + ("a\nb ", "a\nb"), + ("\na\nb\n\n", "a\nb"), + ("\n\na\nb\n\n", "a\nb"), + ("\n\na\n\nb\n\n", "a\n\nb"), + # Note, leading whitespace is preserved: + (" a\nb", " a\nb"), + (" a\nb ", " a\nb"), + (" \n \n a\nb \n \n ", " a\nb"), + ) + for lines, expected in dataset: + with self.subTest(lines=lines, expected=expected): + out = libclinic.normalize_snippet(lines) + self.assertEqual(out, expected) + + def test_normalize_snippet(self): + snippet = """ + one + two + three + """ + + # Expected outputs: + zero_indent = ( + "one\n" + "two\n" + "three" + ) + four_indent = ( + " one\n" + " two\n" + " three" + ) + eight_indent = ( + " one\n" + " two\n" + " three" + ) + expected_outputs = {0: zero_indent, 4: four_indent, 8: eight_indent} + for indent, expected in expected_outputs.items(): + with self.subTest(indent=indent): + actual = libclinic.normalize_snippet(snippet, indent=indent) + self.assertEqual(actual, expected) + + def test_escaped_docstring(self): + dataset = ( + # input, expected + (r"abc", r'"abc"'), + (r"\abc", r'"\\abc"'), + (r"\a\bc", r'"\\a\\bc"'), + (r"\a\\bc", r'"\\a\\\\bc"'), + (r'"abc"', r'"\"abc\""'), + (r"'a'", r'"\'a\'"'), + ) + for line, expected in dataset: + with self.subTest(line=line, expected=expected): + out = libclinic.docstring_for_c_string(line) + self.assertEqual(out, expected) + + def test_format_escape(self): + line = "{}, {a}" + expected = "{{}}, {{a}}" + out = libclinic.format_escape(line) + self.assertEqual(out, expected) + + def test_c_bytes_repr(self): + c_bytes_repr = libclinic.c_bytes_repr + self.assertEqual(c_bytes_repr(b''), '""') + self.assertEqual(c_bytes_repr(b'abc'), '"abc"') + self.assertEqual(c_bytes_repr(b'\a\b\f\n\r\t\v'), r'"\a\b\f\n\r\t\v"') + self.assertEqual(c_bytes_repr(b' \0\x7f'), r'" \000\177"') + self.assertEqual(c_bytes_repr(b'"'), r'"\""') + self.assertEqual(c_bytes_repr(b"'"), r'''"'"''') + self.assertEqual(c_bytes_repr(b'\\'), r'"\\"') + 
self.assertEqual(c_bytes_repr(b'??/'), r'"?\?/"') + self.assertEqual(c_bytes_repr(b'???/'), r'"?\?\?/"') + self.assertEqual(c_bytes_repr(b'/*****/ /*/ */*'), r'"/\*****\/ /\*\/ *\/\*"') + self.assertEqual(c_bytes_repr(b'\xa0'), r'"\240"') + self.assertEqual(c_bytes_repr(b'\xff'), r'"\377"') + + def test_c_str_repr(self): + c_str_repr = libclinic.c_str_repr + self.assertEqual(c_str_repr(''), '""') + self.assertEqual(c_str_repr('abc'), '"abc"') + self.assertEqual(c_str_repr('\a\b\f\n\r\t\v'), r'"\a\b\f\n\r\t\v"') + self.assertEqual(c_str_repr(' \0\x7f'), r'" \000\177"') + self.assertEqual(c_str_repr('"'), r'"\""') + self.assertEqual(c_str_repr("'"), r'''"'"''') + self.assertEqual(c_str_repr('\\'), r'"\\"') + self.assertEqual(c_str_repr('??/'), r'"?\?/"') + self.assertEqual(c_str_repr('???/'), r'"?\?\?/"') + self.assertEqual(c_str_repr('/*****/ /*/ */*'), r'"/\*****\/ /\*\/ *\/\*"') + self.assertEqual(c_str_repr('\xa0'), r'"\u00a0"') + self.assertEqual(c_str_repr('\xff'), r'"\u00ff"') + self.assertEqual(c_str_repr('\u20ac'), r'"\u20ac"') + self.assertEqual(c_str_repr('\U0001f40d'), r'"\U0001f40d"') + + def test_c_unichar_repr(self): + c_unichar_repr = libclinic.c_unichar_repr + self.assertEqual(c_unichar_repr('a'), "'a'") + self.assertEqual(c_unichar_repr('\n'), r"'\n'") + self.assertEqual(c_unichar_repr('\b'), r"'\b'") + self.assertEqual(c_unichar_repr('\0'), '0') + self.assertEqual(c_unichar_repr('\1'), '0x01') + self.assertEqual(c_unichar_repr('\x7f'), '0x7f') + self.assertEqual(c_unichar_repr(' '), "' '") + self.assertEqual(c_unichar_repr('"'), """'"'""") + self.assertEqual(c_unichar_repr("'"), r"'\''") + self.assertEqual(c_unichar_repr('\\'), r"'\\'") + self.assertEqual(c_unichar_repr('?'), "'?'") + self.assertEqual(c_unichar_repr('\xa0'), '0xa0') + self.assertEqual(c_unichar_repr('\xff'), '0xff') + self.assertEqual(c_unichar_repr('\u20ac'), '0x20ac') + self.assertEqual(c_unichar_repr('\U0001f40d'), '0x1f40d') + + def test_indent_all_lines(self): + # Blank lines 
are expected to be unchanged. + self.assertEqual(libclinic.indent_all_lines("", prefix="bar"), "") + + lines = ( + "one\n" + "two" # The missing newline is deliberate. + ) + expected = ( + "barone\n" + "bartwo" + ) + out = libclinic.indent_all_lines(lines, prefix="bar") + self.assertEqual(out, expected) + + # If last line is empty, expect it to be unchanged. + lines = ( + "\n" + "one\n" + "two\n" + "" + ) + expected = ( + "bar\n" + "barone\n" + "bartwo\n" + "" + ) + out = libclinic.indent_all_lines(lines, prefix="bar") + self.assertEqual(out, expected) + + def test_suffix_all_lines(self): + # Blank lines are expected to be unchanged. + self.assertEqual(libclinic.suffix_all_lines("", suffix="foo"), "") + + lines = ( + "one\n" + "two" # The missing newline is deliberate. + ) + expected = ( + "onefoo\n" + "twofoo" + ) + out = libclinic.suffix_all_lines(lines, suffix="foo") + self.assertEqual(out, expected) + + # If last line is empty, expect it to be unchanged. + lines = ( + "\n" + "one\n" + "two\n" + "" + ) + expected = ( + "foo\n" + "onefoo\n" + "twofoo\n" + "" + ) + out = libclinic.suffix_all_lines(lines, suffix="foo") + self.assertEqual(out, expected) + + +class ClinicReprTests(unittest.TestCase): + def test_Block_repr(self): + block = Block("foo") + expected_repr = "" + self.assertEqual(repr(block), expected_repr) + + block2 = Block("bar", "baz", [], "eggs", "spam") + expected_repr_2 = "" + self.assertEqual(repr(block2), expected_repr_2) + + block3 = Block( + input="longboi_" * 100, + dsl_name="wow_so_long", + signatures=[], + output="very_long_" * 100, + indent="" + ) + expected_repr_3 = ( + "" + ) + self.assertEqual(repr(block3), expected_repr_3) + + def test_Destination_repr(self): + c = _make_clinic() + + destination = Destination( + "foo", type="file", clinic=c, args=("eggs",) + ) + self.assertEqual( + repr(destination), "" + ) + + destination2 = Destination("bar", type="buffer", clinic=c) + self.assertEqual(repr(destination2), "") + + def 
test_Module_repr(self): + module = Module("foo", _make_clinic()) + self.assertRegex(repr(module), r"") + + def test_Class_repr(self): + cls = Class("foo", _make_clinic(), None, 'some_typedef', 'some_type_object') + self.assertRegex(repr(cls), r"") + + def test_FunctionKind_repr(self): + self.assertEqual( + repr(FunctionKind.CLASS_METHOD), "" + ) + + def test_Function_and_Parameter_reprs(self): + function = Function( + name='foo', + module=_make_clinic(), + cls=None, + c_basename=None, + full_name='foofoo', + return_converter=int_return_converter(), + kind=FunctionKind.METHOD_INIT, + coexist=False + ) + self.assertEqual(repr(function), "") + + converter = self_converter('bar', 'bar', function) + parameter = Parameter( + "bar", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + function=function, + converter=converter + ) + self.assertEqual(repr(parameter), "") + + def test_Monitor_repr(self): + monitor = libclinic.cpp.Monitor("test.c") + self.assertRegex(repr(monitor), r"") + + monitor.line_number = 42 + monitor.stack.append(("token1", "condition1")) + self.assertRegex( + repr(monitor), r"" + ) + + monitor.stack.append(("token2", "condition2")) + self.assertRegex( + repr(monitor), + r"" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py index fd1743e6701..4d117be1b88 100644 --- a/Lib/test/test_compile.py +++ b/Lib/test/test_compile.py @@ -1249,7 +1249,6 @@ def get_code_lines(self, code): last_line = line return res - @unittest.expectedFailure # TODO: RUSTPYTHON def test_lineno_attribute(self): def load_attr(): return ( @@ -1294,7 +1293,6 @@ def aug_store_attr(): code_lines = self.get_code_lines(func.__code__) self.assertEqual(lines, code_lines) - @unittest.expectedFailure # TODO: RUSTPYTHON; + [0] def test_line_number_genexp(self): def return_genexp(): @@ -2582,7 +2580,6 @@ def test_set(self): def test_dict(self): self.check_stack_size("{" + "x:x, " * self.N + "x:x}") - @unittest.expectedFailure # 
TODO: RUSTPYTHON; AssertionError: 102 not less than or equal to 6 def test_func_args(self): self.check_stack_size("f(" + "x, " * self.N + ")") @@ -2590,7 +2587,6 @@ def test_func_kwargs(self): kwargs = (f'a{i}=x' for i in range(self.N)) self.check_stack_size("f(" + ", ".join(kwargs) + ")") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 102 not less than or equal to 6 def test_meth_args(self): self.check_stack_size("o.m(" + "x, " * self.N + ")") diff --git a/Lib/test/test_ctypes/test_byteswap.py b/Lib/test/test_ctypes/test_byteswap.py index ea5951603f9..6a1bae14773 100644 --- a/Lib/test/test_ctypes/test_byteswap.py +++ b/Lib/test/test_ctypes/test_byteswap.py @@ -166,6 +166,48 @@ def test_endian_double(self): self.assertEqual(s.value, math.pi) self.assertEqual(bin(struct.pack(">d", math.pi)), bin(s)) + @unittest.skipUnless(hasattr(ctypes, 'c_float_complex'), "No complex types") + def test_endian_float_complex(self): + c_float_complex = ctypes.c_float_complex + if sys.byteorder == "little": + self.assertIs(c_float_complex.__ctype_le__, c_float_complex) + self.assertIs(c_float_complex.__ctype_be__.__ctype_le__, + c_float_complex) + else: + self.assertIs(c_float_complex.__ctype_be__, c_float_complex) + self.assertIs(c_float_complex.__ctype_le__.__ctype_be__, + c_float_complex) + s = c_float_complex(math.pi+1j) + self.assertEqual(bin(struct.pack("F", math.pi+1j)), bin(s)) + self.assertAlmostEqual(s.value, math.pi+1j, places=6) + s = c_float_complex.__ctype_le__(math.pi+1j) + self.assertAlmostEqual(s.value, math.pi+1j, places=6) + self.assertEqual(bin(struct.pack("F", math.pi+1j)), bin(s)) + + @unittest.skipUnless(hasattr(ctypes, 'c_double_complex'), "No complex types") + def test_endian_double_complex(self): + c_double_complex = ctypes.c_double_complex + if sys.byteorder == "little": + self.assertIs(c_double_complex.__ctype_le__, c_double_complex) + self.assertIs(c_double_complex.__ctype_be__.__ctype_le__, + c_double_complex) + else: + 
self.assertIs(c_double_complex.__ctype_be__, c_double_complex) + self.assertIs(c_double_complex.__ctype_le__.__ctype_be__, + c_double_complex) + s = c_double_complex(math.pi+1j) + self.assertEqual(bin(struct.pack("D", math.pi+1j)), bin(s)) + self.assertAlmostEqual(s.value, math.pi+1j, places=6) + s = c_double_complex.__ctype_le__(math.pi+1j) + self.assertAlmostEqual(s.value, math.pi+1j, places=6) + self.assertEqual(bin(struct.pack("D", math.pi+1j)), bin(s)) + def test_endian_other(self): self.assertIs(c_byte.__ctype_le__, c_byte) self.assertIs(c_byte.__ctype_be__, c_byte) diff --git a/Lib/test/test_ctypes/test_dlerror.py b/Lib/test/test_ctypes/test_dlerror.py index ea2d97d9000..5658234f9ec 100644 --- a/Lib/test/test_ctypes/test_dlerror.py +++ b/Lib/test/test_ctypes/test_dlerror.py @@ -55,8 +55,6 @@ class TestNullDlsym(unittest.TestCase): this 'dlsym returned NULL -> throw Error' rule. """ - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_null_dlsym(self): import subprocess import tempfile diff --git a/Lib/test/test_ctypes/test_dllist.py b/Lib/test/test_ctypes/test_dllist.py index 0e7c65127f6..53f077be026 100644 --- a/Lib/test/test_ctypes/test_dllist.py +++ b/Lib/test/test_ctypes/test_dllist.py @@ -26,7 +26,6 @@ ) class ListSharedLibraries(unittest.TestCase): - # TODO: RUSTPYTHON @unittest.skipIf(not APPLE, "TODO: RUSTPYTHON") def test_lists_system(self): dlls = ctypes.util.dllist() @@ -36,8 +35,7 @@ def test_lists_system(self): any(lib in dll for dll in dlls for lib in KNOWN_LIBRARIES), f"loaded={dlls}" ) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_lists_updates(self): dlls = ctypes.util.dllist() diff --git a/Lib/test/test_ctypes/test_loading.py b/Lib/test/test_ctypes/test_loading.py index 3b8332fbb30..343f6a07c0a 100644 --- a/Lib/test/test_ctypes/test_loading.py +++ b/Lib/test/test_ctypes/test_loading.py @@ -106,6 +106,14 @@ def test_load_without_name_and_with_handle(self): lib = 
ctypes.WinDLL(name=None, handle=handle) self.assertIs(handle, lib._handle) + @unittest.skipIf(os.name == "nt", 'POSIX-specific test') + @unittest.skipIf(libc_name is None, 'could not find libc') + def test_load_without_name_and_with_handle_posix(self): + lib1 = CDLL(libc_name) + handle = lib1._handle + lib2 = CDLL(name=None, handle=handle) + self.assertIs(lib2._handle, handle) + @unittest.skipUnless(os.name == "nt", 'Windows-specific test') def test_1703286_A(self): # On winXP 64-bit, advapi32 loads at an address that does diff --git a/Lib/test/test_ctypes/test_python_api.py b/Lib/test/test_ctypes/test_python_api.py index a1ee8a0de1e..28abf2ac031 100644 --- a/Lib/test/test_ctypes/test_python_api.py +++ b/Lib/test/test_ctypes/test_python_api.py @@ -7,8 +7,7 @@ class PythonAPITestCase(unittest.TestCase): - # TODO: RUSTPYTHON - requires pythonapi (Python C API) - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; - requires pythonapi (Python C API) def test_PyBytes_FromStringAndSize(self): PyBytes_FromStringAndSize = pythonapi.PyBytes_FromStringAndSize @@ -59,8 +58,7 @@ def test_PyObj_FromPtr(self): del pyobj self.assertEqual(sys.getrefcount(s), ref) - # TODO: RUSTPYTHON - requires pythonapi (Python C API) - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; - requires pythonapi (Python C API) def test_PyOS_snprintf(self): PyOS_snprintf = pythonapi.PyOS_snprintf PyOS_snprintf.argtypes = POINTER(c_char), c_size_t, c_char_p diff --git a/Lib/test/test_ctypes/test_values.py b/Lib/test/test_ctypes/test_values.py index 18554e193be..27db481894f 100644 --- a/Lib/test/test_ctypes/test_values.py +++ b/Lib/test/test_ctypes/test_values.py @@ -39,8 +39,7 @@ def test_undefined(self): class PythonValuesTestCase(unittest.TestCase): """This test only works when python itself is a dll/shared library""" - # TODO: RUSTPYTHON - requires pythonapi (Python C API) - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; - 
requires pythonapi (Python C API) def test_optimizeflag(self): # This test accesses the Py_OptimizeFlag integer, which is # exported by the Python dll and should match the sys.flags value @@ -48,8 +47,7 @@ def test_optimizeflag(self): opt = c_int.in_dll(pythonapi, "Py_OptimizeFlag").value self.assertEqual(opt, sys.flags.optimize) - # TODO: RUSTPYTHON - requires pythonapi (Python C API) - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; - requires pythonapi (Python C API) @thread_unsafe('overrides frozen modules') def test_frozentable(self): # Python exports a PyImport_FrozenModules symbol. This is a diff --git a/Lib/test/test_ctypes/test_win32.py b/Lib/test/test_ctypes/test_win32.py index bb2fc0ca222..0e028e3a37c 100644 --- a/Lib/test/test_ctypes/test_win32.py +++ b/Lib/test/test_ctypes/test_win32.py @@ -13,8 +13,7 @@ @unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') class FunctionCallTestCase(unittest.TestCase): - # TODO: RUSTPYTHON: SEH not implemented, crashes with STATUS_ACCESS_VIOLATION - @unittest.skip("TODO: RUSTPYTHON") + @unittest.skip("TODO: RUSTPYTHON; SEH not implemented, crashes with STATUS_ACCESS_VIOLATION") @unittest.skipUnless('MSC' in sys.version, "SEH only supported by MSC") @unittest.skipIf(sys.executable.lower().endswith('_d.exe'), "SEH not enabled in debug builds") diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index ec60e99f8fb..92bf7998d75 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -4339,7 +4339,6 @@ class C: C.__name__ = Nasty("abc") C.__name__ = "normal" - @unittest.expectedFailureIf(support.is_android, "TODO: RUSTPYTHON; AssertionError: 'C.__rfloordiv__' != 'C.__floordiv__'") def test_subclass_right_op(self): # Testing correct dispatch of subclass overloading __r__... 
@@ -5017,7 +5016,6 @@ def test_qualname_dict(self): ns = {'__qualname__': 1} self.assertRaises(TypeError, type, 'Foo', (), ns) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_cycle_through_dict(self): # See bug #1469629 class X(dict): diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 80d9e87d38f..79c975946f7 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -1259,7 +1259,7 @@ def __eq__(self, o): d = {X(): 0, 1: 1} self.assertRaises(RuntimeError, d.update, other) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON; hangs") def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, dict) support.check_free_after_iterating(self, lambda d: iter(d.keys()), dict) diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py index fcd6a6b8be7..7360ead1144 100644 --- a/Lib/test/test_dis.py +++ b/Lib/test/test_dis.py @@ -1211,15 +1211,12 @@ def test_disassemble_coroutine(self): def test_disassemble_fstring(self): self.do_disassembly_test(_fstring, dis_fstring) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_disassemble_with(self): self.do_disassembly_test(_with, dis_with) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_disassemble_asyncwith(self): self.do_disassembly_test(_asyncwith, dis_asyncwith) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_disassemble_try_finally(self): self.do_disassembly_test(_tryfinally, dis_tryfinally) self.do_disassembly_test(_tryfinallyconst, dis_tryfinallyconst) @@ -1991,26 +1988,22 @@ def test_first_line_set_to_None(self): actual = dis.get_instructions(simple, first_line=None) self.assertInstructionsEqual(list(actual), expected_opinfo_simple) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_outer(self): actual = dis.get_instructions(outer, first_line=expected_outer_line) self.assertInstructionsEqual(list(actual), expected_opinfo_outer) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_nested(self): with 
captured_stdout(): f = outer() actual = dis.get_instructions(f, first_line=expected_f_line) self.assertInstructionsEqual(list(actual), expected_opinfo_f) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_doubly_nested(self): with captured_stdout(): inner = outer()() actual = dis.get_instructions(inner, first_line=expected_inner_line) self.assertInstructionsEqual(list(actual), expected_opinfo_inner) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_jumpy(self): actual = dis.get_instructions(jumpy, first_line=expected_jumpy_line) self.assertInstructionsEqual(list(actual), expected_opinfo_jumpy) @@ -2314,7 +2307,6 @@ def test_iteration(self): via_generator = list(dis.get_instructions(obj)) self.assertInstructionsEqual(via_object, via_generator) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_explicit_first_line(self): actual = dis.Bytecode(outer, first_line=expected_outer_line) self.assertInstructionsEqual(list(actual), expected_opinfo_outer) diff --git a/Lib/test/test_email/__init__.py b/Lib/test/test_email/__init__.py index 455dc48facf..5d708e6e97e 100644 --- a/Lib/test/test_email/__init__.py +++ b/Lib/test/test_email/__init__.py @@ -5,7 +5,6 @@ from email.message import Message from email._policybase import compat32 from test.support import load_package_tests -from test.support.testcase import ExtraAssertions from test.test_email import __file__ as landmark # Load all tests in package @@ -21,7 +20,7 @@ def openfile(filename, *args, **kws): # Base test class -class TestEmailBase(unittest.TestCase, ExtraAssertions): +class TestEmailBase(unittest.TestCase): maxDiff = None # Currently the default policy is compat32. 
By setting that as the default diff --git a/Lib/test/test_email/test__header_value_parser.py b/Lib/test/test_email/test__header_value_parser.py index 64bc3677e87..f3c03062572 100644 --- a/Lib/test/test_email/test__header_value_parser.py +++ b/Lib/test/test_email/test__header_value_parser.py @@ -2617,7 +2617,7 @@ def test_get_address_list_mailboxes_invalid_addresses(self): '') self.assertEqual(address_list.token_type, 'address-list') self.assertEqual(len(address_list.mailboxes), 1) - self.assertEqual(len(address_list.all_mailboxes), 3) + self.assertEqual(len(address_list.all_mailboxes), 4) self.assertEqual([str(x) for x in address_list.all_mailboxes], [str(x) for x in address_list.addresses]) self.assertEqual(address_list.mailboxes[0].domain, 'example.com') @@ -2626,11 +2626,13 @@ def test_get_address_list_mailboxes_invalid_addresses(self): self.assertEqual(address_list.addresses[1].token_type, 'address') self.assertEqual(len(address_list.addresses[0].mailboxes), 1) self.assertEqual(len(address_list.addresses[1].mailboxes), 0) - self.assertEqual(len(address_list.addresses[1].mailboxes), 0) + self.assertEqual(len(address_list.addresses[2].mailboxes), 0) + self.assertEqual(len(address_list.addresses[3].mailboxes), 0) self.assertEqual( address_list.addresses[1].all_mailboxes[0].local_part, 'Foo x') + self.assertEqual(address_list.addresses[2].all_mailboxes[0].value, '[]') self.assertEqual( - address_list.addresses[2].all_mailboxes[0].display_name, + address_list.addresses[3].all_mailboxes[0].display_name, "Nobody Is. 
Special") def test_get_address_list_group_empty(self): @@ -2695,6 +2697,14 @@ def test_get_address_list_group_and_mailboxes(self): self.assertEqual(str(address_list.addresses[1]), str(address_list.mailboxes[2])) + def test_get_address_list_trailing_garbage(self): + address_list = self._test_get_x(parser.get_address_list, + 'unlisted-recipients:; (no To-header on input)', + 'unlisted-recipients:; (no To-header on input)', + 'unlisted-recipients:; ', + [errors.InvalidHeaderDefect]*2 + [errors.ObsoleteHeaderDefect], + '') + def test_invalid_content_disposition(self): content_disp = self._test_parse_x( parser.parse_content_disposition_header, @@ -2851,7 +2861,7 @@ def test_get_msg_id_no_id_right(self): parser.get_msg_id("") @@ -2867,6 +2877,81 @@ def test_get_msg_id_ws_only_local(self): ) self.assertEqual(msg_id.token_type, 'msg-id') + def test_parse_message_ids_valid(self): + message_ids = self._test_parse_x( + parser.parse_message_ids, + " ", + " ", + " ", + [], + ) + self.assertEqual(message_ids.token_type, 'message-id-list') + + def test_parse_message_ids_empty(self): + message_ids = self._test_parse_x( + parser.parse_message_ids, + " ", + " ", + " ", + [errors.InvalidHeaderDefect], + ) + self.assertEqual(message_ids.token_type, 'message-id-list') + + def test_parse_message_ids_comment(self): + message_ids = self._test_parse_x( + parser.parse_message_ids, + " (foo's message from \"bar\")", + " (foo's message from \"bar\")", + " ", + [], + ) + self.assertEqual(message_ids.message_ids[0].value, ' ') + self.assertEqual(message_ids.token_type, 'message-id-list') + + def test_parse_message_ids_no_sep(self): + message_ids = self._test_parse_x( + parser.parse_message_ids, + "", + "", + "", + [], + ) + self.assertEqual(message_ids.message_ids[0].value, '') + self.assertEqual(message_ids.message_ids[1].value, '') + self.assertEqual(message_ids.token_type, 'message-id-list') + + def test_parse_message_ids_comma_sep(self): + message_ids = self._test_parse_x( + 
parser.parse_message_ids, + ",", + " ", + " ", + [errors.InvalidHeaderDefect], + ) + self.assertEqual(message_ids.message_ids[0].value, '') + self.assertEqual(message_ids.message_ids[1].value, '') + self.assertEqual(message_ids.token_type, 'message-id-list') + + def test_parse_message_ids_invalid_id(self): + message_ids = self._test_parse_x( + parser.parse_message_ids, + "", + "", + "", + [errors.InvalidHeaderDefect]*2, + ) + self.assertEqual(message_ids.token_type, 'message-id-list') + + def test_parse_message_ids_broken_ang(self): + message_ids = self._test_parse_x( + parser.parse_message_ids, + " >bar@foo", + " >bar@foo", + " >bar@foo", + [errors.InvalidHeaderDefect]*1, + ) + self.assertEqual(message_ids.token_type, 'message-id-list') + @parameterize @@ -3219,6 +3304,29 @@ def test_address_list_with_specials_in_long_quoted_string(self): with self.subTest(to=to): self._test(parser.get_address_list(to)[0], folded, policy=policy) + def test_address_list_with_long_unwrapable_comment(self): + policy = self.policy.clone(max_line_length=40) + cases = [ + # (to, folded) + ('(loremipsumdolorsitametconsecteturadipi)', + '(loremipsumdolorsitametconsecteturadipi)\n'), + ('(loremipsumdolorsitametconsecteturadipi)', + '(loremipsumdolorsitametconsecteturadipi)\n'), + ('(loremipsum dolorsitametconsecteturadipi)', + '(loremipsum dolorsitametconsecteturadipi)\n'), + ('(loremipsum dolorsitametconsecteturadipi)', + '(loremipsum\n dolorsitametconsecteturadipi)\n'), + ('(Escaped \\( \\) chars \\\\ in comments stay escaped)', + '(Escaped \\( \\) chars \\\\ in comments stay\n escaped)\n'), + ('((loremipsum)(loremipsum)(loremipsum)(loremipsum))', + '((loremipsum)(loremipsum)(loremipsum)(loremipsum))\n'), + ('((loremipsum)(loremipsum)(loremipsum) (loremipsum))', + '((loremipsum)(loremipsum)(loremipsum)\n (loremipsum))\n'), + ] + for (to, folded) in cases: + with self.subTest(to=to): + self._test(parser.get_address_list(to)[0], folded, policy=policy) + # XXX Need tests with comments on 
various sides of a unicode token, # and with unicode tokens in the comments. Spaces inside the quotes # currently don't do the right thing. @@ -3255,5 +3363,23 @@ def test_long_filename_attachment(self): " filename*1*=_TEST_TES.txt\n", ) + def test_fold_unfoldable_element_stealing_whitespace(self): + # gh-142006: When an element is too long to fit on the current line + # the previous line's trailing whitespace should not trigger a double newline. + policy = self.policy.clone(max_line_length=10) + # The non-whitespace text needs to exactly fill the max_line_length (10). + text = ("a" * 9) + ", " + ("b" * 20) + expected = ("a" * 9) + ",\n " + ("b" * 20) + "\n" + token = parser.get_address_list(text)[0] + self._test(token, expected, policy=policy) + + def test_encoded_word_with_undecodable_bytes(self): + self._test(parser.get_address_list( + ' =?utf-8?Q?=E5=AE=A2=E6=88=B6=E6=AD=A3=E8=A6=8F=E4=BA=A4=E7?=' + )[0], + ' =?unknown-8bit?b?5a6i5oi25q2j6KaP5Lqk5w==?=\n', + ) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_email/test_asian_codecs.py b/Lib/test/test_email/test_asian_codecs.py index 1e0caeeaed0..ca44f54c69b 100644 --- a/Lib/test/test_email/test_asian_codecs.py +++ b/Lib/test/test_email/test_asian_codecs.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2006 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # Contact: email-sig@python.org # email package unit tests for (optional) Asian codecs diff --git a/Lib/test/test_email/test_defect_handling.py b/Lib/test/test_email/test_defect_handling.py index 44e76c8ce5e..acc4accccac 100644 --- a/Lib/test/test_email/test_defect_handling.py +++ b/Lib/test/test_email/test_defect_handling.py @@ -126,12 +126,10 @@ def test_multipart_invalid_cte(self): errors.InvalidMultipartContentTransferEncodingDefect) def test_multipart_no_cte_no_defect(self): - if self.raise_expected: return msg = self._str_msg(self.multipart_msg.format('')) self.assertEqual(len(self.get_defects(msg)), 0) def 
test_multipart_valid_cte_no_defect(self): - if self.raise_expected: return for cte in ('7bit', '8bit', 'BINary'): msg = self._str_msg( self.multipart_msg.format("\nContent-Transfer-Encoding: "+cte)) @@ -300,6 +298,47 @@ def test_missing_ending_boundary(self): self.assertDefectsEqual(self.get_defects(msg), [errors.CloseBoundaryNotFoundDefect]) + def test_line_beginning_colon(self): + string = ( + "Subject: Dummy subject\r\n: faulty header line\r\n\r\nbody\r\n" + ) + + with self._raise_point(errors.InvalidHeaderDefect): + msg = self._str_msg(string) + self.assertEqual(len(self.get_defects(msg)), 1) + self.assertDefectsEqual( + self.get_defects(msg), [errors.InvalidHeaderDefect] + ) + + if msg: + self.assertEqual(msg.items(), [("Subject", "Dummy subject")]) + self.assertEqual(msg.get_payload(), "body\r\n") + + def test_misplaced_envelope(self): + string = ( + "Subject: Dummy subject\r\nFrom wtf\r\nTo: abc\r\n\r\nbody\r\n" + ) + with self._raise_point(errors.MisplacedEnvelopeHeaderDefect): + msg = self._str_msg(string) + self.assertEqual(len(self.get_defects(msg)), 1) + self.assertDefectsEqual( + self.get_defects(msg), [errors.MisplacedEnvelopeHeaderDefect] + ) + + if msg: + headers = [("Subject", "Dummy subject"), ("To", "abc")] + self.assertEqual(msg.items(), headers) + self.assertEqual(msg.get_payload(), "body\r\n") + + + +class TestCompat32(TestDefectsBase, TestEmailBase): + + policy = policy.compat32 + + def get_defects(self, obj): + return obj.defects + class TestDefectDetection(TestDefectsBase, TestEmailBase): @@ -332,6 +371,9 @@ def _raise_point(self, defect): with self.assertRaises(defect): yield + def get_defects(self, obj): + return obj.defects + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py index 2d843c7d723..774a9265c8a 100644 --- a/Lib/test/test_email/test_email.py +++ b/Lib/test/test_email/test_email.py @@ -1,4 +1,4 @@ -# Copyright (C) 2001-2010 Python Software Foundation +# 
Copyright (C) 2001 Python Software Foundation # Contact: email-sig@python.org # email package unit tests @@ -481,7 +481,7 @@ def test_get_param_with_quotes(self): "Content-Type: foo; bar*0=\"baz\\\"foobar\"; bar*1=\"\\\"baz\"") self.assertEqual(msg.get_param('bar'), 'baz"foobar"baz') - @unittest.skip('TODO: RUSTPYTHON; Takes a long time to the point of timeouting') + @unittest.skip("TODO: RUSTPYTHON; Timeout") def test_get_param_linear_complexity(self): # Ensure that email.message._parseparam() is fast. # See https://github.com/python/cpython/issues/136063. @@ -768,6 +768,31 @@ def test_nonascii_add_header_with_tspecial(self): "attachment; filename*=utf-8''Fu%C3%9Fballer%20%5Bfilename%5D.ppt", msg['Content-Disposition']) + def test_invalid_header_names(self): + invalid_headers = [ + ('Invalid Header', 'contains space'), + ('Tab\tHeader', 'contains tab'), + ('Colon:Header', 'contains colon'), + ('', 'Empty name'), + (' LeadingSpace', 'starts with space'), + ('TrailingSpace ', 'ends with space'), + ('Header\x7F', 'Non-ASCII character'), + ('Header\x80', 'Extended ASCII'), + ] + for policy in (email.policy.default, email.policy.compat32): + for setter in (Message.__setitem__, Message.add_header): + for name, value in invalid_headers: + self.do_test_invalid_header_names( + policy, setter,name, value) + + def do_test_invalid_header_names(self, policy, setter, name, value): + with self.subTest(policy=policy, setter=setter, name=name, value=value): + message = Message(policy=policy) + pattern = r'(?i)(?=.*invalid)(?=.*header)(?=.*name)' + with self.assertRaisesRegex(ValueError, pattern) as cm: + setter(message, name, value) + self.assertIn(f"{name!r}", str(cm.exception)) + def test_binary_quopri_payload(self): for charset in ('latin-1', 'ascii'): msg = Message() @@ -2241,70 +2266,6 @@ def test_parse_missing_minor_type(self): eq(msg.get_content_maintype(), 'text') eq(msg.get_content_subtype(), 'plain') - # test_defect_handling - def test_same_boundary_inner_outer(self): - 
msg = self._msgobj('msg_15.txt') - # XXX We can probably eventually do better - inner = msg.get_payload(0) - self.assertHasAttr(inner, 'defects') - self.assertEqual(len(inner.defects), 1) - self.assertIsInstance(inner.defects[0], - errors.StartBoundaryNotFoundDefect) - - # test_defect_handling - def test_multipart_no_boundary(self): - msg = self._msgobj('msg_25.txt') - self.assertIsInstance(msg.get_payload(), str) - self.assertEqual(len(msg.defects), 2) - self.assertIsInstance(msg.defects[0], - errors.NoBoundaryInMultipartDefect) - self.assertIsInstance(msg.defects[1], - errors.MultipartInvariantViolationDefect) - - multipart_msg = textwrap.dedent("""\ - Date: Wed, 14 Nov 2007 12:56:23 GMT - From: foo@bar.invalid - To: foo@bar.invalid - Subject: Content-Transfer-Encoding: base64 and multipart - MIME-Version: 1.0 - Content-Type: multipart/mixed; - boundary="===============3344438784458119861=="{} - - --===============3344438784458119861== - Content-Type: text/plain - - Test message - - --===============3344438784458119861== - Content-Type: application/octet-stream - Content-Transfer-Encoding: base64 - - YWJj - - --===============3344438784458119861==-- - """) - - # test_defect_handling - def test_multipart_invalid_cte(self): - msg = self._str_msg( - self.multipart_msg.format("\nContent-Transfer-Encoding: base64")) - self.assertEqual(len(msg.defects), 1) - self.assertIsInstance(msg.defects[0], - errors.InvalidMultipartContentTransferEncodingDefect) - - # test_defect_handling - def test_multipart_no_cte_no_defect(self): - msg = self._str_msg(self.multipart_msg.format('')) - self.assertEqual(len(msg.defects), 0) - - # test_defect_handling - def test_multipart_valid_cte_no_defect(self): - for cte in ('7bit', '8bit', 'BINary'): - msg = self._str_msg( - self.multipart_msg.format( - "\nContent-Transfer-Encoding: {}".format(cte))) - self.assertEqual(len(msg.defects), 0) - # test_headerregistry.TestContentTypeHeader invalid_1 and invalid_2. 
def test_invalid_content_type(self): eq = self.assertEqual @@ -2381,30 +2342,6 @@ def test_missing_start_boundary(self): self.assertIsInstance(bad.defects[0], errors.StartBoundaryNotFoundDefect) - # test_defect_handling - def test_first_line_is_continuation_header(self): - eq = self.assertEqual - m = ' Line 1\nSubject: test\n\nbody' - msg = email.message_from_string(m) - eq(msg.keys(), ['Subject']) - eq(msg.get_payload(), 'body') - eq(len(msg.defects), 1) - self.assertDefectsEqual(msg.defects, - [errors.FirstHeaderLineIsContinuationDefect]) - eq(msg.defects[0].line, ' Line 1\n') - - # test_defect_handling - def test_missing_header_body_separator(self): - # Our heuristic if we see a line that doesn't look like a header (no - # leading whitespace but no ':') is to assume that the blank line that - # separates the header from the body is missing, and to stop parsing - # headers and start parsing the body. - msg = self._str_msg('Subject: test\nnot a header\nTo: abc\n\nb\n') - self.assertEqual(msg.keys(), ['Subject']) - self.assertEqual(msg.get_payload(), 'not a header\nTo: abc\n\nb\n') - self.assertDefectsEqual(msg.defects, - [errors.MissingHeaderBodySeparatorDefect]) - def test_string_payload_with_extra_space_after_cte(self): # https://github.com/python/cpython/issues/98188 cte = "base64 " @@ -4898,6 +4835,15 @@ def test_decode_soft_line_break(self): def test_decode_false_quoting(self): self._test_decode('A=1,B=A ==> A+B==2', 'A=1,B=A ==> A+B==2') + def test_decode_crlf_eol_no_trailing_newline(self): + self._test_decode('abc', 'abc', eol='\r\n') + + def test_decode_crlf_eol_multiline_no_trailing_newline(self): + self._test_decode('a\r\nb', 'a\r\nb', eol='\r\n') + + def test_decode_crlf_eol_with_trailing_newline(self): + self._test_decode('abc\r\n', 'abc\r\n', eol='\r\n') + def _test_encode(self, body, expected_encoded_body, maxlinelen=None, eol=None): kwargs = {} if maxlinelen is None: diff --git a/Lib/test/test_email/test_generator.py 
b/Lib/test/test_email/test_generator.py index c75a842c335..3c9a86f3e8c 100644 --- a/Lib/test/test_email/test_generator.py +++ b/Lib/test/test_email/test_generator.py @@ -1,13 +1,20 @@ import io import textwrap import unittest +import random +import sys from email import message_from_string, message_from_bytes from email.message import EmailMessage +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText from email.generator import Generator, BytesGenerator +import email.generator from email.headerregistry import Address from email import policy import email.errors from test.test_email import TestEmailBase, parameterize +import test.support + @parameterize @@ -288,6 +295,36 @@ def test_keep_long_encoded_newlines(self): g.flatten(msg) self.assertEqual(s.getvalue(), self.typ(expected)) + def _test_boundary_detection(self, linesep): + # Generate a boundary token in the same way as _make_boundary + token = random.randrange(sys.maxsize) + + def _patch_random_randrange(*args, **kwargs): + return token + + with test.support.swap_attr( + random, "randrange", _patch_random_randrange + ): + boundary = self.genclass._make_boundary(text=None) + boundary_in_part = ( + "this goes before the boundary\n--" + + boundary + + "\nthis goes after\n" + ) + msg = MIMEMultipart() + msg.attach(MIMEText(boundary_in_part)) + self.genclass(self.ioclass()).flatten(msg, linesep=linesep) + # Generator checks the message content for the string it is about + # to use as a boundary ('token' in this test) and when it finds it + # in our attachment appends .0 to make the boundary it uses unique. 
+ self.assertEqual(msg.get_boundary(), boundary + ".0") + + def test_lf_boundary_detection(self): + self._test_boundary_detection("\n") + + def test_crlf_boundary_detection(self): + self._test_boundary_detection("\r\n") + class TestGenerator(TestGeneratorBase, TestEmailBase): @@ -313,7 +350,7 @@ def test_flatten_unicode_linesep(self): self.assertEqual(s.getvalue(), self.typ(expected)) def test_verify_generated_headers(self): - """gh-121650: by default the generator prevents header injection""" + # gh-121650: by default the generator prevents header injection class LiteralHeader(str): name = 'Header' def fold(self, **kwargs): @@ -334,6 +371,8 @@ def fold(self, **kwargs): with self.assertRaises(email.errors.HeaderWriteError): message.as_string() + with self.assertRaises(email.errors.HeaderWriteError): + message.as_bytes() class TestBytesGenerator(TestGeneratorBase, TestEmailBase): @@ -391,6 +430,50 @@ def test_defaults_handle_spaces_at_start_of_continuation_line(self): g.flatten(msg) self.assertEqual(s.getvalue(), expected) + # gh-144156: fold between non-encoded and encoded words don't need to encoded + # the separating space + def test_defaults_handle_spaces_at_start_of_continuation_line_2(self): + source = ("Re: [SOS-1495488] Commande et livraison - Demande de retour - " + "bibijolie - 251210-AABBCC - Abo actualités digitales 20 semaines " + "d’abonnement à 24 heures, Bilan, Tribune de Genève et tous les titres Tamedia") + expected = ( + b"Subject: " + b"Re: [SOS-1495488] Commande et livraison - Demande de retour -\n" + b" bibijolie - 251210-AABBCC - Abo =?utf-8?q?actualit=C3=A9s?= digitales 20\n" + b" semaines =?utf-8?q?d=E2=80=99abonnement_=C3=A0?= 24 heures, Bilan, Tribune de\n" + b" =?utf-8?q?Gen=C3=A8ve?= et tous les titres Tamedia\n\n" + ) + msg = EmailMessage() + msg['Subject'] = source + s = io.BytesIO() + g = BytesGenerator(s) + g.flatten(msg) + self.assertEqual(s.getvalue(), expected) + + def test_ew_folding_round_trip_1(self): + print() + source = 
"aaaaaaaaa фффффффф " + msg = EmailMessage() + msg['Subject'] = source + s = io.BytesIO() + g = BytesGenerator(s, maxheaderlen=30) + g.flatten(msg) + flat = s.getvalue() + reparsed = message_from_bytes(flat, policy=policy.default)['Subject'] + self.assertMultiLineEqual(reparsed, source) + + def test_ew_folding_round_trip_2(self): + print() + source = "aaa aaaaaaa aaa ффф фффф " + msg = EmailMessage() + msg['Subject'] = source + s = io.BytesIO() + g = BytesGenerator(s, maxheaderlen=30) + g.flatten(msg) + flat = s.getvalue() + reparsed = message_from_bytes(flat, policy=policy.default)['Subject'] + self.assertMultiLineEqual(reparsed, source) + def test_cte_type_7bit_handles_unknown_8bit(self): source = ("Subject: Maintenant je vous présente mon " "collègue\n\n").encode('utf-8') diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py index d2c571299bc..854b76966ec 100644 --- a/Lib/test/test_email/test_headerregistry.py +++ b/Lib/test/test_email/test_headerregistry.py @@ -7,7 +7,6 @@ from test.test_email import TestEmailBase, parameterize from email import headerregistry from email.headerregistry import Address, Group -from email.header import decode_header from test.support import ALWAYS_EQ @@ -133,11 +132,6 @@ def string_as_value(self, source, decoded, *args): - # TODO: RUSTPYTHON; RustPython currently does not support non-utf8 encoding - if source == '=?gb2312?b?1eLKx9bQzsSy4srUo6E=?=': - raise unittest.SkipTest("TODO: RUSTPYTHON; RustPython currently does not support non-utf8 encoding") - # RUSTPYTHON: End - # ------------------------------------------------------------------ l = len(args) defects = args[0] if l>0 else [] header = 'Subject:' + (' ' if source else '') @@ -171,6 +165,9 @@ def string_as_value(self, } +TestUnstructuredHeader.test_value_rfc2047_gb2312_base64 = unittest.expectedFailure( # TODO: RUSTPYTHON + TestUnstructuredHeader.test_value_rfc2047_gb2312_base64 +) @parameterize class 
TestDateHeader(TestHeaderBase): @@ -1268,12 +1265,12 @@ class TestAddressHeader(TestHeaderBase): 'example.com', None), - } - # XXX: Need many more examples, and in particular some with names in # trailing comments, which aren't currently handled. comments in # general are not handled yet. + } + def example_as_address(self, source, defects, decoded, display_name, addr_spec, username, domain, comment): h = self.make_header('sender', source) @@ -1291,6 +1288,43 @@ def example_as_address(self, source, defects, decoded, display_name, # XXX: we have no comment support yet. #self.assertEqual(a.comment, comment) + example_broken_header_params = { + + 'just_dquote': + ('"', + [errors.InvalidHeaderDefect]*2, + '<>', + '', + '<>', + '', + '', + ), + + } + + def example_broken_header_as_address( + self, + source, + defects, + decoded, + display_name, + addr_spec, + username, + domain, + ): + h = self.make_header('sender', source) + self.assertEqual(h, decoded) + self.assertDefectsEqual(h.defects, defects) + a = h.address + self.assertEqual(str(a), decoded) + self.assertEqual(len(h.groups), 1) + self.assertEqual([a], list(h.groups[0].addresses)) + self.assertEqual([a], list(h.addresses)) + self.assertEqual(a.display_name, display_name) + self.assertEqual(a.addr_spec, addr_spec) + self.assertEqual(a.username, username) + self.assertEqual(a.domain, domain) + def example_as_group(self, source, defects, decoded, display_name, addr_spec, username, domain, comment): source = 'foo: {};'.format(source) @@ -1708,7 +1742,7 @@ def test_fold_unstructured_with_overlong_word(self): 'singlewordthatwontfit') self.assertEqual( h.fold(policy=policy.default.clone(max_line_length=20)), - 'Subject: \n' + 'Subject:\n' ' =?utf-8?q?thisisa?=\n' ' =?utf-8?q?verylon?=\n' ' =?utf-8?q?glineco?=\n' @@ -1724,7 +1758,7 @@ def test_fold_unstructured_with_two_overlong_words(self): 'singlewordthatwontfit plusanotherverylongwordthatwontfit') self.assertEqual( 
h.fold(policy=policy.default.clone(max_line_length=20)), - 'Subject: \n' + 'Subject:\n' ' =?utf-8?q?thisisa?=\n' ' =?utf-8?q?verylon?=\n' ' =?utf-8?q?glineco?=\n' @@ -1818,5 +1852,18 @@ def test_message_id_header_is_not_folded(self): h.fold(policy=policy.default.clone(max_line_length=20)), 'Message-ID:\n <ईमेलfromMessage@wők.com>\n') + def test_fold_references(self): + h = self.make_header( + 'References', + ' ' + '' + ) + self.assertEqual( + h.fold(policy=policy.default.clone(max_line_length=20)), + 'References: ' + '\n' + ' \n') + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_email/test_message.py b/Lib/test/test_email/test_message.py index 966615dcc1d..56ad446694d 100644 --- a/Lib/test/test_email/test_message.py +++ b/Lib/test/test_email/test_message.py @@ -1030,6 +1030,30 @@ def do_test_no_wrapping_max_line_length(self, falsey): parsed_body = parsed.get_body().get_content().rstrip('\n') self.assertEqual(parsed_body, body) + def test_invalid_header_names(self): + invalid_headers = [ + ('Invalid Header', 'contains space'), + ('Tab\tHeader', 'contains tab'), + ('Colon:Header', 'contains colon'), + ('', 'Empty name'), + (' LeadingSpace', 'starts with space'), + ('TrailingSpace ', 'ends with space'), + ('Header\x7F', 'Non-ASCII character'), + ('Header\x80', 'Extended ASCII'), + ] + for email_policy in (policy.default, policy.compat32): + for setter in (EmailMessage.__setitem__, EmailMessage.add_header): + for name, value in invalid_headers: + self.do_test_invalid_header_names(email_policy, setter, name, value) + + def do_test_invalid_header_names(self, policy, setter, name, value): + with self.subTest(policy=policy, setter=setter, name=name, value=value): + message = EmailMessage(policy=policy) + pattern = r'(?i)(?=.*invalid)(?=.*header)(?=.*name)' + with self.assertRaisesRegex(ValueError, pattern) as cm: + setter(message, name, value) + self.assertIn(f"{name!r}", str(cm.exception)) + def test_get_body_malformed(self): """test for 
bpo-42892""" msg = textwrap.dedent("""\ diff --git a/Lib/test/test_email/test_policy.py b/Lib/test/test_email/test_policy.py index baa35fd68e4..90e8e558029 100644 --- a/Lib/test/test_email/test_policy.py +++ b/Lib/test/test_email/test_policy.py @@ -273,7 +273,7 @@ def test_non_ascii_chars_do_not_cause_inf_loop(self): actual = policy.fold('Subject', 'ą' * 12) self.assertEqual( actual, - 'Subject: \n' + + 'Subject:\n' + 12 * ' =?utf-8?q?=C4=85?=\n') def test_short_maxlen_error(self): @@ -296,7 +296,7 @@ def test_short_maxlen_error(self): policy.fold("Subject", subject) def test_verify_generated_headers(self): - """Turning protection off allows header injection""" + # Turning protection off allows header injection policy = email.policy.default.clone(verify_generated_headers=False) for text in ( 'Header: Value\r\nBad: Injection\r\n', @@ -319,6 +319,10 @@ def fold(self, **kwargs): message.as_string(), f"{text}\nBody", ) + self.assertEqual( + message.as_bytes(), + f"{text}\nBody".encode(), + ) # XXX: Need subclassing tests. 
# For adding subclassed objects, make sure the usual rules apply (subclass diff --git a/Lib/test/test_email/test_utils.py b/Lib/test/test_email/test_utils.py index d04b3909efa..c9d09098b50 100644 --- a/Lib/test/test_email/test_utils.py +++ b/Lib/test/test_email/test_utils.py @@ -3,9 +3,17 @@ import test.support import time import unittest -import sys -import os.path -import zoneinfo + +from test.support import cpython_only +from test.support.import_helper import ensure_lazy_imports + + +class TestImportTime(unittest.TestCase): + + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("email.utils", {"random", "socket"}) + class DateTimeTests(unittest.TestCase): @@ -154,10 +162,6 @@ def test_variable_tzname(self): t1 = utils.localtime(t0) self.assertEqual(t1.tzname(), 'EET') - def test_isdst_deprecation(self): - with self.assertWarns(DeprecationWarning): - t0 = datetime.datetime(1990, 1, 1) - t1 = utils.localtime(t0, isdst=True) # Issue #24836: The timezone files are out of date (pre 2011k) # on Mac OS X Snow Leopard. diff --git a/Lib/test/test_email/torture_test.py b/Lib/test/test_email/torture_test.py index 9cf9362c9b7..d15948a38b2 100644 --- a/Lib/test/test_email/torture_test.py +++ b/Lib/test/test_email/torture_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2004 Python Software Foundation +# Copyright (C) 2002 Python Software Foundation # # A torture test of the email package. 
This should not be run as part of the # standard Python test suite since it requires several meg of email messages diff --git a/Lib/test/test_eof.py b/Lib/test/test_eof.py index 2c74e2d87d3..f5a0bc56958 100644 --- a/Lib/test/test_eof.py +++ b/Lib/test/test_eof.py @@ -9,7 +9,6 @@ import unittest class EOFTestCase(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON def test_EOF_single_quote(self): expect = "unterminated string literal (detected at line 1) (, line 1)" for quote in ("'", "\""): @@ -19,7 +18,7 @@ def test_EOF_single_quote(self): self.assertEqual(str(cm.exception), expect) self.assertEqual(cm.exception.offset, 1) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_EOFS(self): expect = ("unterminated triple-quoted string literal (detected at line 3) (, line 1)") with self.assertRaises(SyntaxError) as cm: @@ -46,7 +45,7 @@ def test_EOFS(self): self.assertEqual(cm.exception.text, "ä = '''thîs is ") self.assertEqual(cm.exception.offset, 5) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON @force_not_colorized def test_EOFS_with_file(self): expect = ("(, line 1)") @@ -87,7 +86,7 @@ def test_EOFS_with_file(self): ' ^', 'SyntaxError: unterminated triple-quoted string literal (detected at line 4)']) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON @warnings_helper.ignore_warnings(category=SyntaxWarning) def test_eof_with_line_continuation(self): expect = "unexpected EOF while parsing (, line 1)" @@ -95,7 +94,7 @@ def test_eof_with_line_continuation(self): compile('"\\Xhh" \\', '', 'exec') self.assertEqual(str(cm.exception), expect) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_line_continuation_EOF(self): """A continuation at the end of input must be an error; bpo2180.""" expect = 'unexpected EOF while parsing (, line 1)' @@ -128,7 +127,7 @@ def 
test_line_continuation_EOF(self): exec('\\') self.assertEqual(str(cm.exception), expect) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipIf(not sys.executable, "sys.executable required") @force_not_colorized def test_line_continuation_EOF_from_file_bpo2180(self): diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py index 10010ffa9b8..7e79732a3b9 100644 --- a/Lib/test/test_exceptions.py +++ b/Lib/test/test_exceptions.py @@ -2245,7 +2245,6 @@ def test_assertion_error_location(self): result = run_script(source) self.assertEqual(result[-3:], expected) - @unittest.expectedFailure # TODO: RUSTPYTHON @force_not_colorized def test_multiline_not_highlighted(self): cases = [ diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index d53a3f46134..6868c87171d 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -423,7 +423,6 @@ def test_non_ascii(self): self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007") self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007") - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; AssertionError: ',' not found in '123456789'") def test_locale(self): try: oldloc = locale.setlocale(locale.LC_ALL) diff --git a/Lib/test/test_generated_cases.py b/Lib/test/test_generated_cases.py new file mode 100644 index 00000000000..fc34ac2fdc9 --- /dev/null +++ b/Lib/test/test_generated_cases.py @@ -0,0 +1,2074 @@ +import contextlib +import os +import re +import sys +import tempfile +import unittest + +from io import StringIO +from test import support +from test import test_tools + + +def skip_if_different_mount_drives(): + if sys.platform != "win32": + return + ROOT = os.path.dirname(os.path.dirname(__file__)) + root_drive = os.path.splitroot(ROOT)[0] + cwd_drive = os.path.splitroot(os.getcwd())[0] + if root_drive != cwd_drive: + # May raise ValueError if ROOT and the current working + # different have different mount drives (on Windows). 
+ raise unittest.SkipTest( + f"the current working directory and the Python source code " + f"directory have different mount drives " + f"({cwd_drive} and {root_drive})" + ) + + +skip_if_different_mount_drives() + + +test_tools.skip_if_missing("cases_generator") +with test_tools.imports_under_tool("cases_generator"): + from analyzer import analyze_forest, StackItem + from cwriter import CWriter + import parser + from stack import Local, Stack + import tier1_generator + import opcode_metadata_generator + import optimizer_generator + + +def handle_stderr(): + if support.verbose > 1: + return contextlib.nullcontext() + else: + return support.captured_stderr() + + +def parse_src(src): + p = parser.Parser(src, "test.c") + nodes = [] + while node := p.definition(): + nodes.append(node) + return nodes + + +class TestEffects(unittest.TestCase): + def test_effect_sizes(self): + stack = Stack() + inputs = [ + x := StackItem("x", None, "1"), + y := StackItem("y", None, "oparg"), + z := StackItem("z", None, "oparg*2"), + ] + outputs = [ + StackItem("x", None, "1"), + StackItem("b", None, "oparg*4"), + StackItem("c", None, "1"), + ] + null = CWriter.null() + stack.pop(z, null) + stack.pop(y, null) + stack.pop(x, null) + for out in outputs: + stack.push(Local.undefined(out)) + self.assertEqual(stack.base_offset.to_c(), "-1 - oparg - oparg*2") + self.assertEqual(stack.physical_sp.to_c(), "0") + self.assertEqual(stack.logical_sp.to_c(), "1 - oparg - oparg*2 + oparg*4") + + +class TestGeneratedCases(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.maxDiff = None + + self.temp_dir = tempfile.gettempdir() + self.temp_input_filename = os.path.join(self.temp_dir, "input.txt") + self.temp_output_filename = os.path.join(self.temp_dir, "output.txt") + self.temp_metadata_filename = os.path.join(self.temp_dir, "metadata.txt") + self.temp_pymetadata_filename = os.path.join(self.temp_dir, "pymetadata.txt") + self.temp_executor_filename = os.path.join(self.temp_dir, 
"executor.txt") + + def tearDown(self) -> None: + for filename in [ + self.temp_input_filename, + self.temp_output_filename, + self.temp_metadata_filename, + self.temp_pymetadata_filename, + self.temp_executor_filename, + ]: + try: + os.remove(filename) + except: + pass + super().tearDown() + + def run_cases_test(self, input: str, expected: str): + with open(self.temp_input_filename, "w+") as temp_input: + temp_input.write(parser.BEGIN_MARKER) + temp_input.write(input) + temp_input.write(parser.END_MARKER) + temp_input.flush() + + with handle_stderr(): + tier1_generator.generate_tier1_from_files( + [self.temp_input_filename], self.temp_output_filename, False + ) + + with open(self.temp_output_filename) as temp_output: + lines = temp_output.read() + _, rest = lines.split(tier1_generator.INSTRUCTION_START_MARKER) + instructions, labels_with_prelude_and_postlude = rest.split(tier1_generator.INSTRUCTION_END_MARKER) + _, labels_with_postlude = labels_with_prelude_and_postlude.split(tier1_generator.LABEL_START_MARKER) + labels, _ = labels_with_postlude.split(tier1_generator.LABEL_END_MARKER) + actual = instructions.strip() + "\n\n " + labels.strip() + + self.assertEqual(actual.strip(), expected.strip()) + + def test_inst_no_args(self): + input = """ + inst(OP, (--)) { + SPAM(); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + SPAM(); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_inst_one_pop(self): + input = """ + inst(OP, (value --)) { + SPAM(value); + DEAD(value); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef value; + value = stack_pointer[-1]; + SPAM(value); + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + 
self.run_cases_test(input, output) + + def test_inst_one_push(self): + input = """ + inst(OP, (-- res)) { + res = SPAM(); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef res; + res = SPAM(); + stack_pointer[0] = res; + stack_pointer += 1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_inst_one_push_one_pop(self): + input = """ + inst(OP, (value -- res)) { + res = SPAM(value); + DEAD(value); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef value; + _PyStackRef res; + value = stack_pointer[-1]; + res = SPAM(value); + stack_pointer[-1] = res; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_binary_op(self): + input = """ + inst(OP, (left, right -- res)) { + res = SPAM(left, right); + INPUTS_DEAD(); + + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef left; + _PyStackRef right; + _PyStackRef res; + right = stack_pointer[-1]; + left = stack_pointer[-2]; + res = SPAM(left, right); + stack_pointer[-2] = res; + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_overlap(self): + input = """ + inst(OP, (left, right -- left, result)) { + result = SPAM(left, right); + INPUTS_DEAD(); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef left; + _PyStackRef right; + _PyStackRef result; + right = stack_pointer[-1]; + left 
= stack_pointer[-2]; + result = SPAM(left, right); + stack_pointer[-1] = result; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_predictions(self): + input = """ + inst(OP1, (arg -- res)) { + DEAD(arg); + res = Py_None; + } + inst(OP3, (arg -- res)) { + DEAD(arg); + DEOPT_IF(xxx); + res = Py_None; + } + family(OP1, INLINE_CACHE_ENTRIES_OP1) = { OP3 }; + """ + output = """ + TARGET(OP1) { + #if Py_TAIL_CALL_INTERP + int opcode = OP1; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP1); + PREDICTED_OP1:; + _PyStackRef arg; + _PyStackRef res; + arg = stack_pointer[-1]; + res = Py_None; + stack_pointer[-1] = res; + DISPATCH(); + } + + TARGET(OP3) { + #if Py_TAIL_CALL_INTERP + int opcode = OP3; + (void)(opcode); + #endif + _Py_CODEUNIT* const this_instr = next_instr; + (void)this_instr; + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP3); + static_assert(INLINE_CACHE_ENTRIES_OP1 == 0, "incorrect cache size"); + _PyStackRef arg; + _PyStackRef res; + arg = stack_pointer[-1]; + if (xxx) { + UPDATE_MISS_STATS(OP1); + assert(_PyOpcode_Deopt[opcode] == (OP1)); + JUMP_TO_PREDICTED(OP1); + } + res = Py_None; + stack_pointer[-1] = res; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_sync_sp(self): + input = """ + inst(A, (arg -- res)) { + DEAD(arg); + SYNC_SP(); + escaping_call(); + res = Py_None; + } + inst(B, (arg -- res)) { + DEAD(arg); + res = Py_None; + SYNC_SP(); + escaping_call(); + } + """ + output = """ + TARGET(A) { + #if Py_TAIL_CALL_INTERP + int opcode = A; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(A); + _PyStackRef arg; + _PyStackRef res; + arg = stack_pointer[-1]; + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + _PyFrame_SetStackPointer(frame, stack_pointer); + escaping_call(); + stack_pointer = _PyFrame_GetStackPointer(frame); + res = Py_None; + stack_pointer[0] = res; + 
stack_pointer += 1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + + TARGET(B) { + #if Py_TAIL_CALL_INTERP + int opcode = B; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(B); + _PyStackRef arg; + _PyStackRef res; + arg = stack_pointer[-1]; + res = Py_None; + stack_pointer[-1] = res; + _PyFrame_SetStackPointer(frame, stack_pointer); + escaping_call(); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + + def test_pep7_condition(self): + input = """ + inst(OP, (arg1 -- out)) { + if (arg1) + out = 0; + else { + out = 1; + } + } + """ + output = "" + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_error_if_plain(self): + input = """ + inst(OP, (--)) { + ERROR_IF(cond); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + if (cond) { + JUMP_TO_LABEL(error); + } + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_error_if_plain_with_comment(self): + input = """ + inst(OP, (--)) { + ERROR_IF(cond); // Comment is ok + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + if (cond) { + JUMP_TO_LABEL(error); + } + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_error_if_pop(self): + input = """ + inst(OP, (left, right -- res)) { + SPAM(left, right); + INPUTS_DEAD(); + ERROR_IF(cond); + res = 0; + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef left; + _PyStackRef right; + _PyStackRef res; + right = stack_pointer[-1]; + left = stack_pointer[-2]; + 
SPAM(left, right); + if (cond) { + JUMP_TO_LABEL(pop_2_error); + } + res = 0; + stack_pointer[-2] = res; + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_error_if_pop_with_result(self): + input = """ + inst(OP, (left, right -- res)) { + res = SPAM(left, right); + INPUTS_DEAD(); + ERROR_IF(cond); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef left; + _PyStackRef right; + _PyStackRef res; + right = stack_pointer[-1]; + left = stack_pointer[-2]; + res = SPAM(left, right); + if (cond) { + JUMP_TO_LABEL(pop_2_error); + } + stack_pointer[-2] = res; + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_cache_effect(self): + input = """ + inst(OP, (counter/1, extra/2, value --)) { + DEAD(value); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + _Py_CODEUNIT* const this_instr = next_instr; + (void)this_instr; + frame->instr_ptr = next_instr; + next_instr += 4; + INSTRUCTION_STATS(OP); + _PyStackRef value; + value = stack_pointer[-1]; + uint16_t counter = read_u16(&this_instr[1].cache); + (void)counter; + uint32_t extra = read_u32(&this_instr[2].cache); + (void)extra; + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_suppress_dispatch(self): + input = """ + label(somewhere) { + } + + inst(OP, (--)) { + goto somewhere; + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + JUMP_TO_LABEL(somewhere); + } + + LABEL(somewhere) + { + } + """ + self.run_cases_test(input, output) + + def 
test_macro_instruction(self): + input = """ + inst(OP1, (counter/1, left, right -- left, right)) { + op1(left, right); + } + op(OP2, (extra/2, arg2, left, right -- res)) { + res = op2(arg2, left, right); + INPUTS_DEAD(); + } + macro(OP) = OP1 + cache/2 + OP2; + inst(OP3, (unused/5, arg2, left, right -- res)) { + res = op3(arg2, left, right); + INPUTS_DEAD(); + } + family(OP, INLINE_CACHE_ENTRIES_OP) = { OP3 }; + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 6; + INSTRUCTION_STATS(OP); + PREDICTED_OP:; + _Py_CODEUNIT* const this_instr = next_instr - 6; + (void)this_instr; + _PyStackRef left; + _PyStackRef right; + _PyStackRef arg2; + _PyStackRef res; + // _OP1 + { + right = stack_pointer[-1]; + left = stack_pointer[-2]; + uint16_t counter = read_u16(&this_instr[1].cache); + (void)counter; + _PyFrame_SetStackPointer(frame, stack_pointer); + op1(left, right); + stack_pointer = _PyFrame_GetStackPointer(frame); + } + /* Skip 2 cache entries */ + // OP2 + { + arg2 = stack_pointer[-3]; + uint32_t extra = read_u32(&this_instr[4].cache); + (void)extra; + _PyFrame_SetStackPointer(frame, stack_pointer); + res = op2(arg2, left, right); + stack_pointer = _PyFrame_GetStackPointer(frame); + } + stack_pointer[-3] = res; + stack_pointer += -2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + + TARGET(OP1) { + #if Py_TAIL_CALL_INTERP + int opcode = OP1; + (void)(opcode); + #endif + _Py_CODEUNIT* const this_instr = next_instr; + (void)this_instr; + frame->instr_ptr = next_instr; + next_instr += 2; + INSTRUCTION_STATS(OP1); + _PyStackRef left; + _PyStackRef right; + right = stack_pointer[-1]; + left = stack_pointer[-2]; + uint16_t counter = read_u16(&this_instr[1].cache); + (void)counter; + _PyFrame_SetStackPointer(frame, stack_pointer); + op1(left, right); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + + TARGET(OP3) { + #if Py_TAIL_CALL_INTERP + int 
opcode = OP3; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 6; + INSTRUCTION_STATS(OP3); + static_assert(INLINE_CACHE_ENTRIES_OP == 5, "incorrect cache size"); + _PyStackRef arg2; + _PyStackRef left; + _PyStackRef right; + _PyStackRef res; + /* Skip 5 cache entries */ + right = stack_pointer[-1]; + left = stack_pointer[-2]; + arg2 = stack_pointer[-3]; + _PyFrame_SetStackPointer(frame, stack_pointer); + res = op3(arg2, left, right); + stack_pointer = _PyFrame_GetStackPointer(frame); + stack_pointer[-3] = res; + stack_pointer += -2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_unused_caches(self): + input = """ + inst(OP, (unused/1, unused/2 --)) { + body; + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 4; + INSTRUCTION_STATS(OP); + /* Skip 1 cache entry */ + /* Skip 2 cache entries */ + body; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_pseudo_instruction_no_flags(self): + input = """ + pseudo(OP, (in -- out1, out2)) = { + OP1, + }; + + inst(OP1, (--)) { + } + """ + output = """ + TARGET(OP1) { + #if Py_TAIL_CALL_INTERP + int opcode = OP1; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP1); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_pseudo_instruction_with_flags(self): + input = """ + pseudo(OP, (in1, in2 --), (HAS_ARG, HAS_JUMP)) = { + OP1, + }; + + inst(OP1, (--)) { + } + """ + output = """ + TARGET(OP1) { + #if Py_TAIL_CALL_INTERP + int opcode = OP1; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP1); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_pseudo_instruction_as_sequence(self): + input = """ + pseudo(OP, (in -- out1, out2)) = [ + OP1, OP2 + ]; + + inst(OP1, (--)) { 
+ } + + inst(OP2, (--)) { + } + """ + output = """ + TARGET(OP1) { + #if Py_TAIL_CALL_INTERP + int opcode = OP1; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP1); + DISPATCH(); + } + + TARGET(OP2) { + #if Py_TAIL_CALL_INTERP + int opcode = OP2; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP2); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + + def test_array_input(self): + input = """ + inst(OP, (below, values[oparg*2], above --)) { + SPAM(values, oparg); + DEAD(below); + DEAD(values); + DEAD(above); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef below; + _PyStackRef *values; + _PyStackRef above; + above = stack_pointer[-1]; + values = &stack_pointer[-1 - oparg*2]; + below = stack_pointer[-2 - oparg*2]; + SPAM(values, oparg); + stack_pointer += -2 - oparg*2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_array_output(self): + input = """ + inst(OP, (unused, unused -- below, values[oparg*3], above)) { + SPAM(values, oparg); + below = 0; + above = 0; + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef below; + _PyStackRef *values; + _PyStackRef above; + values = &stack_pointer[-1]; + SPAM(values, oparg); + below = 0; + above = 0; + stack_pointer[-2] = below; + stack_pointer[-1 + oparg*3] = above; + stack_pointer += oparg*3; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_array_input_output(self): + input = """ + inst(OP, (values[oparg] -- values[oparg], above)) { + SPAM(values, oparg); + above = 0; + } + """ + output = """ + 
TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef *values; + _PyStackRef above; + values = &stack_pointer[-oparg]; + SPAM(values, oparg); + above = 0; + stack_pointer[0] = above; + stack_pointer += 1; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_array_error_if(self): + input = """ + inst(OP, (extra, values[oparg] --)) { + DEAD(extra); + DEAD(values); + ERROR_IF(oparg == 0); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef extra; + _PyStackRef *values; + values = &stack_pointer[-oparg]; + extra = stack_pointer[-1 - oparg]; + if (oparg == 0) { + stack_pointer += -1 - oparg; + assert(WITHIN_STACK_BOUNDS()); + JUMP_TO_LABEL(error); + } + stack_pointer += -1 - oparg; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_macro_push_push(self): + input = """ + op(A, (-- val1)) { + val1 = SPAM(); + } + op(B, (-- val2)) { + val2 = SPAM(); + } + macro(M) = A + B; + """ + output = """ + TARGET(M) { + #if Py_TAIL_CALL_INTERP + int opcode = M; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(M); + _PyStackRef val1; + _PyStackRef val2; + // A + { + val1 = SPAM(); + } + // B + { + val2 = SPAM(); + } + stack_pointer[0] = val1; + stack_pointer[1] = val2; + stack_pointer += 2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_override_inst(self): + input = """ + inst(OP, (--)) { + spam; + } + override inst(OP, (--)) { + ham; + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + 
INSTRUCTION_STATS(OP); + ham; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_override_op(self): + input = """ + op(OP, (--)) { + spam; + } + macro(M) = OP; + override op(OP, (--)) { + ham; + } + """ + output = """ + TARGET(M) { + #if Py_TAIL_CALL_INTERP + int opcode = M; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(M); + ham; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_annotated_inst(self): + input = """ + pure inst(OP, (--)) { + ham; + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + ham; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_annotated_op(self): + input = """ + pure op(OP, (--)) { + SPAM(); + } + macro(M) = OP; + """ + output = """ + TARGET(M) { + #if Py_TAIL_CALL_INTERP + int opcode = M; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(M); + SPAM(); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + input = """ + pure register specializing op(OP, (--)) { + SPAM(); + } + macro(M) = OP; + """ + self.run_cases_test(input, output) + + def test_deopt_and_exit(self): + input = """ + pure op(OP, (arg1 -- out)) { + DEOPT_IF(1); + EXIT_IF(1); + } + """ + output = "" + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_array_of_one(self): + input = """ + inst(OP, (arg[1] -- out[1])) { + out[0] = arg[0]; + DEAD(arg); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef *arg; + _PyStackRef *out; + arg = &stack_pointer[-1]; + out = &stack_pointer[-1]; + out[0] = arg[0]; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def 
test_pointer_to_stackref(self): + input = """ + inst(OP, (arg: _PyStackRef * -- out)) { + out = *arg; + DEAD(arg); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef *arg; + _PyStackRef out; + arg = (_PyStackRef *)stack_pointer[-1].bits; + out = *arg; + stack_pointer[-1] = out; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_unused_cached_value(self): + input = """ + op(FIRST, (arg1 -- out)) { + out = arg1; + } + + op(SECOND, (unused -- unused)) { + } + + macro(BOTH) = FIRST + SECOND; + """ + output = """ + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_unused_named_values(self): + input = """ + op(OP, (named -- named)) { + } + + macro(INST) = OP; + """ + output = """ + TARGET(INST) { + #if Py_TAIL_CALL_INTERP + int opcode = INST; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(INST); + DISPATCH(); + } + + """ + self.run_cases_test(input, output) + + def test_used_unused_used(self): + input = """ + op(FIRST, (w -- w)) { + USE(w); + } + + op(SECOND, (x -- x)) { + } + + op(THIRD, (y -- y)) { + USE(y); + } + + macro(TEST) = FIRST + SECOND + THIRD; + """ + output = """ + TARGET(TEST) { + #if Py_TAIL_CALL_INTERP + int opcode = TEST; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(TEST); + _PyStackRef w; + _PyStackRef y; + // FIRST + { + w = stack_pointer[-1]; + USE(w); + } + // SECOND + { + } + // THIRD + { + y = w; + USE(y); + } + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_unused_used_used(self): + input = """ + op(FIRST, (w -- w)) { + } + + op(SECOND, (x -- x)) { + USE(x); + } + + op(THIRD, (y -- y)) { + USE(y); + } + + macro(TEST) = FIRST + SECOND + THIRD; + """ + output = """ + TARGET(TEST) { + #if Py_TAIL_CALL_INTERP 
+ int opcode = TEST; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(TEST); + _PyStackRef x; + _PyStackRef y; + // FIRST + { + } + // SECOND + { + x = stack_pointer[-1]; + USE(x); + } + // THIRD + { + y = x; + USE(y); + } + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_flush(self): + input = """ + op(FIRST, ( -- a, b)) { + a = 0; + b = 1; + } + + op(SECOND, (a, b -- )) { + USE(a, b); + INPUTS_DEAD(); + } + + macro(TEST) = FIRST + flush + SECOND; + """ + output = """ + TARGET(TEST) { + #if Py_TAIL_CALL_INTERP + int opcode = TEST; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(TEST); + _PyStackRef a; + _PyStackRef b; + // FIRST + { + a = 0; + b = 1; + } + // flush + stack_pointer[0] = a; + stack_pointer[1] = b; + stack_pointer += 2; + assert(WITHIN_STACK_BOUNDS()); + // SECOND + { + USE(a, b); + } + stack_pointer += -2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_pop_on_error_peeks(self): + + input = """ + op(FIRST, (x, y -- a, b)) { + a = x; + DEAD(x); + b = y; + DEAD(y); + } + + op(SECOND, (a, b -- a, b)) { + } + + op(THIRD, (j, k --)) { + INPUTS_DEAD(); // Mark j and k as used + ERROR_IF(cond); + } + + macro(TEST) = FIRST + SECOND + THIRD; + """ + output = """ + TARGET(TEST) { + #if Py_TAIL_CALL_INTERP + int opcode = TEST; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(TEST); + _PyStackRef x; + _PyStackRef y; + _PyStackRef a; + _PyStackRef b; + // FIRST + { + y = stack_pointer[-1]; + x = stack_pointer[-2]; + a = x; + b = y; + } + // SECOND + { + } + // THIRD + { + if (cond) { + JUMP_TO_LABEL(pop_2_error); + } + } + stack_pointer += -2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_push_then_error(self): + + input = """ + op(FIRST, ( -- a)) { + a = 1; + } + + 
op(SECOND, (a -- a, b)) { + b = 1; + ERROR_IF(cond); + } + + macro(TEST) = FIRST + SECOND; + """ + + output = """ + TARGET(TEST) { + #if Py_TAIL_CALL_INTERP + int opcode = TEST; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(TEST); + _PyStackRef a; + _PyStackRef b; + // FIRST + { + a = 1; + } + // SECOND + { + b = 1; + if (cond) { + stack_pointer[0] = a; + stack_pointer[1] = b; + stack_pointer += 2; + assert(WITHIN_STACK_BOUNDS()); + JUMP_TO_LABEL(error); + } + } + stack_pointer[0] = a; + stack_pointer[1] = b; + stack_pointer += 2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_error_if_true(self): + + input = """ + inst(OP1, ( --)) { + ERROR_IF(true); + } + inst(OP2, ( --)) { + ERROR_IF(1); + } + """ + output = """ + TARGET(OP1) { + #if Py_TAIL_CALL_INTERP + int opcode = OP1; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP1); + JUMP_TO_LABEL(error); + } + + TARGET(OP2) { + #if Py_TAIL_CALL_INTERP + int opcode = OP2; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP2); + JUMP_TO_LABEL(error); + } + """ + self.run_cases_test(input, output) + + def test_scalar_array_inconsistency(self): + + input = """ + op(FIRST, ( -- a)) { + a = 1; + } + + op(SECOND, (a[1] -- b)) { + b = 1; + } + + macro(TEST) = FIRST + SECOND; + """ + + output = """ + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_array_size_inconsistency(self): + + input = """ + op(FIRST, ( -- a[2])) { + a[0] = 1; + } + + op(SECOND, (a[1] -- b)) { + b = 1; + } + + macro(TEST) = FIRST + SECOND; + """ + + output = """ + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_stack_save_reload(self): + + input = """ + inst(BALANCED, ( -- )) { + SAVE_STACK(); + code(); + RELOAD_STACK(); + } + """ + + output = """ + 
TARGET(BALANCED) { + #if Py_TAIL_CALL_INTERP + int opcode = BALANCED; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(BALANCED); + _PyFrame_SetStackPointer(frame, stack_pointer); + code(); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_stack_save_reload_paired(self): + + input = """ + inst(BALANCED, ( -- )) { + SAVE_STACK(); + RELOAD_STACK(); + } + """ + + output = """ + TARGET(BALANCED) { + #if Py_TAIL_CALL_INTERP + int opcode = BALANCED; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(BALANCED); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_stack_reload_only(self): + + input = """ + inst(BALANCED, ( -- )) { + RELOAD_STACK(); + } + """ + + output = """ + TARGET(BALANCED) { + #if Py_TAIL_CALL_INTERP + int opcode = BALANCED; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(BALANCED); + _PyFrame_SetStackPointer(frame, stack_pointer); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_stack_save_only(self): + + input = """ + inst(BALANCED, ( -- )) { + SAVE_STACK(); + } + """ + + output = """ + TARGET(BALANCED) { + #if Py_TAIL_CALL_INTERP + int opcode = BALANCED; + (void)(opcode); + #endif + _Py_CODEUNIT* const this_instr = next_instr; + (void)this_instr; + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(BALANCED); + _PyFrame_SetStackPointer(frame, stack_pointer); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, output) + + def test_instruction_size_macro(self): + input = """ + inst(OP, (--)) { + frame->return_offset = INSTRUCTION_SIZE; + } + """ + + output = """ + TARGET(OP) { + #if 
Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + frame->return_offset = 1u ; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + # Two instructions of different sizes referencing the same + # uop containing the `INSTRUCTION_SIZE` macro is not allowed. + input = """ + inst(OP, (--)) { + frame->return_offset = INSTRUCTION_SIZE; + } + macro(OP2) = unused/1 + OP; + """ + + output = "" # No output needed as this should raise an error. + with self.assertRaisesRegex(SyntaxError, "All instructions containing a uop"): + self.run_cases_test(input, output) + + def test_escaping_call_next_to_cmacro(self): + input = """ + inst(OP, (--)) { + #ifdef Py_GIL_DISABLED + escaping_call(); + #else + another_escaping_call(); + #endif + yet_another_escaping_call(); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + #ifdef Py_GIL_DISABLED + _PyFrame_SetStackPointer(frame, stack_pointer); + escaping_call(); + stack_pointer = _PyFrame_GetStackPointer(frame); + #else + _PyFrame_SetStackPointer(frame, stack_pointer); + another_escaping_call(); + stack_pointer = _PyFrame_GetStackPointer(frame); + #endif + _PyFrame_SetStackPointer(frame, stack_pointer); + yet_another_escaping_call(); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_pystackref_frompyobject_new_next_to_cmacro(self): + input = """ + inst(OP, (-- out1, out2)) { + PyObject *obj = SPAM(); + #ifdef Py_GIL_DISABLED + out1 = PyStackRef_FromPyObjectNew(obj); + #else + out1 = PyStackRef_FromPyObjectNew(obj); + #endif + out2 = PyStackRef_FromPyObjectNew(obj); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; 
+ INSTRUCTION_STATS(OP); + _PyStackRef out1; + _PyStackRef out2; + PyObject *obj = SPAM(); + #ifdef Py_GIL_DISABLED + out1 = PyStackRef_FromPyObjectNew(obj); + #else + out1 = PyStackRef_FromPyObjectNew(obj); + #endif + out2 = PyStackRef_FromPyObjectNew(obj); + stack_pointer[0] = out1; + stack_pointer[1] = out2; + stack_pointer += 2; + assert(WITHIN_STACK_BOUNDS()); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_no_escaping_calls_in_branching_macros(self): + + input = """ + inst(OP, ( -- )) { + DEOPT_IF(escaping_call()); + } + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, "") + + input = """ + inst(OP, ( -- )) { + EXIT_IF(escaping_call()); + } + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, "") + + input = """ + inst(OP, ( -- )) { + ERROR_IF(escaping_call()); + } + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, "") + + def test_kill_in_wrong_order(self): + input = """ + inst(OP, (a, b -- c)) { + c = b; + PyStackRef_CLOSE(a); + PyStackRef_CLOSE(b); + } + """ + with self.assertRaises(SyntaxError): + self.run_cases_test(input, "") + + def test_complex_label(self): + input = """ + label(other_label) { + } + + label(other_label2) { + } + + label(my_label) { + // Comment + do_thing(); + if (complex) { + goto other_label; + } + goto other_label2; + } + """ + + output = """ + LABEL(other_label) + { + } + + LABEL(other_label2) + { + } + + LABEL(my_label) + { + _PyFrame_SetStackPointer(frame, stack_pointer); + do_thing(); + stack_pointer = _PyFrame_GetStackPointer(frame); + if (complex) { + JUMP_TO_LABEL(other_label); + } + JUMP_TO_LABEL(other_label2); + } + """ + self.run_cases_test(input, output) + + def test_spilled_label(self): + input = """ + spilled label(one) { + RELOAD_STACK(); + goto two; + } + + label(two) { + SAVE_STACK(); + goto one; + } + """ + + output = """ + LABEL(one) + { + stack_pointer = _PyFrame_GetStackPointer(frame); + JUMP_TO_LABEL(two); + } + + 
LABEL(two) + { + _PyFrame_SetStackPointer(frame, stack_pointer); + JUMP_TO_LABEL(one); + } + """ + self.run_cases_test(input, output) + + + def test_incorrect_spills(self): + input1 = """ + spilled label(one) { + goto two; + } + + label(two) { + } + """ + + input2 = """ + spilled label(one) { + } + + label(two) { + goto one; + } + """ + with self.assertRaisesRegex(SyntaxError, ".*reload.*"): + self.run_cases_test(input1, "") + with self.assertRaisesRegex(SyntaxError, ".*spill.*"): + self.run_cases_test(input2, "") + + + def test_multiple_labels(self): + input = """ + label(my_label_1) { + // Comment + do_thing1(); + goto my_label_2; + } + + label(my_label_2) { + // Comment + do_thing2(); + goto my_label_1; + } + """ + + output = """ + LABEL(my_label_1) + { + _PyFrame_SetStackPointer(frame, stack_pointer); + do_thing1(); + stack_pointer = _PyFrame_GetStackPointer(frame); + JUMP_TO_LABEL(my_label_2); + } + + LABEL(my_label_2) + { + _PyFrame_SetStackPointer(frame, stack_pointer); + do_thing2(); + stack_pointer = _PyFrame_GetStackPointer(frame); + JUMP_TO_LABEL(my_label_1); + } + """ + self.run_cases_test(input, output) + + def test_reassigning_live_inputs(self): + input = """ + inst(OP, (in -- in)) { + in = 0; + } + """ + + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef in; + in = stack_pointer[-1]; + in = 0; + stack_pointer[-1] = in; + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + def test_reassigning_dead_inputs(self): + input = """ + inst(OP, (in -- )) { + temp = use(in); + DEAD(in); + in = temp; + PyStackRef_CLOSE(in); + } + """ + output = """ + TARGET(OP) { + #if Py_TAIL_CALL_INTERP + int opcode = OP; + (void)(opcode); + #endif + frame->instr_ptr = next_instr; + next_instr += 1; + INSTRUCTION_STATS(OP); + _PyStackRef in; + in = stack_pointer[-1]; + _PyFrame_SetStackPointer(frame, stack_pointer); + temp = 
use(in); + stack_pointer = _PyFrame_GetStackPointer(frame); + in = temp; + stack_pointer += -1; + assert(WITHIN_STACK_BOUNDS()); + _PyFrame_SetStackPointer(frame, stack_pointer); + PyStackRef_CLOSE(in); + stack_pointer = _PyFrame_GetStackPointer(frame); + DISPATCH(); + } + """ + self.run_cases_test(input, output) + + +class TestGeneratedAbstractCases(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.maxDiff = None + + self.temp_dir = tempfile.gettempdir() + self.temp_input_filename = os.path.join(self.temp_dir, "input.txt") + self.temp_input2_filename = os.path.join(self.temp_dir, "input2.txt") + self.temp_output_filename = os.path.join(self.temp_dir, "output.txt") + + def tearDown(self) -> None: + for filename in [ + self.temp_input_filename, + self.temp_input2_filename, + self.temp_output_filename, + ]: + try: + os.remove(filename) + except: + pass + super().tearDown() + + def run_cases_test(self, input: str, input2: str, expected: str): + with open(self.temp_input_filename, "w+") as temp_input: + temp_input.write(parser.BEGIN_MARKER) + temp_input.write(input) + temp_input.write(parser.END_MARKER) + temp_input.flush() + + with open(self.temp_input2_filename, "w+") as temp_input: + temp_input.write(parser.BEGIN_MARKER) + temp_input.write(input2) + temp_input.write(parser.END_MARKER) + temp_input.flush() + + with handle_stderr(): + optimizer_generator.generate_tier2_abstract_from_files( + [self.temp_input_filename, self.temp_input2_filename], + self.temp_output_filename + ) + + with open(self.temp_output_filename) as temp_output: + lines = temp_output.readlines() + while lines and lines[0].startswith(("// ", "#", " #", "\n")): + lines.pop(0) + while lines and lines[-1].startswith(("#", "\n")): + lines.pop(-1) + actual = "".join(lines) + self.assertEqual(actual.strip(), expected.strip()) + + def test_overridden_abstract(self): + input = """ + pure op(OP, (--)) { + SPAM(); + } + """ + input2 = """ + pure op(OP, (--)) { + eggs(); + } + """ + 
output = """ + case OP: { + eggs(); + break; + } + """ + self.run_cases_test(input, input2, output) + + def test_overridden_abstract_args(self): + input = """ + pure op(OP, (arg1 -- out)) { + out = SPAM(arg1); + } + op(OP2, (arg1 -- out)) { + out = EGGS(arg1); + } + """ + input2 = """ + op(OP, (arg1 -- out)) { + out = EGGS(arg1); + } + """ + output = """ + case OP: { + JitOptSymbol *arg1; + JitOptSymbol *out; + arg1 = stack_pointer[-1]; + out = EGGS(arg1); + stack_pointer[-1] = out; + break; + } + + case OP2: { + JitOptSymbol *out; + out = sym_new_not_null(ctx); + stack_pointer[-1] = out; + break; + } + """ + self.run_cases_test(input, input2, output) + + def test_no_overridden_case(self): + input = """ + pure op(OP, (arg1 -- out)) { + out = SPAM(arg1); + } + + pure op(OP2, (arg1 -- out)) { + } + + """ + input2 = """ + pure op(OP2, (arg1 -- out)) { + out = NULL; + } + """ + output = """ + case OP: { + JitOptSymbol *out; + out = sym_new_not_null(ctx); + stack_pointer[-1] = out; + break; + } + + case OP2: { + JitOptSymbol *out; + out = NULL; + stack_pointer[-1] = out; + break; + } + """ + self.run_cases_test(input, input2, output) + + def test_missing_override_failure(self): + input = """ + pure op(OP, (arg1 -- out)) { + SPAM(); + } + """ + input2 = """ + pure op(OTHER, (arg1 -- out)) { + } + """ + output = """ + """ + with self.assertRaisesRegex(AssertionError, "All abstract uops"): + self.run_cases_test(input, input2, output) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py index 8da74ff530d..8ede6e22fab 100644 --- a/Lib/test/test_generators.py +++ b/Lib/test/test_generators.py @@ -68,7 +68,6 @@ def gen(): del frame support.gc_collect() - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: False is not true def test_refcycle(self): # A generator caught in a refcycle gets finalized anyway. 
old_garbage = gc.garbage[:] @@ -114,7 +113,6 @@ def g3(): return (yield from f()) gen.send(2) self.assertEqual(cm.exception.value, 2) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 1 def test_generator_resurrect(self): # Test that a resurrected generator still has a valid gi_code resurrected = [] @@ -2527,7 +2525,7 @@ def printsolution(self, x): ... SyntaxError: 'yield from' outside function ->>> def f(): x = yield = y # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE +>>> def f(): x = yield = y Traceback (most recent call last): ... SyntaxError: assignment to yield expression not possible diff --git a/Lib/test/test_genexps.py b/Lib/test/test_genexps.py new file mode 100644 index 00000000000..fde12f13cdc --- /dev/null +++ b/Lib/test/test_genexps.py @@ -0,0 +1,294 @@ +import sys +import doctest +import unittest + + +doctests = """ + +Test simple loop with conditional + + >>> sum(i*i for i in range(100) if i&1 == 1) + 166650 + +Test simple nesting + + >>> list((i,j) for i in range(3) for j in range(4) ) + [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)] + +Test nesting with the inner expression dependent on the outer + + >>> list((i,j) for i in range(4) for j in range(i) ) + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)] + +Test the idiom for temporary variable assignment in comprehensions. 
+ + >>> list((j*j for i in range(4) for j in [i+1])) + [1, 4, 9, 16] + >>> list((j*k for i in range(4) for j in [i+1] for k in [j+1])) + [2, 6, 12, 20] + >>> list((j*k for i in range(4) for j, k in [(i+1, i+2)])) + [2, 6, 12, 20] + +Not assignment + + >>> list((i*i for i in [*range(4)])) + [0, 1, 4, 9] + >>> list((i*i for i in (*range(4),))) + [0, 1, 4, 9] + +Make sure the induction variable is not exposed + + >>> i = 20 + >>> sum(i*i for i in range(100)) + 328350 + >>> i + 20 + +Test first class + + >>> g = (i*i for i in range(4)) + >>> type(g) + + >>> list(g) + [0, 1, 4, 9] + +Test direct calls to next() + + >>> g = (i*i for i in range(3)) + >>> next(g) + 0 + >>> next(g) + 1 + >>> next(g) + 4 + >>> next(g) + Traceback (most recent call last): + File "", line 1, in -toplevel- + next(g) + StopIteration + +Does it stay stopped? + + >>> next(g) + Traceback (most recent call last): + File "", line 1, in -toplevel- + next(g) + StopIteration + >>> list(g) + [] + +Test running gen when defining function is out of scope + + >>> def f(n): + ... return (i*i for i in range(n)) + >>> list(f(10)) + [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + + >>> def f(n): + ... return ((i,j) for i in range(3) for j in range(n)) + >>> list(f(4)) + [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)] + >>> def f(n): + ... return ((i,j) for i in range(3) for j in range(4) if j in range(n)) + >>> list(f(4)) + [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)] + >>> list(f(2)) + [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)] + +Verify that parenthesis are required in a statement + + >>> def f(n): + ... return i*i for i in range(n) + Traceback (most recent call last): + ... + SyntaxError: invalid syntax + +Verify that parenthesis are required when used as a keyword argument value + + >>> dict(a = i for i in range(10)) # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + ... 
+ SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? + +Verify that parenthesis are required when used as a keyword argument value + + >>> dict(a = (i for i in range(10))) #doctest: +ELLIPSIS + {'a': at ...>} + +Verify early binding for the outermost for-expression + + >>> x=10 + >>> g = (i*i for i in range(x)) + >>> x = 5 + >>> list(g) + [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + +Verify late binding for the outermost if-expression + + >>> include = (2,4,6,8) + >>> g = (i*i for i in range(10) if i in include) + >>> include = (1,3,5,7,9) + >>> list(g) + [1, 9, 25, 49, 81] + +Verify that the outermost for-expression makes an immediate check +for iterability + >>> (i for i in 6) + Traceback (most recent call last): + File "", line 1, in -toplevel- + (i for i in 6) + TypeError: 'int' object is not iterable + +Verify late binding for the innermost for-expression + + >>> g = ((i,j) for i in range(3) for j in range(x)) + >>> x = 4 + >>> list(g) + [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)] + +Verify re-use of tuples (a side benefit of using genexps over listcomps) + + >>> tupleids = list(map(id, ((i,i) for i in range(10)))) + >>> int(max(tupleids) - min(tupleids)) + 0 + +Verify that syntax error's are raised for genexps used as lvalues + + >>> (y for y in (1,2)) = 10 + Traceback (most recent call last): + ... + SyntaxError: cannot assign to generator expression + + >>> (y for y in (1,2)) += 10 # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + ... + SyntaxError: 'generator expression' is an illegal expression for augmented assignment + + +########### Tests borrowed from or inspired by test_generators.py ############ + +Make a generator that acts like range() + + >>> yrange = lambda n: (i for i in range(n)) + >>> list(yrange(10)) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + +Generators always return to the most recent caller: + + >>> def creator(): + ... r = yrange(5) + ... 
print("creator", next(r)) + ... return r + >>> def caller(): + ... r = creator() + ... for i in r: + ... print("caller", i) + >>> caller() + creator 0 + caller 1 + caller 2 + caller 3 + caller 4 + +Generators can call other generators: + + >>> def zrange(n): + ... for i in yrange(n): + ... yield i + >>> list(zrange(5)) + [0, 1, 2, 3, 4] + + +Verify that a gen exp cannot be resumed while it is actively running: + + >>> g = (next(me) for i in range(10)) + >>> me = g + >>> next(me) + Traceback (most recent call last): + File "", line 1, in -toplevel- + next(me) + File "", line 1, in + g = (next(me) for i in range(10)) + ValueError: generator already executing + +Verify exception propagation + + >>> g = (10 // i for i in (5, 0, 2)) + >>> next(g) + 2 + >>> next(g) + Traceback (most recent call last): + File "", line 1, in -toplevel- + next(g) + File "", line 1, in + g = (10 // i for i in (5, 0, 2)) + ZeroDivisionError: division by zero + >>> next(g) + Traceback (most recent call last): + File "", line 1, in -toplevel- + next(g) + StopIteration + +Make sure that None is a valid return value + + >>> list(None for i in range(10)) + [None, None, None, None, None, None, None, None, None, None] + +Check that generator attributes are present + + >>> g = (i*i for i in range(3)) + >>> expected = set(['gi_frame', 'gi_running']) + >>> set(attr for attr in dir(g) if not attr.startswith('__')) >= expected + True + + >>> from test.support import HAVE_DOCSTRINGS + >>> print(g.__next__.__doc__ if HAVE_DOCSTRINGS else 'Implement next(self).') # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + Implement next(self). 
+ >>> import types + >>> isinstance(g, types.GeneratorType) + True + +Check the __iter__ slot is defined to return self + + >>> iter(g) is g + True + +Verify that the running flag is set properly + + >>> g = (me.gi_running for i in (0,1)) + >>> me = g + >>> me.gi_running + 0 + >>> next(me) + 1 + >>> me.gi_running + 0 + +Verify that genexps are weakly referencable + + >>> import weakref + >>> g = (i*i for i in range(4)) + >>> wr = weakref.ref(g) + >>> wr() is g + True + >>> p = weakref.proxy(g) + >>> list(p) + [0, 1, 4, 9] + + +""" + +# Trace function can throw off the tuple reuse test. +if hasattr(sys, 'gettrace') and sys.gettrace(): + __test__ = {} +else: + __test__ = {'doctests' : doctests} + +def load_tests(loader, tests, pattern): + from test.support.rustpython import DocTestChecker # TODO: RUSTPYTHON + tests.addTest(doctest.DocTestSuite(checker=DocTestChecker())) # TODO: RUSTPYTHON + return tests + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py index 91eb6cc58f3..cf90de7b115 100644 --- a/Lib/test/test_grammar.py +++ b/Lib/test/test_grammar.py @@ -1,11 +1,13 @@ # Python test set -- part 1, grammar. # This just tests whether the parser accepts them all. 
-from test.support import check_syntax_error +from test.support import check_syntax_error, skip_wasi_stack_overflow from test.support import import_helper +import annotationlib import inspect import unittest import sys +import textwrap import warnings # testing import * from sys import * @@ -112,7 +114,7 @@ def test_underscore_literals(self): # Sanity check: no literal begins with an underscore self.assertRaises(NameError, eval, "_0") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_numerical_literals(self): check = self.check_syntax_error check("0b12", "invalid digit '2' in binary literal") @@ -135,7 +137,7 @@ def test_bad_numerical_literals(self): check("1e2_", "invalid decimal literal") check("1e+", "invalid decimal literal") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_end_of_numerical_literals(self): def check(test, error=False): with self.subTest(expr=test): @@ -216,6 +218,27 @@ def test_string_literals(self): ' self.assertEqual(x, y) + def test_string_prefixes(self): + def check(s): + parsed = eval(s) + self.assertIs(type(parsed), str) + self.assertGreater(len(parsed), 0) + + check("u'abc'") + check("r'abc\t'") + check("rf'abc\a {1 + 1}'") + check("fr'abc\a {1 + 1}'") + + def test_bytes_prefixes(self): + def check(s): + parsed = eval(s) + self.assertIs(type(parsed), bytes) + self.assertGreater(len(parsed), 0) + + check("b'abc'") + check("br'abc\t'") + check("rb'abc\a'") + def test_ellipsis(self): x = ... 
self.assertTrue(x is Ellipsis) @@ -228,17 +251,20 @@ def test_eof_error(self): compile(s, "", "exec") self.assertIn("was never closed", str(cm.exception)) -var_annot_global: int # a global annotated is necessary for test_var_annot + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised + @skip_wasi_stack_overflow() + def test_max_level(self): + # Macro defined in Parser/lexer/state.h + MAXLEVEL = 200 + + result = eval("(" * MAXLEVEL + ")" * MAXLEVEL) + self.assertEqual(result, ()) -# custom namespace for testing __annotations__ + with self.assertRaises(SyntaxError) as cm: + eval("(" * (MAXLEVEL + 1) + ")" * (MAXLEVEL + 1)) + self.assertStartsWith(str(cm.exception), 'too many nested parentheses') -class CNS: - def __init__(self): - self._dct = {} - def __setitem__(self, item, value): - self._dct[item.lower()] = value - def __getitem__(self, item): - return self._dct[item] +var_annot_global: int # a global annotated is necessary for test_var_annot class GrammarTests(unittest.TestCase): @@ -272,7 +298,7 @@ def one(): my_lst[one()-1]: int = 5 self.assertEqual(my_lst, [5]) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_var_annot_syntax_errors(self): # parser pass check_syntax_error(self, "def f: int") @@ -371,24 +397,12 @@ class F(C, A): self.assertEqual(E.__annotations__, {}) self.assertEqual(F.__annotations__, {}) - - @unittest.expectedFailure # TODO: RUSTPYTHON - def test_var_annot_metaclass_semantics(self): - class CMeta(type): - @classmethod - def __prepare__(metacls, name, bases, **kwds): - return {'__annotations__': CNS()} - class CC(metaclass=CMeta): - XX: 'ANNOT' - self.assertEqual(CC.__annotations__['xx'], 'ANNOT') - - @unittest.expectedFailure # TODO: RUSTPYTHON def test_var_annot_module_semantics(self): self.assertEqual(test.__annotations__, {}) self.assertEqual(ann_module.__annotations__, - {1: 2, 'x': int, 'y': str, 'f': typing.Tuple[int, int], 'u': int | float}) + {'x': 
int, 'y': str, 'f': typing.Tuple[int, int], 'u': int | float}) self.assertEqual(ann_module.M.__annotations__, - {'123': 123, 'o': type}) + {'o': type}) self.assertEqual(ann_module2.__annotations__, {}) def test_var_annot_in_module(self): @@ -402,55 +416,14 @@ def test_var_annot_in_module(self): with self.assertRaises(NameError): ann_module3.D_bad_ann(5) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_var_annot_simple_exec(self): - gns = {}; lns= {} + gns = {}; lns = {} exec("'docstring'\n" - "__annotations__[1] = 2\n" "x: int = 5\n", gns, lns) - self.assertEqual(lns["__annotations__"], {1: 2, 'x': int}) - with self.assertRaises(KeyError): - gns['__annotations__'] - - @unittest.expectedFailure # TODO: RUSTPYTHON - def test_var_annot_custom_maps(self): - # tests with custom locals() and __annotations__ - ns = {'__annotations__': CNS()} - exec('X: int; Z: str = "Z"; (w): complex = 1j', ns) - self.assertEqual(ns['__annotations__']['x'], int) - self.assertEqual(ns['__annotations__']['z'], str) - with self.assertRaises(KeyError): - ns['__annotations__']['w'] - nonloc_ns = {} - class CNS2: - def __init__(self): - self._dct = {} - def __setitem__(self, item, value): - nonlocal nonloc_ns - self._dct[item] = value - nonloc_ns[item] = value - def __getitem__(self, item): - return self._dct[item] - exec('x: int = 1', {}, CNS2()) - self.assertEqual(nonloc_ns['__annotations__']['x'], int) - - @unittest.expectedFailure # TODO: RUSTPYTHON - def test_var_annot_refleak(self): - # complex case: custom locals plus custom __annotations__ - # this was causing refleak - cns = CNS() - nonloc_ns = {'__annotations__': cns} - class CNS2: - def __init__(self): - self._dct = {'__annotations__': cns} - def __setitem__(self, item, value): - nonlocal nonloc_ns - self._dct[item] = value - nonloc_ns[item] = value - def __getitem__(self, item): - return self._dct[item] - exec('X: str', {}, CNS2()) - self.assertEqual(nonloc_ns['__annotations__']['x'], str) + self.assertNotIn('__annotate__', 
gns) + + gns.update(lns) # __annotate__ looks at globals + self.assertEqual(lns["__annotate__"](annotationlib.Format.VALUE), {'x': int}) def test_var_annot_rhs(self): ns = {} @@ -778,7 +751,7 @@ def test_expr_stmt(self): # Check the heuristic for print & exec covers significant cases # As well as placing some limits on false positives - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_former_statements_refer_to_builtins(self): keywords = "print", "exec" # Cases where we want the custom error @@ -905,188 +878,294 @@ def g3(): self.assertEqual(y, (1, 2, 3), "unparenthesized star expr return") check_syntax_error(self, "class foo:return 1") - def test_break_in_finally(self): - count = 0 - while count < 2: - count += 1 - try: - pass - finally: - break - self.assertEqual(count, 1) + def test_control_flow_in_finally(self): - count = 0 - while count < 2: - count += 1 - try: - continue - finally: - break - self.assertEqual(count, 1) + def run_case(self, src, expected): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', SyntaxWarning) + g, l = {}, { 'self': self } + exec(textwrap.dedent(src), g, l) + self.assertEqual(expected, l['result']) - count = 0 - while count < 2: - count += 1 - try: - 1/0 - finally: - break - self.assertEqual(count, 1) - for count in [0, 1]: - self.assertEqual(count, 0) - try: - pass - finally: - break - self.assertEqual(count, 0) + # *********** Break in finally *********** - for count in [0, 1]: - self.assertEqual(count, 0) - try: - continue - finally: - break - self.assertEqual(count, 0) + run_case( + self, + """ + result = 0 + while result < 2: + result += 1 + try: + pass + finally: + break + """, + 1) + + run_case( + self, + """ + result = 0 + while result < 2: + result += 1 + try: + continue + finally: + break + """, + 1) + + run_case( + self, + """ + result = 0 + while result < 2: + result += 1 + try: + 1/0 + finally: + break + """, + 1) + + run_case( + self, + """ + for 
result in [0, 1]: + self.assertEqual(result, 0) + try: + pass + finally: + break + """, + 0) + + run_case( + self, + """ + for result in [0, 1]: + self.assertEqual(result, 0) + try: + continue + finally: + break + """, + 0) + + run_case( + self, + """ + for result in [0, 1]: + self.assertEqual(result, 0) + try: + 1/0 + finally: + break + """, + 0) - for count in [0, 1]: - self.assertEqual(count, 0) - try: - 1/0 - finally: + + # *********** Continue in finally *********** + + run_case( + self, + """ + result = 0 + while result < 2: + result += 1 + try: + pass + finally: + continue break - self.assertEqual(count, 0) + """, + 2) - def test_continue_in_finally(self): - count = 0 - while count < 2: - count += 1 - try: - pass - finally: - continue - break - self.assertEqual(count, 2) - count = 0 - while count < 2: - count += 1 - try: + run_case( + self, + """ + result = 0 + while result < 2: + result += 1 + try: + break + finally: + continue + """, + 2) + + run_case( + self, + """ + result = 0 + while result < 2: + result += 1 + try: + 1/0 + finally: + continue break - finally: - continue - self.assertEqual(count, 2) + """, + 2) - count = 0 - while count < 2: - count += 1 - try: - 1/0 - finally: - continue - break - self.assertEqual(count, 2) + run_case( + self, + """ + for result in [0, 1]: + try: + pass + finally: + continue + break + """, + 1) - for count in [0, 1]: - try: - pass - finally: - continue - break - self.assertEqual(count, 1) + run_case( + self, + """ + for result in [0, 1]: + try: + break + finally: + continue + """, + 1) - for count in [0, 1]: - try: + run_case( + self, + """ + for result in [0, 1]: + try: + 1/0 + finally: + continue break - finally: - continue - self.assertEqual(count, 1) + """, + 1) - for count in [0, 1]: - try: - 1/0 - finally: - continue - break - self.assertEqual(count, 1) - def test_return_in_finally(self): - def g1(): - try: - pass - finally: - return 1 - self.assertEqual(g1(), 1) + # *********** Return in finally *********** - 
def g2(): - try: - return 2 - finally: - return 3 - self.assertEqual(g2(), 3) + run_case( + self, + """ + def f(): + try: + pass + finally: + return 1 + result = f() + """, + 1) + + run_case( + self, + """ + def f(): + try: + return 2 + finally: + return 3 + result = f() + """, + 3) + + run_case( + self, + """ + def f(): + try: + 1/0 + finally: + return 4 + result = f() + """, + 4) - def g3(): - try: - 1/0 - finally: - return 4 - self.assertEqual(g3(), 4) + # See issue #37830 + run_case( + self, + """ + def break_in_finally_after_return1(x): + for count in [0, 1]: + count2 = 0 + while count2 < 20: + count2 += 10 + try: + return count + count2 + finally: + if x: + break + return 'end', count, count2 + + self.assertEqual(break_in_finally_after_return1(False), 10) + self.assertEqual(break_in_finally_after_return1(True), ('end', 1, 10)) + result = True + """, + True) + + + run_case( + self, + """ + def break_in_finally_after_return2(x): + for count in [0, 1]: + for count2 in [10, 20]: + try: + return count + count2 + finally: + if x: + break + return 'end', count, count2 + + self.assertEqual(break_in_finally_after_return2(False), 10) + self.assertEqual(break_in_finally_after_return2(True), ('end', 1, 10)) + result = True + """, + True) - def test_break_in_finally_after_return(self): # See issue #37830 - def g1(x): - for count in [0, 1]: - count2 = 0 - while count2 < 20: - count2 += 10 + run_case( + self, + """ + def continue_in_finally_after_return1(x): + count = 0 + while count < 100: + count += 1 try: - return count + count2 + return count finally: if x: - break - return 'end', count, count2 - self.assertEqual(g1(False), 10) - self.assertEqual(g1(True), ('end', 1, 10)) - - def g2(x): - for count in [0, 1]: - for count2 in [10, 20]: + continue + return 'end', count + + self.assertEqual(continue_in_finally_after_return1(False), 1) + self.assertEqual(continue_in_finally_after_return1(True), ('end', 100)) + result = True + """, + True) + + run_case( + self, + """ + def 
continue_in_finally_after_return2(x): + for count in [0, 1]: try: - return count + count2 + return count finally: if x: - break - return 'end', count, count2 - self.assertEqual(g2(False), 10) - self.assertEqual(g2(True), ('end', 1, 10)) - - def test_continue_in_finally_after_return(self): - # See issue #37830 - def g1(x): - count = 0 - while count < 100: - count += 1 - try: - return count - finally: - if x: - continue - return 'end', count - self.assertEqual(g1(False), 1) - self.assertEqual(g1(True), ('end', 100)) + continue + return 'end', count - def g2(x): - for count in [0, 1]: - try: - return count - finally: - if x: - continue - return 'end', count - self.assertEqual(g2(False), 0) - self.assertEqual(g2(True), ('end', 1)) + self.assertEqual(continue_in_finally_after_return2(False), 0) + self.assertEqual(continue_in_finally_after_return2(True), ('end', 1)) + result = True + """, + True) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_yield(self): # Allowed as standalone statement def g(): yield 1 @@ -1126,7 +1205,7 @@ def g(): rest = 4, 5, 6; yield 1, 2, 3, *rest # Check annotation refleak on SyntaxError check_syntax_error(self, "def g(a:(yield)): pass") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_yield_in_comprehensions(self): # Check yield in comprehensions def g(): [x for x in [(yield 1)]] @@ -1223,7 +1302,7 @@ def test_assert_failures(self): else: self.fail("AssertionError not raised by 'assert False'") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_assert_syntax_warnings(self): # Ensure that we warn users if they provide a non-zero length tuple as # the assertion test. 
@@ -1238,7 +1317,7 @@ def test_assert_syntax_warnings(self): compile('assert x, "msg"', '', 'exec') compile('assert False, "msg"', '', 'exec') - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_assert_warning_promotes_to_syntax_error(self): # If SyntaxWarning is configured to be an error, it actually raises a # SyntaxError. @@ -1339,6 +1418,8 @@ def test_try(self): try: 1/0 except (EOFError, TypeError, ZeroDivisionError): pass try: 1/0 + except EOFError, TypeError, ZeroDivisionError: pass + try: 1/0 except (EOFError, TypeError, ZeroDivisionError) as msg: pass try: pass finally: pass @@ -1346,8 +1427,6 @@ def test_try(self): compile("try:\n pass\nexcept Exception as a.b:\n pass", "?", "exec") compile("try:\n pass\nexcept Exception as a[b]:\n pass", "?", "exec") - # TODO: RUSTPYTHON - ''' def test_try_star(self): ### try_stmt: 'try': suite (except_star_clause : suite) + ['else' ':' suite] ### except_star_clause: 'except*' expr ['as' NAME] @@ -1364,6 +1443,8 @@ def test_try_star(self): try: 1/0 except* (EOFError, TypeError, ZeroDivisionError): pass try: 1/0 + except* EOFError, TypeError, ZeroDivisionError: pass + try: 1/0 except* (EOFError, TypeError, ZeroDivisionError) as msg: pass try: pass finally: pass @@ -1371,7 +1452,6 @@ def test_try_star(self): compile("try:\n pass\nexcept* Exception as a.b:\n pass", "?", "exec") compile("try:\n pass\nexcept* Exception as a[b]:\n pass", "?", "exec") compile("try:\n pass\nexcept*:\n pass", "?", "exec") - ''' def test_suite(self): # simple_stmt | NEWLINE INDENT NEWLINE* (stmt NEWLINE*)+ DEDENT @@ -1416,7 +1496,7 @@ def test_comparison(self): if 1 not in (): pass if 1 < 1 > 1 == 1 >= 1 <= 1 != 1 in 1 not in x is x is not x: pass - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_comparison_is_literal(self): def check(test, msg): self.check_syntax_warning(test, msg) @@ -1446,7 +1526,7 @@ def check(test, msg): compile('True is 
x', '', 'exec') compile('... is x', '', 'exec') - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_warn_missed_comma(self): def check(test): self.check_syntax_warning(test, msg) @@ -1471,6 +1551,8 @@ def check(test): check('[None (3, 4)]') check('[True (3, 4)]') check('[... (3, 4)]') + check('[t"{x}" (3, 4)]') + check('[t"x={x}" (3, 4)]') msg=r'is not subscriptable; perhaps you missed a comma\?' check('[{1, 2} [i, j]]') @@ -1483,6 +1565,8 @@ def check(test): check('[None [i, j]]') check('[True [i, j]]') check('[... [i, j]]') + check('[t"{x}" [i, j]]') + check('[t"x={x}" [i, j]]') msg=r'indices must be integers or slices, not tuple; perhaps you missed a comma\?' check('[(1, 2) [i, j]]') @@ -1513,6 +1597,9 @@ def check(test): check('[[1, 2] [f"{x}"]]') check('[[1, 2] [f"x={x}"]]') check('[[1, 2] ["abc"]]') + msg=r'indices must be integers or slices, not string.templatelib.Template;' + check('[[1, 2] [t"{x}"]]') + check('[[1, 2] [t"x={x}"]]') msg=r'indices must be integers or slices, not' check('[[1, 2] [b"abc"]]') check('[[1, 2] [12.3]]') @@ -1623,8 +1710,6 @@ def test_atoms(self): ### testlist: test (',' test)* [','] # These have been exercised enough above - # TODO: RUSTPYTHON - ''' def test_classdef(self): # 'class' NAME ['(' [testlist] ')'] ':' suite class B: pass @@ -1649,7 +1734,8 @@ class G: pass class H: pass @d := class_decorator class I: pass - @lambda c: class_decorator(c) + # TODO: RUSTPYTHON; SyntaxError: the symbol 'class_decorator' must be present in the symbol table + # @lambda c: class_decorator(c) class J: pass @[..., class_decorator, ...][1] class K: pass @@ -1657,7 +1743,6 @@ class K: pass class L: pass @[class_decorator][0].__call__.__call__ class M: pass - ''' def test_dictcomps(self): # dictorsetmaker: ( (test ':' test (comp_for | diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py index 9fae17db7b1..11c43ccee19 100644 --- a/Lib/test/test_hash.py +++ b/Lib/test/test_hash.py @@ -263,8 
+263,7 @@ def get_expected_hash(self, position, length): platform = 3 if IS_64BIT else 2 return self.known_hashes[algorithm][position][platform] - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_null_hash(self): # PYTHONHASHSEED=0 disables the randomized hash known_hash_of_obj = self.get_expected_hash(0, 3) @@ -275,8 +274,7 @@ def test_null_hash(self): # It can also be disabled by setting the seed to 0: self.assertEqual(self.get_hash(self.repr_, seed=0), known_hash_of_obj) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @skip_unless_internalhash def test_fixed_hash(self): # test a fixed seed for the randomized hash @@ -284,8 +282,7 @@ def test_fixed_hash(self): h = self.get_expected_hash(1, 3) self.assertEqual(self.get_hash(self.repr_, seed=42), h) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @skip_unless_internalhash def test_long_fixed_hash(self): if self.repr_long is None: @@ -304,8 +301,7 @@ class StrHashRandomizationTests(StringlikeHashRandomizationTests, def test_empty_string(self): self.assertEqual(hash(""), 0) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @skip_unless_internalhash def test_ucs2_string(self): h = self.get_expected_hash(3, 6) diff --git a/Lib/test/test_http_cookies.py b/Lib/test/test_http_cookies.py index b4685f6e75e..ab01d68d1d3 100644 --- a/Lib/test/test_http_cookies.py +++ b/Lib/test/test_http_cookies.py @@ -1,5 +1,5 @@ # Simple test suite for http/cookies.py - +import base64 import copy import unittest import doctest @@ -153,17 +153,19 @@ def test_load(self): self.assertEqual(C.output(['path']), 'Set-Cookie: Customer="WILE_E_COYOTE"; Path=/acme') - self.assertEqual(C.js_output(), r""" + cookie_encoded = base64.b64encode(b'Customer="WILE_E_COYOTE"; Path=/acme; Version=1').decode('ascii') + self.assertEqual(C.js_output(), fr""" """) - 
self.assertEqual(C.js_output(['path']), r""" + cookie_encoded = base64.b64encode(b'Customer="WILE_E_COYOTE"; Path=/acme').decode('ascii') + self.assertEqual(C.js_output(['path']), fr""" """) @@ -268,17 +270,19 @@ def test_quoted_meta(self): self.assertEqual(C.output(['path']), 'Set-Cookie: Customer="WILE_E_COYOTE"; Path=/acme') - self.assertEqual(C.js_output(), r""" + expected_encoded_cookie = base64.b64encode(b'Customer=\"WILE_E_COYOTE\"; Path=/acme; Version=1').decode('ascii') + self.assertEqual(C.js_output(), fr""" """) - self.assertEqual(C.js_output(['path']), r""" + expected_encoded_cookie = base64.b64encode(b'Customer=\"WILE_E_COYOTE\"; Path=/acme').decode('ascii') + self.assertEqual(C.js_output(['path']), fr""" """) @@ -369,13 +373,16 @@ def test_setter(self): self.assertEqual( M.output(), "Set-Cookie: %s=%s; Path=/foo" % (i, "%s_coded_val" % i)) + expected_encoded_cookie = base64.b64encode( + ("%s=%s; Path=/foo" % (i, "%s_coded_val" % i)).encode("ascii") + ).decode('ascii') expected_js_output = """ - """ % (i, "%s_coded_val" % i) + """ % (expected_encoded_cookie,) self.assertEqual(M.js_output(), expected_js_output) for i in ["foo bar", "foo@bar"]: # Try some illegal characters diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 8ce5f853b7c..f7ab3e576c0 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -369,6 +369,51 @@ def test_invalid_headers(self): with self.assertRaisesRegex(ValueError, 'Invalid header'): conn.putheader(name, value) + def test_invalid_tunnel_headers(self): + cases = ( + ('Invalid\r\nName', 'ValidValue'), + ('Invalid\rName', 'ValidValue'), + ('Invalid\nName', 'ValidValue'), + ('\r\nInvalidName', 'ValidValue'), + ('\rInvalidName', 'ValidValue'), + ('\nInvalidName', 'ValidValue'), + (' InvalidName', 'ValidValue'), + ('\tInvalidName', 'ValidValue'), + ('Invalid:Name', 'ValidValue'), + (':InvalidName', 'ValidValue'), + ('ValidName', 'Invalid\r\nValue'), + ('ValidName', 'Invalid\rValue'), + 
('ValidName', 'Invalid\nValue'), + ('ValidName', 'InvalidValue\r\n'), + ('ValidName', 'InvalidValue\r'), + ('ValidName', 'InvalidValue\n'), + ) + for name, value in cases: + with self.subTest((name, value)): + conn = client.HTTPConnection('example.com') + conn.set_tunnel('tunnel', headers={ + name: value + }) + conn.sock = FakeSocket('') + with self.assertRaisesRegex(ValueError, 'Invalid header'): + conn._tunnel() # Called in .connect() + + def test_invalid_tunnel_host(self): + cases = ( + 'invalid\r.host', + '\ninvalid.host', + 'invalid.host\r\n', + 'invalid.host\x00', + 'invalid host', + ) + for tunnel_host in cases: + with self.subTest(tunnel_host): + conn = client.HTTPConnection('example.com') + conn.set_tunnel(tunnel_host) + conn.sock = FakeSocket('') + with self.assertRaisesRegex(ValueError, 'Tunnel host can\'t contain control characters'): + conn._tunnel() # Called in .connect() + def test_headers_debuglevel(self): body = ( b'HTTP/1.1 200 OK\r\n' diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py index 6920cf45533..60413aa7629 100644 --- a/Lib/test/test_import/__init__.py +++ b/Lib/test/test_import/__init__.py @@ -43,6 +43,7 @@ Py_GIL_DISABLED, no_rerun, force_not_colorized_test_class, + catch_unraisable_exception ) from test.support.import_helper import ( forget, make_legacy_pyc, unlink, unload, ready_to_import, @@ -1238,8 +1239,7 @@ def test_script_shadowing_stdlib_sys_path_modification(self): stdout, stderr = popen.communicate() self.assertRegex(stdout, expected_error) - # TODO: RUSTPYTHON: _imp.create_dynamic is for C extensions, not applicable - @unittest.skip("TODO: RustPython _imp.create_dynamic not implemented") + @unittest.expectedFailure # TODO: RUSTPYTHON; _imp.create_dynamic not implemented def test_create_dynamic_null(self): with self.assertRaisesRegex(ValueError, 'embedded null character'): class Spec: @@ -2544,6 +2544,32 @@ def test_disallowed_reimport(self): excsnap = _interpreters.run_string(interpid, script) 
self.assertIsNot(excsnap, None) + @requires_subinterpreters + def test_pyinit_function_raises_exception(self): + # gh-144601: PyInit functions that raised exceptions would cause a + # crash when imported from a subinterpreter. + import _testsinglephase + filename = _testsinglephase.__file__ + script = f"""if True: + from test.test_import import import_extension_from_file + + import_extension_from_file('_testsinglephase_raise_exception', {filename!r})""" + + interp = _interpreters.create() + try: + with catch_unraisable_exception() as cm: + exception = _interpreters.run_string(interp, script) + unraisable = cm.unraisable + finally: + _interpreters.destroy(interp) + + self.assertIsNotNone(exception) + self.assertIsNotNone(exception.type.__name__, "ImportError") + self.assertIsNotNone(exception.msg, "failed to import from subinterpreter due to exception") + self.assertIsNotNone(unraisable) + self.assertIs(unraisable.exc_type, RuntimeError) + self.assertEqual(str(unraisable.exc_value), "evil") + class TestSinglePhaseSnapshot(ModuleSnapshot): """A representation of a single-phase init module for testing. 
diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index d479f21f558..f7a7c0cc825 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -237,7 +237,6 @@ class FakePackage: self.assertFalse(inspect.ispackage(FakePackage())) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: False is not true def test_iscoroutine(self): async_gen_coro = async_generator_function_example(1) gen_coro = gen_coroutine_function_example(1) @@ -888,7 +887,6 @@ def test_getsource_on_generated_class(self): self.assertRaises(OSError, inspect.getsourcelines, A) self.assertIsNone(inspect.getcomments(A)) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: OSError not raised by getsource def test_getsource_on_class_without_firstlineno(self): __firstlineno__ = 1 class C: diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index 4496ecd6554..c51b547f31a 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -804,7 +804,6 @@ def test_closefd_attr(self): file = self.open(f.fileno(), "r", encoding="utf-8", closefd=False) self.assertEqual(file.buffer.raw.closefd, False) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: filter ('', ResourceWarning) did not catch any warning @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON; cyclic GC not supported, causes file locking") def test_garbage_collection(self): # FileIO objects are collected, and collecting them flushes @@ -1114,7 +1113,6 @@ def reader(file, barrier): class CIOTest(IOTest): - @unittest.expectedFailure # TODO: RUSTPYTHON; cyclic gc def test_IOBase_finalize(self): # Issue #12149: segmentation fault on _PyIOBase_finalize when both a # class which inherits IOBase and an object of this class are caught @@ -1823,7 +1821,6 @@ def test_misbehaved_io_read(self): # checking this is not so easy. 
self.assertRaises(OSError, bufio.read, 10) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: filter ('', ResourceWarning) did not catch any warning @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON; cyclic GC not supported, causes file locking") def test_garbage_collection(self): # C BufferedReader objects are collected. @@ -2173,7 +2170,6 @@ def test_initialization(self): self.assertRaises(ValueError, bufio.__init__, rawio, buffer_size=-1) self.assertRaises(ValueError, bufio.write, b"def") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: filter ('', ResourceWarning) did not catch any warning @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON; cyclic GC not supported, causes file locking") def test_garbage_collection(self): # C BufferedWriter objects are collected, and collecting them flushes @@ -2674,7 +2670,6 @@ def test_interleaved_readline_write(self): class CBufferedRandomTest(BufferedRandomTest, SizeofTest): tp = io.BufferedRandom - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: filter ('', ResourceWarning) did not catch any warning @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON; cyclic GC not supported, causes file locking") def test_garbage_collection(self): CBufferedReaderTest.test_garbage_collection(self) @@ -4119,7 +4114,6 @@ def test_initialization(self): t = self.TextIOWrapper.__new__(self.TextIOWrapper) self.assertRaises(Exception, repr, t) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: filter ('', ResourceWarning) did not catch any warning @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON; cyclic GC not supported, causes file locking") def test_garbage_collection(self): # C TextIOWrapper objects are collected, and collecting them flushes diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index 9c26eb08583..7ac48a50233 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -1137,7 +1137,7 @@ def test_iter_neg_setstate(self): 
self.assertEqual(next(it), 0) self.assertEqual(next(it), 1) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: False is not true + @unittest.skip("TODO: RUSTPYTHON; hangs") def test_free_after_iterating(self): check_free_after_iterating(self, iter, SequenceClass, (0,)) diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py index 7b3b30ce449..6d6ca11f673 100644 --- a/Lib/test/test_json/test_decode.py +++ b/Lib/test/test_json/test_decode.py @@ -18,7 +18,6 @@ def test_float(self): self.assertIsInstance(rval, float) self.assertEqual(rval, 1.0) - @unittest.skip("TODO: RUSTPYTHON; called `Result::unwrap()` on an `Err` value: ParseFloatError { kind: Invalid }") def test_nonascii_digits_rejected(self): # JSON specifies only ascii digits, see gh-125687 for num in ["1\uff10", "0.\uff10", "0e\uff10"]: diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py index 4adfcb17c4d..eec64b528cf 100644 --- a/Lib/test/test_json/test_fail.py +++ b/Lib/test/test_json/test_fail.py @@ -239,7 +239,4 @@ def test_linecol(self): (line, col, idx)) class TestPyFail(TestFail, PyTest): pass -class TestCFail(TestFail, CTest): - @unittest.expectedFailure # TODO: RUSTPYTHON - def test_failures(self): - return super().test_failures() +class TestCFail(TestFail, CTest): pass diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index d8a5eb27501..5678536ff18 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -324,6 +324,8 @@ def testAtanh(self): self.assertRaises(ValueError, math.atanh, NINF) self.assertTrue(math.isnan(math.atanh(NAN))) + @unittest.skipIf(sys.platform.startswith("sunos"), + "skipping, see gh-138573") def testAtan2(self): self.assertRaises(TypeError, math.atan2) self.ftest('atan2(-1, 0)', math.atan2(-1, 0), -math.pi/2) diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py index 891c4d76745..7889fa88d00 100644 --- a/Lib/test/test_memoryview.py +++ b/Lib/test/test_memoryview.py @@ -581,6 
+581,28 @@ def test_array_assign(self): m[:] = new_a self.assertEqual(a, new_a) + def test_boolean_format(self): + # Test '?' format (keep all the checks below for UBSan) + # See github.com/python/cpython/issues/148390. + + # m1a and m1b are equivalent to [False, True, False] + m1a = memoryview(b'\0\2\0').cast('?') + self.assertEqual(m1a.tolist(), [False, True, False]) + m1b = memoryview(b'\0\4\0').cast('?') + self.assertEqual(m1b.tolist(), [False, True, False]) + self.assertEqual(m1a, m1b) + + # m2a and m2b are equivalent to [True, True, True] + m2a = memoryview(b'\1\3\5').cast('?') + self.assertEqual(m2a.tolist(), [True, True, True]) + m2b = memoryview(b'\2\4\6').cast('?') + self.assertEqual(m2b.tolist(), [True, True, True]) + self.assertEqual(m2a, m2b) + + allbytes = bytes(range(256)) + allbytes = memoryview(allbytes).cast('?') + self.assertEqual(allbytes.tolist(), [False] + [True] * 255) + class BytesMemorySliceTest(unittest.TestCase, BaseMemorySliceTests, BaseBytesMemoryTests): diff --git a/Lib/test/test_metaclass.py b/Lib/test/test_metaclass.py index 1707df9075a..dfa6c633a63 100644 --- a/Lib/test/test_metaclass.py +++ b/Lib/test/test_metaclass.py @@ -128,8 +128,7 @@ Check for duplicate keywords. - # TODO: RUSTPYTHON - >>> class C(metaclass=type, metaclass=type): pass # doctest: +SKIP + >>> class C(metaclass=type, metaclass=type): pass ... Traceback (most recent call last): [...] @@ -139,9 +138,7 @@ Another way. >>> kwds = {'metaclass': type} - - # TODO: RUSTPYTHON - >>> class C(metaclass=type, **kwds): pass # doctest: +SKIP + >>> class C(metaclass=type, **kwds): pass ... Traceback (most recent call last): [...] @@ -160,9 +157,7 @@ ... def __prepare__(name, bases): ... return LoggingDict() ... - - # TODO: RUSTPYTHON - >>> class C(metaclass=Meta): # doctest: +SKIP + >>> class C(metaclass=Meta): ... foo = 2+2 ... foo = 42 ... bar = 123 @@ -184,22 +179,16 @@ ... print("kw:", sorted(kwds.items())) ... return namespace ... 
- - # TODO: RUSTPYTHON - >>> class C(metaclass=meta): # doctest: +SKIP + >>> class C(metaclass=meta): ... a = 42 ... b = 24 ... meta: C () ns: [('__firstlineno__', 1), ('__module__', 'test.test_metaclass'), ('__qualname__', 'C'), ('__static_attributes__', ()), ('a', 42), ('b', 24)] kw: [] - - # TODO: RUSTPYTHON - >>> type(C) is dict # doctest: +SKIP + >>> type(C) is dict True - - # TODO: RUSTPYTHON - >>> print(sorted(C.items())) # doctest: +SKIP + >>> print(sorted(C.items())) [('__firstlineno__', 1), ('__module__', 'test.test_metaclass'), ('__qualname__', 'C'), ('__static_attributes__', ()), ('a', 42), ('b', 24)] >>> @@ -210,9 +199,7 @@ ... return LoggingDict() ... >>> meta.__prepare__ = prepare - - # TODO: RUSTPYTHON - >>> class C(metaclass=meta, other="booh"): # doctest: +SKIP + >>> class C(metaclass=meta, other="booh"): ... a = 1 ... a = 2 ... b = 3 @@ -281,23 +268,17 @@ ... >>> Base.value 1 - - # TODO: RUSTPYTHON; AttributeError: type object 'WeirdClass' has no attribute 'value' - >>> WeirdClass.value # doctest: +SKIP + >>> WeirdClass.value # TODO: RUSTPYTHON; AttributeError: type object 'WeirdClass' has no attribute 'value' # doctest: +SKIP 1 >>> Base.value = 2 >>> Base.value 2 - - # TODO: RUSTPYTHON; AttributeError: type object 'WeirdClass' has no attribute 'value' - >>> WeirdClass.value # doctest: +SKIP + >>> WeirdClass.value # TODO: RUSTPYTHON; AttributeError: type object 'WeirdClass' has no attribute 'value' # doctest: +SKIP 2 >>> Base.value = 3 >>> Base.value 3 - - # TODO: RUSTPYTHON; AttributeError: type object 'WeirdClass' has no attribute 'value' - >>> WeirdClass.value # doctest: +SKIP + >>> WeirdClass.value # TODO: RUSTPYTHON; AttributeError: type object 'WeirdClass' has no attribute 'value' # doctest: +SKIP 3 """ @@ -311,7 +292,8 @@ __test__ = {'doctests' : doctests} def load_tests(loader, tests, pattern): - tests.addTest(doctest.DocTestSuite()) + from test.support.rustpython import DocTestChecker # TODO: RUSTPYTHON + 
tests.addTest(doctest.DocTestSuite(checker=DocTestChecker())) return tests diff --git a/Lib/test/test_module/__init__.py b/Lib/test/test_module/__init__.py index 4704ad29974..d4ed61648dd 100644 --- a/Lib/test/test_module/__init__.py +++ b/Lib/test/test_module/__init__.py @@ -103,7 +103,6 @@ def f(): gc_collect() self.assertEqual(f().__dict__["bar"], 4) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_clear_dict_in_ref_cycle(self): destroyed = [] m = ModuleType("foo") diff --git a/Lib/test/test_monitoring.py b/Lib/test/test_monitoring.py index 5125701202b..1f72b552c6c 100644 --- a/Lib/test/test_monitoring.py +++ b/Lib/test/test_monitoring.py @@ -1624,7 +1624,6 @@ def whilefunc(n=0): ('branch left', 'func', 44, 50), ('branch right', 'func', 28, 70)]) - @unittest.expectedFailure # TODO: RUSTPYTHON; - bytecode layout differs from CPython def test_except_star(self): class Foo: diff --git a/Lib/test/test_named_expressions.py b/Lib/test/test_named_expressions.py index adf774f102f..4f92176b301 100644 --- a/Lib/test/test_named_expressions.py +++ b/Lib/test/test_named_expressions.py @@ -63,7 +63,6 @@ def test_named_expression_invalid_10(self): with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_11(self): code = """spam(a=1, b := 2)""" @@ -71,7 +70,6 @@ def test_named_expression_invalid_11(self): "positional argument follows keyword argument"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_12(self): code = """spam(a=1, (b := 2))""" @@ -79,7 +77,6 @@ def test_named_expression_invalid_12(self): "positional argument follows keyword argument"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_13(self): code = """spam(a=1, (b := 2))""" @@ -93,7 +90,6 @@ def test_named_expression_invalid_14(self): with 
self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_named_expression_invalid_15(self): code = """(lambda: x := 1)""" diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py index 378f6c5ab59..ae7935ac07e 100644 --- a/Lib/test/test_ordered_dict.py +++ b/Lib/test/test_ordered_dict.py @@ -680,7 +680,6 @@ class A: gc.collect() self.assertIsNone(r()) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: False is not true def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, self.OrderedDict) support.check_free_after_iterating(self, lambda d: iter(d.keys()), self.OrderedDict) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 40466ec67ba..8d359a646d9 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3433,7 +3433,6 @@ def trace(frame, event, arg): sys.settrace(old_trace) return actual_linenos - @unittest.expectedFailure # TODO: RUSTPYTHON def test_default_wildcard(self): def f(command): # 0 match command.split(): # 1 @@ -3494,7 +3493,6 @@ def f(command): # 0 self.assertListEqual(self._trace(f, "go x"), [1, 2, 3]) self.assertListEqual(self._trace(f, "spam"), [1, 2, 3]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_unreachable_code(self): def f(command): # 0 match command: # 1 diff --git a/Lib/test/test_peg_generator/__init__.py b/Lib/test/test_peg_generator/__init__.py new file mode 100644 index 00000000000..b32db4426f2 --- /dev/null +++ b/Lib/test/test_peg_generator/__init__.py @@ -0,0 +1,12 @@ +import os.path +from test import support +from test.support import load_package_tests + + +# Creating a virtual environment and building C extensions is slow +support.requires('cpu') + + +# Load all tests in package +def load_tests(*args): + return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_peg_generator/__main__.py 
b/Lib/test/test_peg_generator/__main__.py new file mode 100644 index 00000000000..1fab1fddb57 --- /dev/null +++ b/Lib/test/test_peg_generator/__main__.py @@ -0,0 +1,4 @@ +import unittest +from . import load_tests + +unittest.main() diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py new file mode 100644 index 00000000000..aa01a9b8f7e --- /dev/null +++ b/Lib/test/test_peg_generator/test_c_parser.py @@ -0,0 +1,523 @@ +import contextlib +import subprocess +import sysconfig +import textwrap +import unittest +import os +import shutil +import tempfile +from pathlib import Path + +from test import test_tools +from test import support +from test.support import os_helper, import_helper +from test.support.script_helper import assert_python_ok + +if support.check_cflags_pgo(): + raise unittest.SkipTest("peg_generator test disabled under PGO build") + +test_tools.skip_if_missing("peg_generator") +with test_tools.imports_under_tool("peg_generator"): + from pegen.grammar_parser import GeneratedParser as GrammarParser + from pegen.testutil import ( + parse_string, + generate_parser_c_extension, + generate_c_parser_source, + ) + + +TEST_TEMPLATE = """ +tmp_dir = {extension_path!r} + +import ast +import traceback +import sys +import unittest + +from test import test_tools +with test_tools.imports_under_tool("peg_generator"): + from pegen.ast_dump import ast_dump + +sys.path.insert(0, tmp_dir) +import parse + +class Tests(unittest.TestCase): + + def check_input_strings_for_grammar( + self, + valid_cases = (), + invalid_cases = (), + ): + if valid_cases: + for case in valid_cases: + parse.parse_string(case, mode=0) + + if invalid_cases: + for case in invalid_cases: + with self.assertRaises(SyntaxError): + parse.parse_string(case, mode=0) + + def verify_ast_generation(self, stmt): + expected_ast = ast.parse(stmt) + actual_ast = parse.parse_string(stmt, mode=1) + self.assertEqual(ast_dump(expected_ast), ast_dump(actual_ast)) + + def 
test_parse(self): + {test_source} + +unittest.main() +""" + + +@support.requires_subprocess() +class TestCParser(unittest.TestCase): + + _has_run = False + + @classmethod + def setUpClass(cls): + if cls._has_run: + # Since gh-104798 (Use setuptools in peg-generator and reenable + # tests), this test case has been producing ref leaks. Initial + # debugging points to bug(s) in setuptools and/or importlib. + # See gh-105063 for more info. + raise unittest.SkipTest("gh-105063: can not rerun because of ref. leaks") + cls._has_run = True + + # When running under regtest, a separate tempdir is used + # as the current directory and watched for left-overs. + # Reusing that as the base for temporary directories + # ensures everything is cleaned up properly and + # cleans up afterwards if not (with warnings). + cls.tmp_base = os.getcwd() + if os.path.samefile(cls.tmp_base, os_helper.SAVEDCWD): + cls.tmp_base = None + # Create a directory for the reuseable static library part of + # the pegen extension build process. This greatly reduces the + # runtime overhead of spawning compiler processes. 
+ cls.library_dir = tempfile.mkdtemp(dir=cls.tmp_base) + cls.addClassCleanup(shutil.rmtree, cls.library_dir) + + with contextlib.ExitStack() as stack: + python_exe = stack.enter_context(support.setup_venv_with_pip_setuptools("venv")) + sitepackages = subprocess.check_output( + [python_exe, "-c", "import sysconfig; print(sysconfig.get_path('platlib'))"], + text=True, + ).strip() + stack.enter_context(import_helper.DirsOnSysPath(sitepackages)) + cls.addClassCleanup(stack.pop_all().close) + + @support.requires_venv_with_pip() + def setUp(self): + self._backup_config_vars = dict(sysconfig._CONFIG_VARS) + cmd = support.missing_compiler_executable() + if cmd is not None: + self.skipTest("The %r command is not found" % cmd) + self.old_cwd = os.getcwd() + self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base) + self.enterContext(os_helper.change_cwd(self.tmp_path)) + + def tearDown(self): + os.chdir(self.old_cwd) + shutil.rmtree(self.tmp_path) + sysconfig._CONFIG_VARS.clear() + sysconfig._CONFIG_VARS.update(self._backup_config_vars) + + def build_extension(self, grammar_source): + grammar = parse_string(grammar_source, GrammarParser) + # Because setUp() already changes the current directory to the + # temporary path, use a relative path here to prevent excessive + # path lengths when compiling. 
+ generate_parser_c_extension(grammar, Path('.'), library_dir=self.library_dir) + + def run_test(self, grammar_source, test_source): + self.build_extension(grammar_source) + test_source = textwrap.indent(textwrap.dedent(test_source), 8 * " ") + assert_python_ok( + "-c", + TEST_TEMPLATE.format(extension_path=self.tmp_path, test_source=test_source), + ) + + def test_c_parser(self) -> None: + grammar_source = """ + start[mod_ty]: a[asdl_stmt_seq*]=stmt* $ { _PyAST_Module(a, NULL, p->arena) } + stmt[stmt_ty]: a=expr_stmt { a } + expr_stmt[stmt_ty]: a=expression NEWLINE { _PyAST_Expr(a, EXTRA) } + expression[expr_ty]: ( l=expression '+' r=term { _PyAST_BinOp(l, Add, r, EXTRA) } + | l=expression '-' r=term { _PyAST_BinOp(l, Sub, r, EXTRA) } + | t=term { t } + ) + term[expr_ty]: ( l=term '*' r=factor { _PyAST_BinOp(l, Mult, r, EXTRA) } + | l=term '/' r=factor { _PyAST_BinOp(l, Div, r, EXTRA) } + | f=factor { f } + ) + factor[expr_ty]: ('(' e=expression ')' { e } + | a=atom { a } + ) + atom[expr_ty]: ( n=NAME { n } + | n=NUMBER { n } + | s=STRING { s } + ) + """ + test_source = """ + expressions = [ + "4+5", + "4-5", + "4*5", + "1+4*5", + "1+4/5", + "(1+1) + (1+1)", + "(1+1) - (1+1)", + "(1+1) * (1+1)", + "(1+1) / (1+1)", + ] + + for expr in expressions: + the_ast = parse.parse_string(expr, mode=1) + expected_ast = ast.parse(expr) + self.assertEqual(ast_dump(the_ast), ast_dump(expected_ast)) + """ + self.run_test(grammar_source, test_source) + + def test_lookahead(self) -> None: + grammar_source = """ + start: NAME &NAME expr NEWLINE? ENDMARKER + expr: NAME | NUMBER + """ + test_source = """ + valid_cases = ["foo bar"] + invalid_cases = ["foo 34"] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_negative_lookahead(self) -> None: + grammar_source = """ + start: NAME !NAME expr NEWLINE? 
ENDMARKER + expr: NAME | NUMBER + """ + test_source = """ + valid_cases = ["foo 34"] + invalid_cases = ["foo bar"] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_cut(self) -> None: + grammar_source = """ + start: X ~ Y Z | X Q S + X: 'x' + Y: 'y' + Z: 'z' + Q: 'q' + S: 's' + """ + test_source = """ + valid_cases = ["x y z"] + invalid_cases = ["x q s"] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_gather(self) -> None: + grammar_source = """ + start: ';'.pass_stmt+ NEWLINE + pass_stmt: 'pass' + """ + test_source = """ + valid_cases = ["pass", "pass; pass"] + invalid_cases = ["pass;", "pass; pass;"] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_left_recursion(self) -> None: + grammar_source = """ + start: expr NEWLINE + expr: ('-' term | expr '+' term | term) + term: NUMBER + """ + test_source = """ + valid_cases = ["-34", "34", "34 + 12", "1 + 1 + 2 + 3"] + self.check_input_strings_for_grammar(valid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_advanced_left_recursive(self) -> None: + grammar_source = """ + start: NUMBER | sign start + sign: ['-'] + """ + test_source = """ + valid_cases = ["23", "-34"] + self.check_input_strings_for_grammar(valid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_mutually_left_recursive(self) -> None: + grammar_source = """ + start: foo 'E' + foo: bar 'A' | 'B' + bar: foo 'C' | 'D' + """ + test_source = """ + valid_cases = ["B E", "D A C A E"] + self.check_input_strings_for_grammar(valid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_nasty_mutually_left_recursive(self) -> None: + grammar_source = """ + start: target '=' + target: maybe '+' | NAME + maybe: maybe '-' | target + """ + test_source = """ + valid_cases = 
["x ="] + invalid_cases = ["x - + ="] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_return_stmt_noexpr_action(self) -> None: + grammar_source = """ + start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) } + statements[asdl_stmt_seq*]: a[asdl_stmt_seq*]=statement+ { a } + statement[stmt_ty]: simple_stmt + simple_stmt[stmt_ty]: small_stmt + small_stmt[stmt_ty]: return_stmt + return_stmt[stmt_ty]: a='return' NEWLINE { _PyAST_Return(NULL, EXTRA) } + """ + test_source = """ + stmt = "return" + self.verify_ast_generation(stmt) + """ + self.run_test(grammar_source, test_source) + + def test_gather_action_ast(self) -> None: + grammar_source = """ + start[mod_ty]: a[asdl_stmt_seq*]=';'.pass_stmt+ NEWLINE ENDMARKER { _PyAST_Module(a, NULL, p->arena) } + pass_stmt[stmt_ty]: a='pass' { _PyAST_Pass(EXTRA)} + """ + test_source = """ + stmt = "pass; pass" + self.verify_ast_generation(stmt) + """ + self.run_test(grammar_source, test_source) + + def test_pass_stmt_action(self) -> None: + grammar_source = """ + start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) } + statements[asdl_stmt_seq*]: a[asdl_stmt_seq*]=statement+ { a } + statement[stmt_ty]: simple_stmt + simple_stmt[stmt_ty]: small_stmt + small_stmt[stmt_ty]: pass_stmt + pass_stmt[stmt_ty]: a='pass' NEWLINE { _PyAST_Pass(EXTRA) } + """ + test_source = """ + stmt = "pass" + self.verify_ast_generation(stmt) + """ + self.run_test(grammar_source, test_source) + + def test_if_stmt_action(self) -> None: + grammar_source = """ + start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) } + statements[asdl_stmt_seq*]: a=statement+ { (asdl_stmt_seq*)_PyPegen_seq_flatten(p, a) } + statement[asdl_stmt_seq*]: a=compound_stmt { (asdl_stmt_seq*)_PyPegen_singleton_seq(p, a) } | simple_stmt + + simple_stmt[asdl_stmt_seq*]: a=small_stmt b=further_small_stmt* [';'] NEWLINE { + 
(asdl_stmt_seq*)_PyPegen_seq_insert_in_front(p, a, b) } + further_small_stmt[stmt_ty]: ';' a=small_stmt { a } + + block: simple_stmt | NEWLINE INDENT a=statements DEDENT { a } + + compound_stmt: if_stmt + + if_stmt: 'if' a=full_expression ':' b=block { _PyAST_If(a, b, NULL, EXTRA) } + + small_stmt[stmt_ty]: pass_stmt + + pass_stmt[stmt_ty]: a='pass' { _PyAST_Pass(EXTRA) } + + full_expression: NAME + """ + test_source = """ + stmt = "pass" + self.verify_ast_generation(stmt) + """ + self.run_test(grammar_source, test_source) + + def test_same_name_different_types(self) -> None: + grammar_source = """ + start[mod_ty]: a[asdl_stmt_seq*]=import_from+ NEWLINE ENDMARKER { _PyAST_Module(a, NULL, p->arena)} + import_from[stmt_ty]: ( a='from' !'import' c=simple_name 'import' d=import_as_names_from { + _PyAST_ImportFrom(c->v.Name.id, d, 0, EXTRA) } + | a='from' '.' 'import' c=import_as_names_from { + _PyAST_ImportFrom(NULL, c, 1, EXTRA) } + ) + simple_name[expr_ty]: NAME + import_as_names_from[asdl_alias_seq*]: a[asdl_alias_seq*]=','.import_as_name_from+ { a } + import_as_name_from[alias_ty]: a=NAME 'as' b=NAME { _PyAST_alias(((expr_ty) a)->v.Name.id, ((expr_ty) b)->v.Name.id, EXTRA) } + """ + test_source = """ + for stmt in ("from a import b as c", "from . 
import a as b"): + expected_ast = ast.parse(stmt) + actual_ast = parse.parse_string(stmt, mode=1) + self.assertEqual(ast_dump(expected_ast), ast_dump(actual_ast)) + """ + self.run_test(grammar_source, test_source) + + def test_with_stmt_with_paren(self) -> None: + grammar_source = """ + start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) } + statements[asdl_stmt_seq*]: a=statement+ { (asdl_stmt_seq*)_PyPegen_seq_flatten(p, a) } + statement[asdl_stmt_seq*]: a=compound_stmt { (asdl_stmt_seq*)_PyPegen_singleton_seq(p, a) } + compound_stmt[stmt_ty]: with_stmt + with_stmt[stmt_ty]: ( + a='with' '(' b[asdl_withitem_seq*]=','.with_item+ ')' ':' c=block { + _PyAST_With(b, (asdl_stmt_seq*) _PyPegen_singleton_seq(p, c), NULL, EXTRA) } + ) + with_item[withitem_ty]: ( + e=NAME o=['as' t=NAME { t }] { _PyAST_withitem(e, _PyPegen_set_expr_context(p, o, Store), p->arena) } + ) + block[stmt_ty]: a=pass_stmt NEWLINE { a } | NEWLINE INDENT a=pass_stmt DEDENT { a } + pass_stmt[stmt_ty]: a='pass' { _PyAST_Pass(EXTRA) } + """ + test_source = """ + stmt = "with (\\n a as b,\\n c as d\\n): pass" + the_ast = parse.parse_string(stmt, mode=1) + self.assertStartsWith(ast_dump(the_ast), + "Module(body=[With(items=[withitem(context_expr=Name(id='a', ctx=Load()), optional_vars=Name(id='b', ctx=Store())), " + "withitem(context_expr=Name(id='c', ctx=Load()), optional_vars=Name(id='d', ctx=Store()))]" + ) + """ + self.run_test(grammar_source, test_source) + + def test_ternary_operator(self) -> None: + grammar_source = """ + start[mod_ty]: a=expr ENDMARKER { _PyAST_Module(a, NULL, p->arena) } + expr[asdl_stmt_seq*]: a=listcomp NEWLINE { (asdl_stmt_seq*)_PyPegen_singleton_seq(p, _PyAST_Expr(a, EXTRA)) } + listcomp[expr_ty]: ( + a='[' b=NAME c=for_if_clauses d=']' { _PyAST_ListComp(b, c, EXTRA) } + ) + for_if_clauses[asdl_comprehension_seq*]: ( + a[asdl_comprehension_seq*]=(y=['async'] 'for' a=NAME 'in' b=NAME c[asdl_expr_seq*]=('if' z=NAME { z })* + { 
_PyAST_comprehension(_PyAST_Name(((expr_ty) a)->v.Name.id, Store, EXTRA), b, c, (y == NULL) ? 0 : 1, p->arena) })+ { a } + ) + """ + test_source = """ + stmt = "[i for i in a if b]" + self.verify_ast_generation(stmt) + """ + self.run_test(grammar_source, test_source) + + def test_syntax_error_for_string(self) -> None: + grammar_source = """ + start: expr+ NEWLINE? ENDMARKER + expr: NAME + """ + test_source = r""" + for text in ("a b 42 b a", "\u540d \u540d 42 \u540d \u540d"): + try: + parse.parse_string(text, mode=0) + except SyntaxError as e: + tb = traceback.format_exc() + self.assertTrue('File "", line 1' in tb) + self.assertTrue(f"SyntaxError: invalid syntax" in tb) + """ + self.run_test(grammar_source, test_source) + + def test_headers_and_trailer(self) -> None: + grammar_source = """ + @header 'SOME HEADER' + @subheader 'SOME SUBHEADER' + @trailer 'SOME TRAILER' + start: expr+ NEWLINE? ENDMARKER + expr: x=NAME + """ + grammar = parse_string(grammar_source, GrammarParser) + parser_source = generate_c_parser_source(grammar) + + self.assertTrue("SOME HEADER" in parser_source) + self.assertTrue("SOME SUBHEADER" in parser_source) + self.assertTrue("SOME TRAILER" in parser_source) + + def test_error_in_rules(self) -> None: + grammar_source = """ + start: expr+ NEWLINE? ENDMARKER + expr: NAME {PyTuple_New(-1)} + """ + # PyTuple_New raises SystemError if an invalid argument was passed. + test_source = """ + with self.assertRaises(SystemError): + parse.parse_string("a", mode=0) + """ + self.run_test(grammar_source, test_source) + + def test_no_soft_keywords(self) -> None: + grammar_source = """ + start: expr+ NEWLINE? ENDMARKER + expr: 'foo' + """ + grammar = parse_string(grammar_source, GrammarParser) + parser_source = generate_c_parser_source(grammar) + assert "expect_soft_keyword" not in parser_source + + def test_soft_keywords(self) -> None: + grammar_source = """ + start: expr+ NEWLINE? 
ENDMARKER + expr: "foo" + """ + grammar = parse_string(grammar_source, GrammarParser) + parser_source = generate_c_parser_source(grammar) + assert "expect_soft_keyword" in parser_source + + def test_soft_keywords_parse(self) -> None: + grammar_source = """ + start: "if" expr '+' expr NEWLINE + expr: NAME + """ + test_source = """ + valid_cases = ["if if + if"] + invalid_cases = ["if if"] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_soft_keywords_lookahead(self) -> None: + grammar_source = """ + start: &"if" "if" expr '+' expr NEWLINE + expr: NAME + """ + test_source = """ + valid_cases = ["if if + if"] + invalid_cases = ["if if"] + self.check_input_strings_for_grammar(valid_cases, invalid_cases) + """ + self.run_test(grammar_source, test_source) + + def test_forced(self) -> None: + grammar_source = """ + start: NAME &&':' | NAME + """ + test_source = """ + self.assertEqual(parse.parse_string("number :", mode=0), None) + with self.assertRaises(SyntaxError) as e: + parse.parse_string("a", mode=0) + self.assertIn("expected ':'", str(e.exception)) + """ + self.run_test(grammar_source, test_source) + + def test_forced_with_group(self) -> None: + grammar_source = """ + start: NAME &&(':' | ';') | NAME + """ + test_source = """ + self.assertEqual(parse.parse_string("number :", mode=0), None) + self.assertEqual(parse.parse_string("number ;", mode=0), None) + with self.assertRaises(SyntaxError) as e: + parse.parse_string("a", mode=0) + self.assertIn("expected (':' | ';')", e.exception.args[0]) + """ + self.run_test(grammar_source, test_source) diff --git a/Lib/test/test_peg_generator/test_first_sets.py b/Lib/test/test_peg_generator/test_first_sets.py new file mode 100644 index 00000000000..d6f8322f034 --- /dev/null +++ b/Lib/test/test_peg_generator/test_first_sets.py @@ -0,0 +1,286 @@ +import unittest + +from test import test_tools +from typing import Dict, Set + 
+test_tools.skip_if_missing("peg_generator") +with test_tools.imports_under_tool("peg_generator"): + from pegen.grammar_parser import GeneratedParser as GrammarParser + from pegen.testutil import parse_string + from pegen.first_sets import FirstSetCalculator + from pegen.grammar import Grammar + + +class TestFirstSets(unittest.TestCase): + def calculate_first_sets(self, grammar_source: str) -> Dict[str, Set[str]]: + grammar: Grammar = parse_string(grammar_source, GrammarParser) + return FirstSetCalculator(grammar.rules).calculate() + + def test_alternatives(self) -> None: + grammar = """ + start: expr NEWLINE? ENDMARKER + expr: A | B + A: 'a' | '-' + B: 'b' | '+' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "A": {"'a'", "'-'"}, + "B": {"'+'", "'b'"}, + "expr": {"'+'", "'a'", "'b'", "'-'"}, + "start": {"'+'", "'a'", "'b'", "'-'"}, + }, + ) + + def test_optionals(self) -> None: + grammar = """ + start: expr NEWLINE + expr: ['a'] ['b'] 'c' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "expr": {"'c'", "'a'", "'b'"}, + "start": {"'c'", "'a'", "'b'"}, + }, + ) + + def test_repeat_with_separator(self) -> None: + grammar = """ + start: ','.thing+ NEWLINE + thing: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"thing": {"NUMBER"}, "start": {"NUMBER"}}, + ) + + def test_optional_operator(self) -> None: + grammar = """ + start: sum NEWLINE + sum: (term)? 'b' + term: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "term": {"NUMBER"}, + "sum": {"NUMBER", "'b'"}, + "start": {"'b'", "NUMBER"}, + }, + ) + + def test_optional_literal(self) -> None: + grammar = """ + start: sum NEWLINE + sum: '+' ? 
term + term: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "term": {"NUMBER"}, + "sum": {"'+'", "NUMBER"}, + "start": {"'+'", "NUMBER"}, + }, + ) + + def test_optional_after(self) -> None: + grammar = """ + start: term NEWLINE + term: NUMBER ['+'] + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"term": {"NUMBER"}, "start": {"NUMBER"}}, + ) + + def test_optional_before(self) -> None: + grammar = """ + start: term NEWLINE + term: ['+'] NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"term": {"NUMBER", "'+'"}, "start": {"NUMBER", "'+'"}}, + ) + + def test_repeat_0(self) -> None: + grammar = """ + start: thing* "+" NEWLINE + thing: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"thing": {"NUMBER"}, "start": {'"+"', "NUMBER"}}, + ) + + def test_repeat_0_with_group(self) -> None: + grammar = """ + start: ('+' '-')* term NEWLINE + term: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"term": {"NUMBER"}, "start": {"'+'", "NUMBER"}}, + ) + + def test_repeat_1(self) -> None: + grammar = """ + start: thing+ '-' NEWLINE + thing: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"thing": {"NUMBER"}, "start": {"NUMBER"}}, + ) + + def test_repeat_1_with_group(self) -> None: + grammar = """ + start: ('+' term)+ term NEWLINE + term: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), {"term": {"NUMBER"}, "start": {"'+'"}} + ) + + def test_gather(self) -> None: + grammar = """ + start: ','.thing+ NEWLINE + thing: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"thing": {"NUMBER"}, "start": {"NUMBER"}}, + ) + + def test_positive_lookahead(self) -> None: + grammar = """ + start: expr NEWLINE + expr: &'a' opt + opt: 'a' | 'b' | 'c' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "expr": {"'a'"}, + "start": {"'a'"}, + "opt": {"'b'", "'c'", "'a'"}, + }, + ) + + def 
test_negative_lookahead(self) -> None: + grammar = """ + start: expr NEWLINE + expr: !'a' opt + opt: 'a' | 'b' | 'c' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "opt": {"'b'", "'a'", "'c'"}, + "expr": {"'b'", "'c'"}, + "start": {"'b'", "'c'"}, + }, + ) + + def test_left_recursion(self) -> None: + grammar = """ + start: expr NEWLINE + expr: ('-' term | expr '+' term | term) + term: NUMBER + foo: 'foo' + bar: 'bar' + baz: 'baz' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "expr": {"NUMBER", "'-'"}, + "term": {"NUMBER"}, + "start": {"NUMBER", "'-'"}, + "foo": {"'foo'"}, + "bar": {"'bar'"}, + "baz": {"'baz'"}, + }, + ) + + def test_advance_left_recursion(self) -> None: + grammar = """ + start: NUMBER | sign start + sign: ['-'] + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"sign": {"'-'", ""}, "start": {"'-'", "NUMBER"}}, + ) + + def test_mutual_left_recursion(self) -> None: + grammar = """ + start: foo 'E' + foo: bar 'A' | 'B' + bar: foo 'C' | 'D' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "foo": {"'D'", "'B'"}, + "bar": {"'D'"}, + "start": {"'D'", "'B'"}, + }, + ) + + def test_nasty_left_recursion(self) -> None: + # TODO: Validate this + grammar = """ + start: target '=' + target: maybe '+' | NAME + maybe: maybe '-' | target + """ + self.assertEqual( + self.calculate_first_sets(grammar), + {"maybe": set(), "target": {"NAME"}, "start": {"NAME"}}, + ) + + def test_nullable_rule(self) -> None: + grammar = """ + start: sign thing $ + sign: ['-'] + thing: NUMBER + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "sign": {"", "'-'"}, + "thing": {"NUMBER"}, + "start": {"NUMBER", "'-'"}, + }, + ) + + def test_epsilon_production_in_start_rule(self) -> None: + grammar = """ + start: ['-'] $ + """ + self.assertEqual( + self.calculate_first_sets(grammar), {"start": {"ENDMARKER", "'-'"}} + ) + + def test_multiple_nullable_rules(self) -> None: + grammar = """ + 
start: sign thing other another $ + sign: ['-'] + thing: ['+'] + other: '*' + another: '/' + """ + self.assertEqual( + self.calculate_first_sets(grammar), + { + "sign": {"", "'-'"}, + "thing": {"'+'", ""}, + "start": {"'+'", "'-'", "'*'"}, + "other": {"'*'"}, + "another": {"'/'"}, + }, + ) diff --git a/Lib/test/test_peg_generator/test_grammar_validator.py b/Lib/test/test_peg_generator/test_grammar_validator.py new file mode 100644 index 00000000000..857aced8ae5 --- /dev/null +++ b/Lib/test/test_peg_generator/test_grammar_validator.py @@ -0,0 +1,77 @@ +import unittest +from test import test_tools + +test_tools.skip_if_missing("peg_generator") +with test_tools.imports_under_tool("peg_generator"): + from pegen.grammar_parser import GeneratedParser as GrammarParser + from pegen.validator import SubRuleValidator, ValidationError + from pegen.validator import RaiseRuleValidator, CutValidator + from pegen.testutil import parse_string + from pegen.grammar import Grammar + + +class TestPegen(unittest.TestCase): + def test_rule_with_no_collision(self) -> None: + grammar_source = """ + start: bad_rule + sum: + | NAME '-' NAME + | NAME '+' NAME + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + validator = SubRuleValidator(grammar) + for rule_name, rule in grammar.rules.items(): + validator.validate_rule(rule_name, rule) + + def test_rule_with_simple_collision(self) -> None: + grammar_source = """ + start: bad_rule + sum: + | NAME '+' NAME + | NAME '+' NAME ';' + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + validator = SubRuleValidator(grammar) + with self.assertRaises(ValidationError): + for rule_name, rule in grammar.rules.items(): + validator.validate_rule(rule_name, rule) + + def test_rule_with_collision_after_some_other_rules(self) -> None: + grammar_source = """ + start: bad_rule + sum: + | NAME '+' NAME + | NAME '*' NAME ';' + | NAME '-' NAME + | NAME '+' NAME ';' + """ + grammar: Grammar = parse_string(grammar_source, 
GrammarParser) + validator = SubRuleValidator(grammar) + with self.assertRaises(ValidationError): + for rule_name, rule in grammar.rules.items(): + validator.validate_rule(rule_name, rule) + + def test_raising_valid_rule(self) -> None: + grammar_source = """ + start: NAME { RAISE_SYNTAX_ERROR("this is not allowed") } + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + validator = RaiseRuleValidator(grammar) + with self.assertRaises(ValidationError): + for rule_name, rule in grammar.rules.items(): + validator.validate_rule(rule_name, rule) + + def test_cut_validator(self) -> None: + grammar_source = """ + star: (OP ~ OP)* + plus: (OP ~ OP)+ + bracket: [OP ~ OP] + gather: OP.(OP ~ OP)+ + nested: [OP | NAME ~ OP] + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + validator = CutValidator(grammar) + for rule_name, rule in grammar.rules.items(): + with self.subTest(rule_name): + with self.assertRaises(ValidationError): + validator.validate_rule(rule_name, rule) diff --git a/Lib/test/test_peg_generator/test_pegen.py b/Lib/test/test_peg_generator/test_pegen.py new file mode 100644 index 00000000000..58ce558c548 --- /dev/null +++ b/Lib/test/test_peg_generator/test_pegen.py @@ -0,0 +1,1132 @@ +import ast +import difflib +import io +import textwrap +import unittest + +from test import test_tools +from typing import Dict, Any +from tokenize import TokenInfo, NAME, NEWLINE, NUMBER, OP + +test_tools.skip_if_missing("peg_generator") +with test_tools.imports_under_tool("peg_generator"): + from pegen.grammar_parser import GeneratedParser as GrammarParser + from pegen.testutil import parse_string, generate_parser, make_parser + from pegen.grammar import GrammarVisitor, GrammarError, Grammar + from pegen.grammar_visualizer import ASTGrammarPrinter + from pegen.parser import Parser + from pegen.parser_generator import compute_nullables, compute_left_recursives + from pegen.python_generator import PythonParserGenerator + + +class 
TestPegen(unittest.TestCase): + def test_parse_grammar(self) -> None: + grammar_source = """ + start: sum NEWLINE + sum: t1=term '+' t2=term { action } | term + term: NUMBER + """ + expected = """ + start: sum NEWLINE + sum: term '+' term | term + term: NUMBER + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + rules = grammar.rules + self.assertEqual(str(grammar), textwrap.dedent(expected).strip()) + # Check the str() and repr() of a few rules; AST nodes don't support ==. + self.assertEqual(str(rules["start"]), "start: sum NEWLINE") + self.assertEqual(str(rules["sum"]), "sum: term '+' term | term") + expected_repr = ( + "Rule('term', None, Rhs([Alt([NamedItem(None, NameLeaf('NUMBER'))])]))" + ) + self.assertEqual(repr(rules["term"]), expected_repr) + + def test_repeated_rules(self) -> None: + grammar_source = """ + start: the_rule NEWLINE + the_rule: 'b' NEWLINE + the_rule: 'a' NEWLINE + """ + with self.assertRaisesRegex(GrammarError, "Repeated rule 'the_rule'"): + parse_string(grammar_source, GrammarParser) + + def test_long_rule_str(self) -> None: + grammar_source = """ + start: zero | one | one zero | one one | one zero zero | one zero one | one one zero | one one one + """ + expected = """ + start: + | zero + | one + | one zero + | one one + | one zero zero + | one zero one + | one one zero + | one one one + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + self.assertEqual(str(grammar.rules["start"]), textwrap.dedent(expected).strip()) + + def test_typed_rules(self) -> None: + grammar = """ + start[int]: sum NEWLINE + sum[int]: t1=term '+' t2=term { action } | term + term[int]: NUMBER + """ + rules = parse_string(grammar, GrammarParser).rules + # Check the str() and repr() of a few rules; AST nodes don't support ==. 
+ self.assertEqual(str(rules["start"]), "start: sum NEWLINE") + self.assertEqual(str(rules["sum"]), "sum: term '+' term | term") + self.assertEqual( + repr(rules["term"]), + "Rule('term', 'int', Rhs([Alt([NamedItem(None, NameLeaf('NUMBER'))])]))", + ) + + def test_gather(self) -> None: + grammar = """ + start: ','.thing+ NEWLINE + thing: NUMBER + """ + rules = parse_string(grammar, GrammarParser).rules + self.assertEqual(str(rules["start"]), "start: ','.thing+ NEWLINE") + self.assertStartsWith(repr(rules["start"]), + "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'" + ) + self.assertEqual(str(rules["thing"]), "thing: NUMBER") + parser_class = make_parser(grammar) + node = parse_string("42\n", parser_class) + node = parse_string("1, 2\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1, 2\n" + ), + TokenInfo( + NUMBER, string="2", start=(1, 3), end=(1, 4), line="1, 2\n" + ), + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 4), end=(1, 5), line="1, 2\n" + ), + ], + ) + + def test_expr_grammar(self) -> None: + grammar = """ + start: sum NEWLINE + sum: term '+' term | term + term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("42\n", parser_class) + self.assertEqual( + node, + [ + TokenInfo(NUMBER, string="42", start=(1, 0), end=(1, 2), line="42\n"), + TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="42\n"), + ], + ) + + def test_optional_operator(self) -> None: + grammar = """ + start: sum NEWLINE + sum: term ('+' term)? 
+ term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1 + 2\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2\n" + ), + [ + TokenInfo( + OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2\n" + ), + TokenInfo( + NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2\n" + ), + ], + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 + 2\n" + ), + ], + ) + node = parse_string("1\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"), + None, + ], + TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"), + ], + ) + + def test_optional_literal(self) -> None: + grammar = """ + start: sum NEWLINE + sum: term '+' ? + term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1+\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1+\n" + ), + TokenInfo(OP, string="+", start=(1, 1), end=(1, 2), line="1+\n"), + ], + TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="1+\n"), + ], + ) + node = parse_string("1\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"), + None, + ], + TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"), + ], + ) + + def test_alt_optional_operator(self) -> None: + grammar = """ + start: sum NEWLINE + sum: term ['+' term] + term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1 + 2\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2\n" + ), + [ + TokenInfo( + OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2\n" + ), + TokenInfo( + NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2\n" + ), + ], + ], + TokenInfo( + NEWLINE, string="\n", 
start=(1, 5), end=(1, 6), line="1 + 2\n" + ), + ], + ) + node = parse_string("1\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"), + None, + ], + TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"), + ], + ) + + def test_repeat_0_simple(self) -> None: + grammar = """ + start: thing thing* NEWLINE + thing: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1 2 3\n", parser_class) + self.assertEqual( + node, + [ + TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 2 3\n"), + [ + TokenInfo( + NUMBER, string="2", start=(1, 2), end=(1, 3), line="1 2 3\n" + ), + TokenInfo( + NUMBER, string="3", start=(1, 4), end=(1, 5), line="1 2 3\n" + ), + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 2 3\n" + ), + ], + ) + node = parse_string("1\n", parser_class) + self.assertEqual( + node, + [ + TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"), + [], + TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"), + ], + ) + + def test_repeat_0_complex(self) -> None: + grammar = """ + start: term ('+' term)* NEWLINE + term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1 + 2 + 3\n", parser_class) + self.assertEqual( + node, + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n" + ), + [ + [ + TokenInfo( + OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n" + ), + TokenInfo( + NUMBER, + string="2", + start=(1, 4), + end=(1, 5), + line="1 + 2 + 3\n", + ), + ], + [ + TokenInfo( + OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n" + ), + TokenInfo( + NUMBER, + string="3", + start=(1, 8), + end=(1, 9), + line="1 + 2 + 3\n", + ), + ], + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n" + ), + ], + ) + + def test_repeat_1_simple(self) -> None: + grammar = """ + start: thing thing+ NEWLINE + thing: 
NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1 2 3\n", parser_class) + self.assertEqual( + node, + [ + TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 2 3\n"), + [ + TokenInfo( + NUMBER, string="2", start=(1, 2), end=(1, 3), line="1 2 3\n" + ), + TokenInfo( + NUMBER, string="3", start=(1, 4), end=(1, 5), line="1 2 3\n" + ), + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 2 3\n" + ), + ], + ) + with self.assertRaises(SyntaxError): + parse_string("1\n", parser_class) + + def test_repeat_1_complex(self) -> None: + grammar = """ + start: term ('+' term)+ NEWLINE + term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1 + 2 + 3\n", parser_class) + self.assertEqual( + node, + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n" + ), + [ + [ + TokenInfo( + OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n" + ), + TokenInfo( + NUMBER, + string="2", + start=(1, 4), + end=(1, 5), + line="1 + 2 + 3\n", + ), + ], + [ + TokenInfo( + OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n" + ), + TokenInfo( + NUMBER, + string="3", + start=(1, 8), + end=(1, 9), + line="1 + 2 + 3\n", + ), + ], + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n" + ), + ], + ) + with self.assertRaises(SyntaxError): + parse_string("1\n", parser_class) + + def test_repeat_with_sep_simple(self) -> None: + grammar = """ + start: ','.thing+ NEWLINE + thing: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("1, 2, 3\n", parser_class) + self.assertEqual( + node, + [ + [ + TokenInfo( + NUMBER, string="1", start=(1, 0), end=(1, 1), line="1, 2, 3\n" + ), + TokenInfo( + NUMBER, string="2", start=(1, 3), end=(1, 4), line="1, 2, 3\n" + ), + TokenInfo( + NUMBER, string="3", start=(1, 6), end=(1, 7), line="1, 2, 3\n" + ), + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 7), end=(1, 8), line="1, 2, 3\n" + ), + ], + 
) + + def test_left_recursive(self) -> None: + grammar_source = """ + start: expr NEWLINE + expr: ('-' term | expr '+' term | term) + term: NUMBER + foo: NAME+ + bar: NAME* + baz: NAME? + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + parser_class = generate_parser(grammar) + rules = grammar.rules + self.assertFalse(rules["start"].left_recursive) + self.assertTrue(rules["expr"].left_recursive) + self.assertFalse(rules["term"].left_recursive) + self.assertFalse(rules["foo"].left_recursive) + self.assertFalse(rules["bar"].left_recursive) + self.assertFalse(rules["baz"].left_recursive) + node = parse_string("1 + 2 + 3\n", parser_class) + self.assertEqual( + node, + [ + [ + [ + TokenInfo( + NUMBER, + string="1", + start=(1, 0), + end=(1, 1), + line="1 + 2 + 3\n", + ), + TokenInfo( + OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n" + ), + TokenInfo( + NUMBER, + string="2", + start=(1, 4), + end=(1, 5), + line="1 + 2 + 3\n", + ), + ], + TokenInfo( + OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n" + ), + TokenInfo( + NUMBER, string="3", start=(1, 8), end=(1, 9), line="1 + 2 + 3\n" + ), + ], + TokenInfo( + NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n" + ), + ], + ) + + def test_python_expr(self) -> None: + grammar = """ + start: expr NEWLINE? 
$ { ast.Expression(expr) } + expr: ( expr '+' term { ast.BinOp(expr, ast.Add(), term, lineno=expr.lineno, col_offset=expr.col_offset, end_lineno=term.end_lineno, end_col_offset=term.end_col_offset) } + | expr '-' term { ast.BinOp(expr, ast.Sub(), term, lineno=expr.lineno, col_offset=expr.col_offset, end_lineno=term.end_lineno, end_col_offset=term.end_col_offset) } + | term { term } + ) + term: ( l=term '*' r=factor { ast.BinOp(l, ast.Mult(), r, lineno=l.lineno, col_offset=l.col_offset, end_lineno=r.end_lineno, end_col_offset=r.end_col_offset) } + | l=term '/' r=factor { ast.BinOp(l, ast.Div(), r, lineno=l.lineno, col_offset=l.col_offset, end_lineno=r.end_lineno, end_col_offset=r.end_col_offset) } + | factor { factor } + ) + factor: ( '(' expr ')' { expr } + | atom { atom } + ) + atom: ( n=NAME { ast.Name(id=n.string, ctx=ast.Load(), lineno=n.start[0], col_offset=n.start[1], end_lineno=n.end[0], end_col_offset=n.end[1]) } + | n=NUMBER { ast.Constant(value=ast.literal_eval(n.string), lineno=n.start[0], col_offset=n.start[1], end_lineno=n.end[0], end_col_offset=n.end[1]) } + ) + """ + parser_class = make_parser(grammar) + node = parse_string("(1 + 2*3 + 5)/(6 - 2)\n", parser_class) + code = compile(node, "", "eval") + val = eval(code) + self.assertEqual(val, 3.0) + + def test_f_string_in_action(self) -> None: + grammar = """ + start: n=NAME NEWLINE? $ { f"name -> {n.string}" } + """ + parser_class = make_parser(grammar) + node = parse_string("a", parser_class) + self.assertEqual(node.strip(), "name -> a") + + def test_nullable(self) -> None: + grammar_source = """ + start: sign NUMBER + sign: ['-' | '+'] + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + rules = grammar.rules + nullables = compute_nullables(rules) + self.assertNotIn(rules["start"], nullables) # Not None! 
+ self.assertIn(rules["sign"], nullables) + + def test_advanced_left_recursive(self) -> None: + grammar_source = """ + start: NUMBER | sign start + sign: ['-'] + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + rules = grammar.rules + nullables = compute_nullables(rules) + compute_left_recursives(rules) + self.assertNotIn(rules["start"], nullables) # Not None! + self.assertIn(rules["sign"], nullables) + self.assertTrue(rules["start"].left_recursive) + self.assertFalse(rules["sign"].left_recursive) + + def test_mutually_left_recursive(self) -> None: + grammar_source = """ + start: foo 'E' + foo: bar 'A' | 'B' + bar: foo 'C' | 'D' + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + out = io.StringIO() + genr = PythonParserGenerator(grammar, out) + rules = grammar.rules + self.assertFalse(rules["start"].left_recursive) + self.assertTrue(rules["foo"].left_recursive) + self.assertTrue(rules["bar"].left_recursive) + genr.generate("") + ns: Dict[str, Any] = {} + exec(out.getvalue(), ns) + parser_class: Type[Parser] = ns["GeneratedParser"] + node = parse_string("D A C A E", parser_class) + + self.assertEqual( + node, + [ + [ + [ + [ + TokenInfo( + type=NAME, + string="D", + start=(1, 0), + end=(1, 1), + line="D A C A E", + ), + TokenInfo( + type=NAME, + string="A", + start=(1, 2), + end=(1, 3), + line="D A C A E", + ), + ], + TokenInfo( + type=NAME, + string="C", + start=(1, 4), + end=(1, 5), + line="D A C A E", + ), + ], + TokenInfo( + type=NAME, + string="A", + start=(1, 6), + end=(1, 7), + line="D A C A E", + ), + ], + TokenInfo( + type=NAME, string="E", start=(1, 8), end=(1, 9), line="D A C A E" + ), + ], + ) + node = parse_string("B C A E", parser_class) + self.assertEqual( + node, + [ + [ + [ + TokenInfo( + type=NAME, + string="B", + start=(1, 0), + end=(1, 1), + line="B C A E", + ), + TokenInfo( + type=NAME, + string="C", + start=(1, 2), + end=(1, 3), + line="B C A E", + ), + ], + TokenInfo( + type=NAME, string="A", 
start=(1, 4), end=(1, 5), line="B C A E" + ), + ], + TokenInfo( + type=NAME, string="E", start=(1, 6), end=(1, 7), line="B C A E" + ), + ], + ) + + def test_nasty_mutually_left_recursive(self) -> None: + # This grammar does not recognize 'x - + =', much to my chagrin. + # But that's the way PEG works. + # [Breathlessly] + # The problem is that the toplevel target call + # recurses into maybe, which recognizes 'x - +', + # and then the toplevel target looks for another '+', + # which fails, so it retreats to NAME, + # which succeeds, so we end up just recognizing 'x', + # and then start fails because there's no '=' after that. + grammar_source = """ + start: target '=' + target: maybe '+' | NAME + maybe: maybe '-' | target + """ + grammar: Grammar = parse_string(grammar_source, GrammarParser) + out = io.StringIO() + genr = PythonParserGenerator(grammar, out) + genr.generate("") + ns: Dict[str, Any] = {} + exec(out.getvalue(), ns) + parser_class = ns["GeneratedParser"] + with self.assertRaises(SyntaxError): + parse_string("x - + =", parser_class) + + def test_lookahead(self) -> None: + grammar = """ + start: (expr_stmt | assign_stmt) &'.' + expr_stmt: !(target '=') expr + assign_stmt: target '=' expr + expr: term ('+' term)* + target: NAME + term: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("foo = 12 + 12 .", parser_class) + self.maxDiff = None + self.assertEqual( + node, + [ + TokenInfo( + NAME, string="foo", start=(1, 0), end=(1, 3), line="foo = 12 + 12 ." + ), + TokenInfo( + OP, string="=", start=(1, 4), end=(1, 5), line="foo = 12 + 12 ." 
+ ), + [ + TokenInfo( + NUMBER, + string="12", + start=(1, 6), + end=(1, 8), + line="foo = 12 + 12 .", + ), + [ + [ + TokenInfo( + OP, + string="+", + start=(1, 9), + end=(1, 10), + line="foo = 12 + 12 .", + ), + TokenInfo( + NUMBER, + string="12", + start=(1, 11), + end=(1, 13), + line="foo = 12 + 12 .", + ), + ] + ], + ], + ], + ) + + def test_named_lookahead_error(self) -> None: + grammar = """ + start: foo=!'x' NAME + """ + with self.assertRaises(SyntaxError): + make_parser(grammar) + + def test_start_leader(self) -> None: + grammar = """ + start: attr | NAME + attr: start '.' NAME + """ + # Would assert False without a special case in compute_left_recursives(). + make_parser(grammar) + + def test_opt_sequence(self) -> None: + grammar = """ + start: [NAME*] + """ + # This case was failing because of a double trailing comma at the end + # of a line in the generated source. See bpo-41044 + make_parser(grammar) + + def test_left_recursion_too_complex(self) -> None: + grammar = """ + start: foo + foo: bar '+' | baz '+' | '+' + bar: baz '-' | foo '-' | '-' + baz: foo '*' | bar '*' | '*' + """ + with self.assertRaises(ValueError) as errinfo: + make_parser(grammar) + self.assertTrue("no leader" in str(errinfo.exception.value)) + + def test_cut(self) -> None: + grammar = """ + start: '(' ~ expr ')' + expr: NUMBER + """ + parser_class = make_parser(grammar) + node = parse_string("(1)", parser_class) + self.assertEqual( + node, + [ + TokenInfo(OP, string="(", start=(1, 0), end=(1, 1), line="(1)"), + TokenInfo(NUMBER, string="1", start=(1, 1), end=(1, 2), line="(1)"), + TokenInfo(OP, string=")", start=(1, 2), end=(1, 3), line="(1)"), + ], + ) + + def test_cut_is_local_in_rule(self) -> None: + grammar = """ + start: + | inner + | 'x' { "ok" } + inner: + | 'x' ~ 'y' + | 'x' + """ + parser_class = make_parser(grammar) + node = parse_string("x", parser_class) + self.assertEqual(node, 'ok') + + def test_cut_is_local_in_parens(self) -> None: + # we currently don't guarantee 
this behavior, see gh-143054 + grammar = """ + start: + | ('x' ~ 'y' | 'x') + | 'x' { "ok" } + """ + parser_class = make_parser(grammar) + node = parse_string("x", parser_class) + self.assertEqual(node, 'ok') + + def test_dangling_reference(self) -> None: + grammar = """ + start: foo ENDMARKER + foo: bar NAME + """ + with self.assertRaises(GrammarError): + parser_class = make_parser(grammar) + + def test_bad_token_reference(self) -> None: + grammar = """ + start: foo + foo: NAMEE + """ + with self.assertRaises(GrammarError): + parser_class = make_parser(grammar) + + def test_missing_start(self) -> None: + grammar = """ + foo: NAME + """ + with self.assertRaises(GrammarError): + parser_class = make_parser(grammar) + + def test_invalid_rule_name(self) -> None: + grammar = """ + start: _a b + _a: 'a' + b: 'b' + """ + with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_a'"): + parser_class = make_parser(grammar) + + def test_invalid_variable_name(self) -> None: + grammar = """ + start: a b + a: _x='a' + b: 'b' + """ + with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_x'"): + parser_class = make_parser(grammar) + + def test_invalid_variable_name_in_temporal_rule(self) -> None: + grammar = """ + start: a b + a: (_x='a' | 'b') | 'c' + b: 'b' + """ + with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_x'"): + parser_class = make_parser(grammar) + + def test_soft_keyword(self) -> None: + grammar = """ + start: + | "number" n=NUMBER { eval(n.string) } + | "string" n=STRING { n.string } + | SOFT_KEYWORD l=NAME n=(NUMBER | NAME | STRING) { l.string + " = " + n.string } + """ + parser_class = make_parser(grammar) + self.assertEqual(parse_string("number 1", parser_class), 1) + self.assertEqual(parse_string("string 'b'", parser_class), "'b'") + self.assertEqual( + parse_string("number test 1", parser_class), "test = 1" + ) + assert ( + parse_string("string test 'b'", parser_class) == "test = 'b'" + ) + with 
self.assertRaises(SyntaxError): + parse_string("test 1", parser_class) + + def test_forced(self) -> None: + grammar = """ + start: NAME &&':' | NAME + """ + parser_class = make_parser(grammar) + self.assertTrue(parse_string("number :", parser_class)) + with self.assertRaises(SyntaxError) as e: + parse_string("a", parser_class) + + self.assertIn("expected ':'", str(e.exception)) + + def test_forced_with_group(self) -> None: + grammar = """ + start: NAME &&(':' | ';') | NAME + """ + parser_class = make_parser(grammar) + self.assertTrue(parse_string("number :", parser_class)) + self.assertTrue(parse_string("number ;", parser_class)) + with self.assertRaises(SyntaxError) as e: + parse_string("a", parser_class) + self.assertIn("expected (':' | ';')", e.exception.args[0]) + + def test_unreachable_explicit(self) -> None: + source = """ + start: NAME { UNREACHABLE } + """ + grammar = parse_string(source, GrammarParser) + out = io.StringIO() + genr = PythonParserGenerator( + grammar, out, unreachable_formatting="This is a test" + ) + genr.generate("") + self.assertIn("This is a test", out.getvalue()) + + def test_unreachable_implicit1(self) -> None: + source = """ + start: NAME | invalid_input + invalid_input: NUMBER { None } + """ + grammar = parse_string(source, GrammarParser) + out = io.StringIO() + genr = PythonParserGenerator( + grammar, out, unreachable_formatting="This is a test" + ) + genr.generate("") + self.assertIn("This is a test", out.getvalue()) + + def test_unreachable_implicit2(self) -> None: + source = """ + start: NAME | '(' invalid_input ')' + invalid_input: NUMBER { None } + """ + grammar = parse_string(source, GrammarParser) + out = io.StringIO() + genr = PythonParserGenerator( + grammar, out, unreachable_formatting="This is a test" + ) + genr.generate("") + self.assertIn("This is a test", out.getvalue()) + + def test_unreachable_implicit3(self) -> None: + source = """ + start: NAME | invalid_input { None } + invalid_input: NUMBER + """ + grammar = 
parse_string(source, GrammarParser) + out = io.StringIO() + genr = PythonParserGenerator( + grammar, out, unreachable_formatting="This is a test" + ) + genr.generate("") + self.assertNotIn("This is a test", out.getvalue()) + + def test_locations_in_alt_action_and_group(self) -> None: + grammar = """ + start: t=term NEWLINE? $ { ast.Expression(t) } + term: + | l=term '*' r=factor { ast.BinOp(l, ast.Mult(), r, LOCATIONS) } + | l=term '/' r=factor { ast.BinOp(l, ast.Div(), r, LOCATIONS) } + | factor + factor: + | ( + n=NAME { ast.Name(id=n.string, ctx=ast.Load(), LOCATIONS) } | + n=NUMBER { ast.Constant(value=ast.literal_eval(n.string), LOCATIONS) } + ) + """ + parser_class = make_parser(grammar) + source = "2*3\n" + o = ast.dump(parse_string(source, parser_class).body, include_attributes=True) + p = ast.dump(ast.parse(source).body[0].value, include_attributes=True).replace( + " kind=None,", "" + ) + diff = "\n".join( + difflib.unified_diff( + o.split("\n"), p.split("\n"), "cpython", "python-pegen" + ) + ) + self.assertFalse(diff) + + +class TestGrammarVisitor: + class Visitor(GrammarVisitor): + def __init__(self) -> None: + self.n_nodes = 0 + + def visit(self, node: Any, *args: Any, **kwargs: Any) -> None: + self.n_nodes += 1 + super().visit(node, *args, **kwargs) + + def test_parse_trivial_grammar(self) -> None: + grammar = """ + start: 'a' + """ + rules = parse_string(grammar, GrammarParser) + visitor = self.Visitor() + + visitor.visit(rules) + + self.assertEqual(visitor.n_nodes, 6) + + def test_parse_or_grammar(self) -> None: + grammar = """ + start: rule + rule: 'a' | 'b' + """ + rules = parse_string(grammar, GrammarParser) + visitor = self.Visitor() + + visitor.visit(rules) + + # Grammar/Rule/Rhs/Alt/NamedItem/NameLeaf -> 6 + # Rule/Rhs/ -> 2 + # Alt/NamedItem/StringLeaf -> 3 + # Alt/NamedItem/StringLeaf -> 3 + + self.assertEqual(visitor.n_nodes, 14) + + def test_parse_repeat1_grammar(self) -> None: + grammar = """ + start: 'a'+ + """ + rules = 
parse_string(grammar, GrammarParser) + visitor = self.Visitor() + + visitor.visit(rules) + + # Grammar/Rule/Rhs/Alt/NamedItem/Repeat1/StringLeaf -> 6 + self.assertEqual(visitor.n_nodes, 7) + + def test_parse_repeat0_grammar(self) -> None: + grammar = """ + start: 'a'* + """ + rules = parse_string(grammar, GrammarParser) + visitor = self.Visitor() + + visitor.visit(rules) + + # Grammar/Rule/Rhs/Alt/NamedItem/Repeat0/StringLeaf -> 6 + + self.assertEqual(visitor.n_nodes, 7) + + def test_parse_optional_grammar(self) -> None: + grammar = """ + start: 'a' ['b'] + """ + rules = parse_string(grammar, GrammarParser) + visitor = self.Visitor() + + visitor.visit(rules) + + # Grammar/Rule/Rhs/Alt/NamedItem/StringLeaf -> 6 + # NamedItem/Opt/Rhs/Alt/NamedItem/Stringleaf -> 6 + + self.assertEqual(visitor.n_nodes, 12) + + +class TestGrammarVisualizer(unittest.TestCase): + def test_simple_rule(self) -> None: + grammar = """ + start: 'a' 'b' + """ + rules = parse_string(grammar, GrammarParser) + + printer = ASTGrammarPrinter() + lines: List[str] = [] + printer.print_grammar_ast(rules, printer=lines.append) + + output = "\n".join(lines) + expected_output = textwrap.dedent( + """\ + └──Rule + └──Rhs + └──Alt + ├──NamedItem + │ └──StringLeaf("'a'") + └──NamedItem + └──StringLeaf("'b'") + """ + ) + + self.assertEqual(output, expected_output) + + def test_multiple_rules(self) -> None: + grammar = """ + start: a b + a: 'a' + b: 'b' + """ + rules = parse_string(grammar, GrammarParser) + + printer = ASTGrammarPrinter() + lines: List[str] = [] + printer.print_grammar_ast(rules, printer=lines.append) + + output = "\n".join(lines) + expected_output = textwrap.dedent( + """\ + └──Rule + └──Rhs + └──Alt + ├──NamedItem + │ └──NameLeaf('a') + └──NamedItem + └──NameLeaf('b') + + └──Rule + └──Rhs + └──Alt + └──NamedItem + └──StringLeaf("'a'") + + └──Rule + └──Rhs + └──Alt + └──NamedItem + └──StringLeaf("'b'") + """ + ) + + self.assertEqual(output, expected_output) + + def test_deep_nested_rule(self) 
-> None: + grammar = """ + start: 'a' ['b'['c'['d']]] + """ + rules = parse_string(grammar, GrammarParser) + + printer = ASTGrammarPrinter() + lines: List[str] = [] + printer.print_grammar_ast(rules, printer=lines.append) + + output = "\n".join(lines) + expected_output = textwrap.dedent( + """\ + └──Rule + └──Rhs + └──Alt + ├──NamedItem + │ └──StringLeaf("'a'") + └──NamedItem + └──Opt + └──Rhs + └──Alt + ├──NamedItem + │ └──StringLeaf("'b'") + └──NamedItem + └──Opt + └──Rhs + └──Alt + ├──NamedItem + │ └──StringLeaf("'c'") + └──NamedItem + └──Opt + └──Rhs + └──Alt + └──NamedItem + └──StringLeaf("'d'") + """ + ) + + self.assertEqual(output, expected_output) diff --git a/Lib/test/test_pep646_syntax.py b/Lib/test/test_pep646_syntax.py index ca8e7d62057..8034bb9e935 100644 --- a/Lib/test/test_pep646_syntax.py +++ b/Lib/test/test_pep646_syntax.py @@ -312,7 +312,7 @@ >>> f4.__annotations__ {'args': StarredB, 'arg1': } - >>> def f5(*args: *b = (1,)): pass # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE + >>> def f5(*args: *b = (1,)): pass # TODO: RUSTPYTHON # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... 
SyntaxError: invalid syntax @@ -320,17 +320,9 @@ __test__ = {'doctests' : doctests} -EXPECTED_FAILURE = doctest.register_optionflag('EXPECTED_FAILURE') # TODO: RUSTPYTHON -class CustomOutputChecker(doctest.OutputChecker): # TODO: RUSTPYTHON - def check_output(self, want, got, optionflags): # TODO: RUSTPYTHON - if optionflags & EXPECTED_FAILURE: # TODO: RUSTPYTHON - if want == got: # TODO: RUSTPYTHON - return False # TODO: RUSTPYTHON - return True # TODO: RUSTPYTHON - return super().check_output(want, got, optionflags) # TODO: RUSTPYTHON - def load_tests(loader, tests, pattern): - tests.addTest(doctest.DocTestSuite(checker=CustomOutputChecker())) # TODO: RUSTPYTHON + from test.support.rustpython import DocTestChecker # TODO: RUSTPYTHON + tests.addTest(doctest.DocTestSuite(checker=DocTestChecker())) # TODO: RUSTPYTHON return tests diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index c9d4a348448..bfb0c173374 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -775,7 +775,6 @@ def test_reverse_name_mapping(self): module, name = mapping(module, name) self.assertEqual((module, name), (module3, name3)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_exceptions(self): self.assertEqual(mapping('exceptions', 'StandardError'), ('builtins', 'Exception')) diff --git a/Lib/test/test_pkg.py b/Lib/test/test_pkg.py index eed0fd1c6b7..d2b724db40d 100644 --- a/Lib/test/test_pkg.py +++ b/Lib/test/test_pkg.py @@ -94,7 +94,7 @@ def mkhier(self, descr): def test_1(self): hier = [("t1", None), ("t1 __init__.py", "")] self.mkhier(hier) - import t1 + import t1 # noqa: F401 def test_2(self): hier = [ @@ -124,7 +124,7 @@ def test_2(self): from t2 import sub from t2.sub import subsub - from t2.sub.subsub import spam + from t2.sub.subsub import spam # noqa: F401 self.assertEqual(sub.__name__, "t2.sub") self.assertEqual(subsub.__name__, "t2.sub.subsub") self.assertEqual(sub.subsub.__name__, "t2.sub.subsub") @@ -190,7 +190,6 @@ def test_5(self): ] 
self.mkhier(hier) - import t5 s = """ from t5 import * self.assertEqual(dir(), ['foo', 'self', 'string', 't5']) diff --git a/Lib/test/test_positional_only_arg.py b/Lib/test/test_positional_only_arg.py index 1817592ca25..eea0625012d 100644 --- a/Lib/test/test_positional_only_arg.py +++ b/Lib/test/test_positional_only_arg.py @@ -23,7 +23,6 @@ def assertRaisesSyntaxError(self, codestr, regex="invalid syntax"): with self.assertRaisesRegex(SyntaxError, regex): compile(codestr + "\n", "", "single") - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_invalid_syntax_errors(self): check_syntax_error(self, "def f(a, b = 5, /, c): pass", "parameter without a default follows parameter with a default") check_syntax_error(self, "def f(a = 5, b, /, c): pass", "parameter without a default follows parameter with a default") @@ -46,7 +45,6 @@ def test_invalid_syntax_errors(self): check_syntax_error(self, "def f(a, /, c, /, d, *, e): pass") check_syntax_error(self, "def f(a, *, c, /, d, e): pass") - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_invalid_syntax_errors_async(self): check_syntax_error(self, "async def f(a, b = 5, /, c): pass", "parameter without a default follows parameter with a default") check_syntax_error(self, "async def f(a = 5, b, /, c): pass", "parameter without a default follows parameter with a default") @@ -235,7 +233,6 @@ def test_lambdas(self): x = lambda a, b, /, : a + b self.assertEqual(x(1, 2), 3) - @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message def test_invalid_syntax_lambda(self): check_syntax_error(self, "lambda a, b = 5, /, c: None", "parameter without a default follows parameter with a default") check_syntax_error(self, "lambda a = 5, b, /, c: None", "parameter without a default follows parameter with a default") diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py index f00f24204b4..c4788f47a06 100644 --- a/Lib/test/test_py_compile.py +++ 
b/Lib/test/test_py_compile.py @@ -132,7 +132,6 @@ def test_exceptions_propagate(self): finally: os.chmod(self.directory, mode.st_mode) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_coding(self): bad_coding = os.path.join(os.path.dirname(__file__), 'tokenizedata', @@ -198,7 +197,6 @@ def test_invalidation_mode(self): fp.read(), 'test', {}) self.assertEqual(flags, 0b1) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_quiet(self): bad_coding = os.path.join(os.path.dirname(__file__), 'tokenizedata', diff --git a/Lib/test/test_pyclbr.py b/Lib/test/test_pyclbr.py index 9e7a67ebee5..1c0d8619227 100644 --- a/Lib/test/test_pyclbr.py +++ b/Lib/test/test_pyclbr.py @@ -177,7 +177,6 @@ def test_easy(self): "DocTestCase", '_DocTestSuite')) self.checkModule('difflib', ignore=("Match",)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_cases(self): # see test.pyclbr_input for the rationale behind the ignored symbols self.checkModule('test.pyclbr_input', ignore=['om', 'f']) diff --git a/Lib/test/test_rlcompleter.py b/Lib/test/test_rlcompleter.py index ffadfee2763..6db31df891b 100644 --- a/Lib/test/test_rlcompleter.py +++ b/Lib/test/test_rlcompleter.py @@ -48,8 +48,7 @@ def test_global_matches(self): self.assertEqual(self.completer.global_matches('CompleteM'), ['CompleteMe(' if MISSING_C_DOCSTRINGS else 'CompleteMe()']) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_attr_matches(self): # test with builtins namespace self.assertEqual(self.stdcompleter.attr_matches('str.s'), @@ -90,7 +89,7 @@ def create_expected_for_none(): ['CompleteMe._ham']) matches = self.completer.attr_matches('CompleteMe.__') for x in matches: - self.assertTrue(x.startswith('CompleteMe.__'), x) + self.assertStartsWith(x, 'CompleteMe.__') self.assertIn('CompleteMe.__name__', matches) self.assertIn('CompleteMe.__new__(', matches) diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index de36b4525c5..d88f2c598f1 100644 --- 
a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -363,7 +363,7 @@ class C(object): gc.collect() self.assertTrue(ref() is None, "Cycle was not collected") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.skipIf("RUSTPYTHON_SKIP_ENV_POLLUTERS" in __import__("os").environ, "TODO: RUSTPYTHON") def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, self.thetype) diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index 62c80aab4b3..fb1a7d876a6 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -2110,8 +2110,6 @@ def test_make_zipfile_rootdir_nodir(self): def check_unpack_archive(self, format, **kwargs): self.check_unpack_archive_with_converter( format, lambda path: path, **kwargs) - self.check_unpack_archive_with_converter( - format, FakePath, **kwargs) self.check_unpack_archive_with_converter(format, FakePath, **kwargs) def check_unpack_archive_with_converter(self, format, converter, **kwargs): @@ -2168,6 +2166,71 @@ def test_unpack_archive_zip(self): with self.assertRaises(TypeError): self.check_unpack_archive('zip', filter='data') + def test_unpack_archive_zip_badpaths(self): + srcdir = self.mkdtemp() + zipname = os.path.join(srcdir, 'test.zip') + abspath = os.path.join(srcdir, 'abspath') + with zipfile.ZipFile(zipname, 'w') as zf: + zf.writestr(abspath, 'badfile') + zf.writestr(os.sep + abspath, 'badfile') + zf.writestr('/abspath', 'badfile') + zf.writestr('C:/abspath', 'badfile') + zf.writestr('D:\\abspath', 'badfile') + zf.writestr('E:abspath', 'badfile') + zf.writestr('F:/G:/abspath', 'badfile') + zf.writestr('//server/share/abspath', 'badfile') + zf.writestr('\\\\server2\\share\\abspath', 'badfile') + zf.writestr('../relpath', 'badfile') + zf.writestr(os.pardir + os.sep + 'relpath2', 'badfile') + zf.writestr('good/file', 'goodfile') + zf.writestr('good..file', 'goodfile') + + dstdir = os.path.join(self.mkdtemp(), 'dst') + unpack_archive(zipname, dstdir) + 
self.assertTrue(os.path.isfile(os.path.join(dstdir, 'good', 'file'))) + self.assertTrue(os.path.isfile(os.path.join(dstdir, 'good..file'))) + self.assertFalse(os.path.exists(abspath)) + self.assertFalse(os.path.exists(os.path.join(dstdir, 'abspath'))) + self.assertFalse(os.path.exists(os.path.join(dstdir, 'G_'))) + self.assertFalse(os.path.exists(os.path.join(dstdir, 'server'))) + if os.name != 'nt': + self.assertTrue(os.path.isfile(os.path.join(dstdir, 'C:', 'abspath'))) + self.assertTrue(os.path.isfile(os.path.join(dstdir, 'D:\\abspath'))) + self.assertTrue(os.path.isfile(os.path.join(dstdir, 'E:abspath'))) + self.assertTrue(os.path.isfile(os.path.join(dstdir, 'F:', 'G:', 'abspath'))) + self.assertTrue(os.path.isfile(os.path.join(dstdir, '\\\\server2\\share\\abspath'))) + if os.pardir == '..': + self.assertFalse(os.path.exists(os.path.join(dstdir, '..', 'relpath'))) + self.assertFalse(os.path.exists(os.path.join(dstdir, 'relpath'))) + else: + self.assertTrue(os.path.isfile(os.path.join(dstdir, '..', 'relpath'))) + self.assertFalse(os.path.exists(os.path.join(dstdir, os.pardir, 'relpath2'))) + self.assertFalse(os.path.exists(os.path.join(dstdir, 'relpath2'))) + + dstdir2 = os.path.join(self.mkdtemp(), 'dst') + os.mkdir(dstdir2) + with os_helper.change_cwd(dstdir2): + unpack_archive(zipname, '') + self.assertTrue(os.path.isfile(os.path.join('good', 'file'))) + self.assertTrue(os.path.isfile('good..file')) + self.assertFalse(os.path.exists(abspath)) + self.assertFalse(os.path.exists('abspath')) + self.assertFalse(os.path.exists('C_')) + self.assertFalse(os.path.exists('server')) + if os.name != 'nt': + self.assertTrue(os.path.isfile(os.path.join('C:', 'abspath'))) + self.assertTrue(os.path.isfile('D:\\abspath')) + self.assertTrue(os.path.isfile('E:abspath')) + self.assertTrue(os.path.isfile(os.path.join('F:', 'G:', 'abspath'))) + self.assertTrue(os.path.isfile('\\\\server2\\share\\abspath')) + if os.pardir == '..': + 
self.assertFalse(os.path.exists(os.path.join('..', 'relpath'))) + self.assertFalse(os.path.exists('relpath')) + else: + self.assertTrue(os.path.isfile(os.path.join('..', 'relpath'))) + self.assertFalse(os.path.exists(os.path.join(os.pardir, 'relpath2'))) + self.assertFalse(os.path.exists('relpath2')) + def test_unpack_registry(self): formats = get_unpack_formats() diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py index 6de7e73c399..c35a2293f79 100644 --- a/Lib/test/test_slice.py +++ b/Lib/test/test_slice.py @@ -286,7 +286,6 @@ def test_deepcopy(self): self.assertIsNot(s.stop, c.stop) self.assertIsNot(s.step, c.step) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_cycle(self): class myobj(): pass o = myobj() diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 6235c8e74cf..2ca356606b2 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -43,7 +43,7 @@ def receive(sock, n, timeout=test.support.SHORT_TIMEOUT): raise RuntimeError("timed out on %r" % (sock,)) -@test.support.requires_fork() # TODO: RUSTPYTHON, os.fork is currently only supported on Unix-based systems +@test.support.requires_fork() @contextlib.contextmanager def simple_subprocess(testcase): """Tests that a custom child process is not waited on (Issue 1540386)""" @@ -218,12 +218,16 @@ def test_ForkingUDPServer(self): self.dgram_examine) @requires_unix_sockets + @unittest.skipIf(test.support.is_apple_mobile and test.support.on_github_actions, + "gh-140702: Test fails regularly on iOS simulator on GitHub Actions") def test_UnixDatagramServer(self): self.run_server(socketserver.UnixDatagramServer, socketserver.DatagramRequestHandler, self.dgram_examine) @requires_unix_sockets + @unittest.skipIf(test.support.is_apple_mobile and test.support.on_github_actions, + "gh-140702: Test fails regularly on iOS simulator on GitHub Actions") def test_ThreadingUnixDatagramServer(self): self.run_server(socketserver.ThreadingUnixDatagramServer, 
socketserver.DatagramRequestHandler, diff --git a/Lib/test/test_sort.py b/Lib/test/test_sort.py index dfb050a34ec..3a6f4f1c05b 100644 --- a/Lib/test/test_sort.py +++ b/Lib/test/test_sort.py @@ -326,7 +326,7 @@ def test_safe_object_compare(self): for L in float_int_lists: check_against_PyObject_RichCompareBool(self, L) - @support.cpython_only # XXX RUSTPYTHON: added by us but it seems like an implementation detail + @unittest.skip("TODO: RUSTPYTHON; not really a todo, it seems like an implementation detail") def test_unsafe_object_compare(self): # This test is by ppperry. It ensures that unsafe_object_compare is diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 8962fe00ed8..2174e14e7cb 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -1444,7 +1444,6 @@ def test_blob_get_item_error_bigint(self): with self.assertRaisesRegex(IndexError, "cannot fit 'int'"): self.blob[_testcapi.ULLONG_MAX] - @unittest.expectedFailure # TODO: RUSTPYTHON def test_blob_set_item_error(self): with self.assertRaisesRegex(TypeError, "cannot be interpreted"): self.blob[0] = b"multiple" diff --git a/Lib/test/test_str.py b/Lib/test/test_str.py index a4859643578..baef294285f 100644 --- a/Lib/test/test_str.py +++ b/Lib/test/test_str.py @@ -2607,7 +2607,7 @@ def test_compare(self): self.assertTrue(astral >= bmp2) self.assertFalse(astral >= astral2) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: False is not true + @unittest.skip("TODO: RUSTPYTHON; hangs") def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, str) if not support.Py_GIL_DISABLED: diff --git a/Lib/test/test_strtod.py b/Lib/test/test_strtod.py index 03c8afa51ef..f263b7ab4f1 100644 --- a/Lib/test/test_strtod.py +++ b/Lib/test/test_strtod.py @@ -173,7 +173,6 @@ def test_halfway_cases(self): s = '{}e{}'.format(digits, exponent) self.check_strtod(s) - @unittest.expectedFailure # TODO: RUSTPYTHON def 
test_boundaries(self): # boundaries expressed as triples (n, e, u), where # n*10**e is an approximation to the boundary value and @@ -194,7 +193,6 @@ def test_boundaries(self): u *= 10 e -= 1 - @unittest.expectedFailure # TODO: RUSTPYTHON def test_underflow_boundary(self): # test values close to 2**-1075, the underflow boundary; similar # to boundary_tests, except that the random error doesn't scale @@ -206,7 +204,6 @@ def test_underflow_boundary(self): s = '{}e{}'.format(digits, exponent) self.check_strtod(s) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bigcomp(self): for ndigs in 5, 10, 14, 15, 16, 17, 18, 19, 20, 40, 41, 50: dig10 = 10**ndigs @@ -284,7 +281,6 @@ def negative_exp(n): self.assertEqual(float(negative_exp(20000)), 1.0) self.assertEqual(float(negative_exp(30000)), 1.0) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_particular(self): # inputs that produced crashes or incorrectly rounded results with # previous versions of dtoa.c, for various reasons diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index a9fe193028e..8ef6dd2fee8 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -1,6 +1,12 @@ +import copy +import gc import os +import pickle +import re +import textwrap import time import unittest +from test.support import script_helper class StructSeqTest(unittest.TestCase): @@ -37,7 +43,7 @@ def test_repr(self): # os.stat() gives a complicated struct sequence. 
st = os.stat(__file__) rep = repr(st) - self.assertTrue(rep.startswith("os.stat_result")) + self.assertStartsWith(rep, "os.stat_result") self.assertIn("st_mode=", rep) self.assertIn("st_ino=", rep) self.assertIn("st_dev=", rep) @@ -81,6 +87,7 @@ def test_fields(self): self.assertEqual(t.n_unnamed_fields, 0) self.assertEqual(t.n_fields, time._STRUCT_TM_ITEMS) + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument dict def test_constructor(self): t = time.struct_time @@ -89,10 +96,72 @@ def test_constructor(self): self.assertRaises(TypeError, t, "123") self.assertRaises(TypeError, t, "123", dict={}) self.assertRaises(TypeError, t, "123456789", dict=None) + self.assertRaises(TypeError, t, seq="123456789", dict={}) + + self.assertEqual(t("123456789"), tuple("123456789")) + self.assertEqual(t("123456789", {}), tuple("123456789")) + self.assertEqual(t("123456789", dict={}), tuple("123456789")) + self.assertEqual(t(sequence="123456789", dict={}), tuple("123456789")) + + self.assertEqual(t("1234567890"), tuple("123456789")) + self.assertEqual(t("1234567890").tm_zone, "0") + self.assertEqual(t("123456789", {"tm_zone": "some zone"}), tuple("123456789")) + self.assertEqual(t("123456789", {"tm_zone": "some zone"}).tm_zone, "some zone") s = "123456789" self.assertEqual("".join(t(s)), s) + @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message + def test_constructor_with_duplicate_fields(self): + t = time.struct_time + + error_message = re.escape("got duplicate or unexpected field name(s)") + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"tm_zone": "some zone"}) + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"tm_zone": "some zone", "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"error": 0, "tm_zone": "some zone"}) + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"error": 0, "tm_zone": "some 
zone", "tm_mon": 1}) + + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 1 arguments, got 2 + def test_constructor_with_duplicate_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + n_visible_fields = os.stat_result.n_sequence_fields + + r = os.stat_result(range(n_visible_fields), {'st_atime': -1.0}) + self.assertEqual(r.st_atime, -1.0) + self.assertEqual(r, tuple(range(n_visible_fields))) + + r = os.stat_result((*range(n_visible_fields), -1.0)) + self.assertEqual(r.st_atime, -1.0) + self.assertEqual(r, tuple(range(n_visible_fields))) + + with self.assertRaisesRegex(TypeError, + re.escape("got duplicate or unexpected field name(s)")): + os.stat_result((*range(n_visible_fields), -1.0), {'st_atime': -1.0}) + + @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message + def test_constructor_with_unknown_fields(self): + t = time.struct_time + + error_message = re.escape("got duplicate or unexpected field name(s)") + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_year": 0}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_year": 0, "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_zone": "some zone", "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_zone": "some zone", "error": 0}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"error": 0, "tm_zone": "some zone", "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"error": 0}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_zone": "some zone", "error": 0}) + def test_eviltuple(self): class Exc(Exception): pass @@ -106,9 +175,80 @@ def __len__(self): self.assertRaises(Exc, time.struct_time, C()) - def test_reduce(self): + def test_pickling(self): t = time.gmtime() - x = t.__reduce__() + for proto in 
range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(t, proto) + t2 = pickle.loads(p) + self.assertEqual(t2.__class__, t.__class__) + self.assertEqual(t2, t) + self.assertEqual(t2.tm_year, t.tm_year) + self.assertEqual(t2.tm_zone, t.tm_zone) + + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 1 arguments, got 2 + def test_pickling_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + r = os.stat_result(range(os.stat_result.n_sequence_fields), + {'st_atime': 1.0, 'st_atime_ns': 2.0}) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(r, proto) + r2 = pickle.loads(p) + self.assertEqual(r2.__class__, r.__class__) + self.assertEqual(r2, r) + self.assertEqual(r2.st_mode, r.st_mode) + self.assertEqual(r2.st_atime, r.st_atime) + self.assertEqual(r2.st_atime_ns, r.st_atime_ns) + + def test_copying(self): + n_fields = time.struct_time.n_fields + t = time.struct_time([[i] for i in range(n_fields)]) + + t2 = copy.copy(t) + self.assertEqual(t2.__class__, t.__class__) + self.assertEqual(t2, t) + self.assertEqual(t2.tm_year, t.tm_year) + self.assertEqual(t2.tm_zone, t.tm_zone) + self.assertIs(t2[0], t[0]) + self.assertIs(t2.tm_year, t.tm_year) + + t3 = copy.deepcopy(t) + self.assertEqual(t3.__class__, t.__class__) + self.assertEqual(t3, t) + self.assertEqual(t3.tm_year, t.tm_year) + self.assertEqual(t3.tm_zone, t.tm_zone) + self.assertIsNot(t3[0], t[0]) + self.assertIsNot(t3.tm_year, t.tm_year) + + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 1 arguments, got 2 + def test_copying_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + n_sequence_fields = os.stat_result.n_sequence_fields + r = os.stat_result([[i] for i in range(n_sequence_fields)], + {'st_atime': [1.0], 'st_atime_ns': [2.0]}) + + r2 = copy.copy(r) + self.assertEqual(r2.__class__, r.__class__) + self.assertEqual(r2, r) + self.assertEqual(r2.st_mode, r.st_mode) + self.assertEqual(r2.st_atime, 
r.st_atime) + self.assertEqual(r2.st_atime_ns, r.st_atime_ns) + self.assertIs(r2[0], r[0]) + self.assertIs(r2.st_mode, r.st_mode) + self.assertIs(r2.st_atime, r.st_atime) + self.assertIs(r2.st_atime_ns, r.st_atime_ns) + + r3 = copy.deepcopy(r) + self.assertEqual(r3.__class__, r.__class__) + self.assertEqual(r3, r) + self.assertEqual(r3.st_mode, r.st_mode) + self.assertEqual(r3.st_atime, r.st_atime) + self.assertEqual(r3.st_atime_ns, r.st_atime_ns) + self.assertIsNot(r3[0], r[0]) + self.assertIsNot(r3.st_mode, r.st_mode) + self.assertIsNot(r3.st_atime, r.st_atime) + self.assertIsNot(r3.st_atime_ns, r.st_atime_ns) def test_extended_getslice(self): # Test extended slicing by comparing with list slicing. @@ -133,6 +273,104 @@ def test_match_args_with_unnamed_fields(self): self.assertEqual(os.stat_result.n_unnamed_fields, 3) self.assertEqual(os.stat_result.__match_args__, expected_args) + def test_copy_replace_all_fields_visible(self): + assert os.times_result.n_unnamed_fields == 0 + assert os.times_result.n_sequence_fields == os.times_result.n_fields + + t = os.times() + + # visible fields + self.assertEqual(copy.replace(t), t) + self.assertIsInstance(copy.replace(t), os.times_result) + self.assertEqual(copy.replace(t, user=1.5), (1.5, *t[1:])) + self.assertEqual(copy.replace(t, system=2.5), (t[0], 2.5, *t[2:])) + self.assertEqual(copy.replace(t, user=1.5, system=2.5), (1.5, 2.5, *t[2:])) + + # unknown fields + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, error=-1) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, user=1, error=-1) + + @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message + def test_copy_replace_with_invisible_fields(self): + assert time.struct_time.n_unnamed_fields == 0 + assert time.struct_time.n_sequence_fields < time.struct_time.n_fields + + t = time.gmtime(0) + + # visible fields + t2 = copy.replace(t) + self.assertEqual(t2, (1970, 1, 1, 0, 0, 0, 3, 1, 0)) + 
self.assertIsInstance(t2, time.struct_time) + t3 = copy.replace(t, tm_year=2000) + self.assertEqual(t3, (2000, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t3.tm_year, 2000) + t4 = copy.replace(t, tm_mon=2) + self.assertEqual(t4, (1970, 2, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t4.tm_mon, 2) + t5 = copy.replace(t, tm_year=2000, tm_mon=2) + self.assertEqual(t5, (2000, 2, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t5.tm_year, 2000) + self.assertEqual(t5.tm_mon, 2) + + # named invisible fields + self.assertHasAttr(t, 'tm_zone') + with self.assertRaisesRegex(AttributeError, 'readonly attribute'): + t.tm_zone = 'some other zone' + self.assertEqual(t2.tm_zone, t.tm_zone) + self.assertEqual(t3.tm_zone, t.tm_zone) + self.assertEqual(t4.tm_zone, t.tm_zone) + t6 = copy.replace(t, tm_zone='some other zone') + self.assertEqual(t, t6) + self.assertEqual(t6.tm_zone, 'some other zone') + t7 = copy.replace(t, tm_year=2000, tm_zone='some other zone') + self.assertEqual(t7, (2000, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t7.tm_year, 2000) + self.assertEqual(t7.tm_zone, 'some other zone') + + # unknown fields + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, error=2) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, tm_year=2000, error=2) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, tm_zone='some other zone', error=2) + + def test_copy_replace_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + r = os.stat_result(range(os.stat_result.n_sequence_fields)) + + error_message = re.escape('__replace__() is not supported') + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, st_mode=1) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, error=2) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, st_mode=1, error=2) 
+ + def test_reference_cycle(self): + # gh-122527: Check that a structseq that's part of a reference cycle + # with its own type doesn't crash. Previously, if the type's dictionary + # was cleared first, the structseq instance would crash in the + # destructor. + script_helper.assert_python_ok("-c", textwrap.dedent(r""" + import time + t = time.gmtime() + type(t).refcyle = t + """)) + + def test_replace_gc_tracked(self): + # Verify that __replace__ results are properly GC-tracked + time_struct = time.gmtime(0) + lst = [] + replaced_struct = time_struct.__replace__(tm_year=lst) + lst.append(replaced_struct) + + self.assertTrue(gc.is_tracked(replaced_struct)) if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 37b5543badf..6b3aa466d06 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -560,16 +560,28 @@ def test_args_from_interpreter_flags(self): # -X options ['-X', 'dev'], ['-Wignore', '-X', 'dev'], + ['-X', 'cpu_count=4'], + ['-X', 'disable-remote-debug'], ['-X', 'faulthandler'], ['-X', 'importtime'], ['-X', 'importtime=2'], + ['-X', 'int_max_str_digits=1000'], + ['-X', 'lazy_imports=all'], + ['-X', 'no_debug_ranges'], ['-X', 'showrefcount'], ['-X', 'tracemalloc'], ['-X', 'tracemalloc=3'], + ['-X', 'warn_default_encoding'], ): with self.subTest(opts=opts): self.check_options(opts, 'args_from_interpreter_flags') + with os_helper.temp_dir() as temp_path: + prefix = os.path.join(temp_path, 'pycache') + opts = ['-X', f'pycache_prefix={prefix}'] + with self.subTest(opts=opts): + self.check_options(opts, 'args_from_interpreter_flags') + self.check_options(['-I', '-E', '-s', '-P'], 'args_from_interpreter_flags', ['-I']) @@ -781,6 +793,7 @@ def test_get_signal_name(self): (128 + int(signal.SIGABRT), 'SIGABRT'), (3221225477, "STATUS_ACCESS_VIOLATION"), (0xC00000FD, "STATUS_STACK_OVERFLOW"), + (0xC0000906, "0xC0000906"), ): self.assertEqual(support.get_signal_name(exitcode), expected, 
exitcode) diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py index 98246ac2214..0934f22d470 100644 --- a/Lib/test/test_syntax.py +++ b/Lib/test/test_syntax.py @@ -59,11 +59,23 @@ Traceback (most recent call last): SyntaxError: cannot assign to __debug__ +>>> def __debug__(): pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> async def __debug__(): pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> class __debug__: pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + >>> del __debug__ Traceback (most recent call last): SyntaxError: cannot delete __debug__ ->>> f() = 1 +>>> f() = 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? @@ -71,11 +83,11 @@ Traceback (most recent call last): SyntaxError: assignment to yield expression not possible ->>> del f() +>>> del f() # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot delete function call ->>> a + 1 = 2 +>>> a + 1 = 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to expression here. Maybe you meant '==' instead of '='? @@ -108,7 +120,7 @@ This test just checks a couple of cases rather than enumerating all of them. 
->>> (a, "b", c) = (1, 2, 3) +>>> (a, "b", c) = (1, 2, 3) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to literal @@ -156,6 +168,18 @@ Traceback (most recent call last): SyntaxError: expected 'else' after 'if' expression +>>> x = 1 if 1 else pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: expected expression after 'else', but statement is given + +>>> x = pass if 1 else 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: expected expression before 'if', but statement is given + +>>> x = pass if 1 else pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: expected expression before 'if', but statement is given + >>> if True: ... print("Hello" ... @@ -176,15 +200,15 @@ Traceback (most recent call last): SyntaxError: assignment to yield expression not possible ->>> a, b += 1, 2 +>>> a, b += 1, 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'tuple' is an illegal expression for augmented assignment ->>> (a, b) += 1, 2 +>>> (a, b) += 1, 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'tuple' is an illegal expression for augmented assignment ->>> [a, b] += 1, 2 +>>> [a, b] += 1, 2 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'list' is an illegal expression for augmented assignment @@ -219,7 +243,7 @@ Traceback (most recent call last): SyntaxError: cannot assign to expression ->>> for i < (): pass +>>> for i < (): pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax @@ -261,11 +285,11 
@@ Comprehensions without 'in' keyword: ->>> [x for x if range(1)] +>>> [x for x if range(1)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'in' expected after for-loop variables ->>> tuple(x for x if range(1)) +>>> tuple(x for x if range(1)) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'in' expected after for-loop variables @@ -277,7 +301,7 @@ Traceback (most recent call last): SyntaxError: cannot assign to expression ->>> [x for a, b, (c + 1, d()) if y] +>>> [x for a, b, (c + 1, d()) if y] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'in' expected after for-loop variables @@ -292,14 +316,20 @@ Comprehensions creating tuples without parentheses should produce a specialized error message: ->>> [x,y for x,y in range(100)] +>>> [x,y for x,y in range(100)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: did you forget parentheses around the comprehension target? ->>> {x,y for x,y in range(100)} +>>> {x,y for x,y in range(100)} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: did you forget parentheses around the comprehension target? +# Incorrectly closed strings + +>>> "The interesting object "The important object" is very important" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: invalid syntax. Is this intended to be part of the string? 
+ # Missing commas in literals collections should not # produce special error messages regarding missing # parentheses, but about missing commas instead @@ -323,7 +353,7 @@ # Make sure soft keywords constructs don't raise specialized # errors regarding missing commas or other spezialiced errors ->>> match x: +>>> match x: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... y = 3 Traceback (most recent call last): SyntaxError: invalid syntax @@ -340,7 +370,7 @@ Traceback (most recent call last): SyntaxError: invalid syntax ->>> match ...: +>>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... case {**rest, "key": value}: ... ... Traceback (most recent call last): @@ -355,7 +385,7 @@ # But prefixes of soft keywords should # still raise specialized errors ->>> (mat x) +>>> (mat x) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax. Perhaps you forgot a comma? @@ -383,7 +413,7 @@ Traceback (most recent call last): SyntaxError: invalid syntax ->>> def f(*None): +>>> def f(*None): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: invalid syntax @@ -438,12 +468,12 @@ Traceback (most recent call last): SyntaxError: var-positional argument cannot have default value ->>> def foo(a,**b=3): +>>> def foo(a,**b=3): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: var-keyword argument cannot have default value ->>> def foo(a,**b: int=3): +>>> def foo(a,**b: int=3): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... 
pass Traceback (most recent call last): SyntaxError: var-keyword argument cannot have default value @@ -493,22 +523,22 @@ Traceback (most recent call last): SyntaxError: * argument may appear only once ->>> def foo(a=1,/*,b,c): +>>> def foo(a=1,/*,b,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected comma between / and * ->>> def foo(a=1,d=,c): +>>> def foo(a=1,d=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected default value expression ->>> def foo(a,d=,c): +>>> def foo(a,d=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected default value expression ->>> def foo(a,d: int=,c): +>>> def foo(a,d: int=,c): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: expected default value expression @@ -541,7 +571,7 @@ Traceback (most recent call last): SyntaxError: / must be ahead of * ->>> lambda a=1,/*,b,c: None +>>> lambda a=1,/*,b,c: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expected comma between / and * @@ -549,7 +579,7 @@ Traceback (most recent call last): SyntaxError: var-positional argument cannot have default value ->>> lambda a,**b=3: None +>>> lambda a,**b=3: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: var-keyword argument cannot have default value @@ -589,11 +619,11 @@ Traceback (most recent call last): SyntaxError: * argument may appear only once ->>> lambda a=1,d=,c: None +>>> lambda a=1,d=,c: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expected default value expression ->>> lambda a,d=,c: None 
+>>> lambda a,d=,c: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expected default value expression @@ -605,14 +635,13 @@ Traceback (most recent call last): SyntaxError: parameter without a default follows parameter with a default ->>> # TODO: RUSTPYTHON ->>> import ast; ast.parse(''' # doctest: +SKIP +>>> import ast; ast.parse(''' ... def f( ... *, # type: int ... a, # type: int ... ): ... pass -... ''', type_comments=True) +... ''', type_comments=True) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: bare * has associated type comment @@ -655,8 +684,7 @@ SyntaxError: Generator expression must be parenthesized >>> f((x for x in L), 1) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ->>> # TODO: RUSTPYTHON ->>> class C(x for x in L): # doctest: +SKIP +>>> class C(x for x in L): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: invalid syntax @@ -756,7 +784,7 @@ ... 290, 291, 292, 293, 294, 295, 296, 297, 298, 299) # doctest: +ELLIPSIS (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ..., 297, 298, 299) ->>> f(lambda x: x[0] = 3) +>>> f(lambda x: x[0] = 3) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? @@ -768,75 +796,76 @@ The grammar accepts any test (basically, any expression) in the keyword slot of a call site. Test a few different options. ->>> f(x()=2) +>>> f(x()=2) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? 
->>> f(a or b=1) +>>> f(a or b=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> f(x.y=1) +>>> f(x.y=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> # TODO: RUSTPYTHON ->>> f((x)=2) # doctest: +SKIP +>>> f((x)=2) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression cannot contain assignment, perhaps you meant "=="? ->>> f(True=1) +>>> f(True=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to True ->>> f(False=1) +>>> f(False=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to False ->>> f(None=1) +>>> f(None=1) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to None >>> f(__debug__=1) Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> # TODO: RUSTPYTHON ->>> __debug__: int # doctest: +SKIP +>>> __debug__: int Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> f(a=) +>>> x.__debug__: int +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ +>>> f(a=) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expected argument value expression ->>> f(a, b, c=) +>>> f(a, b, c=) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expected argument value expression ->>> f(a, b, c=, d) +>>> f(a, b, c=, d) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE 
Traceback (most recent call last): SyntaxError: expected argument value expression ->>> f(*args=[0]) +>>> f(*args=[0]) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to iterable argument unpacking ->>> f(a, b, *args=[0]) +>>> f(a, b, *args=[0]) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to iterable argument unpacking ->>> f(**kwargs={'a': 1}) +>>> f(**kwargs={'a': 1}) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to keyword argument unpacking ->>> f(a, b, *args, **kwargs={'a': 1}) +>>> f(a, b, *args, **kwargs={'a': 1}) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to keyword argument unpacking More set_context(): ->>> (x for x in x) += 1 +>>> (x for x in x) += 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'generator expression' is an illegal expression for augmented assignment ->>> None += 1 +>>> None += 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: 'None' is an illegal expression for augmented assignment >>> __debug__ += 1 Traceback (most recent call last): SyntaxError: cannot assign to __debug__ ->>> f() += 1 -Traceback (most recent call last): +>>> f() += 1 # TODO: RUSTPYTHON; Raises an exception # doctest: +SKIP +Traceback (most recent call last): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE SyntaxError: 'function call' is an illegal expression for augmented assignment -Test continue in finally in weird combinations. +Test control flow in finally continue in for loop under finally should be ok. 
@@ -850,51 +879,63 @@ >>> test() 9 -continue in a finally should be ok. +break in for loop under finally should be ok. >>> def test(): - ... for abc in range(10): - ... try: - ... pass - ... finally: - ... continue - ... print(abc) + ... try: + ... pass + ... finally: + ... for abc in range(10): + ... break + ... print(abc) >>> test() - 9 + 0 + +return in function under finally should be ok. >>> def test(): - ... for abc in range(10): - ... try: - ... pass - ... finally: - ... try: - ... continue - ... except: - ... pass - ... print(abc) + ... try: + ... pass + ... finally: + ... def f(): + ... return 42 + ... print(f()) >>> test() - 9 + 42 + +combine for loop and function def + +return in function under finally should be ok. >>> def test(): - ... for abc in range(10): - ... try: - ... pass - ... finally: - ... try: - ... pass - ... except: - ... continue - ... print(abc) + ... try: + ... pass + ... finally: + ... for i in range(10): + ... def f(): + ... return 42 + ... print(f()) >>> test() - 9 + 42 + + >>> def test(): + ... try: + ... pass + ... finally: + ... def f(): + ... for i in range(10): + ... return 42 + ... print(f()) + >>> test() + 42 A continue outside loop should not be allowed. >>> def foo(): ... try: - ... pass - ... finally: ... continue + ... finally: + ... pass Traceback (most recent call last): ... SyntaxError: 'continue' not properly in loop @@ -914,6 +955,18 @@ ... SyntaxError: 'break' outside loop +elif can't come after an else. + + >>> if a % 2 == 0: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + ... else: + ... pass + ... elif a % 2 == 1: + ... pass + Traceback (most recent call last): + ... + SyntaxError: 'elif' block follows an 'else' block + Misuse of the nonlocal and global statement can lead to a few unique syntax errors. >>> def f(): @@ -960,7 +1013,7 @@ ... 
SyntaxError: name 'x' is parameter and nonlocal - >>> def f(): + >>> def f(): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... global x ... nonlocal x Traceback (most recent call last): @@ -992,7 +1045,7 @@ a complex 'if' (one with 'elif') would fail to notice an invalid suite, leading to spurious errors. - >>> if 1: + >>> if 1: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... x() = 1 ... elif 1: ... pass @@ -1000,7 +1053,7 @@ ... SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? - >>> if 1: + >>> if 1: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... elif 1: ... x() = 1 @@ -1008,7 +1061,7 @@ ... SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? - >>> if 1: + >>> if 1: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... x() = 1 ... elif 1: ... pass @@ -1018,7 +1071,7 @@ ... SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? - >>> if 1: + >>> if 1: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... elif 1: ... x() = 1 @@ -1028,7 +1081,7 @@ ... SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? - >>> if 1: + >>> if 1: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... elif 1: ... pass @@ -1155,7 +1208,7 @@ >>> with block ad something: ... pass Traceback (most recent call last): - SyntaxError: invalid syntax + SyntaxError: invalid syntax. Did you mean 'and'? >>> try ... pass @@ -1193,22 +1246,40 @@ Traceback (most recent call last): SyntaxError: expected ':' - >>> if x = 3: + >>> match x: + ... case a, __debug__, b: + ... pass + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> match x: + ... case a, b, *__debug__: + ... pass + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> match x: + ... 
case Foo(a, __debug__=1, b=2): + ... pass + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> if x = 3: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? - >>> while x = 3: + >>> while x = 3: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? - >>> if x.a = 3: + >>> if x.a = 3: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: cannot assign to attribute here. Maybe you meant '==' instead of '='? - >>> while x.a = 3: + >>> while x.a = 3: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: cannot assign to attribute here. Maybe you meant '==' instead of '='? @@ -1242,39 +1313,39 @@ Parenthesized arguments in function definitions - >>> def f(x, (y, z), w): + >>> def f(x, (y, z), w): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> def f((x, y, z, w)): + >>> def f((x, y, z, w)): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> def f(x, (y, z, w)): + >>> def f(x, (y, z, w)): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> def f((x, y, z), w): + >>> def f((x, y, z), w): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... 
pass Traceback (most recent call last): SyntaxError: Function parameters cannot be parenthesized - >>> lambda x, (y, z), w: None + >>> lambda x, (y, z), w: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized - >>> lambda (x, y, z, w): None + >>> lambda (x, y, z, w): None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized - >>> lambda x, (y, z, w): None + >>> lambda x, (y, z, w): None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized - >>> lambda (x, y, z), w: None + >>> lambda (x, y, z), w: None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Lambda expression parameters cannot be parenthesized @@ -1286,6 +1357,15 @@ Traceback (most recent call last): SyntaxError: expected 'except' or 'finally' block +Custom error message for __debug__ as exception variable + + >>> try: + ... pass + ... except TypeError as __debug__: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + Custom error message for try block mixing except and except* >>> try: @@ -1328,6 +1408,53 @@ Traceback (most recent call last): SyntaxError: cannot have both 'except' and 'except*' on the same 'try' +Better error message for using `except as` with not a name: + + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + ... except TypeError as obj.attr: + ... 
pass + Traceback (most recent call last): + SyntaxError: cannot use except statement with attribute + + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + ... except TypeError as obj[1]: + ... pass + Traceback (most recent call last): + SyntaxError: cannot use except statement with subscript + + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + ... except* TypeError as (obj, name): + ... pass + Traceback (most recent call last): + SyntaxError: cannot use except* statement with tuple + + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + ... except* TypeError as 1: + ... pass + Traceback (most recent call last): + SyntaxError: cannot use except* statement with literal + +Regression tests for gh-133999: + + >>> try: pass + ... except TypeError as name: raise from None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + SyntaxError: invalid syntax + + >>> try: pass + ... except* TypeError as name: raise from None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + SyntaxError: invalid syntax + + >>> match 1: + ... 
case 1 | 2 as abc: raise from None # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + SyntaxError: invalid syntax + Ensure that early = are not matched by the parser as invalid comparisons >>> f(2, 4, x=34); 1 $ 2 Traceback (most recent call last): @@ -1341,33 +1468,33 @@ Traceback (most recent call last): SyntaxError: invalid syntax - >>> dict(x=34, x=1, y=2); x $ y + >>> dict(x=34, x=1, y=2); x $ y # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax Incomplete dictionary literals - >>> {1:2, 3:4, 5} + >>> {1:2, 3:4, 5} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: ':' expected after dictionary key - >>> {1:2, 3:4, 5:} + >>> {1:2, 3:4, 5:} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression expected after dictionary key and ':' - >>> {1: *12+1, 23: 1} + >>> {1: *12+1, 23: 1} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot use a starred expression in a dictionary value - >>> {1: *12+1} + >>> {1: *12+1} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot use a starred expression in a dictionary value - >>> {1: 23, 1: *12+1} + >>> {1: 23, 1: *12+1} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot use a starred expression in a dictionary value - >>> {1:} + >>> {1:} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: expression expected after dictionary key and ':' @@ -1379,7 +1506,7 @@ # Ensure that the error is not raised for invalid expressions - >>> {1: 2, 3: foo(,), 4: 5} + >>> {1: 2, 3: foo(,), 4: 5} # 
TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax @@ -1389,48 +1516,48 @@ Specialized indentation errors: - >>> while condition: + >>> while condition: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'while' statement on line 1 - >>> for x in range(10): + >>> for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'for' statement on line 1 - >>> for x in range(10): + >>> for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... else: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'else' statement on line 3 - >>> async for x in range(10): + >>> async for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'for' statement on line 1 - >>> async for x in range(10): + >>> async for x in range(10): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... else: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'else' statement on line 3 - >>> if something: + >>> if something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'if' statement on line 1 - >>> if something: + >>> if something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... elif something_else: ... 
pass Traceback (most recent call last): IndentationError: expected an indented block after 'elif' statement on line 3 - >>> if something: + >>> if something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass ... elif something_else: ... pass @@ -1439,33 +1566,33 @@ Traceback (most recent call last): IndentationError: expected an indented block after 'else' statement on line 5 - >>> try: + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'try' statement on line 1 - >>> try: + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... something() ... except: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'except' statement on line 3 - >>> try: + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... something() ... except A: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'except' statement on line 3 - >>> try: + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... something() ... except* A: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'except*' statement on line 3 - >>> try: + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... something() ... except A: ... pass @@ -1474,7 +1601,7 @@ Traceback (most recent call last): IndentationError: expected an indented block after 'finally' statement on line 5 - >>> try: + >>> try: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... something() ... except* A: ... pass @@ -1483,68 +1610,81 @@ Traceback (most recent call last): IndentationError: expected an indented block after 'finally' statement on line 5 - >>> with A: + >>> with A: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... 
pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> with A as a, B as b: + >>> with A as a, B as b: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> with (A as a, B as b): + >>> with (A as a, B as b): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> async with A: + >>> async with A: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> async with A as a, B as b: + >>> async with A as a, B as b: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> async with (A as a, B as b): + >>> async with (A as a, B as b): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'with' statement on line 1 - >>> def foo(x, /, y, *, z=2): + >>> def foo(x, /, y, *, z=2): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after function definition on line 1 - >>> def foo[T](x, /, y, *, z=2): + >>> def foo[T](x, /, y, *, z=2): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after function definition on line 1 - >>> class Blech(A): + >>> class Blech(A): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... 
pass Traceback (most recent call last): IndentationError: expected an indented block after class definition on line 1 - >>> class Blech[T](A): + >>> class Blech[T](A): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after class definition on line 1 - >>> match something: + >>> class C(__debug__=42): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> class Meta(type): + ... def __new__(*args, **kwargs): + ... pass + + >>> class C(metaclass=Meta, __debug__=42): # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... pass + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> match something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'match' statement on line 1 - >>> match something: + >>> match something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... case []: ... pass Traceback (most recent call last): IndentationError: expected an indented block after 'case' statement on line 2 - >>> match something: + >>> match something: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... case []: ... ... ... case {}: @@ -1553,40 +1693,24 @@ IndentationError: expected an indented block after 'case' statement on line 4 Make sure that the old "raise X, Y[, Z]" form is gone: - >>> raise X, Y + >>> raise X, Y # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: invalid syntax - >>> raise X, Y, Z + >>> raise X, Y, Z # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... 
SyntaxError: invalid syntax Check that an multiple exception types with missing parentheses -raise a custom exception - - >>> # TODO: RUSTPYTHON - >>> try: # doctest: +SKIP - ... pass - ... except A, B: - ... pass - Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized - - >>> # TODO: RUSTPYTHON - >>> try: # doctest: +SKIP - ... pass - ... except A, B, C: - ... pass - Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized +raise a custom exception only when using 'as' >>> try: ... pass ... except A, B, C as blech: ... pass Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized + SyntaxError: multiple exception types must be parenthesized when using 'as' >>> try: ... pass @@ -1595,29 +1719,15 @@ ... finally: ... pass Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized + SyntaxError: multiple exception types must be parenthesized when using 'as' - >>> try: - ... pass - ... except* A, B: - ... pass - Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized - - >>> try: - ... pass - ... except* A, B, C: - ... pass - Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized - >>> try: ... pass ... except* A, B, C as blech: ... pass Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized + SyntaxError: multiple exception types must be parenthesized when using 'as' >>> try: ... pass @@ -1626,7 +1736,7 @@ ... finally: ... 
pass Traceback (most recent call last): - SyntaxError: multiple exception types must be parenthesized + SyntaxError: multiple exception types must be parenthesized when using 'as' Custom exception for 'except*' without an exception type @@ -1639,45 +1749,255 @@ Traceback (most recent call last): SyntaxError: expected one or more exception types +Check custom exceptions for keywords with typos + +>>> fur a in b: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'for'? + +>>> for a in b: +... pass +... elso: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'else'? + +>>> whille True: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'while'? + +>>> while True: +... pass +... elso: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'else'? + +>>> iff x > 5: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'if'? + +>>> if x: +... pass +... elseif y: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'elif'? + +>>> if x: +... pass +... elif y: +... pass +... elso: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'else'? + +>>> tyo: +... pass +... except y: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'try'? + +>>> classe MyClass: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'class'? + +>>> impor math +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'import'? + +>>> form x import y +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'from'? + +>>> defn calculate_sum(a, b): +... return a + b +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'def'? + +>>> def foo(): +... 
returm result +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'return'? + +>>> lamda x: x ** 2 +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'lambda'? + +>>> def foo(): +... yeld i +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'yield'? + +>>> def foo(): +... globel counter +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'global'? + +>>> frum math import sqrt +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'from'? + +>>> asynch def fetch_data(): +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'async'? + +>>> async def foo(): +... awaid fetch_data() +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'await'? + +>>> raisee ValueError("Error") +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'raise'? + +>>> [ +... x for x +... in range(3) +... of x +... ] +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'if'? + +>>> [ +... 123 fur x +... in range(3) +... if x +... ] +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'for'? + + +>>> for x im n: +... pass +Traceback (most recent call last): +SyntaxError: invalid syntax. Did you mean 'in'? >>> f(a=23, a=234) Traceback (most recent call last): ... SyntaxError: keyword argument repeated: a ->>> {1, 2, 3} = 42 +>>> {1, 2, 3} = 42 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to set display here. Maybe you meant '==' instead of '='? ->>> {1: 2, 3: 4} = 42 +>>> {1: 2, 3: 4} = 42 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to dict literal here. Maybe you meant '==' instead of '='? 
->>> f'{x}' = 42 +>>> f'{x}' = 42 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to f-string expression here. Maybe you meant '==' instead of '='? ->>> f'{x}-{y}' = 42 +>>> f'{x}-{y}' = 42 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: cannot assign to f-string expression here. Maybe you meant '==' instead of '='? ->>> (x, y, z=3, d, e) +>>> ub'' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'b' prefixes are incompatible + +>>> bu"привет" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'b' prefixes are incompatible + +>>> ur'' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'r' prefixes are incompatible + +>>> ru"\t" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'r' prefixes are incompatible + +>>> uf'{1 + 1}' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'f' prefixes are incompatible + +>>> fu"" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'f' prefixes are incompatible + +>>> ut'{1}' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 't' prefixes are incompatible + +>>> tu"234" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 't' prefixes are incompatible + +>>> bf'{x!r}' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 
'b' and 'f' prefixes are incompatible + +>>> fb"text" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'b' and 'f' prefixes are incompatible + +>>> bt"text" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'b' and 't' prefixes are incompatible + +>>> tb'' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'b' and 't' prefixes are incompatible + +>>> tf"{0.3:.02f}" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'f' and 't' prefixes are incompatible + +>>> ft'{x=}' # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'f' and 't' prefixes are incompatible + +>>> tfu"{x=}" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'f' prefixes are incompatible + +>>> turf"{x=}" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'r' prefixes are incompatible + +>>> burft"{x=}" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'u' and 'b' prefixes are incompatible + +>>> brft"{x=}" # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: 'b' and 'f' prefixes are incompatible + +>>> t'{x}' = 42 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot assign to t-string expression here. Maybe you meant '==' instead of '='? 
+ +>>> t'{x}-{y}' = 42 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot assign to t-string expression here. Maybe you meant '==' instead of '='? + +>>> (x, y, z=3, d, e) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> [x, y, z=3, d, e] +>>> [x, y, z=3, d, e] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> [z=3] +>>> [z=3] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> {x, y, z=3, d, e} +>>> {x, y, z=3, d, e} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? ->>> {z=3} +>>> {z=3} # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? @@ -1689,38 +2009,127 @@ Traceback (most recent call last): SyntaxError: trailing comma not allowed without surrounding parentheses ->>> import a from b +>>> import a from b # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z from b.y.z +>>> import a.y.z from b.y.z # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? 
->>> import a from b as bar +>>> import a from b as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z from b.y.z as bar +>>> import a.y.z from b.y.z as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a, b,c from b +>>> import a, b,c from b # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z, b.y.z, c.y.z from b.y.z +>>> import a.y.z, b.y.z, c.y.z from b.y.z # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a,b,c from b as bar +>>> import a,b,c from b as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? ->>> import a.y.z, b.y.z, c.y.z from b.y.z as bar +>>> import a.y.z, b.y.z, c.y.z from b.y.z as bar # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Did you mean to use 'from ... import ...' instead? 
+>>> import __debug__ +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> import a as __debug__ +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> import a.b.c as __debug__ +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> from a import __debug__ +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> from a import b as __debug__ +Traceback (most recent call last): +SyntaxError: cannot assign to __debug__ + +>>> import a as b.c # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use attribute as import target + +>>> import a.b as (a, b) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use tuple as import target + +>>> import a, a.b as 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use literal as import target + +>>> import a.b as 'a', a # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use literal as import target + +>>> from a import (b as c.d) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use attribute as import target + +>>> from a import b as 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use literal as import target + +>>> from a import ( # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +... b as f()) +Traceback (most recent call last): +SyntaxError: cannot use function call as import target + +>>> from a import ( # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +... b as [], +... 
) +Traceback (most recent call last): +SyntaxError: cannot use list as import target + +>>> from a import ( # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +... b, +... c as () +... ) +Traceback (most recent call last): +SyntaxError: cannot use tuple as import target + +>>> from a import b, с as d[e] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use subscript as import target + +>>> from a import с as d[e], b # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE +Traceback (most recent call last): +SyntaxError: cannot use subscript as import target + +# Check that we don't raise a "cannot use name as import target" error +# if there is an error in an unrelated statement after ';' + +>>> import a as b; None = 1 +Traceback (most recent call last): +SyntaxError: cannot assign to None + +>>> import a, b as c; d = 1; None = 1 +Traceback (most recent call last): +SyntaxError: cannot assign to None + +>>> from a import b as c; None = 1 +Traceback (most recent call last): +SyntaxError: cannot assign to None + +>>> from a import b, c as d; e = 1; None = 1 +Traceback (most recent call last): +SyntaxError: cannot assign to None + # Check that we dont raise the "trailing comma" error if there is more # input to the left of the valid part that we parsed. @@ -1824,24 +2233,47 @@ Traceback (most recent call last): SyntaxError: cannot assign to __debug__ - >>> import ä £ + >>> import ä £ # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: invalid character '£' (U+00A3) Invalid pattern matching constructs: - >>> # TODO: RUSTPYTHON - >>> match ...: # doctest: +SKIP + >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... case 42 as _: ... ... 
Traceback (most recent call last): SyntaxError: cannot use '_' as a target - >>> match ...: + >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE ... case 42 as 1+2+4: ... ... Traceback (most recent call last): - SyntaxError: invalid pattern target + SyntaxError: cannot use expression as pattern target + + >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... case 42 as a.b: + ... ... + Traceback (most recent call last): + SyntaxError: cannot use attribute as pattern target + + >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... case 42 as (a, b): + ... ... + Traceback (most recent call last): + SyntaxError: cannot use tuple as pattern target + + >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... case 42 as (a + 1): + ... ... + Traceback (most recent call last): + SyntaxError: cannot use expression as pattern target + + >>> match ...: # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + ... case (32 as x) | (42 as a()): + ... ... + Traceback (most recent call last): + SyntaxError: cannot use function call as pattern target >>> match ...: ... case Foo(z=1, y=2, x): @@ -1875,7 +2307,7 @@ Traceback (most recent call last): ... SyntaxError: invalid syntax - >>> A[:(*b)] + >>> A[:(*b)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: cannot use starred expression here @@ -1894,7 +2326,7 @@ Traceback (most recent call last): ... SyntaxError: invalid syntax - >>> A[(*b):] + >>> A[(*b):] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: cannot use starred expression here @@ -1928,22 +2360,22 @@ A[*(1:2)] - >>> A[*(1:2)] + >>> A[*(1:2)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... 
SyntaxError: Invalid star expression - >>> A[*(1:2)] = 1 + >>> A[*(1:2)] = 1 # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: Invalid star expression - >>> del A[*(1:2)] + >>> del A[*(1:2)] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: Invalid star expression A[*:] and A[:*] - >>> A[*:] + >>> A[*:] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: Invalid star expression @@ -1954,7 +2386,7 @@ A[*] - >>> A[*] + >>> A[*] # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: Invalid star expression @@ -2180,8 +2612,7 @@ def f(x: *b) ... SyntaxError: yield expression cannot be used within a ParamSpec default - >>> # TODO: RUSTPYTHON - >>> type A = (x := 3) # doctest: +SKIP + >>> type A = (x := 3) Traceback (most recent call last): ... SyntaxError: named expression cannot be used within a type alias @@ -2201,22 +2632,30 @@ def f(x: *b) ... SyntaxError: yield expression cannot be used within a type alias - >>> class A[T]((x := 3)): ... + >>> type __debug__ = int + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> class A[__debug__]: pass # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE + Traceback (most recent call last): + SyntaxError: cannot assign to __debug__ + + >>> class A[T]((x := 3)): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: named expression cannot be used within the definition of a generic - >>> class A[T]((yield 3)): ... + >>> class A[T]((yield 3)): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... 
SyntaxError: yield expression cannot be used within the definition of a generic - >>> class A[T]((await 3)): ... + >>> class A[T]((await 3)): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: await expression cannot be used within the definition of a generic - >>> class A[T]((yield from [])): ... + >>> class A[T]((yield from [])): ... # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): ... SyntaxError: yield expression cannot be used within the definition of a generic @@ -2225,23 +2664,23 @@ def f(x: *b) Traceback (most recent call last): SyntaxError: iterable argument unpacking follows keyword argument unpacking - >>> f(**x, *) + >>> f(**x, *) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x, *:) + >>> f(x, *:) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x, *) + >>> f(x, *) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x = 5, *) + >>> f(x = 5, *) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Invalid star expression - >>> f(x = 5, *:) + >>> f(x = 5, *:) # TODO: RUSTPYTHON; Wrong error message # doctest: +EXPECTED_FAILURE Traceback (most recent call last): SyntaxError: Invalid star expression """ @@ -2253,7 +2692,90 @@ def f(x: *b) from test import support -class SyntaxTestCase(unittest.TestCase): +class SyntaxWarningTest(unittest.TestCase): + def check_warning(self, code, errtext, filename="", mode="exec"): + """Check that compiling code raises SyntaxWarning with errtext. + + errtest is a regular expression that must be present in the + text of the warning raised. 
+ """ + with self.assertWarnsRegex(SyntaxWarning, errtext): + compile(code, filename, mode) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxWarning not triggered + def test_return_in_finally(self): + source = textwrap.dedent(""" + def f(): + try: + pass + finally: + return 42 + """) + self.check_warning(source, "'return' in a 'finally' block") + + source = textwrap.dedent(""" + def f(): + try: + pass + finally: + try: + return 42 + except: + pass + """) + self.check_warning(source, "'return' in a 'finally' block") + + source = textwrap.dedent(""" + def f(): + try: + pass + finally: + try: + pass + except: + return 42 + """) + self.check_warning(source, "'return' in a 'finally' block") + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxWarning not triggered + def test_break_and_continue_in_finally(self): + for kw in ('break', 'continue'): + + source = textwrap.dedent(f""" + for abc in range(10): + try: + pass + finally: + {kw} + """) + self.check_warning(source, f"'{kw}' in a 'finally' block") + + source = textwrap.dedent(f""" + for abc in range(10): + try: + pass + finally: + try: + {kw} + except: + pass + """) + self.check_warning(source, f"'{kw}' in a 'finally' block") + + source = textwrap.dedent(f""" + for abc in range(10): + try: + pass + finally: + try: + pass + except: + {kw} + """) + self.check_warning(source, f"'{kw}' in a 'finally' block") + + +class SyntaxErrorTestCase(unittest.TestCase): def _check_error(self, code, errtext, filename="", mode="exec", subclass=None, @@ -2261,7 +2783,7 @@ def _check_error(self, code, errtext, """Check that compiling code raises SyntaxError with errtext. errtest is a regular expression that must be present in the - test of the exception raised. If subclass is specified it + text of the exception raised. If subclass is specified it is the expected subclass of SyntaxError (e.g. IndentationError). 
""" try: @@ -2285,7 +2807,7 @@ def _check_error(self, code, errtext, else: self.fail("compile() did not raise SyntaxError") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_expression_with_assignment(self): self._check_error( "print(end1 + end2 = ' ')", @@ -2299,7 +2821,7 @@ def test_curly_brace_after_primary_raises_immediately(self): def test_assign_call(self): self._check_error("f() = 1", "assign") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_assign_del(self): self._check_error("del (,)", "invalid syntax") self._check_error("del 1", "cannot delete literal") @@ -2390,7 +2912,6 @@ def test_break_outside_loop(self): self._check_error("with object() as obj:\n break", msg, lineno=2) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_continue_outside_loop(self): msg = "not properly in loop" self._check_error("if 0: continue", msg, lineno=1) @@ -2415,36 +2936,32 @@ def test_bad_outdent(self): "unindent does not match .* level", subclass=IndentationError) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_kwargs_last(self): self._check_error("int(base=10, '2')", "positional argument follows keyword argument") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_kwargs_last2(self): self._check_error("int(**{'base': 10}, '2')", "positional argument follows " "keyword argument unpacking") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_kwargs_last3(self): self._check_error("int(**{'base': 10}, *['2'])", "iterable argument unpacking follows " "keyword argument unpacking") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_generator_in_function_call(self): self._check_error("foo(x, y for y in range(3) for z in range(2) if z , p)", "Generator expression must be parenthesized", lineno=1, end_lineno=1, offset=11, end_offset=53) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def 
test_except_then_except_star(self): self._check_error("try: pass\nexcept ValueError: pass\nexcept* TypeError: pass", r"cannot have both 'except' and 'except\*' on the same 'try'", lineno=3, end_lineno=3, offset=1, end_offset=8) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_except_star_then_except(self): self._check_error("try: pass\nexcept* ValueError: pass\nexcept TypeError: pass", r"cannot have both 'except' and 'except\*' on the same 'try'", @@ -2575,7 +3092,6 @@ async def bug(): with self.subTest(f"out of range: {n=}"): self._check_error(get_code(n), "too many statically nested blocks") - @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message def test_barry_as_flufl_with_syntax_errors(self): # The "barry_as_flufl" rule can produce some "bugs-at-a-distance" if # is reading the wrong token in the presence of syntax errors later @@ -2593,7 +3109,7 @@ def func2(): """ self._check_error(code, "expected ':'") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_line_continuation_error_position(self): self._check_error(r"a = 3 \ 4", "unexpected character after line continuation character", @@ -2605,7 +3121,6 @@ def test_invalid_line_continuation_error_position(self): "unexpected character after line continuation character", lineno=3, offset=4) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_line_continuation_left_recursive(self): # Check bpo-42218: SyntaxErrors following left-recursive rules # (t_primary_raw in this case) need to be tested explicitly @@ -2614,7 +3129,7 @@ def test_invalid_line_continuation_left_recursive(self): self._check_error("A.\u03bc\\\n", "unexpected EOF while parsing") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_parenthesis(self): for paren in "([{": self._check_error(paren + "1 + 2", f"\\{paren}' was never closed") @@ -2640,7 +3155,7 @@ def 
test_error_parenthesis(self): s = b'# coding=latin\n(aaaaaaaaaaaaaaaaa\naaaaaaaaaaa\xb5' self._check_error(s, r"'\(' was never closed") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_string_literal(self): self._check_error("'blech", r"unterminated string literal \(.*\)$") @@ -2654,7 +3169,7 @@ def test_error_string_literal(self): self._check_error("'''blech", "unterminated triple-quoted string literal") self._check_error('"""blech', "unterminated triple-quoted string literal") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_invisible_characters(self): self._check_error('print\x17("Hello")', "invalid non-printable character") self._check_error(b"with(0,,):\n\x01", "invalid non-printable character") @@ -2677,7 +3192,6 @@ def case(x): """ compile(code, "", "exec") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_multiline_compiler_error_points_to_the_end(self): self._check_error( "call(\na=1,\na=1\n)", @@ -2730,6 +3244,7 @@ def test_error_on_parser_stack_overflow(self): compile(source, "", mode) @support.cpython_only + @support.skip_wasi_stack_overflow() def test_deep_invalid_rule(self): # Check that a very deep invalid rule in the PEG # parser doesn't have exponential backtracking. @@ -2737,10 +3252,86 @@ def test_deep_invalid_rule(self): with self.assertRaises(SyntaxError): compile(source, "", "exec") + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_except_stmt_invalid_as_expr(self): + self._check_error( + textwrap.dedent( + """ + try: + pass + except ValueError as obj.attr: + pass + """ + ), + errtext="cannot use except statement with attribute", + lineno=4, + end_lineno=4, + offset=22, + end_offset=22 + len("obj.attr"), + ) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_match_stmt_invalid_as_expr(self): + self._check_error( + textwrap.dedent( + """ + match 1: + case x as obj.attr: + ... 
+ """ + ), + errtext="cannot use attribute as pattern target", + lineno=3, + end_lineno=3, + offset=15, + end_offset=15 + len("obj.attr"), + ) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_ifexp_else_stmt(self): + msg = "expected expression after 'else', but statement is given" + + for stmt in [ + "pass", + "return", + "return 2", + "raise Exception('a')", + "del a", + "yield 2", + "assert False", + "break", + "continue", + "import", + "import ast", + "from", + "from ast import *" + ]: + self._check_error(f"x = 1 if 1 else {stmt}", msg) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_ifexp_body_stmt_else_expression(self): + msg = "expected expression before 'if', but statement is given" + + for stmt in [ + "pass", + "break", + "continue" + ]: + self._check_error(f"x = {stmt} if 1 else 1", msg) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_ifexp_body_stmt_else_stmt(self): + msg = "expected expression before 'if', but statement is given" + for lhs_stmt, rhs_stmt in [ + ("pass", "pass"), + ("break", "pass"), + ("continue", "import ast") + ]: + self._check_error(f"x = {lhs_stmt} if 1 else {rhs_stmt}", msg) def load_tests(loader, tests, pattern): - # TODO: RUSTPYTHON Eventually remove the optionflags for ingoring exception details. 
- tests.addTest(doctest.DocTestSuite(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)) + from test.support.rustpython import DocTestChecker # TODO: RUSTPYTHON + tests.addTest(doctest.DocTestSuite(checker=DocTestChecker())) # TODO: RUSTPYTHON return tests diff --git a/Lib/test/test_sys_setprofile.py b/Lib/test/test_sys_setprofile.py index 21a09b51926..813adff2a32 100644 --- a/Lib/test/test_sys_setprofile.py +++ b/Lib/test/test_sys_setprofile.py @@ -30,9 +30,9 @@ def callback(self, frame, event, arg): if (event == "call" or event == "return" or event == "exception"): - self.add_event(event, frame) + self.add_event(event, frame, arg) - def add_event(self, event, frame=None): + def add_event(self, event, frame=None, arg=None): """Add an event to the log.""" if frame is None: frame = sys._getframe(1) @@ -43,7 +43,7 @@ def add_event(self, event, frame=None): frameno = len(self.frames) self.frames.append(frame) - self.events.append((frameno, event, ident(frame))) + self.events.append((frameno, event, ident(frame), arg)) def get_events(self): """Remove calls to add_event().""" @@ -89,11 +89,16 @@ def trace_pass(self, frame): class TestCaseBase(unittest.TestCase): - def check_events(self, callable, expected): + def check_events(self, callable, expected, check_args=False): events = capture_events(callable, self.new_watcher()) - if events != expected: - self.fail("Expected events:\n%s\nReceived events:\n%s" - % (pprint.pformat(expected), pprint.pformat(events))) + if check_args: + if events != expected: + self.fail("Expected events:\n%s\nReceived events:\n%s" + % (pprint.pformat(expected), pprint.pformat(events))) + else: + if [(frameno, event, ident) for frameno, event, ident, arg in events] != expected: + self.fail("Expected events:\n%s\nReceived events:\n%s" + % (pprint.pformat(expected), pprint.pformat(events))) class ProfileHookTestCase(TestCaseBase): @@ -119,7 +124,7 @@ def f(p): def test_caught_exception(self): def f(p): try: 1/0 - except: pass + except ZeroDivisionError: 
pass f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -128,7 +133,7 @@ def f(p): def test_caught_nested_exception(self): def f(p): try: 1/0 - except: pass + except ZeroDivisionError: pass f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -151,9 +156,9 @@ def f(p): def g(p): try: f(p) - except: + except ZeroDivisionError: try: f(p) - except: pass + except ZeroDivisionError: pass f_ident = ident(f) g_ident = ident(g) self.check_events(g, [(1, 'call', g_ident), @@ -164,6 +169,7 @@ def g(p): (1, 'return', g_ident), ]) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_exception_propagation(self): def f(p): 1/0 @@ -182,7 +188,7 @@ def g(p): def test_raise_twice(self): def f(p): try: 1/0 - except: 1/0 + except ZeroDivisionError: 1/0 f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -191,7 +197,7 @@ def f(p): def test_raise_reraise(self): def f(p): try: 1/0 - except: raise + except ZeroDivisionError: raise f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -255,6 +261,24 @@ def g(p): (1, 'return', g_ident), ]) + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_unfinished_generator(self): + def f(): + for i in range(2): + yield i + def g(p): + next(f()) + + f_ident = ident(f) + g_ident = ident(g) + self.check_events(g, [(1, 'call', g_ident, None), + (2, 'call', f_ident, None), + (2, 'return', f_ident, 0), + (2, 'call', f_ident, None), + (2, 'return', f_ident, None), + (1, 'return', g_ident, None), + ], check_args=True) + def test_stop_iteration(self): def f(): for i in range(2): @@ -300,7 +324,7 @@ def f(p): def test_caught_exception(self): def f(p): try: 1/0 - except: pass + except ZeroDivisionError: pass f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -415,5 +439,104 @@ def show_events(callable): pprint.pprint(capture_events(callable)) +class 
TestEdgeCases(unittest.TestCase): + + def setUp(self): + self.addCleanup(sys.setprofile, sys.getprofile()) + sys.setprofile(None) + + def test_reentrancy(self): + def foo(*args): + ... + + def bar(*args): + ... + + class A: + def __call__(self, *args): + pass + + def __del__(self): + sys.setprofile(bar) + + sys.setprofile(A()) + sys.setprofile(foo) + self.assertEqual(sys.getprofile(), bar) + + def test_same_object(self): + def foo(*args): + ... + + sys.setprofile(foo) + del foo + sys.setprofile(sys.getprofile()) + + def test_profile_after_trace_opcodes(self): + def f(): + ... + + sys._getframe().f_trace_opcodes = True + prev_trace = sys.gettrace() + sys.settrace(lambda *args: None) + f() + sys.settrace(prev_trace) + sys.setprofile(lambda *args: None) + f() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_method_with_c_function(self): + # gh-122029 + # When we have a PyMethodObject whose im_func is a C function, we + # should record both the call and the return. f = classmethod(repr) + # is just a way to create a PyMethodObject with a C function. 
+ class A: + f = classmethod(repr) + events = [] + sys.setprofile(lambda frame, event, args: events.append(event)) + A().f() + sys.setprofile(None) + # The last c_call is the call to sys.setprofile + self.assertEqual(events, ['c_call', 'c_return', 'c_call']) + + class B: + f = classmethod(max) + events = [] + sys.setprofile(lambda frame, event, args: events.append(event)) + # Not important, we only want to trigger INSTRUMENTED_CALL_KW + B().f(1, key=lambda x: 0) + sys.setprofile(None) + # The last c_call is the call to sys.setprofile + self.assertEqual( + events, + ['c_call', + 'call', 'return', + 'call', 'return', + 'c_return', + 'c_call' + ] + ) + + # Test CALL_FUNCTION_EX + events = [] + sys.setprofile(lambda frame, event, args: events.append(event)) + # Not important, we only want to trigger INSTRUMENTED_CALL_KW + args = (1,) + m = B().f + m(*args, key=lambda x: 0) + sys.setprofile(None) + # The last c_call is the call to sys.setprofile + # INSTRUMENTED_CALL_FUNCTION_EX has different behavior than the other + # instrumented call bytecodes, it does not unpack the callable before + # calling it. This is probably not ideal because it's not consistent, + # but at least we get a consistent call stack (no unmatched c_call). 
+ self.assertEqual( + events, + ['call', 'return', + 'call', 'return', + 'c_call' + ] + ) + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sys_settrace.py b/Lib/test/test_sys_settrace.py index aa2d54ee16e..7eef1290dc2 100644 --- a/Lib/test/test_sys_settrace.py +++ b/Lib/test/test_sys_settrace.py @@ -1488,8 +1488,6 @@ def test_jump_in_nested_finally_3(output): output.append(11) output.append(12) - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(5, 11, [2, 4], (ValueError, 'after')) def test_no_jump_over_return_try_finally_in_finally_block(output): try: diff --git a/Lib/test/test_syslog.py b/Lib/test/test_syslog.py index b378d62e5cf..9ad2cc51819 100644 --- a/Lib/test/test_syslog.py +++ b/Lib/test/test_syslog.py @@ -43,8 +43,6 @@ def test_setlogmask(self): self.assertEqual(syslog.setlogmask(0), mask) self.assertEqual(syslog.setlogmask(oldmask), mask) - # TODO: RUSTPYTHON; AssertionError: 12 is not false - @unittest.expectedFailure def test_log_mask(self): mask = syslog.LOG_UPTO(syslog.LOG_WARNING) self.assertTrue(mask & syslog.LOG_MASK(syslog.LOG_WARNING)) diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index d6cbb350428..8d9f8824f7c 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -38,8 +38,6 @@ import lzma except ImportError: lzma = None -# XXX: RUSTPYTHON; xz is not supported yet -lzma = None try: from compression import zstd except ImportError: @@ -1236,6 +1234,25 @@ def test_longname_directory(self): self.assertIsNotNone(tar.getmember(longdir)) self.assertIsNotNone(tar.getmember(longdir.removesuffix('/'))) + def test_longname_file_not_directory(self): + # Test reading a longname file and ensure it is not handled as a directory + # Issue #141707 + buf = io.BytesIO() + with tarfile.open(mode='w', fileobj=buf, format=self.format) as tar: + ti = tarfile.TarInfo() + ti.type = tarfile.AREGTYPE + ti.name = ('a' * 99) + '/' + ('b' * 3) + tar.addfile(ti) + + expected = {t.name: t.type for t in 
tar.getmembers()} + + buf.seek(0) + with tarfile.open(mode='r', fileobj=buf) as tar: + actual = {t.name: t.type for t in tar.getmembers()} + + self.assertEqual(expected, actual) + + class GNUReadTest(LongnameTest, ReadTest, unittest.TestCase): subdir = "gnu" diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py index 99052de4c7f..8d752dbb7aa 100644 --- a/Lib/test/test_threading_local.py +++ b/Lib/test/test_threading_local.py @@ -185,7 +185,6 @@ class LocalSubclass(self._local): """To test that subclasses behave properly.""" self._test_dict_attribute(LocalSubclass) - @unittest.expectedFailure # TODO: RUSTPYTHON; cycle detection/collection def test_cycle_collection(self): class X: pass @@ -233,6 +232,10 @@ class ThreadLocalTest(unittest.TestCase, BaseLocalTest): def test_arguments(self): return super().test_arguments() + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_cycle_collection(self): + return super().test_cycle_collection() + class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest): _local = _threading_local.local diff --git a/Lib/test/test_timeout.py b/Lib/test/test_timeout.py index 70a0175d771..967d4ff7e1c 100644 --- a/Lib/test/test_timeout.py +++ b/Lib/test/test_timeout.py @@ -5,9 +5,6 @@ from test import support from test.support import socket_helper -# This requires the 'network' resource as given on the regrtest command line. 
-skip_expected = not support.is_resource_enabled('network') - import time import errno import socket @@ -29,10 +26,8 @@ class CreationTestCase(unittest.TestCase): """Test case for socket.gettimeout() and socket.settimeout()""" def setUp(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - def tearDown(self): - self.sock.close() + self.sock = self.enterContext( + socket.socket(socket.AF_INET, socket.SOCK_STREAM)) def testObjectCreation(self): # Test Socket creation @@ -53,10 +48,10 @@ def testFloatReturnValue(self): def testReturnType(self): # Test return type of gettimeout() self.sock.settimeout(1) - self.assertEqual(type(self.sock.gettimeout()), type(1.0)) + self.assertIs(type(self.sock.gettimeout()), float) self.sock.settimeout(3.9) - self.assertEqual(type(self.sock.gettimeout()), type(1.0)) + self.assertIs(type(self.sock.gettimeout()), float) def testTypeCheck(self): # Test type checking by settimeout() @@ -116,8 +111,6 @@ class TimeoutTestCase(unittest.TestCase): def setUp(self): raise NotImplementedError() - tearDown = setUp - def _sock_operation(self, count, timeout, method, *args): """ Test the specified socket method. @@ -145,19 +138,16 @@ class TCPTimeoutTestCase(TimeoutTestCase): """TCP test case for socket.socket() timeout functions""" def setUp(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock = self.enterContext( + socket.socket(socket.AF_INET, socket.SOCK_STREAM)) self.addr_remote = resolve_address('www.python.org.', 80) - def tearDown(self): - self.sock.close() - - @unittest.skipIf(True, 'need to replace these hosts; see bpo-35518') def testConnectTimeout(self): # Testing connect timeout is tricky: we need to have IP connectivity # to a host that silently drops our packets. We can't simulate this # from Python because it's a function of the underlying TCP/IP stack. 
- # So, the following Snakebite host has been defined: - blackhole = resolve_address('blackhole.snakebite.net', 56666) + # So, the following port on the pythontest.net host has been defined: + blackhole = resolve_address('pythontest.net', 56666) # Blackhole has been configured to silently drop any incoming packets. # No RSTs (for TCP) or ICMP UNREACH (for UDP/ICMP) will be sent back @@ -169,7 +159,7 @@ def testConnectTimeout(self): # to firewalling or general network configuration. In order to improve # our confidence in testing the blackhole, a corresponding 'whitehole' # has also been set up using one port higher: - whitehole = resolve_address('whitehole.snakebite.net', 56667) + whitehole = resolve_address('pythontest.net', 56667) # This address has been configured to immediately drop any incoming # packets as well, but it does it respectfully with regards to the @@ -183,35 +173,27 @@ def testConnectTimeout(self): # timeframe). # For the records, the whitehole/blackhole configuration has been set - # up using the 'pf' firewall (available on BSDs), using the following: + # up using the 'iptables' firewall, using the following rules: # - # ext_if="bge0" - # - # blackhole_ip="35.8.247.6" - # whitehole_ip="35.8.247.6" - # blackhole_port="56666" - # whitehole_port="56667" - # - # block return in log quick on $ext_if proto { tcp udp } \ - # from any to $whitehole_ip port $whitehole_port - # block drop in log quick on $ext_if proto { tcp udp } \ - # from any to $blackhole_ip port $blackhole_port + # -A INPUT -p tcp --destination-port 56666 -j DROP + # -A INPUT -p udp --destination-port 56666 -j DROP + # -A INPUT -p tcp --destination-port 56667 -j REJECT + # -A INPUT -p udp --destination-port 56667 -j REJECT # + # See https://github.com/python/psf-salt/blob/main/pillar/base/firewall/snakebite.sls + # for the current configuration. 
skip = True - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - timeout = support.LOOPBACK_TIMEOUT - sock.settimeout(timeout) - try: - sock.connect((whitehole)) - except TimeoutError: - pass - except OSError as err: - if err.errno == errno.ECONNREFUSED: - skip = False - finally: - sock.close() - del sock + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + try: + timeout = support.LOOPBACK_TIMEOUT + sock.settimeout(timeout) + sock.connect((whitehole)) + except TimeoutError: + pass + except OSError as err: + if err.errno == errno.ECONNREFUSED: + skip = False if skip: self.skipTest( @@ -278,10 +260,8 @@ class UDPTimeoutTestCase(TimeoutTestCase): """UDP test case for socket.socket() timeout functions""" def setUp(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - - def tearDown(self): - self.sock.close() + self.sock = self.enterContext( + socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) def testRecvfromTimeout(self): # Test recvfrom() timeout @@ -292,6 +272,7 @@ def testRecvfromTimeout(self): def setUpModule(): support.requires('network') + support.requires_working_socket(module=True) if __name__ == "__main__": diff --git a/Lib/test/test_tools/i18n_data/ascii-escapes.pot b/Lib/test/test_tools/i18n_data/ascii-escapes.pot index 18d868b6a20..cc5a9f6ba61 100644 --- a/Lib/test/test_tools/i18n_data/ascii-escapes.pot +++ b/Lib/test/test_tools/i18n_data/ascii-escapes.pot @@ -15,30 +15,36 @@ msgstr "" "Generated-By: pygettext.py 1.5\n" +#. Special characters that are always escaped in the POT file #: escapes.py:5 msgid "" "\"\t\n" "\r\\" msgstr "" +#. All ascii characters 0-31 #: escapes.py:8 msgid "" "\000\001\002\003\004\005\006\007\010\t\n" "\013\014\r\016\017\020\021\022\023\024\025\026\027\030\031\032\033\034\035\036\037" msgstr "" +#. All ascii characters 32-126 #: escapes.py:13 msgid " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" msgstr "" +#. 
ascii char 127 #: escapes.py:17 msgid "\177" msgstr "" +#. some characters in the 128-255 range #: escapes.py:20 -msgid "€   ÿ" +msgid "\302\200 \302\240 ÿ" msgstr "" +#. some characters >= 256 encoded as 2, 3 and 4 bytes, respectively #: escapes.py:23 msgid "α ㄱ 𓂀" msgstr "" diff --git a/Lib/test/test_tools/i18n_data/comments.pot b/Lib/test/test_tools/i18n_data/comments.pot new file mode 100644 index 00000000000..a1df46d453c --- /dev/null +++ b/Lib/test/test_tools/i18n_data/comments.pot @@ -0,0 +1,110 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"POT-Creation-Date: 2000-01-01 00:00+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: pygettext.py 1.5\n" + + +#: comments.py:4 +msgid "foo" +msgstr "" + +#. i18n: This is a translator comment +#: comments.py:7 +msgid "bar" +msgstr "" + +#. i18n: This is a translator comment +#. i18n: This is another translator comment +#: comments.py:11 +msgid "baz" +msgstr "" + +#. i18n: This is a translator comment +#. with multiple +#. lines +#: comments.py:16 +msgid "qux" +msgstr "" + +#. i18n: This is a translator comment +#: comments.py:21 +msgid "quux" +msgstr "" + +#. i18n: This is a translator comment +#. with multiple lines +#. i18n: This is another translator comment +#. with multiple lines +#: comments.py:27 +msgid "corge" +msgstr "" + +#: comments.py:31 +msgid "grault" +msgstr "" + +#. i18n: This is another translator comment +#: comments.py:36 +msgid "garply" +msgstr "" + +#: comments.py:40 +msgid "george" +msgstr "" + +#. i18n: This is another translator comment +#: comments.py:45 +msgid "waldo" +msgstr "" + +#. i18n: This is a translator comment +#. i18n: This is also a translator comment +#. 
i18n: This is another translator comment +#: comments.py:50 +msgid "waldo2" +msgstr "" + +#. i18n: This is a translator comment +#. i18n: This is another translator comment +#. i18n: This is yet another translator comment +#. i18n: This is a translator comment +#. with multiple lines +#: comments.py:53 comments.py:56 comments.py:59 comments.py:63 +msgid "fred" +msgstr "" + +#: comments.py:65 +msgid "plugh" +msgstr "" + +#: comments.py:67 +msgid "foobar" +msgstr "" + +#. i18n: This is a translator comment +#: comments.py:71 +msgid "xyzzy" +msgstr "" + +#: comments.py:72 +msgid "thud" +msgstr "" + +#. i18n: This is a translator comment +#. i18n: This is another translator comment +#. i18n: This is yet another translator comment +#: comments.py:78 +msgid "foos" +msgstr "" + diff --git a/Lib/test/test_tools/i18n_data/comments.py b/Lib/test/test_tools/i18n_data/comments.py new file mode 100644 index 00000000000..dca4dfa57b1 --- /dev/null +++ b/Lib/test/test_tools/i18n_data/comments.py @@ -0,0 +1,78 @@ +from gettext import gettext as _ + +# Not a translator comment +_('foo') + +# i18n: This is a translator comment +_('bar') + +# i18n: This is a translator comment +# i18n: This is another translator comment +_('baz') + +# i18n: This is a translator comment +# with multiple +# lines +_('qux') + +# This comment should not be included because +# it does not start with the prefix +# i18n: This is a translator comment +_('quux') + +# i18n: This is a translator comment +# with multiple lines +# i18n: This is another translator comment +# with multiple lines +_('corge') + +# i18n: This comment should be ignored + +_('grault') + +# i18n: This comment should be ignored + +# i18n: This is another translator comment +_('garply') + +# i18n: comment should be ignored +x = 1 +_('george') + +# i18n: This comment should be ignored +x = 1 +# i18n: This is another translator comment +_('waldo') + +# i18n: This is a translator comment +x = 1 # i18n: This is also a translator comment +# 
i18n: This is another translator comment +_('waldo2') + +# i18n: This is a translator comment +_('fred') + +# i18n: This is another translator comment +_('fred') + +# i18n: This is yet another translator comment +_('fred') + +# i18n: This is a translator comment +# with multiple lines +_('fred') + +_('plugh') # i18n: This comment should be ignored + +_('foo' # i18n: This comment should be ignored + 'bar') # i18n: This comment should be ignored + +# i18n: This is a translator comment +_('xyzzy') +_('thud') + + +## i18n: This is a translator comment +# # i18n: This is another translator comment +### ### i18n: This is yet another translator comment +_('foos') diff --git a/Lib/test/test_tools/i18n_data/custom_keywords.pot b/Lib/test/test_tools/i18n_data/custom_keywords.pot new file mode 100644 index 00000000000..03a9cba3a20 --- /dev/null +++ b/Lib/test/test_tools/i18n_data/custom_keywords.pot @@ -0,0 +1,51 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. 
+# +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"POT-Creation-Date: 2000-01-01 00:00+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: pygettext.py 1.5\n" + + +#: custom_keywords.py:10 custom_keywords.py:11 +msgid "bar" +msgstr "" + +#: custom_keywords.py:13 +msgid "cat" +msgid_plural "cats" +msgstr[0] "" +msgstr[1] "" + +#: custom_keywords.py:14 +msgid "dog" +msgid_plural "dogs" +msgstr[0] "" +msgstr[1] "" + +#: custom_keywords.py:16 +msgctxt "context" +msgid "bar" +msgstr "" + +#: custom_keywords.py:18 +msgctxt "context" +msgid "cat" +msgid_plural "cats" +msgstr[0] "" +msgstr[1] "" + +#: custom_keywords.py:34 +msgid "overridden" +msgid_plural "default" +msgstr[0] "" +msgstr[1] "" + diff --git a/Lib/test/test_tools/i18n_data/custom_keywords.py b/Lib/test/test_tools/i18n_data/custom_keywords.py new file mode 100644 index 00000000000..ba0ffe77180 --- /dev/null +++ b/Lib/test/test_tools/i18n_data/custom_keywords.py @@ -0,0 +1,34 @@ +from gettext import ( + gettext as foo, + ngettext as nfoo, + pgettext as pfoo, + npgettext as npfoo, + gettext as bar, + gettext as _, +) + +foo('bar') +foo('bar', 'baz') + +nfoo('cat', 'cats', 1) +nfoo('dog', 'dogs') + +pfoo('context', 'bar') + +npfoo('context', 'cat', 'cats', 1) + +# This is an unknown keyword and should be ignored +bar('baz') + +# 'nfoo' requires at least 2 arguments +nfoo('dog') + +# 'pfoo' requires at least 2 arguments +pfoo('context') + +# 'npfoo' requires at least 3 arguments +npfoo('context') +npfoo('context', 'cat') + +# --keyword should override the default keyword +_('overridden', 'default') diff --git a/Lib/test/test_tools/i18n_data/docstrings.pot b/Lib/test/test_tools/i18n_data/docstrings.pot index 5af1d41422f..387db2413a5 100644 --- a/Lib/test/test_tools/i18n_data/docstrings.pot +++ 
b/Lib/test/test_tools/i18n_data/docstrings.pot @@ -15,26 +15,40 @@ msgstr "" "Generated-By: pygettext.py 1.5\n" -#: docstrings.py:7 +#: docstrings.py:1 +#, docstring +msgid "Module docstring" +msgstr "" + +#: docstrings.py:9 #, docstring msgid "" msgstr "" -#: docstrings.py:18 +#: docstrings.py:15 +#, docstring +msgid "docstring" +msgstr "" + +#: docstrings.py:20 #, docstring msgid "" "multiline\n" -" docstring\n" -" " +"docstring" msgstr "" -#: docstrings.py:25 +#: docstrings.py:27 #, docstring msgid "docstring1" msgstr "" -#: docstrings.py:30 +#: docstrings.py:38 +#, docstring +msgid "nested docstring" +msgstr "" + +#: docstrings.py:43 #, docstring -msgid "Hello, {}!" +msgid "nested class docstring" msgstr "" diff --git a/Lib/test/test_tools/i18n_data/docstrings.py b/Lib/test/test_tools/i18n_data/docstrings.py index 85d7f159d37..151a55a4b56 100644 --- a/Lib/test/test_tools/i18n_data/docstrings.py +++ b/Lib/test/test_tools/i18n_data/docstrings.py @@ -1,3 +1,5 @@ +"""Module docstring""" + # Test docstring extraction from gettext import gettext as _ @@ -10,10 +12,10 @@ def test(x): # Leading empty line def test2(x): - """docstring""" # XXX This should be extracted but isn't. + """docstring""" -# XXX Multiline docstrings should be cleaned with `inspect.cleandoc`. +# Multiline docstrings are cleaned with `inspect.cleandoc`. def test3(x): """multiline docstring @@ -27,15 +29,15 @@ def test4(x): def test5(x): - """Hello, {}!""".format("world!") # XXX This should not be extracted. + """Hello, {}!""".format("world!") # This should not be extracted. # Nested docstrings def test6(x): def inner(y): - """nested docstring""" # XXX This should be extracted but isn't. + """nested docstring""" class Outer: class Inner: - "nested class docstring" # XXX This should be extracted but isn't. 
+ "nested class docstring" diff --git a/Lib/test/test_tools/i18n_data/escapes.pot b/Lib/test/test_tools/i18n_data/escapes.pot index 2c7899d59da..4dfac0f451d 100644 --- a/Lib/test/test_tools/i18n_data/escapes.pot +++ b/Lib/test/test_tools/i18n_data/escapes.pot @@ -15,30 +15,36 @@ msgstr "" "Generated-By: pygettext.py 1.5\n" +#. Special characters that are always escaped in the POT file #: escapes.py:5 msgid "" "\"\t\n" "\r\\" msgstr "" +#. All ascii characters 0-31 #: escapes.py:8 msgid "" "\000\001\002\003\004\005\006\007\010\t\n" "\013\014\r\016\017\020\021\022\023\024\025\026\027\030\031\032\033\034\035\036\037" msgstr "" +#. All ascii characters 32-126 #: escapes.py:13 msgid " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" msgstr "" +#. ascii char 127 #: escapes.py:17 msgid "\177" msgstr "" +#. some characters in the 128-255 range #: escapes.py:20 msgid "\302\200 \302\240 \303\277" msgstr "" +#. some characters >= 256 encoded as 2, 3 and 4 bytes, respectively #: escapes.py:23 msgid "\316\261 \343\204\261 \360\223\202\200" msgstr "" diff --git a/Lib/test/test_tools/i18n_data/messages.pot b/Lib/test/test_tools/i18n_data/messages.pot index ddfbd18349e..e8167acfc07 100644 --- a/Lib/test/test_tools/i18n_data/messages.pot +++ b/Lib/test/test_tools/i18n_data/messages.pot @@ -15,53 +15,85 @@ msgstr "" "Generated-By: pygettext.py 1.5\n" -#: messages.py:5 +#: messages.py:16 msgid "" msgstr "" -#: messages.py:8 messages.py:9 +#: messages.py:19 messages.py:20 messages.py:21 msgid "parentheses" msgstr "" -#: messages.py:12 +#: messages.py:24 msgid "Hello, world!" 
msgstr "" -#: messages.py:15 +#: messages.py:27 msgid "" "Hello,\n" " multiline!\n" msgstr "" -#: messages.py:29 +#: messages.py:46 messages.py:89 messages.py:90 messages.py:93 messages.py:94 +#: messages.py:99 messages.py:100 messages.py:101 +msgid "foo" +msgid_plural "foos" +msgstr[0] "" +msgstr[1] "" + +#: messages.py:47 +msgid "something" +msgstr "" + +#: messages.py:50 msgid "Hello, {}!" msgstr "" -#: messages.py:33 +#: messages.py:54 msgid "1" msgstr "" -#: messages.py:33 +#: messages.py:54 msgid "2" msgstr "" -#: messages.py:34 messages.py:35 +#: messages.py:55 messages.py:56 msgid "A" msgstr "" -#: messages.py:34 messages.py:35 +#: messages.py:55 messages.py:56 msgid "B" msgstr "" -#: messages.py:36 +#: messages.py:57 msgid "set" msgstr "" -#: messages.py:42 +#: messages.py:62 messages.py:63 msgid "nested string" msgstr "" -#: messages.py:47 +#: messages.py:68 msgid "baz" msgstr "" +#: messages.py:71 messages.py:75 +msgid "default value" +msgstr "" + +#: messages.py:91 messages.py:92 messages.py:95 messages.py:96 +msgctxt "context" +msgid "foo" +msgid_plural "foos" +msgstr[0] "" +msgstr[1] "" + +#: messages.py:102 +msgid "domain foo" +msgstr "" + +#: messages.py:118 messages.py:119 +msgid "world" +msgid_plural "worlds" +msgstr[0] "" +msgstr[1] "" + diff --git a/Lib/test/test_tools/i18n_data/messages.py b/Lib/test/test_tools/i18n_data/messages.py index f220294b8d5..9457bcb8611 100644 --- a/Lib/test/test_tools/i18n_data/messages.py +++ b/Lib/test/test_tools/i18n_data/messages.py @@ -1,5 +1,16 @@ # Test message extraction -from gettext import gettext as _ +from gettext import ( + gettext, + ngettext, + pgettext, + npgettext, + dgettext, + dngettext, + dpgettext, + dnpgettext +) + +_ = gettext # Empty string _("") @@ -7,6 +18,7 @@ # Extra parentheses (_("parentheses")) ((_("parentheses"))) +_(("parentheses")) # Multiline strings _("Hello, " @@ -21,13 +33,22 @@ _(None) _(1) _(False) -_(x="kwargs are not allowed") +_(["invalid"]) +_({"invalid"}) +_("string"[3]) 
+_("string"[:3]) +_({"string": "foo"}) + +# pygettext does not allow keyword arguments, but both xgettext and pybabel do +_(x="kwargs are not allowed!") + +# Unusual, but valid arguments _("foo", "bar") _("something", x="something else") # .format() _("Hello, {}!").format("world") # valid -_("Hello, {}!".format("world")) # invalid +_("Hello, {}!".format("world")) # invalid, but xgettext extracts the first string # Nested structures _("1"), _("2") @@ -38,7 +59,7 @@ # Nested functions and classes def test(): - _("nested string") # XXX This should be extracted but isn't. + _("nested string") [_("nested string")] @@ -47,11 +68,11 @@ def bar(self): return _("baz") -def bar(x=_('default value')): # XXX This should be extracted but isn't. +def bar(x=_('default value')): pass -def baz(x=[_('default value')]): # XXX This should be extracted but isn't. +def baz(x=[_('default value')]): pass @@ -62,3 +83,37 @@ def _(x): def _(x="don't extract me"): pass + + +# Other gettext functions +gettext("foo") +ngettext("foo", "foos", 1) +pgettext("context", "foo") +npgettext("context", "foo", "foos", 1) +dgettext("domain", "foo") +dngettext("domain", "foo", "foos", 1) +dpgettext("domain", "context", "foo") +dnpgettext("domain", "context", "foo", "foos", 1) + +# Complex arguments +ngettext("foo", "foos", 42 + (10 - 20)) +ngettext("foo", "foos", *args) +ngettext("foo", "foos", **kwargs) +dgettext(["some", {"complex"}, ("argument",)], "domain foo") + +# Invalid calls which are not extracted +gettext() +ngettext('foo') +pgettext('context') +npgettext('context', 'foo') +dgettext('domain') +dngettext('domain', 'foo') +dpgettext('domain', 'context') +dnpgettext('domain', 'context', 'foo') +dgettext(*args, 'foo') +dpgettext(*args, 'context', 'foo') +dnpgettext(*args, 'context', 'foo', 'foos') + +# f-strings +f"Hello, {_('world')}!" +f"Hello, {ngettext('world', 'worlds', 3)}!" 
diff --git a/Lib/test/test_tools/i18n_data/multiple_keywords.pot b/Lib/test/test_tools/i18n_data/multiple_keywords.pot new file mode 100644 index 00000000000..954cb8e9948 --- /dev/null +++ b/Lib/test/test_tools/i18n_data/multiple_keywords.pot @@ -0,0 +1,38 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"POT-Creation-Date: 2000-01-01 00:00+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: pygettext.py 1.5\n" + + +#: multiple_keywords.py:3 +msgid "bar" +msgstr "" + +#: multiple_keywords.py:5 +msgctxt "baz" +msgid "qux" +msgstr "" + +#: multiple_keywords.py:9 +msgctxt "corge" +msgid "grault" +msgstr "" + +#: multiple_keywords.py:11 +msgctxt "xyzzy" +msgid "foo" +msgid_plural "foos" +msgstr[0] "" +msgstr[1] "" + diff --git a/Lib/test/test_tools/i18n_data/multiple_keywords.py b/Lib/test/test_tools/i18n_data/multiple_keywords.py new file mode 100644 index 00000000000..7bde349505b --- /dev/null +++ b/Lib/test/test_tools/i18n_data/multiple_keywords.py @@ -0,0 +1,11 @@ +from gettext import gettext as foo + +foo('bar') + +foo('baz', 'qux') + +# The 't' specifier is not supported, so the following +# call is extracted as pgettext instead of ngettext. 
+foo('corge', 'grault', 1) + +foo('xyzzy', 'foo', 'foos', 1) diff --git a/Lib/test/test_tools/test_compute_changes.py b/Lib/test/test_tools/test_compute_changes.py new file mode 100644 index 00000000000..b20ff975fc2 --- /dev/null +++ b/Lib/test/test_tools/test_compute_changes.py @@ -0,0 +1,144 @@ +"""Tests to cover the Tools/build/compute-changes.py script.""" + +import importlib +import os +import unittest +from pathlib import Path +from unittest.mock import patch + +from test.test_tools import skip_if_missing, imports_under_tool + +skip_if_missing("build") + +with patch.dict(os.environ, {"GITHUB_DEFAULT_BRANCH": "main"}): + with imports_under_tool("build"): + compute_changes = importlib.import_module("compute-changes") + +process_changed_files = compute_changes.process_changed_files +Outputs = compute_changes.Outputs +ANDROID_DIRS = compute_changes.ANDROID_DIRS +IOS_DIRS = compute_changes.IOS_DIRS +MACOS_DIRS = compute_changes.MACOS_DIRS +WASI_DIRS = compute_changes.WASI_DIRS +RUN_TESTS_IGNORE = compute_changes.RUN_TESTS_IGNORE +UNIX_BUILD_SYSTEM_FILE_NAMES = compute_changes.UNIX_BUILD_SYSTEM_FILE_NAMES +LIBRARY_FUZZER_PATHS = compute_changes.LIBRARY_FUZZER_PATHS + + +class TestProcessChangedFiles(unittest.TestCase): + + def test_windows(self): + f = {Path(".github/workflows/reusable-windows.yml")} + result = process_changed_files(f) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_windows_tests) + + def test_docs(self): + for f in ( + ".github/workflows/reusable-docs.yml", + "Doc/library/datetime.rst", + "Doc/Makefile", + ): + with self.subTest(f=f): + result = process_changed_files({Path(f)}) + self.assertTrue(result.run_docs) + self.assertFalse(result.run_tests) + + def test_ci_fuzz_stdlib(self): + for p in LIBRARY_FUZZER_PATHS: + with self.subTest(p=p): + if p.is_dir(): + f = p / "file" + elif p.is_file(): + f = p + else: + continue + result = process_changed_files({f}) + self.assertTrue(result.run_ci_fuzz_stdlib) + + def test_android(self): + 
for d in ANDROID_DIRS: + with self.subTest(d=d): + result = process_changed_files({Path(d) / "file"}) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_android) + self.assertFalse(result.run_windows_tests) + + def test_ios(self): + for d in IOS_DIRS: + with self.subTest(d=d): + result = process_changed_files({Path(d) / "file"}) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_ios) + self.assertFalse(result.run_windows_tests) + + def test_macos(self): + f = {Path(".github/workflows/reusable-macos.yml")} + result = process_changed_files(f) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_macos) + + for d in MACOS_DIRS: + with self.subTest(d=d): + result = process_changed_files({Path(d) / "file"}) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_macos) + self.assertFalse(result.run_windows_tests) + + def test_wasi(self): + f = {Path(".github/workflows/reusable-wasi.yml")} + result = process_changed_files(f) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_wasi) + + for d in WASI_DIRS: + with self.subTest(d=d): + result = process_changed_files({d / "file"}) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_wasi) + self.assertFalse(result.run_windows_tests) + + def test_unix(self): + for f in UNIX_BUILD_SYSTEM_FILE_NAMES: + with self.subTest(f=f): + result = process_changed_files({f}) + self.assertTrue(result.run_tests) + self.assertFalse(result.run_windows_tests) + + def test_msi(self): + for f in ( + ".github/workflows/reusable-windows-msi.yml", + "Tools/msi/build.bat", + ): + with self.subTest(f=f): + result = process_changed_files({Path(f)}) + self.assertTrue(result.run_windows_msi) + + def test_all_run(self): + for f in ( + ".github/workflows/some-new-workflow.yml", + ".github/workflows/build.yml", + ): + with self.subTest(f=f): + result = process_changed_files({Path(f)}) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_android) + 
self.assertTrue(result.run_ios) + self.assertTrue(result.run_macos) + self.assertTrue(result.run_ubuntu) + self.assertTrue(result.run_wasi) + + def test_all_ignored(self): + for f in RUN_TESTS_IGNORE: + with self.subTest(f=f): + self.assertEqual(process_changed_files({Path(f)}), Outputs()) + + def test_wasi_and_android(self): + f = {Path(".github/workflows/reusable-wasi.yml"), Path("Android/file")} + result = process_changed_files(f) + self.assertTrue(result.run_tests) + self.assertTrue(result.run_wasi) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_tools/test_i18n.py b/Lib/test/test_tools/test_i18n.py index ffa1b1178ed..d1831d68f02 100644 --- a/Lib/test/test_tools/test_i18n.py +++ b/Lib/test/test_tools/test_i18n.py @@ -8,7 +8,7 @@ from pathlib import Path from test.support.script_helper import assert_python_ok -from test.test_tools import skip_if_missing, toolsdir +from test.test_tools import imports_under_tool, skip_if_missing, toolsdir from test.support.os_helper import temp_cwd, temp_dir @@ -17,6 +17,11 @@ DATA_DIR = Path(__file__).resolve().parent / 'i18n_data' +with imports_under_tool("i18n"): + from pygettext import (parse_spec, process_keywords, DEFAULTKEYWORDS, + unparse_spec) + + def normalize_POT_file(pot): """Normalize the POT creation timestamp, charset and file locations to make the POT file easier to compare. 
@@ -87,7 +92,8 @@ def assert_POT_equal(self, expected, actual): self.maxDiff = None self.assertEqual(normalize_POT_file(expected), normalize_POT_file(actual)) - def extract_from_str(self, module_content, *, args=(), strict=True): + def extract_from_str(self, module_content, *, args=(), strict=True, + with_stderr=False, raw=False): """Return all msgids extracted from module_content.""" filename = 'test.py' with temp_cwd(None): @@ -98,12 +104,19 @@ def extract_from_str(self, module_content, *, args=(), strict=True): self.assertEqual(res.err, b'') with open('messages.pot', encoding='utf-8') as fp: data = fp.read() - return self.get_msgids(data) + if not raw: + data = self.get_msgids(data) + if not with_stderr: + return data + return data, res.err def extract_docstrings_from_str(self, module_content): """Return all docstrings extracted from module_content.""" return self.extract_from_str(module_content, args=('--docstrings',), strict=False) + def get_stderr(self, module_content): + return self.extract_from_str(module_content, strict=False, with_stderr=True)[1] + def test_header(self): """Make sure the required fields are in the header, according to: http://www.gnu.org/software/gettext/manual/gettext.html#Header-Entry @@ -149,6 +162,14 @@ def test_POT_Creation_Date(self): # This will raise if the date format does not exactly match. 
datetime.strptime(creationDate, '%Y-%m-%d %H:%M%z') + def test_output_option(self): + for opt in ('-o', '--output='): + with temp_cwd(): + assert_python_ok(self.script, f'{opt}test') + self.assertTrue(os.path.exists('test')) + res = assert_python_ok(self.script, f'{opt}-') + self.assertIn(b'Project-Id-Version: PACKAGE VERSION', res.out) + def test_funcdocstring(self): for doc in ('"""doc"""', "r'''doc'''", "R'doc'", 'u"doc"'): with self.subTest(doc): @@ -332,14 +353,14 @@ def test_calls_in_fstring_with_multiple_args(self): msgids = self.extract_docstrings_from_str(dedent('''\ f"{_('foo', 'bar')}" ''')) - self.assertNotIn('foo', msgids) + self.assertIn('foo', msgids) self.assertNotIn('bar', msgids) def test_calls_in_fstring_with_keyword_args(self): msgids = self.extract_docstrings_from_str(dedent('''\ f"{_('foo', bar='baz')}" ''')) - self.assertNotIn('foo', msgids) + self.assertIn('foo', msgids) self.assertNotIn('bar', msgids) self.assertNotIn('baz', msgids) @@ -400,17 +421,195 @@ def test_files_list(self): self.assertIn(f'msgid "{text2}"', data) self.assertNotIn(text3, data) + def test_help_text(self): + """Test that the help text is displayed.""" + res = assert_python_ok(self.script, '--help') + self.assertEqual(res.out, b'') + self.assertIn(b'pygettext -- Python equivalent of xgettext(1)', res.err) + + def test_error_messages(self): + """Test that pygettext outputs error messages to stderr.""" + stderr = self.get_stderr(dedent('''\ + _(1+2) + ngettext('foo') + dgettext(*args, 'foo') + ''')) + + # Normalize line endings on Windows + stderr = stderr.decode('utf-8').replace('\r', '') + + self.assertEqual( + stderr, + "*** test.py:1: Expected a string constant for argument 1, got 1 + 2\n" + "*** test.py:2: Expected at least 2 positional argument(s) in gettext call, got 1\n" + "*** test.py:3: Variable positional arguments are not allowed in gettext calls\n" + ) + + def test_extract_all_comments(self): + """ + Test that the --add-comments option without an + explicit 
tag extracts all translator comments. + """ + for arg in ('--add-comments', '-c'): + with self.subTest(arg=arg): + data = self.extract_from_str(dedent('''\ + # Translator comment + _("foo") + '''), args=(arg,), raw=True) + self.assertIn('#. Translator comment', data) + + def test_comments_with_multiple_tags(self): + """ + Test that multiple --add-comments tags can be specified. + """ + for arg in ('--add-comments={}', '-c{}'): + with self.subTest(arg=arg): + args = (arg.format('foo:'), arg.format('bar:')) + data = self.extract_from_str(dedent('''\ + # foo: comment + _("foo") + + # bar: comment + _("bar") + + # baz: comment + _("baz") + '''), args=args, raw=True) + self.assertIn('#. foo: comment', data) + self.assertIn('#. bar: comment', data) + self.assertNotIn('#. baz: comment', data) + + def test_comments_not_extracted_without_tags(self): + """ + Test that translator comments are not extracted without + specifying --add-comments. + """ + data = self.extract_from_str(dedent('''\ + # Translator comment + _("foo") + '''), raw=True) + self.assertNotIn('#.', data) + + def test_parse_keyword_spec(self): + valid = ( + ('foo', ('foo', {'msgid': 0})), + ('foo:1', ('foo', {'msgid': 0})), + ('foo:1,2', ('foo', {'msgid': 0, 'msgid_plural': 1})), + ('foo:1, 2', ('foo', {'msgid': 0, 'msgid_plural': 1})), + ('foo:1,2c', ('foo', {'msgid': 0, 'msgctxt': 1})), + ('foo:2c,1', ('foo', {'msgid': 0, 'msgctxt': 1})), + ('foo:2c ,1', ('foo', {'msgid': 0, 'msgctxt': 1})), + ('foo:1,2,3c', ('foo', {'msgid': 0, 'msgid_plural': 1, 'msgctxt': 2})), + ('foo:1, 2, 3c', ('foo', {'msgid': 0, 'msgid_plural': 1, 'msgctxt': 2})), + ('foo:3c,1,2', ('foo', {'msgid': 0, 'msgid_plural': 1, 'msgctxt': 2})), + ) + for spec, expected in valid: + with self.subTest(spec=spec): + self.assertEqual(parse_spec(spec), expected) + # test unparse-parse round-trip + self.assertEqual(parse_spec(unparse_spec(*expected)), expected) + + invalid = ( + ('foo:', "Invalid keyword spec 'foo:': missing argument positions"), 
+ ('foo:bar', "Invalid keyword spec 'foo:bar': position is not an integer"), + ('foo:0', "Invalid keyword spec 'foo:0': argument positions must be strictly positive"), + ('foo:-2', "Invalid keyword spec 'foo:-2': argument positions must be strictly positive"), + ('foo:1,1', "Invalid keyword spec 'foo:1,1': duplicate positions"), + ('foo:1,2,1', "Invalid keyword spec 'foo:1,2,1': duplicate positions"), + ('foo:1c,2,1c', "Invalid keyword spec 'foo:1c,2,1c': duplicate positions"), + ('foo:1c,2,3c', "Invalid keyword spec 'foo:1c,2,3c': msgctxt can only appear once"), + ('foo:1,2,3', "Invalid keyword spec 'foo:1,2,3': too many positions"), + ('foo:1c', "Invalid keyword spec 'foo:1c': msgctxt cannot appear without msgid"), + ) + for spec, message in invalid: + with self.subTest(spec=spec): + with self.assertRaises(ValueError) as cm: + parse_spec(spec) + self.assertEqual(str(cm.exception), message) + + def test_process_keywords(self): + default_keywords = {name: [spec] for name, spec + in DEFAULTKEYWORDS.items()} + inputs = ( + (['foo'], True), + (['_:1,2'], True), + (['foo', 'foo:1,2'], True), + (['foo'], False), + (['_:1,2', '_:1c,2,3', 'pgettext'], False), + # Duplicate entries + (['foo', 'foo'], True), + (['_'], False) + ) + expected = ( + {'foo': [{'msgid': 0}]}, + {'_': [{'msgid': 0, 'msgid_plural': 1}]}, + {'foo': [{'msgid': 0}, {'msgid': 0, 'msgid_plural': 1}]}, + default_keywords | {'foo': [{'msgid': 0}]}, + default_keywords | {'_': [{'msgid': 0, 'msgid_plural': 1}, + {'msgctxt': 0, 'msgid': 1, 'msgid_plural': 2}, + {'msgid': 0}], + 'pgettext': [{'msgid': 0}, + {'msgctxt': 0, 'msgid': 1}]}, + {'foo': [{'msgid': 0}]}, + default_keywords, + ) + for (keywords, no_default_keywords), expected in zip(inputs, expected): + with self.subTest(keywords=keywords, + no_default_keywords=no_default_keywords): + processed = process_keywords( + keywords, + no_default_keywords=no_default_keywords) + self.assertEqual(processed, expected) + + def 
test_multiple_keywords_same_funcname_errors(self): + # If at least one keyword spec for a given funcname matches, + # no error should be printed. + msgids, stderr = self.extract_from_str(dedent('''\ + _("foo", 42) + _(42, "bar") + '''), args=('--keyword=_:1', '--keyword=_:2'), with_stderr=True) + self.assertIn('foo', msgids) + self.assertIn('bar', msgids) + self.assertEqual(stderr, b'') + + # If no keyword spec for a given funcname matches, + # all errors are printed. + msgids, stderr = self.extract_from_str(dedent('''\ + _(x, 42) + _(42, y) + '''), args=('--keyword=_:1', '--keyword=_:2'), with_stderr=True, + strict=False) + self.assertEqual(msgids, ['']) + # Normalize line endings on Windows + stderr = stderr.decode('utf-8').replace('\r', '') + self.assertEqual( + stderr, + '*** test.py:1: No keywords matched gettext call "_":\n' + '\tkeyword="_": Expected a string constant for argument 1, got x\n' + '\tkeyword="_:2": Expected a string constant for argument 2, got 42\n' + '*** test.py:2: No keywords matched gettext call "_":\n' + '\tkeyword="_": Expected a string constant for argument 1, got 42\n' + '\tkeyword="_:2": Expected a string constant for argument 2, got y\n') + def extract_from_snapshots(): snapshots = { - 'messages.py': ('--docstrings',), + 'messages.py': (), 'fileloc.py': ('--docstrings',), 'docstrings.py': ('--docstrings',), + 'comments.py': ('--add-comments=i18n:',), + 'custom_keywords.py': ('--keyword=foo', '--keyword=nfoo:1,2', + '--keyword=pfoo:1c,2', + '--keyword=npfoo:1c,2,3', '--keyword=_:1,2'), + 'multiple_keywords.py': ('--keyword=foo:1c,2,3', '--keyword=foo:1c,2', + '--keyword=foo:1,2', + # repeat a keyword to make sure it is extracted only once + '--keyword=foo', '--keyword=foo'), # == Test character escaping # Escape ascii and unicode: - 'escapes.py': ('--escape',), + 'escapes.py': ('--escape', '--add-comments='), # Escape only ascii and let unicode pass through: - ('escapes.py', 'ascii-escapes.pot'): (), + ('escapes.py', 
'ascii-escapes.pot'): ('--add-comments=',), } for filename, args in snapshots.items(): diff --git a/Lib/test/test_tools/test_makefile.py b/Lib/test/test_tools/test_makefile.py index 4c7588d4d93..31a51606739 100644 --- a/Lib/test/test_tools/test_makefile.py +++ b/Lib/test/test_tools/test_makefile.py @@ -48,15 +48,18 @@ def test_makefile_test_folders(self): if dirname == '__pycache__' or dirname.startswith('.'): dirs.clear() # do not process subfolders continue - # Skip empty dirs: + + # Skip empty dirs (ignoring hidden files and __pycache__): + files = [ + filename for filename in files + if not filename.startswith('.') + ] + dirs = [ + dirname for dirname in dirs + if not dirname.startswith('.') and dirname != "__pycache__" + ] if not dirs and not files: continue - # Skip dirs with hidden-only files: - if files and all( - filename.startswith('.') or filename == '__pycache__' - for filename in files - ): - continue relpath = os.path.relpath(dirpath, support.STDLIB_DIR) with self.subTest(relpath=relpath): diff --git a/Lib/test/test_tools/test_msgfmt.py b/Lib/test/test_tools/test_msgfmt.py index 8cd31680f76..7be606bbff6 100644 --- a/Lib/test/test_tools/test_msgfmt.py +++ b/Lib/test/test_tools/test_msgfmt.py @@ -1,6 +1,7 @@ """Tests for the Tools/i18n/msgfmt.py tool.""" import json +import struct import sys import unittest from gettext import GNUTranslations @@ -8,18 +9,21 @@ from test.support.os_helper import temp_cwd from test.support.script_helper import assert_python_failure, assert_python_ok -from test.test_tools import skip_if_missing, toolsdir +from test.test_tools import imports_under_tool, skip_if_missing, toolsdir skip_if_missing('i18n') data_dir = (Path(__file__).parent / 'msgfmt_data').resolve() script_dir = Path(toolsdir) / 'i18n' -msgfmt = script_dir / 'msgfmt.py' +msgfmt_py = script_dir / 'msgfmt.py' + +with imports_under_tool("i18n"): + import msgfmt def compile_messages(po_file, mo_file): - assert_python_ok(msgfmt, '-o', mo_file, po_file) + 
assert_python_ok(msgfmt_py, '-o', mo_file, po_file) class CompilationTest(unittest.TestCase): @@ -40,6 +44,31 @@ def test_compilation(self): self.assertDictEqual(actual._catalog, expected._catalog) + def test_binary_header(self): + with temp_cwd(): + tmp_mo_file = 'messages.mo' + compile_messages(data_dir / "general.po", tmp_mo_file) + with open(tmp_mo_file, 'rb') as f: + mo_data = f.read() + + ( + magic, + version, + num_strings, + orig_table_offset, + trans_table_offset, + hash_table_size, + hash_table_offset, + ) = struct.unpack("=7I", mo_data[:28]) + + self.assertEqual(magic, 0x950412de) + self.assertEqual(version, 0) + self.assertEqual(num_strings, 9) + self.assertEqual(orig_table_offset, 28) + self.assertEqual(trans_table_offset, 100) + self.assertEqual(hash_table_size, 0) + self.assertEqual(hash_table_offset, 0) + def test_translations(self): with open(data_dir / 'general.mo', 'rb') as f: t = GNUTranslations(f) @@ -62,6 +91,14 @@ def test_translations(self): '%d emails sent.', 2), '%d emails sent.') + def test_po_with_bom(self): + with temp_cwd(): + Path('bom.po').write_bytes(b'\xef\xbb\xbfmsgid "Python"\nmsgstr "Pioton"\n') + + res = assert_python_failure(msgfmt_py, 'bom.po') + err = res.err.decode('utf-8') + self.assertIn('The file bom.po starts with a UTF-8 BOM', err) + def test_invalid_msgid_plural(self): with temp_cwd(): Path('invalid.po').write_text('''\ @@ -69,7 +106,7 @@ def test_invalid_msgid_plural(self): msgstr[0] "singular" ''') - res = assert_python_failure(msgfmt, 'invalid.po') + res = assert_python_failure(msgfmt_py, 'invalid.po') err = res.err.decode('utf-8') self.assertIn('msgid_plural not preceded by msgid', err) @@ -80,7 +117,7 @@ def test_plural_without_msgid_plural(self): msgstr[0] "bar" ''') - res = assert_python_failure(msgfmt, 'invalid.po') + res = assert_python_failure(msgfmt_py, 'invalid.po') err = res.err.decode('utf-8') self.assertIn('plural without msgid_plural', err) @@ -92,7 +129,7 @@ def 
test_indexed_msgstr_without_msgid_plural(self): msgstr "bar" ''') - res = assert_python_failure(msgfmt, 'invalid.po') + res = assert_python_failure(msgfmt_py, 'invalid.po') err = res.err.decode('utf-8') self.assertIn('indexed msgstr required for plural', err) @@ -102,38 +139,136 @@ def test_generic_syntax_error(self): "foo" ''') - res = assert_python_failure(msgfmt, 'invalid.po') + res = assert_python_failure(msgfmt_py, 'invalid.po') err = res.err.decode('utf-8') self.assertIn('Syntax error', err) + +class POParserTest(unittest.TestCase): + @classmethod + def tearDownClass(cls): + # msgfmt uses a global variable to store messages, + # clear it after the tests. + msgfmt.MESSAGES.clear() + + def test_strings(self): + # Test that the PO parser correctly handles and unescape + # strings in the PO file. + # The PO file format allows for a variety of escape sequences, + # octal and hex escapes. + valid_strings = ( + # empty strings + ('""', ''), + ('"" "" ""', ''), + # allowed escape sequences + (r'"\\"', '\\'), + (r'"\""', '"'), + (r'"\t"', '\t'), + (r'"\n"', '\n'), + (r'"\r"', '\r'), + (r'"\f"', '\f'), + (r'"\a"', '\a'), + (r'"\b"', '\b'), + (r'"\v"', '\v'), + # non-empty strings + ('"foo"', 'foo'), + ('"foo" "bar"', 'foobar'), + ('"foo""bar"', 'foobar'), + ('"" "foo" ""', 'foo'), + # newlines and tabs + (r'"foo\nbar"', 'foo\nbar'), + (r'"foo\n" "bar"', 'foo\nbar'), + (r'"foo\tbar"', 'foo\tbar'), + (r'"foo\t" "bar"', 'foo\tbar'), + # escaped quotes + (r'"foo\"bar"', 'foo"bar'), + (r'"foo\"" "bar"', 'foo"bar'), + (r'"foo\\" "bar"', 'foo\\bar'), + # octal escapes + (r'"\120\171\164\150\157\156"', 'Python'), + (r'"\120\171\164" "\150\157\156"', 'Python'), + (r'"\"\120\171\164" "\150\157\156\""', '"Python"'), + # hex escapes + (r'"\x50\x79\x74\x68\x6f\x6e"', 'Python'), + (r'"\x50\x79\x74" "\x68\x6f\x6e"', 'Python'), + (r'"\"\x50\x79\x74" "\x68\x6f\x6e\""', '"Python"'), + ) + + with temp_cwd(): + for po_string, expected in valid_strings: + with 
self.subTest(po_string=po_string): + # Construct a PO file with a single entry, + # compile it, read it into a catalog and + # check the result. + po = f'msgid {po_string}\nmsgstr "translation"' + Path('messages.po').write_text(po) + # Reset the global MESSAGES dictionary + msgfmt.MESSAGES.clear() + msgfmt.make('messages.po', 'messages.mo') + + with open('messages.mo', 'rb') as f: + actual = GNUTranslations(f) + + self.assertDictEqual(actual._catalog, {expected: 'translation'}) + + invalid_strings = ( + # "''", # invalid but currently accepted + '"', + '"""', + '"" "', + 'foo', + '"" "foo', + '"foo" foo', + '42', + '"" 42 ""', + # disallowed escape sequences + # r'"\'"', # invalid but currently accepted + # r'"\e"', # invalid but currently accepted + # r'"\8"', # invalid but currently accepted + # r'"\9"', # invalid but currently accepted + r'"\x"', + r'"\u1234"', + r'"\N{ROMAN NUMERAL NINE}"' + ) + with temp_cwd(): + for invalid_string in invalid_strings: + with self.subTest(string=invalid_string): + po = f'msgid {invalid_string}\nmsgstr "translation"' + Path('messages.po').write_text(po) + # Reset the global MESSAGES dictionary + msgfmt.MESSAGES.clear() + with self.assertRaises(Exception): + msgfmt.make('messages.po', 'messages.mo') + + class CLITest(unittest.TestCase): def test_help(self): for option in ('--help', '-h'): - res = assert_python_ok(msgfmt, option) + res = assert_python_ok(msgfmt_py, option) err = res.err.decode('utf-8') self.assertIn('Generate binary message catalog from textual translation description.', err) def test_version(self): for option in ('--version', '-V'): - res = assert_python_ok(msgfmt, option) + res = assert_python_ok(msgfmt_py, option) out = res.out.decode('utf-8').strip() self.assertEqual('msgfmt.py 1.2', out) def test_invalid_option(self): - res = assert_python_failure(msgfmt, '--invalid-option') + res = assert_python_failure(msgfmt_py, '--invalid-option') err = res.err.decode('utf-8') self.assertIn('Generate binary message 
catalog from textual translation description.', err) self.assertIn('option --invalid-option not recognized', err) def test_no_input_file(self): - res = assert_python_ok(msgfmt) + res = assert_python_ok(msgfmt_py) err = res.err.decode('utf-8').replace('\r\n', '\n') self.assertIn('No input file given\n' "Try `msgfmt --help' for more information.", err) def test_nonexistent_file(self): - assert_python_failure(msgfmt, 'nonexistent.po') + assert_python_failure(msgfmt_py, 'nonexistent.po') def update_catalog_snapshots(): diff --git a/Lib/test/test_type_cache.py b/Lib/test/test_type_cache.py new file mode 100644 index 00000000000..7469a1047f8 --- /dev/null +++ b/Lib/test/test_type_cache.py @@ -0,0 +1,265 @@ +""" Tests for the internal type cache in CPython. """ +import dis +import unittest +import warnings +from test import support +from test.support import import_helper, requires_specialization, requires_specialization_ft +try: + from sys import _clear_type_cache +except ImportError: + _clear_type_cache = None + +# Skip this test if the _testcapi module isn't available. +_testcapi = import_helper.import_module("_testcapi") +_testinternalcapi = import_helper.import_module("_testinternalcapi") +type_get_version = _testcapi.type_get_version +type_assign_specific_version_unsafe = _testinternalcapi.type_assign_specific_version_unsafe +type_assign_version = _testcapi.type_assign_version +type_modified = _testcapi.type_modified + +def clear_type_cache(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + _clear_type_cache() + +@support.cpython_only +@unittest.skipIf(_clear_type_cache is None, "requires sys._clear_type_cache") +class TypeCacheTests(unittest.TestCase): + def test_tp_version_tag_unique(self): + """tp_version_tag should be unique assuming no overflow, even after + clearing type cache. + """ + # Check if global version tag has already overflowed. 
+ Y = type('Y', (), {}) + Y.x = 1 + Y.x # Force a _PyType_Lookup, populating version tag + y_ver = type_get_version(Y) + # Overflow, or not enough left to conduct the test. + if y_ver == 0 or y_ver > 0xFFFFF000: + self.skipTest("Out of type version tags") + # Note: try to avoid any method lookups within this loop, + # It will affect global version tag. + all_version_tags = [] + append_result = all_version_tags.append + assertNotEqual = self.assertNotEqual + for _ in range(30): + clear_type_cache() + X = type('Y', (), {}) + X.x = 1 + X.x + tp_version_tag_after = type_get_version(X) + assertNotEqual(tp_version_tag_after, 0, msg="Version overflowed") + append_result(tp_version_tag_after) + self.assertEqual(len(set(all_version_tags)), 30, + msg=f"{all_version_tags} contains non-unique versions") + + def test_type_assign_version(self): + class C: + x = 5 + + self.assertEqual(type_assign_version(C), 1) + c_ver = type_get_version(C) + + C.x = 6 + self.assertEqual(type_get_version(C), 0) + self.assertEqual(type_assign_version(C), 1) + self.assertNotEqual(type_get_version(C), 0) + self.assertNotEqual(type_get_version(C), c_ver) + + def test_type_assign_specific_version(self): + """meta-test for type_assign_specific_version_unsafe""" + class C: + pass + + type_assign_version(C) + orig_version = type_get_version(C) + if orig_version == 0: + self.skipTest("Could not assign a valid type version") + + type_modified(C) + type_assign_specific_version_unsafe(C, orig_version + 5) + type_assign_version(C) # this should do nothing + + new_version = type_get_version(C) + self.assertEqual(new_version, orig_version + 5) + + clear_type_cache() + + def test_per_class_limit(self): + class C: + x = 0 + + type_assign_version(C) + orig_version = type_get_version(C) + for i in range(1001): + C.x = i + type_assign_version(C) + + new_version = type_get_version(C) + self.assertEqual(new_version, 0) + + def test_119462(self): + + class Holder: + value = None + + @classmethod + def set_value(cls): + 
cls.value = object() + + class HolderSub(Holder): + pass + + for _ in range(1050): + Holder.set_value() + HolderSub.value + +@support.cpython_only +class TypeCacheWithSpecializationTests(unittest.TestCase): + def tearDown(self): + clear_type_cache() + + def _assign_valid_version_or_skip(self, type_): + type_modified(type_) + type_assign_version(type_) + if type_get_version(type_) == 0: + self.skipTest("Could not assign valid type version") + + def _no_more_versions(self, user_type): + type_modified(user_type) + for _ in range(1001): + type_assign_specific_version_unsafe(user_type, 1000_000_000) + type_assign_specific_version_unsafe(user_type, 0) + self.assertEqual(type_get_version(user_type), 0) + + def _all_opnames(self, func): + return set(instr.opname for instr in dis.Bytecode(func, adaptive=True)) + + def _check_specialization(self, func, arg, opname, *, should_specialize): + for _ in range(_testinternalcapi.SPECIALIZATION_THRESHOLD): + func(arg) + + if should_specialize: + self.assertNotIn(opname, self._all_opnames(func)) + else: + self.assertIn(opname, self._all_opnames(func)) + + @requires_specialization + def test_class_load_attr_specialization_user_type(self): + class A: + def foo(self): + pass + + self._assign_valid_version_or_skip(A) + + def load_foo_1(type_): + type_.foo + + self._check_specialization(load_foo_1, A, "LOAD_ATTR", should_specialize=True) + del load_foo_1 + + self._no_more_versions(A) + + def load_foo_2(type_): + return type_.foo + + self._check_specialization(load_foo_2, A, "LOAD_ATTR", should_specialize=False) + + @requires_specialization + def test_class_load_attr_specialization_static_type(self): + self.assertNotEqual(type_get_version(str), 0) + self.assertNotEqual(type_get_version(bytes), 0) + + def get_capitalize_1(type_): + return type_.capitalize + + self._check_specialization(get_capitalize_1, str, "LOAD_ATTR", should_specialize=True) + self.assertEqual(get_capitalize_1(str)('hello'), 'Hello') + 
self.assertEqual(get_capitalize_1(bytes)(b'hello'), b'Hello') + + @requires_specialization + def test_property_load_attr_specialization_user_type(self): + class G: + @property + def x(self): + return 9 + + self._assign_valid_version_or_skip(G) + + def load_x_1(instance): + instance.x + + self._check_specialization(load_x_1, G(), "LOAD_ATTR", should_specialize=True) + del load_x_1 + + self._no_more_versions(G) + + def load_x_2(instance): + instance.x + + self._check_specialization(load_x_2, G(), "LOAD_ATTR", should_specialize=False) + + @requires_specialization + def test_store_attr_specialization_user_type(self): + class B: + __slots__ = ("bar",) + + self._assign_valid_version_or_skip(B) + + def store_bar_1(type_): + type_.bar = 10 + + self._check_specialization(store_bar_1, B(), "STORE_ATTR", should_specialize=True) + del store_bar_1 + + self._no_more_versions(B) + + def store_bar_2(type_): + type_.bar = 10 + + self._check_specialization(store_bar_2, B(), "STORE_ATTR", should_specialize=False) + + @requires_specialization_ft + def test_class_call_specialization_user_type(self): + class F: + def __init__(self): + pass + + self._assign_valid_version_or_skip(F) + + def call_class_1(type_): + type_() + + self._check_specialization(call_class_1, F, "CALL", should_specialize=True) + del call_class_1 + + self._no_more_versions(F) + + def call_class_2(type_): + type_() + + self._check_specialization(call_class_2, F, "CALL", should_specialize=False) + + @requires_specialization + def test_to_bool_specialization_user_type(self): + class H: + pass + + self._assign_valid_version_or_skip(H) + + def to_bool_1(instance): + not instance + + self._check_specialization(to_bool_1, H(), "TO_BOOL", should_specialize=True) + del to_bool_1 + + self._no_more_versions(H) + + def to_bool_2(instance): + not instance + + self._check_specialization(to_bool_2, H(), "TO_BOOL", should_specialize=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_types.py 
b/Lib/test/test_types.py index 588edc376a5..01da70b4c68 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -431,7 +431,6 @@ def test(i, format_spec, result): test(123456, "1=20", '11111111111111123456') test(123456, "*=20", '**************123456') - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; AssertionError: '1,234.57' != '1234.57'") @run_with_locale('LC_NUMERIC', 'en_US.UTF8', '') def test_float__format__locale(self): # test locale support for __format__ code 'n' @@ -441,7 +440,6 @@ def test_float__format__locale(self): self.assertEqual(locale.format_string('%g', x, grouping=True), format(x, 'n')) self.assertEqual(locale.format_string('%.10g', x, grouping=True), format(x, '.10n')) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; AssertionError: '123,456,789,012,345,678,901,234,567,890' != '123456789012345678901234567890'") @run_with_locale('LC_NUMERIC', 'en_US.UTF8', '') def test_int__format__locale(self): # test locale support for __format__ code 'n' for integers @@ -700,8 +698,7 @@ def test_traceback_and_frame_types(self): self.assertIsInstance(exc.__traceback__, types.TracebackType) self.assertIsInstance(exc.__traceback__.tb_frame, types.FrameType) - # XXX: RUSTPYTHON - @unittest.skipUnless(_datetime, "requires _datetime module") + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'NoneType' object has no attribute 'datetime_CAPI' def test_capsule_type(self): self.assertIsInstance(_datetime.datetime_CAPI, types.CapsuleType) @@ -2127,6 +2124,21 @@ class Spam(types.SimpleNamespace): self.assertIs(type(spam2), Spam) self.assertEqual(vars(spam2), {'ham': 5, 'eggs': 9}) + def test_replace_invalid_subtype(self): + # See https://github.com/python/cpython/issues/143636. 
+ class MyNS(types.SimpleNamespace): + def __new__(cls, *args, **kwargs): + if created: + return 12345 + return super().__new__(cls) + + created = False + ns = MyNS() + created = True + err = (r"^expect types\.SimpleNamespace type, " + r"but .+\.MyNS\(\) returned 'int' object") + self.assertRaisesRegex(TypeError, err, copy.replace, ns) + def test_fake_namespace_compare(self): # Issue #24257: Incorrect use of PyObject_IsInstance() caused # SystemError. diff --git a/Lib/test/test_unittest/test_async_case.py b/Lib/test/test_unittest/test_async_case.py index 9b1678caf59..91d45283eb3 100644 --- a/Lib/test/test_unittest/test_async_case.py +++ b/Lib/test/test_unittest/test_async_case.py @@ -296,7 +296,6 @@ async def on_cleanup2(self): test.doCleanups() self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup2', 'cleanup1']) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_deprecation_of_return_val_from_test(self): # Issue 41322 - deprecate return of value that is not None from a test class Nothing: diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index a7c347693f9..43db6c9b49a 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -2672,7 +2672,6 @@ def __deepcopy__(self, memo): e[:] = [E('bar')] self.assertRaises(TypeError, copy.deepcopy, e) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_cyclic_gc(self): class Dummy: pass diff --git a/Lib/test/test_xpickle.py b/Lib/test/test_xpickle.py new file mode 100644 index 00000000000..d87c671d4f5 --- /dev/null +++ b/Lib/test/test_xpickle.py @@ -0,0 +1,281 @@ +# This test covers backwards compatibility with previous versions of Python +# by bouncing pickled objects through Python versions by running xpickle_worker.py. 
+import io +import os +import pickle +import struct +import subprocess +import sys +import unittest + + +from test import support +from test import pickletester + +try: + import _pickle + has_c_implementation = True +except ModuleNotFoundError: + has_c_implementation = False + +support.requires('xpickle') + +is_windows = sys.platform.startswith('win') + +# Map python version to a tuple containing the name of a corresponding valid +# Python binary to execute and its arguments. +py_executable_map = {} + +protocols_map = { + 3: (3, 0), + 4: (3, 4), + 5: (3, 8), +} + +def highest_proto_for_py_version(py_version): + """Finds the highest supported pickle protocol for a given Python version. + Args: + py_version: a 2-tuple of the major, minor version. Eg. Python 3.7 would + be (3, 7) + Returns: + int for the highest supported pickle protocol + """ + proto = 2 + for p, v in protocols_map.items(): + if py_version < v: + break + proto = p + return proto + +def have_python_version(py_version): + """Check whether a Python binary exists for the given py_version and has + support. This respects your PATH. + For Windows, it will first try to use the py launcher specified in PEP 397. + Otherwise (and for all other platforms), it will attempt to check for + python.. + + Eg. given a *py_version* of (3, 7), the function will attempt to try + 'py -3.7' (for Windows) first, then 'python3.7', and return + ['py', '-3.7'] (on Windows) or ['python3.7'] on other platforms. + + Args: + py_version: a 2-tuple of the major, minor version. Eg. python 3.7 would + be (3, 7) + Returns: + List/Tuple containing the Python binary name and its required arguments, + or None if no valid binary names found. 
+ """ + python_str = ".".join(map(str, py_version)) + targets = [('py', f'-{python_str}'), (f'python{python_str}',)] + if py_version not in py_executable_map: + for target in targets[0 if is_windows else 1:]: + try: + worker = subprocess.Popen([*target, '-c', 'pass'], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + shell=is_windows) + worker.communicate() + if worker.returncode == 0: + py_executable_map[py_version] = target + break + except FileNotFoundError: + pass + + return py_executable_map.get(py_version, None) + + +def read_exact(f, n): + buf = b'' + while len(buf) < n: + chunk = f.read(n - len(buf)) + if not chunk: + raise EOFError + buf += chunk + return buf + + +class AbstractCompatTests(pickletester.AbstractPickleTests): + py_version = None + worker = None + + @classmethod + def setUpClass(cls): + assert cls.py_version is not None, 'Needs a python version tuple' + if not have_python_version(cls.py_version): + py_version_str = ".".join(map(str, cls.py_version)) + raise unittest.SkipTest(f'Python {py_version_str} not available') + cls.addClassCleanup(cls.finish_worker) + # Override the default pickle protocol to match what xpickle worker + # will be running. + highest_protocol = highest_proto_for_py_version(cls.py_version) + cls.enterClassContext(support.swap_attr(pickletester, 'protocols', + range(highest_protocol + 1))) + cls.enterClassContext(support.swap_attr(pickle, 'HIGHEST_PROTOCOL', + highest_protocol)) + + @classmethod + def start_worker(cls, python): + target = os.path.join(os.path.dirname(__file__), 'xpickle_worker.py') + worker = subprocess.Popen([*python, target], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + # For windows bpo-17023. 
+ shell=is_windows) + cls.worker = worker + return worker + + @classmethod + def finish_worker(cls): + worker = cls.worker + if worker is None: + return + cls.worker = None + worker.stdin.close() + worker.stdout.close() + worker.stderr.close() + worker.terminate() + worker.wait() + + @classmethod + def send_to_worker(cls, python, data): + """Bounce a pickled object through another version of Python. + This will send data to a child process where it will + be unpickled, then repickled and sent back to the parent process. + Args: + python: list containing the python binary to start and its arguments + data: bytes object to send to the child process + Returns: + The pickled data received from the child process. + """ + worker = cls.worker + if worker is None: + worker = cls.start_worker(python) + + try: + worker.stdin.write(struct.pack('!i', len(data)) + data) + worker.stdin.flush() + + size, = struct.unpack('!i', read_exact(worker.stdout, 4)) + if size > 0: + return read_exact(worker.stdout, size) + # if the worker fails, it will write the exception to stdout + if size < 0: + stdout = read_exact(worker.stdout, -size) + try: + exception = pickle.loads(stdout) + except (pickle.UnpicklingError, EOFError): + pass + else: + if isinstance(exception, Exception): + # To allow for tests which test for errors. + raise exception + _, stderr = worker.communicate() + raise RuntimeError(stderr) + except: + cls.finish_worker() + raise + + def dumps(self, arg, proto=0, **kwargs): + # Skip tests that require buffer_callback arguments since + # there isn't a reliable way to marshal/pickle the callback and ensure + # it works in a different Python version. 
+ if 'buffer_callback' in kwargs: + self.skipTest('Test does not support "buffer_callback" argument.') + f = io.BytesIO() + p = self.pickler(f, proto, **kwargs) + p.dump(arg) + data = struct.pack('!i', proto) + f.getvalue() + python = py_executable_map[self.py_version] + return self.send_to_worker(python, data) + + def loads(self, buf, **kwds): + f = io.BytesIO(buf) + u = self.unpickler(f, **kwds) + return u.load() + + # A scaled-down version of test_bytes from pickletester, to reduce + # the number of calls to self.dumps() and hence reduce the number of + # child python processes forked. This allows the test to complete + # much faster (the one from pickletester takes 3-4 minutes when running + # under text_xpickle). + def test_bytes(self): + if self.py_version < (3, 0): + self.skipTest('not supported in Python < 3.0') + for proto in pickletester.protocols: + for s in b'', b'xyz', b'xyz'*100: + p = self.dumps(s, proto) + self.assert_is_copy(s, self.loads(p)) + s = bytes(range(256)) + p = self.dumps(s, proto) + self.assert_is_copy(s, self.loads(p)) + s = bytes([i for i in range(256) for _ in range(2)]) + p = self.dumps(s, proto) + self.assert_is_copy(s, self.loads(p)) + + # These tests are disabled because they require some special setup + # on the worker that's hard to keep in sync. + test_global_ext1 = None + test_global_ext2 = None + test_global_ext4 = None + + # These tests fail because they require classes from pickletester + # which cannot be properly imported by the xpickle worker. + test_recursive_nested_names = None + test_recursive_nested_names2 = None + + # Attribute lookup problems are expected, disable the test + test_dynamic_class = None + test_evil_class_mutating_dict = None + + # Expected exception is raised during unpickling in a subprocess. + test_pickle_setstate_None = None + + # Other Python version may not have NumPy. 
+ test_buffers_numpy = None + + # Skip tests that require buffer_callback arguments since + # there isn't a reliable way to marshal/pickle the callback and ensure + # it works in a different Python version. + test_in_band_buffers = None + test_buffers_error = None + test_oob_buffers = None + test_oob_buffers_writable_to_readonly = None + +class PyPicklePythonCompat(AbstractCompatTests): + pickler = pickle._Pickler + unpickler = pickle._Unpickler + +if has_c_implementation: + class CPicklePythonCompat(AbstractCompatTests): + pickler = _pickle.Pickler + unpickler = _pickle.Unpickler + + +def make_test(py_version, base): + class_dict = {'py_version': py_version} + name = base.__name__.replace('Python', 'Python%d%d' % py_version) + return type(name, (base, unittest.TestCase), class_dict) + +def load_tests(loader, tests, pattern): + def add_tests(py_version): + test_class = make_test(py_version, PyPicklePythonCompat) + tests.addTest(loader.loadTestsFromTestCase(test_class)) + if has_c_implementation: + test_class = make_test(py_version, CPicklePythonCompat) + tests.addTest(loader.loadTestsFromTestCase(test_class)) + + value = support.get_resource_value('xpickle') + if value is None: + major = sys.version_info.major + assert major == 3 + add_tests((2, 7)) + for minor in range(2, sys.version_info.minor): + add_tests((major, minor)) + else: + add_tests(tuple(map(int, value.split('.')))) + return tests + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_xxlimited.py b/Lib/test/test_xxlimited.py new file mode 100644 index 00000000000..b52e78bc4fb --- /dev/null +++ b/Lib/test/test_xxlimited.py @@ -0,0 +1,90 @@ +import unittest +from test.support import import_helper +import types + +xxlimited = import_helper.import_module('xxlimited') +xxlimited_35 = import_helper.import_module('xxlimited_35') + + +class CommonTests: + module: types.ModuleType + + def test_xxo_new(self): + xxo = self.module.Xxo() + + def test_xxo_attributes(self): + xxo = 
self.module.Xxo() + with self.assertRaises(AttributeError): + xxo.foo + with self.assertRaises(AttributeError): + del xxo.foo + + xxo.foo = 1234 + self.assertEqual(xxo.foo, 1234) + + del xxo.foo + with self.assertRaises(AttributeError): + xxo.foo + + def test_foo(self): + # the foo function adds 2 numbers + self.assertEqual(self.module.foo(1, 2), 3) + + def test_str(self): + self.assertIsSubclass(self.module.Str, str) + self.assertIsNot(self.module.Str, str) + + custom_string = self.module.Str("abcd") + self.assertEqual(custom_string, "abcd") + self.assertEqual(custom_string.upper(), "ABCD") + + def test_new(self): + xxo = self.module.new() + self.assertEqual(xxo.demo("abc"), "abc") + + +class TestXXLimited(CommonTests, unittest.TestCase): + module = xxlimited + + def test_xxo_demo(self): + xxo = self.module.Xxo() + other = self.module.Xxo() + self.assertEqual(xxo.demo("abc"), "abc") + self.assertEqual(xxo.demo(xxo), xxo) + self.assertEqual(xxo.demo(other), other) + self.assertEqual(xxo.demo(0), None) + + def test_error(self): + with self.assertRaises(self.module.Error): + raise self.module.Error + + def test_buffer(self): + xxo = self.module.Xxo() + self.assertEqual(xxo.x_exports, 0) + b1 = memoryview(xxo) + self.assertEqual(xxo.x_exports, 1) + b2 = memoryview(xxo) + self.assertEqual(xxo.x_exports, 2) + b1[0] = 1 + self.assertEqual(b1[0], 1) + self.assertEqual(b2[0], 1) + + +class TestXXLimited35(CommonTests, unittest.TestCase): + module = xxlimited_35 + + def test_xxo_demo(self): + xxo = self.module.Xxo() + other = self.module.Xxo() + self.assertEqual(xxo.demo("abc"), "abc") + self.assertEqual(xxo.demo(0), None) + + def test_roj(self): + # the roj function always fails + with self.assertRaises(SystemError): + self.module.roj(0) + + def test_null(self): + null1 = self.module.Null() + null2 = self.module.Null() + self.assertNotEqual(null1, null2) diff --git a/Lib/test/test_xxtestfuzz.py b/Lib/test/test_xxtestfuzz.py new file mode 100644 index 
00000000000..3304c6e703a --- /dev/null +++ b/Lib/test/test_xxtestfuzz.py @@ -0,0 +1,25 @@ +import faulthandler +from test.support import import_helper +import unittest + +_xxtestfuzz = import_helper.import_module('_xxtestfuzz') + + +class TestFuzzer(unittest.TestCase): + """To keep our https://github.com/google/oss-fuzz API working.""" + + def test_sample_input_smoke_test(self): + """This is only a regression test: Check that it doesn't crash.""" + _xxtestfuzz.run(b"") + _xxtestfuzz.run(b"\0") + _xxtestfuzz.run(b"{") + _xxtestfuzz.run(b" ") + _xxtestfuzz.run(b"x") + _xxtestfuzz.run(b"1") + _xxtestfuzz.run(b"AAAAAAA") + _xxtestfuzz.run(b"AAAAAA\0") + + +if __name__ == "__main__": + faulthandler.enable() + unittest.main() diff --git a/Lib/test/test_zipfile/_path/__init__.py b/Lib/test/test_zipfile/_path/__init__.py index e69de29bb2d..8b137891791 100644 --- a/Lib/test/test_zipfile/_path/__init__.py +++ b/Lib/test/test_zipfile/_path/__init__.py @@ -0,0 +1 @@ + diff --git a/Lib/test/test_zipfile/_path/test_path.py b/Lib/test/test_zipfile/_path/test_path.py index f34251bc93c..351d9eefeb0 100644 --- a/Lib/test/test_zipfile/_path/test_path.py +++ b/Lib/test/test_zipfile/_path/test_path.py @@ -5,7 +5,6 @@ import pickle import stat import sys -import time import unittest import zipfile import zipfile._path @@ -649,7 +648,7 @@ def test_backslash_not_separator(self): """ data = io.BytesIO() zf = zipfile.ZipFile(data, "w") - zf.writestr(DirtyZipInfo.for_name("foo\\bar", zf), b"content") + zf.writestr(DirtyZipInfo("foo\\bar")._for_archive(zf), b"content") zf.filename = '' root = zipfile.Path(zf) (first,) = root.iterdir() @@ -672,20 +671,3 @@ class DirtyZipInfo(zipfile.ZipInfo): def __init__(self, filename, *args, **kwargs): super().__init__(filename, *args, **kwargs) self.filename = filename - - @classmethod - def for_name(cls, name, archive): - """ - Construct the same way that ZipFile.writestr does. 
- - TODO: extract this functionality and re-use - """ - self = cls(filename=name, date_time=time.localtime(time.time())[:6]) - self.compress_type = archive.compression - self.compress_level = archive.compresslevel - if self.filename.endswith('/'): # pragma: no cover - self.external_attr = 0o40775 << 16 # drwxrwxr-x - self.external_attr |= 0x10 # MS-DOS directory flag - else: - self.external_attr = 0o600 << 16 # ?rw------- - return self diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py index 63413d7b944..96f8caa94dc 100644 --- a/Lib/test/test_zipfile/test_core.py +++ b/Lib/test/test_zipfile/test_core.py @@ -6,6 +6,7 @@ import itertools import os import posixpath +import stat import struct import subprocess import sys @@ -19,14 +20,16 @@ from random import randint, random, randbytes from test import archiver_tests -from test.support import script_helper +from test.support import script_helper, os_helper from test.support import ( findfile, requires_zlib, requires_bz2, requires_lzma, - captured_stdout, captured_stderr, requires_subprocess + requires_zstd, captured_stdout, captured_stderr, requires_subprocess, + cpython_only ) from test.support.os_helper import ( TESTFN, unlink, rmtree, temp_dir, temp_cwd, fd_count, FakePath ) +from test.support.import_helper import ensure_lazy_imports TESTFN2 = TESTFN + "2" @@ -48,6 +51,13 @@ def get_files(test): yield f test.assertFalse(f.closed) + +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("zipfile", {"typing"}) + + class AbstractTestsWithSourceFile: @classmethod def setUpClass(cls): @@ -692,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractTestZip64InSmallFiles: # These tests test the ZIP64 functionality 
without using large files, @@ -1343,6 +1357,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractWriterTests: @@ -1412,6 +1430,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase): class LzmaWriterTests(AbstractWriterTests, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdWriterTests(AbstractWriterTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class PyZipFileTests(unittest.TestCase): def assertCompiledIn(self, name, namelist): @@ -1431,7 +1452,7 @@ def requiresWriteAccess(self, path): self.skipTest('requires write access to the installed location') unlink(filename) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_write_pyfile(self): self.requiresWriteAccess(os.path.dirname(__file__)) with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp: @@ -1462,7 +1483,7 @@ def test_write_pyfile(self): self.assertNotIn(bn, zipfp.namelist()) self.assertCompiledIn(bn, zipfp.namelist()) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_write_python_package(self): import email packagedir = os.path.dirname(email.__file__) @@ -1477,7 +1498,7 @@ def test_write_python_package(self): self.assertCompiledIn('email/__init__.py', names) self.assertCompiledIn('email/mime/text.py', names) - @unittest.expectedFailure # TODO: RUSTPYTHON; - AttributeError: module 'os' has no attribute 'supports_effective_ids' + @unittest.expectedFailure # TODO: RUSTPYTHON; - AttributeError: module 'os' has no attribute 'supports_effective_ids' def test_write_filtered_python_package(self): import test packagedir = os.path.dirname(test.__file__) @@ -1508,7 +1529,7 @@ def filter(path): print(reportStr) 
self.assertTrue('SyntaxError' not in reportStr) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_write_with_optimization(self): import email packagedir = os.path.dirname(email.__file__) @@ -1859,6 +1880,35 @@ def test_writestr_extended_local_header_issue1202(self): zinfo.flag_bits |= zipfile._MASK_USE_DATA_DESCRIPTOR # Include an extended local header. orig_zip.writestr(zinfo, data) + def test_write_with_source_date_epoch(self): + with os_helper.EnvironmentVarGuard() as env: + # Set the SOURCE_DATE_EPOCH environment variable to a specific timestamp + env['SOURCE_DATE_EPOCH'] = "1735715999" + + with zipfile.ZipFile(TESTFN, "w") as zf: + zf.writestr("test_source_date_epoch.txt", "Testing SOURCE_DATE_EPOCH") + + with zipfile.ZipFile(TESTFN, "r") as zf: + zip_info = zf.getinfo("test_source_date_epoch.txt") + get_time = time.localtime(int(os.environ['SOURCE_DATE_EPOCH']))[:6] + # Compare each element of the date_time tuple + # Allow for a 1-second difference + for z_time, g_time in zip(zip_info.date_time, get_time): + self.assertAlmostEqual(z_time, g_time, delta=1) + + def test_write_without_source_date_epoch(self): + with os_helper.EnvironmentVarGuard() as env: + del env['SOURCE_DATE_EPOCH'] + + with zipfile.ZipFile(TESTFN, "w") as zf: + zf.writestr("test_no_source_date_epoch.txt", "Testing without SOURCE_DATE_EPOCH") + + with zipfile.ZipFile(TESTFN, "r") as zf: + zip_info = zf.getinfo("test_no_source_date_epoch.txt") + current_time = time.localtime()[:6] + for z_time, c_time in zip(zip_info.date_time, current_time): + self.assertAlmostEqual(z_time, c_time, delta=2) + def test_close(self): """Check that the zipfile is closed after the 'with' block.""" with zipfile.ZipFile(TESTFN2, "w") as zipfp: @@ -2019,6 +2069,25 @@ def test_is_zip_erroneous_file(self): self.assertFalse(zipfile.is_zipfile(fp)) fp.seek(0, 0) self.assertFalse(zipfile.is_zipfile(fp)) + # - passing non-zipfile with ZIP header elements + # data created 
using pyPNG like so: + # d = [(ord('P'), ord('K'), 5, 6), (ord('P'), ord('K'), 6, 6)] + # w = png.Writer(1,2,alpha=True,compression=0) + # f = open('onepix.png', 'wb') + # w.write(f, d) + # w.close() + data = (b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00" + b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I" + b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07" + b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82") + # - passing a filename + with open(TESTFN, "wb") as fp: + fp.write(data) + self.assertFalse(zipfile.is_zipfile(TESTFN)) + # - passing a file-like object + fp = io.BytesIO() + fp.write(data) + self.assertFalse(zipfile.is_zipfile(fp)) def test_damaged_zipfile(self): """Check that zipfiles with missing bytes at the end raise BadZipFile.""" @@ -2048,10 +2117,16 @@ def test_is_zip_valid_file(self): zip_contents = fp.read() # - passing a file-like object fp = io.BytesIO() - fp.write(zip_contents) + end = fp.write(zip_contents) + self.assertEqual(fp.tell(), end) + mid = end // 2 + fp.seek(mid, 0) self.assertTrue(zipfile.is_zipfile(fp)) - fp.seek(0, 0) + # check that the position is left unchanged after the call + # see: https://github.com/python/cpython/issues/122356 + self.assertEqual(fp.tell(), mid) self.assertTrue(zipfile.is_zipfile(fp)) + self.assertEqual(fp.tell(), mid) def test_non_existent_file_raises_OSError(self): # make sure we don't raise an AttributeError when a partially-constructed @@ -2286,6 +2361,34 @@ def test_create_empty_zipinfo_repr(self): zi = zipfile.ZipInfo(filename="empty") self.assertEqual(repr(zi), "") + def test_for_archive(self): + base_filename = TESTFN2.rstrip('/') + + with zipfile.ZipFile(TESTFN, mode="w", compresslevel=1, + compression=zipfile.ZIP_STORED) as zf: + # no trailing forward slash + zi = zipfile.ZipInfo(base_filename)._for_archive(zf) + self.assertEqual(zi.compress_level, 1) + self.assertEqual(zi.compress_type, zipfile.ZIP_STORED) + # ?rw- --- --- + filemode = stat.S_IRUSR | 
stat.S_IWUSR + # filemode is stored as the highest 16 bits of external_attr + self.assertEqual(zi.external_attr >> 16, filemode) + self.assertEqual(zi.external_attr & 0xFF, 0) # no MS-DOS flag + + with zipfile.ZipFile(TESTFN, mode="w", compresslevel=1, + compression=zipfile.ZIP_STORED) as zf: + # with a trailing slash + zi = zipfile.ZipInfo(f'{base_filename}/')._for_archive(zf) + self.assertEqual(zi.compress_level, 1) + self.assertEqual(zi.compress_type, zipfile.ZIP_STORED) + # d rwx rwx r-x + filemode = stat.S_IFDIR + filemode |= stat.S_IRWXU | stat.S_IRWXG + filemode |= stat.S_IROTH | stat.S_IXOTH + self.assertEqual(zi.external_attr >> 16, filemode) + self.assertEqual(zi.external_attr & 0xFF, 0x10) # MS-DOS flag + def test_create_empty_zipinfo_default_attributes(self): """Ensure all required attributes are set.""" zi = zipfile.ZipInfo() @@ -2432,6 +2535,10 @@ def test_decompress_without_3rd_party_library(self): @requires_zlib() def test_full_overlap_different_names(self): + # The ZIP file contains two central directory entries with + # different names which refer to the same local header. + # The name of the local header matches the name of the first + # central directory entry. data = ( b'PK\x03\x04\x14\x00\x00\x00\x08\x00\xa0lH\x05\xe2\x1e' b'8\xbb\x10\x00\x00\x00\t\x04\x00\x00\x01\x00\x00\x00b\xed' @@ -2461,6 +2568,10 @@ def test_full_overlap_different_names(self): @requires_zlib() def test_full_overlap_different_names2(self): + # The ZIP file contains two central directory entries with + # different names which refer to the same local header. + # The name of the local header matches the name of the second + # central directory entry. 
data = ( b'PK\x03\x04\x14\x00\x00\x00\x08\x00\xa0lH\x05\xe2\x1e' b'8\xbb\x10\x00\x00\x00\t\x04\x00\x00\x01\x00\x00\x00a\xed' @@ -2492,6 +2603,8 @@ def test_full_overlap_different_names2(self): @requires_zlib() def test_full_overlap_same_name(self): + # The ZIP file contains two central directory entries with + # the same name which refer to the same local header. data = ( b'PK\x03\x04\x14\x00\x00\x00\x08\x00\xa0lH\x05\xe2\x1e' b'8\xbb\x10\x00\x00\x00\t\x04\x00\x00\x01\x00\x00\x00a\xed' @@ -2524,6 +2637,8 @@ def test_full_overlap_same_name(self): @requires_zlib() def test_quoted_overlap(self): + # The ZIP file contains two files. The second local header + # is contained in the range of the first file. data = ( b'PK\x03\x04\x14\x00\x00\x00\x08\x00\xa0lH\x05Y\xfc' b'8\x044\x00\x00\x00(\x04\x00\x00\x01\x00\x00\x00a\x00' @@ -2555,6 +2670,7 @@ def test_quoted_overlap(self): @requires_zlib() def test_overlap_with_central_dir(self): + # The local header offset is equal to the central directory offset. data = ( b'PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00G_|Z' b'\xe2\x1e8\xbb\x0b\x00\x00\x00\t\x04\x00\x00\x01\x00\x00\x00' @@ -2569,11 +2685,15 @@ def test_overlap_with_central_dir(self): self.assertEqual(zi.header_offset, 0) self.assertEqual(zi.compress_size, 11) self.assertEqual(zi.file_size, 1033) + # Found central directory signature PK\x01\x02 instead of + # local header signature PK\x03\x04. with self.assertRaisesRegex(zipfile.BadZipFile, 'Bad magic number'): zipf.read('a') @requires_zlib() def test_overlap_with_archive_comment(self): + # The local header is written after the central directory, + # in the archive comment. 
data = ( b'PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00G_|Z' b'\xe2\x1e8\xbb\x0b\x00\x00\x00\t\x04\x00\x00\x01\x00\x00\x00' @@ -2685,6 +2805,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase): b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00' b'\x00>\x00\x00\x00\x00\x00') +@requires_zstd() +class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD + zip_with_bad_crc = ( + b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00' + b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00' + b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00' + b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK' + b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00' + b'\x00\x00') class DecryptionTests(unittest.TestCase): """Check that ZIP decryption works. Since the library does not @@ -2912,6 +3043,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD # Provide the tell() method but not seek() class Tellable: @@ -3160,7 +3295,7 @@ def test_write_dir(self): with zipfile.ZipFile(TESTFN, "w") as zipf: zipf.write(dirpath) zinfo = zipf.filelist[0] - self.assertTrue(zinfo.filename.endswith("/x/")) + self.assertEndsWith(zinfo.filename, "/x/") self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10) zipf.write(dirpath, "y") zinfo = zipf.filelist[1] @@ -3168,7 +3303,7 @@ def test_write_dir(self): self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10) with zipfile.ZipFile(TESTFN, "r") as zipf: zinfo = zipf.filelist[0] - self.assertTrue(zinfo.filename.endswith("/x/")) + self.assertEndsWith(zinfo.filename, "/x/") self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10) zinfo = 
zipf.filelist[1] self.assertTrue(zinfo.filename, "y/") @@ -3188,7 +3323,7 @@ def test_writestr_dir(self): self.assertEqual(zinfo.external_attr, (0o40775 << 16) | 0x10) with zipfile.ZipFile(TESTFN, "r") as zipf: zinfo = zipf.filelist[0] - self.assertTrue(zinfo.filename.endswith("x/")) + self.assertEndsWith(zinfo.filename, "x/") self.assertEqual(zinfo.external_attr, (0o40775 << 16) | 0x10) target = os.path.join(TESTFN2, "target") os.mkdir(target) @@ -3415,7 +3550,7 @@ def test_read_zip_with_exe_prepended(self): def test_read_zip64_with_exe_prepended(self): self._test_zip_works(self.exe_zip64) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(sys.executable, 'sys.executable required.') @unittest.skipUnless(os.access('/bin/bash', os.X_OK), 'Test relies on #!/bin/bash working.') @@ -3424,7 +3559,7 @@ def test_execute_zip2(self): output = subprocess.check_output([self.exe_zip, sys.executable]) self.assertIn(b'number in executable: 5', output) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(sys.executable, 'sys.executable required.') @unittest.skipUnless(os.access('/bin/bash', os.X_OK), 'Test relies on #!/bin/bash working.') @@ -3433,17 +3568,15 @@ def test_execute_zip64(self): output = subprocess.check_output([self.exe_zip64, sys.executable]) self.assertIn(b'number in executable: 5', output) - -@unittest.skip("TODO: RUSTPYTHON shift_jis encoding unsupported") +@unittest.skip("TODO: RUSTPYTHON; LookupError: unknown encoding: shift_jis") class EncodedMetadataTests(unittest.TestCase): file_names = ['\u4e00', '\u4e8c', '\u4e09'] # Han 'one', 'two', 'three' file_content = [ "This is pure ASCII.\n".encode('ascii'), # This is modern Japanese. (UTF-8) "\u3053\u308c\u306f\u73fe\u4ee3\u7684\u65e5\u672c\u8a9e\u3067\u3059\u3002\n".encode('utf-8'), - # TODO RUSTPYTHON - # Uncomment when Shift JIS is supported # This is obsolete Japanese. 
(Shift JIS) + # TODO: RUSTPYTHON; LookupError: unknown encoding: shift_jis # "\u3053\u308c\u306f\u53e4\u3044\u65e5\u672c\u8a9e\u3067\u3059\u3002\n".encode('shift_jis'), ] @@ -3484,13 +3617,11 @@ def _test_read(self, zipfp, expected_names, expected_content): self.assertEqual(info.file_size, len(content)) self.assertEqual(zipfp.read(name), content) - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_read_with_metadata_encoding(self): # Read the ZIP archive with correct metadata_encoding with zipfile.ZipFile(TESTFN, "r", metadata_encoding='shift_jis') as zipfp: self._test_read(zipfp, self.file_names, self.file_content) - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_read_without_metadata_encoding(self): # Read the ZIP archive without metadata_encoding expected_names = [name.encode('shift_jis').decode('cp437') @@ -3498,7 +3629,6 @@ def test_read_without_metadata_encoding(self): with zipfile.ZipFile(TESTFN, "r") as zipfp: self._test_read(zipfp, expected_names, self.file_content) - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_read_with_incorrect_metadata_encoding(self): # Read the ZIP archive with incorrect metadata_encoding expected_names = [name.encode('shift_jis').decode('koi8-u') @@ -3506,7 +3636,6 @@ def test_read_with_incorrect_metadata_encoding(self): with zipfile.ZipFile(TESTFN, "r", metadata_encoding='koi8-u') as zipfp: self._test_read(zipfp, expected_names, self.file_content) - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_read_with_unsuitable_metadata_encoding(self): # Read the ZIP archive with metadata_encoding unsuitable for # decoding metadata @@ -3515,7 +3644,6 @@ def test_read_with_unsuitable_metadata_encoding(self): with self.assertRaises(UnicodeDecodeError): zipfile.ZipFile(TESTFN, "r", metadata_encoding='utf-8') - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_read_after_append(self): newname = '\u56db' # Han 'four' expected_names = [name.encode('shift_jis').decode('cp437') @@ -3542,7 +3670,6 @@ def 
test_read_after_append(self): else: self.assertEqual(zipfp.read(name), content) - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_write_with_metadata_encoding(self): ZF = zipfile.ZipFile for mode in ("w", "x", "a"): @@ -3550,7 +3677,6 @@ def test_write_with_metadata_encoding(self): "^metadata_encoding is only"): ZF("nonesuch.zip", mode, metadata_encoding="shift_jis") - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_cli_with_metadata_encoding(self): errmsg = "Non-conforming encodings not supported with -c." args = ["--metadata-encoding=shift_jis", "-c", "nonesuch", "nonesuch"] @@ -3570,7 +3696,6 @@ def test_cli_with_metadata_encoding(self): for name in self.file_names: self.assertIn(name, listing) - @unittest.expectedFailure # TODO: RUSTPYTHON; def test_cli_with_metadata_encoding_extract(self): os.mkdir(TESTFN2) self.addCleanup(rmtree, TESTFN2) diff --git a/Lib/test/test_zoneinfo/test_zoneinfo.py b/Lib/test/test_zoneinfo/test_zoneinfo.py index 85877c984c8..e9516c0b127 100644 --- a/Lib/test/test_zoneinfo/test_zoneinfo.py +++ b/Lib/test/test_zoneinfo/test_zoneinfo.py @@ -741,6 +741,38 @@ def test_empty_zone(self): with self.assertRaises(ValueError): self.klass.from_file(zf) + def test_invalid_transition_index(self): + STD = ZoneOffset("STD", ZERO) + DST = ZoneOffset("DST", ONE_H, ONE_H) + + zf = self.construct_zone([ + ZoneTransition(datetime(2026, 3, 1, 2), STD, DST), + ZoneTransition(datetime(2026, 11, 1, 2), DST, STD), + ], after="", version=1) + + data = bytearray(zf.read()) + timecnt = struct.unpack_from(">l", data, 32)[0] + idx_offset = 44 + timecnt * 4 + data[idx_offset + 1] = 2 # typecnt is 2, so index 2 is OOB + f = io.BytesIO(bytes(data)) + + with self.assertRaises(ValueError): + self.klass.from_file(f) + + def test_transition_lookahead_out_of_bounds(self): + STD = ZoneOffset("STD", ZERO) + DST = ZoneOffset("DST", ONE_H, ONE_H) + EXT = ZoneOffset("EXT", ONE_H) + + zf = self.construct_zone([ + ZoneTransition(datetime(2026, 3, 1), STD, DST), + 
ZoneTransition(datetime(2026, 6, 1), DST, EXT), + ZoneTransition(datetime(2026, 9, 1), EXT, DST), + ], after="") + + zi = self.klass.from_file(zf) + self.assertIsNotNone(zi) + def test_zone_very_large_timestamp(self): """Test when a transition is in the far past or future. @@ -1577,6 +1609,59 @@ class EvilZoneInfo(self.klass): class CZoneInfoCacheTest(ZoneInfoCacheTest): module = c_zoneinfo + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: RuntimeError not raised + def test_inconsistent_weak_cache_get(self): + class Cache: + def get(self, key, default=None): + return 1337 + + class ZI(self.klass): + pass + # Class attribute must be set after class creation + # to override zoneinfo.ZoneInfo.__init_subclass__. + ZI._weak_cache = Cache() + + with self.assertRaises(RuntimeError) as te: + ZI("America/Los_Angeles") + self.assertEqual( + str(te.exception), + "Unexpected instance of int in ZI weak cache for key 'America/Los_Angeles'" + ) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: AttributeError not raised + def test_deleted_weak_cache(self): + class ZI(self.klass): + pass + delattr(ZI, '_weak_cache') + + # These should not segfault + with self.assertRaises(AttributeError): + ZI("UTC") + + with self.assertRaises(AttributeError): + ZI.clear_cache() + + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'int' object has no attribute '_from_cache' + def test_inconsistent_weak_cache_setdefault(self): + class Cache: + def get(self, key, default=None): + return default + def setdefault(self, key, value): + return 1337 + + class ZI(self.klass): + pass + # Class attribute must be set after class creation + # to override zoneinfo.ZoneInfo.__init_subclass__. 
+ ZI._weak_cache = Cache() + + with self.assertRaises(RuntimeError) as te: + ZI("America/Los_Angeles") + self.assertEqual( + str(te.exception), + "Unexpected instance of int in ZI weak cache for key 'America/Los_Angeles'" + ) + class ZoneInfoPickleTest(TzPathUserMixin, ZoneInfoTestBase): module = py_zoneinfo diff --git a/Lib/test/xpickle_worker.py b/Lib/test/xpickle_worker.py new file mode 100644 index 00000000000..1b49515123c --- /dev/null +++ b/Lib/test/xpickle_worker.py @@ -0,0 +1,62 @@ +# This script is called by test_xpickle as a subprocess to load and dump +# pickles in a different Python version. +import os +import pickle +import struct +import sys + + +# This allows the xpickle worker to import picklecommon.py, which it needs +# since some of the pickle objects hold references to picklecommon.py. +test_mod_path = os.path.abspath(os.path.join(os.path.dirname(__file__), + 'picklecommon.py')) +if sys.version_info >= (3, 5): + import importlib.util + spec = importlib.util.spec_from_file_location('test.picklecommon', test_mod_path) + sys.modules['test'] = type(sys)('test') + test_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(test_module) + sys.modules['test.picklecommon'] = test_module +else: + test_module = type(sys)('test.picklecommon') + sys.modules['test.picklecommon'] = test_module + sys.modules['test'] = type(sys)('test') + with open(test_mod_path, 'rb') as f: + sources = f.read() + exec(sources, vars(test_module)) + +def read_exact(f, n): + buf = b'' + while len(buf) < n: + chunk = f.read(n - len(buf)) + if not chunk: + raise EOFError + buf += chunk + return buf + +in_stream = getattr(sys.stdin, 'buffer', sys.stdin) +out_stream = getattr(sys.stdout, 'buffer', sys.stdout) + +try: + while True: + size, = struct.unpack('!i', read_exact(in_stream, 4)) + if not size: + break + data = read_exact(in_stream, size) + protocol, = struct.unpack('!i', data[:4]) + obj = pickle.loads(data[4:]) + data = pickle.dumps(obj, protocol) + 
out_stream.write(struct.pack('!i', len(data)) + data) + out_stream.flush() +except Exception as exc: + # dump the exception to stdout and write to stderr, then exit + try: + data = pickle.dumps(exc) + out_stream.write(struct.pack('!i', -len(data)) + data) + out_stream.flush() + except Exception: + out_stream.write(struct.pack('!i', 0)) + out_stream.flush() + sys.stderr.write(repr(exc)) + sys.stderr.flush() + sys.exit(1) diff --git a/Lib/venv/__init__.py b/Lib/venv/__init__.py index f7a6d261401..88f3340af41 100644 --- a/Lib/venv/__init__.py +++ b/Lib/venv/__init__.py @@ -103,8 +103,6 @@ def _venv_path(self, env_dir, name): vars = { 'base': env_dir, 'platbase': env_dir, - 'installed_base': env_dir, - 'installed_platbase': env_dir, } return sysconfig.get_path(name, scheme='venv', vars=vars) @@ -175,9 +173,20 @@ def create_if_needed(d): context.python_dir = dirname context.python_exe = exename binpath = self._venv_path(env_dir, 'scripts') - incpath = self._venv_path(env_dir, 'include') libpath = self._venv_path(env_dir, 'purelib') + # PEP 405 says venvs should create a local include directory. + # See https://peps.python.org/pep-0405/#include-files + # XXX: This directory is not exposed in sysconfig or anywhere else, and + # doesn't seem to be utilized by modern packaging tools. We keep it + # for backwards-compatibility, and to follow the PEP, but I would + # recommend against using it, as most tooling does not pass it to + # compilers. Instead, until we standardize a site-specific include + # directory, I would recommend installing headers as package data, + # and providing some sort of API to get the include directories. 
+ # Example: https://numpy.org/doc/2.1/reference/generated/numpy.get_include.html + incpath = os.path.join(env_dir, 'Include' if os.name == 'nt' else 'include') + context.inc_path = incpath create_if_needed(incpath) context.lib_path = libpath @@ -304,8 +313,11 @@ def setup_python(self, context): copier(context.executable, path) if not os.path.islink(path): os.chmod(path, 0o755) - for suffix in ('python', 'python3', - f'python3.{sys.version_info[1]}'): + + suffixes = ['python', 'python3', f'python3.{sys.version_info[1]}'] + if sys.version_info[:2] == (3, 14) and sys.getfilesystemencoding() == 'utf-8': + suffixes.append('𝜋thon') + for suffix in suffixes: path = os.path.join(binpath, suffix) if not os.path.exists(path): # Issue 18807: make copies if @@ -576,7 +588,7 @@ def skip_file(f): 'may be binary: %s', srcfile, e) continue if new_data == data: - shutil.copy2(srcfile, dstfile) + shutil.copy(srcfile, dstfile) else: with open(dstfile, 'wb') as f: f.write(new_data) @@ -604,8 +616,7 @@ def create(env_dir, system_site_packages=False, clear=False, def main(args=None): import argparse - parser = argparse.ArgumentParser(prog=__name__, - description='Creates virtual Python ' + parser = argparse.ArgumentParser(description='Creates virtual Python ' 'environments in one or ' 'more target ' 'directories.', @@ -613,7 +624,9 @@ def main(args=None): 'created, you may wish to ' 'activate it, e.g. 
by ' 'sourcing an activate script ' - 'in its bin directory.') + 'in its bin directory.', + color=True, + ) parser.add_argument('dirs', metavar='ENV_DIR', nargs='+', help='A directory to create the environment in.') parser.add_argument('--system-site-packages', default=False, diff --git a/Lib/venv/scripts/nt/venvlauncher.exe b/Lib/venv/scripts/nt/venvlauncher.exe deleted file mode 100644 index c6863b56e57..00000000000 Binary files a/Lib/venv/scripts/nt/venvlauncher.exe and /dev/null differ diff --git a/Lib/venv/scripts/nt/venvlaunchert.exe b/Lib/venv/scripts/nt/venvlaunchert.exe deleted file mode 100644 index c12a7a869f4..00000000000 Binary files a/Lib/venv/scripts/nt/venvlaunchert.exe and /dev/null differ diff --git a/Lib/venv/scripts/nt/venvwlauncher.exe b/Lib/venv/scripts/nt/venvwlauncher.exe deleted file mode 100644 index d0d3733266f..00000000000 Binary files a/Lib/venv/scripts/nt/venvwlauncher.exe and /dev/null differ diff --git a/Lib/venv/scripts/nt/venvwlaunchert.exe b/Lib/venv/scripts/nt/venvwlaunchert.exe deleted file mode 100644 index 9456a9e9b4a..00000000000 Binary files a/Lib/venv/scripts/nt/venvwlaunchert.exe and /dev/null differ diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py index 19aea290b58..c8d8d3a2c0a 100644 --- a/Lib/zipfile/__init__.py +++ b/Lib/zipfile/__init__.py @@ -1410,6 +1410,7 @@ class ZipFile: fp = None # Set here since __del__ checks it _windows_illegal_name_trans_table = None + _ignore_invalid_names = False def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=True, compresslevel=None, *, strict_timestamps=True, metadata_encoding=None): @@ -1890,21 +1891,31 @@ def _extract_member(self, member, targetpath, pwd): # build the destination pathname, replacing # forward slashes to platform specific separators. 
- arcname = member.filename.replace('/', os.path.sep) - - if os.path.altsep: + arcname = member.filename + if os.path.sep != '/': + arcname = arcname.replace('/', os.path.sep) + if os.path.altsep and os.path.altsep != '/': arcname = arcname.replace(os.path.altsep, os.path.sep) # interpret absolute pathname as relative, remove drive letter or # UNC path, redundant separators, "." and ".." components. - arcname = os.path.splitdrive(arcname)[1] + drive, root, arcname = os.path.splitroot(arcname) + if self._ignore_invalid_names and (drive or root): + return None + if self._ignore_invalid_names and os.path.pardir in arcname.split(os.path.sep): + return None invalid_path_parts = ('', os.path.curdir, os.path.pardir) arcname = os.path.sep.join(x for x in arcname.split(os.path.sep) if x not in invalid_path_parts) if os.path.sep == '\\': # filter illegal characters on Windows - arcname = self._sanitize_windows_name(arcname, os.path.sep) + arcname2 = self._sanitize_windows_name(arcname, os.path.sep) + if self._ignore_invalid_names and arcname2 != arcname: + return None + arcname = arcname2 if not arcname and not member.is_dir(): + if self._ignore_invalid_names: + return None raise ValueError("Empty filename.") targetpath = os.path.join(targetpath, arcname) diff --git a/Lib/zoneinfo/_common.py b/Lib/zoneinfo/_common.py index 59f3f0ce853..98668c15d8b 100644 --- a/Lib/zoneinfo/_common.py +++ b/Lib/zoneinfo/_common.py @@ -67,6 +67,10 @@ def load_data(fobj): f">{timecnt}{time_type}", fobj.read(timecnt * time_size) ) trans_idx = struct.unpack(f">{timecnt}B", fobj.read(timecnt)) + + if max(trans_idx) >= typecnt: + raise ValueError("Invalid transition index found while reading TZif: " + f"{max(trans_idx)}") else: trans_list_utc = () trans_idx = () diff --git a/Lib/zoneinfo/_zoneinfo.py b/Lib/zoneinfo/_zoneinfo.py index 3ffdb4c8371..7063eb6a902 100644 --- a/Lib/zoneinfo/_zoneinfo.py +++ b/Lib/zoneinfo/_zoneinfo.py @@ -47,7 +47,11 @@ def __new__(cls, key): cls._strong_cache[key] = 
cls._strong_cache.pop(key, instance) if len(cls._strong_cache) > cls._strong_cache_size: - cls._strong_cache.popitem(last=False) + try: + cls._strong_cache.popitem(last=False) + except KeyError: + # another thread may have already emptied the cache + pass return instance @@ -334,7 +338,7 @@ def _utcoff_to_dstoff(trans_idx, utcoffsets, isdsts): if not isdsts[comp_idx]: dstoff = utcoff - utcoffsets[comp_idx] - if not dstoff and idx < (typecnt - 1): + if not dstoff and idx < (typecnt - 1) and i + 1 < len(trans_idx): comp_idx = trans_idx[i + 1] # If the following transition is also DST and we couldn't diff --git a/README.md b/README.md index 6949c6e66e2..37b577d0f0e 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ $ python # now `python` is the alias of the RustPython for the new env If you'd like to make https requests, you can enable the `ssl` feature, which also lets you install the `pip` package manager. Note that on Windows, you may -need to install OpenSSL, or you can enable the `ssl-vendor` feature instead, +need to install OpenSSL, or you can enable the `ssl-openssl-vendor` feature instead, which compiles OpenSSL for you but requires a C compiler, perl, and `make`. OpenSSL version 3 is expected and tested in CI. Older versions may not work. @@ -102,8 +102,8 @@ rustpython ### SSL provider -For HTTPS requests, `ssl-rustls` feature is enabled by default. You can replace it with `ssl-openssl` feature if your environment requires OpenSSL. -Note that to use OpenSSL on Windows, you may need to install OpenSSL, or you can enable the `ssl-vendor` feature instead, +For HTTPS requests, `ssl-rustls-aws-lc` is enabled by default for the RustPython binary. Embedders can use `rustpython-stdlib`'s provider-agnostic `ssl-rustls` feature and install their own rustls crypto provider, or replace rustls with `ssl-openssl` if their environment requires OpenSSL. 
+Note that to use OpenSSL on Windows, you may need to install OpenSSL, or you can enable the `ssl-openssl-vendor` feature instead, which compiles OpenSSL for you but requires a C compiler, perl, and `make`. OpenSSL version 3 is expected and tested in CI. Older versions may not work. @@ -229,24 +229,10 @@ For a high level overview of the components, see the [architecture](architecture ## Contributing -Contributions are more than welcome, and in many cases we are happy to guide -contributors through PRs or on Discord. Please refer to the -[development guide](DEVELOPMENT.md) as well for tips on developments. +Contributions are welcome and highly appreciated. To get started, check out the +[**contributing guidelines**](CONTRIBUTING.md). -With that in mind, please note this project is maintained by volunteers, some of -the best ways to get started are below: - -Most tasks are listed in the -[issue tracker](https://github.com/RustPython/RustPython/issues). Check issues -labeled with [good first issue](https://github.com/RustPython/RustPython/issues?q=label%3A%22good+first+issue%22+is%3Aissue+is%3Aopen+) if you wish to start coding. - -To enhance CPython compatibility, try to increase unittest coverage by checking this article: [How to contribute to RustPython by CPython unittest](https://rustpython.github.io/guideline/2020/04/04/how-to-contribute-by-cpython-unittest.html) - -Another approach is to checkout the source code: builtin functions and object -methods are often the simplest and easiest way to contribute. - -You can also simply run `python -I scripts/whats_left.py` to assist in finding any unimplemented -method. +You can also join us on [**Discord**](https://discord.gg/vru8NypEhv). 
## Compiling to WebAssembly diff --git a/crates/capi/Cargo.toml b/crates/capi/Cargo.toml index e9dc8de00d8..685b44d1ee0 100644 --- a/crates/capi/Cargo.toml +++ b/crates/capi/Cargo.toml @@ -12,7 +12,9 @@ license.workspace = true crate-type = ["cdylib", "rlib"] [dependencies] -rustpython-vm = { workspace = true, features = ["threading"] } +bitflags = { workspace = true } +num-complex = { workspace = true } +rustpython-vm = { workspace = true, features = ["threading", "compiler"] } rustpython-stdlib = {workspace = true, features = ["threading"] } [dev-dependencies] diff --git a/crates/capi/src/abstract_.rs b/crates/capi/src/abstract_.rs new file mode 100644 index 00000000000..3c109cd1ed5 --- /dev/null +++ b/crates/capi/src/abstract_.rs @@ -0,0 +1,165 @@ +use crate::{PyObject, pystate::with_vm}; +use alloc::slice; +use core::ffi::c_int; +use rustpython_vm::builtins::{PyDict, PyStr, PyTuple}; +use rustpython_vm::function::{FuncArgs, KwArgs, PosArgs}; +use rustpython_vm::{AsObject, Py, PyObjectRef, PyResult, VirtualMachine}; + +const PY_VECTORCALL_ARGUMENTS_OFFSET: usize = 1usize << (usize::BITS as usize - 1); + +fn tuple_to_args(tuple: &Py) -> PosArgs { + tuple.iter().cloned().collect::>().into() +} + +fn dict_to_kwargs(vm: &VirtualMachine, dict: &Py) -> PyResult { + dict.items_vec() + .into_iter() + .map(|(key, value)| { + let key = key + .downcast_ref::() + .map(|s| s.to_string()) + .ok_or_else(|| vm.new_type_error("keywords must be strings"))?; + Ok((key, value)) + }) + .collect::>() + .map(KwArgs::new) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_Call( + callable: *mut PyObject, + args: *mut PyObject, + kwargs: *mut PyObject, +) -> *mut PyObject { + with_vm(|vm| { + let callable = unsafe { &*callable }; + let args = tuple_to_args(unsafe { &*args }.try_downcast_ref::(vm)?); + + let kwargs: Option = unsafe { kwargs.as_ref() } + .map(|kwargs| dict_to_kwargs(vm, kwargs.try_downcast_ref::(vm)?)) + .transpose()?; + + 
callable.call_with_args(FuncArgs::new(args, kwargs.unwrap_or_default()), vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_CallNoArgs(callable: *mut PyObject) -> *mut PyObject { + with_vm(|vm| unsafe { &*callable }.call((), vm)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_Vectorcall( + callable: *mut PyObject, + args: *const *mut PyObject, + nargsf: usize, + kwnames: *mut PyObject, +) -> *mut PyObject { + with_vm(|vm| { + let num_positional_args = nargsf & !PY_VECTORCALL_ARGUMENTS_OFFSET; + + let kwnames: Option<&[PyObjectRef]> = unsafe { + kwnames + .as_ref() + .map(|tuple| Ok(&***tuple.try_downcast_ref::(vm)?)) + .transpose()? + }; + + let args_len = num_positional_args + kwnames.map_or(0, <[PyObjectRef]>::len); + let args = if args_len == 0 { + Vec::new() + } else { + unsafe { slice::from_raw_parts(args, args_len) } + .iter() + .map(|arg| unsafe { &**arg }.to_owned()) + .collect::>() + }; + + let callable = unsafe { &*callable }; + callable.vectorcall(args, num_positional_args, kwnames, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_VectorcallMethod( + name: *mut PyObject, + args: *const *mut PyObject, + nargsf: usize, + kwnames: *mut PyObject, +) -> *mut PyObject { + with_vm(|vm| { + let args_len = nargsf & !PY_VECTORCALL_ARGUMENTS_OFFSET; + + if args_len == 0 { + return Err( + vm.new_system_error("PyObject_VectorcallMethod called with no receiver".to_owned()) + ); + } + + let (receiver, args) = unsafe { slice::from_raw_parts(args, args_len) } + .split_first() + .expect("args_len > 0 should guarantee a receiver"); + + let method_name = unsafe { (&*name).try_downcast_ref::(vm)? }; + let callable = unsafe { (&**receiver).get_attr(method_name, vm)? 
}; + + Ok(unsafe { + PyObject_Vectorcall( + callable.as_object().as_raw().cast_mut(), + args.as_ptr(), + nargsf - 1, + kwnames, + ) + }) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_GetItem(obj: *mut PyObject, key: *mut PyObject) -> *mut PyObject { + with_vm(|vm| { + let obj = unsafe { &*obj }; + let key = unsafe { &*key }; + obj.get_item(key, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_SetItem( + obj: *mut PyObject, + key: *mut PyObject, + value: *mut PyObject, +) -> c_int { + with_vm(|vm| { + let obj = unsafe { &*obj }; + let key = unsafe { &*key }; + let value = unsafe { &*value }.to_owned(); + obj.set_item(key, value, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_DelItem(obj: *mut PyObject, key: *mut PyObject) -> c_int { + with_vm(|vm| { + let obj = unsafe { &*obj }; + let key = unsafe { &*key }; + obj.del_item(key, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_IsSubclass(derived: *mut PyObject, cls: *mut PyObject) -> c_int { + with_vm(|vm| { + let derived = unsafe { &*derived }; + let cls = unsafe { &*cls }; + derived.is_subclass(cls, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyObject_IsInstance(inst: *mut PyObject, cls: *mut PyObject) -> c_int { + with_vm(|vm| { + let inst = unsafe { &*inst }; + let cls = unsafe { &*cls }; + inst.is_instance(cls, vm) + }) +} diff --git a/crates/capi/src/boolobject.rs b/crates/capi/src/boolobject.rs new file mode 100644 index 00000000000..464d0ec0b8a --- /dev/null +++ b/crates/capi/src/boolobject.rs @@ -0,0 +1,28 @@ +use crate::object::define_py_check; +use crate::{PyObject, pystate::with_vm}; +use core::ffi::{c_int, c_long}; +use rustpython_vm::AsObject; + +define_py_check!(fn PyBool_Check, types.bool_type); + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn Py_IsTrue(obj: *mut PyObject) -> c_int { + with_vm(|vm| unsafe { obj.as_ref().is_some_and(|obj| obj.is(&vm.ctx.true_value)) }) +} + +#[unsafe(no_mangle)] 
+pub unsafe extern "C" fn Py_IsFalse(obj: *mut PyObject) -> c_int { + with_vm(|vm| unsafe { obj.as_ref().is_some_and(|obj| obj.is(&vm.ctx.false_value)) }) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyBool_FromLong(value: c_long) -> *mut PyObject { + with_vm(|vm| { + if value == 0 { + &vm.ctx.false_value + } else { + &vm.ctx.true_value + } + .to_owned() + }) +} diff --git a/crates/capi/src/bytesobject.rs b/crates/capi/src/bytesobject.rs new file mode 100644 index 00000000000..1fe535efba5 --- /dev/null +++ b/crates/capi/src/bytesobject.rs @@ -0,0 +1,73 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::c_char; +use rustpython_vm::builtins::PyBytes; + +define_py_check!(fn PyBytes_Check, types.bytes_type); +define_py_check!(exact fn PyBytes_CheckExact, types.bytes_type); + +#[unsafe(no_mangle)] +#[allow(clippy::uninit_vec)] +pub unsafe extern "C" fn PyBytes_FromStringAndSize( + bytes: *mut c_char, + len: isize, +) -> *mut PyObject { + with_vm(|vm| { + let len = len.try_into().map_err(|_| { + vm.new_system_error("Negative size passed to PyBytes_FromStringAndSize") + })?; + + let data = if bytes.is_null() { + let mut data = Vec::with_capacity(len); + unsafe { data.set_len(len) }; + data + } else { + unsafe { core::slice::from_raw_parts(bytes as *const u8, len) }.to_vec() + }; + + Ok(vm.ctx.new_bytes(data)) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyBytes_Size(bytes: *mut PyObject) -> isize { + with_vm(|vm| { + let bytes = unsafe { &*bytes }.try_downcast_ref::(vm)?; + Ok(bytes.as_bytes().len()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyBytes_AsString(bytes: *mut PyObject) -> *mut c_char { + with_vm(|vm| { + let bytes = unsafe { &*bytes }.try_downcast_ref::(vm)?; + Ok(bytes.as_bytes().as_ptr()) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyBytes; + + #[test] + fn test_bytes() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"Hello, 
World!"); + assert_eq!(bytes.as_bytes(), b"Hello, World!"); + }) + } + + #[test] + fn test_bytes_uninit() { + Python::attach(|py| { + let bytes = PyBytes::new_with(py, 13, |data| { + data.copy_from_slice(b"Hello, World!"); + Ok(()) + }) + .unwrap(); + assert_eq!(bytes.as_bytes(), b"Hello, World!"); + }) + } +} diff --git a/crates/capi/src/ceval.rs b/crates/capi/src/ceval.rs new file mode 100644 index 00000000000..be8d9c48b61 --- /dev/null +++ b/crates/capi/src/ceval.rs @@ -0,0 +1,95 @@ +use crate::pystate::with_vm; +use core::ffi::{CStr, c_char, c_int}; +use core::ptr::NonNull; +use rustpython_vm::builtins::{PyCode, PyDict}; +use rustpython_vm::compiler::Mode; +use rustpython_vm::function::ArgMapping; +use rustpython_vm::scope::Scope; +use rustpython_vm::{AsObject, PyObject, TryFromObject}; + +const PY_SINGLE_INPUT: c_int = 256; +const PY_FILE_INPUT: c_int = 257; +const PY_EVAL_INPUT: c_int = 258; +const PY_FUNC_TYPE_INPUT: c_int = 345; + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn Py_CompileString( + code: *const c_char, + filename: *const c_char, + start: c_int, +) -> *mut PyObject { + with_vm(|vm| { + let code = unsafe { CStr::from_ptr(code) }.to_str().map_err(|_| { + vm.new_system_error("Py_CompileString called with non UTF-8 code string") + })?; + let filename = unsafe { CStr::from_ptr(filename) } + .to_str() + .map_err(|_| vm.new_system_error("Py_CompileString called with non UTF-8 filename"))?; + + let mode = match start { + PY_SINGLE_INPUT => Mode::Single, + PY_FILE_INPUT => Mode::Exec, + PY_EVAL_INPUT => Mode::Eval, + PY_FUNC_TYPE_INPUT => Mode::BlockExpr, + _ => { + return Err( + vm.new_system_error("Invalid start argument passed to Py_CompileString") + ); + } + }; + + vm.compile(code, mode, filename.to_owned()) + .map_err(|err| vm.new_syntax_error(&err, Some(code))) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyEval_EvalCode( + co: *mut PyObject, + globals: *mut PyObject, + locals: *mut PyObject, +) -> *mut PyObject { + with_vm(|vm| { + 
let code = unsafe { &*co }.try_downcast_ref::(vm)?; + let globals = unsafe { &*globals }.try_downcast_ref::(vm)?; + let locals = NonNull::new(locals) + .map(|ptr| ArgMapping::try_from_object(vm, unsafe { ptr.as_ref() }.to_owned())) + .transpose()?; + + let scope = Scope::with_builtins(locals, globals.to_owned(), vm); + + vm.run_code_obj(code.to_owned(), scope) + }) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyEval_GetBuiltins() -> *mut PyObject { + with_vm(|vm| { + vm.current_frame().map_or_else( + || vm.builtins.as_object().as_raw(), + |frame| frame.builtins.as_object().as_raw(), + ) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::exceptions::PyException; + use pyo3::prelude::*; + + #[test] + fn test_code_eval() { + Python::attach(|py| { + let result = py.eval(c"1 + 1", None, None).unwrap(); + assert_eq!(result.extract::().unwrap(), 2); + }) + } + + #[test] + fn test_code_run_exception() { + Python::attach(|py| { + let err = py.run(c"raise Exception()", None, None).unwrap_err(); + assert!(err.is_instance_of::(py)); + }) + } +} diff --git a/crates/capi/src/complexobject.rs b/crates/capi/src/complexobject.rs new file mode 100644 index 00000000000..a6b2bb731a0 --- /dev/null +++ b/crates/capi/src/complexobject.rs @@ -0,0 +1,52 @@ +use crate::object::define_py_check; +use crate::{PyObject, pystate::with_vm}; +use core::ffi::c_double; +use num_complex::{Complex, Complex64}; +use rustpython_vm::builtins::PyComplex; +use rustpython_vm::{PyResult, VirtualMachine}; + +define_py_check!(fn PyComplex_Check, types.complex_type); +define_py_check!(exact fn PyComplex_CheckExact, types.complex_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyComplex_FromDoubles(real: c_double, imag: c_double) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_complex(Complex::new(real, imag))) +} + +fn try_to_complex(vm: &VirtualMachine, obj: &PyObject) -> PyResult { + obj.try_downcast_ref::(vm).map_or_else( + |type_err| { + if let Some((complex, _)) = obj.to_owned().try_complex(vm)? 
{ + Ok(complex) + } else { + Err(type_err) + } + }, + |complex| Ok(complex.to_complex()), + ) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyComplex_RealAsDouble(obj: *mut PyObject) -> c_double { + with_vm(|vm| try_to_complex(vm, unsafe { &*obj }).map(|complex| complex.re)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyComplex_ImagAsDouble(obj: *mut PyObject) -> c_double { + with_vm(|vm| try_to_complex(vm, unsafe { &*obj }).map(|complex| complex.im)) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyComplex; + + #[test] + fn test_py_int() { + Python::attach(|py| { + let number = PyComplex::from_doubles(py, 1.0, 2.0); + assert_eq!(number.real(), 1.0); + assert_eq!(number.imag(), 2.0); + }) + } +} diff --git a/crates/capi/src/dictobject.rs b/crates/capi/src/dictobject.rs new file mode 100644 index 00000000000..dc5dd58485e --- /dev/null +++ b/crates/capi/src/dictobject.rs @@ -0,0 +1,132 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::c_int; +use core::ptr::NonNull; +use rustpython_vm::AsObject; +use rustpython_vm::builtins::PyDict; + +define_py_check!(fn PyDict_Check, types.dict_type); +define_py_check!(exact fn PyDict_CheckExact, types.dict_type); +define_py_check!(fn PyDictKeys_Check, types.dict_keys_type); +define_py_check!(fn PyDictValues_Check, types.dict_values_type); +define_py_check!(fn PyDictItems_Check, types.dict_items_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyDict_New() -> *mut PyObject { + with_vm(|vm| vm.ctx.new_dict()) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyDict_SetItem( + dict: *mut PyObject, + key: *mut PyObject, + val: *mut PyObject, +) -> c_int { + with_vm(|vm| { + let dict = unsafe { &*dict }.try_downcast_ref::(vm)?; + let key = unsafe { &*key }; + let value = unsafe { &*val }.to_owned(); + dict.inner_setitem(key, value, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyDict_GetItemRef( + dict: *mut 
PyObject, + key: *mut PyObject, + result: *mut *mut PyObject, +) -> c_int { + with_vm(|vm| { + unsafe { *result = core::ptr::null_mut() }; + let dict = unsafe { &*dict }.try_downcast_ref::(vm)?; + let key = unsafe { &*key }; + + if let Some(value) = dict.inner_getitem_opt(key, vm)? { + unsafe { + *result = value.into_raw().as_ptr(); + } + Ok(true) + } else { + unsafe { + *result = core::ptr::null_mut(); + } + Ok(false) + } + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyDict_Size(dict: *mut PyObject) -> isize { + with_vm(|vm| { + let dict = unsafe { &*dict }.try_downcast_ref::(vm)?; + Ok(dict.__len__()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyDict_Next( + dict: *mut PyObject, + pos: *mut isize, + key: *mut *mut PyObject, + value: *mut *mut PyObject, +) -> c_int { + with_vm(|vm| { + let dict = unsafe { &*dict }.try_downcast_ref::(vm)?; + let index = unsafe { *pos } as usize; + + if let Some((next_pos, k, v)) = dict.next_entry(index) { + unsafe { + *pos = next_pos as isize; + if let Some(key) = NonNull::new(key) { + key.write(k.as_object().as_raw().cast_mut()); + } + if let Some(value) = NonNull::new(value) { + value.write(v.as_object().as_raw().cast_mut()); + } + } + Ok(true) + } else { + Ok(false) + } + }) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::{IntoPyDict, PyDict, PyInt}; + + #[test] + fn test_create_empty_dict() { + Python::attach(|py| { + let dict = PyDict::new(py); + assert!(dict.is_instance_of::()); + }) + } + + #[test] + fn test_create_dict_with_items() { + Python::attach(|py| { + let dict = [(1, 2), (3, 4)].into_py_dict(py)?; + let value = dict.get_item(1)?.unwrap().cast_into::()?; + assert_eq!(value, 2); + assert_eq!(dict.len(), 2); + + Ok::<_, PyErr>(()) + }) + .unwrap() + } + + #[test] + fn test_dict_iter() { + Python::attach(|py| { + let dict = [(1, 2), (3, 4)].into_py_dict(py).unwrap(); + let values = dict + .into_iter() + .flat_map(|(k, v)| [k.extract().unwrap(), 
v.extract().unwrap()]) + .collect::>(); + assert_eq!(values, vec![1, 2, 3, 4]); + }) + } +} diff --git a/crates/capi/src/floatobject.rs b/crates/capi/src/floatobject.rs new file mode 100644 index 00000000000..f1bb078106d --- /dev/null +++ b/crates/capi/src/floatobject.rs @@ -0,0 +1,41 @@ +use crate::object::define_py_check; +use crate::{PyObject, pystate::with_vm}; +use core::ffi::c_double; +use rustpython_vm::builtins::PyFloat; + +define_py_check!(fn PyFloat_Check, types.float_type); +define_py_check!(exact fn PyFloat_CheckExact, types.float_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyFloat_FromDouble(value: c_double) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_float(value)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyFloat_AsDouble(obj: *mut PyObject) -> c_double { + with_vm(|vm| { + let obj_ref = unsafe { &*obj }; + let float_obj = obj_ref + .to_owned() + .try_downcast::(vm) + .or_else(|_| obj_ref.try_float(vm))?; + + Ok(float_obj.to_f64()) + }) +} + +#[cfg(false)] +mod tests { + use core::f64::consts::PI; + use pyo3::prelude::*; + use pyo3::types::PyFloat; + + #[test] + fn test_py_float() { + Python::attach(|py| { + let pi = PyFloat::new(py, PI); + assert!(pi.is_instance_of::()); + assert_eq!(pi.extract::().unwrap(), PI); + }) + } +} diff --git a/crates/capi/src/import.rs b/crates/capi/src/import.rs new file mode 100644 index 00000000000..d380d1f8266 --- /dev/null +++ b/crates/capi/src/import.rs @@ -0,0 +1,22 @@ +use crate::{PyObject, pystate::with_vm}; +use rustpython_vm::builtins::PyStr; + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyImport_Import(name: *mut PyObject) -> *mut PyObject { + with_vm(|vm| { + let name = unsafe { (&*name).try_downcast_ref::(vm)? 
}; + vm.import(name, 0) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + + #[test] + fn test_import() { + Python::attach(|py| { + let _module = py.import("sys").unwrap(); + }) + } +} diff --git a/crates/capi/src/lib.rs b/crates/capi/src/lib.rs index 4dc17536a8a..d7f9f8b7464 100644 --- a/crates/capi/src/lib.rs +++ b/crates/capi/src/lib.rs @@ -8,11 +8,26 @@ use std::sync::MutexGuard; extern crate alloc; +pub mod abstract_; +pub mod boolobject; +pub mod bytesobject; +pub mod ceval; +pub mod complexobject; +pub mod dictobject; +pub mod floatobject; +pub mod import; +pub mod listobject; +pub mod longobject; +pub mod methodobject; pub mod object; +pub mod pycapsule; pub mod pyerrors; pub mod pylifecycle; pub mod pystate; pub mod refcount; +pub mod traceback; +pub mod tupleobject; +pub mod unicodeobject; mod util; /// Get main interpreter of this process. Will be None if it has not been initialized yet. diff --git a/crates/capi/src/listobject.rs b/crates/capi/src/listobject.rs new file mode 100644 index 00000000000..1f70e37b651 --- /dev/null +++ b/crates/capi/src/listobject.rs @@ -0,0 +1,186 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::c_int; +use core::ptr::NonNull; +use rustpython_vm::PyObjectRef; +use rustpython_vm::builtins::PyList; + +define_py_check!(fn PyList_Check, types.list_type); +define_py_check!(exact fn PyList_CheckExact, types.list_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyList_New(size: isize) -> *mut PyObject { + with_vm(|vm| { + let capacity = size + .try_into() + .map_err(|_| vm.new_system_error("Negative size passed to PyList_New"))?; + Ok(vm.ctx.new_list(Vec::with_capacity(capacity))) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyList_Size(obj: *mut PyObject) -> isize { + with_vm(|vm| { + let list = unsafe { &*obj }.try_downcast_ref::(vm)?; + Ok(list.__len__()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyList_GetItemRef(obj: *mut 
PyObject, index: isize) -> *mut PyObject { + with_vm(|vm| { + let list = unsafe { &*obj }.try_downcast_ref::(vm)?; + index + .try_into() + .ok() + .and_then(|index: usize| list.borrow_vec().get(index).map(ToOwned::to_owned)) + .ok_or_else(|| vm.new_index_error(format!("list index out of range: {index}"))) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyList_SetItem( + list: *mut PyObject, + index: isize, + item: *mut PyObject, +) -> c_int { + with_vm(|vm| { + let list = unsafe { &*list }.try_downcast_ref::(vm)?; + let item = unsafe { PyObjectRef::from_raw(NonNull::new_unchecked(item)) }; + let index_error = + || vm.new_index_error(format!("list assignment index out of range: {index}")); + if index < 0 { + return Err(index_error()); + } + + let mut list_mut = list.borrow_vec_mut(); + match index - list_mut.len() as isize { + ..0 => { + list_mut[index as usize] = item; + Ok(()) + } + // This is somewhat a hack, we assume that we are populating a list right after PyList_New + 0 if list_mut.capacity() > index as usize => { + list_mut.push(item); + Ok(()) + } + 0.. 
=> Err(index_error()), + } + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyList_Append(list: *mut PyObject, item: *mut PyObject) -> c_int { + with_vm(|vm| { + let list = unsafe { &*list }.try_downcast_ref::(vm)?; + let item = unsafe { &*item }.to_owned(); + list.borrow_vec_mut().push(item); + Ok(()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyList_Insert( + list: *mut PyObject, + index: isize, + item: *mut PyObject, +) -> c_int { + with_vm(|vm| { + let list = unsafe { &*list }.try_downcast_ref::(vm)?; + let item = unsafe { &*item }.to_owned(); + let mut vec = list.borrow_vec_mut(); + let index = if index < 0 { + index + vec.len() as isize + } else { + index + } + .clamp(0, vec.len() as isize) as usize; + vec.insert(index, item); + Ok(()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyList_Reverse(list: *mut PyObject) -> c_int { + with_vm(|vm| { + let list = unsafe { &*list }.try_downcast_ref::(vm)?; + list.borrow_vec_mut().reverse(); + Ok(()) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::exceptions::PyIndexError; + use pyo3::prelude::*; + use pyo3::types::PyList; + + #[test] + fn test_create_list() { + Python::attach(|py| { + let list = PyList::new(py, &[1, 2, 3]).unwrap(); + assert_eq!(list.len(), 3); + assert_eq!(list.get_item(0).unwrap().extract::().unwrap(), 1); + assert_eq!(list.get_item(1).unwrap().extract::().unwrap(), 2); + assert_eq!(list.get_item(2).unwrap().extract::().unwrap(), 3); + assert!(list.get_item(3).is_err()); + }) + } + + #[test] + fn test_replace_item_in_list() { + Python::attach(|py| { + let list = PyList::new(py, &[1]).unwrap(); + assert_eq!(list.len(), 1); + list.set_item(0, 2).unwrap(); + assert_eq!(list.len(), 1); + assert_eq!(list.get_item(0).unwrap().extract::().unwrap(), 2); + }) + } + + #[test] + fn test_set_item_out_of_range() { + Python::attach(|py| { + let list = PyList::empty(py); + assert!( + list.set_item(0, 1) + .unwrap_err() + .is_instance_of::(py) + ); + }) + } + + #[test] + fn 
test_list_append() { + Python::attach(|py| { + let list = PyList::empty(py); + assert_eq!(list.len(), 0); + list.append(1).unwrap(); + assert_eq!(list.len(), 1); + assert_eq!(list.get_item(0).unwrap().extract::().unwrap(), 1); + }) + } + + #[test] + fn test_list_insert() { + Python::attach(|py| { + let list = PyList::empty(py); + assert_eq!(list.len(), 0); + list.insert(0, 1).unwrap(); + assert_eq!(list.len(), 1); + list.insert(2, 3).unwrap(); + assert_eq!(list.get_item(1).unwrap().extract::().unwrap(), 3); + }) + } + + #[test] + fn test_list_reverse() { + Python::attach(|py| { + let list = PyList::new(py, &[1, 2, 3]).unwrap(); + list.reverse().unwrap(); + assert_eq!(list.get_item(0).unwrap().extract::().unwrap(), 3); + assert_eq!(list.get_item(2).unwrap().extract::().unwrap(), 1); + }) + } +} diff --git a/crates/capi/src/longobject.rs b/crates/capi/src/longobject.rs new file mode 100644 index 00000000000..8c9fe5e1acb --- /dev/null +++ b/crates/capi/src/longobject.rs @@ -0,0 +1,89 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::{c_long, c_longlong, c_ulong, c_ulonglong}; +use rustpython_vm::PyResult; +use rustpython_vm::builtins::PyInt; + +define_py_check!(fn PyLong_Check, types.int_type); +define_py_check!(exact fn PyLong_CheckExact, types.int_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyLong_FromLong(value: c_long) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_int(value)) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyLong_FromLongLong(value: c_longlong) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_int(value)) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyLong_FromSsize_t(value: isize) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_int(value)) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyLong_FromSize_t(value: usize) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_int(value)) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyLong_FromUnsignedLong(value: c_ulong) -> *mut PyObject { + with_vm(|vm| 
vm.ctx.new_int(value)) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyLong_FromUnsignedLongLong(value: c_ulonglong) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_int(value)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyLong_AsLong(obj: *mut PyObject) -> c_long { + with_vm::, _>(|vm| { + unsafe { &*obj } + .to_owned() + .try_index(vm)? + .as_bigint() + .try_into() + .map_err(|_| vm.new_overflow_error("Python int too large to convert to C long")) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyLong_AsUnsignedLongLong(obj: *mut PyObject) -> c_ulonglong { + with_vm::, _>(|vm| { + unsafe { &*obj } + .to_owned() + .try_downcast::(vm)? + .as_bigint() + .try_into() + .map_err(|_| { + vm.new_overflow_error("Python int too large to convert to C unsigned long long") + }) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyInt; + + #[test] + fn test_py_int_u32() { + Python::attach(|py| { + let number = PyInt::new(py, 123); + assert!(number.is_instance_of::()); + assert_eq!(number.extract::().unwrap(), 123); + }) + } + + #[test] + fn test_py_int_u64() { + Python::attach(|py| { + let number = PyInt::new(py, 123u64); + assert!(number.is_instance_of::()); + assert_eq!(number.extract::().unwrap(), 123); + }) + } +} diff --git a/crates/capi/src/methodobject.rs b/crates/capi/src/methodobject.rs new file mode 100644 index 00000000000..c0a6611a01f --- /dev/null +++ b/crates/capi/src/methodobject.rs @@ -0,0 +1,369 @@ +use crate::PyObject; +use crate::object::PyTypeObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::{CStr, c_char, c_int}; +use core::ptr::NonNull; +use rustpython_vm::function::{FuncArgs, HeapMethodDef, PosArgs, PyMethodFlags}; +use rustpython_vm::{AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine}; + +define_py_check!(fn PyCFunction_Check, types.builtin_function_or_method_type); +define_py_check!(exact fn PyCFunction_CheckExact, types.builtin_function_or_method_type); + 
+#[repr(C)] +pub struct PyMethodDef { + pub ml_name: *const c_char, + pub ml_meth: PyMethodPointer, + pub ml_flags: c_int, + pub ml_doc: *const c_char, +} + +#[repr(C)] +#[derive(Copy, Clone)] +#[allow(non_snake_case)] +pub union PyMethodPointer { + pub PyCFunction: unsafe extern "C" fn(slf: *mut PyObject, args: *mut PyObject) -> *mut PyObject, + pub PyCFunctionWithKeywords: unsafe extern "C" fn( + slf: *mut PyObject, + args: *mut PyObject, + kwargs: *mut PyObject, + ) -> *mut PyObject, + pub PyCFunctionFast: unsafe extern "C" fn( + slf: *mut PyObject, + args: *const *mut PyObject, + nargs: isize, + ) -> *mut PyObject, + pub PyCFunctionFastWithKeywords: unsafe extern "C" fn( + slf: *mut PyObject, + args: *const *mut PyObject, + nargs: isize, + kwnames: *mut PyObject, + ) -> *mut PyObject, +} + +pub(crate) fn build_method_def( + vm: &VirtualMachine, + ml: &PyMethodDef, + has_self: bool, +) -> PyResult> { + let name = unsafe { CStr::from_ptr(ml.ml_name) } + .to_str() + .map_err(|_| vm.new_system_error("Method name was not valid UTF-8"))?; + + let doc = NonNull::new(ml.ml_doc.cast_mut()) + .map(|doc| { + unsafe { CStr::from_ptr(doc.as_ptr()) } + .to_str() + .map_err(|_| vm.new_system_error("Method doc was not valid UTF-8")) + }) + .transpose()?; + + let flags = PyMethodFlags::from_bits(ml.ml_flags as u32) + .ok_or_else(|| vm.new_system_error("PyMethodDef contains unknown flags"))?; + + let method = ml.ml_meth; + + if flags.contains(PyMethodFlags::METHOD) { + return Err(vm.new_system_error("METH_METHOD is not supported on abi3")); + } + + let call_flags = flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::KEYWORDS + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::FASTCALL); + + bitflags::bitflags_match!(call_flags, { + PyMethodFlags::NOARGS => { + if has_self { + let callable = move |zelf: PyObjectRef, vm: &VirtualMachine| unsafe { + let f = method.PyCFunction; + let ret_ptr = f(zelf.as_raw().cast_mut(), core::ptr::null_mut()); + 
ret_ptr_to_pyresult(vm, ret_ptr) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + } else { + let callable = move |vm: &VirtualMachine| unsafe { + let f = method.PyCFunction; + let ret_ptr = f(core::ptr::null_mut(), core::ptr::null_mut()); + ret_ptr_to_pyresult(vm, ret_ptr) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + } + }, + PyMethodFlags::VARARGS => { + let callable = move |args: PosArgs, vm: &VirtualMachine| unsafe { + call_function(vm, method, flags, Some(args)) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + }, + PyMethodFlags::VARARGS | PyMethodFlags::KEYWORDS => { + let callable = move | args: FuncArgs, vm: &VirtualMachine| unsafe { + call_function_with_keywords(vm, method, flags, args) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + }, + PyMethodFlags::FASTCALL | PyMethodFlags::KEYWORDS => { + let callable = move |args: FuncArgs, vm: &VirtualMachine| unsafe { + call_fast_function_with_keywords(vm, method, flags, args) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + }, + PyMethodFlags::FASTCALL => { + let callable = move |args: PosArgs, vm: &VirtualMachine| unsafe { + call_fast_function(vm, method, flags, args) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + }, + PyMethodFlags::O => { + let f = unsafe { method.PyCFunction }; + if has_self { + let callable = move |zelf: PyObjectRef, arg: PyObjectRef, vm: &VirtualMachine| -> PyResult { + let ret_ptr = unsafe { f(zelf.as_raw().cast_mut(), arg.as_raw().cast_mut()) }; + ret_ptr_to_pyresult(vm, ret_ptr) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + } else { + let callable = move |arg: PyObjectRef, vm: &VirtualMachine| -> PyResult { + let ret_ptr = unsafe { f(core::ptr::null_mut(), arg.as_raw().cast_mut()) }; + ret_ptr_to_pyresult(vm, ret_ptr) + }; + Ok(vm.ctx.new_method_def(name, callable, flags, doc)) + } + }, + _ => { + Err(vm.new_system_error(format!( + "function {name} has unsupported or invalid 
calling-convention flags: {flags:?}" + ))) + }, + }) +} + +unsafe fn call_function>( + vm: &VirtualMachine, + method: PyMethodPointer, + flags: PyMethodFlags, + args: Option, +) -> PyResult { + let f = unsafe { method.PyCFunction }; + let (slf, arg_tuple) = if let Some(mut args) = args.map(Into::into) { + let slf = take_self_arg(&mut args, flags); + let arg_tuple = vm.ctx.new_tuple(args.args); + (slf, Some(arg_tuple)) + } else { + (None, None) + }; + + let slf_ptr = slf + .as_ref() + .map(|obj| obj.as_object().as_raw().cast_mut()) + .unwrap_or_default(); + + let arg_ptr = arg_tuple + .as_ref() + .map(|tuple| tuple.as_object().as_raw().cast_mut()) + .unwrap_or_default(); + + let ret_ptr = unsafe { f(slf_ptr, arg_ptr) }; + ret_ptr_to_pyresult(vm, ret_ptr) +} + +unsafe fn call_function_with_keywords( + vm: &VirtualMachine, + method: PyMethodPointer, + flags: PyMethodFlags, + mut args: FuncArgs, +) -> PyResult { + let f = unsafe { method.PyCFunctionWithKeywords }; + let slf = take_self_arg(&mut args, flags); + let slf_ptr = slf + .as_ref() + .map(|obj| obj.as_object().as_raw().cast_mut()) + .unwrap_or_default(); + let arg_tuple = vm.ctx.new_tuple(args.args); + let kwargs = vm.ctx.new_dict(); + for (k, v) in args.kwargs { + kwargs.set_item(&*k, v, vm)?; + } + let ret_ptr = unsafe { + f( + slf_ptr, + arg_tuple.as_object().as_raw().cast_mut(), + kwargs.as_object().as_raw().cast_mut(), + ) + }; + ret_ptr_to_pyresult(vm, ret_ptr) +} + +unsafe fn call_fast_function_with_keywords( + vm: &VirtualMachine, + method: PyMethodPointer, + flags: PyMethodFlags, + mut args: FuncArgs, +) -> PyResult { + let f = unsafe { method.PyCFunctionFastWithKeywords }; + let slf = take_self_arg(&mut args, flags); + let slf_ptr = slf + .as_ref() + .map(|obj| obj.as_object().as_raw().cast_mut()) + .unwrap_or_default(); + let nargs = args.args.len(); + let mut fastcall_args = args.args; + let kwnames_tuple = if !args.kwargs.is_empty() { + let mut kwnames = Vec::with_capacity(args.kwargs.len()); + for 
(k, v) in args.kwargs { + kwnames.push(vm.ctx.new_str(k).into()); + fastcall_args.push(v); + } + Some(vm.ctx.new_tuple(kwnames)) + } else { + None + }; + let kwnames_ptr = kwnames_tuple + .as_ref() + .map(|tuple| tuple.as_object().as_raw().cast_mut()) + .unwrap_or_default(); + // SAFETY: PyObjectRef is repr(transparent) over a pointer to PyObject, so a + // Vec has a layout-compatible contiguous backing buffer. The + // vector is kept alive for the duration of the call. + let fastcall_arg_ptrs = fastcall_args.as_ptr().cast::<*mut PyObject>(); + let ret_ptr = unsafe { f(slf_ptr, fastcall_arg_ptrs, nargs as isize, kwnames_ptr) }; + ret_ptr_to_pyresult(vm, ret_ptr) +} + +unsafe fn call_fast_function( + vm: &VirtualMachine, + method: PyMethodPointer, + flags: PyMethodFlags, + args: PosArgs, +) -> PyResult { + let f = unsafe { method.PyCFunctionFast }; + let mut args: FuncArgs = args.into(); + let slf = take_self_arg(&mut args, flags); + let slf_ptr = slf + .as_ref() + .map(|obj| obj.as_object().as_raw().cast_mut()) + .unwrap_or_default(); + // SAFETY: PyObjectRef is repr(transparent) over a pointer to PyObject, so a + // Vec has a layout-compatible contiguous backing buffer. The + // vector is kept alive for the duration of the call. 
+ let fastcall_arg_ptrs = args.args.as_mut_ptr().cast::<*mut PyObject>(); + let ret_ptr = unsafe { f(slf_ptr, fastcall_arg_ptrs, args.args.len() as isize) }; + ret_ptr_to_pyresult(vm, ret_ptr) +} + +fn ret_ptr_to_pyresult(vm: &VirtualMachine, ret_ptr: *mut PyObject) -> PyResult { + let ret_ptr = NonNull::new(ret_ptr).ok_or_else(|| { + vm.take_raised_exception() + .expect("Native function returned NULL, but there was no exception set") + })?; + Ok(unsafe { PyObjectRef::from_raw(ret_ptr) }) +} + +fn take_self_arg(args: &mut FuncArgs, flags: PyMethodFlags) -> Option { + if flags.contains(PyMethodFlags::STATIC) { + None + } else { + args.take_positional() + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCMethod_New( + ml: *mut PyMethodDef, + slf: *mut PyObject, + _module: *mut PyObject, + _cls: *mut PyTypeObject, +) -> *mut PyObject { + with_vm(|vm| -> PyResult { + assert!( + _cls.is_null(), + "PyCMethod_New does not support METH_METHOD on abi3" + ); + let ml = unsafe { &*ml }; + let zelf = unsafe { slf.as_ref().map(|obj| obj.to_owned()) }; + Ok(build_method_def(vm, ml, zelf.is_some())? 
+ .build_function(vm, zelf) + .into()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCFunction_New( + ml: *mut PyMethodDef, + slf: *mut PyObject, +) -> *mut PyObject { + unsafe { PyCMethod_New(ml, slf, core::ptr::null_mut(), core::ptr::null_mut()) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCFunction_NewEx( + ml: *mut PyMethodDef, + slf: *mut PyObject, + module: *mut PyObject, +) -> *mut PyObject { + unsafe { PyCMethod_New(ml, slf, module, core::ptr::null_mut()) } +} + +#[cfg(false)] +mod tests { + use pyo3::exceptions::PyException; + use pyo3::ffi::{PyLong_FromLong, PyObject}; + use pyo3::prelude::*; + use pyo3::types::{PyCFunction, PyInt, PyString}; + + #[test] + fn test_closure_function() { + Python::attach(|py| { + let f = PyCFunction::new_closure(py, None, None, |_args, _kwargs| "Hello from Rust!") + .unwrap(); + + assert_eq!( + f.call0().unwrap().cast::().unwrap(), + "Hello from Rust!" + ); + }) + } + + #[test] + fn test_function_no_args() { + Python::attach(|py| { + unsafe extern "C" fn c_fn(_self: *mut PyObject, _args: *mut PyObject) -> *mut PyObject { + assert!(_self.is_null()); + assert!(_args.is_null()); + unsafe { PyLong_FromLong(4200) } + } + + let py_fn = PyCFunction::new(py, c_fn, c"py_fn", c"", None).unwrap(); + + let result = py_fn + .call0() + .unwrap() + .cast::() + .unwrap() + .extract::() + .unwrap(); + assert_eq!(result, 4200); + + assert!(py_fn.call((1,), None).is_err()); + assert!(py_fn.call((1, 2), None).is_err()); + }) + } + + #[test] + fn test_closure_function_error() { + Python::attach(|py| { + let f = PyCFunction::new_closure(py, None, None, |_args, _kwargs| { + Err::<(), _>(PyException::new_err("Something went wrong")) + }) + .unwrap(); + + let err = f.call0().unwrap_err(); + assert_eq!( + err.value(py).repr().unwrap(), + "Exception('Something went wrong')" + ); + }) + } +} diff --git a/crates/capi/src/object.rs b/crates/capi/src/object.rs index 47ac58560db..9866d9041f9 100644 --- a/crates/capi/src/object.rs 
+++ b/crates/capi/src/object.rs @@ -7,6 +7,36 @@ use rustpython_vm::{AsObject, Py}; pub type PyTypeObject = Py; +macro_rules! define_py_check { + (fn $name:ident, $($ctx_path:ident).+) => { + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $name(obj: *mut crate::PyObject) -> core::ffi::c_int { + crate::pystate::with_vm(|vm| unsafe { + obj + .as_ref() + .map(|obj| obj.class().is_subtype(vm.ctx.$($ctx_path).+)) + .unwrap_or_default() + }) + } + }; + (exact fn $name:ident, $($ctx_path:ident).+) => { + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $name(obj: *mut crate::PyObject) -> core::ffi::c_int { + use rustpython_vm::AsObject; + crate::pystate::with_vm(|vm| unsafe { + obj + .as_ref() + .map(|obj| obj.class().is(vm.ctx.$($ctx_path).+)) + .unwrap_or_default() + }) + } + }; +} + +pub(crate) use define_py_check; +define_py_check!(fn PyType_Check, types.type_type); +define_py_check!(exact fn PyType_CheckExact, types.type_type); + #[unsafe(no_mangle)] pub unsafe extern "C" fn Py_TYPE(op: *mut PyObject) -> *const PyTypeObject { unsafe { (*op).class() } @@ -27,6 +57,42 @@ pub unsafe extern "C" fn PyType_GetFlags(ptr: *const PyTypeObject) -> c_ulong { ty.slots.flags.bits() as u32 as c_ulong } +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyType_IsSubtype(a: *const PyTypeObject, b: *const PyTypeObject) -> c_int { + with_vm(move |_vm| { + let a = unsafe { &*a }; + let b = unsafe { &*b }; + Ok(a.is_subtype(b)) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyType_GetName(ptr: *const PyTypeObject) -> *mut PyObject { + with_vm(|vm| unsafe { &*ptr }.__name__(vm)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyType_GetQualName(ptr: *const PyTypeObject) -> *mut PyObject { + with_vm(|vm| unsafe { &*ptr }.__qualname__(vm)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyType_GetFullyQualifiedName(ptr: *const PyTypeObject) -> *mut PyObject { + with_vm(|vm| { + let ty = unsafe { &*ptr }; + let qualname = ty.__qualname__(vm).try_downcast::(vm)?; + let 
module = ty.__module__(vm); + + if let Some(module) = module.downcast_ref::() + && module.as_wtf8() != "builtins" + { + Ok(vm.ctx.new_str(format!("{module}.{qualname}"))) + } else { + Ok(qualname) + } + }) +} + #[unsafe(no_mangle)] pub extern "C" fn Py_GetConstantBorrowed(constant_id: c_uint) -> *mut PyObject { with_vm(|vm| { diff --git a/crates/capi/src/pycapsule.rs b/crates/capi/src/pycapsule.rs new file mode 100644 index 00000000000..7a5d599c851 --- /dev/null +++ b/crates/capi/src/pycapsule.rs @@ -0,0 +1,160 @@ +use crate::PyObject; +use crate::pystate::with_vm; +use core::ffi::{CStr, c_char, c_int, c_void}; +use core::ptr::NonNull; +use rustpython_vm::builtins::PyCapsule; +use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; + +#[allow(non_camel_case_types)] +pub type PyCapsule_Destructor = unsafe extern "C" fn(capsule: *mut PyObject); + +#[unsafe(no_mangle)] +pub extern "C" fn PyCapsule_New( + pointer: *mut c_void, + name: *const c_char, + destructor: Option, +) -> *mut PyObject { + with_vm(|vm| { + if pointer.is_null() { + return Err(vm.new_value_error("PyCapsule_New called with null pointer")); + } + let name = NonNull::new(name.cast_mut()).map(|ptr| unsafe { CStr::from_ptr(ptr.as_ptr()) }); + Ok(vm.ctx.new_capsule(pointer, name, destructor)) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_GetPointer( + capsule: *mut PyObject, + name: *const c_char, +) -> *mut c_void { + with_vm(|vm| Ok(checked_capsule(vm, unsafe { &*capsule }, name)?.pointer())) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_GetName(capsule: *mut PyObject) -> *const c_char { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + Ok(capsule.name().map(CStr::as_ptr).unwrap_or_default()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_GetContext(capsule: *mut PyObject) -> *mut c_void { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + 
.downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + Ok(capsule.context()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_SetContext( + capsule: *mut PyObject, + context: *mut c_void, +) -> c_int { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + let _: () = capsule.set_context(context); + Ok(()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_SetPointer( + capsule: *mut PyObject, + pointer: *mut c_void, +) -> c_int { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + let _: () = capsule.set_pointer(pointer); + Ok(()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_IsValid(capsule: *mut PyObject, name: *const c_char) -> c_int { + with_vm(|vm| { + if capsule.is_null() { + return false; + } + + checked_capsule(vm, unsafe { &*capsule }, name).is_ok() + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_Import(name: *const c_char, _no_block: c_int) -> *mut c_void { + with_vm(|vm| { + let capsule_name = unsafe { CStr::from_ptr(name) } + .to_str() + .map_err(|_| vm.new_system_error("capsule name is not valid UTF-8"))?; + let (module_name, attrs_path) = capsule_name.split_once('.').ok_or_else(|| { + vm.new_import_error( + "capsule name is missing attribute path", + vm.ctx.new_str(capsule_name), + ) + })?; + let mut obj: PyObjectRef = vm.import(module_name, 0)?; + + for attr in attrs_path.split('.') { + obj = obj.get_attr(attr, vm)?; + } + + Ok(checked_capsule(vm, &obj, name)?.pointer()) + }) +} + +#[inline] +fn names_match(stored_name: *const c_char, expected_name: *const c_char) -> bool { + if stored_name.is_null() || expected_name.is_null() { + stored_name.is_null() && expected_name.is_null() + } else { + unsafe { CStr::from_ptr(stored_name) == CStr::from_ptr(expected_name) 
} + } +} + +#[inline] +fn checked_capsule<'a>( + vm: &VirtualMachine, + obj: &'a PyObject, + name: *const c_char, +) -> PyResult<&'a PyCapsule> { + let capsule = obj + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + + if !names_match(capsule.name().map(CStr::as_ptr).unwrap_or_default(), name) { + return Err(vm.new_value_error("Capsule name does not match")); + } + + if capsule.pointer().is_null() { + return Err(vm.new_value_error("Capsule has null pointer")); + } + + Ok(capsule) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyCapsule; + + #[test] + fn test_capsule_new() { + Python::attach(|py| { + let value = String::from("Some data"); + let capsule = PyCapsule::new_with_value(py, value, c"my_capsule").unwrap(); + assert!(capsule.is_valid_checked(Some(c"my_capsule"))); + let ptr = capsule.pointer_checked(Some(c"my_capsule")).unwrap(); + assert_eq!(unsafe { ptr.cast::().as_ref() }, "Some data"); + }) + } +} diff --git a/crates/capi/src/pyerrors.rs b/crates/capi/src/pyerrors.rs index 4554757c66e..b767b7c4090 100644 --- a/crates/capi/src/pyerrors.rs +++ b/crates/capi/src/pyerrors.rs @@ -1,5 +1,5 @@ -use crate::PyObject; -use crate::pystate::with_vm; +use crate::object::define_py_check; +use crate::{PyObject, pystate::with_vm}; use core::convert::Infallible; use core::ffi::{CStr, c_char, c_int}; use core::ptr::NonNull; @@ -96,6 +96,8 @@ define_exception_statics! 
{ PyExc_EncodingWarning => encoding_warning, } +define_py_check!(fn PyExceptionInstance_Check, exceptions.base_exception_type); + #[unsafe(no_mangle)] pub extern "C" fn PyErr_Occurred() -> *mut PyObject { with_vm(|vm| { @@ -191,6 +193,15 @@ pub unsafe extern "C" fn PyErr_WriteUnraisable(obj: *mut PyObject) { }) } +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyExceptionClass_Check(obj: *mut PyObject) -> c_int { + with_vm(|vm| unsafe { + obj.as_ref() + .and_then(|obj| obj.downcast_ref::()) + .is_some_and(|ty| ty.is_subtype(vm.ctx.exceptions.base_exception_type)) + }) +} + #[unsafe(no_mangle)] pub unsafe extern "C" fn PyErr_NewException( name: *const c_char, @@ -252,13 +263,49 @@ pub unsafe extern "C" fn PyErr_GivenExceptionMatches( }) } +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyException_GetTraceback(exc: *mut PyObject) -> *mut PyObject { + with_vm(|vm| { + let exc = unsafe { &*exc }.try_downcast_ref::(vm)?; + let tb = exc + .__traceback__() + .map(|tb| tb.into_object().into_raw().as_ptr()) + .unwrap_or_default(); + Ok(tb) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyException_GetCause(exc: *mut PyObject) -> *mut PyObject { + with_vm(|vm| { + let exc = unsafe { &*exc }.try_downcast_ref::(vm)?; + let cause = exc + .__cause__() + .map(|cause| cause.into_object().into_raw().as_ptr()) + .unwrap_or_default(); + Ok(cause) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyException_GetContext(exc: *mut PyObject) -> *mut PyObject { + with_vm(|vm| { + let exc = unsafe { &*exc }.try_downcast_ref::(vm)?; + let context = exc + .__context__() + .map(|context| context.into_object().into_raw().as_ptr()) + .unwrap_or_default(); + Ok(context) + }) +} + #[cfg(test)] mod tests { use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; #[test] - fn test_raised_exception() { + fn raised_exception() { Python::attach(|py| { PyTypeError::new_err(py.None()).restore(py); assert!(PyErr::occurred(py)); @@ -268,7 +315,7 @@ mod tests { } #[test] - fn 
test_error_is_instance() { + fn error_is_instance() { Python::attach(|py| { let err = PyTypeError::new_err(py.None()); assert!(err.is_instance_of::(py)); diff --git a/crates/capi/src/pystate.rs b/crates/capi/src/pystate.rs index 97b29bcebe1..f3d12f04a0b 100644 --- a/crates/capi/src/pystate.rs +++ b/crates/capi/src/pystate.rs @@ -58,7 +58,7 @@ mod tests { use rustpython_vm::vm::thread::{current_vm_is_set, with_current_vm}; #[test] - fn test_new_thread() { + fn new_thread() { Python::attach(|_py| { with_current_vm(|_vm| { assert!( @@ -83,7 +83,7 @@ mod tests { } #[test] - fn test_current_vm_main_thread() { + fn current_vm_main_thread() { Python::initialize(); // let RustPython create a vm for this thread. @@ -105,7 +105,7 @@ mod tests { } #[test] - fn test_gilstate_release_detaches_external_thread() { + fn gilstate_release_detaches_external_thread() { Python::initialize(); std::thread::spawn(|| { diff --git a/crates/capi/src/traceback.rs b/crates/capi/src/traceback.rs new file mode 100644 index 00000000000..5b264b1ddbc --- /dev/null +++ b/crates/capi/src/traceback.rs @@ -0,0 +1,24 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::c_int; +use rustpython_vm::function::{FuncArgs, KwArgs}; + +define_py_check!(exact fn PyTraceBack_Check, types.traceback_type); + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyTraceBack_Print(tb: *mut PyObject, file: *mut PyObject) -> c_int { + with_vm(|vm| { + let tb = unsafe { &*tb }; + let file = unsafe { &*file }; + let tb_module = vm.import("traceback", 0)?; + let print_tb = tb_module.get_attr("print_tb", vm)?; + + let kwargs: KwArgs = [("file".to_string(), file.to_owned())] + .into_iter() + .collect(); + print_tb.call(FuncArgs::new(vec![tb.to_owned()], kwargs), vm)?; + + Ok(()) + }) +} diff --git a/crates/capi/src/tupleobject.rs b/crates/capi/src/tupleobject.rs new file mode 100644 index 00000000000..985141f6d4c --- /dev/null +++ b/crates/capi/src/tupleobject.rs @@ -0,0 
+1,122 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::c_int; +use core::slice; +use rustpython_vm::PyResult; +use rustpython_vm::builtins::PyTuple; +use rustpython_vm::sliceable::SliceableSequenceOp; + +define_py_check!(fn PyTuple_Check, types.tuple_type); +define_py_check!(exact fn PyTuple_CheckExact, types.tuple_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyTuple_New(len: isize) -> *mut PyObject { + with_vm(|vm| { + if len == 0 { + return Ok(vm.ctx.empty_tuple.to_owned()); + } + + Err(vm.new_not_implemented_error( + "PyTuple_New for non zero sized tuples is not supported, use PyTuple_FromArray instead", + )) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyTuple_FromArray( + array: *const *mut PyObject, + size: isize, +) -> *mut PyObject { + with_vm(|vm| { + let size = size + .try_into() + .map_err(|_| vm.new_system_error("negative size passed to Tuple_FromArray"))?; + let slice = unsafe { slice::from_raw_parts(array, size) }; + let elements = slice + .iter() + .map(|ptr| unsafe { &**ptr }.to_owned()) + .collect::>(); + Ok(vm.new_tuple(elements)) + }) +} + +#[unsafe(no_mangle)] +pub extern "C" fn PyTuple_SetItem( + _tuple: *mut PyObject, + _pos: isize, + _value: *mut PyObject, +) -> c_int { + with_vm::, _>( + |vm| Err(vm.new_not_implemented_error("Tuple objects are immutable")), + ) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyTuple_Size(tuple: *mut PyObject) -> isize { + with_vm(|vm| { + let tuple = unsafe { &*tuple }.try_downcast_ref::(vm)?; + Ok(tuple.__len__()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyTuple_GetItem(tuple: *mut PyObject, pos: isize) -> *mut PyObject { + with_vm(|vm| { + let tuple = unsafe { &*tuple }.try_downcast_ref::(vm)?; + let result: &PyObject = pos + .try_into() + .ok() + .and_then(|index: usize| tuple.get(index)) + .ok_or_else(|| vm.new_index_error("tuple index out of range"))?; + + Ok(result.as_raw()) + }) +} + 
+#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyTuple_GetSlice( + tuple: *mut PyObject, + low: isize, + high: isize, +) -> *mut PyObject { + with_vm(|vm| { + let tuple = unsafe { &*tuple }.try_downcast_ref::(vm)?; + let len = tuple.__len__() as isize; + let low = low.clamp(0, len); + let high = high.clamp(low, len); + let slice = tuple.do_slice(low as usize..high as usize); + Ok(vm.ctx.new_tuple(slice)) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyTuple; + + #[test] + fn test_empty_tuple() { + Python::attach(|py| { + let tuple = PyTuple::empty(py); + assert_eq!(tuple.len(), 0); + }) + } + + #[test] + fn test_tuple_into_python() { + Python::attach(|py| { + let tuple = (1, 2, 3).into_pyobject(py).unwrap(); + assert_eq!(tuple.len(), 3); + }) + } + + #[test] + fn test_tuple_get_slice() { + Python::attach(|py| { + let tuple = (1, 2, 3).into_pyobject(py).unwrap(); + let slice = tuple.get_slice(1, 2); + assert_eq!(slice.extract::<(u32,)>().unwrap(), (2,)); + }) + } +} diff --git a/crates/capi/src/unicodeobject.rs b/crates/capi/src/unicodeobject.rs new file mode 100644 index 00000000000..76e0fb0df58 --- /dev/null +++ b/crates/capi/src/unicodeobject.rs @@ -0,0 +1,154 @@ +use crate::PyObject; +use crate::object::define_py_check; +use crate::pystate::with_vm; +use core::ffi::{CStr, c_char, c_int}; +use core::ptr::NonNull; +use core::slice; +use core::str; +use rustpython_vm::PyObjectRef; +use rustpython_vm::builtins::PyStr; + +define_py_check!(fn PyUnicode_Check, types.str_type); +define_py_check!(exact fn PyUnicode_CheckExact, types.str_type); + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyUnicode_FromStringAndSize( + s: *const c_char, + len: isize, +) -> *mut PyObject { + with_vm(|vm| { + let len: usize = len + .try_into() + .map_err(|_| vm.new_system_error("length must be non-negative"))?; + + let text = if s.is_null() { + if len != 0 { + return Err(vm.new_system_error( + "PyUnicode_FromStringAndSize called with null data and 
non-zero len", + )); + } + "" + } else { + let bytes = unsafe { slice::from_raw_parts(s.cast::(), len) }; + str::from_utf8(bytes).expect("PyUnicode_FromStringAndSize got non-UTF8 data") + }; + + Ok(vm.ctx.new_str(text)) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyUnicode_AsUTF8AndSize( + obj: *mut PyObject, + size: *mut isize, +) -> *const c_char { + with_vm(|vm| { + let unicode = unsafe { &*obj }.try_downcast_ref::(vm)?; + + let str = unicode.to_str().ok_or_else(|| { + vm.new_system_error("PyUnicode_AsUTF8AndSize only supports UTF-8 or ASCII strings") + })?; + + if size.is_null() { + // We do not support null size arguments because the returned string is not NULL terminated. + return Err( + vm.new_system_error("size argument to PyUnicode_AsUTF8AndSize cannot be null") + ); + } + + unsafe { *size = str.len() as isize }; + Ok(str.as_ptr()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyUnicode_AsEncodedString( + unicode: *mut PyObject, + encoding: *const c_char, + errors: *const c_char, +) -> *mut PyObject { + with_vm(|vm| { + let unicode = unsafe { &*unicode } + .try_downcast_ref::(vm)? 
+ .to_owned(); + let encoding = if encoding.is_null() { + "utf-8" + } else { + unsafe { CStr::from_ptr(encoding) } + .to_str() + .expect("encoding must be valid UTF-8") + }; + let errors = if errors.is_null() { + None + } else { + let errors = unsafe { CStr::from_ptr(errors) } + .to_str() + .expect("errors must be valid UTF-8"); + Some(vm.ctx.new_utf8_str(errors)) + }; + vm.state + .codec_registry + .encode_text(unicode, encoding, errors, vm) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyUnicode_InternInPlace(string: *mut *mut PyObject) { + with_vm(|vm| { + let old_str = unsafe { PyObjectRef::from_raw(NonNull::new_unchecked(*string)) } + .downcast_exact::(vm) + .expect("PyUnicode_InternInPlace called with non-string object"); + + let interned: PyObjectRef = vm.ctx.intern_str(old_str).to_owned().into(); + + unsafe { *string = interned.into_raw().as_ptr() } + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyUnicode_EqualToUTF8AndSize( + unicode: *mut PyObject, + string: *const c_char, + size: isize, +) -> c_int { + with_vm(|vm| { + let size = size.try_into().map_err(|_| { + vm.new_system_error("Negative size passed to PyUnicode_EqualToUTF8AndSize") + })?; + + let unicode = unsafe { &*unicode }.try_downcast_ref::(vm)?; + let result = unsafe { + let slice = slice::from_raw_parts(string as _, size); + str::from_utf8(slice) + } + .ok() + .and_then(|other| Some(unicode.to_str()? 
== other)) + .unwrap_or(false); + + Ok(result) + }) +} + +#[cfg(false)] +mod tests { + use pyo3::intern; + use pyo3::prelude::*; + use pyo3::types::PyString; + + #[test] + fn test_unicode() { + Python::attach(|py| { + let string = PyString::new(py, "Hello, World!"); + assert!(string.is_instance_of::()); + assert_eq!(string.to_str().unwrap(), "Hello, World!"); + assert_eq!(string, "Hello, World!"); + }) + } + + #[test] + fn test_intern_str() { + Python::attach(|py| { + let _string = intern!(py, "Hello, World!"); + }) + } +} diff --git a/crates/capi/src/util.rs b/crates/capi/src/util.rs index 6137ca9029f..6eef9163a5c 100644 --- a/crates/capi/src/util.rs +++ b/crates/capi/src/util.rs @@ -1,6 +1,6 @@ use crate::PyObject; use core::convert::Infallible; -use core::ffi::{c_char, c_double, c_int, c_long, c_void}; +use core::ffi::{c_char, c_double, c_int, c_long, c_ulonglong, c_void}; use rustpython_vm::{PyObjectRef, PyRef, PyResult, VirtualMachine}; pub(crate) trait FfiResult { @@ -76,6 +76,14 @@ impl FfiResult<*mut c_char> for *const u8 { } } +impl FfiResult for *const c_char { + const ERR_VALUE: *const c_char = core::ptr::null_mut(); + + fn into_output(self, _vm: &VirtualMachine) -> *const c_char { + self + } +} + impl FfiResult for usize { const ERR_VALUE: isize = -1; @@ -93,6 +101,14 @@ impl FfiResult for c_long { } } +impl FfiResult for c_ulonglong { + const ERR_VALUE: Self = Self::MAX; + + fn into_output(self, _vm: &VirtualMachine) -> Self { + self + } +} + impl FfiResult for c_double { const ERR_VALUE: Self = -1.0; diff --git a/crates/codegen/Cargo.toml b/crates/codegen/Cargo.toml index 4f32fcc3c3f..031f3b96521 100644 --- a/crates/codegen/Cargo.toml +++ b/crates/codegen/Cargo.toml @@ -19,7 +19,6 @@ rustpython-wtf8 = { workspace = true } ruff_python_ast = { workspace = true } ruff_text_size = { workspace = true } -ahash = { workspace = true } bitflags = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } @@ -29,6 +28,7 @@ num-traits 
= { workspace = true } thiserror = { workspace = true } malachite-bigint = { workspace = true } memchr = { workspace = true } +rapidhash = { workspace = true } unicode_names2 = { workspace = true } [dev-dependencies] diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 3873736c228..c3b86ce46bc 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -27,10 +27,10 @@ use ruff_text_size::{Ranged, TextRange, TextSize}; use rustpython_compiler_core::{ Mode, OneIndexed, PositionEncoding, SourceFile, SourceLocation, bytecode::{ - self, AnyInstruction, Arg as OpArgMarker, BinaryOperator, BuildSliceArgCount, CodeObject, - ComparisonOperator, ConstantData, ConvertValueOparg, Instruction, IntrinsicFunction1, - Invert, LoadAttr, LoadSuperAttr, OpArg, OpArgType, PseudoInstruction, SpecialMethod, - UnpackExArgs, oparg, + self, AnyInstruction, AnyOpcode, Arg as OpArgMarker, BinaryOperator, BuildSliceArgCount, + CodeObject, ComparisonOperator, ConstantData, ConvertValueOparg, Instruction, + IntrinsicFunction1, Invert, LoadAttr, LoadSuperAttr, OpArg, OpArgType, PseudoInstruction, + SpecialMethod, UnpackExArgs, oparg, }, }; use rustpython_wtf8::Wtf8Buf; @@ -82,7 +82,7 @@ impl ExprExt for ast::Expr { const MAXBLOCKS: usize = 20; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FBlockType { WhileLoop, ForLoop, @@ -132,8 +132,10 @@ enum BuiltinGeneratorCallKind { #[derive(Debug, Clone)] pub struct FBlockInfo { pub fb_type: FBlockType, - pub fb_block: BlockIdx, - pub fb_exit: BlockIdx, + // CPython _PyCompile_FBlockInfo stores jump_target_label values here. + pub(crate) fb_block: ir::InstructionSequenceLabel, + // CPython's optional type-specific exit or cleanup jump_target_label. 
+ pub(crate) fb_exit: ir::InstructionSequenceLabel, pub fb_range: TextRange, // additional data for fblock unwinding pub fb_datum: FBlockDatum, @@ -167,17 +169,9 @@ struct Compiler { /// When > 0, the compiler walks AST (consuming sub_tables) but emits no bytecode. /// Mirrors CPython's `c_do_not_emit_bytecode`. do_not_emit_bytecode: u32, - /// Disable constant BoolOp folding in contexts where CPython preserves - /// short-circuit structure, such as starred unpack expressions. - disable_const_boolop_folding: bool, /// Disable constant tuple/list/set collection folding in contexts where /// CPython keeps the builder form for later assignment lowering. disable_const_collection_folding: bool, - split_next_for_normal_exit_from_break: bool, - fallthrough_has_statement_successor: bool, - fallthrough_has_local_statement_successor: bool, - fallthrough_successor_stack: Vec<(bool, bool)>, - try_else_orelse_conditional_base_stack: Vec, } #[derive(Clone, Copy)] @@ -207,7 +201,6 @@ impl Default for CompileOpts { #[derive(Debug, Clone, Copy)] struct CompileContext { - loop_data: Option<(BlockIdx, BlockIdx)>, in_class: bool, func: FunctionContext, /// True if we're anywhere inside an async function (even inside nested comprehensions) @@ -241,6 +234,7 @@ enum ComprehensionLoopControl { loop_block: BlockIdx, if_cleanup_block: BlockIdx, after_block: BlockIdx, + backedge_range: TextRange, is_async: bool, end_async_for_target: BlockIdx, }, @@ -480,68 +474,6 @@ impl Compiler { } } - fn boolop_fast_fold_literal(expr: &ast::Expr) -> bool { - matches!( - expr, - ast::Expr::NumberLiteral(_) - | ast::Expr::StringLiteral(_) - | ast::Expr::BytesLiteral(_) - | ast::Expr::BooleanLiteral(_) - | ast::Expr::NoneLiteral(_) - | ast::Expr::EllipsisLiteral(_) - ) - } - - fn constant_expr_truthiness(&mut self, expr: &ast::Expr) -> CompileResult> { - Ok(self - .try_fold_constant_expr(expr)? 
- .map(|constant| Self::constant_truthiness(&constant))) - } - - fn has_always_taken_jump_in_test( - &mut self, - expr: &ast::Expr, - condition: bool, - ) -> CompileResult { - Ok(match expr { - ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) => { - let (last_value, prefix_values) = values.split_last().unwrap(); - let cond2 = matches!(op, ast::BoolOp::Or); - for value in prefix_values { - if self.has_always_taken_jump_in_test(value, cond2)? { - return Ok(true); - } - } - self.has_always_taken_jump_in_test(last_value, condition)? - } - ast::Expr::UnaryOp(ast::ExprUnaryOp { - op: ast::UnaryOp::Not, - operand, - .. - }) => self.has_always_taken_jump_in_test(operand, !condition)?, - ast::Expr::If(ast::ExprIf { - test, body, orelse, .. - }) => { - self.has_always_taken_jump_in_test(test, false)? - || self.has_always_taken_jump_in_test(body, condition)? - || self.has_always_taken_jump_in_test(orelse, condition)? - } - _ => matches!(self.constant_expr_truthiness(expr)?, Some(value) if value == condition), - }) - } - - fn disable_load_fast_borrow_for_block(&mut self, block: BlockIdx) { - if block != BlockIdx::NULL { - self.current_code_info().blocks[block.idx()].disable_load_fast_borrow = true; - } - } - - fn mark_try_else_orelse_entry_block(&mut self, block: BlockIdx) { - if block != BlockIdx::NULL { - self.current_code_info().blocks[block.idx()].try_else_orelse_entry = true; - } - } - fn new(opts: CompileOpts, source_file: SourceFile, code_name: &str) -> Self { let module_code = ir::CodeInfo { // CPython convention: top-level module / interactive / @@ -550,18 +482,20 @@ impl Compiler { // scope.) This matches the per-scope mapping at // enter_scope::CompilerScope::Module below, which also returns // empty flags. frame.rs:725-731 then binds locals to globals - // for module/REPL frames whose `scope.locals` is None — the + // for module/REPL frames whose `scope.locals` is None - the // correct semantics for `exec(code, globals)` and module init. 
flags: bytecode::CodeFlags::empty(), source_path: source_file.name().to_owned(), private: None, blocks: vec![ir::Block::default()], current_block: BlockIdx::new(0), - annotations_blocks: None, + instr_sequence: ir::InstructionSequence::new(), + instr_sequence_label_map: ir::InstructionSequenceLabelMap::new(), + annotations_instr_sequence: None, metadata: ir::CodeUnitMetadata { name: code_name.to_string(), qualname: Some(code_name.to_string()), - consts: IndexSet::default(), + consts: Default::default(), names: IndexSet::default(), varnames: IndexSet::default(), cellvars: IndexSet::default(), @@ -577,9 +511,8 @@ impl Compiler { in_inlined_comp: false, fblock: Vec::with_capacity(MAXBLOCKS), symbol_table_index: 0, // Module is always the first symbol table + nparams: 0, in_conditional_block: 0, - in_final_with_cleanup_statement: 0, - in_try_else_orelse: 0, next_conditional_annotation_index: 0, }; Self { @@ -591,7 +524,6 @@ impl Compiler { done_with_future_stmts: DoneWithFuture::No, future_annotations: false, ctx: CompileContext { - loop_data: None, in_class: false, func: FunctionContext::NoFunction, in_async_scope: false, @@ -600,27 +532,10 @@ impl Compiler { in_annotation: false, interactive: false, do_not_emit_bytecode: 0, - disable_const_boolop_folding: false, disable_const_collection_folding: false, - split_next_for_normal_exit_from_break: false, - fallthrough_has_statement_successor: false, - fallthrough_has_local_statement_successor: false, - fallthrough_successor_stack: Vec::new(), - try_else_orelse_conditional_base_stack: Vec::new(), } } - fn compile_expression_without_const_boolop_folding( - &mut self, - expression: &ast::Expr, - ) -> CompileResult<()> { - let previous = self.disable_const_boolop_folding; - self.disable_const_boolop_folding = true; - let result = self.compile_expression(expression); - self.disable_const_boolop_folding = previous; - result.map(|_| ()) - } - fn compile_expression_without_const_collection_folding( &mut self, expression: 
&ast::Expr, @@ -636,307 +551,43 @@ impl Compiler { matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)) } - fn statements_end_with_scope_exit(body: &[ast::Stmt]) -> bool { - body.last() - .is_some_and(Self::statement_ends_with_scope_exit) - } - - fn statements_end_with_finally_entry_scope_exit(body: &[ast::Stmt]) -> bool { - body.last() - .is_some_and(Self::statement_ends_with_finally_entry_scope_exit) - } - - fn statement_ends_with_finally_entry_scope_exit(stmt: &ast::Stmt) -> bool { - match stmt { - ast::Stmt::Return(_) - | ast::Stmt::Raise(_) - | ast::Stmt::Break(_) - | ast::Stmt::Continue(_) => true, - ast::Stmt::If(ast::StmtIf { body, .. }) => { - Self::statements_end_with_finally_entry_scope_exit(body) - } - _ => false, - } - } - - fn statement_ends_with_scope_exit(stmt: &ast::Stmt) -> bool { - match stmt { - ast::Stmt::Return(_) | ast::Stmt::Raise(_) => true, - ast::Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - let has_else = elif_else_clauses - .last() - .is_some_and(|clause| clause.test.is_none()); - has_else - && Self::statements_end_with_scope_exit(body) - && elif_else_clauses - .iter() - .all(|clause| Self::statements_end_with_scope_exit(&clause.body)) - } - _ => false, - } - } - - fn statements_end_with_with_cleanup_scope_exit(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::With(ast::StmtWith { body, .. }) => { - Self::statements_end_with_scope_exit(body) - || Self::statements_end_with_with_cleanup_scope_exit(body) - } - _ => false, - }) - } - - fn statements_end_with_nonterminal_with_cleanup(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::With(ast::StmtWith { body, .. 
}) => { - !Self::statements_end_with_scope_exit(body) - && !Self::statements_end_with_with_cleanup_scope_exit(body) - } - _ => false, - }) - } - - fn statements_end_with_try_finally(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| { - matches!( - stmt, - ast::Stmt::Try(ast::StmtTry { finalbody, .. }) if !finalbody.is_empty() - ) - }) - } - - fn statements_end_with_nested_finalbody_try_finally(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::Try(ast::StmtTry { finalbody, .. }) if !finalbody.is_empty() => { - Self::statements_end_with_try_finally(finalbody) - } - _ => false, - }) - } - - fn statements_end_with_try_star_except(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| { - matches!( - stmt, - ast::Stmt::Try(ast::StmtTry { - handlers, - finalbody, - is_star: true, - .. - }) if !handlers.is_empty() && finalbody.is_empty() - ) - }) - } - - fn statements_end_with_try_except_handler_fallthrough(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::Try(ast::StmtTry { - body, - handlers, - finalbody, - is_star: false, - .. - }) => { - finalbody.is_empty() - && !handlers.is_empty() - && Self::statements_end_with_scope_exit(body) - && handlers.iter().any(|handler| { - let ast::ExceptHandler::ExceptHandler(handler) = handler; - !Self::statements_end_with_scope_exit(&handler.body) - }) - } - _ => false, - }) - } - - fn statements_end_with_try_except_else_handler_scope_exit(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::Try(ast::StmtTry { - handlers, - orelse, - finalbody, - is_star: false, - .. 
- }) => { - !orelse.is_empty() - && finalbody.is_empty() - && handlers.iter().any(|handler| { - let ast::ExceptHandler::ExceptHandler(handler) = handler; - Self::statements_end_with_finally_entry_scope_exit(&handler.body) - }) - } - _ => false, - }) - } - - fn statements_end_with_open_conditional_fallthrough(body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::If(ast::StmtIf { - elif_else_clauses, .. - }) => elif_else_clauses - .last() - .is_none_or(|clause| clause.test.is_some()), - _ => false, - }) - } - - fn statements_end_with_conditional_scope_exit(&self, body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::Assert(_) => self.opts.optimize == 0, - ast::Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - Self::statements_end_with_scope_exit(body) - || elif_else_clauses - .iter() - .any(|clause| Self::statements_end_with_scope_exit(&clause.body)) - } - _ => false, - }) - } - - fn statements_end_with_loop_fallthrough(&mut self, body: &[ast::Stmt]) -> CompileResult { - match body.last() { - Some(ast::Stmt::For(ast::StmtFor { body, .. })) => { - Ok(!Self::statements_contain_direct_break(body)) - } - Some(ast::Stmt::While(ast::StmtWhile { test, body, .. })) => { - Ok(!matches!(self.constant_expr_truthiness(test)?, Some(true)) - && !Self::statements_contain_direct_break(body)) - } - _ => Ok(false), - } - } - - fn statements_contain_direct_break(body: &[ast::Stmt]) -> bool { - body.iter().any(Self::statement_contains_direct_break) - } - - fn statements_are_single_for_direct_break(body: &[ast::Stmt]) -> bool { - matches!( - body, - [ast::Stmt::For(ast::StmtFor { body, .. })] - if Self::statements_contain_direct_break(body) - ) - } - - fn statement_contains_direct_break(stmt: &ast::Stmt) -> bool { - match stmt { - ast::Stmt::Break(_) => true, - ast::Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. 
- }) => { - Self::statements_contain_direct_break(body) - || elif_else_clauses - .iter() - .any(|clause| Self::statements_contain_direct_break(&clause.body)) - } - ast::Stmt::With(ast::StmtWith { body, .. }) => { - Self::statements_contain_direct_break(body) - } - ast::Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - Self::statements_contain_direct_break(body) - || handlers.iter().any(|handler| { - let ast::ExceptHandler::ExceptHandler(handler) = handler; - Self::statements_contain_direct_break(&handler.body) - }) - || Self::statements_contain_direct_break(orelse) - || Self::statements_contain_direct_break(finalbody) - } - ast::Stmt::Match(ast::StmtMatch { cases, .. }) => cases - .iter() - .any(|case| Self::statements_contain_direct_break(&case.body)), - ast::Stmt::For(_) | ast::Stmt::While(_) => false, - _ => false, - } - } - - fn has_resuming_bare_except(handlers: &[ast::ExceptHandler]) -> bool { - handlers.iter().any(|handler| { - let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - type_, - body, - .. - }) = handler; - type_.is_none() && !Self::statements_end_with_scope_exit(body) - }) - } - - fn statements_end_with_optimized_finally_entry_scope_exit(&self, body: &[ast::Stmt]) -> bool { - body.last() - .is_some_and(|stmt| self.statement_ends_with_optimized_finally_entry_scope_exit(stmt)) - } - - fn statement_ends_with_optimized_finally_entry_scope_exit(&self, stmt: &ast::Stmt) -> bool { - match stmt { - ast::Stmt::Assert(_) => self.opts.optimize == 0, - ast::Stmt::If(ast::StmtIf { body, .. }) => { - self.statements_end_with_optimized_finally_entry_scope_exit(body) - } - _ => Self::statement_ends_with_finally_entry_scope_exit(stmt), - } - } - - fn preserves_finally_entry_nop(&self, body: &[ast::Stmt]) -> bool { - body.last().is_some_and(|stmt| match stmt { - ast::Stmt::Try(ast::StmtTry { - body, - handlers, - finalbody, - .. 
- }) => { - !finalbody.is_empty() - && !Self::statements_end_with_open_conditional_fallthrough(finalbody) - || (!handlers.is_empty() && Self::statements_end_with_scope_exit(body)) - } - ast::Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - elif_else_clauses.is_empty() - && self.statements_end_with_optimized_finally_entry_scope_exit(body) - } - ast::Stmt::Assert(_) => self.opts.optimize == 0, - _ => false, - }) - } - fn compile_module_annotation_setup_sequence( &mut self, body: &[ast::Stmt], + loc: TextRange, ) -> CompileResult<()> { - let (saved_blocks, saved_current_block) = { + let ( + saved_blocks, + saved_current_block, + saved_instr_sequence, + saved_instr_sequence_label_map, + saved_annotations_instr_sequence, + ) = { let code = self.current_code_info(); ( mem::replace(&mut code.blocks, vec![ir::Block::default()]), mem::replace(&mut code.current_block, BlockIdx::new(0)), + mem::replace(&mut code.instr_sequence, ir::InstructionSequence::new()), + mem::replace( + &mut code.instr_sequence_label_map, + ir::InstructionSequenceLabelMap::new(), + ), + code.annotations_instr_sequence.take(), ) }; - let result = self.compile_module_annotate(body); + let result = self.compile_module_annotate(body, Some(loc)); - let annotations_blocks = { + { let code = self.current_code_info(); - let annotations_blocks = mem::replace(&mut code.blocks, saved_blocks); + code.blocks = saved_blocks; + let annotations_instr_sequence = + mem::replace(&mut code.instr_sequence, saved_instr_sequence); code.current_block = saved_current_block; - annotations_blocks + code.instr_sequence_label_map = saved_instr_sequence_label_map; + code.annotations_instr_sequence = Some(annotations_instr_sequence); + debug_assert!(saved_annotations_instr_sequence.is_none()); }; - self.current_code_info().annotations_blocks = Some(annotations_blocks); result.map(|_| ()) } @@ -948,6 +599,7 @@ impl Compiler { if let Some(lower) = &s.lower { self.compile_expression(lower)?; } else { + 
self.set_source_range(s.range); self.emit_load_const(ConstantData::None); } @@ -955,6 +607,7 @@ impl Compiler { if let Some(upper) = &s.upper { self.compile_expression(upper)?; } else { + self.set_source_range(s.range); self.emit_load_const(ConstantData::None); } @@ -1128,7 +781,7 @@ impl Compiler { } // Compile the starred expression and extend - self.compile_expression_without_const_boolop_folding(value)?; + self.compile_expression(value)?; self.set_source_range(collection_range); match collection_type { CollectionType::List => { @@ -1461,8 +1114,14 @@ impl Compiler { /// Load arguments for super() optimization onto the stack /// Stack result: [global_super, class, self] - fn load_args_for_super(&mut self, super_type: &SuperCallType<'_>) -> CompileResult<()> { + fn load_args_for_super( + &mut self, + super_type: &SuperCallType<'_>, + super_name_range: TextRange, + super_call_range: TextRange, + ) -> CompileResult<()> { // 1. Load global super + self.set_source_range(super_name_range); self.compile_name("super", NameUsage::Load)?; match super_type { @@ -1477,6 +1136,7 @@ impl Compiler { SuperCallType::ZeroArg => { // 0-arg: load __class__ cell and first parameter // Load __class__ from cell/free variable + self.set_source_range(super_call_range); let scope = self.get_ref_type("__class__").map_err(|e| self.error(e))?; let idx = match scope { SymbolScope::Cell => self.get_cell_var_index("__class__"), @@ -1501,6 +1161,7 @@ impl Compiler { "super(): no arguments and no first parameter".to_owned(), )) })?; + self.set_source_range(super_call_range); self.compile_name(&first_param, NameUsage::Load)?; } } @@ -1526,7 +1187,7 @@ impl Compiler { &mut self, name: &str, scope_type: CompilerScope, - key: usize, // In RustPython, we use the index in symbol_table_stack as key + key: usize, // Symbol table stack index used like CPython's scope key. 
lineno: u32, ) -> CompileResult<()> { // Allocate a new compiler unit @@ -1546,6 +1207,7 @@ impl Compiler { // Use varnames from symbol table (already collected in definition order) let varname_cache: IndexSet = ste.varnames.iter().cloned().collect(); + let nparams = ste.varnames.len(); // Build cellvars using dictbytype (CELL scope or COMP_CELL flag, sorted) let mut cellvar_cache = IndexSet::default(); @@ -1639,7 +1301,7 @@ impl Compiler { } // Initialize u_metadata fields - let (flags, posonlyarg_count, arg_count, kwonlyarg_count) = match scope_type { + let (mut flags, posonlyarg_count, arg_count, kwonlyarg_count) = match scope_type { CompilerScope::Module => (bytecode::CodeFlags::empty(), 0, 0, 0), CompilerScope::Class => (bytecode::CodeFlags::empty(), 0, 0, 0), CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Lambda => ( @@ -1668,13 +1330,30 @@ impl Compiler { ), }; - // Set CO_NESTED for scopes defined inside another function/class/etc. - // (i.e., not at module level) - let flags = if self.code_stack.len() > 1 { + if ste.is_method { + flags |= bytecode::CodeFlags::METHOD; + } + + // CPython sets CO_NESTED from symtable's ste_nested, not merely + // from lexical depth: module-level class methods are CO_METHOD but + // not CO_NESTED. 
+ let mut flags = if ste.is_nested + && matches!( + scope_type, + CompilerScope::Function + | CompilerScope::AsyncFunction + | CompilerScope::Lambda + | CompilerScope::Comprehension + | CompilerScope::Annotation + | CompilerScope::TypeParams + ) { flags | bytecode::CodeFlags::NESTED } else { flags }; + if self.future_annotations { + flags |= bytecode::CodeFlags::FUTURE_ANNOTATIONS; + } // Get private name from parent scope let private = if !self.code_stack.is_empty() { @@ -1690,11 +1369,13 @@ impl Compiler { private, blocks: vec![ir::Block::default()], current_block: BlockIdx::new(0), - annotations_blocks: None, + instr_sequence: ir::InstructionSequence::new(), + instr_sequence_label_map: ir::InstructionSequenceLabelMap::new(), + annotations_instr_sequence: None, metadata: ir::CodeUnitMetadata { name: name.to_owned(), qualname: None, // Will be set below - consts: IndexSet::default(), + consts: Default::default(), names: IndexSet::default(), varnames: varname_cache, cellvars: cellvar_cache, @@ -1714,20 +1395,13 @@ impl Compiler { in_inlined_comp: false, fblock: Vec::with_capacity(MAXBLOCKS), symbol_table_index: key, + nparams, in_conditional_block: 0, - in_final_with_cleanup_statement: 0, - in_try_else_orelse: 0, next_conditional_annotation_index: 0, }; // Push the old compiler unit on the stack (like PyCapsule) // This happens before setting qualname - self.fallthrough_successor_stack.push(( - self.fallthrough_has_statement_successor, - self.fallthrough_has_local_statement_successor, - )); - self.fallthrough_has_statement_successor = false; - self.fallthrough_has_local_statement_successor = false; self.code_stack.push(code_info); // Set qualname after pushing (uses compiler_set_qualname logic) @@ -1735,8 +1409,6 @@ impl Compiler { self.set_qualname(); } - self.emit_prefix_cell_setup(); - // Emit RESUME (handles async preamble and module lineno 0) // CPython: LOCATION(lineno, lineno, 0, 0), then loc.lineno = 0 for module self.emit_resume_for_scope(scope_type, 
lineno); @@ -1744,17 +1416,9 @@ impl Compiler { Ok(()) } - /// Emit RESUME instruction with proper handling for async preamble and module lineno. + /// Emit RESUME instruction with proper handling for module lineno. /// codegen_enter_scope equivalent for RESUME emission. fn emit_resume_for_scope(&mut self, scope_type: CompilerScope, lineno: u32) { - // For generators and async functions, emit RETURN_GENERATOR + POP_TOP before RESUME - let is_gen = - scope_type == CompilerScope::AsyncFunction || self.current_symbol_table().is_generator; - if is_gen { - emit!(self, Instruction::ReturnGenerator); - emit!(self, Instruction::PopTop); - } - // CPython: LOCATION(lineno, lineno, 0, 0) // Module scope: loc.lineno = 0 (before the first line) let lineno_override = if scope_type == CompilerScope::Module { @@ -1771,7 +1435,7 @@ impl Compiler { let end_location = location; // end_lineno = lineno, end_col = 0 let except_handler = None; - self.current_block().instructions.push(ir::InstructionInfo { + self.cpython_cfg_builder_addop(ir::InstructionInfo { instr: Instruction::Resume { context: OpArgMarker::marker(), } @@ -1781,48 +1445,10 @@ impl Compiler { location, end_location, except_handler, - folded_from_nonliteral_expr: false, lineno_override, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, }); } - fn emit_prefix_cell_setup(&mut self) { - let metadata = &self.code_stack.last().unwrap().metadata; - let varnames = metadata.varnames.clone(); - let cellvars = metadata.cellvars.clone(); - let freevars = metadata.freevars.clone(); - let ncells = cellvars.len(); - if ncells > 0 { - let cellfixedoffsets = ir::build_cellfixedoffsets(&varnames, &cellvars, &freevars); - let mut sorted = vec![None; varnames.len() + ncells]; - for (oldindex, fixed) in cellfixedoffsets.iter().copied().take(ncells).enumerate() { - 
sorted[fixed as usize] = Some(oldindex); - } - for oldindex in sorted.into_iter().flatten() { - let i_varnum: oparg::VarNum = - u32::try_from(oldindex).expect("too many cellvars").into(); - emit!(self, Instruction::MakeCell { i: i_varnum }); - } - } - - let nfrees = freevars.len(); - if nfrees > 0 { - emit!( - self, - Instruction::CopyFreeVars { - n: u32::try_from(nfrees).expect("too many freevars"), - } - ); - } - } - fn push_output( &mut self, flags: bytecode::CodeFlags, @@ -1848,8 +1474,12 @@ impl Compiler { // enter_scope sets default values based on scope_type, but push_output // allows callers to specify exact values if let Some(info) = self.code_stack.last_mut() { - // Preserve NESTED flag set by enter_scope - info.flags = flags | (info.flags & bytecode::CodeFlags::NESTED); + // Preserve flags computed from the symbol-table context. + info.flags = flags + | (info.flags + & (bytecode::CodeFlags::NESTED + | bytecode::CodeFlags::METHOD + | bytecode::CodeFlags::FUTURE_ANNOTATIONS)); info.metadata.argcount = arg_count; info.metadata.posonlyargcount = posonlyarg_count; info.metadata.kwonlyargcount = kwonlyarg_count; @@ -1860,11 +1490,6 @@ impl Compiler { // compiler_exit_scope fn exit_scope(&mut self) -> CodeObject { let _table = self.pop_symbol_table(); - if let Some((previous, previous_local)) = self.fallthrough_successor_stack.pop() { - self.fallthrough_has_statement_successor = previous; - self.fallthrough_has_local_statement_successor = previous_local; - } - // Various scopes can have sub_tables: // - ast::TypeParams scope can have sub_tables (the function body's symbol table) // - Module scope can have sub_tables (for TypeAlias scopes, nested functions, classes) @@ -1881,11 +1506,6 @@ impl Compiler { fn exit_annotation_scope(&mut self, saved_ctx: CompileContext) -> CodeObject { self.pop_annotation_symbol_table(); self.ctx = saved_ctx; - if let Some((previous, previous_local)) = self.fallthrough_successor_stack.pop() { - 
self.fallthrough_has_statement_successor = previous; - self.fallthrough_has_local_statement_successor = previous_local; - } - let pop = self.code_stack.pop(); let stack_top = compiler_unwrap_option(self, pop); unwrap_internal(self, stack_top.finalize_code(&self.opts)) @@ -1897,6 +1517,7 @@ impl Compiler { fn enter_annotation_scope( &mut self, _func_name: &str, + loc: TextRange, ) -> CompileResult> { if !self.push_annotation_symbol_table() { return Ok(None); @@ -1905,12 +1526,12 @@ impl Compiler { // Annotation scopes are never async (even inside async functions) let saved_ctx = self.ctx; self.ctx = CompileContext { - loop_data: None, in_class: saved_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, }; + self.set_source_range(loc); let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get(); self.enter_scope( @@ -1975,26 +1596,16 @@ impl Compiler { ); // Body label - continue with annotation evaluation - self.switch_to_block(body_block); - } - - /// Push a new fblock - // = compiler_push_fblock - fn push_fblock( - &mut self, - fb_type: FBlockType, - fb_block: BlockIdx, - fb_exit: BlockIdx, - ) -> CompileResult<()> { - self.push_fblock_full(fb_type, fb_block, fb_exit, FBlockDatum::None) + self.use_cpython_label_block(body_block); } - /// Push an fblock with all parameters including fb_datum - fn push_fblock_full( + /// CPython `_PyCompile_PushFBlock()`: store the active label targets on the + /// fblock stack. + fn push_fblock_labels( &mut self, fb_type: FBlockType, - fb_block: BlockIdx, - fb_exit: BlockIdx, + fb_block: ir::InstructionSequenceLabel, + fb_exit: ir::InstructionSequenceLabel, fb_datum: FBlockDatum, ) -> CompileResult<()> { let fb_range = self.current_source_range; @@ -2014,18 +1625,42 @@ impl Compiler { Ok(()) } - /// Pop an fblock - // = compiler_pop_fblock - fn pop_fblock(&mut self, _expected_type: FBlockType) -> FBlockInfo { + /// CPython `_PyCompile_PopFBlock()`: assert the popped type and label. 
+ fn pop_fblock_label( + &mut self, + expected_type: FBlockType, + expected_block: ir::InstructionSequenceLabel, + ) -> FBlockInfo { let code = self.current_code_info(); - // TODO: Add assertion to check expected type matches - // assert!(matches!(fblock.fb_type, expected_type)); - code.fblock.pop().expect("fblock stack underflow") + let fblock = code.fblock.pop().expect("fblock stack underflow"); + debug_assert_eq!(fblock.fb_type, expected_type); + debug_assert_eq!( + fblock.fb_block, expected_block, + "CPython _PyCompile_PopFBlock asserts the popped fb_block label" + ); + fblock + } + + fn set_unwind_source_range(&mut self, loc: Option) { + if let Some(range) = loc { + self.set_source_range(range); + } + } + + fn mark_unwind_no_location(&mut self, loc: Option) { + if loc.is_none() { + self.set_no_location(); + } } /// Unwind a single fblock, emitting cleanup code /// preserve_tos: if true, preserve the top of stack (e.g., return value) - fn unwind_fblock(&mut self, info: &FBlockInfo, preserve_tos: bool) -> CompileResult<()> { + fn unwind_fblock( + &mut self, + info: &FBlockInfo, + preserve_tos: bool, + loc: &mut Option, + ) -> CompileResult<()> { match info.fb_type { FBlockType::WhileLoop | FBlockType::ExceptionHandler @@ -2039,13 +1674,19 @@ impl Compiler { // When returning from a for-loop, CPython swaps the preserved // value with the iterator and uses POP_TOP for loop cleanup. 
if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); + self.mark_unwind_no_location(*loc); } FBlockType::TryExcept => { + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); } FBlockType::FinallyTry => { @@ -2058,71 +1699,106 @@ impl Compiler { FBlockType::FinallyEnd => { // codegen_unwind_fblock(FINALLY_END) if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); // exc_value + self.mark_unwind_no_location(*loc); if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopExcept); + self.mark_unwind_no_location(*loc); } FBlockType::With | FBlockType::AsyncWith => { // Stack: [..., exit_func, self_exit, return_value (if preserve_tos)] - self.set_source_range(info.fb_range); + // CPython codegen_unwind_fblock() assigns *ploc = info->fb_loc + // for WITH/ASYNC_WITH cleanup and then makes following unwind + // instructions artificial with *ploc = NO_LOCATION. 
+ *loc = Some(info.fb_range); + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); if preserve_tos { // Rotate return value below the exit pair // [exit_func, self_exit, value] → [value, exit_func, self_exit] + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 3 }); // [value, self_exit, exit_func] + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); // [value, exit_func, self_exit] } - // Call exit_func(self_exit, None, None, None) - self.emit_load_const(ConstantData::None); - self.emit_load_const(ConstantData::None); - self.emit_load_const(ConstantData::None); - emit!(self, Instruction::Call { argc: 3 }); + self.set_unwind_source_range(*loc); + self.compile_call_exit_with_nones(); // For async with, await the result if matches!(info.fb_type, FBlockType::AsyncWith) { + self.set_unwind_source_range(*loc); emit!(self, Instruction::GetAwaitable { r#where: 2 }); + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(true)?; + let _ = self.compile_yield_from_sequence(true); } // Pop the __exit__ result + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); + *loc = None; } FBlockType::HandlerCleanup => { // codegen_unwind_fblock(HANDLER_CLEANUP) if let FBlockDatum::ExceptionName(_) = info.fb_datum { // Named handler: PopBlock for inner SETUP_CLEANUP + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); } if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } // PopBlock for outer SETUP_CLEANUP (ExceptionHandler) + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopExcept); + self.mark_unwind_no_location(*loc); // If there's an exception 
name, clean it up if let FBlockDatum::ExceptionName(ref name) = info.fb_datum { + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); self.store_name(name)?; + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); self.compile_name(name, NameUsage::Delete)?; + self.mark_unwind_no_location(*loc); } } FBlockType::PopValue => { if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); + self.mark_unwind_no_location(*loc); } } Ok(()) @@ -2135,7 +1811,20 @@ impl Compiler { &mut self, preserve_tos: bool, stop_at_loop: bool, - ) -> CompileResult { + ) -> CompileResult> { + let (unwind_loc, _loop_fblock) = + self.unwind_fblock_stack_with_loop(preserve_tos, stop_at_loop)?; + Ok(unwind_loc) + } + + /// CPython `codegen_unwind_fblock_stack()`: unwind frame blocks and, when + /// requested by break/continue codegen, return the first loop fblock instead + /// of unwinding it. 
+ fn unwind_fblock_stack_with_loop( + &mut self, + preserve_tos: bool, + stop_at_loop: bool, + ) -> CompileResult<(Option, Option)> { // Collect the info we need, with indices for FinallyTry blocks #[derive(Clone)] enum UnwindInfo { @@ -2146,6 +1835,7 @@ impl Compiler { }, } let mut unwind_infos = Vec::new(); + let mut loop_fblock = None; { let code = self.current_code_info(); @@ -2162,6 +1852,7 @@ impl Compiler { FBlockType::WhileLoop | FBlockType::ForLoop ) { + loop_fblock = Some(code.fblock[i].clone()); break; } @@ -2179,15 +1870,17 @@ impl Compiler { } // Process each fblock - let mut unwound_finally = false; + let mut unwind_loc = Some(self.current_source_range); for info in unwind_infos { match info { UnwindInfo::Normal(fblock_info) => { - self.unwind_fblock(&fblock_info, preserve_tos)?; + self.unwind_fblock(&fblock_info, preserve_tos, &mut unwind_loc)?; } UnwindInfo::FinallyTry { body, fblock_idx } => { // codegen_unwind_fblock(FINALLY_TRY) + self.set_unwind_source_range(unwind_loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(unwind_loc); // Temporarily remove the FinallyTry fblock so nested return/break/continue // in the finally body won't see it again @@ -2196,18 +1889,22 @@ impl Compiler { // Push PopValue fblock if preserving tos if preserve_tos { - self.push_fblock( + self.push_fblock_labels( FBlockType::PopValue, - saved_fblock.fb_block, - saved_fblock.fb_block, + ir::InstructionSequenceLabel::NO_LABEL, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, )?; } self.compile_statements(&body)?; - unwound_finally = true; + unwind_loc = None; if preserve_tos { - self.pop_fblock(FBlockType::PopValue); + self.pop_fblock_label( + FBlockType::PopValue, + ir::InstructionSequenceLabel::NO_LABEL, + ); } // Restore the fblock @@ -2217,7 +1914,7 @@ impl Compiler { } } - Ok(unwound_finally) + Ok((unwind_loc, loop_fblock)) } // could take impl Into>, but everything is borrowed from ast structs; we never @@ -2266,7 +1963,7 @@ 
impl Compiler { let mut parent_idx = stack_size - 2; let mut parent = &self.code_stack[parent_idx]; - let parent_scope = self + let mut parent_scope = self .symbol_table_stack .get(parent_idx) .map(|table| table.typ); @@ -2285,31 +1982,48 @@ impl Compiler { // Use grandparent parent_idx = stack_size - 3; parent = &self.code_stack[parent_idx]; + parent_scope = self + .symbol_table_stack + .get(parent_idx) + .map(|table| table.typ); } - // Check if this is a global class/function + // Check if this is a global class/function. + // CPython compiler_set_qualname() only applies this GLOBAL_EXPLICIT + // shortcut to function, async-function, and class scopes. Annotation + // scopes, including type-alias value scopes, still inherit the parent + // function's . qualname. let mut force_global = false; - if stack_size > self.symbol_table_stack.len() { - // We might be in a situation where symbol table isn't pushed yet - // In this case, check the parent symbol table - if let Some(parent_table) = self.symbol_table_stack.last() - && let Some(symbol) = parent_table.lookup(¤t_obj_name) - && symbol.scope == SymbolScope::GlobalExplicit - { - force_global = true; - } - } else if let Some(_current_table) = self.symbol_table_stack.last() { - // Mangle the name if necessary (for private names in classes) - let mangled_name = self.mangle(¤t_obj_name); - - // Look up in parent symbol table to check scope - if self.symbol_table_stack.len() >= 2 { - let parent_table = &self.symbol_table_stack[self.symbol_table_stack.len() - 2]; - if let Some(symbol) = parent_table.lookup(&mangled_name) + let current_scope = self + .code_stack + .last() + .map(|code| self.symbol_table_stack[code.symbol_table_index].typ); + if matches!( + current_scope, + Some(CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Class) + ) { + if stack_size > self.symbol_table_stack.len() { + // We might be in a situation where symbol table isn't pushed yet + // In this case, check the parent symbol table 
+ if let Some(parent_table) = self.symbol_table_stack.last() + && let Some(symbol) = parent_table.lookup(¤t_obj_name) && symbol.scope == SymbolScope::GlobalExplicit { force_global = true; } + } else if let Some(_current_table) = self.symbol_table_stack.last() { + // Mangle the name if necessary (for private names in classes) + let mangled_name = self.mangle(¤t_obj_name); + + // Look up in parent symbol table to check scope + if self.symbol_table_stack.len() >= 2 { + let parent_table = &self.symbol_table_stack[self.symbol_table_stack.len() - 2]; + if let Some(symbol) = parent_table.lookup(&mangled_name) + && symbol.scope == SymbolScope::GlobalExplicit + { + force_global = true; + } + } } } @@ -2322,9 +2036,12 @@ impl Compiler { let parent_obj_name = &parent.metadata.name; // Determine if parent is a function-like scope - let is_function_parent = parent.flags.contains(bytecode::CodeFlags::OPTIMIZED) - && !parent_obj_name.starts_with("<") // Not a special scope like , , etc. - && parent_obj_name != ""; // Not the module scope + let is_function_parent = matches!( + parent_scope, + Some( + CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Lambda + ) + ); if is_function_parent { // For functions, append . to parent qualname @@ -2354,6 +2071,11 @@ impl Compiler { let size_before = self.code_stack.len(); // Set future_annotations from symbol table (detected during symbol table scan) self.future_annotations = symbol_table.future_annotations; + if self.future_annotations { + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + } // Module-level __conditional_annotations__ cell let has_module_cond_ann = Self::scope_needs_conditional_annotations_cell(&symbol_table); @@ -2366,19 +2088,15 @@ impl Compiler { self.symbol_table_stack.push(symbol_table); - // Match flowgraph.c insert_prefix_instructions() for module-level - // synthetic cells before RESUME. 
- if has_module_cond_ann { - self.emit_prefix_cell_setup(); - } - self.emit_resume_for_scope(CompilerScope::Module, 1); emit!(self, PseudoInstruction::AnnotationsPlaceholder); - let (doc, statements) = split_doc(&body.body, &self.opts); + let (doc, statements) = split_doc_with_range(&body.body, &self.opts); + let module_start_loc = self.module_start_location(&body.body); // Handle annotation bookkeeping before the docstring assignment, as // codegen_body() does after _PyCodegen_Module() inserts the prefix set. if Self::find_ann(statements) { + self.set_source_range(module_start_loc); if Self::scope_needs_conditional_annotations_cell(self.current_symbol_table()) { emit!(self, Instruction::BuildSet { count: 0 }); self.store_name("__conditional_annotations__")?; @@ -2389,19 +2107,22 @@ impl Compiler { } } - if let Some(value) = doc { + if let Some((value, range)) = doc { + let saved_range = self.current_source_range; + self.set_source_range(range); self.emit_load_const(ConstantData::Str { value: value.into(), }); let doc = self.name("__doc__"); - emit!(self, Instruction::StoreName { namei: doc }) + emit!(self, Instruction::StoreName { namei: doc }); + self.set_source_range(saved_range); } // Compile all statements self.compile_statements(statements)?; if Self::find_ann(statements) && !self.future_annotations { - self.compile_module_annotation_setup_sequence(statements)?; + self.compile_module_annotation_setup_sequence(statements, module_start_loc)?; } assert_eq!(self.code_stack.len(), size_before); @@ -2421,13 +2142,20 @@ impl Compiler { self.interactive = true; // Set future_annotations from symbol table (detected during symbol table scan) self.future_annotations = symbol_table.future_annotations; + if self.future_annotations { + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + } self.symbol_table_stack.push(symbol_table); + let module_start_loc = self.module_start_location(body); 
self.emit_resume_for_scope(CompilerScope::Module, 1); emit!(self, PseudoInstruction::AnnotationsPlaceholder); // Handle annotations based on future_annotations flag if Self::find_ann(body) { + self.set_source_range(module_start_loc); if self.future_annotations { // PEP 563: Initialize __annotations__ dict emit!(self, Instruction::SetupAnnotations); @@ -2480,7 +2208,7 @@ impl Compiler { }; if Self::find_ann(body) && !self.future_annotations { - self.compile_module_annotation_setup_sequence(body)?; + self.compile_module_annotation_setup_sequence(body, module_start_loc)?; } self.emit_return_value(); @@ -2495,21 +2223,58 @@ impl Compiler { self.symbol_table_stack.push(symbol_table); self.emit_resume_for_scope(CompilerScope::Module, 1); - self.compile_statements(body)?; - - if let Some(last_statement) = body.last() { + if let Some((last_statement, statements)) = body.split_last() { + self.compile_statements(statements)?; match last_statement { - ast::Stmt::Expr(_) => { - self.current_block().instructions.pop(); // pop Instruction::PopTop + ast::Stmt::Expr(ast::StmtExpr { value, .. }) => { + self.compile_expression(value)?; } - ast::Stmt::FunctionDef(_) | ast::Stmt::ClassDef(_) => { - let pop_instructions = self.current_block().instructions.pop(); - let store_inst = compiler_unwrap_option(self, pop_instructions); // pop Instruction::Store - emit!(self, Instruction::Copy { i: 1 }); - self.current_block().instructions.push(store_inst); + ast::Stmt::FunctionDef(ast::StmtFunctionDef { + name, + parameters, + body, + decorator_list, + returns, + type_params, + is_async, + .. + }) => { + validate_duplicate_params(parameters).map_err(|e| self.error(e))?; + self.compile_function_def( + name.as_str(), + parameters, + body, + decorator_list, + returns.as_deref(), + *is_async, + type_params.as_deref(), + true, + )?; + } + ast::Stmt::ClassDef(ast::StmtClassDef { + name, + body, + decorator_list, + type_params, + arguments, + .. 
+ }) => { + self.compile_class_def( + name.as_str(), + body, + decorator_list, + type_params.as_deref(), + arguments.as_deref(), + true, + )?; + } + _ => { + self.compile_statement(last_statement)?; + self.emit_load_const(ConstantData::None); } - _ => self.emit_load_const(ConstantData::None), } + } else { + self.emit_load_const(ConstantData::None); } self.emit_return_value(); @@ -2531,55 +2296,76 @@ impl Compiler { } fn compile_statements(&mut self, statements: &[ast::Stmt]) -> CompileResult<()> { - let inherited_successor = self.fallthrough_has_statement_successor; - for (idx, statement) in statements.iter().enumerate() { - let previous_successor = self.fallthrough_has_statement_successor; - let previous_local_successor = self.fallthrough_has_local_statement_successor; - self.fallthrough_has_statement_successor = - inherited_successor || idx + 1 < statements.len(); - self.fallthrough_has_local_statement_successor = idx + 1 < statements.len(); - let result = self.compile_statement(statement); - self.fallthrough_has_statement_successor = previous_successor; - self.fallthrough_has_local_statement_successor = previous_local_successor; - result?; + for statement in statements { + self.compile_statement(statement)?; } Ok(()) } fn compile_with_body_statements(&mut self, statements: &[ast::Stmt]) -> CompileResult<()> { - let inherited_successor = self.fallthrough_has_statement_successor; - for (idx, statement) in statements.iter().enumerate() { - let previous_successor = self.fallthrough_has_statement_successor; - let previous_local_successor = self.fallthrough_has_local_statement_successor; - self.fallthrough_has_statement_successor = - inherited_successor || idx + 1 < statements.len(); - self.fallthrough_has_local_statement_successor = idx + 1 < statements.len(); - if idx + 1 == statements.len() && matches!(statement, ast::Stmt::Try(_)) { - self.current_code_info().in_final_with_cleanup_statement += 1; - let result = self.compile_statement(statement); - 
self.current_code_info().in_final_with_cleanup_statement -= 1; - self.fallthrough_has_statement_successor = previous_successor; - self.fallthrough_has_local_statement_successor = previous_local_successor; - result?; - } else { - let result = self.compile_statement(statement); - self.fallthrough_has_statement_successor = previous_successor; - self.fallthrough_has_local_statement_successor = previous_local_successor; - result?; - } + for statement in statements { + self.compile_statement(statement)?; } Ok(()) } + /// CPython `codegen_call_exit_with_nones()`. + fn compile_call_exit_with_nones(&mut self) { + self.emit_load_const(ConstantData::None); + self.emit_load_const(ConstantData::None); + self.emit_load_const(ConstantData::None); + emit!(self, Instruction::Call { argc: 3 }); + } + + /// CPython `codegen_with_except_finish()`. + fn compile_with_except_finish(&mut self, cleanup_block: BlockIdx) { + let suppress_block = self.new_block(); + + emit!(self, Instruction::ToBool); + self.set_no_location(); + emit!( + self, + Instruction::PopJumpIfTrue { + delta: suppress_block + } + ); + self.set_no_location(); + emit!(self, Instruction::Reraise { depth: 2 }); + self.set_no_location(); + + self.use_cpython_label_block(suppress_block); + emit!(self, Instruction::PopTop); + self.set_no_location(); + emit!(self, PseudoInstruction::PopBlock); + self.set_no_location(); + emit!(self, Instruction::PopExcept); + self.set_no_location(); + emit!(self, Instruction::PopTop); + self.set_no_location(); + emit!(self, Instruction::PopTop); + self.set_no_location(); + emit!(self, Instruction::PopTop); + self.set_no_location(); + let exit_block = self.new_block(); + emit!( + self, + PseudoInstruction::JumpNoInterrupt { delta: exit_block } + ); + self.set_no_location(); + + self.use_cpython_label_block(cleanup_block); + emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); + emit!(self, Instruction::PopExcept); + self.set_no_location(); + emit!(self, Instruction::Reraise { 
depth: 1 }); + self.set_no_location(); + + self.use_cpython_label_block(exit_block); + } + fn compile_loop_body_statements(&mut self, statements: &[ast::Stmt]) -> CompileResult<()> { - let previous_successor = self.fallthrough_has_statement_successor; - let previous_local_successor = self.fallthrough_has_local_statement_successor; - self.fallthrough_has_statement_successor = false; - self.fallthrough_has_local_statement_successor = false; - let result = self.compile_statements(statements); - self.fallthrough_has_statement_successor = previous_successor; - self.fallthrough_has_local_statement_successor = previous_local_successor; - result + self.compile_statements(statements) } fn scope_needs_conditional_annotations_cell(symbol_table: &SymbolTable) -> bool { @@ -2599,6 +2385,19 @@ impl Compiler { self.compile_name(name, NameUsage::Store) } + fn emit_no_location_exception_name_cleanup(&mut self, name: &str) -> CompileResult<()> { + // CPython codegen_try_except() emits `name = None; del name` + // with NO_LOCATION for `except ... as name` cleanup. 
+ self.set_no_location(); + self.emit_load_const(ConstantData::None); + self.set_no_location(); + self.store_name(name)?; + self.set_no_location(); + self.compile_name(name, NameUsage::Delete)?; + self.set_no_location(); + Ok(()) + } + fn mangle<'a>(&self, name: &'a str) -> Cow<'a, str> { // Use private from current code unit for name mangling let private = self @@ -2710,6 +2509,18 @@ impl Compiler { CompilerScope::Annotation | CompilerScope::TypeParams ) { SymbolScope::GlobalImplicit + } else if matches!( + name.as_ref(), + "__name__" + | "__module__" + | "__qualname__" + | "__firstlineno__" + | "__doc__" + | "__static_attributes__" + | "__classdictcell__" + | "__classcell__" + ) { + SymbolScope::Unknown } else { return Err(self.error(CodegenErrorType::SyntaxError(format!( "the symbol '{name}' must be present in the symbol table" @@ -2966,6 +2777,7 @@ impl Compiler { } ); emit!(self, Instruction::PopTop); + self.set_no_location(); } else { // from mod import a, b as c @@ -3039,8 +2851,9 @@ impl Compiler { body, orelse, is_async, + range, .. - }) => self.compile_for(target, iter, body, orelse, *is_async)?, + }) => self.compile_for(target, iter, body, orelse, *is_async, *range)?, ast::Stmt::Match(ast::StmtMatch { subject, cases, .. }) => { self.compile_match(subject, cases)? } @@ -3062,10 +2875,6 @@ impl Compiler { }; self.set_source_range(*range); emit!(self, Instruction::RaiseVarargs { argc: kind }); - // Start a new block so dead code after raise doesn't - // corrupt the except stack in label_exception_targets - let dead = self.new_block(); - self.switch_to_block(dead); } ast::Stmt::Try(ast::StmtTry { body, @@ -3077,7 +2886,7 @@ impl Compiler { }) => { self.enter_conditional_block(); if *is_star { - self.compile_try_star_except(body, handlers, orelse, finalbody)? + self.compile_try_star_statement(body, handlers, orelse, finalbody)? } else { self.compile_try_statement(body, handlers, orelse, finalbody)? 
} @@ -3103,6 +2912,7 @@ impl Compiler { returns.as_deref(), *is_async, type_params.as_deref(), + false, )? } ast::Stmt::ClassDef(ast::StmtClassDef { @@ -3118,12 +2928,16 @@ impl Compiler { decorator_list, type_params.as_deref(), arguments.as_deref(), + false, )?, - ast::Stmt::Assert(ast::StmtAssert { test, msg, .. }) => { + ast::Stmt::Assert(ast::StmtAssert { + test, msg, range, .. + }) => { // if some flag, ignore all assert statements! if self.opts.optimize == 0 { let after_block = self.new_block(); self.compile_jump_if(test, true, after_block)?; + self.set_source_range(*range); emit!( self, Instruction::LoadCommonConstant { @@ -3132,15 +2946,17 @@ impl Compiler { ); if let Some(e) = msg { self.compile_expression(e)?; + self.set_source_range(*range); emit!(self, Instruction::Call { argc: 0 }); } + self.set_source_range(test.range()); emit!( self, Instruction::RaiseVarargs { argc: bytecode::RaiseKind::Raise, } ); - self.switch_to_block(after_block); + self.use_cpython_label_block(after_block); } else { // Optimized-out asserts still need to consume any nested // scope symbol tables they contain so later nested scopes @@ -3152,18 +2968,12 @@ impl Compiler { } } ast::Stmt::Break(_) => { - emit!(self, Instruction::Nop); // NOP for line tracing // Unwind fblock stack until we find a loop, emitting cleanup for each fblock self.compile_break_continue(statement.range(), true)?; - let dead = self.new_block(); - self.switch_to_block(dead); } ast::Stmt::Continue(_) => { - emit!(self, Instruction::Nop); // NOP for line tracing // Unwind fblock stack until we find a loop, emitting cleanup for each fblock self.compile_break_continue(statement.range(), false)?; - let dead = self.new_block(); - self.switch_to_block(dead); } ast::Stmt::Return(ast::StmtReturn { value, .. 
}) => { if !self.ctx.in_func() { @@ -3177,10 +2987,7 @@ impl Compiler { match value { Some(v) => { if self.ctx.func == FunctionContext::AsyncFunction - && self - .current_code_info() - .flags - .contains(bytecode::CodeFlags::GENERATOR) + && self.current_symbol_table().is_generator { return Err(self.error_ranged( CodegenErrorType::AsyncReturnValue, @@ -3193,9 +3000,11 @@ impl Compiler { None }; let preserve_tos = folded_constant.is_none(); + let mut return_range = stmt_range; if preserve_tos { self.compile_expression(v)?; } else { + return_range = v.range(); self.set_source_range(v.range()); emit!(self, Instruction::Nop); } @@ -3204,16 +3013,17 @@ impl Compiler { if source.line_index(v.range().start()) != source.line_index(stmt_range.start()) { + return_range = stmt_range; self.set_source_range(stmt_range); emit!(self, Instruction::Nop); } - self.set_source_range(stmt_range); - let unwound_finally = self.unwind_fblock_stack(preserve_tos, false)?; - if !unwound_finally { - self.set_source_range(stmt_range); + self.set_source_range(return_range); + let unwind_loc = self.unwind_fblock_stack(preserve_tos, false)?; + if let Some(loc) = unwind_loc { + self.set_source_range(loc); } match folded_constant { - Some(constant) if unwound_finally => { + Some(constant) if unwind_loc.is_none() => { self.emit_return_const_no_location(constant); } Some(constant) => { @@ -3222,7 +3032,7 @@ impl Compiler { } None => { self.emit_return_value(); - if unwound_finally { + if unwind_loc.is_none() { self.set_no_location(); } } @@ -3232,20 +3042,23 @@ impl Compiler { self.set_source_range(stmt_range); emit!(self, Instruction::Nop); // Unwind fblock stack with preserve_tos=false (no value to preserve) - let unwound_finally = self.unwind_fblock_stack(false, false)?; - if unwound_finally { - self.emit_return_const_no_location(ConstantData::None); - } else { - self.set_source_range(stmt_range); + let unwind_loc = self.unwind_fblock_stack(false, false)?; + if let Some(loc) = unwind_loc { + 
self.set_source_range(loc); self.emit_return_const(ConstantData::None); + } else { + self.emit_return_const_no_location(ConstantData::None); } } } self.set_source_range(prev_source_range); - let dead = self.new_block(); - self.switch_to_block(dead); } - ast::Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { + ast::Stmt::Assign(ast::StmtAssign { + targets, + value, + range, + .. + }) => { if targets.len() == 1 && Self::is_unpack_assignment_target(&targets[0]) { self.compile_expression_without_const_collection_folding(value)?; } else { @@ -3254,6 +3067,7 @@ impl Compiler { for (i, target) in targets.iter().enumerate() { if i + 1 != targets.len() { + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); } self.compile_store(target)?; @@ -3267,9 +3081,16 @@ impl Compiler { annotation, value, simple, + range, .. }) => { - self.compile_annotated_assign(target, annotation, value.as_deref(), *simple)?; + self.compile_annotated_assign( + target, + annotation, + value.as_deref(), + *simple, + *range, + )?; // Bare annotations in function scope emit no code; restore // source range so subsequent instructions keep the correct line. if value.is_none() && self.ctx.in_func() { @@ -3288,6 +3109,7 @@ impl Compiler { name, type_params, value, + range, .. 
}) => { let Some(name) = name.as_name_expr() else { @@ -3307,7 +3129,6 @@ impl Compiler { // TypeParams scope is function-like let prev_ctx = self.ctx; self.ctx = CompileContext { - loop_data: None, in_class: prev_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, @@ -3317,7 +3138,7 @@ impl Compiler { value: name_string.clone().into(), }); self.compile_type_params(type_params)?; - self.compile_typealias_value_closure(&name_string, value)?; + self.compile_typealias_value_closure(&name_string, value, *range)?; emit!(self, Instruction::BuildTuple { count: 3 }); emit!( self, @@ -3337,7 +3158,7 @@ impl Compiler { value: name_string.clone().into(), }); self.emit_load_const(ConstantData::None); - self.compile_typealias_value_closure(&name_string, value)?; + self.compile_typealias_value_closure(&name_string, value, *range)?; emit!(self, Instruction::BuildTuple { count: 3 }); emit!( self, @@ -3355,32 +3176,43 @@ impl Compiler { } fn compile_delete(&mut self, expression: &ast::Expr) -> CompileResult<()> { - match &expression { - ast::Expr::Name(ast::ExprName { id, .. }) => { - self.compile_name(id.as_str(), NameUsage::Delete)? - } - ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { - self.compile_expression(value)?; - let namei = self.name(attr.as_str()); - emit!(self, Instruction::DeleteAttr { namei }); - } - ast::Expr::Subscript(ast::ExprSubscript { - value, slice, ctx, .. - }) => { - self.compile_subscript(value, slice, *ctx)?; - } - ast::Expr::Tuple(ast::ExprTuple { elts, .. }) - | ast::Expr::List(ast::ExprList { elts, .. }) => { - for element in elts { - self.compile_delete(element)?; + let prev_source_range = self.current_source_range; + self.set_source_range(expression.range()); + let result = (|| -> CompileResult<()> { + match &expression { + ast::Expr::Name(ast::ExprName { id, .. }) => { + self.compile_name(id.as_str(), NameUsage::Delete)? } + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) => { + self.compile_expression(value)?; + let namei = self.name(attr.as_str()); + self.set_source_range(self.update_start_location_to_match_attr( + expression.range(), + expression.range(), + attr.as_str(), + )); + emit!(self, Instruction::DeleteAttr { namei }); + } + ast::Expr::Subscript(ast::ExprSubscript { + value, slice, ctx, .. + }) => { + self.compile_subscript(value, slice, *ctx)?; + } + ast::Expr::Tuple(ast::ExprTuple { elts, .. }) + | ast::Expr::List(ast::ExprList { elts, .. }) => { + for element in elts { + self.compile_delete(element)?; + } + } + ast::Expr::BinOp(_) | ast::Expr::UnaryOp(_) => { + return Err(self.error(CodegenErrorType::Delete("expression"))); + } + _ => return Err(self.error(CodegenErrorType::Delete(expression.python_name()))), } - ast::Expr::BinOp(_) | ast::Expr::UnaryOp(_) => { - return Err(self.error(CodegenErrorType::Delete("expression"))); - } - _ => return Err(self.error(CodegenErrorType::Delete(expression.python_name()))), - } - Ok(()) + Ok(()) + })(); + self.set_source_range(prev_source_range); + result } fn enter_function(&mut self, name: &str, parameters: &ast::Parameters) -> CompileResult<()> { @@ -3437,7 +3269,8 @@ impl Compiler { /// Apply decorators: each decorator calls the function below it. 
/// Stack: [dec1, dec2, func] → CALL 0 → [dec1, dec2(func)] → CALL 0 → [dec1(dec2(func))] fn apply_decorators(&mut self, decorator_list: &[ast::Decorator]) { - for _ in decorator_list { + for decorator in decorator_list.iter().rev() { + self.set_source_range(decorator.expression.range()); emit!(self, Instruction::Call { argc: 0 }); } } @@ -3449,6 +3282,8 @@ impl Compiler { name: &str, allow_starred: bool, ) -> CompileResult<()> { + let expr_range = expr.range(); + self.set_source_range(expr_range); self.emit_load_const(ConstantData::Tuple { elements: vec![ConstantData::Integer { value: 1.into() }], }); @@ -3473,7 +3308,6 @@ impl Compiler { // TypeParams scope is function-like let prev_ctx = self.ctx; self.ctx = CompileContext { - loop_data: None, in_class: prev_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, @@ -3483,6 +3317,7 @@ impl Compiler { if allow_starred && matches!(expr, ast::Expr::Starred(_)) { if let ast::Expr::Starred(starred) = expr { self.compile_expression(&starred.value)?; + self.set_source_range(expr_range); emit!(self, Instruction::UnpackSequence { count: 1 }); } } else { @@ -3490,12 +3325,14 @@ impl Compiler { } // Return value + self.set_source_range(expr_range); emit!(self, Instruction::ReturnValue); // Exit scope and create closure let code = self.exit_scope(); self.ctx = prev_ctx; + self.set_source_range(expr_range); self.make_closure( code, bytecode::MakeFunctionFlags::from([bytecode::MakeFunctionFlag::Defaults]), @@ -3508,7 +3345,9 @@ impl Compiler { &mut self, alias_name: &str, value: &ast::Expr, + alias_range: TextRange, ) -> CompileResult<()> { + self.set_source_range(alias_range); self.emit_load_const(ConstantData::Tuple { elements: vec![ConstantData::Integer { value: 1.into() }], }); @@ -3525,17 +3364,18 @@ impl Compiler { let prev_ctx = self.ctx; self.ctx = CompileContext { - loop_data: None, in_class: prev_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, }; self.compile_expression(value)?; + 
self.set_source_range(alias_range); emit!(self, Instruction::ReturnValue); let code = self.exit_scope(); self.ctx = prev_ctx; + self.set_source_range(alias_range); self.make_closure( code, bytecode::MakeFunctionFlags::from([bytecode::MakeFunctionFlag::Defaults]), @@ -3554,8 +3394,10 @@ impl Compiler { name, bound, default, + range, .. }) => { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -3563,6 +3405,7 @@ impl Compiler { if let Some(expr) = &bound { self.compile_type_param_bound_or_default(expr, name.as_str(), false)?; + self.set_source_range(*range); let intrinsic = if expr.is_tuple_expr() { bytecode::IntrinsicFunction2::TypeVarWithConstraint } else { @@ -3584,6 +3427,7 @@ impl Compiler { name.as_str(), false, )?; + self.set_source_range(*range); emit!( self, Instruction::CallIntrinsic2 { @@ -3592,10 +3436,17 @@ impl Compiler { ); } + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.store_name(name.as_ref())?; } - ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, default, .. }) => { + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { + name, + default, + range, + .. + }) => { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -3612,6 +3463,7 @@ impl Compiler { name.as_str(), false, )?; + self.set_source_range(*range); emit!( self, Instruction::CallIntrinsic2 { @@ -3620,12 +3472,17 @@ impl Compiler { ); } + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.store_name(name.as_ref())?; } ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { - name, default, .. + name, + default, + range, + .. 
}) => { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -3643,6 +3500,7 @@ impl Compiler { name.as_str(), true, )?; + self.set_source_range(*range); emit!( self, Instruction::CallIntrinsic2 { @@ -3651,11 +3509,15 @@ impl Compiler { ); } + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.store_name(name.as_ref())?; } }; } + if let Some(first) = type_params.type_params.first() { + self.set_source_range(first.range()); + } emit!( self, Instruction::BuildTuple { @@ -3676,289 +3538,75 @@ impl Compiler { return self.compile_try_except_no_finally(body, handlers, orelse); } - let handler_block = self.new_block(); - let finally_block = self.new_block(); - - // finally needs TWO blocks: - // - finally_block: normal path (no exception active) - // - finally_except_block: exception path (PUSH_EXC_INFO -> body -> RERAISE) - let finally_except_block = if !finalbody.is_empty() { - Some(self.new_block()) - } else { - None - }; - let finally_cleanup_block = if finally_except_block.is_some() { - Some(self.new_block()) - } else { - None - }; - // End block - continuation point after try-finally - // Normal path jumps here to skip exception path blocks - let end_block = self.new_block(); - if Self::has_resuming_bare_except(handlers) { - self.disable_load_fast_borrow_for_block(end_block); - } - - // Emit NOP at the try: line so LINE events fire for it - emit!(self, Instruction::Nop); + let body_block = self.new_block(); + let finally_except_block = self.new_block(); + let exit_block = self.new_block(); + let finally_cleanup_block = self.new_block(); - // Setup a finally block if we have a finally statement. 
- // Push fblock with handler info for exception table generation - // IMPORTANT: handler goes to finally_except_block (exception path), not finally_block - if !finalbody.is_empty() { - // SETUP_FINALLY doesn't push lasti for try body handler - // Exception table: L1 to L2 -> L4 [1] (no lasti) - let setup_target = finally_except_block.unwrap_or(finally_block); - emit!( - self, - PseudoInstruction::SetupFinally { - delta: setup_target - } - ); - // Store finally body in fb_datum for unwind_fblock to compile inline - self.push_fblock_full( - FBlockType::FinallyTry, - finally_block, - finally_block, - FBlockDatum::FinallyBody(finalbody.to_vec()), // Clone finally body for unwind - )?; - } + emit!( + self, + PseudoInstruction::SetupFinally { + delta: finally_except_block + } + ); + self.use_cpython_label_block(body_block); + let body_label = self.instr_sequence_label_for_block(body_block); + let finally_except_label = self.instr_sequence_label_for_block(finally_except_block); + self.push_fblock_labels( + FBlockType::FinallyTry, + body_label, + finally_except_label, + FBlockDatum::FinallyBody(finalbody.to_vec()), + )?; - // if handlers is empty, compile body directly - // without wrapping in TryExcept (only FinallyTry is needed) if handlers.is_empty() { - let preserve_finally_entry_nop = self.preserves_finally_entry_nop(body) - || self.statements_end_with_loop_fallthrough(body)?; - - // Just compile body with FinallyTry fblock active (if finalbody exists) self.compile_statements(body)?; - - // Pop FinallyTry fblock BEFORE compiling orelse/finally (normal path) - // This prevents exception table from covering the normal path - if !finalbody.is_empty() { - emit!(self, PseudoInstruction::PopBlock); - if preserve_finally_entry_nop { - self.preserve_last_redundant_nop(); - } else { - self.set_no_location(); - self.remove_last_no_location_nop(); - } - self.pop_fblock(FBlockType::FinallyTry); - } - - // Compile orelse (usually empty for try-finally without except) 
self.compile_statements(orelse)?; - - // Snapshot sub_tables before first finally compilation - // This allows us to restore them for the second compilation (exception path) - let sub_table_cursor = if !finalbody.is_empty() && finally_except_block.is_some() { - self.symbol_table_stack.last().map(|t| t.next_sub_table) - } else { - None - }; - - // Compile finally body inline for normal path - if !finalbody.is_empty() { - self.compile_statements(finalbody)?; - } - - // Jump to end (skip exception path blocks) - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: end_block } - ); - self.set_no_location(); - - if let Some(finally_except) = finally_except_block { - // Restore sub_tables for exception path compilation - if let Some(cursor) = sub_table_cursor - && let Some(current_table) = self.symbol_table_stack.last_mut() - { - current_table.next_sub_table = cursor; - } - - self.switch_to_block(finally_except); - // SETUP_CLEANUP before PUSH_EXC_INFO - if let Some(cleanup) = finally_cleanup_block { - emit!(self, PseudoInstruction::SetupCleanup { delta: cleanup }); - } - emit!(self, Instruction::PushExcInfo); - if let Some(cleanup) = finally_cleanup_block { - self.push_fblock(FBlockType::FinallyEnd, cleanup, cleanup)?; - } - self.compile_statements(finalbody)?; - - // RERAISE must be inside the cleanup handler's exception table - // range. When RERAISE re-raises the exception, the cleanup - // handler (COPY 3, POP_EXCEPT, RERAISE 1) runs POP_EXCEPT to - // restore exc_info before the exception reaches the outer handler. - emit!(self, Instruction::Reraise { depth: 0 }); - self.set_no_location(); - - // PopBlock after RERAISE (dead code, but marks the exception - // table range end so the cleanup covers RERAISE). 
- if finally_cleanup_block.is_some() { - emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::FinallyEnd); - } - } - - if let Some(cleanup) = finally_cleanup_block { - self.switch_to_block(cleanup); - emit!(self, Instruction::Copy { i: 3 }); - emit!(self, Instruction::PopExcept); - emit!(self, Instruction::Reraise { depth: 1 }); - } - - self.switch_to_block(end_block); - return Ok(()); + } else { + self.compile_try_except_no_finally(body, handlers, orelse)?; } - // try: - emit!( - self, - PseudoInstruction::SetupFinally { - delta: handler_block - } - ); - self.push_fblock(FBlockType::TryExcept, handler_block, handler_block)?; - self.compile_statements(body)?; emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - self.pop_fblock(FBlockType::TryExcept); + self.pop_fblock_label(FBlockType::FinallyTry, body_label); - let cleanup_block = self.new_block(); - // We successfully ran the try block: - // else: - self.compile_statements(orelse)?; + let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); + self.compile_statements(finalbody)?; emit!( self, - PseudoInstruction::JumpNoInterrupt { - delta: finally_block, - } + PseudoInstruction::JumpNoInterrupt { delta: exit_block } ); self.set_no_location(); - // except handlers: - self.switch_to_block(handler_block); + if let Some(cursor) = sub_table_cursor + && let Some(current_table) = self.symbol_table_stack.last_mut() + { + current_table.next_sub_table = cursor; + } - // SETUP_CLEANUP(cleanup) for except block - // This handles exceptions during exception matching - // Exception table: L2 to L3 -> L5 [1] lasti - // After PUSH_EXC_INFO, stack is [prev_exc, exc] - // depth=1 means keep prev_exc on stack when routing to cleanup + self.use_cpython_label_block(finally_except_block); emit!( self, PseudoInstruction::SetupCleanup { - delta: cleanup_block + delta: finally_cleanup_block } ); self.set_no_location(); - self.push_fblock(FBlockType::ExceptionHandler, cleanup_block, 
cleanup_block)?; - - // Exception is on top of stack now, pushed by unwind_blocks - // PUSH_EXC_INFO transforms [exc] -> [prev_exc, exc] for PopExcept emit!(self, Instruction::PushExcInfo); self.set_no_location(); - for handler in handlers { - let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - type_, - name, - body, - range: handler_range, - .. - }) = &handler; - self.set_source_range(*handler_range); - let next_handler = self.new_block(); - - if let Some(exc_type) = type_ { - self.compile_expression(exc_type)?; - emit!(self, Instruction::CheckExcMatch); - emit!( - self, - Instruction::PopJumpIfFalse { - delta: next_handler - } - ); - - if let Some(alias) = name { - self.store_name(alias.as_str())? - } else { - emit!(self, Instruction::PopTop); - } - } else { - emit!(self, Instruction::PopTop); - } - - let handler_cleanup_block = if name.is_some() { - let cleanup_end = self.new_block(); - emit!(self, PseudoInstruction::SetupCleanup { delta: cleanup_end }); - self.push_fblock_full( - FBlockType::HandlerCleanup, - cleanup_end, - cleanup_end, - FBlockDatum::ExceptionName(name.as_ref().unwrap().as_str().to_owned()), - )?; - Some(cleanup_end) - } else { - self.push_fblock(FBlockType::HandlerCleanup, finally_block, finally_block)?; - None - }; - - self.compile_statements(body)?; - - self.pop_fblock(FBlockType::HandlerCleanup); - if handler_cleanup_block.is_some() { - emit!(self, PseudoInstruction::PopBlock); - } - - if let Some(cleanup_end) = handler_cleanup_block { - let handler_normal_exit = self.new_block(); - emit!( - self, - PseudoInstruction::JumpNoInterrupt { - delta: handler_normal_exit, - } - ); - - self.switch_to_block(cleanup_end); - if let Some(alias) = name { - self.emit_load_const(ConstantData::None); - self.store_name(alias.as_str())?; - self.compile_name(alias.as_str(), NameUsage::Delete)?; - } - emit!(self, Instruction::Reraise { depth: 1 }); - self.switch_to_block(handler_normal_exit); - } - - emit!(self, 
PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::ExceptionHandler); - emit!(self, Instruction::PopExcept); - - if let Some(alias) = name { - self.emit_load_const(ConstantData::None); - self.store_name(alias.as_str())?; - self.compile_name(alias.as_str(), NameUsage::Delete)?; - } - - emit!( - self, - PseudoInstruction::JumpNoInterrupt { - delta: finally_block, - } - ); - self.set_no_location(); - - self.push_fblock(FBlockType::ExceptionHandler, cleanup_block, cleanup_block)?; - self.switch_to_block(next_handler); - } - + self.push_fblock_labels( + FBlockType::FinallyEnd, + finally_except_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; + self.compile_statements(finalbody)?; + self.pop_fblock_label(FBlockType::FinallyEnd, finally_except_label); emit!(self, Instruction::Reraise { depth: 0 }); self.set_no_location(); - self.pop_fblock(FBlockType::ExceptionHandler); - self.switch_to_block(cleanup_block); + self.use_cpython_label_block(finally_cleanup_block); emit!(self, Instruction::Copy { i: 3 }); self.set_no_location(); emit!(self, Instruction::PopExcept); @@ -3966,108 +3614,7 @@ impl Compiler { emit!(self, Instruction::Reraise { depth: 1 }); self.set_no_location(); - // finally (normal path): - // CPython's codegen_try_finally emits the wrapped try/except first and - // places the outer finally body at the inner try/except end label. Keep - // the FinallyTry fblock active through exception-handler normal exits so - // the CFG and exception-table ranges match that structure. 
- self.switch_to_block(finally_block); - if !finalbody.is_empty() { - let preserve_finally_normal_pop_block_nop = orelse.is_empty() - && !Self::statements_end_with_scope_exit(body) - && (!Self::statements_end_with_open_conditional_fallthrough(body) - || Self::statements_end_with_finally_entry_scope_exit(body)) - && handlers.iter().all(|handler| match handler { - ast::ExceptHandler::ExceptHandler(handler) => { - Self::statements_end_with_scope_exit(&handler.body) - } - }); - if preserve_finally_normal_pop_block_nop && let Some(last_body_stmt) = body.last() { - self.set_source_range(last_body_stmt.range()); - } - emit!(self, PseudoInstruction::PopBlock); - if preserve_finally_normal_pop_block_nop { - self.preserve_last_redundant_nop(); - } else { - self.set_no_location(); - } - self.pop_fblock(FBlockType::FinallyTry); - - // Snapshot sub_tables before first finally compilation (for double compilation issue) - let sub_table_cursor = if finally_except_block.is_some() { - self.symbol_table_stack.last().map(|t| t.next_sub_table) - } else { - None - }; - - self.compile_statements(finalbody)?; - // Jump to end_block to skip exception path blocks - // This prevents fall-through to finally_except_block - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: end_block } - ); - self.set_no_location(); - - // finally (exception path) - // This is where exceptions go to run finally before reraise - // Stack at entry: [lasti, exc] (from exception table with preserve_lasti=true) - let finally_except = finally_except_block.expect("finally except block"); - // Restore sub_tables for exception path compilation - if let Some(cursor) = sub_table_cursor - && let Some(current_table) = self.symbol_table_stack.last_mut() - { - current_table.next_sub_table = cursor; - } - - self.switch_to_block(finally_except); - - // SETUP_CLEANUP for finally body - // Exceptions during finally body need to go to cleanup block - if let Some(cleanup) = finally_cleanup_block { - emit!(self, 
PseudoInstruction::SetupCleanup { delta: cleanup }); - self.set_no_location(); - } - emit!(self, Instruction::PushExcInfo); - self.set_no_location(); - if let Some(cleanup) = finally_cleanup_block { - self.push_fblock(FBlockType::FinallyEnd, cleanup, cleanup)?; - } - - // Run finally body - self.compile_statements(finalbody)?; - - // RERAISE must be inside the cleanup handler's exception table - // range. The cleanup handler (COPY 3, POP_EXCEPT, RERAISE 1) - // runs POP_EXCEPT to restore exc_info before re-raising to - // the outer handler. - emit!(self, Instruction::Reraise { depth: 0 }); - self.set_no_location(); - - // PopBlock after RERAISE (dead code, but marks the exception - // table range end so the cleanup covers RERAISE). - if finally_cleanup_block.is_some() { - emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::FinallyEnd); - } - } - - // finally cleanup block - // This handles exceptions that occur during the finally body itself - // Stack at entry: [lasti, prev_exc, lasti2, exc2] after exception table routing - if let Some(cleanup) = finally_cleanup_block { - self.switch_to_block(cleanup); - // COPY 3: copy the exception from position 3 - emit!(self, Instruction::Copy { i: 3 }); - // POP_EXCEPT: restore prev_exc as current exception - emit!(self, Instruction::PopExcept); - // RERAISE 1: reraise with lasti from stack - emit!(self, Instruction::Reraise { depth: 1 }); - } - - // End block - continuation point after try-finally - // Normal execution continues here after the finally block - self.switch_to_block(end_block); + self.use_cpython_label_block(exit_block); Ok(()) } @@ -4078,104 +3625,36 @@ impl Compiler { handlers: &[ast::ExceptHandler], orelse: &[ast::Stmt], ) -> CompileResult<()> { + let body_block = self.new_block(); let handler_block = self.new_block(); - let cleanup_block = self.new_block(); let end_block = self.new_block(); - let has_terminal_raise_handlers = handlers.iter().all(|handler| { - let 
ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { body, .. }) = - handler; - body.last() - .is_some_and(|stmt| matches!(stmt, ast::Stmt::Raise(_))) - }); - let handlers_end_with_scope_exit = handlers.iter().all(|handler| { - let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { body, .. }) = - handler; - Self::statements_end_with_scope_exit(body) - }); - let typed_handlers_end_with_scope_exit = handlers.iter().all(|handler| { - let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - type_, - body, - .. - }) = handler; - type_.is_some() && Self::statements_end_with_scope_exit(body) - }); - if Self::has_resuming_bare_except(handlers) { - self.disable_load_fast_borrow_for_block(end_block); - } - if typed_handlers_end_with_scope_exit { - self.disable_load_fast_borrow_for_block(end_block); - } - + let cleanup_block = self.new_block(); emit!( self, PseudoInstruction::SetupFinally { delta: handler_block } ); - - self.push_fblock(FBlockType::TryExcept, handler_block, handler_block)?; - let split_for_normal_exit_from_break = orelse.is_empty() - && self.fallthrough_has_statement_successor - && Self::statements_are_single_for_direct_break(body) - && !self - .current_code_info() - .fblock - .iter() - .any(|info| matches!(info.fb_type, FBlockType::With | FBlockType::AsyncWith)); - let previous_split_for_normal_exit_from_break = self.split_next_for_normal_exit_from_break; - self.split_next_for_normal_exit_from_break = - previous_split_for_normal_exit_from_break || split_for_normal_exit_from_break; - let compile_body_result = self.compile_statements(body); - self.split_next_for_normal_exit_from_break = previous_split_for_normal_exit_from_break; - compile_body_result?; - self.pop_fblock(FBlockType::TryExcept); + self.use_cpython_label_block(body_block); + let body_label = self.instr_sequence_label_for_block(body_block); + self.push_fblock_labels( + FBlockType::TryExcept, + body_label, + ir::InstructionSequenceLabel::NO_LABEL, + 
FBlockDatum::None, + )?; + self.compile_statements(body)?; + self.pop_fblock_label(FBlockType::TryExcept, body_label); emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - let exits_directly_to_with_cleanup = { - let code_info = self.current_code_info(); - code_info.in_final_with_cleanup_statement > 0 - && code_info.fblock.last().is_some_and(|info| { - matches!(info.fb_type, FBlockType::With | FBlockType::AsyncWith) - }) - }; - if !orelse.is_empty() && self.statements_end_with_conditional_scope_exit(body) { - self.preserve_last_redundant_nop(); - } else { - self.remove_last_no_location_nop(); - } - if !orelse.is_empty() { - if has_terminal_raise_handlers { - let orelse_block = self.new_block(); - self.switch_to_block(orelse_block); - } - let current = self.current_code_info().current_block; - self.mark_try_else_orelse_entry_block(current); - } - let try_else_orelse_conditional_base = self.current_code_info().in_conditional_block; - self.current_code_info().in_try_else_orelse += 1; - self.try_else_orelse_conditional_base_stack - .push(try_else_orelse_conditional_base); - let compile_orelse_result = self.compile_statements(orelse); - self.try_else_orelse_conditional_base_stack.pop(); - self.current_code_info().in_try_else_orelse -= 1; - compile_orelse_result?; + self.compile_statements(orelse)?; emit!( self, PseudoInstruction::JumpNoInterrupt { delta: end_block } ); self.set_no_location(); - if (!orelse.is_empty() && self.statements_end_with_loop_fallthrough(orelse)?) 
- || (exits_directly_to_with_cleanup - && handlers_end_with_scope_exit - && !Self::statements_end_with_nonterminal_with_cleanup(body)) - { - self.preserve_last_redundant_jump_as_nop(); - } else { - self.remove_last_no_location_nop(); - } - self.switch_to_block(handler_block); + self.use_cpython_label_block(handler_block); emit!( self, PseudoInstruction::SetupCleanup { @@ -4185,9 +3664,14 @@ impl Compiler { self.set_no_location(); emit!(self, Instruction::PushExcInfo); self.set_no_location(); - self.push_fblock(FBlockType::ExceptionHandler, cleanup_block, cleanup_block)?; + self.push_fblock_labels( + FBlockType::ExceptionHandler, + ir::InstructionSequenceLabel::NO_LABEL, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; - for handler in handlers { + for (i, handler) in handlers.iter().enumerate() { let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { type_, name, @@ -4196,12 +3680,16 @@ impl Compiler { .. }) = handler; self.set_source_range(*handler_range); + if type_.is_none() && i < handlers.len() - 1 { + return Err(self.error(CodegenErrorType::SyntaxError( + "default 'except:' must be last".to_owned(), + ))); + } let next_handler = self.new_block(); - let handler_body_exits = Self::statements_end_with_scope_exit(body); - let mut exception_handler_was_popped = false; if let Some(exc_type) = type_ { self.compile_expression(exc_type)?; + self.set_source_range(*handler_range); emit!(self, Instruction::CheckExcMatch); emit!( self, @@ -4212,90 +3700,81 @@ impl Compiler { } if let Some(alias) = name { + let cleanup_end = self.new_block(); + let cleanup_body = self.new_block(); + self.store_name(alias.as_str())?; - let cleanup_end = self.new_block(); emit!(self, PseudoInstruction::SetupCleanup { delta: cleanup_end }); - let cleanup_body = self.new_block(); - self.switch_to_block(cleanup_body); - self.push_fblock_full( + self.use_cpython_label_block(cleanup_body); + let cleanup_body_label = 
self.instr_sequence_label_for_block(cleanup_body); + self.push_fblock_labels( FBlockType::HandlerCleanup, - cleanup_body, - cleanup_end, + cleanup_body_label, + ir::InstructionSequenceLabel::NO_LABEL, FBlockDatum::ExceptionName(alias.as_str().to_owned()), )?; self.compile_statements(body)?; - self.pop_fblock(FBlockType::HandlerCleanup); - if !handler_body_exits { - emit!(self, PseudoInstruction::PopBlock); - self.set_no_location(); - emit!(self, PseudoInstruction::PopBlock); - self.set_no_location(); - self.pop_fblock(FBlockType::ExceptionHandler); - exception_handler_was_popped = true; - emit!(self, Instruction::PopExcept); - self.set_no_location(); - - self.emit_load_const(ConstantData::None); - self.set_no_location(); - self.store_name(alias.as_str())?; - self.set_no_location(); - self.compile_name(alias.as_str(), NameUsage::Delete)?; - self.set_no_location(); - - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: end_block } - ); - self.set_no_location(); - } - - self.switch_to_block(cleanup_end); - self.emit_load_const(ConstantData::None); + self.pop_fblock_label(FBlockType::HandlerCleanup, cleanup_body_label); + emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - self.store_name(alias.as_str())?; + emit!(self, PseudoInstruction::PopBlock); + self.set_no_location(); + emit!(self, Instruction::PopExcept); self.set_no_location(); - self.compile_name(alias.as_str(), NameUsage::Delete)?; + + self.emit_no_location_exception_name_cleanup(alias.as_str())?; + + emit!( + self, + PseudoInstruction::JumpNoInterrupt { delta: end_block } + ); self.set_no_location(); + + self.use_cpython_label_block(cleanup_end); + self.emit_no_location_exception_name_cleanup(alias.as_str())?; emit!(self, Instruction::Reraise { depth: 1 }); self.set_no_location(); } else { - emit!(self, Instruction::PopTop); let cleanup_body = self.new_block(); - self.switch_to_block(cleanup_body); - self.push_fblock(FBlockType::HandlerCleanup, cleanup_body, end_block)?; + + 
emit!(self, Instruction::PopTop); + self.use_cpython_label_block(cleanup_body); + let cleanup_body_label = self.instr_sequence_label_for_block(cleanup_body); + self.push_fblock_labels( + FBlockType::HandlerCleanup, + cleanup_body_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; self.compile_statements(body)?; - self.pop_fblock(FBlockType::HandlerCleanup); - if !handler_body_exits { - emit!(self, PseudoInstruction::PopBlock); - self.set_no_location(); - self.pop_fblock(FBlockType::ExceptionHandler); - exception_handler_was_popped = true; - emit!(self, Instruction::PopExcept); - self.set_no_location(); - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: end_block } - ); - self.set_no_location(); - } + self.pop_fblock_label(FBlockType::HandlerCleanup, cleanup_body_label); + emit!(self, PseudoInstruction::PopBlock); + self.set_no_location(); + emit!(self, Instruction::PopExcept); + self.set_no_location(); + emit!( + self, + PseudoInstruction::JumpNoInterrupt { delta: end_block } + ); + self.set_no_location(); } - if exception_handler_was_popped { - self.push_fblock(FBlockType::ExceptionHandler, cleanup_block, cleanup_block)?; - } - self.switch_to_block(next_handler); + self.use_cpython_label_block(next_handler); } emit!(self, Instruction::Reraise { depth: 0 }); self.set_no_location(); - self.pop_fblock(FBlockType::ExceptionHandler); + self.pop_fblock_label( + FBlockType::ExceptionHandler, + ir::InstructionSequenceLabel::NO_LABEL, + ); - self.switch_to_block(cleanup_block); + self.use_cpython_label_block(cleanup_block); emit!(self, Instruction::Copy { i: 3 }); self.set_no_location(); emit!(self, Instruction::PopExcept); @@ -4303,126 +3782,167 @@ impl Compiler { emit!(self, Instruction::Reraise { depth: 1 }); self.set_no_location(); - self.switch_to_block(end_block); + self.use_cpython_label_block(end_block); Ok(()) } - fn compile_try_star_except( + fn compile_try_star_statement( &mut self, body: &[ast::Stmt], handlers: 
&[ast::ExceptHandler], orelse: &[ast::Stmt], finalbody: &[ast::Stmt], ) -> CompileResult<()> { - // compiler_try_star_except - // Stack layout during handler processing: [prev_exc, orig, list, rest] - let handler_block = self.new_block(); - let finally_block = self.new_block(); - let cleanup_block = self.new_block(); - let end_block = self.new_block(); - let reraise_star_block = self.new_block(); - let reraise_block = self.new_block(); - let finally_cleanup_block = if !finalbody.is_empty() { - Some(self.new_block()) - } else { - None - }; - let exit_block = self.new_block(); - let continuation_block = end_block; - let else_block = if orelse.is_empty() && finalbody.is_empty() { - continuation_block - } else { - self.new_block() - }; - if !handlers.is_empty() { - self.disable_load_fast_borrow_for_block(end_block); - if !finalbody.is_empty() { - self.disable_load_fast_borrow_for_block(exit_block); - } + if finalbody.is_empty() { + return self.compile_try_star_except(body, handlers, orelse); } - // Emit NOP at the try: line so LINE events fire for it - emit!(self, Instruction::Nop); - - // Push fblock with handler info for exception table generation - if !finalbody.is_empty() { - emit!( - self, - PseudoInstruction::SetupFinally { - delta: finally_block - } - ); - self.push_fblock_full( - FBlockType::FinallyTry, - finally_block, - finally_block, - FBlockDatum::FinallyBody(finalbody.to_vec()), - )?; - } + let body_block = self.new_block(); + let finally_except_block = self.new_block(); + let exit_block = self.new_block(); + let finally_cleanup_block = self.new_block(); - // SETUP_FINALLY for try body emit!( self, PseudoInstruction::SetupFinally { - delta: handler_block + delta: finally_except_block } ); - self.push_fblock(FBlockType::TryExcept, handler_block, handler_block)?; - self.compile_statements(body)?; + self.use_cpython_label_block(body_block); + let body_label = self.instr_sequence_label_for_block(body_block); + let finally_except_label = 
self.instr_sequence_label_for_block(finally_except_block); + self.push_fblock_labels( + FBlockType::FinallyTry, + body_label, + finally_except_label, + FBlockDatum::FinallyBody(finalbody.to_vec()), + )?; + + if handlers.is_empty() { + self.compile_statements(body)?; + } else { + self.compile_try_star_except(body, handlers, orelse)?; + } + emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - self.remove_last_no_location_nop(); - self.pop_fblock(FBlockType::TryExcept); + self.pop_fblock_label(FBlockType::FinallyTry, body_label); + + let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); + self.compile_statements(finalbody)?; + emit!( self, - PseudoInstruction::JumpNoInterrupt { delta: else_block } + PseudoInstruction::JumpNoInterrupt { delta: exit_block } ); self.set_no_location(); - self.remove_last_no_location_nop(); - // Exception handler entry - self.switch_to_block(handler_block); - // Stack: [exc] (from exception table) + if let Some(cursor) = sub_table_cursor + && let Some(current_table) = self.symbol_table_stack.last_mut() + { + current_table.next_sub_table = cursor; + } + self.use_cpython_label_block(finally_except_block); emit!( self, PseudoInstruction::SetupCleanup { - delta: cleanup_block + delta: finally_cleanup_block } ); - - // PUSH_EXC_INFO + self.set_no_location(); emit!(self, Instruction::PushExcInfo); - // Stack: [prev_exc, exc] + self.set_no_location(); + self.push_fblock_labels( + FBlockType::FinallyEnd, + finally_except_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; + self.compile_statements(finalbody)?; + self.pop_fblock_label(FBlockType::FinallyEnd, finally_except_label); + emit!(self, Instruction::Reraise { depth: 0 }); + self.set_no_location(); + + self.use_cpython_label_block(finally_cleanup_block); + emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); + emit!(self, Instruction::PopExcept); + self.set_no_location(); + emit!(self, Instruction::Reraise { 
depth: 1 }); + self.set_no_location(); + + self.use_cpython_label_block(exit_block); + + Ok(()) + } + + fn compile_try_star_except( + &mut self, + body: &[ast::Stmt], + handlers: &[ast::ExceptHandler], + orelse: &[ast::Stmt], + ) -> CompileResult<()> { + // compiler_try_star_except + // Stack layout during handler processing: [prev_exc, orig, list, rest] + let body_block = self.new_block(); + let handler_block = self.new_block(); + let else_block = self.new_block(); + let end_block = self.new_block(); + let cleanup_block = self.new_block(); + let reraise_star_block = self.new_block(); + + // SETUP_FINALLY for try body + emit!( + self, + PseudoInstruction::SetupFinally { + delta: handler_block + } + ); + self.use_cpython_label_block(body_block); + let body_label = self.instr_sequence_label_for_block(body_block); + self.push_fblock_labels( + FBlockType::TryExcept, + body_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; + self.compile_statements(body)?; + emit!(self, PseudoInstruction::PopBlock); + self.set_no_location(); + self.pop_fblock_label(FBlockType::TryExcept, body_label); + emit!( + self, + PseudoInstruction::JumpNoInterrupt { delta: else_block } + ); + self.set_no_location(); + + // Exception handler entry + self.use_cpython_label_block(handler_block); + // Stack: [exc] (from exception table) + + emit!( + self, + PseudoInstruction::SetupCleanup { + delta: cleanup_block + } + ); + self.set_no_location(); + + // PUSH_EXC_INFO + emit!(self, Instruction::PushExcInfo); + self.set_no_location(); + // Stack: [prev_exc, exc] // Push EXCEPTION_GROUP_HANDLER fblock - self.push_fblock( + self.push_fblock_labels( FBlockType::ExceptionGroupHandler, - cleanup_block, - cleanup_block, + ir::InstructionSequenceLabel::NO_LABEL, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, )?; - // Initialize handler stack before the loop - // BUILD_LIST 0 + COPY 2 to set up [prev_exc, orig, list, rest] - emit!(self, Instruction::BuildList { count: 
0 }); - // Stack: [prev_exc, exc, []] - emit!(self, Instruction::Copy { i: 2 }); - // Stack: [prev_exc, orig, list, rest] - let n = handlers.len(); - if n == 0 { - // Empty handlers (invalid AST) - append rest to list and proceed - // Stack: [prev_exc, orig, list, rest] - emit!(self, Instruction::ListAppend { i: 1 }); - // Stack: [prev_exc, orig, list] - emit!( - self, - PseudoInstruction::JumpNoInterrupt { - delta: reraise_star_block - } - ); - self.set_no_location(); - } for (i, handler) in handlers.iter().enumerate() { let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { type_, @@ -4432,26 +3952,23 @@ impl Compiler { .. }) = handler; let is_last_handler = i == n - 1; + self.set_source_range(*handler_range); + let next_handler_block = self.new_block(); + let except_with_error_block = self.new_block(); let no_match_block = self.new_block(); - let next_handler_block = if is_last_handler { - reraise_star_block - } else { - self.new_block() - }; + + if i == 0 { + // CPython initializes the except* work stack inside the first + // handler iteration in codegen_try_star_except(). + emit!(self, Instruction::BuildList { count: 0 }); + emit!(self, Instruction::Copy { i: 2 }); + } // Compile exception type if let Some(exc_type) = type_ { - // Check for unparenthesized tuple - if let ast::Expr::Tuple(ast::ExprTuple { elts, range, .. 
}) = exc_type.as_ref() - && let Some(first) = elts.first() - && range.start().to_u32() == first.range().start().to_u32() - { - return Err(self.error(CodegenErrorType::SyntaxError( - "multiple exception types must be parenthesized".to_owned(), - ))); - } self.compile_expression(exc_type)?; + self.set_source_range(*handler_range); } else { return Err(self.error(CodegenErrorType::SyntaxError( "except* must specify an exception type".to_owned(), @@ -4476,7 +3993,8 @@ impl Compiler { // Handler matched // Stack: [prev_exc, orig, list, new_rest, match] // Note: CheckEgMatch already sets the matched exception as current exception - let handler_except_block = self.new_block(); + let cleanup_end_block = self.new_block(); + let cleanup_body_block = self.new_block(); // Store match to name or pop if let Some(alias) = name { @@ -4490,13 +4008,15 @@ impl Compiler { emit!( self, PseudoInstruction::SetupCleanup { - delta: handler_except_block + delta: cleanup_end_block } ); - self.push_fblock_full( + self.use_cpython_label_block(cleanup_body_block); + let cleanup_body_label = self.instr_sequence_label_for_block(cleanup_body_block); + self.push_fblock_labels( FBlockType::HandlerCleanup, - next_handler_block, - end_block, + cleanup_body_label, + ir::InstructionSequenceLabel::NO_LABEL, if let Some(alias) = name { FBlockDatum::ExceptionName(alias.as_str().to_owned()) } else { @@ -4510,36 +4030,29 @@ impl Compiler { // Handler body completed normally emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - self.pop_fblock(FBlockType::HandlerCleanup); + self.pop_fblock_label(FBlockType::HandlerCleanup, cleanup_body_label); // Cleanup name binding if let Some(alias) = name { - self.emit_load_const(ConstantData::None); - self.store_name(alias.as_str())?; - self.compile_name(alias.as_str(), NameUsage::Delete)?; + self.emit_no_location_exception_name_cleanup(alias.as_str())?; } - if is_last_handler { - emit!(self, Instruction::ListAppend { i: 1 }); - } emit!( self, 
PseudoInstruction::JumpNoInterrupt { delta: next_handler_block } ); + self.set_no_location(); // Handler raised an exception (cleanup_end label) - self.switch_to_block(handler_except_block); + self.use_cpython_label_block(cleanup_end_block); // Stack: [prev_exc, orig, list, new_rest, lasti, raised_exc] // (lasti is pushed because push_lasti=true in HANDLER_CLEANUP fblock) // Cleanup name binding - self.set_no_location(); if let Some(alias) = name { - self.emit_load_const(ConstantData::None); - self.store_name(alias.as_str())?; - self.compile_name(alias.as_str(), NameUsage::Delete)?; + self.emit_no_location_exception_name_cleanup(alias.as_str())?; } // LIST_APPEND(3) - append raised_exc to list @@ -4548,59 +4061,63 @@ impl Compiler { // nth_value(i) = stack[len - i - 1], we need stack[2] = list // stack[5 - i - 1] = 2 -> i = 2 emit!(self, Instruction::ListAppend { i: 3 }); + self.set_no_location(); // Stack: [prev_exc, orig, list, new_rest, lasti] // POP_TOP - pop lasti emit!(self, Instruction::PopTop); + self.set_no_location(); // Stack: [prev_exc, orig, list, new_rest] - if is_last_handler { - emit!(self, Instruction::ListAppend { i: 1 }); - emit!( - self, - PseudoInstruction::JumpNoInterrupt { - delta: reraise_star_block - } - ); - } else { - emit!( - self, - PseudoInstruction::JumpNoInterrupt { - delta: next_handler_block - } - ); - } + emit!( + self, + PseudoInstruction::JumpNoInterrupt { + delta: except_with_error_block + } + ); + self.set_no_location(); - if is_last_handler { - self.switch_to_block(no_match_block); - self.set_source_range(*handler_range); - emit!(self, Instruction::PopTop); // pop match (None) - // Stack: [prev_exc, orig, list, new_rest] + self.use_cpython_label_block(next_handler_block); + emit!(self, Instruction::Nop); + self.set_no_location(); + emit!( + self, + PseudoInstruction::JumpNoInterrupt { + delta: except_with_error_block + } + ); + self.set_no_location(); - self.set_no_location(); + self.use_cpython_label_block(no_match_block); + 
self.set_source_range(*handler_range); + emit!(self, Instruction::PopTop); // pop match (None) + // Stack: [prev_exc, orig, list, new_rest] + + self.use_cpython_label_block(except_with_error_block); + + if is_last_handler { emit!(self, Instruction::ListAppend { i: 1 }); + self.set_no_location(); emit!( self, PseudoInstruction::JumpNoInterrupt { delta: reraise_star_block } ); - } else { - self.switch_to_block(no_match_block); - self.set_source_range(*handler_range); - emit!(self, Instruction::PopTop); // pop match (None) - // Stack: [prev_exc, orig, list, new_rest] - self.switch_to_block(next_handler_block); + self.set_no_location(); } } // Pop EXCEPTION_GROUP_HANDLER fblock - self.pop_fblock(FBlockType::ExceptionGroupHandler); + self.pop_fblock_label( + FBlockType::ExceptionGroupHandler, + ir::InstructionSequenceLabel::NO_LABEL, + ); + let reraise_block = self.new_block(); // Reraise star block - self.switch_to_block(reraise_star_block); + self.use_cpython_label_block(reraise_star_block); // Stack: [prev_exc, orig, list] - self.set_no_location(); // CALL_INTRINSIC_2 PREP_RERAISE_STAR // Takes 2 args (orig, list) and produces result @@ -4610,10 +4127,12 @@ impl Compiler { func: bytecode::IntrinsicFunction2::PrepReraiseStar } ); + self.set_no_location(); // Stack: [prev_exc, result] // COPY 1 emit!(self, Instruction::Copy { i: 1 }); + self.set_no_location(); // Stack: [prev_exc, result, result] // POP_JUMP_IF_NOT_NONE reraise @@ -4623,118 +4142,59 @@ impl Compiler { delta: reraise_block } ); + self.set_no_location(); // Stack: [prev_exc, result] // Nothing to reraise // POP_TOP - pop result (None) emit!(self, Instruction::PopTop); + self.set_no_location(); // Stack: [prev_exc] emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); // POP_EXCEPT - restore previous exception context emit!(self, Instruction::PopExcept); + self.set_no_location(); // Stack: [] emit!( self, - PseudoInstruction::JumpNoInterrupt { - delta: continuation_block - } + 
PseudoInstruction::JumpNoInterrupt { delta: end_block } ); + self.set_no_location(); // Reraise the result - self.switch_to_block(reraise_block); + self.use_cpython_label_block(reraise_block); // Stack: [prev_exc, result] emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); emit!(self, Instruction::Swap { i: 2 }); + self.set_no_location(); // Stack: [result, prev_exc] // POP_EXCEPT emit!(self, Instruction::PopExcept); + self.set_no_location(); // Stack: [result] // RERAISE 0 emit!(self, Instruction::Reraise { depth: 0 }); - - self.switch_to_block(cleanup_block); self.set_no_location(); + + self.use_cpython_label_block(cleanup_block); emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); emit!(self, Instruction::PopExcept); + self.set_no_location(); emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); - // try-else path - if else_block != continuation_block { - self.switch_to_block(else_block); - self.compile_statements(orelse)?; - - emit!( - self, - PseudoInstruction::JumpNoInterrupt { - delta: continuation_block - } - ); - self.set_no_location(); - } - - if !finalbody.is_empty() { - self.switch_to_block(end_block); - emit!(self, PseudoInstruction::PopBlock); - self.set_no_location(); - self.remove_last_no_location_nop(); - self.pop_fblock(FBlockType::FinallyTry); - - // Snapshot sub_tables before first finally compilation - let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); - - // Compile finally body inline for normal path - self.compile_statements(finalbody)?; - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: exit_block } - ); - - // Restore sub_tables for exception path compilation - if let Some(cursor) = sub_table_cursor - && let Some(current_table) = self.symbol_table_stack.last_mut() - { - current_table.next_sub_table = cursor; - } - - // Exception handler path - self.switch_to_block(finally_block); - emit!(self, Instruction::PushExcInfo); - - if let Some(cleanup) = 
finally_cleanup_block { - emit!(self, PseudoInstruction::SetupCleanup { delta: cleanup }); - self.push_fblock(FBlockType::FinallyEnd, cleanup, cleanup)?; - } - - self.compile_statements(finalbody)?; - - if finally_cleanup_block.is_some() { - emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::FinallyEnd); - } - - emit!(self, Instruction::Reraise { depth: 0 }); - self.set_no_location(); - - if let Some(cleanup) = finally_cleanup_block { - self.switch_to_block(cleanup); - emit!(self, Instruction::Copy { i: 3 }); - emit!(self, Instruction::PopExcept); - emit!(self, Instruction::Reraise { depth: 1 }); - } - } + self.use_cpython_label_block(else_block); + self.compile_statements(orelse)?; - self.switch_to_block(if finalbody.is_empty() { - end_block - } else { - exit_block - }); + self.use_cpython_label_block(end_block); Ok(()) } @@ -4744,6 +4204,7 @@ impl Compiler { fn compile_default_arguments( &mut self, parameters: &ast::Parameters, + loc: TextRange, ) -> CompileResult { let mut funcflags = bytecode::MakeFunctionFlags::new(); @@ -4759,6 +4220,7 @@ impl Compiler { for default in &defaults { self.compile_expression(default)?; } + self.set_source_range(loc); emit!( self, Instruction::BuildTuple { @@ -4779,11 +4241,13 @@ impl Compiler { if !kw_with_defaults.is_empty() { // Compile kwdefaults and build dict for (arg, default) in &kw_with_defaults { + self.set_source_range(loc); self.emit_load_const(ConstantData::Str { value: self.mangle(arg.name.as_str()).into_owned().into(), }); self.compile_expression(default)?; } + self.set_source_range(loc); emit!( self, Instruction::BuildMap { @@ -4805,10 +4269,8 @@ impl Compiler { body: &[ast::Stmt], is_async: bool, funcflags: bytecode::MakeFunctionFlags, + closure_range: TextRange, ) -> CompileResult<()> { - // Save source range so MAKE_FUNCTION gets the `def` line, not the body's last line - let saved_range = self.current_source_range; - // Always enter function scope self.enter_function(name, parameters)?; 
self.current_code_info() @@ -4818,7 +4280,6 @@ impl Compiler { // Set up context let prev_ctx = self.ctx; self.ctx = CompileContext { - loop_data: None, in_class: prev_ctx.in_class, func: if is_async { FunctionContext::AsyncFunction @@ -4832,35 +4293,37 @@ impl Compiler { // Set qualname self.set_qualname(); - // PEP 479: Wrap generator/coroutine body with StopIteration handler - let is_gen = is_async || self.current_symbol_table().is_generator; - let stop_iteration_block = if is_gen { - let handler_block = self.new_block(); - emit!( - self, - PseudoInstruction::SetupCleanup { - delta: handler_block - } - ); - self.set_no_location(); - self.push_fblock(FBlockType::StopIteration, handler_block, handler_block)?; - Some(handler_block) - } else { - None - }; - // Handle docstring - store in co_consts[0] if present - let (doc_str, body) = split_doc(body, &self.opts); + let (doc_info, body) = split_doc_with_range(body, &self.opts); + let doc_str = doc_info.as_ref().map(|(doc, _)| doc); if let Some(doc) = &doc_str { // Docstring present: store in co_consts[0] and set HAS_DOCSTRING flag self.current_code_info() .metadata .consts .insert_full(ConstantData::Str { - value: doc.to_string().into(), + value: (*doc).to_string().into(), }); self.current_code_info().flags |= bytecode::CodeFlags::HAS_DOCSTRING; } + + let start_label = self.use_cpython_function_start_label(); + + // PEP 479: Wrap generator/coroutine body with StopIteration handler + let is_gen = is_async || self.current_symbol_table().is_generator; + let stop_iteration_block = if is_gen { + let handler_block = self.new_block(); + self.insert_cpython_stopiteration_setup_cleanup(handler_block); + self.push_fblock_labels( + FBlockType::StopIteration, + start_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + )?; + Some(handler_block) + } else { + None + }; // Compile body statements self.compile_statements(body)?; @@ -4881,10 +4344,8 @@ impl Compiler { // Close StopIteration handler and emit handler 
code if let Some(handler_block) = stop_iteration_block { - emit!(self, PseudoInstruction::PopBlock); - self.set_no_location(); - self.pop_fblock(FBlockType::StopIteration); - self.switch_to_block(handler_block); + self.pop_fblock_label(FBlockType::StopIteration, start_label); + self.use_cpython_label_block(handler_block); emit!( self, Instruction::CallIntrinsic1 { @@ -4900,7 +4361,7 @@ impl Compiler { let code = self.exit_scope(); self.ctx = prev_ctx; - self.set_source_range(saved_range); + self.set_source_range(closure_range); // Create function object with closure self.make_closure(code, funcflags)?; @@ -4919,6 +4380,7 @@ impl Compiler { func_name: &str, parameters: &ast::Parameters, returns: Option<&ast::Expr>, + func_range: TextRange, ) -> CompileResult { let has_signature_annotations = parameters .args @@ -4935,7 +4397,7 @@ impl Compiler { } // Try to enter annotation scope - returns None if no annotation_block exists - let Some(saved_ctx) = self.enter_annotation_scope(func_name)? else { + let Some(saved_ctx) = self.enter_annotation_scope(func_name, func_range)? 
else { return Ok(false); }; @@ -4966,6 +4428,7 @@ impl Compiler { for param in parameters_iter { if let Some(annotation) = ¶m.annotation { + self.set_source_range(func_range); self.emit_load_const(ConstantData::Str { value: self.mangle(param.name.as_str()).into_owned().into(), }); @@ -4975,6 +4438,7 @@ impl Compiler { // Handle return annotation if let Some(annotation) = returns { + self.set_source_range(func_range); self.emit_load_const(ConstantData::Str { value: "return".into(), }); @@ -4982,6 +4446,7 @@ impl Compiler { } // Build the map and return it + self.set_source_range(func_range); emit!( self, Instruction::BuildMap { @@ -4994,6 +4459,7 @@ impl Compiler { let annotate_code = self.exit_annotation_scope(saved_ctx); // Make a closure from the code object + self.set_source_range(func_range); self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; Ok(true) @@ -5055,9 +4521,21 @@ impl Compiler { annotations } + fn compile_annotation_for_symbol_cursor_only( + &mut self, + annotation: &ast::Expr, + ) -> CompileResult<()> { + self.consume_skipped_nested_scopes_in_expr(annotation) + } + /// Compile module-level __annotate__ function (PEP 649) /// Returns true if __annotate__ was created and stored - fn compile_module_annotate(&mut self, body: &[ast::Stmt]) -> CompileResult { + fn compile_module_annotate( + &mut self, + body: &[ast::Stmt], + loc: Option, + ) -> CompileResult { + let loc = loc.unwrap_or(self.current_source_range); let annotations = Self::collect_annotations(body); let simple_annotation_count = annotations .iter() @@ -5081,13 +4559,13 @@ impl Compiler { // Annotation scopes are never async (even inside async functions) let saved_ctx = self.ctx; self.ctx = CompileContext { - loop_data: None, in_class: saved_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, }; // Enter annotation scope for code generation + self.set_source_range(loc); let key = self.symbol_table_stack.len() - 1; let lineno = 
self.get_source_line_number().get(); self.enter_scope( @@ -5106,6 +4584,7 @@ impl Compiler { // Emit format validation: if format > VALUE_WITH_FAKE_GLOBALS: raise NotImplementedError self.emit_format_validation(); + self.set_source_range(loc); emit!(self, Instruction::BuildMap { count: 0 }); let mut simple_idx = 0usize; @@ -5114,6 +4593,7 @@ impl Compiler { target, annotation, simple, + range, .. } = stmt; let simple_name = if *simple { @@ -5127,18 +4607,18 @@ impl Compiler { if simple_name.is_none() { if !self.future_annotations { - self.do_not_emit_bytecode += 1; - let result = self.compile_annotation(annotation); - self.do_not_emit_bytecode -= 1; - result?; + self.compile_annotation_for_symbol_cursor_only(annotation)?; } continue; } let not_set_block = has_conditional.then(|| self.new_block()); + let not_set_label = + (!has_conditional).then(|| self.current_code_info().new_instr_sequence_label()); let name = simple_name.expect("missing simple annotation name"); if has_conditional { + self.set_source_range(*range); self.emit_load_const(ConstantData::Integer { value: simple_idx.into(), }); @@ -5164,18 +4644,26 @@ impl Compiler { } self.compile_annotation(annotation)?; + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 2 }); self.emit_load_const(ConstantData::Str { value: self.mangle(name).into_owned().into(), }); + self.set_source_range(loc); emit!(self, Instruction::StoreSubscr); simple_idx += 1; if let Some(not_set_block) = not_set_block { - self.switch_to_block(not_set_block); + self.use_cpython_label_block(not_set_block); + } else if let Some(not_set_label) = not_set_label { + let result = self + .current_code_info() + .use_raw_instr_sequence_label(not_set_label); + unwrap_internal(self, result); } } + self.set_source_range(loc); emit!(self, Instruction::ReturnValue); // Exit annotation scope - pop symbol table, restore to parent's annotation_block, and get code @@ -5195,6 +4683,7 @@ impl Compiler { ); // Make a closure from the code object + 
self.set_source_range(loc); self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; // Store as __annotate_func__ for classes, __annotate__ for modules @@ -5203,13 +4692,14 @@ impl Compiler { } else { "__annotate__" }; + self.set_source_range(loc); self.store_name(name)?; Ok(true) } // = compiler_function - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments, reason = "ignore warning for now")] fn compile_function_def( &mut self, name: &str, @@ -5219,22 +4709,28 @@ impl Compiler { returns: Option<&ast::Expr>, // TODO: use type hint somehow.. is_async: bool, type_params: Option<&ast::TypeParams>, + preserve_value_before_store: bool, ) -> CompileResult<()> { - // Save the source range of the `def` line before compiling decorators/defaults, - // so that the function code object gets the correct co_firstlineno. - let def_source_range = self.current_source_range; + // CPython's FunctionDef/AsyncFunctionDef LOC(s) starts at the + // definition line even when decorators are present. + let stmt_source_range = self.current_source_range; + let def_source_range = self.decorated_definition_range( + stmt_source_range, + decorator_list, + if is_async { "async def " } else { "def " }, + ); self.prepare_decorators(decorator_list)?; // compile defaults and return funcflags - let funcflags = self.compile_default_arguments(parameters)?; + let funcflags = self.compile_default_arguments(parameters, def_source_range)?; // Restore the `def` line range so that enter_function → push_output → get_source_line_number() // records the `def` keyword's line as co_firstlineno, not the last default-argument line. 
self.set_source_range(def_source_range); let is_generic = type_params.is_some(); - let mut num_typeparam_args = 0; + let mut num_typeparam_args = 0u32; // Save context before entering TypeParams scope let saved_ctx = self.ctx; @@ -5256,14 +4752,13 @@ impl Compiler { self.push_output( bytecode::CodeFlags::OPTIMIZED | bytecode::CodeFlags::NEWLOCALS, 0, - num_typeparam_args as u32, + num_typeparam_args, 0, &type_params_name, )?; // TypeParams scope is function-like self.ctx = CompileContext { - loop_data: None, in_class: saved_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, @@ -5290,20 +4785,28 @@ impl Compiler { // Load defaults/kwdefaults with LOAD_FAST for i in 0..num_typeparam_args { - let var_num = oparg::VarNum::from(i as u32); + let var_num = oparg::VarNum::from(i); emit!(self, Instruction::LoadFast { var_num }); } } // Compile annotations as closure (PEP 649) let mut annotations_flag = bytecode::MakeFunctionFlags::new(); - if self.compile_annotations_closure(name, parameters, returns)? { + if self.compile_annotations_closure(name, parameters, returns, def_source_range)? 
{ annotations_flag.insert(bytecode::MakeFunctionFlag::Annotate); } // Compile function body + self.set_source_range(stmt_source_range); let final_funcflags = funcflags | annotations_flag; - self.compile_function_body(name, parameters, body, is_async, final_funcflags)?; + self.compile_function_body( + name, + parameters, + body, + is_async, + final_funcflags, + def_source_range, + )?; // Handle type params if present if is_generic { @@ -5323,7 +4826,8 @@ impl Compiler { emit!(self, Instruction::ReturnValue); // Set argcount for type params scope - self.current_code_info().metadata.argcount = num_typeparam_args as u32; + self.current_code_info().metadata.argcount = num_typeparam_args; + self.current_code_info().nparams = num_typeparam_args as usize; // Exit type params scope and create closure let type_params_code = self.exit_scope(); @@ -5336,13 +4840,13 @@ impl Compiler { emit!( self, Instruction::Swap { - i: num_typeparam_args as u32 + 1 + i: num_typeparam_args + 1 } ); emit!( self, Instruction::Call { - argc: num_typeparam_args as u32 - 1 + argc: num_typeparam_args - 1 } ); } else { @@ -5357,6 +4861,10 @@ impl Compiler { self.apply_decorators(decorator_list); // Store the function + self.set_source_range(def_source_range); + if preserve_value_before_store { + emit!(self, Instruction::Copy { i: 1 }); + } self.store_name(name)?; Ok(()) @@ -5629,42 +5137,25 @@ impl Compiler { self.code_stack.last_mut().unwrap().private = Some(name.to_owned()); // 2. 
Set up class namespace - let (doc_str, body) = split_doc(body, &self.opts); + let (doc_str, body) = split_doc_with_range(body, &self.opts); + let class_body_prefix_range = self.source_line_start_range(firstlineno); + self.set_source_range(class_body_prefix_range); // Load __name__ and store as __module__ - let dunder_name = self.name("__name__"); - emit!(self, Instruction::LoadName { namei: dunder_name }); - let dunder_module = self.name("__module__"); - emit!( - self, - Instruction::StoreName { - namei: dunder_module - } - ); + self.load_name("__name__")?; + self.store_name("__module__")?; // Store __qualname__ self.emit_load_const(ConstantData::Str { value: qualname.into(), }); - let qualname_name = self.name("__qualname__"); - emit!( - self, - Instruction::StoreName { - namei: qualname_name - } - ); + self.store_name("__qualname__")?; // Store __firstlineno__ before __doc__ self.emit_load_const(ConstantData::Integer { value: BigInt::from(firstlineno), }); - let firstlineno_name = self.name("__firstlineno__"); - emit!( - self, - Instruction::StoreName { - namei: firstlineno_name - } - ); + self.store_name("__firstlineno__")?; // Set __type_params__ from the enclosing type-params closure when // compiling a generic class body. @@ -5694,17 +5185,19 @@ impl Compiler { } // Store __doc__ only if there's an explicit docstring. - if let Some(doc) = doc_str { + if let Some((doc, range)) = doc_str { + let saved_range = self.current_source_range; + self.set_source_range(range); self.emit_load_const(ConstantData::Str { value: doc.into() }); - let doc_name = self.name("__doc__"); - emit!(self, Instruction::StoreName { namei: doc_name }); + self.store_name("__doc__")?; + self.set_source_range(saved_range); } // 3. Compile the class body self.compile_statements(body)?; if Self::find_ann(body) && !self.future_annotations { - self.compile_module_annotate(body)?; + self.compile_module_annotate(body, Some(class_body_prefix_range))?; } // 4. 
Handle __classcell__ if needed @@ -5735,13 +5228,7 @@ impl Compiler { .collect(), }); self.set_no_location(); - let static_attrs_name = self.name("__static_attributes__"); - emit!( - self, - Instruction::StoreName { - namei: static_attrs_name - } - ); + self.store_name("__static_attributes__")?; self.set_no_location(); } @@ -5750,13 +5237,7 @@ impl Compiler { let classdict_idx = u32::from(self.get_cell_var_index("__classdict__")); emit!(self, PseudoInstruction::LoadClosure { i: classdict_idx }); self.set_no_location(); - let classdictcell = self.name("__classdictcell__"); - emit!( - self, - Instruction::StoreName { - namei: classdictcell - } - ); + self.store_name("__classdictcell__")?; self.set_no_location(); } @@ -5770,8 +5251,7 @@ impl Compiler { self.set_no_location(); emit!(self, Instruction::Copy { i: 1 }); self.set_no_location(); - let classcell = self.name("__classcell__"); - emit!(self, Instruction::StoreName { namei: classcell }); + self.store_name("__classcell__")?; self.set_no_location(); } else { self.emit_load_const(ConstantData::None); @@ -5793,7 +5273,13 @@ impl Compiler { decorator_list: &[ast::Decorator], type_params: Option<&ast::TypeParams>, arguments: Option<&ast::Arguments>, + preserve_value_before_store: bool, ) -> CompileResult<()> { + // CPython's ClassDef LOC(s) starts at the class line even when + // decorators are present. + let stmt_source_range = self.current_source_range; + let class_source_range = + self.decorated_definition_range(stmt_source_range, decorator_list, "class "); self.prepare_decorators(decorator_list)?; let is_generic = type_params.is_some(); @@ -5828,7 +5314,6 @@ impl Compiler { // TypeParams scope is function-like self.ctx = CompileContext { - loop_data: None, in_class: saved_ctx.in_class, func: FunctionContext::Function, in_async_scope: false, @@ -5837,6 +5322,7 @@ impl Compiler { // Compile type parameters and store them in the synthetic cell that // generic class bodies close over. 
self.compile_type_params(type_params.unwrap())?; + self.set_source_range(class_source_range); self.store_name(".type_params")?; } @@ -5845,11 +5331,11 @@ impl Compiler { self.ctx = CompileContext { func: FunctionContext::NoFunction, in_class: true, - loop_data: None, in_async_scope: false, }; let class_code = self.compile_class_body(name, body, type_params, firstlineno)?; self.ctx = prev_ctx; + self.set_source_range(class_source_range); // Step 3: Generate the rest of the code for the call if is_generic { @@ -5864,6 +5350,7 @@ impl Compiler { // Create .generic_base after the class function and name are on the // stack so the remaining call shape matches CPython's ordering. + self.set_source_range(class_source_range); self.load_name(".type_params")?; emit!( self, @@ -5871,6 +5358,7 @@ impl Compiler { func: bytecode::IntrinsicFunction1::SubscriptGeneric } ); + self.set_source_range(class_source_range); self.store_name(".generic_base")?; // Compile bases and call __build_class__ @@ -5907,10 +5395,13 @@ impl Compiler { } // Add .generic_base as final element + self.set_source_range(class_source_range); self.load_name(".generic_base")?; + self.set_source_range(class_source_range); emit!(self, Instruction::ListAppend { i: 1 }); // Convert list to tuple + self.set_source_range(class_source_range); emit!( self, Instruction::CallIntrinsic1 { @@ -5920,6 +5411,7 @@ impl Compiler { self.compile_call_function_ex_keywords( arguments.map_or(&[][..], |args| &args.keywords[..]), + class_source_range, )?; emit!(self, Instruction::CallFunctionEx); } else if has_double_star { @@ -5928,7 +5420,9 @@ impl Compiler { self.compile_expression(arg)?; } } + self.set_source_range(class_source_range); self.load_name(".generic_base")?; + self.set_source_range(class_source_range); emit!( self, Instruction::BuildTuple { @@ -5936,7 +5430,10 @@ impl Compiler { .map_or(0, |args| u32::try_from(args.args.len()).unwrap()) } ); - self.compile_call_function_ex_keywords(&arguments.unwrap().keywords[..])?; 
+ self.compile_call_function_ex_keywords( + &arguments.unwrap().keywords[..], + class_source_range, + )?; emit!(self, Instruction::CallFunctionEx); } else { // Simple case: no starred bases, no **kwargs @@ -5951,6 +5448,7 @@ impl Compiler { }; // Load .generic_base as the last base + self.set_source_range(class_source_range); self.load_name(".generic_base")?; let nargs = 2 + u32::try_from(base_count).expect("too many base classes") + 1; @@ -5969,9 +5467,11 @@ impl Compiler { }); self.compile_expression(&keyword.value)?; } + self.set_source_range(class_source_range); self.emit_load_const(ConstantData::Tuple { elements: kwarg_names, }); + self.set_source_range(class_source_range); emit!( self, Instruction::CallKw { @@ -5981,11 +5481,13 @@ impl Compiler { } ); } else { + self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: nargs }); } } // Return the created class + self.set_source_range(class_source_range); self.emit_return_value(); // Exit type params scope and wrap in function @@ -5993,8 +5495,11 @@ impl Compiler { self.ctx = saved_ctx; // Execute the type params function + self.set_source_range(class_source_range); self.make_closure(type_params_code, bytecode::MakeFunctionFlags::new())?; + self.set_source_range(class_source_range); emit!(self, Instruction::PushNull); + self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: 0 }); } else { // Non-generic class: standard path @@ -6006,14 +5511,19 @@ impl Compiler { self.emit_load_const(ConstantData::Str { value: name.into() }); if let Some(arguments) = arguments { - self.codegen_call_helper(2, arguments, self.current_source_range)?; + self.codegen_call_helper(2, arguments, class_source_range, None)?; } else { + self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: 2 }); } } // Step 4: Apply decorators and store (common to both paths) self.apply_decorators(decorator_list); + self.set_source_range(class_source_range); + if 
preserve_value_before_store { + emit!(self, Instruction::Copy { i: 1 }); + } self.store_name(name) } @@ -6033,29 +5543,11 @@ impl Compiler { self.new_block() }; - if self.has_always_taken_jump_in_test(test, false)? { - self.disable_load_fast_borrow_for_block(next_block); - self.disable_load_fast_borrow_for_block(end_block); - } - let in_direct_try_else_orelse_conditional = { - let base = self.try_else_orelse_conditional_base_stack.last().copied(); - let code_info = self.current_code_info(); - code_info.in_try_else_orelse > 0 - && base.is_some_and(|base| code_info.in_conditional_block == base + 1) - }; - let preserve_try_else_scope_exit_target_nop = elif_else_clauses.is_empty() - && in_direct_try_else_orelse_conditional - && !self.fallthrough_has_local_statement_successor - && Self::statements_end_with_scope_exit(body); self.compile_jump_if(test, false, next_block)?; self.compile_statements(body)?; let Some((clause, rest)) = elif_else_clauses.split_first() else { - self.switch_to_block(end_block); - if preserve_try_else_scope_exit_target_nop { - self.set_source_range(test.range()); - emit!(self, Instruction::Nop); - } + self.use_cpython_label_block(end_block); return Ok(()); }; @@ -6064,7 +5556,7 @@ impl Compiler { PseudoInstruction::JumpNoInterrupt { delta: end_block } ); self.set_no_location(); - self.switch_to_block(next_block); + self.use_cpython_label_block(next_block); if let Some(test) = &clause.test { self.compile_if(test, &clause.body, rest, test.range())?; @@ -6072,7 +5564,7 @@ impl Compiler { debug_assert!(rest.is_empty()); self.compile_statements(&clause.body)?; } - self.switch_to_block(end_block); + self.use_cpython_label_block(end_block); Ok(()) } @@ -6084,34 +5576,28 @@ impl Compiler { ) -> CompileResult<()> { self.enter_conditional_block(); - let while_block = self.new_block(); - let after_block = self.new_block(); - let else_block = if orelse.is_empty() { - after_block - } else { - self.new_block() - }; - - self.switch_to_block(while_block); - 
self.push_fblock(FBlockType::WhileLoop, while_block, after_block)?; - if matches!(self.constant_expr_truthiness(test)?, Some(false)) { - self.disable_load_fast_borrow_for_block(else_block); - self.disable_load_fast_borrow_for_block(after_block); - } - self.compile_jump_if(test, false, else_block)?; + let loop_block = self.new_block(); + let end_block = self.new_block(); + let anchor_block = self.new_block(); + let loop_label = self.instr_sequence_label_for_block(loop_block); + let end_label = self.instr_sequence_label_for_block(end_block); + self.use_cpython_label_block(loop_block); + self.push_fblock_labels( + FBlockType::WhileLoop, + loop_label, + end_label, + FBlockDatum::None, + )?; + self.compile_jump_if(test, false, anchor_block)?; - let was_in_loop = self.ctx.loop_data.replace((while_block, after_block)); self.compile_loop_body_statements(body)?; - self.ctx.loop_data = was_in_loop; - emit!(self, PseudoInstruction::Jump { delta: while_block }); + emit!(self, PseudoInstruction::Jump { delta: loop_block }); self.set_no_location(); - self.switch_to_block(else_block); - self.pop_fblock(FBlockType::WhileLoop); + self.pop_fblock_label(FBlockType::WhileLoop, loop_label); + self.use_cpython_label_block(anchor_block); self.compile_statements(orelse)?; - if !orelse.is_empty() { - self.switch_to_block(after_block); - } + self.use_cpython_label_block(end_block); self.leave_conditional_block(); Ok(()) @@ -6157,8 +5643,10 @@ impl Compiler { }; let with_range = item.context_expr.range(); + let body_block = self.new_block(); let exc_handler_block = self.new_block(); let after_block = self.new_block(); + let cleanup_block = self.new_block(); // Compile context expression and load __enter__/__exit__ methods self.compile_expression(&item.context_expr)?; @@ -6189,7 +5677,7 @@ impl Compiler { emit!(self, Instruction::Call { argc: 0 }); // [aexit_func, self_ae, awaitable] emit!(self, Instruction::GetAwaitable { r#where: 1 }); self.emit_load_const(ConstantData::None); - let _ = 
self.compile_yield_from_sequence(true)?; + let _ = self.compile_yield_from_sequence(true); } else { // Load __exit__ and __enter__, then call __enter__ emit!( @@ -6218,14 +5706,19 @@ impl Compiler { delta: exc_handler_block } ); - self.push_fblock( - if is_async { - FBlockType::AsyncWith - } else { - FBlockType::With - }, - exc_handler_block, // block start (will become exit target after store) - after_block, + self.use_cpython_label_block(body_block); + let fblock_type = if is_async { + FBlockType::AsyncWith + } else { + FBlockType::With + }; + let body_label = self.instr_sequence_label_for_block(body_block); + let exc_handler_label = self.instr_sequence_label_for_block(exc_handler_block); + self.push_fblock_labels( + fblock_type, + body_label, + exc_handler_label, + FBlockDatum::None, )?; // Store or pop the enter result @@ -6251,74 +5744,24 @@ impl Compiler { self.compile_with(items, body, is_async)?; } - let nested_multiline_with_cleanup_target_nop = - !is_async && Self::statements_end_with_scope_exit(body) && { - let parent_with_ranges: Vec<_> = self - .current_code_info() - .fblock - .iter() - .rev() - .skip(1) - .filter_map(|info| { - matches!(info.fb_type, FBlockType::With).then_some(info.fb_range) - }) - .collect(); - let source = self.source_file.to_source_code(); - let current_line = source.line_index(with_range.start()); - parent_with_ranges - .iter() - .any(|range| source.line_index(range.start()) != current_line) - }; - let preserve_outer_cleanup_target_nop = !is_async - && (Self::statements_end_with_with_cleanup_scope_exit(body) - || self.statements_end_with_conditional_scope_exit(body) - || Self::statements_end_with_try_except_handler_fallthrough(body) - || Self::statements_end_with_try_except_else_handler_scope_exit(body) - || Self::statements_end_with_try_finally(body) - || self.statements_end_with_loop_fallthrough(body)?); - let materialize_async_with_outer_cleanup_target_nop = is_async - && 
Self::statements_end_with_nested_finalbody_try_finally(body) - && self - .current_code_info() - .fblock - .iter() - .any(|info| matches!(info.fb_type, FBlockType::With)); - - // Pop fblock before normal exit. CPython emits this POP_BLOCK with - // no location for sync with, but with the with-item location for - // async with. + // CPython pops the async-with fblock before emitting POP_BLOCK, but + // sync with emits the artificial POP_BLOCK before popping the fblock. if is_async { + self.pop_fblock_label(fblock_type, body_label); self.set_source_range(with_range); - } - emit!(self, PseudoInstruction::PopBlock); - if !is_async { + emit!(self, PseudoInstruction::PopBlock); + } else { + emit!(self, PseudoInstruction::PopBlock); self.set_no_location(); - if preserve_outer_cleanup_target_nop { - self.preserve_last_redundant_nop(); - } else if Self::statements_end_with_try_star_except(body) { - self.force_remove_last_no_location_nop(); - } else { - self.remove_last_no_location_nop(); - } self.set_source_range(with_range); + self.pop_fblock_label(fblock_type, body_label); } - self.pop_fblock(if is_async { - FBlockType::AsyncWith - } else { - FBlockType::With - }); - // ===== Normal exit path ===== - // Stack: [..., exit_func, self_exit] - // Call exit_func(self_exit, None, None, None) - self.emit_load_const(ConstantData::None); - self.emit_load_const(ConstantData::None); - self.emit_load_const(ConstantData::None); - emit!(self, Instruction::Call { argc: 3 }); + self.compile_call_exit_with_nones(); if is_async { emit!(self, Instruction::GetAwaitable { r#where: 2 }); self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(true)?; + let _ = self.compile_yield_from_sequence(true); } emit!(self, Instruction::PopTop); // Pop __exit__ result emit!(self, PseudoInstruction::Jump { delta: after_block }); @@ -6327,10 +5770,7 @@ impl Compiler { // ===== Exception handler path ===== // Stack at entry: [..., exit_func, self_exit, lasti, exc] // 
PUSH_EXC_INFO -> [..., exit_func, self_exit, lasti, prev_exc, exc] - self.switch_to_block(exc_handler_block); - - let cleanup_block = self.new_block(); - let suppress_block = self.new_block(); + self.use_cpython_label_block(exc_handler_block); emit!( self, @@ -6338,7 +5778,6 @@ impl Compiler { delta: cleanup_block } ); - self.push_fblock(FBlockType::ExceptionHandler, exc_handler_block, after_block)?; emit!(self, Instruction::PushExcInfo); @@ -6349,57 +5788,12 @@ impl Compiler { if is_async { emit!(self, Instruction::GetAwaitable { r#where: 2 }); self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(true)?; + let _ = self.compile_yield_from_sequence(true); } - emit!(self, Instruction::ToBool); - emit!( - self, - Instruction::PopJumpIfTrue { - delta: suppress_block - } - ); + self.compile_with_except_finish(cleanup_block); - emit!(self, Instruction::Reraise { depth: 2 }); - - // ===== Suppress block ===== - // Stack: [..., exit_func, self_exit, lasti, prev_exc, exc, True] - self.switch_to_block(suppress_block); - emit!(self, Instruction::PopTop); // pop True - emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::ExceptionHandler); - emit!(self, Instruction::PopExcept); // pop exc, restore prev_exc - emit!(self, Instruction::PopTop); // pop lasti - emit!(self, Instruction::PopTop); // pop self_exit - emit!(self, Instruction::PopTop); // pop exit_func - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: after_block } - ); - - // ===== Cleanup block (for nested exception during __exit__) ===== - // Stack: [..., __exit__, lasti, prev_exc, lasti2, exc2] - // COPY 3: copy prev_exc to TOS - // POP_EXCEPT: restore exception state - // RERAISE 1: re-raise with lasti - // - // NOTE: We DON'T clear the fblock stack here because we want - // outer exception handlers (e.g., try-except wrapping this with statement) - // to be in the exception table for these instructions. 
- // If we cleared fblock, exceptions here would propagate uncaught. - self.switch_to_block(cleanup_block); - emit!(self, Instruction::Copy { i: 3 }); - emit!(self, Instruction::PopExcept); - emit!(self, Instruction::Reraise { depth: 1 }); - - // ===== After block ===== - self.switch_to_block(after_block); - if materialize_async_with_outer_cleanup_target_nop - || nested_multiline_with_cleanup_target_nop - { - self.set_source_range(with_range); - emit!(self, Instruction::Nop); - } + self.use_cpython_label_block(after_block); self.leave_conditional_block(); Ok(()) @@ -6412,28 +5806,35 @@ impl Compiler { body: &[ast::Stmt], orelse: &[ast::Stmt], is_async: bool, + for_range: TextRange, ) -> CompileResult<()> { self.enter_conditional_block(); // Start loop let for_block = self.new_block(); - let else_block = self.new_block(); - let after_block = self.new_block(); - let split_normal_exit_from_break = !is_async - && self.split_next_for_normal_exit_from_break - && Self::statements_contain_direct_break(body) - && self - .current_code_info() - .fblock - .iter() - .any(|info| matches!(info.fb_type, FBlockType::TryExcept)); - let normal_exit_block = if split_normal_exit_from_break { - self.new_block() + let (body_label, send_block) = if is_async { + (None, self.new_block()) } else { - after_block + let body_label = self.current_code_info().new_instr_sequence_label(); + (Some(body_label), BlockIdx::NULL) }; + let else_block = self.new_block(); + let after_block = self.new_block(); + let for_label = self.instr_sequence_label_for_block(for_block); + let after_label = self.instr_sequence_label_for_block(after_block); let mut end_async_for_target = BlockIdx::NULL; + if !is_async { + // CPython codegen_for() pushes the loop fblock before compiling + // the iterable expression. 
+ self.push_fblock_labels( + FBlockType::ForLoop, + for_label, + after_label, + FBlockDatum::None, + )?; + } + // The thing iterated: self.compile_for_iterable_expression(iter, is_async)?; @@ -6441,18 +5842,27 @@ impl Compiler { if self.ctx.func != FunctionContext::AsyncFunction { return Err(self.error(CodegenErrorType::InvalidAsyncFor)); } + self.set_source_range(iter.range()); emit!(self, Instruction::GetAiter); - self.switch_to_block(for_block); + self.use_cpython_label_block(for_block); + self.set_source_range(for_range); // codegen_async_for: push fblock BEFORE SETUP_FINALLY - self.push_fblock(FBlockType::ForLoop, for_block, after_block)?; + self.push_fblock_labels( + FBlockType::ForLoop, + for_label, + after_label, + FBlockDatum::None, + )?; // SETUP_FINALLY to guard the __anext__ call emit!(self, PseudoInstruction::SetupFinally { delta: else_block }); emit!(self, Instruction::GetAnext); self.emit_load_const(ConstantData::None); - end_async_for_target = self.compile_yield_from_sequence(true)?; + self.use_cpython_label_block(send_block); + let _ = self.compile_yield_from_sequence(true); + end_async_for_target = send_block; // POP_BLOCK for SETUP_FINALLY - only GetANext/yield_from are protected emit!(self, PseudoInstruction::PopBlock); emit!(self, Instruction::NotTaken); @@ -6463,65 +5873,51 @@ impl Compiler { // Retrieve Iterator emit!(self, Instruction::GetIter); - self.switch_to_block(for_block); - - // Push fblock for for loop - self.push_fblock(FBlockType::ForLoop, for_block, after_block)?; + self.use_cpython_label_block(for_block); emit!(self, Instruction::ForIter { delta: else_block }); - // Match CPython's line attribution by compiling the loop target on - // the target range directly instead of leaving a synthetic anchor - // NOP between FOR_ITER and the unpack/store sequence. 
let saved_range = self.current_source_range; self.set_source_range(target.range()); + emit!(self, Instruction::Nop); + let result = self + .current_code_info() + .use_raw_instr_sequence_label(body_label.expect("sync for must have body label")); + unwrap_internal(self, result); self.compile_store(target)?; self.set_source_range(saved_range); }; - let was_in_loop = self.ctx.loop_data.replace((for_block, after_block)); self.compile_loop_body_statements(body)?; - self.ctx.loop_data = was_in_loop; emit!(self, PseudoInstruction::Jump { delta: for_block }); self.set_no_location(); - self.switch_to_block(else_block); + self.use_cpython_label_block(else_block); // Except block for __anext__ / end of sync for - // No PopBlock here - for async, POP_BLOCK is already in for_block - self.pop_fblock(FBlockType::ForLoop); - - // End-of-loop instructions are on the `for` line, not the body's last line - let saved_range = self.current_source_range; - self.set_source_range(iter.range()); if is_async { + // codegen_async_for emits END_ASYNC_FOR at the iterator location, + // then pops the for-loop fblock before the else block. + let saved_range = self.current_source_range; + self.set_source_range(iter.range()); self.emit_end_async_for(end_async_for_target); + self.set_source_range(saved_range); } else { + // codegen_for emits END_FOR/POP_ITER with NO_LOCATION. Line numbers + // are propagated later by flowgraph.c::resolve_line_numbers(). 
emit!(self, Instruction::EndFor); + self.set_no_location(); emit!(self, Instruction::PopIter); + self.set_no_location(); } - self.set_source_range(saved_range); + // No PopBlock here - for async, POP_BLOCK is already in for_block + self.pop_fblock_label(FBlockType::ForLoop, for_label); self.compile_statements(orelse)?; - self.switch_to_block(normal_exit_block); + self.use_cpython_label_block(after_block); // Implicit return after for-loop should be attributed to the `for` line self.set_source_range(iter.range()); - if split_normal_exit_from_break { - emit!(self, Instruction::Nop); - self.set_no_location(); - self.preserve_last_redundant_nop(); - self.preserve_last_redundant_jump_as_nop(); - self.mark_last_no_location_exit(); - emit!( - self, - PseudoInstruction::JumpNoInterrupt { delta: after_block } - ); - self.set_no_location(); - self.remove_last_no_location_nop(); - self.switch_to_block(after_block); - self.set_source_range(iter.range()); - } self.leave_conditional_block(); Ok(()) @@ -6541,25 +5937,35 @@ impl Compiler { && elts.len() <= usize::try_from(STACK_USE_GUIDELINE).unwrap() && !elts.iter().any(|e| matches!(e, ast::Expr::Starred(_))) { - if let Some(folded) = self.try_fold_constant_collection(elts, CollectionType::List)? 
{ - self.emit_load_const(folded); - } else { - for elt in elts { - self.compile_expression(elt)?; - } - emit!( - self, - Instruction::BuildTuple { - count: u32::try_from(elts.len()).expect("too many elements"), - } - ); + for elt in elts { + self.compile_expression(elt)?; } + self.set_source_range(iter.range()); + emit!( + self, + Instruction::BuildList { + count: u32::try_from(elts.len()).expect("too many elements"), + } + ); return Ok(()); } self.compile_expression(iter) } + fn compile_comprehension_iter(&mut self, generator: &ast::Comprehension) -> CompileResult<()> { + let saved_range = self.current_source_range; + self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; + self.set_source_range(generator.iter.range()); + if generator.is_async { + emit!(self, Instruction::GetAiter); + } else { + emit!(self, Instruction::GetIter); + } + self.set_source_range(saved_range); + Ok(()) + } + fn singleton_comprehension_assignment_iter(iter: &ast::Expr) -> Option<&ast::Expr> { let elts = match iter { ast::Expr::List(ast::ExprList { elts, .. }) => elts, @@ -6640,12 +6046,14 @@ impl Compiler { } // Iterate over the fail_pop vector in reverse order, skipping the first label. for &label in pc.fail_pop.iter().skip(1).rev() { - self.switch_to_block(label); + // CPython emit_and_reset_fail_pop() uses USE_LABEL here. + self.use_cpython_label_block(label); // Emit the POP instruction. emit!(self, Instruction::PopTop); } // Finally, use the first label. - self.switch_to_block(pc.fail_pop[0]); + // CPython emit_and_reset_fail_pop() uses USE_LABEL here too. + self.use_cpython_label_block(pc.fail_pop[0]); pc.fail_pop.clear(); // Free the memory used by the vector. pc.fail_pop.shrink_to_fit(); @@ -6949,6 +6357,7 @@ impl Compiler { // Compile the class expression. self.compile_expression(&match_class.cls)?; + self.set_source_range(p.range); // Create a new tuple of attribute names. 
let mut attr_names = vec![]; @@ -7085,13 +6494,11 @@ impl Compiler { } // Check for overflow (INT_MAX < size - 1) - if size > (i32::MAX as usize + 1) { - return Err(self.error(CodegenErrorType::SyntaxError( + let size = u32::try_from(size).map_err(|_| { + self.error(CodegenErrorType::SyntaxError( "too many sub-patterns in mapping pattern".to_string(), - ))); - } - #[allow(clippy::cast_possible_truncation, reason = "checked right before")] - let size = size as u32; + )) + })?; // Step 2: If we have keys to match if size > 0 { @@ -7127,8 +6534,9 @@ impl Compiler { seen.insert(key_repr); } - self.compile_expression(key)?; + self.compile_match_pattern_expr(key)?; } + self.set_source_range(p.range); } // Stack: [subject, key1, key2, ..., key_n] @@ -7256,6 +6664,7 @@ impl Compiler { pc.fail_pop.clear(); pc.on_top = 0; // Emit a COPY(1) instruction before compiling the alternative. + self.set_source_range(alt.range()); emit!(self, Instruction::Copy { i: 1 }); self.compile_pattern(alt, pc)?; @@ -7320,7 +6729,8 @@ impl Compiler { self.jump_to_fail_pop(pc, JumpOp::Jump); // Use the label "end". - self.switch_to_block(end); + // CPython codegen_pattern_or() emits USE_LABEL(c, end). + self.use_cpython_label_block(end); // Adjust the final captures. let n_stores = control.as_ref().unwrap().len(); @@ -7425,7 +6835,7 @@ impl Compiler { // Match CPython codegen_pattern_value(): compare, then normalize to bool // before the fail jump. Late IR folding will collapse COMPARE_OP+TO_BOOL // into COMPARE_OP bool(...) when applicable. 
- self.compile_expression(&p.value)?; + self.compile_match_pattern_expr(&p.value)?; emit!( self, Instruction::CompareOp { @@ -7519,6 +6929,7 @@ impl Compiler { for (i, m) in cases.iter().enumerate().take(case_count) { // Only copy the subject if not on the last case if i != case_count - 1 { + self.set_source_range(m.pattern.range()); emit!(self, Instruction::Copy { i: 1 }); } @@ -7549,30 +6960,25 @@ impl Compiler { if let Some(first_stmt) = m.body.first() { self.set_source_range(first_stmt.range()); } - if matches!(m.pattern, ast::Pattern::MatchOr(_)) { - emit!(self, Instruction::Nop); - } emit!(self, Instruction::PopTop); - if matches!(m.body.first(), Some(ast::Stmt::Try(_))) { - let body_block = self.new_block(); - self.switch_to_block(body_block); - } + // CPython emits NEXT_LOCATION here; resolve it after redundant + // NOP removal so a following pass NOP survives. + self.set_last_emitted_lineno_override(ir::NEXT_LOCATION_OVERRIDE); } self.compile_statements(&m.body)?; - emit!(self, PseudoInstruction::JumpNoInterrupt { delta: end }); - if let Some(last) = self.current_block().instructions.last_mut() { - last.match_success_jump = true; - } + emit!(self, PseudoInstruction::Jump { delta: end }); + self.set_no_location(); self.set_source_range(m.pattern.range()); self.emit_and_reset_fail_pop(pattern_context); } if has_default { let m = &cases[num_cases - 1]; + self.set_source_range(m.pattern.range()); if num_cases == 1 { emit!(self, Instruction::PopTop); - } else if m.guard.is_none() { + } else { emit!(self, Instruction::Nop); } if let Some(ref guard) = m.guard { @@ -7580,7 +6986,8 @@ impl Compiler { } self.compile_statements(&m.body)?; } - self.switch_to_block(end); + // CPython codegen_match_inner() emits USE_LABEL(c, end). + self.use_cpython_label_block(end); Ok(()) } @@ -7704,11 +7111,11 @@ impl Compiler { self.set_no_location(); // early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs. 
- self.switch_to_block(cleanup); + self.use_cpython_label_block(cleanup); emit!(self, Instruction::Swap { i: 2 }); emit!(self, Instruction::PopTop); - self.switch_to_block(end); + self.use_cpython_label_block(end); Ok(()) } @@ -7724,9 +7131,8 @@ impl Compiler { let (last_op, mid_ops) = ops.split_last().unwrap(); let (last_comparator, mid_comparators) = comparators.split_last().unwrap(); - self.compile_expression(left)?; - if mid_comparators.is_empty() { + self.compile_expression(left)?; self.compile_expression(last_comparator)?; self.set_source_range(compare_range); self.compile_addcompare(last_op); @@ -7735,7 +7141,7 @@ impl Compiler { } let cleanup = self.new_block(); - let end = self.new_block(); + self.compile_expression(left)?; for (op, comparator) in mid_ops.iter().zip(mid_comparators) { self.compile_expression(comparator)?; @@ -7752,21 +7158,23 @@ impl Compiler { self.compile_addcompare(last_op); emit!(self, Instruction::ToBool); self.emit_pop_jump_by_condition(condition, target_block); + let end = self.new_block(); emit!(self, PseudoInstruction::JumpNoInterrupt { delta: end }); self.set_no_location(); - self.switch_to_block(cleanup); + self.use_cpython_label_block(cleanup); emit!(self, Instruction::PopTop); if !condition { + self.set_no_location(); emit!( self, - PseudoInstruction::Jump { + PseudoInstruction::JumpNoInterrupt { delta: target_block } ); } - self.switch_to_block(end); + self.use_cpython_label_block(end); Ok(()) } @@ -7790,6 +7198,7 @@ impl Compiler { fn compile_annotation(&mut self, annotation: &ast::Expr) -> CompileResult<()> { if self.future_annotations { + self.set_source_range(annotation.range()); self.emit_load_const(ConstantData::Str { value: UnparseExpr::new(annotation, &self.source_file) .to_string() @@ -7854,6 +7263,7 @@ impl Compiler { annotation: &ast::Expr, value: Option<&ast::Expr>, simple: bool, + loc: TextRange, ) -> CompileResult<()> { // Perform the actual assignment first if let Some(value) = value { @@ -7870,6 +7280,7 @@ 
impl Compiler { // PEP 563: Store stringified annotation directly to __annotations__ // Compile annotation as string self.compile_annotation(annotation)?; + self.set_source_range(loc); // Load __annotations__ let annotations_name = self.name("__annotations__"); emit!( @@ -7879,10 +7290,12 @@ impl Compiler { } ); // Load the variable name + self.set_source_range(loc); self.emit_load_const(ConstantData::Str { value: self.mangle(id.as_str()).into_owned().into(), }); // Store: __annotations__[name] = annotation + self.set_source_range(loc); emit!(self, Instruction::StoreSubscr); } else { // PEP 649: Handle conditional annotations @@ -7944,6 +7357,11 @@ impl Compiler { ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { self.maybe_add_static_attribute_to_class(value, attr.as_str()); self.compile_expression(value)?; + self.set_source_range(self.update_start_location_to_match_attr( + target.range(), + target.range(), + attr.as_str(), + )); let namei = self.name(attr.as_str()); emit!(self, Instruction::StoreAttr { namei }); } @@ -8012,15 +7430,25 @@ impl Compiler { op: &ast::Operator, value: &ast::Expr, ) -> CompileResult<()> { + let stmt_range = self.current_source_range; + let target_range = target.range(); enum AugAssignKind<'a> { - Name { id: &'a str }, - Subscript { use_slice_opt: bool }, - Attr { idx: bytecode::NameIdx }, + Name { + id: &'a str, + }, + Subscript { + use_slice_opt: bool, + }, + Attr { + idx: bytecode::NameIdx, + attr_range: TextRange, + }, } let kind = match &target { ast::Expr::Name(ast::ExprName { id, .. 
}) => { let id = id.as_str(); + self.set_source_range(target_range); self.compile_name(id, NameUsage::Load)?; AugAssignKind::Name { id } } @@ -8032,6 +7460,7 @@ impl Compiler { }) => { let use_slice_opt = slice.should_use_slice_optimization(); self.compile_expression(value)?; + self.set_source_range(target_range); if use_slice_opt { let ast::Expr::Slice(slice_expr) = slice.as_ref() else { unreachable!( @@ -8039,12 +7468,14 @@ impl Compiler { ); }; self.compile_slice_two_parts(slice_expr)?; + self.set_source_range(target_range); emit!(self, Instruction::Copy { i: 3 }); emit!(self, Instruction::Copy { i: 3 }); emit!(self, Instruction::Copy { i: 3 }); emit!(self, Instruction::BinarySlice); } else { self.compile_expression(slice)?; + self.set_source_range(target_range); emit!(self, Instruction::Copy { i: 2 }); emit!(self, Instruction::Copy { i: 2 }); emit!( @@ -8059,10 +7490,14 @@ impl Compiler { ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { let attr = attr.as_str(); self.compile_expression(value)?; + let attr_range = + self.update_start_location_to_match_attr(target_range, target_range, attr); + self.set_source_range(attr_range); emit!(self, Instruction::Copy { i: 1 }); let idx = self.name(attr); + self.set_source_range(attr_range); self.emit_load_attr(idx); - AugAssignKind::Attr { idx } + AugAssignKind::Attr { idx, attr_range } } _ => { return Err(self.error(CodegenErrorType::Assign(target.python_name()))); @@ -8070,14 +7505,17 @@ impl Compiler { }; self.compile_expression(value)?; + self.set_source_range(stmt_range); self.compile_op(op, true); match kind { AugAssignKind::Name { id } => { // stack: RESULT + self.set_source_range(target_range); self.compile_name(id, NameUsage::Store)?; } AugAssignKind::Subscript { use_slice_opt } => { + self.set_source_range(target_range); if use_slice_opt { // stack: CONTAINER START STOP RESULT emit!(self, Instruction::Swap { i: 4 }); @@ -8091,8 +7529,9 @@ impl Compiler { emit!(self, Instruction::StoreSubscr); } } 
- AugAssignKind::Attr { idx } => { + AugAssignKind::Attr { idx, attr_range } => { // stack: CONTAINER RESULT + self.set_source_range(attr_range); emit!(self, Instruction::Swap { i: 2 }); emit!(self, Instruction::StoreAttr { namei: idx }); } @@ -8157,7 +7596,7 @@ impl Compiler { self.compile_jump_if_inner(last_value, condition, target_block, source_range)?; if next2 != target_block { - self.switch_to_block(next2); + self.use_cpython_label_block(next2); } Ok(()) } @@ -8176,10 +7615,10 @@ impl Compiler { emit!(self, PseudoInstruction::JumpNoInterrupt { delta: end }); self.set_no_location(); - self.switch_to_block(next2); + self.use_cpython_label_block(next2); self.compile_jump_if_inner(orelse, condition, target_block, source_range)?; - self.switch_to_block(end); + self.use_cpython_label_block(end); Ok(()) } ast::Expr::Compare(ast::ExprCompare { @@ -8190,56 +7629,10 @@ impl Compiler { }) if ops.len() > 1 => { self.compile_jump_if_compare(left, ops, comparators, condition, target_block) } - // `x is None` / `x is not None` → POP_JUMP_IF_NONE / POP_JUMP_IF_NOT_NONE - ast::Expr::Compare(ast::ExprCompare { - left, - ops, - comparators, - .. 
- }) if ops.len() == 1 - && matches!(ops[0], ast::CmpOp::Is | ast::CmpOp::IsNot) - && comparators.len() == 1 - && matches!(&comparators[0], ast::Expr::NoneLiteral(_)) => - { - self.compile_expression(left)?; - let source = self.source_file.to_source_code(); - let comparator_line = source.line_index(comparators[0].range().start()); - let left_line = source.line_index(left.range().start()); - if comparator_line != left_line { - self.set_source_range(comparators[0].range()); - emit!(self, Instruction::Nop); - self.set_source_range(source_range.unwrap_or_else(|| expression.range())); - } - let is_not = matches!(ops[0], ast::CmpOp::IsNot); - // is None + jump_if_false → POP_JUMP_IF_NOT_NONE - // is None + jump_if_true → POP_JUMP_IF_NONE - // is not None + jump_if_false → POP_JUMP_IF_NONE - // is not None + jump_if_true → POP_JUMP_IF_NOT_NONE - let jump_if_none = condition != is_not; - if jump_if_none { - emit!( - self, - Instruction::PopJumpIfNone { - delta: target_block, - } - ); - } else { - emit!( - self, - Instruction::PopJumpIfNotNone { - delta: target_block, - } - ); - } - Ok(()) - } _ => { // Fall back case which always will work! - if matches!(self.constant_expr_truthiness(expression)?, Some(value) if value == condition) - { - self.disable_load_fast_borrow_for_block(target_block); - } self.compile_expression(expression)?; + self.set_source_range(expression.range()); emit!(self, Instruction::ToBool); if condition { emit!( @@ -8276,85 +7669,37 @@ impl Compiler { /// Compile a boolean operation as an expression. /// This means, that the last value remains on the stack. fn compile_bool_op(&mut self, op: &ast::BoolOp, values: &[ast::Expr]) -> CompileResult<()> { - fn flatten_same_boolop_values<'a>( - op: &ast::BoolOp, - value: &'a ast::Expr, - out: &mut Vec<&'a ast::Expr>, - ) { - if let ast::Expr::BoolOp(ast::ExprBoolOp { - op: inner_op, - values, - .. 
- }) = value - && inner_op == op - { - for value in values { - flatten_same_boolop_values(op, value, out); - } - } else { - out.push(value); - } - } - - let mut flattened = Vec::with_capacity(values.len()); - for value in values { - flatten_same_boolop_values(op, value, &mut flattened); - } - + let boolop_range = self.current_source_range; let after_block = self.new_block(); - let (last_value, prefix_values) = flattened.split_last().unwrap(); + let (last_value, prefix_values) = values.split_last().unwrap(); for value in prefix_values { - let continue_block = self.new_block(); self.compile_expression(value)?; + self.set_source_range(boolop_range); self.emit_short_circuit_test(op, after_block); - self.switch_to_block(continue_block); + self.set_source_range(boolop_range); emit!(self, Instruction::PopTop); } self.compile_expression(last_value)?; - self.switch_to_block(after_block); - Ok(()) - } - - fn compile_bool_op_with_head_constant( - &mut self, - op: &ast::BoolOp, - head: ConstantData, - tail: &[ast::Expr], - ) -> CompileResult<()> { - self.emit_load_const(head); - self.mark_last_instruction_folded_from_nonliteral_expr(); - if tail.is_empty() { - return Ok(()); - } - - let after_block = self.new_block(); - for value in tail { - self.emit_short_circuit_test(op, after_block); - emit!(self, Instruction::PopTop); - self.compile_expression(value)?; - } - self.switch_to_block(after_block); + self.use_cpython_label_block(after_block); Ok(()) } - /// Emit `Copy 1` + conditional jump for short-circuit evaluation. - /// For `And`, emits `PopJumpIfFalse`; for `Or`, emits `PopJumpIfTrue`. + /// Emit CPython-style pseudo conditional jump for short-circuit evaluation. + /// flowgraph.c lowers it to `COPY 1; TO_BOOL; POP_JUMP_IF_*`. 
fn emit_short_circuit_test(&mut self, op: &ast::BoolOp, target: BlockIdx) { - emit!(self, Instruction::Copy { i: 1 }); - emit!(self, Instruction::ToBool); match op { ast::BoolOp::And => { - emit!(self, Instruction::PopJumpIfFalse { delta: target }); + emit!(self, PseudoInstruction::JumpIfFalse { delta: target }); } ast::BoolOp::Or => { - emit!(self, Instruction::PopJumpIfTrue { delta: target }); + emit!(self, PseudoInstruction::JumpIfTrue { delta: target }); } } } - fn compile_dict(&mut self, items: &[ast::DictItem]) -> CompileResult<()> { + fn compile_dict(&mut self, items: &[ast::DictItem], range: TextRange) -> CompileResult<()> { let has_unpacking = items.iter().any(|item| item.key.is_none()); if !has_unpacking { @@ -8370,6 +7715,7 @@ impl Compiler { self.compile_expression(item.key.as_ref().unwrap())?; self.compile_expression(&item.value)?; } + self.set_source_range(range); emit!( self, Instruction::BuildMap { @@ -8396,11 +7742,13 @@ impl Compiler { (total_map_add, 0usize) }; + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); let mut idx = 0; for chunk_i in 0..big_count { if chunk_i > 0 { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); } let chunk_size = if idx + BIG_MAP_CHUNK <= n - tail_count { @@ -8411,9 +7759,11 @@ impl Compiler { for item in &items[idx..idx + chunk_size] { self.compile_expression(item.key.as_ref().unwrap())?; self.compile_expression(&item.value)?; + self.set_source_range(range); emit!(self, Instruction::MapAdd { i: 1 }); } if chunk_i > 0 { + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } idx += chunk_size; @@ -8425,12 +7775,14 @@ impl Compiler { self.compile_expression(item.key.as_ref().unwrap())?; self.compile_expression(&item.value)?; } + self.set_source_range(range); emit!( self, Instruction::BuildMap { count: tail_count.to_u32(), } ); + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } } @@ -8449,8 +7801,10 @@ 
impl Compiler { () => { #[allow(unused_assignments)] if elements > 0 { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: elements }); if have_dict { + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } else { have_dict = true; @@ -8470,16 +7824,19 @@ impl Compiler { // ** unpacking entry flush_pending!(); if !have_dict { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); have_dict = true; } self.compile_expression(&item.value)?; + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } } flush_pending!(); if !have_dict { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); } @@ -8493,7 +7850,7 @@ impl Compiler { /// SEND exit /// SETUP_FINALLY fail (via exception table) /// YIELD_VALUE 1 - /// POP_BLOCK (implicit) + /// POP_BLOCK (NO_LOCATION) /// RESUME /// JUMP send /// fail: @@ -8501,29 +7858,24 @@ impl Compiler { /// JUMP exit /// exit: /// END_SEND - fn compile_yield_from_sequence(&mut self, is_await: bool) -> CompileResult { + fn compile_yield_from_sequence(&mut self, is_await: bool) -> BlockIdx { let send_block = self.new_block(); let fail_block = self.new_block(); let exit_block = self.new_block(); // send: - self.switch_to_block(send_block); + self.use_cpython_label_block(send_block); emit!(self, Instruction::Send { delta: exit_block }); // SETUP_FINALLY fail - set up exception handler for YIELD_VALUE emit!(self, PseudoInstruction::SetupFinally { delta: fail_block }); - self.push_fblock( - FBlockType::TryExcept, // Use TryExcept for exception handler - send_block, - exit_block, - )?; // YIELD_VALUE with arg=1 (yield-from/await mode - not wrapped for async gen) emit!(self, Instruction::YieldValue { arg: 1 }); // POP_BLOCK before RESUME emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::TryExcept); + self.set_no_location(); // RESUME emit!( @@ -8549,16 +7901,16 @@ impl Compiler { // CPython lets this block 
fall through to END_SEND during codegen; // push_cold_blocks_to_end later inserts the no-interrupt jump after // moving the cold fail block behind the warm exit path. - self.switch_to_block(fail_block); + self.use_cpython_label_block(fail_block); emit!(self, Instruction::CleanupThrow); // exit: END_SEND // Stack: [receiver, value] (from SEND) or [None, value] (from CLEANUP_THROW) // END_SEND: [receiver/None, value] -> [value] - self.switch_to_block(exit_block); + self.use_cpython_label_block(exit_block); emit!(self, Instruction::EndSend); - Ok(send_block) + send_block } /// Returns true if the expression is a constant with no side effects. @@ -8571,26 +7923,7 @@ impl Compiler { | ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) | ast::Expr::EllipsisLiteral(_) - ) || matches!(expr, ast::Expr::FString(fstring) if Self::fstring_value_is_const(&fstring.value)) - } - - fn fstring_value_is_const(fstring: &ast::FStringValue) -> bool { - for part in fstring { - if !Self::fstring_part_is_const(part) { - return false; - } - } - true - } - - fn fstring_part_is_const(part: &ast::FStringPart) -> bool { - match part { - ast::FStringPart::Literal(_) => true, - ast::FStringPart::FString(fstring) => fstring - .elements - .iter() - .all(|element| matches!(element, ast::InterpolatedStringElement::Literal(_))), - } + ) } fn compile_expression(&mut self, expression: &ast::Expr) -> CompileResult<()> { @@ -8598,16 +7931,6 @@ impl Compiler { let range = expression.range(); self.set_source_range(range); - if let ast::Expr::Subscript(ast::ExprSubscript { - ctx: ast::ExprContext::Load, - .. - }) = expression - && let Some(constant) = self.try_fold_constant_expr(expression)? - { - self.emit_load_const(constant); - return Ok(()); - } - if matches!(expression, ast::Expr::BinOp(_)) && let Some(constant) = self.try_fold_constant_expr(expression)? 
{ @@ -8615,61 +7938,6 @@ impl Compiler { return Ok(()); } - if !self.disable_const_boolop_folding - && let ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) = expression - { - let mut simplified_prefix = 0usize; - let mut last_constant = None; - let mut retained_head = None; - for value in values { - let Some(constant) = self.try_fold_constant_expr(value)? else { - break; - }; - if !Self::boolop_fast_fold_literal(value) { - retained_head = Some(constant); - simplified_prefix += 1; - break; - } - let is_truthy = Self::constant_truthiness(&constant); - last_constant = Some(constant); - match op { - ast::BoolOp::Or if is_truthy => { - self.emit_load_const(last_constant.expect("missing boolop constant")); - self.mark_last_instruction_folded_from_nonliteral_expr(); - return Ok(()); - } - ast::BoolOp::And if !is_truthy => { - self.emit_load_const(last_constant.expect("missing boolop constant")); - self.mark_last_instruction_folded_from_nonliteral_expr(); - return Ok(()); - } - ast::BoolOp::Or | ast::BoolOp::And => { - simplified_prefix += 1; - } - } - } - - if let Some(head) = retained_head { - self.compile_bool_op_with_head_constant(op, head, &values[simplified_prefix..])?; - return Ok(()); - } - if simplified_prefix == values.len() { - self.emit_load_const(last_constant.expect("missing folded boolop constant")); - self.mark_last_instruction_folded_from_nonliteral_expr(); - return Ok(()); - } - if simplified_prefix > 0 { - let tail = &values[simplified_prefix..]; - if let [value] = tail { - self.compile_expression(value)?; - } else { - self.compile_bool_op(op, tail)?; - } - self.mark_last_instruction_folded_from_nonliteral_expr(); - return Ok(()); - } - } - match &expression { ast::Expr::Call(ast::ExprCall { func, arguments, .. @@ -8693,7 +7961,22 @@ impl Compiler { self.compile_subscript(value, slice, *ctx)?; } ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. 
}) => { - self.compile_expression(operand)?; + if let ( + ast::UnaryOp::Not, + ast::Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + .. + }), + ) = (op, operand.as_ref()) + && ops.len() == 1 + { + self.set_source_range(range); + self.compile_compare(left, ops, comparators)?; + } else { + self.compile_expression(operand)?; + } // Restore full expression range before emitting the operation self.set_source_range(range); @@ -8717,11 +8000,14 @@ impl Compiler { if let Some(super_type) = self.can_optimize_super_call(value, attr.as_str()) { // super().attr or super(cls, self).attr optimization // Stack: [global_super, class, self] → LOAD_SUPER_ATTR → [attr] - // Set source range to super() call for arg-loading instructions - let super_range = value.range(); - self.set_source_range(super_range); - self.load_args_for_super(&super_type)?; - self.set_source_range(super_range); + let ast::Expr::Call(ast::ExprCall { + func: super_func, .. + }) = value.as_ref() + else { + unreachable!("can_optimize_super_call only accepts calls"); + }; + self.load_args_for_super(&super_type, super_func.range(), value.range())?; + self.set_source_range(range); let idx = self.name(attr.as_str()); match super_type { SuperCallType::TwoArg { .. } => { @@ -8734,6 +8020,11 @@ impl Compiler { } else { // Normal attribute access self.compile_expression(value)?; + self.set_source_range(self.update_start_location_to_match_attr( + range, + range, + attr.as_str(), + )); let idx = self.name(attr.as_str()); self.emit_load_attr(idx); } @@ -8749,38 +8040,49 @@ impl Compiler { // ast::Expr::Constant(ExprConstant { value, .. }) => { // self.emit_load_const(compile_constant(value)); // } - ast::Expr::List(ast::ExprList { elts, .. }) => { + ast::Expr::List(ast::ExprList { elts, range, .. }) => { + self.set_source_range(*range); self.starunpack_helper(elts, 0, CollectionType::List)?; } - ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { + ast::Expr::Tuple(ast::ExprTuple { elts, range, .. 
}) => { + self.set_source_range(*range); self.starunpack_helper(elts, 0, CollectionType::Tuple)?; } - ast::Expr::Set(ast::ExprSet { elts, .. }) => { + ast::Expr::Set(ast::ExprSet { elts, range, .. }) => { + self.set_source_range(*range); self.starunpack_helper(elts, 0, CollectionType::Set)?; } - ast::Expr::Dict(ast::ExprDict { items, .. }) => { - self.compile_dict(items)?; + ast::Expr::Dict(ast::ExprDict { items, range, .. }) => { + self.compile_dict(items, *range)?; } ast::Expr::Slice(ast::ExprSlice { - lower, upper, step, .. + lower, + upper, + step, + range, + .. }) => { if let Some(folded_const) = self.try_fold_constant_slice( lower.as_deref(), upper.as_deref(), step.as_deref(), )? { + self.set_source_range(*range); self.emit_load_const(folded_const); return Ok(()); } - let mut compile_bound = |bound: Option<&ast::Expr>| match bound { - Some(exp) => self.compile_expression(exp), - None => { - self.emit_load_const(ConstantData::None); - Ok(()) - } - }; - compile_bound(lower.as_deref())?; - compile_bound(upper.as_deref())?; + if let Some(lower) = lower { + self.compile_expression(lower)?; + } else { + self.set_source_range(*range); + self.emit_load_const(ConstantData::None); + } + if let Some(upper) = upper { + self.compile_expression(upper)?; + } else { + self.set_source_range(*range); + self.emit_load_const(ConstantData::None); + } if let Some(step) = step { self.compile_expression(step)?; } @@ -8788,6 +8090,7 @@ impl Compiler { Some(_) => BuildSliceArgCount::Three, None => BuildSliceArgCount::Two, }; + self.set_source_range(*range); emit!(self, Instruction::BuildSlice { argc }); } ast::Expr::Yield(ast::ExprYield { value, .. 
}) => { @@ -8799,6 +8102,7 @@ impl Compiler { Some(expression) => self.compile_expression(expression)?, Option::None => self.emit_load_const(ConstantData::None), }; + self.set_source_range(range); if self.ctx.func == FunctionContext::AsyncFunction { emit!( self, @@ -8821,9 +8125,10 @@ impl Compiler { return Err(self.error(CodegenErrorType::InvalidAwait)); } self.compile_expression(value)?; + self.set_source_range(range); emit!(self, Instruction::GetAwaitable { r#where: 0 }); self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(true)?; + let _ = self.compile_yield_from_sequence(true); } ast::Expr::YieldFrom(ast::ExprYieldFrom { value, .. }) => { match self.ctx.func { @@ -8837,13 +8142,17 @@ impl Compiler { } self.mark_generator(); self.compile_expression(value)?; + self.set_source_range(range); emit!(self, Instruction::GetYieldFromIter); self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(false)?; + let _ = self.compile_yield_from_sequence(false); } ast::Expr::Name(ast::ExprName { id, .. }) => self.load_name(id.as_str())?, ast::Expr::Lambda(ast::ExprLambda { - parameters, body, .. + parameters, + body, + range, + .. 
}) => { let default_params = ast::Parameters::default(); let params = parameters.as_deref().unwrap_or(&default_params); @@ -8865,6 +8174,7 @@ impl Compiler { for element in &defaults { self.compile_expression(element)?; } + self.set_source_range(*range); emit!(self, Instruction::BuildTuple { count: size }); } @@ -8880,11 +8190,13 @@ impl Compiler { if have_kwdefaults { let default_kw_count = kw_with_defaults.len(); for (arg, default) in &kw_with_defaults { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: self.mangle(arg.name.as_str()).into_owned().into(), }); self.compile_expression(default)?; } + self.set_source_range(*range); emit!( self, Instruction::BuildMap { @@ -8906,20 +8218,27 @@ impl Compiler { self.set_qualname(); self.ctx = CompileContext { - loop_data: Option::None, in_class: prev_ctx.in_class, func: FunctionContext::Function, // Lambda is never async, so new scope is not async in_async_scope: false, }; - // Lambda cannot have docstrings, so no None is added to co_consts - self.compile_expression(body)?; + self.set_source_range(body.range()); self.emit_return_value(); + // _PyCodegen_AddReturnAtEnd() appends a no-location + // return-None epilogue even after lambda's explicit + // RETURN_VALUE. It is later removed as unreachable, but + // remove_unused_consts() keeps None when it was the first + // constant in an otherwise constant-free lambda. 
+ if self.current_code_info().metadata.consts.is_empty() { + self.arg_constant(ConstantData::None); + } let code = self.exit_scope(); // Create lambda function with closure + self.set_source_range(*range); self.make_closure(code, func_flags)?; self.ctx = prev_ctx; @@ -8941,6 +8260,7 @@ impl Compiler { generators, &|compiler, collection_add_i| { compiler.compile_comprehension_element(elt)?; + compiler.set_source_range(elt.range()); emit!( compiler, Instruction::ListAppend { @@ -8952,6 +8272,8 @@ impl Compiler { ComprehensionType::List, Self::contains_await(elt) || Self::generators_contain_await(generators), *range, + elt.range(), + elt.range(), )?; } ast::Expr::SetComp(ast::ExprSetComp { @@ -8971,6 +8293,7 @@ impl Compiler { generators, &|compiler, collection_add_i| { compiler.compile_comprehension_element(elt)?; + compiler.set_source_range(elt.range()); emit!( compiler, Instruction::SetAdd { @@ -8982,6 +8305,8 @@ impl Compiler { ComprehensionType::Set, Self::contains_await(elt) || Self::generators_contain_await(generators), *range, + elt.range(), + elt.range(), )?; } ast::Expr::DictComp(ast::ExprDictComp { @@ -9005,6 +8330,10 @@ impl Compiler { compiler.compile_expression(key)?; compiler.compile_expression(value)?; + compiler.set_source_range(TextRange::new( + key.range().start(), + value.range().end(), + )); emit!( compiler, Instruction::MapAdd { @@ -9019,6 +8348,8 @@ impl Compiler { || Self::contains_await(value) || Self::generators_contain_await(generators), *range, + TextRange::new(key.range().start(), value.range().end()), + key.range(), )?; } ast::Expr::Generator(ast::ExprGenerator { @@ -9027,47 +8358,7 @@ impl Compiler { range, .. 
}) => { - // Check if element or generators contain async content - // This makes the generator expression into an async generator - let element_contains_await = - Self::contains_await(elt) || Self::generators_contain_await(generators); - self.compile_comprehension( - "", - None, - generators, - &|compiler, _collection_add_i| { - // Compile the element expression - // Note: if element is an async comprehension, compile_expression - // already handles awaiting it, so we don't need to await again here - compiler.compile_comprehension_element(elt)?; - - compiler.mark_generator(); - if compiler.ctx.func == FunctionContext::AsyncFunction { - emit!( - compiler, - Instruction::CallIntrinsic1 { - func: bytecode::IntrinsicFunction1::AsyncGenWrap - } - ); - } - // arg=0: direct yield (wrapped for async generators) - emit!(compiler, Instruction::YieldValue { arg: 0 }); - emit!( - compiler, - Instruction::Resume { - context: oparg::ResumeContext::from( - oparg::ResumeLocation::AfterYield - ) - } - ); - emit!(compiler, Instruction::PopTop); - - Ok(()) - }, - ComprehensionType::Generator, - element_contains_await, - *range, - )?; + self.compile_generator_expression(elt, generators, *range)?; } ast::Expr::Starred(ast::ExprStarred { value, .. }) => { if self.in_annotation { @@ -9082,19 +8373,12 @@ impl Compiler { ast::Expr::If(ast::ExprIf { test, body, orelse, .. }) => { - let folded_test_truthiness = self - .try_fold_constant_expr(test)? 
- .as_ref() - .map(Self::constant_truthiness); - let else_block = self.new_block(); let after_block = self.new_block(); + let else_block = self.new_block(); self.compile_jump_if(test, false, else_block)?; // True case self.compile_expression(body)?; - if folded_test_truthiness == Some(true) { - self.mark_last_instruction_folded_from_nonliteral_expr(); - } emit!( self, PseudoInstruction::JumpNoInterrupt { delta: after_block } @@ -9102,21 +8386,18 @@ impl Compiler { self.set_no_location(); // False case - self.switch_to_block(else_block); + self.use_cpython_label_block(else_block); self.compile_expression(orelse)?; - if folded_test_truthiness == Some(false) { - self.mark_last_instruction_folded_from_nonliteral_expr(); - } // End - self.switch_to_block(after_block); + self.use_cpython_label_block(after_block); } ast::Expr::Named(ast::ExprNamed { target, value, node_index: _, - range: _, + range, }) => { // Walrus targets in inlined comps should NOT be hidden from locals() if self.current_code_info().in_inlined_comp @@ -9128,8 +8409,10 @@ impl Compiler { info.metadata.fast_hidden_final.swap_remove(name.as_ref()); } self.compile_expression(value)?; + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.compile_store(target)?; + self.set_source_range(target.range()); } ast::Expr::FString(fstring) => { self.compile_expr_fstring(fstring)?; @@ -9176,21 +8459,42 @@ impl Compiler { Ok(()) } - fn detect_builtin_generator_call( + fn cpython_sync_genexpr_call_name<'a>( &self, - func: &ast::Expr, + func: &'a ast::Expr, args: &ast::Arguments, - ) -> Option { + ) -> Option<&'a str> { let ast::Expr::Name(ast::ExprName { id, .. }) = func else { return None; }; - if args.args.len() != 1 - || !args.keywords.is_empty() - || !matches!(args.args[0], ast::Expr::Generator(_)) - { + let [ + ast::Expr::Generator(ast::ExprGenerator { + elt: _, + generators: _, + .. + }), + ] = &args.args[..] 
+ else { + return None; + }; + if !args.keywords.is_empty() || { + let table = self.current_symbol_table(); + table + .sub_tables + .get(table.next_sub_table) + .is_none_or(|generator_entry| generator_entry.is_coroutine) + } { return None; } - match id.as_str() { + Some(id.as_str()) + } + + fn detect_builtin_generator_call( + &self, + func: &ast::Expr, + args: &ast::Arguments, + ) -> Option { + match self.cpython_sync_genexpr_call_name(func, args)? { "tuple" => Some(BuiltinGeneratorCallKind::Tuple), "all" => Some(BuiltinGeneratorCallKind::All), "any" => Some(BuiltinGeneratorCallKind::Any), @@ -9201,12 +8505,13 @@ impl Compiler { /// Emit the optimized inline loop for builtin(genexpr) calls. /// /// Stack on entry: `[func]` where `func` is the builtin candidate. - /// On return the compiler is positioned at the fallback block so the + /// On return the compiler is positioned at the skip-optimization block so the /// normal call path can compile the original generator argument again. fn optimize_builtin_generator_call( &mut self, kind: BuiltinGeneratorCallKind, generator_expr: &ast::Expr, + loc: TextRange, end: BlockIdx, ) -> CompileResult<()> { let common_constant = match kind { @@ -9215,11 +8520,10 @@ impl Compiler { BuiltinGeneratorCallKind::Any => bytecode::CommonConstant::BuiltinAny, }; - let fallback = self.new_block(); - let loop_block = self.new_block(); - let cleanup = self.new_block(); + let skip_optimization = self.new_block(); - // Stack: [func] — copy function for identity check + // Stack: [func] - copy function for identity check + self.set_source_range(loc); emit!(self, Instruction::Copy { i: 1 }); emit!( self, @@ -9228,49 +8532,78 @@ impl Compiler { } ); emit!(self, Instruction::IsOp { invert: Invert::No }); - emit!(self, Instruction::PopJumpIfFalse { delta: fallback }); + emit!( + self, + Instruction::PopJumpIfFalse { + delta: skip_optimization + } + ); emit!(self, Instruction::PopTop); if matches!(kind, BuiltinGeneratorCallKind::Tuple) { + 
self.set_source_range(loc); emit!(self, Instruction::BuildList { count: 0 }); } let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); - self.compile_expression(generator_expr)?; + if let Some(range) = self.cpython_implicit_call_generator_range(generator_expr) { + self.compile_expression_with_generator_range(generator_expr, range)?; + } else { + self.compile_expression(generator_expr)?; + } if let Some(cursor) = sub_table_cursor && let Some(current_table) = self.symbol_table_stack.last_mut() { current_table.next_sub_table = cursor; } - self.switch_to_block(loop_block); + + let loop_block = self.new_block(); + let cleanup = self.new_block(); + self.use_cpython_label_block(loop_block); + self.set_source_range(loc); emit!(self, Instruction::ForIter { delta: cleanup }); match kind { BuiltinGeneratorCallKind::Tuple => { + self.set_source_range(loc); emit!(self, Instruction::ListAppend { i: 2 }); + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: loop_block }); } BuiltinGeneratorCallKind::All => { + self.set_source_range(loc); emit!(self, Instruction::ToBool); emit!(self, Instruction::PopJumpIfTrue { delta: loop_block }); + self.set_source_range(loc); emit!(self, Instruction::PopIter); + self.set_no_location(); self.emit_load_const(ConstantData::Boolean { value: false }); + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); } BuiltinGeneratorCallKind::Any => { + self.set_source_range(loc); emit!(self, Instruction::ToBool); emit!(self, Instruction::PopJumpIfFalse { delta: loop_block }); + self.set_source_range(loc); emit!(self, Instruction::PopIter); + self.set_no_location(); self.emit_load_const(ConstantData::Boolean { value: true }); + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); } } - self.switch_to_block(cleanup); + self.use_cpython_label_block(cleanup); + self.set_source_range(loc); emit!(self, Instruction::EndFor); + self.set_no_location(); + 
self.set_source_range(loc); emit!(self, Instruction::PopIter); + self.set_no_location(); match kind { BuiltinGeneratorCallKind::Tuple => { + self.set_source_range(loc); emit!( self, Instruction::CallIntrinsic1 { @@ -9279,15 +8612,18 @@ impl Compiler { ); } BuiltinGeneratorCallKind::All => { + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: true }); } BuiltinGeneratorCallKind::Any => { + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: false }); } } + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); - self.switch_to_block(fallback); + self.use_cpython_label_block(skip_optimization); Ok(()) } @@ -9305,13 +8641,27 @@ impl Compiler { // super().method() or super(cls, self).method() optimization // CALL path: [global_super, class, self] → LOAD_SUPER_METHOD → [method, self] // CALL_FUNCTION_EX path: [global_super, class, self] → LOAD_SUPER_ATTR → [attr] - // Set source range to the super() call for LOAD_GLOBAL/LOAD_DEREF/etc. - let super_range = value.range(); - self.set_source_range(super_range); - self.load_args_for_super(&super_type)?; - self.set_source_range(super_range); + let ast::Expr::Call(ast::ExprCall { + func: super_func, .. + }) = value.as_ref() + else { + unreachable!("can_optimize_super_call only accepts calls"); + }; + self.load_args_for_super(&super_type, super_func.range(), value.range())?; + let attr_access_range = self.update_start_location_to_match_attr( + func.range(), + func.range(), + attr.as_str(), + ); + let method_call_range = self.update_start_location_to_match_attr( + call_range, + func.range(), + attr.as_str(), + ); + self.set_source_range(attr_access_range); let idx = self.name(attr.as_str()); if uses_ex_call { + self.set_source_range(func.range()); match super_type { SuperCallType::TwoArg { .. 
} => { self.emit_load_super_attr(idx); @@ -9323,11 +8673,11 @@ impl Compiler { // CPython's Attribute_kind super path emits an attr-line // NOP after LOAD_SUPER_ATTR, even when the call later uses // CALL_FUNCTION_EX for starred arguments. - self.set_source_range(attr.range()); + self.set_source_range(attr_access_range); emit!(self, Instruction::Nop); - self.set_source_range(super_range); + self.set_source_range(func.range()); emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range)?; + self.codegen_call_helper(0, args, call_range, None)?; } else { match super_type { SuperCallType::TwoArg { .. } => { @@ -9338,14 +8688,25 @@ impl Compiler { } } // NOP for line tracking at .method( line - self.set_source_range(attr.range()); + self.set_source_range(attr_access_range); emit!(self, Instruction::Nop); // CALL at .method( line (not the full expression line) - self.codegen_call_helper(0, args, attr.range())?; + self.codegen_call_helper(0, args, method_call_range, Some(attr_access_range))?; } } else { self.compile_expression(value)?; let idx = self.name(attr.as_str()); + let attr_access_range = self.update_start_location_to_match_attr( + func.range(), + func.range(), + attr.as_str(), + ); + let method_call_range = self.update_start_location_to_match_attr( + call_range, + func.range(), + attr.as_str(), + ); + self.set_source_range(attr_access_range); // Imported names and CALL_FUNCTION_EX-style calls use plain // LOAD_ATTR + PUSH_NULL; other names use method-call mode. // Check current scope and enclosing scopes for IMPORTED flag. 
@@ -9357,25 +8718,56 @@ impl Compiler { } else { self.emit_load_attr_method(idx); } - self.codegen_call_helper(0, args, call_range)?; + if is_import || uses_ex_call { + self.codegen_call_helper(0, args, call_range, None)?; + } else { + self.codegen_call_helper(0, args, method_call_range, Some(attr_access_range))?; + } } } else if let Some(kind) = (!uses_ex_call) .then(|| self.detect_builtin_generator_call(func, args)) .flatten() { - let end = self.new_block(); + let skip_normal_call = self.new_block(); self.compile_expression(func)?; - self.optimize_builtin_generator_call(kind, &args.args[0], end)?; - self.set_source_range(call_range); + self.optimize_builtin_generator_call( + kind, + &args.args[0], + func.range(), + skip_normal_call, + )?; + self.set_source_range(func.range()); emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range)?; - self.switch_to_block(end); + self.codegen_call_helper(0, args, call_range, None)?; + self.use_cpython_label_block(skip_normal_call); } else { // Regular call: push func, then NULL for self_or_null slot // Stack layout: [func, NULL, args...] - same as method call [func, self, args...] + // CPython `codegen_call()` always creates and uses + // `skip_normal_call`, even when `maybe_optimize_function_call()` + // leaves it untargeted. + let skip_normal_call = self.current_code_info().new_instr_sequence_label(); + let sync_genexpr_call_name = (!uses_ex_call) + .then(|| self.cpython_sync_genexpr_call_name(func, args)) + .flatten() + .is_some(); self.compile_expression(func)?; + if sync_genexpr_call_name { + // CPython `maybe_optimize_function_call()` creates and uses + // `skip_optimization` for every sync name(genexpr) shape after + // loading the function, even when the name is not all/any/tuple. 
+ let skip_optimization = self.current_code_info().new_instr_sequence_label(); + let result = self + .current_code_info() + .use_raw_instr_sequence_label(skip_optimization); + unwrap_internal(self, result); + } emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range)?; + self.codegen_call_helper(0, args, call_range, None)?; + let result = self + .current_code_info() + .use_raw_instr_sequence_label(skip_normal_call); + unwrap_internal(self, result); } Ok(()) } @@ -9397,6 +8789,7 @@ impl Compiler { keywords: &[ast::Keyword], begin: usize, end: usize, + call_range: TextRange, ) -> CompileResult<()> { let n = end - begin; assert!(n > 0); @@ -9405,22 +8798,26 @@ impl Compiler { let big = n * 2 > STACK_USE_GUIDELINE as usize; if big { + self.set_source_range(call_range); emit!(self, Instruction::BuildMap { count: 0 }); } for kw in &keywords[begin..end] { // Key first, then value - this is critical! + self.set_source_range(call_range); self.emit_load_const(ConstantData::Str { value: kw.arg.as_ref().unwrap().as_str().into(), }); self.compile_expression(&kw.value)?; if big { + self.set_source_range(call_range); emit!(self, Instruction::MapAdd { i: 1 }); } } if !big { + self.set_source_range(call_range); emit!(self, Instruction::BuildMap { count: n.to_u32() }); } @@ -9435,6 +8832,7 @@ impl Compiler { additional_positional: u32, arguments: &ast::Arguments, call_range: TextRange, + kw_names_range: Option, ) -> CompileResult<()> { let nelts = arguments.args.len(); let nkwelts = arguments.keywords.len(); @@ -9453,8 +8851,18 @@ impl Compiler { if !has_starred && !has_double_star && !too_big { // Simple call path: no * or ** args + let implicit_generator_range = + if additional_positional == 0 && nelts == 1 && nkwelts == 0 { + self.cpython_implicit_call_generator_range(&arguments.args[0]) + } else { + None + }; for arg in &arguments.args { - self.compile_expression(arg)?; + if let Some(range) = implicit_generator_range { + 
self.compile_expression_with_generator_range(arg, range)?; + } else { + self.compile_expression(arg)?; + } } if nkwelts > 0 { @@ -9468,11 +8876,12 @@ impl Compiler { } // Restore call expression range for kwnames and CALL_KW - self.set_source_range(call_range); + self.set_source_range(kw_names_range.unwrap_or(call_range)); self.emit_load_const(ConstantData::Tuple { elements: kwarg_names, }); + self.set_source_range(call_range); let argc = additional_positional + nelts.to_u32() + nkwelts.to_u32(); emit!(self, Instruction::CallKw { argc }); } else { @@ -9491,40 +8900,23 @@ impl Compiler { // Single starred arg: pass value directly to CallFunctionEx. // Runtime will convert to tuple and validate with function name. if let ast::Expr::Starred(ast::ExprStarred { value, .. }) = &arguments.args[0] { - self.compile_expression_without_const_boolop_folding(value)?; - } - } else if !has_starred { - for arg in &arguments.args { - self.compile_expression(arg)?; - } - self.set_source_range(call_range); - let positional_count = additional_positional + nelts.to_u32(); - if positional_count == 0 { - self.emit_load_const(ConstantData::Tuple { elements: vec![] }); - } else { - emit!( - self, - Instruction::BuildTuple { - count: positional_count - } - ); + self.compile_expression(value)?; } } else { - // Use starunpack_helper to build a list, then convert to tuple + // CPython `codegen_call_helper_impl()` sends every other + // CALL_FUNCTION_EX positional shape through + // `starunpack_helper_impl(..., BUILD_LIST, LIST_APPEND, + // LIST_EXTEND, tuple=1)`, even when the only reason for the + // ex-call path is too many non-starred positional arguments. 
+ self.set_source_range(call_range); self.starunpack_helper( &arguments.args, additional_positional, - CollectionType::List, + CollectionType::Tuple, )?; - emit!( - self, - Instruction::CallIntrinsic1 { - func: IntrinsicFunction1::ListToTuple - } - ); } - self.compile_call_function_ex_keywords(&arguments.keywords)?; + self.compile_call_function_ex_keywords(&arguments.keywords, call_range)?; self.set_source_range(call_range); emit!(self, Instruction::CallFunctionEx); @@ -9536,8 +8928,10 @@ impl Compiler { fn compile_call_function_ex_keywords( &mut self, keywords: &[ast::Keyword], + call_range: TextRange, ) -> CompileResult<()> { if keywords.is_empty() { + self.set_source_range(call_range); emit!(self, Instruction::PushNull); return Ok(()); } @@ -9548,8 +8942,9 @@ impl Compiler { for (i, keyword) in keywords.iter().enumerate() { if keyword.arg.is_none() { if nseen > 0 { - self.codegen_subkwargs(keywords, i - nseen, i)?; + self.codegen_subkwargs(keywords, i - nseen, i, call_range)?; if have_dict { + self.set_source_range(call_range); emit!(self, Instruction::DictMerge { i: 1 }); } have_dict = true; @@ -9557,11 +8952,13 @@ impl Compiler { } if !have_dict { + self.set_source_range(call_range); emit!(self, Instruction::BuildMap { count: 0 }); have_dict = true; } - self.compile_expression_without_const_boolop_folding(&keyword.value)?; + self.compile_expression(&keyword.value)?; + self.set_source_range(call_range); emit!(self, Instruction::DictMerge { i: 1 }); } else { nseen += 1; @@ -9569,8 +8966,9 @@ impl Compiler { } if nseen > 0 { - self.codegen_subkwargs(keywords, keywords.len() - nseen, keywords.len())?; + self.codegen_subkwargs(keywords, keywords.len() - nseen, keywords.len(), call_range)?; if have_dict { + self.set_source_range(call_range); emit!(self, Instruction::DictMerge { i: 1 }); } have_dict = true; @@ -9592,6 +8990,173 @@ impl Compiler { }) } + fn compile_expression_with_generator_range( + &mut self, + expression: &ast::Expr, + range: TextRange, + ) -> 
CompileResult<()> { + if let ast::Expr::Generator(ast::ExprGenerator { + elt, generators, .. + }) = expression + { + self.set_source_range(range); + self.compile_generator_expression(elt, generators, range) + } else { + self.compile_expression(expression) + } + } + + fn cpython_implicit_call_generator_range(&self, expression: &ast::Expr) -> Option { + if !matches!(expression, ast::Expr::Generator(_)) { + return None; + } + let range = expression.range(); + let source = self.source_file.source_text().as_bytes(); + let start = range.start().to_usize(); + let end = range.end().to_usize(); + if source.get(start) == Some(&b'(') + && !Self::starts_with_parenthesized_generator_element(source, start, end) + { + return None; + } + + let mut open = start; + while open > 0 && source[open - 1].is_ascii_whitespace() { + open -= 1; + } + if open == 0 || source[open - 1] != b'(' { + return None; + } + + let mut close = end; + while close < source.len() && source[close].is_ascii_whitespace() { + close += 1; + } + if source.get(close) != Some(&b')') { + return None; + } + + let adjusted_start = u32::try_from(open - 1).ok()?; + let adjusted_end = u32::try_from(close + 1).ok()?; + Some(TextRange::new( + TextSize::from(adjusted_start), + TextSize::from(adjusted_end), + )) + } + + fn starts_with_parenthesized_generator_element( + source: &[u8], + start: usize, + end: usize, + ) -> bool { + let mut depth = 0usize; + let mut i = start; + while i < end { + match source[i] { + b'(' | b'[' | b'{' => depth += 1, + b')' | b']' | b'}' => { + if depth == 0 { + return false; + } + depth -= 1; + if depth == 0 { + return Self::next_token_is_for(source, i + 1, end); + } + } + b'\'' | b'"' => i = Self::skip_python_string_literal(source, i), + _ => {} + } + i += 1; + } + false + } + + fn skip_python_string_literal(source: &[u8], quote: usize) -> usize { + let quote_byte = source[quote]; + let triple = source.get(quote + 1) == Some("e_byte) + && source.get(quote + 2) == Some("e_byte); + let mut i = 
quote + if triple { 3 } else { 1 }; + while i < source.len() { + if source[i] == b'\\' { + i += 2; + continue; + } + if triple { + if source[i] == quote_byte + && source.get(i + 1) == Some("e_byte) + && source.get(i + 2) == Some("e_byte) + { + return i + 2; + } + } else if source[i] == quote_byte { + return i; + } + i += 1; + } + source.len().saturating_sub(1) + } + + fn next_token_is_for(source: &[u8], mut i: usize, end: usize) -> bool { + while i < end && source[i].is_ascii_whitespace() { + i += 1; + } + source.get(i..i + 3) == Some(b"for") + && source + .get(i + 3) + .is_none_or(|byte| !byte.is_ascii_alphanumeric() && *byte != b'_') + } + + fn compile_generator_expression( + &mut self, + elt: &ast::Expr, + generators: &[ast::Comprehension], + range: TextRange, + ) -> CompileResult<()> { + // Check if element or generators contain async content + // This makes the generator expression into an async generator + let element_contains_await = + Self::contains_await(elt) || Self::generators_contain_await(generators); + self.compile_comprehension( + "", + None, + generators, + &|compiler, _collection_add_i| { + // Compile the element expression + // Note: if element is an async comprehension, compile_expression + // already handles awaiting it, so we don't need to await again here + compiler.compile_comprehension_element(elt)?; + + compiler.mark_generator(); + if compiler.ctx.func == FunctionContext::AsyncFunction { + compiler.set_source_range(elt.range()); + emit!( + compiler, + Instruction::CallIntrinsic1 { + func: bytecode::IntrinsicFunction1::AsyncGenWrap + } + ); + } + // arg=0: direct yield (wrapped for async generators) + compiler.set_source_range(elt.range()); + emit!(compiler, Instruction::YieldValue { arg: 0 }); + emit!( + compiler, + Instruction::Resume { + context: oparg::ResumeContext::from(oparg::ResumeLocation::AfterYield) + } + ); + emit!(compiler, Instruction::PopTop); + + Ok(()) + }, + ComprehensionType::Generator, + element_contains_await, + range, + 
elt.range(), + elt.range(), + ) + } + fn consume_next_sub_table(&mut self) -> CompileResult<()> { { let _ = self.push_symbol_table()?; @@ -9723,7 +9288,11 @@ impl Compiler { self.enter_scope(obj_name, scope_type, key, lineno.to_u32())?; if let Some(info) = self.code_stack.last_mut() { - info.flags = flags | (info.flags & bytecode::CodeFlags::NESTED); + info.flags = flags + | (info.flags + & (bytecode::CodeFlags::NESTED + | bytecode::CodeFlags::METHOD + | bytecode::CodeFlags::FUTURE_ANNOTATIONS)); info.metadata.argcount = arg_count; info.metadata.posonlyargcount = posonlyarg_count; info.metadata.kwonlyargcount = kwonlyarg_count; @@ -9731,7 +9300,7 @@ impl Compiler { Ok(()) } - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments, reason = "ignore warning for now")] fn compile_comprehension( &mut self, name: &str, @@ -9741,6 +9310,8 @@ impl Compiler { comprehension_type: ComprehensionType, element_contains_await: bool, comprehension_range: TextRange, + element_range: TextRange, + outer_backedge_range: TextRange, ) -> CompileResult<()> { let prev_ctx = self.ctx; let has_an_async_gen = generators.iter().any(|g| g.is_async); @@ -9788,14 +9359,12 @@ impl Compiler { init_collection, generators, compile_element, - has_an_async_gen, - comprehension_range, + (comprehension_range, element_range, outer_backedge_range), ); } // Non-inlined path: create a new code object (generator expressions, etc.) self.ctx = CompileContext { - loop_data: None, in_class: prev_ctx.in_class, func: if is_async { FunctionContext::AsyncFunction @@ -9819,27 +9388,22 @@ impl Compiler { // scope itself. Peek past those nested scopes so we can enter the // correct comprehension table here, then let the real outermost // iterator compile consume its nested scopes later in parent scope. 
- self.push_output_with_symbol_table(comp_table, flags, 1, 1, 0, name)?; + self.push_output_with_symbol_table(comp_table, flags, 0, 1, 0, name)?; // Set qualname for comprehension self.set_qualname(); + self.set_source_range(comprehension_range); let arg0 = self.varname(".0"); let return_none = init_collection.is_none(); - // PEP 479: Wrap generator/coroutine body with StopIteration handler - let is_gen_scope = self.current_symbol_table().is_generator || is_async; - let stop_iteration_block = if is_gen_scope { + // CPython codegen_comprehension() wraps only generator expressions + // with codegen_wrap_in_stopiteration_handler(); unlike function bodies, + // it does not push a COMPILE_FBLOCK_STOP_ITERATION fblock. + let stop_iteration_block = if comprehension_type == ComprehensionType::Generator { let handler_block = self.new_block(); - emit!( - self, - PseudoInstruction::SetupCleanup { - delta: handler_block - } - ); - self.set_no_location(); - self.push_fblock(FBlockType::StopIteration, handler_block, handler_block)?; + self.insert_cpython_stopiteration_setup_cleanup(handler_block); Some(handler_block) } else { None @@ -9858,66 +9422,81 @@ impl Compiler { && let Some(singleton_iter) = Self::singleton_comprehension_assignment_iter(&generator.iter) { + // CPython allocates start/if_cleanup/anchor labels before the + // singleton sub-iterator fast path sets start = NO_LABEL. 
+ let _start_label = self.current_code_info().new_instr_sequence_label(); + let if_cleanup_block = self.new_block(); + let _anchor_label = self.current_code_info().new_instr_sequence_label(); self.compile_expression(singleton_iter)?; self.compile_store(&generator.target)?; if !generator.ifs.is_empty() { - let if_cleanup_block = self.new_block(); for if_condition in &generator.ifs { self.compile_jump_if(if_condition, false, if_cleanup_block)?; } - let body_block = self.new_block(); - self.switch_to_block(body_block); - loop_labels.push(ComprehensionLoopControl::IfCleanupOnly { if_cleanup_block }); } + loop_labels.push(ComprehensionLoopControl::IfCleanupOnly { if_cleanup_block }); continue; } let loop_block = self.new_block(); - let if_cleanup_block = self.new_block(); - let after_block = self.new_block(); + let (send_block, after_block, if_cleanup_block) = if generator.is_async { + let send_block = self.new_block(); + let after_block = self.new_block(); + let if_cleanup_block = self.new_block(); + (send_block, after_block, if_cleanup_block) + } else { + let if_cleanup_block = self.new_block(); + let after_block = self.new_block(); + (BlockIdx::NULL, after_block, if_cleanup_block) + }; if gen_index == 0 { // Load iterator onto stack (passed as first argument): emit!(self, Instruction::LoadFast { var_num: arg0 }); } else { // Evaluate iterated item: - self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; - - // Get iterator / turn item into an iterator - if generator.is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - } + self.compile_comprehension_iter(generator)?; } - self.switch_to_block(loop_block); + self.use_cpython_label_block(loop_block); let mut end_async_for_target = BlockIdx::NULL; if generator.is_async { - emit!(self, PseudoInstruction::SetupFinally { delta: after_block }); - emit!(self, Instruction::GetAnext); - self.push_fblock( + let loop_label = 
self.instr_sequence_label_for_block(loop_block); + self.push_fblock_labels( FBlockType::AsyncComprehensionGenerator, - loop_block, - after_block, + loop_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, )?; + emit!(self, PseudoInstruction::SetupFinally { delta: after_block }); + emit!(self, Instruction::GetAnext); self.emit_load_const(ConstantData::None); - end_async_for_target = self.compile_yield_from_sequence(true)?; + self.use_cpython_label_block(send_block); + let _ = self.compile_yield_from_sequence(true); + end_async_for_target = send_block; // POP_BLOCK before store: only __anext__/yield_from are // protected by SetupFinally targeting END_ASYNC_FOR. emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::AsyncComprehensionGenerator); self.compile_store(&generator.target)?; } else { + let saved_range = self.current_source_range; + self.set_source_range(generator.iter.range()); emit!(self, Instruction::ForIter { delta: after_block }); + self.set_source_range(saved_range); self.compile_store(&generator.target)?; } real_loop_depth += 1; + let backedge_range = if gen_index + 1 == generators.len() { + element_range + } else { + outer_backedge_range + }; loop_labels.push(ComprehensionLoopControl::Iteration { loop_block, if_cleanup_block, after_block, + backedge_range, is_async: generator.is_async, end_async_for_target, }); @@ -9927,10 +9506,6 @@ impl Compiler { for if_condition in &generator.ifs { self.compile_jump_if(if_condition, false, if_cleanup_block)?; } - if !generator.ifs.is_empty() { - let body_block = self.new_block(); - self.switch_to_block(body_block); - } } compile_element(self, real_loop_depth + 1)?; @@ -9941,43 +9516,47 @@ impl Compiler { loop_block, if_cleanup_block, after_block, + backedge_range, is_async, end_async_for_target, } => { + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); - self.switch_to_block(if_cleanup_block); + 
self.use_cpython_label_block(if_cleanup_block); + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); - self.switch_to_block(after_block); if is_async { + let loop_label = self.instr_sequence_label_for_block(loop_block); + self.pop_fblock_label(FBlockType::AsyncComprehensionGenerator, loop_label); + } + + self.use_cpython_label_block(after_block); + if is_async { + self.set_source_range(comprehension_range); // EndAsyncFor pops both the exception and the aiter // (handler depth is before GetANext, so aiter is at handler depth) self.emit_end_async_for(end_async_for_target); } else { - // END_FOR + POP_ITER pattern (CPython 3.14) - emit!(self, Instruction::EndFor); - emit!(self, Instruction::PopIter); + self.emit_sync_comprehension_end_for(); } } ComprehensionLoopControl::IfCleanupOnly { if_cleanup_block } => { - self.switch_to_block(if_cleanup_block); + self.use_cpython_label_block(if_cleanup_block); } } } if return_none { - self.emit_load_const(ConstantData::None) + self.emit_return_const_no_location(ConstantData::None); + } else { + self.emit_return_value(); } - self.emit_return_value(); - // Close StopIteration handler and emit handler code if let Some(handler_block) = stop_iteration_block { - emit!(self, PseudoInstruction::PopBlock); - self.set_no_location(); - self.pop_fblock(FBlockType::StopIteration); - self.switch_to_block(handler_block); + self.use_cpython_label_block(handler_block); emit!( self, Instruction::CallIntrinsic1 { @@ -9994,29 +9573,23 @@ impl Compiler { self.ctx = prev_ctx; // Create comprehension function with closure + self.set_source_range(comprehension_range); self.make_closure(code, bytecode::MakeFunctionFlags::new())?; - // Evaluate iterated item: - self.compile_for_iterable_expression(&outermost.iter, outermost.is_async)?; + // Evaluate iterated item and get its iterator. 
+ self.compile_comprehension_iter(outermost)?; self.symbol_table_stack .last_mut() .expect("no current symbol table") .next_sub_table += 1; - // Get iterator / turn item into an iterator - // Use is_async from the first generator, not has_an_async_gen which covers ALL generators - if outermost.is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - }; - // Call just created function: + self.set_source_range(comprehension_range); emit!(self, Instruction::Call { argc: 0 }); if is_async_list_set_dict_comprehension { emit!(self, Instruction::GetAwaitable { r#where: 0 }); self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(true)?; + let _ = self.compile_yield_from_sequence(true); } Ok(()) @@ -10030,9 +9603,9 @@ impl Compiler { init_collection: Option, generators: &[ast::Comprehension], compile_element: &dyn Fn(&mut Self, usize) -> CompileResult<()>, - has_async: bool, - comprehension_range: TextRange, + ranges: (TextRange, TextRange, TextRange), ) -> CompileResult<()> { + let (comprehension_range, element_range, outer_backedge_range) = ranges; fn collect_bound_names(target: &ast::Expr, out: &mut Vec) { match target { ast::Expr::Name(ast::ExprName { id, .. }) => out.push(id.to_string()), @@ -10053,10 +9626,7 @@ impl Compiler { // nested scopes (e.g. lambdas) whose sub_tables sit at the current // position in the parent's list. Those must be consumed before we // splice in the comprehension's own children. 
- self.compile_for_iterable_expression( - &generators[0].iter, - has_async && generators[0].is_async, - )?; + self.compile_comprehension_iter(&generators[0])?; self.symbol_table_stack .last_mut() .expect("no current symbol table") @@ -10086,12 +9656,6 @@ impl Compiler { current_table.sub_tables.insert(insert_pos + i, st.clone()); } } - if has_async && generators[0].is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - } - let mut source_order_bound_names = Vec::new(); for generator in generators { collect_bound_names(&generator.target, &mut source_order_bound_names); @@ -10205,25 +9769,27 @@ impl Compiler { ); } - // Step 4: Create the collection (list/set/dict) - if let Some(init_collection) = init_collection { - self._emit(init_collection, OpArg::new(0), BlockIdx::NULL); - // SWAP to get iterator on top - emit!(self, Instruction::Swap { i: 2 }); - } - - // Set up exception handler for cleanup on exception - let cleanup_block = self.new_block(); - let end_block = self.new_block(); - - if !pushed_locals.is_empty() { + // CPython's codegen_push_inlined_comprehension_locals() + // installs the virtual cleanup before codegen_comprehension() + // emits BUILD_LIST/BUILD_SET/BUILD_MAP for the result object. 
+ let cleanup_block = if !pushed_locals.is_empty() { + let cleanup_block = self.new_block(); emit!( self, PseudoInstruction::SetupFinally { delta: cleanup_block } ); - self.push_fblock(FBlockType::TryExcept, cleanup_block, end_block)?; + Some(cleanup_block) + } else { + None + }; + + // Step 4: Create the collection (list/set/dict) + if let Some(init_collection) = init_collection { + self._emit(init_collection, OpArg::new(0), BlockIdx::NULL); + // SWAP to get iterator on top + emit!(self, Instruction::Swap { i: 2 }); } // Step 5: Compile the comprehension loop(s) @@ -10235,50 +9801,57 @@ impl Compiler { && let Some(singleton_iter) = Self::singleton_comprehension_assignment_iter(&generator.iter) { + // CPython allocates start/if_cleanup/anchor labels before + // the singleton sub-iterator fast path sets start = NO_LABEL. + let _start_label = self.current_code_info().new_instr_sequence_label(); + let if_cleanup_block = self.new_block(); + let _anchor_label = self.current_code_info().new_instr_sequence_label(); self.compile_expression(singleton_iter)?; self.compile_store(&generator.target)?; if !generator.ifs.is_empty() { - let if_cleanup_block = self.new_block(); for if_condition in &generator.ifs { self.compile_jump_if(if_condition, false, if_cleanup_block)?; } - let body_block = self.new_block(); - self.switch_to_block(body_block); - loop_labels - .push(ComprehensionLoopControl::IfCleanupOnly { if_cleanup_block }); } + loop_labels.push(ComprehensionLoopControl::IfCleanupOnly { if_cleanup_block }); continue; } let loop_block = self.new_block(); - let if_cleanup_block = self.new_block(); - let after_block = self.new_block(); + let (send_block, after_block, if_cleanup_block) = if generator.is_async { + let send_block = self.new_block(); + let after_block = self.new_block(); + let if_cleanup_block = self.new_block(); + (send_block, after_block, if_cleanup_block) + } else { + let if_cleanup_block = self.new_block(); + let after_block = self.new_block(); + 
(BlockIdx::NULL, after_block, if_cleanup_block) + }; if i > 0 { - self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; - if generator.is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - } + self.compile_comprehension_iter(generator)?; } - self.switch_to_block(loop_block); + self.use_cpython_label_block(loop_block); let mut end_async_for_target = BlockIdx::NULL; if generator.is_async { - emit!(self, PseudoInstruction::SetupFinally { delta: after_block }); - emit!(self, Instruction::GetAnext); - self.push_fblock( + let loop_label = self.instr_sequence_label_for_block(loop_block); + self.push_fblock_labels( FBlockType::AsyncComprehensionGenerator, - loop_block, - after_block, + loop_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, )?; + emit!(self, PseudoInstruction::SetupFinally { delta: after_block }); + emit!(self, Instruction::GetAnext); self.emit_load_const(ConstantData::None); - end_async_for_target = self.compile_yield_from_sequence(true)?; + self.use_cpython_label_block(send_block); + let _ = self.compile_yield_from_sequence(true); + end_async_for_target = send_block; emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::AsyncComprehensionGenerator); self.compile_store(&generator.target)?; } else { let saved_range = self.current_source_range; @@ -10289,10 +9862,16 @@ impl Compiler { } real_loop_depth += 1; + let backedge_range = if i + 1 == generators.len() { + element_range + } else { + outer_backedge_range + }; loop_labels.push(ComprehensionLoopControl::Iteration { loop_block, if_cleanup_block, after_block, + backedge_range, is_async: generator.is_async, end_async_for_target, }); @@ -10314,50 +9893,63 @@ impl Compiler { loop_block, if_cleanup_block, after_block, + backedge_range, is_async, end_async_for_target, } => { + self.use_cpython_label_block(if_cleanup_block); + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: 
loop_block }); - self.switch_to_block(if_cleanup_block); - emit!(self, PseudoInstruction::Jump { delta: loop_block }); + if is_async { + let loop_label = self.instr_sequence_label_for_block(loop_block); + self.pop_fblock_label( + FBlockType::AsyncComprehensionGenerator, + loop_label, + ); + } - self.switch_to_block(after_block); + self.use_cpython_label_block(after_block); if is_async { + self.set_source_range(comprehension_range); self.emit_end_async_for(end_async_for_target); } else { - emit!(self, Instruction::EndFor); - emit!(self, Instruction::PopIter); + self.emit_sync_comprehension_end_for(); } } ComprehensionLoopControl::IfCleanupOnly { if_cleanup_block } => { - self.switch_to_block(if_cleanup_block); + self.use_cpython_label_block(if_cleanup_block); } } } // Step 8: Clean up - restore saved locals (and cell values) self.set_source_range(comprehension_range); - if total_stack_items > 0 { + if let Some(cleanup_block) = cleanup_block { emit!(self, PseudoInstruction::PopBlock); - self.pop_fblock(FBlockType::TryExcept); + self.set_no_location(); // Match CPython codegen_pop_inlined_comprehension_locals(): // the synthetic jump that skips the exception cleanup uses // JUMP_NO_INTERRUPT, which becomes JUMP_BACKWARD_NO_INTERRUPT // when the cleanup tail sits above the final restore block. 
+ let end_block = self.new_block(); emit!( self, PseudoInstruction::JumpNoInterrupt { delta: end_block } ); + self.set_no_location(); // Exception cleanup path - self.switch_to_block(cleanup_block); + self.use_cpython_label_block(cleanup_block); // Stack: [saved_values..., collection, exception] emit!(self, Instruction::Swap { i: 2 }); + self.set_no_location(); emit!(self, Instruction::PopTop); // Pop incomplete collection + self.set_no_location(); // Restore locals and cell values + self.set_source_range(comprehension_range); emit!( self, Instruction::Swap { @@ -10370,9 +9962,11 @@ impl Compiler { } // Re-raise the exception emit!(self, Instruction::Reraise { depth: 0 }); + self.set_no_location(); // Normal end path - self.switch_to_block(end_block); + self.use_cpython_label_block(end_block); + self.set_source_range(comprehension_range); } // SWAP result to TOS (above saved values) @@ -10432,77 +10026,175 @@ impl Compiler { } // Low level helper functions: + /// CPython `_PyCfgBuilder_Addop()`: start a new basic block if the current + /// one is terminated, then append the instruction to the current block. 
+ fn cpython_cfg_builder_addop(&mut self, info: ir::InstructionInfo) { + self.maybe_start_cpython_cfg_addop_block(); + self.push_emitted_instruction(info); + } + + fn push_emitted_instruction(&mut self, info: ir::InstructionInfo) { + self.current_code_info() + .addop_to_instr_sequence(info) + .expect("malformed instruction sequence emission"); + self.current_code_info() + .addop_to_current_block(info) + .expect("malformed CFG emission"); + } + + fn push_emitted_instruction_with_target_label( + &mut self, + info: ir::InstructionInfo, + target_label: ir::InstructionSequenceLabel, + ) { + self.current_code_info() + .addop_to_instr_sequence_with_target_label(info, target_label) + .expect("malformed instruction sequence emission"); + self.current_code_info() + .addop_to_current_block(info) + .expect("malformed CFG emission"); + } + + fn last_emitted_instruction_mut(&mut self) -> Option<&mut ir::InstructionInfo> { + self.current_code_info().last_current_block_instr_mut() + } + + fn set_last_emitted_lineno_override(&mut self, lineno_override: i32) { + self.current_code_info() + .set_last_instr_sequence_lineno_override(lineno_override); + if let Some(last) = self.last_emitted_instruction_mut() { + last.lineno_override = Some(lineno_override); + } + } + fn _emit>(&mut self, instr: I, arg: OpArg, target: BlockIdx) { if self.do_not_emit_bytecode > 0 { return; } + let instr = instr.into(); + let opcode = AnyOpcode::from(instr); + debug_assert!( + !instr.is_assembler(), + "CPython codegen_addop_* must not emit assembler-only opcodes" + ); + debug_assert!( + opcode.has_arg() || instr.has_target() || u32::from(arg) == 0, + "CPython _PyInstructionSequence_Addop requires either OPCODE_HAS_ARG, HAS_TARGET, or oparg == 0" + ); + debug_assert!( + target == BlockIdx::NULL || instr.has_target(), + "CPython codegen_addop_j only accepts HAS_TARGET opcodes" + ); let range = self.current_source_range; let source = self.source_file.to_source_code(); let location = 
source.source_location(range.start(), PositionEncoding::Utf8); let end_location = source.source_location(range.end(), PositionEncoding::Utf8); let except_handler = None; - self.current_block().instructions.push(ir::InstructionInfo { - instr: instr.into(), + self.cpython_cfg_builder_addop(ir::InstructionInfo { + instr, arg, target, location, end_location, except_handler, - folded_from_nonliteral_expr: false, lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, }); } - fn mark_last_instruction_folded_from_nonliteral_expr(&mut self) { - if let Some(info) = self.current_block().instructions.last_mut() { - info.folded_from_nonliteral_expr = true; - } - } - - fn preserve_last_redundant_jump_as_nop(&mut self) { - if let Some(info) = self.current_block().instructions.last_mut() { - info.preserve_redundant_jump_as_nop = true; - } - } - - fn preserve_last_redundant_nop(&mut self) { - if let Some(info) = self.current_block().instructions.last_mut() { - info.preserve_block_start_no_location_nop = true; - } - } - - fn remove_last_no_location_nop(&mut self) { - if let Some(info) = self.current_block().instructions.last_mut() { - info.remove_no_location_nop = true; - } - } - - fn force_remove_last_no_location_nop(&mut self) { - if let Some(info) = self.current_block().instructions.last_mut() { - info.remove_no_location_nop = true; - info.folded_operand_nop = true; + /// CPython `codegen_addop_j()`: emit a HAS_TARGET instruction with a + /// jump_target_label oparg. 
+ fn emit_jump_label>( + &mut self, + instr: I, + target_label: ir::InstructionSequenceLabel, + ) { + if self.do_not_emit_bytecode > 0 { + return; } + let instr = instr.into(); + debug_assert!( + instr.has_target(), + "CPython codegen_addop_j only accepts HAS_TARGET opcodes" + ); + debug_assert!( + !instr.is_assembler(), + "CPython codegen_addop_j must not emit assembler-only opcodes" + ); + let range = self.current_source_range; + let source = self.source_file.to_source_code(); + let location = source.source_location(range.start(), PositionEncoding::Utf8); + let end_location = source.source_location(range.end(), PositionEncoding::Utf8); + let target = self + .current_code_info() + .block_for_instr_sequence_label(target_label); + self.maybe_start_cpython_cfg_addop_block(); + self.push_emitted_instruction_with_target_label( + ir::InstructionInfo { + instr, + arg: OpArg::NULL, + target, + location, + end_location, + except_handler: None, + lineno_override: None, + }, + target_label, + ); } /// Mark the last emitted instruction as having no source location. /// Prevents it from triggering LINE events in sys.monitoring. fn set_no_location(&mut self) { - if let Some(last) = self.current_block().instructions.last_mut() { - last.lineno_override = Some(-1); - } + self.set_last_emitted_lineno_override(-1); } - fn mark_last_no_location_exit(&mut self) { - if let Some(last) = self.current_block().instructions.last_mut() { - last.no_location_exit = true; - } + /// CPython `codegen_sync_comprehension_generator()` emits END_FOR/POP_ITER + /// with NO_LOCATION and lets `flowgraph.c::propagate_line_numbers()` copy + /// the FOR_ITER location onto the cleanup block. 
+ fn emit_sync_comprehension_end_for(&mut self) { + emit!(self, Instruction::EndFor); + self.set_no_location(); + emit!(self, Instruction::PopIter); + self.set_no_location(); + } + + /// CPython `codegen_wrap_in_stopiteration_handler()` inserts + /// `SETUP_CLEANUP` into the instruction sequence at index 0 with + /// `_PyInstructionSequence_InsertInstruction()`. Generator and cell/free + /// prefixes are inserted later by `flowgraph.c::insert_prefix_instructions()`. + fn insert_cpython_stopiteration_setup_cleanup(&mut self, handler_block: BlockIdx) { + let code = self.current_code_info(); + let entry = code + .blocks + .first_mut() + .expect("code unit must have an entry block"); + debug_assert!( + entry + .used_instructions() + .first() + .is_some_and(|info| match info.instr.real() { + Some(Instruction::Resume { context }) => matches!( + context.get(info.arg).location(), + oparg::ResumeLocation::AtFuncStart + ), + _ => false, + }), + "scope entry must start with a function-start RESUME" + ); + debug_assert!( + !entry.used_instructions().iter().any(|info| matches!( + info.instr.real(), + Some( + Instruction::ReturnGenerator + | Instruction::MakeCell { .. } + | Instruction::CopyFreeVars { .. 
} + ) + )), + "CPython inserts StopIteration cleanup before CFG prefix instructions" + ); + + let result = code.insert_start_setup_cleanup(handler_block); + unwrap_internal(self, result); } fn emit_no_arg>(&mut self, ins: I) { @@ -10688,89 +10380,22 @@ impl Compiler { ) -> Option { let (left_int, left_is_bool) = Self::constant_as_fold_int(left)?; let (right_int, right_is_bool) = Self::constant_as_fold_int(right)?; - let zero = BigInt::from(0); - if !left_is_bool && !right_is_bool { + if !(left_is_bool && right_is_bool) { return None; } match op { - ast::Operator::Add => Some(ConstantData::Integer { - value: left_int + right_int, + ast::Operator::BitAnd => Some(ConstantData::Boolean { + value: !left_int.is_zero() & !right_int.is_zero(), }), - ast::Operator::Sub => Some(ConstantData::Integer { - value: left_int - right_int, + ast::Operator::BitOr => Some(ConstantData::Boolean { + value: !left_int.is_zero() | !right_int.is_zero(), }), - ast::Operator::Mult => Some(ConstantData::Integer { - value: left_int * right_int, + ast::Operator::BitXor => Some(ConstantData::Boolean { + value: !left_int.is_zero() ^ !right_int.is_zero(), }), - ast::Operator::Div => { - if right_int.is_zero() { - return None; - } - Some(ConstantData::Float { - value: left_int.to_f64()? 
/ right_int.to_f64()?, - }) - } - ast::Operator::FloorDiv => { - if right_int.is_zero() || left_int < zero || right_int < zero { - return None; - } - Some(ConstantData::Integer { - value: left_int / right_int, - }) - } - ast::Operator::Mod => { - if right_int.is_zero() || left_int < zero || right_int < zero { - return None; - } - Some(ConstantData::Integer { - value: left_int % right_int, - }) - } - ast::Operator::Pow => { - let exponent = right_int.to_u32()?; - if exponent > 128 { - return None; - } - Some(ConstantData::Integer { - value: left_int.pow(exponent), - }) - } - ast::Operator::BitAnd => { - if left_is_bool && right_is_bool { - Some(ConstantData::Boolean { - value: !left_int.is_zero() & !right_int.is_zero(), - }) - } else { - Some(ConstantData::Integer { - value: left_int & right_int, - }) - } - } - ast::Operator::BitOr => { - if left_is_bool && right_is_bool { - Some(ConstantData::Boolean { - value: !left_int.is_zero() | !right_int.is_zero(), - }) - } else { - Some(ConstantData::Integer { - value: left_int | right_int, - }) - } - } - ast::Operator::BitXor => { - if left_is_bool && right_is_bool { - Some(ConstantData::Boolean { - value: !left_int.is_zero() ^ !right_int.is_zero(), - }) - } else { - Some(ConstantData::Integer { - value: left_int ^ right_int, - }) - } - } - ast::Operator::MatMult | ast::Operator::LShift | ast::Operator::RShift => None, + _ => None, } } @@ -10909,6 +10534,7 @@ impl Compiler { (ast::UnaryOp::Invert, ConstantData::Integer { value }) => { ConstantData::Integer { value: !value } } + (ast::UnaryOp::Not, ConstantData::Tuple { .. 
}) => return Ok(None), (ast::UnaryOp::Not, value) => ConstantData::Boolean { value: !Self::constant_truthiness(&value), }, @@ -10958,6 +10584,152 @@ impl Compiler { })) } + fn try_compile_ast_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + Ok(Some(match expr { + ast::Expr::NumberLiteral(num) => match &num.value { + ast::Number::Int(int) => ConstantData::Integer { + value: ruff_int_to_bigint(int).map_err(|e| self.error(e))?, + }, + ast::Number::Float(value) => ConstantData::Float { value: *value }, + ast::Number::Complex { real, imag } => ConstantData::Complex { + value: Complex::new(*real, *imag), + }, + }, + ast::Expr::StringLiteral(s) => ConstantData::Str { + value: self.compile_string_value(s), + }, + ast::Expr::BytesLiteral(b) => ConstantData::Bytes { + value: b.value.bytes().collect(), + }, + ast::Expr::BooleanLiteral(b) => ConstantData::Boolean { value: b.value }, + ast::Expr::NoneLiteral(_) => ConstantData::None, + ast::Expr::EllipsisLiteral(_) => ConstantData::Ellipsis, + _ => return Ok(None), + })) + } + + fn try_negate_match_pattern_constant(constant: ConstantData) -> Option { + match constant { + ConstantData::Integer { value } => Some(ConstantData::Integer { value: -value }), + ConstantData::Float { value } => Some(ConstantData::Float { value: -value }), + ConstantData::Complex { value } => Some(ConstantData::Complex { value: -value }), + ConstantData::Boolean { value } => Some(ConstantData::Integer { + value: -BigInt::from(u8::from(value)), + }), + _ => None, + } + } + + fn constant_as_match_pattern_complex(constant: &ConstantData) -> Option> { + match constant { + ConstantData::Integer { value } => Some(Complex::new(value.to_f64()?, 0.0)), + ConstantData::Float { value } => Some(Complex::new(*value, 0.0)), + ConstantData::Complex { value } => Some(*value), + ConstantData::Boolean { value } => Some(Complex::new(f64::from(u8::from(*value)), 0.0)), + _ => None, + } + } + + fn try_fold_match_pattern_binop( + op: ast::Operator, + left: 
&ConstantData, + right: &ConstantData, + ) -> Option { + if let (ConstantData::Integer { value: left }, ConstantData::Integer { value: right }) = + (left, right) + { + return match op { + ast::Operator::Add => Some(ConstantData::Integer { + value: left + right, + }), + ast::Operator::Sub => Some(ConstantData::Integer { + value: left - right, + }), + _ => None, + }; + } + + let left_is_complex = matches!(left, ConstantData::Complex { .. }); + let right_is_complex = matches!(right, ConstantData::Complex { .. }); + if left_is_complex || right_is_complex { + let left = Self::constant_as_match_pattern_complex(left)?; + let right = Self::constant_as_match_pattern_complex(right)?; + let value = match op { + ast::Operator::Add => Complex::new(left.re + right.re, left.im + right.im), + ast::Operator::Sub => { + let imag = if !left_is_complex && right_is_complex { + -right.im + } else { + left.im - right.im + }; + Complex::new(left.re - right.re, imag) + } + _ => return None, + }; + return Some(ConstantData::Complex { value }); + } + + let left = Self::constant_as_match_pattern_complex(left)?; + let right = Self::constant_as_match_pattern_complex(right)?; + match op { + ast::Operator::Add => Some(ConstantData::Float { + value: left.re + right.re, + }), + ast::Operator::Sub => Some(ConstantData::Float { + value: left.re - right.re, + }), + _ => None, + } + } + + fn try_fold_match_pattern_const_expr( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + // CPython 3.14 ast_preprocess.c::fold_const_match_patterns() + // folds only the constant forms needed by match patterns before + // codegen_pattern_value()/codegen_pattern_mapping_key() visit them. + Ok(match expr { + ast::Expr::UnaryOp(ast::ExprUnaryOp { + op: ast::UnaryOp::USub, + operand, + .. + }) => { + let Some(constant) = self.try_compile_ast_constant(operand)? else { + return Ok(None); + }; + Self::try_negate_match_pattern_constant(constant) + } + ast::Expr::BinOp(ast::ExprBinOp { + left, op, right, .. 
+ }) if matches!(op, ast::Operator::Add | ast::Operator::Sub) => { + let Some(left) = (match self.try_fold_match_pattern_const_expr(left)? { + Some(constant) => Some(constant), + None => self.try_compile_ast_constant(left)?, + }) else { + return Ok(None); + }; + let Some(right) = self.try_compile_ast_constant(right)? else { + return Ok(None); + }; + Self::try_fold_match_pattern_binop(*op, &left, &right) + } + _ => None, + }) + } + + fn compile_match_pattern_expr(&mut self, expr: &ast::Expr) -> CompileResult<()> { + if let Some(constant) = self.try_fold_match_pattern_const_expr(expr)? { + self.emit_load_const(constant); + } else { + self.compile_expression(expr)?; + } + Ok(()) + } + fn emit_load_const(&mut self, constant: ConstantData) { let idx = self.arg_constant(constant); self.emit_arg(idx, |consti| Instruction::LoadConst { consti }) @@ -11018,10 +10790,8 @@ impl Compiler { fn emit_return_const_no_location(&mut self, constant: ConstantData) { self.emit_load_const(constant); self.set_no_location(); - self.mark_last_no_location_exit(); emit!(self, Instruction::ReturnValue); self.set_no_location(); - self.mark_last_no_location_exit(); } fn emit_end_async_for(&mut self, send_target: BlockIdx) { @@ -11134,232 +10904,184 @@ impl Compiler { return Ok(()); } - // unwind_fblock_stack - // We need to unwind fblocks and compile cleanup code. For FinallyTry blocks, - // we need to compile the finally body inline, but we must temporarily pop - // the fblock so that nested break/continue in the finally body don't see it. 
- - // First, find the loop - let code = self.current_code_info(); - let mut loop_idx = None; - let mut is_for_loop = false; - - for i in (0..code.fblock.len()).rev() { - match code.fblock[i].fb_type { - FBlockType::WhileLoop => { - loop_idx = Some(i); - is_for_loop = false; - break; - } - FBlockType::ForLoop => { - loop_idx = Some(i); - is_for_loop = true; - break; - } - FBlockType::ExceptionGroupHandler => { - return Err( - self.error_ranged(CodegenErrorType::BreakContinueReturnInExceptStar, range) - ); - } - _ => {} - } - } + let prev_source_range = self.current_source_range; + self.set_source_range(range); + emit!(self, Instruction::Nop); - let Some(loop_idx) = loop_idx else { + let (mut unwind_loc, loop_fblock) = self.unwind_fblock_stack_with_loop(false, true)?; + let Some(loop_fblock) = loop_fblock else { + self.set_source_range(prev_source_range); if is_break { return Err(self.error_ranged(CodegenErrorType::InvalidBreak, range)); } return Err(self.error_ranged(CodegenErrorType::InvalidContinue, range)); }; - let loop_block = code.fblock[loop_idx].fb_block; - let exit_block = code.fblock[loop_idx].fb_exit; - - let prev_source_range = self.current_source_range; - self.set_source_range(range); - emit!(self, Instruction::Nop); - self.set_source_range(prev_source_range); + // CPython `codegen_break()` unwinds the loop fblock itself after + // `codegen_unwind_fblock_stack()` returns it. `codegen_continue()` + // jumps directly to the loop body label. 
+ if is_break { + self.unwind_fblock(&loop_fblock, false, &mut unwind_loc)?; + } - // Collect the fblocks we need to unwind through, from top down to (but not including) the loop - #[derive(Clone)] - enum UnwindAction { - With { - is_async: bool, - range: TextRange, - }, - HandlerCleanup { - name: Option, - }, - TryExcept, - FinallyTry { - body: Vec, - fblock_idx: usize, + // Jump to target + let target_label = if is_break { + debug_assert!(loop_fblock.fb_exit.is_jump_target_label()); + loop_fblock.fb_exit + } else { + debug_assert!(loop_fblock.fb_block.is_jump_target_label()); + loop_fblock.fb_block + }; + if let Some(loc) = unwind_loc { + self.set_source_range(loc); + } else { + self.set_source_range(range); + }; + self.emit_jump_label( + PseudoInstruction::Jump { + delta: OpArgMarker::marker(), }, - FinallyEnd, - PopValue, // Pop return value when continue/break cancels a return + target_label, + ); + if unwind_loc.is_none() { + self.set_no_location(); } - let mut unwind_actions = Vec::new(); + self.set_source_range(prev_source_range); + Ok(()) + } + + /// CPython `_PyCfgBuilder_Addop()` calls + /// `cfg_builder_maybe_start_new_block()` before appending each instruction. + /// That keeps any `IS_TERMINATOR_OPCODE` as the final instruction in its + /// basicblock before `flowgraph.c::check_cfg()`. 
+ fn maybe_start_cpython_cfg_addop_block(&mut self) { + let code = self.current_code_info(); + let cur = code.current_block; + if !code.blocks[cur.idx()] + .instructions + .last() + .is_some_and(|instr| instr.instr.is_terminator()) { - let code = self.current_code_info(); - for i in (loop_idx + 1..code.fblock.len()).rev() { - match code.fblock[i].fb_type { - FBlockType::With => { - unwind_actions.push(UnwindAction::With { - is_async: false, - range: code.fblock[i].fb_range, - }); - } - FBlockType::AsyncWith => { - unwind_actions.push(UnwindAction::With { - is_async: true, - range: code.fblock[i].fb_range, - }); - } - FBlockType::HandlerCleanup => { - let name = match &code.fblock[i].fb_datum { - FBlockDatum::ExceptionName(name) => Some(name.clone()), - _ => None, - }; - unwind_actions.push(UnwindAction::HandlerCleanup { name }); - } - FBlockType::TryExcept => { - unwind_actions.push(UnwindAction::TryExcept); - } - FBlockType::FinallyTry => { - // Need to execute finally body before break/continue - if let FBlockDatum::FinallyBody(ref body) = code.fblock[i].fb_datum { - unwind_actions.push(UnwindAction::FinallyTry { - body: body.clone(), - fblock_idx: i, - }); - } - } - FBlockType::FinallyEnd => { - // Inside finally block reached via exception - need to pop exception - unwind_actions.push(UnwindAction::FinallyEnd); - } - FBlockType::PopValue => { - // Pop the return value that was saved on stack - unwind_actions.push(UnwindAction::PopValue); - } - _ => {} - } - } + return; } - // Emit cleanup for each fblock - let mut jump_no_location = false; - for action in unwind_actions { - match action { - UnwindAction::With { is_async, range } => { - // Stack: [..., exit_func, self_exit] - let saved_range = self.current_source_range; - self.set_source_range(range); - emit!(self, PseudoInstruction::PopBlock); - self.emit_load_const(ConstantData::None); - self.emit_load_const(ConstantData::None); - self.emit_load_const(ConstantData::None); - emit!(self, Instruction::Call { 
argc: 3 }); - - if is_async { - emit!(self, Instruction::GetAwaitable { r#where: 2 }); - self.emit_load_const(ConstantData::None); - let _ = self.compile_yield_from_sequence(true)?; - } - - emit!(self, Instruction::PopTop); - self.set_source_range(saved_range); - jump_no_location = true; - } - UnwindAction::HandlerCleanup { ref name } => { - // codegen_unwind_fblock(HANDLER_CLEANUP) - if name.is_some() { - // Named handler: PopBlock for inner SETUP_CLEANUP - emit!(self, PseudoInstruction::PopBlock); - } - // PopBlock for outer SETUP_CLEANUP (ExceptionHandler) - emit!(self, PseudoInstruction::PopBlock); - emit!(self, Instruction::PopExcept); - if let Some(name) = name { - self.emit_load_const(ConstantData::None); - self.store_name(name)?; - self.compile_name(name, NameUsage::Delete)?; - } - } - UnwindAction::TryExcept => { - // codegen_unwind_fblock(TRY_EXCEPT) - emit!(self, PseudoInstruction::PopBlock); - } - UnwindAction::FinallyTry { body, fblock_idx } => { - // codegen_unwind_fblock(FINALLY_TRY) - emit!(self, PseudoInstruction::PopBlock); - - // compile finally body inline - // Temporarily pop the FinallyTry fblock so nested break/continue - // in the finally body won't see it again. - let code = self.current_code_info(); - let saved_fblock = code.fblock.remove(fblock_idx); + debug_assert_eq!(code.blocks[cur.idx()].next, BlockIdx::NULL); + let block = self.cpython_cfg_builder_new_block(); + self.cpython_cfg_builder_use_next_block(block); + } - self.compile_statements(&body)?; + /// CPython `codegen_funcbody()` emits `NEW_JUMP_TARGET_LABEL(start)` and + /// `USE_LABEL(start)` after scope entry. That label is part of the + /// instruction-sequence label map, but it does not become a CFG + /// `basicblock.b_label` unless some emitted instruction targets it. 
+ fn use_cpython_function_start_label(&mut self) -> ir::InstructionSequenceLabel { + let (label, result) = { + let code = self.current_code_info(); + let label = code.new_instr_sequence_label(); + let result = code.use_raw_instr_sequence_label(label); + (label, result) + }; + unwrap_internal(self, result); + label + } - // Restore the fblock (though this break/continue will jump away, - // this keeps the fblock stack consistent for error checking) - let code = self.current_code_info(); - code.fblock.insert(fblock_idx, saved_fblock); - jump_no_location = true; - } - UnwindAction::FinallyEnd => { - // codegen_unwind_fblock(FINALLY_END) - emit!(self, Instruction::PopTop); // exc_value - emit!(self, PseudoInstruction::PopBlock); - emit!(self, Instruction::PopExcept); - } - UnwindAction::PopValue => { - // Pop the return value - continue/break cancels the pending return - emit!(self, Instruction::PopTop); - } - } - } + fn instr_sequence_label_for_block(&mut self, block: BlockIdx) -> ir::InstructionSequenceLabel { + let result = self + .current_code_info() + .instr_sequence_label_for_block(block); + unwrap_internal(self, result) + } - // CPython unwinds a for-loop break with POP_TOP rather than POP_ITER. - if is_break && is_for_loop { - emit!(self, Instruction::PopTop); + /// Switch to a block as CPython instruction-sequence labels would resolve. + /// + /// Consecutive `USE_LABEL()` calls can map multiple labels to the same + /// instruction offset before `_PyCfg_FromInstructionSequence()` builds a + /// CFG. Reuse an empty current block even if it already carries a label so + /// codegen CFG path preserves that aliasing. 
+ fn use_cpython_label_block(&mut self, block: BlockIdx) { + let result = self.current_code_info().use_instr_sequence_label(block); + unwrap_internal(self, result); + let code = self.current_code_info(); + let block = code.resolve_instr_sequence_label(block); + let cur = code.current_block; + let can_reuse_current = cur != block + && code.blocks[cur.idx()].is_empty() + && code.blocks[cur.idx()].next == BlockIdx::NULL + && code.blocks[block.idx()].is_empty() + && code.blocks[block.idx()].next == BlockIdx::NULL; + + if !can_reuse_current { + let result = code.mark_cpython_cfg_label(block); + unwrap_internal(self, result); + self.switch_to_block(block); + return; } - // Jump to target - let target = if is_break { exit_block } else { loop_block }; - let saved_range = self.current_source_range; - self.set_source_range(range); - emit!(self, PseudoInstruction::Jump { delta: target }); - if jump_no_location { - self.set_no_location(); + { + let target = &code.blocks[block.idx()]; + debug_assert!(!target.except_handler); + debug_assert!(!target.preserve_lasti); + debug_assert_eq!(target.start_depth, ir::START_DEPTH_UNSET); + debug_assert!(!target.cold); } - self.set_source_range(saved_range); + let result = code.mark_cpython_cfg_label(cur); + unwrap_internal(self, result); - Ok(()) + let result = self + .current_code_info() + .use_instr_sequence_label_at_block(block, cur); + unwrap_internal(self, result); } - fn current_block(&mut self) -> &mut ir::Block { - let info = self.current_code_info(); - &mut info.blocks[info.current_block] + fn new_block(&mut self) -> BlockIdx { + let result = self + .current_code_info() + .blocks + .try_reserve(1) + .map_err(|_| InternalError::MalformedControlFlowGraph); + unwrap_internal(self, result); + let code = self.current_code_info(); + let idx = BlockIdx::new(code.blocks.len().to_u32()); + code.blocks.push(ir::Block::default()); + let result = code.push_unmapped_instr_sequence_label(); + unwrap_internal(self, result); + idx } - fn 
new_block(&mut self) -> BlockIdx { + fn new_unlabeled_block(&mut self) -> BlockIdx { + let result = self + .current_code_info() + .blocks + .try_reserve(1) + .map_err(|_| InternalError::MalformedControlFlowGraph); + unwrap_internal(self, result); let code = self.current_code_info(); let idx = BlockIdx::new(code.blocks.len().to_u32()); - let inherited_disable_load_fast_borrow = - code.blocks[code.current_block].disable_load_fast_borrow; - let block = ir::Block { - disable_load_fast_borrow: inherited_disable_load_fast_borrow, - ..ir::Block::default() - }; - code.blocks.push(block); + code.blocks.push(ir::Block::default()); + let result = code.push_unlabeled_instr_sequence_block(); + unwrap_internal(self, result); idx } + /// flowgraph.c cfg_builder_new_block + fn cpython_cfg_builder_new_block(&mut self) -> BlockIdx { + self.new_unlabeled_block() + } + + /// flowgraph.c cfg_builder_use_next_block + fn cpython_cfg_builder_use_next_block(&mut self, block: BlockIdx) { + let code = self.current_code_info(); + let cur = code.current_block; + code.blocks[cur.idx()].next = block; + code.current_block = block; + } + fn switch_to_block(&mut self, block: BlockIdx) { + let result = self.current_code_info().use_instr_sequence_label(block); + unwrap_internal(self, result); let code = self.current_code_info(); + let block = code.resolve_instr_sequence_label(block); let prev = code.current_block; assert_ne!(prev, block, "recursive switching {prev:?} -> {block:?}"); assert_eq!( @@ -11381,6 +11103,65 @@ impl Compiler { self.current_source_range = range; } + fn decorated_definition_range( + &self, + statement_range: TextRange, + decorator_list: &[ast::Decorator], + keyword: &str, + ) -> TextRange { + let Some(last_decorator) = decorator_list.last() else { + return statement_range; + }; + let search_start = last_decorator.expression.range().end(); + if search_start >= statement_range.end() { + return statement_range; + } + let search_range = TextRange::new(search_start, 
statement_range.end()); + let source = self.source_file.slice(search_range); + let Some(keyword_offset) = source.find(keyword) else { + return statement_range; + }; + let Ok(keyword_offset) = u32::try_from(keyword_offset) else { + return statement_range; + }; + TextRange::new( + search_start + TextSize::new(keyword_offset), + statement_range.end(), + ) + } + + fn update_start_location_to_match_attr( + &self, + loc_range: TextRange, + attr_range: TextRange, + attr: &str, + ) -> TextRange { + let source = self.source_file.to_source_code(); + if source.line_index(loc_range.start()) == source.line_index(attr_range.end()) { + return loc_range; + } + let Ok(attr_len) = u32::try_from(attr.len()) else { + return TextRange::new(loc_range.start(), loc_range.end()); + }; + let attr_len = TextSize::new(attr_len); + if attr_len > attr_range.len() { + return TextRange::new(loc_range.start(), loc_range.end()); + } + TextRange::new(attr_range.end() - attr_len, loc_range.end()) + } + + fn source_line_start_range(&self, lineno: u32) -> TextRange { + let source = self.source_file.to_source_code(); + let line = OneIndexed::new(lineno as usize).unwrap_or(OneIndexed::MIN); + let start = source.line_start(line); + TextRange::new(start, start) + } + + fn module_start_location(&self, body: &[ast::Stmt]) -> TextRange { + body.first() + .map_or_else(|| self.source_line_start_range(1), Ranged::range) + } + fn get_source_line_number(&mut self) -> OneIndexed { self.source_file .to_source_code() @@ -11388,7 +11169,14 @@ impl Compiler { } fn mark_generator(&mut self) { - self.current_code_info().flags |= bytecode::CodeFlags::GENERATOR + let is_async = self.ctx.func == FunctionContext::AsyncFunction; + let flags = &mut self.current_code_info().flags; + if is_async { + flags.remove(bytecode::CodeFlags::COROUTINE); + flags.insert(bytecode::CodeFlags::ASYNC_GENERATOR); + } else { + flags.insert(bytecode::CodeFlags::GENERATOR); + } } /// Whether the expression contains an await expression and @@ 
-11460,18 +11248,25 @@ impl Compiler { let mut element_count = 0; let mut pending_literal = None; + let mut pending_literal_range = None; let mut pending_literal_no_location = false; for part in fstring { self.compile_fstring_part_into( part, &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, false, )?; } - self.set_source_range(fstring_range); - self.finish_fstring(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + Some(fstring_range), + ); Ok(()) } @@ -11485,17 +11280,24 @@ impl Compiler { let mut element_count = 0; let mut pending_literal = None; + let mut pending_literal_range = None; let mut pending_literal_no_location = false; for part in fstring { self.compile_fstring_part_into( part, &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, true, )?; } - self.finish_fstring_join(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring_join( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + ); Ok(()) } @@ -11503,6 +11305,7 @@ impl Compiler { &mut self, part: &ast::FStringPart, pending_literal: &mut Option, + pending_literal_range: &mut Option, pending_literal_no_location: &mut bool, element_count: &mut u32, append_to_join_list: bool, @@ -11511,10 +11314,11 @@ impl Compiler { ast::FStringPart::Literal(string) => { let value = self.compile_fstring_part_literal_value(string); if pending_literal.is_none() { - self.set_source_range(string.range); + *pending_literal_range = Some(string.range); *pending_literal_no_location = string.range == TextRange::default(); *pending_literal = Some(value); } else if let Some(pending) = pending_literal.as_mut() { + Self::extend_pending_literal_range(pending_literal_range, string.range); *pending_literal_no_location &= string.range 
== TextRange::default(); pending.push_wtf8(value.as_ref()); } @@ -11524,7 +11328,7 @@ impl Compiler { fstring.flags, &fstring.elements, pending_literal, - pending_literal_no_location, + (pending_literal_range, pending_literal_no_location), element_count, append_to_join_list, ), @@ -11534,12 +11338,15 @@ impl Compiler { fn finish_fstring( &mut self, mut pending_literal: Option, + mut pending_literal_range: Option, mut pending_literal_no_location: bool, mut element_count: u32, + fstring_range: Option, ) { let keep_empty = element_count == 0; self.emit_pending_fstring_literal( &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, keep_empty, @@ -11547,10 +11354,16 @@ impl Compiler { ); if element_count == 0 { + if let Some(fstring_range) = fstring_range { + self.set_source_range(fstring_range); + } self.emit_load_const(ConstantData::Str { value: Wtf8Buf::new(), }); } else if element_count > 1 { + if let Some(fstring_range) = fstring_range { + self.set_source_range(fstring_range); + } emit!( self, Instruction::BuildString { @@ -11563,12 +11376,14 @@ impl Compiler { fn finish_fstring_join( &mut self, mut pending_literal: Option, + mut pending_literal_range: Option, mut pending_literal_no_location: bool, mut element_count: u32, ) { let keep_empty = element_count == 0; self.emit_pending_fstring_literal( &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, keep_empty, @@ -11580,6 +11395,7 @@ impl Compiler { fn emit_pending_fstring_literal( &mut self, pending_literal: &mut Option, + pending_literal_range: &mut Option, pending_literal_no_location: &mut bool, element_count: &mut u32, keep_empty: bool, @@ -11588,6 +11404,7 @@ impl Compiler { let Some(value) = pending_literal.take() else { return; }; + let range = pending_literal_range.take(); let no_location = *pending_literal_no_location; *pending_literal_no_location = false; @@ -11598,6 +11415,9 @@ impl Compiler { return; 
} + if let Some(range) = range { + self.set_source_range(range); + } self.emit_load_const(ConstantData::Str { value }); if no_location { self.set_no_location(); @@ -11608,6 +11428,18 @@ impl Compiler { } } + fn extend_pending_literal_range(pending: &mut Option, range: TextRange) { + let Some(existing) = pending else { + *pending = Some(range); + return; + }; + if *existing == TextRange::default() { + *existing = range; + } else if range != TextRange::default() { + *existing = TextRange::new(existing.start(), range.end()); + } + } + fn count_fstring_parts(&self, fstring: &[ast::FStringPart]) -> u32 { let mut element_count = 0; let mut pending_literal = None; @@ -11663,6 +11495,7 @@ impl Compiler { &mut self, flags: ast::FStringFlags, fstring_elements: &ast::InterpolatedStringElements, + fstring_range: Option, ) -> CompileResult<()> { if self.count_fstring_elements(flags, fstring_elements) > STACK_USE_GUIDELINE { return self.compile_fstring_elements_joined(flags, fstring_elements); @@ -11670,16 +11503,23 @@ impl Compiler { let mut element_count = 0; let mut pending_literal: Option = None; + let mut pending_literal_range: Option = None; let mut pending_literal_no_location = false; self.compile_fstring_elements_into( flags, fstring_elements, &mut pending_literal, - &mut pending_literal_no_location, + (&mut pending_literal_range, &mut pending_literal_no_location), &mut element_count, false, )?; - self.finish_fstring(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + fstring_range, + ); Ok(()) } @@ -11697,37 +11537,58 @@ impl Compiler { let mut element_count = 0; let mut pending_literal: Option = None; + let mut pending_literal_range: Option = None; let mut pending_literal_no_location = false; self.compile_fstring_elements_into( flags, fstring_elements, &mut pending_literal, - &mut pending_literal_no_location, + (&mut pending_literal_range, &mut 
pending_literal_no_location), &mut element_count, true, )?; - self.finish_fstring_join(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring_join( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + ); Ok(()) } + fn cpython_format_spec_range(&self, range: TextRange) -> TextRange { + let start = range.start().to_usize(); + if start == 0 { + return range; + } + let source = self.source_file.source_text().as_bytes(); + if source.get(start - 1) == Some(&b':') { + TextRange::new(range.start() - TextSize::new(1), range.end()) + } else { + range + } + } + fn compile_fstring_elements_into( &mut self, flags: ast::FStringFlags, fstring_elements: &ast::InterpolatedStringElements, pending_literal: &mut Option, - pending_literal_no_location: &mut bool, + pending_literal_meta: (&mut Option, &mut bool), element_count: &mut u32, append_to_join_list: bool, ) -> CompileResult<()> { + let (pending_literal_range, pending_literal_no_location) = pending_literal_meta; for element in fstring_elements { match element { ast::InterpolatedStringElement::Literal(string) => { let value = self.compile_fstring_literal_value(string, flags); if pending_literal.is_none() { - self.set_source_range(string.range); + *pending_literal_range = Some(string.range); *pending_literal_no_location = string.range == TextRange::default(); *pending_literal = Some(value); } else if let Some(pending) = pending_literal.as_mut() { + Self::extend_pending_literal_range(pending_literal_range, string.range); *pending_literal_no_location &= string.range == TextRange::default(); pending.push_wtf8(value.as_ref()); } @@ -11742,18 +11603,33 @@ impl Compiler { if let Some(ast::DebugText { leading, trailing }) = &fstring_expr.debug_text { let range = fstring_expr.expression.range(); + let leading = strip_fstring_debug_comments(leading); + let trailing = strip_fstring_debug_comments(trailing); let source = self.source_file.slice(range); - let text = [ - 
strip_fstring_debug_comments(leading).as_str(), - source, - strip_fstring_debug_comments(trailing).as_str(), - ] - .concat(); + let text = [leading.as_str(), source, trailing.as_str()].concat(); + let debug_text_range = TextRange::new( + range.start() + - TextSize::new( + u32::try_from(leading.len()) + .expect("debug f-string leading text too long"), + ), + range.end() + + TextSize::new( + u32::try_from(trailing.len()) + .expect("debug f-string trailing text too long"), + ), + ); let text: Wtf8Buf = text.into(); if pending_literal.is_none() { + *pending_literal_range = Some(debug_text_range); *pending_literal_no_location = false; *pending_literal = Some(Wtf8Buf::new()); + } else { + Self::extend_pending_literal_range( + pending_literal_range, + debug_text_range, + ); } pending_literal.as_mut().unwrap().push_wtf8(text.as_ref()); @@ -11769,6 +11645,7 @@ impl Compiler { self.emit_pending_fstring_literal( pending_literal, + pending_literal_range, pending_literal_no_location, element_count, false, @@ -11777,22 +11654,32 @@ impl Compiler { self.compile_expression(&fstring_expr.expression)?; + let formatted_value_range = fstring_expr.range; match conversion { ConvertValueOparg::None => {} ConvertValueOparg::Str | ConvertValueOparg::Repr | ConvertValueOparg::Ascii => { + self.set_source_range(formatted_value_range); emit!(self, Instruction::ConvertValue { oparg: conversion }) } } match &fstring_expr.format_spec { Some(format_spec) => { - self.compile_fstring_elements(flags, &format_spec.elements)?; - + let format_spec_range = + self.cpython_format_spec_range(format_spec.range); + self.compile_fstring_elements( + flags, + &format_spec.elements, + Some(format_spec_range), + )?; + + self.set_source_range(formatted_value_range); emit!(self, Instruction::FormatWithSpec); } None => { + self.set_source_range(formatted_value_range); emit!(self, Instruction::FormatSimple); } } @@ -11984,7 +11871,11 @@ impl Compiler { let has_format_spec = interp.format_spec.is_some(); if let 
Some(format_spec) = &interp.format_spec { - self.compile_fstring_elements(ast::FStringFlags::empty(), &format_spec.elements)?; + self.compile_fstring_elements( + ast::FStringFlags::empty(), + &format_spec.elements, + Some(format_spec.range), + )?; } // CPython keeps bit 1 set in BUILD_INTERPOLATION's oparg and uses @@ -12090,17 +11981,20 @@ fn expandtabs(input: &str, tab_size: usize) -> String { expanded_str } -fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, &'a [ast::Stmt]) { +fn split_doc_with_range<'a>( + body: &'a [ast::Stmt], + opts: &CompileOpts, +) -> (Option<(String, TextRange)>, &'a [ast::Stmt]) { if let Some((ast::Stmt::Expr(expr), body_rest)) = body.split_first() { let doc_comment = match &*expr.value { - ast::Expr::StringLiteral(value) => Some(&value.value), + ast::Expr::StringLiteral(value) => Some((&value.value, expr.value.range())), // f-strings are not allowed in Python doc comments. ast::Expr::FString(_) => None, _ => None, }; - if let Some(doc) = doc_comment { + if let Some((doc, range)) = doc_comment { return if opts.optimize < 2 { - (Some(clean_doc(doc.to_str())), body_rest) + (Some((clean_doc(doc.to_str()), range)), body_rest) } else { (None, body_rest) }; @@ -12109,6 +12003,12 @@ fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, (None, body) } +#[cfg(test)] +fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, &'a [ast::Stmt]) { + let (doc, body) = split_doc_with_range(body, opts); + (doc.map(|(doc, _)| doc), body) +} + pub fn ruff_int_to_bigint(int: &ast::Int) -> Result { if let Some(small) = int.as_u64() { Ok(BigInt::from(small)) @@ -12187,7 +12087,7 @@ mod ruff_tests { /// Test if the compiler can correctly identify fstrings containing an `await` expression. 
#[test] - fn test_fstring_contains_await() { + fn fstring_contains_await() { let range = TextRange::default(); let flags = ast::FStringFlags::empty(); @@ -12383,6 +12283,46 @@ mod tests { compiler.exit_scope() } + #[test] + fn empty_module_implicit_return_inherits_resume_location_like_cpython() { + let code = compile_exec(""); + // CPython 3.14 codegen emits the implicit LOAD_CONST/RETURN_VALUE with + // NO_LOCATION, then flowgraph.c::propagate_line_numbers() propagates + // the module RESUME location, whose line is 0. + assert_eq!(code.linetable.as_ref(), &[0xf2, 0x03, 0x01, 0x01, 0x01]); + } + + #[test] + fn redundant_nop_location_copies_full_location_like_cpython() { + let code = compile_exec( + "\ +def f(x, y, z): + while x: + if y: + pass + elif z: + if y < 0: + return y + if z: + y = y + 1 + elif y: + return 1 + return -1 +", + ); + let f = find_code(&code, "f").expect("missing function code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdf, 0x0a, 0x0b, 0xdf, 0x0b, 0x0c, 0xd9, 0x0c, 0x10, 0xdf, 0x0d, 0x0e, + 0xd8, 0x0f, 0x10, 0x90, 0x31, 0x8c, 0x75, 0xd8, 0x17, 0x18, 0x90, 0x08, 0xdf, 0x0f, + 0x10, 0xd8, 0x14, 0x15, 0x98, 0x01, 0x95, 0x45, 0x92, 0x01, 0xf1, 0x03, 0x00, 0x10, + 0x11, 0xe7, 0x0d, 0x0e, 0x89, 0x51, 0xd9, 0x13, 0x14, 0xd8, 0x0b, 0x0d, 0x80, 0x49, + ], + "CPython basicblock_remove_redundant_nops() copies the full NOP location into a following no-location jump" + ); + } + fn scan_program_symbol_table(source: &str) -> SymbolTable { let source_file = SourceFileBuilder::new("source_path", source).finish(); let parsed = ruff_python_parser::parse( @@ -12400,6 +12340,16 @@ mod tests { .unwrap() } + fn find_symbol_table<'a>(table: &'a SymbolTable, name: &str) -> Option<&'a SymbolTable> { + if table.name == name { + return Some(table); + } + table + .sub_tables + .iter() + .find_map(|sub_table| find_symbol_table(sub_table, name)) + } + fn compile_exec_late_cfg_trace(source: &str) -> Vec<(String, String)> { let opts = CompileOpts::default(); 
let source_file = SourceFileBuilder::new("source_path", source).finish(); @@ -12439,7 +12389,7 @@ mod tests { ruff_python_ast::Mod::Module(stmts) => stmts, _ => unreachable!(), }; - let symbol_table = SymbolTable::scan_program(&ast, source_file.clone()) + let mut symbol_table = SymbolTable::scan_program(&ast, source_file.clone()) .map_err(|e| e.into_codegen_error(source_file.name().to_owned())) .unwrap(); let function = ast @@ -12450,6 +12400,11 @@ mod tests { _ => None, }) .unwrap_or_else(|| panic!("missing function {function_name}")); + symbol_table.next_sub_table = symbol_table + .sub_tables + .iter() + .position(|table| table.name == function_name) + .unwrap_or_else(|| panic!("missing symbol table for {function_name}")); let name = &function.name; let parameters = &function.parameters; @@ -12469,7 +12424,6 @@ mod tests { let prev_ctx = compiler.ctx; compiler.ctx = CompileContext { - loop_data: None, in_class: prev_ctx.in_class, func: if is_async { FunctionContext::AsyncFunction @@ -12479,6 +12433,24 @@ mod tests { in_async_scope: is_async, }; compiler.set_qualname(); + let (_doc_str, body) = split_doc(body, &compiler.opts); + let start_label = compiler.use_cpython_function_start_label(); + let is_gen = is_async || compiler.current_symbol_table().is_generator; + let stop_iteration_block = if is_gen { + let handler_block = compiler.new_block(); + compiler.insert_cpython_stopiteration_setup_cleanup(handler_block); + compiler + .push_fblock_labels( + FBlockType::StopIteration, + start_label, + ir::InstructionSequenceLabel::NO_LABEL, + FBlockDatum::None, + ) + .unwrap(); + Some(handler_block) + } else { + None + }; compiler.compile_statements(body).unwrap(); match body.last() { Some(ast::Stmt::Return(_)) => {} @@ -12487,6 +12459,19 @@ mod tests { if compiler.current_code_info().metadata.consts.is_empty() { compiler.arg_constant(ConstantData::None); } + if let Some(handler_block) = stop_iteration_block { + compiler.pop_fblock_label(FBlockType::StopIteration, 
start_label); + compiler.use_cpython_label_block(handler_block); + emit!( + compiler, + Instruction::CallIntrinsic1 { + func: oparg::IntrinsicFunction1::StopIterationError + } + ); + compiler.set_no_location(); + emit!(compiler, Instruction::Reraise { depth: 1u32 }); + compiler.set_no_location(); + } let _table = compiler.pop_symbol_table(); let stack_top = compiler.code_stack.pop().unwrap(); @@ -12494,44 +12479,43 @@ mod tests { } #[test] - #[ignore = "debug helper"] - fn debug_trace_nested_continue_after_optional_body() { + fn try_else_nested_try_const_list_keeps_setup_finally_nop() { let trace = compile_single_function_late_cfg_trace( - "\ -def f(names, show_empty, keywords, args_buffer, args, cls, object, level): - for name in names: - value = getattr(cls, name) - if not show_empty: - if value == []: - field_type = cls._field_types.get(name, object) - if getattr(field_type, '__origin__', ...) is list: - if not keywords: - args_buffer.append(repr(value)) - continue - if not keywords: - args.extend(args_buffer) - args_buffer = [] - value, simple = _format(value, level) - if keywords: - args.append('%s=%s' % (name, value)) + r#" +def f(arch): + try: + [arch, *_] = g() + except OSError: + pass + else: + try: + arch = ['x86', 'MIPS', 'Alpha', 'PowerPC', None, + 'ARM', 'ia64', None, None, + 'AMD64', None, None, 'ARM64', + ][int(arch)] + except (ValueError, IndexError): + pass else: - args.append(value) -", + if arch: + return arch +"#, "f", ); - for (label, dump) in trace { - if label == "after_reorder" - || label == "after_remove_redundant_nops_and_jumps" - || label == "after_final_cfg_cleanup" - || label == "after_borrow_deopts" - { - eprintln!("=== {label} ===\n{dump}"); - } - } + let (_, dump) = trace + .iter() + .find(|(label, _)| label == "after_convert_pseudo_ops") + .expect("missing convert_pseudo_ops trace"); + assert!( + dump.contains("[disp=8:9 raw=8:9-17:28 override=None] Real(Nop)"), + "SETUP_FINALLY should survive as a line-bearing NOP like CPython" + 
); + assert!( + dump.contains("[disp=9:20 raw=9:20-12:14 override=None] Real(BuildList"), + "CPython optimize_lists_and_sets() restores the literal location to BUILD_LIST" + ); } #[test] - #[ignore = "debug helper"] fn debug_trace_make_dataclass_borrow_tail() { let trace = compile_single_function_late_cfg_trace( r#" @@ -12562,7 +12546,6 @@ def f(module, cls, decorator, init, repr, eq, order, unsafe_hash, frozen, match_ } #[test] - #[ignore = "debug helper"] fn debug_trace_protected_attr_subscript_tail() { let trace = compile_single_function_late_cfg_trace( r#" @@ -12587,7 +12570,6 @@ def f(f, oldcls, newcls): } #[test] - #[ignore = "debug helper"] fn debug_trace_dtrace_tail() { let trace = compile_single_function_late_cfg_trace( r#" @@ -12607,9 +12589,8 @@ def f(proc, unittest): "f", ); for (label, dump) in trace { - if label == "after_raw_optimize_load_fast_borrow" + if label == "after_optimize_load_fast" || label.contains("deoptimize_borrow_in_protected_conditional_tail") - || label.contains("deoptimize_borrow_after_terminal_except_tail") { eprintln!("=== {label} ===\n{dump}"); } @@ -12617,7 +12598,6 @@ def f(proc, unittest): } #[test] - #[ignore = "debug helper"] fn debug_trace_colorize_tail() { let trace = compile_single_function_late_cfg_trace( r#" @@ -12638,9 +12618,8 @@ def f(sys, os, file): "f", ); for (label, dump) in trace { - if label == "after_raw_optimize_load_fast_borrow" + if label == "after_optimize_load_fast" || label == "after_deoptimize_borrow_after_protected_import" - || label == "after_optimize_load_fast_borrow" || label == "after_borrow_deopts" { eprintln!("=== {label} ===\n{dump}"); @@ -12649,333 +12628,3044 @@ def f(sys, os, file): } #[test] - fn test_named_except_continue_resume_try_body_keeps_method_borrows() { + fn for_try_except_break_keeps_cpython_if_layout() { let code = compile_exec( - r#" -def f(self, block=True): - if not block and not self.wait(timeout=0): - return None - while self.event_queue.empty(): - while True: - try: - 
self.push_char(self.read(1)) - except OSError as err: - if err.errno == errno.EINTR: - if not self.event_queue.empty(): - return self.event_queue.get() - else: - continue - else: - raise - else: + "\ +def f(support, func, value): + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): + try: + if func() == value: break - return self.event_queue.get() -"#, + except NotImplementedError: + break + sink(value) +", ); let f = find_code(&code, "f").expect("missing f code"); - let instructions: Vec<_> = f + let ops = f .instructions .iter() - .filter(|unit| !matches!(unit.op, Instruction::Cache)) - .collect(); - - let attr_idx = |name: &str| { - instructions - .iter() - .position(|unit| match unit.op { - Instruction::LoadAttr { namei } => { - let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); - f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == name - } - _ => false, - }) - .unwrap_or_else(|| panic!("missing {name} LOAD_ATTR")) - }; - let push_char_idx = attr_idx("push_char"); - let read_idx = attr_idx("read"); - - assert!( - matches!( - instructions[push_char_idx - 1].op, - Instruction::LoadFastBorrow { .. } - ), - "named-except cleanup continue backedge should not deopt the protected try-body method receiver, got instructions={:?}", - instructions.iter().map(|unit| unit.op).collect::>() - ); + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect::>(); + let cond = ops + .iter() + .position(|op| matches!(op, Instruction::PopJumpIfFalse { .. })) + .expect("missing CPython-style false jump for if/break"); assert!( matches!( - instructions[read_idx - 1].op, - Instruction::LoadFastBorrow { .. } + ops.get(cond..cond + 5), + Some([ + Instruction::PopJumpIfFalse { .. }, + Instruction::NotTaken, + Instruction::PopTop, + Instruction::JumpForward { .. }, + Instruction::JumpBackward { .. 
}, + ]) ), - "nested protected try-body method receiver should remain borrowed like CPython, got instructions={:?}", - instructions.iter().map(|unit| unit.op).collect::>() + "CPython codegen_if() keeps the break cleanup in the true-body fallthrough before the loop backedge, got ops={ops:?}" ); } #[test] - fn test_boolop_or_shared_body_keeps_false_jump_before_loop_backedge() { + fn try_else_loop_break_keeps_body_before_protected_backedge() { let code = compile_exec( - r#" -def f(value): - digits = [] - for digit in value: - if isinstance(digit, int) and 0 <= digit <= 9: - if digits or digit != 0: - digits.append(digit) + "\ +def f(input): + while 1: + try: + pass + except IndexError: + break else: - raise ValueError - return digits -"#, + key = None + while key is None: + key = input() + if key not in ('', 'q'): + key = None + if key == 'q': + break +", ); - let f = find_code(&code, "f").expect("missing function code"); - let ops: Vec<_> = f + let f = find_code(&code, "f").expect("missing f code"); + let ops = f .instructions .iter() .map(|unit| unit.op) .filter(|op| !matches!(op, Instruction::Cache)) - .collect(); - + .collect::>(); assert!( - ops.windows(11).any(|window| { + ops.windows(6).any(|window| { matches!( window, [ - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::ToBool, - Instruction::PopJumpIfTrue { .. }, - Instruction::NotTaken, - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::LoadSmallInt { .. }, Instruction::CompareOp { .. }, Instruction::PopJumpIfFalse { .. }, Instruction::NotTaken, - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::LoadAttr { .. }, + Instruction::LoadConst { .. }, + Instruction::ReturnValue, + Instruction::JumpBackward { .. 
}, ] ) }), - "OR-shared body should keep CPython last-condition false jump before the loop backedge, got ops={ops:?}" + "CPython codegen_if() keeps the break body before the false backedge into the protected try/except loop, got ops={ops:?}" + ); + } + + #[test] + fn loop_nested_if_tail_keeps_duplicate_jump_back_label() { + let code = compile_exec( + "\ +def f(value, digits): + for digit in value: + if isinstance(digit, int) and 0 <= digit <= 9: + if digits or digit != 0: + digits.append(digit) + else: + raise ValueError('x') + return digits +", ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache | Instruction::Nop)) + .collect::>(); assert!( - !ops.windows(6).any(|window| { + ops.windows(4).any(|window| { matches!( window, [ - Instruction::CompareOp { .. }, - Instruction::PopJumpIfTrue { .. }, - Instruction::NotTaken, - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. }, - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::LoadAttr { .. }, + Instruction::PopTop, + Instruction::JumpBackward { .. }, + Instruction::JumpBackward { .. }, + Instruction::LoadGlobal { .. 
}, ] ) }), - "OR-shared body should not be moved after the implicit loop backedge, got ops={ops:?}" + "CPython codegen_if() leaves a distinct no-location end label before the loop else/raise path, got ops={ops:?}" ); } #[test] - fn test_single_if_loop_backedge_keeps_true_body_fallthrough_backedge_shape() { + fn match_for_break_threads_empty_end_label_to_outer_backedge() { let code = compile_exec( - r##" -def f(buffer, pos, last_char): - while pos > 0: - pos -= 1 - if buffer[pos] == "#": - last_char = None - return last_char -"##, + "\ +def f(items, T): + for st in items: + match st.type: + case T.TYPE: + for c in st.children: + if c.name == st.name: + x = 1 + break + return x +", ); let f = find_code(&code, "f").expect("missing f code"); - let ops: Vec<_> = f + let ops = f .instructions .iter() .map(|unit| unit.op) - .filter(|op| !matches!(op, Instruction::Cache)) - .collect(); - + .filter(|op| !matches!(op, Instruction::Cache | Instruction::Nop)) + .collect::>(); assert!( - ops.windows(6).any(|window| { + ops.windows(5).any(|window| { matches!( window, [ - Instruction::CompareOp { .. }, - Instruction::PopJumpIfTrue { .. }, - Instruction::NotTaken, - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. }, - Instruction::LoadConst { .. }, - Instruction::StoreFast { .. }, + Instruction::PopTop, + Instruction::JumpBackward { .. }, + Instruction::EndFor, + Instruction::PopIter, + Instruction::JumpBackward { .. 
}, ] ) }), - "single-if loop tail should keep CPython true-body plus fallthrough-backedge shape, got ops={ops:?}", + "CPython codegen_break() threads the match-case inner for break through the empty end label to the outer loop backedge, got ops={ops:?}" + ); + } + + #[test] + fn match_constant_guard_keeps_cpython_guard_nop_before_subject_pop() { + let code = compile_exec( + "\ +def f(self): + x = 0 + match x: + case 0 if True: + y = 0 + case 0 if True: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) +", ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect::>(); assert!( - !ops.windows(4).any(|window| { + ops.windows(7).any(|window| { matches!( window, [ Instruction::CompareOp { .. }, Instruction::PopJumpIfFalse { .. }, Instruction::NotTaken, - Instruction::LoadConst { .. }, + Instruction::Nop, + Instruction::PopTop, + Instruction::LoadSmallInt { .. } | Instruction::LoadConst { .. }, + Instruction::StoreFast { .. 
}, ] ) }), - "single-if loop tail should not be inverted away from CPython shape, got ops={ops:?}", + "CPython codegen_match_inner() emits the guard through codegen_jump_if(), and flowgraph.c keeps the folded constant-guard NOP in a separate success block before POP_TOP, got ops={ops:?}" ); } - fn find_code<'a>(code: &'a CodeObject, name: &str) -> Option<&'a CodeObject> { - if code.obj_name == name { - return Some(code); - } - code.constants.iter().find_map(|constant| { - if let ConstantData::Code { code } = constant { - find_code(code, name) - } else { - None - } - }) - } - - fn has_common_constant(code: &CodeObject, expected: bytecode::CommonConstant) -> bool { - code.instructions.iter().any(|unit| match unit.op { - Instruction::LoadCommonConstant { idx } => { - idx.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected - } - _ => false, - }) - } - - fn has_intrinsic_1(code: &CodeObject, expected: IntrinsicFunction1) -> bool { - code.instructions.iter().any(|unit| match unit.op { - Instruction::CallIntrinsic1 { func } => { - func.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected - } - _ => false, - }) - } - #[test] - fn test_trace_assert_true_try_pair() { - let trace = compile_exec_late_cfg_trace( + fn match_or_default_tail_uses_cpython_load_fast_borrow() { + let code = compile_exec( "\ -try: - assert True -except AssertionError as e: - fail() -try: - assert True, 'msg' -except AssertionError as e: - fail() +def f(format, annotationlib, cls, annotation_fields, return_type, MISSING): + Format = annotationlib.Format + match format: + case Format.VALUE | Format.FORWARDREF | Format.STRING: + cls_annotations = {} + for base in reversed(cls.__mro__): + cls_annotations.update( + annotationlib.get_annotations(base, format=format) + ) + new_annotations = {} + for k in annotation_fields: + try: + new_annotations[k] = cls_annotations[k] + except KeyError: + pass + if return_type is not MISSING: + if format == Format.STRING: + new_annotations['return'] = 
annotationlib.type_repr(return_type) + else: + new_annotations['return'] = return_type + return new_annotations + case _: + raise NotImplementedError(format) ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let raise = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::RaiseVarargs { .. })) + .expect("missing default raise"); + let load_format = &ops[raise - 2]; + let Instruction::LoadFastBorrow { var_num } = load_format.op else { + panic!( + "CPython codegen_match_inner() emits the default case without a load-fast barrier, so optimize_load_fast() borrows the raise argument; got ops={ops:?}" + ); + }; + let arg = OpArg::new(u32::from(u8::from(load_format.arg))); + assert_eq!(f.varnames[usize::from(var_num.get(arg))], "format"); } #[test] - fn test_trace_for_unpack_list_literal() { - let trace = compile_exec_late_cfg_trace( + fn preceding_match_or_default_tail_keeps_cpython_strong_load_fast() { + let code = compile_exec( "\ -result = [] -for x, in [(1,), (2,), (3,)]: - result.append(x) +def f(format): + match format: + case _lazy_annotationlib.Format.VALUE | _lazy_annotationlib.Format.FORWARDREF: + return checked_types + case _lazy_annotationlib.Format.STRING: + return _lazy_annotationlib.annotations_to_string(types) + case _: + raise NotImplementedError(format) ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let raise = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::RaiseVarargs { .. 
})) + .expect("missing default raise"); + let load_format = &ops[raise - 2]; + let Instruction::LoadFast { var_num } = load_format.op else { + panic!( + "CPython keeps the default raise argument strong when an earlier copied OR-pattern precedes the final non-default case; got ops={ops:?}" + ); + }; + let arg = OpArg::new(u32::from(u8::from(load_format.arg))); + assert_eq!(f.varnames[usize::from(var_num.get(arg))], "format"); } #[test] - fn test_trace_break_in_finally_function() { - let trace = compile_single_function_late_cfg_trace( + fn try_else_after_nested_try_except_exit_keeps_cpython_strong_load_fast() { + let code = compile_exec( "\ def f(self): - count = 0 - while count < 2: - count += 1 + try: try: - pass - finally: - break - self.assertEqual(count, 1) + 1 / 0 + except ZeroDivisionError: + raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail('No exception raised') ", - "f", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let fail = ops + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "fail" + } + _ => false, + }) + .expect("missing fail load"); + assert!( + matches!(ops[fail - 1].op, Instruction::LoadFast { .. 
}), + "CPython codegen_try_except() keeps the orelse entry after a nested try/except end label strong, got ops={ops:?}" + ); } #[test] - fn test_import_originated_name_disables_method_call_optimization_even_with_local_import() { + fn try_after_inherited_try_barrier_keeps_successor_loads_strong() { let code = compile_exec( "\ -import warnings +def f(self, x): + try: + 1 / 0 + except EOFError: + pass + except TypeError as msg: + pass + except: + pass + else: + pass + try: + x + except (EOFError, TypeError, ZeroDivisionError): + pass + with self.assertRaises(SyntaxError): + pass -def f(ch): - import warnings - warnings.warn( - '\"\\\\%c\" is an invalid escape sequence' % ch - if 0x20 <= ch < 0x7F - else '\"\\\\x%02x\" is an invalid escape sequence' % ch, - DeprecationWarning, - stacklevel=2, - ) +def g(self): + try: + 1 / 0 + except: + pass + try: + 1 / 0 + except (EOFError, TypeError, ZeroDivisionError): + pass + with self.assertRaises(SyntaxError): + pass ", ); let f = find_code(&code, "f").expect("missing f code"); - let ops: Vec<_> = f.instructions.iter().map(|unit| unit.op).collect(); - let warn_attr = ops - .iter() - .position(|op| matches!(op, Instruction::LoadAttr { .. })) - .expect("missing LOAD_ATTR for warnings.warn"); - let push_null = ops[warn_attr + 10..] - .iter() - .position(|op| matches!(op, Instruction::PushNull)) - .map(|idx| warn_attr + 10 + idx) - .expect("expected PUSH_NULL after plain LOAD_ATTR"); - - let load_attr = match f.instructions[warn_attr].op { - Instruction::LoadAttr { namei } => namei.get(OpArg::new(u32::from(u8::from( - f.instructions[warn_attr].arg, - )))), - _ => unreachable!(), - }; + let x_loads = load_fast_ops_for_var(f, "x"); assert!( - !load_attr.is_method(), - "import-originated names should use plain LOAD_ATTR" + x_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython codegen_try_except() reaches the first try end through USE_LABEL(end); flowgraph.c keeps that barrier state through the following try end, got x loads {x_loads:?}" ); + let self_loads = load_fast_ops_for_var(f, "self"); assert!( - matches!(ops[push_null + 1], Instruction::LoadSmallInt { .. }), - "expected warning message expression to start after PUSH_NULL, got ops={ops:?}" + self_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython keeps successor with-statement loads strong after a try that started from an inherited try barrier, got self loads {self_loads:?}" + ); + let g = find_code(&code, "g").expect("missing g code"); + let g_self_loads = load_fast_ops_for_var(g, "self"); + assert!( + g_self_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython keeps a bare-handler try end label as a barrier when the next statement is another try, got self loads {g_self_loads:?}" ); } #[test] - fn test_trace_constant_false_elif_chain() { - let trace = compile_exec_late_cfg_trace( + fn loop_continue_try_before_try_else_keeps_orelse_loads_strong() { + let code = compile_exec( "\ -if 0: pass -elif 0: pass -elif 0: pass -elif 0: pass -else: pass +def f(candidate_locales, locales): + for loc in candidate_locales: + try: + work(loc) + except Error: + continue + encoding = getencoding() + try: + localeconv() + except Exception as err: + print(loc, encoding, type(err), err) + else: + locales.append(loc) ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } - } + let f = find_code(&code, "f").expect("missing f code"); + let locales_loads = load_fast_ops_for_var(f, "locales"); + assert!( + locales_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython codegen_try_except() leaves an empty end label after a loop try whose handlers continue; optimize_load_fast() keeps the following try/else append receiver strong, got {locales_loads:?}" + ); + let loc_loads = load_fast_ops_for_var(f, "loc"); + assert!( + loc_loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython keeps the try/else append argument strong after the inherited end-label barrier, got {loc_loads:?}" + ); + } + + #[test] + fn try_else_after_try_finally_conditional_finalbody_keeps_store_attr_loads_strong() { + let code = compile_exec( + "\ +def f(self, w, pid, prev): + try: + try: + if cond: + prev = call() + pid = spawn() + finally: + if prev is not None: + reset(prev) + except: + close(w) + raise + else: + self._fd = w + self._pid = pid + finally: + close(r) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let w_self_pairs = load_fast_pair_ops_for_vars(f, "w", "self"); + let pid_self_pairs = load_fast_pair_ops_for_vars(f, "pid", "self"); + assert!( + w_self_pairs + .iter() + .all(|op| matches!(op, Instruction::LoadFastLoadFast { .. })) + && !w_self_pairs.is_empty(), + "CPython codegen_try_finally() calls codegen_try_except() inside the active finally try; the else suite starts from the inner try/finally exit label and keeps w/self strong, got {w_self_pairs:?}" + ); + assert!( + pid_self_pairs + .iter() + .all(|op| matches!(op, Instruction::LoadFastLoadFast { .. 
})) + && !pid_self_pairs.is_empty(), + "CPython keeps the second store-attr source pair strong in the same try/except/else/finally else suite, got {pid_self_pairs:?}" + ); + } + + #[test] + fn try_except_end_before_following_try_keeps_protected_attr_loads_strong() { + let code = compile_exec( + "\ +def f(f, dotlock=True): + dotlock_done = False + try: + if dotlock: + try: + pre_lock = _create_temporary(f.name + '.lock') + pre_lock.close() + except OSError as e: + if e.errno in (errno.EACCES, errno.EROFS): + return + else: + raise + try: + try: + os.link(pre_lock.name, f.name + '.lock') + dotlock_done = True + except (AttributeError, PermissionError): + os.rename(pre_lock.name, f.name + '.lock') + dotlock_done = True + else: + os.unlink(pre_lock.name) + except FileExistsError: + os.remove(pre_lock.name) + raise ExternalClashError('dot lock unavailable: %s' % + f.name) + except: + if dotlock_done: + os.remove(f.name + '.lock') + raise +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let pre_lock_loads = load_fast_ops_for_var(f, "pre_lock"); + let strong_pre_lock_loads = pre_lock_loads + .iter() + .filter(|op| matches!(op, Instruction::LoadFast { .. })) + .count(); + assert!( + strong_pre_lock_loads >= 2, + "CPython codegen_try_except() emits USE_LABEL(end) before the following try; flowgraph.c::optimize_load_fast() does not push fallthrough through the empty end label, so protected pre_lock attr receivers stay strong, got {pre_lock_loads:?}" + ); + let f_loads = load_fast_ops_for_var(f, "f"); + assert!( + f_loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython keeps the f.name receiver inside the following protected try strong after the preceding try/except end label, got {f_loads:?}" + ); + } + + #[test] + fn try_except_method_probe_end_before_if_keeps_loads_strong() { + let code = compile_exec( + "\ +def f(param, value=None, quote=True): + if value is not None and len(value) > 0: + if isinstance(value, tuple): + param += '*' + value = encode(value[2], value[0], value[1]) + return f'{param}={value}' + else: + try: + value.encode('ascii') + except UnicodeEncodeError: + param += '*' + value = encode(value, 'utf-8', '') + return f'{param}={value}' + if quote or tspecials.search(value): + return f'{param}=\"{quote_value(value)}\"' + else: + return f'{param}={value}' + else: + return param +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let quote_loads = load_fast_ops_for_var(f, "quote"); + assert!( + quote_loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython codegen_try_except() emits USE_LABEL(end) after a handled method probe; flowgraph.c::optimize_load_fast() leaves the following if-test local strong, got {quote_loads:?}" + ); + let param_loads = load_fast_ops_for_var(f, "param"); + assert!( + param_loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython keeps the return f-string local loads strong after the protected method-probe end label, got {param_loads:?}" + ); + let value_loads = load_fast_ops_for_var(f, "value"); + assert!( + value_loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython keeps the post-try value loads strong after the protected method-probe end label, got {value_loads:?}" + ); + } + + #[test] + fn try_except_fallthrough_before_return_call_keeps_borrow() { + let code = compile_exec( + "\ +def f(obj, lock, ctx, cls, class_cache): + if cond1(obj): + return Synchronized(obj, lock, ctx) + elif cond2(obj): + return SynchronizedArray(obj, lock, ctx) + else: + try: + scls = class_cache[cls] + except KeyError: + scls = make_synchronized(cls) + return scls(obj, lock, ctx) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let scls_loads = load_fast_ops_for_var(f, "scls"); + assert!( + scls_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFastBorrow { .. })), + "CPython codegen_try_except() emits USE_LABEL(end) and the following return call into the same basic block; flowgraph.c borrows the callable local, got {scls_loads:?}" + ); + let obj_lock_pairs = load_fast_pair_ops_for_vars(f, "obj", "lock"); + assert!( + obj_lock_pairs + .iter() + .all(|op| matches!(op, Instruction::LoadFastBorrowLoadFastBorrow { .. })), + "CPython flowgraph.c borrows the return call argument pair after a typed fallthrough handler, got {obj_lock_pairs:?}" + ); + let ctx_loads = load_fast_ops_for_var(f, "ctx"); + assert!( + ctx_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFastBorrow { .. 
})), + "CPython flowgraph.c borrows the trailing return call argument after the try-end label, got {ctx_loads:?}" + ); + } + + #[test] + fn try_except_comprehension_handler_before_return_call_keeps_borrow() { + let code = compile_exec( + "\ +def f(obj, lock, ctx): + assert not isinstance(obj, SynchronizedBase), 'object already synchronized' + ctx = ctx or get_context() + + if isinstance(obj, ctypes._SimpleCData): + return Synchronized(obj, lock, ctx) + elif isinstance(obj, ctypes.Array): + if obj._type_ is ctypes.c_char: + return SynchronizedString(obj, lock, ctx) + return SynchronizedArray(obj, lock, ctx) + else: + cls = type(obj) + try: + scls = class_cache[cls] + except KeyError: + names = [field[0] for field in cls._fields_] + d = {name: make_property(name) for name in names} + classname = 'Synchronized' + cls.__name__ + scls = class_cache[cls] = type(classname, (SynchronizedBase,), d) + return scls(obj, lock, ctx) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let scls_loads = load_fast_ops_for_var(f, "scls"); + assert!( + scls_loads + .iter() + .any(|op| matches!(op, Instruction::LoadFastBorrow { .. })) + && scls_loads + .iter() + .all(|op| !matches!(op, Instruction::LoadFast { .. })), + "CPython codegen_try_except() keeps the return call in the end-label block even when the handler contains comprehensions, got {scls_loads:?}" + ); + let obj_lock_pairs = load_fast_pair_ops_for_vars(f, "obj", "lock"); + assert!( + obj_lock_pairs + .iter() + .any(|op| matches!(op, Instruction::LoadFastBorrowLoadFastBorrow { .. })) + && obj_lock_pairs + .iter() + .all(|op| !matches!(op, Instruction::LoadFastLoadFast { .. 
})), + "CPython flowgraph.c borrows the return call argument pair after the try-end label and cold handler reordering, got {obj_lock_pairs:?}" + ); + let instructions = non_cache_instructions(f).collect::>(); + let ctx_idx = varname_index(f, "ctx"); + let has_borrowed_return_tail = instructions.windows(6).any(|window| { + matches!(window[0].op, Instruction::LoadFastBorrow { .. }) + && matches!(window[1].op, Instruction::PushNull) + && matches!( + window[2].op, + Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) + && matches!(window[3].op, Instruction::LoadFastBorrow { .. }) + && { + let Instruction::LoadFastBorrow { var_num } = window[3].op else { + return false; + }; + usize::from(var_num.get(OpArg::new(u32::from(u8::from(window[3].arg))))) + == ctx_idx + } + && matches!(window[4].op, Instruction::Call { .. }) + && matches!(window[5].op, Instruction::ReturnValue) + }); + assert!( + has_borrowed_return_tail, + "CPython flowgraph.c borrows the full final return-call tail after the protected try body, got instructions={instructions:?}" + ); + } + + #[test] + fn try_finally_closed_conditional_exit_allows_cpython_borrow() { + let code = compile_exec( + "\ +def f(self, os, tempfile, oldmode): + try: + work() + finally: + if os.name == 'nt': + os.chmod(tempfile.tempdir, oldmode) + else: + os.chmod(tempfile.tempdir, oldmode) + self.assertEqual(os.listdir(tempfile.tempdir), []) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let assert_equal = ops + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() + == "assertEqual" + } + _ => false, + }) + .expect("missing assertEqual load"); + assert!( + matches!(ops[assert_equal - 1].op, Instruction::LoadFastBorrow { .. 
}), + "CPython codegen_try_finally() does not make a load-fast barrier after a closed conditional finalbody, got ops={ops:?}" + ); + } + + #[test] + fn handler_resume_after_nested_try_keeps_successor_load_fast_strong() { + let code = compile_exec( + r#" +def f(x): + try: + try: + import readline + except ImportError: + readline = None + else: + import rlcompleter + except ImportError: + return + try: + if x: + y = 1 + except ImportError: + return +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let truthiness_load = ops + .windows(3) + .find(|window| { + let arg = OpArg::new(u32::from(u8::from(window[0].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "x" + ) && matches!(window[1].op, Instruction::ToBool) + && matches!(window[2].op, Instruction::PopJumpIfFalse { .. }) + }) + .unwrap_or_else(|| { + panic!( + "missing if x truthiness load: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(truthiness_load[0].op, Instruction::LoadFast { .. 
}), + "CPython flowgraph.c leaves an empty try-end label before this successor block, so optimize_load_fast() does not borrow the if-test load: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn nested_finally_handler_try_end_keeps_return_load_fast_strong() { + let code = compile_exec( + r#" +def f(sys): + try: + import _testinternalcapi + depth = _testinternalcapi.get_recursion_depth() + except (ImportError, RecursionError) as exc: + try: + depth = 0 + frame = sys._getframe() + while frame is not None: + depth += 1 + frame = frame.f_back + finally: + frame = None + return max(depth - 1, 1) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let depth_load = ops + .windows(3) + .find(|window| { + let arg = OpArg::new(u32::from(u8::from(window[0].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "depth" + ) && matches!(window[1].op, Instruction::LoadSmallInt { .. }) + && matches!(window[2].op, Instruction::BinaryOp { .. }) + }) + .unwrap_or_else(|| { + panic!( + "missing return depth load: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(depth_load[0].op, Instruction::LoadFast { .. 
}), + "CPython flowgraph.c preserves an empty try-end label before this return block, so optimize_load_fast() leaves depth strong: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn try_except_finally_exit_label_keeps_return_load_fast_strong() { + let code = compile_exec( + r#" +def f(): + global _importing_zlib + if _importing_zlib: + _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') + raise ZipImportError("can't decompress data; zlib not available") + + _importing_zlib = True + try: + from zlib import decompress + except Exception: + _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') + raise ZipImportError("can't decompress data; zlib not available") + finally: + _importing_zlib = False + + _bootstrap._verbose_message('zipimport: zlib available') + return decompress +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let return_load = ops + .windows(2) + .find(|window| { + let arg = OpArg::new(u32::from(u8::from(window[0].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "decompress" + ) && matches!(window[1].op, Instruction::ReturnValue) + }) + .unwrap_or_else(|| { + panic!( + "missing return decompress load: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(return_load[0].op, Instruction::LoadFast { .. 
}), + "CPython codegen_try_finally() emits a JUMP_NO_INTERRUPT to an empty exit label after the normal finally body, and flowgraph.c::optimize_load_fast() does not fall through an empty block: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn bare_except_finally_exit_label_keeps_successor_load_fast_strong() { + let code = compile_exec( + r#" +def f(self): + hit = False + try: + pass + except: + hit = True + finally: + done = True + self.assertFalse(hit) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let assert_false = ops + .windows(3) + .find(|window| { + let self_arg = OpArg::new(u32::from(u8::from(window[0].arg))); + let attr_arg = OpArg::new(u32::from(u8::from(window[1].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(self_arg))] == "self" + ) && matches!( + window[1].op, + Instruction::LoadAttr { namei } + if f.names[usize::try_from(namei.get(attr_arg).name_idx()).unwrap()] + == "assertFalse" + ) && matches!( + window[2].op, + Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. } + ) + }) + .unwrap_or_else(|| { + panic!( + "missing assertFalse call: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(assert_false[0].op, Instruction::LoadFast { .. }) + && matches!(assert_false[2].op, Instruction::LoadFast { .. 
}), + "CPython codegen_try_finally() wraps codegen_try_except(); with a bare handler, the normal finally body jumps to an empty exit label, and flowgraph.c::optimize_load_fast() does not fall through that empty block: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn typed_except_finally_fallthrough_keeps_successor_load_fast_borrow() { + let code = compile_exec( + r#" +def f(self): + hit = False + try: + pass + except Exception: + hit = True + finally: + done = True + self.assertFalse(hit) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let assert_false = ops + .windows(3) + .find(|window| { + let self_arg = OpArg::new(u32::from(u8::from(window[0].arg))); + let attr_arg = OpArg::new(u32::from(u8::from(window[1].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(self_arg))] == "self" + ) && matches!( + window[1].op, + Instruction::LoadAttr { namei } + if f.names[usize::try_from(namei.get(attr_arg).name_idx()).unwrap()] + == "assertFalse" + ) && matches!( + window[2].op, + Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. } + ) + }) + .unwrap_or_else(|| { + panic!( + "missing assertFalse call: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(assert_false[0].op, Instruction::LoadFastBorrow { .. }) + && matches!(assert_false[2].op, Instruction::LoadFastBorrow { .. 
}), + "CPython typed-handler fallthrough keeps this successor reachable for optimize_load_fast(), so the loads remain borrowed: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn bare_except_finally_no_exception_shares_return_target() { + let code = compile_exec( + "\ +def func(): + try: + 2 + except: + 4 + finally: + 6 +", + ); + let f = find_code(&code, "func").expect("missing func code"); + let ops = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect::>(); + let first_push_exc = ops + .iter() + .position(|op| matches!(op, Instruction::PushExcInfo)) + .expect("missing PushExcInfo"); + let returns_before_handler = ops[..first_push_exc] + .iter() + .filter(|op| matches!(op, Instruction::ReturnValue)) + .count(); + + assert_eq!( + returns_before_handler, 1, + "CPython codegen_try_finally() wraps codegen_try_except(); the bare handler jumps back to the normal finally return target instead of forcing duplicate_end_returns() to create a Rust-only return copy, got ops={ops:?}" + ); + } + + #[test] + fn for_exhaustion_assert_false_message_borrows_load_fast() { + let code = compile_exec( + r#" +def f(arg, opcode): + for i, nb_op in enumerate(opcode._nb_ops): + if arg == nb_op[0]: + return i + assert False, f"{arg} is not a valid BINARY_OP argument." 
+"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let assertion_arg_load = ops + .windows(2) + .find(|window| { + let arg = OpArg::new(u32::from(u8::from(window[0].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "arg" + ) && matches!(window[1].op, Instruction::FormatSimple) + }) + .unwrap_or_else(|| { + panic!( + "missing assertion message arg load: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(assertion_arg_load[0].op, Instruction::LoadFastBorrow { .. }), + "CPython codegen_assert() emits AssertionError directly after the failing test, so flowgraph.c::optimize_load_fast() visits the assertion message block: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn try_except_else_conditional_join_borrows_else_receiver() { + let code = compile_exec( + r#" +def f(self): + try: + if not self.result_is_file() or not self.sendfile(): + for data in self.result: + self.write(data) + self.finish_content() + except: + if hasattr(self.result, 'close'): + self.result.close() + raise + else: + self.close() +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let else_close_load = ops + .windows(6) + .find(|window| { + let load_arg = OpArg::new(u32::from(u8::from(window[0].arg))); + let attr_arg = OpArg::new(u32::from(u8::from(window[1].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(load_arg))] == "self" + ) && matches!( + window[1].op, + Instruction::LoadAttr { namei } + if 
f.names[usize::try_from(namei.get(attr_arg).name_idx()).unwrap()] + == "close" + ) && matches!(window[2].op, Instruction::Call { .. }) + && matches!(window[3].op, Instruction::PopTop) + && matches!(window[4].op, Instruction::LoadConst { .. }) + && matches!(window[5].op, Instruction::ReturnValue) + }) + .unwrap_or_else(|| { + panic!( + "missing else self.close return sequence: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + assert!( + matches!(else_close_load[0].op, Instruction::LoadFastBorrow { .. }), + "CPython codegen_try_except() emits Try.orelse directly with VISIT_SEQ(), so flowgraph.c::optimize_load_fast() reaches the else self.close receiver: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn try_finally_exit_label_reuses_empty_block_for_borrow() { + let code = compile_exec( + r#" +def f(self, exc): + try: + end_time = self.get_time() + finally: + result = self.context.__exit__(*exc) + self.seconds = end_time - self.start_time + return result +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + let seconds_store = ops + .windows(5) + .find(|window| { + let load_pair = OpArg::new(u32::from(u8::from(window[0].arg))); + let self_load = OpArg::new(u32::from(u8::from(window[3].arg))); + matches!( + window[0].op, + Instruction::LoadFastLoadFast { var_nums } + | Instruction::LoadFastBorrowLoadFastBorrow { var_nums } + if { + let (left, right) = var_nums.get(load_pair).indexes(); + f.varnames[usize::from(left)] == "end_time" + && f.varnames[usize::from(right)] == "self" + } + ) && matches!(window[1].op, Instruction::LoadAttr { .. }) + && matches!(window[2].op, Instruction::BinaryOp { .. 
}) + && matches!( + window[3].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(self_load))] == "self" + ) + && matches!(window[4].op, Instruction::StoreAttr { .. }) + }) + .unwrap_or_else(|| { + panic!( + "missing self.seconds store sequence: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + let return_result = ops + .windows(2) + .find(|window| { + let arg = OpArg::new(u32::from(u8::from(window[0].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "result" + ) && matches!(window[1].op, Instruction::ReturnValue) + }) + .unwrap_or_else(|| { + panic!( + "missing result return sequence: {:?}", + ops.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!( + seconds_store[0].op, + Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) && matches!(seconds_store[3].op, Instruction::LoadFastBorrow { .. }) + && matches!(return_result[0].op, Instruction::LoadFastBorrow { .. 
}), + "CPython codegen_try_finally() emits JUMP_NO_INTERRUPT to the exit label, and flowgraph.c labels the current empty block instead of inserting a b_next barrier before following code: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn with_tail_while_true_break_successor_uses_strong_load() { + let code = compile_exec( + r#" +def f(self, cm): + with cm as out: + while 1: + data = out.read() + if not data: + break + self.close() +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let close_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "close" + } + _ => false, + }) + .expect("missing close load"); + + assert!( + matches!( + instructions[close_attr - 1].op, + Instruction::LoadFast { .. 
} + ), + "CPython codegen_while() emits USE_LABEL(end) for the tail break target before codegen_with_inner() emits normal __exit__ cleanup, so flowgraph.c::optimize_load_fast() leaves the successor receiver strong: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn folded_ifexp_nested_in_call_keeps_successor_load_fast_strong() { + let code = compile_exec( + r#" +def f(self, g): + self.x = g(self.y, optimization='' if __debug__ else 1) + self.close() +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let close_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "close" + } + _ => false, + }) + .expect("missing close load"); + + assert!( + matches!( + instructions[close_attr - 1].op, + Instruction::LoadFast { .. 
} + ), + "CPython codegen_ifexp() emits USE_LABEL(end); with a folded conditional nested in a larger stack expression, flowgraph.c::optimize_load_fast() sees an empty end block and does not visit the successor loads: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn folded_ifexp_assignment_before_with_keeps_context_load_fast_strong() { + let code = compile_exec( + r#" +def f(self, cm): + optlevel = 1 if __debug__ else 0 + with cm as t: + self.use(t, optlevel) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let first_exit = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadSpecial { method } => { + method.get(OpArg::new(u32::from(u8::from(unit.arg)))) == SpecialMethod::Exit + } + _ => false, + }) + .expect("missing __exit__ load"); + let cm_load = (0..first_exit) + .rev() + .find(|idx| { + let arg = OpArg::new(u32::from(u8::from(instructions[*idx].arg))); + matches!( + instructions[*idx].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "cm" + ) + }) + .expect("missing cm load before __exit__"); + + assert!( + matches!(instructions[cm_load].op, Instruction::LoadFast { .. 
}), + "CPython codegen_ifexp() emits USE_LABEL(end), then codegen_with_inner() starts the with header after that empty block; flowgraph.c::optimize_load_fast() does not push fallthrough successors from empty blocks, so the context manager load stays strong: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn folded_ifexp_assignment_keeps_later_statement_load_fast_strong() { + let code = compile_exec( + r#" +def f(self, x): + optlevel = 1 if __debug__ else 0 + ext = '.pyc' + self.use(x, ext) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let use_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "use" + } + _ => false, + }) + .expect("missing use load"); + + assert!( + matches!(instructions[use_attr - 1].op, Instruction::LoadFast { .. 
}), + "CPython codegen_ifexp() emits USE_LABEL(end) for the folded assignment; flowgraph.c::optimize_load_fast() does not push fallthrough from the empty end block, so the next statement receiver stays strong: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn const_assignment_before_with_keeps_context_load_fast_borrowed() { + let code = compile_exec( + r#" +def f(self, cm): + optlevel = 1 + with cm as t: + self.use(t, optlevel) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let first_exit = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadSpecial { method } => { + method.get(OpArg::new(u32::from(u8::from(unit.arg)))) == SpecialMethod::Exit + } + _ => false, + }) + .expect("missing __exit__ load"); + let cm_load = (0..first_exit) + .rev() + .find(|idx| { + let arg = OpArg::new(u32::from(u8::from(instructions[*idx].arg))); + matches!( + instructions[*idx].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "cm" + ) + }) + .expect("missing cm load before __exit__"); + + assert!( + matches!(instructions[cm_load].op, Instruction::LoadFastBorrow { .. 
}), + "without CPython's folded if-expression end label, flowgraph.c::optimize_load_fast() sees the following with header in the same reachable state and borrows the context manager load: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn if_end_label_reuse_allows_following_return_borrow() { + let code = compile_exec( + r#" +def f(data): + msgids = [] + reading_msgid = False + cur_msgid = [] + for line in data.split('\n'): + if reading_msgid: + if line.startswith('"'): + cur_msgid.append(line.strip('"')) + else: + msgids.append('\n'.join(cur_msgid)) + cur_msgid = [] + reading_msgid = False + continue + if line.startswith('msgid '): + line = line[len('msgid '):] + cur_msgid.append(line.strip('"')) + reading_msgid = True + else: + if reading_msgid: + msgids.append('\n'.join(cur_msgid)) + + return msgids +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let return_load = instructions + .windows(2) + .find(|window| { + let arg = OpArg::new(u32::from(u8::from(window[0].arg))); + matches!( + window[0].op, + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(arg))] == "msgids" + ) && matches!(window[1].op, Instruction::ReturnValue) + }) + .unwrap_or_else(|| { + panic!( + "missing return msgids sequence: {:?}", + instructions.iter().map(|unit| unit.op).collect::>() + ) + }); + + assert!( + matches!(return_load[0].op, Instruction::LoadFastBorrow { .. 
}), + "CPython codegen_if() ends with USE_LABEL(end), and flowgraph.c::cfg_builder_current_block_is_terminated() reuses the current empty block for that label instead of leaving a b_next barrier before the following return: {:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn named_except_continue_resume_try_body_keeps_method_borrows() { + let code = compile_exec( + r#" +def f(self, block=True): + if not block and not self.wait(timeout=0): + return None + while self.event_queue.empty(): + while True: + try: + self.push_char(self.read(1)) + except OSError as err: + if err.errno == errno.EINTR: + if not self.event_queue.empty(): + return self.event_queue.get() + else: + continue + else: + raise + else: + break + return self.event_queue.get() +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + + let attr_idx = |name: &str| { + instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == name + } + _ => false, + }) + .unwrap_or_else(|| panic!("missing {name} LOAD_ATTR")) + }; + let push_char_idx = attr_idx("push_char"); + let read_idx = attr_idx("read"); + + assert!( + matches!( + instructions[push_char_idx - 1].op, + Instruction::LoadFastBorrow { .. } + ), + "named-except cleanup continue backedge should not deopt the protected try-body method receiver, got instructions={:?}", + instructions.iter().map(|unit| unit.op).collect::>() + ); + assert!( + matches!( + instructions[read_idx - 1].op, + Instruction::LoadFastBorrow { .. 
} + ), + "nested protected try-body method receiver should remain borrowed like CPython, got instructions={:?}", + instructions.iter().map(|unit| unit.op).collect::>() + ); + } + + #[test] + fn boolop_or_shared_body_keeps_false_jump_before_loop_backedge() { + let code = compile_exec( + r#" +def f(value): + digits = [] + for digit in value: + if isinstance(digit, int) and 0 <= digit <= 9: + if digits or digit != 0: + digits.append(digit) + else: + raise ValueError + return digits +"#, + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.windows(11).any(|window| { + matches!( + window, + [ + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::ToBool, + Instruction::PopJumpIfTrue { .. }, + Instruction::NotTaken, + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadSmallInt { .. }, + Instruction::CompareOp { .. }, + Instruction::PopJumpIfFalse { .. }, + Instruction::NotTaken, + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. }, + ] + ) + }), + "OR-shared body should keep CPython last-condition false jump before the loop backedge, got ops={ops:?}" + ); + assert!( + !ops.windows(6).any(|window| { + matches!( + window, + [ + Instruction::CompareOp { .. }, + Instruction::PopJumpIfTrue { .. }, + Instruction::NotTaken, + Instruction::JumpBackward { .. } + | Instruction::JumpBackwardNoInterrupt { .. }, + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. 
}, + ] + ) + }), + "OR-shared body should not be moved after the implicit loop backedge, got ops={ops:?}" + ); + } + + #[test] + fn single_if_loop_backedge_keeps_true_body_fallthrough_backedge_shape() { + let code = compile_exec( + r##" +def f(buffer, pos, last_char): + while pos > 0: + pos -= 1 + if buffer[pos] == "#": + last_char = None + return last_char +"##, + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.windows(6).any(|window| { + matches!( + window, + [ + Instruction::CompareOp { .. }, + Instruction::PopJumpIfTrue { .. }, + Instruction::NotTaken, + Instruction::JumpBackward { .. } + | Instruction::JumpBackwardNoInterrupt { .. }, + Instruction::LoadConst { .. }, + Instruction::StoreFast { .. }, + ] + ) + }), + "single-if loop tail should keep CPython true-body plus fallthrough-backedge shape, got ops={ops:?}", + ); + assert!( + !ops.windows(4).any(|window| { + matches!( + window, + [ + Instruction::CompareOp { .. }, + Instruction::PopJumpIfFalse { .. }, + Instruction::NotTaken, + Instruction::LoadConst { .. 
}, + ] + ) + }), + "single-if loop tail should not be inverted away from CPython shape, got ops={ops:?}", + ); + } + + fn find_code<'a>(code: &'a CodeObject, name: &str) -> Option<&'a CodeObject> { + if code.obj_name == name { + return Some(code); + } + code.constants.iter().find_map(|constant| { + if let ConstantData::Code { code } = constant { + find_code(code, name) + } else { + None + } + }) + } + + fn find_direct_child_code<'a>(code: &'a CodeObject, name: &str) -> Option<&'a CodeObject> { + code.constants.iter().find_map(|constant| { + if let ConstantData::Code { code } = constant { + (code.obj_name == name).then_some(code.as_ref()) + } else { + None + } + }) + } + + #[test] + fn annotated_multiline_function_body_keeps_def_firstlineno_like_cpython() { + let code = compile_exec( + r#" +a = 1 +def f( + x: a, +): ... +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + // CPython 3.14 codegen_function() computes firstlineno from the + // FunctionDef before compiling annotations, then passes it to + // codegen_function_body(). + assert_eq!(f.linetable.as_ref(), &[0x80, 0x00, 0xe1, 0x03, 0x06]); + } + + #[test] + fn annotation_scope_return_uses_function_location_like_cpython() { + let code = compile_exec( + r#" +def g(): + def f(x: not (int is int), /): ... +"#, + ); + let g = find_code(&code, "g").expect("missing g code"); + let annotate = find_code(g, "__annotate__").expect("missing annotation code"); + // CPython 3.14 codegen_function_annotations() receives LOC(function) + // and uses it for the annotation closure's BUILD_MAP/RETURN_VALUE and + // for the parent MAKE_FUNCTION annotate sequence. 
+ assert_eq!(g.linetable.as_ref(), &[0x80, 0x00, 0xdf, 0x04, 0x26]); + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd7, 0x04, 0x26, 0xd1, 0x04, 0x26, 0x94, 0x23, 0x9c, 0x13, 0xd0, 0x0d, + 0x1d, 0xd1, 0x04, 0x26, + ], + ); + } + + #[test] + fn module_deferred_annotations_use_start_location_like_cpython() { + let code = compile_exec( + "\ +import os +X: int +Y: str +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + + // CPython 3.14 compile.c::start_location() passes the first module + // statement location into _PyCodegen_Module(), and + // codegen_process_deferred_annotations() uses that loc for annotation + // scope setup, BUILD_MAP, STORE_SUBSCR, and RETURN_VALUE. + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0x87, 0x09, 0x81, 0x09, 0xdf, 0x00, 0x06, 0x82, 0x06, 0x84, 0x33, 0x81, + 0x06, 0xf1, 0x03, 0x00, 0x01, 0x0a, 0xe7, 0x00, 0x06, 0x82, 0x06, 0x84, 0x33, 0x81, + 0x06, 0xf2, 0x05, 0x00, 0x01, 0x0a, + ] + ); + } + + #[test] + fn super_method_call_kw_names_use_attribute_location_like_cpython() { + let code = compile_exec( + "\ +class C: + def f(self, x, y): + super().__init__( + x=x, + y=y) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let call_kw_index = f + .instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::CallKw { .. })) + .expect("missing CALL_KW"); + let (kw_names, (location, end_location)) = f + .instructions + .iter() + .zip(&f.locations) + .take(call_kw_index) + .rev() + .find(|(unit, _)| matches!(unit.op, Instruction::LoadConst { .. })) + .expect("missing CALL_KW names tuple"); + + assert!( + matches!(kw_names.op, Instruction::LoadConst { .. 
}), + "expected keyword names tuple before CALL_KW" + ); + assert_eq!( + (location.line.get(), end_location.line.get()), + (3, 3), + "CPython maybe_optimize_method_call() passes the updated method-attribute loc into codegen_call_simple_kw_helper()" + ); + } + + #[test] + fn lambda_return_uses_body_location_like_cpython() { + let code = compile_exec( + "\ +def outer(): + return lambda x: x if x else 1 +", + ); + let lambda = find_code(&code, "").expect("missing lambda code"); + let return_positions: Vec<_> = lambda + .instructions + .iter() + .zip(&lambda.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 22, 2, 35), (2, 22, 2, 35)], + "CPython codegen_lambda() emits RETURN_VALUE at LOC(lambda body)" + ); + } + + #[test] + fn not_compare_uses_unary_location_like_cpython() { + let code = compile_exec( + "\ +def f(self, other): + return not self == other +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 parses the Compare inside UnaryOp(Not) with the + // UnaryOp start location, so codegen_compare() emits COMPARE_OP at + // the full "not self == other" range before flowgraph folds TO_BOOL. + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0f, 0x13, 0xd2, 0x0b, 0x1c, 0xd0, 0x04, 0x1c, + ] + ); + } + + #[test] + fn not_chained_compare_keeps_compare_location_like_cpython() { + let code = compile_exec( + "\ +def f(c): + return not (b\" \" <= c <= b\"~\") +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython's single Compare under UnaryOp(Not) includes "not" in the + // Compare range, but chained comparisons keep their inner range for + // compare scaffolding and only use the UnaryOp range for TO_BOOL and + // UNARY_NOT. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x10, 0x14, 0x98, 0x01, 0xd7, 0x10, 0x21, 0xd4, 0x10, 0x21, 0x98, + 0x54, 0xd1, 0x10, 0x21, 0xd4, 0x0b, 0x22, 0xd0, 0x04, 0x22, 0xd1, 0x10, 0x21, 0xd4, + 0x0b, 0x22, 0xd0, 0x04, 0x22, + ] + ); + } + + #[test] + fn type_param_scopes_use_cpython_locations() { + let code = compile_exec("type BoundGenericAlias[X: int] = set[X]\n"); + let type_params = find_code(&code, "") + .expect("missing generic parameters code"); + let bound = find_direct_child_code(type_params, "X").expect("missing X bound code"); + let alias = + find_direct_child_code(type_params, "BoundGenericAlias").expect("missing alias code"); + + // CPython 3.14 codegen_type_params() emits type-parameter ops at + // LOC(typeparam), bound/default evaluator ops at LOC(e), and type alias + // body plumbing at LOC(s). + assert_eq!( + type_params.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0xd0, 0x00, 0x27, 0x90, 0x76, 0x9b, 0x23, 0x93, 0x76, 0xd7, 0x00, + 0x27, 0xd1, 0x00, 0x27, + ], + ); + assert_eq!( + bound.linetable.as_ref(), + &[0x80, 0x00, 0x9f, 0x23, 0x9e, 0x23] + ); + assert_eq!( + alias.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0xd7, 0x00, 0x27, 0xd0, 0x00, 0x27, 0xa4, 0x13, 0xa0, 0x51, 0xa5, + 0x16, 0xd0, 0x00, 0x27, + ], + ); + } + + #[test] + fn generic_function_annotation_scope_uses_function_location_like_cpython() { + let code = compile_exec("def f[T](x: int): ...\n"); + let type_params = + find_code(&code, "").expect("missing type params code"); + let annotate = + find_direct_child_code(type_params, "__annotate__").expect("missing annotation code"); + + // CPython 3.14 passes LOC(function) into codegen_function_annotations(), + // even when the annotation closure is emitted inside the generic + // parameters scope after codegen_type_params(). 
+ assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd7, 0x00, 0x15, 0xd1, 0x00, 0x15, 0x8c, 0x43, 0xd1, 0x00, 0x15, + ], + ); + } + + #[test] + fn generic_class_type_params_store_uses_class_location_like_cpython() { + let code = compile_exec( + "\ +def outer(): + class X[T]: ... +", + ); + let type_params = + find_code(&code, "").expect("missing type params code"); + + // CPython 3.14 codegen_class() calls codegen_type_params(), then stores + // the resulting .type_params cell with codegen_nameop(c, LOC(class), ...). + assert_eq!( + type_params.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0x8c, 0x41, 0x87, 0x4f, 0x87, 0x4f, 0x80, 0x4f, + ] + ); + } + + #[test] + fn generic_class_wrapper_ops_use_class_location_like_cpython() { + let code = compile_exec( + "\ +def f(): + class X[T](tuple): + pass +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let wrapper_positions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Resume { .. 
})) + .take(4) + .zip(f.locations.iter().filter(|_| true).skip(1)) + .map(|(unit, (location, end_location))| { + ( + unit.op, + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + ) + }) + .collect(); + assert_eq!( + wrapper_positions + .iter() + .map(|(_, line, col, end_line, end_col)| (*line, *col, *end_line, *end_col)) + .collect::>(), + vec![(2, 5, 3, 13); 4], + "CPython codegen_class() emits type-params wrapper closure, PUSH_NULL, and CALL at LOC(class)" + ); + + let type_params = + find_code(f, "").expect("missing generic parameters code"); + let generic_base_position = type_params + .instructions + .iter() + .zip(&type_params.locations) + .find_map(|(unit, (location, end_location))| { + let Instruction::LoadFastBorrow { var_num } = unit.op else { + return None; + }; + let idx = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + let localsplus = type_params + .varnames + .iter() + .chain(type_params.cellvars.iter()) + .chain(type_params.freevars.iter()) + .collect::>(); + localsplus + .get(usize::from(idx)) + .is_some_and(|name| name.as_str() == ".generic_base") + .then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing .generic_base load"); + assert_eq!( + generic_base_position, + (2, 5, 3, 13), + "CPython codegen_class() injects .generic_base with LOC(class)" + ); + } + + #[test] + fn class_deferred_annotations_use_class_body_location_like_cpython() { + let code = compile_exec( + r#" +class C: + "doc" + x: int +"#, + ); + let class_code = find_code(&code, "C").expect("missing class code"); + + // CPython 3.14 calls codegen_body(c, loc, ...) from codegen_class_body() + // with LOCATION(firstlineno, firstlineno, 0, 0). Deferred annotation + // closure setup and following artificial class tail inherit that class + // body location, not the annotation expression location. 
+ assert_eq!( + class_code.linetable.as_ref(), + &[ + 0xf8, 0x87, 0x00, 0x80, 0x00, 0xd9, 0x04, 0x09, 0xf7, 0x03, 0x00, 0x01, 0x01, 0x83, + 0x00, + ], + ); + } + + #[test] + fn future_annotation_string_uses_annotation_location_like_cpython() { + let code = compile_exec("from __future__ import annotations\nclass Bar:\n foo: Foo\n"); + let class_code = find_code(&code, "Bar").expect("missing class code"); + + // CPython 3.14 codegen_annassign() calls codegen_visit_annexpr(), + // which emits the stringized annotation at LOC(annotation), then emits + // the __annotations__ store sequence at LOC(AnnAssign). + assert_eq!( + class_code.linetable.as_ref(), + &[0x87, 0x00, 0xd8, 0x09, 0x0c, 0x87, 0x48] + ); + } + + #[test] + fn lambda_dict_literal_ops_use_dict_location_like_cpython() { + let code = compile_exec( + "\ +f = lambda data: {'x': data} +g = lambda i: {**i} +", + ); + let f = find_code(&code, "").expect("missing f lambda code"); + let g = code + .constants + .iter() + .filter_map(|constant| { + if let ConstantData::Code { code } = constant { + (code.obj_name == "").then_some(code.as_ref()) + } else { + None + } + }) + .nth(1) + .expect("missing g lambda code"); + + // CPython 3.14 codegen_dict()/codegen_subdict() uses LOC(dict) for + // BUILD_MAP, MAP_ADD, and DICT_UPDATE, so the lambda RETURN_VALUE + // inherits the full dict literal location after compiling its body. 
+ assert_eq!( + f.linetable.as_ref(), + &[0x80, 0x00, 0x90, 0x23, 0x90, 0x74, 0x91, 0x1b] + ); + assert_eq!( + g.linetable.as_ref(), + &[0x80, 0x00, 0x88, 0x65, 0x90, 0x11, 0x89, 0x65] + ); + } + + #[test] + fn class_function_like_scopes_set_method_flag_like_cpython() { + let code = compile_exec_with_options( + r#" +class C: + def m(self): + pass + + async def am(self): + pass + + f = lambda self: self + y = (i for i in ()) + +def f(): + pass +"#, + CompileOpts::default(), + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let method = find_code(class_code, "m").expect("missing method code"); + let async_method = find_code(class_code, "am").expect("missing async method code"); + let lambda = find_code(class_code, "").expect("missing lambda code"); + let genexpr = find_code(class_code, "").expect("missing genexpr code"); + let module_function = find_code(&code, "f").expect("missing module function code"); + + for code in [method, async_method, lambda, genexpr] { + assert!( + code.flags.contains(bytecode::CodeFlags::METHOD), + "class-scope function-like code should carry CO_METHOD like CPython 3.14, got {:?}", + code.flags + ); + } + assert!( + !module_function.flags.contains(bytecode::CodeFlags::METHOD), + "module-scope function must not carry CO_METHOD" + ); + } + + #[test] + fn inlined_comprehension_lambda_in_class_is_not_method_like_cpython() { + let code = compile_exec( + "\ +class C: + def method(self): + super() + return __class__ + items = [(lambda: i) for i in range(5)] +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let lambda = find_code(class_code, "").expect("missing lambda code"); + assert!( + lambda.flags.contains(bytecode::CodeFlags::NESTED), + "lambda under inlined class comprehension should stay nested" + ); + assert!( + !lambda.flags.contains(bytecode::CodeFlags::METHOD), + "CPython creates this lambda while the current symtable block is the comprehension, not the class" + ); + } + + 
#[test] + fn genexpr_implicit_iterator_is_not_posonly_like_cpython() { + let code = compile_exec("x = (i for i in ())"); + let genexpr = find_code(&code, "").expect("missing genexpr code"); + + assert_eq!(genexpr.arg_count, 1); + assert_eq!( + genexpr.posonlyarg_count, 0, + "CPython codegen_comprehension() sets u_argcount=1 and leaves u_posonlyargcount=0" + ); + } + + #[test] + fn async_generator_uses_cpython_async_generator_flag() { + let code = compile_exec_with_options( + r#" +def g(): + yield 1 + +async def c(): + return 1 + +async def ag(): + yield 1 +"#, + CompileOpts::default(), + ); + let generator = find_code(&code, "g").expect("missing generator code"); + let coroutine = find_code(&code, "c").expect("missing coroutine code"); + let async_generator = find_code(&code, "ag").expect("missing async generator code"); + + assert!(generator.flags.contains(bytecode::CodeFlags::GENERATOR)); + assert!(!generator.flags.contains(bytecode::CodeFlags::COROUTINE)); + assert!( + !generator + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + + assert!(coroutine.flags.contains(bytecode::CodeFlags::COROUTINE)); + assert!(!coroutine.flags.contains(bytecode::CodeFlags::GENERATOR)); + assert!( + !coroutine + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + + assert!( + async_generator + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + assert!( + !async_generator + .flags + .contains(bytecode::CodeFlags::GENERATOR) + ); + assert!( + !async_generator + .flags + .contains(bytecode::CodeFlags::COROUTINE) + ); + } + + #[test] + fn is_none_jump_preserves_cpython_const_order() { + let code = compile_exec_with_options( + r#" +def f(self, payload): + "doc" + if self.x is None: + self.x = [payload] + else: + raise TypeError("bad") +"#, + CompileOpts::default(), + ); + let function = find_code(&code, "f").expect("missing function code"); + assert!( + matches!( + function.constants.as_ref(), + [ + ConstantData::Str { value: doc }, + 
ConstantData::None, + ConstantData::Str { value: message }, + ] if doc.as_ref() == "doc" && message.as_ref() == "bad" + ), + "CPython registers None from the pre-folded `is None` comparison before the else-body string" + ); + } + + #[test] + fn stop_iteration_handler_starts_at_scope_start_resume_like_cpython() { + let code = compile_exec_with_options( + r#" +def g(): + yield 1 + +async def c(): + return 1 + +x = (i for i in ()) +"#, + CompileOpts::default(), + ); + + fn assert_stop_iteration_table_starts_at_resume(code: &CodeObject) { + let resume_idx = u32::try_from( + code.instructions + .iter() + .position(|unit| { + matches!( + unit.op, + Instruction::Resume { context } + if matches!( + context + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .location(), + oparg::ResumeLocation::AtFuncStart + ) + ) + }) + .expect("missing function-start RESUME"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&code.exceptiontable); + assert!( + entries.iter().any(|entry| entry.start == resume_idx), + "CPython codegen_wrap_in_stopiteration_handler() inserts SETUP_CLEANUP before RESUME so the StopIteration table starts at RESUME; resume_idx={resume_idx}, entries={entries:?}, instructions={:?}", + code.instructions + ); + } + + assert_stop_iteration_table_starts_at_resume(find_code(&code, "g").expect("missing g")); + assert_stop_iteration_table_starts_at_resume(find_code(&code, "c").expect("missing c")); + assert_stop_iteration_table_starts_at_resume( + find_code(&code, "").expect("missing genexpr"), + ); + } + + #[test] + fn inlined_comprehension_cleanup_starts_at_result_build_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(self): + return [k for k, v in self._headers] +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let build_list_idx = u32::try_from( + f.instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::BuildList { .. 
})) + .expect("missing BUILD_LIST"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + assert!( + entries.iter().any(|entry| { + entry.start == build_list_idx && entry.depth == 3 && !entry.push_lasti + }), + "CPython codegen_push_inlined_comprehension_locals() emits SETUP_FINALLY before BUILD_LIST, so the virtual cleanup table starts at BUILD_LIST with saved locals depth; build_list_idx={build_list_idx}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn or_return_not_taken_before_jump_target_splits_exception_table_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(self, maintype): + if maintype != "multipart" or not self.is_multipart(): + return + yield 1 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let not_taken_before_return = u32::try_from( + f.instructions + .windows(3) + .position(|window| { + matches!( + window, + [ + CodeUnit { + op: Instruction::NotTaken, + .. + }, + CodeUnit { + op: Instruction::LoadConst { .. }, + .. + }, + CodeUnit { + op: Instruction::ReturnValue, + .. 
+ }, + ] + ) + }) + .expect("missing NOT_TAKEN before return"), + ) + .unwrap(); + let return_load = not_taken_before_return + 1; + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries.iter().all(|entry| { + not_taken_before_return < entry.start || not_taken_before_return >= entry.end + }), + "CPython normalize_jumps() can leave a NOT_TAKEN before a separately labelled jump target outside the generator StopIteration range; entries={entries:?}, instructions={:?}", + f.instructions + ); + assert!( + entries + .iter() + .any(|entry| entry.start <= return_load && return_load < entry.end), + "the return block after that NOT_TAKEN is still protected by the StopIteration handler; entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn loop_break_condition_splits_exception_table_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(start, items): + if start: + for x in items: + if x == start: + break + yield 1 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let break_jump = u32::try_from( + f.instructions + .windows(3) + .position(|window| { + matches!( + window, + [ + CodeUnit { + op: Instruction::PopJumpIfTrue { .. }, + .. + }, + CodeUnit { + op: Instruction::Cache, + .. + }, + CodeUnit { + op: Instruction::NotTaken, + .. + }, + ] + ) || matches!( + window, + [ + CodeUnit { + op: Instruction::PopJumpIfTrue { .. }, + .. + }, + CodeUnit { + op: Instruction::NotTaken, + .. + }, + CodeUnit { + op: Instruction::JumpBackward { .. }, + .. 
+ }, + ] + ) + }) + .expect("missing loop break conditional jump"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries + .iter() + .all(|entry| break_jump < entry.start || break_jump >= entry.end), + "CPython normalize_jumps() leaves the loop-break conditional before the synthetic NOT_TAKEN/JUMP_BACKWARD block outside the StopIteration table; break_jump={break_jump}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn nested_ifexp_not_taken_splits_exception_table_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(flag, subparts): + if flag: + candidate = subparts[0] if subparts else None + yield 1 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let conditional_expr_not_taken = u32::try_from( + f.instructions + .iter() + .enumerate() + .find_map(|(idx, unit)| { + if !matches!(unit.op, Instruction::NotTaken) { + return None; + } + let prev = f.instructions[..idx] + .iter() + .rev() + .find(|unit| !matches!(unit.op, Instruction::Cache))?; + let mut following = f.instructions[idx + 1..] + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)); + let next = following.next()?; + let after_next = following.next()?; + (matches!(prev.op, Instruction::PopJumpIfFalse { .. }) + && matches!(next.op, Instruction::LoadFastBorrow { .. }) + && matches!(after_next.op, Instruction::LoadSmallInt { .. 
})) + .then_some(idx) + }) + .expect("missing conditional expression NOT_TAKEN"), + ) + .unwrap(); + let body_start = conditional_expr_not_taken + 1; + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries.iter().all(|entry| { + conditional_expr_not_taken < entry.start || conditional_expr_not_taken >= entry.end + }), + "CPython codegen_ifexp() uses a separate orelse label inside conditional statements, leaving the normalize_jumps NOT_TAKEN outside the StopIteration table; not_taken={conditional_expr_not_taken}, entries={entries:?}, instructions={:?}", + f.instructions + ); + assert!( + entries + .iter() + .any(|entry| entry.start <= body_start && body_start < entry.end), + "the conditional-expression body after that NOT_TAKEN remains protected; body_start={body_start}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn bool_not_taken_after_conditional_yield_splits_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(a, b, c): + if a: + yield 1 + if b: + x = 2 + if c: + x = 3 + yield 4 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let split_not_taken = f + .instructions + .iter() + .enumerate() + .filter_map(|(idx, unit)| { + if !matches!(unit.op, Instruction::NotTaken) { + return None; + } + let prev = f.instructions[..idx] + .iter() + .rev() + .find(|unit| !matches!(unit.op, Instruction::Cache))?; + matches!( + prev.op, + Instruction::PopJumpIfFalse { .. } | Instruction::PopJumpIfTrue { .. 
} + ) + .then(|| u32::try_from(idx).unwrap()) + }) + .nth(1) + .expect("missing second bool conditional NOT_TAKEN"); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries + .iter() + .all(|entry| split_not_taken < entry.start || split_not_taken >= entry.end), + "CPython labels exception targets before normalize_jumps(), so the general bool-jump NOT_TAKEN after a conditional yield is outside the StopIteration table; not_taken={split_not_taken}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + fn non_cache_instructions(code: &CodeObject) -> impl Iterator { + code.instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + } + + fn varname_index(code: &CodeObject, name: &str) -> usize { + code.varnames + .iter() + .position(|varname| varname.as_str() == name) + .unwrap_or_else(|| panic!("missing {name} local")) + } + + fn load_fast_ops_for_var(code: &CodeObject, name: &str) -> Vec { + let var_idx = varname_index(code, name); + non_cache_instructions(code) + .filter_map(|unit| match unit.op { + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { + let var_num = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + (usize::from(var_num) == var_idx).then_some(unit.op) + } + _ => None, + }) + .collect() + } + + fn load_fast_pair_ops_for_vars( + code: &CodeObject, + left_name: &str, + right_name: &str, + ) -> Vec { + let left_idx = varname_index(code, left_name); + let right_idx = varname_index(code, right_name); + non_cache_instructions(code) + .filter_map(|unit| { + let var_nums = match unit.op { + Instruction::LoadFastLoadFast { var_nums } + | Instruction::LoadFastBorrowLoadFastBorrow { var_nums } => var_nums, + _ => return None, + }; + let (left, right) = var_nums + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .indexes(); + (usize::from(left) == left_idx && usize::from(right) == right_idx) + .then_some(unit.op) + }) + .collect() + } + + fn 
count_strong_loads_for_vars(code: &CodeObject, names: &[&str]) -> usize { + let var_indices = names + .iter() + .map(|name| varname_index(code, name)) + .collect::>(); + non_cache_instructions(code) + .filter(|unit| match unit.op { + Instruction::LoadFast { var_num } => { + let var_num = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + var_indices.contains(&usize::from(var_num)) + } + _ => false, + }) + .count() + } + + fn count_strong_loads(code: &CodeObject) -> usize { + non_cache_instructions(code) + .filter(|unit| matches!(unit.op, Instruction::LoadFast { .. })) + .count() + } + + #[test] + fn match_or_default_block_keeps_load_fast_strong() { + let code = compile_exec( + r#" +def f(format, other): + match format: + case 1 | 2: + return other + case _: + raise NotImplementedError(other) +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let loads = load_fast_ops_for_var(function, "other"); + assert!( + matches!( + loads.as_slice(), + [ + Instruction::LoadFastBorrow { .. }, + Instruction::LoadFastBorrow { .. }, + Instruction::LoadFast { .. }, + ] + ), + "CPython optimize_load_fast() keeps trailing OR-pattern default loads strong, got {loads:?}", + ); + } + + #[test] + fn match_nested_or_default_block_keeps_load_fast_strong() { + let code = compile_exec( + r#" +def f(format, other): + match format: + case [1 | 2, value]: + return other + case _: + raise NotImplementedError(other) +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let loads = load_fast_ops_for_var(function, "other"); + assert!( + loads + .iter() + .all(|op| matches!(op, Instruction::LoadFastBorrow { .. 
})), + "CPython 3.14 optimize_load_fast() borrows nested OR-pattern default loads, got {loads:?}", + ); + } + + #[test] + fn match_success_next_location_preserves_pass_nop() { + let code = compile_exec( + r#" +def f(command): + match command: + case "": + pass + case _ as unknown: + sink(unknown) + return False +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let ops = non_cache_instructions(function) + .map(|unit| unit.op) + .collect::>(); + assert!( + ops.windows(3).any(|window| matches!( + window, + [ + Instruction::PopTop, + Instruction::Nop, + Instruction::LoadConst { .. }, + ] + )), + "CPython NEXT_LOCATION keeps the pass NOP after match subject POP_TOP, got {ops:?}", + ); + } + + #[test] + fn match_subject_copy_uses_case_pattern_location_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case 1: + return True + case 2: + return False +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let copy_line = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, _))| { + let Instruction::Copy { i } = unit.op else { + return None; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + (i.get(arg) == 1).then_some(location.line.get()) + }) + .expect("missing match subject COPY"); + assert_eq!( + copy_line, 3, + "CPython codegen_match_inner() emits ADDOP_I(c, LOC(m->pattern), COPY, 1)" + ); + } + + #[test] + fn match_or_alternative_copies_use_alternative_locations_like_cpython() { + let code = compile_exec( + "\ +def f(): + x = False + match 0: + case 0 | 1 | 2 | 3: + x = True + return x +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x08, 0x0d, 0x80, 0x41, 0xd8, 0x0a, 0x0b, 0xdf, 0x0d, 0x0e, 0x97, + 0x11, 0x97, 0x51, 0x9f, 0x11, 0x88, 0x5d, 0xe0, 0x0b, 0x0c, 0x80, 0x48, 0xf0, 0x05, + 0x00, 0x0e, 0x1b, 0xd8, 0x10, 0x14, 0x88, 0x41, 0xd8, 0x0b, 0x0c, 0x80, 0x48, + ], + "CPython 
codegen_pattern_or() emits each alternative COPY with LOC(alt)" + ); + } + + #[test] + fn match_success_jump_uses_no_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + match 0: + case 0: + x = True + case 0: + x = False + self.assertIs(x, True) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0a, 0x0b, 0xde, 0x0d, 0x0e, 0xd9, 0x10, 0x14, 0x89, 0x41, 0xdd, + 0x0d, 0x0e, 0xd8, 0x10, 0x15, 0x88, 0x41, 0xd8, 0x04, 0x08, 0x87, 0x4d, 0x81, 0x4d, + 0x90, 0x21, 0x90, 0x54, 0xd6, 0x04, 0x1a, + ], + "CPython codegen_match_inner() emits the success jump with NO_LOCATION" + ); + } + + #[test] + fn match_mapping_keys_scaffolding_uses_mapping_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + x = {} + y = None + match x: + case {0: 0}: + y = 0 + self.assertIs(y, None) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x08, 0x0a, 0x80, 0x41, 0xd8, 0x08, 0x0c, 0x80, 0x41, 0xd8, 0x0a, + 0x0b, 0xdf, 0x0d, 0x13, 0x8f, 0x56, 0x8a, 0x56, 0x95, 0x11, 0x89, 0x56, 0xd8, 0x10, + 0x11, 0x89, 0x41, 0xf2, 0x03, 0x00, 0x0e, 0x14, 0xe0, 0x04, 0x08, 0x87, 0x4d, 0x81, + 0x4d, 0x90, 0x21, 0x90, 0x54, 0xd6, 0x04, 0x1a, + ], + "CPython codegen_pattern_mapping() returns to LOC(p) for BUILD_TUPLE/MATCH_KEYS scaffolding" + ); + } + + #[test] + fn match_class_scaffolding_uses_class_pattern_location_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case bool(z): + y = 0 + return y, z +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0a, 0x0b, 0xdc, 0x0d, 0x11, 0x8f, 0x57, 0x88, 0x57, 0xd8, 0x10, + 0x11, 0x88, 0x41, 0xd8, 0x0b, 0x0c, 0x88, 0x34, 0x80, 0x4b, 0xf0, 0x05, 0x00, 0x0e, + 0x15, 0xe0, 0x0b, 0x0c, 0x88, 0x61, 0x88, 0x34, 0x80, 0x4b, + ], + "CPython codegen_pattern_class() returns to 
LOC(p) after VISIT(cls)" + ); + } + + #[test] + fn while_try_body_layout_keeps_false_jump_to_anchor() { + let code = compile_exec( + r#" +def f(stack, itstack, node_to_stack_index): + while True: + while stack: + try: + node = itstack[-1]() + break + except StopIteration: + del node_to_stack_index[stack.pop()] + itstack.pop() + else: + break +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let ops = non_cache_instructions(function) + .map(|unit| unit.op) + .collect::>(); + let stack_test = ops + .windows(5) + .find(|window| { + matches!( + window, + [ + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::ToBool, + Instruction::PopJumpIfFalse { .. }, + Instruction::NotTaken, + Instruction::Nop, + ] + ) + }) + .unwrap_or_else(|| { + panic!("expected CPython-style while/try false jump to anchor, got {ops:?}") + }); + assert!(matches!(stack_test[2], Instruction::PopJumpIfFalse { .. })); + } + + #[test] + fn while_if_not_break_keeps_body_call() { + let code = compile_exec( + r#" +def f(waiters): + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + break +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let ops = non_cache_instructions(function) + .map(|unit| unit.op) + .collect::>(); + assert!( + ops.windows(4).any(|window| matches!( + window, + [ + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. }, + Instruction::LoadConst { .. }, + Instruction::Call { .. 
}, + ] + )), + "CPython keeps waiter.set_result(None) before the break, got {ops:?}", + ); + } + + fn localsplus_name(code: &CodeObject, idx: usize) -> Option<&str> { + if idx < code.varnames.len() { + return Some(code.varnames[idx].as_str()); + } + + let mut extra_idx = idx - code.varnames.len(); + for cellvar in &code.cellvars { + if !code.varnames.iter().any(|varname| varname == cellvar) { + if extra_idx == 0 { + return Some(cellvar.as_str()); + } + extra_idx -= 1; + } + } + code.freevars.get(extra_idx).map(|name| name.as_str()) + } + + fn has_common_constant(code: &CodeObject, expected: bytecode::CommonConstant) -> bool { + code.instructions.iter().any(|unit| match unit.op { + Instruction::LoadCommonConstant { idx } => { + idx.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected + } + _ => false, + }) + } + + fn has_intrinsic_1(code: &CodeObject, expected: IntrinsicFunction1) -> bool { + code.instructions.iter().any(|unit| match unit.op { + Instruction::CallIntrinsic1 { func } => { + func.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected + } + _ => false, + }) + } + + #[test] + fn trace_assert_true_try_pair() { + let trace = compile_exec_late_cfg_trace( + "\ +try: + assert True +except AssertionError as e: + fail() +try: + assert True, 'msg' +except AssertionError as e: + fail() +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn trace_for_unpack_list_literal() { + let trace = compile_exec_late_cfg_trace( + "\ +result = [] +for x, in [(1,), (2,), (3,)]: + result.append(x) +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn trace_break_in_finally_function() { + let trace = compile_single_function_late_cfg_trace( + "\ +def f(self): + count = 0 + while count < 2: + count += 1 + try: + pass + finally: + break + self.assertEqual(count, 1) +", + "f", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + 
fn import_originated_name_disables_method_call_optimization_even_with_local_import() { + let code = compile_exec( + "\ +import warnings + +def f(ch): + import warnings + warnings.warn( + '\"\\\\%c\" is an invalid escape sequence' % ch + if 0x20 <= ch < 0x7F + else '\"\\\\x%02x\" is an invalid escape sequence' % ch, + DeprecationWarning, + stacklevel=2, + ) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f.instructions.iter().map(|unit| unit.op).collect(); + let warn_attr = ops + .iter() + .position(|op| matches!(op, Instruction::LoadAttr { .. })) + .expect("missing LOAD_ATTR for warnings.warn"); + let push_null = ops[warn_attr + 10..] + .iter() + .position(|op| matches!(op, Instruction::PushNull)) + .map(|idx| warn_attr + 10 + idx) + .expect("expected PUSH_NULL after plain LOAD_ATTR"); + + let load_attr = match f.instructions[warn_attr].op { + Instruction::LoadAttr { namei } => namei.get(OpArg::new(u32::from(u8::from( + f.instructions[warn_attr].arg, + )))), + _ => unreachable!(), + }; + assert!( + !load_attr.is_method(), + "import-originated names should use plain LOAD_ATTR" + ); + assert!( + matches!(ops[push_null + 1], Instruction::LoadSmallInt { .. 
}), + "expected warning message expression to start after PUSH_NULL, got ops={ops:?}" + ); + } + + #[test] + fn trace_constant_false_elif_chain() { + let trace = compile_exec_late_cfg_trace( + "\ +if 0: pass +elif 0: pass +elif 0: pass +elif 0: pass +else: pass +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } #[test] - fn test_trace_multi_pass_suite() { + fn trace_multi_pass_suite() { let trace = compile_exec_late_cfg_trace( "\ if 1: @@ -12995,7 +15685,7 @@ if 1: } #[test] - fn test_trace_single_compare_if() { + fn trace_single_compare_if() { let trace = compile_exec_late_cfg_trace( "\ if 1 == 1: @@ -13008,7 +15698,7 @@ if 1 == 1: } #[test] - fn test_trace_comparison_suite() { + fn trace_comparison_suite() { let trace = compile_exec_late_cfg_trace( "\ if 1: pass @@ -13031,7 +15721,7 @@ if 1 not in (): pass } #[test] - fn test_trace_if_for_except_layout() { + fn trace_if_for_except_layout() { let trace = compile_exec_late_cfg_trace( "\ from sys import maxsize @@ -13051,7 +15741,7 @@ elif maxsize == 9223372036854775807: } #[test] - fn test_break_in_finally_tail_loads_borrow_through_empty_fallthrough_block() { + fn break_in_finally_tail_loads_borrow_through_empty_fallthrough_block() { let code = compile_exec( "\ def f(self): @@ -13094,7 +15784,7 @@ def f(self): } #[test] - fn test_plain_constant_bool_op_folds_to_selected_operand() { + fn plain_constant_bool_op_folds_to_selected_operand() { let code = compile_exec( "\ x = 1 or 2 or 3 @@ -13143,7 +15833,598 @@ x = 1 or 2 or 3 } #[test] - fn test_starred_call_preserves_bool_op_short_circuit_shape() { + fn taken_constant_boolop_load_const_uses_literal_location_like_cpython() { + let code = compile_exec( + "\ +def and_false(x): + return False and x + +def or_true(x): + return True or x +", + ); + let and_false = find_code(&code, "and_false").expect("missing and_false code"); + let or_true = find_code(&code, "or_true").expect("missing or_true code"); + + // CPython 3.14 
codegen_boolop() VISITs the selected literal before the + // short-circuit jump is optimized away, so the surviving LOAD_CONST + // keeps the literal range rather than the whole BoolOp range. + assert_eq!( + and_false.linetable.as_ref(), + &[0x80, 0x00, 0xd8, 0x0b, 0x10, 0xd0, 0x04, 0x16] + ); + assert_eq!( + or_true.linetable.as_ref(), + &[0x80, 0x00, 0xd8, 0x0b, 0x0f, 0xd0, 0x04, 0x14] + ); + } + + #[test] + fn assert_false_message_call_uses_assert_location_like_cpython() { + let code = compile_exec( + "\ +def f(): + assert False, \"x\" +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 codegen_assert() emits LOAD_COMMON_CONSTANT and CALL + // at LOC(assert statement), then RAISE_VARARGS at LOC(test). + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x04, 0x15, 0x90, 0x23, 0xd3, 0x04, 0x15, 0x88, 0x35, + ] + ); + } + + #[test] + fn static_swap_implicit_return_keeps_preswap_store_location_like_cpython() { + let code = compile_exec( + "\ +def f(a, b): + a, b = a, b + b, a = a, b +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 flowgraph.c resolves line numbers before + // optimize_basic_block() turns BUILD_TUPLE/UNPACK_SEQUENCE into SWAP + // and apply_static_swaps() reorders the STORE_FAST pair. The + // synthetic return epilogue therefore keeps the pre-swap final store + // location. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0b, 0x0c, 0x80, 0x71, 0xd8, 0x0b, 0x0c, 0x82, 0x71, + ] + ); + } + + #[test] + fn unpack_store_pair_jump_uses_second_target_location_like_cpython() { + let code = compile_exec( + "\ +def f(value): + if value.startswith('=?'): + try: + token, value = get_encoded_word(value) + except E: + token, value = get_atext(value) + else: + token, value = get_atext(value) + atom.append(token) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let jump_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::JumpForward { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing post-try JUMP_FORWARD"); + + // CPython 3.14 flowgraph.c turns the second STORE_FAST into a NOP + // during STORE_FAST_STORE_FAST fusion, then NOP removal copies that + // second target location onto the following no-location jump. + assert_eq!(jump_position, (4, 20, 4, 25)); + } + + #[test] + fn chained_store_pair_jump_keeps_copy_target_location_like_cpython() { + let code = compile_exec( + "\ +def f(flag): + if flag: + a = b = True + else: + a = False + b = False + g(a, b) + return a +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let jump_position = f + .instructions + .windows(2) + .zip(f.locations.windows(2)) + .find_map(|(units, locations)| { + matches!(units[0].op, Instruction::StoreFastStoreFast { .. }) + .then(|| { + matches!(units[1].op, Instruction::JumpForward { .. 
}).then_some(( + locations[1].0.line.get(), + locations[1].0.character_offset.get(), + locations[1].1.line.get(), + locations[1].1.character_offset.get(), + )) + }) + .flatten() + }) + .expect("missing jump after chained STORE_FAST_STORE_FAST"); + + // CPython 3.14 flowgraph.c preserves the second chained-assignment + // target location on the jump that skips the else body. + assert_eq!(jump_position, (3, 13, 3, 14)); + } + + #[test] + fn tuple_store_pair_jump_keeps_fused_store_location_like_cpython() { + let code = compile_exec( + "\ +def f(flag, n, exp): + if flag: + n, d = n * 10**exp, 1 + else: + d = -exp + g(n, d) + return n +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let jump_position = f + .instructions + .windows(2) + .zip(f.locations.windows(2)) + .find_map(|(units, locations)| { + matches!(units[0].op, Instruction::StoreFastStoreFast { .. }) + .then(|| { + matches!(units[1].op, Instruction::JumpForward { .. }).then_some(( + locations[1].0.line.get(), + locations[1].0.character_offset.get(), + locations[1].1.line.get(), + locations[1].1.character_offset.get(), + )) + }) + .flatten() + }) + .expect("missing jump after tuple STORE_FAST_STORE_FAST"); + + // Without COPY before the fused stores, CPython keeps the fused + // STORE_FAST_STORE_FAST location on the following jump. + assert_eq!(jump_position, (3, 12, 3, 13)); + } + + #[test] + fn genexpr_make_closure_and_call_use_genexpr_location_like_cpython() { + let code = compile_exec( + "\ +def f(parameters): + return ((p, type(p)) for p in parameters) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let genexpr = find_code(f, "").expect("missing genexpr code"); + + // CPython 3.14 codegen_comprehension() uses LOC(e) for + // codegen_make_closure(), the outer CALL, and the implicit .0 load + // in codegen_sync_comprehension_generator(). 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd9, 0x0b, 0x2d, 0xa1, 0x2a, 0xd3, 0x0b, 0x2d, 0xd0, 0x04, 0x2d, + ] + ); + assert_eq!( + genexpr.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0b, 0x2d, 0xa1, 0x2a, 0x98, 0x51, 0x94, 0x04, 0x90, + 0x51, 0x93, 0x07, 0x8d, 0x4c, 0xa3, 0x2a, 0xf9, + ] + ); + } + + #[test] + fn implicit_call_genexpr_range_includes_call_parens_like_cpython() { + let code = compile_exec( + "\ +def implicit(): + return list(x for x in range(10)) + +def explicit(): + return list((x for x in range(10))) +", + ); + let implicit = find_code(&code, "implicit").expect("missing implicit code"); + let implicit_gen = find_code(implicit, "").expect("missing implicit genexpr code"); + let explicit = find_code(&code, "explicit").expect("missing explicit code"); + let explicit_gen = find_code(explicit, "").expect("missing explicit genexpr code"); + + // CPython's parser gives an unparenthesized sole GeneratorExp call + // argument the call-parenthesized range, and codegen_comprehension() + // uses LOC(e) for MAKE_FUNCTION, the outer CALL, and the implicit .0 + // LOAD_FAST. Explicitly parenthesized genexprs already carry their own + // parentheses and must not be widened again. 
+ assert_eq!( + implicit.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x0f, 0x25, 0x9c, 0x35, 0xa0, 0x12, 0x9c, 0x39, + 0xd3, 0x0f, 0x25, 0xd3, 0x0b, 0x25, 0xd0, 0x04, 0x25, + ] + ); + assert_eq!( + implicit_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0f, 0x25, 0x99, 0x39, 0x90, 0x61, 0x94, 0x01, 0x9b, + 0x39, 0xf9, + ] + ); + assert_eq!( + explicit_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x10, 0x26, 0x99, 0x49, 0x90, 0x71, 0x94, 0x11, 0x9b, + 0x49, 0xf9, + ] + ); + } + + #[test] + fn implicit_call_genexpr_parenthesized_element_range_like_cpython() { + let code = compile_exec( + "\ +def bytes_binop(): + return bytes((x ^ 0x5C) for x in range(256)) + +def dict_tuple(d): + return dict((v, k) for (k, v) in d.items()) + +def plain_tuple_elt(xs): + return list((x, y) for x, y in xs) + +def explicit_gen(xs): + return list(((x, y) for x, y in xs)) +", + ); + let bytes_binop = find_code(&code, "bytes_binop").expect("missing bytes_binop code"); + let bytes_gen = find_code(bytes_binop, "").expect("missing bytes genexpr code"); + let dict_tuple = find_code(&code, "dict_tuple").expect("missing dict_tuple code"); + let dict_gen = find_code(dict_tuple, "").expect("missing dict genexpr code"); + let plain_tuple_elt = + find_code(&code, "plain_tuple_elt").expect("missing plain_tuple_elt code"); + let plain_gen = + find_code(plain_tuple_elt, "").expect("missing plain genexpr code"); + let explicit_gen = find_code(&code, "explicit_gen").expect("missing explicit_gen code"); + let explicit_inner = + find_code(explicit_gen, "").expect("missing explicit genexpr code"); + + // CPython 3.14's parser includes the call argument parentheses in + // LOC(GeneratorExp) for implicit sole-argument generator expressions, + // even when the element expression itself starts with parentheses. 
+ assert_eq!( + bytes_binop.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x10, 0xd1, 0x10, 0x30, 0xa4, 0x55, 0xa8, 0x33, 0xa4, 0x5a, + 0xd3, 0x10, 0x30, 0xd3, 0x0b, 0x30, 0xd0, 0x04, 0x30, + ] + ); + assert_eq!( + bytes_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x10, 0x30, 0xa1, 0x5a, 0xa0, 0x01, 0x90, 0x64, 0x97, + 0x28, 0x92, 0x28, 0xa3, 0x5a, 0xf9, + ] + ); + assert_eq!( + dict_tuple.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x0f, 0x2f, 0xa0, 0x51, 0xa7, 0x57, 0xa1, 0x57, + 0xa4, 0x59, 0xd3, 0x0f, 0x2f, 0xd3, 0x0b, 0x2f, 0xd0, 0x04, 0x2f, + ] + ); + assert_eq!( + dict_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0f, 0x2f, 0xa1, 0x59, 0x99, 0x36, 0x98, 0x41, 0x90, + 0x11, 0x95, 0x06, 0xa3, 0x59, 0xf9, + ] + ); + assert_eq!( + plain_tuple_elt.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x0f, 0x26, 0xa1, 0x32, 0xd3, 0x0f, 0x26, 0xd3, + 0x0b, 0x26, 0xd0, 0x04, 0x26, + ] + ); + assert_eq!( + plain_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0f, 0x26, 0xa1, 0x32, 0x99, 0x34, 0x98, 0x31, 0x90, + 0x11, 0x95, 0x06, 0xa3, 0x32, 0xf9, + ] + ); + assert_eq!( + explicit_gen.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x10, 0x27, 0xa1, 0x42, 0xd3, 0x10, 0x27, 0xd3, + 0x0b, 0x28, 0xd0, 0x04, 0x28, + ] + ); + assert_eq!( + explicit_inner.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x10, 0x27, 0xa1, 0x42, 0x99, 0x44, 0x98, 0x41, 0x90, + 0x21, 0x95, 0x16, 0xa3, 0x42, 0xf9, + ] + ); + } + + #[test] + fn genexpr_filter_cleanup_jumps_use_element_location_like_cpython() { + let code = compile_exec( + "\ +def simple(names): + return (x for x in names if not _ishidden(x)) + +def boolop(fields): + return (f for f in fields if f.init and not f.kw_only) +", + ); + let simple = find_code(&code, "simple").expect("missing simple code"); + let simple_gen = find_code(simple, "").expect("missing simple genexpr code"); + let boolop = find_code(&code, 
"boolop").expect("missing boolop code"); + let boolop_gen = find_code(boolop, "").expect("missing boolop genexpr code"); + + // CPython 3.14 codegen_sync_comprehension_generator() emits the + // comprehension guard jump to if_cleanup, then emits the if_cleanup + // backedge with elt_loc. flowgraph.c::jump_thread() copies that target + // jump location to the threaded POP_JUMP/NOT_TAKEN cleanup path. + assert_eq!( + simple_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0b, 0x31, 0x91, 0x75, 0x90, 0x21, 0xa4, 0x49, 0xa8, + 0x61, 0xa7, 0x4c, 0x8f, 0x41, 0x8a, 0x41, 0x93, 0x75, 0xf9, + ] + ); + assert_eq!( + boolop_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0b, 0x3a, 0x91, 0x76, 0x90, 0x21, 0xa7, 0x16, 0xa5, + 0x16, 0x8c, 0x41, 0xb0, 0x01, 0xb7, 0x09, 0xb5, 0x09, 0x8f, 0x41, 0x8a, 0x41, 0x93, + 0x76, 0xf9, + ] + ); + } + + #[test] + fn try_finally_exception_scaffolding_uses_no_location_like_cpython() { + let code = compile_exec( + "\ +def f(self, node): + self.flag = True + try: + self.body(node) + finally: + self.flag = False +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 codegen_try_finally() emits the exception path + // SETUP_CLEANUP/PUSH_EXC_INFO and POP_EXCEPT_AND_RERAISE with + // NO_LOCATION; flowgraph line propagation then gives only the + // finalbody's direct RERAISE the finalbody location. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x10, 0x14, 0x80, 0x44, 0x84, 0x49, 0xf0, 0x02, 0x03, 0x05, 0x1a, + 0xd8, 0x08, 0x0c, 0x8f, 0x09, 0x89, 0x09, 0x90, 0x24, 0x8c, 0x0f, 0xe0, 0x14, 0x19, + 0x88, 0x04, 0x8e, 0x09, 0xf8, 0x90, 0x45, 0x88, 0x04, 0x8d, 0x09, 0xfa, + ] + ); + } + + #[test] + fn adjacent_no_location_entries_merge_like_cpython() { + let code = compile_exec( + "\ +def f(file): + if sys.platform == \"win32\": + try: + import nt + if not nt._supports_virtual_terminal(): + return False + except (ImportError, AttributeError): + return False + try: + return os.isatty(file.fileno()) + except OSError: + return hasattr(file, \"isatty\") and file.isatty() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython's NO_LOCATION is {-1, -1, -1, -1}, and + // assemble.c::assemble_location_info() merges adjacent instructions + // with the same NO_LOCATION into one linetable entry. + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x07, 0x0a, 0x87, 0x7c, 0x81, 0x7c, 0x90, 0x77, 0xd4, 0x07, 0x1e, + 0xf0, 0x02, 0x05, 0x09, 0x19, 0xdb, 0x0c, 0x15, 0xd8, 0x13, 0x15, 0xd7, 0x13, 0x30, + 0xd1, 0x13, 0x30, 0xd7, 0x13, 0x32, 0xd2, 0x13, 0x32, 0xd9, 0x17, 0x1c, 0xf0, 0x03, + 0x00, 0x14, 0x33, 0xf0, 0x08, 0x03, 0x05, 0x39, 0xdc, 0x0f, 0x11, 0x8f, 0x79, 0x89, + 0x79, 0x98, 0x14, 0x9f, 0x1b, 0x99, 0x1b, 0x9b, 0x1d, 0xd3, 0x0f, 0x27, 0xd0, 0x08, + 0x27, 0xf8, 0xf4, 0x07, 0x00, 0x11, 0x1c, 0x9c, 0x5e, 0xd0, 0x0f, 0x2c, 0xf4, 0x00, + 0x01, 0x09, 0x19, 0xda, 0x13, 0x18, 0xf0, 0x03, 0x01, 0x09, 0x19, 0xfb, 0xf4, 0x08, + 0x00, 0x0c, 0x13, 0xf4, 0x00, 0x01, 0x05, 0x39, 0xdc, 0x0f, 0x16, 0x90, 0x74, 0x98, + 0x58, 0xd3, 0x0f, 0x26, 0xd7, 0x0f, 0x38, 0xd0, 0x0f, 0x38, 0xa8, 0x34, 0xaf, 0x3b, + 0xa9, 0x3b, 0xab, 0x3d, 0xd2, 0x08, 0x38, 0xf0, 0x03, 0x01, 0x05, 0x39, 0xfa, + ] + ); + } + + #[test] + fn fstring_format_ops_use_formatted_value_location_like_cpython() { + let code = compile_exec( + "\ +def simple(self): + return 
f'{self.value}' + +def spec(x): + return f'{x!r:>3}' +", + ); + let simple = find_code(&code, "simple").expect("missing simple code"); + let spec = find_code(&code, "spec").expect("missing spec code"); + + // CPython 3.14 codegen_formatted_value() VISITs the inner expression + // first, then emits CONVERT_VALUE / FORMAT_* at LOC(FormattedValue). + assert_eq!( + simple.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0e, 0x12, 0x8f, 0x6a, 0x89, 0x6a, 0x88, 0x5c, 0xd0, 0x04, 0x1a, + ] + ); + assert_eq!( + spec.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0e, 0x0f, 0x88, 0x58, 0x90, 0x22, 0x88, 0x58, 0xd0, 0x04, 0x16, + ] + ); + } + + #[test] + fn debug_fstring_literal_location_like_cpython() { + fn string_load_position(code: &CodeObject, expected: &str) -> (usize, usize, usize, usize) { + code.instructions + .iter() + .zip(&code.locations) + .find_map(|(unit, (location, end_location))| { + let Instruction::LoadConst { consti } = unit.op else { + return None; + }; + let constant = + &code.constants[consti.get(OpArg::new(u32::from(u8::from(unit.arg))))]; + matches!(constant, ConstantData::Str { value } if value.to_string() == expected) + .then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing debug f-string literal") + } + + let code = compile_exec( + "\ +def simple(x): + return f'{x=}' + +def prefixed(x): + return f'a {x=} b' +", + ); + let simple = find_code(&code, "simple").expect("missing simple code"); + let prefixed = find_code(&code, "prefixed").expect("missing prefixed code"); + + assert_eq!( + string_load_position(simple, "x="), + (2, 15, 2, 17), + "CPython represents f'{{x=}}' debug text as a literal at the expression/debug-text location" + ); + assert_eq!( + string_load_position(prefixed, "a x="), + (5, 14, 5, 19), + "CPython extends a pending f-string literal through the debug text range" + ); + } + + #[test] + fn 
fstring_format_spec_build_string_location_like_cpython() { + let code = compile_exec( + "\ +def simple(lbl, label_width): + return f'{lbl:>{label_width}}' + +def padded(digits, int_len): + return f'{digits:0>{int_len + 1}d}' +", + ); + let simple = find_code(&code, "simple").expect("missing simple code"); + let padded = find_code(&code, "padded").expect("missing padded code"); + + let build_string_position = |code: &CodeObject| { + code.instructions + .iter() + .zip(&code.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::BuildString { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing format-spec BUILD_STRING") + }; + + assert_eq!( + build_string_position(simple), + (2, 18, 2, 33), + "CPython uses the format-spec JoinedStr location, including the ':' prefix, for BUILD_STRING" + ); + assert_eq!( + build_string_position(padded), + (5, 21, 5, 38), + "CPython format-spec JoinedStr location spans from ':' through the final literal" + ); + } + + #[test] + fn joined_string_literals_extend_pending_literal_location_like_cpython() { + let code = compile_exec( + "\ +def f(a): + return ( + 'x' + f'y{a}z' + 'w' + ) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xf0, 0x04, 0x01, 0x09, 0x0c, 0xd8, 0x0c, 0x0d, 0x88, 0x33, 0xf0, 0x00, + 0x01, 0x0f, 0x0c, 0xf0, 0x03, 0x02, 0x09, 0x0c, 0xf0, 0x03, 0x04, 0x05, 0x06, + ], + "CPython parser/codegen represents adjacent f-string literal fragments as Constant ranges spanning the merged fragments" + ); + } + + #[test] + fn starred_call_preserves_bool_op_short_circuit_shape() { let code = compile_exec( "\ def f(g): @@ -13174,7 +16455,7 @@ def f(g): } #[test] - fn test_partial_constant_bool_op_folds_prefix_in_value_context() { + fn partial_constant_bool_op_folds_prefix_in_value_context() { let code = 
compile_exec( "\ def outer(null): @@ -13215,7 +16496,61 @@ def outer(null): } #[test] - fn test_taken_constant_boolop_jump_disables_following_borrows() { + fn decorated_definitions_use_cpython_locations() { + let code = compile_exec( + "\ +def dec(f): return f + +class C: + @dec + def f(self): + yield + +@dec +class D: + pass + +class E: + @dec + def g(self, flags: int, /) -> memoryview: + raise NotImplementedError +", + ); + let c = find_code(&code, "C").expect("missing C code"); + let d = find_code(&code, "D").expect("missing D code"); + let e = find_code(&code, "E").expect("missing E code"); + let annotate = find_code(e, "__annotate__").expect("missing annotation code"); + + // CPython 3.14 codegen_function()/codegen_class() evaluate + // decorators first, then use LOC(s) for codegen_make_closure() and + // codegen_nameop(); codegen_apply_decorators() emits CALL at each + // decorator expression's location. + assert_eq!( + c.linetable.as_ref(), + &[ + 0xf8, 0x87, 0x00, 0x80, 0x00, 0xd8, 0x05, 0x08, 0xf1, 0x02, 0x01, 0x05, 0x0e, 0xf3, + 0x03, 0x00, 0x06, 0x09, 0xf6, 0x02, 0x01, 0x05, 0x0e, + ] + ); + assert_eq!(d.linetable.as_ref(), &[0x86, 0x00, 0xe3, 0x04, 0x08]); + assert_eq!( + e.linetable.as_ref(), + &[ + 0xf8, 0x87, 0x00, 0x80, 0x00, 0xd8, 0x05, 0x08, 0xf7, 0x02, 0x01, 0x05, 0x22, 0xf3, + 0x03, 0x00, 0x06, 0x09, 0xf6, 0x02, 0x01, 0x05, 0x22, + ] + ); + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0xf7, 0x00, 0x01, 0x05, 0x22, 0xf1, 0x00, 0x01, 0x05, 0x22, 0x91, + 0x73, 0xf0, 0x00, 0x01, 0x05, 0x22, 0xa1, 0x2a, 0xf1, 0x00, 0x01, 0x05, 0x22, + ] + ); + } + + #[test] + fn taken_constant_boolop_jump_disables_following_borrows() { for source in [ "\ def f(self): @@ -13254,7 +16589,7 @@ def f(self): } #[test] - fn test_untaken_constant_boolop_jump_keeps_following_borrows() { + fn not_taken_constant_boolop_jump_keeps_following_borrows() { for source in [ "\ def f(self): @@ -13286,7 +16621,226 @@ def f(self): } #[test] - fn 
test_nonliteral_constant_bool_op_preserves_short_circuit_shape() { + fn while_before_folded_boolop_if_keeps_successor_load_fast_strong() { + let code = compile_exec( + "\ +def f(running, errors): + try: + while running: + pass + if __debug__ and errors: + raise ExceptionGroup('x', errors) + return errors + finally: + del errors +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let errors_loads = load_fast_ops_for_var(f, "errors"); + assert!( + errors_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython codegen_while() emits USE_LABEL(end), then codegen_jump_if() emits a folded constant BoolOp prefix; flowgraph.c::basicblock_optimize_load_const() and optimize_load_fast() leave the successor errors loads strong, got {errors_loads:?}" + ); + } + + #[test] + fn with_try_except_tail_keeps_successor_load_fast_strong() { + let code = compile_exec( + "\ +def f(self, cm, value): + with cm: + try: + self.run() + except OSError as e: + if e.errno: + raise ConnectionError + else: + raise + self.run_loop(value.done) + self.assertTrue(value.nbytes) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let value_loads = load_fast_ops_for_var(f, "value"); + assert!( + value_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython codegen_try_except() leaves USE_LABEL(end) before codegen_with_inner() emits normal __exit__ cleanup, so optimize_load_fast() leaves with-successor loads strong; got {value_loads:?}" + ); + } + + #[test] + fn with_try_except_conditional_body_allows_successor_borrow() { + let code = compile_exec( + "\ +def f(max_decode, gzf, cm, decoded): + with cm: + try: + if max_decode < 0: + decoded = gzf.read() + else: + decoded = gzf.read(max_decode + 1) + except OSError: + raise ValueError('invalid data') + if max_decode >= 0 and len(decoded) > max_decode: + raise ValueError('too large') + return decoded +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let max_decode_loads = load_fast_ops_for_var(f, "max_decode"); + assert!( + max_decode_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFastBorrow { .. })), + "CPython codegen_try_except() keeps borrowing through a with tail when the protected try body ends with a conditional label, got {max_decode_loads:?}" + ); + } + + #[test] + fn loop_try_orelse_nested_try_before_next_try_keeps_load_fast_strong() { + let code = compile_exec( + "\ +def f(test_cases): + for src, dest in test_cases: + try: + os.symlink(src, dest) + except FileNotFoundError: + pass + else: + try: + os.remove(dest) + except OSError: + pass + try: + os.symlink(os.fsencode(src), os.fsencode(dest)) + except FileNotFoundError: + pass + else: + try: + os.remove(dest) + except OSError: + pass +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let strong_src_dest_loads = count_strong_loads_for_vars(f, &["src", "dest"]); + assert!( + strong_src_dest_loads >= 2, + "CPython codegen_try_except() emits orelse then USE_LABEL(end) before the following loop try, so optimize_load_fast() leaves fsencode arguments strong; got {strong_src_dest_loads} strong src/dest loads" + ); + } + + #[test] + fn try_orelse_with_before_next_try_keeps_load_fast_strong() { + let code = compile_exec( + "\ +def f(self, f, cm): + 
try: + f = C(0) + except ValueError: + pass + else: + self.assertTrue(f.readable()) + with cm: + with C(False): + pass + try: + f = C(1) + except ValueError: + pass + else: + self.assertFalse(f.readable()) +", + ); + let f_code = find_code(&code, "f").expect("missing function code"); + let strong_self_or_f_loads = count_strong_loads_for_vars(f_code, &["self", "f"]); + assert!( + strong_self_or_f_loads >= 2, + "CPython codegen_try_except() emits orelse with nested with cleanup before USE_LABEL(end), so optimize_load_fast() leaves loads in the following try/else strong; got {strong_self_or_f_loads} strong self/f loads" + ); + } + + #[test] + fn try_orelse_single_with_before_next_try_keeps_borrows() { + let code = compile_exec( + "\ +def f(self, cm): + try: + import _testcapi + except ImportError: + pass + else: + code = 'x' + with cm: + out = self.run_xdev('-c', code) + self.assertEqual(out, 'x') + try: + import faulthandler + except ImportError: + pass + else: + code = 'y' + out = self.run_xdev('-c', code) + self.assertEqual(out, 'y') +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let final_return = instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::ReturnValue)) + .expect("missing return"); + let tail = &instructions[..final_return]; + let borrowed_self_loads = tail + .iter() + .filter(|unit| match unit.op { + Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.varnames[usize::from(var_num.get(arg))] == "self" + } + _ => false, + }) + .count(); + assert!( + borrowed_self_loads >= 4, + "CPython codegen_with() for a single with in try/else does not make the following try/else a load-fast barrier; expected borrowed self loads, got instructions={instructions:?}" + ); + } + + #[test] + fn 
with_try_finally_nested_with_keeps_successor_load_fast_strong() { + let code = compile_exec( + "\ +def f(self, cm): + with cm as cm1: + try: + work() + finally: + with cm as cm2: + work() + e1 = cm1.exception + e12 = e1.__cause__ + self.assertIsInstance(e12, Error) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let strong_loads = count_strong_loads(f); + assert!( + strong_loads >= 3, + "CPython codegen_try_finally() inlines a finalbody with codegen_with_inner() cleanup before the with exit label, so optimize_load_fast() leaves successor loads strong; got {strong_loads} strong loads" + ); + } + + #[test] + fn nonliteral_constant_bool_op_preserves_short_circuit_shape() { let code = compile_exec( "\ x = (\"a\"[0]) or 2 @@ -13324,7 +16878,7 @@ x = (\"a\"[0]) or 2 } #[test] - fn test_unary_positive_complex_constant_folds_to_load_const() { + fn unary_positive_complex_constant_folds_to_load_const() { let code = compile_exec( "\ x = +0.0j @@ -13358,7 +16912,7 @@ x = +0.0j } #[test] - fn test_folded_nonliteral_bool_op_tail_keeps_plain_load_fast() { + fn folded_nonliteral_bool_op_tail_keeps_plain_load_fast() { let code = compile_exec( "\ def and_true(x): @@ -13397,7 +16951,57 @@ def or_false(x): } #[test] - fn test_folded_nonliteral_tuple_unpack_tail_keeps_plain_load_fast() { + fn folded_nonliteral_bool_op_direct_tail_load_keeps_plain_load_fast() { + let code = compile_exec( + "\ +def return_tail(x): + return False or x + +def assign_tail(x): + y = False or x + return y + +def call_arg(x, g): + return g(False or x) + +def attr_tail(x): + return False or x.y + +def class_tail(class_decorator): + @False or class_decorator + class H: + pass +", + ); + + for name in ["return_tail", "assign_tail", "call_arg", "class_tail"] { + let function = find_code(&code, name).unwrap_or_else(|| panic!("missing {name} code")); + let local = if name == "class_tail" { + "class_decorator" + } else { + "x" + }; + let loads = load_fast_ops_for_var(function, local); + 
assert!( + loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. })), + "CPython codegen_boolop() emits USE_LABEL(end) after the folded BoolOp tail, so optimize_load_fast() leaves the direct tail load strong in {name}; got {loads:?}" + ); + } + + let attr_tail = find_code(&code, "attr_tail").expect("missing attr_tail code"); + let attr_receiver_loads = load_fast_ops_for_var(attr_tail, "x"); + assert!( + attr_receiver_loads + .iter() + .all(|op| matches!(op, Instruction::LoadFastBorrow { .. })), + "CPython only keeps direct folded tail local loads strong; an attribute receiver is consumed before the BoolOp end label, got {attr_receiver_loads:?}" + ); + } + + #[test] + fn folded_nonliteral_tuple_unpack_tail_keeps_plain_load_fast() { let code = compile_exec( "\ def f(self, mod): @@ -13438,7 +17042,7 @@ def f(self, mod): } #[test] - fn test_scope_exit_instructions_keep_line_numbers() { + fn scope_exit_instructions_keep_line_numbers() { let code = compile_exec( "\ async def test(): @@ -13457,7 +17061,7 @@ async def test(): } #[test] - fn test_attribute_ex_call_uses_plain_load_attr() { + fn attribute_ex_call_uses_plain_load_attr() { let code = compile_exec( "\ def f(cls, args, kwargs): @@ -13493,7 +17097,7 @@ def f(cls, args, kwargs): } #[test] - fn test_large_plain_call_uses_direct_call_until_stack_guideline() { + fn large_plain_call_uses_direct_call_until_stack_guideline() { let code = compile_exec( "\ def f(g): @@ -13525,7 +17129,49 @@ def f(g): } #[test] - fn test_simple_attribute_call_keeps_method_load() { + fn too_large_plain_call_uses_cpython_tuple_ex_call_path() { + let args = (0..=STACK_USE_GUIDELINE) + .map(|i| format!("'v{i}'")) + .collect::>() + .join(", "); + let code = compile_exec(&format!("def f(g):\n return g({args})\n")); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + 
ops.iter() + .any(|op| matches!(op, Instruction::CallFunctionEx)), + "CPython routes calls over _PY_STACK_USE_GUIDELINE through CALL_FUNCTION_EX, got ops={ops:?}" + ); + assert!( + !ops.iter().any(|op| matches!( + op, + Instruction::BuildTuple { .. } + | Instruction::ListAppend { .. } + | Instruction::CallIntrinsic1 { .. } + )), + "CPython flowgraph.c folds the starunpack tuple for constant too-large calls, got ops={ops:?}" + ); + assert!( + f.constants.iter().any(|constant| { + matches!( + constant, + ConstantData::Tuple { elements } + if elements.len() == usize::try_from(STACK_USE_GUIDELINE + 1).unwrap() + ) + }), + "expected CPython folded tuple constant for too-large call args, got constants={:?}", + f.constants.iter().collect::>() + ); + } + + #[test] + fn simple_attribute_call_keeps_method_load() { let code = compile_exec( "\ def f(obj, arg): @@ -13551,7 +17197,7 @@ def f(obj, arg): } #[test] - fn test_starred_super_call_keeps_attr_line_nop() { + fn starred_super_call_keeps_attr_line_nop() { let code = compile_exec( "\ def outer(log): @@ -13587,7 +17233,7 @@ def outer(log): } #[test] - fn test_builtin_any_genexpr_call_is_optimized() { + fn builtin_any_genexpr_call_is_optimized() { let code = compile_exec( "\ def f(xs): @@ -13615,10 +17261,67 @@ def f(xs): 1, "fallback call path should remain for shadowed any()" ); + let genexpr_const_count = f + .constants + .iter() + .filter(|constant| { + matches!(constant, ConstantData::Code { code } if code.obj_name == "") + }) + .count(); + assert_eq!( + genexpr_const_count, 1, + "optimized and fallback any(genexpr) paths should share the same CPython-range code const" + ); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdf, 0x0b, 0x0e, 0x8b, 0x33, 0x89, 0x6f, 0x99, 0x22, 0x8b, 0x6f, 0x8f, + 0x33, 0x8c, 0x33, 0xd0, 0x04, 0x1d, 0x8a, 0x33, 0xd0, 0x04, 0x1d, 0x88, 0x33, 0x89, + 0x6f, 0x99, 0x22, 0x8b, 0x6f, 0xd3, 0x0b, 0x1d, 0xd0, 0x04, 0x1d, + ] + ); } #[test] - fn 
test_builtin_tuple_genexpr_call_is_optimized_but_list_set_are_not() { + fn builtin_any_async_genexpr_call_is_not_optimized() { + let code = compile_exec( + "\ +async def f(xs): + return any(x async for x in xs) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + + assert!( + !has_common_constant(f, bytecode::CommonConstant::BuiltinAny), + "CPython maybe_optimize_function_call() skips coroutine generator expressions" + ); + assert!( + f.instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::Call { .. })), + "async genexpr any() should stay on the normal call path" + ); + } + + #[test] + fn builtin_any_genexpr_outermost_await_is_optimized_like_cpython() { + let code = compile_exec( + "\ +async def f(get_xs): + return any(x for x in await get_xs()) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + + assert!( + has_common_constant(f, bytecode::CommonConstant::BuiltinAny), + "CPython checks the generator expression symtable entry, so await in the outermost iterator does not make the genexpr coroutine" + ); + } + + #[test] + fn builtin_tuple_genexpr_call_is_optimized_but_list_set_are_not() { let code = compile_exec( "\ def tuple_f(xs): @@ -13647,6 +17350,14 @@ def set_f(xs): }) .expect("tuple(genexpr) fast path should emit LIST_APPEND"); assert_eq!(tuple_list_append, 2); + assert_eq!( + tuple_f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdf, 0x0b, 0x10, 0x8c, 0x35, 0x91, 0x0f, 0x99, 0x42, 0x93, 0x0f, 0x8f, + 0x35, 0xd0, 0x04, 0x1f, 0x88, 0x35, 0x91, 0x0f, 0x99, 0x42, 0x93, 0x0f, 0xd3, 0x0b, + 0x1f, 0xd0, 0x04, 0x1f, + ] + ); let list_f = find_code(&code, "list_f").expect("missing list_f code"); assert!( @@ -13676,7 +17387,7 @@ def set_f(xs): } #[test] - fn test_builtin_tuple_genexpr_try_assignment_uses_shared_tail() { + fn builtin_tuple_genexpr_try_assignment_uses_shared_tail() { let code = compile_exec( "\ def f(xs): @@ -13718,7 +17429,7 @@ def f(xs): } #[test] - fn 
test_builtin_tuple_genexpr_unprotected_assignment_return_duplicates_tail() { + fn builtin_tuple_genexpr_unprotected_assignment_return_duplicates_tail() { let code = compile_exec( "\ def f(arg): @@ -13753,7 +17464,7 @@ def f(arg): } #[test] - fn test_unprotected_builtin_any_prefix_before_returning_try_keeps_borrow() { + fn unprotected_builtin_any_prefix_before_returning_try_keeps_borrow() { let code = compile_exec( "\ def f(template): @@ -13805,7 +17516,7 @@ def f(template): } #[test] - fn test_module_store_uses_store_global_when_nested_scope_declares_global() { + fn module_store_uses_store_global_when_nested_scope_declares_global() { let code = compile_exec( "\ _address_fmt_re = None @@ -13828,7 +17539,7 @@ class C: } #[test] - fn test_conditional_return_epilogue_is_duplicated() { + fn conditional_return_epilogue_is_duplicated() { let code = compile_exec( "\ def f(base, cls, state): @@ -13850,7 +17561,7 @@ def f(base, cls, state): } #[test] - fn test_loop_store_subscr_threads_direct_backedge() { + fn loop_store_subscr_threads_direct_backedge() { let code = compile_exec( "\ def f(kwonlyargs, kw_only_defaults, arg2value): @@ -13890,7 +17601,7 @@ def f(kwonlyargs, kw_only_defaults, arg2value): } #[test] - fn test_protected_store_subscr_tail_uses_strong_loads() { + fn protected_store_subscr_tail_uses_strong_loads() { let code = compile_exec( "\ def f(cache, lock, format): @@ -13972,7 +17683,40 @@ def g(lock, format): } #[test] - fn test_augassign_two_part_slice_uses_slice_opcodes() { + fn try_except_inner_for_cleanup_allows_try_end_borrow() { + let code = compile_exec( + "\ +def f(self, futures, already_completed, future, short_timeout): + for timeout in (0, short_timeout): + with self.subTest(timeout): + completed_futures = set() + try: + for item in futures.as_completed(already_completed | {future}, timeout): + completed_futures.add(item) + except futures.TimeoutError: + pass + self.assertEqual(completed_futures, already_completed) +", + ); + let f = find_code(&code, 
"f").expect("missing function code"); + let borrow_pair_count = f + .instructions + .iter() + .filter(|unit| matches!(unit.op, Instruction::LoadFastBorrowLoadFastBorrow { .. })) + .count(); + + assert!( + borrow_pair_count >= 2, + "expected CPython-style borrowed pair loads before and after inner for cleanup, got ops={:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn augassign_two_part_slice_uses_slice_opcodes() { let code = compile_exec( "\ def aug(x, a, b, y): @@ -14033,7 +17777,28 @@ def aug(x, a, b, y): } #[test] - fn test_loop_return_reorders_backedge_before_exit_cleanup() { + fn augassign_constant_slice_copy_uses_subscript_location_like_cpython() { + let code = compile_exec( + "\ +def aug_const(x, y): + x[1:2] += y +", + ); + let aug_const = find_code(&code, "aug_const").expect("missing aug_const code"); + + // CPython 3.14 codegen_augassign() visits a constant slice, then emits + // COPY/COPY/BINARY_OP NB_SUBSCR at LOC(target), not at LOC(slice). + assert_eq!( + aug_const.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x04, 0x05, 0x80, 0x63, 0x87, 0x46, 0x88, 0x61, 0x85, 0x4b, 0x85, + 0x46, + ] + ); + } + + #[test] + fn loop_return_reorders_backedge_before_exit_cleanup() { let code = compile_exec( "\ def f(obj): @@ -14066,24 +17831,8 @@ def f(obj): ] ) }); - let has_conservative_shape = ops.windows(9).any(|window| { - matches!( - window, - [ - Instruction::PopJumpIfNone { .. }, - Instruction::NotTaken, - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::Swap { .. }, - Instruction::PopTop, - Instruction::ReturnValue, - Instruction::Nop, - Instruction::JumpBackward { .. 
}, - Instruction::EndFor, - ] - ) - }); assert!( - has_cpython_shape || has_conservative_shape, + has_cpython_shape, "expected loop return null-check to keep the backedge adjacent to the return cleanup, got ops={ops:?}" ); @@ -14103,12 +17852,12 @@ def f(obj): } #[test] - fn test_nested_try_finally_cleanup_reorder_does_not_invert_forward_jumps() { + fn nested_try_finally_cleanup_reorder_does_not_invert_forward_jumps() { compile_exec(include_str!("../../../Lib/poplib.py")); } #[test] - fn test_conditional_body_is_preserved_before_final_return() { + fn conditional_body_is_preserved_before_final_return() { let code = compile_exec( "\ def f(x, y): @@ -14142,7 +17891,7 @@ def f(x, y): } #[test] - fn test_nested_conditional_body_is_preserved_before_final_return() { + fn nested_conditional_body_is_preserved_before_final_return() { let code = compile_exec( "\ def outer(): @@ -14181,7 +17930,7 @@ def outer(): } #[test] - fn test_try_line_nop_is_preserved_before_setup_finally() { + fn try_line_nop_is_preserved_before_setup_finally() { let code = compile_exec( "\ def f(msg): @@ -14213,7 +17962,7 @@ def f(msg): } #[test] - fn test_nested_try_line_nops_after_for_cleanup_are_preserved() { + fn nested_try_line_nops_after_for_cleanup_are_preserved() { let code = compile_exec( "\ def f(xs, env): @@ -14258,7 +18007,7 @@ def f(xs, env): } #[test] - fn test_try_finally_assert_keeps_finalbody_entry_nop() { + fn try_finally_assert_keeps_finalbody_entry_nop() { let code = compile_exec( "\ def f(x): @@ -14309,7 +18058,7 @@ def f(x): } #[test] - fn test_try_finally_if_break_false_edge_keeps_finalbody_entry_nop() { + fn try_finally_if_break_false_edge_keeps_finalbody_entry_nop() { let code = compile_exec( "\ def f(self, pid): @@ -14351,7 +18100,7 @@ def f(self, pid): } #[test] - fn test_try_percent_format_preprocess_removes_redundant_try_nop() { + fn try_percent_format_preprocess_removes_redundant_try_nop() { let code = compile_exec( "\ def f(self, signal): @@ -14404,7 +18153,7 @@ def 
f(self, signal): } #[test] - fn test_nested_try_except_in_finally_exception_path_shares_continuation() { + fn nested_try_except_in_finally_exception_path_shares_continuation() { let code = compile_exec( "\ def f(self, exc_type, KeyboardInterrupt, TimeoutExpired): @@ -14462,7 +18211,7 @@ def f(self, exc_type, KeyboardInterrupt, TimeoutExpired): } #[test] - fn test_try_else_return_keeps_nop_before_final_call_return() { + fn try_else_return_keeps_nop_before_final_call_return() { let code = compile_exec( "\ def f(msg): @@ -14504,7 +18253,7 @@ def f(msg): } #[test] - fn test_try_else_conditional_scope_exit_keeps_pop_block_nop() { + fn try_else_conditional_scope_exit_keeps_pop_block_nop() { let code = compile_exec( "\ def f(values, check): @@ -14547,7 +18296,7 @@ def f(values, check): } #[test] - fn test_try_else_loop_fallthrough_keeps_end_jump_nop_before_finally() { + fn try_else_loop_fallthrough_keeps_end_jump_nop_before_finally() { let code = compile_exec( "\ def f(locale, category, locales): @@ -14602,7 +18351,7 @@ def f(locale, category, locales): } #[test] - fn test_conditional_compare_uses_bool_compare_oparg() { + fn conditional_compare_uses_bool_compare_oparg() { let code = compile_exec( "\ def f(x, y): @@ -14622,7 +18371,7 @@ def f(x, y): } #[test] - fn test_multiline_is_none_conditional_keeps_comparator_nop() { + fn multiline_is_none_conditional_keeps_comparator_nop() { let code = compile_exec( "\ def f(x): @@ -14656,7 +18405,7 @@ def f(x): } #[test] - fn test_chained_conditional_compares_use_bool_compare_oparg() { + fn chained_conditional_compares_use_bool_compare_oparg() { let code = compile_exec( "\ def f(a, b, c): @@ -14677,7 +18426,7 @@ def f(a, b, c): } #[test] - fn test_shared_final_return_is_cloned_for_jump_target() { + fn shared_final_return_is_cloned_for_jump_target() { let code = compile_exec( "\ def f(node): @@ -14712,7 +18461,7 @@ def f(node): } #[test] - fn test_for_break_uses_poptop_cleanup() { + fn for_break_uses_poptop_cleanup() { let code = 
compile_exec( "\ def f(parts): @@ -14762,7 +18511,7 @@ def f(parts): } #[test] - fn test_for_exit_before_elif_does_not_leave_line_anchor_nop() { + fn for_exit_before_elif_does_not_leave_line_anchor_nop() { let code = compile_exec( "\ from sys import maxsize @@ -14814,7 +18563,7 @@ elif maxsize == 9223372036854775807: } #[test] - fn test_for_tuple_target_does_not_leave_loop_header_nop() { + fn for_tuple_target_does_not_leave_loop_header_nop() { let code = compile_exec( "\ def f(pairs): @@ -14858,7 +18607,7 @@ def f(pairs): } #[test] - fn test_tstring_build_template_matches_cpython_stack_order() { + fn tstring_build_template_matches_cpython_stack_order() { let code = compile_exec("t = t\"{0}\""); let units: Vec<_> = code .instructions @@ -14900,7 +18649,7 @@ def f(pairs): } #[test] - fn test_tstring_debug_specifier_uses_debug_literal_and_repr_default() { + fn tstring_debug_specifier_uses_debug_literal_and_repr_default() { let code = compile_exec( "\ value = 42 @@ -14943,7 +18692,7 @@ t = t\"Value: {value=}\" } #[test] - fn test_tstring_literal_preserves_surrogate_wtf8() { + fn tstring_literal_preserves_surrogate_wtf8() { let code = compile_exec("t = t\"\\ud800\""); assert!(code.constants.iter().any(|constant| matches!( @@ -14953,7 +18702,7 @@ t = t\"Value: {value=}\" } #[test] - fn test_break_in_finally_after_return_keeps_load_fast_check_for_loop_locals() { + fn break_in_finally_after_return_keeps_load_fast_check_for_loop_locals() { let code = compile_exec( "\ def g2(x): @@ -14992,7 +18741,7 @@ def g2(x): } #[test] - fn test_high_index_parameter_stays_initialized_in_fast_scan() { + fn high_index_parameter_stays_initialized_in_fast_scan() { let params = (0..65) .map(|idx| format!("p{idx}")) .collect::>() @@ -15008,31 +18757,21 @@ def f({params}): assert!( f.instructions.iter().any(|unit| matches!( unit.op, - Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } + Instruction::LoadFastCheck { var_num } if f.varnames 
[usize::from(var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))))] == "p64" )), - "expected high-index parameter p64 to use LOAD_FAST, got ops={:?}", + "CPython 3.14 fast_scan_many_locals() checks high-index parameters per block; expected p64 to use LOAD_FAST_CHECK, got ops={:?}", f.instructions .iter() .map(|unit| unit.op) .collect::>() ); - assert!( - !f.instructions.iter().any(|unit| matches!( - unit.op, - Instruction::LoadFastCheck { var_num } - if f.varnames - [usize::from(var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))))] - == "p64" - )), - "high-index parameter p64 should not use LOAD_FAST_CHECK before deletion" - ); } #[test] - fn test_deleted_high_index_parameter_uses_load_fast_check() { + fn deleted_high_index_parameter_uses_load_fast_check() { let params = (0..65) .map(|idx| format!("p{idx}")) .collect::>() @@ -15063,7 +18802,7 @@ def f({params}): } #[test] - fn test_assert_without_message_raises_class_directly() { + fn assert_without_message_raises_class_directly() { let code = compile_exec( "\ def f(x): @@ -15087,7 +18826,7 @@ def f(x): } #[test] - fn test_assert_with_message_uses_common_constant_direct_call() { + fn assert_with_message_uses_common_constant_direct_call() { let code = compile_exec( "\ def f(x, y): @@ -15137,7 +18876,7 @@ def f(x, y): } #[test] - fn test_conditional_assert_message_target_uses_strong_load_fast() { + fn conditional_assert_message_target_uses_strong_load_fast() { let code = compile_exec( "\ def f(fname): @@ -15175,7 +18914,7 @@ def f(fname): } #[test] - fn test_chained_compare_assert_message_keeps_borrowed_load_fast() { + fn chained_compare_assert_message_keeps_borrowed_load_fast() { let code = compile_exec( "\ def f(month): @@ -15215,7 +18954,7 @@ def f(month): } #[test] - fn test_assert_message_after_condition_in_same_block_keeps_borrowed_loads() { + fn assert_message_after_condition_in_same_block_keeps_borrowed_loads() { let code = compile_exec( "\ def f(expected_ns, namespace): @@ -15291,7 +19030,7 @@ def 
f(expected_ns, namespace): } #[test] - fn test_bare_function_annotations_check_attribute_and_subscript_expressions() { + fn bare_function_annotations_check_attribute_and_subscript_expressions() { let code = compile_exec( "\ def f(one: int): @@ -15325,7 +19064,7 @@ def f(one: int): } #[test] - fn test_function_local_annassign_annotation_does_not_capture_outer_local() { + fn function_local_annassign_annotation_does_not_capture_outer_local() { let code = compile_exec( "\ def f(): @@ -15360,7 +19099,7 @@ def f(): } #[test] - fn test_finally_exception_path_inlines_except_pass_reraise_tail() { + fn finally_exception_path_inlines_except_pass_reraise_tail() { let source = "\ def f(self, file, backupfilename): try: @@ -15408,7 +19147,63 @@ def f(self, file, backupfilename): } #[test] - fn test_non_simple_bare_name_annotation_does_not_create_local_binding() { + fn nested_finally_exception_path_prunes_dead_normal_cleanup() { + let code = compile_exec( + "\ +def f(): + try: + raise ValueError + finally: + try: + raise KeyError + finally: + 1/0 +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let load_global_name = |unit: &&bytecode::CodeUnit| match unit.op { + Instruction::LoadGlobal { namei } => { + let name = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))) >> 1; + Some(f.names[usize::try_from(name).unwrap()].as_str()) + } + _ => None, + }; + let value_error_pos = ops + .iter() + .position(|unit| load_global_name(unit) == Some("ValueError")) + .expect("missing ValueError load"); + let key_error_pos = ops + .iter() + .position(|unit| load_global_name(unit) == Some("KeyError")) + .expect("missing KeyError load"); + let first_push_exc_after_value_error = ops[value_error_pos..] 
+ .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .map(|pos| pos + value_error_pos) + .expect("missing finally exception entry"); + let first_copy_after_value_error = ops[value_error_pos..] + .iter() + .position(|unit| matches!(unit.op, Instruction::Copy { .. })) + .map(|pos| pos + value_error_pos) + .expect("missing finally cleanup copy"); + + assert!( + first_push_exc_after_value_error < key_error_pos, + "CPython codegen_try_finally() enters the outer finally exception path before compiling the nested try body; got ops={ops:?}" + ); + assert!( + first_push_exc_after_value_error < first_copy_after_value_error, + "CPython remove_unreachable() must not keep cleanup targets from dead normal-finally paths before the live exception path; got ops={ops:?}" + ); + } + + #[test] + fn non_simple_bare_name_annotation_does_not_create_local_binding() { let code = compile_exec( "\ def f2bad(): @@ -15440,7 +19235,7 @@ def f2bad(): } #[test] - fn test_negative_constant_binop_folds_after_unary_folding() { + fn negative_constant_binop_folds_after_unary_folding() { let code = compile_exec( "\ def f(): @@ -15472,7 +19267,7 @@ def f(): } #[test] - fn test_genexpr_filter_header_uses_store_fast_load_fast() { + fn genexpr_filter_header_uses_store_fast_load_fast() { let code = compile_exec( "\ def f(it): @@ -15504,7 +19299,7 @@ def f(it): } #[test] - fn test_generator_filter_keeps_cpython_style_forward_yield_body_entry() { + fn generator_filter_keeps_cpython_style_forward_yield_body_entry() { let code = compile_exec( "\ def gen(it): @@ -15542,7 +19337,7 @@ def gen(it): } #[test] - fn test_generator_negated_filter_keeps_cpython_style_false_edge_into_yield_body() { + fn generator_negated_filter_keeps_cpython_style_false_edge_into_yield_body() { let code = compile_exec( "\ def gen(fields): @@ -15580,7 +19375,7 @@ def gen(fields): } #[test] - fn test_loop_filter_with_nested_loop_body_uses_cpython_implicit_continue_layout() { + fn 
loop_filter_with_nested_loop_body_uses_cpython_implicit_continue_layout() { let code = compile_exec( "\ def f(values): @@ -15627,7 +19422,7 @@ def f(values): } #[test] - fn test_final_elif_with_inlined_comprehensions_threads_backedge_before_body() { + fn final_elif_with_inlined_comprehensions_threads_backedge_before_body() { let code = compile_exec( "\ def f(checks, enumeration, named): @@ -15698,7 +19493,7 @@ def f(checks, enumeration, named): } #[test] - fn test_multi_with_header_uses_store_fast_load_fast() { + fn multi_with_header_uses_store_fast_load_fast() { let code = compile_exec( "\ def f(manager): @@ -15720,7 +19515,7 @@ def f(manager): } #[test] - fn test_sequential_store_then_load_uses_store_fast_load_fast() { + fn sequential_store_then_load_uses_store_fast_load_fast() { let code = compile_exec( "\ def f(self): @@ -15741,7 +19536,7 @@ def f(self): } #[test] - fn test_match_guard_capture_uses_store_fast_load_fast() { + fn match_guard_capture_uses_store_fast_load_fast() { let code = compile_exec( "\ def f(): @@ -15764,7 +19559,7 @@ def f(): } #[test] - fn test_match_nested_capture_uses_store_fast_store_fast() { + fn match_nested_capture_uses_store_fast_store_fast() { let code = compile_exec( "\ def f(x): @@ -15787,7 +19582,7 @@ def f(x): } #[test] - fn test_match_value_real_zero_minus_zero_complex_folds_to_negative_zero_imag() { + fn match_value_real_zero_minus_zero_complex_folds_to_negative_zero_imag() { let code = compile_exec( "\ def f(x): @@ -15808,7 +19603,39 @@ def f(x): } #[test] - fn test_match_or_uses_shared_success_block() { + fn match_negative_value_const_precedes_implicit_none_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case -0.0: + y = 0 +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let negative_zero_index = f + .constants + .iter() + .position(|constant| { + matches!( + constant, + ConstantData::Float { value } if *value == 0.0 && value.is_sign_negative() + ) + }) + .expect("missing 
folded -0.0 match value"); + let none_index = f + .constants + .iter() + .position(|constant| matches!(constant, ConstantData::None)) + .expect("missing implicit None"); + assert!( + negative_zero_index < none_index, + "CPython ast_preprocess.c folds MatchValue constants before codegen registers the implicit None" + ); + } + + #[test] + fn match_or_uses_shared_success_block() { let code = compile_exec( "\ def http_error(status): @@ -15855,7 +19682,7 @@ def http_error(status): } #[test] - fn test_match_try_body_keeps_setup_nop_after_success_pop() { + fn match_try_body_keeps_setup_nop_after_success_pop() { let code = compile_exec( "\ def f(x): @@ -15891,7 +19718,7 @@ def f(x): } #[test] - fn test_match_mapping_attribute_key_keeps_plain_load_fast() { + fn match_mapping_attribute_key_keeps_plain_load_fast_without_block_disable() { let code = compile_exec( "\ def f(self): @@ -15905,6 +19732,27 @@ def f(self): ", ); let f = find_code(&code, "f").expect("missing function code"); + let assert_raises_attr = f + .instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() + == "assertRaises" + } + _ => false, + }) + .expect("missing assertRaises attribute load"); + let assert_raises_receiver = f.instructions[assert_raises_attr - 1].op; + assert!( + matches!(assert_raises_receiver, Instruction::LoadFastBorrow { .. }), + "mapping attribute key handling must not disable borrow optimization for the whole block; got ops={:?}", + f.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); let key_load_idx = f .instructions .iter() @@ -15919,7 +19767,7 @@ def f(self): let prev = f.instructions[key_load_idx - 1].op; assert!( matches!(prev, Instruction::LoadFast { .. 
}), - "expected plain LOAD_FAST before Keys.KEY mapping key, got ops={:?}", + "CPython optimize_load_fast() records MATCH_KEYS' no-input pseudo-ref with the produced-value loop index, so this consumed Keys load stays strong; got ops={:?}", f.instructions .iter() .map(|unit| unit.op) @@ -15928,8 +19776,7 @@ def f(self): } #[test] - #[ignore = "debug trace for sequence star-wildcard pattern layout"] - fn test_debug_trace_match_sequence_star_wildcard_layout() { + fn debug_trace_match_sequence_star_wildcard_layout() { let trace = compile_single_function_late_cfg_trace( "\ def f(w): @@ -15946,8 +19793,7 @@ def f(w): } #[test] - #[ignore = "debug trace for loop bool-chain jump-back layout"] - fn test_debug_trace_loop_break_bool_chain_layout() { + fn debug_trace_loop_break_bool_chain_layout() { let trace = compile_single_function_late_cfg_trace( "\ def f(filters, text, category, module, lineno, defaultaction): @@ -15970,8 +19816,7 @@ def f(filters, text, category, module, lineno, defaultaction): } #[test] - #[ignore = "debug trace for loop conditional body jump-back layout"] - fn test_debug_trace_loop_conditional_body_layout() { + fn debug_trace_loop_conditional_body_layout() { let trace = compile_single_function_late_cfg_trace( "\ def f(new, old): @@ -15988,8 +19833,140 @@ def f(new, old): } #[test] - #[ignore = "debug trace for minimized utf7 encode nested-if layout"] - fn test_debug_trace_utf7_min_encode_layout() { + fn if_false_body_blocks_following_load_fast_borrow() { + let code = compile_exec( + "\ +def f(self, groupby): + self.a() + if False: + self.dead() + self.b(groupby) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let units: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let b_attr_idx = units + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + 
f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "b" + } + _ => false, + }) + .expect("missing self.b attribute load"); + assert!( + matches!(units[b_attr_idx - 1].op, Instruction::LoadFast { .. }), + "CPython keeps plain LOAD_FAST after an if False dead-body placeholder, got ops={:?}", + units.iter().map(|unit| unit.op).collect::>() + ); + assert!( + matches!(units[b_attr_idx + 1].op, Instruction::LoadFast { .. }), + "CPython keeps the argument LOAD_FAST plain after an if False dead-body placeholder, got ops={:?}", + units.iter().map(|unit| unit.op).collect::>() + ); + } + + #[test] + fn imap_append_untagged_assert_tail_keeps_load_fast() { + let code = compile_exec( + "\ +def f(self, typ, dat): + if self._idle_capture: + if (not self._idle_responses or + isinstance(self._idle_responses[-1][1][-1], bytes)): + self._idle_responses.append((typ, [dat])) + else: + response = self._idle_responses[-1] + assert response[0] == typ + response[1].append(dat) + if __debug__ and self.debug >= 5: + self._mesg(f'idle: queue untagged {typ} {dat!r}') + return +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let units: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let debug_attr_idx = units + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "debug" + } + _ => false, + }) + .expect("missing debug attribute load"); + assert!( + matches!(units[debug_attr_idx - 1].op, Instruction::LoadFast { .. 
}), + "CPython keeps the debug tail after an emptied bool-op block as LOAD_FAST, got ops={:?}", + units.iter().map(|unit| unit.op).collect::>() + ); + + let mesg_attr_idx = units + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "_mesg" + } + _ => false, + }) + .expect("missing _mesg attribute load"); + assert!( + matches!(units[mesg_attr_idx - 1].op, Instruction::LoadFast { .. }) + && matches!(units[mesg_attr_idx + 2].op, Instruction::LoadFast { .. }) + && matches!(units[mesg_attr_idx + 5].op, Instruction::LoadFast { .. }), + "CPython keeps LOAD_FAST in the debug message after the empty join block, got ops={:?}", + units.iter().map(|unit| unit.op).collect::>() + ); + } + + #[test] + fn assert_success_empty_boolop_block_keeps_load_fast() { + let code = compile_exec( + "\ +def f(self): + imap = self._imap + assert not imap._idle_responses + assert not imap._idle_capture + if __debug__ and imap.debug >= 4: + imap._mesg(f'idle start duration={self._duration}') +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let units: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let debug_attr_idx = units + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "debug" + } + _ => false, + }) + .expect("missing debug attribute load"); + assert!( + matches!(units[debug_attr_idx - 1].op, Instruction::LoadFast { .. 
}), + "CPython preserves the empty assert-success bool-op block as a LOAD_FAST barrier, got ops={:?}", + units.iter().map(|unit| unit.op).collect::>() + ); + } + + #[test] + fn debug_trace_utf7_min_encode_layout() { let trace = compile_single_function_late_cfg_trace( "\ def f(s, size, encodeSetO, encodeWhiteSpace): @@ -16017,8 +19994,7 @@ def f(s, size, encodeSetO, encodeWhiteSpace): } #[test] - #[ignore = "debug trace for with-protected loop bool-chain layout"] - fn test_debug_trace_with_loop_break_bool_chain_layout() { + fn debug_trace_with_loop_break_bool_chain_layout() { let trace = compile_single_function_late_cfg_trace( "\ def f(filters, text, category, module, lineno, defaultaction, _wm): @@ -16042,7 +20018,7 @@ def f(filters, text, category, module, lineno, defaultaction, _wm): } #[test] - fn test_try_except_else_with_finally_keeps_with_handler_before_outer_except() { + fn try_except_else_with_finally_keeps_with_handler_before_outer_except() { let code = compile_exec( "\ def f(i): @@ -16101,7 +20077,7 @@ def f(i): } #[test] - fn test_nested_try_finally_keeps_inner_finally_cleanup_nop() { + fn nested_try_finally_keeps_inner_finally_cleanup_nop() { let code = compile_exec( "\ def f(a, b, d): @@ -16143,7 +20119,7 @@ def f(a, b, d): } #[test] - fn test_nested_finally_open_conditional_falls_through_without_entry_nop() { + fn nested_finally_open_conditional_falls_through_without_entry_nop() { let code = compile_exec( "\ def f(self, f, closed, new_key): @@ -16197,7 +20173,64 @@ def f(self, f, closed, new_key): } #[test] - fn test_with_try_finally_normal_cleanup_keeps_redundant_jump_nop() { + fn nested_finally_closed_conditional_falls_through_without_extra_entry_nop() { + let code = compile_exec( + "\ +def f(was_enabled, faulthandler, sys, orig_stderr): + try: + try: + faulthandler.enable() + faulthandler.disable() + finally: + if was_enabled: + faulthandler.enable() + else: + faulthandler.disable() + finally: + sys.stderr = orig_stderr +", + ); + let f = 
find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache | Instruction::NotTaken)) + .collect(); + + assert!( + ops.windows(3).any(|window| { + matches!( + window, + [ + Instruction::Nop, + Instruction::LoadFastBorrowLoadFastBorrow { .. } + | Instruction::LoadFastLoadFast { .. }, + Instruction::StoreAttr { .. }, + ] + ) + }), + "CPython keeps the inner finally cleanup anchor before the outer finalbody, got ops={ops:?}" + ); + assert!( + !ops.windows(4).any(|window| { + matches!( + window, + [ + Instruction::Nop, + Instruction::Nop, + Instruction::LoadFastBorrowLoadFastBorrow { .. } + | Instruction::LoadFastLoadFast { .. }, + Instruction::StoreAttr { .. }, + ] + ) + }), + "closed conditional inner finalbody should not add a second outer finalbody-entry NOP, got ops={ops:?}" + ); + } + + #[test] + fn with_try_finally_normal_cleanup_keeps_redundant_jump_nop() { let code = compile_exec( "\ def f(cm): @@ -16238,7 +20271,7 @@ def f(cm): } #[test] - fn test_with_try_except_normal_cleanup_keeps_body_exit_nop() { + fn with_try_except_normal_cleanup_keeps_body_exit_nop() { let code = compile_exec( "\ def f(cm, names, modname): @@ -16278,7 +20311,7 @@ def f(cm, names, modname): } #[test] - fn test_with_try_except_return_handler_keeps_body_exit_nop() { + fn with_try_except_return_handler_keeps_body_exit_nop() { let code = compile_exec( "\ def f(cm): @@ -16317,7 +20350,7 @@ def f(cm): } #[test] - fn test_with_try_except_else_return_handler_keeps_body_exit_nop() { + fn with_try_except_else_return_handler_keeps_body_exit_nop() { let code = compile_exec( "\ def f(cm, func, check): @@ -16359,7 +20392,7 @@ def f(cm, func, check): } #[test] - fn test_with_try_except_else_continue_handler_keeps_body_exit_nop() { + fn with_try_except_else_continue_handler_keeps_body_exit_nop() { let code = compile_exec( "\ def f(meta_path, cm): @@ -16405,7 +20438,7 @@ def f(meta_path, cm): } 
#[test] - fn test_elif_boolop_skips_following_elif_with_forward_jumpback_block() { + fn elif_boolop_skips_following_elif_with_forward_jumpback_block() { let code = compile_exec( r#" def f(module, fromlist, import_, recursive=False): @@ -16454,7 +20487,7 @@ def f(module, fromlist, import_, recursive=False): } #[test] - fn test_with_nonterminal_try_except_normal_cleanup_drops_body_exit_nop() { + fn with_nonterminal_try_except_normal_cleanup_drops_body_exit_nop() { let code = compile_exec( "\ def f(cm): @@ -16508,7 +20541,7 @@ def f(cm): } #[test] - fn test_with_try_except_scope_exit_body_handler_fallthrough_keeps_body_exit_nop() { + fn with_try_except_scope_exit_body_handler_fallthrough_keeps_body_exit_nop() { let code = compile_exec( "\ def f(cm, ValueError): @@ -16547,7 +20580,7 @@ def f(cm, ValueError): } #[test] - fn test_with_try_except_nested_with_normal_cleanup_drops_body_exit_nop() { + fn with_try_except_nested_with_normal_cleanup_drops_body_exit_nop() { let code = compile_exec( "\ def f(open, src, dst, copyfileobj): @@ -16606,7 +20639,7 @@ def f(open, src, dst, copyfileobj): } #[test] - fn test_with_nested_if_try_except_normal_cleanup_drops_body_exit_nop() { + fn with_nested_if_try_except_normal_cleanup_drops_body_exit_nop() { let code = compile_exec( "\ def f(cm, root): @@ -16646,7 +20679,7 @@ def f(cm, root): } #[test] - fn test_try_except_finally_normal_cleanup_keeps_body_exit_nop() { + fn try_except_finally_normal_cleanup_keeps_body_exit_nop() { let code = compile_exec( "\ def f(self, x): @@ -16691,7 +20724,7 @@ def f(self, x): } #[test] - fn test_try_except_finally_open_conditional_fallthrough_drops_body_exit_nop() { + fn try_except_finally_open_conditional_fallthrough_drops_body_exit_nop() { let code = compile_exec( "\ def f(err, ov, self): @@ -16724,7 +20757,7 @@ def f(err, ov, self): } #[test] - fn test_try_finally_loop_fallthrough_keeps_finalbody_entry_nop() { + fn try_finally_loop_fallthrough_keeps_finalbody_entry_nop() { let code = compile_exec( 
"\ def f(close, dup, first, second): @@ -16766,7 +20799,134 @@ def f(close, dup, first, second): } #[test] - fn test_try_finally_loop_direct_break_drops_finalbody_entry_nop() { + fn try_finally_boolop_while_fallthrough_drops_finalbody_entry_nop() { + let code = compile_exec( + "\ +def f(active, socket_map, asyncore): + try: + while active and socket_map: + asyncore.loop(timeout=0.1, count=1) + finally: + asyncore.close_all(ignore_all=True) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.windows(3).any(|window| { + matches!( + window, + [ + Instruction::JumpBackward { .. }, + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. }, + ] + ) + }), + "CPython removes the no-location POP_BLOCK NOP before a boolop-while try/finally finalbody, got ops={ops:?}" + ); + assert!( + !ops.windows(4).any(|window| { + matches!( + window, + [ + Instruction::JumpBackward { .. }, + Instruction::Nop, + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. }, + ] + ) + }), + "boolop-while try/finally finalbody should not keep a POP_BLOCK NOP, got ops={ops:?}" + ); + } + + #[test] + fn try_finally_with_infinite_loop_body_drops_finalbody_entry_nop() { + let code = compile_exec( + "\ +def f(self, func, args, kwargs): + try: + with self.assertRaises(ZeroDivisionError) as cm: + while True: + self.setAlarm(self.alarm_time) + func(*args, **kwargs) + finally: + self.setAlarm(0) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.windows(3).any(|window| { + matches!( + window, + [ + Instruction::Reraise { .. }, + Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. 
}, + ] + ) + }), + "CPython removes the no-location POP_BLOCK NOP before the normal finalbody after a with-wrapped infinite loop, got ops={ops:?}" + ); + assert!( + !ops.windows(4).any(|window| { + matches!( + window, + [ + Instruction::Reraise { .. }, + Instruction::Nop, + Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. }, + ] + ) + }), + "with-wrapped infinite loop try/finally should not keep a finalbody-entry NOP, got ops={ops:?}" + ); + } + + #[test] + fn try_finally_with_finalbody_blocks_following_with_borrow() { + let code = compile_exec( + "\ +def f(self, sock, socket, HOST, OSError, TypeError): + try: + sock.bind((HOST, 0)) + socket.close(sock.fileno()) + with self.assertRaises(OSError): + sock.listen(1) + finally: + with self.assertRaises(OSError): + sock.close() + with self.assertRaises(TypeError): + socket.close(42, 42) + with self.assertRaises(OSError): + socket.close(-1) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let strong_self_loads = count_strong_loads_for_vars(f, &["self"]); + assert!( + strong_self_loads >= 2, + "CPython codegen_try_finally() emits USE_LABEL(exit) after a with finalbody, so optimize_load_fast() leaves following with receivers strong; got {strong_self_loads} strong self loads" + ); + } + + #[test] + fn try_finally_loop_direct_break_drops_finalbody_entry_nop() { let code = compile_exec( "\ def f(lines, close): @@ -16820,7 +20980,150 @@ def f(lines, close): } #[test] - fn test_try_except_finally_suppressing_handler_drops_body_exit_nop() { + fn try_except_finally_handler_normal_exit_keeps_nointerrupt_jump() { + let code = compile_exec( + "\ +def f(): + try: + 2 + except: + 4 + finally: + 6 +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.windows(2).any(|window| { + matches!( + window, + [ + Instruction::PopExcept, + 
Instruction::JumpBackwardNoInterrupt { .. } + ] + ) + }), + "CPython codegen_try_except() emits JUMP_NO_INTERRUPT to the inner end label; when wrapped by codegen_try_finally(), push_cold_blocks_to_end() preserves it as a backward no-interrupt jump, got ops={ops:?}", + ); + assert!( + !ops.windows(3).any(|window| { + matches!( + window, + [ + Instruction::PopExcept, + Instruction::LoadConst { .. }, + Instruction::ReturnValue, + ] + ) + }), + "try/except/finally handler normal exit should not inline the function epilogue over CPython's no-interrupt jump, got ops={ops:?}", + ); + } + + #[test] + fn nested_while_break_keeps_cpython_unreachable_end_epilogue() { + let code = compile_exec( + "\ +def f(): + TRUE = 1 + while TRUE: + while TRUE: + break + break +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let returns = ops + .iter() + .filter(|op| matches!(op, Instruction::ReturnValue)) + .count(); + + assert_eq!( + returns, 3, + "CPython codegen_while() emits separate anchor/end labels and codegen_break() jumps to loop->fb_exit; after redundant jump removal, the b_next return epilogue still remains, got ops={ops:?}", + ); + } + + #[test] + fn while_else_break_keeps_separate_continue_backedges() { + let code = compile_exec( + "\ +def func(): + TRUE = 1 + x = [1] + while x: + x.pop() + while TRUE: + break + else: + continue +", + ); + let func = find_code(&code, "func").expect("missing func code"); + let ops: Vec<_> = func + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let jump_backwards = ops + .iter() + .filter(|op| matches!(op, Instruction::JumpBackward { .. 
})) + .count(); + + assert_eq!( + jump_backwards, 2, + "CPython codegen_break() emits a line-bearing jump to the inner while end, and codegen_while() keeps the else anchor separate from the end label; the break path and else-continue path should remain distinct backedges, got ops={ops:?}", + ); + } + + #[test] + fn break_through_finally_assert_tail_keeps_borrow_loads() { + let code = compile_exec( + "\ +def func(): + a, c, d, i = 1, 1, 1, 99 + try: + for i in range(3): + try: + a = 5 + if i > 0: + break + a = 8 + finally: + c = 10 + except: + d = 12 + assert a == 5 and c == 10 and d == 1 +", + ); + let func = find_code(&code, "func").expect("missing func code"); + for name in ["a", "c", "d"] { + let loads = load_fast_ops_for_var(func, name); + assert!( + loads + .iter() + .any(|op| matches!(op, Instruction::LoadFastBorrow { .. })), + "CPython flowgraph.c::optimize_load_fast() keeps assert-tail {name} loads borrowed after a resuming bare except; got loads={loads:?}", + ); + } + } + + #[test] + fn try_except_finally_suppressing_handler_drops_body_exit_nop() { let code = compile_exec( "\ def f(self): @@ -16859,7 +21162,7 @@ def f(self): } #[test] - fn test_conditional_break_finally_does_not_keep_break_cleanup_nop() { + fn conditional_break_finally_does_not_keep_break_cleanup_nop() { let code = compile_exec( "\ def f(tar1, x): @@ -16873,6 +21176,11 @@ def f(tar1, x): ", ); let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); let ops_lines: Vec<_> = f .instructions .iter() @@ -16897,10 +21205,66 @@ def f(tar1, x): }), "expected CPython-style break cleanup to jump directly into finally body, got ops_lines={ops_lines:?}", ); + + let close_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + 
f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "close" + } + _ => false, + }) + .expect("missing close load"); + assert!( + matches!( + instructions[close_attr - 1].op, + Instruction::LoadFastBorrow { .. } + ), + "CPython visits the finalbody through the loop fallthrough when break is not the loop body tail; got instructions={instructions:?}", + ); + } + + #[test] + fn tail_conditional_break_finally_uses_empty_end_label_barrier() { + let code = compile_exec( + "\ +def f(tar1, x): + try: + while True: + if x: + break + finally: + tar1.close() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let close_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "close" + } + _ => false, + }) + .expect("missing close load"); + assert!( + matches!( + instructions[close_attr - 1].op, + Instruction::LoadFast { .. 
} + ), + "CPython leaves an empty while-end label before finalbody when the direct break is the loop body tail; got instructions={instructions:?}", + ); } #[test] - fn test_with_break_cleanup_makes_following_jump_artificial() { + fn with_break_cleanup_makes_following_jump_artificial() { let code = compile_exec( "\ def f(self): @@ -16940,7 +21304,7 @@ def f(self): } #[test] - fn test_while_exit_before_with_cleanup_materializes_anchor_nop() { + fn while_exit_before_with_cleanup_materializes_anchor_nop() { let code = compile_exec( "\ def f(selector, self): @@ -16982,7 +21346,7 @@ def f(selector, self): } #[test] - fn test_nested_boolop_same_or_prefixes_compile_without_extra_boolop_block() { + fn nested_boolop_same_or_prefixes_compile_without_extra_boolop_block() { let code = compile_exec( "\ def f(c, encodeO, encodeWS): @@ -17011,7 +21375,7 @@ def f(c, encodeO, encodeWS): } #[test] - fn test_nested_opposite_boolop_threads_to_fallthrough_like_cpython() { + fn nested_opposite_boolop_threads_to_fallthrough_like_cpython() { for source in [ "\ def f(a, b, c): @@ -17050,7 +21414,7 @@ def f(a, b, c): } #[test] - fn test_loop_or_continue_keeps_boolop_true_edge_to_continue() { + fn loop_or_continue_keeps_boolop_true_edge_to_continue() { let code = compile_exec( "\ def f(numpy_array, lshape, rshape, litems, fmt, tl): @@ -17106,7 +21470,7 @@ def f(numpy_array, lshape, rshape, litems, fmt, tl): } #[test] - fn test_nested_and_or_expression_threads_same_false_short_circuit() { + fn nested_and_or_expression_threads_same_false_short_circuit() { let code = compile_exec( "\ def f(fmt, MEMORYVIEW): @@ -17137,7 +21501,7 @@ def f(fmt, MEMORYVIEW): } #[test] - fn test_broad_exception_import_keeps_borrow_in_common_tail() { + fn broad_exception_import_keeps_borrow_in_common_tail() { let code = compile_exec( "\ def f(msg): @@ -17175,7 +21539,7 @@ def f(msg): } #[test] - fn test_try_import_return_handler_deopts_common_tail_borrow() { + fn try_import_return_handler_deopts_common_tail_borrow() { 
let code = compile_exec( "\ def f(): @@ -17206,7 +21570,7 @@ def f(): } #[test] - fn test_try_import_return_handler_deopts_later_protected_tail_borrow() { + fn try_import_return_handler_deopts_later_protected_tail_borrow() { let code = compile_exec( "\ def f(info_add): @@ -17255,7 +21619,7 @@ def f(info_add): } #[test] - fn test_try_import_continue_handler_deopts_loop_tail_borrow() { + fn try_import_continue_handler_deopts_loop_tail_borrow() { let code = compile_exec( "\ def f(size): @@ -17306,7 +21670,7 @@ def f(size): } #[test] - fn test_try_import_continue_inside_loop_keeps_earlier_loop_body_borrows() { + fn try_import_continue_inside_loop_keeps_earlier_loop_body_borrows() { let code = compile_exec( r#" def f(s, size, errors): @@ -17369,7 +21733,7 @@ def f(s, size, errors): } #[test] - fn test_try_import_pass_else_keeps_borrow() { + fn try_import_pass_else_keeps_borrow() { let code = compile_exec( "\ def f(self): @@ -17407,7 +21771,7 @@ def f(self): } #[test] - fn test_try_import_broad_handler_implicit_return_keeps_borrow() { + fn try_import_broad_handler_implicit_return_keeps_borrow() { let code = compile_exec( "\ def f(self, record): @@ -17460,7 +21824,7 @@ def f(self, record): } #[test] - fn test_try_import_handler_assignment_resume_tail_keeps_borrow() { + fn try_import_handler_assignment_resume_tail_keeps_borrow() { let code = compile_exec( "\ def f(): @@ -17511,7 +21875,109 @@ def f(): } #[test] - fn test_protected_attr_direct_return_keeps_borrow() { + fn empty_fallthrough_handler_assignment_tail_keeps_borrows() { + let code = compile_exec( + "\ +def f(value): + obs_local_part = ObsLocalPart() + try: + token, value = get_word(value) + except HeaderParseError: + if value[0] not in CFWS_LEADER: + raise + token, value = get_cfws(value) + obs_local_part.append(token) + return obs_local_part, value +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, 
Instruction::Cache)) + .collect(); + let handler_start = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .expect("missing handler entry"); + let normal_path = &ops[..handler_start]; + let load_name = |unit: &&CodeUnit, name: &str, borrowed: bool| { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + match (unit.op, borrowed) { + (Instruction::LoadFastBorrow { var_num }, true) + | (Instruction::LoadFast { var_num }, false) => { + f.varnames[usize::from(var_num.get(arg))].as_str() == name + } + _ => false, + } + }; + + for name in ["obs_local_part", "token"] { + assert!( + normal_path.iter().any(|unit| load_name(unit, name, true)), + "handler assignment tail should keep CPython-style borrowed {name} loads, got path={normal_path:?}" + ); + assert!( + !normal_path.iter().any(|unit| load_name(unit, name, false)), + "handler assignment tail should not force strong {name}, got path={normal_path:?}" + ); + } + } + + #[test] + fn protected_store_of_preinitialized_local_keeps_return_borrow() { + let code = compile_exec( + "\ +def f(obj): + maybe_routine = obj + try: + maybe_routine = inspect.unwrap(maybe_routine) + except ValueError: + pass + return inspect.isroutine(maybe_routine) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let handler_start = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .expect("missing handler entry"); + let normal_path = &ops[..handler_start]; + let maybe_routine_idx = f + .varnames + .iter() + .position(|name| name == "maybe_routine") + .expect("missing maybe_routine local"); + let loads_maybe_routine = |unit: &&CodeUnit, borrowed: bool| match (unit.op, borrowed) { + (Instruction::LoadFastBorrow { var_num }, true) + | (Instruction::LoadFast { var_num }, false) => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + 
usize::from(var_num.get(arg)) == maybe_routine_idx + } + _ => false, + }; + + assert!( + normal_path + .iter() + .any(|unit| loads_maybe_routine(unit, true)), + "preinitialized protected-store tail should keep CPython-style borrowed local, got path={normal_path:?}" + ); + assert!( + !normal_path + .iter() + .any(|unit| loads_maybe_routine(unit, false)), + "preinitialized protected-store tail should not force strong local, got path={normal_path:?}" + ); + } + + #[test] + fn protected_attr_direct_return_keeps_borrow() { let code = compile_exec( "\ def f(obj): @@ -17552,7 +22018,7 @@ def f(obj): } #[test] - fn test_protected_store_normal_tail_uses_strong_loads() { + fn protected_store_normal_tail_uses_strong_loads() { let code = compile_exec( "\ def f(tarfile, tarinfo, self): @@ -17592,7 +22058,46 @@ def f(tarfile, tarinfo, self): } #[test] - fn test_protected_call_arm_final_store_return_uses_strong_load() { + fn protected_subscript_store_normal_tail_uses_strong_loads() { + let code = compile_exec( + "\ +def f(self, d, option, fallback): + try: + value = d[option] + except KeyError: + return fallback + return self.convert(value, option) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let value_store = ops + .iter() + .position(|op| matches!(op, Instruction::StoreFast { .. })) + .expect("missing value store"); + let handler_start = ops + .iter() + .position(|op| matches!(op, Instruction::PushExcInfo)) + .expect("missing handler entry"); + let normal_tail = &ops[value_store + 1..handler_start]; + + assert!( + !normal_tail.iter().any(|op| matches!( + op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. 
} + )), + "expected CPython-style strong LOAD_FAST after protected subscript store, got tail={normal_tail:?}", + ); + } + + #[test] + fn protected_call_arm_final_store_return_uses_strong_load() { let code = compile_exec( "\ def f(self, action, default_metavar): @@ -17667,7 +22172,7 @@ def f(self, action, default_metavar): } #[test] - fn test_protected_store_try_else_tail_keeps_borrowed_loads() { + fn protected_store_try_else_tail_keeps_borrowed_loads() { let code = compile_exec( "\ def f(value): @@ -17759,7 +22264,7 @@ def f(value): } #[test] - fn test_nested_try_except_common_tail_uses_strong_loads() { + fn nested_try_except_common_tail_uses_strong_loads() { let code = compile_exec( "\ def f(value): @@ -17802,7 +22307,7 @@ def f(value): } #[test] - fn test_nested_try_except_branch_tail_with_following_try_uses_strong_loads() { + fn nested_try_except_branch_tail_with_following_try_uses_strong_loads() { let code = compile_exec( r#" def f(value): @@ -17868,7 +22373,7 @@ def f(value): } #[test] - fn test_nested_try_store_subscr_following_try_tail_uses_strong_loads() { + fn nested_try_store_subscr_following_try_tail_uses_strong_loads() { let code = compile_exec( r#" def f(value): @@ -17925,7 +22430,7 @@ def f(value): } #[test] - fn test_resuming_except_in_loop_keeps_post_try_store_tail_borrowed() { + fn resuming_except_in_loop_keeps_post_try_store_tail_borrowed() { let code = compile_exec( "\ def f(part, lines, maxlen, encoding): @@ -17985,7 +22490,7 @@ def f(part, lines, maxlen, encoding): } #[test] - fn test_handler_resume_loop_latch_method_call_uses_strong_loads() { + fn handler_resume_loop_latch_method_call_uses_strong_loads() { let code = compile_exec( "\ def f(phrase, value): @@ -18046,7 +22551,7 @@ def f(phrase, value): } #[test] - fn test_single_handler_multiple_resume_branches_keep_post_try_tail_borrowed() { + fn single_handler_multiple_resume_branches_keep_post_try_tail_borrowed() { let code = compile_exec( "\ def f(part, lines, maxlen, encoding): @@ -18110,7 
+22615,7 @@ def f(part, lines, maxlen, encoding): } #[test] - fn test_nested_exception_handler_resume_update_tail_uses_strong_load() { + fn nested_exception_handler_resume_update_tail_uses_strong_load() { let code = compile_exec( "\ def f(inpos, size, g, replacement): @@ -18170,7 +22675,7 @@ def f(inpos, size, g, replacement): } #[test] - fn test_protected_store_finally_cleanup_keeps_borrow_tail() { + fn protected_store_finally_cleanup_keeps_borrow_tail() { let code = compile_exec( "\ def f(re, f): @@ -18213,7 +22718,7 @@ def f(re, f): } #[test] - fn test_try_else_finally_cleanup_keeps_borrow_tail() { + fn try_else_finally_cleanup_keeps_borrow_tail() { let code = compile_exec( "\ def f(re, open): @@ -18261,7 +22766,53 @@ def f(re, open): } #[test] - fn test_generator_protected_store_subscr_tail_uses_strong_loads() { + fn typed_terminal_attr_store_deopts_later_protected_store_subscr() { + let code = compile_exec( + "\ +def f(instance, self, _NOT_FOUND): + try: + cache = instance.__dict__ + except AttributeError: + raise TypeError('missing dict') from None + val = cache.get(self.attrname, _NOT_FOUND) + if val is _NOT_FOUND: + val = self.func(instance) + try: + cache[self.attrname] = val + except TypeError: + raise TypeError('bad cache') from None + return val +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let store_subscr = ops + .iter() + .position(|op| matches!(op, Instruction::StoreSubscr)) + .expect("missing STORE_SUBSCR"); + let window = &ops[store_subscr.saturating_sub(3)..=store_subscr]; + + assert!( + matches!( + window, + [ + Instruction::LoadFastLoadFast { .. }, + Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. 
}, + Instruction::StoreSubscr, + ] + ), + "CPython keeps strong loads before protected STORE_SUBSCR after a typed terminal attr-store try, got window={window:?}; ops={ops:?}" + ); + } + + #[test] + fn generator_protected_store_subscr_tail_uses_strong_loads() { let code = compile_exec( "\ def f(names, modules): @@ -18304,7 +22855,7 @@ def f(names, modules): } #[test] - fn test_protected_call_function_ex_store_tail_uses_strong_loads() { + fn protected_call_function_ex_store_tail_uses_strong_loads() { let code = compile_exec( "\ def f(func, *args): @@ -18346,7 +22897,7 @@ def f(func, *args): } #[test] - fn test_protected_attr_subscript_tail_uses_strong_load_fast() { + fn protected_attr_subscript_tail_uses_strong_load_fast() { let code = compile_exec( "\ def f(obj, idx): @@ -18383,7 +22934,7 @@ def f(obj, idx): } #[test] - fn test_protected_direct_subscript_tail_uses_strong_load_fast() { + fn protected_direct_subscript_tail_uses_strong_load_fast() { let code = compile_exec( "\ def f(seq): @@ -18424,7 +22975,7 @@ def f(seq): } #[test] - fn test_protected_attr_iter_chain_uses_strong_load_fast() { + fn protected_attr_iter_chain_uses_strong_load_fast() { let code = compile_exec( "\ def f(fields): @@ -18461,7 +23012,7 @@ def f(fields): } #[test] - fn test_generator_except_return_handler_deopts_normal_tail_borrows() { + fn generator_except_return_handler_deopts_normal_tail_borrows() { let code = compile_exec( "\ def f(fields): @@ -18505,7 +23056,7 @@ def f(fields): } #[test] - fn test_generator_except_yielding_handler_keeps_normal_tail_borrows() { + fn generator_except_yielding_handler_keeps_normal_tail_borrows() { let code = compile_exec( "\ def f(tp, parent=None): @@ -18544,7 +23095,7 @@ def f(tp, parent=None): } #[test] - fn test_generator_returning_except_keeps_yield_from_resume_tail_borrow() { + fn generator_returning_except_keeps_yield_from_resume_tail_borrow() { let code = compile_exec( "\ def f(self, action): @@ -18590,7 +23141,7 @@ def f(self, action): } #[test] - 
fn test_generator_except_pass_resume_tail_keeps_borrows() { + fn generator_except_pass_resume_tail_keeps_borrows() { let code = compile_exec( "\ def f(self, msg): @@ -18645,7 +23196,7 @@ def f(self, msg): } #[test] - fn test_async_for_cleanup_resume_tail_uses_strong_loads() { + fn async_for_cleanup_resume_tail_uses_strong_loads() { let code = compile_exec( "\ async def f(g, self, x): @@ -18699,7 +23250,7 @@ async def f(g, self, x): } #[test] - fn test_async_generator_async_with_yield_keeps_borrow() { + fn async_generator_async_with_yield_keeps_borrow() { let code = compile_exec( "\ async def f(self, my_cm): @@ -18736,7 +23287,7 @@ async def f(self, my_cm): } #[test] - fn test_deoptimized_async_with_enter_continuation_uses_strong_loads() { + fn deoptimized_async_with_enter_continuation_uses_strong_loads() { let code = compile_exec( "\ async def f(): @@ -18782,7 +23333,7 @@ async def f(): } #[test] - fn test_async_with_bare_raise_continuation_keeps_borrow() { + fn async_with_bare_raise_continuation_keeps_borrow() { let code = compile_exec( "\ async def f(tg): @@ -18817,7 +23368,7 @@ async def f(tg): } #[test] - fn test_except_star_tail_uses_strong_loads() { + fn except_star_tail_uses_strong_loads() { let code = compile_exec( "\ def f(self): @@ -18835,6 +23386,10 @@ def f(self): .map(|unit| unit.op) .filter(|op| !matches!(op, Instruction::Cache)) .collect(); + let fail_attr = ops + .iter() + .position(|op| matches!(op, Instruction::LoadAttr { .. })) + .expect("missing self.fail load"); assert!( ops.windows(4).any(|window| { @@ -18864,10 +23419,17 @@ def f(self): }), "except* tail should not borrow the receiver after the handler region, got ops={ops:?}" ); + assert!( + !matches!( + ops.get(fail_attr.saturating_sub(2)), + Some(Instruction::JumpForward { .. 
}) + ), + "except* end label should not compile as an extra jump before the continuation, got ops={ops:?}" + ); } #[test] - fn test_protected_attr_subscript_store_tail_uses_strong_load_fast() { + fn protected_attr_subscript_store_tail_uses_strong_load_fast() { let code = compile_exec( "\ def f(f, oldcls, newcls): @@ -18919,7 +23481,7 @@ def f(f, oldcls, newcls): } #[test] - fn test_plain_attr_subscript_tail_keeps_borrow() { + fn plain_attr_subscript_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, name): @@ -18952,7 +23514,7 @@ def f(self, name): } #[test] - fn test_plain_attr_iter_chain_keeps_borrow() { + fn plain_attr_iter_chain_keeps_borrow() { let code = compile_exec( "\ def f(fields): @@ -18984,7 +23546,7 @@ def f(fields): } #[test] - fn test_genexpr_true_filter_omits_bool_scaffolding() { + fn genexpr_true_filter_omits_bool_scaffolding() { let code = compile_exec( "\ def f(it): @@ -19022,7 +23584,7 @@ def f(it): } #[test] - fn test_classdictcell_uses_load_closure_path_and_borrows_after_optimize() { + fn classdictcell_uses_load_closure_path_and_borrows_after_optimize() { let code = compile_exec( "\ class C: @@ -19064,7 +23626,7 @@ class C: } #[test] - fn test_conditional_class_body_duplicates_no_location_exit_tail() { + fn conditional_class_body_duplicates_no_location_exit_tail() { let code = compile_exec( "\ flag = False @@ -19110,7 +23672,7 @@ class C: } #[test] - fn test_class_lambda_assignment_does_not_create_classdictcell() { + fn class_lambda_assignment_does_not_create_classdictcell() { let code = compile_exec( "\ class C: @@ -19140,7 +23702,7 @@ class C: } #[test] - fn test_nested_function_static_attributes_are_collected() { + fn nested_function_static_attributes_are_collected() { let code = compile_exec( "\ class C: @@ -19180,7 +23742,7 @@ class C: } #[test] - fn test_static_attributes_match_cpython_store_rule() { + fn static_attributes_match_cpython_store_rule() { let code = compile_exec( "\ class C: @@ -19221,7 +23783,7 @@ class C: } 
#[test] - fn test_decorated_class_uses_first_decorator_for_firstlineno() { + fn decorated_class_uses_first_decorator_for_firstlineno() { let code = compile_exec( "\ @dec1 @@ -19271,7 +23833,91 @@ class C: } #[test] - fn test_future_annotations_class_uses_direct_annotation_store() { + fn class_firstlineno_store_uses_name_resolution() { + let code = compile_exec( + "\ +def f(): + __firstlineno__ = 1 + class C: + nonlocal __firstlineno__ + return C +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + + assert!( + class_code + .freevars + .iter() + .any(|name| name == "__firstlineno__"), + "class should close over nonlocal __firstlineno__, got freevars={:?}", + class_code.freevars + ); + assert!( + class_code.instructions.iter().any(|unit| match unit.op { + Instruction::StoreDeref { i } => { + let idx = i.get(OpArg::new(u32::from(u8::from(unit.arg)))).as_usize(); + localsplus_name(class_code, idx) == Some("__firstlineno__") + } + _ => false, + }), + "CPython routes __firstlineno__ through name resolution and emits STORE_DEREF for __firstlineno__, got ops={:?} freevars={:?}", + class_code.instructions, + class_code.freevars + ); + assert!( + !class_code.instructions.iter().any(|unit| { + matches!( + unit.op, + Instruction::StoreName { namei } + if class_code.names + [namei.get(OpArg::new(u32::from(u8::from(unit.arg)))) as usize] + .as_str() + == "__firstlineno__" + ) + }), + "nonlocal __firstlineno__ should not be stored with STORE_NAME, got ops={:?}", + class_code.instructions + ); + } + + #[test] + fn lambda_parent_qualname_includes_locals() { + let code = compile_exec( + "\ +def f(): + return lambda: (lambda: None) +", + ); + let mut lambda_qualnames = Vec::new(); + fn collect_lambda_qualnames(code: &CodeObject, out: &mut Vec) { + if code.obj_name == "" { + out.push(code.qualname.to_string()); + } + for constant in code.constants.iter() { + if let ConstantData::Code { code } = constant { + collect_lambda_qualnames(code, out); + } + } + } 
+ collect_lambda_qualnames(&code, &mut lambda_qualnames); + + assert!( + lambda_qualnames + .iter() + .any(|name| name == "f.."), + "missing outer lambda qualname, got {lambda_qualnames:?}" + ); + assert!( + lambda_qualnames + .iter() + .any(|name| name == "f...."), + "nested lambda parent should include . like CPython, got {lambda_qualnames:?}" + ); + } + + #[test] + fn future_annotations_class_uses_direct_annotation_store() { let code = compile_exec( "\ from __future__ import annotations @@ -19312,7 +23958,7 @@ class C: } #[test] - fn test_future_annotations_module_keeps_conditional_annotations_cell() { + fn future_annotations_module_keeps_conditional_annotations_cell() { let code = compile_exec( "\ from __future__ import annotations @@ -19330,7 +23976,7 @@ x: int = 1 } #[test] - fn test_future_annotations_conditional_class_keeps_conditional_annotations_cell() { + fn future_annotations_conditional_class_keeps_conditional_annotations_cell() { let code = compile_exec( "\ from __future__ import annotations @@ -19352,7 +23998,7 @@ class C: } #[test] - fn test_future_annotations_setup_precedes_docstring() { + fn future_annotations_setup_precedes_docstring() { let code = compile_exec( "\ \"module doc\" @@ -19417,7 +24063,93 @@ class C: } #[test] - fn test_plain_super_call_keeps_class_freevar() { + fn future_annotations_flag_is_inherited_like_cpython() { + let code = compile_exec( + "\ +from __future__ import annotations + +def f(): + class C: + pass + return C +", + ); + assert!(code.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + let f = find_code(&code, "f").expect("missing f code"); + assert!(f.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + let class_code = find_code(f, "C").expect("missing C code"); + assert!( + class_code + .flags + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS) + ); + } + + #[test] + fn annotation_scope_nested_flag_matches_cpython() { + let code = compile_exec( + "\ +class C: + x: int + +def outer(): + class D: + y: 
int +", + ); + let class_code = find_code(&code, "C").expect("missing C code"); + let class_annotate = + find_code(class_code, "__annotate__").expect("missing class annotation code"); + assert!( + !class_annotate.flags.contains(bytecode::CodeFlags::NESTED), + "module-level class annotation scope should not be nested" + ); + + let outer = find_code(&code, "outer").expect("missing outer code"); + let nested_class = find_code(outer, "D").expect("missing nested class code"); + let nested_annotate = + find_code(nested_class, "__annotate__").expect("missing nested annotation code"); + assert!( + nested_annotate.flags.contains(bytecode::CodeFlags::NESTED), + "annotation scope under a nested class should be nested" + ); + } + + #[test] + fn function_like_parent_marks_child_nested_like_cpython() { + let code = compile_exec( + "\ +x = lambda: (lambda: None) +type A[T] = T +", + ); + let outer_lambda = find_code(&code, "").expect("missing outer lambda code"); + assert!( + !outer_lambda.flags.contains(bytecode::CodeFlags::NESTED), + "module-level lambda should not be nested" + ); + let inner_lambda = + find_direct_child_code(outer_lambda, "").expect("missing inner lambda code"); + assert!( + inner_lambda.flags.contains(bytecode::CodeFlags::NESTED), + "lambda inside lambda should be nested" + ); + + let type_params = + find_code(&code, "").expect("missing type params code"); + assert!( + !type_params.flags.contains(bytecode::CodeFlags::NESTED), + "module-level type-parameter scope should not be nested" + ); + let type_alias = find_direct_child_code(type_params, "A").expect("missing type alias code"); + assert!( + type_alias.flags.contains(bytecode::CodeFlags::NESTED), + "type alias body inside type-parameter scope should be nested" + ); + } + + #[test] + fn plain_super_call_keeps_class_freevar() { let code = compile_exec( "\ class A: @@ -19449,7 +24181,7 @@ class B(A): } #[test] - fn test_nested_class_super_does_not_create_outer_class_closure() { + fn 
nested_class_super_does_not_create_outer_class_closure() { let code = compile_exec( "\ class C: @@ -19481,7 +24213,7 @@ class C: } #[test] - fn test_nested_closure_parameter_class_does_not_create_outer_class_closure() { + fn nested_closure_parameter_class_does_not_create_outer_class_closure() { let code = compile_exec( "\ class C: @@ -19516,7 +24248,7 @@ class C: } #[test] - fn test_chained_compare_jump_uses_single_cleanup_copy() { + fn chained_compare_jump_uses_single_cleanup_copy() { let code = compile_exec( "\ def f(code): @@ -19541,7 +24273,7 @@ def f(code): } #[test] - fn test_yield_from_cleanup_jumps_to_shared_end_send() { + fn yield_from_cleanup_jumps_to_shared_end_send() { let code = compile_exec( "\ def outer(): @@ -19576,7 +24308,7 @@ def outer(): } #[test] - fn test_try_except_falls_through_to_post_handler_code() { + fn try_except_falls_through_to_post_handler_code() { let code = compile_exec( "\ def f(): @@ -19617,7 +24349,92 @@ def f(): } #[test] - fn test_try_except_while_body_preserves_while_exit_line_nop() { + fn try_finally_loop_fallthrough_pop_block_bounds_exception_table() { + let code = compile_exec( + "\ +def f(os, E, data): + try: + while True: + part = os.read(3, 50000) + data += part + if not part or len(data) > 50000: + break + finally: + os.close(3) + if data: + raise E(2, 'x') +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let raise_idx = u32::try_from( + f.instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::RaiseVarargs { .. 
})) + .expect("missing post-finally raise"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries + .iter() + .all(|entry| raise_idx < entry.start || raise_idx >= entry.end), + "post-finally raise should not remain protected by the try/finally table; entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn except_as_alias_cleanup_exception_table_matches_cpython() { + let code = compile_exec( + "\ +def bug(): + try: + 1/0 + except Exception as e: + tb = e.__traceback__ + return tb +", + ); + let bug = find_code(&code, "bug").expect("missing bug code"); + let entries = bytecode::decode_exception_table(&bug.exceptiontable); + let not_taken_idx = u32::try_from( + bug.instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::NotTaken)) + .expect("missing NOT_TAKEN"), + ) + .unwrap(); + let alias_store_idx = not_taken_idx + 1; + let copy_idx = u32::try_from( + bug.instructions + .iter() + .position(|unit| { + matches!( + unit.op, + Instruction::Copy { i } + if i.get(OpArg::new(u32::from(u8::from(unit.arg)))) == 3 + ) + }) + .expect("missing outer cleanup COPY"), + ) + .unwrap(); + + assert!( + entries.iter().any(|entry| { + entry.start <= not_taken_idx + && alias_store_idx < entry.end + && entry.target == copy_idx + && entry.depth == 1 + && entry.push_lasti + }), + "CPython codegen_try_except() stores the exception alias before the inner SETUP_CLEANUP, so NOT_TAKEN and the alias store stay covered by the outer cleanup entry; entries={entries:?}, instructions={:?}", + bug.instructions + ); + } + + #[test] + fn try_except_while_body_preserves_while_exit_line_nop() { let code = compile_exec( "\ def f(x, E): @@ -19660,7 +24477,39 @@ def f(x, E): } #[test] - fn test_try_except_for_direct_break_preserves_normal_exhaustion_nop() { + fn constant_true_while_preserves_loop_line_nop() { + let code = compile_exec( + "\ +def f(self, callback): + i = 1 + while True: + for j in 1, 2, 5: + 
number = i * j + if callback: + callback(number, j) + return number, j +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let store_i = ops + .iter() + .position(|op| matches!(op, Instruction::StoreFast { .. })) + .expect("missing i store"); + + assert!( + matches!(ops.get(store_i + 1), Some(Instruction::Nop)), + "constant-true while should keep CPython loop-line NOP after setup, got ops={ops:?}" + ); + } + + #[test] + fn try_except_for_direct_break_preserves_normal_exhaustion_nop() { let code = compile_exec( "\ def f(xs, g, E): @@ -19699,7 +24548,7 @@ def f(xs, g, E): } #[test] - fn test_try_except_for_without_direct_break_drops_normal_exhaustion_nop() { + fn try_except_for_without_direct_break_drops_normal_exhaustion_nop() { let code = compile_exec( "\ def f(xs, g, E): @@ -19736,7 +24585,7 @@ def f(xs, g, E): } #[test] - fn test_terminal_except_before_conditional_tail_uses_strong_load() { + fn terminal_except_before_conditional_tail_uses_strong_load() { let code = compile_exec( "\ def f(self, Exception): @@ -19777,7 +24626,7 @@ def f(self, Exception): } #[test] - fn test_try_except_continuation_folded_tuple_drops_operand_nop() { + fn try_except_continuation_folded_tuple_drops_operand_nop() { let code = compile_exec( "\ def f(): @@ -19814,7 +24663,7 @@ def f(): } #[test] - fn test_if_else_normal_fallthrough_end_label_drops_return_anchor_nop() { + fn if_else_normal_fallthrough_end_label_drops_return_anchor_nop() { let code = compile_exec( "\ def f(s): @@ -19859,7 +24708,7 @@ def f(s): } #[test] - fn test_explicit_final_return_none_is_not_duplicated() { + fn explicit_final_return_none_is_not_duplicated() { let code = compile_exec( "\ def f(src, dst, length, exception, bufsize): @@ -19903,7 +24752,7 @@ def f(src, dst, length, exception, bufsize): } #[test] - fn 
test_named_except_cleanup_keeps_jump_over_cleanup_and_next_try() { + fn named_except_cleanup_keeps_jump_over_cleanup_and_next_try() { let code = compile_exec( r#" def f(self): @@ -19952,7 +24801,7 @@ def f(self): } #[test] - fn test_named_except_with_suppress_does_not_duplicate_following_with() { + fn named_except_with_suppress_does_not_duplicate_following_with() { let code = compile_exec( "\ def f(StringIO, captured_output, print): @@ -19992,7 +24841,7 @@ def f(StringIO, captured_output, print): } #[test] - fn test_bare_except_deopts_post_handler_load_fast_borrow() { + fn bare_except_deopts_post_handler_load_fast_borrow() { let code = compile_exec( "\ def f(self): @@ -20023,7 +24872,51 @@ def f(self): } #[test] - fn test_typed_except_keeps_post_handler_load_fast_borrow() { + fn bare_except_before_if_deopts_successor_load_fast_borrow() { + let code = compile_exec( + "\ +def f(self, x): + try: + x = g() + except: + self.fail('raised') + if x: + self.fail('unexpected') +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let fail_loads = instructions + .iter() + .enumerate() + .filter_map(|(idx, unit)| { + let Instruction::LoadAttr { namei } = unit.op else { + return None; + }; + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + (f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "fail") + .then_some(idx) + }) + .collect::>(); + assert!( + fail_loads.len() >= 2, + "expected handler and successor fail calls, got instructions={instructions:?}" + ); + assert!( + matches!( + instructions.get(fail_loads[1] - 1).map(|unit| unit.op), + Some(Instruction::LoadFast { .. 
}) + ), + "CPython codegen_try_except() sends a fallthrough bare handler through USE_LABEL(end); flowgraph.c::optimize_load_fast() stops at that empty end label before the following if, got instructions={instructions:?}" + ); + } + + #[test] + fn typed_except_keeps_post_handler_load_fast_borrow() { let code = compile_exec( "\ def f(self): @@ -20057,7 +24950,264 @@ def f(self): } #[test] - fn test_conditional_typed_except_return_join_keeps_borrow() { + fn bare_except_terminal_handler_store_subscr_tail_uses_strong_loads() { + let code = compile_exec( + "\ +def f(g, cache, filename, mtime, result): + try: + module = g() + except: + return None + cache[filename] = (mtime, result) + return result +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let is_pair = |unit: &&CodeUnit, left_name: &str, right_name: &str| { + let Instruction::LoadFastLoadFast { var_nums } = unit.op else { + return false; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == left_name + && f.varnames[usize::from(right)] == right_name + }; + let is_borrow_pair = |unit: &&CodeUnit, left_name: &str, right_name: &str| { + let Instruction::LoadFastBorrowLoadFastBorrow { var_nums } = unit.op else { + return false; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == left_name + && f.varnames[usize::from(right)] == right_name + }; + + assert!( + ops.iter().any(|unit| is_pair(unit, "mtime", "result")) + && ops.iter().any(|unit| is_pair(unit, "cache", "filename")), + "CPython optimize_load_fast() stops at the empty try-end block for a terminal bare handler, got ops={ops:?}" + ); + assert!( + !ops.iter() + .any(|unit| is_borrow_pair(unit, "mtime", "result") + || is_borrow_pair(unit, 
"cache", "filename")), + "terminal bare handler post-try store tail should not be borrowed, got ops={ops:?}" + ); + } + + #[test] + fn while_true_try_else_break_tail_uses_strong_loads() { + let code = compile_exec( + "\ +def f(path, prefix, self, cache, read, E, stat): + while True: + try: + st = stat(path) + except E: + path = path.dirname + else: + if st.mode: + raise E + break + if path not in cache: + cache[path] = read(path) + self.archive = path + self.prefix = prefix +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let is_pair = |unit: &&CodeUnit, left_name: &str, right_name: &str| { + let Instruction::LoadFastLoadFast { var_nums } = unit.op else { + return false; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == left_name + && f.varnames[usize::from(right)] == right_name + }; + let is_borrow_pair = |unit: &&CodeUnit, left_name: &str, right_name: &str| { + let Instruction::LoadFastBorrowLoadFastBorrow { var_nums } = unit.op else { + return false; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == left_name + && f.varnames[usize::from(right)] == right_name + }; + + assert!( + ops.iter().any(|unit| is_pair(unit, "path", "cache")) + && ops.iter().any(|unit| is_pair(unit, "path", "self")) + && ops.iter().any(|unit| is_pair(unit, "prefix", "self")), + "CPython codegen_while() leaves an empty break end label that stops optimize_load_fast(), got ops={ops:?}" + ); + assert!( + !ops.iter().any(|unit| is_borrow_pair(unit, "path", "cache") + || is_borrow_pair(unit, "path", "self") + || is_borrow_pair(unit, "prefix", "self")), + "while-true break successor should not be reached through a Rust-only fallthrough, got ops={ops:?}" + ); + } + + 
#[test] + fn except_handler_resume_return_call_tail_keeps_borrow() { + let code = compile_exec( + "\ +def f(class_cache, cls, KeyError, make, obj, lock, ctx): + try: + scls = class_cache[cls] + except KeyError: + scls = make(cls) + return scls(obj, lock, ctx) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let borrows_name = |unit: &&CodeUnit, name: &str| match unit.op { + Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.varnames[usize::from(var_num.get(arg))] == name + } + Instruction::LoadFastBorrowLoadFastBorrow { var_nums } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == name || f.varnames[usize::from(right)] == name + } + _ => false, + }; + let strong_loads_name = |unit: &&CodeUnit, name: &str| match unit.op { + Instruction::LoadFast { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.varnames[usize::from(var_num.get(arg))] == name + } + Instruction::LoadFastLoadFast { var_nums } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == name || f.varnames[usize::from(right)] == name + } + _ => false, + }; + let return_idx = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::ReturnValue)) + .expect("missing return"); + let tail = &ops[..return_idx]; + + for name in ["scls", "obj", "lock", "ctx"] { + assert!( + tail.iter().any(|unit| borrows_name(unit, name)), + "handler resume to CPython codegen_try_except() end label should keep return-call {name} borrowed, got tail={tail:?}" + ); + assert!( + !tail.iter().any(|unit| strong_loads_name(unit, name)), + "handler resume return-call tail should not be separated by a Rust-only empty end block for {name}, 
got tail={tail:?}" + ); + } + } + + #[test] + fn typed_except_named_handler_closure_tail_keeps_borrows() { + let code = compile_exec( + "\ +def f(self): + filename = TESTFN + ICACLS = expandvars('icacls') + try: + check_output([ICACLS, filename]) + except CalledProcessError as ex: + self.skipTest('Unable to create inaccessible file') + def cleanup(): + check_output([ICACLS, filename]) + self.addCleanup(cleanup) + stat1 = stat(filename) + stat2 = stat(filename) + self.assertEqual(stat1, stat2) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let make_function = ops + .iter() + .position(|op| matches!(op, Instruction::MakeFunction)) + .expect("missing MAKE_FUNCTION for cleanup closure"); + let handler_start = ops + .iter() + .position(|op| matches!(op, Instruction::PushExcInfo)) + .expect("missing handler entry"); + let tail = &ops[make_function.saturating_sub(2)..handler_start]; + + assert!( + tail.iter().any(|op| { + matches!( + op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) + }), + "typed except closure continuation should be visited by optimize_load_fast(), got tail={tail:?}", + ); + assert!( + !tail + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython codegen_try_except() uses USE_LABEL(end), so the handler continuation should be a shared passthrough and post-handler closure/tail loads should borrow; got tail={tail:?}", + ); + } + + #[test] + fn named_terminal_raise_handler_keeps_return_pair_borrowed() { + let code = compile_exec( + "\ +def f(factory, worker_json, test_name, stdout, E): + try: + result = factory(worker_json) + except E as exc: + raise RuntimeError(test_name, stdout, exc) + return result, stdout +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let has_borrowed_pair = instructions.iter().any(|unit| { + let Instruction::LoadFastBorrowLoadFastBorrow { var_nums } = unit.op else { + return false; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == "result" && f.varnames[usize::from(right)] == "stdout" + }); + + assert!( + has_borrowed_pair, + "named terminal-raise handler should follow CPython codegen_try_except()/flowgraph.c cleanup reachability and keep return pair borrowed; got instructions={instructions:?}" + ); + } + + #[test] + fn conditional_typed_except_return_join_keeps_borrow() { let code = compile_exec( "\ def f(cond, obj, xs, E): @@ -20109,7 +25259,7 @@ def f(cond, obj, xs, E): } #[test] - fn test_typed_except_pass_resume_store_subscr_tail_keeps_borrows() { + fn typed_except_pass_resume_store_subscr_tail_keeps_borrows() { let code = compile_exec( "\ def f(self, sys, KeyError): @@ -20157,7 +25307,7 @@ def f(self, sys, KeyError): } #[test] - fn test_reraising_typed_except_deopts_post_handler_loads() { + fn reraising_typed_except_deopts_post_handler_loads() { let code = compile_exec( "\ def f(x, os, self, pid, exitcode): @@ -20211,7 +25361,7 @@ def f(x, os, self, pid, exitcode): } #[test] - fn 
test_reraising_outer_handler_keeps_explicit_raise_call_arg_borrow() { + fn reraising_outer_handler_keeps_explicit_raise_call_arg_borrow() { let code = compile_exec( "\ def f(file, os, stat, errno, self, fd): @@ -20259,7 +25409,7 @@ def f(file, os, stat, errno, self, fd): } #[test] - fn test_reraising_except_loop_backedge_keeps_loop_header_borrow() { + fn reraising_except_loop_backedge_keeps_loop_header_borrow() { let code = compile_exec( "\ def f(self, tag, expect_bye): @@ -20315,7 +25465,7 @@ def f(self, tag, expect_bye): } #[test] - fn test_protected_store_break_handler_deopts_bool_guard_tail() { + fn protected_store_break_handler_deopts_bool_guard_tail() { let code = compile_exec( "\ def f(self, size): @@ -20367,7 +25517,7 @@ def f(self, size): } #[test] - fn test_assertion_success_join_keeps_following_debug_tail_borrowed() { + fn assertion_success_join_keeps_following_debug_tail_borrowed() { let code = compile_exec( "\ def f(self, typ, dat): @@ -20429,7 +25579,7 @@ def f(self, typ, dat): } #[test] - fn test_multi_protected_method_call_terminal_handler_keeps_try_body_borrows() { + fn multi_protected_method_call_terminal_handler_keeps_try_body_borrows() { let code = compile_exec( "\ def f(self, literal): @@ -20493,7 +25643,7 @@ def f(self, literal): } #[test] - fn test_dunder_debug_constant_false_if_deopts_tail_borrow() { + fn dunder_debug_constant_false_if_deopts_tail_borrow() { let code = compile_exec( "\ def f(self): @@ -20527,7 +25677,7 @@ def f(self): } #[test] - fn test_constant_slice_folds_constant_bounds() { + fn constant_slice_folds_constant_bounds() { let code = compile_exec( "\ def f(obj): @@ -20576,7 +25726,7 @@ def f(obj): } #[test] - fn test_negative_step_slice_uses_build_slice() { + fn negative_step_slice_uses_build_slice() { let code = compile_exec( "\ def f(obj): @@ -20610,7 +25760,47 @@ def f(obj): } #[test] - fn test_bool_int_binop_constants_fold() { + fn slice_none_bounds_and_build_slice_use_slice_location_like_cpython() { + let code = 
compile_exec( + "\ +def f(obj, step): + return obj[::step] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let slice_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + let op = match unit.op { + Instruction::LoadConst { .. } => "LOAD_CONST", + Instruction::BuildSlice { .. } => "BUILD_SLICE", + _ => return None, + }; + Some(( + op, + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + slice_positions, + vec![ + ("LOAD_CONST", 2, 16, 2, 22), + ("LOAD_CONST", 2, 16, 2, 22), + ("BUILD_SLICE", 2, 16, 2, 22), + ], + "CPython codegen_slice() emits missing bounds and BUILD_SLICE at LOC(slice)" + ); + } + + #[test] + fn bool_int_binop_constants_fold() { let code = compile_exec( "\ def f(): @@ -20661,7 +25851,7 @@ def g(): } #[test] - fn test_double_not_expression_folds_to_bool_conversion() { + fn double_not_expression_folds_to_bool_conversion() { let code = compile_exec( "\ def f(x): @@ -20691,7 +25881,7 @@ def f(x): } #[test] - fn test_tuple_bound_slice_uses_two_part_slice_path() { + fn tuple_bound_slice_uses_two_part_slice_path() { let code = compile_exec( "\ def f(obj): @@ -20723,7 +25913,7 @@ def f(obj): } #[test] - fn test_exception_cleanup_jump_to_return_is_inlined() { + fn exception_cleanup_jump_to_return_is_inlined() { let code = compile_exec( "\ def f(names, cls): @@ -20748,7 +25938,7 @@ def f(names, cls): } #[test] - fn test_except_break_preserves_plain_jump_when_inlining_no_lineno_tail() { + fn except_break_preserves_plain_jump_when_inlining_no_lineno_tail() { let code = compile_exec( "\ def f(compiler_so, cc_args): @@ -20798,7 +25988,7 @@ def f(compiler_so, cc_args): } #[test] - fn test_nested_with_bare_except_keeps_handler_cleanup_before_following_code() { + fn nested_with_bare_except_keeps_handler_cleanup_before_following_code() { let code = 
compile_exec( "\ def f(cm, self): @@ -20843,7 +26033,7 @@ def f(cm, self): } #[test] - fn test_try_else_for_cleanup_drops_redundant_jump_nop() { + fn try_else_for_cleanup_drops_redundant_jump_nop() { let code = compile_exec( "\ def f(self, xs, ys, cm1, cm2): @@ -20910,7 +26100,7 @@ def f(self, xs, ys, cm1, cm2): } #[test] - fn test_non_none_final_return_is_not_duplicated() { + fn non_none_final_return_is_not_duplicated() { let code = compile_exec( "\ def f(p, s): @@ -20952,7 +26142,7 @@ def f(p, s): } #[test] - fn test_for_return_unary_constant_preserves_value_over_iterator_cleanup() { + fn for_return_unary_constant_preserves_value_over_iterator_cleanup() { let code = compile_exec( "\ def f(xs): @@ -20984,7 +26174,7 @@ def f(xs): } #[test] - fn test_try_else_if_return_keeps_conditional_target_nop() { + fn try_else_if_return_keeps_conditional_target_nop() { let code = compile_exec( "\ def f(cond): @@ -21025,7 +26215,7 @@ def f(cond): } #[test] - fn test_try_else_nested_if_return_drops_inner_conditional_target_nop() { + fn try_else_nested_if_return_drops_inner_conditional_target_nop() { let code = compile_exec( "\ def f(obj, Sig): @@ -21080,7 +26270,7 @@ def f(obj, Sig): } #[test] - fn test_try_else_nested_final_if_return_drops_nested_conditional_target_nop() { + fn try_else_nested_final_if_return_drops_nested_conditional_target_nop() { let code = compile_exec( "\ def f(cond, outer): @@ -21120,7 +26310,7 @@ def f(cond, outer): } #[test] - fn test_named_except_conditional_branch_duplicates_cleanup_return() { + fn named_except_conditional_branch_duplicates_cleanup_return() { let code = compile_exec( "\ def f(self): @@ -21163,7 +26353,7 @@ def f(self): } #[test] - fn test_named_except_conditional_before_explicit_return_shares_cleanup_return() { + fn named_except_conditional_before_explicit_return_shares_cleanup_return() { let code = compile_exec( "\ def f(onerror, err, OSError): @@ -21207,7 +26397,51 @@ def f(onerror, err, OSError): } #[test] - fn 
test_listcomp_cleanup_tail_keeps_split_store_fast_pair() { + fn named_except_boolop_condition_shares_cleanup_return() { + let code = compile_exec( + "\ +def f(self, module_name, ModuleNotFoundError): + try: + return importlib.import_module(module_name) + except ModuleNotFoundError as error: + if self._warn_on_extension_import and module_name in builtin_hashes: + logging.getLogger(__name__).warning('msg', error, exc_info=error) + return None +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + let cleanup_return_count = ops + .windows(6) + .filter(|window| { + matches!( + window, + [ + Instruction::PopExcept, + Instruction::LoadConst { .. }, + Instruction::StoreFast { .. } | Instruction::StoreName { .. }, + Instruction::DeleteFast { .. } | Instruction::DeleteName { .. }, + Instruction::LoadConst { .. }, + Instruction::ReturnValue, + ] + ) + }) + .count(); + + assert_eq!( + cleanup_return_count, 1, + "CPython keeps a shared named-except cleanup return when multiple BoolOp false edges target the same cleanup block, got ops={ops:?}" + ); + } + + #[test] + fn listcomp_cleanup_tail_keeps_split_store_fast_pair() { let code = compile_exec( "\ def f(escaped_string, quote_types): @@ -21245,7 +26479,7 @@ def f(escaped_string, quote_types): } #[test] - fn test_dictcomp_cleanup_tail_keeps_split_store_fast_pair() { + fn dictcomp_cleanup_tail_keeps_split_store_fast_pair() { let code = compile_exec( "\ def f(obj, g): @@ -21282,7 +26516,7 @@ def f(obj, g): } #[test] - fn test_static_swap_triple_assign_keeps_store_fast_store_fast() { + fn static_swap_triple_assign_keeps_store_fast_store_fast() { let code = compile_exec( "\ def f(x, y, z): @@ -21314,7 +26548,7 @@ def f(x, y, z): } #[test] - fn test_static_swap_duplicate_pair_eliminates_swap() { + fn static_swap_duplicate_pair_eliminates_swap() { let code = compile_exec( "\ def 
f(x, y): @@ -21343,7 +26577,7 @@ def f(x, y): } #[test] - fn test_static_swap_duplicate_prefix_eliminates_swap() { + fn static_swap_duplicate_prefix_eliminates_swap() { let code = compile_exec( "\ def f(x, y, z): @@ -21375,7 +26609,7 @@ def f(x, y, z): } #[test] - fn test_constant_if_expression_stmt_in_loop_removes_empty_body() { + fn constant_if_expression_stmt_in_loop_removes_empty_body() { let code = compile_exec( "\ def f(x): @@ -21399,7 +26633,7 @@ def f(x): } #[test] - fn test_if_expression_in_jump_context_skips_constant_true_arm_load() { + fn if_expression_in_jump_context_skips_constant_true_arm_load() { let code = compile_exec( "\ def f(): @@ -21422,7 +26656,7 @@ def f(): } #[test] - fn test_with_suppress_tail_duplicates_final_return_none() { + fn with_suppress_tail_duplicates_final_return_none() { let code = compile_exec( "\ def f(cm, cond): @@ -21456,7 +26690,7 @@ def f(cm, cond): } #[test] - fn test_with_conditional_bare_return_keeps_return_line_nop_before_exit_cleanup() { + fn with_conditional_bare_return_keeps_return_line_nop_before_exit_cleanup() { let code = compile_exec( "\ def f(cm, registry, altkey): @@ -21495,7 +26729,7 @@ def f(cm, registry, altkey): } #[test] - fn test_multiline_nested_with_return_finally_keeps_inner_cleanup_anchor_nop() { + fn multiline_nested_with_return_finally_keeps_inner_cleanup_anchor_nop() { let code = compile_exec( "\ def f(a, b, path): @@ -21538,7 +26772,69 @@ def f(a, b, path): } #[test] - fn test_try_finally_conditional_return_duplicates_finally_exit_return() { + fn with_return_value_uses_context_expr_location_like_cpython() { + let code = compile_exec( + "\ +def f(cm, func, args, kwds): + with cm: + return func(*args, **kwds) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let return_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + 
location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 10, 2, 12), (2, 10, 2, 12)], + "CPython codegen_unwind_fblock(WITH) leaves RETURN_VALUE inheriting the context expression location" + ); + } + + #[test] + fn async_with_return_value_uses_context_expr_location_like_cpython() { + let code = compile_exec( + "\ +async def f(cm, func, args, kwds): + async with cm: + return await func(*args, **kwds) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let return_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 16, 2, 18), (2, 16, 2, 18)], + "CPython codegen_unwind_fblock(ASYNC_WITH) leaves RETURN_VALUE inheriting the context expression location" + ); + } + + #[test] + fn try_finally_conditional_return_duplicates_finally_exit_return() { let code = compile_exec( "\ def f(flag, data, callback): @@ -21570,7 +26866,7 @@ def f(flag, data, callback): } #[test] - fn test_named_except_conditional_cleanup_is_inlined_per_branch() { + fn named_except_conditional_cleanup_is_inlined_per_branch() { let code = compile_exec( "\ def f(self, logger): @@ -21617,7 +26913,7 @@ def f(self, logger): } #[test] - fn test_try_finally_exception_path_duplicates_conditional_reraise() { + fn try_finally_exception_path_duplicates_conditional_reraise() { let code = compile_exec( "\ def f(flag, callback): @@ -21647,7 +26943,7 @@ def f(flag, callback): } #[test] - fn test_genexpr_compare_header_uses_store_fast_load_fast_like_cpython() { + fn genexpr_compare_header_uses_store_fast_load_fast_like_cpython() { let code = compile_exec( "\ def f(it): @@ 
-21678,7 +26974,7 @@ def f(it): } #[test] - fn test_fstring_adjacent_literals_are_merged() { + fn fstring_adjacent_literals_are_merged() { let code = compile_exec( "\ def f(cls, proto): @@ -21720,7 +27016,7 @@ def f(cls, proto): } #[test] - fn test_literal_only_fstring_statement_is_optimized_away() { + fn literal_only_fstring_statement_keeps_const_like_cpython() { let code = compile_exec( "\ def f(): @@ -21730,22 +27026,16 @@ def f(): let f = find_code(&code, "f").expect("missing function code"); assert!( - !f.instructions - .iter() - .any(|unit| matches!(unit.op, Instruction::PopTop)), - "literal-only f-string statement should be removed" - ); - assert!( - !f.constants.iter().any(|constant| matches!( + f.constants.iter().any(|constant| matches!( constant, ConstantData::Str { value } if value.to_string() == "Not a docstring" )), - "literal-only f-string should not survive in constants" + "constant f-string statement should survive in co_consts like CPython" ); } #[test] - fn test_empty_fstring_literals_are_elided_around_interpolation() { + fn empty_fstring_literals_are_elided_around_interpolation() { let code = compile_exec( "\ def f(x): @@ -21783,7 +27073,7 @@ def f(x): } #[test] - fn test_large_fstring_uses_join_list_like_cpython() { + fn large_fstring_uses_join_list_like_cpython() { let mut source = String::from("def f(x):\n return f\""); for _ in 0..=STACK_USE_GUIDELINE { source.push_str("{x}"); @@ -21825,7 +27115,7 @@ def f(x): } #[test] - fn test_large_power_is_not_constant_folded() { + fn large_power_is_not_constant_folded() { let code = compile_exec("x = 2**100\n"); assert!(code.instructions.iter().any(|unit| match unit.op { @@ -21837,7 +27127,7 @@ def f(x): } #[test] - fn test_string_and_bytes_binops_constant_fold_like_cpython() { + fn string_and_bytes_binops_constant_fold_like_cpython() { let code = compile_exec( "\ x = b'\\\\' + b'u1881'\n\ @@ -21865,7 +27155,7 @@ y = 103 * 'a' + 'x'\n", } #[test] - fn 
test_float_floor_division_constant_folds_like_cpython() { + fn float_floor_division_constant_folds_like_cpython() { let code = compile_exec( "\ x = 1.0 // 0.1\n\ @@ -21897,7 +27187,7 @@ z = 1e300 * 1e300 * 0\n", } #[test] - fn test_float_power_overflow_constant_does_not_fold() { + fn float_power_overflow_constant_does_not_fold() { let code = compile_exec("x = 1e300 ** 2\n"); assert!( @@ -21913,7 +27203,7 @@ z = 1e300 * 1e300 * 0\n", } #[test] - fn test_large_string_and_bytes_binops_constant_fold_like_cpython() { + fn large_string_and_bytes_binops_constant_fold_like_cpython() { let code = compile_exec( r#" encoded = b'\xff\xfe\x00\x00' + b'\x00\x00\x01\x00' * 1024 @@ -21940,7 +27230,7 @@ text = '\U00010000' * 1024 } #[test] - fn test_constant_string_subscript_folds_inside_collection() { + fn constant_string_subscript_folds_inside_collection() { let code = compile_exec( "\ values = [item for item in [r\"\\\\'a\\\\'\", r\"\\t3\", r\"\\\\\"[0]]]\n", @@ -21963,7 +27253,7 @@ values = [item for item in [r\"\\\\'a\\\\'\", r\"\\t3\", r\"\\\\\"[0]]]\n", } #[test] - fn test_constant_string_subscript_with_surrogate_skips_lossy_fold() { + fn constant_string_subscript_with_surrogate_skips_lossy_fold() { let code = compile_exec("value = \"\\ud800\"[0]\n"); assert!( @@ -21980,7 +27270,7 @@ values = [item for item in [r\"\\\\'a\\\\'\", r\"\\t3\", r\"\\\\\"[0]]]\n", } #[test] - fn test_constant_subscript_folds_in_load_context() { + fn constant_subscript_folds_in_load_context() { let cases = [ ("value = (1, 2, 3)[0]\n", Some(BigInt::from(1)), None), ("value = b\"abc\"[0]\n", Some(BigInt::from(97)), None), @@ -22031,7 +27321,30 @@ values = [item for item in [r\"\\\\'a\\\\'\", r\"\\t3\", r\"\\\\\"[0]]]\n", } #[test] - fn test_constant_slice_subscript_folds_in_load_context() { + fn constant_subscript_registers_source_const_before_result_like_cpython() { + let code = compile_exec("value = 'string'[3]\n"); + let source_index = code + .constants + .iter() + .position(|constant| { + 
matches!(constant, ConstantData::Str { value } if value.to_string() == "string") + }) + .expect("missing source string constant"); + let result_index = code + .constants + .iter() + .position(|constant| { + matches!(constant, ConstantData::Str { value } if value.to_string() == "i") + }) + .expect("missing folded subscript result"); + assert!( + source_index < result_index, + "CPython codegen_subscript emits the source constant before flowgraph.c folds NB_SUBSCR" + ); + } + + #[test] + fn constant_slice_subscript_folds_in_load_context() { let code = compile_exec( "\ a = 'hello'[:4]\n\ @@ -22072,7 +27385,7 @@ c = (1, 2, 3)[:2]\n", } #[test] - fn test_list_of_constant_tuples_uses_list_extend() { + fn list_of_constant_tuples_uses_list_extend() { let code = compile_exec( "\ deprecated_cases = [('a', 'b'), ('c', 'd'), ('e', 'f'), ('g', 'h'), ('i', 'j')] @@ -22088,7 +27401,7 @@ deprecated_cases = [('a', 'b'), ('c', 'd'), ('e', 'f'), ('g', 'h'), ('i', 'j')] } #[test] - fn test_large_list_of_unary_constants_uses_list_extend() { + fn large_list_of_unary_constants_uses_list_extend() { let code = compile_exec( "\ values = [-1, not True, ~0, +True, 5] @@ -22115,7 +27428,7 @@ values = [-1, not True, ~0, +True, 5] } #[test] - fn test_outer_unary_after_binop_folds_before_list_folding() { + fn outer_unary_after_binop_folds_before_list_folding() { let code = compile_exec( "\ values = [2.0**53, -0.5, -2.0**-54] @@ -22148,7 +27461,7 @@ values = [2.0**53, -0.5, -2.0**-54] } #[test] - fn test_negative_integer_power_folds_to_float_constant() { + fn negative_integer_power_folds_to_float_constant() { let code = compile_exec("value = -3.0 * 2**(-333)\n"); assert!( @@ -22167,7 +27480,7 @@ values = [2.0**53, -0.5, -2.0**-54] } #[test] - fn test_complex_power_constants_fold_like_cpython() { + fn complex_power_constants_fold_like_cpython() { let code = compile_exec( "\ one = 3j ** 0j @@ -22194,7 +27507,30 @@ zero = 0j ** 2 } #[test] - fn test_zero_complex_power_exception_constants_do_not_fold() 
{ + fn folded_nan_constants_are_not_deduplicated_like_cpython() { + let code = compile_exec( + "\ +def f(): + repr(1e300 * 1e300 * 0) + repr(-1e300 * 1e300 * 0) + str(1e300 * 1e300 * 0) + str(-1e300 * 1e300 * 0) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let nan_count = f + .constants + .iter() + .filter(|constant| matches!(constant, ConstantData::Float { value } if value.is_nan())) + .count(); + assert_eq!( + nan_count, 4, + "CPython _PyCode_ConstantKey keeps folded NaN constants distinct" + ); + } + + #[test] + fn zero_complex_power_exception_constants_do_not_fold() { let code = compile_exec("value = 0j ** (3 - 2j)\n"); assert!( @@ -22207,7 +27543,7 @@ zero = 0j ** 2 } #[test] - fn test_large_constant_list_keeps_streaming_build() { + fn large_constant_list_keeps_streaming_build() { let source = format!( "values = [{}]\n", (0..31) @@ -22235,7 +27571,7 @@ zero = 0j ** 2 } #[test] - fn test_large_constant_tuple_stream_folds_to_tuple_const() { + fn large_constant_tuple_stream_folds_to_tuple_const() { let source = format!( "values = ({},)\n", (0..31) @@ -22262,7 +27598,7 @@ zero = 0j ** 2 } #[test] - fn test_annotation_closure_uses_format_varname() { + fn annotation_closure_uses_format_varname() { let code = compile_exec( "\ class C: @@ -22279,7 +27615,54 @@ class C: } #[test] - fn test_type_param_evaluator_uses_dot_format_varname() { + fn non_simple_class_annotation_is_not_deferred_like_cpython() { + let code = compile_exec( + "\ +class C: + x.y: list = [] + z: int +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + let names = annotate + .names + .iter() + .map(|name| name.as_str()) + .collect::>(); + assert_eq!(names, vec!["int"]); + } + + #[test] + fn non_simple_annotation_only_consumes_symbol_table_cursor() { + let code = compile_exec( + "\ +class C: + x.y: (lambda: str) = [] + z: (lambda: int) +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ 
code"); + let lambdas = annotate + .constants + .iter() + .filter_map(|constant| match constant { + ConstantData::Code { code } if code.obj_name == "" => Some(code.as_ref()), + _ => None, + }) + .collect::>(); + assert_eq!(lambdas.len(), 1); + assert_eq!( + lambdas[0] + .names + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["int"] + ); + } + + #[test] + fn type_param_evaluator_uses_dot_format_varname() { let code = compile_exec( "\ class C[T: int]: @@ -22296,7 +27679,7 @@ class C[T: int]: } #[test] - fn test_generic_class_double_star_bases_use_tuple_ex_call_path() { + fn generic_class_double_star_bases_use_tuple_ex_call_path() { let code = compile_exec( "\ def f(Base, kwargs): @@ -22335,7 +27718,7 @@ def f(Base, kwargs): } #[test] - fn test_generic_function_defaults_call_type_params_like_cpython() { + fn generic_function_defaults_call_type_params_like_cpython() { let code = compile_exec( "\ def func[T](a: T = 'a', *, b: T = 'b'): @@ -22377,12 +27760,33 @@ def func[T](a: T = 'a', *, b: T = 'b'): ] ) }), - "generic defaults call should not use RustPython-specific PUSH_NULL reshuffle, got ops={ops:?}" + "CPython generic defaults use SWAP/CALL after codegen_make_closure(), not a PUSH_NULL reshuffle, got ops={ops:?}" ); } #[test] - fn test_class_type_param_bound_prefers_classdict_over_outer_function_local() { + fn generic_function_type_params_varnames_include_defaults_like_cpython() { + let code = compile_exec( + "\ +def func[T](): + pass +", + ); + let type_params = + find_code(&code, "").expect("missing type params code"); + assert_eq!(type_params.arg_count, 0); + assert_eq!( + type_params + .varnames + .iter() + .map(String::as_str) + .collect::>(), + vec![".defaults", "T"] + ); + } + + #[test] + fn class_type_param_bound_prefers_classdict_over_outer_function_local() { let code = compile_exec( "\ def f(self): @@ -22442,7 +27846,7 @@ def f(self): } #[test] - fn test_class_type_param_bound_respects_class_global_over_outer_function_local() { + fn 
class_type_param_bound_respects_class_global_over_outer_function_local() { let code = compile_exec( "\ def f(self): @@ -22486,7 +27890,7 @@ def f(self): } #[test] - fn test_generic_type_alias_in_class_does_not_capture_module_name() { + fn generic_type_alias_in_class_does_not_capture_module_name() { let code = compile_exec( r#" T = U = "global" @@ -22557,7 +27961,7 @@ class C: } #[test] - fn test_nested_generic_class_base_child_free_keeps_classdict_lookup() { + fn nested_generic_class_base_child_free_keeps_classdict_lookup() { for (child_name, base_expr) in [ ("", "make_base(T for _ in (1,))"), ("", "make_base([T for _ in (1,)])"), @@ -22607,7 +28011,7 @@ class C[T]: } #[test] - fn test_class_annotation_global_resolution_matches_cpython() { + fn class_annotation_global_resolution_matches_cpython() { let class_global = compile_exec( "\ X = 'global' @@ -22658,7 +28062,7 @@ def f(): } #[test] - fn test_constant_tuple_binops_fold_like_cpython() { + fn constant_tuple_binops_fold_like_cpython() { let code = compile_exec("value = (1,) * 17 + ('spam',)\n"); assert!( @@ -22681,7 +28085,67 @@ def f(): } #[test] - fn test_constant_list_iterable_uses_tuple() { + fn tuple_not_keeps_to_bool_unary_not_like_cpython() { + let code = compile_exec( + "\ +def f(): + return not () +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + + assert!( + ops.windows(3).any(|window| { + matches!(window[0].op, Instruction::LoadConst { consti } + if matches!( + &f.constants[consti.get(OpArg::new(u32::from(u8::from(window[0].arg))))], + ConstantData::Tuple { elements } if elements.is_empty() + )) && matches!(window[1].op, Instruction::ToBool) + && matches!(window[2].op, Instruction::UnaryNot) + }), + "CPython codegen emits TO_BOOL; UNARY_NOT for UnaryOp(Not), while flowgraph.c folds tuple literals only after the LOAD_CONST+TO_BOOL pass, got instructions={:?}", + 
f.instructions + ); + } + + #[test] + fn tuple_if_test_keeps_to_bool_jump_like_cpython() { + let code = compile_exec( + "\ +def f(): + if (): + return 1 + return 2 +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + + assert!( + ops.windows(3).any(|window| { + matches!(window[0].op, Instruction::LoadConst { consti } + if matches!( + &f.constants[consti.get(OpArg::new(u32::from(u8::from(window[0].arg))))], + ConstantData::Tuple { elements } if elements.is_empty() + )) && matches!(window[1].op, Instruction::ToBool) + && matches!(window[2].op, Instruction::PopJumpIfFalse { .. }) + }), + "CPython leaves tuple literal truth tests as LOAD_CONST tuple; TO_BOOL; POP_JUMP_IF_FALSE because tuple folding happens after constant jump folding, got instructions={:?}", + f.instructions + ); + } + + #[test] + fn constant_list_iterable_uses_tuple() { let code = compile_exec( "\ def f(): @@ -22714,7 +28178,229 @@ def f(): } #[test] - fn test_large_constant_list_iterable_keeps_streaming_list_build() { + fn constant_list_iterable_preserves_cpython_const_order() { + let code = compile_exec( + "\ +def f(): + for x in ['a', 'b', 'c']: + pass +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + + assert!( + matches!(constants[0], ConstantData::Str { value } if value.to_string() == "a"), + "CPython emits list elements as LOAD_CONST before flowgraph folds GET_ITER lists" + ); + assert!(matches!(constants[1], ConstantData::None)); + assert!(matches!( + constants[2], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Str { value: first }, + ConstantData::Str { value: second }, + ConstantData::Str { value: third }, + ] if first.to_string() == "a" + && second.to_string() == "b" + && third.to_string() == "c" + ) + )); + } + + #[test] + fn 
try_except_folded_tuple_consts_follow_cpython_block_order() { + let code = compile_exec( + "\ +def f(macrelease): + try: + g() + except ValueError: + macrelease = (10, 3) + if macrelease >= (10, 4): + pass +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + + assert!( + constants.windows(2).any(|window| { + matches!( + window, + [ + ConstantData::Tuple { elements: first }, + ConstantData::Tuple { elements: second }, + ] if matches!( + (first.as_slice(), second.as_slice()), + ( + [ + ConstantData::Integer { value: a }, + ConstantData::Integer { value: b }, + ], + [ + ConstantData::Integer { value: c }, + ConstantData::Integer { value: d }, + ], + ) if a == &BigInt::from(10) + && b == &BigInt::from(3) + && c == &BigInt::from(10) + && d == &BigInt::from(4) + ) + ) + }), + "CPython flowgraph.c walks b_next order, so the except-body tuple is folded before the following if-test tuple; got {constants:?}" + ); + } + + #[test] + fn small_set_membership_folds_before_later_unary_const_like_cpython() { + let code = compile_exec( + r#" +def f(method, n): + if method not in {"linear", "ranked"}: + pass + if method == "ranked": + start = (n - 1) / -2 +"#, + ); + let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + let frozenset_index = constants + .iter() + .position(|constant| matches!(constant, ConstantData::Frozenset { .. 
})) + .expect("missing folded membership frozenset"); + let negative_two_index = constants + .iter() + .position(|constant| { + matches!( + constant, + ConstantData::Integer { value } if value == &BigInt::from(-2) + ) + }) + .expect("missing folded -2 constant"); + + assert!( + frozenset_index < negative_two_index, + "CPython flowgraph.c optimizes BUILD_SET+CONTAINS_OP inline before folding the later unary -2; got {constants:?}" + ); + } + + #[test] + fn boolop_const_order_keeps_cpython_codegen_constants() { + let code = compile_exec( + "\ +def or_false(x): + return False or x + +def zero_or_tuple(): + return 0 or (1, -1) + +def tuple_or_tuple(): + return (1, -1) or (-1, 1) +", + ); + + let or_false = find_code(&code, "or_false").expect("missing or_false code"); + let constants = or_false.constants.iter().collect::>(); + assert_eq!(constants.len(), 1); + assert!( + matches!(constants[0], ConstantData::Boolean { value: false }), + "CPython registers the skipped boolop literal before flowgraph removes the branch" + ); + + let zero_or_tuple = find_code(&code, "zero_or_tuple").expect("missing zero_or_tuple code"); + let constants = zero_or_tuple.constants.iter().collect::>(); + assert_eq!(constants.len(), 2); + assert!( + matches!( + constants[0], + ConstantData::Integer { value } if value == &BigInt::from(0) + ) && matches!( + constants[1], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Integer { value: one }, + ConstantData::Integer { value: minus_one }, + ] if one == &BigInt::from(1) && minus_one == &BigInt::from(-1) + ) + ), + "CPython keeps the skipped scalar literal before the folded tuple constant" + ); + + let tuple_or_tuple = + find_code(&code, "tuple_or_tuple").expect("missing tuple_or_tuple code"); + let constants = tuple_or_tuple.constants.iter().collect::>(); + assert_eq!(constants.len(), 3); + assert!( + matches!( + constants[0], + ConstantData::Integer { value } if value == &BigInt::from(1) + ) && matches!( + 
constants[1], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Integer { value: one }, + ConstantData::Integer { value: minus_one }, + ] if one == &BigInt::from(1) && minus_one == &BigInt::from(-1) + ) + ) && matches!( + constants[2], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Integer { value: minus_one }, + ConstantData::Integer { value: one }, + ] if minus_one == &BigInt::from(-1) && one == &BigInt::from(1) + ) + ), + "CPython compiles boolop tuple heads before flowgraph folds them" + ); + } + + #[test] + fn lambda_without_body_constants_keeps_none_like_cpython() { + let code = compile_exec("f = lambda x: x"); + let lambda = find_code(&code, "").expect("missing lambda code"); + let constants = lambda.constants.iter().collect::>(); + assert_eq!(constants.len(), 1); + + assert!( + matches!(constants[0], ConstantData::None), + "CPython AddReturnAtEnd registers None for constant-free lambdas" + ); + } + + #[test] + fn call_function_ex_empty_args_tuple_is_folded_late_like_cpython() { + let code = compile_exec( + "\ +def f(g, kwargs, ns): + g(**kwargs) + ns['T'] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + assert_eq!(constants.len(), 3); + + assert!( + matches!(constants[0], ConstantData::Str { value } if value.to_string() == "T") + && matches!(constants[1], ConstantData::None) + && matches!(constants[2], ConstantData::Tuple { elements } if elements.is_empty()), + "CPython emits BUILD_TUPLE 0 for CALL_FUNCTION_EX args and folds it after earlier constants" + ); + } + + #[test] + fn large_constant_list_iterable_keeps_streaming_list_build() { let source = format!( "def f():\n for x in [{}]:\n pass\n", (0..=STACK_USE_GUIDELINE) @@ -22742,7 +28428,7 @@ def f(): } #[test] - fn test_constant_set_iterable_uses_frozenset_const() { + fn constant_set_iterable_uses_frozenset_const() { let code = 
compile_exec( "\ def f(): @@ -22772,7 +28458,7 @@ def f(): } #[test] - fn test_constant_list_membership_uses_tuple_const() { + fn constant_list_membership_uses_tuple_const() { let code = compile_exec( "\ f = lambda x: x in [1, 2, 3] @@ -22802,7 +28488,7 @@ f = lambda x: x in [1, 2, 3] } #[test] - fn test_small_constant_set_membership_uses_frozenset_const() { + fn small_constant_set_membership_uses_frozenset_const() { let code = compile_exec( "\ f = lambda x: x in {0} @@ -22825,7 +28511,7 @@ f = lambda x: x in {0} } #[test] - fn test_nonconstant_list_membership_uses_tuple() { + fn nonconstant_list_membership_uses_tuple() { let code = compile_exec( "\ def f(a, b, c, x): @@ -22855,7 +28541,7 @@ def f(a, b, c, x): } #[test] - fn test_unary_not_membership_and_identity_invert_compare_op() { + fn unary_not_membership_and_identity_invert_compare_op() { let code = compile_exec( "\ def f(a, b, d): @@ -22885,7 +28571,7 @@ def f(a, b, d): } #[test] - fn test_starred_tuple_iterable_drops_list_to_tuple_before_get_iter() { + fn starred_tuple_iterable_drops_list_to_tuple_before_get_iter() { let code = compile_exec( "\ def f(a, b, c): @@ -22908,7 +28594,7 @@ def f(a, b, c): } #[test] - fn test_comprehension_single_list_iterable_uses_tuple() { + fn comprehension_single_list_iterable_uses_tuple() { let code = compile_exec( "\ def g(): @@ -22935,7 +28621,221 @@ def g(): } #[test] - fn test_nested_comprehension_list_iterable_uses_tuple() { + fn comprehension_list_iterable_build_uses_iter_location_like_cpython() { + let code = compile_exec( + "\ +async def f(i): + return i + +async def run_list(): + return [await c for c in [f(1), f(41)]] +", + ); + let run_list = find_code(&code, "run_list").expect("missing run_list code"); + assert_eq!( + run_list.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xdc, 0x1e, 0x1f, 0xa0, 0x01, 0x9b, 0x64, 0xa4, 0x41, 0xa0, + 0x62, 0xa3, 0x45, 0x99, 0x5d, 0xd3, 0x0b, 0x2b, 0x99, 0x5d, 0x98, 0x01, 0x8f, 0x47, + 0x8a, 0x47, 0x99, 0x5d, 0xd1, 0x0b, 0x2b, 
0xd0, 0x04, 0x2b, 0x89, 0x47, 0xf9, 0xd2, + 0x0b, 0x2b, 0xf9, + ], + "CPython codegen_comprehension_iter() emits GET_ITER at LOC(comp->iter)" + ); + } + + #[test] + fn comprehension_boolop_iter_get_iter_uses_iter_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + return any(not w.cancelled() for w in (self._waiters or ())) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let get_iter_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::GetIter).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert!( + get_iter_positions.contains(&(2, 44, 2, 63)), + "CPython codegen_comprehension_iter() emits GET_ITER at LOC(comp->iter), got {get_iter_positions:?}" + ); + } + + #[test] + fn inlined_comprehension_backedges_use_element_location_like_cpython() { + let code = compile_exec( + "\ +async def f(i): + return i + +async def run_list(): + return [s for c in [f(''), f('abc')] for s in await c] +", + ); + let run_list = find_code(&code, "run_list").expect("missing run_list code"); + assert_eq!( + run_list.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xdc, 0x18, 0x19, 0x98, 0x22, 0x9b, 0x05, 0x9c, 0x71, 0xa0, + 0x15, 0x9b, 0x78, 0xd1, 0x17, 0x28, 0xd4, 0x0b, 0x3a, 0xd1, 0x17, 0x28, 0x90, 0x21, + 0xb7, 0x27, 0xb2, 0x27, 0xa8, 0x51, 0x8a, 0x41, 0xb1, 0x27, 0x89, 0x41, 0xd1, 0x17, + 0x28, 0xd2, 0x0b, 0x3a, 0xd0, 0x04, 0x3a, 0xb1, 0x27, 0xf9, 0xd3, 0x0b, 0x3a, 0xf9, + ], + "CPython codegen_sync_comprehension_generator() emits comprehension backedges at elt_loc" + ); + } + + #[test] + fn nested_dict_comprehension_outer_backedge_uses_key_location_like_cpython() { + let code = compile_exec( + "\ +def f(items): + return {op: i for i, ops in items for op in ops} +", + ); + let f = find_code(&code, "f").expect("missing function 
code"); + let backedge_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::JumpBackward { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert!( + backedge_positions.contains(&(2, 13, 2, 18)), + "CPython extends only the terminal dict-comprehension MAP_ADD/backedge location from key through value, got {backedge_positions:?}" + ); + assert!( + backedge_positions.contains(&(2, 13, 2, 15)), + "CPython keeps outer dict-comprehension generator backedges at LOC(key), got {backedge_positions:?}" + ); + } + + #[test] + fn inlined_comprehension_filter_jump_uses_element_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + return [action for action in self._actions if action.option_strings] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let filter_jump_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::PopJumpIfTrue { .. 
}).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing optimized filter jump"); + assert_eq!( + filter_jump_position, + (2, 13, 2, 19), + "CPython inlined comprehension filter jump inherits the element/backedge location after CFG cleanup" + ); + } + + #[test] + fn inlined_comprehension_ifexp_guard_jump_uses_body_location_like_cpython() { + let code = compile_exec( + "\ +def f(fields): + return [f for f in fields if (f.compare if f.hash is None else f.hash)] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let jump_forward_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::JumpForward { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing if-expression body jump"); + assert_eq!( + jump_forward_position, + (2, 35, 2, 44), + "CPython flowgraph.c::propagate_line_numbers() copies the if-expression body location onto the NO_LOCATION jump" + ); + } + + #[test] + fn inlined_async_comprehension_end_async_for_uses_comprehension_location_like_cpython() { + let code = compile_exec( + "\ +async def f(it): + for i in it: + yield i + +async def run_list(): + return [i + 1 async for i in f([10, 20])] +", + ); + let run_list = find_code(&code, "run_list").expect("missing run_list code"); + assert_eq!( + run_list.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xdc, 0x21, 0x22, 0xa0, 0x42, 0xa8, 0x02, 0xa0, 0x38, 0xa4, + 0x1b, 0xd7, 0x0b, 0x2d, 0xd3, 0x0b, 0x2d, 0x98, 0x41, 0x90, 0x01, 0x8f, 0x45, 0x88, + 0x45, 0xd4, 0x0b, 0x2d, 0xd0, 0x04, 0x2d, 0xf9, 0xd2, 0x0b, 0x2d, 0xf9, + ], + "CPython codegen_async_comprehension_generator() emits END_ASYNC_FOR at comprehension loc" + ); + } + + #[test] + fn 
async_for_anext_sequence_uses_statement_location_like_cpython() { + let code = compile_exec( + "\ +async def f(source, buffer): + async for i1, i2 in source(): + buffer.append(i1 + i2) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd9, 0x18, 0x1e, 0x9c, 0x08, 0xf7, 0x00, 0x01, 0x05, 0x1f, + 0xf0, 0x00, 0x01, 0x05, 0x1f, 0x89, 0x66, 0x88, 0x62, 0xd8, 0x08, 0x0e, 0x8f, 0x0d, + 0x89, 0x0d, 0x90, 0x62, 0x95, 0x67, 0xd6, 0x08, 0x1e, 0xf1, 0x03, 0x01, 0x05, 0x1f, + 0x9a, 0x08, 0xf9, + ], + "CPython codegen_async_for() emits GET_ANEXT/yield-from scaffolding at LOC(s)" + ); + } + + #[test] + fn nested_comprehension_list_iterable_uses_tuple() { let code = compile_exec( "\ def f(): @@ -22962,7 +28862,7 @@ def f(): } #[test] - fn test_comprehension_singleton_sub_iter_uses_assignment_idiom() { + fn comprehension_singleton_sub_iter_uses_assignment_idiom() { let code = compile_exec( "\ def f(): @@ -23003,7 +28903,7 @@ def f(): } #[test] - fn test_constant_comprehension_iterable_with_unary_int_uses_tuple_const() { + fn constant_comprehension_iterable_with_unary_int_uses_tuple_const() { let code = compile_exec( "\ l = lambda : [2 < x for x in [-1, 3, 0]] @@ -23029,7 +28929,7 @@ l = lambda : [2 < x for x in [-1, 3, 0]] } #[test] - fn test_module_scope_listcomp_is_inlined() { + fn module_scope_listcomp_is_inlined() { let code = compile_exec("values = [i for i in range(3)]\n"); assert!( @@ -23046,7 +28946,7 @@ l = lambda : [2 < x for x in [-1, 3, 0]] } #[test] - fn test_module_scope_dictcomp_is_inlined() { + fn module_scope_dictcomp_is_inlined() { let code = compile_exec("mapping = {i: i for i in range(3)}\n"); assert!( @@ -23063,7 +28963,7 @@ l = lambda : [2 < x for x in [-1, 3, 0]] } #[test] - fn test_async_dictcomp_in_async_function_is_inlined() { + fn async_dictcomp_in_async_function_is_inlined() { let code = compile_exec( "\ async def f(items): @@ -23098,7 +28998,7 @@ async def f(items): } 
#[test] - fn test_async_inlined_comprehension_inlines_restore_return_into_end_async_for() { + fn async_inlined_comprehension_inlines_restore_return_into_end_async_for() { let code = compile_exec( "\ async def f(): @@ -23146,7 +29046,7 @@ async def f(): } #[test] - fn test_await_cleanup_throw_falls_through_until_cold_reorder() { + fn await_cleanup_throw_falls_through_until_cold_reorder() { let code = compile_exec( "\ async def f(): @@ -23183,7 +29083,7 @@ async def f(): } #[test] - fn test_match_async_inlined_comprehension_success_jump_no_interrupt() { + fn match_async_inlined_comprehension_success_jump_layout() { let code = compile_exec( "\ async def f(name_3, name_5): @@ -23206,35 +29106,36 @@ async def f(name_3, name_5): .collect(); assert!( - ops.windows(3).any(|window| { + ops.windows(4).any(|window| { matches!( window, [ Instruction::PopTop, + Instruction::JumpForward { .. }, + Instruction::Copy { .. }, Instruction::StoreFast { .. }, - Instruction::JumpBackwardNoInterrupt { .. }, ] ) }), - "expected CPython-style no-interrupt match success backedge after async comprehension cleanup, got ops={ops:?}" + "expected CPython-style plain match success jump before async comprehension case, got ops={ops:?}" ); assert!( - !ops.windows(3).any(|window| { + ops.windows(3).any(|window| { matches!( window, [ - Instruction::PopTop, Instruction::StoreFast { .. }, - Instruction::JumpBackward { .. }, + Instruction::JumpBackwardNoInterrupt { .. }, + Instruction::CallIntrinsic1 { .. 
}, ] ) }), - "match success cleanup backedge should not be a regular interrupting jump, got ops={ops:?}" + "CPython codegen_pop_inlined_comprehension_locals() emits JUMP_NO_INTERRUPT before the cleanup path; after flowgraph reordering it remains a backward no-interrupt jump before the StopIteration handler, got ops={ops:?}" ); } #[test] - fn test_for_loop_if_return_reorders_continue_backedge_before_exit_body() { + fn for_loop_if_return_reorders_continue_backedge_before_exit_body() { let code = compile_exec( "\ def f(items, occurrence): @@ -23283,7 +29184,7 @@ def f(items, occurrence): } #[test] - fn test_sync_with_after_async_for_keeps_end_async_for_line_marker() { + fn sync_with_after_async_for_keeps_end_async_for_line_marker() { let code = compile_exec( "\ async def f(cm, source, tgt): @@ -23319,7 +29220,7 @@ async def f(cm, source, tgt): } #[test] - fn test_genexpr_with_async_comprehension_element_is_async_generator() { + fn genexpr_with_async_comprehension_element_is_async_generator() { let code = compile_exec( "\ async def f(): @@ -23353,7 +29254,25 @@ async def f(): } #[test] - fn test_nested_module_scope_dictcomp_symbols_are_local() { + fn async_comprehension_propagates_coroutine_to_enclosing_genexpr_like_cpython() { + let symbol_table = scan_program_symbol_table( + "\ +async def f(): + gen = ([i async for i in asynciter([1, 2])] for j in [10, 20]) + return [x async for x in gen] +", + ); + let genexpr = + find_symbol_table(&symbol_table, "").expect("missing genexpr symbol table"); + assert!(genexpr.is_generator, "expected genexpr symbol table"); + assert!( + genexpr.is_coroutine, + "CPython symtable_handle_comprehension() propagates non-generator async comprehension ste_coroutine to the enclosing genexpr" + ); + } + + #[test] + fn nested_module_scope_dictcomp_symbols_are_local() { let symbol_table = scan_program_symbol_table( "\ deoptmap = { @@ -23393,7 +29312,7 @@ deoptmap = { } #[test] - fn test_nested_module_scope_dictcomp_uses_fast_locals() { + fn 
nested_module_scope_dictcomp_uses_fast_locals() { let code = compile_exec( "\ deoptmap = { @@ -23440,7 +29359,7 @@ deoptmap = { } #[test] - fn test_module_scope_inlined_comprehension_keeps_outer_iter_as_name_lookup() { + fn module_scope_inlined_comprehension_keeps_outer_iter_as_name_lookup() { let code = compile_exec( "\ path_separators = ['/'] @@ -23481,7 +29400,7 @@ _pathseps_with_colon = {f':{s}' for s in path_separators} } #[test] - fn test_function_scope_inlined_comprehension_restore_keeps_swap_before_duplicate_store() { + fn function_scope_inlined_comprehension_restore_keeps_swap_before_duplicate_store() { let code = compile_exec( "\ def f(): @@ -23512,7 +29431,7 @@ def f(): } #[test] - fn test_inlined_comprehension_namedexpr_target_stays_parent_fast_local() { + fn inlined_comprehension_namedexpr_target_stays_parent_fast_local() { let code = compile_exec( "\ def f(seq, emit): @@ -23532,7 +29451,26 @@ def f(seq, emit): } #[test] - fn test_global_namedexpr_in_inlined_comprehension_saves_fast_slot() { + fn inlined_comprehension_namedexpr_varnames_match_cpython_order() { + let code = compile_exec( + "\ +def f(): + def spam(a): + return a + input_data = [1, 2, 3] + res = [(x, y, x / y) for x in input_data if (y := spam(x)) > 0] + return res +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.varnames.iter().map(String::as_str).collect::>(), + vec!["spam", "input_data", "x", "y", "res"] + ); + } + + #[test] + fn global_namedexpr_in_inlined_comprehension_saves_fast_slot() { let code = compile_exec( "\ def f(seq, value): @@ -23560,7 +29498,28 @@ def f(seq, value): } #[test] - fn test_genexpr_namedexpr_target_is_cell_not_fast_local() { + fn namedexpr_copy_uses_namedexpr_location_like_cpython() { + let code = compile_exec( + "\ +def outer(): + a = 10 + def spam(): + nonlocal a + (a := 20) +", + ); + let spam = find_code(&code, "spam").expect("missing spam code"); + + // CPython 3.14 NamedExpr_kind emits COPY at LOC(named expression), + 
// between visiting the value and visiting the target. + assert_eq!( + spam.linetable.as_ref(), + &[0xf8, 0x80, 0x00, 0xe0, 0x0e, 0x10, 0x88, 0x17, 0x8b, 0x11,] + ); + } + + #[test] + fn genexpr_namedexpr_target_is_cell_not_fast_local() { let code = compile_exec( "\ def f(seq): @@ -23578,7 +29537,32 @@ def f(seq): } #[test] - fn test_inlined_comprehension_restore_does_not_form_store_fast_load_fast() { + fn public_cellvars_follow_cpython_localsplus_order() { + let code = compile_exec( + "\ +def f(): + x = 10 + t = False + g = ((i, j) for i in range(x) if t for j in range(x)) + [x for x in range(3)] + return g +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + assert_eq!( + f.varnames.iter().map(String::as_str).collect::>(), + ["g", "x"] + ); + assert_eq!( + f.cellvars.iter().map(String::as_str).collect::>(), + ["x", "t"], + "CPython assemble.c exposes co_cellvars in localsplus order: merged local cells before non-local cells" + ); + } + + #[test] + fn inlined_comprehension_restore_does_not_form_store_fast_load_fast() { let code = compile_exec( "\ def f(e): @@ -23668,7 +29652,7 @@ def g(datadir): } #[test] - fn test_single_mode_folded_multiline_constant_does_not_leave_nops() { + fn single_mode_folded_multiline_constant_does_not_leave_nops() { let code = compile_single( "\ (- @@ -23689,7 +29673,7 @@ def g(datadir): } #[test] - fn test_folded_multiline_tuple_constant_does_not_leave_operand_nops() { + fn folded_multiline_tuple_constant_does_not_leave_operand_nops() { let code = compile_exec( "\ values = ( @@ -23711,7 +29695,7 @@ values = ( } #[test] - fn test_folded_multiline_bytes_binop_does_not_leave_operand_nops() { + fn folded_multiline_bytes_binop_does_not_leave_operand_nops() { let code = compile_exec( "\ def f(self, out): @@ -23732,7 +29716,7 @@ def f(self, out): } #[test] - fn test_folded_binop_at_branch_body_start_does_not_leave_nop() { + fn folded_binop_at_branch_body_start_does_not_leave_nop() { let code = compile_exec( "\ def f(sys): @@ 
-23765,7 +29749,7 @@ def f(sys): } #[test] - fn test_folded_iterable_at_assert_target_does_not_leave_nop() { + fn folded_iterable_at_assert_target_does_not_leave_nop() { let code = compile_exec( r#" def f(caches, non_caches): @@ -23791,7 +29775,7 @@ def f(caches, non_caches): } #[test] - fn test_multiline_unpack_target_uses_element_locations() { + fn multiline_unpack_target_uses_element_locations() { let code = compile_exec( "\ def f(cm): @@ -23816,7 +29800,7 @@ def f(cm): } #[test] - fn test_or_condition_in_jump_context_uses_shared_true_fallthrough() { + fn or_condition_in_jump_context_uses_shared_true_fallthrough() { let code = compile_exec( "\ def f(lines): @@ -23851,7 +29835,7 @@ def f(lines): } #[test] - fn test_loop_break_bool_chain_reorders_false_path_to_jump_back() { + fn loop_break_bool_chain_reorders_false_path_to_jump_back() { let code = compile_exec( "\ def f(filters, text, category, module, lineno, defaultaction): @@ -23894,7 +29878,7 @@ def f(filters, text, category, module, lineno, defaultaction): } #[test] - fn test_loop_conditional_body_keeps_duplicate_jump_back_paths() { + fn loop_conditional_body_keeps_duplicate_jump_back_paths() { let code = compile_exec( "\ def f(new, old): @@ -23944,7 +29928,7 @@ def f(new, old): } #[test] - fn test_try_loop_inner_if_keeps_duplicate_jump_back_paths() { + fn try_loop_inner_if_keeps_duplicate_jump_back_paths() { let code = compile_exec( "\ def f(config, logging): @@ -23993,7 +29977,7 @@ def f(config, logging): } #[test] - fn test_try_loop_nested_bool_tail_keeps_duplicate_jump_back_paths() { + fn try_loop_nested_bool_tail_keeps_duplicate_jump_back_paths() { let code = compile_exec( "\ def f(obj, flags, writer, value, Error): @@ -24043,7 +30027,7 @@ def f(obj, flags, writer, value, Error): } #[test] - fn test_nested_continue_shares_backedge_with_fallthrough_body() { + fn nested_continue_shares_backedge_with_fallthrough_body() { let code = compile_exec( "\ def f(names, show_empty, keywords, args_buffer, args, cls, 
object, level): @@ -24106,7 +30090,7 @@ def f(names, show_empty, keywords, args_buffer, args, cls, object, level): } #[test] - fn test_line_bearing_loop_if_false_backedge_keeps_body_before_jump_back() { + fn line_bearing_loop_if_false_backedge_keeps_body_before_jump_back() { let code = compile_exec( "\ def f(self, replacement_pairs): @@ -24165,7 +30149,7 @@ def f(self, replacement_pairs): } #[test] - fn test_branch_local_implicit_continue_keeps_body_before_jump_back() { + fn branch_local_implicit_continue_keeps_body_before_jump_back() { let code = compile_exec( "\ def f(items, outer, cond, sub, out): @@ -24224,7 +30208,7 @@ def f(items, outer, cond, sub, out): } #[test] - fn test_boolop_continue_deduplicates_marker_jump_back() { + fn boolop_continue_deduplicates_marker_jump_back() { let code = compile_exec( "\ def f(ws, seen, more_than): @@ -24276,7 +30260,7 @@ def f(ws, seen, more_than): } #[test] - fn test_loop_elif_nested_if_false_backedge_keeps_body_before_jump_back() { + fn loop_elif_nested_if_false_backedge_keeps_body_before_jump_back() { let code = compile_exec( "\ def f(keys, parse_int, d, ampm, AM, PM): @@ -24366,7 +30350,7 @@ def f(keys, parse_int, d, ampm, AM, PM): } #[test] - fn test_loop_nested_if_before_elif_keeps_body_before_false_backedge() { + fn loop_nested_if_before_elif_keeps_body_before_false_backedge() { let code = compile_exec( "\ def f(keys, parse_int, found_dict, locale_time): @@ -24432,7 +30416,7 @@ def f(keys, parse_int, found_dict, locale_time): } #[test] - fn test_elif_pass_before_raise_keeps_line_bearing_forward_jump() { + fn elif_pass_before_raise_keeps_line_bearing_forward_jump() { let code = compile_exec( "\ def f(entries, path, self): @@ -24485,7 +30469,7 @@ def f(entries, path, self): } #[test] - fn test_loop_multiblock_conditional_body_keeps_body_before_jump_back() { + fn loop_multiblock_conditional_body_keeps_body_before_jump_back() { let code = compile_exec( "\ def f(random, d, f): @@ -24528,7 +30512,7 @@ def f(random, d, f): } 
#[test] - fn test_loop_not_conditional_body_threads_true_path_to_jump_back() { + fn loop_not_conditional_body_threads_true_path_to_jump_back() { let code = compile_exec( "\ def f(xs): @@ -24564,7 +30548,7 @@ def f(xs): } #[test] - fn test_loop_not_in_conditional_body_threads_true_path_to_jump_back() { + fn loop_not_in_conditional_body_threads_true_path_to_jump_back() { let code = compile_exec( "\ def f(native, array): @@ -24601,7 +30585,7 @@ def f(native, array): } #[test] - fn test_while_implicit_continue_body_after_jumpback_for_boolop_call_arg() { + fn while_implicit_continue_body_after_jumpback_for_boolop_call_arg() { let code = compile_exec( "\ def f(source, state, verbose, nested): @@ -24645,7 +30629,7 @@ def f(source, state, verbose, nested): } #[test] - fn test_multiblock_elif_continue_keeps_next_test_before_backedge() { + fn multiblock_elif_continue_keeps_next_test_before_backedge() { let code = compile_exec( "\ def f(source, state, verbose, nested, subpatternappend, start, MAXGROUPS): @@ -24737,7 +30721,7 @@ def f(source, state, verbose, nested, subpatternappend, start, MAXGROUPS): } #[test] - fn test_while_scope_exit_body_keeps_line_backedge_before_raise_body() { + fn while_scope_exit_body_keeps_line_backedge_before_raise_body() { let code = compile_exec( "\ FLAGS = {} @@ -24832,7 +30816,7 @@ def f(source, state, char): } #[test] - fn test_call_body_implicit_continue_keeps_cpython_normalized_forward_jump() { + fn call_body_implicit_continue_keeps_cpython_normalized_forward_jump() { let code = compile_exec( "\ DIGITS = '0123456789' @@ -24891,7 +30875,7 @@ def f(s, sget, lappend, addgroup, this, c): } #[test] - fn test_empty_if_end_label_preserves_cpython_return_anchor_nop() { + fn empty_if_end_label_preserves_cpython_return_anchor_nop() { let code = compile_exec( "\ SRE_FLAG_LOCALE = 1 @@ -24939,7 +30923,7 @@ def f(src, flags): } #[test] - fn test_nested_except_normal_exit_return_uses_strong_loads() { + fn 
nested_except_normal_exit_return_uses_strong_loads() { let code = compile_exec( "\ LITERAL = 1 @@ -24983,7 +30967,7 @@ def f(source, escape): } #[test] - fn test_targeted_nop_after_prefix_for_else_uses_strong_for_tail_loads() { + fn targeted_nop_after_prefix_for_else_uses_strong_for_tail_loads() { let code = compile_exec( "\ LITERAL = 1 @@ -25052,7 +31036,7 @@ def f(items): } #[test] - fn test_plain_pass_before_for_tail_keeps_borrows() { + fn plain_pass_before_for_tail_keeps_borrows() { let code = compile_exec( "\ def f(xs): @@ -25081,7 +31065,7 @@ def f(xs): } #[test] - fn test_targeted_nop_after_return_uses_strong_pair_call_args() { + fn targeted_nop_after_return_uses_strong_pair_call_args() { let code = compile_exec( "\ def f(x, d, count, inner, hi, w1, lo, w2): @@ -25133,7 +31117,7 @@ def f(x, d, count, inner, hi, w1, lo, w2): } #[test] - fn test_loop_if_pass_uses_line_bearing_jump_back_instead_of_nop() { + fn loop_if_pass_uses_line_bearing_jump_back_instead_of_nop() { let code = compile_exec( "\ def f(x, y): @@ -25174,7 +31158,7 @@ def f(x, y): } #[test] - fn test_constant_true_while_pass_keeps_loop_header_nop() { + fn constant_true_while_pass_keeps_loop_header_nop() { let code = compile_exec( "\ def f(): @@ -25206,7 +31190,7 @@ def f(): } #[test] - fn test_nested_if_shared_jump_back_target_is_duplicated() { + fn nested_if_shared_jump_back_target_is_duplicated() { let code = compile_exec( "\ def f(s, size, encodeSetO, encodeWhiteSpace): @@ -25256,7 +31240,7 @@ def f(s, size, encodeSetO, encodeWhiteSpace): } #[test] - fn test_exception_cleanup_backedge_target_is_shared() { + fn exception_cleanup_backedge_target_is_shared() { let code = compile_exec( "\ def f(enum_class, value, Flag, int_type, is_single_bit): @@ -25320,7 +31304,7 @@ def f(enum_class, value, Flag, int_type, is_single_bit): } #[test] - fn test_protected_loop_conditional_keeps_forward_body_entry() { + fn protected_loop_conditional_keeps_forward_body_entry() { let code = compile_exec( "\ def 
outer(it, C1): @@ -25363,7 +31347,7 @@ def outer(it, C1): } #[test] - fn test_nested_except_false_path_duplicates_pop_except_jump_back_tail() { + fn nested_except_false_path_duplicates_pop_except_jump_back_tail() { let code = compile_exec( "\ def f(it, C3): @@ -25408,7 +31392,7 @@ def f(it, C3): } #[test] - fn test_more_nested_except_false_paths_duplicate_all_jump_back_tails() { + fn more_nested_except_false_paths_duplicate_all_jump_back_tails() { let code = compile_exec( "\ def f(it, C3, C4): @@ -25462,7 +31446,7 @@ def f(it, C3, C4): } #[test] - fn test_no_wraparound_jump_keeps_forward_hop_before_loop_backedge() { + fn no_wraparound_jump_keeps_forward_hop_before_loop_backedge() { let code = compile_exec( "\ def while_not_chained(a, b, c): @@ -25497,7 +31481,7 @@ def while_not_chained(a, b, c): } #[test] - fn test_nested_while_chained_compare_break_keeps_break_jump_block() { + fn nested_while_chained_compare_break_keeps_break_jump_block() { let code = compile_exec( "\ def f(start, self, stop, size): @@ -25541,7 +31525,7 @@ def f(start, self, stop, size): } #[test] - fn test_while_break_else_keeps_true_edge_into_forward_break_body() { + fn while_break_else_keeps_true_edge_into_forward_break_body() { let code = compile_exec( "\ def f(i): @@ -25580,7 +31564,7 @@ def f(i): } #[test] - fn test_nested_if_continue_reorders_false_path_to_loop_backedge() { + fn nested_if_continue_reorders_false_path_to_loop_backedge() { let code = compile_exec( "\ def f(items, changes): @@ -25621,7 +31605,7 @@ def f(items, changes): } #[test] - fn test_loop_assert_keeps_false_edge_into_raise_body() { + fn loop_assert_keeps_false_edge_into_raise_body() { let code = compile_exec( "\ def f(bytecode): @@ -25657,7 +31641,7 @@ def f(bytecode): } #[test] - fn test_and_is_not_none_loop_guard_uses_direct_jump_back_false_path() { + fn and_is_not_none_loop_guard_uses_direct_jump_back_false_path() { let code = compile_exec( "\ def f(code): @@ -25696,7 +31680,7 @@ def f(code): } #[test] - fn 
test_large_is_not_none_loop_guard_uses_direct_jump_back_false_path() { + fn large_is_not_none_loop_guard_uses_direct_jump_back_false_path() { let code = compile_exec( "\ def f(cls, _FIELDS, _PARAMS): @@ -25744,7 +31728,7 @@ def f(cls, _FIELDS, _PARAMS): } #[test] - fn test_continue_inside_with_keeps_line_marker_nop_before_exit_cleanup() { + fn continue_inside_with_keeps_line_marker_nop_before_exit_cleanup() { let code = compile_exec( "\ def f(it): @@ -25785,7 +31769,7 @@ def f(it): } #[test] - fn test_nested_async_with_normal_cleanup_drops_pop_block_nop() { + fn nested_async_with_normal_cleanup_drops_pop_block_nop() { let code = compile_exec( "\ async def foo(): @@ -25836,7 +31820,7 @@ async def foo(): } #[test] - fn test_async_with_try_finally_before_outer_sync_with_cleanup_keeps_anchor_nop() { + fn async_with_try_finally_before_outer_sync_with_cleanup_keeps_anchor_nop() { let code = compile_exec( "\ async def foo(self): @@ -25882,7 +31866,7 @@ async def foo(self): } #[test] - fn test_nested_terminal_with_keeps_outer_cleanup_target_nop() { + fn nested_terminal_with_keeps_outer_cleanup_target_nop() { let code = compile_exec( "\ def f(): @@ -25918,7 +31902,7 @@ def f(): } #[test] - fn test_nested_nonterminal_with_drops_outer_cleanup_target_nop() { + fn nested_nonterminal_with_drops_outer_cleanup_target_nop() { let code = compile_exec( "\ def f(): @@ -25954,7 +31938,57 @@ def f(): } #[test] - fn test_try_loop_elif_places_return_before_orelse_tail() { + fn nested_terminal_with_before_successor_drops_after_block_nop() { + let code = compile_exec( + "\ +def f(a, b, c): + with a: + with b: + raise c + c() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + + assert!( + ops.windows(4).any(|window| { + matches!( + window, + [ + Instruction::Copy { .. }, + Instruction::PopExcept, + Instruction::Reraise { .. 
}, + Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }, + ] + ) + }), + "CPython falls through from the terminal inner-with cleanup to the following statement without an after-block NOP, got ops={ops:?}" + ); + assert!( + !ops.windows(5).any(|window| { + matches!( + window, + [ + Instruction::Copy { .. }, + Instruction::PopExcept, + Instruction::Reraise { .. }, + Instruction::Nop, + Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }, + ] + ) + }), + "unexpected inner with after-block NOP before following statement, got ops={ops:?}" + ); + } + + #[test] + fn try_loop_elif_places_return_before_orelse_tail() { let code = compile_exec( "\ def f(source, suggest, tb, s): @@ -26018,7 +32052,7 @@ def f(source, suggest, tb, s): } #[test] - fn test_constant_false_while_else_deopts_post_else_borrows() { + fn constant_false_while_else_deopts_post_else_borrows() { let code = compile_exec( "\ def f(self): @@ -26057,7 +32091,7 @@ def f(self): } #[test] - fn test_single_unpack_assignment_disables_constant_collection_folding() { + fn single_unpack_assignment_disables_constant_collection_folding() { let code = compile_exec("a, b, c = 1, 2, 3\n"); assert!( @@ -26090,7 +32124,7 @@ def f(self): } #[test] - fn test_four_item_unpack_assignment_folds_tuple_constant_like_cpython() { + fn four_item_unpack_assignment_folds_tuple_constant_like_cpython() { let code = compile_exec("a, b, c, d = 1, 2, 3, 4\n"); assert!( @@ -26120,7 +32154,7 @@ def f(self): } #[test] - fn test_chained_unpack_assignment_keeps_constant_collection_folding() { + fn chained_unpack_assignment_keeps_constant_collection_folding() { let code = compile_exec("(a, b) = c = d = (1, 2)\n"); assert!( @@ -26146,7 +32180,7 @@ def f(self): } #[test] - fn test_constant_true_assert_skips_message_nested_scope() { + fn constant_true_assert_skips_message_nested_scope() { let code = compile_exec("assert 1, (lambda x: x + 1)\n"); assert_eq!( @@ -26171,7 +32205,7 @@ def f(self): } #[test] - fn 
test_constant_false_assert_uses_direct_raise_shape() { + fn constant_false_assert_uses_direct_raise_shape() { let code = compile_exec("assert 0, (lambda x: x + 1)\n"); assert!( @@ -26210,7 +32244,7 @@ def f(self): } #[test] - fn test_constant_unary_positive_and_invert_fold() { + fn constant_unary_positive_and_invert_fold() { let code = compile_exec("x = +1\nx = ~1\n"); assert!( @@ -26229,7 +32263,7 @@ def f(self): } #[test] - fn test_bool_invert_is_not_const_folded() { + fn bool_invert_is_not_const_folded() { let code = compile_exec("x = ~True\n"); assert!( @@ -26245,7 +32279,7 @@ def f(self): } #[test] - fn test_optimized_assert_preserves_nested_scope_order() { + fn optimized_assert_preserves_nested_scope_order() { compile_exec_optimized( "\ class S: @@ -26259,7 +32293,7 @@ class S: } #[test] - fn test_optimized_assert_with_nested_scope_in_first_iter() { + fn optimized_assert_with_nested_scope_in_first_iter() { // First iterator of a comprehension is evaluated in the enclosing // scope, so nested scopes inside it (the generator here) must also // be consumed when the assert is optimized away. @@ -26273,7 +32307,7 @@ def f(items): } #[test] - fn test_optimized_assert_with_lambda_defaults() { + fn optimized_assert_with_lambda_defaults() { // Lambda default values are evaluated in the enclosing scope, // so nested scopes inside defaults must be consumed. 
compile_exec_optimized( @@ -26286,7 +32320,7 @@ def f(items): } #[test] - fn test_try_else_nested_scopes_keep_subtable_cursor_aligned() { + fn try_else_nested_scopes_keep_subtable_cursor_aligned() { let code = compile_exec( "\ try: @@ -26322,7 +32356,7 @@ else: } #[test] - fn test_nested_try_else_multi_resume_join_keeps_strong_load_fast_tail() { + fn nested_try_else_multi_resume_join_keeps_strong_load_fast_tail() { let code = compile_exec( "\ def f(msg): @@ -26385,7 +32419,7 @@ def f(msg): } #[test] - fn test_protected_conditional_tail_keeps_strong_load_fast() { + fn protected_conditional_tail_keeps_strong_load_fast() { let code = compile_exec( "\ def f(m, class_name, category, warning_base): @@ -26435,7 +32469,7 @@ def f(m, class_name, category, warning_base): } #[test] - fn test_nonresuming_protected_conditional_tail_keeps_strong_load_fast() { + fn nonresuming_protected_conditional_tail_keeps_strong_load_fast() { let code = compile_exec( "\ def f(href, parse='xml'): @@ -26479,7 +32513,7 @@ def f(href, parse='xml'): } #[test] - fn test_optional_nonresuming_protected_tail_keeps_borrow() { + fn optional_nonresuming_protected_tail_keeps_borrow() { let code = compile_exec( "\ def f(b): @@ -26534,7 +32568,7 @@ def f(b): } #[test] - fn test_handled_except_conditional_tail_keeps_borrow() { + fn handled_except_conditional_tail_keeps_borrow() { let code = compile_exec( "\ def f(self): @@ -26590,7 +32624,7 @@ def f(self): } #[test] - fn test_handled_except_else_tail_keeps_borrow() { + fn handled_except_else_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, fut=None): @@ -26654,7 +32688,7 @@ def f(self, fut=None): } #[test] - fn test_reraising_handler_with_handled_returns_keeps_borrow() { + fn reraising_handler_with_handled_returns_keeps_borrow() { let code = compile_exec( "\ def f(self, fut=None): @@ -26731,7 +32765,7 @@ def f(self, fut=None): } #[test] - fn test_with_protected_conditional_tail_without_exception_match_keeps_borrow() { + fn 
with_protected_conditional_tail_without_exception_match_keeps_borrow() { let code = compile_exec( "\ def f(self, cm, p, platform): @@ -26781,7 +32815,7 @@ def f(self, cm, p, platform): } #[test] - fn test_listcomp_cleanup_predecessor_does_not_deopt_following_conditional_tail() { + fn listcomp_cleanup_predecessor_does_not_deopt_following_conditional_tail() { let code = compile_exec( "\ def f(self, compile_snippet): @@ -26831,7 +32865,7 @@ def f(self, compile_snippet): } #[test] - fn test_handler_resume_loop_conditional_tail_keeps_strong_load_fast() { + fn handler_resume_loop_conditional_tail_keeps_strong_load_fast() { let code = compile_exec( "\ def f(self): @@ -26914,7 +32948,7 @@ def f(self): } #[test] - fn test_handler_resume_while_conditional_tail_keeps_borrow_load_fast() { + fn handler_resume_while_conditional_tail_keeps_borrow_load_fast() { let code = compile_exec( "\ def f(value): @@ -26970,7 +33004,7 @@ def f(value): } #[test] - fn test_multi_handler_resume_while_tail_keeps_borrow_load_fast() { + fn multi_handler_resume_while_tail_keeps_borrow_load_fast() { let code = compile_exec( "\ def f(value): @@ -27036,7 +33070,7 @@ def f(value): } #[test] - fn test_multi_handler_resume_before_with_keeps_with_body_borrows() { + fn multi_handler_resume_before_with_keeps_with_body_borrows() { let code = compile_exec( "\ def f(self, input, cm): @@ -27089,7 +33123,7 @@ def f(self, input, cm): } #[test] - fn test_suppressing_with_and_typed_except_resume_loop_method_tail_keeps_strong_load_fast() { + fn suppressing_with_and_typed_except_resume_loop_method_tail_keeps_strong_load_fast() { let code = compile_exec( "\ def f(proc, text): @@ -27131,7 +33165,7 @@ def f(proc, text): } #[test] - fn test_handler_break_join_loop_body_and_tail_keep_strong_load_fast() { + fn handler_break_join_loop_body_and_tail_keep_strong_load_fast() { let code = compile_exec( "\ def f(function, stem): @@ -27214,7 +33248,7 @@ def f(function, stem): } #[test] - fn 
test_handler_resume_to_loop_header_keeps_loop_header_borrows() { + fn handler_resume_to_loop_header_keeps_loop_header_borrows() { let code = compile_exec( "\ def f(value, Phrase, get_word, errors, ENDS, DOT): @@ -27276,7 +33310,7 @@ def f(value, Phrase, get_word, errors, ENDS, DOT): } #[test] - fn test_reraising_except_loop_break_tail_keeps_post_loop_borrows() { + fn reraising_except_loop_break_tail_keeps_post_loop_borrows() { let code = compile_exec( "\ def f(flag=1, count=0): @@ -27331,7 +33365,7 @@ def f(flag=1, count=0): } #[test] - fn test_try_except_continue_keeps_try_line_nop_before_continue_jump() { + fn try_except_continue_keeps_try_line_nop_before_continue_jump() { let code = compile_exec( "\ def f(done=False): @@ -27369,7 +33403,7 @@ def f(done=False): } #[test] - fn test_for_else_pass_keeps_line_marker_after_pop_iter() { + fn for_else_pass_keeps_line_marker_after_pop_iter() { let code = compile_exec( "\ def f(): @@ -27407,7 +33441,7 @@ def f(): } #[test] - fn test_folded_if_chain_after_previous_chain_keeps_final_elif_line_marker() { + fn folded_if_chain_after_previous_chain_keeps_final_elif_line_marker() { let code = compile_exec( "\ def f(): @@ -27435,7 +33469,7 @@ def f(): } #[test] - fn test_handler_resume_before_later_loop_keeps_borrowed_tail_loads() { + fn handler_resume_before_later_loop_keeps_borrowed_tail_loads() { let code = compile_exec( "\ def f(msg, category): @@ -27519,7 +33553,7 @@ def f(msg, category): } #[test] - fn test_async_early_return_send_tail_uses_strong_load_fast_after_entry() { + fn async_early_return_send_tail_uses_strong_load_fast_after_entry() { let code = compile_exec( "\ class C: @@ -27618,7 +33652,7 @@ class C: } #[test] - fn test_async_with_return_await_after_early_return_keeps_borrow_load_fast() { + fn async_with_return_await_after_early_return_keeps_borrow_load_fast() { let code = compile_exec( "\ async def wait_for(fut, timeout): @@ -27689,7 +33723,155 @@ async def wait_for(fut, timeout): } #[test] - fn 
test_protected_import_tail_keeps_strong_load_fast() { + fn async_nested_try_finally_except_after_await_return_uses_strong_loads() { + let code = compile_exec( + "\ +async def f(a, b, c, h): + try: + try: + await sleep(1) + finally: + h() + except E: + pass + await sleep(0) + return a, b, c +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let return_idx = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::ReturnValue)) + .expect("missing return"); + let tail = &ops[return_idx.saturating_sub(4)..return_idx]; + + assert!( + tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFast { .. } | Instruction::LoadFastLoadFast { .. } + ) + }), + "nested try/finally inside try/except leaves CPython's empty normal-exit block before the next await, so return loads stay strong, got tail={tail:?}", + ); + assert!( + tail.iter().all(|unit| { + !matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. 
} + ) + }), + "nested try/finally inside try/except should not borrow final return loads, got tail={tail:?}", + ); + } + + #[test] + fn async_conditional_raise_finally_except_after_await_return_uses_strong_pair() { + let code = compile_exec( + "\ +async def f(self, asyncio, sys, task, timeout_handle, sleep): + timed_out = False + structured_block_finished = False + outer_code_reached = False + try: + try: + await asyncio.sleep(sleep) + structured_block_finished = True + finally: + timeout_handle.cancel() + if ( + timed_out + and task.uncancel() == 0 + and type(sys.exception()) is asyncio.CancelledError + ): + raise TimeoutError + except TimeoutError: + self.assertTrue(timed_out) + outer_code_reached = True + await asyncio.sleep(0) + return timed_out, structured_block_finished, outer_code_reached +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let return_idx = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::ReturnValue)) + .expect("missing return"); + let tail = &ops[return_idx.saturating_sub(5)..return_idx]; + + assert!( + tail.iter() + .any(|unit| matches!(unit.op, Instruction::LoadFastLoadFast { .. })), + "conditional-raise finally inside try/except should keep CPython-style strong final return pair after await, got tail={tail:?}", + ); + assert!( + tail.iter() + .all(|unit| !matches!(unit.op, Instruction::LoadFastBorrowLoadFastBorrow { .. 
})), + "conditional-raise finally inside try/except should not borrow final return pair after await, got tail={tail:?}", + ); + } + + #[test] + fn try_else_attribute_probe_end_allows_following_loads_borrow() { + let code = compile_exec( + "\ +def f(self): + args = (1,) + try: + getstate = self.__getstate__ + except AttributeError: + dict = None + else: + dict = getstate() + if dict: + return args, dict + return args +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let pair_arg = { + let args = f + .varnames + .iter() + .position(|name| name == "args") + .and_then(|idx| u8::try_from(idx).ok()) + .expect("missing args local"); + let dict = f + .varnames + .iter() + .position(|name| name == "dict") + .and_then(|idx| u8::try_from(idx).ok()) + .expect("missing dict local"); + (args << 4) | dict + }; + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + + assert!( + ops.iter().any(|unit| matches!( + unit.op, + Instruction::LoadFastBorrowLoadFastBorrow { .. 
} + ) && u8::from(unit.arg) == pair_arg), + "CPython 3.14 optimize_load_fast() borrows args/dict after the try/else attribute-probe end; got ops={ops:?}", + ); + } + + #[test] + fn protected_import_tail_keeps_strong_load_fast() { let code = compile_exec( "\ def f(s, size, pos, errors): @@ -27741,7 +33923,7 @@ def f(s, size, pos, errors): } #[test] - fn test_nested_protected_import_tail_keeps_strong_load_fast() { + fn nested_protected_import_tail_keeps_strong_load_fast() { let code = compile_exec( "\ def f(self, mode, compresslevel): @@ -27799,7 +33981,7 @@ def f(self, mode, compresslevel): } #[test] - fn test_unprotected_import_before_with_keeps_borrow() { + fn unprotected_import_before_with_keeps_borrow() { let code = compile_exec( "\ def f(self, document): @@ -27847,7 +34029,7 @@ def f(self, document): } #[test] - fn test_from_import_after_conditional_store_join_uses_strong_prefix_loads() { + fn from_import_after_conditional_store_join_uses_strong_prefix_loads() { let code = compile_exec( "\ def f(x): @@ -27881,12 +34063,12 @@ def f(x): assert!( matches!(ops[rpartition_idx - 1].op, Instruction::LoadFast { .. 
}), - "CPython keeps the conditional-store join receiver strong before IMPORT_FROM, got ops={ops:?}" + "CPython optimize_load_fast() keeps the conditional-store join receiver strong before IMPORT_FROM, got ops={ops:?}" ); } #[test] - fn test_plain_import_after_conditional_store_join_keeps_borrow_prefix_loads() { + fn plain_import_after_conditional_store_join_keeps_borrow_prefix_loads() { let code = compile_exec( "\ def f(x): @@ -27928,7 +34110,7 @@ def f(x): } #[test] - fn test_unprotected_prefix_before_try_keeps_attr_subscript_borrow() { + fn unprotected_prefix_before_try_keeps_attr_subscript_borrow() { let code = compile_exec( "\ def f(): @@ -27970,7 +34152,7 @@ def f(): } #[test] - fn test_terminal_except_inlined_comprehension_keeps_borrowed_warm_loads() { + fn terminal_except_inlined_comprehension_keeps_borrowed_warm_loads() { let code = compile_exec( r##" def f(output): @@ -28031,7 +34213,7 @@ def f(output): } #[test] - fn test_outer_guarded_protected_import_keeps_borrow_tail() { + fn outer_guarded_protected_import_keeps_borrow_tail() { let code = compile_exec( "\ def f(sys, os, file): @@ -28074,7 +34256,7 @@ def f(sys, os, file): } #[test] - fn test_loop_or_break_continue_orders_break_before_backedge() { + fn loop_or_break_continue_orders_break_before_backedge() { let code = compile_exec( "\ def f(self, quoted): @@ -28137,7 +34319,36 @@ def f(self, quoted): } #[test] - fn test_for_continue_before_return_orders_backedge_before_return_body() { + fn backward_jump_extended_arg_accounts_for_jump_cache() { + let mut source = String::from("def f(x, items):\n while x:\n"); + for _ in 0..10 { + source.push_str(" x = len(items[-1])\n"); + } + for _ in 0..6 { + source.push_str(" len(items)\n"); + } + source.push_str(" continue\n"); + + let code = compile_exec(&source); + let f = find_code(&code, "f").expect("missing f code"); + assert!( + f.instructions.windows(2).any(|window| { + matches!( + (&window[0].op, &window[1].op), + ( + Instruction::ExtendedArg, + 
Instruction::JumpBackward { .. } + | Instruction::JumpBackwardNoInterrupt { .. } + ) + ) + }), + "CPython assemble.c resolves unconditional jumps before jump offsets, so the first offset pass must include JUMP_BACKWARD's inline cache and emit EXTENDED_ARG at this boundary; got instructions={:?}", + f.instructions + ); + } + + #[test] + fn for_continue_before_return_orders_backedge_before_return_body() { let code = compile_exec( "\ def f(self): @@ -28197,7 +34408,7 @@ def f(self): } #[test] - fn test_while_conditional_return_orders_backedge_before_return_body() { + fn while_conditional_return_orders_backedge_before_return_body() { let code = compile_exec( "\ def f(self, tag): @@ -28247,7 +34458,7 @@ def f(self, tag): } #[test] - fn test_while_boolop_conditional_return_splits_backedges_before_return_body() { + fn while_boolop_conditional_return_splits_backedges_before_return_body() { let code = compile_exec( "\ def f(flags, A, B, stop): @@ -28293,7 +34504,7 @@ def f(flags, A, B, stop): } #[test] - fn test_for_break_to_return_orders_backedge_before_return() { + fn for_break_to_return_orders_backedge_before_return() { let code = compile_exec( "\ def f(it): @@ -28352,7 +34563,7 @@ def f(it): } #[test] - fn test_for_conditional_raise_orders_backedge_before_raise() { + fn for_conditional_raise_orders_backedge_before_raise() { let code = compile_exec( "\ def f(items, limit): @@ -28408,7 +34619,7 @@ def f(items, limit): } #[test] - fn test_simple_for_conditional_raise_orders_backedge_before_raise() { + fn simple_for_conditional_raise_orders_backedge_before_raise() { let code = compile_exec( "\ def f(kw): @@ -28457,7 +34668,7 @@ def f(kw): } #[test] - fn test_loop_nested_boolop_exit_keeps_cpython_backedge_line_order() { + fn loop_nested_boolop_exit_keeps_cpython_backedge_line_order() { let code = compile_exec( "\ def f(found, value, m, done, name, renamed_variables, keep_unresolved, variables): @@ -28520,7 +34731,7 @@ def f(found, value, m, done, name, renamed_variables, 
keep_unresolved, variables } #[test] - fn test_loop_conditional_raise_before_elif_keeps_raise_before_backedge() { + fn loop_conditional_raise_before_elif_keeps_raise_before_backedge() { let code = compile_exec( "\ def f(checks, missing, named): @@ -28573,7 +34784,7 @@ def f(checks, missing, named): } #[test] - fn test_protected_for_is_none_raise_threads_backedge_before_raise() { + fn protected_for_is_none_raise_threads_backedge_before_raise() { let code = compile_exec( "\ def f(stacklevel, frame, skip_file_prefixes): @@ -28613,7 +34824,7 @@ def f(stacklevel, frame, skip_file_prefixes): } #[test] - fn test_exception_handler_loop_conditional_raise_orders_backedge_before_raise() { + fn exception_handler_loop_conditional_raise_orders_backedge_before_raise() { let code = compile_exec( "\ def f(chunk, dec, i): @@ -28668,7 +34879,7 @@ def f(chunk, dec, i): } #[test] - fn test_exception_handler_loop_conditional_return_orders_backedge_before_return() { + fn exception_handler_loop_conditional_return_orders_backedge_before_return() { let code = compile_exec( "\ def f(cls, value): @@ -28726,7 +34937,7 @@ def f(cls, value): } #[test] - fn test_loop_if_body_keeps_fallthrough_before_implicit_continue_backedge() { + fn loop_if_body_keeps_fallthrough_before_implicit_continue_backedge() { let code = compile_exec( "\ def f(b, curr, curr_append, decoded_append, packI, curr_clear): @@ -28786,7 +34997,7 @@ def f(b, curr, curr_append, decoded_append, packI, curr_clear): } #[test] - fn test_if_not_continue_before_conditional_listcomp_body_keeps_cpython_layout() { + fn if_not_continue_before_conditional_listcomp_body_keeps_cpython_layout() { let code = compile_exec( "\ def f(data, use): @@ -28828,7 +35039,7 @@ def f(data, use): } #[test] - fn test_chained_compare_continue_does_not_duplicate_cleanup_backedge() { + fn chained_compare_continue_does_not_duplicate_cleanup_backedge() { let code = compile_exec( "\ def f(items): @@ -28891,7 +35102,7 @@ def f(items): } #[test] - fn 
test_try_else_loop_if_body_keeps_cpython_fallthrough_before_backedge() { + fn try_else_loop_if_body_keeps_cpython_fallthrough_before_backedge() { let code = compile_exec( "\ def f(self, ready, selector, key, input_view, os, BrokenPipeError): @@ -28954,7 +35165,7 @@ def f(self, ready, selector, key, input_view, os, BrokenPipeError): } #[test] - fn test_try_else_after_conditional_raise_keeps_loop_if_body_before_backedge() { + fn try_else_after_conditional_raise_keeps_loop_if_body_before_backedge() { let code = compile_exec( "\ def f(seq, flag, stat, OSError, pred, SpecialFileError): @@ -29015,7 +35226,7 @@ def f(seq, flag, stat, OSError, pred, SpecialFileError): } #[test] - fn test_explicit_continue_after_return_orders_return_before_backedge() { + fn explicit_continue_after_return_orders_return_before_backedge() { let code = compile_exec( "\ def f(j, n): @@ -29070,7 +35281,35 @@ def f(j, n): } #[test] - fn test_implicit_while_tail_return_orders_backedge_before_return() { + fn while_break_tail_does_not_duplicate_loop_false_return_epilogue() { + let code = compile_exec( + "\ +def f(waiters): + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + break +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let returns = ops + .iter() + .filter(|unit| matches!(unit.op, Instruction::ReturnValue)) + .count(); + assert_eq!( + returns, 2, + "CPython codegen_while() reuses the empty anchor block for USE_LABEL(end) when orelse is empty, so only the break fallthrough and loop-false epilogues remain, got ops={ops:?}", + ); + } + + #[test] + fn implicit_while_tail_return_orders_backedge_before_return() { let code = compile_exec( "\ def f(self, j, n): @@ -29125,7 +35364,7 @@ def f(self, j, n): } #[test] - fn test_branch_arm_implicit_continue_keeps_return_before_backedge() { + fn 
branch_arm_implicit_continue_keeps_return_before_backedge() { let code = compile_exec( "\ def f(self, j, n, c): @@ -29183,7 +35422,7 @@ def f(self, j, n, c): } #[test] - fn test_nested_implicit_while_tail_return_orders_backedge_before_return() { + fn nested_implicit_while_tail_return_orders_backedge_before_return() { let code = compile_exec( "\ def f(self, rawdata, j, match): @@ -29243,7 +35482,7 @@ def f(self, rawdata, j, match): } #[test] - fn test_join_store_global_before_import_keeps_strong_load_fast() { + fn join_store_global_before_import_keeps_strong_load_fast() { let code = compile_exec( "\ def f(module=None): @@ -29277,7 +35516,7 @@ def f(module=None): } #[test] - fn test_handler_resume_join_keeps_borrow_in_common_tail() { + fn handler_resume_join_keeps_borrow_in_common_tail() { let code = compile_exec( "\ def f(p, errors, s, pos, look, final, escape_start, st): @@ -29342,7 +35581,7 @@ def f(p, errors, s, pos, look, final, escape_start, st): } #[test] - fn test_multi_handler_guarded_resume_tail_keeps_borrow() { + fn multi_handler_guarded_resume_tail_keeps_borrow() { let code = compile_exec( "\ def f(a): @@ -29395,7 +35634,7 @@ def f(a): } #[test] - fn test_multi_handler_method_tail_keeps_borrow() { + fn multi_handler_method_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, xs): @@ -29435,7 +35674,7 @@ def f(self, xs): } #[test] - fn test_named_except_cleanup_loop_header_keeps_borrow_in_for_loop() { + fn named_except_cleanup_loop_header_keeps_borrow_in_for_loop() { let code = compile_exec( "\ def f(args): @@ -29478,7 +35717,7 @@ def f(args): } #[test] - fn test_multi_named_except_loop_header_keeps_borrow_for_normal_path() { + fn multi_named_except_loop_header_keeps_borrow_for_normal_path() { let code = compile_exec( "\ def f(self): @@ -29535,7 +35774,7 @@ def f(self): } #[test] - fn test_named_except_cleanup_simple_resume_tail_keeps_borrow() { + fn named_except_cleanup_simple_resume_tail_keeps_borrow() { let code = compile_exec( "\ def f(self): 
@@ -29581,7 +35820,7 @@ def f(self): } #[test] - fn test_named_except_cleanup_conditional_raise_tail_keeps_borrow() { + fn named_except_cleanup_conditional_raise_tail_keeps_borrow() { let code = compile_exec( "\ def f(self): @@ -29629,7 +35868,7 @@ def f(self): } #[test] - fn test_with_suppress_named_except_resume_tail_uses_strong_loads() { + fn with_suppress_named_except_resume_tail_uses_strong_loads() { let code = compile_exec( "\ def f(self, cm, E): @@ -29682,7 +35921,7 @@ def f(self, cm, E): } #[test] - fn test_with_named_except_return_value_keeps_borrow() { + fn with_named_except_return_value_keeps_borrow() { let code = compile_exec( "\ def f(self, b, BlockingIOError): @@ -29743,7 +35982,7 @@ def f(self, b, BlockingIOError): } #[test] - fn test_with_final_conditional_return_preserves_fallthrough_cleanup_nop() { + fn with_final_conditional_return_preserves_fallthrough_cleanup_nop() { let code = compile_exec( "\ def f(self): @@ -29780,7 +36019,7 @@ def f(self): } #[test] - fn test_with_while_fallthrough_preserves_cleanup_nop() { + fn with_while_fallthrough_preserves_cleanup_nop() { let code = compile_exec( "\ def f(cm, source): @@ -29819,7 +36058,61 @@ def f(cm, source): } #[test] - fn test_with_while_true_break_drops_cleanup_nop() { + fn with_for_fallthrough_drops_cleanup_nop() { + let code = compile_exec( + "\ +def f(cm, xs, g): + with cm: + for x in xs: + g(x) + return None +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .map(|unit| unit.op) + .collect(); + + assert!( + ops.windows(6).any(|window| { + matches!( + window, + [ + Instruction::EndFor, + Instruction::PopIter, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + Instruction::Call { .. 
}, + ] + ) + }), + "with cleanup after for fallthrough should directly follow END_FOR/POP_ITER like CPython, got ops={ops:?}" + ); + assert!( + !ops.windows(7).any(|window| { + matches!( + window, + [ + Instruction::EndFor, + Instruction::PopIter, + Instruction::Nop, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + Instruction::Call { .. }, + ] + ) + }), + "with cleanup after for fallthrough should not preserve a POP_BLOCK NOP, got ops={ops:?}" + ); + } + + #[test] + fn with_while_true_break_drops_cleanup_nop() { let code = compile_exec( "\ def f(cm, source): @@ -29858,7 +36151,63 @@ def f(cm, source): } #[test] - fn test_with_final_assert_preserves_cleanup_nop() { + fn multi_with_while_true_try_except_drops_outer_cleanup_nop() { + let code = compile_exec( + "\ +def f(cm1, cm2, g, E): + with cm1, cm2: + while True: + try: + g() + except E: + pass +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .map(|unit| unit.op) + .collect(); + + assert!( + ops.windows(6).any(|window| { + matches!( + window, + [ + Instruction::Copy { .. }, + Instruction::PopExcept, + Instruction::Reraise { .. }, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + ] + ) + }), + "outer with cleanup after an infinite inner with body should follow the inner cleanup directly like CPython, got ops={ops:?}" + ); + assert!( + !ops.windows(7).any(|window| { + matches!( + window, + [ + Instruction::Copy { .. }, + Instruction::PopExcept, + Instruction::Reraise { .. }, + Instruction::Nop, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. }, + Instruction::LoadConst { .. 
}, + ] + ) + }), + "outer with cleanup after an infinite inner with body should not keep a POP_BLOCK NOP, got ops={ops:?}" + ); + } + + #[test] + fn with_final_assert_preserves_cleanup_nop() { let code = compile_exec( "\ def f(cm, dst): @@ -29893,7 +36242,7 @@ def f(cm, dst): } #[test] - fn test_named_except_conditional_reraise_final_store_attr_keeps_borrow() { + fn named_except_conditional_reraise_final_store_attr_keeps_borrow() { let code = compile_exec( "\ def f(self, fd, file, closefd, owned_fd, OSError, AttributeError, errno, os, stat, _setmode): @@ -29979,7 +36328,7 @@ def f(self, fd, file, closefd, owned_fd, OSError, AttributeError, errno, os, sta } #[test] - fn test_with_except_else_with_resume_loop_tail_uses_strong_loads() { + fn with_except_else_with_resume_loop_tail_uses_strong_loads() { let code = compile_exec( "\ def f(self, cm, E): @@ -30031,7 +36380,426 @@ def f(self, cm, E): } #[test] - fn test_plain_with_then_global_loop_tail_keeps_borrow() { + fn final_with_try_except_resume_loop_tail_uses_strong_loads() { + let code = compile_exec( + r#" +def f(resources, valid_zones, TZPATH, os): + try: + with resources.open("r") as f: + pass + except Exception: + pass + for tz_root in TZPATH: + if not os.path.exists(tz_root): + continue + valid_zones.add(tz_root) + return valid_zones +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let get_iter = instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::GetIter)) + .expect("missing post-try loop iterator"); + let tail = &instructions[get_iter.saturating_sub(1)..]; + let load_fast_name = |unit: &&bytecode::CodeUnit| match unit.op { + Instruction::LoadFast { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + Some(f.varnames[usize::from(var_num.get(arg))].as_str()) + } + _ => None, + }; + let borrowed_name = |unit: 
&&bytecode::CodeUnit| match unit.op { + Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + Some(f.varnames[usize::from(var_num.get(arg))].as_str()) + } + _ => None, + }; + + for name in ["TZPATH", "os", "tz_root", "valid_zones"] { + assert!( + tail.iter() + .filter_map(load_fast_name) + .any(|loaded| loaded == name), + "expected CPython-style strong LOAD_FAST for {name} after final with/except resume, got tail={tail:?}", + ); + assert!( + tail.iter() + .filter_map(borrowed_name) + .all(|loaded| loaded != name), + "final with/except resume loop tail should not borrow {name}, got tail={tail:?}", + ); + } + } + + #[test] + fn finally_ending_try_except_resume_tail_uses_strong_loads() { + let code = compile_exec( + r#" +def f(self, fobj, unlink, TESTFN, C): + try: + fobj.write(1) + finally: + fobj.close() + try: + unlink(TESTFN) + except OSError: + pass + a, b = C(2), C(3) + self.assertEqual((a, b), (1, 2)) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let assert_equal = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() + == "assertEqual" + } + _ => false, + }) + .expect("missing assertEqual load"); + let handler_start = instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .expect("missing exception path"); + let tail = &instructions[assert_equal.saturating_sub(1)..handler_start]; + let is_strong_pair = |unit: &&bytecode::CodeUnit, left_name: &str, right_name: &str| { + let Instruction::LoadFastLoadFast { var_nums } = unit.op else { + return false; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = 
var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == left_name + && f.varnames[usize::from(right)] == right_name + }; + + assert!( + tail.iter() + .any(|unit| matches!(unit.op, Instruction::LoadFast { .. })), + "expected finally/try-except resume tail to use strong LOAD_FAST ops, got tail={tail:?}" + ); + assert!( + tail.iter().any(|unit| is_strong_pair(unit, "a", "b")), + "expected finally/try-except resume tuple to use strong LOAD_FAST_LOAD_FAST, got tail={tail:?}" + ); + assert!( + tail.iter().all(|unit| { + !matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) + }), + "finally/try-except resume tail should not borrow LOAD_FAST ops, got tail={tail:?}" + ); + } + + #[test] + fn try_finally_bare_reraise_handler_resume_tail_uses_strong_loads() { + let code = compile_exec( + "\ +def f(self, os, alive_r, alive_w, address, pid): + try: + pid = g() + except: + os.close(alive_w) + raise + finally: + os.close(alive_r) + self.address = address + self.alive_w = alive_w + self.pid = pid +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let close_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "close" + } + _ => false, + }) + .expect("missing close load"); + let handler_start = instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .expect("missing exception path"); + let tail = &instructions[close_attr.saturating_sub(1)..handler_start]; + + assert!( + tail.iter() + .any(|unit| matches!(unit.op, Instruction::LoadFast { .. 
})), + "bare-reraise try/finally resume tail should use strong LOAD_FAST ops, got tail={tail:?}" + ); + assert!( + tail.iter().all(|unit| { + !matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) + }), + "bare-reraise try/finally resume tail should not borrow LOAD_FAST ops, got tail={tail:?}" + ); + } + + #[test] + fn typed_except_return_resume_tail_uses_strong_loads() { + let code = compile_exec( + "\ +def f(resource, desired_fds, max_fds): + try: + import math + except ImportError: + return None + fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + if fd_limit < desired_fds and fd_limit < max_fds: + return desired_fds, max_fds + return fd_limit +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let getrlimit_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "getrlimit" + } + _ => false, + }) + .expect("missing getrlimit load"); + let handler_start = instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .expect("missing exception path"); + let tail = &instructions[getrlimit_attr.saturating_sub(1)..handler_start]; + + assert!( + tail.iter() + .any(|unit| matches!(unit.op, Instruction::LoadFast { .. })), + "typed except-return resume tail should use strong LOAD_FAST ops, got tail={tail:?}" + ); + assert!( + tail.iter().all(|unit| { + !matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. 
} + ) + }), + "typed except-return resume tail should not borrow LOAD_FAST ops, got tail={tail:?}" + ); + } + + #[test] + fn resuming_except_before_try_preserves_next_try_entry_barrier() { + let code = compile_exec( + "\ +def f(scan_once, s, end, _ws, _w): + try: + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + try: + value, end = scan_once(s, end) + except StopIteration as err: + raise ValueError(s, err.value) from None + return value, end +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let scan_once_load = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.varnames[usize::from(var_num.get(arg))].as_str() == "scan_once" + } + _ => false, + }) + .expect("missing scan_once load"); + let scan_tail = &instructions[scan_once_load..scan_once_load + 4]; + + assert!( + scan_tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFast { .. } | Instruction::LoadFastLoadFast { .. } + ) + }), + "resuming except before another try should enter next try with strong LOAD_FAST ops, got scan_tail={scan_tail:?}" + ); + assert!( + scan_tail.iter().all(|unit| { + !matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. 
} + ) + }), + "resuming except before another try should not borrow next try entry loads, got scan_tail={scan_tail:?}" + ); + } + + #[test] + fn simple_except_before_try_keeps_next_try_entry_borrowed() { + let code = compile_exec( + "\ +def f(scan_once, s, end): + try: + g(s, end) + except IndexError: + pass + try: + value, end = scan_once(s, end) + except StopIteration: + pass + return value, end +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let scan_once_load = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.varnames[usize::from(var_num.get(arg))].as_str() == "scan_once" + } + _ => false, + }) + .expect("missing scan_once load"); + let scan_tail = &instructions[scan_once_load..scan_once_load + 4]; + + assert!( + scan_tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) + }), + "simple except before another try should keep next try entry borrowed, got scan_tail={scan_tail:?}" + ); + assert!( + !scan_tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFast { .. } | Instruction::LoadFastLoadFast { .. 
} + ) + }), + "simple except before another try should not force strong LOAD_FAST, got scan_tail={scan_tail:?}" + ); + } + + #[test] + fn loop_break_except_before_try_preserves_next_try_entry_barrier() { + let code = compile_exec( + "\ +def f(scan_once, seq1, seq2, n): + for i in range(n): + try: + item1 = seq1[i] + except (TypeError, IndexError, NotImplementedError): + break + try: + item2 = seq2[i] + except (TypeError, IndexError, NotImplementedError): + break + return item1, item2 +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let second_subscript = instructions + .iter() + .position(|unit| { + matches!( + unit.op, + Instruction::BinaryOp { op } + if op.get(OpArg::new(u32::from(u8::from(unit.arg)))) + == oparg::BinaryOperator::Subscr + ) + }) + .expect("missing first subscript"); + let second_subscript = instructions[second_subscript + 1..] + .iter() + .position(|unit| { + matches!( + unit.op, + Instruction::BinaryOp { op } + if op.get(OpArg::new(u32::from(u8::from(unit.arg)))) + == oparg::BinaryOperator::Subscr + ) + }) + .map(|idx| idx + second_subscript + 1) + .expect("missing second subscript"); + let scan_tail = &instructions[second_subscript.saturating_sub(2)..second_subscript]; + + assert!( + scan_tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFast { .. } | Instruction::LoadFastLoadFast { .. } + ) + }), + "loop break except before another try should keep next try entry strong, got scan_tail={scan_tail:?}" + ); + assert!( + !scan_tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. 
} + ) + }), + "loop break except before another try should not borrow next try entry loads, got scan_tail={scan_tail:?}" + ); + } + + #[test] + fn plain_with_then_global_loop_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, cm): @@ -30070,7 +36838,7 @@ def f(self, cm): } #[test] - fn test_context_manager_for_join_tail_keeps_borrow() { + fn context_manager_for_join_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, factory): @@ -30110,7 +36878,7 @@ def f(self, factory): } #[test] - fn test_with_except_resume_normal_tail_uses_strong_loads() { + fn with_except_resume_normal_tail_uses_strong_loads() { let code = compile_exec( "\ def f(self, cm, E): @@ -30152,7 +36920,7 @@ def f(self, cm, E): } #[test] - fn test_with_except_else_attr_subscript_tail_keeps_borrow() { + fn with_except_else_attr_subscript_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, cm, E, obj): @@ -30206,7 +36974,7 @@ def f(self, cm, E, obj): } #[test] - fn test_with_suppress_attr_subscript_tail_keeps_borrow() { + fn with_suppress_attr_subscript_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, cm): @@ -30247,7 +37015,7 @@ def f(self, cm): } #[test] - fn test_named_except_conditional_reraise_deopts_with_chain_tail() { + fn named_except_conditional_reraise_deopts_with_chain_tail() { let code = compile_exec( "\ def f(self, arc, tmp_filename, new_mode): @@ -30306,7 +37074,51 @@ def f(self, arc, tmp_filename, new_mode): } #[test] - fn test_terminal_except_before_with_deopts_with_body_borrows() { + fn terminal_bare_reraise_successor_join_keeps_final_store_borrow() { + let code = compile_exec( + "\ +def f(self, fd, appending, errno): + try: + if appending: + try: + seek(fd) + except OSError as e: + if e.errno != errno: + raise + except: + self.stat = None + raise + self._fd = fd +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + 
.collect(); + let store_fd = instructions + .iter() + .position(|unit| match unit.op { + Instruction::StoreAttr { namei } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.names[usize::try_from(namei.get(arg)).unwrap()].as_str() == "_fd" + } + _ => false, + }) + .expect("missing _fd STORE_ATTR"); + + assert!( + matches!( + instructions[store_fd - 1].op, + Instruction::LoadFastBorrowLoadFastBorrow { .. } + ), + "terminal bare-reraise body successor join should keep CPython-style borrowed final store pair, got instructions={instructions:?}" + ); + } + + #[test] + fn terminal_except_before_with_deopts_with_body_borrows() { let code = compile_exec( "\ def f(self, cm): @@ -30360,7 +37172,7 @@ def f(self, cm): } #[test] - fn test_terminal_except_resume_tail_uses_strong_loads() { + fn terminal_except_resume_tail_uses_strong_loads() { let code = compile_exec( "\ def f(re, proc, unittest): @@ -30438,7 +37250,7 @@ def f(re, proc, unittest): } #[test] - fn test_terminal_except_conditional_return_tail_uses_strong_loads() { + fn terminal_except_conditional_return_tail_uses_strong_loads() { let code = compile_exec( "\ def f(param, value, quote): @@ -30505,7 +37317,7 @@ def f(param, value, quote): } #[test] - fn test_terminal_except_successor_call_tail_uses_strong_load() { + fn terminal_except_successor_call_tail_uses_strong_load() { let code = compile_exec( "\ def f(curr, decoded_append, packI, curr_clear, Error): @@ -30549,7 +37361,236 @@ def f(curr, decoded_append, packI, curr_clear, Error): } #[test] - fn test_terminal_except_following_if_tail_uses_strong_loads() { + fn loop_terminal_except_continue_if_tail_keeps_borrowed_loads() { + let code = compile_exec( + "\ +def f(self, parser, opt, accum, rest, section, map, path, depth): + while rest: + rawval = rest.pop() + try: + if len(path) == 1: + opt = parser.optionxform(path[0]) + v = map[opt] + elif len(path) == 2: + sect = path[0] + opt = parser.optionxform(path[1]) + v = parser.get(sect, opt, raw=True) + else: 
+ raise InterpolationSyntaxError(option, section, 'x') + except (KeyError, NoSectionError, NoOptionError): + raise InterpolationMissingOptionError(option, section, rawval, ':'.join(path)) from None + if v is None: + continue + if '$' in v: + self._interpolate_some(parser, opt, accum, v, sect, dict(parser.items(sect, raw=True)), depth + 1) + else: + accum.append(v) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let handler_start = ops + .iter() + .position(|unit| matches!(unit.op, Instruction::PushExcInfo)) + .expect("missing handler entry"); + let warm_path = &ops[..handler_start]; + let post_try_if = warm_path + .iter() + .position(|unit| matches!(unit.op, Instruction::PopJumpIfNotNone { .. })) + .expect("missing post-try if"); + let tail = &warm_path[post_try_if.saturating_sub(1)..]; + + let mentions_name = |unit: &CodeUnit, name: &str| match unit.op { + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.varnames[usize::from(var_num.get(arg))] == name + } + Instruction::LoadFastLoadFast { var_nums } + | Instruction::LoadFastBorrowLoadFastBorrow { var_nums } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + let (left, right) = var_nums.get(arg).indexes(); + f.varnames[usize::from(left)] == name || f.varnames[usize::from(right)] == name + } + _ => false, + }; + let borrows_name = |name: &str| { + tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFastBorrow { .. } + | Instruction::LoadFastBorrowLoadFastBorrow { .. } + ) && mentions_name(unit, name) + }) + }; + let strong_loads_name = |name: &str| { + tail.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadFast { .. } | Instruction::LoadFastLoadFast { .. 
} + ) && mentions_name(unit, name) + }) + }; + + for name in ["v", "self", "parser", "opt", "accum", "depth"] { + assert!( + borrows_name(name), + "CPython keeps loop terminal-except continue-if tail borrowed for {name}, got tail={tail:?}" + ); + assert!( + !strong_loads_name(name), + "loop terminal-except continue-if tail should not deopt {name}, got tail={tail:?}" + ); + } + } + + #[test] + fn method_call_try_return_handler_keeps_following_receiver_borrowed() { + let code = compile_exec( + "\ +def f(charset, failobj, E): + try: + charset.encode('us-ascii') + except E: + return failobj + return charset.lower() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let lower_attr = ops + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "lower" + } + _ => false, + }) + .expect("missing lower attr"); + let receiver = &ops[lower_attr - 1]; + + assert!( + matches!(receiver.op, Instruction::LoadFastBorrow { .. 
}), + "CPython codegen_try_except() does not leave a load-fast barrier after method-call try body when the handler returns, got ops={ops:?}" + ); + } + + #[test] + fn typed_terminal_method_call_try_deopts_successor_call_args() { + let code = compile_exec( + "\ +def f(events, callback, args, self, sig, signal): + try: + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError: + raise RuntimeError('bad signal') from None + handle = events.Handle(callback, args, self, None) + self._signal_handlers[sig] = handle +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let handle_store = ops + .iter() + .position(|unit| { + matches!( + unit.op, + Instruction::StoreFast { var_num } + if f.varnames[usize::from( + var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))) + )] == "handle" + ) + }) + .expect("missing handle store"); + let handle_call_window = &ops[handle_store.saturating_sub(4)..handle_store]; + + assert!( + matches!( + handle_call_window, + [ + bytecode::CodeUnit { + op: Instruction::LoadFastLoadFast { .. }, + .. + }, + bytecode::CodeUnit { + op: Instruction::LoadFast { .. }, + .. + }, + bytecode::CodeUnit { + op: Instruction::LoadConst { .. }, + .. + }, + bytecode::CodeUnit { + op: Instruction::Call { .. }, + .. 
+ }, + ] + ), + "CPython codegen_try_except() leaves a USE_LABEL(end) continuation after the terminal typed handler; successor call args should be strong loads, got window={handle_call_window:?}; ops={ops:?}" + ); + } + + #[test] + fn typed_terminal_unpack_call_try_deopts_successor_call_args() { + let code = compile_exec( + "\ +def f(self, OSError): + try: + request, client_address = self.get_request() + except OSError: + return + if self.verify_request(request, client_address): + self.process_request(request, client_address) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let verify_attr = ops + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + f.names[usize::try_from(namei.get(arg).name_idx()).unwrap()].as_str() + == "verify_request" + } + _ => false, + }) + .expect("missing verify_request attr"); + let request_pair = &ops[verify_attr + 1]; + + assert!( + matches!( + request_pair.op, + Instruction::LoadFastLoadFast { var_nums } + if { + let arg = OpArg::new(u32::from(u8::from(request_pair.arg))); + let pair = var_nums.get(arg).as_u32(); + f.varnames[(pair >> 4) as usize].as_str() == "request" + && f.varnames[(pair & 0xF) as usize].as_str() == "client_address" + } + ), + "CPython codegen_try_except() keeps successor call args strong after terminal typed unpack-call try body, got ops={ops:?}" + ); + } + + #[test] + fn terminal_except_following_if_tail_uses_strong_loads() { let code = compile_exec( "\ def f(s): @@ -30611,7 +37652,7 @@ def f(s): } #[test] - fn test_bare_except_internal_condition_keeps_try_body_borrows() { + fn bare_except_internal_condition_keeps_try_body_borrows() { let code = compile_exec( "\ def f(buffering, raw, binary, result, BufferedReader): @@ -30692,7 +37733,7 @@ def f(buffering, raw, binary, result, BufferedReader): 
} #[test] - fn test_try_except_else_terminal_handler_conditional_tail_uses_strong_loads() { + fn try_except_else_terminal_handler_conditional_tail_uses_strong_loads() { let code = compile_exec( "\ def f(self, pos, whence): @@ -30779,7 +37820,7 @@ def f(self, pos, whence): } #[test] - fn test_try_except_else_outer_join_keeps_borrowed_loads() { + fn try_except_else_outer_join_keeps_borrowed_loads() { let code = compile_exec( "\ def f(self, pos=None): @@ -30861,7 +37902,7 @@ def f(self, pos=None): } #[test] - fn test_terminal_except_else_final_store_attr_tail_uses_strong_loads() { + fn terminal_except_else_final_store_attr_tail_uses_strong_loads() { let code = compile_exec( "\ def f(self, E, Event): @@ -30921,7 +37962,7 @@ def f(self, E, Event): } #[test] - fn test_except_break_try_else_loop_tail_keeps_else_borrows() { + fn except_break_try_else_loop_tail_keeps_else_borrows() { let code = compile_exec( "\ def f(self): @@ -30979,7 +38020,7 @@ def f(self): } #[test] - fn test_protected_method_call_after_terminal_except_tail_uses_strong_loads() { + fn protected_method_call_after_terminal_except_tail_uses_strong_loads() { let code = compile_exec( "\ def f(items, chunk, out, packI, Error): @@ -31050,7 +38091,7 @@ def f(items, chunk, out, packI, Error): } #[test] - fn test_terminal_reraising_handler_keeps_try_body_method_borrows() { + fn terminal_reraising_handler_keeps_try_body_method_borrows() { let code = compile_exec( "\ def f(self): @@ -31098,7 +38139,7 @@ def f(self): } #[test] - fn test_terminal_except_loop_successor_augassign_uses_strong_load_pair() { + fn terminal_except_loop_successor_augassign_uses_strong_load_pair() { let code = compile_exec( "\ def f(items, decoded, b32rev): @@ -31139,7 +38180,7 @@ def f(items, decoded, b32rev): } #[test] - fn test_terminal_except_loop_backedge_keeps_header_borrows() { + fn terminal_except_loop_backedge_keeps_header_borrows() { let code = compile_exec( "\ def f(self, value, start=0, stop=None): @@ -31180,7 +38221,95 @@ def 
f(self, value, start=0, stop=None): } #[test] - fn test_loop_if_implicit_continue_places_body_after_jumpback() { + fn one_line_protected_infinite_while_body_uses_strong_pair() { + let code = compile_exec( + "\ +def f(): + items = range(1, 4) + try: + i = 0 + while 1: i = items[i] + except IndexError: + pass +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + + assert!( + instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadFastLoadFast { .. })), + "one-line protected infinite while body should match CPython strong pair, got instructions={instructions:?}" + ); + assert!( + !instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadFastBorrowLoadFastBorrow { .. })), + "one-line protected infinite while body should not borrow the loop body pair, got instructions={instructions:?}" + ); + } + + #[test] + fn multiline_protected_infinite_while_body_keeps_borrow_pair() { + let code = compile_exec( + "\ +def f(items): + try: + i = 0 + while 1: + i = items[i] + except IndexError: + pass +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + + assert!( + instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadFastBorrowLoadFastBorrow { .. 
})), + "multiline protected infinite while body should keep CPython borrowed pair, got instructions={instructions:?}" + ); + } + + #[test] + fn one_line_protected_infinite_while_method_call_keeps_borrow_receiver() { + let code = compile_exec( + "\ +def f(self): + try: + while 1: self.x() + except IndexError: + pass +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + + assert!( + instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::LoadFast { .. })), + "CPython 3.14 keeps the one-line protected bound-method receiver strong, got instructions={instructions:?}" + ); + } + + #[test] + fn loop_if_implicit_continue_places_body_after_jumpback() { let code = compile_exec( "\ def f(_config_vars, _INITPRE): @@ -31220,7 +38349,7 @@ def f(_config_vars, _INITPRE): } #[test] - fn test_loop_if_call_body_implicit_continue_places_body_after_jumpback() { + fn loop_if_call_body_implicit_continue_places_body_after_jumpback() { let code = compile_exec( "\ def f(seq, db): @@ -31264,7 +38393,7 @@ def f(seq, db): } #[test] - fn test_nested_loop_if_try_body_implicit_continue_places_body_after_jumpback() { + fn nested_loop_if_try_body_implicit_continue_places_body_after_jumpback() { let code = compile_exec( "\ def f(seq, broken, codecs, LookupError, s, Queue, bytes): @@ -31347,7 +38476,7 @@ def f(seq, broken, codecs, LookupError, s, Queue, bytes): } #[test] - fn test_loop_branch_raise_before_elif_keeps_body_before_backedge() { + fn loop_branch_raise_before_elif_keeps_body_before_backedge() { let code = compile_exec( "\ def f(checks, UNIQUE, CONTINUOUS, ValueError): @@ -31407,7 +38536,7 @@ def f(checks, UNIQUE, CONTINUOUS, ValueError): } #[test] - fn test_loop_nested_raise_then_append_places_body_after_false_backedge() { + fn loop_nested_raise_then_append_places_body_after_false_backedge() { let code = compile_exec( "\ def f(args, 
parameters, enforce_default_ordering, type_var_tuple_encountered, default_encountered, TypeError): @@ -31467,7 +38596,7 @@ def f(args, parameters, enforce_default_ordering, type_var_tuple_encountered, de } #[test] - fn test_loop_break_before_adjacent_break_keeps_body_before_backedge() { + fn loop_break_before_adjacent_break_keeps_body_before_backedge() { let code = compile_exec( "\ def f(pattern, prefix, get_prefix): @@ -31530,7 +38659,7 @@ def f(pattern, prefix, get_prefix): } #[test] - fn test_loop_elif_and_pass_keeps_shared_false_backedge_after_body() { + fn loop_elif_and_pass_keeps_shared_false_backedge_after_body() { let code = compile_exec( "\ def f(methods, simple_keys, checked_keys, checked_enum, simple_enum, failed): @@ -31596,7 +38725,7 @@ def f(methods, simple_keys, checked_keys, checked_enum, simple_enum, failed): } #[test] - fn test_loop_nested_if_delete_slice_places_body_after_jumpback() { + fn loop_nested_if_delete_slice_places_body_after_jumpback() { let code = compile_exec( "\ def f(compiler_so): @@ -31645,7 +38774,7 @@ def f(compiler_so): } #[test] - fn test_loop_if_subscr_store_delete_places_body_after_jumpback() { + fn loop_if_subscr_store_delete_places_body_after_jumpback() { let code = compile_exec( "\ def f(chunks): @@ -31685,7 +38814,7 @@ def f(chunks): } #[test] - fn test_final_elif_implicit_continue_places_jumpback_before_body() { + fn final_elif_implicit_continue_places_jumpback_before_body() { let code = compile_exec( "\ def f(state, nextchar, whitespace, token, posix, quoted, debug): @@ -31746,7 +38875,7 @@ def f(state, nextchar, whitespace, token, posix, quoted, debug): } #[test] - fn test_final_attribute_elif_implicit_continue_places_jumpback_before_body() { + fn final_attribute_elif_implicit_continue_places_jumpback_before_body() { let code = compile_exec( "\ def f(self, nextchar, quoted): @@ -31795,7 +38924,7 @@ def f(self, nextchar, quoted): } #[test] - fn test_inner_if_implicit_continue_keeps_line_bearing_body_before_backedge() { 
+ fn inner_if_implicit_continue_keeps_line_bearing_body_before_backedge() { let code = compile_exec( "\ def f(self, nextchar, quoted): @@ -31854,7 +38983,7 @@ def f(self, nextchar, quoted): } #[test] - fn test_except_handler_with_conditional_raise_and_resume_keeps_borrow() { + fn except_handler_with_conditional_raise_and_resume_keeps_borrow() { let code = compile_exec( "\ def f(formatstr, args, output, overflowok): @@ -31909,7 +39038,7 @@ def f(formatstr, args, output, overflowok): } #[test] - fn test_typed_except_resume_import_warning_tail_keeps_borrows() { + fn typed_except_resume_import_warning_tail_keeps_borrows() { let code = compile_exec( r#" def f(mod_name, error, sys, RuntimeWarning): @@ -31992,7 +39121,7 @@ def f(mod_name, error, sys, RuntimeWarning): } #[test] - fn test_reraising_except_else_tail_keeps_borrow() { + fn reraising_except_else_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, data, length): @@ -32050,7 +39179,7 @@ def f(self, data, length): } #[test] - fn test_try_else_finally_with_keeps_context_manager_borrow() { + fn try_else_finally_with_keeps_context_manager_borrow() { let code = compile_exec( "\ def f(i): @@ -32093,7 +39222,7 @@ def f(i): } #[test] - fn test_except_star_handler_pop_block_does_not_leave_nop_before_with_exit() { + fn except_star_handler_pop_block_does_not_leave_nop_before_with_exit() { let code = compile_exec( "\ def f(self): @@ -32130,7 +39259,7 @@ def f(self): } #[test] - fn test_except_star_body_to_else_jump_drops_without_line_nop() { + fn except_star_body_to_else_jump_drops_without_line_nop() { let code = compile_exec( "\ async def f(self, cm): @@ -32172,7 +39301,7 @@ async def f(self, cm): } #[test] - fn test_resuming_except_before_with_keeps_with_body_borrows() { + fn resuming_except_before_with_keeps_with_body_borrows() { let code = compile_exec( "\ def f(self, cm): @@ -32219,7 +39348,7 @@ def f(self, cm): } #[test] - fn test_nested_finally_except_resume_loop_uses_strong_loads() { + fn 
nested_finally_except_resume_loop_uses_strong_loads() { let code = compile_exec( "\ def f(self, xs): @@ -32269,7 +39398,7 @@ def f(self, xs): } #[test] - fn test_finally_protected_loop_without_except_resume_keeps_borrows() { + fn finally_protected_loop_without_except_resume_keeps_borrows() { let code = compile_exec( "\ def f(self, obj, expected, buf): @@ -32338,7 +39467,7 @@ def f(self, obj, expected, buf): } #[test] - fn test_plain_except_resume_loop_keeps_borrows() { + fn plain_except_resume_loop_keeps_borrows() { let code = compile_exec( "\ def f(self, xs): @@ -32379,7 +39508,7 @@ def f(self, xs): } #[test] - fn test_except_pass_resume_loop_branch_keeps_borrows() { + fn except_pass_resume_loop_branch_keeps_borrows() { let code = compile_exec( r#" def f(self, cls, fns): @@ -32468,7 +39597,7 @@ def f(self, cls, fns): } #[test] - fn test_named_except_cleanup_deopts_same_guard_fallbacks_not_outer_tail() { + fn named_except_cleanup_deopts_same_guard_fallbacks_not_outer_tail() { let code = compile_exec( r#" def f(s, size, errors, final): @@ -32582,7 +39711,7 @@ def f(s, size, errors, final): } #[test] - fn test_imap_idle_status_debug_tail_keeps_borrow() { + fn imap_idle_status_debug_tail_keeps_borrow() { let code = compile_exec( "\ def f(self, exc_type, CRLF, OSError): @@ -32624,7 +39753,7 @@ def f(self, exc_type, CRLF, OSError): } #[test] - fn test_match_async_comprehension_iter_keeps_capture_borrow() { + fn match_async_comprehension_iter_keeps_capture_borrow() { let code = compile_exec( r#" async def name_4(): @@ -32658,7 +39787,92 @@ async def name_4(): } #[test] - fn test_with_protected_generator_tail_after_cleanup_uses_strong_loads() { + fn match_fail_cleanup_label_reuse_keeps_post_match_borrow() { + let code = compile_exec( + r#" +def f(self): + match 0: + case 0: + x = True + self.assertIs(x, True) +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, 
Instruction::Cache)) + .collect(); + let assert_is_attr = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadAttr { namei } => { + let load_attr = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))); + f.names[usize::try_from(load_attr.name_idx()).unwrap()].as_str() == "assertIs" + } + _ => false, + }) + .expect("missing assertIs attr load"); + let receiver = &instructions[assert_is_attr - 1]; + assert!( + matches!( + receiver.op, + Instruction::LoadFastBorrow { var_num } + if f.varnames[usize::from(var_num.get(OpArg::new(u32::from(u8::from(receiver.arg)))))] == "self" + ), + "CPython codegen_match_inner uses USE_LABEL(c, end), so post-match receiver should stay borrowed; got {receiver:?}" + ); + } + + #[test] + fn match_or_preserves_explicit_success_jumps() { + let code = compile_exec( + r#" +def f(w): + match w: + case 1 | 2 | 3: + out = locals() + del out["w"] + return out +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + let instructions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect(); + let locals_pos = instructions + .iter() + .position(|unit| match unit.op { + Instruction::LoadGlobal { namei } => { + let name = namei.get(OpArg::new(u32::from(u8::from(unit.arg)))) >> 1; + f.names[usize::try_from(name).unwrap()].as_str() == "locals" + } + _ => false, + }) + .expect("missing locals load"); + let pattern_prefix = &instructions[..locals_pos]; + let false_jumps = pattern_prefix + .iter() + .filter(|unit| matches!(unit.op, Instruction::PopJumpIfFalse { .. })) + .count(); + let success_jumps = pattern_prefix + .iter() + .filter(|unit| matches!(unit.op, Instruction::JumpForward { .. 
})) + .count(); + assert_eq!( + false_jumps, 3, + "CPython codegen_pattern_or() keeps each alternative as false-jump plus success jump; got prefix={pattern_prefix:?}" + ); + assert_eq!( + success_jumps, 3, + "CPython codegen_pattern_or() keeps explicit success JUMPs for all alternatives; got prefix={pattern_prefix:?}" + ); + } + + #[test] + fn with_protected_generator_tail_after_cleanup_uses_strong_loads() { let code = compile_exec( r#" def f(scandir, fspath, path, reversed, top, OSError, topdown=True, followlinks=False): @@ -32735,7 +39949,7 @@ def f(scandir, fspath, path, reversed, top, OSError, topdown=True, followlinks=F } #[test] - fn test_yield_from_finally_cleanup_keeps_normal_path_borrows() { + fn yield_from_finally_cleanup_keeps_normal_path_borrows() { let code = compile_exec( r#" def f(_fwalk, stack, isbytes, topdown, onerror, follow_symlinks, close): diff --git a/crates/codegen/src/error.rs b/crates/codegen/src/error.rs index 086f9dfd739..fb848354e86 100644 --- a/crates/codegen/src/error.rs +++ b/crates/codegen/src/error.rs @@ -38,16 +38,20 @@ impl fmt::Display for CodegenError { #[derive(Debug)] #[non_exhaustive] pub enum InternalError { - StackOverflow, StackUnderflow, + InconsistentStackDepth, + InvalidStackEffect, + MalformedControlFlowGraph, MissingSymbol(String), } impl Display for InternalError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::StackOverflow => write!(f, "stack overflow"), - Self::StackUnderflow => write!(f, "stack underflow"), + Self::StackUnderflow => write!(f, "Invalid CFG, stack underflow"), + Self::InconsistentStackDepth => write!(f, "Invalid CFG, inconsistent stackdepth"), + Self::InvalidStackEffect => write!(f, "Invalid stack effect"), + Self::MalformedControlFlowGraph => write!(f, "malformed control flow graph."), Self::MissingSymbol(s) => write!( f, "The symbol '{s}' must be present in the symbol table, even when it is undefined in python." 
@@ -101,74 +105,73 @@ impl core::error::Error for CodegenErrorType {} impl fmt::Display for CodegenErrorType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use CodegenErrorType::*; match self { - Assign(target) => write!(f, "cannot assign to {target}"), - Delete(target) => write!(f, "cannot delete {target}"), - SyntaxError(err) => write!(f, "{}", err.as_str()), - MultipleStarArgs => { + Self::Assign(target) => write!(f, "cannot assign to {target}"), + Self::Delete(target) => write!(f, "cannot delete {target}"), + Self::SyntaxError(err) => write!(f, "{}", err.as_str()), + Self::MultipleStarArgs => { write!(f, "multiple starred expressions in assignment") } - InvalidStarExpr => write!(f, "can't use starred expression here"), - InvalidBreak => write!(f, "'break' outside loop"), - InvalidContinue => write!(f, "'continue' outside loop"), - InvalidReturn => write!(f, "'return' outside function"), - InvalidYield => write!(f, "'yield' outside function"), - InvalidYieldFrom => write!(f, "'yield from' outside function"), - InvalidAwait => write!(f, "'await' outside async function"), - InvalidAsyncFor => write!(f, "'async for' outside async function"), - InvalidAsyncWith => write!(f, "'async with' outside async function"), - InvalidAsyncComprehension => { + Self::InvalidStarExpr => write!(f, "can't use starred expression here"), + Self::InvalidBreak => write!(f, "'break' outside loop"), + Self::InvalidContinue => write!(f, "'continue' not properly in loop"), + Self::InvalidReturn => write!(f, "'return' outside function"), + Self::InvalidYield => write!(f, "'yield' outside function"), + Self::InvalidYieldFrom => write!(f, "'yield from' outside function"), + Self::InvalidAwait => write!(f, "'await' outside async function"), + Self::InvalidAsyncFor => write!(f, "'async for' outside async function"), + Self::InvalidAsyncWith => write!(f, "'async with' outside async function"), + Self::InvalidAsyncComprehension => { write!( f, "asynchronous comprehension outside of 
an asynchronous function" ) } - AsyncYieldFrom => write!(f, "'yield from' inside async function"), - AsyncReturnValue => { + Self::AsyncYieldFrom => write!(f, "'yield from' inside async function"), + Self::AsyncReturnValue => { write!(f, "'return' with value inside async generator") } - InvalidFuturePlacement => write!( + Self::InvalidFuturePlacement => write!( f, "from __future__ imports must occur at the beginning of the file" ), - InvalidFutureFeature(feat) => { + Self::InvalidFutureFeature(feat) => { write!(f, "future feature {feat} is not defined") } - FunctionImportStar => { + Self::FunctionImportStar => { write!(f, "import * only allowed at module level") } - TooManyStarUnpack => { + Self::TooManyStarUnpack => { write!(f, "too many expressions in star-unpacking assignment") } - EmptyWithItems => { + Self::EmptyWithItems => { write!(f, "empty items on With") } - EmptyWithBody => { + Self::EmptyWithBody => { write!(f, "empty body on With") } - ForbiddenName => { + Self::ForbiddenName => { write!(f, "forbidden attribute name") } - DuplicateStore(s) => { + Self::DuplicateStore(s) => { write!(f, "duplicate store {s}") } - UnreachablePattern(reason) => { + Self::UnreachablePattern(reason) => { write!(f, "{reason} makes remaining patterns unreachable") } - RepeatedAttributePattern => { + Self::RepeatedAttributePattern => { write!(f, "attribute name repeated in class pattern") } - ConflictingNameBindPattern => { + Self::ConflictingNameBindPattern => { write!(f, "alternative patterns bind different names") } - BreakContinueReturnInExceptStar => { + Self::BreakContinueReturnInExceptStar => { write!( f, "'break', 'continue' and 'return' cannot appear in an except* block" ) } - NotImplementedYet => { + Self::NotImplementedYet => { write!(f, "RustPython does not implement this feature yet") } } diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index 231da704f6f..f22efe8e52d 100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -1,4 +1,3 
@@ -use alloc::collections::VecDeque; use core::ops; use crate::{IndexMap, IndexSet, error::InternalError}; @@ -10,10 +9,10 @@ use rustpython_wtf8::Wtf8Buf; use rustpython_compiler_core::{ OneIndexed, SourceLocation, bytecode::{ - AnyInstruction, AnyOpcode, Arg, CO_FAST_CELL, CO_FAST_FREE, CO_FAST_HIDDEN, CO_FAST_LOCAL, - CodeFlags, CodeObject, CodeUnit, CodeUnits, ConstantData, ExceptionTableEntry, - InstrDisplayContext, Instruction, IntrinsicFunction1, Label, OpArg, Opcode, - PseudoInstruction, PseudoOpcode, PyCodeLocationInfoKind, encode_exception_table, oparg, + AnyInstruction, AnyOpcode, Arg, CO_FAST_ARG_KW, CO_FAST_ARG_POS, CO_FAST_ARG_VAR, + CO_FAST_CELL, CO_FAST_FREE, CO_FAST_HIDDEN, CO_FAST_LOCAL, CodeFlags, CodeObject, CodeUnit, + CodeUnits, ConstantData, InstrDisplayContext, Instruction, IntrinsicFunction1, OpArg, + OpArgByte, Opcode, PseudoInstruction, PseudoOpcode, PyCodeLocationInfoKind, oparg, }, varint::{write_signed_varint, write_varint}, }; @@ -27,20 +26,145 @@ struct LineTableLocation { end_col: i32, } -const MAX_INT_SIZE_BITS: u64 = 128; +#[derive(Clone, Copy)] +struct InstructionLocation { + location: SourceLocation, + end_location: SourceLocation, + lineno_override: Option, +} + +pub(crate) const LINE_ONLY_LOCATION_OVERRIDE: i32 = -4; +pub(crate) const NEXT_LOCATION_OVERRIDE: i32 = -2; +pub(crate) const NO_LOCATION_OVERRIDE: i32 = -1; + +const MAX_INT_SIZE: u64 = 128; const MAX_COLLECTION_SIZE: usize = 256; +const DEFAULT_CODE_SIZE: usize = 128; +const DEFAULT_LNOTAB_SIZE: usize = 16; +const DEFAULT_CNOTAB_SIZE: usize = 32; +const DEFAULT_BLOCK_SIZE: usize = 16; +const INITIAL_INSTR_SEQUENCE_SIZE: usize = 100; +const INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE: usize = 10; +const MAX_REAL_OPCODE: u16 = 254; +const MAX_OPCODE: u16 = 511; const MAX_TOTAL_ITEMS: isize = 1024; const MAX_STR_SIZE: usize = 4096; const MIN_CONST_SEQUENCE_SIZE: usize = 3; const STACK_USE_GUIDELINE: usize = 30; +/// pycore_opcode_utils.h IS_WITHIN_OPCODE_RANGE +fn 
is_within_opcode_range(opcode: AnyOpcode) -> bool { + match opcode { + AnyOpcode::Real(opcode) => u16::from(opcode.as_u8()) <= MAX_REAL_OPCODE, + AnyOpcode::Pseudo(opcode) => opcode.as_u16() <= MAX_OPCODE, + } +} + +#[derive(Clone, Debug, Default)] +pub struct ConstantPool { + constants: Vec, +} + +impl ConstantPool { + fn constant_contains_nan(constant: &ConstantData) -> bool { + match constant { + ConstantData::Float { value } => value.is_nan(), + ConstantData::Complex { value } => value.re.is_nan() || value.im.is_nan(), + ConstantData::Tuple { elements } | ConstantData::Frozenset { elements } => { + elements.iter().any(Self::constant_contains_nan) + } + ConstantData::Slice { elements } => elements.iter().any(Self::constant_contains_nan), + _ => false, + } + } + + pub fn insert_full(&mut self, constant: ConstantData) -> (usize, bool) { + // CPython's _PyCode_ConstantKey() keeps NaN-bearing constants distinct + // because Python-level NaN keys do not compare equal. + if !Self::constant_contains_nan(&constant) + && let Some(idx) = self + .constants + .iter() + .position(|existing| existing == &constant) + { + return (idx, false); + } + let idx = self.constants.len(); + self.constants.push(constant); + (idx, true) + } + + fn try_insert_full(&mut self, constant: ConstantData) -> crate::InternalResult<(usize, bool)> { + // CPython's _PyCode_ConstantKey() keeps NaN-bearing constants distinct + // because Python-level NaN keys do not compare equal. 
+ if !Self::constant_contains_nan(&constant) + && let Some(idx) = self + .constants + .iter() + .position(|existing| existing == &constant) + { + return Ok((idx, false)); + } + self.constants + .try_reserve_exact(1) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + let idx = self.constants.len(); + self.constants.push(constant); + Ok((idx, true)) + } + + pub fn insert(&mut self, constant: ConstantData) -> bool { + self.insert_full(constant).1 + } + + #[must_use] + pub fn get_index(&self, idx: usize) -> Option<&ConstantData> { + self.constants.get(idx) + } + + pub fn iter(&self) -> core::slice::Iter<'_, ConstantData> { + self.constants.iter() + } + + #[must_use] + pub fn len(&self) -> usize { + self.constants.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.constants.is_empty() + } + + pub fn clear(&mut self) { + self.constants.clear(); + } +} + +impl ops::Index for ConstantPool { + type Output = ConstantData; + + fn index(&self, idx: usize) -> &Self::Output { + &self.constants[idx] + } +} + +impl IntoIterator for ConstantPool { + type Item = ConstantData; + type IntoIter = alloc::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.constants.into_iter() + } +} + /// Metadata for a code unit // = _PyCompile_CodeUnitMetadata #[derive(Clone, Debug)] pub struct CodeUnitMetadata { pub name: String, // u_name (obj_name) pub qualname: Option, // u_qualname - pub consts: IndexSet, // u_consts + pub consts: ConstantPool, // u_consts pub names: IndexSet, // u_names pub varnames: IndexSet, // u_varnames pub cellvars: IndexSet, // u_cellvars @@ -115,24 +239,8 @@ pub struct InstructionInfo { pub location: SourceLocation, pub end_location: SourceLocation, pub except_handler: Option, - pub folded_from_nonliteral_expr: bool, /// Override line number for linetable (e.g., line 0 for module RESUME) pub lineno_override: Option, - /// Number of CACHE code units emitted after this instruction - pub cache_entries: u32, - /// Preserve a redundant 
jump until final emission so a zero-width jump - /// materializes as a line-marker NOP, matching CPython's late CFG shape. - pub preserve_redundant_jump_as_nop: bool, - /// Drop this NOP before line propagation if it still has no location. - pub remove_no_location_nop: bool, - /// This no-location NOP was created by CPython-style nop_out() folding. - pub folded_operand_nop: bool, - /// This instruction was emitted as part of a synthetic no-location exit. - pub no_location_exit: bool, - /// Keep this no-location NOP until line propagation when it starts a block. - pub preserve_block_start_no_location_nop: bool, - /// This is the success jump emitted after a non-final match case body. - pub match_success_jump: bool, } /// Exception handler information for an instruction. @@ -140,18946 +248,7389 @@ pub struct InstructionInfo { pub struct ExceptHandlerInfo { /// Block to jump to when exception occurs pub handler_block: BlockIdx, - /// Stack depth at handler entry - pub stack_depth: u32, /// Whether to push lasti before exception pub preserve_lasti: bool, } -fn set_to_nop(info: &mut InstructionInfo) { - info.instr = Instruction::Nop.into(); +/// flowgraph.c INSTR_SET_OP0 +fn instr_set_op0(info: &mut InstructionInfo, instr: AnyInstruction) { + debug_assert!(!AnyOpcode::from(instr).has_arg()); + info.instr = instr; info.arg = OpArg::new(0); - info.target = BlockIdx::NULL; - info.folded_from_nonliteral_expr = false; - info.cache_entries = 0; - info.preserve_redundant_jump_as_nop = false; - info.remove_no_location_nop = false; - info.folded_operand_nop = false; - info.no_location_exit = false; - info.preserve_block_start_no_location_nop = false; - info.match_success_jump = false; } -fn nop_out_no_location(info: &mut InstructionInfo) { - set_to_nop(info); - info.lineno_override = Some(-1); - info.remove_no_location_nop = true; - info.folded_operand_nop = true; +/// flowgraph.c INSTR_SET_OP1 +fn instr_set_op1(info: &mut InstructionInfo, instr: AnyInstruction, arg: OpArg) { + 
debug_assert!(AnyOpcode::from(instr).has_arg()); + info.instr = instr; + info.arg = arg; } -fn is_named_except_cleanup_normal_exit_block(block: &Block) -> bool { - let len = block.instructions.len(); - if len < 5 { - return false; - } - let tail = &block.instructions[len - 5..]; - matches!(tail[0].instr.real(), Some(Instruction::PopExcept)) - && matches!(tail[1].instr.real(), Some(Instruction::LoadConst { .. })) - && matches!( - tail[2].instr.real(), - Some(Instruction::StoreName { .. } | Instruction::StoreFast { .. }) - ) - && matches!( - tail[3].instr.real(), - Some(Instruction::DeleteName { .. } | Instruction::DeleteFast { .. }) - ) - && tail[4].instr.is_unconditional_jump() +/// flowgraph.c INSTR_SET_LOC +fn instr_set_loc( + info: &mut InstructionInfo, + location: SourceLocation, + end_location: SourceLocation, + lineno_override: Option, +) { + info.location = location; + info.end_location = end_location; + info.lineno_override = lineno_override; } -fn is_standalone_named_except_cleanup_normal_exit_block(block: &Block) -> bool { - let len = block.instructions.len(); - if len < 5 || !is_named_except_cleanup_normal_exit_block(block) { - return false; +fn instr_location(info: &InstructionInfo) -> InstructionLocation { + InstructionLocation { + location: info.location, + end_location: info.end_location, + lineno_override: info.lineno_override, } - block.instructions[..len - 5].iter().all(|info| { - matches!( - info.instr.real(), - Some(Instruction::Nop | Instruction::NotTaken) - ) - }) } -fn named_except_cleanup_body_is_fast_local_only(block: &Block) -> bool { - let len = block.instructions.len(); - if len < 5 || !is_named_except_cleanup_normal_exit_block(block) { - return false; - } - block.instructions[..len - 5].iter().all(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Nop - | Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::StoreFast { .. 
} - ) - ) - }) +fn instr_set_location(info: &mut InstructionInfo, loc: InstructionLocation) { + instr_set_loc(info, loc.location, loc.end_location, loc.lineno_override); } -// spell-checker:ignore petgraph -// TODO: look into using petgraph for handling blocks and stuff? it's heavier than this, but it -// might enable more analysis/optimizations -#[derive(Debug, Clone)] -pub struct Block { - pub instructions: Vec, - pub next: BlockIdx, - // Post-codegen analysis fields (set by label_exception_targets) - /// Whether this block is an exception handler target (b_except_handler) - pub except_handler: bool, - /// Whether to preserve lasti for this handler block (b_preserve_lasti) - pub preserve_lasti: bool, - /// Stack depth at block entry, set by stack depth analysis - pub start_depth: Option, - /// Whether this block is only reachable via exception table (b_cold) - pub cold: bool, - /// Whether LOAD_FAST borrow optimization should be suppressed for this block. - pub disable_load_fast_borrow: bool, - /// Entry block for a try-else orelse suite split after POP_BLOCK. 
- pub try_else_orelse_entry: bool, +fn no_instruction_location() -> InstructionLocation { + InstructionLocation { + location: SourceLocation::default(), + end_location: SourceLocation::default(), + lineno_override: Some(NO_LOCATION_OVERRIDE), + } } -impl Default for Block { - fn default() -> Self { - Self { - instructions: Vec::new(), - next: BlockIdx::NULL, - except_handler: false, - preserve_lasti: false, - start_depth: None, - cold: false, - disable_load_fast_borrow: false, - try_else_orelse_entry: false, +fn set_to_nop(info: &mut InstructionInfo) { + instr_set_op0(info, Instruction::Nop.into()); +} + +fn nop_out_no_location(info: &mut InstructionInfo) { + set_to_nop(info); + instr_set_loc( + info, + SourceLocation::default(), + SourceLocation::default(), + Some(NO_LOCATION_OVERRIDE), + ); +} + +fn empty_instruction_info() -> InstructionInfo { + InstructionInfo { + instr: Instruction::Nop.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: None, + } +} + +/// codegen.c _Py_CArray_EnsureCapacity +fn c_array_ensure_capacity( + allocated_entries: usize, + idx: usize, + initial_num_entries: usize, +) -> crate::InternalResult { + if allocated_entries == 0 { + let new_alloc = if idx >= initial_num_entries { + idx.checked_add(initial_num_entries) + .ok_or(InternalError::MalformedControlFlowGraph)? + } else { + initial_num_entries + }; + Ok(new_alloc) + } else if idx >= allocated_entries { + let oldsize = allocated_entries + .checked_mul(core::mem::size_of::()) + .ok_or(InternalError::MalformedControlFlowGraph)?; + let doubled = allocated_entries + .checked_mul(2) + .ok_or(InternalError::MalformedControlFlowGraph)?; + let new_alloc = if idx >= doubled { + idx.checked_add(initial_num_entries) + .ok_or(InternalError::MalformedControlFlowGraph)? 
+ } else { + doubled + }; + let newsize = new_alloc + .checked_mul(core::mem::size_of::()) + .ok_or(InternalError::MalformedControlFlowGraph)?; + if oldsize > usize::MAX >> 1 || newsize == 0 { + return Err(InternalError::MalformedControlFlowGraph); } + Ok(new_alloc) + } else { + Ok(allocated_entries) } } -pub struct CodeInfo { - pub flags: CodeFlags, - pub source_path: String, - pub private: Option, // For private name mangling, mostly for class - - pub blocks: Vec, - pub current_block: BlockIdx, - pub annotations_blocks: Option>, +/// flowgraph.c basicblock_next_instr +fn basicblock_next_instr(block: &mut Block) -> crate::InternalResult { + let off = block.instruction_used; + let new_allocation = c_array_ensure_capacity::( + block.instruction_allocation, + off + 1, + DEFAULT_BLOCK_SIZE, + )?; + if new_allocation > block.instruction_allocation { + if new_allocation > block.instructions.len() { + block + .instructions + .try_reserve_exact(new_allocation - block.instructions.len()) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + block + .instructions + .resize_with(new_allocation, empty_instruction_info); + } + block.instruction_allocation = new_allocation; + } + debug_assert!(block.instruction_allocation > off); + block.instruction_used += 1; + Ok(off) +} - pub metadata: CodeUnitMetadata, +/// flowgraph.c basicblock_last_instr +fn basicblock_last_instr(block: &Block) -> Option<&InstructionInfo> { + debug_assert!(block.instruction_allocation >= block.instruction_used); + if block.instruction_used > 0 { + debug_assert!(!block.instructions.is_empty()); + Some(&block.instructions[block.instruction_used - 1]) + } else { + None + } +} - // For class scopes: attributes accessed via self.X - pub static_attributes: Option>, +/// flowgraph.c basicblock_last_instr +fn basicblock_last_instr_mut(block: &mut Block) -> Option<&mut InstructionInfo> { + debug_assert!(block.instruction_allocation >= block.instruction_used); + if block.instruction_used > 0 { + 
debug_assert!(!block.instructions.is_empty()); + Some(&mut block.instructions[block.instruction_used - 1]) + } else { + None + } +} - // True if compiling an inlined comprehension - pub in_inlined_comp: bool, +/// flowgraph.c basicblock_addop +fn basicblock_addop(block: &mut Block, mut info: InstructionInfo) -> crate::InternalResult<()> { + let opcode = AnyOpcode::from(info.instr); + debug_assert!(is_within_opcode_range(opcode)); + debug_assert!(!info.instr.is_assembler()); + debug_assert!( + info.instr.has_arg() || info.instr.has_target() || u32::from(info.arg) == 0, + "CPython basicblock_addop requires OPCODE_HAS_ARG, HAS_TARGET, or oparg == 0" + ); + debug_assert!( + u32::from(info.arg) < (1 << 30), + "CPython basicblock_addop requires 0 <= oparg < (1 << 30)" + ); + let off = basicblock_next_instr(block)?; + let except_handler = block.instructions[off].except_handler; + info.target = BlockIdx::NULL; + info.except_handler = except_handler; + block.instructions[off] = info; + Ok(()) +} - // Block stack for tracking nested control structures - pub fblock: Vec, +/// flowgraph.c basicblock_insert_instruction +fn basicblock_insert_instruction( + block: &mut Block, + pos: usize, + info: InstructionInfo, +) -> crate::InternalResult<()> { + let old_len = block.instruction_used; + debug_assert!(pos <= old_len); + basicblock_next_instr(block)?; + for i in (pos + 1..=old_len).rev() { + block.instructions[i] = block.instructions[i - 1]; + } + block.instructions[pos] = info; + Ok(()) +} - // Reference to the symbol table for this scope - pub symbol_table_index: usize, +/// flowgraph.c basicblock_append_instructions +fn basicblock_append_block_instructions( + blocks: &mut [Block], + to: BlockIdx, + from: BlockIdx, +) -> crate::InternalResult<()> { + debug_assert_ne!(to, from); + let from_len = blocks[from.idx()].instruction_used; + for i in 0..from_len { + let info = blocks[from.idx()].instructions[i]; + let off = basicblock_next_instr(&mut blocks[to.idx()])?; + 
blocks[to.idx()].instructions[off] = info; + } + Ok(()) +} - // PEP 649: Track nesting depth inside conditional blocks (if/for/while/etc.) - // u_in_conditional_block - pub in_conditional_block: u32, +/// flowgraph.c direct `b_iused = 0` +fn basicblock_clear(block: &mut Block) { + block.instruction_used = 0; +} - // Track when compiling the final direct statement in a sync with body. - pub in_final_with_cleanup_statement: u32, +/// CPython direct `b_instr[0]` access. Some passes set `b_iused = 0` +/// without clearing the backing array, so an empty basic block can still have +/// a first raw instruction slot. +fn basicblock_raw_first_instr_mut(block: &mut Block) -> &mut InstructionInfo { + debug_assert!(block.instruction_allocation > 0); + &mut block.instructions[0] +} - // Track when compiling the orelse suite of try/except. - pub in_try_else_orelse: u32, +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) struct InstructionSequenceLabel(i32); - // PEP 649: Next index for conditional annotation tracking - // u_next_conditional_annotation_index - pub next_conditional_annotation_index: u32, +/// flowgraph.c SAME_LABEL +fn same_label(a: InstructionSequenceLabel, b: InstructionSequenceLabel) -> bool { + a == b } -impl CodeInfo { - pub fn finalize_code( - mut self, - opts: &crate::compile::CompileOpts, - ) -> crate::InternalResult { - self.splice_annotations_blocks(); - // Constant folding passes - self.fold_binop_constants(); - self.fold_unary_constants(); - self.fold_binop_constants(); // re-run after unary folding: -1 + 2 → 1 - self.fold_unary_constants(); // re-run after binop folding: -(2.0 ** -54) → const - self.fold_tuple_constants(); - self.fold_binop_constants(); // re-run after tuple folding: (1,) * 2 → const - self.fold_list_constants(); - self.fold_set_constants(); - self.optimize_lists_and_sets(); - self.convert_to_load_small_int(); - self.remove_unused_consts(); - - // DCE always runs (removes dead code after terminal instructions) - self.dce(); - 
// BUILD_TUPLE n + UNPACK_SEQUENCE n → NOP + SWAP (n=2,3) or NOP+NOP (n=1) - self.optimize_build_tuple_unpack(); - // Dead store elimination for duplicate STORE_FAST targets - // (apply_static_swaps in CPython's flowgraph.c) - self.eliminate_dead_stores(); - // apply_static_swaps: reorder stores to eliminate SWAPs - self.apply_static_swaps(); - // Peephole optimizer handles constant and compare folding. - self.peephole_optimize(); - self.fold_tuple_constants(); - self.fold_binop_constants(); - self.fold_list_constants(); - self.fold_set_constants(); - self.optimize_lists_and_sets(); - self.convert_to_load_small_int(); - self.remove_unused_consts(); - self.dce(); - - // Phase 1: _PyCfg_OptimizeCodeUnit (flowgraph.c) - // Split blocks so each block has at most one branch as its last instruction - split_blocks_at_jumps(&mut self.blocks); - mark_except_handlers(&mut self.blocks); - label_exception_targets(&mut self.blocks); - // CPython's CFG builder does not leave empty unconditional-jump targets - // in front of small exit blocks. Redirect only unconditional jumps - // here so inline_small_or_no_lineno_blocks() can see direct exit - // targets without erasing conditional target NOP anchors. - redirect_empty_unconditional_jump_targets(&mut self.blocks); - // CPython optimize_cfg starts by inlining tiny exit/no-lineno blocks - // before unreachable elimination and later jump cleanup. - inline_small_or_no_lineno_blocks(&mut self.blocks); - // optimize_cfg: jump threading (before push_cold_blocks_to_end) - jump_threading(&mut self.blocks); - self.eliminate_unreachable_blocks(); - self.remove_nops(); - self.add_checks_for_loads_of_uninitialized_variables(); - // CPython inserts superinstructions in _PyCfg_OptimizeCodeUnit, before - // later jump normalization / block reordering can create adjacencies - // that never exist at this stage in flowgraph.c. 
- self.insert_superinstructions(); - // CPython resolves line numbers once before cold-block extraction and - // again after reordering blocks. - resolve_line_numbers(&mut self.blocks); - inline_single_predecessor_artificial_expr_exit_blocks(&mut self.blocks); - push_cold_blocks_to_end(&mut self.blocks); - reorder_conditional_chain_and_jump_back_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, true, true); - - // Phase 2: _PyCfg_OptimizedCfgToInstructionSequence (flowgraph.c) - normalize_jumps(&mut self.blocks); - reorder_conditional_exit_and_jump_blocks(&mut self.blocks); - reorder_conditional_jump_and_exit_blocks(&mut self.blocks); - reorder_conditional_break_continue_blocks(&mut self.blocks); - reorder_conditional_explicit_continue_scope_exit_blocks(&mut self.blocks); - reorder_conditional_implicit_continue_scope_exit_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, true, true); - reorder_exception_handler_conditional_continue_scope_exit_blocks(&mut self.blocks); - deduplicate_adjacent_jump_back_blocks(&mut self.blocks); - reorder_conditional_body_and_implicit_continue_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, true, true); - reorder_jump_over_exception_cleanup_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, false, true); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, false, true); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, false, false); - self.dce(); // re-run within-block DCE after normalize_jumps creates new instructions - self.eliminate_unreachable_blocks(); - resolve_line_numbers(&mut self.blocks); - materialize_empty_conditional_exit_targets(&mut self.blocks); - redirect_empty_block_targets(&mut self.blocks); - inline_small_fast_return_blocks(&mut self.blocks); - 
inline_unprotected_tuple_genexpr_assignment_return_blocks(&mut self.blocks); - duplicate_end_returns(&mut self.blocks, &self.metadata); - duplicate_fallthrough_jump_back_targets(&mut self.blocks); - duplicate_shared_jump_back_targets(&mut self.blocks); - self.dce(); // truncate after terminal in blocks that got return duplicated - self.eliminate_unreachable_blocks(); // remove now-unreachable last block - self.remove_redundant_const_pop_top_pairs(); - remove_redundant_nops_and_jumps(&mut self.blocks); - // Some jump-only blocks only appear after late CFG cleanup. Thread them - // once more so loop backedges stay direct instead of becoming - // JUMP_FORWARD -> JUMP_BACKWARD chains. - jump_threading_unconditional(&mut self.blocks); - reorder_jump_over_exception_cleanup_blocks(&mut self.blocks); - self.eliminate_unreachable_blocks(); - remove_redundant_nops_and_jumps(&mut self.blocks); - inline_with_suppress_return_blocks(&mut self.blocks); - inline_pop_except_return_blocks(&mut self.blocks); - inline_named_except_cleanup_normal_exit_jumps(&mut self.blocks); - duplicate_named_except_cleanup_returns(&mut self.blocks, &self.metadata); - self.eliminate_unreachable_blocks(); - resolve_line_numbers(&mut self.blocks); - let cellfixedoffsets = build_cellfixedoffsets( - &self.metadata.varnames, - &self.metadata.cellvars, - &self.metadata.freevars, - ); - // Late CFG cleanup can create or reshuffle handler entry blocks. - // Refresh exceptional block flags before optimize_load_fast_borrow so - // borrow loads are not introduced into exception-handler paths. - mark_except_handlers(&mut self.blocks); - redirect_empty_block_targets(&mut self.blocks); - // CPython's optimize_load_fast runs with block start depths already known. - // Compute them here so the abstract stack simulation can use the real - // CFG entry depth for each block. 
- let max_stackdepth = self.max_stackdepth()?; - // Match CPython order: pseudo ops are lowered after stackdepth - // calculation, then redundant NOPs from pseudo lowering are removed - // before optimize_load_fast. - convert_pseudo_ops(&mut self.blocks, &cellfixedoffsets); - remove_redundant_nops_and_jumps(&mut self.blocks); - self.mark_unprotected_debug_four_tails_borrow_disabled(); - self.mark_exception_handler_transition_targets_borrow_disabled(); - self.mark_targeted_nop_for_tails_borrow_disabled(); - self.restore_conditional_exception_for_iter_join_borrows(); - self.compute_load_fast_start_depths(); - // optimize_load_fast: after normalize_jumps - self.optimize_load_fast_borrow(); - self.deoptimize_borrow_in_targeted_assert_message_blocks(); - self.deoptimize_borrow_for_folded_nonliteral_exprs(); - self.deoptimize_borrow_after_generator_exception_return(); - self.deoptimize_borrow_after_async_for_cleanup_resume(); - self.deoptimize_borrow_after_multi_handler_resume_join(); - self.deoptimize_borrow_after_named_except_cleanup_join(); - self.deoptimize_borrow_after_reraising_except_handler(); - self.deoptimize_borrow_in_protected_conditional_tail(); - self.deoptimize_borrow_after_terminal_except_tail(); - self.deoptimize_borrow_after_except_star_try_tail(); - self.deoptimize_borrow_in_protected_method_call_after_terminal_except_tail(); - self.deoptimize_borrow_after_terminal_except_before_with(); - self.deoptimize_borrow_after_handler_resume_loop_tail(); - self.deoptimize_borrow_after_protected_import(); - self.deoptimize_borrow_before_import_after_join_store(); - self.deoptimize_borrow_after_protected_store_tail(); - self.deoptimize_borrow_after_deoptimized_async_with_enter(); - self.deoptimize_borrow_for_handler_return_paths(); - self.deoptimize_borrow_for_match_keys_attr(); - self.deoptimize_borrow_in_protected_attr_chain_tail(); - self.reborrow_after_suppressing_handler_resume_cleanup(); - self.deoptimize_store_fast_store_fast_after_cleanup(); - 
self.apply_static_swaps(); - self.deoptimize_store_fast_store_fast_after_cleanup(); - self.optimize_load_global_push_null(); - self.reorder_entry_prefix_cell_setup(); - self.remove_unused_consts(); +/// flowgraph.c IS_LABEL +fn is_label(label: InstructionSequenceLabel) -> bool { + !same_label(label, InstructionSequenceLabel::NO_LABEL) +} - let Self { - flags, - source_path, - private: _, // private is only used during compilation +impl InstructionSequenceLabel { + pub(crate) const NO_LABEL: Self = Self(-1); - mut blocks, - current_block: _, - annotations_blocks: _, - metadata, - static_attributes: _, - in_inlined_comp: _, - fblock: _, - symbol_table_index: _, - in_conditional_block: _, - in_final_with_cleanup_statement: _, - in_try_else_orelse: _, - next_conditional_annotation_index: _, - } = self; + pub(crate) fn from_index(index: i32) -> Self { + Self(index) + } - let CodeUnitMetadata { - name: obj_name, - qualname, - consts: constants, - names: name_cache, - varnames: varname_cache, - cellvars: cellvar_cache, - freevars: freevar_cache, - fast_hidden, - fast_hidden_final, - argcount: arg_count, - posonlyargcount: posonlyarg_count, - kwonlyargcount: kwonlyarg_count, - firstlineno: first_line_number, - } = metadata; + pub(crate) fn is_jump_target_label(self) -> bool { + is_label(self) + } - let mut instructions = Vec::new(); - let mut locations = Vec::new(); - let mut linetable_locations: Vec = Vec::new(); - - // Build cellfixedoffsets for cell-local merging - let cellfixedoffsets = - build_cellfixedoffsets(&varname_cache, &cellvar_cache, &freevar_cache); - // Convert pseudo ops (LoadClosure uses cellfixedoffsets) and fixup DEREF opargs - convert_pseudo_ops(&mut blocks, &cellfixedoffsets); - fixup_deref_opargs(&mut blocks, &cellfixedoffsets); - // Remove redundant NOPs, keeping line-marker NOPs only when - // they are needed to preserve tracing. 
- let mut block_order = Vec::new(); - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - block_order.push(current); - current = blocks[current.idx()].next; - } - for block_idx in block_order { - let bi = block_idx.idx(); - let mut src_instructions = core::mem::take(&mut blocks[bi].instructions); - let mut kept = Vec::with_capacity(src_instructions.len()); - let mut prev_lineno = -1i32; - - for src in 0..src_instructions.len() { - let instr = src_instructions[src]; - let lineno = instr - .lineno_override - .unwrap_or_else(|| instr.location.line.get() as i32); - let mut remove = false; - - if matches!(instr.instr.real(), Some(Instruction::Nop)) { - if instr.preserve_redundant_jump_as_nop { - remove = false; - } else if lineno < 0 || prev_lineno == lineno { - remove = true; - } else if src < src_instructions.len() - 1 { - if src_instructions[src + 1].instr.is_block_push() { - remove = false; - } else if src_instructions[src + 1].folded_from_nonliteral_expr { - remove = true; - } else { - let next_lineno = src_instructions[src + 1] - .lineno_override - .unwrap_or_else(|| { - src_instructions[src + 1].location.line.get() as i32 - }); - if next_lineno == lineno { - remove = true; - } else if next_lineno < 0 { - src_instructions[src + 1].lineno_override = Some(lineno); - remove = true; - } - } - } else { - let mut next = blocks[bi].next; - while next != BlockIdx::NULL && blocks[next.idx()].instructions.is_empty() { - next = blocks[next.idx()].next; - } - if next != BlockIdx::NULL { - let mut next_lineno = None; - for next_instr in &blocks[next.idx()].instructions { - let line = next_instr - .lineno_override - .unwrap_or_else(|| next_instr.location.line.get() as i32); - if matches!(next_instr.instr.real(), Some(Instruction::Nop)) - && line < 0 - { - continue; - } - next_lineno = Some(line); - break; - } - if next_lineno.is_some_and(|line| line == lineno) { - remove = true; - } - } - } - } + pub(crate) fn idx(self) -> usize { + debug_assert!(self.0 >= 0); 
+ self.0 as usize + } +} - if !remove { - kept.push(instr); - prev_lineno = lineno; - } - } +#[derive(Clone, Copy)] +struct InstructionSequenceExceptHandlerInfo { + h_label: i32, + start_depth: i32, + preserve_lasti: i32, +} - blocks[bi].instructions = kept; - } +const NO_EXCEPTION_HANDLER_LABEL: i32 = -1; +const ZERO_EXCEPTION_HANDLER_INFO: InstructionSequenceExceptHandlerInfo = + InstructionSequenceExceptHandlerInfo { + h_label: 0, + start_depth: 0, + preserve_lasti: 0, + }; - // Final DCE: truncate instructions after terminal ops in linearized blocks. - // This catches dead code created by normalize_jumps after the initial DCE. - for block in &mut blocks { - if let Some(pos) = block - .instructions - .iter() - .position(|ins| ins.instr.is_scope_exit() || ins.instr.is_unconditional_jump()) - { - block.instructions.truncate(pos + 1); - } - } +#[derive(Clone, Copy)] +struct InstructionSequenceEntry { + info: InstructionInfo, + except_handler: InstructionSequenceExceptHandlerInfo, + i_target: i32, + i_offset: i32, +} - // Pre-compute cache_entries for real (non-pseudo) instructions - for block in &mut blocks { - for instr in &mut block.instructions { - if let AnyInstruction::Real(op) = instr.instr { - instr.cache_entries = op.cache_entries() as u32; - } - } +impl InstructionSequenceEntry { + fn new(info: InstructionInfo, except_handler: InstructionSequenceExceptHandlerInfo) -> Self { + Self { + info, + except_handler, + i_target: 0, + i_offset: 0, + } + } +} + +const INSTRUCTION_SEQUENCE_UNSET_LABEL: i32 = -111; + +#[derive(Clone)] +pub(crate) struct InstructionSequence { + /// CPython `instr_sequence.s_instrs`, including allocated slots beyond `s_used`. + instrs: Vec, + /// CPython `instr_sequence.s_allocated`, the allocated size of `s_instrs`. + instr_allocation: usize, + /// CPython `instr_sequence.s_used`, the number of used entries in `s_instrs`. + instr_used: usize, + /// CPython `instr_sequence.s_next_free_label`. 
+ next_free_label: i32, + label_map: Option>, + label_map_allocation: usize, + annotations_code: Option>, +} + +impl InstructionSequence { + pub(crate) fn new() -> Self { + instruction_sequence_new() + } +} + +/// instruction_sequence.c _PyInstructionSequence_New / inst_seq_create +fn instruction_sequence_new() -> InstructionSequence { + InstructionSequence { + instrs: Vec::new(), + instr_allocation: 0, + instr_used: 0, + next_free_label: 0, + label_map: None, + label_map_allocation: 0, + annotations_code: None, + } +} + +/// instruction_sequence.c instr_sequence_next_inst +fn instruction_sequence_next_inst(seq: &mut InstructionSequence) -> crate::InternalResult { + debug_assert!(!seq.instrs.is_empty() || seq.instr_used == 0); + let idx = seq.instr_used; + let new_allocation = c_array_ensure_capacity::( + seq.instr_allocation, + idx + 1, + INITIAL_INSTR_SEQUENCE_SIZE, + )?; + if new_allocation > seq.instr_allocation { + if new_allocation > seq.instrs.capacity() { + seq.instrs + .try_reserve_exact(new_allocation - seq.instrs.capacity()) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + } + if new_allocation > seq.instrs.len() { + seq.instrs.resize( + new_allocation, + InstructionSequenceEntry::new( + InstructionInfo { + instr: Instruction::Cache.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: None, + }, + ZERO_EXCEPTION_HANDLER_INFO, + ), + ); } + seq.instr_allocation = new_allocation; + } + debug_assert!(seq.instr_allocation > idx); + seq.instr_used += 1; + Ok(idx) +} - let mut block_to_offset = vec![Label::from_u32(0); blocks.len()]; - // block_to_index: maps block idx to instruction index (for exception table) - // This is the index into the final instructions array, including EXTENDED_ARG and CACHE - let mut block_to_index = vec![0u32; blocks.len()]; - // The offset (in code units) of END_SEND from SEND in the yield-from 
sequence. - const END_SEND_OFFSET: u32 = 5; - loop { - let mut num_instructions = 0; - for (idx, block) in iter_blocks(&blocks) { - block_to_offset[idx.idx()] = Label::from_u32(num_instructions as u32); - // block_to_index uses the same value as block_to_offset but as u32 - // because lasti in frame.rs is the index into instructions array - // and instructions array index == byte offset (each instruction is 1 CodeUnit) - block_to_index[idx.idx()] = num_instructions as u32; - for instr in &block.instructions { - num_instructions += instr.arg.instr_size() + instr.cache_entries as usize; - } - } - - instructions.reserve_exact(num_instructions); - locations.reserve_exact(num_instructions); +/// instruction_sequence.c _PyInstructionSequence_NewLabel +fn instruction_sequence_new_label(seq: &mut InstructionSequence) -> InstructionSequenceLabel { + seq.next_free_label += 1; + InstructionSequenceLabel(seq.next_free_label) +} - let mut recompile = false; - let mut next_block = BlockIdx(0); - while next_block != BlockIdx::NULL { - let block = &mut blocks[next_block]; - // Track current instruction offset for jump direction resolution - let mut current_offset = block_to_offset[next_block.idx()].as_u32(); - for info in &mut block.instructions { - let target = info.target; - let mut op = info.instr.expect_real(); - let old_arg_size = info.arg.instr_size(); - let old_cache_entries = info.cache_entries; - // Keep offsets fixed within this pass: changes in jump - // arg/cache sizes only take effect in the next iteration. 
- let offset_after = current_offset + old_arg_size as u32 + old_cache_entries; - - if target != BlockIdx::NULL { - let target_offset = block_to_offset[target.idx()].as_u32(); - if info.instr.is_unconditional_jump() && target_offset == offset_after { - op = Opcode::Nop.into(); - info.instr = op.into(); - info.target = BlockIdx::NULL; - recompile = true; - let updated_cache = op.cache_entries() as u32; - recompile |= updated_cache != old_cache_entries; - info.cache_entries = updated_cache; - let new_arg = OpArg::NULL; - recompile |= new_arg.instr_size() != old_arg_size; - info.arg = new_arg; - } else { - // Direction must be based on concrete instruction offsets. - // Empty blocks can share offsets, so block-order-based resolution - // may classify some jumps incorrectly. - op = match op.into() { - Opcode::JumpForward if target_offset <= current_offset => { - Opcode::JumpBackward.into() - } - Opcode::JumpBackward if target_offset > current_offset => { - Opcode::JumpForward.into() - } - Opcode::JumpBackwardNoInterrupt - if target_offset > current_offset => - { - Opcode::JumpForward.into() - } - _ => op, - }; - info.instr = op.into(); - let updated_cache = op.cache_entries() as u32; - recompile |= updated_cache != old_cache_entries; - info.cache_entries = updated_cache; - let new_arg = if matches!(op, Instruction::EndAsyncFor) { - let arg = offset_after - .checked_sub(target_offset + END_SEND_OFFSET) - .expect("END_ASYNC_FOR target must be before instruction"); - OpArg::new(arg) - } else if matches!( - op.into(), - Opcode::JumpBackward | Opcode::JumpBackwardNoInterrupt - ) { - let arg = offset_after - .checked_sub(target_offset) - .expect("backward jump target must be before instruction"); - OpArg::new(arg) - } else { - let arg = target_offset - .checked_sub(offset_after) - .expect("forward jump target must be after instruction"); - OpArg::new(arg) - }; - recompile |= new_arg.instr_size() != old_arg_size; - info.arg = new_arg; - } - } +/// instruction_sequence.c 
_PyInstructionSequence_Addop asserts. +fn instruction_sequence_debug_check_addop(info: &InstructionInfo) { + let opcode = AnyOpcode::from(info.instr); + debug_assert!(is_within_opcode_range(opcode)); + debug_assert!( + opcode.has_arg() || info.instr.has_target() || u32::from(info.arg) == 0, + "CPython _PyInstructionSequence_Addop requires either OPCODE_HAS_ARG, HAS_TARGET, or oparg == 0" + ); + debug_assert!( + u32::from(info.arg) < (1 << 30), + "CPython _PyInstructionSequence_Addop requires 0 <= oparg < (1 << 30)" + ); +} - let cache_count = info.cache_entries as usize; - let (extras, lo_arg) = info.arg.split(); - let loc_pair = (info.location, info.end_location); - locations.extend(core::iter::repeat_n( - loc_pair, - info.arg.instr_size() + cache_count, - )); - // Collect linetable locations with lineno_override support - let lt_loc = LineTableLocation { - line: info - .lineno_override - .unwrap_or_else(|| info.location.line.get() as i32), - end_line: info.end_location.line.get() as i32, - col: info.location.character_offset.to_zero_indexed() as i32, - end_col: info.end_location.character_offset.to_zero_indexed() as i32, - }; - linetable_locations.extend(core::iter::repeat_n(lt_loc, info.arg.instr_size())); - // CACHE entries inherit parent instruction's location - if cache_count > 0 { - linetable_locations.extend(core::iter::repeat_n(lt_loc, cache_count)); - } - instructions.extend( - extras - .map(|byte| CodeUnit::new(Instruction::ExtendedArg, byte)) - .chain([CodeUnit { op, arg: lo_arg }]), - ); - // Emit CACHE code units after the instruction (all zeroed) - if cache_count > 0 { - instructions.extend(core::iter::repeat_n( - CodeUnit::new(Instruction::Cache, 0.into()), - cache_count, - )); - } - current_offset = offset_after; - } - next_block = block.next; +/// instruction_sequence.c _PyInstructionSequence_SetAnnotationsCode +fn instruction_sequence_set_annotations_code( + seq: &mut InstructionSequence, + annotations_code: Option>, +) { + 
debug_assert!(seq.annotations_code.is_none()); + seq.annotations_code = annotations_code; +} + +/// instruction_sequence.c _PyInstructionSequence_UseLabel +#[allow(clippy::needless_range_loop)] +fn instruction_sequence_use_label( + seq: &mut InstructionSequence, + label: InstructionSequenceLabel, +) -> crate::InternalResult<()> { + let old_size = seq.label_map_allocation; + let new_allocation = c_array_ensure_capacity::( + seq.label_map_allocation, + label.idx(), + INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE, + )?; + if new_allocation > seq.label_map_allocation { + if let Some(label_map) = &mut seq.label_map { + if new_allocation > label_map.capacity() { + label_map + .try_reserve_exact(new_allocation - label_map.capacity()) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; } + } else { + let mut label_map = Vec::new(); + label_map + .try_reserve_exact(new_allocation) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + seq.label_map = Some(label_map); + } + seq.label_map_allocation = new_allocation; + } + let label_map = seq + .label_map + .as_mut() + .ok_or(InternalError::MalformedControlFlowGraph)?; + if label_map.len() < seq.label_map_allocation { + label_map.resize(seq.label_map_allocation, INSTRUCTION_SEQUENCE_UNSET_LABEL); + } + for i in old_size..seq.label_map_allocation { + label_map[i] = INSTRUCTION_SEQUENCE_UNSET_LABEL; + } + label_map[label.idx()] = seq.instr_used as i32; + Ok(()) +} + +/// instruction_sequence.c _PyInstructionSequence_Addop +fn instruction_sequence_addop( + seq: &mut InstructionSequence, + info: InstructionInfo, +) -> crate::InternalResult<&mut InstructionSequenceEntry> { + instruction_sequence_debug_check_addop(&info); + let idx = instruction_sequence_next_inst(seq)?; + let entry = &mut seq.instrs[idx]; + entry.info = info; + Ok(entry) +} + +fn instruction_sequence_last_info_mut( + seq: &mut InstructionSequence, +) -> Option<&mut InstructionInfo> { + if seq.instr_used == 0 { + None + } else { + Some(&mut 
seq.instrs[seq.instr_used - 1].info) + } +} - if !recompile { - break; +/// instruction_sequence.c _PyInstructionSequence_InsertInstruction +#[allow(clippy::needless_range_loop)] +fn instruction_sequence_insert_instruction( + seq: &mut InstructionSequence, + pos: usize, + info: InstructionInfo, +) -> crate::InternalResult<()> { + debug_assert!(pos <= seq.instr_used); + let last_idx = instruction_sequence_next_inst(seq)?; + for i in (pos..last_idx).rev() { + seq.instrs[i + 1] = seq.instrs[i]; + } + seq.instrs[pos].info = info; + if let Some(label_map) = &mut seq.label_map { + let pos = pos as i32; + for lbl in 0..seq.label_map_allocation { + if label_map[lbl] >= pos { + label_map[lbl] += 1; } - - instructions.clear(); - locations.clear(); - linetable_locations.clear(); } + } + Ok(()) +} - // Generate linetable from linetable_locations (supports line 0 for RESUME) - let linetable = generate_linetable( - &linetable_locations, - first_line_number.get() as i32, - opts.debug_ranges, - ); - let locations = rustpython_compiler_core::marshal::linetable_to_locations( - &linetable, - first_line_number.get() as i32, - instructions.len(), - ); - - // Generate exception table before moving source_path - let exceptiontable = generate_exception_table(&blocks, &block_to_index); - - // Build localspluskinds with cell-local merging - let nlocals = varname_cache.len(); - let ncells = cellvar_cache.len(); - let nfrees = freevar_cache.len(); - let numdropped = cellvar_cache - .iter() - .filter(|cv| varname_cache.contains(cv.as_str())) - .count(); - let nlocalsplus = nlocals + ncells - numdropped + nfrees; - let mut localspluskinds = vec![0u8; nlocalsplus]; - // Mark locals - for kind in localspluskinds.iter_mut().take(nlocals) { - *kind = CO_FAST_LOCAL; - } - // Mark cells (merged and non-merged) - for (i, cellvar) in cellvar_cache.iter().enumerate() { - let idx = cellfixedoffsets[i] as usize; - if varname_cache.contains(cellvar.as_str()) { - localspluskinds[idx] |= CO_FAST_CELL; // 
merged: LOCAL | CELL - } else { - localspluskinds[idx] = CO_FAST_CELL; +/// instruction_sequence.c _PyInstructionSequence_ApplyLabelMap +#[allow(clippy::needless_range_loop, clippy::unnecessary_wraps)] +fn instruction_sequence_apply_label_map( + instrs: &mut InstructionSequence, +) -> crate::InternalResult<()> { + { + let Some(label_map) = instrs.label_map.as_ref() else { + return Ok(()); + }; + for i in 0..instrs.instr_used { + let entry = &mut instrs.instrs[i]; + if entry.info.instr.has_target() { + let label = u32::from(entry.info.arg) as usize; + debug_assert!(label < instrs.label_map_allocation); + let target = label_map[label]; + debug_assert!(target >= 0); + entry.info.arg = OpArg::new(target as u32); } - } - // Mark frees - for i in 0..nfrees { - let idx = cellfixedoffsets[ncells + i] as usize; - localspluskinds[idx] = CO_FAST_FREE; - } - // Apply CO_FAST_HIDDEN for inlined comprehension variables - for (name, &hidden) in &fast_hidden { - if (hidden || fast_hidden_final.contains(name)) - && let Some(idx) = varname_cache.get_index_of(name.as_str()) - { - localspluskinds[idx] |= CO_FAST_HIDDEN; + let handler = &mut entry.except_handler; + if handler.h_label >= 0 { + let label = handler.h_label as usize; + debug_assert!(label < instrs.label_map_allocation); + handler.h_label = label_map[label]; } } + } + instrs.label_map = None; + instrs.label_map_allocation = 0; + Ok(()) +} - Ok(CodeObject { - flags, - posonlyarg_count, - arg_count, - kwonlyarg_count, - source_path, - first_line_number: Some(first_line_number), - obj_name: obj_name.clone(), - qualname: qualname.unwrap_or(obj_name), - - max_stackdepth, - instructions: CodeUnits::from(instructions), - locations, - constants: constants.into_iter().collect(), - names: name_cache.into_iter().collect(), - varnames: varname_cache.into_iter().collect(), - cellvars: cellvar_cache.into_iter().collect(), - freevars: freevar_cache.into_iter().collect(), - localspluskinds: localspluskinds.into_boxed_slice(), - linetable, 
- exceptiontable, - }) - } - - fn dce(&mut self) { - // Truncate instructions after terminal instructions within each block - for block in &mut self.blocks { - let mut last_instr = None; - for (i, ins) in block.instructions.iter().enumerate() { - if ins.instr.is_scope_exit() || ins.instr.is_unconditional_jump() { - last_instr = Some(i); - break; - } - } - if let Some(i) = last_instr { - block.instructions.truncate(i + 1); - } - } - } - - fn reorder_entry_prefix_cell_setup(&mut self) { - let Some(entry) = self.blocks.first_mut() else { - return; - }; - let ncells = self.metadata.cellvars.len(); - let nfrees = self.metadata.freevars.len(); - if ncells == 0 && nfrees == 0 { - return; - } - - let prefix_len = entry - .instructions - .iter() - .take_while(|info| { - matches!( - info.instr.real(), - Some(Instruction::MakeCell { .. } | Instruction::CopyFreeVars { .. }) - ) - }) - .count(); - if prefix_len == 0 { - return; - } - - let original_prefix = entry.instructions[..prefix_len].to_vec(); - let anchor = original_prefix[0]; - let rest = entry.instructions.split_off(prefix_len); - entry.instructions.clear(); - - if nfrees > 0 { - entry.instructions.push(InstructionInfo { - instr: Instruction::CopyFreeVars { n: Arg::marker() }.into(), - arg: OpArg::new(nfrees as u32), - ..anchor - }); +/// flowgraph.c _PyCfg_ToInstructionSequence +fn cfg_to_instruction_sequence( + blocks: &mut [Block], + instr_sequence: &mut InstructionSequence, +) -> crate::InternalResult<()> { + let mut label_id = 0; + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + blocks[block_idx.idx()].cpython_label = InstructionSequenceLabel::from_index(label_id); + label_id += 1; + block_idx = blocks[block_idx.idx()].next; + } + + block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let block_label = blocks[block_idx.idx()].cpython_label; + debug_assert!(is_label(block_label)); + instruction_sequence_use_label(instr_sequence, block_label)?; + + let instr_count = 
blocks[block_idx.idx()].instruction_used; + for i in 0..instr_count { + if blocks[block_idx.idx()].instructions[i].instr.has_target() { + let target_block = blocks[block_idx.idx()].instructions[i].target; + debug_assert!(target_block != BlockIdx::NULL); + let lbl = blocks[target_block.idx()].cpython_label; + debug_assert!(is_label(lbl)); + blocks[block_idx.idx()].instructions[i].arg = OpArg::new(lbl.0 as u32); + } + + let mut info = blocks[block_idx.idx()].instructions[i]; + info.target = BlockIdx::NULL; + let except_handler = info.except_handler.take(); + let entry = instruction_sequence_addop(instr_sequence, info)?; + let hi = &mut entry.except_handler; + if let Some(handler) = except_handler { + debug_assert!(handler.handler_block != BlockIdx::NULL); + let lbl = blocks[handler.handler_block.idx()].cpython_label; + debug_assert!(is_label(lbl)); + let start_depth = blocks[handler.handler_block.idx()].start_depth; + debug_assert!(start_depth >= 0); + hi.h_label = lbl.0; + hi.start_depth = start_depth; + hi.preserve_lasti = i32::from(handler.preserve_lasti); + } else { + hi.h_label = NO_EXCEPTION_HANDLER_LABEL; + } + } + block_idx = blocks[block_idx.idx()].next; + } + + instruction_sequence_apply_label_map(instr_sequence)?; + Ok(()) +} + +/// assemble.c instr_size +fn instr_size(instr: &InstructionInfo) -> usize { + let opcode = instr.instr.expect_real(); + let oparg = u32::from(instr.arg) as i32; + debug_assert!( + instr.instr.has_arg() || oparg == 0, + "CPython assemble.c instr_size requires OPCODE_HAS_ARG or oparg == 0" + ); + let extended_args = + (0xFF_FFFF < oparg) as usize + (0xFF_FF < oparg) as usize + (0xFF < oparg) as usize; + let caches = opcode.cache_entries(); + extended_args + 1 + caches +} + +/// pycore_opcode_metadata.h is_pseudo_target +fn is_pseudo_target(pseudo: PseudoOpcode, target: Opcode) -> bool { + match pseudo { + PseudoOpcode::LoadClosure => matches!(target, Opcode::LoadFast), + PseudoOpcode::StoreFastMaybeNull => matches!(target, 
Opcode::StoreFast), + PseudoOpcode::AnnotationsPlaceholder + | PseudoOpcode::SetupFinally + | PseudoOpcode::SetupCleanup + | PseudoOpcode::SetupWith + | PseudoOpcode::PopBlock => matches!(target, Opcode::Nop), + PseudoOpcode::Jump => matches!(target, Opcode::JumpForward | Opcode::JumpBackward), + PseudoOpcode::JumpNoInterrupt => { + matches!( + target, + Opcode::JumpForward | Opcode::JumpBackwardNoInterrupt + ) } - - let cellfixedoffsets = build_cellfixedoffsets( - &self.metadata.varnames, - &self.metadata.cellvars, - &self.metadata.freevars, - ); - let mut sorted = vec![None; self.metadata.varnames.len() + ncells]; - for (oldindex, fixed) in cellfixedoffsets.iter().copied().take(ncells).enumerate() { - sorted[fixed as usize] = Some(oldindex); + PseudoOpcode::JumpIfFalse => { + matches!( + target, + Opcode::Copy | Opcode::ToBool | Opcode::PopJumpIfFalse + ) } - for oldindex in sorted.into_iter().flatten() { - entry.instructions.push(InstructionInfo { - instr: Instruction::MakeCell { i: Arg::marker() }.into(), - arg: OpArg::new(oldindex as u32), - ..anchor - }); + PseudoOpcode::JumpIfTrue => { + matches!( + target, + Opcode::Copy | Opcode::ToBool | Opcode::PopJumpIfTrue + ) } - - entry.instructions.extend(rest); } - - /// Clear blocks that are unreachable (not entry, not a jump target, - /// and only reachable via fall-through from a terminal block). 
- fn eliminate_unreachable_blocks(&mut self) { - let mut reachable = vec![false; self.blocks.len()]; - reachable[0] = true; - - // Fixpoint: only mark targets of already-reachable blocks - let mut changed = true; - while changed { - changed = false; - for i in 0..self.blocks.len() { - if !reachable[i] { - continue; - } - // Mark jump targets and exception handlers - for ins in &self.blocks[i].instructions { - if ins.target != BlockIdx::NULL && !reachable[ins.target.idx()] { - reachable[ins.target.idx()] = true; - changed = true; +} +/// assemble.c resolve_unconditional_jumps +#[allow(clippy::unnecessary_wraps)] +fn resolve_unconditional_jumps( + instr_sequence: &mut InstructionSequence, +) -> crate::InternalResult<()> { + for i in 0..instr_sequence.instr_used { + let instr = &mut instr_sequence.instrs[i].info; + let is_forward = (u32::from(instr.arg) as i32) > i as i32; + match instr.instr { + AnyInstruction::Pseudo(PseudoInstruction::Jump { .. }) => { + debug_assert!(is_pseudo_target(PseudoOpcode::Jump, Opcode::JumpForward)); + debug_assert!(is_pseudo_target(PseudoOpcode::Jump, Opcode::JumpBackward)); + if is_forward { + instr.instr = Instruction::JumpForward { + delta: Arg::marker(), } - if let Some(eh) = &ins.except_handler - && !reachable[eh.handler_block.idx()] - { - reachable[eh.handler_block.idx()] = true; - changed = true; + .into(); + } else { + instr.instr = Instruction::JumpBackward { + delta: Arg::marker(), } + .into(); } - // Mark fall-through - let next = self.blocks[i].next; - if next != BlockIdx::NULL - && !reachable[next.idx()] - && !self.blocks[i].instructions.last().is_some_and(|ins| { - ins.instr.is_scope_exit() || ins.instr.is_unconditional_jump() - }) - { - reachable[next.idx()] = true; - changed = true; + } + AnyInstruction::Pseudo(PseudoInstruction::JumpNoInterrupt { .. 
}) => { + debug_assert!(is_pseudo_target( + PseudoOpcode::JumpNoInterrupt, + Opcode::JumpForward + )); + debug_assert!(is_pseudo_target( + PseudoOpcode::JumpNoInterrupt, + Opcode::JumpBackwardNoInterrupt + )); + if is_forward { + instr.instr = Instruction::JumpForward { + delta: Arg::marker(), + } + .into(); + } else { + instr.instr = Instruction::JumpBackwardNoInterrupt { + delta: Arg::marker(), + } + .into(); } } - } - - for (i, block) in self.blocks.iter_mut().enumerate() { - if !reachable[i] { - block.instructions.clear(); + _ => { + if instr.instr.has_jump() && matches!(instr.instr, AnyInstruction::Pseudo(_)) { + unreachable!("remaining pseudo jump in resolve_unconditional_jumps"); + } } } } + Ok(()) +} - fn eval_unary_constant( - operand: &ConstantData, - op: Instruction, - intrinsic: Option, - ) -> Option { - match (operand, op, intrinsic) { - (ConstantData::Integer { value }, Instruction::UnaryNegative, None) => { - Some(ConstantData::Integer { value: -value }) - } - (ConstantData::Float { value }, Instruction::UnaryNegative, None) => { - Some(ConstantData::Float { value: -value }) - } - (ConstantData::Complex { value }, Instruction::UnaryNegative, None) => { - Some(ConstantData::Complex { value: -value }) - } - (ConstantData::Boolean { value }, Instruction::UnaryNegative, None) => { - Some(ConstantData::Integer { - value: BigInt::from(-i32::from(*value)), - }) - } - (ConstantData::Integer { value }, Instruction::UnaryInvert, None) => { - Some(ConstantData::Integer { value: !value }) - } - (ConstantData::Boolean { .. }, Instruction::UnaryInvert, None) => None, - ( - ConstantData::Integer { value }, - Instruction::CallIntrinsic1 { .. }, - Some(oparg::IntrinsicFunction1::UnaryPositive), - ) => Some(ConstantData::Integer { - value: value.clone(), - }), - ( - ConstantData::Float { value }, - Instruction::CallIntrinsic1 { .. 
}, - Some(oparg::IntrinsicFunction1::UnaryPositive), - ) => Some(ConstantData::Float { value: *value }), - ( - ConstantData::Boolean { value }, - Instruction::CallIntrinsic1 { .. }, - Some(oparg::IntrinsicFunction1::UnaryPositive), - ) => Some(ConstantData::Integer { - value: BigInt::from(i32::from(*value)), - }), - ( - ConstantData::Complex { value }, - Instruction::CallIntrinsic1 { .. }, - Some(oparg::IntrinsicFunction1::UnaryPositive), - ) => Some(ConstantData::Complex { value: *value }), - _ => None, +/// assemble.c resolve_jump_offsets +#[allow(clippy::needless_range_loop, clippy::unnecessary_wraps)] +fn resolve_jump_offsets(instr_sequence: &mut InstructionSequence) -> crate::InternalResult<()> { + // The offset (in code units) of END_SEND from SEND in the yield-from sequence. + const END_SEND_OFFSET: i32 = 5; + for i in 0..instr_sequence.instr_used { + let instr = &mut instr_sequence.instrs[i]; + let opcode = instr.info.instr.expect_real(); + if opcode.has_jump() { + instr.i_target = u32::from(instr.info.arg) as i32; } } - - /// Fold constant unary operations following CPython fold_const_unaryop(). 
- fn fold_unary_constants(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i < block.instructions.len() { - let instr = &block.instructions[i]; - let (op, intrinsic) = match instr.instr.real() { - Some(Instruction::UnaryNegative) => (Instruction::UnaryNegative, None), - Some(Instruction::UnaryInvert) => (Instruction::UnaryInvert, None), - Some(Instruction::CallIntrinsic1 { func }) - if matches!( - func.get(instr.arg), - oparg::IntrinsicFunction1::UnaryPositive - ) => - { - ( - Instruction::CallIntrinsic1 { - func: Arg::marker(), - }, - Some(func.get(instr.arg)), - ) - } - _ => { - i += 1; - continue; - } - }; - let Some(operand_index) = i - .checked_sub(1) - .and_then(|start| Self::get_const_loading_instr_indices(block, start, 1)) - .and_then(|indices| indices.into_iter().next()) - else { - i += 1; - continue; - }; - let operand = - Self::get_const_value_from(&self.metadata, &block.instructions[operand_index]); - if let Some(operand) = operand - && let Some(folded_const) = Self::eval_unary_constant(&operand, op, intrinsic) - { - let (const_idx, _) = self.metadata.consts.insert_full(folded_const); - nop_out_no_location(&mut block.instructions[operand_index]); - let mut prev = operand_index; - while let Some(idx) = prev.checked_sub(1) { - if !matches!(block.instructions[idx].instr.real(), Some(Instruction::Nop)) { - break; - } - block.instructions[idx].location = block.instructions[i].location; - block.instructions[idx].end_location = block.instructions[i].end_location; - prev = idx; - } - block.instructions[i].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - block.instructions[i].folded_from_nonliteral_expr = false; - i = i.saturating_sub(1); + let mut extended_arg_recompile; + loop { + let mut totsize = 0i32; + for i in 0..instr_sequence.instr_used { + let instr = &mut instr_sequence.instrs[i]; + instr.i_offset = totsize; + let isize = 
instr_size(&instr.info); + totsize += isize as i32; + } + extended_arg_recompile = false; + let mut offset = 0i32; + for i in 0..instr_sequence.instr_used { + let isize = instr_size(&instr_sequence.instrs[i].info); + // Jump offsets are computed relative to the instruction pointer + // after fetching the jump instruction. + offset += isize as i32; + + let opcode = instr_sequence.instrs[i].info.instr.expect_real(); + if opcode.has_jump() { + let target = instr_sequence.instrs[i].i_target; + let target_offset = instr_sequence.instrs[target as usize].i_offset; + let info = &mut instr_sequence.instrs[i].info; + let op = opcode; + let mut oparg = target_offset; + info.arg = OpArg::new(oparg as u32); + if matches!(op, Instruction::EndAsyncFor) { + oparg = offset - oparg - END_SEND_OFFSET; + } else if oparg < offset { + debug_assert!(matches!( + op.into(), + Opcode::JumpBackward | Opcode::JumpBackwardNoInterrupt + )); + oparg = offset - oparg; } else { - i += 1; + debug_assert!(!matches!( + op.into(), + Opcode::JumpBackward | Opcode::JumpBackwardNoInterrupt + )); + oparg -= offset; } - } - } - } - - fn get_const_loading_instr_indices( - block: &Block, - mut start: usize, - size: usize, - ) -> Option> { - let mut indices = Vec::with_capacity(size); - loop { - let instr = block.instructions.get(start)?; - if !matches!(instr.instr.real(), Some(Instruction::Nop)) { - Self::get_const_value_from_dummy(instr)?; - indices.push(start); - if indices.len() == size { - break; + info.arg = OpArg::new(oparg as u32); + if instr_size(info) != isize { + extended_arg_recompile = true; } } - start = start.checked_sub(1)?; } - indices.reverse(); - Some(indices) - } - - fn get_const_sequence( - metadata: &CodeUnitMetadata, - block: &Block, - build_index: usize, - size: usize, - ) -> Option<(Vec, Vec)> { - if size == 0 { - return Some((Vec::new(), Vec::new())); - } - - let operand_indices = build_index - .checked_sub(1) - .and_then(|start| Self::get_const_loading_instr_indices(block, start, 
size))?; - let mut elements = Vec::with_capacity(size); - for &j in &operand_indices { - let load_instr = &block.instructions[j]; - if load_instr.folded_from_nonliteral_expr { - return None; - } - elements.push(Self::get_const_value_from(metadata, load_instr)?); + if !extended_arg_recompile { + break; } + } + + Ok(()) +} + +struct AssembledCode { + instructions: Vec, + linetable: Box<[u8]>, + exceptiontable: Box<[u8]>, +} + +struct LocalsPlusInfo { + cellvars: Box<[String]>, + kinds: Box<[u8]>, +} + +/// assemble.c same_location +fn same_location(a: LineTableLocation, b: LineTableLocation) -> bool { + a.line == b.line && a.end_line == b.end_line && a.col == b.col && a.end_col == b.end_col +} + +fn instruction_linetable_location(info: &InstructionInfo) -> LineTableLocation { + match info.lineno_override { + Some(NO_LOCATION_OVERRIDE) => LineTableLocation { + line: NO_LOCATION_OVERRIDE, + end_line: NO_LOCATION_OVERRIDE, + col: NO_LOCATION_OVERRIDE, + end_col: NO_LOCATION_OVERRIDE, + }, + Some(LINE_ONLY_LOCATION_OVERRIDE) => LineTableLocation { + line: info.location.line.get() as i32, + end_line: info.end_location.line.get() as i32, + col: -1, + end_col: -1, + }, + Some(NEXT_LOCATION_OVERRIDE) => next_linetable_location(), + Some(lineno) => LineTableLocation { + line: lineno, + end_line: info.end_location.line.get() as i32, + col: info.location.character_offset.to_zero_indexed() as i32, + end_col: info.end_location.character_offset.to_zero_indexed() as i32, + }, + None => LineTableLocation { + line: info.location.line.get() as i32, + end_line: info.end_location.line.get() as i32, + col: info.location.character_offset.to_zero_indexed() as i32, + end_col: info.end_location.character_offset.to_zero_indexed() as i32, + }, + } +} + +/// assemble.c write_instr +fn write_instr(instructions: &mut Vec, info: &InstructionInfo, ilen: usize) { + let opcode = info.instr.expect_real(); + let oparg = u32::from(info.arg) as i32; + debug_assert!( + info.instr.has_arg() || oparg == 0, 
+ "CPython assemble.c write_instr requires OPCODE_HAS_ARG or oparg == 0" + ); + let caches = opcode.cache_entries(); + let non_cache_units = ilen - caches; + match non_cache_units { + 1..=4 => {} + _ => unreachable!("CPython write_instr expects 1 to 4 non-cache code units"), + } + if non_cache_units >= 4 { + instructions.push(CodeUnit::new( + Instruction::ExtendedArg, + OpArgByte::new(((oparg >> 24) & 0xff) as u8), + )); + } + if non_cache_units >= 3 { + instructions.push(CodeUnit::new( + Instruction::ExtendedArg, + OpArgByte::new(((oparg >> 16) & 0xff) as u8), + )); + } + if non_cache_units >= 2 { + instructions.push(CodeUnit::new( + Instruction::ExtendedArg, + OpArgByte::new(((oparg >> 8) & 0xff) as u8), + )); + } + instructions.push(CodeUnit::new(opcode, OpArgByte::new((oparg & 0xff) as u8))); + for _ in 0..caches { + instructions.push(CodeUnit::new(Instruction::Cache, OpArgByte::new(0))); + } +} - Some((operand_indices, elements)) +/// assemble.c assemble_emit_instr +fn assemble_emit_instr( + instructions: &mut Vec, + info: &mut InstructionInfo, +) -> crate::InternalResult<()> { + let size = instr_size(info); + let required = instructions + .len() + .checked_add(size) + .ok_or(InternalError::MalformedControlFlowGraph)?; + if required >= instructions.capacity() { + vec_try_resize_to_double_capacity(instructions)?; } + write_instr(instructions, info, size); + Ok(()) +} - fn get_non_nop_instr_indices(block: &Block, start: usize, count: usize) -> Option> { - let mut indices = Vec::with_capacity(count); - for idx in start..block.instructions.len() { - if !matches!(block.instructions[idx].instr.real(), Some(Instruction::Nop)) { - indices.push(idx); - if indices.len() == count { - return Some(indices); - } +/// assemble.c assemble_location_info +#[allow(clippy::needless_range_loop)] +fn assemble_location_info( + instr_sequence: &mut InstructionSequence, + first_line: i32, + debug_ranges: bool, +) -> crate::InternalResult> { + for i in 
(0..instr_sequence.instr_used).rev() { + let loc = instruction_linetable_location(&instr_sequence.instrs[i].info); + if same_location(loc, next_linetable_location()) { + if instr_sequence.instrs[i] + .info + .instr + .expect_real() + .is_terminator() + { + instr_sequence.instrs[i].info.lineno_override = Some(NO_LOCATION_OVERRIDE); + } else { + debug_assert!(i < instr_sequence.instr_used - 1); + let next = instr_sequence.instrs[i + 1].info; + instr_set_loc( + &mut instr_sequence.instrs[i].info, + next.location, + next.end_location, + next.lineno_override, + ); } } - None } - /// Constant folding: fold LOAD_CONST/LOAD_SMALL_INT + LOAD_CONST/LOAD_SMALL_INT + BINARY_OP - /// into a single LOAD_CONST when the result is computable at compile time. - /// = fold_binops_on_constants in CPython flowgraph.c - fn fold_binop_constants(&mut self) { - use oparg::BinaryOperator as BinOp; - - for block in &mut self.blocks { - let mut i = 0; - while i < block.instructions.len() { - let Some(Instruction::BinaryOp { .. 
}) = block.instructions[i].instr.real() else { - i += 1; - continue; - }; + let mut linetable = Vec::new(); + vec_try_reserve_exact(&mut linetable, DEFAULT_CNOTAB_SIZE)?; + let mut prev_line = first_line; + let mut loc = no_linetable_location(); + let mut size = 0; + for i in 0..instr_sequence.instr_used { + let entry = &instr_sequence.instrs[i]; + let instr_loc = instruction_linetable_location(&entry.info); + if !same_location(loc, instr_loc) { + assemble_emit_location(&mut linetable, loc, size, &mut prev_line, debug_ranges)?; + loc = instr_loc; + size = 0; + } + size += instr_size(&entry.info); + } + assemble_emit_location(&mut linetable, loc, size, &mut prev_line, debug_ranges)?; + Ok(linetable.into_boxed_slice()) +} - let Some(operand_indices) = i - .checked_sub(1) - .and_then(|start| Self::get_const_loading_instr_indices(block, start, 2)) - else { - i += 1; - continue; - }; +/// assemble.c assemble_emit +fn assemble_emit( + instr_sequence: &mut InstructionSequence, + first_line: i32, + debug_ranges: bool, +) -> crate::InternalResult { + let mut instructions = Vec::new(); + vec_try_reserve_exact( + &mut instructions, + DEFAULT_CODE_SIZE / core::mem::size_of::(), + )?; - let op_raw = u32::from(block.instructions[i].arg); - let Ok(op) = BinOp::try_from(op_raw) else { - i += 1; - continue; - }; + for i in 0..instr_sequence.instr_used { + let instr = &mut instr_sequence.instrs[i].info; + assemble_emit_instr(&mut instructions, instr)?; + } - let left = Self::get_const_value_from( - &self.metadata, - &block.instructions[operand_indices[0]], - ); - let right = Self::get_const_value_from( - &self.metadata, - &block.instructions[operand_indices[1]], - ); + let linetable = assemble_location_info(instr_sequence, first_line, debug_ranges)?; - let (Some(left_val), Some(right_val)) = (left, right) else { - i += 1; - continue; - }; + let exceptiontable = + assemble_exception_table(&instr_sequence.instrs[..instr_sequence.instr_used])?; - let result = 
Self::eval_binop(&left_val, &right_val, op); + Ok(AssembledCode { + instructions, + linetable, + exceptiontable, + }) +} - if let Some(result_const) = result { - let (const_idx, _) = self.metadata.consts.insert_full(result_const); - let folded_from_nonliteral_expr = operand_indices - .iter() - .any(|&idx| block.instructions[idx].folded_from_nonliteral_expr); - for &idx in &operand_indices { - nop_out_no_location(&mut block.instructions[idx]); - } - block.instructions[i].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - block.instructions[i].folded_from_nonliteral_expr = folded_from_nonliteral_expr; - i = i.saturating_sub(1); // re-check with previous instruction - } else { - i += 1; - } +/// assemble.c compute_localsplus_info +fn compute_localsplus_info( + umd: &CodeUnitMetadata, + nlocalsplus: usize, + flags: CodeFlags, +) -> crate::InternalResult { + let nlocals = umd.varnames.len(); + let ncells = umd.cellvars.len(); + let nfrees = umd.freevars.len(); + let mut localspluskinds = Vec::new(); + vec_try_reserve_exact(&mut localspluskinds, nlocalsplus)?; + localspluskinds.resize(nlocalsplus, 0); + let mut cellvars = Vec::new(); + vec_try_reserve_exact(&mut cellvars, ncells)?; + + let argvarkinds = [ + (umd.posonlyargcount as usize, CO_FAST_ARG_POS), + (umd.argcount as usize, CO_FAST_ARG_POS | CO_FAST_ARG_KW), + (umd.kwonlyargcount as usize, CO_FAST_ARG_KW), + ( + usize::from(flags.contains(CodeFlags::VARARGS)), + CO_FAST_ARG_VAR | CO_FAST_ARG_POS, + ), + ( + usize::from(flags.contains(CodeFlags::VARKEYWORDS)), + CO_FAST_ARG_VAR | CO_FAST_ARG_KW, + ), + (usize::MAX, 0), + ]; + let mut pos = 0usize; + let mut max = 0usize; + for (count, argkind) in argvarkinds { + max = if count == usize::MAX { + usize::MAX + } else { + max + count + }; + while pos < max && pos < nlocals { + let name = umd + .varnames + .get_index(pos) + .expect("varname index is in range") + .as_str(); + let mut kind 
= CO_FAST_LOCAL | argkind; + if umd.fast_hidden.get(name).copied().unwrap_or(false) + || umd.fast_hidden_final.contains(name) + { + kind |= CO_FAST_HIDDEN; } + if umd.cellvars.contains(name) { + kind |= CO_FAST_CELL; + cellvars.push(name.to_owned()); + } + localspluskinds[pos] = kind; + pos += 1; } } - fn get_const_value_from_dummy(info: &InstructionInfo) -> Option<()> { - match info.instr.real() { - Some(Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. }) => Some(()), - _ => None, + let mut numdropped = 0usize; + let mut cellvar_offset = -1i32; + for i in 0..ncells { + let name = umd + .cellvars + .get_index(i) + .expect("cellvar index is in range") + .as_str(); + if umd.varnames.contains(name) { + numdropped += 1; + continue; } + let offset = i + nlocals - numdropped; + debug_assert!(offset < nlocalsplus); + cellvars.push(name.to_owned()); + localspluskinds[offset] = CO_FAST_CELL; + cellvar_offset = offset as i32; } - fn get_const_value_from( - metadata: &CodeUnitMetadata, - info: &InstructionInfo, - ) -> Option { - match info.instr.real() { - Some(Instruction::LoadConst { .. }) => { - let idx = u32::from(info.arg) as usize; - metadata.consts.get_index(idx).cloned() - } - Some(Instruction::LoadSmallInt { .. 
}) => { - let v = u32::from(info.arg) as i32; - Some(ConstantData::Integer { - value: BigInt::from(v), - }) - } - _ => None, - } - } + for i in 0..nfrees { + let offset = ncells + i + nlocals - numdropped; + debug_assert!(offset < nlocalsplus); + debug_assert!((offset as i32) > cellvar_offset); + localspluskinds[offset] = CO_FAST_FREE; + } + + debug_assert_eq!( + nlocalsplus, + nlocals + ncells - numdropped + nfrees, + "CPython prepare_localsplus() result must match assemble.c localsplus sizing" + ); + debug_assert_eq!(cellvars.len(), ncells); + Ok(LocalsPlusInfo { + cellvars: cellvars.into_boxed_slice(), + kinds: localspluskinds.into_boxed_slice(), + }) +} - fn const_folding_check_complexity(obj: &ConstantData, mut limit: isize) -> Option { - if let ConstantData::Tuple { elements } = obj { - limit -= isize::try_from(elements.len()).ok()?; - if limit < 0 { - return None; - } - for element in elements { - limit = Self::const_folding_check_complexity(element, limit)?; - } +#[derive(Debug, Clone)] +pub struct Block { + /// CPython `basicblock.b_list`, allocation-order list distinct from CFG `b_next`. + allocation_next: BlockIdx, + /// CPython `basicblock.b_label` used by translate_jump_labels_to_targets. + cpython_label: InstructionSequenceLabel, + /// CPython `basicblock.b_ialloc`, the allocated size of `b_instr`. + instruction_allocation: usize, + /// Exception stack at start of block, used by label_exception_targets (b_exceptstack) + except_stack: Option, + /// CPython `basicblock.b_instr`, including allocated slots beyond `b_iused`. + pub instructions: Vec, + pub next: BlockIdx, + /// CPython `basicblock.b_iused`, the number of used entries in `b_instr`. 
+ instruction_used: usize, + /// Potentially uninitialized locals mask for local-check analysis (b_unsafe_locals_mask) + unsafe_locals_mask: u64, + /// Number of incoming CFG edges from reachable blocks (b_predecessors) + predecessors: i32, + /// Stack depth at block entry, set by stack depth analysis + pub start_depth: i32, + /// Whether to preserve lasti for this handler block (b_preserve_lasti) + pub preserve_lasti: bool, + /// Temporary traversal mark used by CFG passes (b_visited) + visited: bool, + /// Whether this block is an exception handler target (b_except_handler) + pub except_handler: bool, + /// Whether this block is only reachable via exception table (b_cold) + pub cold: bool, + /// Definitely reachable outside exception-only paths (b_warm) + warm: bool, +} + +impl Default for Block { + fn default() -> Self { + Self { + allocation_next: BlockIdx::NULL, + cpython_label: InstructionSequenceLabel::NO_LABEL, + instruction_allocation: 0, + except_stack: None, + instructions: Vec::new(), + next: BlockIdx::NULL, + instruction_used: 0, + unsafe_locals_mask: 0, + predecessors: 0, + start_depth: START_DEPTH_UNSET, + preserve_lasti: false, + visited: false, + except_handler: false, + cold: false, + warm: false, } - Some(limit) } +} - fn eval_binop( - left: &ConstantData, - right: &ConstantData, - op: oparg::BinaryOperator, - ) -> Option { - use oparg::BinaryOperator as BinOp; +impl Block { + pub(crate) fn used_instructions(&self) -> &[InstructionInfo] { + &self.instructions[..self.instruction_used] + } - fn repeat_wtf8(value: &Wtf8Buf, n: usize) -> Wtf8Buf { - let mut result = Wtf8Buf::with_capacity(value.len().saturating_mul(n)); - for _ in 0..n { - result.push_wtf8(value); - } - result - } + pub(crate) fn is_empty(&self) -> bool { + self.instruction_used == 0 + } +} - fn checked_repeat_count(n: &BigInt, item_size: usize) -> Option { - let n = n.to_isize()?; - if item_size != 0 && (n < 0 || n as usize > MAX_STR_SIZE / item_size) { - return None; - } - 
Some(n.max(0) as usize) - } +pub(crate) const START_DEPTH_UNSET: i32 = i32::MIN; +const CO_MAXBLOCKS: usize = 20; - fn eval_complex_binop( - left: Complex, - right: Complex, - op: BinOp, - ) -> Option { - fn complex_const(value: Complex) -> Option { - (value.re.is_finite() && value.im.is_finite()) - .then_some(ConstantData::Complex { value }) - } +/// flowgraph.c struct _PyCfgExceptStack +#[derive(Clone, Debug)] +struct CfgExceptStack { + handlers: [BlockIdx; CO_MAXBLOCKS + 2], + depth: usize, +} - let value = match op { - BinOp::Add => left + right, - BinOp::Subtract => { - let re = left.re - right.re; - // Preserve CPython's signed-zero behavior for real-zero - // minus zero-complex expressions such as `0 - 0j`. - let im = if left.re == 0.0 - && left.im == 0.0 - && right.re == 0.0 - && right.im == 0.0 - && !right.im.is_sign_negative() - { - -0.0 - } else { - left.im - right.im - }; - Complex::new(re, im) - } - BinOp::Multiply => left * right, - BinOp::TrueDivide => { - if right == Complex::new(0.0, 0.0) { - return None; - } - left / right - } - BinOp::Power => { - if left == Complex::new(0.0, 0.0) { - if right.im != 0.0 || right.re < 0.0 { - return None; - } +/// flowgraph.c `basicblock **stack` +#[derive(Clone, Debug)] +struct CfgTraversalStack { + stack: Vec, + sp: usize, +} - return complex_const(if right.re == 0.0 { - Complex::new(1.0, 0.0) - } else { - Complex::new(0.0, 0.0) - }); - } +impl CfgTraversalStack { + fn push(&mut self, block: BlockIdx) { + debug_assert!(self.sp < self.stack.len()); + self.stack[self.sp] = block; + self.sp += 1; + } - if right.im == 0.0 - && right.re.fract() == 0.0 - && right.re >= f64::from(i32::MIN) - && right.re <= f64::from(i32::MAX) - { - left.powi(right.re as i32) - } else { - left.powc(right) - } - } - _ => return None, - }; - complex_const(value) + fn pop(&mut self) -> Option { + if self.sp == 0 { + return None; } + self.sp -= 1; + Some(self.stack[self.sp]) + } - fn float_div_mod(left: f64, right: f64) -> Option<(f64, 
f64)> { - if right == 0.0 { - return None; - } - - let mut modulo = left % right; - let div = (left - modulo) / right; - let floordiv = if modulo != 0.0 { - let div = if (right < 0.0) != (modulo < 0.0) { - modulo += right; - div - 1.0 - } else { - div - }; - let mut floordiv = div.floor(); - if div - floordiv > 0.5 { - floordiv += 1.0; - } - floordiv - } else { - modulo = 0.0f64.copysign(right); - 0.0f64.copysign(left / right) - }; - - Some((floordiv, modulo)) - } - - fn constant_as_index(value: &ConstantData) -> Option { - match value { - ConstantData::Integer { value } => value.to_i64().or_else(|| { - if value < &BigInt::from(0) { - Some(i64::MIN) - } else { - Some(i64::MAX) - } - }), - ConstantData::Boolean { value } => Some(i64::from(*value)), - _ => None, - } - } - - fn slice_bound(value: &ConstantData) -> Option> { - match value { - ConstantData::None => Some(None), - _ => constant_as_index(value).map(Some), - } - } - - fn adjusted_slice_indices(len: usize, slice: &[ConstantData; 3]) -> Option> { - let len = i64::try_from(len).ok()?; - let start = slice_bound(&slice[0])?; - let stop = slice_bound(&slice[1])?; - let step = slice_bound(&slice[2])?.unwrap_or(1); - if step == 0 || step == i64::MIN { - return None; - } - - let step_is_negative = step < 0; - let lower = if step_is_negative { -1 } else { 0 }; - let upper = if step_is_negative { len - 1 } else { len }; - let adjust = |value: Option, default: i64| { - let mut value = value.unwrap_or(default); - if value < 0 { - value = value.saturating_add(len); - if value < 0 { - value = lower; - } - } else if value >= len { - value = upper; - } - value - }; - let start = adjust(start, if step_is_negative { upper } else { lower }); - let stop = adjust(stop, if step_is_negative { lower } else { upper }); - - let mut indices = Vec::new(); - let mut index = i128::from(start); - let stop = i128::from(stop); - let step = i128::from(step); - if step > 0 { - while index < stop { - indices.push(usize::try_from(index).ok()?); 
- index += step; - } - } else { - while index > stop { - indices.push(usize::try_from(index).ok()?); - index += step; - } - } - Some(indices) - } - - fn adjusted_const_index(len: usize, index: &ConstantData) -> Option { - let len = i64::try_from(len).ok()?; - let index = constant_as_index(index)?; - let index = if index < 0 { - index.saturating_add(len) - } else { - index - }; - if index < 0 || index >= len { - return None; - } - usize::try_from(index).ok() - } - - fn eval_const_subscript( - container: &ConstantData, - index: &ConstantData, - ) -> Option { - match (container, index) { - ( - ConstantData::Str { value }, - ConstantData::Integer { .. } | ConstantData::Boolean { .. }, - ) => { - let string = value.to_string(); - if string.contains(char::REPLACEMENT_CHARACTER) { - return None; - } - let chars = string.chars().collect::>(); - let index = adjusted_const_index(chars.len(), index)?; - Some(ConstantData::Str { - value: chars[index].to_string().into(), - }) - } - (ConstantData::Str { value }, ConstantData::Slice { elements }) => { - let string = value.to_string(); - if string.contains(char::REPLACEMENT_CHARACTER) { - return None; - } - let chars = string.chars().collect::>(); - let mut result = String::new(); - for index in adjusted_slice_indices(chars.len(), elements)? { - result.push(chars[index]); - } - Some(ConstantData::Str { - value: result.into(), - }) - } - ( - ConstantData::Bytes { value }, - ConstantData::Integer { .. } | ConstantData::Boolean { .. }, - ) => { - let index = adjusted_const_index(value.len(), index)?; - Some(ConstantData::Integer { - value: BigInt::from(value[index]), - }) - } - (ConstantData::Bytes { value }, ConstantData::Slice { elements }) => { - let mut result = Vec::new(); - for index in adjusted_slice_indices(value.len(), elements)? { - result.push(value[index]); - } - Some(ConstantData::Bytes { value: result }) - } - ( - ConstantData::Tuple { elements }, - ConstantData::Integer { .. } | ConstantData::Boolean { .. 
}, - ) => { - let index = adjusted_const_index(elements.len(), index)?; - Some(elements[index].clone()) - } - (ConstantData::Tuple { elements }, ConstantData::Slice { elements: slice }) => { - let elements = adjusted_slice_indices(elements.len(), slice)? - .into_iter() - .map(|index| elements[index].clone()) - .collect(); - Some(ConstantData::Tuple { elements }) - } - _ => None, - } - } - - if matches!(op, BinOp::Subscr) { - return eval_const_subscript(left, right); - } - - match (left, right) { - (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) => { - let result = match op { - BinOp::Add => l + r, - BinOp::Subtract => l - r, - BinOp::Multiply => { - if !l.is_zero() && !r.is_zero() && l.bits() + r.bits() > MAX_INT_SIZE_BITS { - return None; - } - l * r - } - BinOp::TrueDivide => { - if r.is_zero() { - return None; - } - let l_f = l.to_f64()?; - let r_f = r.to_f64()?; - let result = l_f / r_f; - if !result.is_finite() { - return None; - } - return Some(ConstantData::Float { value: result }); - } - BinOp::FloorDivide => { - if r.is_zero() { - return None; - } - // Python floor division: round towards negative infinity - let (q, rem) = (l.clone() / r.clone(), l.clone() % r.clone()); - if !rem.is_zero() && (rem < BigInt::from(0)) != (*r < BigInt::from(0)) { - q - 1 - } else { - q - } - } - BinOp::Remainder => { - if r.is_zero() { - return None; - } - // Python modulo: result has same sign as divisor - let rem = l.clone() % r.clone(); - if !rem.is_zero() && (rem < BigInt::from(0)) != (*r < BigInt::from(0)) { - rem + r - } else { - rem - } - } - BinOp::Power => { - if r < &BigInt::from(0) { - if l.is_zero() { - return None; - } - let base = l.to_f64()?; - if !base.is_finite() { - return None; - } - let result = if let Some(exp) = r.to_i32() { - base.powi(exp) - } else { - base.powf(r.to_f64()?) 
- }; - if !result.is_finite() { - return None; - } - return Some(ConstantData::Float { value: result }); - } - let exp: u64 = r.try_into().ok()?; - let exp_usize = usize::try_from(exp).ok()?; - if !l.is_zero() && exp > 0 && l.bits() > MAX_INT_SIZE_BITS / exp { - return None; - } - num_traits::pow::pow(l.clone(), exp_usize) - } - BinOp::Lshift => { - let shift: u64 = r.try_into().ok()?; - let shift_usize = usize::try_from(shift).ok()?; - if shift > MAX_INT_SIZE_BITS - || (!l.is_zero() && l.bits() > MAX_INT_SIZE_BITS - shift) - { - return None; - } - l << shift_usize - } - BinOp::Rshift => { - let shift: u32 = r.try_into().ok()?; - l >> (shift as usize) - } - BinOp::And => l & r, - BinOp::Or => l | r, - BinOp::Xor => l ^ r, - _ => return None, - }; - Some(ConstantData::Integer { value: result }) - } - (ConstantData::Float { value: l }, ConstantData::Float { value: r }) => { - let result = match op { - BinOp::Add => l + r, - BinOp::Subtract => l - r, - BinOp::Multiply => l * r, - BinOp::TrueDivide => { - if *r == 0.0 { - return None; - } - l / r - } - BinOp::FloorDivide => { - let (floordiv, _) = float_div_mod(*l, *r)?; - floordiv - } - BinOp::Remainder => { - let (_, modulo) = float_div_mod(*l, *r)?; - modulo - } - BinOp::Power => l.powf(*r), - _ => return None, - }; - if matches!(op, BinOp::Power) && !result.is_finite() { - return None; - } - Some(ConstantData::Float { value: result }) - } - // Int op Float or Float op Int → Float - (ConstantData::Integer { value: l }, ConstantData::Float { value: r }) => { - let l_f = l.to_f64()?; - Self::eval_binop( - &ConstantData::Float { value: l_f }, - &ConstantData::Float { value: *r }, - op, - ) - } - (ConstantData::Float { value: l }, ConstantData::Integer { value: r }) => { - let r_f = r.to_f64()?; - Self::eval_binop( - &ConstantData::Float { value: *l }, - &ConstantData::Float { value: r_f }, - op, - ) - } - (ConstantData::Integer { value: l }, ConstantData::Complex { value: r }) => { - 
eval_complex_binop(Complex::new(l.to_f64()?, 0.0), *r, op) - } - (ConstantData::Complex { value: l }, ConstantData::Integer { value: r }) => { - eval_complex_binop(*l, Complex::new(r.to_f64()?, 0.0), op) - } - (ConstantData::Float { value: l }, ConstantData::Complex { value: r }) => { - eval_complex_binop(Complex::new(*l, 0.0), *r, op) - } - (ConstantData::Complex { value: l }, ConstantData::Float { value: r }) => { - eval_complex_binop(*l, Complex::new(*r, 0.0), op) - } - (ConstantData::Complex { value: l }, ConstantData::Complex { value: r }) => { - eval_complex_binop(*l, *r, op) - } - // String concatenation and repetition - (ConstantData::Str { value: l }, ConstantData::Str { value: r }) - if matches!(op, BinOp::Add) => - { - let mut result = l.clone(); - result.push_wtf8(r); - Some(ConstantData::Str { value: result }) - } - (ConstantData::Str { value: s }, ConstantData::Integer { value: n }) - if matches!(op, BinOp::Multiply) => - { - let n = checked_repeat_count(n, s.code_points().count())?; - let result = repeat_wtf8(s, n); - Some(ConstantData::Str { value: result }) - } - (ConstantData::Tuple { elements: l }, ConstantData::Tuple { elements: r }) - if matches!(op, BinOp::Add) => - { - let mut result = l.clone(); - result.extend(r.iter().cloned()); - Some(ConstantData::Tuple { elements: result }) - } - (ConstantData::Tuple { elements }, ConstantData::Integer { value: n }) - if matches!(op, BinOp::Multiply) => - { - let n = n.to_usize()?; - if n != 0 && !elements.is_empty() { - if n > MAX_COLLECTION_SIZE / elements.len() { - return None; - } - Self::const_folding_check_complexity( - &ConstantData::Tuple { - elements: elements.clone(), - }, - MAX_TOTAL_ITEMS / isize::try_from(n).ok()?, - )?; - } - let mut result = Vec::with_capacity(elements.len() * n); - for _ in 0..n { - result.extend(elements.iter().cloned()); - } - Some(ConstantData::Tuple { elements: result }) - } - (ConstantData::Integer { value: n }, ConstantData::Tuple { elements }) - if matches!(op, 
BinOp::Multiply) => - { - let n = n.to_usize()?; - if n != 0 && !elements.is_empty() { - if n > MAX_COLLECTION_SIZE / elements.len() { - return None; - } - Self::const_folding_check_complexity( - &ConstantData::Tuple { - elements: elements.clone(), - }, - MAX_TOTAL_ITEMS / isize::try_from(n).ok()?, - )?; - } - let mut result = Vec::with_capacity(elements.len() * n); - for _ in 0..n { - result.extend(elements.iter().cloned()); - } - Some(ConstantData::Tuple { elements: result }) - } - (ConstantData::Integer { value: n }, ConstantData::Str { value: s }) - if matches!(op, BinOp::Multiply) => - { - let n = checked_repeat_count(n, s.code_points().count())?; - let result = repeat_wtf8(s, n); - Some(ConstantData::Str { value: result }) - } - (ConstantData::Bytes { value: l }, ConstantData::Bytes { value: r }) - if matches!(op, BinOp::Add) => - { - let mut result = l.clone(); - result.extend_from_slice(r); - Some(ConstantData::Bytes { value: result }) - } - (ConstantData::Bytes { value: b }, ConstantData::Integer { value: n }) - if matches!(op, BinOp::Multiply) => - { - let n = checked_repeat_count(n, b.len())?; - Some(ConstantData::Bytes { value: b.repeat(n) }) - } - (ConstantData::Integer { value: n }, ConstantData::Bytes { value: b }) - if matches!(op, BinOp::Multiply) => - { - let n = checked_repeat_count(n, b.len())?; - Some(ConstantData::Bytes { value: b.repeat(n) }) - } - _ => None, - } - } - - fn fold_tuple_constant_at( - metadata: &mut CodeUnitMetadata, - block: &mut Block, - i: usize, - ) -> bool { - let Some(Instruction::BuildTuple { .. }) = block.instructions[i].instr.real() else { - return false; - }; - - let tuple_size = u32::from(block.instructions[i].arg) as usize; - if tuple_size <= 3 - && block - .instructions - .get(i + 1) - .and_then(|next| next.instr.real()) - .is_some_and(|next| { - matches!( - next, - Instruction::UnpackSequence { .. 
} - if usize::try_from(u32::from(block.instructions[i + 1].arg)).ok() - == Some(tuple_size) - ) - }) - { - return false; - } - if tuple_size == 0 { - let (const_idx, _) = metadata.consts.insert_full(ConstantData::Tuple { - elements: Vec::new(), - }); - block.instructions[i].instr = Opcode::LoadConst.into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - return true; - } - - let folded_from_nonliteral_expr = block.instructions[i].folded_from_nonliteral_expr; - let Some((operand_indices, elements)) = - Self::get_const_sequence(metadata, block, i, tuple_size) - else { - return false; - }; - - let (const_idx, _) = metadata - .consts - .insert_full(ConstantData::Tuple { elements }); - - for &j in &operand_indices { - nop_out_no_location(&mut block.instructions[j]); - } - - block.instructions[i].instr = Opcode::LoadConst.into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - block.instructions[i].folded_from_nonliteral_expr = folded_from_nonliteral_expr; - true - } - - fn fold_constant_intrinsic_list_to_tuple_at( - metadata: &mut CodeUnitMetadata, - block: &mut Block, - i: usize, - ) -> bool { - let Some(Instruction::CallIntrinsic1 { func }) = block.instructions[i].instr.real() else { - return false; - }; - if func.get(block.instructions[i].arg) != IntrinsicFunction1::ListToTuple { - return false; - } - - let mut consts_found = 0usize; - let mut expect_append = true; - let mut pos = i; - while let Some(prev) = pos.checked_sub(1) { - pos = prev; - let instr = &block.instructions[pos]; - if matches!(instr.instr.real(), Some(Instruction::Nop)) { - continue; - } - - if matches!(instr.instr.real(), Some(Instruction::BuildList { .. 
})) - && u32::from(instr.arg) == 0 - { - if !expect_append { - return false; - } - - let mut elements = Vec::with_capacity(consts_found); - let mut folded_from_nonliteral_expr = false; - let mut expect_load = true; - for idx in pos + 1..i { - let instr = &block.instructions[idx]; - if matches!(instr.instr.real(), Some(Instruction::Nop)) { - continue; - } - if expect_load { - let Some(value) = Self::get_const_value_from(metadata, instr) else { - return false; - }; - folded_from_nonliteral_expr |= instr.folded_from_nonliteral_expr; - elements.push(value); - } else if !matches!(instr.instr.real(), Some(Instruction::ListAppend { .. })) - || u32::from(instr.arg) != 1 - { - return false; - } - expect_load = !expect_load; - } - if !expect_load || elements.len() != consts_found { - return false; - } - - let (const_idx, _) = metadata - .consts - .insert_full(ConstantData::Tuple { elements }); - for idx in pos..i { - nop_out_no_location(&mut block.instructions[idx]); - } - block.instructions[i].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - block.instructions[i].folded_from_nonliteral_expr = folded_from_nonliteral_expr; - return true; - } - - if expect_append { - if !matches!(instr.instr.real(), Some(Instruction::ListAppend { .. })) - || u32::from(instr.arg) != 1 - { - return false; - } - } else { - if Self::get_const_value_from_dummy(instr).is_none() { - return false; - } - consts_found += 1; - } - expect_append = !expect_append; - } - - false - } - - fn fold_list_constant_at(metadata: &mut CodeUnitMetadata, block: &mut Block, i: usize) -> bool { - let Some(Instruction::BuildList { .. 
}) = block.instructions[i].instr.real() else { - return false; - }; - - let list_size = u32::from(block.instructions[i].arg) as usize; - if list_size == 0 || list_size > STACK_USE_GUIDELINE { - return false; - } - - let Some((operand_indices, elements)) = - Self::get_const_sequence(metadata, block, i, list_size) - else { - return false; - }; - if list_size < MIN_CONST_SEQUENCE_SIZE { - return false; - } - - let (const_idx, _) = metadata - .consts - .insert_full(ConstantData::Tuple { elements }); - - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; - let eh = block.instructions[i].except_handler; - - let build_idx = operand_indices[0]; - let const_idx_slot = operand_indices[1]; - - block.instructions[build_idx].instr = Instruction::BuildList { - count: Arg::marker(), - } - .into(); - block.instructions[build_idx].arg = OpArg::new(0); - block.instructions[build_idx].location = folded_loc; - block.instructions[build_idx].end_location = end_loc; - block.instructions[build_idx].except_handler = eh; - - block.instructions[const_idx_slot].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[const_idx_slot].arg = OpArg::new(const_idx as u32); - block.instructions[const_idx_slot].location = folded_loc; - block.instructions[const_idx_slot].end_location = end_loc; - block.instructions[const_idx_slot].except_handler = eh; - - for &j in &operand_indices[2..] { - set_to_nop(&mut block.instructions[j]); - block.instructions[j].location = folded_loc; - } - - block.instructions[i].instr = Opcode::ListExtend.into(); - block.instructions[i].arg = OpArg::new(1); - true - } - - /// Constant folding: fold LOAD_CONST/LOAD_SMALL_INT + BUILD_TUPLE into LOAD_CONST tuple - /// fold_tuple_of_constants. This also folds constant list/set literals - /// in block order to match CPython's optimize_basic_block() const-table order. 
- fn fold_tuple_constants(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i < block.instructions.len() { - if Self::fold_tuple_constant_at(&mut self.metadata, block, i) - || Self::fold_list_constant_at(&mut self.metadata, block, i) - || Self::fold_set_constant_at(&mut self.metadata, block, i) - { - i += 1; - continue; - } - i += 1; - } - } - } - - /// Fold constant list literals: LOAD_CONST* + BUILD_LIST N → - /// BUILD_LIST 0 + LOAD_CONST (tuple) + LIST_EXTEND 1 - fn fold_list_constants(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i < block.instructions.len() { - let instr = &block.instructions[i]; - let Some(Instruction::BuildList { .. }) = instr.instr.real() else { - i += 1; - continue; - }; - - let list_size = u32::from(instr.arg) as usize; - if list_size == 0 || list_size > STACK_USE_GUIDELINE { - i += 1; - continue; - } - - let Some((operand_indices, elements)) = - Self::get_const_sequence(&self.metadata, block, i, list_size) - else { - i += 1; - continue; - }; - if list_size < MIN_CONST_SEQUENCE_SIZE { - i += 1; - continue; - } - - let tuple_const = ConstantData::Tuple { elements }; - let (const_idx, _) = self.metadata.consts.insert_full(tuple_const); - - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; - let eh = block.instructions[i].except_handler; - - let build_idx = operand_indices[0]; - let const_idx_slot = operand_indices[1]; - - block.instructions[build_idx].instr = Instruction::BuildList { - count: Arg::marker(), - } - .into(); - block.instructions[build_idx].arg = OpArg::new(0); - block.instructions[build_idx].location = folded_loc; - block.instructions[build_idx].end_location = end_loc; - block.instructions[build_idx].except_handler = eh; - - block.instructions[const_idx_slot].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[const_idx_slot].arg = OpArg::new(const_idx as u32); - 
block.instructions[const_idx_slot].location = folded_loc; - block.instructions[const_idx_slot].end_location = end_loc; - block.instructions[const_idx_slot].except_handler = eh; - - // NOP the rest - for &j in &operand_indices[2..] { - set_to_nop(&mut block.instructions[j]); - block.instructions[j].location = folded_loc; - } - - // slot[i] (was BUILD_LIST) → LIST_EXTEND 1 - block.instructions[i].instr = Opcode::ListExtend.into(); - block.instructions[i].arg = OpArg::new(1); - - i += 1; - } - } - } - - /// Port of CPython's flowgraph.c optimize_lists_and_sets(). - /// - /// For GET_ITER / CONTAINS_OP users: - /// - Constant BUILD_LIST/BUILD_SET becomes LOAD_CONST tuple/frozenset. - /// - Non-constant BUILD_LIST becomes BUILD_TUPLE. - /// - Previously folded BUILD_LIST 0 + LOAD_CONST + LIST_EXTEND and - /// BUILD_SET 0 + LOAD_CONST + SET_UPDATE collapse back to LOAD_CONST. - fn optimize_lists_and_sets(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i + 1 < block.instructions.len() { - if matches!( - block.instructions[i].instr.real(), - Some(Instruction::CallIntrinsic1 { func }) - if func.get(block.instructions[i].arg) == IntrinsicFunction1::ListToTuple - ) { - if matches!( - block - .instructions - .get(i + 1) - .and_then(|instr| instr.instr.real()), - Some(Instruction::GetIter) - ) { - set_to_nop(&mut block.instructions[i]); - i += 2; - continue; - } - if Self::fold_constant_intrinsic_list_to_tuple_at(&mut self.metadata, block, i) - { - i += 1; - continue; - } - } - - if let Some(non_nop4) = Self::get_non_nop_instr_indices(block, i, 4) { - let is_build_list = non_nop4[0] == i - && matches!( - block.instructions[non_nop4[0]].instr.real(), - Some(Instruction::BuildList { .. }) - ) - && u32::from(block.instructions[non_nop4[0]].arg) == 0; - let is_const = matches!( - block.instructions[non_nop4[1]].instr.real(), - Some(Instruction::LoadConst { .. 
}) - ); - let is_list_extend = matches!( - block.instructions[non_nop4[2]].instr.real(), - Some(Instruction::ListExtend { .. }) - ) && u32::from(block.instructions[non_nop4[2]].arg) == 1; - let uses_iter_or_contains = matches!( - block.instructions[non_nop4[3]].instr.real(), - Some(Instruction::GetIter | Instruction::ContainsOp { .. }) - ); - - if is_build_list && is_const && is_list_extend && uses_iter_or_contains { - let loc = block.instructions[i].location; - set_to_nop(&mut block.instructions[i]); - block.instructions[i].location = loc; - set_to_nop(&mut block.instructions[non_nop4[2]]); - block.instructions[non_nop4[2]].location = loc; - i += 1; - continue; - } - - let is_build_set = non_nop4[0] == i - && matches!( - block.instructions[non_nop4[0]].instr.real(), - Some(Instruction::BuildSet { .. }) - ) - && u32::from(block.instructions[non_nop4[0]].arg) == 0; - let is_set_update = matches!( - block.instructions[non_nop4[2]].instr.real(), - Some(Instruction::SetUpdate { .. }) - ) && u32::from(block.instructions[non_nop4[2]].arg) == 1; - - if is_build_set && is_const && is_set_update && uses_iter_or_contains { - let loc = block.instructions[i].location; - set_to_nop(&mut block.instructions[i]); - block.instructions[i].location = loc; - set_to_nop(&mut block.instructions[non_nop4[2]]); - block.instructions[non_nop4[2]].location = loc; - i += 1; - continue; - } - } - - let Some(non_nop2) = Self::get_non_nop_instr_indices(block, i, 2) else { - i += 1; - continue; - }; - let uses_iter_or_contains = non_nop2[0] == i - && matches!( - block.instructions[non_nop2[1]].instr.real(), - Some(Instruction::GetIter | Instruction::ContainsOp { .. }) - ); - if !uses_iter_or_contains { - i += 1; - continue; - } - - if matches!( - block.instructions[i].instr.real(), - Some(Instruction::BuildList { .. 
}) - ) { - let seq_size = u32::from(block.instructions[i].arg) as usize; - if seq_size > STACK_USE_GUIDELINE { - i += 2; - continue; - } - if let Some((operand_indices, elements)) = - Self::get_const_sequence(&self.metadata, block, i, seq_size) - { - let const_data = ConstantData::Tuple { elements }; - let (const_idx, _) = self.metadata.consts.insert_full(const_data); - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; - let eh = block.instructions[i].except_handler; - - for &j in &operand_indices { - set_to_nop(&mut block.instructions[j]); - block.instructions[j].location = folded_loc; - block.instructions[j].end_location = end_loc; - } - - block.instructions[i].instr = Opcode::LoadConst.into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - block.instructions[i].location = folded_loc; - block.instructions[i].end_location = end_loc; - block.instructions[i].except_handler = eh; - i += 2; - continue; - } - - block.instructions[i].instr = Opcode::BuildTuple.into(); - i += 2; - } else if matches!( - block.instructions[i].instr.real(), - Some(Instruction::BuildSet { .. 
}) - ) { - let seq_size = u32::from(block.instructions[i].arg) as usize; - if seq_size > STACK_USE_GUIDELINE { - i += 2; - continue; - } - let Some((operand_indices, elements)) = - Self::get_const_sequence(&self.metadata, block, i, seq_size) - else { - i += 2; - continue; - }; - let const_data = ConstantData::Frozenset { elements }; - let (const_idx, _) = self.metadata.consts.insert_full(const_data); - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; - let eh = block.instructions[i].except_handler; - - for &j in &operand_indices { - set_to_nop(&mut block.instructions[j]); - block.instructions[j].location = folded_loc; - block.instructions[j].end_location = end_loc; - } - - block.instructions[i].instr = Opcode::LoadConst.into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); - block.instructions[i].location = folded_loc; - block.instructions[i].end_location = end_loc; - block.instructions[i].except_handler = eh; - i += 2; - } else { - i += 1; - } - } - } - } - - fn fold_set_constant_at(metadata: &mut CodeUnitMetadata, block: &mut Block, i: usize) -> bool { - let Some(Instruction::BuildSet { .. 
}) = block.instructions[i].instr.real() else { - return false; - }; - - let set_size = u32::from(block.instructions[i].arg) as usize; - if !(3..=STACK_USE_GUIDELINE).contains(&set_size) { - return false; - } - - let Some((operand_indices, elements)) = - Self::get_const_sequence(metadata, block, i, set_size) - else { - return false; - }; - let (const_idx, _) = metadata - .consts - .insert_full(ConstantData::Frozenset { elements }); - - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; - let eh = block.instructions[i].except_handler; - - let build_idx = operand_indices[0]; - let const_idx_slot = operand_indices[1]; - - block.instructions[build_idx].instr = Instruction::BuildSet { - count: Arg::marker(), - } - .into(); - block.instructions[build_idx].arg = OpArg::new(0); - block.instructions[build_idx].location = folded_loc; - block.instructions[build_idx].end_location = end_loc; - block.instructions[build_idx].except_handler = eh; - - block.instructions[const_idx_slot].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[const_idx_slot].arg = OpArg::new(const_idx as u32); - block.instructions[const_idx_slot].location = folded_loc; - block.instructions[const_idx_slot].end_location = end_loc; - block.instructions[const_idx_slot].except_handler = eh; - - for &j in &operand_indices[2..] { - set_to_nop(&mut block.instructions[j]); - block.instructions[j].location = folded_loc; - } - - block.instructions[i].instr = Opcode::SetUpdate.into(); - block.instructions[i].arg = OpArg::new(1); - true - } - - /// Fold constant set literals: LOAD_CONST* + BUILD_SET N → - /// BUILD_SET 0 + LOAD_CONST (frozenset-as-tuple) + SET_UPDATE 1 - fn fold_set_constants(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i < block.instructions.len() { - let instr = &block.instructions[i]; - let Some(Instruction::BuildSet { .. 
}) = instr.instr.real() else { - i += 1; - continue; - }; - - let set_size = u32::from(instr.arg) as usize; - if !(3..=STACK_USE_GUIDELINE).contains(&set_size) { - i += 1; - continue; - } - - let Some((operand_indices, elements)) = - Self::get_const_sequence(&self.metadata, block, i, set_size) - else { - i += 1; - continue; - }; - let const_data = ConstantData::Frozenset { elements }; - let (const_idx, _) = self.metadata.consts.insert_full(const_data); - - let folded_loc = block.instructions[i].location; - let end_loc = block.instructions[i].end_location; - let eh = block.instructions[i].except_handler; - - let build_idx = operand_indices[0]; - let const_idx_slot = operand_indices[1]; - - block.instructions[build_idx].instr = Instruction::BuildSet { - count: Arg::marker(), - } - .into(); - block.instructions[build_idx].arg = OpArg::new(0); - block.instructions[build_idx].location = folded_loc; - block.instructions[build_idx].end_location = end_loc; - block.instructions[build_idx].except_handler = eh; - - block.instructions[const_idx_slot].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[const_idx_slot].arg = OpArg::new(const_idx as u32); - block.instructions[const_idx_slot].location = folded_loc; - block.instructions[const_idx_slot].end_location = end_loc; - block.instructions[const_idx_slot].except_handler = eh; - - for &j in &operand_indices[2..] { - set_to_nop(&mut block.instructions[j]); - block.instructions[j].location = folded_loc; - } - - block.instructions[i].instr = Opcode::SetUpdate.into(); - block.instructions[i].arg = OpArg::new(1); - - i += 1; - } - } - } - - /// BUILD_TUPLE n + UNPACK_SEQUENCE n optimization. 
- /// - /// Ported from CPython flowgraph.c optimize_basic_block: - /// - n == 1: both become NOP (identity operation) - /// - n == 2 or 3: BUILD_TUPLE → NOP, UNPACK_SEQUENCE → SWAP - fn optimize_build_tuple_unpack(&mut self) { - for block in &mut self.blocks { - let instructions = &mut block.instructions; - let len = instructions.len(); - for i in 0..len.saturating_sub(1) { - let Some(Instruction::BuildTuple { .. }) = instructions[i].instr.real() else { - continue; - }; - let n = u32::from(instructions[i].arg); - let Some(Instruction::UnpackSequence { .. }) = instructions[i + 1].instr.real() - else { - continue; - }; - if u32::from(instructions[i + 1].arg) != n { - continue; - } - match n { - 1 => { - instructions[i].instr = Opcode::Nop.into(); - instructions[i].arg = OpArg::new(0); - instructions[i + 1].instr = Opcode::Nop.into(); - instructions[i + 1].arg = OpArg::new(0); - } - 2 | 3 => { - instructions[i].instr = Opcode::Nop.into(); - instructions[i].arg = OpArg::new(0); - instructions[i + 1].instr = Opcode::Swap.into(); - instructions[i + 1].arg = OpArg::new(n); - } - _ => {} - } - } - } - } - - /// apply_static_swaps: eliminate SWAPs by reordering target stores/pops. - /// - /// Ported from CPython Python/flowgraph.c::apply_static_swaps. - /// For each SWAP N, find the 1st and N-th swappable instructions after - /// it. If both are STORE_FAST/POP_TOP and safe to swap, exchange them - /// in the bytecode and replace SWAP with NOP. - /// - /// Safety: abort if the two stores write the same variable, or if any - /// intervening swappable stores to one of the same variables. Do not - /// cross line-number boundaries (user-visible name bindings). - fn apply_static_swaps(&mut self) { - const VISITED: i32 = -1; - - /// Instruction classes that are safe to reorder around SWAP. 
- fn is_swappable(instr: &AnyInstruction) -> bool { - matches!( - (*instr).into(), - AnyOpcode::Real(Opcode::StoreFast | Opcode::PopTop) - | AnyOpcode::Pseudo(PseudoOpcode::StoreFastMaybeNull) - ) - } - - /// Variable index that a STORE_FAST writes to, or None. - fn stores_to(info: &InstructionInfo) -> Option { - match info.instr.into() { - AnyOpcode::Real(Opcode::StoreFast) => Some(u32::from(info.arg)), - AnyOpcode::Pseudo(PseudoOpcode::StoreFastMaybeNull) => Some(u32::from(info.arg)), - _ => None, - } - } - - /// Next swappable index after `i` in `instructions`, skipping NOPs. - /// Returns None if a non-NOP non-swappable instruction blocks, or - /// if `lineno >= 0` and a different lineno is encountered. - fn next_swappable( - instructions: &[InstructionInfo], - mut i: usize, - lineno: i32, - ) -> Option { - loop { - i += 1; - if i >= instructions.len() { - return None; - } - let info = &instructions[i]; - let info_lineno = info.location.line.get() as i32; - if lineno >= 0 && info_lineno > 0 && info_lineno != lineno { - return None; - } - if matches!(info.instr, AnyInstruction::Real(Instruction::Nop)) { - continue; - } - if is_swappable(&info.instr) { - return Some(i); - } - return None; - } - } - - fn optimize_swap_block(instructions: &mut [InstructionInfo]) { - let mut i = 0usize; - while i < instructions.len() { - let AnyInstruction::Real(Instruction::Swap { .. }) = instructions[i].instr else { - i += 1; - continue; - }; - - let mut len = 0usize; - let mut depth = 0usize; - let mut more = false; - while i + len < instructions.len() { - let info = &instructions[i + len]; - match info.instr.real() { - Some(Instruction::Swap { .. 
}) => { - let oparg = u32::from(info.arg) as usize; - depth = depth.max(oparg); - more |= len > 0; - len += 1; - } - Some(Instruction::Nop) => { - len += 1; - } - _ => break, - } - } - - if !more { - i += len.max(1); - continue; - } - - let mut stack: Vec = (0..depth as i32).collect(); - for info in &instructions[i..i + len] { - if matches!(info.instr.real(), Some(Instruction::Swap { .. })) { - let oparg = u32::from(info.arg) as usize; - stack.swap(0, oparg - 1); - } - } - - let mut current = len as isize - 1; - for slot in 0..depth { - if stack[slot] == VISITED || stack[slot] == slot as i32 { - continue; - } - let mut j = slot; - loop { - if j != 0 { - let out = &mut instructions[i + current as usize]; - out.instr = Opcode::Swap.into(); - out.arg = OpArg::new((j + 1) as u32); - out.target = BlockIdx::NULL; - current -= 1; - } - if stack[j] == VISITED { - debug_assert_eq!(j, slot); - break; - } - let next_j = stack[j] as usize; - stack[j] = VISITED; - j = next_j; - } - } - while current >= 0 { - set_to_nop(&mut instructions[i + current as usize]); - current -= 1; - } - i += len; - } - } - - fn apply_from(instructions: &mut [InstructionInfo], mut i: isize) { - while i >= 0 { - let idx = i as usize; - let swap_arg = match instructions[idx].instr.real() { - Some(Instruction::Swap { .. }) => u32::from(instructions[idx].arg), - Some( - Instruction::Nop | Instruction::PopTop | Instruction::StoreFast { .. }, - ) => { - i -= 1; - continue; - } - _ if matches!( - instructions[idx].instr.pseudo(), - Some(PseudoInstruction::StoreFastMaybeNull { .. 
}) - ) => - { - i -= 1; - continue; - } - _ => return, - }; - - if swap_arg < 2 { - return; - } - - let Some(j) = next_swappable(instructions, idx, -1) else { - return; - }; - let lineno = instructions[j].location.line.get() as i32; - let mut k = j; - for _ in 1..swap_arg { - let Some(next) = next_swappable(instructions, k, lineno) else { - return; - }; - k = next; - } - - let store_j = stores_to(&instructions[j]); - let store_k = stores_to(&instructions[k]); - if store_j.is_some() || store_k.is_some() { - if store_j == store_k { - return; - } - let conflict = instructions[(j + 1)..k].iter().any(|info| { - if let Some(store_idx) = stores_to(info) { - Some(store_idx) == store_j || Some(store_idx) == store_k - } else { - false - } - }); - if conflict { - return; - } - } - - instructions[idx].instr = Opcode::Nop.into(); - instructions[idx].arg = OpArg::new(0); - instructions.swap(j, k); - i -= 1; - } - } - - for block in &mut self.blocks { - optimize_swap_block(&mut block.instructions); - let len = block.instructions.len(); - for i in 0..len { - if matches!( - block.instructions[i].instr.real(), - Some(Instruction::Swap { .. }) - ) { - apply_from(&mut block.instructions, i as isize); - } - } - } - } - - /// Eliminate dead stores in STORE_FAST sequences (apply_static_swaps). - /// - /// In sequences of consecutive STORE_FAST instructions (from tuple unpacking), - /// only collapse directly adjacent duplicate targets. - /// - /// CPython preserves non-adjacent duplicates such as `_, expr, _` so the - /// store layout still reflects the original unpack order. Replacing the - /// first `_` with POP_TOP there changes the emitted superinstructions and - /// bytecode shape even though the final value is the same. 
- fn eliminate_dead_stores(&mut self) { - for block in &mut self.blocks { - let instructions = &mut block.instructions; - let len = instructions.len(); - let mut i = 0; - while i < len { - // Look for UNPACK_SEQUENCE or UNPACK_EX - let is_unpack = matches!( - instructions[i].instr.into(), - AnyOpcode::Real(Opcode::UnpackSequence | Opcode::UnpackEx) - ); - if !is_unpack { - i += 1; - continue; - } - // Scan the run of STORE_FAST right after the unpack - let run_start = i + 1; - let mut run_end = run_start; - while run_end < len - && matches!( - instructions[run_end].instr.into(), - AnyOpcode::Real(Opcode::StoreFast) - ) - { - run_end += 1; - } - if run_end - run_start >= 2 { - let mut j = run_start; - while j < run_end { - let arg = u32::from(instructions[j].arg); - let mut group_end = j + 1; - while group_end < run_end && u32::from(instructions[group_end].arg) == arg { - group_end += 1; - } - for instr in &mut instructions[j..group_end.saturating_sub(1)] { - instr.instr = Opcode::PopTop.into(); - instr.arg = OpArg::new(0); - } - j = group_end; - } - } - i = run_end.max(i + 1); - } - - // General same-line duplicate STORE_FAST elimination from - // flowgraph.c optimize_basic_block(). This is required for - // apply_static_swaps() patterns such as `a, a = x, y`. - for i in 0..instructions.len().saturating_sub(1) { - let lhs = &instructions[i]; - let rhs = &instructions[i + 1]; - if !matches!(lhs.instr.real(), Some(Instruction::StoreFast { .. })) - || !matches!(rhs.instr.real(), Some(Instruction::StoreFast { .. 
})) - || u32::from(lhs.arg) != u32::from(rhs.arg) - || instruction_lineno(lhs) != instruction_lineno(rhs) - { - continue; - } - instructions[i].instr = Instruction::PopTop.into(); - instructions[i].arg = OpArg::NULL; - instructions[i].target = BlockIdx::NULL; - } - } - } - - /// Peephole optimization: combine consecutive instructions into super-instructions - fn peephole_optimize(&mut self) { - let const_truthiness = - |instr: Instruction, arg: OpArg, metadata: &CodeUnitMetadata| match instr { - Instruction::LoadConst { consti } => { - let constant = &metadata.consts[consti.get(arg).as_usize()]; - Some(match constant { - ConstantData::Tuple { elements } => !elements.is_empty(), - ConstantData::Integer { value } => !value.is_zero(), - ConstantData::Float { value } => *value != 0.0, - ConstantData::Complex { value } => value.re != 0.0 || value.im != 0.0, - ConstantData::Boolean { value } => *value, - ConstantData::Str { value } => !value.is_empty(), - ConstantData::Bytes { value } => !value.is_empty(), - ConstantData::Code { .. } => true, - ConstantData::Slice { .. 
} => true, - ConstantData::Frozenset { elements } => !elements.is_empty(), - ConstantData::None => false, - ConstantData::Ellipsis => true, - }) - } - Instruction::LoadSmallInt { i } => Some(i.get(arg) != 0), - _ => None, - }; - for (block_idx, block) in self.blocks.iter_mut().enumerate() { - let mut i = 0; - while i + 1 < block.instructions.len() { - let curr = &block.instructions[i]; - let next = &block.instructions[i + 1]; - let curr_arg = curr.arg; - let next_arg = next.arg; - - // Only combine if both are real instructions (not pseudo) - let (Some(curr_instr), Some(next_instr)) = (curr.instr.real(), next.instr.real()) - else { - i += 1; - continue; - }; - - if matches!(curr_instr, Instruction::ToBool) - && matches!(next_instr, Instruction::UnaryNot) - && let Some(Instruction::ToBool) = block - .instructions - .get(i + 2) - .and_then(|info| info.instr.real()) - && let Some(Instruction::UnaryNot) = block - .instructions - .get(i + 3) - .and_then(|info| info.instr.real()) - { - block.instructions.drain(i + 1..=i + 3); - continue; - } - - if matches!(curr_instr, Instruction::UnaryNot | Instruction::ToBool) - && matches!(next_instr, Instruction::ToBool) - { - block.instructions.remove(i + 1); - continue; - } - - if matches!( - curr_instr, - Instruction::ContainsOp { .. } | Instruction::IsOp { .. } - ) && matches!(next_instr, Instruction::UnaryNot) - { - set_to_nop(&mut block.instructions[i]); - block.instructions[i + 1].instr = curr_instr.into(); - block.instructions[i + 1].arg = OpArg::new(u32::from(curr_arg) ^ 1); - i += 1; - continue; - } - - if let Some(is_true) = const_truthiness(curr_instr, curr.arg, &self.metadata) { - let jump_if_true = match next_instr { - Instruction::PopJumpIfTrue { .. } => Some(true), - Instruction::PopJumpIfFalse { .. 
} => Some(false), - _ => None, - }; - if let Some(jump_if_true) = jump_if_true { - let target = match next_instr { - Instruction::PopJumpIfTrue { delta } - | Instruction::PopJumpIfFalse { delta } => delta.get(next.arg), - _ => unreachable!(), - }; - set_to_nop(&mut block.instructions[i]); - let preserves_pure_self_loop_anchor = i == 0 - && block.instructions[i + 1..].iter().all(|info| { - if info.target == BlockIdx(block_idx as u32) - && info.instr.is_unconditional_jump() - { - return true; - } - matches!(info.instr.real(), Some(Instruction::Nop)) - }) - && block.instructions[i + 1..].iter().any(|info| { - info.target == BlockIdx(block_idx as u32) - && info.instr.is_unconditional_jump() - }); - if preserves_pure_self_loop_anchor { - block.instructions[i].preserve_block_start_no_location_nop = true; - } - if is_true == jump_if_true { - block.instructions[i + 1].instr = PseudoInstruction::Jump { - delta: Arg::marker(), - } - .into(); - block.instructions[i + 1].arg = OpArg::new(u32::from(target)); - } else { - set_to_nop(&mut block.instructions[i + 1]); - } - i += 1; - continue; - } - } - - if let Instruction::LoadConst { consti } = curr_instr { - let constant = &self.metadata.consts[consti.get(curr_arg).as_usize()]; - if matches!(constant, ConstantData::None) - && let Instruction::IsOp { invert } = next_instr - { - let mut jump_idx = i + 2; - if jump_idx >= block.instructions.len() { - i += 1; - continue; - } - - if matches!( - block.instructions[jump_idx].instr.real(), - Some(Instruction::ToBool) - ) { - set_to_nop(&mut block.instructions[jump_idx]); - jump_idx += 1; - if jump_idx >= block.instructions.len() { - i += 1; - continue; - } - } - - let Some(jump_instr) = block.instructions[jump_idx].instr.real() else { - i += 1; - continue; - }; - - let mut invert = matches!( - invert.get(next_arg), - rustpython_compiler_core::bytecode::Invert::Yes - ); - let delta = match jump_instr { - Instruction::PopJumpIfFalse { delta } => { - invert = !invert; - 
delta.get(block.instructions[jump_idx].arg) - } - Instruction::PopJumpIfTrue { delta } => { - delta.get(block.instructions[jump_idx].arg) - } - _ => { - i += 1; - continue; - } - }; - - set_to_nop(&mut block.instructions[i]); - set_to_nop(&mut block.instructions[i + 1]); - block.instructions[jump_idx].instr = if invert { - Instruction::PopJumpIfNotNone { - delta: Arg::marker(), - } - } else { - Instruction::PopJumpIfNone { - delta: Arg::marker(), - } - } - .into(); - block.instructions[jump_idx].arg = OpArg::new(u32::from(delta)); - i = jump_idx; - continue; - } - } - - if matches!( - curr_instr, - Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. } - ) && matches!(next_instr, Instruction::PopTop) - { - set_to_nop(&mut block.instructions[i]); - set_to_nop(&mut block.instructions[i + 1]); - i += 1; - continue; - } - - if matches!(curr_instr, Instruction::Copy { i } if i.get(curr.arg) == 1) - && matches!(next_instr, Instruction::PopTop) - { - set_to_nop(&mut block.instructions[i]); - set_to_nop(&mut block.instructions[i + 1]); - i += 1; - continue; - } - - let combined = { - match (curr_instr, next_instr) { - // Note: StoreFast + LoadFast → StoreFastLoadFast is done in a - // later pass aligned with CPython insert_superinstructions(). - ( - Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. }, - Instruction::ToBool, - ) => { - if let Some(value) = - const_truthiness(curr_instr, curr.arg, &self.metadata) - { - let (const_idx, _) = self - .metadata - .consts - .insert_full(ConstantData::Boolean { value }); - Some(( - Instruction::LoadConst { - consti: Arg::marker(), - }, - OpArg::new(const_idx as u32), - )) - } else { - None - } - } - (Instruction::CompareOp { .. }, Instruction::ToBool) => Some(( - curr_instr, - OpArg::new(u32::from(curr.arg) | oparg::COMPARE_OP_BOOL_MASK), - )), - ( - Instruction::ContainsOp { .. } | Instruction::IsOp { .. 
}, - Instruction::ToBool, - ) => Some((curr_instr, curr.arg)), - (Instruction::LoadConst { consti }, Instruction::UnaryNot) => { - let constant = &self.metadata.consts[consti.get(curr.arg).as_usize()]; - match constant { - ConstantData::Boolean { value } => { - let (const_idx, _) = self - .metadata - .consts - .insert_full(ConstantData::Boolean { value: !value }); - Some(((Opcode::LoadConst.into()), OpArg::new(const_idx as u32))) - } - _ => None, - } - } - _ => None, - } - }; - - if let Some((new_instr, new_arg)) = combined { - // Combine: keep first instruction's location, replace with combined instruction - block.instructions[i].instr = new_instr.into(); - block.instructions[i].arg = new_arg; - // Remove the second instruction - block.instructions.remove(i + 1); - // Don't increment i - check if we can combine again with the next instruction - } else { - i += 1; - } - } - } - } - - /// LOAD_GLOBAL + PUSH_NULL -> LOAD_GLOBAL , NOP - fn optimize_load_global_push_null(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i + 1 < block.instructions.len() { - let curr = &block.instructions[i]; - let next = &block.instructions[i + 1]; - - let (Some(Instruction::LoadGlobal { .. 
}), Some(Instruction::PushNull)) = - (curr.instr.real(), next.instr.real()) - else { - i += 1; - continue; - }; - - let oparg = u32::from(block.instructions[i].arg); - if (oparg & 1) != 0 { - i += 1; - continue; - } - - block.instructions[i].arg = OpArg::new(oparg | 1); - block.instructions.remove(i + 1); - } - } - } - - fn remove_redundant_const_pop_top_pairs(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i + 1 < block.instructions.len() { - let curr = &block.instructions[i]; - let next = &block.instructions[i + 1]; - let Some(curr_instr) = curr.instr.real() else { - i += 1; - continue; - }; - let Some(next_instr) = next.instr.real() else { - i += 1; - continue; - }; - - let redundant = matches!( - (curr_instr, next_instr), - ( - Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. }, - Instruction::PopTop - ) - ) || matches!(curr_instr, Instruction::Copy { i } if i.get(curr.arg) == 1) - && matches!(next_instr, Instruction::PopTop); - - if redundant { - set_to_nop(&mut block.instructions[i]); - set_to_nop(&mut block.instructions[i + 1]); - i += 2; - } else { - i += 1; - } - } - } - } - - /// Convert LOAD_CONST for small integers to LOAD_SMALL_INT - /// maybe_instr_make_load_smallint - fn convert_to_load_small_int(&mut self) { - for block in &mut self.blocks { - for instr in &mut block.instructions { - // Check if it's a LOAD_CONST instruction - let Some(Instruction::LoadConst { .. 
}) = instr.instr.real() else { - continue; - }; - - // Get the constant value - let const_idx = u32::from(instr.arg) as usize; - let Some(constant) = self.metadata.consts.get_index(const_idx) else { - continue; - }; - - // Check if it's a small integer - let ConstantData::Integer { value } = constant else { - continue; - }; - - // LOAD_SMALL_INT oparg is unsigned, so only 0..=255 can be encoded - if let Some(small) = value.to_i32().filter(|v| (0..=255).contains(v)) { - // Convert LOAD_CONST to LOAD_SMALL_INT - instr.instr = Opcode::LoadSmallInt.into(); - // The arg is the i32 value stored as u32 (two's complement) - instr.arg = OpArg::new(small as u32); - } - } - } - } - - /// Remove constants that are no longer referenced by LOAD_CONST instructions. - /// remove_unused_consts - fn remove_unused_consts(&mut self) { - let nconsts = self.metadata.consts.len(); - if nconsts == 0 { - return; - } - - // Mark used constants - // The first constant (index 0) is always kept (may be docstring) - let mut used = vec![false; nconsts]; - used[0] = true; - - for block in &self.blocks { - for instr in &block.instructions { - if let Some(Instruction::LoadConst { .. 
}) = instr.instr.real() { - let idx = u32::from(instr.arg) as usize; - if idx < nconsts { - used[idx] = true; - } - } - } - } - - // Check if any constants can be removed - let n_used: usize = used.iter().filter(|&&u| u).count(); - if n_used == nconsts { - return; // Nothing to remove - } - - // Build old_to_new index mapping - let mut old_to_new = vec![0usize; nconsts]; - let mut new_idx = 0usize; - for (old_idx, &is_used) in used.iter().enumerate() { - if is_used { - old_to_new[old_idx] = new_idx; - new_idx += 1; - } - } - - // Build new consts list - let old_consts: Vec<_> = self.metadata.consts.iter().cloned().collect(); - self.metadata.consts.clear(); - for (old_idx, constant) in old_consts.into_iter().enumerate() { - if used[old_idx] { - self.metadata.consts.insert(constant); - } - } - - // Update LOAD_CONST instruction arguments - for block in &mut self.blocks { - for instr in &mut block.instructions { - if let Some(Instruction::LoadConst { .. }) = instr.instr.real() { - let old_idx = u32::from(instr.arg) as usize; - if old_idx < nconsts { - instr.arg = OpArg::new(old_to_new[old_idx] as u32); - } - } - } - } - } - - /// Remove NOP instructions from all blocks, but keep NOPs that introduce - /// a new source line (they serve as line markers for monitoring LINE events). 
- fn remove_nops(&mut self) { - let layout_predecessors = compute_layout_predecessors(&self.blocks); - let keep_target_start_nops: Vec<_> = (0..self.blocks.len()) - .map(|idx| { - keep_target_start_no_location_nop( - &self.blocks, - BlockIdx(idx as u32), - &layout_predecessors, - ) - }) - .collect(); - let mut conditional_targets = vec![false; self.blocks.len()]; - for block in &self.blocks { - for instr in &block.instructions { - if instr.target != BlockIdx::NULL && is_conditional_jump(&instr.instr) { - let target = next_nonempty_block(&self.blocks, instr.target); - if target != BlockIdx::NULL { - conditional_targets[target.idx()] = true; - } - } - } - } - let preserve_loop_exit_pop_block_nops: Vec<_> = (0..self.blocks.len()) - .map(|idx| { - let block_idx = BlockIdx(idx as u32); - let block = &self.blocks[idx]; - let layout_pred = layout_predecessors[idx]; - block.instructions.first().is_some_and(|instr| { - matches!(instr.instr.real(), Some(Instruction::Nop)) - && instr.remove_no_location_nop - && instruction_lineno(instr) < 0 - && conditional_targets[idx] - && layout_pred != BlockIdx::NULL - && self.blocks[layout_pred.idx()] - .instructions - .last() - .is_some_and(|last| { - matches!( - last.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) && next_nonempty_block(&self.blocks, last.target) != block_idx - }) - }) - }) - .collect(); - - for (block_idx, block) in self.blocks.iter_mut().enumerate() { - let mut prev_line = None; - let mut src = 0usize; - block.instructions.retain(|ins| { - let keep = 'keep: { - if matches!(ins.instr.real(), Some(Instruction::Nop)) { - let keep_loop_exit_pop_block = src == 0 - && preserve_loop_exit_pop_block_nops - .get(block_idx) - .copied() - .unwrap_or(false); - let keep_target_start = src == 0 - && keep_target_start_nops - .get(block_idx) - .copied() - .unwrap_or(false); - if ins.remove_no_location_nop - && instruction_lineno(ins) < 0 - && !keep_loop_exit_pop_block - && (!keep_target_start || ins.folded_operand_nop) - { - break 'keep false; - } - let line = ins.location.line.get() as i32; - if prev_line == Some(line) { - break 'keep false; - } - } - prev_line = Some(instruction_lineno(ins)); - true - }; - src += 1; - keep - }); - } - } - - /// insert_superinstructions (flowgraph.c): combine adjacent same-line - /// LOAD_FAST / STORE_FAST pairs before later flowgraph passes change - /// block layout. - fn insert_superinstructions(&mut self) { - for block in &mut self.blocks { - let mut i = 0; - while i + 1 < block.instructions.len() { - let curr = &block.instructions[i]; - let next = &block.instructions[i + 1]; - if instruction_lineno(curr) != instruction_lineno(next) { - i += 1; - continue; - } - - match (curr.instr.real(), next.instr.real()) { - (Some(Instruction::LoadFast { .. }), Some(Instruction::LoadFast { .. })) => { - let idx1 = u32::from(curr.arg); - let idx2 = u32::from(next.arg); - if idx1 >= 16 || idx2 >= 16 { - i += 1; - continue; - } - let packed = (idx1 << 4) | idx2; - block.instructions[i].instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(packed); - block.instructions.remove(i + 1); - } - (Some(Instruction::StoreFast { .. }), Some(Instruction::LoadFast { .. 
})) => { - let store_idx = u32::from(curr.arg); - let load_idx = u32::from(next.arg); - if store_idx >= 16 || load_idx >= 16 { - i += 1; - continue; - } - let packed = (store_idx << 4) | load_idx; - block.instructions[i].instr = Instruction::StoreFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(packed); - block.instructions.remove(i + 1); - } - (Some(Instruction::StoreFast { .. }), Some(Instruction::StoreFast { .. })) => { - let idx1 = u32::from(curr.arg); - let idx2 = u32::from(next.arg); - if idx1 >= 16 || idx2 >= 16 { - i += 1; - continue; - } - let packed = (idx1 << 4) | idx2; - block.instructions[i].instr = Instruction::StoreFastStoreFast { - var_nums: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(packed); - block.instructions.remove(i + 1); - } - _ => i += 1, - } - } - } - } - - fn optimize_load_fast_borrow(&mut self) { - // NOT_LOCAL marker: instruction didn't come from a LOAD_FAST - const NOT_LOCAL: usize = usize::MAX; - const DUMMY_INSTR: isize = -1; - const SUPPORT_KILLED: u8 = 1; - const STORED_AS_LOCAL: u8 = 2; - const REF_UNCONSUMED: u8 = 4; - - #[derive(Clone, Copy)] - struct AbstractRef { - instr: isize, - local: usize, - } - - fn push_ref(refs: &mut Vec, instr: isize, local: usize) { - refs.push(AbstractRef { instr, local }); - } - - fn pop_ref(refs: &mut Vec) -> Option { - refs.pop() - } - - fn at_ref(refs: &[AbstractRef], idx: usize) -> Option { - refs.get(idx).copied() - } - - fn swap_top(refs: &mut [AbstractRef], depth: usize) { - let top = refs.len() - 1; - let other = refs.len() - depth; - refs.swap(top, other); - } - - fn kill_local(instr_flags: &mut [u8], refs: &[AbstractRef], local: usize) { - for r in refs.iter().copied().filter(|r| r.local == local) { - debug_assert!(r.instr >= 0); - instr_flags[r.instr as usize] |= SUPPORT_KILLED; - } - } - - fn store_local(instr_flags: &mut [u8], refs: &[AbstractRef], local: usize, r: AbstractRef) { - kill_local(instr_flags, refs, 
local); - if r.instr != DUMMY_INSTR { - instr_flags[r.instr as usize] |= STORED_AS_LOCAL; - } - } - - fn decode_packed_fast_locals(arg: OpArg) -> (usize, usize) { - let packed = u32::from(arg); - (((packed >> 4) & 0xF) as usize, (packed & 0xF) as usize) - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some(Instruction::JumpBackwardNoInterrupt { .. }) - ) - }); - has_pop_except && jumps_to_target - } - - fn block_falls_through_after_store_fast_store_fast( - block: &Block, - target: BlockIdx, - ) -> bool { - block.next == target - && matches!( - block.instructions.last().and_then(|info| info.instr.real()), - Some(Instruction::StoreFastStoreFast { .. }) - ) - } - - fn block_ends_with_backward_jump(block: &Block) -> bool { - matches!( - block.instructions.last().and_then(|info| info.instr.real()), - Some( - Instruction::JumpBackward { .. } | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - } - - fn block_is_backward_jump_only(block: &Block) -> bool { - let mut real = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)); - matches!( - real.next(), - Some( - Instruction::JumpBackward { .. } | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) && real.next().is_none() - } - - fn block_is_resume_loop_latch(blocks: &[Block], target: BlockIdx) -> bool { - let block = &blocks[target.idx()]; - if block_ends_with_backward_jump(block) { - return true; - } - block.next != BlockIdx::NULL && block_is_backward_jump_only(&blocks[block.next.idx()]) - } - - fn push_block( - worklist: &mut Vec, - visited: &mut [bool], - blocks: &[Block], - source: BlockIdx, - target: BlockIdx, - start_depth: usize, - ) { - let expected = blocks[target.idx()].start_depth.map(|depth| depth as usize); - if expected != Some(start_depth) { - debug_assert!( - expected == Some(start_depth), - "optimize_load_fast_borrow start_depth mismatch: source={source:?} target={target:?} expected={expected:?} actual={:?} source_last={:?} target_instrs={:?}", - Some(start_depth), - blocks[source.idx()] - .instructions - .last() - .and_then(|info| info.instr.real()), - blocks[target.idx()] - .instructions - .iter() - .map(|info| info.instr) - .collect::>(), - ); - return; - } - if !visited[target.idx()] { - visited[target.idx()] = true; - worklist.push(target); - } - } - - let mut handler_resume_loop_latch = vec![false; self.blocks.len()]; - for block in &self.blocks { - let Some(target) = block.instructions.last().map(|info| info.target) else { - continue; - }; - if target != BlockIdx::NULL - && is_handler_resume_predecessor(block, target) - && block_is_resume_loop_latch(&self.blocks, target) - && self - .blocks - .iter() - .any(|pred| block_falls_through_after_store_fast_store_fast(pred, target)) - { - handler_resume_loop_latch[target.idx()] = true; - } - } - - let mut visited = vec![false; self.blocks.len()]; - let mut worklist = vec![BlockIdx(0)]; - visited[0] = true; - - while let Some(block_idx) = worklist.pop() { - let block = &self.blocks[block_idx]; - - let mut instr_flags = vec![0u8; block.instructions.len()]; - let start_depth = block.start_depth.unwrap_or(0) as usize; - let mut refs = Vec::with_capacity(block.instructions.len() + 
start_depth + 2); - for _ in 0..start_depth { - push_ref(&mut refs, DUMMY_INSTR, NOT_LOCAL); - } - - for (i, info) in block.instructions.iter().enumerate() { - let instr = info.instr; - let arg_u32 = u32::from(info.arg); - - match instr { - AnyInstruction::Real(Instruction::DeleteFast { var_num }) => { - kill_local(&mut instr_flags, &refs, usize::from(var_num.get(info.arg))); - } - AnyInstruction::Real(Instruction::LoadFast { var_num }) => { - push_ref(&mut refs, i as isize, usize::from(var_num.get(info.arg))); - } - AnyInstruction::Real(Instruction::LoadFastAndClear { var_num }) => { - let local = usize::from(var_num.get(info.arg)); - kill_local(&mut instr_flags, &refs, local); - push_ref(&mut refs, i as isize, local); - } - AnyInstruction::Real(Instruction::LoadFastLoadFast { .. }) => { - let (local1, local2) = decode_packed_fast_locals(info.arg); - push_ref(&mut refs, i as isize, local1); - push_ref(&mut refs, i as isize, local2); - } - AnyInstruction::Real(Instruction::StoreFast { var_num }) => { - let Some(r) = pop_ref(&mut refs) else { - continue; - }; - store_local( - &mut instr_flags, - &refs, - usize::from(var_num.get(info.arg)), - r, - ); - } - AnyInstruction::Pseudo(PseudoInstruction::StoreFastMaybeNull { var_num }) => { - let Some(r) = pop_ref(&mut refs) else { - continue; - }; - store_local(&mut instr_flags, &refs, var_num.get(info.arg) as usize, r); - } - AnyInstruction::Real(Instruction::StoreFastLoadFast { .. }) => { - let (store_local_idx, load_local_idx) = decode_packed_fast_locals(info.arg); - let Some(r) = pop_ref(&mut refs) else { - continue; - }; - store_local(&mut instr_flags, &refs, store_local_idx, r); - push_ref(&mut refs, i as isize, load_local_idx); - } - AnyInstruction::Real(Instruction::StoreFastStoreFast { .. 
}) => { - let (local1, local2) = decode_packed_fast_locals(info.arg); - let Some(r1) = pop_ref(&mut refs) else { - continue; - }; - store_local(&mut instr_flags, &refs, local1, r1); - let Some(r2) = pop_ref(&mut refs) else { - continue; - }; - store_local(&mut instr_flags, &refs, local2, r2); - } - AnyInstruction::Real(Instruction::Copy { i: _ }) => { - let depth = arg_u32 as usize; - if depth == 0 || refs.len() < depth { - continue; - } - let r = at_ref(&refs, refs.len() - depth).expect("copy index in bounds"); - push_ref(&mut refs, r.instr, r.local); - } - AnyInstruction::Real(Instruction::Swap { i: _ }) => { - let depth = arg_u32 as usize; - if depth < 2 || refs.len() < depth { - continue; - } - swap_top(&mut refs, depth); - } - AnyInstruction::Real( - Instruction::FormatSimple - | Instruction::GetAnext - | Instruction::GetLen - | Instruction::GetYieldFromIter - | Instruction::ImportFrom { .. } - | Instruction::MatchKeys - | Instruction::MatchMapping - | Instruction::MatchSequence - | Instruction::WithExceptStart, - ) => { - let effect = instr.stack_effect_info(arg_u32); - let net_pushed = effect.pushed() as isize - effect.popped() as isize; - debug_assert!(net_pushed >= 0); - for _ in 0..net_pushed { - push_ref(&mut refs, i as isize, NOT_LOCAL); - } - } - AnyInstruction::Real( - Instruction::DictMerge { .. } - | Instruction::DictUpdate { .. } - | Instruction::ListAppend { .. } - | Instruction::ListExtend { .. } - | Instruction::MapAdd { .. } - | Instruction::Reraise { .. } - | Instruction::SetAdd { .. } - | Instruction::SetUpdate { .. }, - ) => { - let effect = instr.stack_effect_info(arg_u32); - let net_popped = effect.popped() as isize - effect.pushed() as isize; - debug_assert!(net_popped > 0); - for _ in 0..net_popped { - let _ = pop_ref(&mut refs); - } - } - AnyInstruction::Real( - Instruction::EndSend | Instruction::SetFunctionAttribute { .. 
}, - ) => { - let Some(tos) = pop_ref(&mut refs) else { - continue; - }; - let _ = pop_ref(&mut refs); - push_ref(&mut refs, tos.instr, tos.local); - } - AnyInstruction::Real(Instruction::CheckExcMatch) => { - let _ = pop_ref(&mut refs); - push_ref(&mut refs, i as isize, NOT_LOCAL); - } - AnyInstruction::Real(Instruction::ForIter { .. }) => { - let target = info.target; - if target != BlockIdx::NULL { - push_block( - &mut worklist, - &mut visited, - &self.blocks, - block_idx, - target, - refs.len() + 1, - ); - } - push_ref(&mut refs, i as isize, NOT_LOCAL); - } - AnyInstruction::Real(Instruction::LoadAttr { .. }) => { - let Some(self_ref) = pop_ref(&mut refs) else { - continue; - }; - push_ref(&mut refs, i as isize, NOT_LOCAL); - if arg_u32 & 1 != 0 { - push_ref(&mut refs, self_ref.instr, self_ref.local); - } - } - AnyInstruction::Real(Instruction::LoadSuperAttr { .. }) => { - let _ = pop_ref(&mut refs); - let _ = pop_ref(&mut refs); - let Some(self_ref) = pop_ref(&mut refs) else { - continue; - }; - push_ref(&mut refs, i as isize, NOT_LOCAL); - if arg_u32 & 1 != 0 { - push_ref(&mut refs, self_ref.instr, self_ref.local); - } - } - AnyInstruction::Real( - Instruction::LoadSpecial { .. } | Instruction::PushExcInfo, - ) => { - let Some(tos) = pop_ref(&mut refs) else { - continue; - }; - push_ref(&mut refs, i as isize, NOT_LOCAL); - push_ref(&mut refs, tos.instr, tos.local); - } - AnyInstruction::Real(Instruction::Send { .. 
}) => { - let target = info.target; - if target != BlockIdx::NULL { - push_block( - &mut worklist, - &mut visited, - &self.blocks, - block_idx, - target, - refs.len(), - ); - } - let _ = pop_ref(&mut refs); - push_ref(&mut refs, i as isize, NOT_LOCAL); - } - _ => { - let effect = instr.stack_effect_info(arg_u32); - let num_popped = effect.popped() as usize; - let num_pushed = effect.pushed() as usize; - let target = info.target; - if target != BlockIdx::NULL { - let target_depth = refs - .len() - .saturating_sub(num_popped) - .saturating_add(num_pushed); - push_block( - &mut worklist, - &mut visited, - &self.blocks, - block_idx, - target, - target_depth, - ); - } - if !instr.is_block_push() { - for _ in 0..num_popped { - let _ = pop_ref(&mut refs); - } - for _ in 0..num_pushed { - push_ref(&mut refs, i as isize, NOT_LOCAL); - } - } - } - } - } - - let next = block.next; - if next != BlockIdx::NULL - && block.instructions.last().is_none_or(|term| { - !term.instr.is_unconditional_jump() && !term.instr.is_scope_exit() - }) - { - push_block( - &mut worklist, - &mut visited, - &self.blocks, - block_idx, - next, - refs.len(), - ); - } - - for r in refs { - if r.instr != DUMMY_INSTR { - instr_flags[r.instr as usize] |= REF_UNCONSUMED; - } - } - - let block = &mut self.blocks[block_idx]; - if block.disable_load_fast_borrow || handler_resume_loop_latch[block_idx.idx()] { - continue; - } - for (i, info) in block.instructions.iter_mut().enumerate() { - if instr_flags[i] != 0 { - continue; - } - match info.instr.real() { - Some(Instruction::LoadFast { .. }) => { - info.instr = Instruction::LoadFastBorrow { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastLoadFast { .. 
}) => { - info.instr = Instruction::LoadFastBorrowLoadFastBorrow { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - } - - fn compute_load_fast_start_depths(&mut self) { - fn stackdepth_push( - stack: &mut Vec, - start_depths: &mut [u32], - target: BlockIdx, - depth: u32, - ) { - let idx = target.idx(); - let block_depth = &mut start_depths[idx]; - debug_assert!( - *block_depth == u32::MAX || *block_depth == depth, - "Invalid CFG, inconsistent optimize_load_fast stackdepth for block {:?}: existing={}, new={}", - target, - *block_depth, - depth, - ); - if *block_depth == u32::MAX { - *block_depth = depth; - stack.push(target); - } - } - - let mut stack = Vec::with_capacity(self.blocks.len()); - let mut start_depths = vec![u32::MAX; self.blocks.len()]; - stackdepth_push(&mut stack, &mut start_depths, BlockIdx(0), 0); - - 'process_blocks: while let Some(block_idx) = stack.pop() { - let mut depth = start_depths[block_idx.idx()]; - let block = &self.blocks[block_idx]; - for ins in &block.instructions { - let instr = &ins.instr; - let effect = instr.stack_effect(ins.arg.into()); - let new_depth = depth.saturating_add_signed(effect); - if ins.target != BlockIdx::NULL { - let jump_effect = instr.stack_effect_jump(ins.arg.into()); - let target_depth = depth.saturating_add_signed(jump_effect); - stackdepth_push(&mut stack, &mut start_depths, ins.target, target_depth); - } - depth = new_depth; - if instr.is_scope_exit() || instr.is_unconditional_jump() { - continue 'process_blocks; - } - } - if block.next != BlockIdx::NULL { - stackdepth_push(&mut stack, &mut start_depths, block.next, depth); - } - } - - for (block, &start_depth) in self.blocks.iter_mut().zip(&start_depths) { - block.start_depth = (start_depth != u32::MAX).then_some(start_depth); - } - } - - fn deoptimize_borrow_in_targeted_assert_message_blocks(&mut self) { - fn is_assertion_error_load(info: &InstructionInfo) -> bool { - matches!( - info.instr.real(), - 
Some(Instruction::LoadCommonConstant { idx }) - if idx.get(info.arg) == oparg::CommonConstant::AssertionError - ) - } - - fn is_direct_call_zero(info: &InstructionInfo) -> bool { - matches!( - info.instr.real(), - Some(Instruction::Call { argc }) if argc.get(info.arg) == 0 - ) - } - - fn has_prior_real_work(block: &Block, start: usize) -> bool { - block.instructions[..start].iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - }) - } - - fn deoptimize_borrow(info: &mut InstructionInfo) { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - - fn has_same_line_target_predecessor( - blocks: &[Block], - incoming_origins: &[Vec], - block_idx: BlockIdx, - line: i32, - ) -> bool { - incoming_origins[block_idx.idx()].iter().any(|&pred| { - blocks[pred.idx()].instructions.iter().any(|info| { - info.target != BlockIdx::NULL - && next_nonempty_block(blocks, info.target) == block_idx - && instruction_lineno(info) == line - }) - }) - } - - let target_flags = compute_target_predecessor_flags(&self.blocks); - let reachable = compute_reachable_blocks(&self.blocks); - let incoming_origins = compute_incoming_origins(&self.blocks, &reachable); - for block_idx in 0..self.blocks.len() { - if block_idx == 0 || !target_flags.targeted[block_idx] { - continue; - } - - let block = &self.blocks[block_idx]; - let mut assert_start = None; - let mut ranges = Vec::new(); - for i in 0..block.instructions.len() { - if is_assertion_error_load(&block.instructions[i]) { - assert_start = Some(i); - continue; - } - - let Some(start) = assert_start else { - continue; - }; - if !is_direct_call_zero(&block.instructions[i]) { - continue; - } - if 
!block.instructions[i + 1..] - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::RaiseVarargs { .. }))) - { - assert_start = None; - continue; - } - if has_prior_real_work(block, start) { - assert_start = None; - continue; - } - - let assert_line = instruction_lineno(&block.instructions[start]); - if has_same_line_target_predecessor( - &self.blocks, - &incoming_origins, - BlockIdx::new(block_idx as u32), - assert_line, - ) { - assert_start = None; - continue; - } - - ranges.push((start + 1, i)); - assert_start = None; - } - - let block = &mut self.blocks[block_idx]; - for (start, end) in ranges { - for info in &mut block.instructions[start..end] { - deoptimize_borrow(info); - } - } - } - } - - fn mark_unprotected_debug_four_tails_borrow_disabled(&mut self) { - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn debug_four_guard_name_load( - block: &Block, - names: &IndexSet, - varnames: &IndexSet, - ) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .take(6) - .collect(); - if reals.len() < 5 { - return false; - } - let loads_imap_fast = match reals[0].instr.real() { - Some( - Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num }, - ) => varnames - .get_index(usize::from(var_num.get(reals[0].arg))) - .is_some_and(|name| name.as_str() == "imap"), - _ => false, - }; - let loads_debug_attr = matches!( - reals[1].instr.real(), - Some(Instruction::LoadAttr { namei }) - if names[usize::try_from(namei.get(reals[1].arg).name_idx()).unwrap()].as_str() - == "debug" - ); - let compares_with_four = reals.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadSmallInt { i }) if i.get(info.arg) == 4 - ) - }) && reals - .iter() - .any(|info| matches!(info.instr.real(), 
Some(Instruction::CompareOp { .. }))); - let has_conditional = reals.iter().any(|info| is_conditional_jump(&info.instr)); - loads_imap_fast && loads_debug_attr && compares_with_four && has_conditional - } - - fn block_has_jump_back_predecessor_to( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - ) -> bool { - predecessors[target.idx()].iter().any(|pred| { - blocks[pred.idx()].instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }) - }) - } - - fn block_has_mesg_call(block: &Block, names: &IndexSet) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadAttr { namei }) - if names[usize::try_from(namei.get(info.arg).name_idx()).unwrap()].as_str() - == "_mesg" - ) - }) && block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::Call { .. }))) - } - - fn normal_successors( - blocks: &[Block], - predecessors: &[Vec], - block_idx: BlockIdx, - ) -> Vec { - let block = &blocks[block_idx.idx()]; - let mut successors = Vec::new(); - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - successors.push(block.next); - } - let is_loop_header = - block_has_jump_back_predecessor_to(blocks, predecessors, block_idx); - for info in &block.instructions { - if info.target != BlockIdx::NULL && !is_loop_header { - successors.push(info.target); - } - } - successors - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let mut to_disable = Vec::new(); - for (idx, block) in 
self.blocks.iter().enumerate() { - let has_protected_predecessor = predecessors[idx] - .iter() - .any(|pred| block_has_protected_instructions(&self.blocks[pred.idx()])); - if block.cold - || block_is_exceptional(block) - || block_has_protected_instructions(block) - || has_protected_predecessor - || !debug_four_guard_name_load(block, &self.metadata.names, &self.metadata.varnames) - { - continue; - } - let block_idx = BlockIdx::new(idx as u32); - to_disable.push(block_idx); - let mut seen = vec![false; self.blocks.len()]; - let mut stack = normal_successors(&self.blocks, &predecessors, block_idx); - while let Some(successor) = stack.pop() { - if successor == BlockIdx::NULL || seen[successor.idx()] { - continue; - } - seen[successor.idx()] = true; - let successor_block = &self.blocks[successor.idx()]; - if successor_block.cold || block_is_exceptional(successor_block) { - continue; - } - to_disable.push(successor); - if block_has_protected_instructions(successor_block) - || successor_block - .instructions - .last() - .is_some_and(|info| info.instr.is_scope_exit()) - || successor_block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }) - { - continue; - } - if block_has_mesg_call(successor_block, &self.metadata.names) { - let has_loop_successor = - normal_successors(&self.blocks, &predecessors, successor) - .into_iter() - .any(|next| { - next != BlockIdx::NULL - && block_has_jump_back_predecessor_to( - &self.blocks, - &predecessors, - next, - ) - }); - if !has_loop_successor { - continue; - } - } - for next in normal_successors(&self.blocks, &predecessors, successor) { - stack.push(next); - } - } - } - to_disable.sort_by_key(|idx| idx.idx()); - to_disable.dedup(); - for block_idx in to_disable { - self.blocks[block_idx.idx()].disable_load_fast_borrow = true; - } - } - - fn mark_exception_handler_transition_targets_borrow_disabled(&mut self) { - fn first_real_handler(block: &Block) -> Option { - block - .instructions - .iter() - .find(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .and_then(|info| info.except_handler) - } - - fn last_real_handler(block: &Block) -> Option { - block - .instructions - .iter() - .rev() - .find(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .and_then(|info| info.except_handler) - } - - fn has_direct_tuple_return_tail(block: &Block) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .collect(); - reals - .iter() - .any(|instr| matches!(instr, Instruction::BuildTuple { .. })) - && reals - .last() - .is_some_and(|instr| matches!(instr, Instruction::ReturnValue)) - && !reals.iter().any(|instr| { - matches!( - instr, - Instruction::Swap { .. } - | Instruction::PopExcept - | Instruction::Reraise { .. 
} - | Instruction::WithExceptStart - ) - }) - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (idx, block) in self.blocks.iter().enumerate() { - let block_idx = BlockIdx::new(idx as u32); - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(block_idx); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(block_idx); - } - } - } - - let mut to_disable = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - if !has_direct_tuple_return_tail(block) { - continue; - } - let Some(handler) = first_real_handler(block) else { - continue; - }; - if predecessors[idx].iter().any(|pred| { - last_real_handler(&self.blocks[pred.idx()]) - .is_some_and(|pred_handler| pred_handler != handler) - }) { - to_disable.push(idx); - } - } - - for idx in to_disable { - self.blocks[idx].disable_load_fast_borrow = true; - } - } - - fn mark_targeted_nop_for_tails_borrow_disabled(&mut self) { - fn is_nop_only_block(block: &Block) -> bool { - !block.instructions.is_empty() - && block.instructions.iter().all(|info| { - matches!( - info.instr.real(), - Some(Instruction::Nop | Instruction::NotTaken) - ) - }) - } - - fn starts_for_iter_tail(block: &Block) -> bool { - let mut saw_iterable = false; - for info in block.instructions.iter().filter(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - }) { - match info.instr.real() { - Some( - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadName { .. } - | Instruction::LoadGlobal { .. }, - ) if !saw_iterable => saw_iterable = true, - Some(Instruction::GetIter) if saw_iterable => return true, - Some(Instruction::BuildList { .. } | Instruction::StoreFast { .. 
}) - if !saw_iterable => {} - _ => return false, - } - } - false - } - - let mut fallthrough_predecessors = vec![Vec::new(); self.blocks.len()]; - let mut jump_predecessors = vec![Vec::new(); self.blocks.len()]; - for (idx, block) in self.blocks.iter().enumerate() { - let block_idx = BlockIdx::new(idx as u32); - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - fallthrough_predecessors[block.next.idx()].push(block_idx); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - jump_predecessors[info.target.idx()].push(block_idx); - } - } - } - - let mut seeds = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - if !starts_for_iter_tail(block) { - continue; - } - let has_targeted_nop_predecessor = fallthrough_predecessors[idx].iter().any(|pred| { - is_nop_only_block(&self.blocks[pred.idx()]) - && !jump_predecessors[pred.idx()].is_empty() - }); - if has_targeted_nop_predecessor { - seeds.push(BlockIdx::new(idx as u32)); - } - } - - let mut seen = vec![false; self.blocks.len()]; - for seed in seeds { - let mut stack = vec![seed]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || seen[block_idx.idx()] { - continue; - } - seen[block_idx.idx()] = true; - self.blocks[block_idx.idx()].disable_load_fast_borrow = true; - - let block = &self.blocks[block_idx.idx()]; - if block - .instructions - .last() - .is_some_and(|info| info.instr.is_scope_exit()) - { - continue; - } - if block.next != BlockIdx::NULL && block.next.idx() >= seed.idx() { - stack.push(block.next); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL && info.target.idx() >= seed.idx() { - stack.push(info.target); - } - } - } - } - } - - fn restore_conditional_exception_for_iter_join_borrows(&mut self) { - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn is_conditional_predecessor_to(block: &Block, 
target: BlockIdx) -> bool { - block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::PopJumpIfFalse { .. } - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfNone { .. } - | Instruction::PopJumpIfNotNone { .. } - ) - ) - }) - } - - fn starts_with_for_iter(block: &Block) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .take(3) - .collect(); - matches!( - reals.as_slice(), - [ - Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }, - Instruction::GetIter, - .. - ] | [ - Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }, - Instruction::LoadAttr { .. }, - Instruction::GetIter, - .. - ] - ) - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (idx, block) in self.blocks.iter().enumerate() { - let block_idx = BlockIdx::new(idx as u32); - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(block_idx); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(block_idx); - } - } - } - - let mut to_restore = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - if !block.disable_load_fast_borrow - || block.cold - || block_is_exceptional(block) - || !starts_with_for_iter(block) - { - continue; - } - let target = BlockIdx::new(idx as u32); - let has_protected_predecessor = predecessors[idx] - .iter() - .any(|pred| block_has_protected_instructions(&self.blocks[pred.idx()])); - let has_conditional_normal_predecessor = predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !pred_block.disable_load_fast_borrow - && !pred_block.cold - && !block_is_exceptional(pred_block) - && !block_has_protected_instructions(pred_block) - && is_conditional_predecessor_to(pred_block, target) - }); - if 
has_protected_predecessor && has_conditional_normal_predecessor { - to_restore.push(idx); - } - } - - for idx in to_restore { - self.blocks[idx].disable_load_fast_borrow = false; - } - } - - fn deoptimize_borrow_for_handler_return_paths(&mut self) { - for block in &mut self.blocks { - let len = block.instructions.len(); - for i in 0..len { - let Some(Instruction::LoadFastBorrow { .. }) = block.instructions[i].instr.real() - else { - continue; - }; - let tail = &block.instructions[i + 1..]; - if tail.len() < 3 { - continue; - } - if !matches!(tail[0].instr.real(), Some(Instruction::Swap { .. })) { - continue; - } - if !matches!(tail[1].instr.real(), Some(Instruction::PopExcept)) { - continue; - } - if !matches!(tail[2].instr.real(), Some(Instruction::ReturnValue)) { - continue; - } - block.instructions[i].instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - } - } - - fn deoptimize_borrow_after_generator_exception_return(&mut self) { - if !self.flags.contains(CodeFlags::GENERATOR) { - return; - } - - fn deoptimize_block_borrows(block: &mut Block) { - let mut after_end_send = false; - for info in &mut block.instructions { - if matches!(info.instr.real(), Some(Instruction::EndSend)) { - after_end_send = true; - continue; - } - if after_end_send { - continue; - } - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn handler_checks_exception(blocks: &[Block], handler: BlockIdx) -> bool { - let mut stack = vec![handler]; - let mut visited = vec![false; blocks.len()]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL { - continue; - } - let idx = block_idx.idx(); - if visited[idx] { - continue; - } - visited[idx] = true; - - let block = &blocks[idx]; - let mut can_fallthrough = true; - for info in &block.instructions { - if matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) { - return true; - } - if info.target != BlockIdx::NULL { - stack.push(info.target); - } - if info.instr.is_scope_exit() || info.instr.is_unconditional_jump() { - can_fallthrough = false; - break; - } - } - if can_fallthrough && block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - false - } - - fn handler_returns_without_yield(blocks: &[Block], handler: BlockIdx) -> bool { - let mut stack = vec![(handler, false)]; - let mut visited = vec![[false; 2]; blocks.len()]; - while let Some((block_idx, mut saw_yield)) = stack.pop() { - if block_idx == BlockIdx::NULL { - continue; - } - let idx = block_idx.idx(); - let yield_idx = usize::from(saw_yield); - if visited[idx][yield_idx] { - continue; - } - visited[idx][yield_idx] = true; - - let block = &blocks[idx]; - let handler_resume_jump = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - && block.instructions.last().is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }); - let mut can_fallthrough = true; - for info in &block.instructions { - if matches!(info.instr.real(), Some(Instruction::YieldValue { .. 
})) { - saw_yield = true; - } - if matches!(info.instr.real(), Some(Instruction::ReturnValue)) { - if !saw_yield { - return true; - } - can_fallthrough = false; - break; - } - if info.target != BlockIdx::NULL - && !(handler_resume_jump && info.instr.is_unconditional_jump()) - { - stack.push((info.target, saw_yield)); - } - if info.instr.is_scope_exit() || info.instr.is_unconditional_jump() { - can_fallthrough = false; - break; - } - } - if can_fallthrough && block.next != BlockIdx::NULL { - stack.push((block.next, saw_yield)); - } - } - false - } - - fn normal_path_reaches_returning_handler( - blocks: &[Block], - start: BlockIdx, - returning_handler: &[bool], - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if block.instructions.iter().any(|info| { - info.except_handler - .is_some_and(|handler| returning_handler[handler.handler_block.idx()]) - }) { - return true; - } - let Some(last) = block.instructions.last() else { - if block.next != BlockIdx::NULL { - stack.push(block.next); - } - continue; - }; - if last.instr.is_scope_exit() { - continue; - } - if last.instr.is_unconditional_jump() { - if last.target != BlockIdx::NULL { - stack.push(last.target); - } - continue; - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - stack.push(target); - } - } - if block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - false - } - - let mut returning_handler = vec![false; self.blocks.len()]; - for block in &self.blocks { - for handler in block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - { - if 
!returning_handler[handler.idx()] { - returning_handler[handler.idx()] = - handler_checks_exception(&self.blocks, handler) - && handler_returns_without_yield(&self.blocks, handler); - } - } - } - - let seeds: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) || block.cold { - return None; - } - let seed = BlockIdx::new(idx as u32); - let prev_protected_return = self.blocks.iter().any(|pred| { - pred.next == seed - && pred.instructions.iter().any(|info| { - info.except_handler.is_some_and(|handler| { - returning_handler[handler.handler_block.idx()] - }) - }) - }); - (prev_protected_return - && !normal_path_reaches_returning_handler( - &self.blocks, - seed, - &returning_handler, - )) - .then_some(seed) - }) - .collect(); - - let mut visited = vec![false; self.blocks.len()]; - for seed in seeds { - let mut cursor = seed; - while cursor != BlockIdx::NULL { - let idx = cursor.idx(); - if visited[idx] || block_is_exceptional(&self.blocks[idx]) || self.blocks[idx].cold - { - break; - } - visited[idx] = true; - deoptimize_block_borrows(&mut self.blocks[idx]); - cursor = self.blocks[idx].next; - } - } - } - - fn deoptimize_borrow_after_async_for_cleanup_resume(&mut self) { - fn deoptimize_block_borrows_from(block: &mut Block, start: usize) { - for info in block.instructions.iter_mut().skip(start) { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - let mut same_block_starts = Vec::new(); - let mut seeds = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - for (instr_idx, info) in block.instructions.iter().enumerate() { - if !matches!(info.instr.real(), Some(Instruction::EndAsyncFor)) { - continue; - } - if block.instructions[instr_idx + 1..] - .iter() - .any(|info| info.instr.real().is_some()) - { - same_block_starts.push((BlockIdx::new(idx as u32), instr_idx + 1)); - } else if block.next != BlockIdx::NULL { - let next = &self.blocks[block.next.idx()]; - let seed = next - .instructions - .last() - .filter(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) - .map_or(block.next, |info| info.target); - seeds.push(seed); - } - } - } - - for (block_idx, start) in same_block_starts { - if !block_is_exceptional(&self.blocks[block_idx.idx()]) { - deoptimize_block_borrows_from(&mut self.blocks[block_idx.idx()], start); - } - } - for seed in seeds { - if seed != BlockIdx::NULL && !block_is_exceptional(&self.blocks[seed.idx()]) { - deoptimize_block_borrows_from(&mut self.blocks[seed.idx()], 0); - } - } - } - - fn deoptimize_borrow_after_deoptimized_async_with_enter(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_deoptimized_async_with_enter(block: &Block) -> bool { - let has_async_enter = block - .instructions - .iter() - .any(|info| match info.instr.real() { - Some(Instruction::LoadSpecial { method }) => { - method.get(info.arg) == oparg::SpecialMethod::AEnter - } - _ => false, - }); - let has_strong_fast = block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastLoadFast { .. }) - ) - }); - has_async_enter && has_strong_fast - } - - fn send_targets(block: &Block) -> impl Iterator + '_ { - block.instructions.iter().filter_map(|info| { - matches!(info.instr.real(), Some(Instruction::Send { .. })) - .then_some(info.target) - .filter(|target| *target != BlockIdx::NULL) - }) - } - - fn block_calls_before_raise(block: &Block) -> bool { - let mut saw_call = false; - for info in &block.instructions { - match info.instr.real() { - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::CallFunctionEx, - ) => saw_call = true, - Some(Instruction::RaiseVarargs { .. 
}) => return saw_call, - _ => {} - } - } - false - } - - let mut seeds = Vec::new(); - for block in &self.blocks { - if !block_has_deoptimized_async_with_enter(block) { - continue; - } - seeds.extend(send_targets(block)); - if block.next != BlockIdx::NULL { - seeds.extend(send_targets(&self.blocks[block.next.idx()])); - } - } - - for seed in seeds { - if !block_is_exceptional(&self.blocks[seed.idx()]) - && block_calls_before_raise(&self.blocks[seed.idx()]) - { - deoptimize_block_borrows(&mut self.blocks[seed.idx()]); - } - } - } - - fn deoptimize_borrow_after_multi_handler_resume_join(&mut self) { - fn first_real_instr(block: &Block) -> Option { - block.instructions.iter().find_map(|info| info.instr.real()) - } - - fn is_handler_resume_jump_block(block: &Block) -> bool { - let Some(last_info) = block.instructions.last() else { - return false; - }; - if last_info.target == BlockIdx::NULL || !last_info.instr.is_unconditional_jump() { - return false; - } - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - } - - fn is_with_suppress_resume_jump_block(block: &Block) -> bool { - if !is_handler_resume_jump_block(block) { - return false; - } - - let mut saw_pop_except = false; - let mut pop_top_after_pop_except = 0usize; - for info in &block.instructions { - match info.instr.real() { - Some(Instruction::PopExcept) => saw_pop_except = true, - Some(Instruction::PopTop) if saw_pop_except => { - pop_top_after_pop_except += 1; - } - _ => {} - } - } - saw_pop_except && pop_top_after_pop_except >= 3 - } - - fn block_has_check_exc_match(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - } - - fn block_has_push_exc_info(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PushExcInfo))) - } - - fn mark_handler_entries( - blocks: &[Block], - 
predecessors: &[Vec], - start: BlockIdx, - stop: BlockIdx, - entries: &mut [bool], - ) { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || cursor == stop || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - if block_has_push_exc_info(&blocks[cursor.idx()]) { - entries[cursor.idx()] = true; - continue; - } - for pred in &predecessors[cursor.idx()] { - stack.push(*pred); - } - } - } - - fn predecessor_chain_has_check_exc_match( - blocks: &[Block], - predecessors: &[Vec], - start: BlockIdx, - stop: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || cursor == stop || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - if block_has_check_exc_match(&blocks[cursor.idx()]) { - return true; - } - for pred in &predecessors[cursor.idx()] { - stack.push(*pred); - } - } - false - } - - fn has_exception_match_resume_predecessor( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - ) -> bool { - predecessors[target.idx()].iter().any(|pred| { - let block = &blocks[pred.idx()]; - is_handler_resume_jump_block(block) - && !is_with_suppress_resume_jump_block(block) - && predecessor_chain_has_check_exc_match(blocks, predecessors, *pred, target) - }) - } - - fn starts_with_fast_attr_call(block: &Block) -> bool { - let infos: Vec<_> = block - .instructions - .iter() - .filter(|info| info.instr.real().is_some()) - .take(2) - .collect(); - matches!( - infos.as_slice(), - [ - first, - second, - .. - ] if matches!( - first.instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }) - ) && matches!(second.instr.real(), Some(Instruction::LoadAttr { .. 
})) - ) - } - - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn starts_with_conditional_guard(block: &Block) -> bool { - let infos: Vec<_> = block - .instructions - .iter() - .filter(|info| info.instr.real().is_some()) - .take(3) - .collect(); - if infos.len() < 2 { - return false; - } - let starts_with_load_fast = matches!( - infos[0].instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }) - ); - if !starts_with_load_fast { - return false; - } - matches!( - infos.get(1).and_then(|info| info.instr.real()), - Some( - Instruction::PopJumpIfFalse { .. } - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfNone { .. } - | Instruction::PopJumpIfNotNone { .. } - ) - ) || (matches!(infos[1].instr.real(), Some(Instruction::ToBool)) - && matches!( - infos.get(2).and_then(|info| info.instr.real()), - Some( - Instruction::PopJumpIfFalse { .. } - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfNone { .. } - | Instruction::PopJumpIfNotNone { .. 
} - ) - )) - } - - let mut handler_resume_predecessors = vec![Vec::new(); self.blocks.len()]; - let mut is_handler_resume_block = vec![false; self.blocks.len()]; - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - for (block_idx, block) in self.blocks.iter().enumerate() { - if !is_handler_resume_jump_block(block) { - continue; - } - is_handler_resume_block[block_idx] = true; - let target = block - .instructions - .last() - .expect("resume jump block has a last instruction") - .target; - handler_resume_predecessors[target.idx()].push(BlockIdx::new(block_idx as u32)); - } - let function_has_with_cleanup = self.blocks.iter().any(|block| { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::WithExceptStart))) - }); - - let mut visited = vec![false; self.blocks.len()]; - for (idx, resume_preds) in handler_resume_predecessors.iter().enumerate() { - if resume_preds.len() < 2 { - continue; - } - let seed = BlockIdx::new(idx as u32); - let mut handler_entries = vec![false; self.blocks.len()]; - for pred in resume_preds { - mark_handler_entries( - &self.blocks, - &predecessors, - *pred, - seed, - &mut handler_entries, - ); - } - if handler_entries.iter().filter(|&&is_entry| is_entry).count() < 2 { - continue; - } - if matches!( - first_real_instr(&self.blocks[seed.idx()]), - Some(Instruction::ForIter { .. 
}) - ) { - continue; - } - let mut segment = Vec::new(); - let mut cursor = seed; - while cursor != BlockIdx::NULL { - if block_is_exceptional(&self.blocks[cursor.idx()]) { - break; - } - segment.push(cursor); - cursor = self.blocks[cursor.idx()].next; - } - let has_complex_tail = segment.iter().any(|block_idx| { - self.blocks[block_idx.idx()] - .instructions - .iter() - .any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::ForIter { .. } - | Instruction::EndFor - | Instruction::PopIter - | Instruction::LoadFastAndClear { .. } - | Instruction::LoadFastCheck { .. } - | Instruction::ListAppend { .. } - | Instruction::MapAdd { .. } - | Instruction::SetAdd { .. } - ) - ) - }) - }); - let tail_enters_stacked_context = segment.iter().any(|block_idx| { - self.blocks[block_idx.idx()] - .start_depth - .is_some_and(|depth| depth > 1) - }); - let has_simple_with_except_resume_tail = - starts_with_fast_attr_call(&self.blocks[seed.idx()]) - && has_exception_match_resume_predecessor(&self.blocks, &predecessors, seed) - && predecessors[seed.idx()] - .iter() - .any(|pred| is_with_suppress_resume_jump_block(&self.blocks[pred.idx()])); - if !(starts_with_conditional_guard(&self.blocks[seed.idx()]) - && has_complex_tail - && !(function_has_with_cleanup && tail_enters_stacked_context) - || has_simple_with_except_resume_tail) - { - continue; - } - - let mut in_segment = vec![false; self.blocks.len()]; - for block_idx in &segment { - in_segment[block_idx.idx()] = true; - } - for block_idx in segment { - if visited[block_idx.idx()] { - continue; - } - if block_idx != seed - && predecessors[block_idx.idx()].iter().any(|pred| { - !in_segment[pred.idx()] - && !is_handler_resume_block[pred.idx()] - && self.blocks[pred.idx()] - .instructions - .iter() - .any(|info| info.instr.real().is_some()) - }) - { - continue; - } - visited[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - } - - fn 
deoptimize_borrow_after_named_except_cleanup_join(&mut self) { - fn first_real_instr(block: &Block) -> Option { - block.instructions.iter().find_map(|info| info.instr.real()) - } - - fn leading_bool_guard_local(block: &Block) -> Option { - let infos: Vec<_> = block - .instructions - .iter() - .filter(|info| info.instr.real().is_some()) - .take(3) - .collect(); - if infos.len() < 3 { - return None; - } - let load_local = match infos[0].instr.real() { - Some(Instruction::LoadFast { var_num }) => usize::from(var_num.get(infos[0].arg)), - Some(Instruction::LoadFastBorrow { var_num }) => { - usize::from(var_num.get(infos[0].arg)) - } - _ => return None, - }; - if !matches!(infos[1].instr.real(), Some(Instruction::ToBool)) { - return None; - } - if !matches!( - infos[2].instr.real(), - Some( - Instruction::PopJumpIfFalse { .. } - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfNone { .. } - | Instruction::PopJumpIfNotNone { .. } - ) - ) { - return None; - } - Some(load_local) - } - - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_simple_scope_exit(block: &Block) -> bool { - for info in &block.instructions { - match info.instr.real() { - Some(instr) if instr.is_scope_exit() => return true, - Some( - Instruction::Nop - | Instruction::NotTaken - | Instruction::LoadConst { .. } - | Instruction::LoadSmallInt { .. } - | Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. } - | Instruction::StoreName { .. } - | Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastCheck { .. 
} - | Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - | Instruction::BuildTuple { .. }, - ) => {} - Some(_) => return false, - None => {} - } - } - false - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last_info) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL && !successors.contains(&block.next) { - successors.push(block.next); - } - return successors; - } - if last_info.instr.is_scope_exit() { - return Vec::new(); - } - if last_info.instr.is_unconditional_jump() { - return (last_info.target != BlockIdx::NULL) - .then_some(last_info.target) - .into_iter() - .collect(); - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - fn is_return_value_through_with_exit(block: &Block) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .collect(); - if !matches!( - reals.first(), - Some( - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) { - return false; - } - if !matches!(reals.last(), Some(Instruction::ReturnValue)) { - return false; - } - reals.iter().skip(1).all(|instr| { - matches!( - instr, - Instruction::Swap { .. } - | Instruction::LoadConst { .. } - | Instruction::Call { .. 
} - | Instruction::PopTop - | Instruction::ReturnValue - ) - }) - } - - fn is_simple_store_attr_tail(block: &Block) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .collect(); - matches!( - reals.as_slice(), - [ - Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }, - Instruction::StoreAttr { .. }, - .. - ] | [ - Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. }, - Instruction::StoreAttr { .. }, - .. - ] - ) - } - - fn leading_bool_guard_has_scope_exit_successor(blocks: &[Block], block: &Block) -> bool { - let Some(cond_idx) = trailing_conditional_jump_index(block) else { - return false; - }; - [block.instructions[cond_idx].target, block.next] - .into_iter() - .any(|successor| { - successor != BlockIdx::NULL - && block_has_simple_scope_exit(&blocks[successor.idx()]) - }) - } - - fn path_reaches_named_cleanup( - blocks: &[Block], - start: BlockIdx, - cleanup: BlockIdx, - resume_target: BlockIdx, - ) -> bool { - if start == BlockIdx::NULL || start == resume_target { - return false; - } - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL - || block_idx == resume_target - || visited[block_idx.idx()] - { - continue; - } - if block_idx == cleanup { - return true; - } - visited[block_idx.idx()] = true; - for successor in normal_successors(&blocks[block_idx.idx()]) { - stack.push(successor); - } - } - false - } - - fn path_reaches_explicit_raise( - blocks: &[Block], - start: BlockIdx, - cleanup: BlockIdx, - resume_target: BlockIdx, - ) -> bool { - if start == BlockIdx::NULL || start == cleanup || start == resume_target { - return false; - } - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == 
BlockIdx::NULL - || block_idx == cleanup - || block_idx == resume_target - || visited[block_idx.idx()] - { - continue; - } - let block = &blocks[block_idx.idx()]; - if block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::RaiseVarargs { .. }))) - { - return true; - } - visited[block_idx.idx()] = true; - for successor in normal_successors(block) { - stack.push(successor); - } - } - false - } - - fn named_cleanup_has_conditional_raise_sibling( - blocks: &[Block], - cleanup: BlockIdx, - resume_target: BlockIdx, - ) -> bool { - for block in blocks { - let Some(cond_idx) = trailing_conditional_jump_index(block) else { - continue; - }; - let jump_target = block.instructions[cond_idx].target; - let fallthrough = block.next; - if jump_target == BlockIdx::NULL || fallthrough == BlockIdx::NULL { - continue; - } - - let jump_reaches_cleanup = - path_reaches_named_cleanup(blocks, jump_target, cleanup, resume_target); - let fallthrough_reaches_cleanup = - path_reaches_named_cleanup(blocks, fallthrough, cleanup, resume_target); - if jump_reaches_cleanup == fallthrough_reaches_cleanup { - continue; - } - - let sibling = if jump_reaches_cleanup { - fallthrough - } else { - jump_target - }; - if path_reaches_explicit_raise(blocks, sibling, cleanup, resume_target) { - return true; - } - } - false - } - - fn linear_tail_has_with_setup(blocks: &[Block], start: BlockIdx) -> bool { - let mut cursor = start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if block_is_exceptional(block) { - return false; - } - if block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::LoadSpecial { .. 
}))) - { - return true; - } - let Some(last_info) = block.instructions.last() else { - cursor = block.next; - continue; - }; - if last_info.instr.is_scope_exit() || last_info.instr.is_unconditional_jump() { - return false; - } - cursor = block.next; - } - false - } - - fn is_with_suppress_resume_block(block: &Block) -> bool { - let Some(last_info) = block.instructions.last() else { - return false; - }; - if last_info.target == BlockIdx::NULL || !last_info.instr.is_unconditional_jump() { - return false; - } - - let mut saw_pop_except = false; - let mut pop_top_after_pop_except = 0usize; - for info in &block.instructions { - match info.instr.real() { - Some(Instruction::PopExcept) => saw_pop_except = true, - Some(Instruction::PopTop) if saw_pop_except => { - pop_top_after_pop_except += 1; - } - _ => {} - } - } - saw_pop_except && pop_top_after_pop_except >= 3 - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - let mut named_cleanup_predecessors = vec![0usize; self.blocks.len()]; - let mut named_cleanup_requires_deopt = vec![false; self.blocks.len()]; - let mut has_with_suppress_resume_predecessor = vec![false; self.blocks.len()]; - let mut is_allowed_cleanup_resume_block = vec![false; self.blocks.len()]; - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - - for (block_idx, block) in self.blocks.iter().enumerate() { - let Some(last_info) = block.instructions.last() else { - continue; - }; - if last_info.target == BlockIdx::NULL || !last_info.instr.is_unconditional_jump() { - continue; - } - if is_with_suppress_resume_block(block) { - has_with_suppress_resume_predecessor[last_info.target.idx()] = true; - } - if block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - { - is_allowed_cleanup_resume_block[block_idx] = true; - } - if !is_named_except_cleanup_normal_exit_block(block) { - continue; - } - if 
matches!( - first_real_instr(&self.blocks[last_info.target.idx()]), - Some(Instruction::ForIter { .. }) - ) { - continue; - } - if matches!( - last_info.instr.real(), - Some(Instruction::JumpBackward { .. }) - ) && block_has_protected_instructions(&self.blocks[last_info.target.idx()]) - { - continue; - } - named_cleanup_predecessors[last_info.target.idx()] += 1; - if linear_tail_has_with_setup(&self.blocks, last_info.target) - && named_cleanup_has_conditional_raise_sibling( - &self.blocks, - BlockIdx::new(block_idx as u32), - last_info.target, - ) - { - named_cleanup_requires_deopt[last_info.target.idx()] = true; - } - } - for (idx, has_with_suppress_resume) in has_with_suppress_resume_predecessor - .iter() - .copied() - .enumerate() - { - if has_with_suppress_resume && named_cleanup_predecessors[idx] > 0 { - named_cleanup_requires_deopt[idx] = true; - } - } - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let mut visited = vec![false; self.blocks.len()]; - for (idx, &count) in named_cleanup_predecessors.iter().enumerate() { - if count == 0 { - continue; - } - let seed = BlockIdx::new(idx as u32); - let mut segment = Vec::new(); - let mut cursor = seed; - let seed_guard_local = leading_bool_guard_local(&self.blocks[seed.idx()]); - let mut fallback_guard_local = None; - while cursor != BlockIdx::NULL { - let block = &self.blocks[cursor.idx()]; - if block_is_exceptional(block) { - break; - } - if cursor != seed - && let Some(local) = leading_bool_guard_local(block) - { - if !leading_bool_guard_has_scope_exit_successor(&self.blocks, block) { - break; - } - if seed_guard_local.is_some_and(|seed_local| seed_local != local) { - break; - } - match fallback_guard_local { - None => 
fallback_guard_local = Some(local), - Some(expected) if expected != local => break, - Some(_) => {} - } - } - segment.push(cursor); - cursor = block.next; - } - let requires_deopt = named_cleanup_requires_deopt[idx]; - if fallback_guard_local.is_none() && !requires_deopt { - continue; - } - - let mut in_segment = vec![false; self.blocks.len()]; - for block_idx in &segment { - in_segment[block_idx.idx()] = true; - } - for block_idx in segment { - if visited[block_idx.idx()] { - continue; - } - let is_same_guard_fallback = fallback_guard_local.is_some_and(|local| { - leading_bool_guard_local(&self.blocks[block_idx.idx()]) == Some(local) - }); - if !requires_deopt && !is_same_guard_fallback { - continue; - } - if requires_deopt - && is_return_value_through_with_exit(&self.blocks[block_idx.idx()]) - { - continue; - } - if requires_deopt && is_simple_store_attr_tail(&self.blocks[block_idx.idx()]) { - continue; - } - if block_idx != seed - && !is_same_guard_fallback - && predecessors[block_idx.idx()].iter().any(|pred| { - !in_segment[pred.idx()] && !is_allowed_cleanup_resume_block[pred.idx()] - }) - { - continue; - } - visited[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - } - - fn deoptimize_borrow_after_reraising_except_handler(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_fast_load(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastLoadFast { .. 
} - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) - }) - } - - fn block_requires_post_reraise_strong_loads(block: &Block) -> bool { - block_has_protected_instructions(block) - || block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::DeleteFast { .. } - | Instruction::StoreAttr { .. } - | Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. } - | Instruction::StoreSubscr - ) - ) - }) - } - - fn block_ends_with_explicit_raise(block: &Block) -> bool { - block - .instructions - .iter() - .rev() - .filter_map(|info| info.instr.real().map(|instr| (instr, info.arg))) - .find(|(instr, _)| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .is_some_and(|(instr, arg)| { - matches!( - instr, - Instruction::RaiseVarargs { argc } - if argc.get(arg) != oparg::RaiseKind::BareRaise - ) - }) - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_non_nop_real_instructions(block: &Block) -> bool { - block.instructions.iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop | Instruction::Cache)) - }) - } - - fn block_jumps_backward_to(block: &Block, target: BlockIdx) -> bool { - block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }) - } - - fn block_jumps_unconditionally_to(block: &Block, target: BlockIdx) -> bool { - block - .instructions - .iter() - .any(|info| info.target == target && info.instr.is_unconditional_jump()) - } - - fn handler_chain_has_explicit_reraise(blocks: &[Block], handler: BlockIdx) -> bool { - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::RaiseVarargs { argc }) - if argc.get(info.arg) == oparg::RaiseKind::BareRaise - ) - }) { - return true; - } - if block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::PopExcept | Instruction::Reraise { .. }) - ) - }) { - return false; - } - cursor = block.next; - } - false - } - - fn handler_chain_resumes_normally(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL - && info.instr.is_unconditional_jump() - && has_pop_except - }) { - return true; - } - if has_pop_except - && !matches!( - last_real, - Some(Instruction::RaiseVarargs { .. } | Instruction::Reraise { .. 
}) - ) - { - return true; - } - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if !has_pop_except && target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!( - last_real, - Some( - Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. } - | Instruction::ReturnValue - ) - ) && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - false - } - - fn handler_chain_has_named_cleanup_or_explicit_reraise( - blocks: &[Block], - handler: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block - .instructions - .iter() - .any(|info| match info.instr.real() { - Some( - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. }, - ) => true, - Some(Instruction::RaiseVarargs { argc }) => { - argc.get(info.arg) == oparg::RaiseKind::BareRaise - } - _ => false, - }) - { - return true; - } - if block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::PopExcept | Instruction::Reraise { .. 
}) - ) - }) { - continue; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - false - } - - fn protected_block_handler_has_named_cleanup_or_explicit_reraise( - blocks: &[Block], - block: &Block, - ) -> bool { - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| handler_chain_has_named_cleanup_or_explicit_reraise(blocks, handler)) - } - - fn nonresuming_reraise_handlers(blocks: &[Block], block: &Block) -> Vec { - let mut handlers = Vec::new(); - for handler in block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - { - if handlers.contains(&handler) { - continue; - } - if handler_chain_has_explicit_reraise(blocks, handler) - && !handler_chain_resumes_normally(blocks, handler) - { - handlers.push(handler); - } - } - handlers - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - return Vec::new(); - } - if matches!( - last.instr.real(), - Some( - Instruction::JumpBackward { .. } | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - fn normal_path_reaches_handler( - blocks: &[Block], - start: BlockIdx, - handler: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if block.instructions.iter().any(|info| { - info.except_handler - .is_some_and(|except_handler| except_handler.handler_block == handler) - }) { - return true; - } - stack.extend(normal_successors(block)); - } - false - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let has_reraising_except_handler = self.blocks.iter().any(|block| { - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| handler_chain_has_explicit_reraise(&self.blocks, handler)) - }); - if !has_reraising_except_handler { - 
return; - } - - let mut follows_protected_body = vec![false; self.blocks.len()]; - for (idx, block) in self.blocks.iter().enumerate() { - if block_is_exceptional(block) - || block.cold - || !block_has_fast_load(block) - || !block_requires_post_reraise_strong_loads(block) - { - continue; - } - if predecessors[idx] - .iter() - .any(|pred| is_named_except_cleanup_normal_exit_block(&self.blocks[pred.idx()])) - { - continue; - } - let mut seen = vec![false; self.blocks.len()]; - let mut stack = predecessors[idx].clone(); - while let Some(pred) = stack.pop() { - if pred == BlockIdx::NULL || seen[pred.idx()] { - continue; - } - seen[pred.idx()] = true; - let pred_block = &self.blocks[pred.idx()]; - if block_has_protected_instructions(pred_block) { - if block_jumps_backward_to(pred_block, BlockIdx::new(idx as u32)) { - continue; - } - if block_jumps_unconditionally_to(pred_block, BlockIdx::new(idx as u32)) { - continue; - } - let handlers = nonresuming_reraise_handlers(&self.blocks, pred_block); - follows_protected_body[idx] = !handlers.is_empty() - && !handlers.iter().any(|handler| { - normal_path_reaches_handler( - &self.blocks, - BlockIdx::new(idx as u32), - *handler, - ) - }); - break; - } - if !block_is_exceptional(pred_block) - && !pred_block.cold - && !block_has_non_nop_real_instructions(pred_block) - { - stack.extend(predecessors[pred.idx()].iter().copied()); - } - } - } - - for (idx, follows_protected_body) in follows_protected_body.iter().enumerate() { - if !*follows_protected_body { - continue; - } - let mut visited = vec![false; self.blocks.len()]; - let mut stack = vec![BlockIdx::new(idx as u32)]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - let block = &self.blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if protected_block_handler_has_named_cleanup_or_explicit_reraise( - &self.blocks, - block, - ) { - visited[block_idx.idx()] = true; - 
stack.extend(normal_successors(block)); - continue; - } - if predecessors[block_idx.idx()] - .iter() - .any(|pred| is_named_except_cleanup_normal_exit_block(&self.blocks[pred.idx()])) - { - continue; - } - if block_ends_with_explicit_raise(block) { - continue; - } - visited[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - if self.blocks[block_idx.idx()] - .instructions - .last() - .is_some_and(|info| info.instr.is_scope_exit()) - { - continue; - } - stack.extend(normal_successors(&self.blocks[block_idx.idx()])); - } - } - } - - fn deoptimize_borrow_in_protected_conditional_tail(&mut self) { - fn second_last_real_instr(block: &Block) -> Option { - let mut reals = block - .instructions - .iter() - .rev() - .filter_map(|info| info.instr.real()); - let _last = reals.next()?; - reals.next() - } - - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_non_nop_real_instructions(block: &Block) -> bool { - block.instructions.iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop)) - }) - } - - fn starts_with_assertion_error(block: &Block) -> bool { - block - .instructions - .iter() - .find(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .is_some_and(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadCommonConstant { idx }) - if idx.get(info.arg) == oparg::CommonConstant::AssertionError - ) - }) - } - - fn success_path_stays_protected(blocks: &[Block], start: BlockIdx) -> bool { - let mut cursor = start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - let has_non_marker_real = block.instructions.iter().any(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }); - if has_non_marker_real { - return block_has_protected_instructions(block); - } - cursor = block.next; - } - false - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let pop_top_count = block - .instructions - .iter() - .filter(|info| matches!(info.instr.real(), Some(Instruction::PopTop))) - .count(); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. 
} - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }); - has_pop_except && pop_top_count == 0 && jumps_to_target - } - - fn has_direct_handler_resume_predecessor(blocks: &[Block], target: BlockIdx) -> bool { - blocks.iter().any(|pred_block| { - let has_pop_except = pred_block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = pred_block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }); - has_pop_except && jumps_to_target - }) - } - - fn handler_chain_resumes_normally(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL - && info.instr.is_unconditional_jump() - && has_pop_except - }) { - return true; - } - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if !has_pop_except && target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!( - last_real, - Some( - 
Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. } - | Instruction::ReturnValue - ) - ) && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - false - } - - fn handler_chain_has_explicit_raise(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - if block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::RaiseVarargs { .. }))) - { - return true; - } - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!(last_real, Some(Instruction::ReturnValue)) - && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - false - } - - fn handler_chain_has_multiple_handled_returns(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - let mut handled_returns = 0usize; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let has_return = block - .instructions - .iter() - .any(|info| 
matches!(info.instr.real(), Some(Instruction::ReturnValue))); - let has_raise = block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::RaiseVarargs { .. } | Instruction::Reraise { .. }) - ) - }); - if has_pop_except && has_return && !has_raise { - handled_returns += 1; - if handled_returns >= 2 { - return true; - } - } - let last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!(last_real, Some(Instruction::ReturnValue)) - && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - false - } - - fn block_has_nonresuming_exception_match_handler(blocks: &[Block], block: &Block) -> bool { - let mut seen_handlers = Vec::new(); - for handler in block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - { - if seen_handlers.contains(&handler) { - continue; - } - seen_handlers.push(handler); - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - let mut has_exception_match = false; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - if blocks[cursor.idx()].instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) { - has_exception_match = true; - break; - } - cursor = blocks[cursor.idx()].next; - } - if has_exception_match - && handler_chain_has_explicit_raise(blocks, handler) - && !handler_chain_has_multiple_handled_returns(blocks, 
handler) - && !handler_chain_resumes_normally(blocks, handler) - { - return true; - } - } - false - } - - fn nonresuming_handlers(blocks: &[Block], block: &Block) -> Vec { - let mut handlers = Vec::new(); - for handler in block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - { - if handlers.contains(&handler) { - continue; - } - if handler_chain_has_explicit_raise(blocks, handler) - && !handler_chain_has_multiple_handled_returns(blocks, handler) - && !handler_chain_resumes_normally(blocks, handler) - { - handlers.push(handler); - } - } - handlers - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - return Vec::new(); - } - if matches!( - last.instr.real(), - Some( - Instruction::JumpBackward { .. } | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - fn normal_path_reaches_handler( - blocks: &[Block], - start: BlockIdx, - handler: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if 
block_is_exceptional(block) || block.cold { - continue; - } - if block.instructions.iter().any(|info| { - info.except_handler - .is_some_and(|except_handler| except_handler.handler_block == handler) - }) { - return true; - } - stack.extend(normal_successors(block)); - } - false - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - let mut is_handler_resume_block = vec![false; self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if matches!(second_last_real_instr(block), Some(Instruction::PopExcept)) - && block.instructions.last().is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) - { - is_handler_resume_block[pred_idx] = true; - } - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let seeds: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - let cond_idx = trailing_conditional_jump_index(block)?; - if block.try_else_orelse_entry { - return None; - } - let prev_protected = predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - block_has_protected_instructions(pred_block) - && block_has_exception_match_handler(&self.blocks, pred_block) - && block_has_nonresuming_exception_match_handler(&self.blocks, pred_block) - }); - let prev_nonresuming_handlers: Vec<_> = predecessors[idx] - .iter() - .flat_map(|pred| { - let pred_block = &self.blocks[pred.idx()]; - if block_has_protected_instructions(pred_block) { - nonresuming_handlers(&self.blocks, pred_block) - } else { - Vec::new() - } - }) - .collect(); - let has_unprotected_normal_predecessor = predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !block_is_exceptional(pred_block) - && !pred_block.cold - && 
!is_handler_resume_block[pred.idx()] - && !is_handler_resume_predecessor(pred_block, BlockIdx::new(idx as u32)) - && !block_has_protected_instructions(pred_block) - && block_has_non_nop_real_instructions(pred_block) - }); - let assertion_message_fallthrough = block.next != BlockIdx::NULL - && starts_with_assertion_error(&self.blocks[block.next.idx()]); - let protected_assert = assertion_message_fallthrough - && block_has_protected_instructions(block) - && block_has_exception_match_handler(&self.blocks, block) - && block_has_nonresuming_exception_match_handler(&self.blocks, block); - let seed = if assertion_message_fallthrough { - block.instructions[cond_idx].target - } else { - BlockIdx::new(idx as u32) - }; - let force_deopt = assertion_message_fallthrough - && !success_path_stays_protected(&self.blocks, seed); - let seed_enabled = if assertion_message_fallthrough { - force_deopt && (prev_protected || protected_assert) - } else { - prev_protected - }; - let same_handler_continuation = prev_nonresuming_handlers - .iter() - .any(|handler| normal_path_reaches_handler(&self.blocks, seed, *handler)); - (!block_is_exceptional(block) - && seed != BlockIdx::NULL - && seed_enabled - && !same_handler_continuation - && !has_unprotected_normal_predecessor) - .then_some((seed, force_deopt)) - }) - .collect(); - - let mut visited = vec![false; self.blocks.len()]; - for (seed, force_deopt) in seeds { - let mut segment = Vec::new(); - let mut cursor = seed; - while cursor != BlockIdx::NULL { - if block_is_exceptional(&self.blocks[cursor.idx()]) { - break; - } - if cursor != seed - && predecessors[cursor.idx()].iter().any(|pred| { - is_handler_resume_block[pred.idx()] - || is_handler_resume_predecessor(&self.blocks[pred.idx()], cursor) - || is_named_except_cleanup_normal_exit_block(&self.blocks[pred.idx()]) - }) - { - break; - } - segment.push(cursor); - cursor = self.blocks[cursor.idx()].next; - } - if segment.iter().any(|block_idx| { - predecessors[block_idx.idx()] - .iter() - 
.any(|pred| is_named_except_cleanup_normal_exit_block(&self.blocks[pred.idx()])) - }) { - continue; - } - - let segment_ops: Vec<_> = segment - .iter() - .flat_map(|block_idx| { - self.blocks[block_idx.idx()] - .instructions - .iter() - .filter_map(|info| info.instr.real()) - }) - .collect(); - let call_count = segment_ops - .iter() - .filter(|instr| matches!(instr, Instruction::Call { .. })) - .count(); - let raise_count = segment_ops - .iter() - .filter(|instr| matches!(instr, Instruction::RaiseVarargs { .. })) - .count(); - let return_count = segment_ops - .iter() - .filter(|instr| matches!(instr, Instruction::ReturnValue)) - .count(); - let conditional_count = segment_ops - .iter() - .filter(|instr| { - matches!( - instr, - Instruction::PopJumpIfFalse { .. } - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfNone { .. } - | Instruction::PopJumpIfNotNone { .. } - ) - }) - .count(); - let has_handler_resume_predecessor = - predecessors[seed.idx()].iter().any(|pred| { - is_handler_resume_block[pred.idx()] - || is_handler_resume_predecessor(&self.blocks[pred.idx()], seed) - }) || has_direct_handler_resume_predecessor(&self.blocks, seed); - if has_handler_resume_predecessor { - continue; - } - let has_loop_cleanup_predecessor = predecessors[seed.idx()].iter().any(|pred| { - self.blocks[pred.idx()].instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::EndFor | Instruction::EndAsyncFor | Instruction::PopIter) - ) - }) - }); - let has_named_except_cleanup_predecessor = predecessors[seed.idx()] - .iter() - .any(|pred| is_named_except_cleanup_normal_exit_block(&self.blocks[pred.idx()])); - let has_complex_tail = segment_ops.iter().any(|instr| { - matches!( - instr, - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. } - | Instruction::ForIter { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - | Instruction::EndFor - | Instruction::PopIter - | Instruction::LoadFastAndClear { .. } - | Instruction::LoadFastCheck { .. } - | Instruction::ListAppend { .. } - | Instruction::MapAdd { .. } - | Instruction::SetAdd { .. } - ) - }); - let has_loop_or_comprehension_tail = segment_ops.iter().any(|instr| { - matches!( - instr, - Instruction::ForIter { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - | Instruction::EndFor - | Instruction::PopIter - | Instruction::LoadFastAndClear { .. } - | Instruction::LoadFastCheck { .. } - | Instruction::ListAppend { .. } - | Instruction::MapAdd { .. } - | Instruction::SetAdd { .. } - ) - }); - let has_store_fast_tail = segment_ops.iter().any(|instr| { - matches!( - instr, - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. } - ) - }); - let has_nonresuming_protected_conditional_tail = !has_handler_resume_predecessor - && !has_loop_cleanup_predecessor - && !has_loop_or_comprehension_tail - && has_store_fast_tail - && call_count >= 1 - && return_count >= 1 - && conditional_count == 1; - let has_existing_protected_conditional_tail = !has_loop_cleanup_predecessor - && !has_handler_resume_predecessor - && !has_named_except_cleanup_predecessor - && !has_complex_tail - && call_count == 2 - && raise_count == 1 - && return_count == 1 - && conditional_count == 1; - if !(force_deopt - || has_nonresuming_protected_conditional_tail - || has_existing_protected_conditional_tail) - { - continue; - } - - let mut in_segment = vec![false; self.blocks.len()]; - for block_idx in &segment { - in_segment[block_idx.idx()] = true; - } - - for block_idx in segment { - if visited[block_idx.idx()] { - continue; - } - if predecessors[block_idx.idx()].iter().any(|pred| { - is_handler_resume_block[pred.idx()] - || is_handler_resume_predecessor(&self.blocks[pred.idx()], block_idx) - || is_named_except_cleanup_normal_exit_block(&self.blocks[pred.idx()]) - 
}) || has_direct_handler_resume_predecessor(&self.blocks, block_idx) - { - continue; - } - if block_has_protected_instructions(&self.blocks[block_idx.idx()]) { - continue; - } - if !force_deopt - && block_idx != seed - && predecessors[block_idx.idx()] - .iter() - .any(|pred| !in_segment[pred.idx()] && !is_handler_resume_block[pred.idx()]) - { - continue; - } - visited[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - } - - fn deoptimize_borrow_after_terminal_except_tail(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_real_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.instr.real().is_some()) - } - - fn block_has_non_nop_real_instructions(block: &Block) -> bool { - block.instructions.iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop)) - }) - } - - fn handler_chain_resumes_normally(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut has_terminal_exit = false; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let 
last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL - && info.instr.is_unconditional_jump() - && has_pop_except - }) { - return true; - } - if block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. } - | Instruction::ReturnValue - ) - ) - }) { - has_terminal_exit = true; - } - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if !has_pop_except && target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!( - last_real, - Some( - Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. } - | Instruction::ReturnValue - ) - ) && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - !has_terminal_exit - } - - fn handler_reaches_match_before_terminal(blocks: &[Block], handler: BlockIdx) -> bool { - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if cursor != handler && block.except_handler { - return false; - } - for info in &block.instructions { - match info.instr.real() { - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) => { - return true; - } - Some( - Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. 
} - | Instruction::ReturnValue, - ) => return false, - _ => {} - } - } - cursor = block.next; - } - false - } - - fn handler_chain_has_multiple_handled_returns(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - let mut handled_returns = 0usize; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let has_return = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ReturnValue))); - let has_raise = block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::RaiseVarargs { .. } | Instruction::Reraise { .. }) - ) - }); - if has_pop_except && has_return && !has_raise { - handled_returns += 1; - if handled_returns >= 2 { - return true; - } - } - let last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!(last_real, Some(Instruction::ReturnValue)) - && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - false - } - - fn handler_is_terminal_exception_handler(blocks: &[Block], handler: BlockIdx) -> bool { - handler_reaches_match_before_terminal(blocks, handler) - && !handler_chain_has_multiple_handled_returns(blocks, handler) - && 
!handler_chain_resumes_normally(blocks, handler) - } - - fn handler_chain_has_handled_return_or_bare_reraise( - blocks: &[Block], - handler: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let has_return = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ReturnValue))); - if has_pop_except && has_return { - return true; - } - if block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::RaiseVarargs { argc }) - if argc.get(info.arg) == oparg::RaiseKind::BareRaise - ) - }) { - return true; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - false - } - - fn handler_body_has_conditional_after_match(blocks: &[Block], handler: BlockIdx) -> bool { - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - let mut in_body = false; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - for info in &block.instructions { - match info.instr.real() { - Some(Instruction::PopTop) if !in_body => { - in_body = true; - } - Some(instr) - if in_body && is_conditional_jump(&AnyInstruction::Real(instr)) => - { - return true; - } - Some( - Instruction::PopExcept - | Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. 
} - | Instruction::ReturnValue, - ) => return false, - _ => {} - } - } - cursor = block.next; - } - false - } - - fn handler_is_terminal_for_conditional_tail(blocks: &[Block], handler: BlockIdx) -> bool { - handler_reaches_match_before_terminal(blocks, handler) - && handler_chain_has_handled_return_or_bare_reraise(blocks, handler) - && !handler_body_has_conditional_after_match(blocks, handler) - && !handler_chain_resumes_normally(blocks, handler) - } - - fn trailing_protected_tail_terminal_exception_handler( - blocks: &[Block], - block: &Block, - ) -> Option { - for info in block.instructions.iter().rev() { - match info.instr.real() { - Some(Instruction::Nop | Instruction::NotTaken | Instruction::PopTop) => {} - Some(_) => { - let handler = info.except_handler.map(|handler| handler.handler_block)?; - return handler_is_terminal_exception_handler(blocks, handler) - .then_some(handler); - } - None => {} - } - } - None - } - - fn trailing_protected_tail_conditional_exception_handler( - blocks: &[Block], - block: &Block, - ) -> Option { - for info in block.instructions.iter().rev() { - match info.instr.real() { - Some(Instruction::Nop | Instruction::NotTaken | Instruction::PopTop) => {} - Some(_) => { - let handler = info.except_handler.map(|handler| handler.handler_block)?; - return handler_is_terminal_for_conditional_tail(blocks, handler) - .then_some(handler); - } - None => {} - } - } - None - } - - fn protected_tail_ends_with_conditional(block: &Block) -> bool { - block - .instructions - .iter() - .rev() - .filter_map(|info| info.instr.real()) - .find(|instr| { - !matches!( - instr, - Instruction::Nop | Instruction::NotTaken | Instruction::PopTop - ) - }) - .is_some_and(|instr| is_conditional_jump(&AnyInstruction::Real(instr))) - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - 
return Vec::new(); - } - if matches!( - last.instr.real(), - Some( - Instruction::JumpBackward { .. } | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - fn normal_path_reaches_handler( - blocks: &[Block], - start: BlockIdx, - handler: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if block.instructions.iter().any(|info| { - info.except_handler - .is_some_and(|except_handler| except_handler.handler_block == handler) - }) { - return true; - } - stack.extend(normal_successors(block)); - } - false - } - - fn has_call_store_before_trailing_conditional(block: &Block) -> bool { - let Some(cond_idx) = trailing_conditional_jump_index(block) else { - return false; - }; - block.instructions[..cond_idx].iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::Call { .. } | Instruction::CallKw { .. }) - ) - }) && block.instructions[..cond_idx].iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. 
} - ) - ) - }) - } - - fn has_call_and_store(block: &Block) -> bool { - let mut has_call = false; - let mut has_store_fast = false; - for info in &block.instructions { - match info.instr.real() { - Some(Instruction::Call { .. } | Instruction::CallKw { .. }) => has_call = true, - Some( - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. }, - ) => has_store_fast = true, - _ => {} - } - } - has_call && has_store_fast - } - - fn normal_tail_reaches_conditional(blocks: &[Block], start: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if trailing_conditional_jump_index(block).is_some() - || has_call_store_before_trailing_conditional(block) - { - return true; - } - if block - .instructions - .last() - .is_some_and(|info| info.instr.is_scope_exit()) - { - continue; - } - stack.extend(normal_successors(block)); - } - false - } - - fn has_load_fast_pair(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) - }) - } - - fn has_call(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::Call { .. } | Instruction::CallKw { .. 
}) - ) - }) - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }); - has_pop_except && jumps_to_target - } - - fn is_unprotected_call_store_bridge_to(block: &Block, successor: BlockIdx) -> bool { - !block_is_exceptional(block) - && !block.cold - && !block_has_protected_instructions(block) - && has_call_and_store(block) - && trailing_conditional_jump_index(block).is_none() - && normal_successors(block).contains(&successor) - } - - fn is_simple_fast_return_block(block: &Block) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .collect(); - matches!( - reals.as_slice(), - [ - Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. 
}, - Instruction::ReturnValue, - ] - ) - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let mut seeds = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - let has_protected_call_predecessor = predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - block_has_protected_instructions(pred_block) && has_call(pred_block) - }); - let has_call_store_tail = has_call_and_store(block) && has_load_fast_pair(block); - let has_conditional_tail = trailing_conditional_jump_index(block).is_some() - || has_call_store_before_trailing_conditional(block) - || predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !block_is_exceptional(pred_block) - && !pred_block.cold - && has_call_and_store(pred_block) - && has_load_fast_pair(pred_block) - }); - let has_structured_terminal_tail_shape = has_conditional_tail; - if block_is_exceptional(block) - || block.cold - || block_has_protected_instructions(block) - || !block_has_real_instructions(block) - || (has_conditional_tail - && block.start_depth.is_some_and(|depth| depth > 0) - && !has_protected_call_predecessor - && !has_call_store_tail) - || block.try_else_orelse_entry - || predecessors[idx].iter().any(|pred| { - is_handler_resume_predecessor( - &self.blocks[pred.idx()], - BlockIdx::new(idx as u32), - ) - }) - || predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !block_is_exceptional(pred_block) - && !pred_block.cold - && !block_has_protected_instructions(pred_block) - && block_has_non_nop_real_instructions(pred_block) - && !is_unprotected_call_store_bridge_to( - pred_block, - 
BlockIdx::new(idx as u32), - ) - }) - || !(has_structured_terminal_tail_shape - || has_protected_call_predecessor - || has_call_store_tail) - { - continue; - } - - let mut seen = vec![false; self.blocks.len()]; - let mut stack = predecessors[idx].clone(); - while let Some(pred) = stack.pop() { - if pred == BlockIdx::NULL || seen[pred.idx()] { - continue; - } - seen[pred.idx()] = true; - let pred_block = &self.blocks[pred.idx()]; - if block_has_protected_instructions(pred_block) { - if protected_tail_ends_with_conditional(pred_block) { - break; - } - if let Some(handler) = - trailing_protected_tail_terminal_exception_handler(&self.blocks, pred_block) - { - let seed = BlockIdx::new(idx as u32); - if normal_path_reaches_handler(&self.blocks, seed, handler) { - break; - } - seeds.push(( - seed, - (has_protected_call_predecessor || has_call_store_tail) - && !has_structured_terminal_tail_shape, - false, - )); - } - break; - } - if is_unprotected_call_store_bridge_to(pred_block, BlockIdx::new(idx as u32)) { - stack.extend(predecessors[pred.idx()].iter().copied()); - continue; - } - if !block_is_exceptional(pred_block) - && !pred_block.cold - && !block_has_non_nop_real_instructions(pred_block) - { - stack.extend(predecessors[pred.idx()].iter().copied()); - } - } - } - - for (pred_idx, pred_block) in self.blocks.iter().enumerate() { - if block_is_exceptional(pred_block) || pred_block.cold { - continue; - } - let Some(handler) = - trailing_protected_tail_conditional_exception_handler(&self.blocks, pred_block) - else { - continue; - }; - for successor in normal_successors(pred_block) { - if successor == BlockIdx::NULL { - continue; - } - let successor_block = &self.blocks[successor.idx()]; - if block_is_exceptional(successor_block) - || successor_block.cold - || successor_block.try_else_orelse_entry - || !normal_tail_reaches_conditional(&self.blocks, successor) - || normal_path_reaches_handler(&self.blocks, successor, handler) - || 
predecessors[successor.idx()].iter().any(|pred| { - *pred != BlockIdx::new(pred_idx as u32) - && is_handler_resume_predecessor(&self.blocks[pred.idx()], successor) - }) - { - continue; - } - seeds.push((successor, false, true)); - } - } - - let mut visited = vec![false; self.blocks.len()]; - for (seed, direct_only, skip_simple_fast_returns) in seeds { - let mut reachable_from_seed = vec![false; self.blocks.len()]; - let mut reachability_stack = vec![seed]; - while let Some(block_idx) = reachability_stack.pop() { - if block_idx == BlockIdx::NULL || reachable_from_seed[block_idx.idx()] { - continue; - } - let block = &self.blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - reachable_from_seed[block_idx.idx()] = true; - reachability_stack.extend(normal_successors(block)); - } - - let mut stack = vec![seed]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - let block = &self.blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if block_idx != seed - && predecessors[block_idx.idx()].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !block_is_exceptional(pred_block) - && !pred_block.cold - && !reachable_from_seed[pred.idx()] - && normal_successors(pred_block).contains(&block_idx) - }) - { - continue; - } - visited[block_idx.idx()] = true; - let successors = normal_successors(&self.blocks[block_idx.idx()]); - if !(self.blocks[block_idx.idx()].try_else_orelse_entry - || skip_simple_fast_returns - && is_simple_fast_return_block(&self.blocks[block_idx.idx()])) - { - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - if direct_only { - continue; - } - for successor in successors { - stack.push(successor); - } - } - } - } - - fn deoptimize_borrow_in_protected_method_call_after_terminal_except_tail(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut 
block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_non_nop_real_instructions(block: &Block) -> bool { - block.instructions.iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop)) - }) - } - - fn block_ends_with_return_value(block: &Block) -> bool { - block - .instructions - .iter() - .rev() - .find_map(|info| info.instr.real()) - .is_some_and(|instr| matches!(instr, Instruction::ReturnValue)) - } - - fn handler_chain_has_exception_match(blocks: &[Block], handler: BlockIdx) -> bool { - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - if blocks[cursor.idx()].instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) { - return true; - } - cursor = blocks[cursor.idx()].next; - } - false - } - - fn block_has_protected_method_call(blocks: &[Block], block: &Block) -> bool { - if !block_has_protected_instructions(block) { - return false; - } - block.instructions.iter().any(|info| { - let is_method_load = matches!( - info.instr.real(), - Some(Instruction::LoadAttr { namei }) if namei.get(info.arg).is_method() - ); - is_method_load - && info.except_handler.is_some_and(|handler| { - handler_chain_has_exception_match(blocks, handler.handler_block) - }) - }) - } - - fn protected_method_call_handlers(blocks: &[Block], block: &Block) -> Vec { - let mut handlers = Vec::new(); - 
for info in &block.instructions { - if !matches!( - info.instr.real(), - Some(Instruction::LoadAttr { namei }) if namei.get(info.arg).is_method() - ) { - continue; - } - let Some(handler) = info.except_handler.map(|handler| handler.handler_block) else { - continue; - }; - if !handler_chain_has_exception_match(blocks, handler) { - continue; - } - if !handlers.contains(&handler) { - handlers.push(handler); - } - } - handlers - } - - fn block_shares_handler(block: &Block, handlers: &[BlockIdx]) -> bool { - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| handlers.contains(&handler)) - } - - fn starts_with_inlined_comprehension_restore(block: &Block) -> bool { - if block.start_depth.is_none_or(|depth| depth == 0) { - return false; - } - let mut saw_store = false; - for info in &block.instructions { - match info.instr.real() { - Some( - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. 
}, - ) => saw_store = true, - Some(Instruction::Nop) => {} - Some(_) => return saw_store, - None => {} - } - } - false - } - - fn handler_chain_resumes_normally(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut has_terminal_exit = false; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let last_real_info = block - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()); - let last_real = last_real_info.and_then(|info| info.instr.real()); - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL - && info.instr.is_unconditional_jump() - && has_pop_except - }) { - return true; - } - if block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. } - | Instruction::ReturnValue - ) - ) - }) { - has_terminal_exit = true; - } - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if last_real_info.is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) { - let target = last_real_info.unwrap().target; - if !has_pop_except && target != BlockIdx::NULL { - stack.push(target); - } - } else if !matches!( - last_real, - Some( - Instruction::RaiseVarargs { .. } - | Instruction::Reraise { .. 
} - | Instruction::ReturnValue - ) - ) && block.next != BlockIdx::NULL - { - stack.push(block.next); - } - } - !has_terminal_exit - } - - fn protected_block_has_terminal_exception_handler(blocks: &[Block], block: &Block) -> bool { - let mut seen_handlers = Vec::new(); - for handler in block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - { - if seen_handlers.contains(&handler) { - continue; - } - seen_handlers.push(handler); - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - let mut has_exception_match = false; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - if blocks[cursor.idx()].instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) { - has_exception_match = true; - break; - } - cursor = blocks[cursor.idx()].next; - } - if has_exception_match && !handler_chain_resumes_normally(blocks, handler) { - return true; - } - } - false - } - - fn protected_block_has_raising_exception_handler(blocks: &[Block], block: &Block) -> bool { - let mut seen_handlers = Vec::new(); - for handler in block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - { - if seen_handlers.contains(&handler) { - continue; - } - seen_handlers.push(handler); - let mut stack = Vec::new(); - let mut visited = vec![false; blocks.len()]; - let mut cursor = handler; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let handler_block = &blocks[cursor.idx()]; - if handler_block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) { - if handler_block.next != BlockIdx::NULL { - stack.push(handler_block.next); - } - break; - } - cursor = handler_block.next; - } - visited.fill(false); - while let Some(cursor) = 
stack.pop() { - if cursor == BlockIdx::NULL || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - let mut stop_path = false; - for info in &block.instructions { - match info.instr.real() { - Some( - Instruction::RaiseVarargs { .. } | Instruction::Reraise { .. }, - ) => { - return true; - } - Some(Instruction::ReturnValue | Instruction::PopExcept) => { - stop_path = true; - break; - } - _ => {} - } - } - if stop_path { - continue; - } - for info in &block.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - if info.instr.is_unconditional_jump() && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - } - false - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let mut to_deopt = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - if block_is_exceptional(block) - || block.cold - || starts_with_inlined_comprehension_restore(block) - || !block_has_protected_method_call(&self.blocks, block) - || block_ends_with_return_value(block) - || predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !block_is_exceptional(pred_block) - && !pred_block.cold - && !block_has_protected_instructions(pred_block) - && block_has_non_nop_real_instructions(pred_block) - }) - { - continue; - } - - let method_handlers = protected_method_call_handlers(&self.blocks, block); - let mut seen = vec![false; self.blocks.len()]; - let mut stack = predecessors[idx].clone(); - while let Some(pred) = 
stack.pop() { - if pred == BlockIdx::NULL || seen[pred.idx()] { - continue; - } - seen[pred.idx()] = true; - let pred_block = &self.blocks[pred.idx()]; - if block_has_protected_instructions(pred_block) { - if protected_block_has_terminal_exception_handler(&self.blocks, pred_block) - && protected_block_has_raising_exception_handler(&self.blocks, pred_block) - && !block_shares_handler(pred_block, &method_handlers) - { - to_deopt.push(BlockIdx::new(idx as u32)); - } - break; - } - if !block_is_exceptional(pred_block) - && !pred_block.cold - && !block_has_non_nop_real_instructions(pred_block) - { - stack.extend(predecessors[pred.idx()].iter().copied()); - } - } - } - - for block_idx in to_deopt { - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - - fn deoptimize_borrow_after_except_star_try_tail(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_fast_load(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. 
} - ) - ) - }) - } - - fn handler_chain_has_exception_group_match(blocks: &[Block], handler: BlockIdx) -> bool { - let mut cursor = handler; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - if blocks[cursor.idx()] - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::CheckEgMatch))) - { - return true; - } - cursor = blocks[cursor.idx()].next; - } - false - } - - fn block_has_exception_group_handler(blocks: &[Block], block: &Block) -> bool { - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| handler_chain_has_exception_group_match(blocks, handler)) - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (idx, block) in self.blocks.iter().enumerate() { - for successor in normal_successors(block) { - predecessors[successor.idx()].push(BlockIdx::new(idx as u32)); - } - } - - let mut to_deopt = Vec::new(); - for (idx, block) in self.blocks.iter().enumerate() { - if block_is_exceptional(block) - || block.cold - || 
block_has_exception_group_handler(&self.blocks, block) - || !block_has_fast_load(block) - { - continue; - } - - let mut visited = vec![false; self.blocks.len()]; - let mut stack = predecessors[idx].clone(); - while let Some(pred) = stack.pop() { - if pred == BlockIdx::NULL || visited[pred.idx()] { - continue; - } - visited[pred.idx()] = true; - let pred_block = &self.blocks[pred.idx()]; - if block_has_protected_instructions(pred_block) - && block_has_exception_group_handler(&self.blocks, pred_block) - { - to_deopt.push(BlockIdx::new(idx as u32)); - break; - } - if !block_is_exceptional(pred_block) && !pred_block.cold { - stack.extend(predecessors[pred.idx()].iter().copied()); - } - } - } - - to_deopt.sort_by_key(|idx| idx.idx()); - to_deopt.dedup(); - for block_idx in to_deopt { - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - - fn deoptimize_borrow_for_folded_nonliteral_exprs(&mut self) { - let mut starts_after_folded_nonliteral_expr = vec![false; self.blocks.len()]; - for block in &self.blocks { - let Some(last) = block - .instructions - .iter() - .rev() - .find(|info| !matches!(info.instr.real(), Some(Instruction::Nop))) - else { - continue; - }; - if !last.folded_from_nonliteral_expr { - continue; - } - if block.next != BlockIdx::NULL { - starts_after_folded_nonliteral_expr[block.next.idx()] = true; - } - if last.target != BlockIdx::NULL { - starts_after_folded_nonliteral_expr[last.target.idx()] = true; - } - } - - for (block_idx, block) in self.blocks.iter_mut().enumerate() { - let mut deopt_tail = false; - let mut prev_non_nop_folded = starts_after_folded_nonliteral_expr[block_idx]; - let mut prev_prev_non_nop_folded = false; - let mut prev_non_nop_was_unpack = false; - for info in &mut block.instructions { - let folded_from_nonliteral_expr = info.folded_from_nonliteral_expr; - let real = info.instr.real(); - let store_from_folded_nonliteral_expr = match real { - Some(Instruction::StoreFast { .. } | Instruction::StoreFastLoadFast { .. 
}) => { - folded_from_nonliteral_expr || prev_non_nop_folded - } - Some(Instruction::StoreFastStoreFast { .. }) => { - folded_from_nonliteral_expr - || prev_non_nop_folded - || (prev_non_nop_was_unpack && prev_prev_non_nop_folded) - } - _ => false, - }; - if store_from_folded_nonliteral_expr { - deopt_tail = true; - } - if !folded_from_nonliteral_expr && !deopt_tail { - if let Some(real) = real - && !matches!(real, Instruction::Nop | Instruction::Cache) - { - prev_prev_non_nop_folded = prev_non_nop_folded; - prev_non_nop_folded = folded_from_nonliteral_expr; - prev_non_nop_was_unpack = - matches!(real, Instruction::UnpackSequence { .. }); - } - continue; - } - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - if let Some(real) = real - && !matches!(real, Instruction::Nop | Instruction::Cache) - { - prev_prev_non_nop_folded = prev_non_nop_folded; - prev_non_nop_folded = folded_from_nonliteral_expr; - prev_non_nop_was_unpack = matches!(real, Instruction::UnpackSequence { .. }); - } - } - } - } - - fn deoptimize_borrow_after_terminal_except_before_with(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }); - has_pop_except && jumps_to_target - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_starts_with_with_setup(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::LoadSpecial { .. }))) - } - - fn block_starts_with_with_cleanup_handler(block: &Block) -> bool { - let mut reals = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop)); - matches!( - (reals.next(), reals.next()), - ( - Some(Instruction::PushExcInfo), - Some(Instruction::WithExceptStart) - ) - ) - } - - fn block_has_exception_match(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - } - - fn next_chain_reaches_exception_match_before_with_cleanup( - blocks: &[Block], - start: BlockIdx, - ) -> bool { - let mut cursor = start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if block_starts_with_with_cleanup_handler(block) { - return false; - } - if block_has_exception_match(block) { - return true; - } - cursor = block.next; - } - 
false - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - for successor in normal_successors(block) { - predecessors[successor.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - - let seeds: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) - || !block_has_protected_instructions(block) - || !block_starts_with_with_setup(block) - || !next_chain_reaches_exception_match_before_with_cleanup( - &self.blocks, - block.next, - ) - { - return None; - } - let has_terminal_except_predecessor = predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - pred_block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - && block_has_exception_match_handler(&self.blocks, pred_block) - }); - let has_handler_resume_predecessor = predecessors[idx].iter().any(|pred| { - is_handler_resume_predecessor( - &self.blocks[pred.idx()], - BlockIdx::new(idx as u32), - ) - }); - (has_terminal_except_predecessor && !has_handler_resume_predecessor) - .then_some(BlockIdx::new(idx as u32)) - }) - .collect(); - - 
let mut to_deopt = Vec::new(); - let mut visited = vec![false; self.blocks.len()]; - for seed in seeds { - let mut stack = vec![seed]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - let block = &self.blocks[block_idx.idx()]; - if block_is_exceptional(block) || !block_has_protected_instructions(block) { - continue; - } - visited[block_idx.idx()] = true; - to_deopt.push(block_idx); - for successor in normal_successors(block) { - stack.push(successor); - } - } - } - - to_deopt.sort_by_key(|idx| idx.idx()); - to_deopt.dedup(); - for block_idx in to_deopt { - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - - fn deoptimize_borrow_after_handler_resume_loop_tail(&mut self) { - fn deoptimize_block_borrows(block: &mut Block) { - for info in &mut block.instructions { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }); - has_pop_except && jumps_to_target - } - - fn handler_chain_resumes_to_loop_header( - blocks: &[Block], - handler: BlockIdx, - loop_header: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - if has_pop_except - && block.instructions.iter().any(|info| { - info.target == loop_header - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }) - { - return true; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - false - } - - fn handler_chain_stores_and_resumes_to_loop_header( - blocks: &[Block], - handler: BlockIdx, - loop_header: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![(handler, false)]; - while let Some((block_idx, stores_local)) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - let stores_local = stores_local || block_has_store_fast(block); - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - if has_pop_except - && stores_local - && block.instructions.iter().any(|info| { - info.target == loop_header - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }) - { - return true; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - stack.push((info.target, stores_local)); - } - } - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - stack.push((block.next, stores_local)); - } - } - false - } - - fn protected_block_handler_resumes_to_self(blocks: &[Block], block_idx: BlockIdx) -> bool { - let block = &blocks[block_idx.idx()]; - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| handler_chain_resumes_to_loop_header(blocks, handler, block_idx)) - } - - fn is_suppressing_with_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - if !block - .instructions - .last() - .is_some_and(|info| info.target == target && info.instr.is_unconditional_jump()) - { - return false; - } - let mut reals = block - .instructions - .iter() - .rev() - .filter_map(|info| info.instr.real()); - matches!( - (reals.next(), reals.next(), reals.next(), reals.next(), reals.next()), - ( - Some(instr), - Some(Instruction::PopTop), - Some(Instruction::PopTop), - Some(Instruction::PopTop), - Some(Instruction::PopExcept), - ) if instr.is_unconditional_jump() - ) - } - - fn block_has_check_exc_match(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_non_nop_real_instructions(block: &Block) -> bool { - block.instructions.iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - }) - } - - fn predecessor_chain_has_check_exc_match( - blocks: &[Block], - predecessors: &[Vec], - start: BlockIdx, - stop: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; 
- let mut stack = vec![start]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || cursor == stop || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - if block_has_check_exc_match(&blocks[cursor.idx()]) { - return true; - } - for pred in &predecessors[cursor.idx()] { - stack.push(*pred); - } - } - false - } - - fn has_exception_match_resume_predecessor( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - ) -> bool { - predecessors[target.idx()].iter().any(|pred| { - let block = &blocks[pred.idx()]; - is_handler_resume_predecessor(block, target) - && !is_suppressing_with_resume_predecessor(block, target) - && predecessor_chain_has_check_exc_match(blocks, predecessors, *pred, target) - }) - } - - fn is_plain_protected_resume_successor( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - ) -> bool { - let mut has_handler_resume_predecessor = false; - let mut has_normal_fallthrough_predecessor = false; - for pred in &predecessors[target.idx()] { - let pred_block = &blocks[pred.idx()]; - if is_handler_resume_predecessor(pred_block, target) { - has_handler_resume_predecessor = true; - continue; - } - if block_is_exceptional(pred_block) || pred_block.cold { - continue; - } - if next_nonempty_block(blocks, pred_block.next) == target { - has_normal_fallthrough_predecessor = true; - continue; - } - return false; - } - has_handler_resume_predecessor && has_normal_fallthrough_predecessor - } - - fn predecessor_chain_has_protected_instructions( - blocks: &[Block], - predecessors: &[Vec], - start: BlockIdx, - stop: BlockIdx, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || cursor == stop || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - if block_has_protected_instructions(&blocks[cursor.idx()]) { - return true; - } - for pred in &predecessors[cursor.idx()] { - 
stack.push(*pred); - } - } - false - } - - fn block_stores_local(block: &Block, local: usize) -> bool { - block - .instructions - .iter() - .any(|info| match info.instr.real() { - Some(Instruction::StoreFast { var_num }) => { - usize::from(var_num.get(info.arg)) == local - } - Some(Instruction::StoreFastLoadFast { var_nums }) => { - let (store_idx, _) = var_nums.get(info.arg).indexes(); - usize::from(store_idx) == local - } - Some(Instruction::StoreFastStoreFast { var_nums }) => { - let (left, right) = var_nums.get(info.arg).indexes(); - usize::from(left) == local || usize::from(right) == local - } - _ => false, - }) - } - - fn predecessor_chain_stores_local( - blocks: &[Block], - predecessors: &[Vec], - start: BlockIdx, - stop: BlockIdx, - local: usize, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || cursor == stop || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - if block_stores_local(&blocks[cursor.idx()], local) { - return true; - } - for pred in &predecessors[cursor.idx()] { - stack.push(*pred); - } - } - false - } - - fn starts_with_bool_guard(block: &Block) -> bool { - let infos: Vec<_> = block - .instructions - .iter() - .filter(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .take(3) - .collect(); - matches!( - infos.as_slice(), - [ - first, - second, - third, - .. - ] if matches!( - first.instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }) - ) && matches!(second.instr.real(), Some(Instruction::ToBool)) - && matches!( - third.instr.real(), - Some( - Instruction::PopJumpIfFalse { .. } - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfNone { .. } - | Instruction::PopJumpIfNotNone { .. 
} - ) - ) - ) - } - - fn block_has_loop_back(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }) - } - - fn block_has_loop_back_to_or_before( - blocks: &[Block], - block: &Block, - target_block: BlockIdx, - ) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) && (info.target == target_block - || comes_before(blocks, info.target, target_block)) - }) - } - - fn block_has_for_iter(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. }))) - } - - fn block_has_get_iter(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::GetIter))) - } - - fn block_has_fast_load(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) - }) - } - - fn block_has_fast_load_pair(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) - }) - } - - fn block_has_method_load(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadAttr { namei }) if namei.get(info.arg).is_method() - ) - }) - } - - fn block_has_store_fast(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. 
} - ) - ) - }) - } - - fn block_has_call(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::CallFunctionEx - ) - ) - }) - } - - fn block_has_conditional_jump_to(block: &Block, target: BlockIdx) -> bool { - block - .instructions - .iter() - .any(|info| info.target == target && is_conditional_jump(&info.instr)) - } - - fn loop_back_target(block: &Block) -> Option { - block.instructions.iter().find_map(|info| { - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - .then_some(info.target) - }) - } - - fn conditional_fallthrough_loop_header( - blocks: &[Block], - block: &Block, - target: BlockIdx, - ) -> Option { - if !block_has_conditional_jump_to(block, target) { - return None; - } - loop_back_target(block).or_else(|| { - (block.next != BlockIdx::NULL) - .then(|| loop_back_target(&blocks[block.next.idx()]))? 
- }) - } - - fn any_protected_handler_resumes_to_loop_header( - blocks: &[Block], - loop_header: BlockIdx, - ) -> bool { - blocks.iter().any(|block| { - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| { - handler_chain_resumes_to_loop_header(blocks, handler, loop_header) - }) - }) - } - - fn any_storing_protected_handler_resumes_to_loop_header( - blocks: &[Block], - loop_header: BlockIdx, - ) -> bool { - blocks.iter().any(|block| { - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| { - handler_chain_stores_and_resumes_to_loop_header( - blocks, - handler, - loop_header, - ) - }) - }) - } - - fn any_exception_cleanup_jumps_to(blocks: &[Block], target: BlockIdx) -> bool { - blocks.iter().any(|block| { - block.cold - && block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - && block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }) - }) - } - - fn fallthrough_chain_reaches_target( - blocks: &[Block], - start: BlockIdx, - target: BlockIdx, - ) -> bool { - let mut cursor = start; - let mut seen = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && !seen[cursor.idx()] { - if cursor == target { - return true; - } - seen[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if block_has_non_nop_real_instructions(block) || !block_has_fallthrough(block) { - return false; - } - cursor = block.next; - } - false - } - - fn block_is_normal_finally_cleanup_call(block: &Block) -> bool { - block_has_call(block) - && !block_has_store_fast(block) - && block_has_fallthrough(block) - && block - .instructions - .last() - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::PopTop))) - } - - fn first_fast_load_local(block: &Block) -> Option { - block - .instructions - .iter() - .find_map(|info| match info.instr.real() { - Some( - Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num }, - ) => Some(usize::from(var_num.get(info.arg))), - _ => None, - }) - } - - fn trailing_conditional_guard_local(block: &Block, target: BlockIdx) -> Option { - let infos: Vec<_> = block - .instructions - .iter() - .filter(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .collect(); - let jump = infos.last()?; - if jump.target != target || !is_conditional_jump(&jump.instr) { - return None; - } - let load = if infos - .get(infos.len().wrapping_sub(2)) - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::ToBool))) - { - infos.get(infos.len().wrapping_sub(3))? - } else { - infos.get(infos.len().wrapping_sub(2))? 
- }; - match load.instr.real() { - Some( - Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num }, - ) => Some(usize::from(var_num.get(load.arg))), - _ => None, - } - } - - fn block_is_calling_finally_cleanup(block: &Block) -> bool { - let has_call = block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::CallFunctionEx - ) - ) - }); - has_call - && block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::Reraise { .. }))) - && !block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - } - - fn handler_chain_calls_finally_cleanup(blocks: &[Block], handler: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![handler]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_calling_finally_cleanup(block) { - return true; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - stack.push(block.next); - } - } - false - } - - fn protected_block_handler_calls_finally_cleanup( - blocks: &[Block], - block_idx: BlockIdx, - ) -> bool { - let block = &blocks[block_idx.idx()]; - block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .any(|handler| handler_chain_calls_finally_cleanup(blocks, handler)) - } - - fn has_jump_back_predecessor_to( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - ) -> bool { - predecessors[target.idx()].iter().any(|pred| { - let pred_block = &blocks[pred.idx()]; - if pred_block.cold || block_is_exceptional(pred_block) { - 
return false; - } - blocks[pred.idx()].instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }) - }) - } - - fn tail_successors( - blocks: &[Block], - predecessors: &[Vec], - block_idx: BlockIdx, - ) -> Vec { - let block = &blocks[block_idx.idx()]; - if block_has_loop_back(block) { - return Vec::new(); - } - if let Some(for_iter) = block - .instructions - .iter() - .find(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. }))) - { - let mut successors = Vec::with_capacity(3); - if for_iter.target != BlockIdx::NULL { - successors.push(for_iter.target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - let is_loop_header = has_jump_back_predecessor_to(blocks, predecessors, block_idx); - if target != BlockIdx::NULL && !is_loop_header { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - fn tail_has_for_loop_back( - blocks: &[Block], - predecessors: &[Vec], - seed: BlockIdx, - ) -> bool { - let mut seen = vec![false; blocks.len()]; - let mut stack = vec![seed]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || seen[cursor.idx()] { - continue; - } - seen[cursor.idx()] = 
true; - let block = &blocks[cursor.idx()]; - for info in &block.instructions { - if matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) && info.target != BlockIdx::NULL - && block_has_for_iter(&blocks[info.target.idx()]) - { - return true; - } - } - for successor in tail_successors(blocks, predecessors, cursor) { - stack.push(successor); - } - } - false - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - let mut is_handler_resume_block = vec![false; self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - let block_idx = BlockIdx::new(pred_idx as u32); - if block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - && block.instructions.last().is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) - { - is_handler_resume_block[pred_idx] = true; - } - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(block_idx); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(block_idx); - } - } - } - let has_exception_match_handler = self.blocks.iter().any(|block| { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - }); - let has_exception_group_match_handler = self.blocks.iter().any(|block| { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::CheckEgMatch))) - }); - let suppressing_exception_match_method_tails: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) - || block_has_for_iter(block) - || !block_has_fast_load(block) - || !block_has_method_load(block) - { - return None; - } - let target = BlockIdx::new(idx as u32); - let 
has_suppressing_with_resume_predecessor = - predecessors[idx].iter().any(|pred| { - is_suppressing_with_resume_predecessor(&self.blocks[pred.idx()], target) - }); - (has_suppressing_with_resume_predecessor - && (has_exception_group_match_handler - || has_exception_match_resume_predecessor( - &self.blocks, - &predecessors, - target, - ))) - .then_some(target) - }) - .collect(); - for block_idx in suppressing_exception_match_method_tails { - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - - let handler_resume_loop_successor_tails: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) - || block.cold - || !block_has_fast_load_pair(block) - || !block_has_call(block) - { - return None; - } - predecessors[idx] - .iter() - .any(|pred| { - !block_is_exceptional(&self.blocks[pred.idx()]) - && !self.blocks[pred.idx()].cold - && !block_has_store_fast(&self.blocks[pred.idx()]) - && protected_block_handler_resumes_to_self(&self.blocks, *pred) - }) - .then_some(BlockIdx::new(idx as u32)) - }) - .collect(); - for block_idx in handler_resume_loop_successor_tails { - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - - let finally_cleanup_successor_tails: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) - || block.cold - || block_has_protected_instructions(block) - || block_has_call(block) - || !starts_with_bool_guard(block) - || !block_has_fast_load(block) - { - return None; - } - let target = BlockIdx::new(idx as u32); - predecessors[idx] - .iter() - .any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - block_has_conditional_jump_to(pred_block, target) - && pred_block.next != BlockIdx::NULL - && { - let cleanup_block = &self.blocks[pred_block.next.idx()]; - block_is_normal_finally_cleanup_call(cleanup_block) - && trailing_conditional_guard_local(pred_block, target) - == first_fast_load_local(cleanup_block) - && 
fallthrough_chain_reaches_target( - &self.blocks, - cleanup_block.next, - target, - ) - } - }) - .then_some(BlockIdx::new(idx as u32)) - }) - .collect(); - let mut visited_finally_tail = vec![false; self.blocks.len()]; - for seed in finally_cleanup_successor_tails { - let mut segment = Vec::new(); - let mut seen = vec![false; self.blocks.len()]; - let mut stack = vec![seed]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || seen[cursor.idx()] { - continue; - } - seen[cursor.idx()] = true; - let block = &self.blocks[cursor.idx()]; - if block_is_exceptional(block) || block.cold || block_has_loop_back(block) { - continue; - } - segment.push(cursor); - for successor in tail_successors(&self.blocks, &predecessors, cursor) { - stack.push(successor); - } - } - let segment_ops: Vec<_> = segment - .iter() - .flat_map(|block_idx| { - self.blocks[block_idx.idx()] - .instructions - .iter() - .filter_map(|info| info.instr.real()) - }) - .collect(); - let has_call = segment_ops.iter().any(|instr| { - matches!(instr, Instruction::Call { .. } | Instruction::CallKw { .. }) - }); - let has_store_fast = segment_ops.iter().any(|instr| { - matches!( - instr, - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. 
} - ) - }); - if !has_call || !has_store_fast { - continue; - } - for block_idx in segment { - if visited_finally_tail[block_idx.idx()] { - continue; - } - visited_finally_tail[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - - let handler_break_join_tails: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) - || block.cold - || !block_has_fast_load(block) - || !block_has_method_load(block) - || !block_has_call(block) - || !has_exception_match_resume_predecessor( - &self.blocks, - &predecessors, - BlockIdx::new(idx as u32), - ) - { - return None; - } - let join = BlockIdx::new(idx as u32); - let body = predecessors[idx].iter().find_map(|jump_pred| { - let jump_block = &self.blocks[jump_pred.idx()]; - if block_is_exceptional(jump_block) - || jump_block.cold - || !jump_block.instructions.last().is_some_and(|info| { - info.target == join && info.instr.is_unconditional_jump() - }) - { - return None; - } - predecessors[jump_pred.idx()].iter().find_map(|cond_pred| { - let cond_block = &self.blocks[cond_pred.idx()]; - if block_is_exceptional(cond_block) || cond_block.cold { - return None; - } - let cond_idx = trailing_conditional_jump_index(cond_block)?; - let cond = cond_block.instructions[cond_idx]; - if cond.target == *jump_pred - || cond.target == BlockIdx::NULL - || cond_block.next != *jump_pred - { - return None; - } - let body = next_nonempty_block(&self.blocks, cond.target); - if body == BlockIdx::NULL { - return None; - } - let body_block = &self.blocks[body.idx()]; - (!block_is_exceptional(body_block) - && !body_block.cold - && block_has_loop_back(body_block) - && block_has_fast_load(body_block) - && block_has_method_load(body_block) - && block_has_call(body_block) - && block_has_store_fast(body_block)) - .then_some(body) - }) - })?; - Some((join, body)) - }) - .collect(); - for (join, body) in handler_break_join_tails { - deoptimize_block_borrows(&mut 
self.blocks[join.idx()]); - deoptimize_block_borrows(&mut self.blocks[body.idx()]); - } - - let conditional_loop_bypass_tails: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) - || block.cold - || !block_has_fast_load(block) - || !has_exception_match_handler - { - return None; - } - let target = BlockIdx::new(idx as u32); - if is_plain_protected_resume_successor(&self.blocks, &predecessors, target) { - return None; - } - predecessors[idx] - .iter() - .any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - conditional_fallthrough_loop_header(&self.blocks, pred_block, target) - .is_some_and(|loop_header| { - let storing_handler_resumes = - any_storing_protected_handler_resumes_to_loop_header( - &self.blocks, - loop_header, - ) || any_storing_protected_handler_resumes_to_loop_header( - &self.blocks, - *pred, - ); - block_has_for_iter(&self.blocks[loop_header.idx()]) - && predecessor_chain_has_protected_instructions( - &self.blocks, - &predecessors, - *pred, - loop_header, - ) - && storing_handler_resumes - && (any_protected_handler_resumes_to_loop_header( - &self.blocks, - loop_header, - ) || any_protected_handler_resumes_to_loop_header( - &self.blocks, - *pred, - ) || any_exception_cleanup_jumps_to( - &self.blocks, - loop_header, - ) || any_exception_cleanup_jumps_to(&self.blocks, *pred)) - }) - }) - .then_some(target) - }) - .collect(); - let mut visited_conditional_tail = vec![false; self.blocks.len()]; - for seed in conditional_loop_bypass_tails { - let mut segment = Vec::new(); - let mut seen = vec![false; self.blocks.len()]; - let mut stack = vec![seed]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || seen[cursor.idx()] { - continue; - } - seen[cursor.idx()] = true; - let block = &self.blocks[cursor.idx()]; - if block_is_exceptional(block) || block.cold || block_has_loop_back(block) { - continue; - } - segment.push(cursor); - for successor in 
tail_successors(&self.blocks, &predecessors, cursor) { - stack.push(successor); - } - } - let segment_ops: Vec<_> = segment - .iter() - .flat_map(|block_idx| { - self.blocks[block_idx.idx()] - .instructions - .iter() - .filter_map(|info| info.instr.real()) - }) - .collect(); - let has_store_fast = segment_ops.iter().any(|instr| { - matches!( - instr, - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. } - ) - }); - if !has_store_fast { - continue; - } - for block_idx in segment { - if visited_conditional_tail[block_idx.idx()] { - continue; - } - visited_conditional_tail[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - - let seeds: Vec<_> = self - .blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - let has_bool_guard_tail = starts_with_bool_guard(block); - let has_loop_tail = block_has_for_iter(block) || block_has_get_iter(block); - let has_protected_predecessor = predecessors[idx].iter().any(|pred| { - self.blocks[pred.idx()] - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - }); - let has_protected_finally_cleanup_predecessor = predecessors[idx] - .iter() - .any(|pred| protected_block_handler_calls_finally_cleanup(&self.blocks, *pred)); - let has_finally_except_loop_tail = has_exception_match_handler - && has_loop_tail - && has_protected_finally_cleanup_predecessor; - let has_handler_resume_predecessor = predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !is_named_except_cleanup_normal_exit_block(pred_block) - && (is_handler_resume_block[pred.idx()] - || is_handler_resume_predecessor(pred_block, BlockIdx::new(idx as u32))) - }); - let is_plain_protected_resume_successor = is_plain_protected_resume_successor( - &self.blocks, - &predecessors, - BlockIdx::new(idx as u32), - ); - let has_suppressing_with_resume_predecessor = - predecessors[idx].iter().any(|pred| { - 
is_suppressing_with_resume_predecessor( - &self.blocks[pred.idx()], - BlockIdx::new(idx as u32), - ) - }); - let has_exception_match_resume_predecessor = has_exception_match_resume_predecessor( - &self.blocks, - &predecessors, - BlockIdx::new(idx as u32), - ); - let bool_guard_local = has_bool_guard_tail - .then(|| first_fast_load_local(block)) - .flatten(); - let handler_resume_predecessor_stores_guard = - bool_guard_local.is_some_and(|local| { - predecessors[idx].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - is_handler_resume_predecessor(pred_block, BlockIdx::new(idx as u32)) - && predecessor_chain_stores_local( - &self.blocks, - &predecessors, - *pred, - BlockIdx::new(idx as u32), - local, - ) - }) - }); - let is_loop_header = has_jump_back_predecessor_to( - &self.blocks, - &predecessors, - BlockIdx::new(idx as u32), - ); - let has_handler_resume_loop_tail = block_has_get_iter(block) - && has_suppressing_with_resume_predecessor - && has_exception_match_resume_predecessor; - let has_supported_tail = has_bool_guard_tail - || has_finally_except_loop_tail - || has_handler_resume_loop_tail; - if block_is_exceptional(block) || !has_supported_tail { - return None; - } - let should_seed = (has_protected_predecessor - && has_finally_except_loop_tail - && !has_suppressing_with_resume_predecessor) - || (has_bool_guard_tail - && has_handler_resume_predecessor - && !handler_resume_predecessor_stores_guard - && !is_plain_protected_resume_successor - && !is_loop_header - && tail_has_for_loop_back( - &self.blocks, - &predecessors, - BlockIdx::new(idx as u32), - )) - || has_handler_resume_loop_tail; - let allow_any_loop_back = - has_finally_except_loop_tail || has_handler_resume_loop_tail; - should_seed.then_some(( - BlockIdx::new(idx as u32), - has_handler_resume_loop_tail, - allow_any_loop_back, - )) - }) - .collect(); - - let mut visited = vec![false; self.blocks.len()]; - for (seed, include_join_tail, allow_any_loop_back) in seeds { - let mut segment = 
Vec::new(); - let mut found_loop_back = false; - let mut seen = vec![false; self.blocks.len()]; - let mut stack = vec![seed]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || seen[cursor.idx()] { - continue; - } - seen[cursor.idx()] = true; - let block = &self.blocks[cursor.idx()]; - if block_is_exceptional(block) { - continue; - } - segment.push(cursor); - if block_has_loop_back(block) { - found_loop_back |= allow_any_loop_back - || block_has_loop_back_to_or_before(&self.blocks, block, seed); - continue; - } - for successor in tail_successors(&self.blocks, &predecessors, cursor) { - stack.push(successor); - } - } - if !found_loop_back { - continue; - } - - let segment_ops: Vec<_> = segment - .iter() - .flat_map(|block_idx| { - self.blocks[block_idx.idx()] - .instructions - .iter() - .filter_map(|info| info.instr.real()) - }) - .collect(); - let has_call = segment_ops.iter().any(|instr| { - matches!(instr, Instruction::Call { .. } | Instruction::CallKw { .. }) - }); - let has_store_fast = segment_ops.iter().any(|instr| { - matches!( - instr, - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. 
} - ) - }); - if !has_call || !has_store_fast { - continue; - } - - let mut in_segment = vec![false; self.blocks.len()]; - for block_idx in &segment { - in_segment[block_idx.idx()] = true; - } - - for block_idx in segment { - if visited[block_idx.idx()] { - continue; - } - if !include_join_tail - && block_idx != seed - && predecessors[block_idx.idx()] - .iter() - .any(|pred| !in_segment[pred.idx()] && !is_handler_resume_block[pred.idx()]) - { - continue; - } - if has_exception_group_match_handler - && block_has_for_iter(&self.blocks[block_idx.idx()]) - && block_has_protected_instructions(&self.blocks[block_idx.idx()]) - { - continue; - } - visited[block_idx.idx()] = true; - deoptimize_block_borrows(&mut self.blocks[block_idx.idx()]); - } - } - } - - fn deoptimize_borrow_after_protected_import(&mut self) { - fn deoptimize_borrow(info: &mut InstructionInfo) { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - - fn deoptimize_block_borrows_from(block: &mut Block, start: usize) { - for info in block.instructions.iter_mut().skip(start) { - deoptimize_borrow(info); - } - } - - fn deoptimize_protected_block_borrows_from( - block: &mut Block, - start: usize, - protected_store_locals: Option<&[bool]>, - ) { - for info in block.instructions.iter_mut().skip(start) { - if info.except_handler.is_none() { - break; - } - if let Some(protected_store_locals) = protected_store_locals { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => { - let local = usize::from(var_num.get(info.arg)); - if protected_store_locals.get(local).copied().unwrap_or(false) { - continue; - } - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { var_nums }) => { - let (left, right) = var_nums.get(info.arg).indexes(); - let skip_left = protected_store_locals - .get(usize::from(left)) - .copied() - .unwrap_or(false); - let skip_right = protected_store_locals - .get(usize::from(right)) - .copied() - .unwrap_or(false); - if skip_left && skip_right { - continue; - } - } - _ => {} - } - } - deoptimize_borrow(info); - } - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }); - has_pop_except && jumps_to_target - } - - fn handler_chain_returns(blocks: &[Block], handler_block: BlockIdx) -> bool { - let mut cursor = handler_block; - let mut visited = vec![false; blocks.len()]; - let mut after_pop_except = false; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - for info in &blocks[cursor.idx()].instructions { - match info.instr.real() { - Some(Instruction::ReturnValue) if !info.no_location_exit => return true, - Some(Instruction::PopExcept) => after_pop_except = true, - Some(_) if after_pop_except && is_conditional_jump(&info.instr) => { - return false; - } - Some(instr) - if after_pop_except - && (instr.is_unconditional_jump() || instr.is_scope_exit()) => - { - return false; - } - _ => {} - } - } - cursor = blocks[cursor.idx()].next; - } - false - } - - fn handler_chain_continues_to_or_before( - blocks: &[Block], - handler_block: BlockIdx, - seed: BlockIdx, - block_order: &[u32], - ) -> bool { - let mut cursor = handler_block; - let mut visited = vec![false; blocks.len()]; - let mut after_pop_except = false; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - for info in &blocks[cursor.idx()].instructions { - match info.instr.real() { - Some(Instruction::PopExcept) => after_pop_except = true, - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
}, - ) if after_pop_except - && info.target != BlockIdx::NULL - && block_order[info.target.idx()] <= block_order[seed.idx()] => - { - return true; - } - Some(_) if after_pop_except && is_conditional_jump(&info.instr) => { - return false; - } - Some(instr) - if after_pop_except - && (instr.is_unconditional_jump() || instr.is_scope_exit()) => - { - return false; - } - _ => {} - } - } - cursor = blocks[cursor.idx()].next; - } - false - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn push_normal_successors(stack: &mut Vec, block: &Block) { - if block.next != BlockIdx::NULL { - stack.push(block.next); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - } - - fn has_nested_protected_import_tail( - blocks: &[Block], - seed: BlockIdx, - import_idx: usize, - ) -> bool { - let Some(import_handler) = blocks[seed.idx()].instructions[import_idx].except_handler - else { - return false; - }; - let mut cursor = seed; - let mut start = import_idx + 1; - while cursor != BlockIdx::NULL && !block_is_exceptional(&blocks[cursor.idx()]) { - let block = &blocks[cursor.idx()]; - let mut saw_protected = false; - for info in block.instructions.iter().skip(start) { - let Some(handler) = info.except_handler else { - return false; - }; - saw_protected = true; - if handler.handler_block != import_handler.handler_block { - return true; - } - } - if !saw_protected { - return false; - } - cursor = block.next; - start = 0; - } - false - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let 
mut block_order = vec![u32::MAX; self.blocks.len()]; - let mut cursor = BlockIdx(0); - let mut pos = 0u32; - while cursor != BlockIdx::NULL { - block_order[cursor.idx()] = pos; - pos += 1; - cursor = self.blocks[cursor.idx()].next; - } - - let seeds: Vec<_> = - self.blocks - .iter() - .enumerate() - .filter_map(|(idx, block)| { - if block_is_exceptional(block) { - return None; - } - let import_idx = block.instructions.iter().position(|info| { - info.except_handler.is_some() - && matches!(info.instr.real(), Some(Instruction::ImportName { .. })) - })?; - let handler_returns = block.instructions[import_idx] - .except_handler - .is_some_and(|handler| { - handler_chain_returns(&self.blocks, handler.handler_block) - }); - let handler_continues = block.instructions[import_idx] - .except_handler - .is_some_and(|handler| { - handler_chain_continues_to_or_before( - &self.blocks, - handler.handler_block, - BlockIdx::new(idx as u32), - &block_order, - ) - }); - let nested_protected_tail = has_nested_protected_import_tail( - &self.blocks, - BlockIdx::new(idx as u32), - import_idx, - ); - if !handler_returns && !handler_continues && !nested_protected_tail { - return None; - } - Some(( - BlockIdx::new(idx as u32), - import_idx, - handler_returns, - handler_continues, - nested_protected_tail, - )) - }) - .collect(); - - let mut visited = vec![false; self.blocks.len()]; - for (seed, import_idx, handler_returns, handler_continues, nested_protected_tail) in seeds { - let mut protected_store_locals = vec![false; self.metadata.varnames.len()]; - for info in self.blocks[seed.idx()] - .instructions - .iter() - .skip(import_idx + 1) - { - if info.except_handler.is_none() { - break; - } - if let Some(Instruction::StoreFast { var_num }) = info.instr.real() { - let local = usize::from(var_num.get(info.arg)); - if let Some(slot) = protected_store_locals.get_mut(local) { - *slot = true; - } - } - } - - let mut in_segment = vec![false; self.blocks.len()]; - in_segment[seed.idx()] = true; - let 
mut segment = vec![(seed, import_idx + 1)]; - let mut cursor = self.blocks[seed.idx()].next; - while cursor != BlockIdx::NULL && !block_is_exceptional(&self.blocks[cursor.idx()]) { - if !handler_continues - && !handler_returns - && !nested_protected_tail - && block_has_protected_instructions(&self.blocks[cursor.idx()]) - { - break; - } - if predecessors[cursor.idx()].iter().any(|pred| { - !in_segment[pred.idx()] - && block_order[pred.idx()] < block_order[cursor.idx()] - && !is_handler_resume_predecessor(&self.blocks[pred.idx()], cursor) - }) { - break; - } - in_segment[cursor.idx()] = true; - segment.push((cursor, 0)); - if self.blocks[cursor.idx()] - .instructions - .iter() - .any(|info| info.instr.real().is_some_and(|instr| instr.is_scope_exit())) - && !handler_returns - && !handler_continues - { - break; - } - cursor = self.blocks[cursor.idx()].next; - } - - if nested_protected_tail { - let mut stack = Vec::new(); - for (block_idx, _) in &segment { - push_normal_successors(&mut stack, &self.blocks[block_idx.idx()]); - } - while let Some(candidate) = stack.pop() { - if candidate == BlockIdx::NULL - || in_segment[candidate.idx()] - || block_is_exceptional(&self.blocks[candidate.idx()]) - || !block_has_protected_instructions(&self.blocks[candidate.idx()]) - { - continue; - } - if predecessors[candidate.idx()].iter().any(|pred| { - !in_segment[pred.idx()] - && block_order[pred.idx()] < block_order[candidate.idx()] - && !is_handler_resume_predecessor(&self.blocks[pred.idx()], candidate) - }) { - continue; - } - in_segment[candidate.idx()] = true; - segment.push((candidate, 0)); - push_normal_successors(&mut stack, &self.blocks[candidate.idx()]); - } - } - - for (block_idx, start) in segment { - if visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - if block_idx == seed { - let protected_store_locals = - (!nested_protected_tail).then_some(protected_store_locals.as_slice()); - deoptimize_protected_block_borrows_from( - &mut 
self.blocks[block_idx.idx()], - start, - protected_store_locals, - ); - } else { - deoptimize_block_borrows_from(&mut self.blocks[block_idx.idx()], start); - } - } - } - } - - fn deoptimize_borrow_after_protected_store_tail(&mut self) { - fn deoptimize_block_borrows_from(block: &mut Block, start: usize) { - for info in block.instructions.iter_mut().skip(start) { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn is_handler_resume_predecessor(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }); - has_pop_except && jumps_to_target - } - - fn handler_chain_can_resume_to_segment( - blocks: &[Block], - block: &Block, - in_segment: &[bool], - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let handler_blocks: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - for handler_block in handler_blocks { - let mut cursor = handler_block; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let handler = &blocks[cursor.idx()]; - let mut after_pop_except = false; - for info in &handler.instructions { - if matches!(info.instr.real(), Some(Instruction::PopExcept)) { - after_pop_except = true; - continue; - } - if after_pop_except - && info.target != BlockIdx::NULL - && in_segment[info.target.idx()] - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - { - return true; - } - } - cursor = handler.next; - } - } - false - } - - fn handler_chain_has_explicit_raise(blocks: &[Block], block: &Block) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - let handler = &blocks[cursor.idx()]; - if handler - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::RaiseVarargs { .. 
}))) - { - return true; - } - for info in &handler.instructions { - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - if info.instr.is_unconditional_jump() && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if handler.next != BlockIdx::NULL { - stack.push(handler.next); - } - } - false - } - - fn handler_chain_exits_loop_after_pop_except(blocks: &[Block], block: &Block) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - let handler = &blocks[cursor.idx()]; - let mut after_pop_except = false; - for info in &handler.instructions { - if matches!(info.instr.real(), Some(Instruction::PopExcept)) { - after_pop_except = true; - continue; - } - if after_pop_except - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - { - return true; - } - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - if info.instr.is_unconditional_jump() && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if handler.next != BlockIdx::NULL { - stack.push(handler.next); - } - } - false - } - - fn handler_chain_has_nested_exception_match(blocks: &[Block], block: &Block) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - let mut matches_seen = 0; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || visited[cursor.idx()] { - continue; - } - visited[cursor.idx()] = true; - let handler = &blocks[cursor.idx()]; - for info in &handler.instructions { - if matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) { - matches_seen += 1; - if matches_seen > 1 { - return true; - } - } - if is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL { - stack.push(info.target); - } - if info.instr.is_unconditional_jump() && info.target != BlockIdx::NULL { - stack.push(info.target); - } - } - if handler.next != BlockIdx::NULL { - stack.push(handler.next); - } - } - false - } - - fn block_has_tail_deopt_trigger_from(block: &Block, start: usize) -> bool { - block.instructions.iter().skip(start).any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::CallFunctionEx - | Instruction::StoreAttr { .. } - | Instruction::StoreSubscr - ) - ) - }) - } - - fn block_has_generator_delegation(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::GetYieldFromIter - | Instruction::Send { .. } - | Instruction::YieldValue { .. 
} - ) - ) - }) - } - - fn block_has_external_backward_jump(block: &Block, in_segment: &[bool]) -> bool { - block.instructions.iter().any(|info| { - let target_is_external = - info.target != BlockIdx::NULL && !in_segment[info.target.idx()]; - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) && target_is_external - }) - } - - fn block_has_for_iter(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. }))) - } - - fn block_suffix_has_for_loop_back(block: &Block, blocks: &[Block], start: usize) -> bool { - block.instructions.iter().skip(start).any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) && info.target != BlockIdx::NULL - && block_has_for_iter(&blocks[info.target.idx()]) - }) - } - - fn block_has_for_loop_back(block: &Block, blocks: &[Block]) -> bool { - block_suffix_has_for_loop_back(block, blocks, 0) - } - - fn block_has_backward_jump(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }) - } - - fn normal_successors(block: &Block) -> Vec { - let Some(last) = block.instructions.last() else { - return (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect(); - }; - if last.instr.is_scope_exit() { - return Vec::new(); - } - if last.instr.is_unconditional_jump() { - return (last.target != BlockIdx::NULL) - .then_some(last.target) - .into_iter() - .collect(); - } - if let Some(cond_idx) = trailing_conditional_jump_index(block) { - let mut successors = Vec::with_capacity(2); - let target = block.instructions[cond_idx].target; - if target != BlockIdx::NULL { - successors.push(target); - } - if block.next != BlockIdx::NULL { - successors.push(block.next); - } - return successors; - } - (block.next != BlockIdx::NULL) - .then_some(block.next) - .into_iter() - .collect() - } - - fn segment_reaches_external_backward_jump( - blocks: &[Block], - segment: &[(BlockIdx, usize)], - in_segment: &[bool], - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = segment - .iter() - .map(|(block_idx, _)| *block_idx) - .collect::>(); - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if block_has_external_backward_jump(block, in_segment) { - return true; - } - stack.extend(normal_successors(block)); - } - false - } - - fn normal_path_reaches_for_loop_back(blocks: &[Block], start: BlockIdx) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = vec![start]; - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block_is_exceptional(block) || block.cold { - continue; - } - if block_has_for_loop_back(block, blocks) { - return true; - } - 
stack.extend(normal_successors(block)); - } - false - } - - fn segment_has_for_loop_back(blocks: &[Block], segment: &[(BlockIdx, usize)]) -> bool { - segment - .iter() - .any(|(block_idx, _)| block_has_for_loop_back(&blocks[block_idx.idx()], blocks)) - } - - fn segment_has_backward_jump(blocks: &[Block], segment: &[(BlockIdx, usize)]) -> bool { - segment - .iter() - .any(|(block_idx, _)| block_has_backward_jump(&blocks[block_idx.idx()])) - } - - fn segment_has_yield_value(blocks: &[Block], segment: &[(BlockIdx, usize)]) -> bool { - segment.iter().any(|(block_idx, start)| { - blocks[block_idx.idx()] - .instructions - .iter() - .skip(*start) - .any(|info| matches!(info.instr.real(), Some(Instruction::YieldValue { .. }))) - }) - } - - fn block_suffix_starts_with_builtin_any_all_fast_path( - block: &Block, - names: &IndexSet, - start: usize, - ) -> bool { - block - .instructions - .iter() - .skip(start) - .find_map(|info| { - let instr = info.instr.real()?; - if matches!(instr, Instruction::Nop | Instruction::NotTaken) { - return None; - } - Some((instr, info.arg)) - }) - .is_some_and(|(instr, arg)| { - matches!( - instr, - Instruction::LoadGlobal { namei } - if names[usize::try_from(namei.get(arg) >> 1).unwrap()].as_str() - == "any" - || names[usize::try_from(namei.get(arg) >> 1).unwrap()].as_str() - == "all" - ) - }) - } - - fn segment_starts_with_builtin_any_all_fast_path( - blocks: &[Block], - names: &IndexSet, - segment: &[(BlockIdx, usize)], - ) -> bool { - let Some((block_idx, start)) = segment.first() else { - return false; - }; - block_suffix_starts_with_builtin_any_all_fast_path( - &blocks[block_idx.idx()], - names, - *start, - ) - } - - fn segment_has_named_except_cleanup_predecessor( - blocks: &[Block], - predecessors: &[Vec], - segment: &[(BlockIdx, usize)], - ) -> bool { - segment.iter().any(|(block_idx, _)| { - predecessors[block_idx.idx()] - .iter() - .any(|pred| is_named_except_cleanup_normal_exit_block(&blocks[pred.idx()])) - }) - } - - fn 
protected_store_subscr_operand_start(block: &Block) -> Option { - let store_idx = block.instructions.iter().position(|info| { - matches!(info.instr.real(), Some(Instruction::StoreSubscr)) - && info.except_handler.is_some() - })?; - - let mut stack_items = 0; - for start in (0..store_idx).rev() { - let produced = match block.instructions[start].instr.real() { - Some( - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadGlobal { .. } - | Instruction::LoadName { .. } - | Instruction::LoadDeref { .. }, - ) => 1, - Some( - Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. }, - ) => 2, - _ => return None, - }; - stack_items += produced; - if stack_items >= 3 { - return Some(start); - } - } - None - } - - fn block_has_attr_named(block: &Block, names: &IndexSet, attr: &str) -> bool { - block.instructions.iter().any(|info| { - let raw = u32::from(info.arg) as usize; - matches!( - info.instr.real(), - Some(Instruction::LoadAttr { namei }) - if names[usize::try_from(namei.get(info.arg).name_idx()).unwrap()].as_str() - == attr - || names - .get_index(raw) - .is_some_and(|name| name.as_str() == attr) - || names - .get_index(raw >> 1) - .is_some_and(|name| name.as_str() == attr) - ) - }) - } - - fn block_has_protected_instructions(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - } - - fn block_has_non_nop_real_instructions(block: &Block) -> bool { - block.instructions.iter().any(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop)) - }) - } - - fn first_unprotected_suffix(block: &Block) -> Option { - let mut saw_protected = false; - for (idx, info) in block.instructions.iter().enumerate() { - if info.except_handler.is_some() { - saw_protected = true; - } else if saw_protected { - return Some(idx); - } - } - None - } - - fn collect_stored_fast_locals_until(block: &Block, end: usize) -> Vec { - let mut locals = 
Vec::new(); - for info in block.instructions.iter().take(end) { - collect_stored_fast_local(info, &mut locals); - } - locals - } - - fn collect_protected_stored_fast_locals_until(block: &Block, end: usize) -> Vec { - let mut locals = Vec::new(); - for info in block - .instructions - .iter() - .take(end) - .filter(|info| info.except_handler.is_some()) - { - collect_stored_fast_local(info, &mut locals); - } - locals - } - - fn collect_stored_fast_local(info: &InstructionInfo, locals: &mut Vec) { - match info.instr.real() { - Some(Instruction::StoreFast { var_num }) => { - locals.push(usize::from(var_num.get(info.arg))); - } - Some(Instruction::StoreFastLoadFast { var_nums }) => { - let (store_idx, _) = var_nums.get(info.arg).indexes(); - locals.push(usize::from(store_idx)); - } - Some(Instruction::StoreFastStoreFast { var_nums }) => { - let (idx1, idx2) = var_nums.get(info.arg).indexes(); - locals.push(usize::from(idx1)); - locals.push(usize::from(idx2)); - } - _ => {} - } - } - - fn collect_borrowed_stored_locals_in_segment( - blocks: &[Block], - segment: &[(BlockIdx, usize)], - stored_locals: &[usize], - ) -> Vec { - let mut borrowed = Vec::new(); - for (block_idx, start) in segment { - for info in blocks[block_idx.idx()].instructions.iter().skip(*start) { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => { - let local = usize::from(var_num.get(info.arg)); - if stored_locals.contains(&local) { - borrowed.push(local); - } - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { var_nums }) => { - let (left, right) = var_nums.get(info.arg).indexes(); - for local in [usize::from(left), usize::from(right)] { - if stored_locals.contains(&local) { - borrowed.push(local); - } - } - } - _ => {} - } - } - } - borrowed.sort_unstable(); - borrowed.dedup(); - borrowed - } - - fn handler_chain_resumes_after_assigning_locals( - blocks: &[Block], - block: &Block, - in_segment: &[bool], - locals: &[usize], - ) -> bool { - if locals.is_empty() { - return 
false; - } - - let mut visited = vec![false; blocks.len()]; - let handler_blocks: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - for handler_block in handler_blocks { - let mut cursor = handler_block; - let mut assigned = Vec::new(); - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let handler = &blocks[cursor.idx()]; - let mut after_pop_except = false; - for info in &handler.instructions { - if !after_pop_except { - collect_stored_fast_local(info, &mut assigned); - } - if matches!(info.instr.real(), Some(Instruction::PopExcept)) { - after_pop_except = true; - continue; - } - if after_pop_except - && info.target != BlockIdx::NULL - && in_segment[info.target.idx()] - && info.instr.is_unconditional_jump() - { - assigned.sort_unstable(); - assigned.dedup(); - if locals.iter().all(|local| assigned.contains(local)) { - return true; - } - } - } - cursor = handler.next; - } - } - false - } - - fn block_is_normal_cleanup_call(block: &Block, metadata: &CodeUnitMetadata) -> bool { - if !block_has_fallthrough(block) { - return false; - } - let reals: Vec<_> = block - .instructions - .iter() - .filter(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .collect(); - let [.., none1, none2, none3, call, pop_top] = reals.as_slice() else { - return false; - }; - is_load_const_none(none1, metadata) - && is_load_const_none(none2, metadata) - && is_load_const_none(none3, metadata) - && matches!(call.instr.real(), Some(Instruction::Call { .. 
})) - && matches!(pop_top.instr.real(), Some(Instruction::PopTop)) - && collect_stored_fast_locals_until(block, block.instructions.len()).is_empty() - } - - fn collect_protected_predecessor_stored_fast_locals( - blocks: &[Block], - predecessors: &[Vec], - start: BlockIdx, - ) -> Vec { - let mut locals = Vec::new(); - let mut visited = vec![false; blocks.len()]; - let mut stack = predecessors[start.idx()].clone(); - while let Some(block_idx) = stack.pop() { - if block_idx == BlockIdx::NULL || visited[block_idx.idx()] { - continue; - } - visited[block_idx.idx()] = true; - let block = &blocks[block_idx.idx()]; - if block.cold - || block_is_exceptional(block) - || !block_has_protected_instructions(block) - { - continue; - } - locals.extend(collect_protected_stored_fast_locals_until( - block, - block.instructions.len(), - )); - stack.extend(predecessors[block_idx.idx()].iter().copied()); - } - locals.sort_unstable(); - locals.dedup(); - locals - } - - fn borrows_any_local_from(block: &Block, locals: &[usize], start: usize) -> bool { - block - .instructions - .iter() - .skip(start) - .any(|info| match info.instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => { - locals.contains(&usize::from(var_num.get(info.arg))) - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { var_nums }) => { - let (idx1, idx2) = var_nums.get(info.arg).indexes(); - locals.contains(&usize::from(idx1)) || locals.contains(&usize::from(idx2)) - } - _ => false, - }) - } - - fn borrowed_inplace_local_update_start(block: &Block) -> Option { - for i in 0..block.instructions.len().saturating_sub(3) { - let local = match block.instructions[i].instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => { - usize::from(var_num.get(block.instructions[i].arg)) - } - _ => continue, - }; - let Some(Instruction::BinaryOp { op }) = block.instructions[i + 2].instr.real() - else { - continue; - }; - if !matches!( - op.get(block.instructions[i + 2].arg), - oparg::BinaryOperator::InplaceAdd - | 
oparg::BinaryOperator::InplaceSubtract - | oparg::BinaryOperator::InplaceMultiply - | oparg::BinaryOperator::InplaceMatrixMultiply - | oparg::BinaryOperator::InplaceTrueDivide - | oparg::BinaryOperator::InplaceFloorDivide - | oparg::BinaryOperator::InplaceRemainder - | oparg::BinaryOperator::InplacePower - | oparg::BinaryOperator::InplaceLshift - | oparg::BinaryOperator::InplaceRshift - | oparg::BinaryOperator::InplaceAnd - | oparg::BinaryOperator::InplaceXor - | oparg::BinaryOperator::InplaceOr - ) { - continue; - } - if matches!( - block.instructions[i + 3].instr.real(), - Some(Instruction::StoreFast { var_num }) - if usize::from(var_num.get(block.instructions[i + 3].arg)) == local - ) { - return Some(i); - } - } - None - } - - fn starts_with_borrowed_local_bool_guard_from( - block: &Block, - locals: &[usize], - start: usize, - ) -> bool { - let mut reals = block - .instructions - .iter() - .skip(start) - .filter(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .take(3); - let Some(first) = reals.next() else { - return false; - }; - let Some(second) = reals.next() else { - return false; - }; - let Some(third) = reals.next() else { - return false; - }; - let borrows_stored_local = match first.instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => { - locals.contains(&usize::from(var_num.get(first.arg))) - } - _ => false, - }; - borrows_stored_local - && matches!(second.instr.real(), Some(Instruction::ToBool)) - && matches!( - third.instr.real(), - Some(Instruction::PopJumpIfFalse { .. } | Instruction::PopJumpIfTrue { .. 
}) - ) - } - - fn starts_with_borrowed_local_bool_guard(block: &Block, locals: &[usize]) -> bool { - starts_with_borrowed_local_bool_guard_from(block, locals, 0) - } - - fn protected_store_bool_guard_start(block: &Block) -> Option<(usize, Vec)> { - let mut saw_call = false; - for store_idx in 0..block.instructions.len() { - saw_call |= matches!( - block.instructions[store_idx].instr.real(), - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::CallFunctionEx - ) - ); - if !saw_call { - continue; - } - let local = match block.instructions[store_idx].instr.real() { - Some(Instruction::StoreFast { var_num }) => { - usize::from(var_num.get(block.instructions[store_idx].arg)) - } - _ => continue, - }; - let mut reals = block - .instructions - .iter() - .enumerate() - .skip(store_idx + 1) - .filter(|(_, info)| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }); - let Some((load_idx, load)) = reals.next() else { - continue; - }; - let borrows_stored_local = matches!( - load.instr.real(), - Some(Instruction::LoadFastBorrow { var_num }) - if usize::from(var_num.get(load.arg)) == local - ); - if !borrows_stored_local { - continue; - } - let Some((_, second)) = reals.next() else { - continue; - }; - let Some((_, third)) = reals.next() else { - continue; - }; - if matches!(second.instr.real(), Some(Instruction::ToBool)) - && matches!( - third.instr.real(), - Some( - Instruction::PopJumpIfFalse { .. } | Instruction::PopJumpIfTrue { .. 
} - ) - ) - { - return Some((load_idx, vec![local])); - } - } - None - } - - fn conditional_target(block: &Block) -> Option { - block - .instructions - .iter() - .find(|info| is_conditional_jump(&info.instr) && info.target != BlockIdx::NULL) - .map(|info| info.target) - } - - fn block_is_simple_exit_branch(block: &Block) -> bool { - let last_real = block - .instructions - .iter() - .rev() - .find_map(|info| info.instr.real()); - last_real.is_some_and(|instr| { - instr.is_scope_exit() || AnyInstruction::Real(instr).is_unconditional_jump() - }) - } - - fn contains_debug_four_guard(block: &Block, names: &IndexSet) -> bool { - let reals: Vec<_> = block - .instructions - .iter() - .filter(|info| { - info.instr.real().is_some_and(|instr| { - !matches!(instr, Instruction::Nop | Instruction::NotTaken) - }) - }) - .collect(); - if reals.len() < 5 { - return false; - } - reals.windows(5).any(|window| { - let loads_debug_attr = window.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadAttr { namei }) - if names[usize::try_from(namei.get(info.arg).name_idx()).unwrap()].as_str() - == "debug" - ) - }); - let compares_with_four = window.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::LoadSmallInt { i }) if i.get(info.arg) == 4 - ) - }) && window - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::CompareOp { .. 
}))); - let has_conditional = window.iter().any(|info| is_conditional_jump(&info.instr)); - loads_debug_attr && compares_with_four && has_conditional - }) - } - - fn marker_only_block(block: &Block) -> bool { - block.instructions.iter().all(|info| { - info.instr - .real() - .is_none_or(|instr| matches!(instr, Instruction::Nop | Instruction::NotTaken)) - }) - } - - fn predecessor_chain_contains_debug_four_guard( - blocks: &[Block], - predecessors: &[Vec], - block_idx: BlockIdx, - names: &IndexSet, - ) -> bool { - predecessors[block_idx.idx()].iter().any(|pred| { - contains_debug_four_guard(&blocks[pred.idx()], names) - || (marker_only_block(&blocks[pred.idx()]) - && predecessors[pred.idx()].iter().any(|pred_pred| { - contains_debug_four_guard(&blocks[pred_pred.idx()], names) - })) - }) - } - - fn collect_unprotected_tail_segment( - blocks: &[Block], - tail: BlockIdx, - ) -> (Vec<(BlockIdx, usize)>, Vec) { - let mut in_segment = vec![false; blocks.len()]; - let mut segment = Vec::new(); - let mut cursor = tail; - if cursor == BlockIdx::NULL - || block_is_exceptional(&blocks[cursor.idx()]) - || blocks[cursor.idx()].try_else_orelse_entry - || block_has_protected_instructions(&blocks[cursor.idx()]) - { - return (segment, in_segment); - } - while cursor != BlockIdx::NULL { - let segment_block = &blocks[cursor.idx()]; - if block_is_exceptional(segment_block) - || segment_block.try_else_orelse_entry - || block_has_protected_instructions(segment_block) - { - break; - } - segment.push((cursor, 0)); - in_segment[cursor.idx()] = true; - let last_real = segment_block - .instructions - .iter() - .rev() - .find_map(|info| info.instr.real()); - if last_real.is_some_and(|instr| { - instr.is_scope_exit() || AnyInstruction::Real(instr).is_unconditional_jump() - }) { - break; - } - cursor = next_nonempty_block(blocks, segment_block.next); - } - (segment, in_segment) - } - - fn collect_unprotected_tail_region( - blocks: &[Block], - tail: BlockIdx, - ) -> (Vec<(BlockIdx, usize)>, Vec) { 
- let mut in_segment = vec![false; blocks.len()]; - let mut segment = Vec::new(); - let mut stack = vec![tail]; - while let Some(cursor) = stack.pop() { - if cursor == BlockIdx::NULL || in_segment[cursor.idx()] { - continue; - } - let segment_block = &blocks[cursor.idx()]; - if block_is_exceptional(segment_block) || segment_block.try_else_orelse_entry { - continue; - } - segment.push((cursor, 0)); - in_segment[cursor.idx()] = true; - let last_real = segment_block - .instructions - .iter() - .rev() - .find_map(|info| info.instr.real()); - if last_real.is_some_and(|instr| { - instr.is_scope_exit() || AnyInstruction::Real(instr).is_unconditional_jump() - }) { - continue; - } - stack.extend(normal_successors(segment_block)); - } - (segment, in_segment) - } - - fn is_plain_protected_resume_successor( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - ) -> bool { - if target == BlockIdx::NULL { - return false; - } - let mut has_handler_resume_predecessor = false; - let mut has_normal_fallthrough_predecessor = false; - for pred in &predecessors[target.idx()] { - let pred_block = &blocks[pred.idx()]; - if is_handler_resume_predecessor(pred_block, target) { - has_handler_resume_predecessor = true; - continue; - } - if block_is_exceptional(pred_block) || pred_block.cold { - continue; - } - if next_nonempty_block(blocks, pred_block.next) == target { - has_normal_fallthrough_predecessor = true; - continue; - } - return false; - } - has_handler_resume_predecessor && has_normal_fallthrough_predecessor - } - - fn protected_call_store_return_load_start(block: &Block) -> Option { - let mut saw_call = false; - let end = block - .instructions - .iter() - .position(|info| matches!(info.instr.real(), Some(Instruction::PushExcInfo))) - .unwrap_or(block.instructions.len()); - for idx in 0..end.saturating_sub(2) { - match block.instructions[idx].instr.real() { - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. 
} - | Instruction::CallFunctionEx, - ) => { - saw_call = true; - continue; - } - Some(Instruction::StoreFast { var_num }) if saw_call => { - let stored = usize::from(var_num.get(block.instructions[idx].arg)); - let loaded = match block.instructions[idx + 1].instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => { - usize::from(var_num.get(block.instructions[idx + 1].arg)) - } - _ => continue, - }; - if stored == loaded - && matches!( - block.instructions[idx + 2].instr.real(), - Some(Instruction::ReturnValue) - ) - { - return Some(idx + 1); - } - } - _ => {} - } - } - None - } - - fn block_has_exception_match_trailer(block: &Block) -> bool { - let mut saw_push_exc_info = false; - for info in &block.instructions { - match info.instr.real() { - Some(Instruction::PushExcInfo) => { - saw_push_exc_info = true; - } - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - if saw_push_exc_info => - { - return true; - } - _ => {} - } - } - false - } - - fn block_starts_with_borrowed_local_return(block: &Block) -> Option<(usize, usize)> { - let mut reals = block.instructions.iter().enumerate().filter(|(_, info)| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - }); - let (load_idx, load) = reals.next()?; - let local = match load.instr.real() { - Some(Instruction::LoadFastBorrow { var_num }) => usize::from(var_num.get(load.arg)), - _ => return None, - }; - let (_, ret) = reals.next()?; - if matches!(ret.instr.real(), Some(Instruction::ReturnValue)) { - Some((load_idx, local)) - } else { - None - } - } - - fn block_is_exception_match_entry(block: &Block) -> bool { - block.cold - && block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - } - - fn cold_layout_tail_reaches_exception_match_entry( - blocks: &[Block], - start: BlockIdx, - ) -> bool { - let mut cursor = start; - let mut visited = vec![false; 
blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if !block.cold { - return false; - } - if block_is_exception_match_entry(block) { - return true; - } - cursor = block.next; - } - false - } - - fn protected_call_store_local_predecessor(block: &Block, local: usize) -> bool { - if !block.disable_load_fast_borrow { - return false; - } - let mut saw_call = false; - for info in &block.instructions { - match info.instr.real() { - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. } - | Instruction::CallFunctionEx, - ) => { - saw_call = true; - } - Some(Instruction::StoreFast { var_num }) - if saw_call && usize::from(var_num.get(info.arg)) == local => - { - return true; - } - _ => {} - } - } - false - } - - fn predecessor_chain_has_protected_call_store_local( - blocks: &[Block], - predecessors: &[Vec], - target: BlockIdx, - local: usize, - ) -> bool { - let mut visited = vec![false; blocks.len()]; - let mut stack = predecessors[target.idx()].clone(); - while let Some(pred) = stack.pop() { - if pred == BlockIdx::NULL || visited[pred.idx()] { - continue; - } - visited[pred.idx()] = true; - let block = &blocks[pred.idx()]; - if protected_call_store_local_predecessor(block, local) { - return true; - } - if marker_only_block(block) { - stack.extend(predecessors[pred.idx()].iter().copied()); - } - } - false - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let mut to_deopt = Vec::new(); - let has_exception_match_handler = self.blocks.iter().any(|block| { - block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - 
Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - }); - if has_exception_match_handler { - for (block_idx, block) in self.blocks.iter().enumerate() { - if block_has_protected_instructions(block) - && let Some(start) = protected_store_subscr_operand_start(block) - { - to_deopt.push((BlockIdx::new(block_idx as u32), start)); - } - } - } - for (block_idx, block) in self.blocks.iter().enumerate() { - if block_has_exception_match_trailer(block) - && let Some(start) = protected_call_store_return_load_start(block) - { - to_deopt.push((BlockIdx::new(block_idx as u32), start)); - } - } - for (block_idx, block) in self.blocks.iter().enumerate() { - if !cold_layout_tail_reaches_exception_match_entry(&self.blocks, block.next) { - continue; - } - let Some((start, local)) = block_starts_with_borrowed_local_return(block) else { - continue; - }; - if predecessor_chain_has_protected_call_store_local( - &self.blocks, - &predecessors, - BlockIdx::new(block_idx as u32), - local, - ) { - to_deopt.push((BlockIdx::new(block_idx as u32), start)); - } - } - for (block_idx, block) in self.blocks.iter().enumerate() { - if block_is_exceptional(block) - || !block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - || !block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Call { .. } - | Instruction::CallKw { .. 
} - | Instruction::CallFunctionEx - ) - ) - }) - || block_has_generator_delegation(block) - || !block_has_exception_match_handler(&self.blocks, block) - { - continue; - } - if let Some(start) = protected_store_subscr_operand_start(block) { - to_deopt.push((BlockIdx::new(block_idx as u32), start)); - continue; - } - if let Some((start, stored_locals)) = protected_store_bool_guard_start(block) { - if block_suffix_has_for_loop_back(block, &self.blocks, start) - && !block_suffix_starts_with_builtin_any_all_fast_path( - block, - &self.metadata.names, - start, - ) - { - continue; - } - if !handler_chain_exits_loop_after_pop_except(&self.blocks, block) { - continue; - } - let handler_has_explicit_raise = - handler_chain_has_explicit_raise(&self.blocks, block); - let jump_target = conditional_target(block); - let fallthrough = next_nonempty_block(&self.blocks, block.next); - if !handler_has_explicit_raise && let Some(jump_target) = jump_target { - let branches = [(jump_target, fallthrough), (fallthrough, jump_target)]; - for (work, exit) in branches { - if work == BlockIdx::NULL || exit == BlockIdx::NULL { - continue; - } - let work_block = &self.blocks[work.idx()]; - let exit_block = &self.blocks[exit.idx()]; - if !block_is_exceptional(work_block) - && !block_has_protected_instructions(work_block) - && block_has_tail_deopt_trigger_from(work_block, 0) - && block_is_simple_exit_branch(exit_block) - { - if normal_path_reaches_for_loop_back(&self.blocks, work) { - continue; - } - to_deopt.push((BlockIdx::new(block_idx as u32), start)); - if borrows_any_local_from(work_block, &stored_locals, 0) { - to_deopt.push((work, 0)); - } - } - } - } - } - let same_block_tail_start = first_unprotected_suffix(block); - if let Some(start) = same_block_tail_start { - if block.try_else_orelse_entry { - continue; - } - if block_suffix_has_for_loop_back(block, &self.blocks, start) - && !block_suffix_starts_with_builtin_any_all_fast_path( - block, - &self.metadata.names, - start, - ) - { - 
continue; - } - let stored_locals = collect_protected_stored_fast_locals_until(block, start); - let handler_has_explicit_raise = - handler_chain_has_explicit_raise(&self.blocks, block); - if stored_locals.is_empty() - || handler_has_explicit_raise - || !starts_with_borrowed_local_bool_guard_from(block, &stored_locals, start) - { - continue; - } - let jump_target = conditional_target(block); - let fallthrough = next_nonempty_block(&self.blocks, block.next); - if let Some(jump_target) = jump_target { - let branches = [(jump_target, fallthrough), (fallthrough, jump_target)]; - for (work, exit) in branches { - if work == BlockIdx::NULL || exit == BlockIdx::NULL { - continue; - } - let work_block = &self.blocks[work.idx()]; - let exit_block = &self.blocks[exit.idx()]; - if !block_is_exceptional(work_block) - && !block_has_protected_instructions(work_block) - && block_has_tail_deopt_trigger_from(work_block, 0) - && block_is_simple_exit_branch(exit_block) - { - if normal_path_reaches_for_loop_back(&self.blocks, work) { - continue; - } - to_deopt.push((BlockIdx::new(block_idx as u32), start)); - if borrows_any_local_from(work_block, &stored_locals, 0) { - to_deopt.push((work, 0)); - } - } - } - } - continue; - } - let tail = next_nonempty_block(&self.blocks, block.next); - let (segment, in_segment) = collect_unprotected_tail_segment(&self.blocks, tail); - let handler_has_nested_exception_match = - handler_chain_has_nested_exception_match(&self.blocks, block); - let linear_segment_reaches_external_backward_jump = - segment_reaches_external_backward_jump(&self.blocks, &segment, &in_segment); - let linear_segment_has_for_loop_back = - segment_has_for_loop_back(&self.blocks, &segment); - let linear_segment_has_backward_jump = - segment_has_backward_jump(&self.blocks, &segment); - let linear_segment_starts_with_builtin_any_all_fast_path = - segment_starts_with_builtin_any_all_fast_path( - &self.blocks, - &self.metadata.names, - &segment, - ); - if 
handler_has_nested_exception_match - && handler_chain_can_resume_to_segment(&self.blocks, block, &in_segment) - { - for (block_idx, start) in &segment { - if let Some(update_start) = - borrowed_inplace_local_update_start(&self.blocks[block_idx.idx()]) - { - to_deopt.push((*block_idx, (*start).max(update_start))); - } - } - } - let segment_has_yield = segment_has_yield_value(&self.blocks, &segment); - let mut stored_locals = - collect_protected_stored_fast_locals_until(block, block.instructions.len()); - if stored_locals.is_empty() && segment_has_yield { - stored_locals = collect_protected_predecessor_stored_fast_locals( - &self.blocks, - &predecessors, - BlockIdx::new(block_idx as u32), - ); - } - if stored_locals.is_empty() { - continue; - } - let borrowed_stored_locals = - collect_borrowed_stored_locals_in_segment(&self.blocks, &segment, &stored_locals); - if !handler_has_nested_exception_match - && handler_chain_resumes_after_assigning_locals( - &self.blocks, - block, - &in_segment, - &borrowed_stored_locals, - ) - { - continue; - } - let handler_has_explicit_raise = handler_chain_has_explicit_raise(&self.blocks, block); - if !handler_has_explicit_raise - && segment_has_yield - && segment.iter().any(|(block_idx, start)| { - block_has_tail_deopt_trigger_from(&self.blocks[block_idx.idx()], *start) - }) - && segment.iter().any(|(block_idx, start)| { - borrows_any_local_from(&self.blocks[block_idx.idx()], &stored_locals, *start) - }) - { - let (yield_tail_region, _) = collect_unprotected_tail_region(&self.blocks, tail); - for (block_idx, start) in yield_tail_region { - to_deopt.push((block_idx, start)); - } - continue; - } - if tail != BlockIdx::NULL - && !block_is_exceptional(&self.blocks[tail.idx()]) - && !self.blocks[tail.idx()].try_else_orelse_entry - && !block_has_protected_instructions(&self.blocks[tail.idx()]) - && !is_plain_protected_resume_successor(&self.blocks, &predecessors, tail) - && starts_with_borrowed_local_bool_guard(&self.blocks[tail.idx()], 
&stored_locals) - && !handler_has_explicit_raise - { - let jump_target = conditional_target(&self.blocks[tail.idx()]); - let fallthrough = next_nonempty_block(&self.blocks, self.blocks[tail.idx()].next); - if let Some(jump_target) = jump_target { - let branches = [(jump_target, fallthrough), (fallthrough, jump_target)]; - for (work, exit) in branches { - if work == BlockIdx::NULL || exit == BlockIdx::NULL { - continue; - } - let work_block = &self.blocks[work.idx()]; - let exit_block = &self.blocks[exit.idx()]; - if !block_is_exceptional(work_block) - && !block_has_protected_instructions(work_block) - && block_has_tail_deopt_trigger_from(work_block, 0) - && block_is_simple_exit_branch(exit_block) - { - if normal_path_reaches_for_loop_back(&self.blocks, work) { - continue; - } - to_deopt.push((tail, 0)); - if borrows_any_local_from(work_block, &stored_locals, 0) { - to_deopt.push((work, 0)); - } - } - } - } - } - if segment.is_empty() - || segment.iter().any(|(block_idx, _)| { - contains_debug_four_guard(&self.blocks[block_idx.idx()], &self.metadata.names) - }) - || predecessor_chain_contains_debug_four_guard( - &self.blocks, - &predecessors, - segment[0].0, - &self.metadata.names, - ) - || !segment.iter().any(|(block_idx, start)| { - block_has_tail_deopt_trigger_from(&self.blocks[block_idx.idx()], *start) - }) - || segment_has_named_except_cleanup_predecessor( - &self.blocks, - &predecessors, - &segment, - ) - || (linear_segment_has_for_loop_back - && !linear_segment_starts_with_builtin_any_all_fast_path) - || (handler_chain_can_resume_to_segment(&self.blocks, block, &in_segment) - && (linear_segment_reaches_external_backward_jump - || (!handler_has_nested_exception_match - && linear_segment_has_backward_jump))) - || segment.iter().any(|(block_idx, _)| { - predecessors[block_idx.idx()].iter().any(|pred| { - let pred_block = &self.blocks[pred.idx()]; - !in_segment[pred.idx()] - && !block_is_exceptional(pred_block) - && !pred_block.cold - && 
!block_has_protected_instructions(pred_block) - && block_has_non_nop_real_instructions(pred_block) - }) - }) - || !segment.iter().any(|(block_idx, start)| { - borrows_any_local_from(&self.blocks[block_idx.idx()], &stored_locals, *start) - }) - { - continue; - } - let (deopt_segment, deopt_in_segment) = if handler_has_nested_exception_match - && !linear_segment_reaches_external_backward_jump - { - collect_unprotected_tail_region(&self.blocks, tail) - } else { - (segment, in_segment) - }; - if segment_has_for_loop_back(&self.blocks, &deopt_segment) - && !segment_starts_with_builtin_any_all_fast_path( - &self.blocks, - &self.metadata.names, - &deopt_segment, - ) - { - continue; - } - let deopt_segment_reaches_external_backward_jump = - segment_reaches_external_backward_jump( - &self.blocks, - &deopt_segment, - &deopt_in_segment, - ); - for (block_idx, start) in deopt_segment { - if deopt_segment_reaches_external_backward_jump - && predecessors[block_idx.idx()].iter().any(|pred| { - is_handler_resume_predecessor(&self.blocks[pred.idx()], block_idx) - }) - { - continue; - } - to_deopt.push((block_idx, start)); - } - } - - let mut continue_targets = Vec::new(); - for (handler_idx, block) in self.blocks.iter().enumerate() { - if !block.cold - || !block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) - { - continue; - } - let mut visited = vec![false; self.blocks.len()]; - let mut cursor = BlockIdx::new(handler_idx as u32); - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let handler = &self.blocks[cursor.idx()]; - let has_pop_except = handler - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - if has_pop_except { - for info in &handler.instructions { - if info.target != BlockIdx::NULL - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. 
} - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - { - continue_targets.push(info.target); - } - } - } - cursor = handler.next; - } - } - - continue_targets.sort_by_key(|idx| idx.idx()); - continue_targets.dedup(); - for target in continue_targets { - let block = &self.blocks[target.idx()]; - if block.cold - || block_is_exceptional(block) - || !block_has_tail_deopt_trigger_from(block, 0) - { - continue; - } - let stored_locals = collect_stored_fast_locals_until(block, block.instructions.len()); - if stored_locals.is_empty() { - continue; - } - let tail = next_nonempty_block(&self.blocks, block.next); - if tail == BlockIdx::NULL - || block_is_exceptional(&self.blocks[tail.idx()]) - || !block_has_tail_deopt_trigger_from(&self.blocks[tail.idx()], 0) - || !borrows_any_local_from(&self.blocks[tail.idx()], &stored_locals, 0) - { - continue; - } - let tail_jumps_back_to_target = - self.blocks[tail.idx()].instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. 
} - ) - ) - }); - if tail_jumps_back_to_target { - if normal_path_reaches_for_loop_back(&self.blocks, tail) - && !block_suffix_starts_with_builtin_any_all_fast_path( - &self.blocks[tail.idx()], - &self.metadata.names, - 0, - ) - { - continue; - } - to_deopt.push((tail, 0)); - } - } - - for block in &self.blocks { - if block.cold - || block_is_exceptional(block) - || !block_is_normal_cleanup_call(block, &self.metadata) - { - continue; - } - let tail = next_nonempty_block(&self.blocks, block.next); - let (region, _) = collect_unprotected_tail_region(&self.blocks, tail); - if region.is_empty() - || !segment_has_yield_value(&self.blocks, ®ion) - || !region.iter().any(|(block_idx, start)| { - block_has_tail_deopt_trigger_from(&self.blocks[block_idx.idx()], *start) - }) - { - continue; - } - for (block_idx, start) in region { - to_deopt.push((block_idx, start)); - } - } - - to_deopt.sort_by_key(|(idx, start)| (idx.idx(), *start)); - let mut merged: Vec<(BlockIdx, usize)> = Vec::new(); - for (idx, start) in to_deopt { - match merged.last_mut() { - Some((last_idx, last_start)) if *last_idx == idx => { - *last_start = (*last_start).min(start); - } - _ => merged.push((idx, start)), - } - } - for (block_idx, start) in merged { - if block_has_attr_named(&self.blocks[block_idx.idx()], &self.metadata.names, "_mesg") { - continue; - } - if contains_debug_four_guard(&self.blocks[block_idx.idx()], &self.metadata.names) - || predecessor_chain_contains_debug_four_guard( - &self.blocks, - &predecessors, - block_idx, - &self.metadata.names, - ) - { - continue; - } - deoptimize_block_borrows_from(&mut self.blocks[block_idx.idx()], start); - } - } - - fn reborrow_after_suppressing_handler_resume_cleanup(&mut self) { - fn is_suppressing_handler_resume_to(block: &Block, target: BlockIdx) -> bool { - let has_pop_except = block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))); - let clears_exception_name = block - .instructions - .iter() - 
.any(|info| matches!(info.instr.real(), Some(Instruction::StoreFast { .. }))) - && block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::DeleteFast { .. }))); - let jumps_to_target = block.instructions.iter().any(|info| { - info.target == target - && matches!( - info.instr.real(), - Some( - Instruction::JumpForward { .. } - | Instruction::JumpBackward { .. } - | Instruction::JumpBackwardNoInterrupt { .. } - ) - ) - }); - let reraises = block.instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::RaiseVarargs { .. } | Instruction::Reraise { .. }) - ) - }); - has_pop_except && clears_exception_name && jumps_to_target && !reraises - } - - fn block_enters_with_context(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::LoadSpecial { .. }))) - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - for (block_idx, predecessors) in predecessors.iter().enumerate() { - let target = BlockIdx::new(block_idx as u32); - if self.blocks[block_idx].cold || block_is_exceptional(&self.blocks[block_idx]) { - continue; - } - if !predecessors - .iter() - .any(|pred| is_suppressing_handler_resume_to(&self.blocks[pred.idx()], target)) - { - continue; - } - if block_enters_with_context(&self.blocks[block_idx]) { - continue; - } - let starts_with_cleanup_call = self.blocks[block_idx] - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .take(5) - .any(|instr| matches!(instr, Instruction::Call { .. 
})); - if !starts_with_cleanup_call { - continue; - } - let starts_with_named_except_value_load = self.blocks[block_idx] - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .take(5) - .any(|instr| matches!(instr, Instruction::LoadFastCheck { .. })); - if starts_with_named_except_value_load { - continue; - } - let first_real = self.blocks[block_idx].instructions.iter().position(|info| { - info.instr - .real() - .is_some_and(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - }); - if let Some(first_real) = first_real { - let starts_with_receiver_load = matches!( - ( - self.blocks[block_idx].instructions[first_real].instr.real(), - self.blocks[block_idx] - .instructions - .get(first_real + 1) - .and_then(|info| info.instr.real()), - ), - ( - Some(Instruction::LoadFast { .. }), - Some(Instruction::LoadAttr { .. } | Instruction::LoadSuperAttr { .. }) - ) - ); - if starts_with_receiver_load { - self.blocks[block_idx].instructions[first_real].instr = - Instruction::LoadFastBorrow { - var_num: Arg::marker(), - } - .into(); - } - } - } - } - - fn deoptimize_borrow_before_import_after_join_store(&mut self) { - let mut predecessor_count = vec![0usize; self.blocks.len()]; - for block in &self.blocks { - if block.next != BlockIdx::NULL { - predecessor_count[block.next.idx()] += 1; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessor_count[info.target.idx()] += 1; - } - } - } - - for (block_idx, block) in self.blocks.iter_mut().enumerate() { - if predecessor_count[block_idx] < 2 { - continue; - } - - let len = block.instructions.len(); - let first_import_from = block - .instructions - .iter() - .position(|info| matches!(info.instr.real(), Some(Instruction::ImportFrom { .. }))); - if let Some(first_import_from) = first_import_from { - for idx in 0..first_import_from { - if matches!( - block.instructions[idx].instr.real(), - Some(Instruction::LoadFastBorrow { .. 
}) - ) { - block.instructions[idx].instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - } - } - - if !block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ImportName { .. }))) - { - continue; - } - - for idx in 0..len.saturating_sub(1) { - if !matches!( - block.instructions[idx].instr.real(), - Some(Instruction::LoadFastBorrow { .. }) - ) { - continue; - } - if !matches!( - block.instructions[idx + 1].instr.real(), - Some(Instruction::StoreGlobal { .. }) - ) { - continue; - } - block.instructions[idx].instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - } - } - - fn deoptimize_borrow_for_match_keys_attr(&mut self) { - let Some(key_name_idx) = self.metadata.names.get_index_of("KEY") else { - return; - }; - - let mut to_deopt = Vec::new(); - for block_idx in 0..self.blocks.len() { - let block = &self.blocks[block_idx]; - let len = block.instructions.len(); - for i in 0..len { - let Some(Instruction::LoadFastBorrow { .. }) = block.instructions[i].instr.real() - else { - continue; - }; - let Some(Instruction::LoadAttr { namei }) = block - .instructions - .get(i + 1) - .and_then(|info| info.instr.real()) - else { - continue; - }; - let load_attr = namei.get(block.instructions[i + 1].arg); - if load_attr.is_method() || load_attr.name_idx() as usize != key_name_idx { - continue; - } - - let mut saw_build_tuple = false; - let mut saw_match_keys = false; - let mut scan_block_idx = block_idx; - let mut scan_start = i + 2; - loop { - let scan_block = &self.blocks[scan_block_idx]; - for info in scan_block.instructions.iter().skip(scan_start) { - match info.instr.real() { - Some( - Instruction::LoadConst { .. } - | Instruction::LoadSmallInt { .. } - | Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadAttr { .. } - | Instruction::Nop, - ) => {} - Some(Instruction::BuildTuple { .. 
}) => saw_build_tuple = true, - Some(Instruction::MatchKeys) => { - saw_match_keys = true; - break; - } - _ => { - saw_build_tuple = false; - break; - } - } - } - if saw_match_keys { - break; - } - let Some(last) = scan_block.instructions.last() else { - break; - }; - if scan_block.next == BlockIdx::NULL - || last.instr.is_scope_exit() - || last.instr.is_unconditional_jump() - || last.target != BlockIdx::NULL - { - break; - } - scan_block_idx = scan_block.next.idx(); - scan_start = 0; - } - - if saw_build_tuple && saw_match_keys { - to_deopt.push((block_idx, i)); - } - } - } - - for (block_idx, instr_idx) in to_deopt { - self.blocks[block_idx].instructions[instr_idx].instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - } - - fn deoptimize_borrow_in_protected_attr_chain_tail(&mut self) { - fn second_last_real_instr(block: &Block) -> Option { - let mut reals = block - .instructions - .iter() - .rev() - .filter_map(|info| info.instr.real()); - let _last = reals.next()?; - reals.next() - } - - fn block_ends_with_suppressing_with_resume_jump(block: &Block) -> bool { - let mut reals = block - .instructions - .iter() - .rev() - .filter_map(|info| info.instr.real()); - let Some(last) = reals.next() else { - return false; - }; - if !last.is_unconditional_jump() { - return false; - } - matches!( - (reals.next(), reals.next(), reals.next(), reals.next()), - ( - Some(Instruction::PopTop), - Some(Instruction::PopTop), - Some(Instruction::PopTop), - Some(Instruction::PopExcept) - ) - ) - } - - fn block_ends_with_handler_resume_jump(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - && block.instructions.last().is_some_and(|info| { - info.target != BlockIdx::NULL && info.instr.is_unconditional_jump() - }) - } - - fn handler_chain_returns_before_resume(blocks: &[Block], handler_block: BlockIdx) -> bool { - let mut cursor = handler_block; - let mut visited = vec![false; 
blocks.len()]; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - if block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ReturnValue))) - { - return true; - } - if block_ends_with_handler_resume_jump(block) { - return false; - } - cursor = block.next; - } - false - } - - fn block_has_returning_exception_match_handler(blocks: &[Block], block: &Block) -> bool { - let mut visited = vec![false; blocks.len()]; - let handler_blocks: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - for handler_block in handler_blocks { - let mut cursor = handler_block; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - if blocks[cursor.idx()].instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) { - return handler_chain_returns_before_resume(blocks, handler_block); - } - cursor = blocks[cursor.idx()].next; - } - } - false - } - - fn deoptimize_borrow(info: &mut InstructionInfo) { - match info.instr.real() { - Some(Instruction::LoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. }) => { - info.instr = Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - - fn is_attr_load(instr: Instruction) -> bool { - matches!( - instr, - Instruction::LoadAttr { .. } | Instruction::LoadSuperAttr { .. 
} - ) - } - - fn attr_load_is_method(info: InstructionInfo) -> bool { - match info.instr.real() { - Some(Instruction::LoadAttr { namei }) => namei.get(info.arg).is_method(), - Some(Instruction::LoadSuperAttr { namei }) => namei.get(info.arg).is_load_method(), - _ => false, - } - } - - fn is_subscript_index_setup(instr: Instruction) -> bool { - matches!( - instr, - Instruction::LoadConst { .. } - | Instruction::LoadSmallInt { .. } - | Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastCheck { .. } - | Instruction::Nop - ) - } - - enum DeoptKind { - ReturnIter { - tail_start_idx: usize, - }, - Subscript { - binary_op_idx: usize, - direct_root: bool, - }, - } - - fn should_deopt_borrowed_attr_chain( - real_instrs: &[(usize, InstructionInfo)], - load_idx: usize, - ) -> Option { - let mut cursor = load_idx + 1; - let mut last_attr_is_method = false; - while let Some((_, info)) = real_instrs.get(cursor) { - if !info.instr.real().is_some_and(is_attr_load) { - break; - } - last_attr_is_method = attr_load_is_method(*info); - cursor += 1; - } - let direct_root = cursor == load_idx + 1; - if direct_root - && !real_instrs.get(cursor).is_some_and(|(_, info)| { - info.instr.real().is_some_and(is_subscript_index_setup) - }) - { - return None; - } - - let (_, next_info) = real_instrs.get(cursor)?; - - match next_info.instr.real() { - Some(Instruction::GetIter) => Some(DeoptKind::ReturnIter { - tail_start_idx: cursor + 1, - }), - Some(Instruction::Call { .. } | Instruction::CallKw { .. 
}) => real_instrs - .get(cursor + 1) - .and_then(|(_, info)| info.instr.real()) - .and_then(|instr| { - matches!(instr, Instruction::GetIter).then_some(DeoptKind::ReturnIter { - tail_start_idx: cursor + 2, - }) - }), - _ => { - if last_attr_is_method { - return None; - } - while real_instrs.get(cursor).is_some_and(|(_, info)| { - info.instr.real().is_some_and(is_subscript_index_setup) - }) { - cursor += 1; - } - real_instrs.get(cursor).and_then(|(_, info)| { - matches!( - info.instr.real(), - Some(Instruction::BinaryOp { op }) - if op.get(info.arg) == oparg::BinaryOperator::Subscr - ) - .then_some(DeoptKind::Subscript { - binary_op_idx: cursor, - direct_root, - }) - }) - } - } - } - - fn tail_returns_without_store( - blocks: &[Block], - is_pre_handler: &[bool], - start_block_idx: BlockIdx, - start_instr_idx: usize, - ) -> bool { - let mut block_idx = start_block_idx; - let mut current_start = start_instr_idx; - for _ in 0..blocks.len() { - if block_idx == BlockIdx::NULL || !is_pre_handler[block_idx.idx()] { - break; - } - let block = &blocks[block_idx.idx()]; - for info in block.instructions.iter().skip(current_start) { - match info.instr.real() { - Some(Instruction::ReturnValue) => return true, - Some( - Instruction::StoreFast { .. } - | Instruction::StoreFastLoadFast { .. } - | Instruction::StoreFastStoreFast { .. } - | Instruction::DeleteFast { .. } - | Instruction::LoadFastAndClear { .. 
}, - ) => return false, - _ => {} - } - } - block_idx = block.next; - current_start = 0; - } - false - } - - let mut order = Vec::new(); - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - order.push(current); - current = self.blocks[current.idx()].next; - } - - let mut has_handler_resume_predecessor = vec![false; self.blocks.len()]; - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - let Some(last_info) = block.instructions.last() else { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - continue; - }; - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx::new(pred_idx as u32)); - } - if last_info.target == BlockIdx::NULL || !last_info.instr.is_unconditional_jump() { - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - continue; - } - let is_handler_resume_jump = - matches!(second_last_real_instr(block), Some(Instruction::PopExcept)) - || block_ends_with_suppressing_with_resume_jump(block); - if !is_handler_resume_jump { - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - continue; - } - has_handler_resume_predecessor[last_info.target.idx()] = true; - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx::new(pred_idx as u32)); - } - } - } - - let Some(first_handler_pos) = order.iter().position(|block_idx| { - self.blocks[block_idx.idx()] - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PushExcInfo))) - }) else { - return; - }; - let mut is_pre_handler = vec![false; self.blocks.len()]; - for &block_idx in &order[..first_handler_pos] { - is_pre_handler[block_idx.idx()] = true; - } - 
let mut is_protected_source = vec![false; self.blocks.len()]; - let mut reachable_from_protected_predecessor = vec![false; self.blocks.len()]; - let mut reachable_from_protected = vec![false; self.blocks.len()]; - let mut direct_subscript_from_returning_protected = vec![false; self.blocks.len()]; - for &block_idx in &order[..first_handler_pos] { - let idx = block_idx.idx(); - let block = &self.blocks[idx]; - is_protected_source[idx] = block_has_exception_match_handler(&self.blocks, block); - let has_direct_normal_protected_predecessor = predecessors[idx].iter().any(|pred| { - !block_is_exceptional(&self.blocks[pred.idx()]) - && self.blocks[pred.idx()] - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - }); - let has_direct_returning_protected_predecessor = predecessors[idx].iter().any(|pred| { - !block_is_exceptional(&self.blocks[pred.idx()]) - && block_has_returning_exception_match_handler( - &self.blocks, - &self.blocks[pred.idx()], - ) - }); - let has_unprotected_normal_predecessor = predecessors[idx].iter().any(|pred| { - !block_is_exceptional(&self.blocks[pred.idx()]) - && !self.blocks[pred.idx()] - .instructions - .iter() - .any(|info| info.except_handler.is_some()) - }); - reachable_from_protected_predecessor[idx] = - has_direct_normal_protected_predecessor && !has_unprotected_normal_predecessor; - reachable_from_protected[idx] = (is_protected_source[idx] - || has_direct_normal_protected_predecessor) - && !has_unprotected_normal_predecessor; - direct_subscript_from_returning_protected[idx] = - has_direct_returning_protected_predecessor && !has_unprotected_normal_predecessor; - } - - let mut cross_block_deopts = Vec::new(); - for &block_idx in &order[..first_handler_pos] { - if has_handler_resume_predecessor[block_idx.idx()] { - continue; - } - let block_instr_len = self.blocks[block_idx.idx()].instructions.len(); - let real_instrs: Vec<_> = self.blocks[block_idx.idx()] - .instructions - .iter() - .copied() - .enumerate() - .filter(|(_, 
info)| info.instr.real().is_some()) - .collect(); - let mut to_deopt = Vec::new(); - for (real_idx, (instr_idx, info)) in real_instrs.iter().enumerate() { - let has_prior_protected_instr = real_instrs[..real_idx] - .iter() - .any(|(_, info)| info.except_handler.is_some()); - let is_attr_chain_root = matches!( - info.instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. }) - ); - if info.except_handler.is_some() || !is_attr_chain_root { - continue; - } - let Some(deopt_kind) = should_deopt_borrowed_attr_chain(&real_instrs, real_idx) - else { - continue; - }; - if let DeoptKind::ReturnIter { tail_start_idx } = deopt_kind { - if !(reachable_from_protected_predecessor[block_idx.idx()] - || is_protected_source[block_idx.idx()]) - { - continue; - } - let tail_instr_idx = real_instrs - .get(tail_start_idx) - .map_or(block_instr_len, |(instr_idx, _)| *instr_idx); - if !tail_returns_without_store( - &self.blocks, - &is_pre_handler, - block_idx, - tail_instr_idx, - ) { - continue; - } - } - if let DeoptKind::Subscript { direct_root, .. } = deopt_kind { - let should_deopt_subscript = if direct_root { - direct_subscript_from_returning_protected[block_idx.idx()] - } else { - reachable_from_protected_predecessor[block_idx.idx()] - || (is_protected_source[block_idx.idx()] && has_prior_protected_instr) - }; - if !should_deopt_subscript { - continue; - } - } - if matches!(info.instr.real(), Some(Instruction::LoadFastBorrow { .. })) { - to_deopt.push(*instr_idx); - } - if let DeoptKind::Subscript { binary_op_idx, .. } = deopt_kind { - for (extra_instr_idx, extra_info) in real_instrs - .iter() - .skip(real_idx + 1) - .take(binary_op_idx.saturating_sub(real_idx + 1)) - .map(|(idx, info)| (*idx, *info)) - { - if matches!( - extra_info.instr.real(), - Some(Instruction::LoadFastBorrow { .. 
}) - ) { - to_deopt.push(extra_instr_idx); - } - } - if matches!( - real_instrs - .get(binary_op_idx + 1) - .and_then(|(_, info)| info.instr.real()), - Some(Instruction::StoreFast { .. }) - ) { - for (extra_instr_idx, extra_info) in - real_instrs.iter().skip(binary_op_idx + 2) - { - if matches!( - extra_info.instr.real(), - Some( - Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) { - to_deopt.push(*extra_instr_idx); - } - } - let mut linear_tail = vec![block_idx]; - let mut cursor = self.blocks[block_idx.idx()].next; - while cursor != BlockIdx::NULL - && is_pre_handler[cursor.idx()] - && !block_is_exceptional(&self.blocks[cursor.idx()]) - { - if predecessors[cursor.idx()].iter().any(|pred| { - !linear_tail.contains(pred) - && !has_handler_resume_predecessor[pred.idx()] - }) { - break; - } - linear_tail.push(cursor); - cursor = self.blocks[cursor.idx()].next; - } - for tail_block_idx in linear_tail.into_iter().skip(1) { - for (tail_instr_idx, tail_info) in self.blocks[tail_block_idx.idx()] - .instructions - .iter() - .enumerate() - { - if matches!( - tail_info.instr.real(), - Some( - Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - ) - ) { - cross_block_deopts.push((tail_block_idx, tail_instr_idx)); - } - } - } - } - } - } - let block = &mut self.blocks[block_idx.idx()]; - for instr_idx in to_deopt { - deoptimize_borrow(&mut block.instructions[instr_idx]); - } - } - for (block_idx, instr_idx) in cross_block_deopts { - match self.blocks[block_idx.idx()].instructions[instr_idx] - .instr - .real() - { - Some(Instruction::LoadFastBorrow { .. }) => { - self.blocks[block_idx.idx()].instructions[instr_idx].instr = - Instruction::LoadFast { - var_num: Arg::marker(), - } - .into(); - } - Some(Instruction::LoadFastBorrowLoadFastBorrow { .. 
}) => { - self.blocks[block_idx.idx()].instructions[instr_idx].instr = - Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - } - .into(); - } - _ => {} - } - } - } - - fn deoptimize_store_fast_store_fast_after_cleanup(&mut self) { - fn last_real_instr(block: &Block) -> Option { - block - .instructions - .iter() - .rev() - .find_map(|info| info.instr.real()) - } - - fn is_cleanup_restore_prefix(instructions: &[InstructionInfo]) -> bool { - let mut saw_pop_iter = false; - for info in instructions { - match info.instr.real() { - Some(Instruction::EndFor) if !saw_pop_iter => {} - Some(Instruction::PopIter) if !saw_pop_iter => saw_pop_iter = true, - Some(Instruction::Swap { .. } | Instruction::PopTop) if saw_pop_iter => {} - _ => return false, - } - } - saw_pop_iter - } - - let mut predecessors = vec![Vec::new(); self.blocks.len()]; - for (pred_idx, block) in self.blocks.iter().enumerate() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()].push(BlockIdx(pred_idx as u32)); - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()].push(BlockIdx(pred_idx as u32)); - } - } - } - - let starts_after_cleanup: Vec = predecessors - .iter() - .map(|predecessor_blocks| { - !predecessor_blocks.is_empty() - && predecessor_blocks.iter().copied().all(|pred_idx| { - matches!( - last_real_instr(&self.blocks[pred_idx]), - Some(Instruction::PopIter | Instruction::Swap { .. }) - ) - }) - }) - .collect(); - - for (block_idx, block) in self.blocks.iter_mut().enumerate() { - let mut new_instructions = Vec::with_capacity(block.instructions.len()); - let mut in_restore_prefix = starts_after_cleanup[block_idx]; - for (i, info) in block.instructions.iter().copied().enumerate() { - if !in_restore_prefix - && matches!( - info.instr.real(), - Some( - Instruction::StoreFast { .. } | Instruction::StoreFastStoreFast { .. 
} - ) - ) - && !new_instructions.is_empty() - && (new_instructions.iter().all(|prev: &InstructionInfo| { - matches!( - prev.instr.real(), - Some(Instruction::Swap { .. } | Instruction::PopTop) - ) - }) || is_cleanup_restore_prefix(&new_instructions)) - { - in_restore_prefix = true; - } - let expand = matches!( - info.instr.real(), - Some(Instruction::StoreFastStoreFast { .. }) - ) && (is_cleanup_restore_prefix(&new_instructions) - || (i == 0 && starts_after_cleanup[block_idx]) - || in_restore_prefix); - - if expand { - let Some(Instruction::StoreFastStoreFast { var_nums }) = info.instr.real() - else { - unreachable!(); - }; - let packed = var_nums.get(info.arg); - let (idx1, idx2) = packed.indexes(); - - let mut first = info; - first.instr = Instruction::StoreFast { - var_num: Arg::marker(), - } - .into(); - first.arg = OpArg::new(u32::from(idx1)); - new_instructions.push(first); - - let mut second = info; - second.instr = Instruction::StoreFast { - var_num: Arg::marker(), - } - .into(); - second.arg = OpArg::new(u32::from(idx2)); - new_instructions.push(second); - continue; - } - - in_restore_prefix &= - matches!(info.instr.real(), Some(Instruction::StoreFast { .. 
})); - new_instructions.push(info); - } - block.instructions = new_instructions; - } - } - - fn fast_scan_many_locals( - &mut self, - nlocals: usize, - nparams: usize, - merged_cell_local: &impl Fn(usize) -> Option, - ) { - const PARAM_INITIALIZED: usize = usize::MAX; - - debug_assert!(nlocals > 64); - let mut states = vec![0usize; nlocals - 64]; - let high_params = nparams.saturating_sub(64).min(states.len()); - for state in states.iter_mut().take(high_params) { - *state = PARAM_INITIALIZED; - } - - let is_known = |idx: usize, state: usize, blocknum: usize| { - state == blocknum || (idx < nparams && state == PARAM_INITIALIZED) - }; - - let mut blocknum = 0usize; - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - blocknum += 1; - let old_instructions = self.blocks[current.idx()].instructions.clone(); - let mut new_instructions = Vec::with_capacity(old_instructions.len()); - let mut changed = false; - - for mut info in old_instructions { - match info.instr.real() { - Some( - Instruction::DeleteFast { var_num } - | Instruction::LoadFastAndClear { var_num }, - ) => { - let idx = usize::from(var_num.get(info.arg)); - if idx >= 64 && idx < nlocals { - states[idx - 64] = blocknum - 1; - } - new_instructions.push(info); - } - None if matches!( - info.instr.pseudo(), - Some(PseudoInstruction::StoreFastMaybeNull { .. 
}) - ) => - { - let Some(PseudoInstruction::StoreFastMaybeNull { var_num }) = - info.instr.pseudo() - else { - unreachable!(); - }; - let idx = var_num.get(info.arg) as usize; - if idx >= 64 && idx < nlocals { - states[idx - 64] = blocknum - 1; - } - new_instructions.push(info); - } - Some(Instruction::DeleteDeref { i }) => { - let cell_relative = usize::from(i.get(info.arg)); - if let Some(idx) = merged_cell_local(cell_relative) - && idx >= 64 - && idx < nlocals - { - states[idx - 64] = blocknum - 1; - } - new_instructions.push(info); - } - Some(Instruction::StoreFast { var_num }) => { - let idx = usize::from(var_num.get(info.arg)); - if idx >= 64 && idx < nlocals { - states[idx - 64] = blocknum; - } - new_instructions.push(info); - } - Some(Instruction::StoreDeref { i }) => { - let cell_relative = usize::from(i.get(info.arg)); - if let Some(idx) = merged_cell_local(cell_relative) - && idx >= 64 - && idx < nlocals - { - states[idx - 64] = blocknum; - } - new_instructions.push(info); - } - Some(Instruction::StoreFastStoreFast { var_nums }) => { - let packed = var_nums.get(info.arg); - let (idx1, idx2) = packed.indexes(); - let idx1 = usize::from(idx1); - let idx2 = usize::from(idx2); - if idx1 >= 64 && idx1 < nlocals { - states[idx1 - 64] = blocknum; - } - if idx2 >= 64 && idx2 < nlocals { - states[idx2 - 64] = blocknum; - } - new_instructions.push(info); - } - Some(Instruction::StoreFastLoadFast { var_nums }) => { - let packed = var_nums.get(info.arg); - let (store_idx, load_idx) = packed.indexes(); - let store_idx = usize::from(store_idx); - let load_idx = usize::from(load_idx); - if store_idx >= 64 && store_idx < nlocals { - states[store_idx - 64] = blocknum; - } - if load_idx >= 64 && load_idx < nlocals { - if !is_known(load_idx, states[load_idx - 64], blocknum) { - let mut first = info; - first.instr = Instruction::StoreFast { - var_num: Arg::marker(), - } - .into(); - first.arg = OpArg::new(store_idx as u32); - - let mut second = info; - second.instr = 
Opcode::LoadFastCheck.into(); - second.arg = OpArg::new(load_idx as u32); - - new_instructions.push(first); - new_instructions.push(second); - changed = true; - } else { - new_instructions.push(info); - } - } else { - new_instructions.push(info); - } - } - Some(Instruction::LoadFast { var_num }) => { - let idx = usize::from(var_num.get(info.arg)); - if idx >= 64 && idx < nlocals && !is_known(idx, states[idx - 64], blocknum) - { - info.instr = Opcode::LoadFastCheck.into(); - states[idx - 64] = blocknum; - changed = true; - } - new_instructions.push(info); - } - Some(Instruction::LoadFastLoadFast { var_nums }) => { - let packed = var_nums.get(info.arg); - let (idx1, idx2) = packed.indexes(); - let idx1 = usize::from(idx1); - let idx2 = usize::from(idx2); - let needs_check_1 = idx1 >= 64 - && idx1 < nlocals - && !is_known(idx1, states[idx1 - 64], blocknum); - if needs_check_1 { - states[idx1 - 64] = blocknum; - } - let needs_check_2 = idx2 >= 64 - && idx2 < nlocals - && !is_known(idx2, states[idx2 - 64], blocknum); - - if needs_check_1 || needs_check_2 { - let mut first = info; - first.instr = if needs_check_1 { - Opcode::LoadFastCheck - } else { - Opcode::LoadFast - } - .into(); - first.arg = OpArg::new(idx1 as u32); - - let mut second = info; - second.instr = if needs_check_2 { - Opcode::LoadFastCheck.into() - } else { - Opcode::LoadFast.into() - }; - second.arg = OpArg::new(idx2 as u32); - - new_instructions.push(first); - new_instructions.push(second); - changed = true; - if needs_check_2 { - states[idx2 - 64] = blocknum; - } - } else { - new_instructions.push(info); - } - } - Some(Instruction::LoadFastCheck { var_num }) => { - let idx = usize::from(var_num.get(info.arg)); - if idx >= 64 && idx < nlocals { - states[idx - 64] = blocknum; - } - new_instructions.push(info); - } - _ => new_instructions.push(info), - } - } - - if changed { - self.blocks[current.idx()].instructions = new_instructions; - } - current = self.blocks[current.idx()].next; - } - } - - fn 
add_checks_for_loads_of_uninitialized_variables(&mut self) { - let mut nlocals = self.metadata.varnames.len(); - if nlocals == 0 { - return; - } - - let cell_to_local: Vec<_> = self - .metadata - .cellvars - .iter() - .map(|name| self.metadata.varnames.get_index_of(name.as_str())) - .collect(); - let merged_cell_local = - |cell_relative: usize| cell_to_local.get(cell_relative).copied().flatten(); - - let mut nparams = self.metadata.argcount as usize + self.metadata.kwonlyargcount as usize; - if self.flags.contains(CodeFlags::VARARGS) { - nparams += 1; - } - if self.flags.contains(CodeFlags::VARKEYWORDS) { - nparams += 1; - } - nparams = nparams.min(nlocals); - - if nlocals > 64 { - self.fast_scan_many_locals(nlocals, nparams, &merged_cell_local); - nlocals = 64; - } - - let mut in_masks: Vec>> = vec![None; self.blocks.len()]; - let mut start_mask = vec![false; nlocals]; - for slot in start_mask.iter_mut().skip(nparams) { - *slot = true; - } - in_masks[0] = Some(start_mask); - - let mut worklist = vec![BlockIdx(0)]; - while let Some(block_idx) = worklist.pop() { - let idx = block_idx.idx(); - let Some(mut unsafe_mask) = in_masks[idx].clone() else { - continue; - }; - - let old_instructions = self.blocks[idx].instructions.clone(); - let mut new_instructions = Vec::with_capacity(old_instructions.len()); - let mut changed = false; - - for info in old_instructions { - let mut info = info; - if let Some(eh) = info.except_handler { - let target = next_nonempty_block(&self.blocks, eh.handler_block); - if target != BlockIdx::NULL - && merge_unsafe_mask(&mut in_masks[target.idx()], &unsafe_mask) - { - worklist.push(target); - } - } - if matches!(info.instr.real(), Some(Instruction::ForIter { .. })) - && info.target != BlockIdx::NULL - && merge_unsafe_mask(&mut in_masks[info.target.idx()], &unsafe_mask) - { - worklist.push(info.target); - } - match info.instr.real() { - None if matches!( - info.instr.pseudo(), - Some(PseudoInstruction::StoreFastMaybeNull { .. 
}) - ) => - { - let Some(PseudoInstruction::StoreFastMaybeNull { var_num }) = - info.instr.pseudo() - else { - unreachable!(); - }; - let var_idx = var_num.get(info.arg) as usize; - if var_idx < nlocals { - unsafe_mask[var_idx] = true; - } - new_instructions.push(info); - } - Some(Instruction::DeleteFast { var_num }) => { - let var_idx = usize::from(var_num.get(info.arg)); - if var_idx < nlocals { - unsafe_mask[var_idx] = true; - } - new_instructions.push(info); - } - Some(Instruction::LoadFastAndClear { var_num }) => { - let var_idx = usize::from(var_num.get(info.arg)); - if var_idx < nlocals { - unsafe_mask[var_idx] = true; - } - new_instructions.push(info); - } - Some(Instruction::StoreFast { var_num }) => { - let var_idx = usize::from(var_num.get(info.arg)); - if var_idx < nlocals { - unsafe_mask[var_idx] = false; - } - new_instructions.push(info); - } - Some(Instruction::StoreDeref { i }) => { - let cell_relative = usize::from(i.get(info.arg)); - if let Some(var_idx) = merged_cell_local(cell_relative) - && var_idx < nlocals - { - unsafe_mask[var_idx] = false; - } - new_instructions.push(info); - } - Some(Instruction::StoreFastStoreFast { var_nums }) => { - let packed = var_nums.get(info.arg); - let (idx1, idx2) = packed.indexes(); - let idx1 = usize::from(idx1); - let idx2 = usize::from(idx2); - if idx1 < nlocals { - unsafe_mask[idx1] = false; - } - if idx2 < nlocals { - unsafe_mask[idx2] = false; - } - new_instructions.push(info); - } - Some(Instruction::LoadFastCheck { var_num }) => { - let var_idx = usize::from(var_num.get(info.arg)); - if var_idx < nlocals { - unsafe_mask[var_idx] = false; - } - new_instructions.push(info); - } - Some(Instruction::DeleteDeref { i }) => { - let cell_relative = usize::from(i.get(info.arg)); - if let Some(var_idx) = merged_cell_local(cell_relative) - && var_idx < nlocals - { - unsafe_mask[var_idx] = true; - } - new_instructions.push(info); - } - Some( - Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num 
}, - ) => { - let var_idx = usize::from(var_num.get(info.arg)); - if var_idx < nlocals && unsafe_mask[var_idx] { - info.instr = Opcode::LoadFastCheck.into(); - changed = true; - } - if var_idx < nlocals { - unsafe_mask[var_idx] = false; - } - new_instructions.push(info); - } - Some( - Instruction::LoadFastLoadFast { var_nums } - | Instruction::LoadFastBorrowLoadFastBorrow { var_nums }, - ) => { - let packed = var_nums.get(info.arg); - let (idx1, idx2) = packed.indexes(); - let idx1 = usize::from(idx1); - let idx2 = usize::from(idx2); - let needs_check_1 = idx1 < nlocals && unsafe_mask[idx1]; - let needs_check_2 = idx2 < nlocals && unsafe_mask[idx2]; - if needs_check_1 || needs_check_2 { - let mut first = info; - first.instr = if needs_check_1 { - Opcode::LoadFastCheck - } else { - Opcode::LoadFast - } - .into(); - first.arg = OpArg::new(idx1 as u32); - - let mut second = info; - second.instr = if needs_check_2 { - Opcode::LoadFastCheck.into() - } else { - Opcode::LoadFast.into() - }; - second.arg = OpArg::new(idx2 as u32); - - new_instructions.push(first); - new_instructions.push(second); - changed = true; - } else { - new_instructions.push(info); - } - if idx1 < nlocals { - unsafe_mask[idx1] = false; - } - if idx2 < nlocals { - unsafe_mask[idx2] = false; - } - } - _ => new_instructions.push(info), - } - } - - if changed { - self.blocks[idx].instructions = new_instructions; - } - - let block = &self.blocks[idx]; - if block_has_fallthrough(block) { - let next = next_nonempty_block(&self.blocks, block.next); - if next != BlockIdx::NULL - && merge_unsafe_mask(&mut in_masks[next.idx()], &unsafe_mask) - { - worklist.push(next); - } - } - - if let Some(last) = block.instructions.last() - && is_jump_instruction(last) - { - let target = next_nonempty_block(&self.blocks, last.target); - if target != BlockIdx::NULL - && merge_unsafe_mask(&mut in_masks[target.idx()], &unsafe_mask) - { - worklist.push(target); - } - } - } - } - - fn max_stackdepth(&mut self) -> 
crate::InternalResult { - let mut maxdepth = 0u32; - let mut stack = Vec::with_capacity(self.blocks.len()); - let mut start_depths = vec![u32::MAX; self.blocks.len()]; - stackdepth_push(&mut stack, &mut start_depths, BlockIdx(0), 0); - const DEBUG: bool = false; - 'process_blocks: while let Some(block_idx) = stack.pop() { - let idx = block_idx.idx(); - let mut depth = start_depths[idx]; - if DEBUG { - eprintln!("===BLOCK {}===", block_idx.0); - } - let block = &self.blocks[block_idx]; - for ins in &block.instructions { - let instr = &ins.instr; - let effect = instr.stack_effect(ins.arg.into()); - if DEBUG { - let display_arg = if ins.target == BlockIdx::NULL { - ins.arg - } else { - OpArg::new(ins.target.0) - }; - eprint!("{display_arg:?}: {depth} {effect:+} => "); - } - let new_depth = depth.checked_add_signed(effect).ok_or({ - if effect < 0 { - InternalError::StackUnderflow - } else { - InternalError::StackOverflow - } - })?; - if DEBUG { - eprintln!("{new_depth}"); - } - if new_depth > maxdepth { - maxdepth = new_depth - } - // Process target blocks for branching instructions - if ins.target != BlockIdx::NULL { - let jump_effect = instr.stack_effect_jump(ins.arg.into()); - let target_depth = depth.checked_add_signed(jump_effect).ok_or({ - if jump_effect < 0 { - InternalError::StackUnderflow - } else { - InternalError::StackOverflow - } - })?; - if target_depth > maxdepth { - maxdepth = target_depth; - } - let target = next_nonempty_block(&self.blocks, ins.target); - if target != BlockIdx::NULL { - stackdepth_push(&mut stack, &mut start_depths, target, target_depth); - } - } - depth = new_depth; - if instr.is_scope_exit() || instr.is_unconditional_jump() { - continue 'process_blocks; - } - } - // Only push next block if it's not NULL - let next = next_nonempty_block(&self.blocks, block.next); - if next != BlockIdx::NULL { - stackdepth_push(&mut stack, &mut start_depths, next, depth); - } - } - if DEBUG { - eprintln!("DONE: {maxdepth}"); - } - - for (block, 
&start_depth) in self.blocks.iter_mut().zip(&start_depths) { - block.start_depth = (start_depth != u32::MAX).then_some(start_depth); - } - - // Fix up handler stack_depth in ExceptHandlerInfo using start_depths - // computed above: depth = start_depth - 1 - preserve_lasti - for block in &mut self.blocks { - for ins in &mut block.instructions { - if let Some(ref mut handler) = ins.except_handler { - let h_start = start_depths[handler.handler_block.idx()]; - if h_start != u32::MAX { - let adjustment = 1 + handler.preserve_lasti as u32; - debug_assert!( - h_start >= adjustment, - "handler start depth {h_start} too shallow for adjustment {adjustment}" - ); - handler.stack_depth = h_start.saturating_sub(adjustment); - } - } - } - } - - Ok(maxdepth) - } -} - -#[cfg(test)] -impl CodeInfo { - fn debug_block_dump(&self) -> String { - let mut out = String::new(); - for (block_idx, block) in iter_blocks(&self.blocks) { - use core::fmt::Write; - let _ = writeln!( - out, - "block {} next={} cold={} except={} preserve_lasti={} disable_borrow={} start_depth={}", - u32::from(block_idx), - if block.next == BlockIdx::NULL { - String::from("NULL") - } else { - u32::from(block.next).to_string() - }, - block.cold, - block.except_handler, - block.preserve_lasti, - block.disable_load_fast_borrow, - block - .start_depth - .map_or_else(|| String::from("None"), |depth| depth.to_string()), - ); - for info in &block.instructions { - let lineno = instruction_lineno(info); - let _ = writeln!( - out, - " [disp={} raw={} override={:?}] {:?} arg={} target={}", - lineno, - info.location.line.get(), - info.lineno_override, - info.instr, - u32::from(info.arg), - if info.target == BlockIdx::NULL { - String::from("NULL") - } else { - u32::from(info.target).to_string() - } - ); - } - } - out - } - - pub(crate) fn debug_late_cfg_trace(mut self) -> crate::InternalResult> { - let mut trace = Vec::new(); - trace.push(("initial".to_owned(), self.debug_block_dump())); - - self.splice_annotations_blocks(); - 
self.fold_binop_constants(); - self.fold_unary_constants(); - self.fold_binop_constants(); - self.fold_unary_constants(); - self.fold_tuple_constants(); - self.fold_binop_constants(); - self.fold_list_constants(); - self.fold_set_constants(); - self.optimize_lists_and_sets(); - self.convert_to_load_small_int(); - self.remove_unused_consts(); - self.dce(); - self.optimize_build_tuple_unpack(); - self.eliminate_dead_stores(); - self.apply_static_swaps(); - self.peephole_optimize(); - trace.push(( - "after_peephole_optimize".to_owned(), - self.debug_block_dump(), - )); - self.fold_tuple_constants(); - self.fold_binop_constants(); - self.fold_list_constants(); - self.fold_set_constants(); - self.optimize_lists_and_sets(); - self.convert_to_load_small_int(); - self.remove_unused_consts(); - self.dce(); - split_blocks_at_jumps(&mut self.blocks); - trace.push(( - "after_split_blocks_at_jumps".to_owned(), - self.debug_block_dump(), - )); - mark_except_handlers(&mut self.blocks); - label_exception_targets(&mut self.blocks); - redirect_empty_unconditional_jump_targets(&mut self.blocks); - inline_small_or_no_lineno_blocks(&mut self.blocks); - trace.push(( - "after_inline_small_or_no_lineno_blocks".to_owned(), - self.debug_block_dump(), - )); - jump_threading(&mut self.blocks); - trace.push(("after_jump_threading".to_owned(), self.debug_block_dump())); - self.eliminate_unreachable_blocks(); - self.remove_nops(); - trace.push(( - "after_early_remove_nops".to_owned(), - self.debug_block_dump(), - )); - self.add_checks_for_loads_of_uninitialized_variables(); - self.insert_superinstructions(); - resolve_line_numbers(&mut self.blocks); - inline_single_predecessor_artificial_expr_exit_blocks(&mut self.blocks); - trace.push(( - "after_first_resolve_line_numbers".to_owned(), - self.debug_block_dump(), - )); - push_cold_blocks_to_end(&mut self.blocks); - trace.push(( - "after_push_cold_before_chain_reorder".to_owned(), - self.debug_block_dump(), - )); - 
reorder_conditional_chain_and_jump_back_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, true, true); - - trace.push(( - "after_push_cold_blocks_to_end".to_owned(), - self.debug_block_dump(), - )); - - normalize_jumps(&mut self.blocks); - trace.push(("after_normalize_jumps".to_owned(), self.debug_block_dump())); - reorder_conditional_exit_and_jump_blocks(&mut self.blocks); - reorder_conditional_jump_and_exit_blocks(&mut self.blocks); - reorder_conditional_break_continue_blocks(&mut self.blocks); - reorder_conditional_explicit_continue_scope_exit_blocks(&mut self.blocks); - reorder_conditional_implicit_continue_scope_exit_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, true, true); - reorder_exception_handler_conditional_continue_scope_exit_blocks(&mut self.blocks); - deduplicate_adjacent_jump_back_blocks(&mut self.blocks); - reorder_conditional_body_and_implicit_continue_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, true, true); - reorder_jump_over_exception_cleanup_blocks(&mut self.blocks); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, false, true); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, false, true); - reorder_conditional_scope_exit_and_jump_back_blocks(&mut self.blocks, false, false); - trace.push(("after_reorder".to_owned(), self.debug_block_dump())); - - self.dce(); - self.eliminate_unreachable_blocks(); - trace.push(("after_dce_unreachable".to_owned(), self.debug_block_dump())); - - resolve_line_numbers(&mut self.blocks); - trace.push(( - "after_resolve_line_numbers".to_owned(), - self.debug_block_dump(), - )); - - materialize_empty_conditional_exit_targets(&mut self.blocks); - trace.push(( - "after_materialize_empty_conditional_exit_targets".to_owned(), - self.debug_block_dump(), - )); - redirect_empty_block_targets(&mut self.blocks); - trace.push(( - 
"after_redirect_empty_block_targets".to_owned(), - self.debug_block_dump(), - )); - - inline_small_fast_return_blocks(&mut self.blocks); - inline_unprotected_tuple_genexpr_assignment_return_blocks(&mut self.blocks); - trace.push(( - "after_inline_small_fast_return_blocks".to_owned(), - self.debug_block_dump(), - )); - - duplicate_end_returns(&mut self.blocks, &self.metadata); - duplicate_fallthrough_jump_back_targets(&mut self.blocks); - duplicate_shared_jump_back_targets(&mut self.blocks); - trace.push(( - "after_duplicate_jump_back_targets".to_owned(), - self.debug_block_dump(), - )); - - self.dce(); - self.eliminate_unreachable_blocks(); - trace.push(( - "after_second_dce_unreachable".to_owned(), - self.debug_block_dump(), - )); - - resolve_line_numbers(&mut self.blocks); - trace.push(( - "after_final_resolve_line_numbers".to_owned(), - self.debug_block_dump(), - )); - - self.remove_redundant_const_pop_top_pairs(); - remove_redundant_nops_and_jumps(&mut self.blocks); - trace.push(( - "after_remove_redundant_nops_and_jumps".to_owned(), - self.debug_block_dump(), - )); - - jump_threading_unconditional(&mut self.blocks); - reorder_jump_over_exception_cleanup_blocks(&mut self.blocks); - self.eliminate_unreachable_blocks(); - remove_redundant_nops_and_jumps(&mut self.blocks); - inline_with_suppress_return_blocks(&mut self.blocks); - inline_pop_except_return_blocks(&mut self.blocks); - inline_named_except_cleanup_normal_exit_jumps(&mut self.blocks); - duplicate_named_except_cleanup_returns(&mut self.blocks, &self.metadata); - self.eliminate_unreachable_blocks(); - trace.push(( - "after_final_cfg_cleanup".to_owned(), - self.debug_block_dump(), - )); - - resolve_line_numbers(&mut self.blocks); - trace.push(( - "after_post_cleanup_resolve_line_numbers".to_owned(), - self.debug_block_dump(), - )); - - let cellfixedoffsets = build_cellfixedoffsets( - &self.metadata.varnames, - &self.metadata.cellvars, - &self.metadata.freevars, - ); - mark_except_handlers(&mut 
self.blocks); - redirect_empty_block_targets(&mut self.blocks); - let _ = self.max_stackdepth()?; - convert_pseudo_ops(&mut self.blocks, &cellfixedoffsets); - remove_redundant_nops_and_jumps(&mut self.blocks); - self.mark_unprotected_debug_four_tails_borrow_disabled(); - self.mark_exception_handler_transition_targets_borrow_disabled(); - self.mark_targeted_nop_for_tails_borrow_disabled(); - trace.push(( - "after_convert_pseudo_ops".to_owned(), - self.debug_block_dump(), - )); - self.compute_load_fast_start_depths(); - trace.push(( - "after_compute_load_fast_start_depths".to_owned(), - self.debug_block_dump(), - )); - self.optimize_load_fast_borrow(); - trace.push(( - "after_raw_optimize_load_fast_borrow".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_in_targeted_assert_message_blocks(); - trace.push(( - "after_deoptimize_borrow_in_targeted_assert_message_blocks".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_for_folded_nonliteral_exprs(); - trace.push(( - "after_deoptimize_borrow_for_folded_nonliteral_exprs".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_generator_exception_return(); - self.deoptimize_borrow_after_async_for_cleanup_resume(); - trace.push(( - "after_deoptimize_borrow_after_generator_exception_return".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_multi_handler_resume_join(); - trace.push(( - "after_deoptimize_borrow_after_multi_handler_resume_join".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_named_except_cleanup_join(); - trace.push(( - "after_deoptimize_borrow_after_named_except_cleanup_join".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_reraising_except_handler(); - trace.push(( - "after_deoptimize_borrow_after_reraising_except_handler".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_in_protected_conditional_tail(); - trace.push(( - 
"after_deoptimize_borrow_in_protected_conditional_tail".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_terminal_except_tail(); - trace.push(( - "after_deoptimize_borrow_after_terminal_except_tail".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_except_star_try_tail(); - trace.push(( - "after_deoptimize_borrow_after_except_star_try_tail".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_in_protected_method_call_after_terminal_except_tail(); - trace.push(( - "after_deoptimize_borrow_in_protected_method_call_after_terminal_except_tail" - .to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_terminal_except_before_with(); - trace.push(( - "after_deoptimize_borrow_after_terminal_except_before_with".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_handler_resume_loop_tail(); - trace.push(( - "after_deoptimize_borrow_after_handler_resume_loop_tail".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_protected_import(); - trace.push(( - "after_deoptimize_borrow_after_protected_import".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_before_import_after_join_store(); - trace.push(( - "after_deoptimize_borrow_before_import_after_join_store".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_protected_store_tail(); - trace.push(( - "after_deoptimize_borrow_after_protected_store_tail".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_after_deoptimized_async_with_enter(); - trace.push(( - "after_optimize_load_fast_borrow".to_owned(), - self.debug_block_dump(), - )); - self.deoptimize_borrow_for_handler_return_paths(); - self.deoptimize_borrow_for_match_keys_attr(); - self.deoptimize_borrow_in_protected_attr_chain_tail(); - self.reborrow_after_suppressing_handler_resume_cleanup(); - trace.push(("after_borrow_deopts".to_owned(), self.debug_block_dump())); - 
self.deoptimize_store_fast_store_fast_after_cleanup(); - self.apply_static_swaps(); - self.deoptimize_store_fast_store_fast_after_cleanup(); - self.optimize_load_global_push_null(); - self.reorder_entry_prefix_cell_setup(); - self.remove_unused_consts(); - - Ok(trace) + fn capacity(&self) -> usize { + self.stack.len() } } -impl CodeInfo { - fn remap_block_idx(idx: BlockIdx, base: u32) -> BlockIdx { - if idx == BlockIdx::NULL { - idx - } else { - BlockIdx::new(u32::from(idx) + base) - } - } - - fn splice_annotations_blocks(&mut self) { - let mut placeholder = None; - for (block_idx, block) in self.blocks.iter().enumerate() { - if let Some(instr_idx) = block.instructions.iter().position(|info| { - matches!( - info.instr.pseudo(), - Some(PseudoInstruction::AnnotationsPlaceholder) - ) - }) { - placeholder = Some((block_idx, instr_idx)); - break; - } - } - - let Some((block_idx, instr_idx)) = placeholder else { - return; - }; - - let Some(mut annotations_blocks) = self.annotations_blocks.take() else { - self.blocks[block_idx].instructions.remove(instr_idx); - return; - }; - if annotations_blocks.is_empty() { - self.blocks[block_idx].instructions.remove(instr_idx); - return; - } - - let base = self.blocks.len() as u32; - for block in &mut annotations_blocks { - block.next = Self::remap_block_idx(block.next, base); - for info in &mut block.instructions { - info.target = Self::remap_block_idx(info.target, base); - if let Some(handler) = &mut info.except_handler { - handler.handler_block = Self::remap_block_idx(handler.handler_block, base); - } - } - } - - let ann_entry = BlockIdx::new(base); - let ann_tail = { - let mut cursor = ann_entry; - while annotations_blocks[(u32::from(cursor) - base) as usize].next != BlockIdx::NULL { - cursor = annotations_blocks[(u32::from(cursor) - base) as usize].next; - } - cursor - }; - - let old_next = self.blocks[block_idx].next; - let suffix = self.blocks[block_idx].instructions.split_off(instr_idx + 1); - 
self.blocks[block_idx].instructions.pop(); - - let suffix_block = if suffix.is_empty() { - old_next - } else { - let suffix_idx = BlockIdx::new(base + annotations_blocks.len() as u32); - let disable_load_fast_borrow = self.blocks[block_idx].disable_load_fast_borrow; - let block = Block { - instructions: suffix, - next: old_next, - disable_load_fast_borrow, - ..Default::default() - }; - annotations_blocks.push(block); - suffix_idx - }; - - self.blocks[block_idx].next = ann_entry; - let ann_tail_local = (u32::from(ann_tail) - base) as usize; - annotations_blocks[ann_tail_local].next = suffix_block; - self.blocks.extend(annotations_blocks); +#[derive(Clone, Debug)] +pub(crate) struct InstructionSequenceLabelMap { + block_labels: Vec, + /// Codegen-side shadow of CPython's instruction-sequence label map. + /// + /// `_PyInstructionSequence_UseLabel()` can map multiple labels to the same + /// instruction offset before `_PyCfg_FromInstructionSequence()` materializes + /// CFG blocks. The codegen CFG path keeps the same aliasing by resolving + /// those labels to the block that owns the shared offset. 
+ cpython_block_by_label: Vec, +} + +fn instruction_sequence_label_map_register_label( + map: &mut InstructionSequenceLabelMap, + label: InstructionSequenceLabel, +) -> crate::InternalResult<()> { + debug_assert!(is_label(label)); + let old_size = map.cpython_block_by_label.len(); + let new_allocation = c_array_ensure_capacity::( + old_size, + label.idx(), + INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE, + )?; + if new_allocation > old_size { + if new_allocation > map.cpython_block_by_label.capacity() { + map.cpython_block_by_label + .try_reserve_exact(new_allocation - map.cpython_block_by_label.capacity()) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + } + map.cpython_block_by_label + .resize(new_allocation, BlockIdx::NULL); + for i in old_size..map.cpython_block_by_label.len() { + map.cpython_block_by_label[i] = BlockIdx::NULL; + } + } + debug_assert!(map.cpython_block_by_label.len() > label.idx()); + Ok(()) +} + +fn instruction_sequence_label_map_ensure_label_for_block( + map: &mut InstructionSequenceLabelMap, + seq: &mut InstructionSequence, + block: BlockIdx, +) -> crate::InternalResult { + debug_assert_ne!(block, BlockIdx::NULL); + let block_label = map.block_labels[block.idx()]; + if is_label(block_label) { + return Ok(block_label); + } + let label = instruction_sequence_new_label(seq); + debug_assert_eq!(label.0, seq.next_free_label); + instruction_sequence_label_map_register_label(map, label)?; + map.cpython_block_by_label[label.idx()] = block; + map.block_labels[block.idx()] = label; + Ok(label) +} + +fn instruction_sequence_label_map_label_for_block( + map: &InstructionSequenceLabelMap, + block: BlockIdx, +) -> InstructionSequenceLabel { + debug_assert_ne!(block, BlockIdx::NULL); + map.block_labels + .get(block.idx()) + .copied() + .unwrap_or(InstructionSequenceLabel::NO_LABEL) +} + +fn instruction_sequence_label_map_block_for_label( + map: &InstructionSequenceLabelMap, + label: InstructionSequenceLabel, +) -> Option { + if !is_label(label) { + 
return None; } + map.cpython_block_by_label + .get(label.idx()) + .copied() + .filter(|&block| block != BlockIdx::NULL) } -impl InstrDisplayContext for CodeInfo { - type Constant = ConstantData; - - fn get_constant(&self, consti: oparg::ConstIdx) -> &ConstantData { - &self.metadata.consts[consti.as_usize()] - } - - fn get_name(&self, i: usize) -> &str { - self.metadata.names[i].as_ref() - } - - fn get_varname(&self, var_num: oparg::VarNum) -> &str { - self.metadata.varnames[var_num.as_usize()].as_ref() +fn instruction_sequence_label_map_resolve_label( + map: &InstructionSequenceLabelMap, + block: BlockIdx, +) -> BlockIdx { + if block == BlockIdx::NULL { + return BlockIdx::NULL; } - - fn get_localsplus_name(&self, var_num: oparg::VarNum) -> &str { - let idx = var_num.as_usize(); - let nlocals = self.metadata.varnames.len(); - if idx < nlocals { - self.metadata.varnames[idx].as_ref() - } else { - let cell_idx = idx - nlocals; - self.metadata - .cellvars - .get_index(cell_idx) - .unwrap_or_else(|| &self.metadata.freevars[cell_idx - self.metadata.cellvars.len()]) - .as_ref() - } + let label = instruction_sequence_label_map_label_for_block(map, block); + if !is_label(label) { + return block; } + instruction_sequence_label_map_block_for_label(map, label).unwrap_or_else(|| { + debug_assert!( + false, + "CPython instruction-sequence label must map to a codegen CFG block" + ); + BlockIdx::NULL + }) } -fn stackdepth_push( - stack: &mut Vec, - start_depths: &mut [u32], - target: BlockIdx, - depth: u32, -) { - let idx = target.idx(); - let block_depth = &mut start_depths[idx]; - if depth > *block_depth || *block_depth == u32::MAX { - *block_depth = depth; - stack.push(target); +fn instruction_sequence_label_map_resolve_label_to_block( + map: &InstructionSequenceLabelMap, + label: InstructionSequenceLabel, +) -> BlockIdx { + if !is_label(label) { + return BlockIdx::NULL; } + instruction_sequence_label_map_block_for_label(map, label).unwrap_or_else(|| { + debug_assert!( + false, 
+ "CPython instruction-sequence label must map to a codegen CFG block" + ); + BlockIdx::NULL + }) } -fn iter_blocks(blocks: &[Block]) -> impl Iterator + '_ { - let mut next = BlockIdx(0); - core::iter::from_fn(move || { - if next == BlockIdx::NULL { - return None; - } - let (idx, b) = (next, &blocks[next]); - next = b.next; - Some((idx, b)) - }) +fn instruction_sequence_label_oparg(label: InstructionSequenceLabel) -> OpArg { + debug_assert!(is_label(label)); + OpArg::new(label.idx() as u32) } -/// Generate Python 3.11+ format linetable from source locations -fn generate_linetable( - locations: &[LineTableLocation], - first_line: i32, - debug_ranges: bool, -) -> Box<[u8]> { - if locations.is_empty() { - return Box::new([]); +fn instruction_sequence_label_map_use_label_at_block( + map: &mut InstructionSequenceLabelMap, + seq: &mut InstructionSequence, + from: BlockIdx, + to: BlockIdx, +) -> crate::InternalResult<()> { + if from == BlockIdx::NULL || from == to { + return Ok(()); } - - let mut linetable = Vec::new(); - // Initialize prev_line to first_line - // The first entry's delta is relative to co_firstlineno - let mut prev_line = first_line; - let mut i = 0; - - while i < locations.len() { - let loc = &locations[i]; - - // Count consecutive instructions with the same location - let mut length = 1; - while i + length < locations.len() && locations[i + length] == locations[i] { - length += 1; - } - - // Process in chunks of up to 8 instructions - while length > 0 { - let entry_length = length.min(8); - - // Get line information - let line = loc.line; - - // NO_LOCATION: emit PyCodeLocationInfoKind::None entries (CACHE, etc.) 
- if line == -1 { - linetable.push( - 0x80 | ((PyCodeLocationInfoKind::None as u8) << 3) | ((entry_length - 1) as u8), - ); - // Do NOT update prev_line - length -= entry_length; - i += entry_length; - continue; - } - - let end_line = loc.end_line; - let line_delta = line - prev_line; - let end_line_delta = end_line - line; - - // When debug_ranges is disabled, only emit line info (NoColumns format) - if !debug_ranges { - // NoColumns format (code 13): line info only, no column data - linetable.push( - 0x80 | ((PyCodeLocationInfoKind::NoColumns as u8) << 3) - | ((entry_length - 1) as u8), - ); - write_signed_varint(&mut linetable, line_delta); - - prev_line = line; - length -= entry_length; - i += entry_length; - continue; - } - - // Get column information (only when debug_ranges is enabled) - let col = loc.col; - let end_col = loc.end_col; - - // Choose the appropriate encoding based on line delta and column info - if line_delta == 0 && end_line_delta == 0 { - if col < 80 && end_col - col < 16 && end_col >= col { - // Short form (codes 0-9) for common cases - let code = (col / 8).min(9) as u8; // Short0 to Short9 - linetable.push(0x80 | (code << 3) | ((entry_length - 1) as u8)); - let col_byte = (((col % 8) as u8) << 4) | ((end_col - col) as u8 & 0xf); - linetable.push(col_byte); - } else if col < 128 && end_col < 128 { - // One-line form (code 10) for same line - linetable.push( - 0x80 | ((PyCodeLocationInfoKind::OneLine0 as u8) << 3) - | ((entry_length - 1) as u8), - ); - linetable.push(col as u8); - linetable.push(end_col as u8); - } else { - // Long form for columns >= 128 - linetable.push( - 0x80 | ((PyCodeLocationInfoKind::Long as u8) << 3) - | ((entry_length - 1) as u8), - ); - write_signed_varint(&mut linetable, 0); // line_delta = 0 - write_varint(&mut linetable, 0); // end_line delta = 0 - write_varint(&mut linetable, (col as u32) + 1); - write_varint(&mut linetable, (end_col as u32) + 1); - } - } else if line_delta > 0 && line_delta < 3 && 
end_line_delta == 0 { - // One-line form (codes 11-12) for line deltas 1-2 - if col < 128 && end_col < 128 { - let code = (PyCodeLocationInfoKind::OneLine0 as u8) + (line_delta as u8); - linetable.push(0x80 | (code << 3) | ((entry_length - 1) as u8)); - linetable.push(col as u8); - linetable.push(end_col as u8); - } else { - // Long form for columns >= 128 - linetable.push( - 0x80 | ((PyCodeLocationInfoKind::Long as u8) << 3) - | ((entry_length - 1) as u8), - ); - write_signed_varint(&mut linetable, line_delta); - write_varint(&mut linetable, 0); // end_line delta = 0 - write_varint(&mut linetable, (col as u32) + 1); - write_varint(&mut linetable, (end_col as u32) + 1); - } - } else { - // Long form (code 14) for all other cases - // Handles: line_delta < 0, line_delta >= 3, multi-line spans, or columns >= 128 - linetable.push( - 0x80 | ((PyCodeLocationInfoKind::Long as u8) << 3) | ((entry_length - 1) as u8), - ); - write_signed_varint(&mut linetable, line_delta); - write_varint(&mut linetable, end_line_delta as u32); - write_varint(&mut linetable, (col as u32) + 1); - write_varint(&mut linetable, (end_col as u32) + 1); - } - - prev_line = line; - length -= entry_length; - i += entry_length; + let from_label = instruction_sequence_label_map_ensure_label_for_block(map, seq, from)?; + debug_assert!(map.cpython_block_by_label.len() > from_label.idx()); + let to_block = instruction_sequence_label_map_resolve_label(map, to); + if to_block == BlockIdx::NULL { + debug_assert!( + false, + "CPython label target must map to a codegen CFG block" + ); + return Ok(()); + } + map.cpython_block_by_label[from_label.idx()] = to_block; + Ok(()) +} + +fn instruction_sequence_label_map_push_unlabeled_block( + map: &mut InstructionSequenceLabelMap, +) -> crate::InternalResult<()> { + map.block_labels + .try_reserve(1) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + map.block_labels.push(InstructionSequenceLabel::NO_LABEL); + Ok(()) +} + +fn 
instruction_sequence_label_map_push_unmapped_label( + map: &mut InstructionSequenceLabelMap, + seq: &mut InstructionSequence, +) -> crate::InternalResult<()> { + let label = instruction_sequence_new_label(seq); + debug_assert_eq!(label.0, seq.next_free_label); + instruction_sequence_label_map_register_label(map, label)?; + let block = BlockIdx( + map.block_labels + .len() + .to_u32() + .ok_or(InternalError::MalformedControlFlowGraph)?, + ); + map.cpython_block_by_label[label.idx()] = block; + map.block_labels + .try_reserve(1) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + map.block_labels.push(label); + Ok(()) +} + +impl InstructionSequenceLabelMap { + pub(crate) fn new() -> Self { + Self { + block_labels: vec![InstructionSequenceLabel::NO_LABEL], + cpython_block_by_label: Vec::new(), } } - - linetable.into_boxed_slice() } -/// Generate Python 3.11+ exception table from instruction handler info -fn generate_exception_table(blocks: &[Block], block_to_index: &[u32]) -> Box<[u8]> { - let mut entries: Vec = Vec::new(); - let mut current_entry: Option<(ExceptHandlerInfo, u32)> = None; // (handler_info, start_index) - let mut instr_index = 0u32; - - // Iterate through all instructions in block order - // instr_index is the index into the final instructions array (including EXTENDED_ARG) - // This matches how frame.rs uses lasti - for (_, block) in iter_blocks(blocks) { - for instr in &block.instructions { - // instr_size includes EXTENDED_ARG and CACHE entries - let instr_size = instr.arg.instr_size() as u32 + instr.cache_entries; - - match (¤t_entry, instr.except_handler) { - // No current entry, no handler - nothing to do - (None, None) => {} - - // No current entry, handler starts - begin new entry - (None, Some(handler)) => { - current_entry = Some((handler, instr_index)); - } - - // Current entry exists, same handler - continue - (Some((curr_handler, _)), Some(handler)) - if curr_handler.handler_block == handler.handler_block - && 
curr_handler.stack_depth == handler.stack_depth - && curr_handler.preserve_lasti == handler.preserve_lasti => {} - - // Current entry exists, different handler - finish current, start new - (Some((curr_handler, start)), Some(handler)) => { - let target_index = block_to_index[curr_handler.handler_block.idx()]; - entries.push(ExceptionTableEntry::new( - *start, - instr_index, - target_index, - curr_handler.stack_depth as u16, - curr_handler.preserve_lasti, - )); - current_entry = Some((handler, instr_index)); - } - - // Current entry exists, no handler - finish current entry - (Some((curr_handler, start)), None) => { - let target_index = block_to_index[curr_handler.handler_block.idx()]; - entries.push(ExceptionTableEntry::new( - *start, - instr_index, - target_index, - curr_handler.stack_depth as u16, - curr_handler.preserve_lasti, - )); - current_entry = None; - } - } +pub struct CodeInfo { + pub flags: CodeFlags, + pub source_path: String, + pub private: Option, // For private name mangling, mostly for class - instr_index += instr_size; // Account for EXTENDED_ARG instructions - } - } + pub blocks: Vec, + pub current_block: BlockIdx, + pub(crate) instr_sequence: InstructionSequence, + pub(crate) instr_sequence_label_map: InstructionSequenceLabelMap, + pub(crate) annotations_instr_sequence: Option, - // Finish any remaining entry - if let Some((curr_handler, start)) = current_entry { - let target_index = block_to_index[curr_handler.handler_block.idx()]; - entries.push(ExceptionTableEntry::new( - start, - instr_index, - target_index, - curr_handler.stack_depth as u16, - curr_handler.preserve_lasti, - )); - } + pub metadata: CodeUnitMetadata, - encode_exception_table(&entries) -} + // For class scopes: attributes accessed via self.X + pub static_attributes: Option>, -/// Mark exception handler target blocks. 
-/// flowgraph.c mark_except_handlers -pub(crate) fn mark_except_handlers(blocks: &mut [Block]) { - // Reset handler flags - for block in blocks.iter_mut() { - block.except_handler = false; - block.preserve_lasti = false; - } - // Mark target blocks of SETUP_* as except handlers - let targets: Vec = blocks - .iter() - .flat_map(|b| b.instructions.iter()) - .filter(|i| i.instr.is_block_push() && i.target != BlockIdx::NULL) - .map(|i| i.target.idx()) - .collect(); - for idx in targets { - blocks[idx].except_handler = true; - } -} + // True if compiling an inlined comprehension + pub in_inlined_comp: bool, -/// flowgraph.c mark_cold -fn mark_cold(blocks: &mut [Block]) { - let n = blocks.len(); - let mut warm = vec![false; n]; - let mut queue = VecDeque::new(); + // Block stack for tracking nested control structures + pub fblock: Vec, - warm[0] = true; - queue.push_back(BlockIdx(0)); + // Reference to the symbol table for this scope + pub symbol_table_index: usize, + // CPython compile.c uses PyList_GET_SIZE(u->u_ste->ste_varnames) + // when calling flowgraph.c _PyCfg_OptimizeCodeUnit(). + pub nparams: usize, - while let Some(block_idx) = queue.pop_front() { - let block = &blocks[block_idx.idx()]; + // PEP 649: Track nesting depth inside conditional blocks (if/for/while/etc.) 
+ // u_in_conditional_block + pub in_conditional_block: u32, - let has_fallthrough = block - .instructions - .last() - .is_none_or(|ins| !ins.instr.is_scope_exit() && !ins.instr.is_unconditional_jump()); - if has_fallthrough && block.next != BlockIdx::NULL { - let next_idx = block.next.idx(); - if !blocks[next_idx].except_handler && !warm[next_idx] { - warm[next_idx] = true; - queue.push_back(block.next); - } - } + // PEP 649: Next index for conditional annotation tracking + // u_next_conditional_annotation_index + pub next_conditional_annotation_index: u32, +} - for instr in &block.instructions { - if instr.target != BlockIdx::NULL { - let target_idx = instr.target.idx(); - if !blocks[target_idx].except_handler && !warm[target_idx] { - warm[target_idx] = true; - queue.push_back(instr.target); - } - } +impl CodeInfo { + pub(crate) fn addop_to_instr_sequence( + &mut self, + mut info: InstructionInfo, + ) -> crate::InternalResult<()> { + if info.instr.has_target() && info.target != BlockIdx::NULL { + let label = instruction_sequence_label_map_ensure_label_for_block( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, + info.target, + )?; + info.arg = instruction_sequence_label_oparg(label); + info.target = BlockIdx::NULL; + } + instruction_sequence_addop(&mut self.instr_sequence, info)?; + Ok(()) + } + + pub(crate) fn addop_to_instr_sequence_with_target_label( + &mut self, + mut info: InstructionInfo, + target_label: InstructionSequenceLabel, + ) -> crate::InternalResult<()> { + if !info.instr.has_target() { + return Err(InternalError::MalformedControlFlowGraph); } + info.arg = instruction_sequence_label_oparg(target_label); + info.target = BlockIdx::NULL; + instruction_sequence_addop(&mut self.instr_sequence, info)?; + Ok(()) } - for (i, block) in blocks.iter_mut().enumerate() { - block.cold = !warm[i]; + pub(crate) fn addop_to_current_block( + &mut self, + info: InstructionInfo, + ) -> crate::InternalResult<()> { + basicblock_addop(&mut 
self.blocks[self.current_block.idx()], info) } -} -/// flowgraph.c push_cold_blocks_to_end -fn push_cold_blocks_to_end(blocks: &mut Vec) { - if blocks.len() <= 1 { - return; + pub(crate) fn last_current_block_instr_mut(&mut self) -> Option<&mut InstructionInfo> { + basicblock_last_instr_mut(&mut self.blocks[self.current_block.idx()]) } - mark_cold(blocks); + pub(crate) fn set_last_instr_sequence_lineno_override(&mut self, lineno_override: i32) { + if let Some(last) = instruction_sequence_last_info_mut(&mut self.instr_sequence) { + last.lineno_override = Some(lineno_override); + } + } - // If a cold block falls through to a warm block, add an explicit jump - let fixups: Vec<(BlockIdx, BlockIdx)> = iter_blocks(blocks) - .filter(|(_, block)| { - block.cold - && block.next != BlockIdx::NULL - && !blocks[block.next.idx()].cold - && block.instructions.last().is_none_or(|ins| { - !ins.instr.is_scope_exit() && !ins.instr.is_unconditional_jump() - }) - }) - .map(|(idx, block)| (idx, block.next)) - .collect(); - - for (cold_idx, warm_next) in fixups { - let jump_block_idx = BlockIdx(blocks.len() as u32); - let mut jump_block = Block { - cold: true, - ..Block::default() - }; - jump_block.instructions.push(InstructionInfo { - instr: PseudoOpcode::JumpNoInterrupt.into(), - arg: OpArg::new(0), - target: warm_next, - location: SourceLocation::default(), - end_location: SourceLocation::default(), - except_handler: None, - folded_from_nonliteral_expr: false, - lineno_override: Some(-1), - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - }); - jump_block.next = blocks[cold_idx.idx()].next; - blocks[cold_idx.idx()].next = jump_block_idx; - blocks.push(jump_block); + pub(crate) fn use_instr_sequence_label( + &mut self, + block: BlockIdx, + ) -> crate::InternalResult<()> { + let label = 
instruction_sequence_label_map_ensure_label_for_block( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, + block, + )?; + instruction_sequence_use_label(&mut self.instr_sequence, label) } - // Extract cold block streaks and append at the end - let mut cold_head: BlockIdx = BlockIdx::NULL; - let mut cold_tail: BlockIdx = BlockIdx::NULL; - let mut current = BlockIdx(0); - assert!(!blocks[0].cold); + pub(crate) fn new_instr_sequence_label(&mut self) -> InstructionSequenceLabel { + instruction_sequence_new_label(&mut self.instr_sequence) + } - while current != BlockIdx::NULL { - let next = blocks[current.idx()].next; - if next == BlockIdx::NULL { - break; - } + pub(crate) fn use_raw_instr_sequence_label( + &mut self, + label: InstructionSequenceLabel, + ) -> crate::InternalResult<()> { + instruction_sequence_use_label(&mut self.instr_sequence, label) + } - if blocks[next.idx()].cold { - let cold_start = next; - let mut cold_end = next; - while blocks[cold_end.idx()].next != BlockIdx::NULL - && blocks[blocks[cold_end.idx()].next.idx()].cold - { - cold_end = blocks[cold_end.idx()].next; - } + pub(crate) fn mark_cpython_cfg_label(&mut self, block: BlockIdx) -> crate::InternalResult<()> { + let label = instruction_sequence_label_map_ensure_label_for_block( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, + block, + )?; + self.blocks[block.idx()].cpython_label = label; + Ok(()) + } - let after_cold = blocks[cold_end.idx()].next; - blocks[current.idx()].next = after_cold; - blocks[cold_end.idx()].next = BlockIdx::NULL; + pub(crate) fn resolve_instr_sequence_label(&self, block: BlockIdx) -> BlockIdx { + instruction_sequence_label_map_resolve_label(&self.instr_sequence_label_map, block) + } - if cold_head == BlockIdx::NULL { - cold_head = cold_start; - } else { - blocks[cold_tail.idx()].next = cold_start; - } - cold_tail = cold_end; - } else { - current = next; - } + pub(crate) fn block_for_instr_sequence_label( + &self, + label: 
InstructionSequenceLabel, + ) -> BlockIdx { + instruction_sequence_label_map_resolve_label_to_block(&self.instr_sequence_label_map, label) } - if cold_head != BlockIdx::NULL { - let mut last = current; - while blocks[last.idx()].next != BlockIdx::NULL { - last = blocks[last.idx()].next; - } - blocks[last.idx()].next = cold_head; - remove_redundant_nops_and_jumps(blocks); + pub(crate) fn use_instr_sequence_label_at_block( + &mut self, + from: BlockIdx, + to: BlockIdx, + ) -> crate::InternalResult<()> { + instruction_sequence_label_map_use_label_at_block( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, + from, + to, + ) } -} -/// Split blocks at branch points so each block has at most one branch -/// (conditional/unconditional jump) as its last instruction. -/// This matches CPython's CFG structure where each basic block has one exit. -fn split_blocks_at_jumps(blocks: &mut Vec) { - let mut bi = 0; - while bi < blocks.len() { - // Find the first jump/branch instruction in the block - let split_at = { - let block = &blocks[bi]; - let mut found = None; - for (i, ins) in block.instructions.iter().enumerate() { - if is_conditional_jump(&ins.instr) - || ins.instr.is_unconditional_jump() - || ins.instr.is_scope_exit() - { - if i + 1 < block.instructions.len() { - found = Some(i + 1); - } - break; - } - } - found - }; - if let Some(pos) = split_at { - let new_block_idx = BlockIdx(blocks.len() as u32); - let tail: Vec = blocks[bi].instructions.drain(pos..).collect(); - let old_next = blocks[bi].next; - let cold = blocks[bi].cold; - let disable_load_fast_borrow = blocks[bi].disable_load_fast_borrow; - blocks[bi].next = new_block_idx; - blocks.push(Block { - instructions: tail, - next: old_next, - cold, - disable_load_fast_borrow, - ..Block::default() - }); - // Don't increment bi - re-check current block (it might still have issues) + pub(crate) fn instr_sequence_label_for_block( + &mut self, + block: BlockIdx, + ) -> crate::InternalResult { + if block == 
BlockIdx::NULL { + Ok(InstructionSequenceLabel::NO_LABEL) } else { - bi += 1; + instruction_sequence_label_map_ensure_label_for_block( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, + block, + ) } } -} - -/// Jump threading: when a block's last jump targets a block whose first -/// instruction is an unconditional jump, redirect to the final target. -/// flowgraph.c optimize_basic_block + jump_thread -fn jump_threading(blocks: &mut [Block]) { - jump_threading_impl(blocks, true); -} - -fn jump_threading_unconditional(blocks: &mut [Block]) { - jump_threading_impl(blocks, false); -} -fn short_circuit_stub_conditional(block: &Block) -> Option { - let cond_idx = trailing_conditional_jump_index(block)?; - if cond_idx < 2 { - return None; - } - let [first, second, ..] = block.instructions.as_slice() else { - return None; - }; - if !matches!(first.instr.real(), Some(Instruction::Copy { i }) if i.get(first.arg) == 1) - || !matches!(second.instr.real(), Some(Instruction::ToBool)) - { - return None; + pub(crate) fn insert_start_setup_cleanup( + &mut self, + handler_block: BlockIdx, + ) -> crate::InternalResult<()> { + let handler_label = instruction_sequence_label_map_ensure_label_for_block( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, + handler_block, + )?; + instruction_sequence_insert_instruction( + &mut self.instr_sequence, + 0, + InstructionInfo { + instr: PseudoInstruction::SetupCleanup { + delta: Arg::marker(), + } + .into(), + arg: instruction_sequence_label_oparg(handler_label), + target: BlockIdx::NULL, + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: Some(NO_LOCATION_OVERRIDE), + }, + ) } - let only_markers_between = block.instructions[2..cond_idx].iter().all(|info| { - matches!( - info.instr.real(), - None | Some(Instruction::Nop | Instruction::NotTaken) + pub(crate) fn push_unmapped_instr_sequence_label(&mut self) -> crate::InternalResult<()> { + 
instruction_sequence_label_map_push_unmapped_label( + &mut self.instr_sequence_label_map, + &mut self.instr_sequence, ) - }); - if !only_markers_between { - return None; } - block.instructions[cond_idx].instr.real() -} - -fn opposite_short_circuit_target(block: &Block, source: AnyInstruction) -> bool { - let Some(conditional) = short_circuit_stub_conditional(block) else { - return false; - }; - matches!( - (source.real(), Some(conditional)), - ( - Some(Instruction::PopJumpIfFalse { .. }), - Some(Instruction::PopJumpIfTrue { .. }) - ) | ( - Some(Instruction::PopJumpIfTrue { .. }), - Some(Instruction::PopJumpIfFalse { .. }) - ) - ) -} + pub(crate) fn push_unlabeled_instr_sequence_block(&mut self) -> crate::InternalResult<()> { + instruction_sequence_label_map_push_unlabeled_block(&mut self.instr_sequence_label_map) + } -fn same_short_circuit_target(block: &Block, source: AnyInstruction) -> Option { - let conditional = short_circuit_stub_conditional(block)?; - matches!( - (source.real(), Some(conditional)), - ( - Some(Instruction::PopJumpIfFalse { .. }), - Some(Instruction::PopJumpIfFalse { .. }) - ) | ( - Some(Instruction::PopJumpIfTrue { .. }), - Some(Instruction::PopJumpIfTrue { .. }) - ) - ) - .then_some(block.instructions[trailing_conditional_jump_index(block)?].target) + fn take_recorded_instr_sequence(&mut self) -> crate::InternalResult { + let mut instr_sequence = + core::mem::replace(&mut self.instr_sequence, instruction_sequence_new()); + if let Some(mut annotations_instr_sequence) = self.annotations_instr_sequence.take() { + instruction_sequence_apply_label_map(&mut annotations_instr_sequence)?; + instruction_sequence_set_annotations_code( + &mut instr_sequence, + Some(Box::new(annotations_instr_sequence)), + ); + } + Ok(instr_sequence) + } + + fn prepare_cfg_from_codegen(&mut self) -> crate::InternalResult { + // CPython compile.c optimize_and_assemble_code_unit passes + // u_instr_sequence directly into flowgraph.c _PyCfg_FromInstructionSequence(). 
+ self.take_recorded_instr_sequence() + } +} + +fn optimize_code_unit( + metadata: &mut CodeUnitMetadata, + blocks: &mut Vec, + instr_sequence: InstructionSequence, + nlocals: usize, + nparams: usize, +) -> crate::InternalResult<()> { + // Phase 1: _PyCfg_OptimizeCodeUnit (flowgraph.c) + *blocks = cfg_from_instruction_sequence(instr_sequence)?; + translate_jump_labels_to_targets(blocks)?; + mark_except_handlers(blocks)?; + label_exception_targets(blocks)?; + optimize_cfg(metadata, blocks, metadata.firstlineno)?; + remove_unused_consts(blocks, &mut metadata.consts)?; + add_checks_for_loads_of_uninitialized_variables(blocks, nlocals, nparams)?; + // CPython inserts superinstructions in _PyCfg_OptimizeCodeUnit, before + // later jump normalization / block reordering can create adjacencies + // that never exist at this stage in flowgraph.c. + insert_superinstructions(blocks)?; + push_cold_blocks_to_end(blocks)?; + // CPython resolves line numbers again after cold-block extraction. + resolve_line_numbers(blocks, metadata.firstlineno)?; + Ok(()) +} + +fn optimize_cfg( + metadata: &mut CodeUnitMetadata, + blocks: &mut Vec, + firstlineno: OneIndexed, +) -> crate::InternalResult<()> { + // flowgraph.c optimize_cfg + // CPython optimize_cfg() starts with check_cfg() and raises + // SystemError if a jump or scope exit is not the last instruction in + // its block. + check_cfg(blocks)?; + inline_small_or_no_lineno_blocks(blocks)?; + // CPython does not re-run instruction-sequence label-map/CFG conversion + // after this point. Unreferenced label blocks left by jump inlining + // remain block boundaries and can preserve line-marker NOPs. + remove_unreachable(blocks)?; + // CPython optimize_cfg resolves line numbers before local checks and + // superinstruction insertion, so fusion decisions see propagated + // source locations. 
+ resolve_line_numbers(blocks, firstlineno)?; + // CPython optimize_cfg() runs optimize_load_const() and then + // optimize_basic_block() after line numbers are resolved. + optimize_load_const(metadata, blocks)?; + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next_block = blocks[block_idx.idx()].next; + optimize_basic_block(blocks, metadata, block_idx)?; + block_idx = next_block; + } + remove_redundant_nops_and_pairs(blocks)?; + // CPython optimize_cfg() removes newly-unreachable blocks and + // redundant NOP/jump chains before _PyCfg_OptimizeCodeUnit() prunes + // unused constants. + remove_unreachable(blocks)?; + remove_redundant_nops_and_jumps(blocks)?; + #[cfg(debug_assertions)] + assert!(no_redundant_jumps(blocks)); + Ok(()) +} + +fn optimized_cfg_to_instruction_sequence( + metadata: &CodeUnitMetadata, + flags: CodeFlags, + blocks: &mut Vec, +) -> crate::InternalResult<(u32, usize, InstructionSequence)> { + // Phase 2: _PyCfg_OptimizedCfgToInstructionSequence (flowgraph.c) + convert_pseudo_conditional_jumps(blocks)?; + let max_stackdepth = calculate_stackdepth(blocks)?; + debug_assert!(!is_generator(flags) || max_stackdepth != 0); + let nlocalsplus = prepare_localsplus(metadata, blocks, flags)?; + // Match CPython order: pseudo ops are lowered after stackdepth and + // localsplus preparation, before normalize_jumps. 
+ convert_pseudo_ops(blocks)?; + normalize_jumps(blocks)?; + #[cfg(debug_assertions)] + assert!(no_redundant_jumps(blocks)); + // optimize_load_fast: after normalize_jumps + optimize_load_fast(blocks)?; + + let mut instr_sequence = instruction_sequence_new(); + cfg_to_instruction_sequence(blocks, &mut instr_sequence)?; + Ok((max_stackdepth, nlocalsplus, instr_sequence)) } -#[derive(Clone, Copy, PartialEq, Eq)] -enum JumpThreadKind { - Plain, - NoInterrupt, -} +impl CodeInfo { + #[allow(clippy::needless_range_loop)] + pub fn finalize_code( + mut self, + opts: &crate::compile::CompileOpts, + ) -> crate::InternalResult { + let instr_sequence = self.prepare_cfg_from_codegen()?; + let nlocals = self.metadata.varnames.len(); + let nparams = self.nparams; + optimize_code_unit( + &mut self.metadata, + &mut self.blocks, + instr_sequence, + nlocals, + nparams, + )?; + let (max_stackdepth, nlocalsplus, mut instr_sequence) = + optimized_cfg_to_instruction_sequence(&self.metadata, self.flags, &mut self.blocks)?; + let localsplusinfo = compute_localsplus_info(&self.metadata, nlocalsplus, self.flags)?; -fn jump_thread_kind(instr: AnyInstruction) -> Option { - Some(match instr.into() { - AnyOpcode::Pseudo(PseudoOpcode::Jump) - | AnyOpcode::Real(Opcode::JumpForward | Opcode::JumpBackward) => JumpThreadKind::Plain, - AnyOpcode::Pseudo(PseudoOpcode::JumpNoInterrupt) - | AnyOpcode::Real(Opcode::JumpBackwardNoInterrupt) => JumpThreadKind::NoInterrupt, - _ => return None, - }) -} + let Self { + flags, + source_path, + private: _, // private is only used during compilation -fn threaded_jump_instr( - source: AnyInstruction, - target: AnyInstruction, - conditional: bool, -) -> Option { - let target_kind = jump_thread_kind(target)?; - if conditional { - return (target_kind == JumpThreadKind::Plain).then_some(source); - } + blocks: _, + current_block: _, + instr_sequence: _, + instr_sequence_label_map: _, + annotations_instr_sequence: _, + metadata, + static_attributes: _, + in_inlined_comp: 
_, + fblock: _, + symbol_table_index: _, + nparams: _, + in_conditional_block: _, + next_conditional_annotation_index: _, + } = self; - let source_kind = jump_thread_kind(source)?; - let result_kind = if source_kind == JumpThreadKind::NoInterrupt - && target_kind == JumpThreadKind::NoInterrupt - { - JumpThreadKind::NoInterrupt - } else { - JumpThreadKind::Plain - }; + let CodeUnitMetadata { + name: obj_name, + qualname, + consts: constants, + names: name_cache, + varnames: varname_cache, + cellvars: _, + freevars: freevar_cache, + fast_hidden: _, + fast_hidden_final: _, + argcount: arg_count, + posonlyargcount: posonlyarg_count, + kwonlyargcount: kwonlyarg_count, + firstlineno: first_line_number, + } = metadata; - Some(match (source.into(), result_kind) { - (AnyOpcode::Pseudo(_), JumpThreadKind::Plain) => PseudoOpcode::Jump.into(), - (AnyOpcode::Pseudo(_), JumpThreadKind::NoInterrupt) => PseudoOpcode::JumpNoInterrupt.into(), - (AnyOpcode::Real(Opcode::JumpBackwardNoInterrupt), JumpThreadKind::Plain) => { - Opcode::JumpBackward.into() - } - (AnyOpcode::Real(Opcode::JumpBackwardNoInterrupt), JumpThreadKind::NoInterrupt) => source, - (AnyOpcode::Real(Opcode::JumpForward | Opcode::JumpBackward), JumpThreadKind::Plain) => { - source - } - ( - AnyOpcode::Real(Opcode::JumpForward | Opcode::JumpBackward), - JumpThreadKind::NoInterrupt, - ) => PseudoOpcode::JumpNoInterrupt.into(), - _ => return None, - }) -} + resolve_unconditional_jumps(&mut instr_sequence)?; + resolve_jump_offsets(&mut instr_sequence)?; + let assembled = assemble_emit( + &mut instr_sequence, + first_line_number.get() as i32, + opts.debug_ranges, + )?; + let locations = rustpython_compiler_core::marshal::linetable_to_locations( + &assembled.linetable, + first_line_number.get() as i32, + assembled.instructions.len(), + ); -fn can_thread_conditional_through_forward_nointerrupt( - source: AnyInstruction, - target_pos: u32, - final_target_pos: u32, -) -> bool { - matches!( - source.real(), - 
Some(Instruction::PopJumpIfNone { .. } | Instruction::PopJumpIfNotNone { .. }) - ) && final_target_pos > target_pos -} + Ok(CodeObject { + flags, + posonlyarg_count, + arg_count, + kwonlyarg_count, + source_path, + first_line_number: Some(first_line_number), + obj_name: obj_name.clone(), + qualname: qualname.unwrap_or(obj_name), -fn jump_threading_impl(blocks: &mut [Block], include_conditional: bool) { - let mut changed = true; - while changed { - changed = false; - let mut block_order = vec![u32::MAX; blocks.len()]; - let mut cursor = BlockIdx(0); - let mut pos = 0u32; - while cursor != BlockIdx::NULL { - block_order[cursor.idx()] = pos; - pos += 1; - cursor = blocks[cursor.idx()].next; - } - for bi in 0..blocks.len() { - let last_idx = match blocks[bi].instructions.len().checked_sub(1) { - Some(i) => i, - None => continue, - }; - let ins = blocks[bi].instructions[last_idx]; - let mut target = ins.target; - if target == BlockIdx::NULL { - continue; - } - if !(ins.instr.is_unconditional_jump() - || include_conditional && is_conditional_jump(&ins.instr)) - { - continue; - } - target = next_nonempty_block(blocks, target); - if target == BlockIdx::NULL { - continue; - } - if include_conditional && is_conditional_jump(&ins.instr) { - let next = next_nonempty_block(blocks, blocks[bi].next); - let next_is_scope_exit = next != BlockIdx::NULL - && blocks[next.idx()] - .instructions - .last() - .is_some_and(|instr| instr.instr.is_scope_exit()); - if next_is_scope_exit { - let target_pos = block_order.get(target.idx()).copied().unwrap_or(u32::MAX); - let target_first_jump = blocks[target.idx()].instructions.first().copied(); - let threads_match_success_jump_to_forward_nointerrupt = - matches!(ins.instr.real(), Some(Instruction::PopJumpIfNone { .. 
})) - && target_first_jump - .filter(|target_ins| target_ins.instr.is_unconditional_jump()) - .filter(|target_ins| target_ins.target != BlockIdx::NULL) - .is_some_and(|target_ins| { - let final_target_pos = block_order - .get(target_ins.target.idx()) - .copied() - .unwrap_or(u32::MAX); - jump_thread_kind(target_ins.instr) - == Some(JumpThreadKind::NoInterrupt) - && target_ins.match_success_jump - && final_target_pos > target_pos - }); - let next_raises = blocks[next.idx()].instructions.iter().any(|instr| { - matches!(instr.instr.real(), Some(Instruction::RaiseVarargs { .. })) - }); - let target_is_loop_backedge = blocks[target.idx()] - .instructions - .first() - .filter(|target_ins| target_ins.instr.is_unconditional_jump()) - .map(|target_ins| next_nonempty_block(blocks, target_ins.target)) - .is_some_and(|final_target| { - final_target == BlockIdx(bi as u32) - || comes_before(blocks, final_target, BlockIdx(bi as u32)) - }); - if !(threads_match_success_jump_to_forward_nointerrupt - || block_is_protected(&blocks[bi]) - && next_raises - && target_is_loop_backedge) - { - continue; - } - } - } - if include_conditional - && is_conditional_jump(&ins.instr) - && opposite_short_circuit_target(&blocks[target.idx()], ins.instr) - { - let final_target = next_nonempty_block(blocks, blocks[target.idx()].next); - if final_target != BlockIdx::NULL && ins.target != final_target { - blocks[bi].instructions[last_idx].target = final_target; - changed = true; - continue; - } - } - if include_conditional - && is_conditional_jump(&ins.instr) - && let Some(final_target) = - same_short_circuit_target(&blocks[target.idx()], ins.instr) - && final_target != BlockIdx::NULL - && ins.target != final_target - { - blocks[bi].instructions[last_idx].target = final_target; - changed = true; - continue; - } - if include_conditional && is_conditional_jump(&ins.instr) { - let source_pos = block_order[bi]; - let target_pos = block_order.get(target.idx()).copied().unwrap_or(u32::MAX); - if target_pos <= 
source_pos { - continue; - } - } - // Match CPython's early flowgraph jump threading: inspect the - // target block's first instruction only. A later unconditional-only - // cleanup pass may thread through line-anchor NOPs introduced after - // jump normalization. - let target_jump = if include_conditional { - blocks[target.idx()].instructions.first().copied() - } else { - blocks[target.idx()] - .instructions - .iter() - .find(|info| !matches!(info.instr.real(), Some(Instruction::Nop))) - .copied() - }; - if let Some(target_ins) = target_jump - && target_ins.instr.is_unconditional_jump() - && target_ins.target != BlockIdx::NULL - && target_ins.target != target - { - if !include_conditional - && blocks[target.idx()] - .instructions - .iter() - .take_while(|info| matches!(info.instr.real(), Some(Instruction::Nop))) - .any(instruction_has_lineno) - { - continue; - } - if !include_conditional && instruction_has_lineno(&target_ins) { - continue; - } - let source_pos = block_order[bi]; - let target_pos = block_order.get(target.idx()).copied().unwrap_or(u32::MAX); - let final_target = target_ins.target; - let final_target_pos = block_order - .get(final_target.idx()) - .copied() - .unwrap_or(u32::MAX); - let conditional = is_conditional_jump(&ins.instr); - if !include_conditional && source_pos < target_pos && final_target_pos < target_pos - { - // Keep the forward hop when threading would turn it into a - // backward edge. CPython preserves this shape for chained - // compare loop exits to avoid wraparound-style jumps. - continue; - } - if !include_conditional - && matches!( - jump_thread_kind(ins.instr), - Some(JumpThreadKind::NoInterrupt) - ) - && matches!( - jump_thread_kind(target_ins.instr), - Some(JumpThreadKind::Plain) - ) - { - // CPython does not late-thread WITH suppress exits through - // the line-anchored continue/break jump that follows. 
- continue; - } - let Some(threaded_instr) = (if conditional { - match jump_thread_kind(target_ins.instr) { - Some(JumpThreadKind::Plain) => Some(ins.instr), - Some(JumpThreadKind::NoInterrupt) - if target_ins.match_success_jump - && can_thread_conditional_through_forward_nointerrupt( - ins.instr, - target_pos, - final_target_pos, - ) => - { - // A forward JUMP_NO_INTERRUPT assembles to the same - // JUMP_FORWARD opcode as CPython's plain match - // success jump. Limit this to None-check match - // success tests; boolean finally-cleanup jumps need - // the stronger no-interrupt shape for later CFG - // cleanup decisions. - Some(ins.instr) - } - _ => None, - } - } else { - threaded_jump_instr(ins.instr, target_ins.instr, false) - }) else { - continue; - }; - if ins.target == final_target { - continue; - } - set_to_nop(&mut blocks[bi].instructions[last_idx]); - let mut threaded = ins; - threaded.instr = threaded_instr; - threaded.arg = OpArg::new(0); - threaded.target = final_target; - threaded.location = target_ins.location; - threaded.end_location = target_ins.end_location; - threaded.cache_entries = 0; - blocks[bi].instructions.push(threaded); - changed = true; - } - } - if include_conditional { - break; - } + max_stackdepth, + instructions: CodeUnits::from(assembled.instructions), + locations, + constants: constants.into_iter().collect(), + names: name_cache.into_iter().collect(), + varnames: varname_cache.into_iter().collect(), + cellvars: localsplusinfo.cellvars, + freevars: freevar_cache.into_iter().collect(), + localspluskinds: localsplusinfo.kinds, + linetable: assembled.linetable, + exceptiontable: assembled.exceptiontable, + }) } } -fn is_conditional_jump(instr: &AnyInstruction) -> bool { - matches!( - instr.real().map(Into::into), - Some( - Opcode::PopJumpIfFalse - | Opcode::PopJumpIfTrue - | Opcode::PopJumpIfNone - | Opcode::PopJumpIfNotNone - ) - ) -} - -fn is_false_path_conditional_jump(instr: &AnyInstruction) -> bool { - matches!( - 
instr.real().map(Into::into), - Some(Opcode::PopJumpIfFalse | Opcode::PopJumpIfNone | Opcode::PopJumpIfNotNone) - ) -} - -/// Invert a conditional jump opcode. -fn reversed_conditional(instr: &AnyInstruction) -> Option { - Some(match AnyOpcode::from(*instr).real()? { - Opcode::PopJumpIfFalse => Opcode::PopJumpIfTrue.into(), - Opcode::PopJumpIfTrue => Opcode::PopJumpIfFalse.into(), - Opcode::PopJumpIfNone => Opcode::PopJumpIfNotNone.into(), - Opcode::PopJumpIfNotNone => Opcode::PopJumpIfNone.into(), - _ => return None, - }) +/// flowgraph.c IS_GENERATOR +fn is_generator(flags: CodeFlags) -> bool { + flags.intersects(CodeFlags::GENERATOR | CodeFlags::COROUTINE | CodeFlags::ASYNC_GENERATOR) } -/// flowgraph.c normalize_jumps -fn normalize_jumps(blocks: &mut Vec) { - let mut visit_order = Vec::new(); - let mut visited = vec![false; blocks.len()]; - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - visit_order.push(current); - visited[current.idx()] = true; - current = blocks[current.idx()].next; - } - - visited.fill(false); - - for &block_idx in &visit_order { - let idx = block_idx.idx(); - visited[idx] = true; - - // Normalize conditional jumps: forward gets NOT_TAKEN, backward gets inverted - let last = blocks[idx].instructions.last(); - if let Some(last_ins) = last - && is_conditional_jump(&last_ins.instr) - && last_ins.target != BlockIdx::NULL - { - let target = last_ins.target; - let is_forward = !visited[target.idx()]; - - if is_forward { - // Insert NOT_TAKEN after forward conditional jump - let not_taken = InstructionInfo { - instr: Opcode::NotTaken.into(), - arg: OpArg::new(0), +/// flowgraph.c insert_prefix_instructions +fn insert_prefix_instructions( + metadata: &CodeUnitMetadata, + blocks: &mut [Block], + cellfixedoffsets: &[i32], + nfreevars: usize, + flags: CodeFlags, +) -> crate::InternalResult<()> { + debug_assert!(!blocks.is_empty()); + let entry = &mut blocks[0]; + let ncellvars = metadata.cellvars.len(); + let firstlineno = 
metadata.firstlineno; + debug_assert!(firstlineno.get() > 0); + + if is_generator(flags) { + let location = SourceLocation { + line: firstlineno, + character_offset: OneIndexed::MIN, + }; + basicblock_insert_instruction( + entry, + 0, + InstructionInfo { + instr: Instruction::ReturnGenerator.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location, + end_location: location, + except_handler: None, + lineno_override: Some(LINE_ONLY_LOCATION_OVERRIDE), + }, + )?; + basicblock_insert_instruction( + entry, + 1, + InstructionInfo { + instr: Instruction::PopTop.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location, + end_location: location, + except_handler: None, + lineno_override: Some(LINE_ONLY_LOCATION_OVERRIDE), + }, + )?; + } + + if ncellvars > 0 { + let nvars = metadata.varnames.len() + ncellvars; + let mut sorted = Vec::new(); + vec_try_reserve_exact(&mut sorted, nvars)?; + sorted.resize(nvars, 0i32); + for i in 0..ncellvars { + sorted[cellfixedoffsets[i] as usize] = i as i32 + 1; + } + let mut ncellsused = 0; + let mut i = 0; + while ncellsused < ncellvars { + let oldindex = sorted[i] - 1; + i += 1; + if oldindex == -1 { + continue; + } + basicblock_insert_instruction( + entry, + ncellsused, + InstructionInfo { + instr: Instruction::MakeCell { i: Arg::marker() }.into(), + arg: OpArg::new(oldindex as u32), target: BlockIdx::NULL, - location: last_ins.location, - end_location: last_ins.end_location, - except_handler: last_ins.except_handler, - folded_from_nonliteral_expr: false, - lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - }; - blocks[idx].instructions.push(not_taken); - } else { - // Backward conditional jump: invert and create new block - // Transform: `cond_jump T` (backward) - // Into: `reversed_cond_jump b_next` + new block 
[NOT_TAKEN, JUMP T] - let loc = last_ins.location; - let end_loc = last_ins.end_location; - let exc_handler = last_ins.except_handler; - - if let Some(reversed) = reversed_conditional(&last_ins.instr) { - let old_next = blocks[idx].next; - let is_cold = blocks[idx].cold; - let disable_load_fast_borrow = blocks[idx].disable_load_fast_borrow; - - // Create new block with NOT_TAKEN + JUMP to original backward target - let new_block_idx = BlockIdx(blocks.len() as u32); - let mut new_block = Block { - cold: is_cold, - disable_load_fast_borrow, - ..Block::default() - }; - new_block.instructions.push(InstructionInfo { - instr: Opcode::NotTaken.into(), - arg: OpArg::new(0), - target: BlockIdx::NULL, - location: loc, - end_location: end_loc, - except_handler: exc_handler, - folded_from_nonliteral_expr: false, - lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - }); - new_block.instructions.push(InstructionInfo { - instr: PseudoOpcode::Jump.into(), - arg: OpArg::new(0), - target, - location: loc, - end_location: end_loc, - except_handler: exc_handler, - folded_from_nonliteral_expr: false, - lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - }); - new_block.next = old_next; + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: Some(NO_LOCATION_OVERRIDE), + }, + )?; + ncellsused += 1; + } + } - // Update the conditional jump: invert opcode, target = old next block - let last_mut = blocks[idx].instructions.last_mut().unwrap(); - last_mut.instr = reversed; - last_mut.target = old_next; + if nfreevars > 0 { + 
basicblock_insert_instruction( + entry, + 0, + InstructionInfo { + instr: Instruction::CopyFreeVars { n: Arg::marker() }.into(), + arg: OpArg::new(nfreevars as u32), + target: BlockIdx::NULL, + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: Some(NO_LOCATION_OVERRIDE), + }, + )?; + } + Ok(()) +} - // Splice new block between current and old next - blocks[idx].next = new_block_idx; - blocks.push(new_block); +/// flowgraph.c prepare_localsplus +fn prepare_localsplus( + metadata: &CodeUnitMetadata, + blocks: &mut [Block], + flags: CodeFlags, +) -> crate::InternalResult { + let nlocals = metadata.varnames.len(); + let ncellvars = metadata.cellvars.len(); + let nfreevars = metadata.freevars.len(); + let int_max = i32::MAX as usize; + debug_assert!(nlocals < int_max); + debug_assert!(ncellvars < int_max); + debug_assert!(nfreevars < int_max); + debug_assert!(int_max - nlocals - ncellvars > 0); + debug_assert!(int_max - nlocals - ncellvars - nfreevars > 0); + let mut nlocalsplus = nlocals + ncellvars + nfreevars; + let mut cellfixedoffsets = build_cellfixedoffsets(metadata)?; + + // This must be called before fix_cell_offsets(). 
+ insert_prefix_instructions(metadata, blocks, &cellfixedoffsets, nfreevars, flags)?; + + let numdropped = fix_cell_offsets(metadata, blocks, &mut cellfixedoffsets); + nlocalsplus -= numdropped; + Ok(nlocalsplus) +} + +/// flowgraph.c remove_unreachable +fn remove_unreachable(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + blocks[block_idx.idx()].predecessors = 0; + block_idx = blocks[block_idx.idx()].next; + } + + let mut stack = make_cfg_traversal_stack(blocks)?; + blocks[0].predecessors = 1; + stack.push(BlockIdx(0)); + blocks[0].visited = true; + while let Some(current) = stack.pop() { + let idx = current.idx(); + let next = blocks[idx].next; + if next != BlockIdx::NULL && bb_has_fallthrough(&blocks[idx]) { + if !blocks[next.idx()].visited { + debug_assert_eq!(blocks[next.idx()].predecessors, 0); + stack.push(next); + blocks[next.idx()].visited = true; + } + blocks[next.idx()].predecessors += 1; + } - // Extend visited array and update visit order - visited.push(true); + let instr_count = blocks[idx].instruction_used; + for i in 0..instr_count { + let instr = blocks[idx].instructions[i]; + if is_jump(&instr) || is_block_push(&instr) { + let target = instr.target; + debug_assert!(target != BlockIdx::NULL); + let target_idx = target.idx(); + if !blocks[target_idx].visited { + stack.push(target); + blocks[target_idx].visited = true; } + blocks[target_idx].predecessors += 1; } } } - // Rebuild visit_order since backward normalization may have added new blocks - let mut visit_order = Vec::new(); - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - visit_order.push(current); - current = blocks[current.idx()].next; + block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let i = block_idx.idx(); + let next = blocks[i].next; + if blocks[i].predecessors == 0 { + let block = &mut blocks[i]; + basicblock_clear(block); + block.except_handler = false; + } + block_idx 
= next; } + Ok(()) +} - // Resolve JUMP/JUMP_NO_INTERRUPT pseudo instructions before offset fixpoint. - let mut block_order = vec![0u32; blocks.len()]; - for (pos, &block_idx) in visit_order.iter().enumerate() { - block_order[block_idx.idx()] = pos as u32; +/// flowgraph.c eval_const_unaryop +fn eval_const_unaryop( + operand: &ConstantData, + op: Instruction, + intrinsic: Option, +) -> Option { + match (operand, op, intrinsic) { + (ConstantData::Integer { value }, Instruction::UnaryNegative, None) => { + Some(ConstantData::Integer { value: -value }) + } + (ConstantData::Float { value }, Instruction::UnaryNegative, None) => { + Some(ConstantData::Float { value: -value }) + } + (ConstantData::Complex { value }, Instruction::UnaryNegative, None) => { + Some(ConstantData::Complex { value: -value }) + } + (ConstantData::Boolean { value }, Instruction::UnaryNegative, None) => { + Some(ConstantData::Integer { + value: BigInt::from(-i32::from(*value)), + }) + } + (ConstantData::Integer { value }, Instruction::UnaryInvert, None) => { + Some(ConstantData::Integer { value: !value }) + } + (ConstantData::Boolean { .. }, Instruction::UnaryInvert, None) => None, + (_, Instruction::UnaryNot, None) => Some(ConstantData::Boolean { + value: !constant_truthiness(operand), + }), + ( + ConstantData::Integer { value }, + Instruction::CallIntrinsic1 { .. }, + Some(oparg::IntrinsicFunction1::UnaryPositive), + ) => Some(ConstantData::Integer { + value: value.clone(), + }), + ( + ConstantData::Float { value }, + Instruction::CallIntrinsic1 { .. }, + Some(oparg::IntrinsicFunction1::UnaryPositive), + ) => Some(ConstantData::Float { value: *value }), + ( + ConstantData::Boolean { value }, + Instruction::CallIntrinsic1 { .. }, + Some(oparg::IntrinsicFunction1::UnaryPositive), + ) => Some(ConstantData::Integer { + value: BigInt::from(i32::from(*value)), + }), + ( + ConstantData::Complex { value }, + Instruction::CallIntrinsic1 { .. 
}, + Some(oparg::IntrinsicFunction1::UnaryPositive), + ) => Some(ConstantData::Complex { value: *value }), + _ => None, } +} - for &block_idx in &visit_order { - let source_pos = block_order[block_idx.idx()]; - for info in &mut blocks[block_idx.idx()].instructions { - let target = info.target; - if target == BlockIdx::NULL { - continue; - } - let target_pos = block_order[target.idx()]; - info.instr = match info.instr.into() { - AnyOpcode::Pseudo(PseudoOpcode::Jump) => { - if target_pos > source_pos { - Opcode::JumpForward.into() - } else { - Opcode::JumpBackward.into() - } - } - AnyOpcode::Pseudo(PseudoOpcode::JumpNoInterrupt) => { - if target_pos > source_pos { - Opcode::JumpForward.into() - } else { - Opcode::JumpBackwardNoInterrupt.into() - } - } - _ => info.instr, - }; +fn constant_truthiness(constant: &ConstantData) -> bool { + match constant { + ConstantData::Tuple { elements } | ConstantData::Frozenset { elements } => { + !elements.is_empty() } + ConstantData::Integer { value } => !value.is_zero(), + ConstantData::Float { value } => *value != 0.0, + ConstantData::Complex { value } => value.re != 0.0 || value.im != 0.0, + ConstantData::Boolean { value } => *value, + ConstantData::Str { value } => !value.is_empty(), + ConstantData::Bytes { value } => !value.is_empty(), + ConstantData::Code { .. } | ConstantData::Slice { .. 
} | ConstantData::Ellipsis => true, + ConstantData::None => false, } } -/// flowgraph.c inline_small_or_no_lineno_blocks -fn inline_small_or_no_lineno_blocks(blocks: &mut [Block]) { - const MAX_COPY_SIZE: usize = 4; +fn load_const_truthiness( + instr: Instruction, + arg: OpArg, + metadata: &CodeUnitMetadata, +) -> Option { + match instr { + Instruction::LoadConst { consti } => { + let constant = &metadata.consts[consti.get(arg).as_usize()]; + Some(constant_truthiness(constant)) + } + Instruction::LoadSmallInt { i } => Some(i.get(arg) != 0), + _ => None, + } +} - let block_exits_scope = |block: &Block| { - block - .instructions - .last() - .is_some_and(|ins| ins.instr.is_scope_exit()) - }; - let block_has_no_lineno = |block: &Block| { - block - .instructions - .iter() - .all(|ins| !instruction_has_lineno(ins)) - }; - let target_pushes_handler = |block: &Block| { - block - .instructions - .iter() - .any(|ins| ins.instr.is_block_push()) - }; - let block_ends_with_list_to_tuple_jump = |block: &Block| { - let ops: Vec<_> = block - .instructions - .iter() - .filter(|info| !matches!(info.instr.real(), Some(Instruction::Nop))) - .collect(); - let Some((last, prefix)) = ops.split_last() else { - return false; - }; - if !last.instr.is_unconditional_jump() { - return false; +/// flowgraph.c add_const +fn add_const( + metadata: &mut CodeUnitMetadata, + constant: ConstantData, +) -> crate::InternalResult { + Ok(metadata.consts.try_insert_full(constant)?.0) +} + +fn instr_make_load_const( + metadata: &mut CodeUnitMetadata, + instr: &mut InstructionInfo, + constant: ConstantData, +) -> crate::InternalResult<()> { + if maybe_instr_make_load_smallint(instr, &constant) { + return Ok(()); + } + + let const_idx = add_const(metadata, constant)?; + instr_set_op1( + instr, + Instruction::LoadConst { + consti: Arg::marker(), } - let Some(prev) = prefix.last() else { - return false; - }; - match prev.instr.real() { - Some(Instruction::CallIntrinsic1 { func }) => { - func.get(prev.arg) == 
IntrinsicFunction1::ListToTuple - } - _ => false, + .into(), + OpArg::new(const_idx as u32), + ); + Ok(()) +} + +/// flowgraph.c fold_const_unaryop +fn fold_const_unaryop( + metadata: &mut CodeUnitMetadata, + block: &mut Block, + i: usize, +) -> crate::InternalResult { + let instr = &block.instructions[i]; + let (op, intrinsic) = match instr.instr.real() { + Some(Instruction::UnaryNegative) => (Instruction::UnaryNegative, None), + Some(Instruction::UnaryInvert) => (Instruction::UnaryInvert, None), + Some(Instruction::UnaryNot) => (Instruction::UnaryNot, None), + Some(Instruction::CallIntrinsic1 { func }) + if matches!( + func.get(instr.arg), + oparg::IntrinsicFunction1::UnaryPositive + ) => + { + ( + Instruction::CallIntrinsic1 { + func: Arg::marker(), + }, + Some(func.get(instr.arg)), + ) } + _ => return Ok(false), }; - let block_starts_with_store_and_exits = |block: &Block| { - block.instructions.first().is_some_and(|info| { - matches!( - info.instr.real(), - Some( - Instruction::StoreFast { .. } - | Instruction::StoreGlobal { .. } - | Instruction::StoreName { .. } - | Instruction::StoreDeref { .. } - ) - ) - }) && block_exits_scope(block) + let Some(operand_index) = (if let Some(start) = i.checked_sub(1) { + get_const_loading_instrs(block, start, 1)? + } else { + None + }) + .and_then(|indices| indices.into_iter().next()) else { + return Ok(false); }; - let block_is_simple_fast_return = |block: &Block| { - matches!( - block.instructions.as_slice(), - [load, ret] - if matches!( - load.instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. 
}) - ) && matches!(ret.instr.real(), Some(Instruction::ReturnValue)) - ) + let operand = get_const_value(metadata, &block.instructions[operand_index]); + let Some(operand) = operand else { + return Ok(false); }; - let normal_layout_fallthrough_into = |blocks: &[Block], target: BlockIdx| { - let mut current = BlockIdx(0); - let mut previous_nonempty = BlockIdx::NULL; - while current != BlockIdx::NULL && current != target { - if !blocks[current.idx()].instructions.is_empty() { - previous_nonempty = current; - } - current = blocks[current.idx()].next; - } - previous_nonempty != BlockIdx::NULL - && block_has_fallthrough(&blocks[previous_nonempty.idx()]) - && next_nonempty_block(blocks, blocks[previous_nonempty.idx()].next) == target + let Some(folded_const) = eval_const_unaryop(&operand, op, intrinsic) else { + return Ok(false); }; + nop_out(block, &[operand_index]); + instr_make_load_const(metadata, &mut block.instructions[i], folded_const)?; + Ok(true) +} + +/// flowgraph.c get_const_loading_instrs +fn get_const_loading_instrs( + block: &Block, + mut start: usize, + size: usize, +) -> crate::InternalResult>> { + let mut indices = Vec::new(); + indices + .try_reserve_exact(size) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; loop { - let mut changes = false; - let mut predecessors = vec![0usize; blocks.len()]; - for block in blocks.iter() { - if block.next != BlockIdx::NULL { - predecessors[block.next.idx()] += 1; - } - for info in &block.instructions { - if info.target != BlockIdx::NULL { - predecessors[info.target.idx()] += 1; - } - } + if start >= block.instruction_used { + return Ok(None); } - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let next = blocks[current.idx()].next; - let Some(last) = blocks[current.idx()].instructions.last().copied() else { - current = next; - continue; - }; - if !last.instr.is_unconditional_jump() || last.target == BlockIdx::NULL { - current = next; - continue; + let instr = 
&block.instructions[start]; + if !matches!(instr.instr.real(), Some(Instruction::Nop)) { + if !loads_const(instr) { + return Ok(None); } - - let target = last.target; - if is_named_except_cleanup_normal_exit_block(&blocks[current.idx()]) - && target_pushes_handler(&blocks[target.idx()]) - && !named_except_cleanup_body_is_fast_local_only(&blocks[current.idx()]) - { - current = next; - continue; - } - let small_exit_block = block_exits_scope(&blocks[target.idx()]) - && blocks[target.idx()].instructions.len() <= MAX_COPY_SIZE; - let no_lineno_no_fallthrough = block_has_no_lineno(&blocks[target.idx()]) - && !block_has_fallthrough(&blocks[target.idx()]); - let shared_artificial_expr_exit = small_exit_block - && predecessors[target.idx()] > 1 - && is_artificial_expr_stmt_exit_block(&blocks[target.idx()]) - && !instruction_has_lineno(&blocks[target.idx()].instructions[0]) - && !instruction_has_lineno(&blocks[target.idx()].instructions[1]) - && !instruction_has_lineno(&blocks[target.idx()].instructions[2]); - let shared_tuple_genexpr_assignment_tail = small_exit_block - && predecessors[target.idx()] > 1 - && block_ends_with_list_to_tuple_jump(&blocks[current.idx()]) - && block_starts_with_store_and_exits(&blocks[target.idx()]); - if !shared_artificial_expr_exit - && !shared_tuple_genexpr_assignment_tail - && (small_exit_block || no_lineno_no_fallthrough) - { - let removed_jump_kind = jump_thread_kind(last.instr); - let preserve_removed_jump_nop = last.preserve_redundant_jump_as_nop; - let keep_removed_jump_nop = removed_jump_kind == Some(JumpThreadKind::NoInterrupt) - || blocks[current.idx()] - .instructions - .last() - .is_some_and(instruction_has_lineno); - if keep_removed_jump_nop { - let preserve_empty_end_label_nop = removed_jump_kind - == Some(JumpThreadKind::NoInterrupt) - && small_exit_block - && blocks[current.idx()].instructions.len() == 1 - && block_is_simple_fast_return(&blocks[target.idx()]) - && !normal_layout_fallthrough_into(blocks, current); - if let 
Some(last_instr) = blocks[current.idx()].instructions.last_mut() { - let lineno_override = last_instr.lineno_override; - set_to_nop(last_instr); - last_instr.lineno_override = lineno_override; - last_instr.preserve_block_start_no_location_nop |= - preserve_removed_jump_nop; - if preserve_empty_end_label_nop { - last_instr.lineno_override = None; - last_instr.preserve_block_start_no_location_nop = true; - } - } - } else { - let _ = blocks[current.idx()].instructions.pop(); - } - blocks[current.idx()] - .instructions - .extend(blocks[target.idx()].instructions.clone()); - if no_lineno_no_fallthrough - && removed_jump_kind == Some(JumpThreadKind::Plain) - && let Some(last) = blocks[current.idx()].instructions.last_mut() - && jump_thread_kind(last.instr) == Some(JumpThreadKind::NoInterrupt) - { - last.instr = match last.instr.into() { - AnyOpcode::Pseudo(PseudoOpcode::JumpNoInterrupt) => { - PseudoOpcode::Jump.into() - } - AnyOpcode::Real(Opcode::JumpBackwardNoInterrupt) => { - Opcode::JumpBackward.into() - } - _ => last.instr, - }; - } - changes = true; + indices.push(start); + if indices.len() == size { + break; } - - current = next; } + let Some(prev) = start.checked_sub(1) else { + return Ok(None); + }; + start = prev; + } + indices.reverse(); + Ok(Some(indices)) +} - if !changes { - break; - } +/// flowgraph.c nop_out +fn nop_out(block: &mut Block, instrs: &[usize]) { + for &i in instrs { + nop_out_no_location(&mut block.instructions[i]); } } -fn is_artificial_expr_stmt_exit_block(block: &Block) -> bool { - matches!( - block.instructions.as_slice(), - [ - InstructionInfo { - instr: AnyInstruction::Real(Instruction::PopTop), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::ReturnValue), - .. 
- } - ] - ) +/// flowgraph.c fold_const_binop +fn fold_const_binop( + metadata: &mut CodeUnitMetadata, + block: &mut Block, + i: usize, +) -> crate::InternalResult { + use oparg::BinaryOperator as BinOp; + + let Some(Instruction::BinaryOp { .. }) = block.instructions[i].instr.real() else { + return Ok(false); + }; + let Some(operand_indices) = (if let Some(start) = i.checked_sub(1) { + get_const_loading_instrs(block, start, 2)? + } else { + None + }) else { + return Ok(false); + }; + let op_raw = u32::from(block.instructions[i].arg); + let Ok(op) = BinOp::try_from(op_raw) else { + return Ok(false); + }; + let left = get_const_value(metadata, &block.instructions[operand_indices[0]]); + let right = get_const_value(metadata, &block.instructions[operand_indices[1]]); + let (Some(left_val), Some(right_val)) = (left, right) else { + return Ok(false); + }; + let Some(result_const) = eval_const_binop(&left_val, &right_val, op) else { + return Ok(false); + }; + nop_out(block, &operand_indices); + instr_make_load_const(metadata, &mut block.instructions[i], result_const)?; + Ok(true) } -fn inline_single_predecessor_artificial_expr_exit_blocks(blocks: &mut [Block]) { - let predecessors = compute_predecessors(blocks); +/// flowgraph.c loads_const +fn loads_const(info: &InstructionInfo) -> bool { + info.instr.has_const() || matches!(info.instr.real(), Some(Instruction::LoadSmallInt { .. })) +} - for idx in 0..blocks.len() { - let Some(last) = blocks[idx].instructions.last().copied() else { - continue; - }; - if !last.instr.is_unconditional_jump() || last.target == BlockIdx::NULL { - continue; +/// flowgraph.c get_const_value +fn get_const_value(metadata: &CodeUnitMetadata, info: &InstructionInfo) -> Option { + match info.instr.real() { + Some(Instruction::LoadSmallInt { .. 
}) => { + let v = u32::from(info.arg) as i32; + Some(ConstantData::Integer { + value: BigInt::from(v), + }) } - - let target = next_nonempty_block(blocks, last.target); - if target == BlockIdx::NULL - || predecessors[target.idx()] != 1 - || !is_artificial_expr_stmt_exit_block(&blocks[target.idx()]) - { - continue; + _ if info.instr.has_const() => { + let idx = u32::from(info.arg) as usize; + metadata.consts.get_index(idx).cloned() } + _ => None, + } +} - let is_jump_wrapper = blocks[idx] - .instructions - .split_last() - .is_some_and(|(_, prefix)| { - prefix - .iter() - .all(|ins| matches!(ins.instr.real(), Some(Instruction::Nop))) - }); - if is_jump_wrapper { - continue; +/// flowgraph.c const_folding_check_complexity +fn const_folding_check_complexity(obj: &ConstantData, mut limit: isize) -> Option { + if let ConstantData::Tuple { elements } = obj { + limit -= isize::try_from(elements.len()).ok()?; + if limit < 0 { + return None; } - - if blocks[idx] - .instructions - .last() - .is_some_and(instruction_has_lineno) - { - if let Some(last_instr) = blocks[idx].instructions.last_mut() { - set_to_nop(last_instr); - } - } else { - let _ = blocks[idx].instructions.pop(); + for element in elements { + limit = const_folding_check_complexity(element, limit)?; } - blocks[idx] - .instructions - .extend(blocks[target.idx()].instructions.clone()); } + Some(limit) } -struct TargetPredecessorFlags { - targeted: Vec, - plain_jump: Vec, +fn repeat_wtf8(value: &Wtf8Buf, n: usize) -> Option { + let mut result = Wtf8Buf::new(); + result.try_reserve_exact(value.len().checked_mul(n)?).ok()?; + for _ in 0..n { + result.push_wtf8(value); + } + Some(result) } -fn compute_target_predecessor_flags(blocks: &[Block]) -> TargetPredecessorFlags { - let mut targeted = vec![false; blocks.len()]; - let mut plain_jump = vec![false; blocks.len()]; - for block in blocks { - for instr in &block.instructions { - if instr.target == BlockIdx::NULL { - continue; - } - let target = 
next_nonempty_block(blocks, instr.target); - if target == BlockIdx::NULL { - continue; - } - let idx = target.idx(); - targeted[idx] = true; - if matches!(jump_thread_kind(instr.instr), Some(JumpThreadKind::Plain)) { - plain_jump[idx] = true; - } - } - } - TargetPredecessorFlags { - targeted, - plain_jump, +fn checked_repeat_count(n: &BigInt, item_size: usize) -> Option { + let n = n.to_isize()?; + if item_size != 0 && (n < 0 || n as usize > MAX_STR_SIZE / item_size) { + return None; } + Some(n.max(0) as usize) } -fn remove_redundant_nops_in_blocks(blocks: &mut [Block]) -> usize { - let mut changes = 0; - let plain_jump_targets = compute_target_predecessor_flags(blocks).plain_jump; - let layout_predecessors = compute_layout_predecessors(blocks); - let mut block_order = Vec::new(); - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - block_order.push(current); - current = blocks[current.idx()].next; - } - for block_idx in block_order { - let bi = block_idx.idx(); - let keep_target_start_nop = - keep_target_start_no_location_nop(blocks, block_idx, &layout_predecessors); - let follows_same_line_pop_iter = - layout_predecessor_ends_with_pop_iter_on_line(blocks, block_idx, &layout_predecessors); - let mut src_instructions = core::mem::take(&mut blocks[bi].instructions); - let mut kept = Vec::with_capacity(src_instructions.len()); - let mut prev_lineno = -1i32; - - for src in 0..src_instructions.len() { - let instr = src_instructions[src]; - let lineno = instruction_lineno(&instr); - let mut remove = false; - - if matches!(instr.instr.real(), Some(Instruction::Nop)) { - if instr.no_location_exit && instr.preserve_redundant_jump_as_nop { - remove = false; - } else if src == 0 - && lineno > 0 - && ((!keep_target_start_nop && follows_same_line_pop_iter == Some(lineno)) - || (instr.preserve_block_start_no_location_nop - && block_tail_starts_with_async_with_normal_exit( - &src_instructions[src + 1..], - ))) - { - remove = true; - } else if 
instr.preserve_redundant_jump_as_nop - || instr.preserve_block_start_no_location_nop - { - remove = false; - } else if lineno < 0 { - remove = true; - } else if instr.remove_no_location_nop - && src == 0 - && plain_jump_targets[block_idx.idx()] - && instr.lineno_override.is_some() - && !keep_target_start_nop - { - let next_lineno = src_instructions[src + 1..].iter().find_map(|next_instr| { - let line = instruction_lineno(next_instr); - if matches!(next_instr.instr.real(), Some(Instruction::Nop)) && line < 0 { - None - } else { - Some(line) - } - }); - if next_lineno.is_some_and(|next_lineno| lineno < next_lineno) { - remove = true; - } - } else if prev_lineno == lineno { - remove = true; - } else if src < src_instructions.len() - 1 { - if src_instructions[src + 1].instr.is_block_push() { - remove = false; - } else if src_instructions[src + 1].instr.is_unconditional_jump() - && src_instructions[src + 1].target != block_idx - { - let next_lineno = instruction_lineno(&src_instructions[src + 1]); - if next_lineno == lineno || next_lineno < 0 { - src_instructions[src + 1].lineno_override = Some(lineno); - remove = true; - } - } else if src_instructions[src + 1].folded_from_nonliteral_expr { - remove = true; - } else { - let next_lineno = instruction_lineno(&src_instructions[src + 1]); - if next_lineno == lineno { - remove = true; - } else if next_lineno < 0 { - src_instructions[src + 1].lineno_override = Some(lineno); - remove = true; - } - } - } else { - let next = next_nonempty_block(blocks, blocks[bi].next); - if next != BlockIdx::NULL { - let mut next_info = None; - for (next_idx, next_instr) in - blocks[next.idx()].instructions.iter().enumerate() - { - let line = instruction_lineno(next_instr); - if matches!(next_instr.instr.real(), Some(Instruction::Nop)) && line < 0 - { - continue; - } - next_info = Some((next_idx, line)); - break; - } - if let Some((next_idx, next_lineno)) = next_info { - if next_lineno == lineno { - remove = true; - } else if next_lineno < 0 { 
- blocks[next.idx()].instructions[next_idx].lineno_override = - Some(lineno); - remove = true; - } - } - } +/// flowgraph.c const_folding_safe_multiply +fn const_folding_safe_multiply(left: &ConstantData, right: &ConstantData) -> Option { + match (left, right) { + (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) => { + if !l.is_zero() && !r.is_zero() && l.bits() + r.bits() > MAX_INT_SIZE { + return None; + } + Some(ConstantData::Integer { value: l * r }) + } + (ConstantData::Float { value: l }, ConstantData::Float { value: r }) => { + Some(ConstantData::Float { value: l * r }) + } + (ConstantData::Str { value: s }, ConstantData::Integer { value: n }) => { + let n = checked_repeat_count(n, s.code_points().count())?; + Some(ConstantData::Str { + value: repeat_wtf8(s, n)?, + }) + } + (ConstantData::Integer { .. }, ConstantData::Str { .. }) => { + const_folding_safe_multiply(right, left) + } + (ConstantData::Bytes { value: b }, ConstantData::Integer { value: n }) => { + let n = checked_repeat_count(n, b.len())?; + let mut value = Vec::new(); + value.try_reserve_exact(b.len().checked_mul(n)?).ok()?; + for _ in 0..n { + value.extend_from_slice(b); + } + Some(ConstantData::Bytes { value }) + } + (ConstantData::Integer { .. }, ConstantData::Bytes { .. }) => { + const_folding_safe_multiply(right, left) + } + (ConstantData::Tuple { elements }, ConstantData::Integer { value: n }) => { + let n = n.to_usize()?; + if n != 0 && !elements.is_empty() { + if n > MAX_COLLECTION_SIZE / elements.len() { + return None; } + const_folding_check_complexity( + &ConstantData::Tuple { + elements: elements.clone(), + }, + MAX_TOTAL_ITEMS / isize::try_from(n).ok()?, + )?; } - - if remove { - changes += 1; - } else { - kept.push(instr); - prev_lineno = lineno; + let mut result = Vec::new(); + result + .try_reserve_exact(elements.len().checked_mul(n)?) 
+ .ok()?; + for _ in 0..n { + result.extend(elements.iter().cloned()); } + Some(ConstantData::Tuple { elements: result }) } - - blocks[bi].instructions = kept; + (ConstantData::Integer { .. }, ConstantData::Tuple { .. }) => { + const_folding_safe_multiply(right, left) + } + _ => None, } - - changes } -fn remove_redundant_jumps_in_blocks(blocks: &mut [Block]) -> usize { - let mut changes = 0; - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = next_nonempty_block(blocks, blocks[idx].next); - if next != BlockIdx::NULL { - let Some(last_instr) = blocks[idx].instructions.last().copied() else { - current = blocks[idx].next; - continue; - }; - if last_instr.instr.is_unconditional_jump() - && last_instr.target != BlockIdx::NULL - && next_nonempty_block(blocks, last_instr.target) == next - { - let preserve_redundant_jump_nop = if last_instr.preserve_redundant_jump_as_nop { - let line = instruction_lineno(&last_instr); - let next_line = blocks[next.idx()].instructions.iter().find_map(|instr| { - let line = instruction_lineno(instr); - (!matches!(instr.instr.real(), Some(Instruction::Nop)) || line >= 0) - .then_some(line) - }); - line < 0 - || line > 0 - && !block_jump_follows_async_send_pop(&blocks[idx]) - && !(block_jump_follows_with_normal_exit(&blocks[idx]) - && block_tail_starts_with_async_with_normal_exit( - &blocks[next.idx()].instructions, - )) - && next_line.is_some_and(|next_line| next_line < line) +/// flowgraph.c const_folding_safe_power +fn const_folding_safe_power(left: &ConstantData, right: &ConstantData) -> Option { + match (left, right) { + (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) => { + if r < &BigInt::from(0) { + if l.is_zero() { + return None; + } + let base = l.to_f64()?; + if !base.is_finite() { + return None; + } + let result = if let Some(exp) = r.to_i32() { + base.powi(exp) } else { - false + base.powf(r.to_f64()?) 
}; - let last_instr = blocks[idx].instructions.last_mut().unwrap(); - let remove_no_location_nop = last_instr.remove_no_location_nop; - let folded_operand_nop = last_instr.folded_operand_nop; - let preserve_block_start_no_location_nop = - last_instr.preserve_block_start_no_location_nop; - set_to_nop(last_instr); - last_instr.preserve_redundant_jump_as_nop = preserve_redundant_jump_nop; - last_instr.remove_no_location_nop = remove_no_location_nop; - last_instr.folded_operand_nop = folded_operand_nop; - last_instr.preserve_block_start_no_location_nop = - preserve_block_start_no_location_nop || preserve_redundant_jump_nop; - changes += 1; - current = blocks[idx].next; - continue; + if !result.is_finite() { + return None; + } + return Some(ConstantData::Float { value: result }); + } + let exp: u64 = r.try_into().ok()?; + let exp_usize = usize::try_from(exp).ok()?; + if !l.is_zero() && exp > 0 && l.bits() > MAX_INT_SIZE / exp { + return None; } + Some(ConstantData::Integer { + value: num_traits::pow::pow(l.clone(), exp_usize), + }) } - current = blocks[idx].next; + (ConstantData::Float { value: l }, ConstantData::Float { value: r }) => { + let result = l.powf(*r); + result + .is_finite() + .then_some(ConstantData::Float { value: result }) + } + _ => None, } - changes } - -fn remove_redundant_nops_and_jumps(blocks: &mut [Block]) { - loop { - let removed_nops = remove_redundant_nops_in_blocks(blocks); - let removed_jumps = remove_redundant_jumps_in_blocks(blocks); - if removed_nops + removed_jumps == 0 { - break; - } + +/// flowgraph.c const_folding_safe_lshift +fn const_folding_safe_lshift(left: &ConstantData, right: &ConstantData) -> Option { + let (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) = (left, right) + else { + return None; + }; + let shift: u64 = r.try_into().ok()?; + let shift_usize = usize::try_from(shift).ok()?; + if shift > MAX_INT_SIZE || (!l.is_zero() && l.bits() > MAX_INT_SIZE - shift) { + return None; } + 
Some(ConstantData::Integer { + value: l << shift_usize, + }) } -fn redirect_empty_block_targets(blocks: &mut [Block]) { - let redirected_targets: Vec> = blocks - .iter() - .map(|block| { - block - .instructions - .iter() - .map(|instr| { - if instr.target == BlockIdx::NULL { - BlockIdx::NULL - } else { - next_nonempty_block(blocks, instr.target) - } - }) - .collect() - }) - .collect(); +/// flowgraph.c const_folding_safe_mod +fn const_folding_safe_mod(left: &ConstantData, right: &ConstantData) -> Option { + if matches!(left, ConstantData::Str { .. } | ConstantData::Bytes { .. }) { + return None; + } - for (block, block_targets) in blocks.iter_mut().zip(redirected_targets) { - for (instr, target) in block.instructions.iter_mut().zip(block_targets) { - if target != BlockIdx::NULL { - instr.target = target; + match (left, right) { + (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) => { + if r.is_zero() { + return None; } + let rem = l.clone() % r.clone(); + let value = if !rem.is_zero() && (rem < BigInt::from(0)) != (*r < BigInt::from(0)) { + rem + r + } else { + rem + }; + Some(ConstantData::Integer { value }) } + (ConstantData::Float { value: l }, ConstantData::Float { value: r }) => { + let (_, modulo) = float_div_mod(*l, *r)?; + Some(ConstantData::Float { value: modulo }) + } + _ => None, } } -fn redirect_empty_unconditional_jump_targets(blocks: &mut [Block]) { - const MAX_COPY_SIZE: usize = 4; +fn float_div_mod(left: f64, right: f64) -> Option<(f64, f64)> { + if right == 0.0 { + return None; + } - let block_exits_to_large_reraise = |block_idx: BlockIdx| { - let block = &blocks[block_idx.idx()]; - let Some(last) = block.instructions.last() else { - return false; - }; - let reraise_block = if matches!(last.instr.real(), Some(Instruction::Reraise { .. 
})) { - block_idx - } else if last.instr.is_unconditional_jump() && last.target != BlockIdx::NULL { - next_nonempty_block(blocks, last.target) + let mut modulo = left % right; + let div = (left - modulo) / right; + let floordiv = if modulo != 0.0 { + let div = if (right < 0.0) != (modulo < 0.0) { + modulo += right; + div - 1.0 } else { - BlockIdx::NULL + div }; - reraise_block != BlockIdx::NULL - && blocks[reraise_block.idx()].instructions.len() > MAX_COPY_SIZE - && blocks[reraise_block.idx()] - .instructions - .last() - .is_some_and(|instr| { - matches!(instr.instr.real(), Some(Instruction::Reraise { .. })) - }) + let mut floordiv = div.floor(); + if div - floordiv > 0.5 { + floordiv += 1.0; + } + floordiv + } else { + modulo = 0.0f64.copysign(right); + 0.0f64.copysign(left / right) }; - let mut raw_predecessors = vec![0u32; blocks.len()]; - for block in blocks.iter() { - if block_has_fallthrough(block) && block.next != BlockIdx::NULL { - raw_predecessors[block.next.idx()] += 1; + Some((floordiv, modulo)) +} + +/// flowgraph.c eval_const_binop complex result construction +fn eval_const_complex_const(value: Complex) -> Option { + (value.re.is_finite() && value.im.is_finite()).then_some(ConstantData::Complex { value }) +} + +/// flowgraph.c eval_const_binop complex operations +fn eval_const_complex_binop( + left: Complex, + right: Complex, + op: oparg::BinaryOperator, +) -> Option { + use oparg::BinaryOperator as BinOp; + + let value = match op { + BinOp::Add => left + right, + BinOp::Subtract => { + let re = left.re - right.re; + // Preserve CPython's signed-zero behavior for real-zero + // minus zero-complex expressions such as `0 - 0j`. 
+ let im = if left.re == 0.0 + && left.im == 0.0 + && right.re == 0.0 + && right.im == 0.0 + && !right.im.is_sign_negative() + { + -0.0 + } else { + left.im - right.im + }; + Complex::new(re, im) } - for instr in &block.instructions { - if instr.target != BlockIdx::NULL { - raw_predecessors[instr.target.idx()] += 1; + BinOp::Multiply => left * right, + BinOp::TrueDivide => { + if right == Complex::new(0.0, 0.0) { + return None; } + left / right } - } + BinOp::Power => { + if left == Complex::new(0.0, 0.0) { + if right.im != 0.0 || right.re < 0.0 { + return None; + } - let redirected_targets: Vec> = blocks - .iter() - .map(|block| { - block - .instructions - .iter() - .map(|instr| { - if instr.target == BlockIdx::NULL || !instr.instr.is_unconditional_jump() { - instr.target - } else { - if blocks[instr.target.idx()].instructions.is_empty() - && raw_predecessors[instr.target.idx()] > 1 - && { - let target = next_nonempty_block(blocks, instr.target); - target != BlockIdx::NULL && block_exits_to_large_reraise(target) - } - { - return instr.target; - } - let target = next_nonempty_block(blocks, instr.target); - if matches!( - jump_thread_kind(instr.instr), - Some(JumpThreadKind::NoInterrupt) - ) && (target == BlockIdx::NULL - || is_jump_back_only_block(blocks, target)) - { - instr.target - } else { - target - } - } - }) - .collect() - }) - .collect(); + return eval_const_complex_const(if right.re == 0.0 { + Complex::new(1.0, 0.0) + } else { + Complex::new(0.0, 0.0) + }); + } - for (block, block_targets) in blocks.iter_mut().zip(redirected_targets) { - for (instr, target) in block.instructions.iter_mut().zip(block_targets) { - if target != BlockIdx::NULL { - instr.target = target; + if right.im == 0.0 + && right.re.fract() == 0.0 + && right.re >= f64::from(i32::MIN) + && right.re <= f64::from(i32::MAX) + { + left.powi(right.re as i32) + } else { + left.powc(right) } } - } + _ => return None, + }; + eval_const_complex_const(value) } -fn 
materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { - fn block_starts_with_with_normal_exit(block: &Block) -> bool { - matches!( - block.instructions.as_slice(), - [ - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::Call { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::PopTop), - .. - }, - .. - ] - ) +/// flowgraph.c eval_const_binop subscript index conversion +fn constant_as_index(value: &ConstantData) -> Option { + match value { + ConstantData::Integer { value } => value.to_i64().or_else(|| { + if value < &BigInt::from(0) { + Some(i64::MIN) + } else { + Some(i64::MAX) + } + }), + ConstantData::Boolean { value } => Some(i64::from(*value)), + _ => None, } +} - fn with_normal_exit_is_followed_by_try(blocks: &[Block], block_idx: BlockIdx) -> bool { - if block_idx == BlockIdx::NULL - || !block_starts_with_with_normal_exit(&blocks[block_idx.idx()]) - { - return false; - } - let next = next_nonempty_block(blocks, blocks[block_idx.idx()].next); - next != BlockIdx::NULL - && blocks[next.idx()].instructions.first().is_some_and(|info| { - matches!( - info.instr, - AnyInstruction::Pseudo(PseudoInstruction::SetupFinally { .. 
}) - | AnyInstruction::Real(Instruction::Nop) - ) - }) +/// flowgraph.c eval_const_binop subscript slice bound conversion +fn slice_bound(value: &ConstantData) -> Option> { + match value { + ConstantData::None => Some(None), + _ => constant_as_index(value).map(Some), } +} - fn has_loop_backedge_to(blocks: &[Block], target: BlockIdx) -> bool { - blocks.iter().enumerate().any(|(source_idx, block)| { - let source = BlockIdx(source_idx as u32); - comes_before(blocks, target, source) - && block.instructions.iter().any(|info| { - info.instr.is_unconditional_jump() - && info.target != BlockIdx::NULL - && next_nonempty_block(blocks, info.target) == target - }) - }) +/// flowgraph.c eval_const_binop subscript slice index adjustment +fn adjusted_slice_indices(len: usize, slice: &[ConstantData; 3]) -> Option> { + let len = i64::try_from(len).ok()?; + let start = slice_bound(&slice[0])?; + let stop = slice_bound(&slice[1])?; + let step = slice_bound(&slice[2])?.unwrap_or(1); + if step == 0 || step == i64::MIN { + return None; } - let mut jump_back_inserts = Vec::new(); - let mut inserts = Vec::new(); - let mut target_start_inserts = Vec::new(); - let mut jump_back_target_locations = Vec::new(); - for (block_idx, block) in blocks.iter().enumerate() { - let source = BlockIdx(block_idx as u32); - let (last, allow_scope_exit_target) = if let Some(last) = block - .instructions - .last() - .filter(|info| is_conditional_jump(&info.instr)) - { - (last, true) - } else if let Some(cond_idx) = trailing_conditional_jump_index(block) { - (&block.instructions[cond_idx], false) - } else { - continue; - }; - if last.target == BlockIdx::NULL { - continue; - } - let target = last.target; - if !blocks[target.idx()].instructions.is_empty() { - if is_jump_back_only_block(blocks, target) && block_has_no_lineno(&blocks[target.idx()]) - { - jump_back_target_locations.push((*last, target)); - } - if with_normal_exit_is_followed_by_try(blocks, target) - && has_loop_backedge_to(blocks, source) - && 
!matches!( - blocks[target.idx()] - .instructions - .first() - .and_then(|info| info.instr.real()), - Some(Instruction::Nop) - ) - { - target_start_inserts.push((*last, target)); + let step_is_negative = step < 0; + let lower = if step_is_negative { -1 } else { 0 }; + let upper = if step_is_negative { len - 1 } else { len }; + let adjust = |value: Option, default: i64| { + let mut value = value.unwrap_or(default); + if value < 0 { + value = value.saturating_add(len); + if value < 0 { + value = lower; } - continue; - } - let next = next_nonempty_block(blocks, blocks[target.idx()].next); - if next != BlockIdx::NULL - && is_jump_only_block(&blocks[next.idx()]) - && block_has_no_lineno(&blocks[next.idx()]) - && comes_before( - blocks, - next_nonempty_block(blocks, blocks[next.idx()].instructions[0].target), - next, - ) - { - jump_back_inserts.push((BlockIdx(block_idx as u32), target, next)); - continue; + } else if value >= len { + value = upper; } - if next == BlockIdx::NULL - || !((allow_scope_exit_target && is_scope_exit_block(&blocks[next.idx()])) - || (with_normal_exit_is_followed_by_try(blocks, next) - && has_loop_backedge_to(blocks, source))) - { - continue; + value + }; + let start = adjust(start, if step_is_negative { upper } else { lower }); + let stop = adjust(stop, if step_is_negative { lower } else { upper }); + + let mut index = i128::from(start); + let stop = i128::from(stop); + let step = i128::from(step); + let slice_len = if step > 0 { + if index < stop { + usize::try_from((stop - index - 1) / step + 1).ok()? + } else { + 0 } - inserts.push((*last, target)); - } - - for (source, target) in jump_back_target_locations { - if !is_jump_back_only_block(blocks, target) || !block_has_no_lineno(&blocks[target.idx()]) { - continue; + } else if index > stop { + usize::try_from((index - stop - 1) / -step + 1).ok()? 
+ } else { + 0 + }; + let mut indices = Vec::new(); + indices.try_reserve_exact(slice_len).ok()?; + if step > 0 { + while index < stop { + indices.push(usize::try_from(index).ok()?); + index += step; } - if let Some(first) = blocks[target.idx()].instructions.first_mut() { - overwrite_location(first, source.location, source.end_location); + } else { + while index > stop { + indices.push(usize::try_from(index).ok()?); + index += step; } } + Some(indices) +} - for (source, target, next) in jump_back_inserts { - if !blocks[target.idx()].instructions.is_empty() { - continue; - } - let Some(last) = blocks[source.idx()].instructions.last().copied() else { - continue; - }; - let mut cloned = blocks[next.idx()].instructions[0]; - overwrite_location(&mut cloned, last.location, last.end_location); - blocks[target.idx()].instructions.push(cloned); +/// flowgraph.c eval_const_binop subscript index adjustment +fn adjusted_const_index(len: usize, index: &ConstantData) -> Option { + let len = i64::try_from(len).ok()?; + let index = constant_as_index(index)?; + let index = if index < 0 { + index.saturating_add(len) + } else { + index + }; + if index < 0 || index >= len { + return None; } + usize::try_from(index).ok() +} - for (source, target) in inserts { - if !blocks[target.idx()].instructions.is_empty() { - continue; +/// flowgraph.c eval_const_binop NB_SUBSCR +fn eval_const_subscript(container: &ConstantData, index: &ConstantData) -> Option { + match (container, index) { + ( + ConstantData::Str { value }, + ConstantData::Integer { .. } | ConstantData::Boolean { .. 
}, + ) => { + let string = value.to_string(); + if string.contains(char::REPLACEMENT_CHARACTER) { + return None; + } + let mut chars = Vec::new(); + chars.try_reserve_exact(string.chars().count()).ok()?; + chars.extend(string.chars()); + let index = adjusted_const_index(chars.len(), index)?; + Some(ConstantData::Str { + value: chars[index].to_string().into(), + }) } - blocks[target.idx()].instructions.push(InstructionInfo { - instr: Instruction::Nop.into(), - arg: OpArg::NULL, - target: BlockIdx::NULL, - location: source.location, - end_location: source.end_location, - except_handler: None, - folded_from_nonliteral_expr: false, - lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - }); - } - - for (source, target) in target_start_inserts.into_iter().rev() { - if !with_normal_exit_is_followed_by_try(blocks, target) - || matches!( - blocks[target.idx()] - .instructions - .first() - .and_then(|info| info.instr.real()), - Some(Instruction::Nop) - ) - { - continue; + (ConstantData::Str { value }, ConstantData::Slice { elements }) => { + let string = value.to_string(); + if string.contains(char::REPLACEMENT_CHARACTER) { + return None; + } + let mut chars = Vec::new(); + chars.try_reserve_exact(string.chars().count()).ok()?; + chars.extend(string.chars()); + let indices = adjusted_slice_indices(chars.len(), elements)?; + let capacity = indices.iter().try_fold(0usize, |capacity, &index| { + capacity.checked_add(chars[index].len_utf8()) + })?; + let mut result = String::new(); + result.try_reserve_exact(capacity).ok()?; + for index in indices { + result.push(chars[index]); + } + Some(ConstantData::Str { + value: result.into(), + }) } - blocks[target.idx()].instructions.insert( - 0, - InstructionInfo { - instr: Instruction::Nop.into(), - arg: OpArg::NULL, - target: BlockIdx::NULL, 
- location: source.location, - end_location: source.end_location, - except_handler: None, - folded_from_nonliteral_expr: false, - lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - }, - ); - } -} - -fn merge_unsafe_mask(slot: &mut Option>, incoming: &[bool]) -> bool { - match slot { - Some(existing) => { - let mut changed = false; - for (dst, src) in existing.iter_mut().zip(incoming.iter().copied()) { - if src && !*dst { - *dst = true; - changed = true; - } + ( + ConstantData::Bytes { value }, + ConstantData::Integer { .. } | ConstantData::Boolean { .. }, + ) => { + let index = adjusted_const_index(value.len(), index)?; + Some(ConstantData::Integer { + value: BigInt::from(value[index]), + }) + } + (ConstantData::Bytes { value }, ConstantData::Slice { elements }) => { + let indices = adjusted_slice_indices(value.len(), elements)?; + let mut result = Vec::new(); + result.try_reserve_exact(indices.len()).ok()?; + for index in indices { + result.push(value[index]); } - changed + Some(ConstantData::Bytes { value: result }) } - None => { - *slot = Some(incoming.to_vec()); - true + ( + ConstantData::Tuple { elements }, + ConstantData::Integer { .. } | ConstantData::Boolean { .. }, + ) => { + let index = adjusted_const_index(elements.len(), index)?; + Some(elements[index].clone()) + } + (ConstantData::Tuple { elements }, ConstantData::Slice { elements: slice }) => { + let indices = adjusted_slice_indices(elements.len(), slice)?; + let mut result = Vec::new(); + result.try_reserve_exact(indices.len()).ok()?; + for index in indices { + result.push(elements[index].clone()); + } + Some(ConstantData::Tuple { elements: result }) } + _ => None, } } -/// Follow chain of empty blocks to find first non-empty block. 
-fn next_nonempty_block(blocks: &[Block], mut idx: BlockIdx) -> BlockIdx { - while idx != BlockIdx::NULL - && blocks[idx.idx()].instructions.is_empty() - && blocks[idx.idx()].next != BlockIdx::NULL - { - idx = blocks[idx.idx()].next; +/// flowgraph.c eval_const_binop bool/int coercion +fn constant_as_int(value: &ConstantData) -> Option<(BigInt, bool)> { + match value { + ConstantData::Boolean { value } => Some((BigInt::from(u8::from(*value)), true)), + ConstantData::Integer { value } => Some((value.clone(), false)), + _ => None, } - idx -} - -fn is_load_const_none(instr: &InstructionInfo, metadata: &CodeUnitMetadata) -> bool { - matches!(instr.instr.real(), Some(Instruction::LoadConst { .. })) - && matches!( - metadata.consts.get_index(u32::from(instr.arg) as usize), - Some(ConstantData::None) - ) -} - -fn block_tail_starts_with_async_with_normal_exit(instructions: &[InstructionInfo]) -> bool { - matches!( - instructions, - [ - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::LoadConst { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::Call { .. }), - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::GetAwaitable { .. }), - .. - }, - .. 
- ] - ) -} - -fn instruction_lineno(instr: &InstructionInfo) -> i32 { - instr - .lineno_override - .unwrap_or_else(|| instr.location.line.get() as i32) -} - -fn instruction_has_lineno(instr: &InstructionInfo) -> bool { - instruction_lineno(instr) > 0 -} - -fn propagation_location(instr: &InstructionInfo) -> Option<(SourceLocation, SourceLocation)> { - instruction_has_lineno(instr).then_some((instr.location, instr.end_location)) -} - -fn block_has_fallthrough(block: &Block) -> bool { - block - .instructions - .last() - .is_none_or(|ins| !ins.instr.is_scope_exit() && !ins.instr.is_unconditional_jump()) } -fn is_jump_instruction(instr: &InstructionInfo) -> bool { - instr.instr.is_unconditional_jump() || is_conditional_jump(&instr.instr) -} +/// flowgraph.c eval_const_binop +fn eval_const_binop( + left: &ConstantData, + right: &ConstantData, + op: oparg::BinaryOperator, +) -> Option { + use oparg::BinaryOperator as BinOp; -fn is_exit_without_lineno(blocks: &[Block], block_idx: BlockIdx) -> bool { - let block = &blocks[block_idx.idx()]; - let Some(first) = block.instructions.first() else { - return false; - }; - if instruction_has_lineno(first) || !block_has_no_lineno(block) { - return false; + if matches!(op, BinOp::Subscr) { + return eval_const_subscript(left, right); } - if block - .instructions - .last() - .is_some_and(|last| last.instr.is_scope_exit()) + if let (Some((left_int, left_is_bool)), Some((right_int, right_is_bool))) = + (constant_as_int(left), constant_as_int(right)) + && (left_is_bool || right_is_bool) { - return true; + if left_is_bool && right_is_bool { + match op { + BinOp::And => { + return Some(ConstantData::Boolean { + value: !left_int.is_zero() & !right_int.is_zero(), + }); + } + BinOp::Or => { + return Some(ConstantData::Boolean { + value: !left_int.is_zero() | !right_int.is_zero(), + }); + } + BinOp::Xor => { + return Some(ConstantData::Boolean { + value: !left_int.is_zero() ^ !right_int.is_zero(), + }); + } + _ => {} + } + } + + return 
eval_const_binop( + &ConstantData::Integer { value: left_int }, + &ConstantData::Integer { value: right_int }, + op, + ); } - // CPython duplicates no-lineno exit blocks before propagating locations. - // RustPython's late CFG can inline the following synthetic jump-back block - // into that exit block first, collapsing `POP_EXCEPT; JUMP_BACKWARD` into a - // single block. Treat that merged tail as exit-like only for real loop - // backedges so resolve_line_numbers() can recover CPython's loop cleanup - // structure without duplicating cold exception-handler jumps to normal - // forward continuation code. - let Some((last, prefix)) = block.instructions.split_last() else { - return false; - }; - last.instr.is_unconditional_jump() - && prefix.iter().all(|info| { - matches!( - info.instr.real(), - Some(Instruction::PopExcept | Instruction::Nop) + match (left, right) { + (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) => { + let result = match op { + BinOp::Add => l + r, + BinOp::Subtract => l - r, + BinOp::Multiply => { + return const_folding_safe_multiply(left, right); + } + BinOp::TrueDivide => { + if r.is_zero() { + return None; + } + let l_f = l.to_f64()?; + let r_f = r.to_f64()?; + let result = l_f / r_f; + if !result.is_finite() { + return None; + } + return Some(ConstantData::Float { value: result }); + } + BinOp::FloorDivide => { + if r.is_zero() { + return None; + } + // Python floor division: round towards negative infinity + let (q, rem) = (l.clone() / r.clone(), l.clone() % r.clone()); + if !rem.is_zero() && (rem < BigInt::from(0)) != (*r < BigInt::from(0)) { + q - 1 + } else { + q + } + } + BinOp::Remainder => return const_folding_safe_mod(left, right), + BinOp::Power => return const_folding_safe_power(left, right), + BinOp::Lshift => return const_folding_safe_lshift(left, right), + BinOp::Rshift => { + let shift: u32 = r.try_into().ok()?; + l >> (shift as usize) + } + BinOp::And => l & r, + BinOp::Or => l | r, + BinOp::Xor => l ^ 
r, + _ => return None, + }; + Some(ConstantData::Integer { value: result }) + } + (ConstantData::Float { value: l }, ConstantData::Float { value: r }) => { + let result = match op { + BinOp::Add => l + r, + BinOp::Subtract => l - r, + BinOp::Multiply => return const_folding_safe_multiply(left, right), + BinOp::TrueDivide => { + if *r == 0.0 { + return None; + } + l / r + } + BinOp::FloorDivide => { + let (floordiv, _) = float_div_mod(*l, *r)?; + floordiv + } + BinOp::Remainder => return const_folding_safe_mod(left, right), + BinOp::Power => return const_folding_safe_power(left, right), + _ => return None, + }; + if matches!(op, BinOp::Power) && !result.is_finite() { + return None; + } + Some(ConstantData::Float { value: result }) + } + // Int op Float or Float op Int → Float + (ConstantData::Integer { value: l }, ConstantData::Float { value: r }) => { + let l_f = l.to_f64()?; + eval_const_binop( + &ConstantData::Float { value: l_f }, + &ConstantData::Float { value: *r }, + op, ) - }) - && prefix - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - && has_non_exception_loop_backedge_to(blocks, block_idx, last.target) -} - -fn is_eval_break_without_lineno(blocks: &[Block], block_idx: BlockIdx) -> bool { - let block = &blocks[block_idx.idx()]; - let Some(first) = block.instructions.first() else { - return false; - }; - !instruction_has_lineno(first) && block_has_no_lineno(block) && block_has_eval_break(block) -} - -fn block_has_eval_break(block: &Block) -> bool { - block.instructions.iter().any(|info| { - matches!( - info.instr, - AnyInstruction::Pseudo(PseudoInstruction::Jump { .. }) - | AnyInstruction::Real( - Instruction::Call { .. } - | Instruction::CallFunctionEx - | Instruction::CallKw { .. } - | Instruction::JumpBackward { .. } - | Instruction::Resume { .. 
} - ) - ) - }) + } + (ConstantData::Float { value: l }, ConstantData::Integer { value: r }) => { + let r_f = r.to_f64()?; + eval_const_binop( + &ConstantData::Float { value: *l }, + &ConstantData::Float { value: r_f }, + op, + ) + } + (ConstantData::Integer { value: l }, ConstantData::Complex { value: r }) => { + eval_const_complex_binop(Complex::new(l.to_f64()?, 0.0), *r, op) + } + (ConstantData::Complex { value: l }, ConstantData::Integer { value: r }) => { + eval_const_complex_binop(*l, Complex::new(r.to_f64()?, 0.0), op) + } + (ConstantData::Float { value: l }, ConstantData::Complex { value: r }) => { + eval_const_complex_binop(Complex::new(*l, 0.0), *r, op) + } + (ConstantData::Complex { value: l }, ConstantData::Float { value: r }) => { + eval_const_complex_binop(*l, Complex::new(*r, 0.0), op) + } + (ConstantData::Complex { value: l }, ConstantData::Complex { value: r }) => { + eval_const_complex_binop(*l, *r, op) + } + // String concatenation and repetition + (ConstantData::Str { value: l }, ConstantData::Str { value: r }) + if matches!(op, BinOp::Add) => + { + let mut result = Wtf8Buf::new(); + result + .try_reserve_exact(l.len().checked_add(r.len())?) + .ok()?; + result.push_wtf8(l); + result.push_wtf8(r); + Some(ConstantData::Str { value: result }) + } + (ConstantData::Str { .. }, ConstantData::Integer { .. }) + if matches!(op, BinOp::Multiply) => + { + const_folding_safe_multiply(left, right) + } + (ConstantData::Tuple { elements: l }, ConstantData::Tuple { elements: r }) + if matches!(op, BinOp::Add) => + { + let mut result = Vec::new(); + result + .try_reserve_exact(l.len().checked_add(r.len())?) + .ok()?; + result.extend(l.iter().cloned()); + result.extend(r.iter().cloned()); + Some(ConstantData::Tuple { elements: result }) + } + (ConstantData::Tuple { .. }, ConstantData::Integer { .. }) + if matches!(op, BinOp::Multiply) => + { + const_folding_safe_multiply(left, right) + } + (ConstantData::Integer { .. }, ConstantData::Tuple { .. 
}) + if matches!(op, BinOp::Multiply) => + { + const_folding_safe_multiply(left, right) + } + (ConstantData::Integer { .. }, ConstantData::Str { .. }) + if matches!(op, BinOp::Multiply) => + { + const_folding_safe_multiply(left, right) + } + (ConstantData::Bytes { value: l }, ConstantData::Bytes { value: r }) + if matches!(op, BinOp::Add) => + { + let mut result = Vec::new(); + result + .try_reserve_exact(l.len().checked_add(r.len())?) + .ok()?; + result.extend_from_slice(l); + result.extend_from_slice(r); + Some(ConstantData::Bytes { value: result }) + } + (ConstantData::Bytes { .. }, ConstantData::Integer { .. }) + if matches!(op, BinOp::Multiply) => + { + const_folding_safe_multiply(left, right) + } + (ConstantData::Integer { .. }, ConstantData::Bytes { .. }) + if matches!(op, BinOp::Multiply) => + { + const_folding_safe_multiply(left, right) + } + _ => None, + } } -fn block_has_no_lineno(block: &Block) -> bool { - block - .instructions - .iter() - .all(|ins| !instruction_has_lineno(ins)) -} +/// flowgraph.c fold_tuple_of_constants +fn fold_tuple_of_constants( + metadata: &mut CodeUnitMetadata, + block: &mut Block, + i: usize, +) -> crate::InternalResult { + let Some(Instruction::BuildTuple { .. }) = block.instructions[i].instr.real() else { + return Ok(false); + }; -fn shared_jump_back_target(block: &Block) -> Option { - if !block_has_no_lineno(block) { - return None; + let tuple_size = u32::from(block.instructions[i].arg) as usize; + if tuple_size > STACK_USE_GUIDELINE { + return Ok(false); } - let (last, prefix) = block.instructions.split_last()?; - if !last.instr.is_unconditional_jump() || last.target == BlockIdx::NULL { - return None; - } + let Some(operand_indices) = (if tuple_size == 0 { + Some(Vec::new()) + } else if let Some(start) = i.checked_sub(1) { + get_const_loading_instrs(block, start, tuple_size)? 
+ } else { + None + }) else { + return Ok(false); + }; - if !prefix.iter().all(|info| { - matches!( - info.instr.real(), - Some(Instruction::PopExcept | Instruction::Nop) - ) - }) { - return None; + let mut elements = Vec::new(); + elements + .try_reserve_exact(tuple_size) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + for &j in &operand_indices { + let Some(element) = get_const_value(metadata, &block.instructions[j]) else { + return Ok(false); + }; + elements.push(element); } - Some(last.target) -} - -fn block_has_non_exception_loop_backedge_to( - blocks: &[Block], - source: BlockIdx, - target: BlockIdx, -) -> bool { - let target = next_nonempty_block(blocks, target); - source != BlockIdx::NULL - && target != BlockIdx::NULL - && !block_is_exceptional(&blocks[source.idx()]) - && comes_before(blocks, target, source) - && blocks[source.idx()].instructions.iter().any(|info| { - info.instr.is_unconditional_jump() - && info.target != BlockIdx::NULL - && next_nonempty_block(blocks, info.target) == target - }) -} - -fn has_non_exception_loop_backedge_to( - blocks: &[Block], - cleanup_block: BlockIdx, - target: BlockIdx, -) -> bool { - blocks.iter().enumerate().any(|(source_idx, block)| { - let source = BlockIdx(source_idx as u32); - source != cleanup_block - && !block_is_exceptional(block) - && block_has_non_exception_loop_backedge_to(blocks, source, target) - }) + nop_out(block, &operand_indices); + instr_make_load_const( + metadata, + &mut block.instructions[i], + ConstantData::Tuple { elements }, + )?; + Ok(true) } -fn is_jump_only_block(block: &Block) -> bool { - let [instr] = block.instructions.as_slice() else { - return false; +fn fold_constant_intrinsic_list_to_tuple( + metadata: &mut CodeUnitMetadata, + block: &mut Block, + i: usize, +) -> crate::InternalResult { + let Some(Instruction::CallIntrinsic1 { func }) = block.instructions[i].instr.real() else { + return Ok(false); }; - instr.instr.is_unconditional_jump() && instr.target != BlockIdx::NULL -} 
- -fn is_jump_back_only_block(blocks: &[Block], block_idx: BlockIdx) -> bool { - if block_idx == BlockIdx::NULL || !is_jump_only_block(&blocks[block_idx.idx()]) { - return false; + if func.get(block.instructions[i].arg) != IntrinsicFunction1::ListToTuple { + return Ok(false); } - comes_before( - blocks, - next_nonempty_block(blocks, blocks[block_idx.idx()].instructions[0].target), - block_idx, - ) -} - -fn is_pop_top_jump_block(block: &Block) -> bool { - let mut real_instrs = block - .instructions - .iter() - .filter(|info| !matches!(info.instr.real(), Some(Instruction::Nop))); - let Some(first) = real_instrs.next() else { - return false; - }; - let Some(second) = real_instrs.next() else { - return false; - }; - real_instrs.next().is_none() - && matches!(first.instr.real(), Some(Instruction::PopTop)) - && second.instr.is_unconditional_jump() - && second.target != BlockIdx::NULL -} -fn is_scope_exit_block(block: &Block) -> bool { - block - .instructions - .last() - .is_some_and(|instr| instr.instr.is_scope_exit()) -} + let mut consts_found = 0usize; + let mut expect_append = true; + let mut pos = i; + while let Some(prev) = pos.checked_sub(1) { + pos = prev; + let instr = &block.instructions[pos]; + if matches!(instr.instr.real(), Some(Instruction::Nop)) { + continue; + } -fn is_pop_top_scope_exit_block(block: &Block) -> bool { - is_scope_exit_block(block) - && matches!( - block - .instructions - .first() - .and_then(|info| info.instr.real()), - Some(Instruction::PopTop) - ) -} + if matches!(instr.instr.real(), Some(Instruction::BuildList { .. 
})) + && u32::from(instr.arg) == 0 + { + if !expect_append { + return Ok(false); + } -fn is_pop_top_exit_like_block(block: &Block) -> bool { - is_pop_top_scope_exit_block(block) || is_pop_top_jump_block(block) -} + let mut elements = Vec::new(); + elements + .try_reserve_exact(consts_found) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + for idx in (pos..i).rev() { + if matches!(block.instructions[idx].instr.real(), Some(Instruction::Nop)) { + continue; + } + if loads_const(&block.instructions[idx]) { + let Some(value) = get_const_value(metadata, &block.instructions[idx]) else { + return Ok(false); + }; + elements.push(value); + } + nop_out_no_location(&mut block.instructions[idx]); + } + debug_assert_eq!(elements.len(), consts_found); + elements.reverse(); + instr_make_load_const( + metadata, + &mut block.instructions[i], + ConstantData::Tuple { elements }, + )?; + return Ok(true); + } -fn is_loop_cleanup_block(block: &Block) -> bool { - block - .instructions - .iter() - .find_map(|info| info.instr.real()) - .is_some_and(|instr| { - matches!( - instr, - Instruction::EndFor | Instruction::EndAsyncFor | Instruction::PopIter - ) - }) -} + if expect_append { + if !matches!(instr.instr.real(), Some(Instruction::ListAppend { .. })) + || u32::from(instr.arg) != 1 + { + return Ok(false); + } + } else { + if !loads_const(instr) { + return Ok(false); + } + consts_found += 1; + } + expect_append = !expect_append; + } -fn is_async_loop_cleanup_block(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::EndAsyncFor))) + Ok(false) } -fn is_exception_cleanup_block(block: &Block) -> bool { - block - .instructions - .iter() - .any(|instr| matches!(instr.instr.real(), Some(Instruction::PopExcept))) - && block - .instructions - .last() - .is_some_and(|instr| matches!(instr.instr.real(), Some(Instruction::Reraise { .. }))) -} +/// Port of CPython's flowgraph.c optimize_lists_and_sets(). 
+fn optimize_lists_and_sets( + metadata: &mut CodeUnitMetadata, + block: &mut Block, + i: usize, + nextop: Option, +) -> crate::InternalResult { + let Some(instr) = block.instructions[i].instr.real() else { + return Ok(false); + }; + let is_list = matches!(instr, Instruction::BuildList { .. }); + let is_set = matches!(instr, Instruction::BuildSet { .. }); + if !is_list && !is_set { + return Ok(false); + } -fn is_reraise_scope_exit_block(block: &Block) -> bool { - block - .instructions - .last() - .is_some_and(|instr| matches!(instr.instr.real(), Some(Instruction::Reraise { .. }))) -} + let contains_or_iter = matches!( + nextop, + Some(Instruction::GetIter | Instruction::ContainsOp { .. }) + ); + let seq_size = u32::from(block.instructions[i].arg) as usize; + if seq_size > STACK_USE_GUIDELINE || (seq_size < MIN_CONST_SEQUENCE_SIZE && !contains_or_iter) { + return Ok(false); + } -fn block_starts_with_with_exit_none_call(block: &Block) -> bool { - let real_instrs: Vec<_> = block - .instructions - .iter() - .filter_map(|info| { - let instr = info.instr.real()?; - (!matches!(instr, Instruction::Nop)).then_some(instr) - }) - .take(4) - .collect(); - matches!( - real_instrs.as_slice(), - [ - Instruction::LoadConst { .. }, - Instruction::LoadConst { .. }, - Instruction::LoadConst { .. }, - Instruction::Call { .. }, - ] - ) -} + let Some(operand_indices) = (if seq_size == 0 { + Some(Vec::new()) + } else if let Some(start) = i.checked_sub(1) { + get_const_loading_instrs(block, start, seq_size)? 
+ } else { + None + }) else { + if contains_or_iter && is_list { + let arg = block.instructions[i].arg; + instr_set_op1(&mut block.instructions[i], Opcode::BuildTuple.into(), arg); + return Ok(true); + } + return Ok(false); + }; -fn keep_target_start_no_location_nop( - blocks: &[Block], - target: BlockIdx, - layout_predecessors: &[BlockIdx], -) -> bool { - if target == BlockIdx::NULL { - return false; + let mut elements = Vec::new(); + elements + .try_reserve_exact(seq_size) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + for &j in &operand_indices { + let Some(element) = get_const_value(metadata, &block.instructions[j]) else { + return Ok(false); + }; + elements.push(element); } - let Some(first) = blocks[target.idx()].instructions.first() else { - return false; + + let const_data = if is_list { + ConstantData::Tuple { elements } + } else { + ConstantData::Frozenset { elements } }; - if !matches!(first.instr.real(), Some(Instruction::Nop)) { - return false; - } - let layout_pred = layout_predecessors[target.idx()]; - if layout_pred == BlockIdx::NULL { - return false; - } - if is_async_loop_cleanup_block(&blocks[layout_pred.idx()]) { - return true; - } - is_exception_cleanup_block(&blocks[layout_pred.idx()]) - && !block_starts_with_with_exit_none_call(&blocks[target.idx()]) -} + let const_idx = add_const(metadata, const_data)?; -fn layout_predecessor_ends_with_pop_iter_on_line( - blocks: &[Block], - target: BlockIdx, - layout_predecessors: &[BlockIdx], -) -> Option { - let layout_pred = layout_predecessors[target.idx()]; - if layout_pred == BlockIdx::NULL - || !block_has_fallthrough(&blocks[layout_pred.idx()]) - || next_nonempty_block(blocks, blocks[layout_pred.idx()].next) != target - { - return None; - } - let last = blocks[layout_pred.idx()].instructions.last()?; - matches!(last.instr.real(), Some(Instruction::PopIter)).then_some(instruction_lineno(last)) -} + if !contains_or_iter { + debug_assert!(i >= 2); + let folded_loc = 
block.instructions[i].location; + let end_loc = block.instructions[i].end_location; -fn is_with_suppress_exit_block(block: &Block) -> bool { - let real_instrs: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .collect(); - matches!( - real_instrs.as_slice(), - [ - Instruction::PopTop, - Instruction::PopExcept, - Instruction::PopTop, - Instruction::PopTop, - Instruction::PopTop, - last, - ] if last.is_unconditional_jump() - ) -} + nop_out(block, &operand_indices); -fn block_is_protected(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| info.except_handler.is_some()) -} + let build_instr = if is_list { + Instruction::BuildList { + count: Arg::marker(), + } + .into() + } else { + Instruction::BuildSet { + count: Arg::marker(), + } + .into() + }; + instr_set_op1(&mut block.instructions[i - 2], build_instr, OpArg::new(0)); + block.instructions[i - 2].location = folded_loc; + block.instructions[i - 2].end_location = end_loc; + block.instructions[i - 2].lineno_override = None; + + instr_set_op1( + &mut block.instructions[i - 1], + Instruction::LoadConst { + consti: Arg::marker(), + } + .into(), + OpArg::new(const_idx as u32), + ); -fn block_contains_suspension_point(block: &Block) -> bool { - block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .any(|instr| { - matches!( - instr, - Instruction::YieldValue { .. } - | Instruction::GetAwaitable { .. 
} - | Instruction::GetAnext - | Instruction::EndAsyncFor - ) - }) -} + let extend_instr = if is_list { + Opcode::ListExtend + } else { + Opcode::SetUpdate + }; + instr_set_op1( + &mut block.instructions[i], + extend_instr.into(), + OpArg::new(1), + ); + return Ok(true); + } -fn block_jump_follows_async_send_pop(block: &Block) -> bool { - let mut before_jump = - block - .instructions - .iter() - .rev() - .skip(1) - .filter_map(|info| match info.instr.real() { - Some(Instruction::Nop) => None, - instr => instr, - }); - matches!( - (before_jump.next(), before_jump.next()), - (Some(Instruction::PopTop), Some(Instruction::EndSend)) - ) + nop_out(block, &operand_indices); + + instr_set_op1( + &mut block.instructions[i], + Instruction::LoadConst { + consti: Arg::marker(), + } + .into(), + OpArg::new(const_idx as u32), + ); + Ok(true) } -fn block_jump_follows_with_normal_exit(block: &Block) -> bool { - let mut before_jump = - block - .instructions - .iter() - .rev() - .skip(1) - .filter_map(|info| match info.instr.real() { - Some(Instruction::Nop) => None, - instr => instr, - }); +/// flowgraph.c VISITED +const VISITED: i32 = -1; + +/// flowgraph.c SWAPPABLE +fn is_swappable(instr: &AnyInstruction) -> bool { matches!( - ( - before_jump.next(), - before_jump.next(), - before_jump.next(), - before_jump.next(), - before_jump.next(), - ), - ( - Some(Instruction::PopTop), - Some(Instruction::Call { .. }), - Some(Instruction::LoadConst { .. }), - Some(Instruction::LoadConst { .. }), - Some(Instruction::LoadConst { .. }), - ) + (*instr).into(), + AnyOpcode::Real(Opcode::StoreFast | Opcode::PopTop) + | AnyOpcode::Pseudo(PseudoOpcode::StoreFastMaybeNull) ) } -fn is_stop_iteration_error_handler_block(block: &Block) -> bool { - matches!( - block.instructions.as_slice(), - [ - InstructionInfo { - instr: AnyInstruction::Real(Instruction::CallIntrinsic1 { func }), - arg, - .. - }, - InstructionInfo { - instr: AnyInstruction::Real(Instruction::Reraise { .. }), - .. 
- } - ] if matches!(func.get(*arg), oparg::IntrinsicFunction1::StopIterationError) - ) +/// flowgraph.c STORES_TO +fn stores_to(info: &InstructionInfo) -> i32 { + match info.instr.into() { + AnyOpcode::Real(Opcode::StoreFast) + | AnyOpcode::Pseudo(PseudoOpcode::StoreFastMaybeNull) => u32::from(info.arg) as i32, + _ => -1, + } } -fn block_has_only_stop_iteration_error_handlers(block: &Block, blocks: &[Block]) -> bool { - let mut saw_handler = false; - for info in &block.instructions { - let Some(handler) = info.except_handler else { +/// flowgraph.c next_swappable_instruction +fn next_swappable_instruction(block: &Block, mut i: usize, lineno: i32) -> Option { + loop { + i += 1; + if i >= block.instruction_used { + return None; + } + let info = &block.instructions[i]; + let info_lineno = instruction_lineno(info); + if lineno >= 0 && info_lineno != lineno { + return None; + } + if matches!(info.instr, AnyInstruction::Real(Instruction::Nop)) { continue; - }; - saw_handler = true; - let target = next_nonempty_block(blocks, handler.handler_block); - if target == BlockIdx::NULL || !is_stop_iteration_error_handler_block(&blocks[target.idx()]) - { - return false; } + if is_swappable(&info.instr) { + return Some(i); + } + return None; } - saw_handler } -fn block_has_exception_match_handler(blocks: &[Block], block: &Block) -> bool { - let mut visited = vec![false; blocks.len()]; - let handler_blocks: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.except_handler.map(|handler| handler.handler_block)) - .collect(); - for handler_block in handler_blocks { - let mut cursor = handler_block; - while cursor != BlockIdx::NULL && !visited[cursor.idx()] { - visited[cursor.idx()] = true; - if blocks[cursor.idx()].instructions.iter().any(|info| { - matches!( - info.instr.real(), - Some(Instruction::CheckExcMatch | Instruction::CheckEgMatch) - ) - }) { - return true; +/// flowgraph.c swaptimize +fn swaptimize(block: &mut Block, ix: &mut usize) -> 
crate::InternalResult<()> { + debug_assert!(matches!( + block.instructions[*ix].instr.real(), + Some(Instruction::Swap { .. }) + )); + let mut depth = u32::from(block.instructions[*ix].arg) as usize; + let mut len = 1usize; + let mut more = false; + let limit = block.instruction_used - *ix; + while len < limit { + match block.instructions[*ix + len].instr.real() { + Some(Instruction::Swap { .. }) => { + depth = depth.max(u32::from(block.instructions[*ix + len].arg) as usize); + more = true; + len += 1; } - if blocks[cursor.idx()] - .instructions - .iter() - .any(|info| info.instr.is_scope_exit()) - { - break; + Some(Instruction::Nop) => { + len += 1; } - cursor = blocks[cursor.idx()].next; + _ => break, } } - false -} - -fn block_is_exceptional(block: &Block) -> bool { - block.except_handler || block.preserve_lasti || is_exception_cleanup_block(block) -} - -fn has_exceptional_duplicate_lineno(blocks: &[Block], source: BlockIdx, lineno: i32) -> bool { - blocks.iter().enumerate().any(|(idx, block)| { - BlockIdx(idx as u32) != source - && (block.cold || block_is_exceptional(block) || block_is_protected(block)) - && block - .instructions - .iter() - .any(|info| instruction_lineno(info) == lineno) - }) -} -fn trailing_conditional_jump_index(block: &Block) -> Option { - let last_idx = block.instructions.len().checked_sub(1)?; - if is_conditional_jump(&block.instructions[last_idx].instr) - && block.instructions[last_idx].target != BlockIdx::NULL - { - return Some(last_idx); + if !more { + return Ok(()); } - let cond_idx = last_idx.checked_sub(1)?; - if matches!( - block.instructions[last_idx].instr.real(), - Some(Instruction::NotTaken) - ) && is_conditional_jump(&block.instructions[cond_idx].instr) - && block.instructions[cond_idx].target != BlockIdx::NULL - { - Some(cond_idx) - } else { - None + let mut stack = Vec::new(); + stack + .try_reserve_exact(depth) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + stack.resize(depth, 0); + let mut i = 0; + while i < 
depth { + stack[i] = i as i32; + i += 1; } -} - -fn block_is_pure_conditional_test(block: &Block) -> bool { - let Some(cond_idx) = trailing_conditional_jump_index(block) else { - return false; - }; - block.instructions[..cond_idx].iter().all(|info| { - matches!( - info.instr.real(), - Some( - Instruction::Nop - | Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - | Instruction::LoadDeref { .. } - | Instruction::LoadGlobal { .. } - | Instruction::LoadConst { .. } - | Instruction::LoadSmallInt { .. } - | Instruction::LoadAttr { .. } - | Instruction::BinaryOp { .. } - | Instruction::ContainsOp { .. } - | Instruction::IsOp { .. } - | Instruction::CompareOp { .. } - | Instruction::ToBool - ) - ) - }) -} - -fn reorder_conditional_exit_and_jump_blocks(blocks: &mut [Block]) { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let last = blocks[idx].instructions[cond_idx]; - let Some(reversed) = reversed_conditional(&last.instr) else { - current = next; - continue; - }; + i = 0; + while i < len { + let info = &block.instructions[*ix + i]; + if matches!(info.instr.real(), Some(Instruction::Swap { .. 
})) { + let oparg = u32::from(info.arg) as usize; + stack.swap(0, oparg - 1); + } + i += 1; + } - let exit_start = next; - let jump_start = last.target; - if exit_start == BlockIdx::NULL || jump_start == BlockIdx::NULL || exit_start == jump_start - { - current = next; + let mut current = len as isize - 1; + for i in 0..depth { + if stack[i] == VISITED || stack[i] == i as i32 { continue; } - - let mut exit_end = BlockIdx::NULL; - let mut exit_block = BlockIdx::NULL; - let mut cursor = exit_start; - let mut exit_segment_valid = true; - while cursor != BlockIdx::NULL && cursor != jump_start { - if block_is_exceptional(&blocks[cursor.idx()]) { - exit_segment_valid = false; + let mut j = i; + loop { + if j != 0 { + debug_assert!(current >= 0); + let out = &mut block.instructions[*ix + current as usize]; + out.instr = Opcode::Swap.into(); + out.arg = OpArg::new((j + 1) as u32); + current -= 1; + } + if stack[j] == VISITED { + debug_assert_eq!(j, i); break; } - if !blocks[cursor.idx()].instructions.is_empty() { - if exit_block != BlockIdx::NULL { - exit_segment_valid = false; - break; - } - exit_block = cursor; - } - exit_end = cursor; - cursor = blocks[cursor.idx()].next; - } - if !exit_segment_valid - || cursor != jump_start - || exit_end == BlockIdx::NULL - || exit_block == BlockIdx::NULL - || !is_scope_exit_block(&blocks[exit_block.idx()]) - { - current = next; - continue; + let next_j = stack[j] as usize; + stack[j] = VISITED; + j = next_j; } + } - let mut jump_end = BlockIdx::NULL; - let mut jump_block = BlockIdx::NULL; - cursor = jump_start; - while cursor != BlockIdx::NULL { - if block_is_exceptional(&blocks[cursor.idx()]) - || block_is_protected(&blocks[cursor.idx()]) - { - jump_block = BlockIdx::NULL; - break; - } - jump_end = cursor; - if blocks[cursor.idx()].instructions.is_empty() { - cursor = blocks[cursor.idx()].next; + while current >= 0 { + set_to_nop(&mut block.instructions[*ix + current as usize]); + current -= 1; + } + *ix += len - 1; + Ok(()) +} + 
+/// flowgraph.c apply_static_swaps +fn apply_static_swaps(block: &mut Block, mut i: isize) { + while i >= 0 { + let idx = i as usize; + debug_assert!(idx < block.instruction_used); + let swap_arg = match block.instructions[idx].instr.real() { + Some(Instruction::Swap { .. }) => u32::from(block.instructions[idx].arg), + Some(Instruction::Nop | Instruction::PopTop | Instruction::StoreFast { .. }) => { + i -= 1; continue; } - if is_jump_only_block(&blocks[cursor.idx()]) { - jump_block = cursor; - } - break; - } - if jump_block == BlockIdx::NULL { - current = next; - continue; - } - let jump_instr = blocks[jump_block.idx()].instructions[0]; - let jump_is_forward = matches!( - jump_instr.instr.real(), - Some(Instruction::JumpForward { .. }) - ); - let jump_is_backward = matches!( - jump_instr.instr.real(), - Some(Instruction::JumpBackward { .. }) - ); - let exit_is_reorderable = - blocks[exit_block.idx()] - .instructions - .last() - .is_some_and(|info| { - matches!( - info.instr.real(), - Some(Instruction::ReturnValue | Instruction::RaiseVarargs { .. }) - ) - }); - if !(jump_is_forward || jump_is_backward && exit_is_reorderable) { - current = next; - continue; - } - if jump_is_backward { - if block_is_protected(&blocks[idx]) - || block_is_protected(&blocks[exit_block.idx()]) - || block_is_protected(&blocks[jump_block.idx()]) + _ if matches!( + block.instructions[idx].instr.pseudo(), + Some(PseudoInstruction::StoreFastMaybeNull { .. 
}) + ) => { - current = next; + i -= 1; continue; } - let after_jump_start = blocks[jump_end.idx()].next; - let after_jump = next_nonempty_block(blocks, after_jump_start); - let after_jump_is_end_async_for = after_jump != BlockIdx::NULL - && blocks[after_jump.idx()] - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::EndAsyncFor))); - let after_jump_is_adjacent_scope_exit = after_jump != BlockIdx::NULL - && is_pop_top_exit_like_block(&blocks[after_jump.idx()]); - if after_jump_is_end_async_for || after_jump_is_adjacent_scope_exit { - current = next; - continue; + _ => return, + }; + + let Some(j) = next_swappable_instruction(block, idx, -1) else { + return; + }; + let lineno = instruction_lineno(&block.instructions[j]); + let mut k = j; + for _ in 1..swap_arg { + let Some(next) = next_swappable_instruction(block, k, lineno) else { + return; + }; + k = next; + } + + let store_j = stores_to(&block.instructions[j]); + let store_k = stores_to(&block.instructions[k]); + if store_j >= 0 || store_k >= 0 { + if store_j == store_k { + return; } - let jump_target = jump_instr.target; - let jumps_to_for_iter = blocks[jump_target.idx()] - .instructions - .first() - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. 
}))); - let jump_exits_to_loop_exit = trailing_conditional_jump_index( - &blocks[jump_target.idx()], - ) - .is_some_and(|loop_cond_idx| { - let loop_exit_start = blocks[jump_target.idx()].instructions[loop_cond_idx].target; - loop_exit_start == after_jump_start - || next_nonempty_block(blocks, loop_exit_start) == after_jump - }); - if after_jump != BlockIdx::NULL - && !blocks[after_jump.idx()].cold - && !jumps_to_for_iter - && !jump_exits_to_loop_exit - { - current = next; - continue; + let mut idx = j + 1; + while idx < k { + let store_idx = stores_to(&block.instructions[idx]); + if store_idx >= 0 && (store_idx == store_j || store_idx == store_k) { + return; + } + idx += 1; } } - let after_jump = blocks[jump_end.idx()].next; - blocks[idx].next = jump_start; - blocks[jump_end.idx()].next = exit_start; - blocks[exit_end.idx()].next = after_jump; + set_to_nop(&mut block.instructions[idx]); + block.instructions.swap(j, k); + i -= 1; + } +} - let cond_mut = &mut blocks[idx].instructions[cond_idx]; - cond_mut.instr = reversed; - cond_mut.target = exit_start; +/// flowgraph.c optimize_basic_block swap pass +fn apply_static_swaps_block(block: &mut Block) -> crate::InternalResult<()> { + let mut i = 0; + while i < block.instruction_used { + if matches!( + block.instructions[i].instr.real(), + Some(Instruction::Swap { .. 
}) + ) { + swaptimize(block, &mut i)?; + apply_static_swaps(block, i as isize); + } + i += 1; + } + Ok(()) +} - current = after_jump; +/// flowgraph.c maybe_instr_make_load_smallint +fn maybe_instr_make_load_smallint(instr: &mut InstructionInfo, constant: &ConstantData) -> bool { + if let ConstantData::Integer { value } = constant + && let Some(small) = value.to_i32().filter(|v| (0..=255).contains(v)) + { + instr_set_op1(instr, Opcode::LoadSmallInt.into(), OpArg::new(small as u32)); + return true; } + false } -fn reorder_conditional_jump_and_exit_blocks(blocks: &mut [Block]) { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; +/// flowgraph.c basicblock_optimize_load_const +fn basicblock_optimize_load_const( + metadata: &mut CodeUnitMetadata, + block: &mut Block, +) -> crate::InternalResult<()> { + let mut i = 0; + let mut effective_opcode = None; + let mut effective_oparg = OpArg::new(0); + while i < block.instruction_used { + if matches!( + block.instructions[i].instr.real(), + Some(Instruction::LoadConst { .. }) + ) && let Some(constant) = get_const_value(metadata, &block.instructions[i]) + { + maybe_instr_make_load_smallint(&mut block.instructions[i], &constant); + } + + let curr = block.instructions[i]; + let curr_arg = curr.arg; + + // Only combine if the source is a real instruction. + let Some(curr_instr) = curr.instr.real() else { + i += 1; continue; }; - let last = blocks[idx].instructions[cond_idx]; - let Some(reversed) = reversed_conditional(&last.instr) else { - current = next; + let is_copy_of_load_const = matches!( + (effective_opcode, curr_instr), + (Some(Instruction::LoadConst { .. 
}), Instruction::Copy { i }) if i.get(curr_arg) == 1 + ); + if !is_copy_of_load_const { + effective_opcode = Some(curr_instr); + effective_oparg = curr_arg; + } + let Some(const_instr) = effective_opcode else { + i += 1; continue; }; + let const_arg = effective_oparg; - let jump_start = next; - let exit_start = last.target; - if jump_start == BlockIdx::NULL || exit_start == BlockIdx::NULL || jump_start == exit_start - { - current = next; + if i + 1 >= block.instruction_used { + i += 1; continue; } - let mut jump_end = BlockIdx::NULL; - let mut jump_block = BlockIdx::NULL; - let mut cursor = jump_start; - let mut jump_segment_valid = true; - while cursor != BlockIdx::NULL && cursor != exit_start { - if block_is_exceptional(&blocks[cursor.idx()]) { - jump_segment_valid = false; - break; - } - if !blocks[cursor.idx()].instructions.is_empty() { - if jump_block != BlockIdx::NULL || !is_jump_only_block(&blocks[cursor.idx()]) { - jump_segment_valid = false; - break; + let next = block.instructions[i + 1]; + let next_arg = next.arg; + + if let Some(is_true) = load_const_truthiness(const_instr, const_arg, metadata) { + let const_jump = match (next.instr.real(), next.instr.pseudo()) { + (_, Some(PseudoInstruction::JumpIfTrue { .. })) => Some((true, false)), + (_, Some(PseudoInstruction::JumpIfFalse { .. })) => Some((false, false)), + (Some(Instruction::PopJumpIfTrue { .. }), _) => Some((true, true)), + (Some(Instruction::PopJumpIfFalse { .. 
}), _) => Some((false, true)), + _ => None, + }; + if let Some((jump_if_true, pops_condition)) = const_jump { + if pops_condition { + set_to_nop(&mut block.instructions[i]); + } + if is_true == jump_if_true { + block.instructions[i + 1].instr = PseudoInstruction::Jump { + delta: Arg::marker(), + } + .into(); + } else { + set_to_nop(&mut block.instructions[i + 1]); } - jump_block = cursor; + i += 1; + continue; } - jump_end = cursor; - cursor = blocks[cursor.idx()].next; - } - if !jump_segment_valid || cursor != exit_start || jump_block == BlockIdx::NULL { - current = next; - continue; - } - let jump_instr = blocks[jump_block.idx()].instructions[0]; - if !matches!( - jump_instr.instr.real(), - Some(Instruction::JumpForward { .. }) - ) { - current = next; - continue; } - if jump_instr.lineno_override.is_some_and(|line| line >= 0) - && instruction_lineno(&jump_instr) != instruction_lineno(&last) - { - current = next; + + // The remaining combinations require both instructions to be real. + let Some(next_instr) = next.instr.real() else { + i += 1; continue; - } + }; - let mut exit_end = BlockIdx::NULL; - let mut exit_block = BlockIdx::NULL; - let after_exit = loop { - if cursor == BlockIdx::NULL { - break BlockIdx::NULL; - } - if block_is_exceptional(&blocks[cursor.idx()]) { - if exit_block != BlockIdx::NULL { - break cursor; + if let Instruction::LoadConst { consti } = const_instr { + let constant = &metadata.consts[consti.get(const_arg).as_usize()]; + if matches!(constant, ConstantData::None) + && let Instruction::IsOp { invert } = next_instr + { + let mut jump_idx = i + 2; + if jump_idx >= block.instruction_used { + i += 1; + continue; } - exit_block = BlockIdx::NULL; - break BlockIdx::NULL; - } - if !blocks[cursor.idx()].instructions.is_empty() { - if exit_block != BlockIdx::NULL { - break cursor; + + if matches!( + block.instructions[jump_idx].instr.real(), + Some(Instruction::ToBool) + ) { + set_to_nop(&mut block.instructions[jump_idx]); + jump_idx += 1; + if 
jump_idx >= block.instruction_used { + i += 1; + continue; + } } - if !is_scope_exit_block(&blocks[cursor.idx()]) { - exit_block = BlockIdx::NULL; - break BlockIdx::NULL; + + let Some(jump_instr) = block.instructions[jump_idx].instr.real() else { + i += 1; + continue; + }; + + let mut invert = matches!( + invert.get(next_arg), + rustpython_compiler_core::bytecode::Invert::Yes + ); + match jump_instr { + Instruction::PopJumpIfFalse { .. } => { + invert = !invert; + } + Instruction::PopJumpIfTrue { .. } => {} + _ => { + i += 1; + continue; + } + }; + + set_to_nop(&mut block.instructions[i]); + set_to_nop(&mut block.instructions[i + 1]); + block.instructions[jump_idx].instr = if invert { + Instruction::PopJumpIfNotNone { + delta: Arg::marker(), + } + } else { + Instruction::PopJumpIfNone { + delta: Arg::marker(), + } } - exit_block = cursor; + .into(); + i = jump_idx; + continue; } - exit_end = cursor; - cursor = blocks[cursor.idx()].next; - }; - if exit_block == BlockIdx::NULL || exit_end == BlockIdx::NULL { - current = next; - continue; } - blocks[idx].next = exit_start; - blocks[exit_end.idx()].next = jump_start; - blocks[jump_end.idx()].next = after_exit; - - let cond_mut = &mut blocks[idx].instructions[cond_idx]; - cond_mut.instr = reversed; - cond_mut.target = jump_start; + if matches!( + const_instr, + Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. 
} + ) && matches!(next_instr, Instruction::ToBool) + && let Some(value) = load_const_truthiness(const_instr, const_arg, metadata) + { + let const_idx = add_const(metadata, ConstantData::Boolean { value })?; + set_to_nop(&mut block.instructions[i]); + instr_set_op1( + &mut block.instructions[i + 1], + Instruction::LoadConst { + consti: Arg::marker(), + } + .into(), + OpArg::new(const_idx as u32), + ); + i += 1; + continue; + } - current = after_exit; + i += 1; } + Ok(()) } -fn reorder_conditional_chain_and_jump_back_blocks(blocks: &mut Vec) { - let target_comes_before = |target: BlockIdx, block: BlockIdx, blocks: &[Block]| -> bool { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - if current == target { - return true; - } - if current == block { - return false; - } - current = blocks[current.idx()].next; - } - false - }; - - fn is_single_delete_subscr_body(block: &Block) -> bool { - let real: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .collect(); - real.iter() - .filter(|instr| matches!(instr, Instruction::DeleteSubscr)) - .count() - == 1 - && matches!(real.last(), Some(Instruction::DeleteSubscr)) - && real.iter().all(|instr| { - matches!( - instr, - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - | Instruction::LoadSmallInt { .. } - | Instruction::LoadConst { .. } - | Instruction::Copy { .. } - | Instruction::Swap { .. } - | Instruction::BinaryOp { .. } - | Instruction::BinarySlice - | Instruction::BuildSlice { .. 
} - | Instruction::StoreSubscr - | Instruction::DeleteSubscr - ) - }) +/// flowgraph.c optimize_load_const +fn optimize_load_const( + metadata: &mut CodeUnitMetadata, + blocks: &mut [Block], +) -> crate::InternalResult<()> { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next_block = blocks[block_idx.idx()].next; + let block = &mut blocks[block_idx]; + basicblock_optimize_load_const(metadata, block)?; + block_idx = next_block; } + Ok(()) +} - fn chain_is_conditional_single_delete_body( - blocks: &[Block], - chain_start: BlockIdx, - jump_start: BlockIdx, - ) -> bool { - if chain_start == BlockIdx::NULL || jump_start == BlockIdx::NULL { - return false; - } - let Some(chain_cond_idx) = trailing_conditional_jump_index(&blocks[chain_start.idx()]) - else { - return false; +/// flowgraph.c optimize_basic_block +fn optimize_basic_block( + blocks: &mut [Block], + metadata: &mut CodeUnitMetadata, + block_idx: BlockIdx, +) -> crate::InternalResult<()> { + let bi = block_idx.idx(); + let mut nop = InstructionInfo { + instr: Instruction::Nop.into(), + arg: OpArg::NULL, + target: BlockIdx::NULL, + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: None, + }; + instr_set_op0(&mut nop, Instruction::Nop.into()); + let mut i = 0; + while i < blocks[bi].instruction_used { + let inst = blocks[bi].instructions[i]; + debug_assert!(!inst.instr.is_assembler()); + let target = if inst.instr.has_target() { + let target = inst.target; + debug_assert!(target != BlockIdx::NULL); + debug_assert!(blocks[target.idx()].instruction_used != 0); + debug_assert!(!blocks[target.idx()].instructions[0].instr.is_assembler()); + blocks[target.idx()].instructions[0] + } else { + nop }; - let chain_cond = blocks[chain_start.idx()].instructions[chain_cond_idx]; - if !is_false_path_conditional_jump(&chain_cond.instr) || chain_cond.target != jump_start { - return false; - } - let body = 
next_nonempty_block(blocks, blocks[chain_start.idx()].next); - body != BlockIdx::NULL - && !block_is_exceptional(&blocks[body.idx()]) - && !block_is_protected(&blocks[body.idx()]) - && next_nonempty_block(blocks, blocks[body.idx()].next) == jump_start - && is_single_delete_subscr_body(&blocks[body.idx()]) - } - fn jump_targets_for_iter(blocks: &[Block], jump_block: BlockIdx) -> bool { - if jump_block == BlockIdx::NULL { - return false; - } - let Some(info) = blocks[jump_block.idx()].instructions.first() else { - return false; - }; - let target = next_nonempty_block(blocks, info.target); - target != BlockIdx::NULL - && blocks[target.idx()] - .instructions - .first() - .is_some_and(|target_info| { - matches!(target_info.instr.real(), Some(Instruction::ForIter { .. })) - }) - } - - fn true_body_false_backedge_is_already_normalized( - blocks: &[Block], - conditional: InstructionInfo, - false_backedge: BlockIdx, - true_body: BlockIdx, - ) -> bool { - fn comes_before(blocks: &[Block], target: BlockIdx, block: BlockIdx) -> bool { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - if current == target { - return true; + let nextop = blocks[bi] + .instructions + .get(i + 1) + .and_then(|next| next.instr.real()); + + match inst.instr { + AnyInstruction::Real(Instruction::BuildTuple { .. }) => { + let oparg = u32::from(inst.arg); + if matches!(nextop, Some(Instruction::UnpackSequence { .. })) + && u32::from(blocks[bi].instructions[i + 1].arg) == oparg + { + match oparg { + 1 => { + set_to_nop(&mut blocks[bi].instructions[i]); + set_to_nop(&mut blocks[bi].instructions[i + 1]); + i += 1; + continue; + } + 2 | 3 => { + set_to_nop(&mut blocks[bi].instructions[i]); + blocks[bi].instructions[i + 1].instr = + Instruction::Swap { i: Arg::marker() }.into(); + i += 1; + continue; + } + _ => {} + } } - if current == block { - return false; + fold_tuple_of_constants(metadata, &mut blocks[bi], i)?; + } + AnyInstruction::Real(Instruction::BuildList { .. 
} | Instruction::BuildSet { .. }) => { + optimize_lists_and_sets(metadata, &mut blocks[bi], i, nextop)?; + } + AnyInstruction::Real( + Instruction::PopJumpIfNotNone { .. } | Instruction::PopJumpIfNone { .. }, + ) if matches!(target.instr.into(), AnyOpcode::Pseudo(PseudoOpcode::Jump)) + && jump_thread(blocks, block_idx, i, &target, inst.instr)? => + { + continue; + } + AnyInstruction::Real(Instruction::PopJumpIfFalse { .. }) + if matches!(target.instr.into(), AnyOpcode::Pseudo(PseudoOpcode::Jump)) + && jump_thread(blocks, block_idx, i, &target, inst.instr)? => + { + continue; + } + AnyInstruction::Real(Instruction::PopJumpIfTrue { .. }) + if matches!(target.instr.into(), AnyOpcode::Pseudo(PseudoOpcode::Jump)) + && jump_thread(blocks, block_idx, i, &target, inst.instr)? => + { + continue; + } + AnyInstruction::Pseudo( + pseudo @ (PseudoInstruction::JumpIfFalse { .. } + | PseudoInstruction::JumpIfTrue { .. }), + ) => { + let opcode = pseudo.into(); + match target.instr.pseudo().map(Into::into) { + Some(PseudoOpcode::Jump) + if jump_thread(blocks, block_idx, i, &target, opcode)? => + { + continue; + } + Some(PseudoOpcode::JumpIfFalse) + if matches!( + opcode, + AnyInstruction::Pseudo(PseudoInstruction::JumpIfFalse { .. }) + ) && jump_thread(blocks, block_idx, i, &target, opcode)? => + { + continue; + } + Some(PseudoOpcode::JumpIfTrue) + if matches!( + opcode, + AnyInstruction::Pseudo(PseudoInstruction::JumpIfTrue { .. }) + ) && jump_thread(blocks, block_idx, i, &target, opcode)? => + { + continue; + } + Some(PseudoOpcode::JumpIfFalse | PseudoOpcode::JumpIfTrue) => { + let next = blocks[inst.target.idx()].next; + debug_assert!(next != BlockIdx::NULL); + debug_assert!(next != inst.target); + blocks[bi].instructions[i].target = next; + continue; + } + _ => {} } - current = blocks[current.idx()].next; } - false - } - - if !matches!( - conditional.instr.real(), - Some(Instruction::PopJumpIfTrue { .. 
}) - ) || false_backedge == BlockIdx::NULL - || true_body == BlockIdx::NULL - || !is_jump_only_block(&blocks[false_backedge.idx()]) - { - return false; - } - let false_target = blocks[false_backedge.idx()].instructions[0].target; - if false_target == BlockIdx::NULL || !comes_before(blocks, false_target, false_backedge) { - return false; - } - let true_tail = next_nonempty_block(blocks, blocks[true_body.idx()].next); - true_tail != BlockIdx::NULL - && is_jump_only_block(&blocks[true_tail.idx()]) - && blocks[true_tail.idx()].instructions[0].target == false_target - } - - fn has_other_conditional_predecessor_to( - blocks: &[Block], - conditional_target_counts: &[usize], - target: BlockIdx, - current: BlockIdx, - ) -> bool { - let current_targets = blocks[current.idx()] - .instructions - .iter() - .filter(|info| info.target == target && is_conditional_jump(&info.instr)) - .count(); - conditional_target_counts[target.idx()] > current_targets - } - - let has_exceptional_lineno_sources = blocks - .iter() - .any(|block| block.cold || block_is_exceptional(block) || block_is_protected(block)); - let mut conditional_target_counts = vec![0usize; blocks.len()]; - for block in blocks.iter() { - for info in &block.instructions { - if info.target != BlockIdx::NULL && is_conditional_jump(&info.instr) { - conditional_target_counts[info.target.idx()] += 1; + AnyInstruction::Pseudo( + PseudoInstruction::Jump { .. } | PseudoInstruction::JumpNoInterrupt { .. }, + ) => match target.instr.into() { + AnyOpcode::Pseudo(PseudoOpcode::Jump) + if jump_thread(blocks, block_idx, i, &target, PseudoOpcode::Jump.into())? => + { + continue; + } + AnyOpcode::Pseudo(PseudoOpcode::JumpNoInterrupt) + if jump_thread(blocks, block_idx, i, &target, inst.instr)? => + { + continue; + } + _ => {} + }, + // CPython leaves FOR_ITER jump threading disabled. + AnyInstruction::Real(Instruction::ForIter { .. }) => {} + AnyInstruction::Real(Instruction::StoreFast { .. 
}) + if matches!(nextop, Some(Instruction::StoreFast { .. })) + && u32::from(inst.arg) == u32::from(blocks[bi].instructions[i + 1].arg) + && instruction_lineno(&blocks[bi].instructions[i]) + == instruction_lineno(&blocks[bi].instructions[i + 1]) => + { + blocks[bi].instructions[i].instr = Instruction::PopTop.into(); + blocks[bi].instructions[i].arg = OpArg::NULL; + } + AnyInstruction::Real(Instruction::Swap { .. }) if u32::from(inst.arg) == 1 => { + set_to_nop(&mut blocks[bi].instructions[i]); + } + AnyInstruction::Real(Instruction::LoadGlobal { .. }) + if matches!(nextop, Some(Instruction::PushNull)) + && (u32::from(inst.arg) & 1) == 0 => + { + instr_set_op1( + &mut blocks[bi].instructions[i], + inst.instr, + OpArg::new(u32::from(inst.arg) | 1), + ); + set_to_nop(&mut blocks[bi].instructions[i + 1]); + } + AnyInstruction::Real(Instruction::CompareOp { .. }) + if matches!(nextop, Some(Instruction::ToBool)) => + { + set_to_nop(&mut blocks[bi].instructions[i]); + instr_set_op1( + &mut blocks[bi].instructions[i + 1], + inst.instr, + OpArg::new(u32::from(inst.arg) | oparg::COMPARE_OP_BOOL_MASK), + ); + i += 1; + continue; + } + AnyInstruction::Real(Instruction::ContainsOp { .. } | Instruction::IsOp { .. }) + if matches!(nextop, Some(Instruction::ToBool)) => + { + set_to_nop(&mut blocks[bi].instructions[i]); + instr_set_op1(&mut blocks[bi].instructions[i + 1], inst.instr, inst.arg); + i += 1; + continue; + } + AnyInstruction::Real(Instruction::ContainsOp { .. } | Instruction::IsOp { .. 
}) + if matches!(nextop, Some(Instruction::UnaryNot)) => + { + set_to_nop(&mut blocks[bi].instructions[i]); + let inverted = u32::from(inst.arg) ^ 1; + debug_assert!(inverted == 0 || inverted == 1); + instr_set_op1( + &mut blocks[bi].instructions[i + 1], + inst.instr, + OpArg::new(inverted), + ); + i += 1; + continue; + } + AnyInstruction::Real(Instruction::ToBool) + if matches!(nextop, Some(Instruction::ToBool)) => + { + set_to_nop(&mut blocks[bi].instructions[i]); + i += 1; + continue; + } + AnyInstruction::Real(Instruction::UnaryNot) => { + if matches!(nextop, Some(Instruction::ToBool)) { + set_to_nop(&mut blocks[bi].instructions[i]); + instr_set_op0(&mut blocks[bi].instructions[i + 1], inst.instr); + i += 1; + continue; + } + if matches!(nextop, Some(Instruction::UnaryNot)) { + set_to_nop(&mut blocks[bi].instructions[i]); + set_to_nop(&mut blocks[bi].instructions[i + 1]); + i += 1; + continue; + } + fold_const_unaryop(metadata, &mut blocks[bi], i)?; + } + AnyInstruction::Real(Instruction::UnaryInvert | Instruction::UnaryNegative) => { + fold_const_unaryop(metadata, &mut blocks[bi], i)?; + } + AnyInstruction::Real(Instruction::CallIntrinsic1 { func }) => { + match func.get(inst.arg) { + IntrinsicFunction1::ListToTuple => { + if matches!(nextop, Some(Instruction::GetIter)) { + set_to_nop(&mut blocks[bi].instructions[i]); + } else { + fold_constant_intrinsic_list_to_tuple(metadata, &mut blocks[bi], i)?; + } + } + IntrinsicFunction1::UnaryPositive => { + fold_const_unaryop(metadata, &mut blocks[bi], i)?; + } + _ => {} + } } + AnyInstruction::Real(Instruction::BinaryOp { .. 
}) => { + fold_const_binop(metadata, &mut blocks[bi], i)?; + } + _ => {} } + + i += 1; } + apply_static_swaps_block(&mut blocks[block_idx])?; + Ok(()) +} - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let last = blocks[idx].instructions[cond_idx]; +/// flowgraph.c remove_redundant_nops_and_pairs +#[allow(clippy::if_same_then_else, clippy::useless_let_if_seq)] +#[allow(clippy::unnecessary_wraps)] +fn remove_redundant_nops_and_pairs(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut done = false; - let Some(reversed) = reversed_conditional(&last.instr) else { - current = next; - continue; - }; + while !done { + done = true; + let mut instr: Option<(BlockIdx, usize)> = None; + let mut block_idx = BlockIdx::new(0); - let chain_start = next; - let jump_start = last.target; - if chain_start == BlockIdx::NULL - || jump_start == BlockIdx::NULL - || chain_start == jump_start - { - current = next; - continue; - } - if true_body_false_backedge_is_already_normalized(blocks, last, chain_start, jump_start) { - current = next; - continue; - } - let mut chain_has_suspension_point = false; - let mut scan = chain_start; - while scan != BlockIdx::NULL && scan != jump_start { - if block_contains_suspension_point(&blocks[scan.idx()]) { - chain_has_suspension_point = true; - break; + while block_idx != BlockIdx::NULL { + basicblock_remove_redundant_nops(blocks, block_idx)?; + if is_label(blocks[block_idx.idx()].cpython_label) { + instr = None; } - scan = blocks[scan.idx()].next; - } - let chain_starts_with_false_path_jump = trailing_conditional_jump_index( - &blocks[chain_start.idx()], - ) - .is_some_and(|chain_cond_idx| { - is_false_path_conditional_jump( - &blocks[chain_start.idx()].instructions[chain_cond_idx].instr, - ) - }); - let chain_is_single_exit_block = 
is_scope_exit_block(&blocks[chain_start.idx()]) - && next_nonempty_block(blocks, blocks[chain_start.idx()].next) == jump_start; - let chain_is_jump_only_exit_block = is_jump_only_block(&blocks[chain_start.idx()]) - && !target_comes_before( - blocks[chain_start.idx()].instructions[0].target, - chain_start, - blocks, - ); - let chain_is_jump_back_exit_block = is_jump_only_block(&blocks[chain_start.idx()]) - && target_comes_before( - blocks[chain_start.idx()].instructions[0].target, - chain_start, - blocks, - ); - let allow_true_path_jump_back_reorder = - matches!(last.instr.real(), Some(Instruction::PopJumpIfTrue { .. })) - && (chain_has_suspension_point - || chain_starts_with_false_path_jump - || chain_is_single_exit_block - || chain_is_jump_only_exit_block - || chain_is_jump_back_exit_block); - let is_generic_false_path_reorder = !allow_true_path_jump_back_reorder; - if !is_false_path_conditional_jump(&last.instr) && !allow_true_path_jump_back_reorder { - current = next; - continue; - } - if is_generic_false_path_reorder && is_scope_exit_block(&blocks[chain_start.idx()]) { - current = next; - continue; - } - if has_exceptional_lineno_sources - && is_generic_false_path_reorder - && has_exceptional_duplicate_lineno(blocks, current, instruction_lineno(&last)) - { - current = next; - continue; - } - if is_generic_false_path_reorder - && has_other_conditional_predecessor_to( - blocks, - &conditional_target_counts, - chain_start, - current, - ) - { - current = next; - continue; - } - if block_is_protected(&blocks[idx]) && block_contains_suspension_point(&blocks[idx]) { - current = next; - continue; - } - if let Some(chain_cond_idx) = trailing_conditional_jump_index(&blocks[chain_start.idx()]) { - let chain_cond = blocks[chain_start.idx()].instructions[chain_cond_idx]; - if matches!( - chain_cond.instr.real().map(Into::into), - Some(Opcode::PopJumpIfTrue) - ) { - let chain_true_target = next_nonempty_block(blocks, chain_cond.target); - if chain_true_target != 
BlockIdx::NULL - && !is_scope_exit_block(&blocks[chain_true_target.idx()]) - && !is_jump_only_block(&blocks[chain_true_target.idx()]) - && !is_pop_top_jump_block(&blocks[chain_true_target.idx()]) - { - current = next; - continue; + + let len = blocks[block_idx.idx()].instruction_used; + for instr_idx in 0..len { + let prev_instr = instr; + instr = Some((block_idx, instr_idx)); + let instr_info = blocks[block_idx.idx()].instructions[instr_idx]; + let mut prev_opcode = None; + let mut prev_oparg = 0; + if let Some((prev_block, prev_instr_idx)) = prev_instr { + let prev_info = blocks[prev_block.idx()].instructions[prev_instr_idx]; + prev_opcode = prev_info.instr.real(); + prev_oparg = match prev_info.instr.real() { + Some(Instruction::Copy { i }) => i.get(prev_info.arg), + _ => u32::from(prev_info.arg), + }; + } + let opcode = instr_info.instr.real(); + let mut is_redundant_pair = false; + if matches!(opcode, Some(Instruction::PopTop)) { + if matches!( + prev_opcode, + Some(Instruction::LoadConst { .. } | Instruction::LoadSmallInt { .. }) + ) { + is_redundant_pair = true; + } else if matches!(prev_opcode, Some(Instruction::Copy { .. 
})) + && prev_oparg == 1 + { + is_redundant_pair = true; + } + } + + if is_redundant_pair { + let (prev_block, prev_instr_idx) = + prev_instr.expect("redundant pair has previous"); + set_to_nop(&mut blocks[prev_block.idx()].instructions[prev_instr_idx]); + set_to_nop(&mut blocks[block_idx.idx()].instructions[instr_idx]); + done = false; } } - } - let mut chain_end = BlockIdx::NULL; - let mut saw_nonempty = false; - let mut nonempty_blocks = 0usize; - let mut real_instr_count = 0usize; - let mut cursor = chain_start; - let mut chain_valid = true; - while cursor != BlockIdx::NULL && cursor != jump_start { - if block_is_exceptional(&blocks[cursor.idx()]) - || (block_is_protected(&blocks[cursor.idx()]) - && block_contains_suspension_point(&blocks[cursor.idx()]) - && !block_has_only_stop_iteration_error_handlers(&blocks[cursor.idx()], blocks)) - { - chain_valid = false; - break; + let mut instr_is_jump = false; + if let Some((instr_block, instr_idx)) = instr { + instr_is_jump = is_jump(&blocks[instr_block.idx()].instructions[instr_idx]); } - if !blocks[cursor.idx()].instructions.is_empty() { - saw_nonempty = true; - nonempty_blocks += 1; - real_instr_count += blocks[cursor.idx()] - .instructions - .iter() - .filter(|info| info.instr.real().is_some()) - .count(); + let block = &blocks[block_idx.idx()]; + if instr_is_jump || !bb_has_fallthrough(block) { + instr = None; } - chain_end = cursor; - cursor = blocks[cursor.idx()].next; - } - if !chain_valid || !saw_nonempty || chain_end == BlockIdx::NULL || cursor != jump_start { - current = next; - continue; - } - let chain_is_conditional_single_delete_body = - chain_is_conditional_single_delete_body(blocks, chain_start, jump_start); - if is_generic_false_path_reorder - && nonempty_blocks > 1 - && !chain_is_conditional_single_delete_body - { - current = next; - continue; - } - if !is_generic_false_path_reorder && (nonempty_blocks > 8 || real_instr_count > 80) { - current = next; - continue; + block_idx = block.next; } + } + 
Ok(()) +} - let mut jump_end = BlockIdx::NULL; - let mut jump_block = BlockIdx::NULL; - cursor = jump_start; - while cursor != BlockIdx::NULL { - if block_is_exceptional(&blocks[cursor.idx()]) { - jump_block = BlockIdx::NULL; - break; - } - jump_end = cursor; - if blocks[cursor.idx()].instructions.is_empty() { - cursor = blocks[cursor.idx()].next; - continue; - } - if !is_jump_only_block(&blocks[cursor.idx()]) - || !target_comes_before(blocks[cursor.idx()].instructions[0].target, cursor, blocks) - { - jump_block = BlockIdx::NULL; - } else { - jump_block = cursor; +/// flowgraph.c remove_unused_consts +#[allow(clippy::needless_range_loop)] +fn remove_unused_consts( + blocks: &mut [Block], + consts: &mut ConstantPool, +) -> crate::InternalResult<()> { + let nconsts = consts.len(); + if nconsts == 0 { + return Ok(()); + } + + let mut index_map = Vec::new(); + index_map + .try_reserve_exact(nconsts) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + index_map.resize(nconsts, 0isize); + for i in 1..nconsts { + index_map[i] = -1; + } + // The first constant may be docstring; keep it always. + index_map[0] = 0; + + // Mark used consts. 
+ let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let block = &blocks[block_idx]; + for i in 0..block.instruction_used { + let instr = &block.instructions[i]; + if instr.instr.has_const() { + let index = u32::from(instr.arg) as usize; + debug_assert!(index < nconsts); + index_map[index] = index as isize; } - break; - } - if jump_block == BlockIdx::NULL || jump_end == BlockIdx::NULL { - current = next; - continue; - } - let after_jump = next_nonempty_block(blocks, blocks[jump_block.idx()].next); - let jump_is_artificial = blocks[jump_block.idx()] - .instructions - .first() - .is_some_and(|info| matches!(info.lineno_override, Some(line) if line < 0)); - let after_jump_is_adjacent_scope_exit = - after_jump != BlockIdx::NULL && is_pop_top_exit_like_block(&blocks[after_jump.idx()]); - if !is_generic_false_path_reorder - && chain_is_single_exit_block - && after_jump_is_adjacent_scope_exit - { - current = next; - continue; - } - if is_generic_false_path_reorder - && jump_is_artificial - && after_jump != BlockIdx::NULL - && is_loop_cleanup_block(&blocks[after_jump.idx()]) - && !chain_is_conditional_single_delete_body - { - current = next; - continue; - } - if !is_generic_false_path_reorder - && jump_is_artificial - && jump_targets_for_iter(blocks, jump_block) - { - current = next; - continue; - } - if nonempty_blocks == 1 - && !is_jump_only_block(&blocks[chain_start.idx()]) - && after_jump != BlockIdx::NULL - && !blocks[after_jump.idx()].cold - && !block_is_exceptional(&blocks[after_jump.idx()]) - && !is_scope_exit_block(&blocks[after_jump.idx()]) - && !is_loop_cleanup_block(&blocks[after_jump.idx()]) - { - current = next; - continue; } + block_idx = block.next; + } - let mut cloned_jump = blocks[jump_block.idx()].clone(); - cloned_jump.next = chain_start; - cloned_jump.start_depth = None; - let cloned_idx = BlockIdx::new(blocks.len() as u32); - blocks.push(cloned_jump); - conditional_target_counts.push(0); - blocks[idx].next = cloned_idx; - let 
cond_mut = &mut blocks[idx].instructions[cond_idx]; - if cond_mut.target != BlockIdx::NULL { - conditional_target_counts[cond_mut.target.idx()] -= 1; + // Now index_map[i] == i if consts[i] is used, -1 otherwise. + // Condense consts. + let mut n_used_consts = 0; + for i in 0..nconsts { + if index_map[i] != -1 { + debug_assert_eq!(index_map[i], i as isize); + index_map[n_used_consts] = index_map[i]; + n_used_consts += 1; } - conditional_target_counts[chain_start.idx()] += 1; - cond_mut.instr = reversed; - cond_mut.target = chain_start; + } - current = next; + if n_used_consts == nconsts { + return Ok(()); } -} -fn reorder_conditional_scope_exit_and_jump_back_blocks( - blocks: &mut [Block], - allow_for_iter_jump_targets: bool, - allow_true_scope_exit_reorder: bool, -) { - fn retarget_empty_chain_targets( - blocks: &mut [Block], - chain_start: BlockIdx, - chain_end: BlockIdx, - replacement: BlockIdx, - ) { - if chain_start == BlockIdx::NULL - || chain_end == BlockIdx::NULL - || replacement == BlockIdx::NULL - { - return; + // Move all used consts to the beginning of the consts list. + debug_assert!(n_used_consts < nconsts); + for i in 0..n_used_consts { + let old_index = index_map[i] as usize; + debug_assert!(i <= old_index && old_index < nconsts); + if i != old_index { + let value = consts.constants[old_index].clone(); + consts.constants[i] = value; } + } - let mut in_chain = vec![false; blocks.len()]; - let mut cursor = chain_start; - while cursor != BlockIdx::NULL && cursor != chain_end { - let block = &blocks[cursor.idx()]; - if !block.instructions.is_empty() { - return; + // Truncate the consts list at its new size. + consts.constants.truncate(n_used_consts); + + // Adjust const indices in the bytecode. 
+ let mut reverse_index_map = Vec::new(); + reverse_index_map + .try_reserve_exact(nconsts) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + reverse_index_map.resize(nconsts, 0isize); + for i in 0..nconsts { + reverse_index_map[i] = -1; + } + for i in 0..n_used_consts { + let old_index = index_map[i]; + debug_assert!(old_index != -1); + let old_index = old_index as usize; + debug_assert_eq!(reverse_index_map[old_index], -1); + reverse_index_map[old_index] = i as isize; + } + + block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next_block = blocks[block_idx.idx()].next; + let block = &mut blocks[block_idx]; + for i in 0..block.instruction_used { + let instr = &mut block.instructions[i]; + if instr.instr.has_const() { + let index = u32::from(instr.arg) as usize; + debug_assert!(reverse_index_map[index] >= 0); + debug_assert!(reverse_index_map[index] < n_used_consts as isize); + instr.arg = OpArg::new(reverse_index_map[index] as u32); } - in_chain[cursor.idx()] = true; - cursor = block.next; } - if cursor != chain_end { - return; + block_idx = next_block; + } + Ok(()) +} + +fn optimize_load_fast(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut max_instrs = 0; + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + max_instrs = max_instrs.max(blocks[current.idx()].instruction_used); + current = blocks[current.idx()].next; + } + let mut instr_flags = Vec::new(); + instr_flags + .try_reserve_exact(max_instrs) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + instr_flags.resize(max_instrs, 0u8); + let mut refs = RefStack { + refs: Vec::new(), + size: 0, + capacity: 0, + }; + let mut worklist = make_cfg_traversal_stack(blocks)?; + worklist.push(BlockIdx(0)); + blocks[0].start_depth = 0; + blocks[0].visited = true; + while let Some(block_idx) = worklist.pop() { + let block_i = block_idx.idx(); + + let instr_count = blocks[block_i].instruction_used; + instr_flags[..instr_count].fill(0); + 
debug_assert!(blocks[block_i].start_depth >= 0); + let start_depth = blocks[block_i].start_depth as usize; + ref_stack_clear(&mut refs); + for _ in 0..start_depth { + push_ref(&mut refs, DUMMY_INSTR, NOT_LOCAL)?; } - for block in blocks { - for instr in &mut block.instructions { - if instr.target != BlockIdx::NULL && in_chain[instr.target.idx()] { - instr.target = replacement; + for i in 0..instr_count { + let info = blocks[block_i].instructions[i]; + let instr = info.instr; + let arg_u32 = u32::from(info.arg); + debug_assert!(!matches!(instr.real(), Some(Instruction::ExtendedArg))); + + match instr { + AnyInstruction::Real(Instruction::DeleteFast { var_num }) => { + kill_local( + &mut instr_flags, + &refs, + local_as_ref_local(usize::from(var_num.get(info.arg))), + ); + } + AnyInstruction::Real(Instruction::LoadFast { var_num }) => { + push_ref( + &mut refs, + i as isize, + local_as_ref_local(usize::from(var_num.get(info.arg))), + )?; + } + AnyInstruction::Real(Instruction::LoadFastAndClear { var_num }) => { + let local = local_as_ref_local(usize::from(var_num.get(info.arg))); + kill_local(&mut instr_flags, &refs, local); + push_ref(&mut refs, i as isize, local)?; + } + AnyInstruction::Real(Instruction::LoadFastLoadFast { .. }) => { + let local1 = (arg_u32 >> 4) as isize; + let local2 = (arg_u32 & 15) as isize; + push_ref(&mut refs, i as isize, local1)?; + push_ref(&mut refs, i as isize, local2)?; + } + AnyInstruction::Real(Instruction::StoreFast { var_num }) => { + let r = ref_stack_pop(&mut refs); + store_local( + &mut instr_flags, + &refs, + local_as_ref_local(usize::from(var_num.get(info.arg))), + r, + ); + } + AnyInstruction::Real(Instruction::StoreFastLoadFast { .. }) => { + let r = ref_stack_pop(&mut refs); + store_local(&mut instr_flags, &refs, (arg_u32 >> 4) as isize, r); + push_ref(&mut refs, i as isize, (arg_u32 & 15) as isize)?; + } + AnyInstruction::Real(Instruction::StoreFastStoreFast { .. 
}) => { + let r1 = ref_stack_pop(&mut refs); + store_local(&mut instr_flags, &refs, (arg_u32 >> 4) as isize, r1); + let r2 = ref_stack_pop(&mut refs); + store_local(&mut instr_flags, &refs, (arg_u32 & 15) as isize, r2); + } + AnyInstruction::Real(Instruction::Copy { i: _ }) => { + let depth = arg_u32 as usize; + assert!(depth > 0); + assert!(refs.size >= depth); + let r = ref_stack_at(&refs, refs.size - depth); + push_ref(&mut refs, r.instr, r.local)?; + } + AnyInstruction::Real(Instruction::Swap { i: _ }) => { + let depth = arg_u32 as usize; + assert!(depth >= 2); + assert!(refs.size >= depth); + ref_stack_swap_top(&mut refs, depth); + } + AnyInstruction::Real( + Instruction::FormatSimple + | Instruction::GetAnext + | Instruction::GetLen + | Instruction::GetYieldFromIter + | Instruction::ImportFrom { .. } + | Instruction::MatchKeys + | Instruction::MatchMapping + | Instruction::MatchSequence + | Instruction::WithExceptStart, + ) => { + let effect = instr.stack_effect_info(arg_u32); + let net_pushed = effect.pushed() as isize - effect.popped() as isize; + debug_assert!(net_pushed >= 0); + // CPython optimize_load_fast() shadows the outer + // instruction index in this produced-value loop. + for produced in 0..net_pushed { + push_ref(&mut refs, produced, NOT_LOCAL)?; + } + } + AnyInstruction::Real( + Instruction::DictMerge { .. } + | Instruction::DictUpdate { .. } + | Instruction::ListAppend { .. } + | Instruction::ListExtend { .. } + | Instruction::MapAdd { .. } + | Instruction::Reraise { .. } + | Instruction::SetAdd { .. } + | Instruction::SetUpdate { .. }, + ) => { + let effect = instr.stack_effect_info(arg_u32); + let net_popped = effect.popped() as isize - effect.pushed() as isize; + debug_assert!(net_popped > 0); + for _ in 0..net_popped { + let _ = ref_stack_pop(&mut refs); + } + } + AnyInstruction::Real( + Instruction::EndSend | Instruction::SetFunctionAttribute { .. 
}, + ) => { + let effect = instr.stack_effect_info(arg_u32); + debug_assert_eq!(effect.popped(), 2); + debug_assert_eq!(effect.pushed(), 1); + let tos = ref_stack_pop(&mut refs); + let _ = ref_stack_pop(&mut refs); + push_ref(&mut refs, tos.instr, tos.local)?; + } + AnyInstruction::Real(Instruction::CheckExcMatch) => { + let _ = ref_stack_pop(&mut refs); + push_ref(&mut refs, i as isize, NOT_LOCAL)?; + } + AnyInstruction::Real(Instruction::ForIter { .. }) => { + let target = info.target; + debug_assert!(target != BlockIdx::NULL); + load_fast_push_block(&mut worklist, blocks, target, refs.size + 1); + push_ref(&mut refs, i as isize, NOT_LOCAL)?; + } + AnyInstruction::Real( + Instruction::LoadAttr { .. } | Instruction::LoadSuperAttr { .. }, + ) => { + let self_ref = ref_stack_pop(&mut refs); + if matches!(instr.real(), Some(Instruction::LoadSuperAttr { .. })) { + let _ = ref_stack_pop(&mut refs); + let _ = ref_stack_pop(&mut refs); + } + push_ref(&mut refs, i as isize, NOT_LOCAL)?; + if arg_u32 & 1 != 0 { + push_ref(&mut refs, self_ref.instr, self_ref.local)?; + } + } + AnyInstruction::Real( + Instruction::LoadSpecial { .. } | Instruction::PushExcInfo, + ) => { + let tos = ref_stack_pop(&mut refs); + push_ref(&mut refs, i as isize, NOT_LOCAL)?; + push_ref(&mut refs, tos.instr, tos.local)?; + } + AnyInstruction::Real(Instruction::Send { .. 
}) => { + let target = info.target; + debug_assert!(target != BlockIdx::NULL); + load_fast_push_block(&mut worklist, blocks, target, refs.size); + let _ = ref_stack_pop(&mut refs); + push_ref(&mut refs, i as isize, NOT_LOCAL)?; + } + _ => { + let effect = instr.stack_effect_info(arg_u32); + let num_popped = effect.popped() as usize; + let num_pushed = effect.pushed() as usize; + let target = info.target; + if instr.has_target() { + debug_assert!(target != BlockIdx::NULL); + debug_assert!(refs.size >= num_popped); + let target_depth = refs.size - num_popped + num_pushed; + load_fast_push_block(&mut worklist, blocks, target, target_depth); + } + if !is_block_push(&info) { + for _ in 0..num_popped { + let _ = ref_stack_pop(&mut refs); + } + for _ in 0..num_pushed { + push_ref(&mut refs, i as isize, NOT_LOCAL)?; + } + } } } } - } - fn jump_targets_for_iter(blocks: &[Block], jump_block: BlockIdx) -> bool { - if jump_block == BlockIdx::NULL { - return false; + let fallthrough = blocks[block_i].next; + let term = basicblock_last_instr(&blocks[block_i]).copied(); + if let Some(term) = term + && fallthrough != BlockIdx::NULL + && !term.instr.is_unconditional_jump() + && !term.instr.is_scope_exit() + { + debug_assert!(bb_has_fallthrough(&blocks[block_i])); + load_fast_push_block(&mut worklist, blocks, fallthrough, refs.size); } - let Some(info) = blocks[jump_block.idx()].instructions.first() else { - return false; - }; - let target = next_nonempty_block(blocks, info.target); - target != BlockIdx::NULL - && blocks[target.idx()] - .instructions - .first() - .is_some_and(|target_info| { - matches!(target_info.instr.real(), Some(Instruction::ForIter { .. })) - }) - } - fn is_explicit_continue_to_for_iter(blocks: &[Block], jump_block: BlockIdx) -> bool { - if jump_block == BlockIdx::NULL { - return false; - } - let Some(info) = blocks[jump_block.idx()].instructions.first() else { - return false; - }; - matches!( - info.instr.real(), - Some(Instruction::JumpBackward { .. 
} | Instruction::JumpBackwardNoInterrupt { .. }) - ) && !matches!(info.lineno_override, Some(line) if line < 0) - && jump_targets_for_iter(blocks, jump_block) - } - - fn is_explicit_continue_after_conditional( - blocks: &[Block], - jump_block: BlockIdx, - cond: InstructionInfo, - ) -> bool { - if !is_explicit_continue_to_for_iter(blocks, jump_block) { - return false; + for i in 0..refs.size { + let r = ref_stack_at(&refs, i); + if r.instr != DUMMY_INSTR { + instr_flags[r.instr as usize] |= LoadFastInstrFlag::RefUnconsumed as u8; + } } - let Some(info) = blocks[jump_block.idx()].instructions.first() else { - return false; - }; - instruction_lineno(info) > instruction_lineno(&cond) - } - fn is_explicit_non_for_jump_back(blocks: &[Block], jump_block: BlockIdx) -> bool { - if jump_block == BlockIdx::NULL || jump_targets_for_iter(blocks, jump_block) { - return false; + let block = &mut blocks[block_idx]; + let iused = block.instruction_used; + let mut i = 0; + while i < iused { + let info = &mut block.instructions[i]; + if instr_flags[i] != 0 { + i += 1; + continue; + } + match info.instr.real() { + Some(Instruction::LoadFast { .. }) => { + info.instr = Instruction::LoadFastBorrow { + var_num: Arg::marker(), + } + .into(); + } + Some(Instruction::LoadFastLoadFast { .. }) => { + info.instr = Instruction::LoadFastBorrowLoadFastBorrow { + var_nums: Arg::marker(), + } + .into(); + } + _ => {} + } + i += 1; } - let Some(info) = blocks[jump_block.idx()].instructions.first() else { - return false; - }; - matches!(info.instr.real(), Some(Instruction::JumpBackward { .. 
})) - && info.lineno_override.is_some_and(|line| line >= 0) } + Ok(()) +} +/// flowgraph.c calculate_stackdepth +fn calculate_stackdepth(blocks: &mut [Block]) -> crate::InternalResult { let mut current = BlockIdx(0); while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let cond = blocks[idx].instructions[cond_idx]; - if matches!(cond.instr.real(), Some(Instruction::PopJumpIfFalse { .. })) { - let exit_start = next; - let exit_block = next_nonempty_block(blocks, exit_start); - let jump_start = cond.target; - let jump_block = next_nonempty_block(blocks, jump_start); - let after_jump = if jump_block != BlockIdx::NULL { - next_nonempty_block(blocks, blocks[jump_block.idx()].next) - } else { - BlockIdx::NULL - }; - let after_jump_continues_conditional_chain = after_jump != BlockIdx::NULL - && block_is_pure_conditional_test(&blocks[after_jump.idx()]); - if exit_start == BlockIdx::NULL - || exit_block == BlockIdx::NULL - || jump_start == BlockIdx::NULL - || jump_block == BlockIdx::NULL - || jump_block == exit_block - || after_jump_continues_conditional_chain - || block_is_exceptional(&blocks[idx]) - || block_is_exceptional(&blocks[jump_block.idx()]) - || block_is_exceptional(&blocks[exit_block.idx()]) - || block_is_protected(&blocks[idx]) - || block_is_protected(&blocks[jump_block.idx()]) - || block_is_protected(&blocks[exit_block.idx()]) - || !is_scope_exit_block(&blocks[exit_block.idx()]) - || !is_jump_back_only_block(blocks, jump_block) - || (!allow_for_iter_jump_targets - && is_explicit_continue_to_for_iter(blocks, jump_block)) - && blocks[exit_block.idx()].instructions.iter().any(|info| { - matches!(info.instr.real(), Some(Instruction::RaiseVarargs { .. 
})) - }) - || (!allow_for_iter_jump_targets - && is_explicit_non_for_jump_back(blocks, jump_block)) - || (after_jump != BlockIdx::NULL - && is_pop_top_exit_like_block(&blocks[after_jump.idx()])) - || (after_jump != BlockIdx::NULL - && !blocks[after_jump.idx()].cold - && !is_scope_exit_block(&blocks[after_jump.idx()]) - && !is_loop_cleanup_block(&blocks[after_jump.idx()]) - && jump_targets_for_iter(blocks, jump_block)) - || next_nonempty_block(blocks, blocks[exit_block.idx()].next) != jump_block - { - current = next; - continue; + blocks[current.idx()].start_depth = START_DEPTH_UNSET; + current = blocks[current.idx()].next; + } + let mut stack = make_cfg_traversal_stack(blocks)?; + let mut maxdepth = 0i32; + stackdepth_push(&mut stack, blocks, BlockIdx(0), 0)?; + while let Some(block_idx) = stack.pop() { + let idx = block_idx.idx(); + let mut depth = blocks[idx].start_depth; + debug_assert!(depth >= 0); + let mut next = blocks[idx].next; + let instr_count = blocks[idx].instruction_used; + for i in 0..instr_count { + let ins = blocks[idx].instructions[i]; + let instr = &ins.instr; + let effects = get_stack_effects(*instr, ins.arg, 0)?; + let new_depth = depth + effects.net; + if new_depth < 0 { + return Err(InternalError::StackUnderflow); + } + maxdepth = maxdepth.max(depth); + if instr.has_target() && !matches!(instr.real(), Some(Instruction::EndAsyncFor)) { + debug_assert!(ins.target != BlockIdx::NULL); + let effects = get_stack_effects(*instr, ins.arg, 1)?; + let target_depth = depth + effects.net; + debug_assert!(target_depth >= 0); + maxdepth = maxdepth.max(depth); + stackdepth_push(&mut stack, blocks, ins.target, target_depth)?; + } + depth = new_depth; + debug_assert!(!instr.is_assembler()); + if instr.is_unconditional_jump() || instr.is_scope_exit() { + next = BlockIdx::NULL; + break; } - - let after_jump = blocks[jump_block.idx()].next; - blocks[idx].instructions[cond_idx].instr = reversed_conditional(&cond.instr) - .expect("PopJumpIfFalse has a reversed 
conditional"); - blocks[idx].instructions[cond_idx].target = exit_start; - blocks[idx].next = jump_start; - blocks[jump_block.idx()].next = exit_start; - blocks[exit_block.idx()].next = after_jump; - current = next; - continue; - } - if !matches!(cond.instr.real(), Some(Instruction::PopJumpIfTrue { .. })) { - current = next; - continue; } - if !allow_true_scope_exit_reorder { - current = next; - continue; - } - let exit_block = next_nonempty_block(blocks, cond.target); - let jump_block = next_nonempty_block(blocks, next); - if exit_block == BlockIdx::NULL - || jump_block == BlockIdx::NULL - || !is_scope_exit_block(&blocks[exit_block.idx()]) - || !is_jump_only_block(&blocks[jump_block.idx()]) - || (jump_targets_for_iter(blocks, jump_block) - && !is_explicit_continue_after_conditional(blocks, jump_block, cond)) - || next_nonempty_block(blocks, blocks[jump_block.idx()].next) != exit_block - || !comes_before( - blocks, - next_nonempty_block(blocks, blocks[jump_block.idx()].instructions[0].target), - jump_block, - ) - { - current = next; - continue; + if next != BlockIdx::NULL { + debug_assert!(bb_has_fallthrough(&blocks[idx])); + stackdepth_push(&mut stack, blocks, next, depth)?; } - let after_exit = blocks[exit_block.idx()].next; - retarget_empty_chain_targets(blocks, next, jump_block, jump_block); - blocks[idx].instructions[cond_idx].instr = - reversed_conditional(&cond.instr).expect("PopJumpIfTrue has a reversed conditional"); - blocks[idx].instructions[cond_idx].target = jump_block; - blocks[idx].next = exit_block; - blocks[exit_block.idx()].next = jump_block; - blocks[jump_block.idx()].next = after_exit; - current = next; } + + let stackdepth = maxdepth; + Ok(stackdepth as u32) } -fn reorder_conditional_break_continue_blocks(blocks: &mut [Block]) { - fn is_break_exit_like(blocks: &[Block], block_idx: BlockIdx) -> bool { - let block = &blocks[block_idx.idx()]; - if !is_jump_only_block(block) { - return false; +#[cfg(test)] +impl CodeInfo { + fn 
debug_block_dump(&self) -> String { + let mut out = String::new(); + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + use core::fmt::Write; + let block = &self.blocks[block_idx.idx()]; + let block_return = if basicblock_returns(block) { + " return" + } else { + "" + }; + let _ = writeln!( + out, + "block {} next={} cold={} except={} preserve_lasti={} start_depth={}{}", + u32::from(block_idx), + if block.next == BlockIdx::NULL { + String::from("NULL") + } else { + u32::from(block.next).to_string() + }, + block.cold, + block.except_handler, + block.preserve_lasti, + if block.start_depth < 0 { + String::from("None") + } else { + block.start_depth.to_string() + }, + block_return, + ); + for info in &block.instructions[..block.instruction_used] { + let lineno = instruction_lineno(info); + let _ = writeln!( + out, + " [disp={}:{} raw={}:{}-{}:{} override={:?}] {:?} arg={} target={}", + lineno, + info.location.character_offset.get(), + info.location.line.get(), + info.location.character_offset.get(), + info.end_location.line.get(), + info.end_location.character_offset.get(), + info.lineno_override, + info.instr, + u32::from(info.arg), + if info.target == BlockIdx::NULL { + String::from("NULL") + } else { + u32::from(info.target).to_string() + } + ); + } + block_idx = block.next; } - matches!( - block.instructions[0].instr.real(), - Some(Instruction::JumpForward { .. }) - ) && !comes_before( - blocks, - next_nonempty_block(blocks, block.instructions[0].target), - block_idx, - ) + out } - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let cond = blocks[idx].instructions[cond_idx]; - if !matches!(cond.instr.real(), Some(Instruction::PopJumpIfTrue { .. 
})) { - current = next; - continue; + pub(crate) fn debug_late_cfg_trace(mut self) -> crate::InternalResult> { + let mut trace = Vec::new(); + trace.push(("initial".to_owned(), self.debug_block_dump())); + + let instr_sequence = self.prepare_cfg_from_codegen()?; + self.blocks = cfg_from_instruction_sequence(instr_sequence)?; + trace.push(( + "after_cfg_from_instruction_sequence".to_owned(), + self.debug_block_dump(), + )); + translate_jump_labels_to_targets(&mut self.blocks)?; + mark_except_handlers(&mut self.blocks)?; + label_exception_targets(&mut self.blocks)?; + check_cfg(&self.blocks)?; + inline_small_or_no_lineno_blocks(&mut self.blocks)?; + trace.push(( + "after_inline_small_or_no_lineno_blocks".to_owned(), + self.debug_block_dump(), + )); + remove_unreachable(&mut self.blocks)?; + resolve_line_numbers(&mut self.blocks, self.metadata.firstlineno)?; + optimize_load_const(&mut self.metadata, &mut self.blocks)?; + trace.push(( + "after_optimize_load_const".to_owned(), + self.debug_block_dump(), + )); + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next_block = self.blocks[block_idx.idx()].next; + optimize_basic_block(&mut self.blocks, &mut self.metadata, block_idx)?; + block_idx = next_block; } + trace.push(( + "after_optimize_basic_block".to_owned(), + self.debug_block_dump(), + )); + remove_redundant_nops_and_pairs(&mut self.blocks)?; + remove_unreachable(&mut self.blocks)?; + remove_redundant_nops_and_jumps(&mut self.blocks)?; + #[cfg(debug_assertions)] + assert!(no_redundant_jumps(&self.blocks)); + remove_unused_consts(&mut self.blocks, &mut self.metadata.consts)?; + trace.push(( + "after_optimize_cfg_cleanup".to_owned(), + self.debug_block_dump(), + )); + let nlocals = self.metadata.varnames.len(); + let nparams = self.nparams; + add_checks_for_loads_of_uninitialized_variables(&mut self.blocks, nlocals, nparams)?; + insert_superinstructions(&mut self.blocks)?; + push_cold_blocks_to_end(&mut self.blocks)?; + trace.push(( + 
"after_push_cold_before_chain_reorder".to_owned(), + self.debug_block_dump(), + )); + resolve_line_numbers(&mut self.blocks, self.metadata.firstlineno)?; + trace.push(( + "after_push_cold_resolve_line_numbers".to_owned(), + self.debug_block_dump(), + )); + + trace.push(( + "after_push_cold_blocks_to_end".to_owned(), + self.debug_block_dump(), + )); + + convert_pseudo_conditional_jumps(&mut self.blocks)?; + trace.push(( + "after_convert_pseudo_conditional_jumps".to_owned(), + self.debug_block_dump(), + )); - let jump_start = next; - let jump_block = next_nonempty_block(blocks, jump_start); - let exit_start = cond.target; - let exit_block = next_nonempty_block(blocks, exit_start); - if jump_start == BlockIdx::NULL - || jump_block == BlockIdx::NULL - || exit_start == BlockIdx::NULL - || exit_block == BlockIdx::NULL - || jump_block == exit_block - || block_is_exceptional(&blocks[idx]) - || block_is_exceptional(&blocks[jump_block.idx()]) - || block_is_exceptional(&blocks[exit_block.idx()]) - || block_is_protected(&blocks[idx]) - || block_is_protected(&blocks[jump_block.idx()]) - || block_is_protected(&blocks[exit_block.idx()]) - || !is_jump_back_only_block(blocks, jump_block) - || !is_break_exit_like(blocks, exit_block) - || next_nonempty_block(blocks, blocks[jump_block.idx()].next) != exit_block - { - current = next; - continue; - } + let _max_stackdepth = calculate_stackdepth(&mut self.blocks)?; + let _nlocalsplus = prepare_localsplus(&self.metadata, &mut self.blocks, self.flags)?; + convert_pseudo_ops(&mut self.blocks)?; + trace.push(( + "after_convert_pseudo_ops".to_owned(), + self.debug_block_dump(), + )); - let after_exit = blocks[exit_block.idx()].next; - blocks[idx].instructions[cond_idx].instr = - reversed_conditional(&cond.instr).expect("PopJumpIfTrue has a reversed conditional"); - blocks[idx].instructions[cond_idx].target = jump_start; - blocks[idx].next = exit_start; - blocks[exit_block.idx()].next = jump_start; - blocks[jump_block.idx()].next = after_exit; 
- current = next; + normalize_jumps(&mut self.blocks)?; + #[cfg(debug_assertions)] + assert!(no_redundant_jumps(&self.blocks)); + trace.push(("after_normalize_jumps".to_owned(), self.debug_block_dump())); + optimize_load_fast(&mut self.blocks)?; + trace.push(( + "after_optimize_load_fast".to_owned(), + self.debug_block_dump(), + )); + + Ok(trace) } } -fn reorder_conditional_explicit_continue_scope_exit_blocks(blocks: &mut [Block]) { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let cond = blocks[idx].instructions[cond_idx]; - if !matches!(cond.instr.real(), Some(Instruction::PopJumpIfTrue { .. })) { - current = next; - continue; - } +impl InstrDisplayContext for CodeInfo { + type Constant = ConstantData; - let jump_start = next; - let jump_block = next_nonempty_block(blocks, jump_start); - let exit_start = cond.target; - let exit_block = next_nonempty_block(blocks, exit_start); - if jump_start == BlockIdx::NULL - || jump_block == BlockIdx::NULL - || exit_start == BlockIdx::NULL - || exit_block == BlockIdx::NULL - || jump_block == exit_block - || block_is_exceptional(&blocks[idx]) - || block_is_exceptional(&blocks[jump_block.idx()]) - || block_is_exceptional(&blocks[exit_block.idx()]) - || block_is_protected(&blocks[idx]) - || block_is_protected(&blocks[jump_block.idx()]) - || block_is_protected(&blocks[exit_block.idx()]) - || !is_scope_exit_block(&blocks[exit_block.idx()]) - || !is_jump_back_only_block(blocks, jump_block) - || next_nonempty_block(blocks, blocks[jump_block.idx()].next) != exit_block - || blocks[jump_block.idx()] - .instructions - .first() - .is_none_or(|info| info.lineno_override.is_none_or(|lineno| lineno < 0)) - { - current = next; - continue; - } + fn get_constant(&self, consti: oparg::ConstIdx) -> &ConstantData { + &self.metadata.consts[consti.as_usize()] 
+ } - let after_exit = blocks[exit_block.idx()].next; - blocks[idx].instructions[cond_idx].instr = - reversed_conditional(&cond.instr).expect("PopJumpIfTrue has a reversed conditional"); - blocks[idx].instructions[cond_idx].target = jump_start; - blocks[idx].next = exit_start; - blocks[exit_block.idx()].next = jump_start; - blocks[jump_block.idx()].next = after_exit; - current = next; + fn get_name(&self, i: usize) -> &str { + self.metadata.names[i].as_ref() } -} -fn reorder_conditional_implicit_continue_scope_exit_blocks(blocks: &mut [Block]) { - fn jump_back_target(block: &Block) -> Option { - let [info] = block.instructions.as_slice() else { - return None; - }; - if matches!(info.instr.real(), Some(Instruction::JumpBackward { .. })) - && info.target != BlockIdx::NULL - { - Some(info.target) + fn get_varname(&self, var_num: oparg::VarNum) -> &str { + self.metadata.varnames[var_num.as_usize()].as_ref() + } + + fn get_localsplus_name(&self, var_num: oparg::VarNum) -> &str { + let idx = var_num.as_usize(); + let nlocals = self.metadata.varnames.len(); + if idx < nlocals { + self.metadata.varnames[idx].as_ref() } else { - None + let cell_idx = idx - nlocals; + self.metadata + .cellvars + .get_index(cell_idx) + .unwrap_or_else(|| &self.metadata.freevars[cell_idx - self.metadata.cellvars.len()]) + .as_ref() } } +} - fn scope_exit_segment_tail_before_jump( - blocks: &[Block], - start: BlockIdx, - jump_start: BlockIdx, - ) -> Option { - let jump_block = next_nonempty_block(blocks, jump_start); - if jump_block == BlockIdx::NULL { - return None; - } +const NOT_LOCAL: isize = -1; +const DUMMY_INSTR: isize = -1; - let mut segment = Vec::new(); - let mut cursor = next_nonempty_block(blocks, start); - while cursor != BlockIdx::NULL && cursor != jump_block { - if segment.len() >= blocks.len() { - return None; - } - let block = &blocks[cursor.idx()]; - if block_is_exceptional(block) || block_is_protected(block) { - return None; +/// flowgraph.c make_super_instruction +fn 
make_super_instruction( + inst1: &mut InstructionInfo, + inst2: &mut InstructionInfo, + super_op: AnyInstruction, +) { + let line1 = instruction_lineno(inst1); + let line2 = instruction_lineno(inst2); + if line1 >= 0 && line2 >= 0 && line1 != line2 { + return; + } + let arg1 = u32::from(inst1.arg); + let arg2 = u32::from(inst2.arg); + if arg1 >= 16 || arg2 >= 16 { + return; + } + instr_set_op1(inst1, super_op, OpArg::new((arg1 << 4) | arg2)); + set_to_nop(inst2); +} + +/// flowgraph.c insert_superinstructions +fn insert_superinstructions(blocks: &mut [Block]) -> crate::InternalResult { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next_block = blocks[block_idx.idx()].next; + let block = &mut blocks[block_idx]; + for i in 0..block.instruction_used { + let nextop = (i + 1 < block.instruction_used) + .then(|| block.instructions[i + 1].instr.real()) + .flatten(); + match block.instructions[i].instr.real() { + Some(Instruction::LoadFast { .. }) => { + if matches!(nextop, Some(Instruction::LoadFast { .. })) { + let (inst1, rest) = block.instructions[i..].split_at_mut(1); + make_super_instruction( + &mut inst1[0], + &mut rest[0], + Instruction::LoadFastLoadFast { + var_nums: Arg::marker(), + } + .into(), + ); + } + } + Some(Instruction::StoreFast { .. }) => match nextop { + Some(Instruction::LoadFast { .. }) => { + let (inst1, rest) = block.instructions[i..].split_at_mut(1); + make_super_instruction( + &mut inst1[0], + &mut rest[0], + Instruction::StoreFastLoadFast { + var_nums: Arg::marker(), + } + .into(), + ); + } + Some(Instruction::StoreFast { .. 
}) => { + let (inst1, rest) = block.instructions[i..].split_at_mut(1); + make_super_instruction( + &mut inst1[0], + &mut rest[0], + Instruction::StoreFastStoreFast { + var_nums: Arg::marker(), + } + .into(), + ); + } + _ => {} + }, + _ => {} } - segment.push(cursor); - cursor = next_nonempty_block(blocks, block.next); - } - if cursor != jump_block || segment.is_empty() { - return None; - } - - let mut in_segment = vec![false; blocks.len()]; - for block_idx in &segment { - in_segment[block_idx.idx()] = true; } + block_idx = next_block; + } + let res = remove_redundant_nops(blocks)?; + #[cfg(debug_assertions)] + assert!(no_redundant_nops(blocks)); + Ok(res) +} - let mut has_scope_exit = false; - for block_idx in &segment { - let block = &blocks[block_idx.idx()]; - if is_scope_exit_block(block) { - has_scope_exit = true; - continue; - } +/// flowgraph.c LoadFastInstrFlag +#[repr(u8)] +enum LoadFastInstrFlag { + SupportKilled = 1, + StoredAsLocal = 2, + RefUnconsumed = 4, +} - if block_has_fallthrough(block) { - let next = next_nonempty_block(blocks, block.next); - if next == BlockIdx::NULL || !in_segment[next.idx()] { - return None; - } - } +/// flowgraph.c ref +#[derive(Clone, Copy)] +struct Ref { + instr: isize, + local: isize, +} - for info in &block.instructions { - if info.target == BlockIdx::NULL { - continue; - } - let target = next_nonempty_block(blocks, info.target); - if target == BlockIdx::NULL || !in_segment[target.idx()] { - return None; - } - } - } +/// flowgraph.c ref_stack +struct RefStack { + refs: Vec, + size: usize, + capacity: usize, +} - has_scope_exit.then_some(*segment.last().expect("non-empty segment")) +/// flowgraph.c ref_stack_push +fn ref_stack_push(stack: &mut RefStack, r: Ref) -> crate::InternalResult<()> { + debug_assert_eq!(stack.refs.len(), stack.capacity); + if stack.size == stack.capacity { + let doubled = stack.capacity * 2; + let new_cap = 32.max(doubled); + stack + .refs + .try_reserve_exact(new_cap - stack.capacity) + 
.map_err(|_| InternalError::MalformedControlFlowGraph)?; + stack.refs.resize(new_cap, Ref { instr: 0, local: 0 }); + stack.capacity = new_cap; } + stack.refs[stack.size] = r; + stack.size += 1; + Ok(()) +} - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let cond = blocks[idx].instructions[cond_idx]; - if !matches!(cond.instr.real(), Some(Instruction::PopJumpIfFalse { .. })) { - current = next; - continue; - } +/// flowgraph.c ref_stack_pop +fn ref_stack_pop(stack: &mut RefStack) -> Ref { + assert!(stack.size > 0); + stack.size -= 1; + stack.refs[stack.size] +} - let exit_start = next; - let exit_block = next_nonempty_block(blocks, exit_start); - let jump_start = cond.target; - let jump_block = next_nonempty_block(blocks, jump_start); - let jump_target = if jump_block != BlockIdx::NULL { - jump_back_target(&blocks[jump_block.idx()]) - } else { - None - }; - let jumps_to_for_iter = jump_target.is_some_and(|target| { - blocks[target.idx()] - .instructions - .first() - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. 
}))) - }); - let after_jump_start = if jump_block != BlockIdx::NULL { - blocks[jump_block.idx()].next - } else { - BlockIdx::NULL - }; - let after_jump = next_nonempty_block(blocks, after_jump_start); - let jump_exits_to_loop_exit = jump_target.is_some_and(|target| { - trailing_conditional_jump_index(&blocks[target.idx()]).is_some_and(|cond_idx| { - let loop_exit_start = blocks[target.idx()].instructions[cond_idx].target; - loop_exit_start == after_jump_start - || next_nonempty_block(blocks, loop_exit_start) == after_jump - }) - }); - let jump_has_lineno = blocks[jump_block.idx()] - .instructions - .first() - .is_some_and(|info| instruction_lineno(info) >= 0); - let exit_segment_tail = scope_exit_segment_tail_before_jump(blocks, exit_start, jump_start); - if exit_start == BlockIdx::NULL - || exit_block == BlockIdx::NULL - || jump_start == BlockIdx::NULL - || jump_block == BlockIdx::NULL - || exit_block == jump_block - || block_is_exceptional(&blocks[idx]) - || block_is_exceptional(&blocks[exit_block.idx()]) - || block_is_exceptional(&blocks[jump_block.idx()]) - || block_is_protected(&blocks[idx]) - || block_is_protected(&blocks[exit_block.idx()]) - || block_is_protected(&blocks[jump_block.idx()]) - || exit_segment_tail.is_none() - || !is_jump_back_only_block(blocks, jump_block) - || jumps_to_for_iter - || (after_jump != BlockIdx::NULL - && !blocks[after_jump.idx()].cold - && !jump_exits_to_loop_exit - && !jump_has_lineno) - { - current = next; - continue; - } +/// flowgraph.c ref_stack_swap_top +fn ref_stack_swap_top(stack: &mut RefStack, off: usize) { + assert!(off >= 2 && stack.size >= off); + let top = stack.size - 1; + let other = stack.size - off; + stack.refs.swap(top, other); +} - let exit_segment_tail = exit_segment_tail.expect("checked above"); - let after_jump = blocks[jump_block.idx()].next; - blocks[idx].instructions[cond_idx].instr = - reversed_conditional(&cond.instr).expect("PopJumpIfFalse has a reversed conditional"); - 
blocks[idx].instructions[cond_idx].target = exit_start; - blocks[idx].next = jump_start; - blocks[jump_block.idx()].next = exit_start; - blocks[exit_segment_tail.idx()].next = after_jump; - current = next; - } +/// flowgraph.c ref_stack_at +fn ref_stack_at(stack: &RefStack, idx: usize) -> Ref { + assert!(idx < stack.size); + stack.refs[idx] } -fn reorder_exception_handler_conditional_continue_scope_exit_blocks(blocks: &mut [Block]) { - fn handler_scope_exit_returns(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - && block - .instructions - .last() - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::ReturnValue))) - } +/// flowgraph.c ref_stack_clear +fn ref_stack_clear(stack: &mut RefStack) { + stack.size = 0; +} - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let cond = blocks[idx].instructions[cond_idx]; - if !matches!( - cond.instr.real(), - Some(Instruction::PopJumpIfNotNone { .. } | Instruction::PopJumpIfFalse { .. 
}) - ) || !(block_is_exceptional(&blocks[idx]) || blocks[idx].cold) - { - current = next; - continue; - } +/// flowgraph.c optimize_load_fast PUSH_REF +fn push_ref(stack: &mut RefStack, instr: isize, local: isize) -> crate::InternalResult<()> { + ref_stack_push(stack, Ref { instr, local }) +} - let exit_start = next; - let exit_block = next_nonempty_block(blocks, exit_start); - let jump_start = cond.target; - let jump_block = next_nonempty_block(blocks, jump_start); - let jump_target = if jump_block != BlockIdx::NULL { - next_nonempty_block(blocks, blocks[jump_block.idx()].instructions[0].target) - } else { - BlockIdx::NULL - }; - let exit_raises = exit_block != BlockIdx::NULL - && blocks[exit_block.idx()] - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::RaiseVarargs { .. }))); - let exit_returns = - exit_block != BlockIdx::NULL && handler_scope_exit_returns(&blocks[exit_block.idx()]); - if exit_start == BlockIdx::NULL - || exit_block == BlockIdx::NULL - || jump_start == BlockIdx::NULL - || jump_block == BlockIdx::NULL - || jump_target == BlockIdx::NULL - || exit_block == jump_block - || !is_scope_exit_block(&blocks[exit_block.idx()]) - || !is_jump_back_only_block(blocks, jump_block) - || !(exit_raises || exit_returns) - || !blocks[jump_target.idx()] - .instructions - .first() - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. 
}))) - || next_nonempty_block(blocks, blocks[exit_block.idx()].next) != jump_block - { - current = next; +/// flowgraph.c kill_local +fn kill_local(instr_flags: &mut [u8], refs: &RefStack, local: isize) { + for i in 0..refs.size { + let r = ref_stack_at(refs, i); + if r.local != local { continue; } - - let after_jump = blocks[jump_block.idx()].next; - blocks[idx].instructions[cond_idx].instr = reversed_conditional(&cond.instr) - .expect("handler conditional has a reversed conditional"); - blocks[idx].instructions[cond_idx].target = exit_start; - blocks[idx].next = jump_start; - blocks[jump_block.idx()].next = exit_start; - blocks[exit_block.idx()].next = after_jump; - current = next; + debug_assert!(r.instr >= 0); + instr_flags[r.instr as usize] |= LoadFastInstrFlag::SupportKilled as u8; } } -fn deduplicate_adjacent_jump_back_blocks(blocks: &mut [Block]) { - fn jump_back_target(block: &Block) -> Option { - let [info] = block.instructions.as_slice() else { - return None; - }; - if matches!(info.instr.real(), Some(Instruction::JumpBackward { .. })) - && info.target != BlockIdx::NULL - { - Some(info.target) - } else { - None - } +/// flowgraph.c store_local +fn store_local(instr_flags: &mut [u8], refs: &RefStack, local: isize, r: Ref) { + kill_local(instr_flags, refs, local); + if r.instr != DUMMY_INSTR { + instr_flags[r.instr as usize] |= LoadFastInstrFlag::StoredAsLocal as u8; } +} - fn jump_back_lineno(block: &Block) -> Option { - let [info] = block.instructions.as_slice() else { - return None; - }; - matches!(info.instr.real(), Some(Instruction::JumpBackward { .. 
})) - .then(|| instruction_lineno(info)) - } +fn local_as_ref_local(local: usize) -> isize { + local as isize +} - fn has_protected_conditional_predecessor( - blocks: &[Block], - incoming_origins: &[Vec], - target: BlockIdx, - ) -> bool { - incoming_origins[target.idx()] - .iter() - .copied() - .any(|origin| { - blocks[origin.idx()].instructions.iter().any(|info| { - info.except_handler.is_some() - && is_conditional_jump(&info.instr) - && next_nonempty_block(blocks, info.target) == target - }) - }) +/// flowgraph.c load_fast_push_block +fn load_fast_push_block( + worklist: &mut CfgTraversalStack, + blocks: &mut [Block], + target: BlockIdx, + start_depth: usize, +) { + debug_assert!(target != BlockIdx::NULL); + debug_assert!(blocks[target.idx()].start_depth >= 0); + debug_assert_eq!(blocks[target.idx()].start_depth as usize, start_depth,); + if !blocks[target.idx()].visited { + blocks[target.idx()].visited = true; + worklist.push(target); } +} - let reachable = compute_reachable_blocks(blocks); - let incoming_origins = compute_incoming_origins(blocks, &reachable); - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let Some(target) = jump_back_target(&blocks[current.idx()]) else { - current = blocks[current.idx()].next; - continue; - }; - if block_is_exceptional(&blocks[current.idx()]) - || block_is_protected(&blocks[current.idx()]) - { - current = blocks[current.idx()].next; - continue; - } +fn stackdepth_push( + stack: &mut CfgTraversalStack, + blocks: &mut [Block], + target: BlockIdx, + depth: i32, +) -> crate::InternalResult<()> { + let idx = target.idx(); + let block_depth = &mut blocks[idx].start_depth; + if !(*block_depth < 0 || *block_depth == depth) { + return Err(InternalError::InconsistentStackDepth); + } + if *block_depth < depth && *block_depth < 100 { + debug_assert!(*block_depth < 0); + *block_depth = depth; + stack.push(target); + } + Ok(()) +} - let duplicate = next_nonempty_block(blocks, blocks[current.idx()].next); - if 
duplicate == BlockIdx::NULL - || duplicate == current - || block_is_exceptional(&blocks[duplicate.idx()]) - || block_is_protected(&blocks[duplicate.idx()]) - || jump_back_target(&blocks[duplicate.idx()]) != Some(target) - || jump_back_lineno(&blocks[duplicate.idx()]) - != jump_back_lineno(&blocks[current.idx()]) - || has_protected_conditional_predecessor(blocks, &incoming_origins, current) - || has_protected_conditional_predecessor(blocks, &incoming_origins, duplicate) - { - current = blocks[current.idx()].next; - continue; - } +/// flowgraph.c stack_effects +struct StackEffects { + net: i32, +} - for block in blocks.iter_mut() { - for info in &mut block.instructions { - if info.target == duplicate { - info.target = current; - } - } - } - current = blocks[current.idx()].next; +/// flowgraph.c get_stack_effects +#[allow(clippy::unnecessary_wraps)] +fn get_stack_effects( + instr: AnyInstruction, + oparg: OpArg, + jump: i32, +) -> crate::InternalResult { + if instr + .real() + .is_some_and(|op| op.as_opcode().deopt().is_some()) + { + return Err(InternalError::InvalidStackEffect); } + let oparg = u32::from(oparg); + let net = if instr.is_block_push() && jump == 0 { + 0 + } else if jump != 0 { + instr.stack_effect_jump(oparg) + } else { + instr.stack_effect(oparg) + }; + Ok(StackEffects { net }) } -fn reorder_conditional_body_and_implicit_continue_blocks(blocks: &mut Vec) { - fn jump_back_target(blocks: &[Block], block_idx: BlockIdx) -> Option { - if block_idx == BlockIdx::NULL { - return None; - } - let instructions = blocks[block_idx.idx()].instructions.as_slice(); - let jump = match instructions { - [jump] if jump.instr.is_unconditional_jump() => jump, - [not_taken, jump] - if matches!(not_taken.instr.real(), Some(Instruction::NotTaken)) - && jump.instr.is_unconditional_jump() => - { - jump - } - _ => return None, - }; - if jump.target == BlockIdx::NULL - || !comes_before(blocks, next_nonempty_block(blocks, jump.target), block_idx) - { - return None; - } - 
Some(jump.target) - } - - fn find_body_tail( - blocks: &[Block], - body_start: BlockIdx, - false_jump: BlockIdx, - target: BlockIdx, - ) -> Option { - let mut cursor = body_start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && cursor != false_jump { - if visited[cursor.idx()] { - return None; - } - visited[cursor.idx()] = true; - if block_is_exceptional(&blocks[cursor.idx()]) { - return None; - } - if jump_back_target(blocks, cursor) == Some(target) { - return Some(cursor); - } - cursor = blocks[cursor.idx()].next; - } - None +fn vec_try_reserve_exact(vec: &mut Vec, additional: usize) -> crate::InternalResult<()> { + vec.try_reserve_exact(additional) + .map_err(|_| InternalError::MalformedControlFlowGraph) +} + +fn vec_try_resize_to_double_capacity(vec: &mut Vec) -> crate::InternalResult<()> { + let capacity = vec.capacity(); + debug_assert!(capacity > 0); + let len = capacity + .checked_mul(core::mem::size_of::()) + .ok_or(InternalError::MalformedControlFlowGraph)?; + if capacity == 0 || len > usize::MAX / 2 { + return Err(InternalError::MalformedControlFlowGraph); } + let new_capacity = capacity * 2; + let additional = new_capacity + .checked_sub(vec.len()) + .ok_or(InternalError::MalformedControlFlowGraph)?; + vec_try_reserve_exact(vec, additional) +} - fn find_body_tail_before_jump( - blocks: &[Block], - body_start: BlockIdx, - jump_start: BlockIdx, - ) -> Option { - let mut cursor = body_start; - let mut tail = BlockIdx::NULL; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && cursor != jump_start { - if visited[cursor.idx()] { - return None; - } - visited[cursor.idx()] = true; - if block_is_exceptional(&blocks[cursor.idx()]) { - return None; - } - tail = cursor; - cursor = blocks[cursor.idx()].next; - } - (tail != BlockIdx::NULL && cursor == jump_start).then_some(tail) - } - - fn body_segment_contains_for_iter( - blocks: &[Block], - body_start: BlockIdx, - body_tail: BlockIdx, - ) -> bool { - 
let mut cursor = body_start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL { - if visited[cursor.idx()] { - return false; - } - visited[cursor.idx()] = true; - if blocks[cursor.idx()] - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::ForIter { .. }))) - { - return true; - } - if cursor == body_tail { - return false; - } - cursor = blocks[cursor.idx()].next; - } - false +/// assemble.c write_location_first_byte +fn write_location_first_byte(linetable: &mut Vec, code: u8, length: usize) { + linetable.extend(write_location_entry_start(code, length)); +} + +/// pycore_code.h write_location_entry_start +fn write_location_entry_start(code: u8, length: usize) -> [u8; 1] { + debug_assert!(length > 0 && length <= 8); + debug_assert_eq!(code & 15, code); + [0x80 | (code << 3) | ((length - 1) as u8)] +} + +/// assemble.c write_location_byte +fn write_location_byte(linetable: &mut Vec, value: u8) { + linetable.push(value); +} + +/// assemble.c write_location_varint +fn write_location_varint(linetable: &mut Vec, value: u32) { + write_varint(linetable, value); +} + +/// assemble.c write_location_signed_varint +fn write_location_signed_varint(linetable: &mut Vec, value: i32) { + write_signed_varint(linetable, value); +} + +/// assemble.c write_location_info_short_form +fn write_location_info_short_form( + linetable: &mut Vec, + length: usize, + column: i32, + end_column: i32, +) { + debug_assert!(length > 0 && length <= 8); + debug_assert!(column < 80); + debug_assert!(end_column >= column); + debug_assert!(end_column - column < 16); + let column_low_bits = column & 7; + let column_group = column >> 3; + let code = PyCodeLocationInfoKind::Short0 as u8 + column_group as u8; + write_location_first_byte(linetable, code, length); + write_location_byte( + linetable, + ((column_low_bits as u8) << 4) | ((end_column - column) as u8), + ); +} + +/// assemble.c write_location_info_oneline_form +fn 
write_location_info_oneline_form( + linetable: &mut Vec, + length: usize, + line_delta: i32, + column: i32, + end_column: i32, +) { + debug_assert!(length > 0 && length <= 8); + debug_assert!((0..3).contains(&line_delta)); + debug_assert!(column < 128); + debug_assert!(end_column < 128); + let code = PyCodeLocationInfoKind::OneLine0 as u8 + line_delta as u8; + write_location_first_byte(linetable, code, length); + write_location_byte(linetable, column as u8); + write_location_byte(linetable, end_column as u8); +} + +/// assemble.c write_location_info_long_form +fn write_location_info_long_form( + linetable: &mut Vec, + loc: LineTableLocation, + length: usize, + line_delta: i32, +) { + debug_assert!(length > 0 && length <= 8); + write_location_first_byte(linetable, PyCodeLocationInfoKind::Long as u8, length); + write_location_signed_varint(linetable, line_delta); + debug_assert!(loc.end_line >= loc.line); + write_location_varint(linetable, (loc.end_line - loc.line) as u32); + write_location_varint( + linetable, + if loc.col < 0 { 0 } else { (loc.col as u32) + 1 }, + ); + write_location_varint( + linetable, + if loc.end_col < 0 { + 0 + } else { + (loc.end_col as u32) + 1 + }, + ); +} + +/// assemble.c write_location_info_none +fn write_location_info_none(linetable: &mut Vec, length: usize) { + write_location_first_byte(linetable, PyCodeLocationInfoKind::None as u8, length); +} + +/// assemble.c write_location_info_no_column +fn write_location_info_no_column(linetable: &mut Vec, length: usize, line_delta: i32) { + write_location_first_byte(linetable, PyCodeLocationInfoKind::NoColumns as u8, length); + write_location_signed_varint(linetable, line_delta); +} + +/// assemble.c write_location_info_entry +fn write_location_info_entry( + linetable: &mut Vec, + loc: LineTableLocation, + length: usize, + prev_line: &mut i32, + debug_ranges: bool, +) -> crate::InternalResult<()> { + const THEORETICAL_MAX_ENTRY_SIZE: usize = 25; + if linetable + .len() + 
.checked_add(THEORETICAL_MAX_ENTRY_SIZE) + .ok_or(InternalError::MalformedControlFlowGraph)? + >= linetable.capacity() + { + debug_assert!(linetable.capacity() > THEORETICAL_MAX_ENTRY_SIZE); + vec_try_resize_to_double_capacity(linetable)?; + } + if loc.line == NO_LOCATION_OVERRIDE { + write_location_info_none(linetable, length); + return Ok(()); + } + + let line_delta = loc.line - *prev_line; + let column = loc.col; + let end_column = loc.end_col; + if !debug_ranges + || ((column < 0 || end_column < 0) && (loc.end_line == loc.line || loc.end_line < 0)) + { + write_location_info_no_column(linetable, length, line_delta); + *prev_line = loc.line; + return Ok(()); } - fn body_segment_contains_protected_block( - blocks: &[Block], - body_start: BlockIdx, - body_tail: BlockIdx, - ) -> bool { - let mut cursor = body_start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL { - if visited[cursor.idx()] { - return false; - } - visited[cursor.idx()] = true; - if block_is_protected(&blocks[cursor.idx()]) { - return true; - } - if cursor == body_tail { - return false; - } - cursor = blocks[cursor.idx()].next; + if loc.end_line == loc.line { + if line_delta == 0 && column < 80 && end_column - column < 16 && end_column >= column { + write_location_info_short_form(linetable, length, column, end_column); + return Ok(()); + } + if (0..3).contains(&line_delta) && column < 128 && end_column < 128 { + write_location_info_oneline_form(linetable, length, line_delta, column, end_column); + *prev_line = loc.line; + return Ok(()); } - false } - fn body_segment_contains_scope_exit( - blocks: &[Block], - body_start: BlockIdx, - body_tail: BlockIdx, - ) -> bool { - let mut cursor = body_start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL { - if visited[cursor.idx()] { - return false; - } - visited[cursor.idx()] = true; - if blocks[cursor.idx()] - .instructions - .iter() - .any(|info| info.instr.is_scope_exit()) - { - return true; 
- } - if cursor == body_tail { - return false; + write_location_info_long_form(linetable, loc, length, line_delta); + *prev_line = loc.line; + Ok(()) +} + +/// assemble.c assemble_emit_location +fn assemble_emit_location( + linetable: &mut Vec, + loc: LineTableLocation, + mut size: usize, + prev_line: &mut i32, + debug_ranges: bool, +) -> crate::InternalResult<()> { + if size == 0 { + return Ok(()); + } + while size > 8 { + write_location_info_entry(linetable, loc, 8, prev_line, debug_ranges)?; + size -= 8; + } + write_location_info_entry(linetable, loc, size, prev_line, debug_ranges) +} + +fn no_linetable_location() -> LineTableLocation { + LineTableLocation { + line: NO_LOCATION_OVERRIDE, + end_line: NO_LOCATION_OVERRIDE, + col: NO_LOCATION_OVERRIDE, + end_col: NO_LOCATION_OVERRIDE, + } +} + +fn next_linetable_location() -> LineTableLocation { + LineTableLocation { + line: NEXT_LOCATION_OVERRIDE, + end_line: NEXT_LOCATION_OVERRIDE, + col: NEXT_LOCATION_OVERRIDE, + end_col: NEXT_LOCATION_OVERRIDE, + } +} + +/// assemble.c assemble_emit_exception_table_item +fn assemble_emit_exception_table_item(table: &mut Vec, value: i32, mut msb: u8) { + debug_assert!((msb | 128) == 128); + debug_assert!((0..(1 << 30)).contains(&value)); + let value = value as u32; + const CONTINUATION_BIT: u8 = 64; + if value >= 1 << 24 { + table.push(((value >> 24) as u8) | CONTINUATION_BIT | msb); + msb = 0; + } + if value >= 1 << 18 { + table.push((((value >> 18) & 0x3f) as u8) | CONTINUATION_BIT | msb); + msb = 0; + } + if value >= 1 << 12 { + table.push((((value >> 12) & 0x3f) as u8) | CONTINUATION_BIT | msb); + msb = 0; + } + if value >= 1 << 6 { + table.push((((value >> 6) & 0x3f) as u8) | CONTINUATION_BIT | msb); + msb = 0; + } + table.push(((value & 0x3f) as u8) | msb); +} + +/// assemble.c assemble_emit_exception_table_entry +fn assemble_emit_exception_table_entry( + table: &mut Vec, + start: i32, + end: i32, + handler_offset: i32, + handler: InstructionSequenceExceptHandlerInfo, +) 
-> crate::InternalResult<()> { + const MAX_SIZE_OF_ENTRY: usize = 20; + if table + .len() + .checked_add(MAX_SIZE_OF_ENTRY) + .ok_or(InternalError::MalformedControlFlowGraph)? + >= table.capacity() + { + vec_try_resize_to_double_capacity(table)?; + } + let size = end - start; + debug_assert!(end > start); + let target = handler_offset; + let mut depth = handler.start_depth - 1; + if handler.preserve_lasti > 0 { + depth -= 1; + } + debug_assert!(depth >= 0); + let depth_lasti = (depth << 1) | handler.preserve_lasti; + assemble_emit_exception_table_item(table, start, 1 << 7); + assemble_emit_exception_table_item(table, size, 0); + assemble_emit_exception_table_item(table, target, 0); + assemble_emit_exception_table_item(table, depth_lasti, 0); + Ok(()) +} + +/// assemble.c assemble_exception_table +fn assemble_exception_table( + instrs: &[InstructionSequenceEntry], +) -> crate::InternalResult> { + let mut table = Vec::new(); + vec_try_reserve_exact(&mut table, DEFAULT_LNOTAB_SIZE)?; + let mut handler = InstructionSequenceExceptHandlerInfo { + h_label: NO_EXCEPTION_HANDLER_LABEL, + start_depth: -1, + preserve_lasti: -1, + }; + let mut start = -1; + let mut ioffset = 0i32; + + for i in 0..instrs.len() { + let instr = &instrs[i]; + if instr.except_handler.h_label != handler.h_label { + if handler.h_label >= 0 { + let handler_offset = instrs[handler.h_label as usize].i_offset; + assemble_emit_exception_table_entry( + &mut table, + start, + ioffset, + handler_offset, + handler, + )?; } - cursor = blocks[cursor.idx()].next; + start = ioffset; + handler = instr.except_handler; } - false + ioffset += instr_size(&instr.info) as i32; } - fn body_segment_contains_jump_back_to( - blocks: &[Block], - body_start: BlockIdx, - body_tail: BlockIdx, - target: BlockIdx, - ) -> bool { - let mut cursor = body_start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL { - if visited[cursor.idx()] { - return false; - } - visited[cursor.idx()] = true; - if 
jump_back_target(blocks, cursor) == Some(target) { - return true; - } - if cursor == body_tail { - return false; - } - cursor = blocks[cursor.idx()].next; - } - false + if handler.h_label >= 0 { + let handler_offset = instrs[handler.h_label as usize].i_offset; + assemble_emit_exception_table_entry(&mut table, start, ioffset, handler_offset, handler)?; } - fn body_segment_contains_any_jump_back( - blocks: &[Block], - body_start: BlockIdx, - body_tail: BlockIdx, - ) -> bool { - fn jump_back_or_self_target(blocks: &[Block], block_idx: BlockIdx) -> Option { - if block_idx == BlockIdx::NULL { - return None; - } - let jump = blocks[block_idx.idx()].instructions.last()?; - if !jump.instr.is_unconditional_jump() { - return None; - } - if jump.target == BlockIdx::NULL { - return None; - } - let target = next_nonempty_block(blocks, jump.target); - if target == block_idx || comes_before(blocks, target, block_idx) { - Some(jump.target) - } else { - None - } - } + Ok(table.into_boxed_slice()) +} - let mut cursor = body_start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL { - if visited[cursor.idx()] { - return false; - } - visited[cursor.idx()] = true; - if jump_back_or_self_target(blocks, cursor).is_some() { - return true; - } - if cursor == body_tail { - return false; - } - cursor = blocks[cursor.idx()].next; +/// Mark exception handler target blocks. 
+/// flowgraph.c mark_except_handlers +#[allow(clippy::unnecessary_wraps)] +pub(crate) fn mark_except_handlers(blocks: &mut [Block]) -> crate::InternalResult<()> { + #[cfg(debug_assertions)] + { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + assert!(!blocks[block_idx.idx()].except_handler); + block_idx = blocks[block_idx.idx()].next; } - false } - fn block_ends_with_jump_back_to( - blocks: &[Block], - block_idx: BlockIdx, - target: BlockIdx, - ) -> bool { - let Some(last) = blocks[block_idx.idx()].instructions.last() else { - return false; - }; - last.instr.is_unconditional_jump() - && last.target != BlockIdx::NULL - && next_nonempty_block(blocks, last.target) == next_nonempty_block(blocks, target) - && comes_before(blocks, next_nonempty_block(blocks, target), block_idx) - } + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next = blocks[block_idx.idx()].next; + let instr_count = blocks[block_idx.idx()].instruction_used; + for i in 0..instr_count { + let instr = blocks[block_idx.idx()].instructions[i]; + if is_block_push(&instr) { + debug_assert!(instr.target != BlockIdx::NULL); + blocks[instr.target.idx()].except_handler = true; + } + } + block_idx = next; + } + Ok(()) +} + +/// flowgraph.c mark_cold (two-pass to match CPython). +/// +/// Phase 1 (mark_warm): propagate "warm" from entry via fall-through and +/// jump targets. CPython asserts while visiting warm blocks that they are not +/// exception handlers. +/// +/// Phase 2 (mark_cold): propagate "cold" from except_handler blocks via +/// forward edges. Blocks reached only via runtime exception dispatch are +/// marked cold and pushed to the end by push_cold_blocks_to_end. +/// +/// Blocks reached by neither phase remain `cold=false`. They are typically +/// empty unreachable placeholders left by remove_unreachable; they stay in +/// their original chain position (e.g. 
between entry and the post-try +/// continuation for a nested try/except whose inner_end was emptied by +/// optimize_cfg). This matches CPython's behavior and is necessary for +/// optimize_load_fast to terminate fall-through at those placeholders. +/// flowgraph.c mark_warm +fn mark_warm(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut stack = make_cfg_traversal_stack(blocks)?; + stack.push(BlockIdx(0)); + blocks[0].visited = true; + while let Some(block_idx) = stack.pop() { + let idx = block_idx.idx(); + debug_assert!(!blocks[idx].except_handler); + blocks[idx].warm = true; - fn conditional_jump_target_count(blocks: &[Block], target: BlockIdx) -> usize { - let target = next_nonempty_block(blocks, target); - blocks - .iter() - .flat_map(|block| &block.instructions) - .filter(|info| { - is_conditional_jump(&info.instr) - && info.target != BlockIdx::NULL - && next_nonempty_block(blocks, info.target) == target - }) - .count() - } + let next = blocks[idx].next; + if next != BlockIdx::NULL && bb_has_fallthrough(&blocks[idx]) && !blocks[next.idx()].visited + { + stack.push(next); + blocks[next.idx()].visited = true; + } - fn empty_chain_reaches(blocks: &[Block], start: BlockIdx, target: BlockIdx) -> bool { - let mut cursor = start; - let mut visited = vec![false; blocks.len()]; - while cursor != BlockIdx::NULL && cursor != target { - if visited[cursor.idx()] - || block_is_exceptional(&blocks[cursor.idx()]) - || block_is_protected(&blocks[cursor.idx()]) - || !blocks[cursor.idx()].instructions.is_empty() - { - return false; + let instr_count = blocks[idx].instruction_used; + for i in 0..instr_count { + let instr = blocks[idx].instructions[i]; + if is_jump(&instr) { + let target = instr.target; + debug_assert!(target != BlockIdx::NULL); + if !blocks[target.idx()].visited { + stack.push(target); + blocks[target.idx()].visited = true; + } } - visited[cursor.idx()] = true; - cursor = blocks[cursor.idx()].next; } - cursor == target } + Ok(()) +} - fn 
block_starts_loop_cleanup(blocks: &[Block], block_idx: BlockIdx) -> bool { - block_idx != BlockIdx::NULL - && matches!( - blocks[block_idx.idx()] - .instructions - .first() - .and_then(|info| info.instr.real()), - Some(Instruction::EndFor | Instruction::PopIter) - ) +fn mark_cold(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let block = &mut blocks[block_idx.idx()]; + debug_assert!(!block.cold); + debug_assert!(!block.warm); + block_idx = block.next; } - fn is_single_delete_subscr_body(block: &Block) -> bool { - let real: Vec<_> = block - .instructions - .iter() - .filter_map(|info| info.instr.real()) - .filter(|instr| !matches!(instr, Instruction::Nop | Instruction::NotTaken)) - .collect(); - real.iter() - .filter(|instr| matches!(instr, Instruction::DeleteSubscr)) - .count() - == 1 - && matches!(real.last(), Some(Instruction::DeleteSubscr)) - && real.iter().all(|instr| { - matches!( - instr, - Instruction::LoadFast { .. } - | Instruction::LoadFastBorrow { .. } - | Instruction::LoadFastLoadFast { .. } - | Instruction::LoadFastBorrowLoadFastBorrow { .. } - | Instruction::LoadSmallInt { .. } - | Instruction::LoadConst { .. } - | Instruction::Copy { .. } - | Instruction::Swap { .. } - | Instruction::BinaryOp { .. } - | Instruction::BinarySlice - | Instruction::BuildSlice { .. } - | Instruction::StoreSubscr - | Instruction::DeleteSubscr - ) - }) - } + mark_warm(blocks)?; - fn block_has_call(block: &Block) -> bool { - block - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::Call { .. 
}))) - } - - fn protected_region_has_prior_scope_exit( - blocks: &[Block], - loop_target: BlockIdx, - current: BlockIdx, - ) -> bool { - let mut cursor = next_nonempty_block(blocks, loop_target); - let mut visited = vec![false; blocks.len()]; - let mut saw_protected = false; - while cursor != BlockIdx::NULL && cursor != current { - if visited[cursor.idx()] { - return false; - } - visited[cursor.idx()] = true; - let block = &blocks[cursor.idx()]; - saw_protected |= block_is_protected(block); - if saw_protected - && block - .instructions - .iter() - .any(|info| info.instr.is_scope_exit()) - { - return true; - } - cursor = block.next; + let mut cold_stack = make_cfg_traversal_stack(blocks)?; + block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let i = block_idx.idx(); + let next = blocks[i].next; + let block = &blocks[i]; + if block.except_handler { + debug_assert!(!block.warm); + cold_stack.push(block_idx); + blocks[i].visited = true; } - false + block_idx = next; } - - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); + while let Some(block_idx) = cold_stack.pop() { + let idx = block_idx.idx(); + blocks[idx].cold = true; let next = blocks[idx].next; - let Some(cond_idx) = trailing_conditional_jump_index(&blocks[idx]) else { - current = next; - continue; - }; - let cond = blocks[idx].instructions[cond_idx]; - let Some(reversed_cond) = reversed_conditional(&cond.instr) else { - current = next; - continue; - }; - - let true_jump_start = cond.target; - let true_jump = next_nonempty_block(blocks, true_jump_start); - let body_start = next; - let body = next_nonempty_block(blocks, body_start); - let true_jump_loop_target = jump_back_target(blocks, true_jump); - let normalized_forward_conditional = blocks[idx] - .instructions - .get(cond_idx + 1) - .is_some_and(|info| matches!(info.instr.real(), Some(Instruction::NotTaken))); - if true_jump_loop_target.is_some() - && (true_jump_start == true_jump - || 
empty_chain_reaches(blocks, true_jump_start, true_jump)) - && body_start != BlockIdx::NULL - && body != BlockIdx::NULL - && true_jump != BlockIdx::NULL - && !block_is_exceptional(&blocks[idx]) - && !block_is_exceptional(&blocks[true_jump.idx()]) - && !block_is_protected(&blocks[idx]) - && !block_is_protected(&blocks[true_jump.idx()]) - && next_nonempty_block(blocks, blocks[true_jump.idx()].next) != body - && let Some(body_tail) = find_body_tail_before_jump(blocks, body, true_jump_start) - { - let loop_target = true_jump_loop_target.expect("checked above"); - let after_jump = blocks[true_jump.idx()].next; - let body_tail_is_conditional = - trailing_conditional_jump_index(&blocks[body_tail.idx()]).is_some(); - let body_is_single_block = body == body_tail; - let body_has_scope_exit = body_segment_contains_scope_exit(blocks, body, body_tail); - let body_tail_has_scope_exit = is_scope_exit_block(&blocks[body_tail.idx()]); - let body_starts_conditional_chain = block_is_pure_conditional_test(&blocks[body.idx()]); - let body_tail_ends_with_loop_backedge = - block_ends_with_jump_back_to(blocks, body_tail, loop_target); - let body_has_loop_backedge = body_tail_ends_with_loop_backedge - || body_segment_contains_jump_back_to(blocks, body, body_tail, loop_target); - if body_has_scope_exit && !body_tail_has_scope_exit { - current = next; - continue; - } - let body_has_inner_for_iter = body_segment_contains_for_iter(blocks, body, body_tail); - let body_has_any_loop_backedge = body_has_inner_for_iter - && body_segment_contains_any_jump_back(blocks, body, body_tail); - let normalized_single_block_can_reorder = - !normalized_forward_conditional || !block_has_call(&blocks[body.idx()]); - let has_exceptional_duplicate_condition_line = - has_exceptional_duplicate_lineno(blocks, current, instruction_lineno(&cond)); - let has_prior_protected_scope_exit = - protected_region_has_prior_scope_exit(blocks, loop_target, current); - let after_jump_target = next_nonempty_block(blocks, 
after_jump); - let after_jump_starts_loop_cleanup = - block_starts_loop_cleanup(blocks, after_jump_target); - let after_jump_continues_conditional_chain = after_jump_target != BlockIdx::NULL - && block_is_pure_conditional_test(&blocks[after_jump_target.idx()]); - let body_is_pop_top_exit_like = - body_is_single_block && is_pop_top_exit_like_block(&blocks[body.idx()]); - if (body_has_scope_exit || body_is_pop_top_exit_like) - && after_jump_target != BlockIdx::NULL - && is_pop_top_exit_like_block(&blocks[after_jump_target.idx()]) - { - current = next; - continue; - } - if after_jump_continues_conditional_chain { - current = next; - continue; + if next != BlockIdx::NULL && bb_has_fallthrough(&blocks[idx]) { + let next_idx = next.idx(); + if !blocks[next_idx].warm && !blocks[next_idx].visited { + cold_stack.push(next); + blocks[next_idx].visited = true; } - let jump_target_has_multiple_conditional_predecessors = - conditional_jump_target_count(blocks, true_jump_start) > 1; - let simple_single_block_can_reorder = body_is_single_block - && !body_tail_is_conditional - && !has_exceptional_duplicate_condition_line - && after_jump_starts_loop_cleanup - && !body_has_scope_exit - && (!blocks[body.idx()] - .instructions - .iter() - .any(|info| info.instr.is_unconditional_jump()) - || body_tail_ends_with_loop_backedge) - && !after_jump_continues_conditional_chain - && matches!(cond.instr.real(), Some(Instruction::PopJumpIfFalse { .. 
})); - let trailing_implicit_continue_can_reorder = after_jump != BlockIdx::NULL - && next_nonempty_block(blocks, after_jump) != body - && (!after_jump_starts_loop_cleanup - || body_has_loop_backedge - || body_has_any_loop_backedge) - && !(body_has_scope_exit - && after_jump_target != BlockIdx::NULL - && is_scope_exit_block(&blocks[after_jump_target.idx()]) - && !body_starts_conditional_chain) - && !is_scope_exit_block(&blocks[body.idx()]); - let can_reorder = !has_exceptional_duplicate_condition_line - && !has_prior_protected_scope_exit - && (!jump_target_has_multiple_conditional_predecessors || body_tail_has_scope_exit) - && ((!body_tail_is_conditional - && ((normalized_single_block_can_reorder - && body_is_single_block - && matches!(cond.instr.real(), Some(Instruction::PopJumpIfTrue { .. }))) - || (body == body_tail - && is_single_delete_subscr_body(&blocks[body.idx()])) - || simple_single_block_can_reorder)) - || (trailing_implicit_continue_can_reorder - && ((body_has_scope_exit && body_tail_has_scope_exit) - || body_has_loop_backedge - || (after_jump_starts_loop_cleanup && body_has_any_loop_backedge)) - && !after_jump_continues_conditional_chain)); - if can_reorder && after_jump != BlockIdx::NULL && after_jump != body_start { - blocks[idx].instructions[cond_idx].instr = reversed_cond; - blocks[idx].instructions[cond_idx].target = body_start; - if body_tail_ends_with_loop_backedge && true_jump_start == true_jump { - blocks[idx].next = true_jump_start; - blocks[true_jump.idx()].next = body_start; - blocks[body_tail.idx()].next = after_jump; - } else { - let cloned_jump_idx = BlockIdx(blocks.len() as u32); - let mut cloned_jump = blocks[true_jump.idx()].clone(); - cloned_jump.next = body_start; - blocks.push(cloned_jump); + } - blocks[idx].next = cloned_jump_idx; - blocks[body_tail.idx()].next = true_jump_start; + let instr_count = blocks[idx].instruction_used; + for i in 0..instr_count { + let instr = blocks[idx].instructions[i]; + if is_jump(&instr) { + 
debug_assert_eq!(i, instr_count - 1); + let target = instr.target; + debug_assert!(target != BlockIdx::NULL); + if !blocks[target.idx()].warm && !blocks[target.idx()].visited { + cold_stack.push(target); + blocks[target.idx()].visited = true; } - current = blocks[idx].next; - continue; } } + } + Ok(()) +} - if !matches!(cond.instr.real(), Some(Instruction::PopJumpIfTrue { .. })) { - current = next; - continue; - } +/// flowgraph.c push_cold_blocks_to_end +fn push_cold_blocks_to_end(blocks: &mut Vec) -> crate::InternalResult<()> { + if blocks[0].next == BlockIdx::NULL { + return Ok(()); + } - let false_jump_start = next; - let false_jump = next_nonempty_block(blocks, false_jump_start); - let body_start = cond.target; - let body = next_nonempty_block(blocks, body_start); - let Some(loop_target) = jump_back_target(blocks, false_jump) else { - current = next; - continue; - }; - if false_jump_start == BlockIdx::NULL - || false_jump == BlockIdx::NULL - || body_start == BlockIdx::NULL - || body == BlockIdx::NULL - || false_jump == body - || block_is_exceptional(&blocks[idx]) - || block_is_exceptional(&blocks[false_jump.idx()]) - || block_is_exceptional(&blocks[body.idx()]) - || block_is_protected(&blocks[idx]) - || block_is_protected(&blocks[false_jump.idx()]) - || block_is_protected(&blocks[body.idx()]) - || next_nonempty_block(blocks, blocks[false_jump.idx()].next) != body + mark_cold(blocks)?; + let mut next_label = get_max_label(blocks) + 1; + + // If a cold block falls through to a warm block, add an explicit jump + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next = blocks[block_idx.idx()].next; + if blocks[block_idx.idx()].cold + && bb_has_fallthrough(&blocks[block_idx.idx()]) + && next != BlockIdx::NULL + && blocks[next.idx()].warm { - current = next; - continue; + let explicit_jump = blocks_new_block(blocks)?; + if !is_label(blocks[next.idx()].cpython_label) { + blocks[next.idx()].cpython_label = 
InstructionSequenceLabel::from_index(next_label); + next_label += 1; + } + let jump_label = blocks[next.idx()].cpython_label; + debug_assert!(is_label(jump_label)); + basicblock_addop( + &mut blocks[explicit_jump.idx()], + InstructionInfo { + instr: PseudoOpcode::JumpNoInterrupt.into(), + arg: instruction_sequence_label_oparg(jump_label), + target: BlockIdx::NULL, + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: Some(NO_LOCATION_OVERRIDE), + }, + )?; + blocks[explicit_jump.idx()].cold = true; + blocks[explicit_jump.idx()].next = next; + blocks[explicit_jump.idx()].predecessors = 1; + blocks[block_idx.idx()].next = explicit_jump; + let target = blocks[explicit_jump.idx()].next; + let last = basicblock_last_instr_mut(&mut blocks[explicit_jump.idx()]) + .expect("missing explicit jump"); + last.target = target; } + block_idx = blocks[block_idx.idx()].next; + } - let Some(body_tail) = find_body_tail(blocks, body, false_jump, loop_target) else { - current = next; - continue; - }; - let after_body = blocks[body_tail.idx()].next; - if after_body == BlockIdx::NULL || after_body == false_jump_start { - current = next; - continue; + assert!(!blocks[0].cold); + let mut cold_blocks: BlockIdx = BlockIdx::NULL; + let mut cold_blocks_tail: BlockIdx = BlockIdx::NULL; + let mut block_idx = BlockIdx(0); + + while blocks[block_idx.idx()].next != BlockIdx::NULL { + debug_assert!(!blocks[block_idx.idx()].cold); + while blocks[block_idx.idx()].next != BlockIdx::NULL + && !blocks[blocks[block_idx.idx()].next.idx()].cold + { + block_idx = blocks[block_idx.idx()].next; } - let body_contains_for_iter = body_segment_contains_for_iter(blocks, body, body_tail); - let simple_body_can_reorder = !block_has_call(&blocks[body.idx()]) - && jump_back_target(blocks, body_tail) == Some(loop_target); - let false_jump_starts_with_not_taken = blocks[false_jump.idx()] - .instructions - .first() - .is_some_and(|info| 
matches!(info.instr.real(), Some(Instruction::NotTaken))); - if simple_body_can_reorder && false_jump_starts_with_not_taken && !body_contains_for_iter { - current = next; - continue; + if blocks[block_idx.idx()].next == BlockIdx::NULL { + break; } - if (!body_contains_for_iter && !simple_body_can_reorder) - || body_segment_contains_protected_block(blocks, body, body_tail) - || trailing_conditional_jump_index(&blocks[body.idx()]).is_some() - || block_starts_loop_cleanup(blocks, next_nonempty_block(blocks, after_body)) + + debug_assert!(!blocks[block_idx.idx()].cold); + debug_assert!(blocks[blocks[block_idx.idx()].next.idx()].cold); + + let mut block_end = blocks[block_idx.idx()].next; + while blocks[block_end.idx()].next != BlockIdx::NULL + && blocks[blocks[block_end.idx()].next.idx()].cold { - current = next; - continue; + block_end = blocks[block_end.idx()].next; } - if blocks[idx] - .instructions - .get(cond_idx + 1) - .is_none_or(|info| !matches!(info.instr.real(), Some(Instruction::NotTaken))) - && matches!( - blocks[false_jump.idx()] - .instructions - .first() - .and_then(|info| info.instr.real()), - Some(Instruction::NotTaken) - ) - { - let not_taken = blocks[false_jump.idx()].instructions.remove(0); - blocks[idx].instructions.insert(cond_idx + 1, not_taken); + debug_assert!(blocks[block_end.idx()].cold); + debug_assert!( + blocks[block_end.idx()].next == BlockIdx::NULL + || !blocks[blocks[block_end.idx()].next.idx()].cold + ); + + if cold_blocks == BlockIdx::NULL { + cold_blocks = blocks[block_idx.idx()].next; + } else { + blocks[cold_blocks_tail.idx()].next = blocks[block_idx.idx()].next; } - blocks[idx].instructions[cond_idx].instr = - reversed_conditional(&cond.instr).expect("PopJumpIfTrue has a reversed conditional"); - blocks[idx].instructions[cond_idx].target = false_jump_start; - blocks[idx].next = body_start; - blocks[body_tail.idx()].next = false_jump_start; - blocks[false_jump.idx()].next = after_body; - current = blocks[idx].next; + 
cold_blocks_tail = block_end; + blocks[block_idx.idx()].next = blocks[block_end.idx()].next; + blocks[block_end.idx()].next = BlockIdx::NULL; } -} -#[allow(dead_code)] -fn reorder_jump_over_exception_cleanup_blocks(blocks: &mut [Block]) { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let idx = current.idx(); - let next = blocks[idx].next; - if blocks[idx].cold && is_with_suppress_exit_block(&blocks[idx]) { - current = next; - continue; - } - let Some(last) = blocks[idx].instructions.last().copied() else { - current = next; - continue; - }; - if !matches!(last.instr.real(), Some(Instruction::JumpForward { .. })) - || last.target == BlockIdx::NULL - { - current = next; - continue; - } + debug_assert!(blocks[block_idx.idx()].next == BlockIdx::NULL); + blocks[block_idx.idx()].next = cold_blocks; - let cleanup_start = next; - let target_start = last.target; - let target = next_nonempty_block(blocks, target_start); - if cleanup_start == BlockIdx::NULL || target == BlockIdx::NULL || cleanup_start == target { - current = next; - continue; - } - // Keep the target anchored to the first target block. If we have to - // skip leading empty blocks here, reordering can leave the jump shape - // inconsistent in nested cleanup chains such as poplib.POP3.close(). 
- if target_start != target { - current = next; - continue; - } + if cold_blocks != BlockIdx::NULL { + remove_redundant_nops_and_jumps(blocks)?; + } + Ok(()) +} - let mut cleanup_end = BlockIdx::NULL; - let mut saw_exceptional = false; - let mut cursor = cleanup_start; - while cursor != BlockIdx::NULL && cursor != target { - if blocks[cursor.idx()].instructions.is_empty() { - cleanup_end = cursor; - cursor = blocks[cursor.idx()].next; - continue; - } - if !(block_is_exceptional(&blocks[cursor.idx()]) - || is_exception_cleanup_block(&blocks[cursor.idx()]) - || blocks[cursor.idx()].cold && is_reraise_scope_exit_block(&blocks[cursor.idx()])) - { - cleanup_end = BlockIdx::NULL; - break; +/// flowgraph.c check_cfg +fn check_cfg(blocks: &[Block]) -> crate::InternalResult<()> { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let block = &blocks[block_idx.idx()]; + for i in 0..block.instruction_used { + let opcode = block.instructions[i].instr; + debug_assert!(!opcode.is_assembler()); + if opcode.is_terminator() && i != block.instruction_used - 1 { + return Err(InternalError::MalformedControlFlowGraph); } - saw_exceptional = true; - cleanup_end = cursor; - cursor = blocks[cursor.idx()].next; - } - if !saw_exceptional || cleanup_end == BlockIdx::NULL || cursor != target { - current = next; - continue; } + block_idx = block.next; + } + Ok(()) +} - let mut target_end = BlockIdx::NULL; - let mut target_exit = BlockIdx::NULL; - let mut nonempty_target_blocks = 0usize; - cursor = target; - while cursor != BlockIdx::NULL { - if block_is_exceptional(&blocks[cursor.idx()]) { - break; - } - target_end = cursor; - if !blocks[cursor.idx()].instructions.is_empty() { - nonempty_target_blocks += 1; - target_exit = cursor; +/// flowgraph.c jump_thread +fn jump_thread( + blocks: &mut [Block], + block_idx: BlockIdx, + instr_idx: usize, + target: &InstructionInfo, + opcode: AnyInstruction, +) -> crate::InternalResult { + let bi = block_idx.idx(); + 
debug_assert!(is_jump(&blocks[bi].instructions[instr_idx])); + debug_assert!(is_jump(target)); + debug_assert_eq!(instr_idx + 1, blocks[bi].instruction_used); + debug_assert!(target.target != BlockIdx::NULL); + if blocks[bi].instructions[instr_idx].target != target.target { + set_to_nop(&mut blocks[bi].instructions[instr_idx]); + basicblock_add_jump(blocks, block_idx, opcode, target.target, target)?; + return Ok(true); + } + Ok(false) +} + +/// flowgraph.c basicblock_add_jump +fn basicblock_add_jump( + blocks: &mut [Block], + block_idx: BlockIdx, + instr: AnyInstruction, + target: BlockIdx, + loc_source: &InstructionInfo, +) -> crate::InternalResult<()> { + let bi = block_idx.idx(); + let last = basicblock_last_instr(&blocks[bi]); + if last.is_some_and(is_jump) { + return Err(InternalError::MalformedControlFlowGraph); + } + debug_assert!(target != BlockIdx::NULL); + let label = blocks[target.idx()].cpython_label; + debug_assert!(is_label(label)); + let arg = instruction_sequence_label_oparg(label); + let block = &mut blocks[bi]; + basicblock_addop( + block, + InstructionInfo { + instr, + arg, + target: BlockIdx::NULL, + location: loc_source.location, + end_location: loc_source.end_location, + except_handler: None, + lineno_override: loc_source.lineno_override, + }, + )?; + let last = basicblock_last_instr_mut(block).expect("missing jump"); + debug_assert!(match (last.instr, instr) { + (AnyInstruction::Real(last), AnyInstruction::Real(opcode)) => + last.as_opcode() == opcode.as_opcode(), + (AnyInstruction::Pseudo(last), AnyInstruction::Pseudo(opcode)) => + last.as_opcode() == opcode.as_opcode(), + _ => false, + }); + last.target = target; + Ok(()) +} + +/// pycore_opcode_utils.h IS_CONDITIONAL_JUMP_OPCODE +fn is_conditional_jump_opcode(instr: &AnyInstruction) -> bool { + matches!( + instr.real().map(Into::into), + Some( + Opcode::PopJumpIfFalse + | Opcode::PopJumpIfTrue + | Opcode::PopJumpIfNone + | Opcode::PopJumpIfNotNone + ) + ) +} + +/// flowgraph.c 
convert_pseudo_conditional_jumps +fn convert_pseudo_conditional_jumps(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next = blocks[block_idx.idx()].next; + let block = &mut blocks[block_idx.idx()]; + let mut i = 0; + while i < block.instruction_used { + let instr = block.instructions[i]; + let opcode = instr.instr; + if matches!( + opcode.pseudo(), + Some(PseudoInstruction::JumpIfFalse { .. } | PseudoInstruction::JumpIfTrue { .. }) + ) { + debug_assert_eq!(i, block.instruction_used - 1); + block.instructions[i].instr = + if matches!(opcode.pseudo(), Some(PseudoInstruction::JumpIfFalse { .. })) { + Instruction::PopJumpIfFalse { + delta: Arg::marker(), + } + .into() + } else { + Instruction::PopJumpIfTrue { + delta: Arg::marker(), + } + .into() + }; + + let location = instr.location; + let end_location = instr.end_location; + let except_handler = instr.except_handler; + let lineno_override = instr.lineno_override; + let copy = InstructionInfo { + instr: Instruction::Copy { i: Arg::marker() }.into(), + arg: OpArg::new(1), + target: BlockIdx::NULL, + location, + end_location, + except_handler, + lineno_override, + }; + basicblock_insert_instruction(block, i, copy)?; + i += 1; + + let to_bool = InstructionInfo { + instr: Instruction::ToBool.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location, + end_location, + except_handler, + lineno_override, + }; + basicblock_insert_instruction(block, i, to_bool)?; + i += 1; } - cursor = blocks[cursor.idx()].next; + i += 1; } + block_idx = next; + } + Ok(()) +} - const MAX_REORDERED_EXIT_BLOCK_SIZE: usize = 4; +/// flowgraph.c normalize_jumps_in_block +fn normalize_jumps_in_block( + blocks: &mut Vec, + block_idx: BlockIdx, +) -> crate::InternalResult<()> { + let idx = block_idx.idx(); + let Some(last_ins) = basicblock_last_instr(&blocks[idx]).copied() else { + return Ok(()); + }; + if !is_conditional_jump_opcode(&last_ins.instr) { + 
return Ok(()); + } + debug_assert!(!last_ins.instr.is_assembler()); - if target_end == BlockIdx::NULL - || target_exit == BlockIdx::NULL - || nonempty_target_blocks != 1 - || target_exit != target_end - || !is_scope_exit_block(&blocks[target_exit.idx()]) - || blocks[target_exit.idx()].instructions.len() > MAX_REORDERED_EXIT_BLOCK_SIZE - { - current = next; - continue; - } + debug_assert!(last_ins.target != BlockIdx::NULL); + let is_forward = !blocks[last_ins.target.idx()].visited; - let after_target = blocks[target_end.idx()].next; - blocks[idx].next = target_start; - blocks[target_end.idx()].next = cleanup_start; - blocks[cleanup_end.idx()].next = after_target; - current = after_target; + if is_forward { + // Insert NOT_TAKEN after forward conditional jump. + let not_taken = InstructionInfo { + instr: Opcode::NotTaken.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location: last_ins.location, + end_location: last_ins.end_location, + except_handler: None, + lineno_override: last_ins.lineno_override, + }; + basicblock_addop(&mut blocks[idx], not_taken)?; + return Ok(()); } + + let reversed_opcode = match AnyOpcode::from(last_ins.instr).real() { + Some(Opcode::PopJumpIfNotNone) => Opcode::PopJumpIfNone.into(), + Some(Opcode::PopJumpIfNone) => Opcode::PopJumpIfNotNone.into(), + Some(Opcode::PopJumpIfFalse) => Opcode::PopJumpIfTrue.into(), + Some(Opcode::PopJumpIfTrue) => Opcode::PopJumpIfFalse.into(), + _ => unreachable!("conditional jump has reverse opcode"), + }; + + // Transform 'conditional jump T' to 'reversed_jump b_next' followed by + // 'jump_backwards T'. 
+ let loc = last_ins.location; + let end_loc = last_ins.end_location; + + let target = last_ins.target; + let backwards_jump_idx = blocks_new_block(blocks)?; + basicblock_addop( + &mut blocks[backwards_jump_idx.idx()], + InstructionInfo { + instr: Opcode::NotTaken.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location: loc, + end_location: end_loc, + except_handler: None, + lineno_override: last_ins.lineno_override, + }, + )?; + basicblock_add_jump( + blocks, + backwards_jump_idx, + PseudoOpcode::Jump.into(), + target, + &last_ins, + )?; + blocks[backwards_jump_idx.idx()].start_depth = blocks[target.idx()].start_depth; + + let old_next = blocks[idx].next; + debug_assert!(old_next != BlockIdx::NULL); + + let last_mut = basicblock_last_instr_mut(&mut blocks[idx]).unwrap(); + last_mut.instr = reversed_opcode; + last_mut.target = old_next; + + blocks[backwards_jump_idx.idx()].cold = blocks[idx].cold; + blocks[backwards_jump_idx.idx()].next = old_next; + blocks[idx].next = backwards_jump_idx; + Ok(()) } -fn maybe_propagate_location( - instr: &mut InstructionInfo, - location: SourceLocation, - end_location: SourceLocation, -) { - if !instruction_has_lineno(instr) { - instr.location = location; - instr.end_location = end_location; - instr.lineno_override = None; +/// flowgraph.c normalize_jumps +fn normalize_jumps(blocks: &mut Vec) -> crate::InternalResult<()> { + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + blocks[current.idx()].visited = false; + current = blocks[current.idx()].next; + } + + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + let idx = current.idx(); + blocks[idx].visited = true; + normalize_jumps_in_block(blocks, current)?; + current = blocks[idx].next; } + Ok(()) } -fn overwrite_location( - instr: &mut InstructionInfo, - location: SourceLocation, - end_location: SourceLocation, -) { - instr.location = location; - instr.end_location = end_location; - instr.lineno_override = None; +/// flowgraph.c 
basicblock_inline_small_or_no_lineno_blocks +fn basicblock_inline_small_or_no_lineno_blocks( + blocks: &mut [Block], + block_idx: BlockIdx, +) -> crate::InternalResult { + let Some(last) = basicblock_last_instr(&blocks[block_idx.idx()]).copied() else { + return Ok(false); + }; + if !last.instr.is_unconditional_jump() { + return Ok(false); + } + + let target = last.target; + debug_assert!(target != BlockIdx::NULL); + let small_exit_block = basicblock_exits_scope(&blocks[target.idx()]) + && blocks[target.idx()].instruction_used <= MAX_COPY_SIZE; + let no_lineno_no_fallthrough = basicblock_has_no_lineno(&blocks[target.idx()]) + && !bb_has_fallthrough(&blocks[target.idx()]); + if small_exit_block || no_lineno_no_fallthrough { + debug_assert!(is_jump(&last)); + let removed_jump_opcode = last.instr; + let last = basicblock_last_instr_mut(&mut blocks[block_idx.idx()]) + .expect("non-empty block has last instruction"); + set_to_nop(last); + basicblock_append_block_instructions(blocks, block_idx, target)?; + if no_lineno_no_fallthrough { + let last = basicblock_last_instr_mut(&mut blocks[block_idx.idx()]).unwrap(); + if last.instr.is_unconditional_jump() + && matches!( + removed_jump_opcode.into(), + AnyOpcode::Pseudo(PseudoOpcode::Jump) + ) + { + last.instr = PseudoOpcode::Jump.into(); + } + } + blocks[target.idx()].predecessors -= 1; + return Ok(true); + } + Ok(false) } -fn compute_reachable_blocks(blocks: &[Block]) -> Vec { - let mut reachable = vec![false; blocks.len()]; - if blocks.is_empty() { - return reachable; +/// flowgraph.c inline_small_or_no_lineno_blocks +fn inline_small_or_no_lineno_blocks(blocks: &mut [Block]) -> crate::InternalResult { + loop { + let mut changes = false; + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + let next = blocks[current.idx()].next; + let res = basicblock_inline_small_or_no_lineno_blocks(blocks, current)?; + if res { + changes = true; + } + + current = next; + } + if !changes { + return Ok(changes); + } } +} - 
reachable[0] = true; - let mut changed = true; - while changed { - changed = false; - for i in 0..blocks.len() { - if !reachable[i] { +/// flowgraph.c basicblock_remove_redundant_nops +#[allow(clippy::unnecessary_wraps)] +fn basicblock_remove_redundant_nops( + blocks: &mut [Block], + block_idx: BlockIdx, +) -> crate::InternalResult { + let bi = block_idx.idx(); + let mut dest = 0; + let mut prev_lineno = -1i32; + let instr_count = blocks[bi].instruction_used; + + for src in 0..instr_count { + let instr = blocks[bi].instructions[src]; + let lineno = instruction_lineno(&instr); + + if matches!(instr.instr.real(), Some(Instruction::Nop)) { + if lineno < 0 { continue; } - for ins in &blocks[i].instructions { - if ins.target != BlockIdx::NULL && !reachable[ins.target.idx()] { - reachable[ins.target.idx()] = true; - changed = true; + if prev_lineno == lineno { + continue; + } + if src < instr_count - 1 { + let next_lineno = instruction_lineno(&blocks[bi].instructions[src + 1]); + if next_lineno == lineno { + continue; } - if let Some(eh) = &ins.except_handler - && !reachable[eh.handler_block.idx()] - { - reachable[eh.handler_block.idx()] = true; - changed = true; + if next_lineno < 0 { + instr_set_loc( + &mut blocks[bi].instructions[src + 1], + instr.location, + instr.end_location, + instr.lineno_override, + ); + continue; + } + } else { + let next = next_nonempty_block(blocks, blocks[bi].next); + if next != BlockIdx::NULL { + let mut next_loc = no_linetable_location(); + let mut next_i = 0; + while next_i < blocks[next.idx()].instruction_used { + let instr = blocks[next.idx()].instructions[next_i]; + if matches!(instr.instr.real(), Some(Instruction::Nop)) + && instruction_lineno(&instr) < 0 + { + next_i += 1; + continue; + } + next_loc = instruction_linetable_location(&instr); + break; + } + if lineno == next_loc.line { + continue; + } } } - let next = blocks[i].next; - if next != BlockIdx::NULL - && !reachable[next.idx()] - && 
!blocks[i].instructions.last().is_some_and(|ins| { - ins.instr.is_scope_exit() || ins.instr.is_unconditional_jump() - }) - { - reachable[next.idx()] = true; - changed = true; - } } + + if dest != src { + blocks[bi].instructions[dest] = blocks[bi].instructions[src]; + } + dest += 1; + prev_lineno = lineno; } - reachable + debug_assert!(dest <= instr_count); + let num_removed = instr_count - dest; + blocks[bi].instruction_used = dest; + Ok(num_removed) +} + +/// flowgraph.c remove_redundant_nops +#[allow(clippy::unnecessary_wraps)] +fn remove_redundant_nops(blocks: &mut [Block]) -> crate::InternalResult { + let mut changes = 0; + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + let next = blocks[current.idx()].next; + let change = basicblock_remove_redundant_nops(blocks, current)?; + changes += change; + current = next; + } + Ok(changes) } -fn compute_predecessors(blocks: &[Block]) -> Vec { - let mut predecessors = vec![0u32; blocks.len()]; - if blocks.is_empty() { - return predecessors; +/// flowgraph.c no_redundant_nops +#[cfg(debug_assertions)] +fn no_redundant_nops(blocks: &mut [Block]) -> bool { + match remove_redundant_nops(blocks) { + Ok(0) => true, + Ok(_) | Err(_) => false, } +} - let reachable = compute_reachable_blocks(blocks); - predecessors[0] = 1; +/// flowgraph.c remove_redundant_jumps +fn remove_redundant_jumps(blocks: &mut [Block]) -> crate::InternalResult { + let mut changes = 0; let mut current = BlockIdx(0); while current != BlockIdx::NULL { - if !reachable[current.idx()] { - current = blocks[current.idx()].next; + let block_idx = current.idx(); + let Some(last) = basicblock_last_instr(&blocks[block_idx]).copied() else { + current = blocks[block_idx].next; continue; + }; + debug_assert!(!last.instr.is_assembler()); + if last.instr.is_unconditional_jump() { + let jump_target = next_nonempty_block(blocks, last.target); + if jump_target == BlockIdx::NULL { + return Err(InternalError::MalformedControlFlowGraph); + } + let next = 
next_nonempty_block(blocks, blocks[block_idx].next); + if jump_target == next { + changes += 1; + let last = basicblock_last_instr_mut(&mut blocks[block_idx]).unwrap(); + set_to_nop(last); + } } + current = blocks[block_idx].next; + } + Ok(changes) +} +/// flowgraph.c no_redundant_jumps +#[cfg(debug_assertions)] +fn no_redundant_jumps(blocks: &[Block]) -> bool { + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { let block = &blocks[current.idx()]; - if block_has_fallthrough(block) { + if let Some(last) = basicblock_last_instr(block) + && last.instr.is_unconditional_jump() + { let next = next_nonempty_block(blocks, block.next); - if next != BlockIdx::NULL && reachable[next.idx()] { - predecessors[next.idx()] += 1; - } - } - for ins in &block.instructions { - if ins.target != BlockIdx::NULL { - let target = next_nonempty_block(blocks, ins.target); - if target != BlockIdx::NULL && reachable[target.idx()] { - predecessors[target.idx()] += 1; + let jump_target = next_nonempty_block(blocks, last.target); + if jump_target == next { + assert!(next != BlockIdx::NULL); + if instruction_lineno(last) + == instruction_lineno(&blocks[next.idx()].instructions[0]) + { + assert_ne!( + instruction_lineno(last), + instruction_lineno(&blocks[next.idx()].instructions[0]), + "redundant jump has same line as fallthrough target" + ); + return false; } } } current = block.next; } - predecessors + true } -fn record_incoming_origin(origins: &mut [Vec], target: BlockIdx, source: BlockIdx) { - let incoming = &mut origins[target.idx()]; - if !incoming.contains(&source) { - incoming.push(source); +fn remove_redundant_nops_and_jumps(blocks: &mut [Block]) -> crate::InternalResult<()> { + loop { + // Convergence is guaranteed because the number of redundant jumps and + // nops only decreases. 
+ let removed_nops = remove_redundant_nops(blocks)?; + let removed_jumps = remove_redundant_jumps(blocks)?; + if removed_nops + removed_jumps == 0 { + break; + } } + Ok(()) } -fn compute_incoming_origins(blocks: &[Block], reachable: &[bool]) -> Vec> { - let mut origins = vec![Vec::new(); blocks.len()]; +/// flowgraph.c make_cfg_traversal_stack +fn make_cfg_traversal_stack(blocks: &mut [Block]) -> crate::InternalResult { + debug_assert!(!blocks.is_empty()); + let mut nblocks = 0; let mut current = BlockIdx(0); while current != BlockIdx::NULL { - if !reachable[current.idx()] { - current = blocks[current.idx()].next; - continue; + blocks[current.idx()].visited = false; + nblocks += 1; + current = blocks[current.idx()].next; + } + debug_assert!(nblocks > 0); + let mut stack = Vec::new(); + stack + .try_reserve_exact(nblocks) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + stack.resize(nblocks, BlockIdx::NULL); + let stack = CfgTraversalStack { stack, sp: 0 }; + debug_assert_eq!(stack.capacity(), nblocks); + Ok(stack) +} + +fn blocks_new_block(blocks: &mut Vec) -> crate::InternalResult { + blocks + .try_reserve(1) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + let block_idx = BlockIdx( + blocks + .len() + .to_u32() + .ok_or(InternalError::MalformedControlFlowGraph)?, + ); + blocks.push(Block::default()); + Ok(block_idx) +} + +/// flowgraph.c struct _PyCfgBuilder +struct CfgBuilder { + blocks: Vec, + entry: BlockIdx, + block_list: BlockIdx, + current: BlockIdx, + current_label: InstructionSequenceLabel, +} + +/// flowgraph.c cfg_builder_new_block +fn cfg_builder_new_block(g: &mut CfgBuilder) -> crate::InternalResult { + let block = blocks_new_block(&mut g.blocks)?; + g.blocks[block.idx()].allocation_next = g.block_list; + g.blocks[block.idx()].cpython_label = InstructionSequenceLabel::NO_LABEL; + g.block_list = block; + Ok(block) +} + +/// flowgraph.c cfg_builder_use_next_block +fn cfg_builder_use_next_block(g: &mut CfgBuilder, block: BlockIdx) 
-> BlockIdx { + debug_assert!(block != BlockIdx::NULL); + g.blocks[g.current.idx()].next = block; + g.current = block; + block +} + +/// flowgraph.c init_cfg_builder +fn init_cfg_builder(g: &mut CfgBuilder) -> crate::InternalResult<()> { + g.block_list = BlockIdx::NULL; + let block = cfg_builder_new_block(g)?; + g.entry = block; + g.current = block; + g.current_label = InstructionSequenceLabel::NO_LABEL; + Ok(()) +} + +/// flowgraph.c _PyCfgBuilder_New +fn cfg_builder_new() -> crate::InternalResult { + let mut builder = CfgBuilder { + blocks: Vec::new(), + entry: BlockIdx::NULL, + block_list: BlockIdx::NULL, + current: BlockIdx::NULL, + current_label: InstructionSequenceLabel::NO_LABEL, + }; + init_cfg_builder(&mut builder)?; + Ok(builder) +} + +/// flowgraph.c cfg_builder_current_block_is_terminated +fn cfg_builder_current_block_is_terminated(g: &mut CfgBuilder) -> bool { + let block = &mut g.blocks[g.current.idx()]; + let last = basicblock_last_instr(block).copied(); + if last.is_some_and(|last| last.instr.is_terminator()) { + return true; + } + if is_label(g.current_label) { + if last.is_some() || is_label(block.cpython_label) { + return true; } + block.cpython_label = g.current_label; + g.current_label = InstructionSequenceLabel::NO_LABEL; + } + false +} - let block = &blocks[current.idx()]; - if block_has_fallthrough(block) { - let next = next_nonempty_block(blocks, block.next); - if next != BlockIdx::NULL && reachable[next.idx()] { - record_incoming_origin(&mut origins, next, current); - } +/// flowgraph.c cfg_builder_maybe_start_new_block +fn cfg_builder_maybe_start_new_block(g: &mut CfgBuilder) -> crate::InternalResult<()> { + if cfg_builder_current_block_is_terminated(g) { + let block = cfg_builder_new_block(g)?; + g.blocks[block.idx()].cpython_label = g.current_label; + g.current_label = InstructionSequenceLabel::NO_LABEL; + cfg_builder_use_next_block(g, block); + } + Ok(()) +} + +/// flowgraph.c _PyCfgBuilder_UseLabel +fn cfg_builder_use_label( + g: &mut 
CfgBuilder, + label_id: InstructionSequenceLabel, +) -> crate::InternalResult<()> { + g.current_label = label_id; + cfg_builder_maybe_start_new_block(g) +} + +/// flowgraph.c _PyCfgBuilder_Addop +fn cfg_builder_addop(g: &mut CfgBuilder, info: InstructionInfo) -> crate::InternalResult<()> { + cfg_builder_maybe_start_new_block(g)?; + basicblock_addop(&mut g.blocks[g.current.idx()], info) +} + +/// flowgraph.c cfg_builder_check +fn cfg_builder_check(g: &CfgBuilder) -> bool { + debug_assert!(g.entry != BlockIdx::NULL); + debug_assert!(g.blocks[g.entry.idx()].instruction_used != 0); + let mut block = g.block_list; + while block != BlockIdx::NULL { + debug_assert!(block.idx() < g.blocks.len()); + let block_ref = &g.blocks[block.idx()]; + let has_instr_array = block_ref.instruction_allocation > 0; + if has_instr_array { + debug_assert!(block_ref.instruction_allocation > 0); + debug_assert_eq!( + block_ref.instructions.len(), + block_ref.instruction_allocation + ); + debug_assert!(block_ref.instruction_allocation >= block_ref.instruction_used); + } else { + debug_assert_eq!(block_ref.instruction_used, 0); + debug_assert_eq!(block_ref.instruction_allocation, 0); } - for ins in &block.instructions { - if ins.target != BlockIdx::NULL { - let target = next_nonempty_block(blocks, ins.target); - if target != BlockIdx::NULL && reachable[target.idx()] { - record_incoming_origin(&mut origins, target, current); - } + block = block_ref.allocation_next; + } + true +} + +/// flowgraph.c _PyCfgBuilder_CheckSize +fn cfg_builder_check_size(g: &CfgBuilder) -> crate::InternalResult<()> { + debug_assert!(g.entry != BlockIdx::NULL); + debug_assert!(g.block_list != BlockIdx::NULL); + debug_assert!(g.current != BlockIdx::NULL); + let mut nblocks = 0usize; + let mut block = g.block_list; + while block != BlockIdx::NULL { + debug_assert!(block.idx() < g.blocks.len()); + nblocks += 1; + block = g.blocks[block.idx()].allocation_next; + } + debug_assert_eq!(nblocks, g.blocks.len()); + if nblocks > 
usize::MAX / core::mem::size_of::() { + return Err(InternalError::MalformedControlFlowGraph); + } + Ok(()) +} + +/// flowgraph.c translate_jump_labels_to_targets +fn translate_jump_labels_to_targets(blocks: &mut [Block]) -> crate::InternalResult<()> { + let max_label = get_max_label(blocks); + let label_count = (max_label + 1) as usize; + if label_count > usize::MAX / core::mem::size_of::() { + return Err(InternalError::MalformedControlFlowGraph); + } + let mut label_to_block = Vec::new(); + vec_try_reserve_exact(&mut label_to_block, label_count)?; + label_to_block.resize(label_count, BlockIdx::NULL); + + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let block = &blocks[block_idx.idx()]; + if is_label(block.cpython_label) { + let label_id = block.cpython_label; + debug_assert!(label_id.0 <= max_label); + label_to_block[label_id.idx()] = block_idx; + } + block_idx = block.next; + } + + block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next = blocks[block_idx.idx()].next; + for i in 0..blocks[block_idx.idx()].instruction_used { + let info = &mut blocks[block_idx.idx()].instructions[i]; + debug_assert_eq!(info.target, BlockIdx::NULL); + if info.instr.has_target() { + let lbl = u32::from(info.arg) as i32; + debug_assert!(lbl >= 0 && lbl <= max_label); + let target = label_to_block[lbl as usize]; + debug_assert!(target != BlockIdx::NULL); + debug_assert_eq!( + blocks[target.idx()].cpython_label, + InstructionSequenceLabel(lbl) + ); + info.target = target; } } - current = block.next; + block_idx = next; } - origins + Ok(()) } -fn duplicate_exits_without_lineno(blocks: &mut Vec, predecessors: &mut Vec) { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let block = &blocks[current.idx()]; - let last = match block.instructions.last() { - Some(ins) if ins.target != BlockIdx::NULL && is_jump_instruction(ins) => ins, - _ => { - current = blocks[current.idx()].next; - continue; - } - }; +/// flowgraph.c 
_PyCfg_FromInstructionSequence +fn cfg_from_instruction_sequence( + mut instr_sequence: InstructionSequence, +) -> crate::InternalResult> { + instruction_sequence_apply_label_map(&mut instr_sequence)?; + let mut builder = cfg_builder_new()?; - let target = next_nonempty_block(blocks, last.target); - if target == BlockIdx::NULL { - current = blocks[current.idx()].next; - continue; + for i in 0..instr_sequence.instr_used { + instr_sequence.instrs[i].i_target = 0; + } + for i in 0..instr_sequence.instr_used { + if instr_sequence.instrs[i].info.instr.has_target() { + let target_offset = u32::from(instr_sequence.instrs[i].info.arg) as usize; + debug_assert!(target_offset < instr_sequence.instr_used); + instr_sequence.instrs[target_offset].i_target = 1; } - let target_is_exit_without_lineno = is_exit_without_lineno(blocks, target); - let target_is_protected_eval_break = - is_eval_break_without_lineno(blocks, target) && last.except_handler.is_some(); - if !target_is_exit_without_lineno && !target_is_protected_eval_break { - current = blocks[current.idx()].next; + } + let InstructionSequence { + instrs, + instr_used, + label_map, + label_map_allocation, + annotations_code, + .. 
+ } = instr_sequence; + debug_assert!(label_map.is_none()); + debug_assert_eq!(label_map_allocation, 0); + + let mut offset = 0i32; + + let mut i = 0; + while i < instr_used { + let mut entry = instrs[i]; + if matches!( + entry.info.instr.pseudo(), + Some(PseudoInstruction::AnnotationsPlaceholder) + ) { + if let Some(annotations_code) = &annotations_code { + debug_assert!(annotations_code.label_map.is_none()); + debug_assert_eq!(annotations_code.label_map_allocation, 0); + for j in 0..annotations_code.instr_used { + let ann_entry = annotations_code.instrs[j]; + debug_assert!(!ann_entry.info.instr.has_target()); + let mut info = ann_entry.info; + info.target = BlockIdx::NULL; + cfg_builder_addop(&mut builder, info)?; + } + offset += annotations_code.instr_used as i32 - 1; + } else { + offset -= 1; + } + i += 1; continue; } - if predecessors[target.idx()] <= 1 { - current = blocks[current.idx()].next; - continue; + + if entry.i_target != 0 { + let label_id = i as i32 + offset; + let label = InstructionSequenceLabel(label_id); + cfg_builder_use_label(&mut builder, label)?; } - // Copy the exit block and splice it into the linked list after the - // original target block, matching CPython's copy_basicblock() layout. 
- let new_idx = BlockIdx(blocks.len() as u32); - let mut new_block = blocks[target.idx()].clone(); - if let Some(first) = new_block.instructions.first_mut() - && let Some((location, end_location)) = propagation_location(last) - { - overwrite_location(first, location, end_location); + let opcode = entry.info.instr; + let mut oparg = entry.info.arg; + if opcode.has_target() { + let target_offset = u32::from(oparg) as i32 + offset; + debug_assert!(target_offset >= 0); + oparg = OpArg::new(target_offset as u32); } - let old_next = blocks[target.idx()].next; - new_block.next = old_next; - blocks.push(new_block); - blocks[target.idx()].next = new_idx; - - // Update the jump target - let last_mut = blocks[current.idx()].instructions.last_mut().unwrap(); - last_mut.target = new_idx; - predecessors[target.idx()] -= 1; - predecessors.push(1); - current = blocks[current.idx()].next; + entry.info.instr = opcode; + entry.info.arg = oparg; + entry.info.target = BlockIdx::NULL; + cfg_builder_addop(&mut builder, entry.info)?; + i += 1; } - let reachable = compute_reachable_blocks(blocks); - let incoming_origins = compute_incoming_origins(blocks, &reachable); - current = BlockIdx(0); - while current != BlockIdx::NULL { - let block = &blocks[current.idx()]; - if let Some(last) = block.instructions.last() - && block_has_fallthrough(block) - { - let target = next_nonempty_block(blocks, block.next); - if target != BlockIdx::NULL - && (predecessors[target.idx()] == 1 - || has_unique_fallthrough_origin( - blocks, - &reachable, - &incoming_origins, - current, - target, - )) - && (is_exit_without_lineno(blocks, target) - || is_eval_break_without_lineno(blocks, target)) - && let Some((location, end_location)) = propagation_location(last) - && let Some(first) = blocks[target.idx()].instructions.first_mut() - { - maybe_propagate_location(first, location, end_location); - } + cfg_builder_check_size(&builder)?; + debug_assert!(cfg_builder_check(&builder)); + Ok(builder.blocks) +} + +/// 
flowgraph.c maybe_push +fn maybe_push( + blocks: &mut [Block], + worklist: &mut CfgTraversalStack, + block: BlockIdx, + unsafe_mask: u64, +) { + debug_assert!(block != BlockIdx::NULL); + + let idx = block.idx(); + let both = blocks[idx].unsafe_locals_mask | unsafe_mask; + if blocks[idx].unsafe_locals_mask != both { + blocks[idx].unsafe_locals_mask = both; + if !blocks[idx].visited { + worklist.push(block); + blocks[idx].visited = true; } - current = blocks[current.idx()].next; } } -fn propagate_line_numbers(blocks: &mut [Block], predecessors: &[u32]) { - let reachable = compute_reachable_blocks(blocks); - let incoming_origins = compute_incoming_origins(blocks, &reachable); - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - if !blocks[current.idx()].instructions.is_empty() { - let (next_block, has_fallthrough) = { - let block = &blocks[current.idx()]; - (block.next, block_has_fallthrough(block)) - }; +/// flowgraph.c scan_block_for_locals +fn scan_block_for_locals( + blocks: &mut [Block], + block_idx: BlockIdx, + worklist: &mut CfgTraversalStack, +) { + let idx = block_idx.idx(); + let mut unsafe_mask = blocks[idx].unsafe_locals_mask; + let instr_count = blocks[idx].instruction_used; - let prev_location = { - let block = &mut blocks[current.idx()]; - let mut prev_location = None; - for instr in &mut block.instructions { - if let Some((location, end_location)) = prev_location { - maybe_propagate_location(instr, location, end_location); - } - prev_location = propagation_location(instr); - } - prev_location - }; - let last = blocks[current.idx()].instructions.last().copied().unwrap(); - - if has_fallthrough { - let target = next_nonempty_block(blocks, next_block); - if target != BlockIdx::NULL - && (predecessors[target.idx()] == 1 - || has_unique_fallthrough_origin( - blocks, - &reachable, - &incoming_origins, - current, - target, - )) - && let Some((location, end_location)) = prev_location - && let Some(first) = 
blocks[target.idx()].instructions.first_mut() - { - maybe_propagate_location(first, location, end_location); + for i in 0..instr_count { + let (instr, arg, except_handler) = { + let info = &blocks[idx].instructions[i]; + ( + info.instr, + info.arg, + info.except_handler.map(|eh| eh.handler_block), + ) + }; + debug_assert!(!matches!(instr.real(), Some(Instruction::ExtendedArg))); + + if let Some(handler_block) = except_handler { + maybe_push(blocks, worklist, handler_block, unsafe_mask); + } + + let oparg = u32::from(arg) as usize; + if oparg >= LOCAL_UNSAFE_MASK_BITS { + continue; + } + + let bit = 1u64 << oparg; + match instr { + AnyInstruction::Real( + Instruction::DeleteFast { .. } | Instruction::LoadFastAndClear { .. }, + ) + | AnyInstruction::Pseudo(PseudoInstruction::StoreFastMaybeNull { .. }) => { + unsafe_mask |= bit; + } + AnyInstruction::Real(Instruction::StoreFast { .. }) => { + unsafe_mask &= !bit; + } + AnyInstruction::Real(Instruction::LoadFastCheck { .. }) => { + // If this doesn't raise, then the local is defined. + unsafe_mask &= !bit; + } + AnyInstruction::Real(Instruction::LoadFast { .. 
}) => { + if unsafe_mask & bit != 0 { + blocks[idx].instructions[i].instr = Opcode::LoadFastCheck.into(); } + unsafe_mask &= !bit; } + _ => {} + } + } - if is_jump_instruction(&last) { - let mut target = next_nonempty_block(blocks, last.target); - while target != BlockIdx::NULL - && blocks[target.idx()].instructions.is_empty() - && predecessors[target.idx()] == 1 - { - target = blocks[target.idx()].next; + let next = blocks[idx].next; + if next != BlockIdx::NULL && bb_has_fallthrough(&blocks[idx]) { + maybe_push(blocks, worklist, next, unsafe_mask); + } + + let last = basicblock_last_instr(&blocks[idx]).copied(); + if let Some(last) = last + && is_jump(&last) + { + let target = last.target; + debug_assert!(target != BlockIdx::NULL); + maybe_push(blocks, worklist, target, unsafe_mask); + } +} + +/// flowgraph.c fast_scan_many_locals +fn fast_scan_many_locals(blocks: &mut [Block], nlocals: usize) -> crate::InternalResult<()> { + debug_assert!(nlocals > LOCAL_UNSAFE_MASK_BITS); + let mut states = Vec::new(); + states + .try_reserve_exact(nlocals - LOCAL_UNSAFE_MASK_BITS) + .map_err(|_| InternalError::MalformedControlFlowGraph)?; + states.resize(nlocals - LOCAL_UNSAFE_MASK_BITS, 0usize); + let mut blocknum = 0usize; + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + blocknum += 1; + for i in 0..blocks[current.idx()].instruction_used { + let info = &mut blocks[current.idx()].instructions[i]; + debug_assert!(!matches!(info.instr.real(), Some(Instruction::ExtendedArg))); + let arg = u32::from(info.arg) as usize; + if arg < LOCAL_UNSAFE_MASK_BITS { + continue; + } + debug_assert!(arg >= LOCAL_UNSAFE_MASK_BITS); + match info.instr { + AnyInstruction::Real( + Instruction::DeleteFast { .. } | Instruction::LoadFastAndClear { .. }, + ) + | AnyInstruction::Pseudo(PseudoInstruction::StoreFastMaybeNull { .. 
}) => { + debug_assert!(arg < nlocals); + states[arg - LOCAL_UNSAFE_MASK_BITS] = blocknum - 1; } - if target != BlockIdx::NULL - && (predecessors[target.idx()] == 1 - || has_unique_jump_origin( - blocks, - &reachable, - &incoming_origins, - current, - target, - )) - && let Some((location, end_location)) = prev_location - && let Some(first) = blocks[target.idx()].instructions.first_mut() - { - maybe_propagate_location(first, location, end_location); + AnyInstruction::Real(Instruction::StoreFast { .. }) => { + debug_assert!(arg < nlocals); + states[arg - LOCAL_UNSAFE_MASK_BITS] = blocknum; + } + AnyInstruction::Real(Instruction::LoadFast { .. }) => { + debug_assert!(arg < nlocals); + if states[arg - LOCAL_UNSAFE_MASK_BITS] != blocknum { + info.instr = Opcode::LoadFastCheck.into(); + } + states[arg - LOCAL_UNSAFE_MASK_BITS] = blocknum; } + _ => {} } } current = blocks[current.idx()].next; } + Ok(()) } -fn resolve_line_numbers(blocks: &mut Vec) { - let mut predecessors = compute_predecessors(blocks); - duplicate_exits_without_lineno(blocks, &mut predecessors); - propagate_line_numbers(blocks, &predecessors); -} +/// flowgraph.c add_checks_for_loads_of_uninitialized_variables +fn add_checks_for_loads_of_uninitialized_variables( + blocks: &mut [Block], + mut nlocals: usize, + nparams: usize, +) -> crate::InternalResult<()> { + if nlocals == 0 { + return Ok(()); + } -fn find_layout_predecessor(blocks: &[Block], target: BlockIdx) -> BlockIdx { - if target == BlockIdx::NULL { - return BlockIdx::NULL; + if nlocals > LOCAL_UNSAFE_MASK_BITS { + fast_scan_many_locals(blocks, nlocals)?; + nlocals = LOCAL_UNSAFE_MASK_BITS; + } + + let mut worklist = make_cfg_traversal_stack(blocks)?; + let mut start_mask = 0u64; + for i in nparams..nlocals { + start_mask |= 1u64 << i; } + maybe_push(blocks, &mut worklist, BlockIdx(0), start_mask); + let mut current = BlockIdx(0); while current != BlockIdx::NULL { - if blocks[current.idx()].next == target { - return current; - } + 
scan_block_for_locals(blocks, current, &mut worklist); current = blocks[current.idx()].next; } - BlockIdx::NULL -} -fn compute_layout_predecessors(blocks: &[Block]) -> Vec { - let mut predecessors = vec![BlockIdx::NULL; blocks.len()]; - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let next = blocks[current.idx()].next; - if next != BlockIdx::NULL { - predecessors[next.idx()] = current; - } - current = next; + while let Some(block_idx) = worklist.pop() { + blocks[block_idx.idx()].visited = false; + scan_block_for_locals(blocks, block_idx, &mut worklist); } - predecessors + Ok(()) } -fn has_unique_fallthrough_origin( - blocks: &[Block], - reachable: &[bool], - incoming_origins: &[Vec], - source: BlockIdx, - target: BlockIdx, -) -> bool { - if source == BlockIdx::NULL - || target == BlockIdx::NULL - || !reachable[source.idx()] - || !block_has_fallthrough(&blocks[source.idx()]) - || next_nonempty_block(blocks, blocks[source.idx()].next) != target - { - return false; +/// Follow chain of empty blocks to find first non-empty block. 
+fn next_nonempty_block(blocks: &[Block], mut idx: BlockIdx) -> BlockIdx { + while idx != BlockIdx::NULL && blocks[idx.idx()].instruction_used == 0 { + idx = blocks[idx.idx()].next; } + idx +} - let chain_start = blocks[source.idx()].next; - let mut current = chain_start; - while current != target { - if current == BlockIdx::NULL { - return false; - } - if !blocks[current.idx()].instructions.is_empty() { - return false; - } - current = blocks[current.idx()].next; +fn instruction_lineno(instr: &InstructionInfo) -> i32 { + match instr.lineno_override { + Some(LINE_ONLY_LOCATION_OVERRIDE) | None => instr.location.line.get() as i32, + Some(lineno) => lineno, } +} - fn empty_chain_contains( - blocks: &[Block], - mut current: BlockIdx, - target: BlockIdx, - needle: BlockIdx, - ) -> bool { - while current != target { - if current == needle { - return true; - } - current = blocks[current.idx()].next; - } +fn instruction_is_no_location(instr: &InstructionInfo) -> bool { + instruction_lineno(instr) == NO_LOCATION_OVERRIDE +} + +/// flowgraph.c basicblock_nofallthrough +fn basicblock_nofallthrough(block: &Block) -> bool { + let last = basicblock_last_instr(block); + last.is_some_and(|last| last.instr.is_scope_exit() || last.instr.is_unconditional_jump()) +} + +/// flowgraph.c BB_NO_FALLTHROUGH +fn bb_no_fallthrough(block: &Block) -> bool { + basicblock_nofallthrough(block) +} + +/// flowgraph.c BB_HAS_FALLTHROUGH +fn bb_has_fallthrough(block: &Block) -> bool { + !bb_no_fallthrough(block) +} + +/// flowgraph.c add_checks_for_loads_of_uninitialized_variables uses uint64_t masks. 
+const LOCAL_UNSAFE_MASK_BITS: usize = 64; + +/// flowgraph.c MAX_COPY_SIZE +const MAX_COPY_SIZE: usize = 4; + +/// flowgraph.c is_jump +fn is_jump(instr: &InstructionInfo) -> bool { + instr.instr.has_jump() +} + +/// flowgraph.c is_block_push +fn is_block_push(instr: &InstructionInfo) -> bool { + instr.instr.is_block_push() +} + +/// flowgraph.c basicblock_returns +#[cfg(test)] +fn basicblock_returns(block: &Block) -> bool { + let last = basicblock_last_instr(block); + if let Some(last) = last { + matches!(last.instr.real(), Some(Instruction::ReturnValue)) + } else { false } +} - incoming_origins[target.idx()].iter().all(|&origin| { - origin == source || empty_chain_contains(blocks, chain_start, target, origin) - }) +/// flowgraph.c basicblock_exits_scope +fn basicblock_exits_scope(block: &Block) -> bool { + let last = basicblock_last_instr(block); + last.is_some_and(|last| last.instr.is_scope_exit()) } -fn has_unique_jump_origin( - blocks: &[Block], - reachable: &[bool], - incoming_origins: &[Vec], - source: BlockIdx, - target: BlockIdx, -) -> bool { - if source == BlockIdx::NULL - || target == BlockIdx::NULL - || !reachable[source.idx()] - || !blocks[source.idx()] - .instructions - .last() - .is_some_and(|instr| is_jump_instruction(instr) && instr.target != BlockIdx::NULL) - { - return false; +/// flowgraph.c is_exit_or_eval_check_without_lineno +fn is_exit_or_eval_check_without_lineno(block: &Block) -> bool { + if basicblock_exits_scope(block) || basicblock_has_eval_break(block) { + basicblock_has_no_lineno(block) + } else { + false } - - incoming_origins[target.idx()].iter().all(|&origin| { - origin == source - || (blocks[origin.idx()].instructions.is_empty() - && next_nonempty_block(blocks, blocks[origin.idx()].next) == target) - }) } -fn comes_before(blocks: &[Block], first: BlockIdx, second: BlockIdx) -> bool { - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - if current == first { +/// flowgraph.c basicblock_has_eval_break +fn 
basicblock_has_eval_break(block: &Block) -> bool { + let mut i = 0; + while i < block.instruction_used { + if block.instructions[i].instr.has_eval_break() { return true; } - if current == second { + i += 1; + } + false +} + +/// flowgraph.c basicblock_has_no_lineno +fn basicblock_has_no_lineno(block: &Block) -> bool { + let mut i = 0; + while i < block.instruction_used { + if instruction_lineno(&block.instructions[i]) >= 0 { return false; } - current = blocks[current.idx()].next; + i += 1; } - false + true } -fn duplicate_shared_jump_back_targets(blocks: &mut Vec) { - let predecessors = compute_predecessors(blocks); - let mut clones = Vec::new(); - let mut lineful_clones_before_target = Vec::new(); - let mut block_order = vec![usize::MAX; blocks.len()]; +/// flowgraph.c copy_basicblock +fn copy_basicblock( + blocks: &mut Vec, + block_idx: BlockIdx, +) -> crate::InternalResult { + debug_assert!(bb_no_fallthrough(&blocks[block_idx.idx()])); + let result = blocks_new_block(blocks)?; + basicblock_append_block_instructions(blocks, result, block_idx)?; + Ok(result) +} + +/// flowgraph.c get_max_label +fn get_max_label(blocks: &[Block]) -> i32 { + let mut lbl = -1; let mut current = BlockIdx(0); - let mut pos = 0usize; while current != BlockIdx::NULL { - block_order[current.idx()] = pos; - pos += 1; + let cpython_label = blocks[current.idx()].cpython_label; + lbl = lbl.max(cpython_label.0); current = blocks[current.idx()].next; } + lbl +} - for target in 0..blocks.len() { - let target = BlockIdx(target as u32); - if is_jump_back_only_block(blocks, target) - && instruction_lineno(&blocks[target.idx()].instructions[0]) >= 0 - { - let jump_target = - next_nonempty_block(blocks, blocks[target.idx()].instructions[0].target); - let layout_pred = find_layout_predecessor(blocks, target); - if jump_target != BlockIdx::NULL - && comes_before(blocks, jump_target, target) - && layout_pred != BlockIdx::NULL - && !block_has_fallthrough(&blocks[layout_pred.idx()]) - && 
predecessors[target.idx()] >= 2 +fn duplicate_exits_without_lineno(blocks: &mut Vec) -> crate::InternalResult<()> { + let mut next_lbl = get_max_label(blocks) + 1; + + let entryblock = BlockIdx(0); + let mut b = entryblock; + while b != BlockIdx::NULL { + let Some(last) = basicblock_last_instr(&blocks[b.idx()]).copied() else { + b = blocks[b.idx()].next; + continue; + }; + if is_jump(&last) { + debug_assert!(last.target != BlockIdx::NULL); + let target = next_nonempty_block(blocks, last.target); + debug_assert!(target != BlockIdx::NULL); + if is_exit_or_eval_check_without_lineno(&blocks[target.idx()]) + && blocks[target.idx()].predecessors > 1 { - let target_location = blocks[target.idx()].instructions[0].location; - let target_end_location = blocks[target.idx()].instructions[0].end_location; - let target_follows_forward_jump = blocks[layout_pred.idx()] - .instructions - .last() - .is_some_and(|info| { - matches!(info.instr.real(), Some(Instruction::JumpForward { .. })) - }); - for block_idx in 0..blocks.len() { - let block_idx = BlockIdx(block_idx as u32); - for (instr_idx, info) in blocks[block_idx.idx()].instructions.iter().enumerate() - { - if !is_conditional_jump(&info.instr) - || info.target == BlockIdx::NULL - || next_nonempty_block(blocks, info.target) != target - || block_order[block_idx.idx()] >= block_order[target.idx()] - { - continue; - } - let jump_lineno = instruction_lineno(info); - if target_follows_forward_jump && target_location.line >= info.location.line - { - continue; - } - if jump_lineno >= 0 - && (info.location != target_location - || info.end_location != target_end_location) - { - lineful_clones_before_target.push((target, block_idx, instr_idx)); - } - } - } - } + let new_target = copy_basicblock(blocks, target)?; + instr_set_location( + &mut blocks[new_target.idx()].instructions[0], + instr_location(&last), + ); + let last_mut = basicblock_last_instr_mut(&mut blocks[b.idx()]).unwrap(); + last_mut.target = new_target; + 
blocks[target.idx()].predecessors -= 1; + blocks[new_target.idx()].predecessors = 1; + blocks[new_target.idx()].next = blocks[target.idx()].next; + blocks[new_target.idx()].cpython_label = InstructionSequenceLabel(next_lbl); + next_lbl += 1; + blocks[target.idx()].next = new_target; + } + } + b = blocks[b.idx()].next; + } + + b = entryblock; + while b != BlockIdx::NULL { + let next = blocks[b.idx()].next; + if bb_has_fallthrough(&blocks[b.idx()]) + && next != BlockIdx::NULL + && blocks[b.idx()].instruction_used != 0 + && is_exit_or_eval_check_without_lineno(&blocks[next.idx()]) + { + let last = *basicblock_last_instr(&blocks[b.idx()]).expect("block has instructions"); + instr_set_location( + &mut blocks[next.idx()].instructions[0], + instr_location(&last), + ); } + b = blocks[b.idx()].next; + } + Ok(()) +} - let Some(jump_target) = shared_jump_back_target(&blocks[target.idx()]) else { +fn propagate_line_numbers(blocks: &mut [Block]) { + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + let idx = current.idx(); + let Some(last) = basicblock_last_instr(&blocks[idx]).copied() else { + current = blocks[idx].next; continue; }; - if blocks[target.idx()] - .instructions - .iter() - .any(|info| matches!(info.instr.real(), Some(Instruction::PopExcept))) - { - continue; - } - let jump_target = next_nonempty_block(blocks, jump_target); - if jump_target == BlockIdx::NULL || !comes_before(blocks, jump_target, target) { - continue; - } - if !has_non_exception_loop_backedge_to(blocks, target, jump_target) { - continue; + let mut prev_location = no_instruction_location(); + for i in 0..blocks[idx].instruction_used { + if instruction_is_no_location(&blocks[idx].instructions[i]) { + instr_set_location(&mut blocks[idx].instructions[i], prev_location); + } else { + prev_location = instr_location(&blocks[idx].instructions[i]); + } } - let layout_pred = find_layout_predecessor(blocks, target); - if layout_pred != BlockIdx::NULL - && 
(!block_has_fallthrough(&blocks[layout_pred.idx()]) - || next_nonempty_block(blocks, blocks[layout_pred.idx()].next) != target) - && predecessors[target.idx()] >= 2 - { - let mut jump_predecessors = Vec::new(); - for block_idx in 0..blocks.len() { - let block_idx = BlockIdx(block_idx as u32); - for (instr_idx, info) in blocks[block_idx.idx()].instructions.iter().enumerate() { - if !is_jump_instruction(info) || info.target == BlockIdx::NULL { - continue; - } - if next_nonempty_block(blocks, info.target) == target { - jump_predecessors.push((block_idx, instr_idx)); - } - } - } - if jump_predecessors.len() >= 2 - && jump_predecessors.iter().all(|(block_idx, instr_idx)| { - is_conditional_jump(&blocks[block_idx.idx()].instructions[*instr_idx].instr) - && block_order[block_idx.idx()] < block_order[target.idx()] - }) - && let Some((keep_block, keep_instr)) = jump_predecessors - .iter() - .max_by_key(|(block_idx, _)| block_order[block_idx.idx()]) - .copied() + let next = blocks[idx].next; + if bb_has_fallthrough(&blocks[idx]) { + debug_assert!(next != BlockIdx::NULL); + if next != BlockIdx::NULL + && blocks[next.idx()].predecessors == 1 + && blocks[next.idx()].instruction_used != 0 + && instruction_is_no_location(&blocks[next.idx()].instructions[0]) { - for (block_idx, instr_idx) in jump_predecessors { - if block_idx == keep_block && instr_idx == keep_instr { - continue; - } - clones.push((target, block_idx, instr_idx)); - } - continue; + instr_set_location(&mut blocks[next.idx()].instructions[0], prev_location); } } - if layout_pred == BlockIdx::NULL - || !block_has_fallthrough(&blocks[layout_pred.idx()]) - || next_nonempty_block(blocks, blocks[layout_pred.idx()].next) != target - || predecessors[target.idx()] < 2 - { - continue; - } - - for block_idx in 0..blocks.len() { - let block_idx = BlockIdx(block_idx as u32); - if block_idx == target || block_idx == layout_pred { - continue; - } - - for (instr_idx, info) in blocks[block_idx.idx()].instructions.iter().enumerate() 
{ - if !is_jump_instruction(info) || info.target == BlockIdx::NULL { - continue; - } - if next_nonempty_block(blocks, info.target) != target { - continue; + if is_jump(&last) { + let target = last.target; + debug_assert!(target != BlockIdx::NULL); + if blocks[target.idx()].predecessors == 1 { + let instr = basicblock_raw_first_instr_mut(&mut blocks[target.idx()]); + if instruction_is_no_location(instr) { + instr_set_location(instr, prev_location); } - clones.push((target, block_idx, instr_idx)); } } + current = blocks[current.idx()].next; } +} - lineful_clones_before_target.sort_by_key(|(target, block_idx, _)| { - ( - block_order[target.idx()], - usize::MAX - block_order[block_idx.idx()], - ) - }); - for (target, block_idx, instr_idx) in lineful_clones_before_target { - if next_nonempty_block( - blocks, - blocks[block_idx.idx()].instructions[instr_idx].target, - ) != target - { - continue; - } - let jump = blocks[block_idx.idx()].instructions[instr_idx]; - let layout_pred = find_layout_predecessor(blocks, target); - if layout_pred == BlockIdx::NULL { - continue; - } +fn resolve_line_numbers( + blocks: &mut Vec, + _firstlineno: OneIndexed, +) -> crate::InternalResult<()> { + duplicate_exits_without_lineno(blocks)?; + propagate_line_numbers(blocks); + Ok(()) +} - let mut cloned = blocks[target.idx()].clone(); - if let Some(first) = cloned.instructions.first_mut() { - overwrite_location(first, jump.location, jump.end_location); - } - let new_idx = BlockIdx(blocks.len() as u32); - cloned.next = target; - blocks.push(cloned); - blocks[layout_pred.idx()].next = new_idx; - blocks[block_idx.idx()].instructions[instr_idx].target = new_idx; - } - - for (target, block_idx, instr_idx) in clones.into_iter().rev() { - let jump = blocks[block_idx.idx()].instructions[instr_idx]; - let mut cloned = blocks[target.idx()].clone(); - if let Some(first) = cloned.instructions.first_mut() { - overwrite_location(first, jump.location, jump.end_location); - } +/// flowgraph.c 
make_except_stack +#[allow(clippy::unnecessary_wraps)] +fn make_except_stack() -> crate::InternalResult { + let handlers = [BlockIdx::NULL; CO_MAXBLOCKS + 2]; + debug_assert_eq!(handlers[0], BlockIdx::NULL); + Ok(CfgExceptStack { handlers, depth: 0 }) +} - let new_idx = BlockIdx(blocks.len() as u32); - let old_next = blocks[target.idx()].next; - cloned.next = old_next; - blocks.push(cloned); - blocks[target.idx()].next = new_idx; - blocks[block_idx.idx()].instructions[instr_idx].target = new_idx; - } +/// flowgraph.c copy_except_stack +#[allow(clippy::unnecessary_wraps)] +fn copy_except_stack(stack: &CfgExceptStack) -> crate::InternalResult { + debug_assert!(stack.depth <= CO_MAXBLOCKS + 1); + Ok(CfgExceptStack { + handlers: stack.handlers, + depth: stack.depth, + }) } -fn duplicate_fallthrough_jump_back_targets(blocks: &mut Vec) { - fn block_has_real_fallthrough_body(block: &Block) -> bool { - block.instructions.iter().any(|info| { - !matches!( - info.instr.real(), - Some(Instruction::Nop | Instruction::NotTaken | Instruction::PopTop) - ) - }) +/// flowgraph.c except_stack_top +fn except_stack_top(stack: &CfgExceptStack, blocks: &[Block]) -> Option { + debug_assert!(stack.depth <= CO_MAXBLOCKS + 1); + let handler_block = stack.handlers[stack.depth]; + if handler_block == BlockIdx::NULL { + return None; } + Some(ExceptHandlerInfo { + handler_block, + preserve_lasti: blocks[handler_block.idx()].preserve_lasti, + }) +} - let predecessors = compute_predecessors(blocks); - let mut clones = Vec::new(); +/// flowgraph.c push_except_block +fn push_except_block( + stack: &mut CfgExceptStack, + setup: InstructionInfo, + blocks: &mut [Block], +) -> Option { + debug_assert!(is_block_push(&setup)); + let instr = setup.instr; + let target = setup.target; + debug_assert!(target != BlockIdx::NULL); + if matches!( + instr.pseudo(), + Some(PseudoInstruction::SetupWith { .. } | PseudoInstruction::SetupCleanup { .. 
}) + ) { + blocks[target.idx()].preserve_lasti = true; + } + debug_assert!(stack.depth <= CO_MAXBLOCKS); + stack.depth += 1; + stack.handlers[stack.depth] = target; + debug_assert!(stack.depth <= CO_MAXBLOCKS + 1); + except_stack_top(stack, blocks) +} - let mut layout_pred = BlockIdx(0); - while layout_pred != BlockIdx::NULL { - if !block_has_fallthrough(&blocks[layout_pred.idx()]) - || !block_has_real_fallthrough_body(&blocks[layout_pred.idx()]) - { - layout_pred = blocks[layout_pred.idx()].next; - continue; - } +/// flowgraph.c pop_except_block +fn pop_except_block(stack: &mut CfgExceptStack, blocks: &[Block]) -> Option { + debug_assert!(stack.depth > 0); + stack.depth -= 1; + debug_assert!(stack.depth <= CO_MAXBLOCKS); + except_stack_top(stack, blocks) +} - let target = next_nonempty_block(blocks, blocks[layout_pred.idx()].next); - if target == BlockIdx::NULL - || predecessors[target.idx()] < 2 - || !is_jump_back_only_block(blocks, target) - { - layout_pred = blocks[layout_pred.idx()].next; - continue; - } - if blocks[target.idx()].instructions[0] - .lineno_override - .is_some_and(|lineno| lineno >= 0) - { - layout_pred = blocks[layout_pred.idx()].next; - continue; - } - if !block_has_no_lineno(&blocks[target.idx()]) - && trailing_conditional_jump_index(&blocks[layout_pred.idx()]).is_some() - { - layout_pred = blocks[layout_pred.idx()].next; - continue; - } - let jump_target = next_nonempty_block(blocks, blocks[target.idx()].instructions[0].target); - if jump_target == BlockIdx::NULL - || (!has_non_exception_loop_backedge_to(blocks, target, jump_target) - && !block_has_non_exception_loop_backedge_to(blocks, target, jump_target)) - { - layout_pred = blocks[layout_pred.idx()].next; - continue; - } +pub(crate) fn label_exception_targets(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut todo = make_cfg_traversal_stack(blocks)?; - let has_non_layout_jump_predecessor = blocks.iter().enumerate().any(|(idx, block)| { - let block_idx = BlockIdx(idx as u32); 
- block_idx != layout_pred - && block_idx != target - && block.instructions.iter().any(|info| { - is_jump_instruction(info) - && info.target != BlockIdx::NULL - && next_nonempty_block(blocks, info.target) == target - }) - }); - if has_non_layout_jump_predecessor { - clones.push((layout_pred, target)); - } + todo.push(BlockIdx(0)); + blocks[0].visited = true; + blocks[0].except_stack = Some(make_except_stack()?); - layout_pred = blocks[layout_pred.idx()].next; - } + while let Some(block_idx) = todo.pop() { + let bi = block_idx.idx(); + debug_assert!(blocks[bi].visited); + let mut stack = Some( + blocks[bi] + .except_stack + .take() + .expect("visited exception block has an except stack"), + ); + let mut handler = except_stack_top(stack.as_ref().expect("active exception stack"), blocks); + let mut last_yield_except_depth: i32 = -1; + let mut stack_transferred = false; - for (layout_pred, target) in clones.into_iter().rev() { - if next_nonempty_block(blocks, blocks[layout_pred.idx()].next) != target { - continue; + let instr_count = blocks[bi].instruction_used; + for i in 0..instr_count { + let info = blocks[bi].instructions[i]; + let instr = info.instr; + let target = info.target; + let arg = info.arg; + + if is_block_push(&info) { + debug_assert!(target != BlockIdx::NULL); + if !blocks[target.idx()].visited { + blocks[target.idx()].except_stack = Some(copy_except_stack( + stack.as_ref().expect("active exception stack"), + )?); + todo.push(target); + blocks[target.idx()].visited = true; + } + handler = push_except_block( + stack.as_mut().expect("active exception stack"), + info, + blocks, + ); + } else if instr.is_pop_block() { + handler = pop_except_block(stack.as_mut().expect("active exception stack"), blocks); + set_to_nop(&mut blocks[bi].instructions[i]); + } else if is_jump(&blocks[bi].instructions[i]) { + blocks[bi].instructions[i].except_handler = handler; + debug_assert_eq!(i, instr_count - 1); + + // CPython label_exception_targets(): copy the except stack + 
// when this block can also fall through, otherwise transfer it + // to the jump target. + debug_assert!(target != BlockIdx::NULL); + if !blocks[target.idx()].visited { + if bb_has_fallthrough(&blocks[bi]) { + blocks[target.idx()].except_stack = Some(copy_except_stack( + stack.as_ref().expect("active exception stack"), + )?); + } else { + blocks[target.idx()].except_stack = stack.take(); + stack_transferred = true; + todo.push(target); + blocks[target.idx()].visited = true; + break; + } + todo.push(target); + blocks[target.idx()].visited = true; + } + } else if matches!(instr.real(), Some(Instruction::YieldValue { .. })) { + blocks[bi].instructions[i].except_handler = handler; + last_yield_except_depth = + stack.as_ref().expect("active exception stack").depth as i32; + } else if let Some(Instruction::Resume { context: _ }) = instr.real() { + blocks[bi].instructions[i].except_handler = handler; + let resume_arg = u32::from(arg); + if resume_arg != u32::from(oparg::ResumeLocation::AtFuncStart) { + debug_assert!(last_yield_except_depth >= 0); + if last_yield_except_depth == 1 { + blocks[bi].instructions[i].arg = + OpArg::new(resume_arg | oparg::ResumeContext::DEPTH1_MASK); + } + last_yield_except_depth = -1; + } + } else { + blocks[bi].instructions[i].except_handler = handler; + } } - let Some(last) = blocks[layout_pred.idx()].instructions.last().copied() else { - continue; - }; - let new_idx = BlockIdx(blocks.len() as u32); - let mut cloned = blocks[target.idx()].clone(); - if let Some(first) = cloned.instructions.first_mut() { - overwrite_location(first, last.location, last.end_location); + let next = blocks[bi].next; + if !stack_transferred && bb_has_fallthrough(&blocks[bi]) { + debug_assert!(next != BlockIdx::NULL); + if next != BlockIdx::NULL && !blocks[next.idx()].visited { + blocks[next.idx()].except_stack = stack.take(); + todo.push(next); + blocks[next.idx()].visited = true; + } + } + } + #[cfg(debug_assertions)] + { + let mut block_idx = BlockIdx(0); + while 
block_idx != BlockIdx::NULL { + let block = &blocks[block_idx.idx()]; + debug_assert!(block.except_stack.is_none()); + block_idx = block.next; } - cloned.next = blocks[layout_pred.idx()].next; - blocks.push(cloned); - blocks[layout_pred.idx()].next = new_idx; } + Ok(()) } -/// Duplicate `LOAD_CONST None + RETURN_VALUE` for blocks that fall through -/// to the final return block. -fn duplicate_end_returns(blocks: &mut Vec, metadata: &CodeUnitMetadata) { - // Walk the block chain and keep the last non-cold non-empty block. - // After cold exception handlers are pushed to the end, the mainline - // return epilogue can sit before trailing cold blocks. - let mut last_block = BlockIdx::NULL; - let mut last_nonempty_block = BlockIdx::NULL; - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - if !blocks[current.idx()].instructions.is_empty() { - last_nonempty_block = current; - if !blocks[current.idx()].cold { - last_block = current; +/// Convert remaining pseudo ops to real instructions or NOP. +/// flowgraph.c convert_pseudo_ops +pub(crate) fn convert_pseudo_ops(blocks: &mut [Block]) -> crate::InternalResult<()> { + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next = blocks[block_idx.idx()].next; + let block = &mut blocks[block_idx.idx()]; + for i in 0..block.instruction_used { + let info = &mut block.instructions[i]; + if is_block_push(info) { + set_to_nop(info); + } else if matches!( + info.instr.pseudo(), + Some(PseudoInstruction::LoadClosure { .. }) + ) { + debug_assert!(is_pseudo_target( + PseudoOpcode::LoadClosure, + Opcode::LoadFast + )); + info.instr = Opcode::LoadFast.into(); + } else if matches!( + info.instr.pseudo(), + Some(PseudoInstruction::StoreFastMaybeNull { .. 
}) + ) { + debug_assert!(is_pseudo_target( + PseudoOpcode::StoreFastMaybeNull, + Opcode::StoreFast + )); + info.instr = Instruction::StoreFast { + var_num: Arg::marker(), + } + .into(); } } - current = blocks[current.idx()].next; - } - if last_block == BlockIdx::NULL { - last_block = last_nonempty_block; - } - if last_block == BlockIdx::NULL { - return; + block_idx = next; } + // CPython flowgraph.c::convert_pseudo_ops() finishes by calling + // remove_redundant_nops_and_jumps(). + remove_redundant_nops_and_jumps(blocks) +} - let last_insts = &blocks[last_block.idx()].instructions; - // Only apply when the last block is EXACTLY a return-None epilogue. - let is_return_block = last_insts.len() == 2 - && matches!( - last_insts[0].instr, - AnyInstruction::Real(Instruction::LoadConst { .. }) - ) - && is_load_const_none(&last_insts[0], metadata) - && last_insts[0].no_location_exit - && last_insts[1].no_location_exit - && matches!( - last_insts[1].instr, - AnyInstruction::Real(Instruction::ReturnValue) - ); - if !is_return_block { - return; +/// flowgraph.c build_cellfixedoffsets +#[allow(clippy::needless_range_loop)] +pub(crate) fn build_cellfixedoffsets( + metadata: &CodeUnitMetadata, +) -> crate::InternalResult> { + let nlocals = metadata.varnames.len(); + let ncellvars = metadata.cellvars.len(); + let nfreevars = metadata.freevars.len(); + let noffsets = ncellvars + nfreevars; + let mut fixed = Vec::new(); + vec_try_reserve_exact(&mut fixed, noffsets)?; + fixed.resize(noffsets, 0); + for i in 0..noffsets { + fixed[i] = (nlocals + i) as i32; + } + for oldindex in 0..ncellvars { + let varname = metadata + .cellvars + .get_index(oldindex) + .expect("cellvar index is in range"); + if let Some(varindex) = metadata.varnames.get_index_of(varname) { + let argoffset = varindex as i32; + fixed[oldindex] = argoffset; + } } + Ok(fixed) +} - // Get the return instructions to clone - let return_insts: Vec = last_insts[last_insts.len() - 2..].to_vec(); - let predecessors = 
compute_predecessors(blocks); +/// flowgraph.c fix_cell_offsets +#[allow(clippy::needless_range_loop)] +pub(crate) fn fix_cell_offsets( + metadata: &CodeUnitMetadata, + blocks: &mut [Block], + cellfixedoffsets: &mut [i32], +) -> usize { + let nlocals = metadata.varnames.len(); + let ncellvars = metadata.cellvars.len(); + let nfreevars = metadata.freevars.len(); + let noffsets = ncellvars + nfreevars; + debug_assert_eq!(cellfixedoffsets.len(), noffsets); - // Find non-cold blocks that reach the last return block either by - // fallthrough or as an unconditional jump target that should get its own - // cloned epilogue. - let mut fallthrough_blocks_to_fix = Vec::new(); - let mut jump_targets_to_fix = Vec::new(); - current = BlockIdx(0); - while current != BlockIdx::NULL { - let block = &blocks[current.idx()]; - let next = next_nonempty_block(blocks, block.next); - if current != last_block && !block.cold { - let last_ins = block.instructions.last(); - let has_fallthrough = last_ins - .is_none_or(|ins| !ins.instr.is_scope_exit() && !ins.instr.is_unconditional_jump()); - // Don't duplicate if block already ends with the same return pattern - let already_has_return = block.instructions.len() >= 2 && { - let n = block.instructions.len(); - matches!( - block.instructions[n - 2].instr, - AnyInstruction::Real(Instruction::LoadConst { .. 
}) - ) && matches!( - block.instructions[n - 1].instr, - AnyInstruction::Real(Instruction::ReturnValue) + let mut numdropped = 0usize; + for i in 0..noffsets { + if cellfixedoffsets[i] == (i + nlocals) as i32 { + cellfixedoffsets[i] -= numdropped as i32; + } else { + numdropped += 1; + } + } + + let mut block_idx = BlockIdx(0); + while block_idx != BlockIdx::NULL { + let next = blocks[block_idx.idx()].next; + let block = &mut blocks[block_idx.idx()]; + for i in 0..block.instruction_used { + let inst = &mut block.instructions[i]; + debug_assert!( + !matches!(inst.instr.real(), Some(Instruction::ExtendedArg)), + "fix_cell_offsets is called before extended args are generated" + ); + let oldoffset = u32::from(inst.arg) as i32; + match inst.instr { + AnyInstruction::Real( + Instruction::MakeCell { .. } + | Instruction::LoadDeref { .. } + | Instruction::StoreDeref { .. } + | Instruction::DeleteDeref { .. } + | Instruction::LoadFromDictOrDeref { .. }, ) - }; - if !block.except_handler - && next == last_block - && has_fallthrough - && trailing_conditional_jump_index(block).is_none() - && !already_has_return - { - fallthrough_blocks_to_fix.push(current); - } - let jump_idx = trailing_conditional_jump_index(block).or_else(|| { - block.instructions.last().and_then(|last| { - (last.instr.is_unconditional_jump() && last.target != BlockIdx::NULL) - .then_some(block.instructions.len() - 1) - }) - }); - if let Some(jump_idx) = jump_idx { - let jump = &block.instructions[jump_idx]; - if jump.target != BlockIdx::NULL - && next_nonempty_block(blocks, jump.target) == last_block - && (is_conditional_jump(&jump.instr) || predecessors[last_block.idx()] > 1) - { - jump_targets_to_fix.push((current, jump_idx)); + | AnyInstruction::Pseudo(PseudoInstruction::LoadClosure { .. 
}) => { + debug_assert!(oldoffset >= 0); + debug_assert!(oldoffset < noffsets as i32); + let fixed_offset = cellfixedoffsets[oldoffset as usize]; + debug_assert!(fixed_offset >= 0); + inst.arg = OpArg::new(fixed_offset as u32); } + _ => {} } } - current = blocks[current.idx()].next; + block_idx = next; } + numdropped +} - // Duplicate the return instructions at the end of fall-through blocks - for block_idx in fallthrough_blocks_to_fix { - let propagated_location = blocks[block_idx.idx()] - .instructions - .last() - .map(|instr| (instr.location, instr.end_location)); - let mut cloned_return = return_insts.clone(); - if !instruction_has_lineno(&cloned_return[0]) - && let Some((location, end_location)) = propagated_location - { - for instr in &mut cloned_return { - overwrite_location(instr, location, end_location); - } +#[cfg(test)] +mod tests { + use super::*; + + fn test_location(line: u32) -> SourceLocation { + SourceLocation { + line: OneIndexed::new(line as usize).expect("valid line number"), + character_offset: OneIndexed::MIN, } - blocks[block_idx.idx()].instructions.extend(cloned_return); } - // Clone the final return block for jump predecessors so their target layout - // matches CPython's duplicated exit blocks. 
- for (block_idx, instr_idx) in jump_targets_to_fix.into_iter().rev() { - let jump = blocks[block_idx.idx()].instructions[instr_idx]; - let mut cloned_return = return_insts.clone(); - if let Some(first) = cloned_return.first_mut() { - overwrite_location(first, jump.location, jump.end_location); + fn test_instr(instr: Instruction, line: u32) -> InstructionInfo { + InstructionInfo { + instr: instr.into(), + arg: OpArg::new(0), + target: BlockIdx::NULL, + location: test_location(line), + end_location: test_location(line), + except_handler: None, + lineno_override: None, } - let new_idx = BlockIdx(blocks.len() as u32); - let is_conditional = is_conditional_jump(&jump.instr); - let new_block = Block { - cold: blocks[last_block.idx()].cold, - except_handler: blocks[last_block.idx()].except_handler, - disable_load_fast_borrow: blocks[last_block.idx()].disable_load_fast_borrow, - instructions: cloned_return, - next: if is_conditional { - last_block - } else { - blocks[block_idx.idx()].next + } + + fn test_jump(target: BlockIdx, line: u32) -> InstructionInfo { + let mut instr = test_instr(Instruction::Nop, line); + instr.instr = PseudoOpcode::Jump.into(); + instr.target = target; + instr + } + + fn test_cond_jump(target: BlockIdx, line: u32) -> InstructionInfo { + let mut instr = test_instr(Instruction::Nop, line); + instr.instr = PseudoOpcode::JumpIfFalse.into(); + instr.target = target; + instr + } + + fn test_block_push(block: &mut Block, info: InstructionInfo) { + let off = basicblock_next_instr(block).expect("test block instruction slot"); + block.instructions[off] = info; + } + + fn test_code_info(block: Block) -> CodeInfo { + CodeInfo { + flags: CodeFlags::empty(), + source_path: "source_path".to_owned(), + private: None, + blocks: vec![block], + current_block: BlockIdx::new(0), + instr_sequence: instruction_sequence_new(), + instr_sequence_label_map: InstructionSequenceLabelMap::new(), + annotations_instr_sequence: None, + metadata: CodeUnitMetadata { + name: 
"".to_owned(), + qualname: Some("".to_owned()), + consts: Default::default(), + names: IndexSet::default(), + varnames: IndexSet::default(), + cellvars: IndexSet::default(), + freevars: IndexSet::default(), + fast_hidden: IndexMap::default(), + fast_hidden_final: IndexSet::default(), + argcount: 0, + posonlyargcount: 0, + kwonlyargcount: 0, + firstlineno: OneIndexed::MIN, }, - ..Block::default() - }; - blocks.push(new_block); - if is_conditional { - let layout_pred = find_layout_predecessor(blocks, last_block); - if layout_pred != BlockIdx::NULL { - blocks[layout_pred.idx()].next = new_idx; - } - } else { - blocks[block_idx.idx()].next = new_idx; + static_attributes: None, + in_inlined_comp: false, + fblock: Vec::new(), + symbol_table_index: 0, + nparams: 0, + in_conditional_block: 0, + next_conditional_annotation_index: 0, } - blocks[block_idx.idx()].instructions[instr_idx].target = new_idx; } -} -fn inline_small_fast_return_blocks(blocks: &mut [Block]) { - fn block_is_small_fast_return(block: &Block) -> bool { - if block.instructions.len() > 3 { - return false; + #[test] + fn get_stack_effects_rejects_cpython_deopt_opcodes() { + match get_stack_effects(Instruction::BinaryOpAddInt.into(), OpArg::new(0), 0) { + Err(InternalError::InvalidStackEffect) => {} + Err(err) => panic!("unexpected stack-effect error: {err}"), + Ok(_) => panic!("CPython get_stack_effects rejects specialized deopt opcodes"), } - let real: Vec<_> = block - .instructions - .iter() - .filter(|info| !matches!(info.instr.real(), Some(Instruction::Nop))) - .collect(); - matches!( - real.as_slice(), - [load, ret] - if matches!( - load.instr.real(), - Some(Instruction::LoadFast { .. } | Instruction::LoadFastBorrow { .. 
}) - ) && matches!(ret.instr.real(), Some(Instruction::ReturnValue)) - ) } - loop { - let mut changed = false; - let mut current = BlockIdx(0); - while current != BlockIdx::NULL { - let next = blocks[current.idx()].next; - let Some(last) = blocks[current.idx()].instructions.last().copied() else { - current = next; - continue; - }; - if !last.instr.is_unconditional_jump() || last.target == BlockIdx::NULL { - current = next; - continue; - } - let target = next_nonempty_block(blocks, last.target); - if target == BlockIdx::NULL || !block_is_small_fast_return(&blocks[target.idx()]) { - current = next; - continue; - } + #[test] + fn instruction_sequence_label_shadow_preserves_cpython_offset_aliases() { + let mut seq = instruction_sequence_new(); + let mut labels = InstructionSequenceLabelMap::new(); + instruction_sequence_label_map_push_unmapped_label(&mut labels, &mut seq).unwrap(); + instruction_sequence_label_map_push_unmapped_label(&mut labels, &mut seq).unwrap(); + assert_eq!( + labels.cpython_block_by_label.len(), + INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE + ); - if jump_thread_kind(last.instr) == Some(JumpThreadKind::NoInterrupt) - || instruction_has_lineno(&last) - { - let last_instr = blocks[current.idx()].instructions.last_mut().unwrap(); - let lineno_override = last_instr.lineno_override; - set_to_nop(last_instr); - last_instr.lineno_override = lineno_override; - } else { - blocks[current.idx()].instructions.pop(); - } - let cloned = blocks[target.idx()].instructions.clone(); - blocks[current.idx()].instructions.extend(cloned); - changed = true; - current = next; - } - if !changed { - break; - } + let first = BlockIdx::new(1); + let second = BlockIdx::new(2); + assert_ne!( + instruction_sequence_label_map_label_for_block(&labels, first), + instruction_sequence_label_map_label_for_block(&labels, second) + ); + + // CPython `_PyInstructionSequence_UseLabel()` can map consecutive + // labels to the same instruction offset. 
The codegen CFG shadow must + // resolve the later block label to the block owning that shared offset. + instruction_sequence_label_map_use_label_at_block(&mut labels, &mut seq, second, first) + .unwrap(); + assert_eq!( + instruction_sequence_label_map_resolve_label(&labels, first), + first + ); + assert_eq!( + instruction_sequence_label_map_resolve_label(&labels, second), + first + ); } -} -fn is_fast_store_load_return_block(block: &Block) -> bool { - let [store, load, ret] = block.instructions.as_slice() else { - return false; - }; - let stored = match store.instr.real() { - Some(Instruction::StoreFast { var_num }) => usize::from(var_num.get(store.arg)), - _ => return false, - }; - let loaded = match load.instr.real() { - Some(Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num }) => { - usize::from(var_num.get(load.arg)) - } - _ => return false, - }; - stored == loaded && matches!(ret.instr.real(), Some(Instruction::ReturnValue)) -} + #[test] + fn except_stack_tracks_cpython_depth_and_handler_slots() { + let mut stack = make_except_stack().unwrap(); + assert_eq!(stack.depth, 0); + assert_eq!(stack.handlers.len(), CO_MAXBLOCKS + 2); + assert_eq!(stack.handlers[0], BlockIdx::NULL); -fn inline_unprotected_tuple_genexpr_assignment_return_blocks(blocks: &mut [Block]) { - for block_idx in 0..blocks.len() { - if block_is_protected(&blocks[block_idx]) { - continue; - } + let mut blocks = vec![Block::default(), Block::default()]; + assert!(except_stack_top(&stack, &blocks).is_none()); - let Some(jump_idx) = blocks[block_idx].instructions.len().checked_sub(1) else { - continue; + let setup = InstructionInfo { + instr: PseudoInstruction::SetupWith { + delta: Arg::marker(), + } + .into(), + arg: OpArg::new(0), + target: BlockIdx::new(1), + location: SourceLocation::default(), + end_location: SourceLocation::default(), + except_handler: None, + lineno_override: None, }; - let jump = blocks[block_idx].instructions[jump_idx]; - if 
!jump.instr.is_unconditional_jump() || jump.target == BlockIdx::NULL { - continue; - } - - let previous_is_list_to_tuple = blocks[block_idx].instructions[..jump_idx] - .iter() - .rev() - .find_map(|info| match info.instr.real() { - Some(Instruction::CallIntrinsic1 { func }) => { - Some(func.get(info.arg) == IntrinsicFunction1::ListToTuple) - } - Some(Instruction::Nop | Instruction::NotTaken) => None, - Some(_) => Some(false), - None => None, - }) - .unwrap_or(false); - if !previous_is_list_to_tuple { - continue; - } + let handler = push_except_block(&mut stack, setup, &mut blocks).unwrap(); + assert_eq!(stack.depth, 1); + assert_eq!(stack.handlers[1], BlockIdx::new(1)); + assert_eq!(handler.handler_block, BlockIdx::new(1)); + assert!(handler.preserve_lasti); + assert!(blocks[1].preserve_lasti); - let target = next_nonempty_block(blocks, jump.target); - if target == BlockIdx::NULL || !is_fast_store_load_return_block(&blocks[target.idx()]) { - continue; - } + let copy = copy_except_stack(&stack).unwrap(); + assert_eq!(copy.depth, stack.depth); + assert_eq!(copy.handlers, stack.handlers); - let mut cloned = blocks[target.idx()].instructions.clone(); - if let Some(first) = cloned.first_mut() { - overwrite_location(first, jump.location, jump.end_location); - } - blocks[block_idx].instructions.pop(); - blocks[block_idx].instructions.extend(cloned); + assert!(pop_except_block(&mut stack, &blocks).is_none()); + assert_eq!(stack.depth, 0); } -} -fn inline_with_suppress_return_blocks(blocks: &mut [Block]) { - fn has_with_suppress_prefix(block: &Block, jump_idx: usize) -> bool { - let tail: Vec<_> = block.instructions[..jump_idx] - .iter() - .filter_map(|info| info.instr.real()) - .rev() - .take(5) - .collect(); - matches!( - tail.as_slice(), - [ - Instruction::PopTop, - Instruction::PopTop, - Instruction::PopTop, - Instruction::PopExcept, - Instruction::PopTop, - ] + #[test] + fn ref_stack_tracks_cpython_size_and_allocated_refs() { + let mut stack = RefStack { + refs: 
Vec::new(), + size: 0, + capacity: 0, + }; + ref_stack_push(&mut stack, Ref { instr: 7, local: 3 }).unwrap(); + assert_eq!(stack.size, 1); + assert_eq!(stack.capacity, 32); + assert_eq!(stack.refs.len(), 32); + assert_eq!(ref_stack_at(&stack, 0).instr, 7); + assert_eq!(ref_stack_at(&stack, 0).local, 3); + + ref_stack_clear(&mut stack); + assert_eq!(stack.size, 0); + assert_eq!(stack.capacity, 32); + assert_eq!(stack.refs.len(), 32); + + ref_stack_push( + &mut stack, + Ref { + instr: DUMMY_INSTR, + local: NOT_LOCAL, + }, ) + .unwrap(); + assert_eq!(stack.size, 1); + assert_eq!(ref_stack_pop(&mut stack).instr, DUMMY_INSTR); + assert_eq!(stack.size, 0); } - for block_idx in 0..blocks.len() { - let Some(jump_idx) = blocks[block_idx].instructions.len().checked_sub(1) else { - continue; + #[test] + fn cfg_traversal_stack_resets_visited_and_allocates_for_blocks() { + let mut blocks = vec![Block::default(), Block::default()]; + blocks[0].next = BlockIdx::new(1); + blocks[0].visited = true; + blocks[1].visited = true; + + let mut stack = make_cfg_traversal_stack(&mut blocks).unwrap(); + assert!(!blocks[0].visited); + assert!(!blocks[1].visited); + assert!(stack.capacity() >= 2); + assert_eq!(stack.pop(), None); + + stack.push(BlockIdx::new(1)); + stack.push(BlockIdx::new(0)); + assert_eq!(stack.pop(), Some(BlockIdx::new(0))); + assert_eq!(stack.pop(), Some(BlockIdx::new(1))); + assert_eq!(stack.pop(), None); + } + + #[test] + fn instruction_sequence_insert_preserves_cpython_slot_metadata() { + let handler = InstructionSequenceExceptHandlerInfo { + h_label: 7, + start_depth: 3, + preserve_lasti: 1, }; - let jump = blocks[block_idx].instructions[jump_idx]; - if !jump.instr.is_unconditional_jump() || jump.target == BlockIdx::NULL { - continue; - } - if !has_with_suppress_prefix(&blocks[block_idx], jump_idx) { - continue; - } + let mut seq = instruction_sequence_new(); + let entry = instruction_sequence_addop(&mut seq, test_instr(Instruction::Nop, 11)).unwrap(); + 
entry.except_handler = handler; + entry.i_target = 1; + entry.i_offset = 42; + + instruction_sequence_insert_instruction(&mut seq, 0, test_instr(Instruction::PopTop, 12)) + .unwrap(); + + // CPython `_PyInstructionSequence_InsertInstruction()` shifts the + // backing instruction slots, then overwrites only opcode/oparg/loc. + let inserted = &seq.instrs[0]; + assert!(matches!( + inserted.info.instr.real(), + Some(Instruction::PopTop) + )); + assert_eq!(inserted.except_handler.h_label, handler.h_label); + assert_eq!(inserted.except_handler.start_depth, handler.start_depth); + assert_eq!( + inserted.except_handler.preserve_lasti, + handler.preserve_lasti + ); + assert_eq!(inserted.i_target, 1); + assert_eq!(inserted.i_offset, 42); + } - let target = next_nonempty_block(blocks, jump.target); - if target == BlockIdx::NULL || !is_const_return_block(&blocks[target.idx()]) { - continue; + #[test] + fn instruction_sequence_tracks_cpython_c_array_allocation() { + let mut seq = instruction_sequence_new(); + for i in 0..99 { + instruction_sequence_addop(&mut seq, test_instr(Instruction::Nop, 10 + i)).unwrap(); } + assert_eq!(seq.instr_allocation, INITIAL_INSTR_SEQUENCE_SIZE); + assert_eq!(seq.instrs.len(), seq.instr_allocation); + assert_eq!(seq.instr_used, 99); - let mut cloned_return = blocks[target.idx()].instructions.clone(); - for instr in &mut cloned_return { - overwrite_location(instr, jump.location, jump.end_location); - } - blocks[block_idx].instructions.pop(); - blocks[block_idx].instructions.extend(cloned_return); + // CPython calls `_Py_CArray_EnsureCapacity(s_used + 1)`, so the 100th + // instruction expands a 100-slot array to 200 before returning offset 99. 
+ instruction_sequence_addop(&mut seq, test_instr(Instruction::Nop, 109)).unwrap(); + assert_eq!(seq.instr_allocation, INITIAL_INSTR_SEQUENCE_SIZE * 2); + assert_eq!(seq.instrs.len(), seq.instr_allocation); + assert_eq!(seq.instr_used, 100); } -} - -fn is_named_except_cleanup_return_block(block: &Block, metadata: &CodeUnitMetadata) -> bool { - matches!( - block.instructions.as_slice(), - [pop_except, load_none1, store, delete, load_none2, ret] - if matches!(pop_except.instr.real(), Some(Instruction::PopExcept)) - && is_load_const_none(load_none1, metadata) - && matches!( - store.instr.real(), - Some(Instruction::StoreFast { .. } | Instruction::StoreName { .. }) - ) - && matches!( - delete.instr.real(), - Some(Instruction::DeleteFast { .. } | Instruction::DeleteName { .. }) - ) - && is_load_const_none(load_none2, metadata) - && matches!(ret.instr.real(), Some(Instruction::ReturnValue)) - ) -} -fn duplicate_named_except_cleanup_returns(blocks: &mut Vec, metadata: &CodeUnitMetadata) { - let predecessors = compute_predecessors(blocks); - let mut clones = Vec::new(); - - for target in 0..blocks.len() { - let target = BlockIdx(target as u32); - if !is_named_except_cleanup_return_block(&blocks[target.idx()], metadata) { - continue; - } + #[test] + fn instruction_sequence_label_map_tracks_cpython_c_array_allocation() { + let mut seq = instruction_sequence_new(); + instruction_sequence_use_label(&mut seq, InstructionSequenceLabel::from_index(1)).unwrap(); + assert_eq!( + seq.label_map_allocation, + INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE + ); + assert_eq!( + seq.label_map.as_ref().expect("label map allocated").len(), + INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE + ); - let layout_pred = find_layout_predecessor(blocks, target); - if layout_pred == BlockIdx::NULL - || next_nonempty_block(blocks, blocks[layout_pred.idx()].next) != target - { - continue; - } + // CPython passes the label id itself to `_Py_CArray_EnsureCapacity()`. 
+ // Label 10 therefore expands the initial 10-slot map to 20. + instruction_sequence_use_label(&mut seq, InstructionSequenceLabel::from_index(10)).unwrap(); + assert_eq!( + seq.label_map_allocation, + INITIAL_INSTR_SEQUENCE_LABELS_MAP_SIZE * 2 + ); + } - let fallthroughs_into_target = blocks[layout_pred.idx()] - .instructions - .last() - .is_none_or(|ins| !ins.instr.is_scope_exit() && !ins.instr.is_unconditional_jump()); - if !fallthroughs_into_target || predecessors[target.idx()] < 2 { - continue; - } + #[test] + fn basicblock_addop_reuses_cpython_spare_except_handler_slot() { + let handler = ExceptHandlerInfo { + handler_block: BlockIdx::new(7), + preserve_lasti: true, + }; + let mut block = Block::default(); + let mut stale = test_instr(Instruction::Nop, 11); + stale.except_handler = Some(handler); + test_block_push(&mut block, stale); + basicblock_clear(&mut block); - let target_lineno = blocks[target.idx()] - .instructions - .first() - .map_or(-1, instruction_lineno); - let layout_pred_lineno = blocks[layout_pred.idx()] - .instructions - .iter() - .rev() - .find(|info| info.instr.real().is_some()) - .map_or(-1, instruction_lineno); - if target_lineno > 0 && layout_pred_lineno > 0 && target_lineno != layout_pred_lineno { - continue; - } + basicblock_addop(&mut block, test_instr(Instruction::PopTop, 12)) + .expect("basicblock_addop succeeds"); - for block_idx in 0..blocks.len() { - if block_idx == target.idx() { - continue; - } - let Some(instr_idx) = trailing_conditional_jump_index(&blocks[block_idx]) else { - continue; - }; - if next_nonempty_block(blocks, blocks[block_idx].instructions[instr_idx].target) - != target - { - continue; - } - clones.push((BlockIdx(block_idx as u32), instr_idx, target)); - } + // CPython `basicblock_addop()` writes opcode/oparg/target/location into + // the reused `b_instr[b_iused]` slot, but does not clear `i_except`. 
+ assert_eq!(block.instruction_used, 1); + assert_eq!(block.instructions[0].except_handler, Some(handler)); + assert_eq!(block.instructions[0].target, BlockIdx::NULL); } - for (block_idx, instr_idx, target) in clones.into_iter().rev() { - let jump = blocks[block_idx.idx()].instructions[instr_idx]; - let mut cloned = blocks[target.idx()].instructions.clone(); - if let Some(first) = cloned.first_mut() { - overwrite_location(first, jump.location, jump.end_location); + #[test] + fn basicblock_next_instr_tracks_cpython_c_array_allocation() { + let mut block = Block::default(); + for i in 0..15 { + basicblock_addop(&mut block, test_instr(Instruction::PopTop, 10 + i)) + .expect("basicblock_addop succeeds"); } + assert_eq!(block.instruction_allocation, DEFAULT_BLOCK_SIZE); - let new_idx = BlockIdx(blocks.len() as u32); - let next = blocks[target.idx()].next; - blocks.push(Block { - cold: blocks[target.idx()].cold, - except_handler: blocks[target.idx()].except_handler, - disable_load_fast_borrow: blocks[target.idx()].disable_load_fast_borrow, - instructions: cloned, - next, - ..Block::default() - }); - blocks[target.idx()].next = new_idx; - blocks[block_idx.idx()].instructions[instr_idx].target = new_idx; + // CPython calls `_Py_CArray_EnsureCapacity(b_iused + 1)`, so the 16th + // instruction expands a 16-slot array to 32 before returning offset 15. + basicblock_addop(&mut block, test_instr(Instruction::PopTop, 25)) + .expect("basicblock_addop succeeds"); + assert_eq!(block.instruction_allocation, DEFAULT_BLOCK_SIZE * 2); } -} - -fn is_const_return_block(block: &Block) -> bool { - block.instructions.len() == 2 - && matches!( - block.instructions[0].instr.real(), - Some(Instruction::LoadConst { .. 
}) - ) - && matches!( - block.instructions[1].instr.real(), - Some(Instruction::ReturnValue) - ) -} -fn inline_pop_except_return_blocks(blocks: &mut [Block]) { - for block_idx in 0..blocks.len() { - let Some(jump_idx) = blocks[block_idx].instructions.len().checked_sub(1) else { - continue; + #[test] + fn basicblock_insert_instruction_consumes_spare_without_inheriting_except_handler() { + let handler = ExceptHandlerInfo { + handler_block: BlockIdx::new(9), + preserve_lasti: false, }; - let jump = blocks[block_idx].instructions[jump_idx]; - if !jump.instr.is_unconditional_jump() || jump.target == BlockIdx::NULL { - continue; - } + let mut block = Block::default(); + test_block_push(&mut block, test_instr(Instruction::Nop, 21)); + let mut stale = test_instr(Instruction::Nop, 22); + stale.except_handler = Some(handler); + test_block_push(&mut block, stale); + block.instruction_used = 1; - let Some(last_real_before_jump) = blocks[block_idx].instructions[..jump_idx] - .iter() - .rev() - .find_map(|info| info.instr.real()) - else { - continue; + basicblock_insert_instruction(&mut block, 0, test_instr(Instruction::PopTop, 23)) + .expect("basicblock_insert_instruction succeeds"); + + // CPython `basicblock_insert_instruction()` also obtains a slot with + // `basicblock_next_instr()`, then overwrites the inserted position with + // the provided instruction copy, including its `i_except` value. 
+ assert_eq!(block.instruction_used, 2); + assert_eq!(block.instructions[0].except_handler, None); + } + + #[test] + fn basicblock_clear_preserves_cpython_spare_slots() { + let handler = ExceptHandlerInfo { + handler_block: BlockIdx::new(3), + preserve_lasti: true, }; - if !matches!(last_real_before_jump, Instruction::PopExcept) { - continue; - } + let mut block = Block::default(); + let mut stale = test_instr(Instruction::PopTop, 31); + stale.except_handler = Some(handler); + test_block_push(&mut block, stale); - let target = next_nonempty_block(blocks, jump.target); - if target == BlockIdx::NULL || !is_const_return_block(&blocks[target.idx()]) { - continue; - } + basicblock_clear(&mut block); + basicblock_addop(&mut block, test_instr(Instruction::Nop, 32)) + .expect("basicblock_addop succeeds"); - let mut cloned_return = blocks[target.idx()].instructions.clone(); - for instr in &mut cloned_return { - overwrite_location(instr, jump.location, jump.end_location); - } - blocks[block_idx].instructions.pop(); - blocks[block_idx].instructions.extend(cloned_return); + // CPython `remove_unreachable()` sets `b_iused = 0` without clearing the + // backing `b_instr` slot. A later `basicblock_addop()` reuses that slot + // and does not overwrite `i_except`. 
+ assert_eq!(block.instruction_used, 1); + assert_eq!(block.instructions[0].except_handler, Some(handler)); } -} -fn inline_named_except_cleanup_normal_exit_jumps(blocks: &mut [Block]) { - for block_idx in 0..blocks.len() { - let Some(jump_idx) = blocks[block_idx].instructions.len().checked_sub(1) else { - continue; - }; - let jump = blocks[block_idx].instructions[jump_idx]; - if !jump.instr.is_unconditional_jump() || jump.target == BlockIdx::NULL { - continue; + #[test] + fn basicblock_clear_reuses_cpython_spare_slots_in_offset_order() { + let mut block = Block::default(); + for i in 0..3 { + let mut stale = test_instr(Instruction::Nop, 35 + i); + stale.except_handler = Some(ExceptHandlerInfo { + handler_block: BlockIdx::new(i + 1), + preserve_lasti: false, + }); + test_block_push(&mut block, stale); } - let target = next_nonempty_block(blocks, jump.target); - if target == BlockIdx::NULL - || target == BlockIdx(block_idx as u32) - || !is_standalone_named_except_cleanup_normal_exit_block(&blocks[target.idx()]) - { - continue; + basicblock_clear(&mut block); + for i in 0..3 { + basicblock_addop(&mut block, test_instr(Instruction::PopTop, 38 + i)) + .expect("basicblock_addop succeeds"); } - let cloned_cleanup = blocks[target.idx()].instructions.clone(); - blocks[block_idx].instructions.pop(); - blocks[block_idx].instructions.extend(cloned_cleanup); + let handlers = block + .used_instructions() + .iter() + .map(|instr| { + instr + .except_handler + .expect("reused CPython slot") + .handler_block + }) + .collect::>(); + assert_eq!( + handlers, + [BlockIdx::new(1), BlockIdx::new(2), BlockIdx::new(3)] + ); } -} -/// Label exception targets: walk CFG with except stack, set per-instruction -/// handler info and block preserve_lasti flag. Converts POP_BLOCK to NOP. 
-/// flowgraph.c label_exception_targets + push_except_block -pub(crate) fn label_exception_targets(blocks: &mut [Block]) { - #[derive(Clone)] - struct ExceptEntry { - handler_block: BlockIdx, - preserve_lasti: bool, - } + #[test] + fn basicblock_append_instructions_overwrites_cpython_spare_slot() { + let handler = ExceptHandlerInfo { + handler_block: BlockIdx::new(5), + preserve_lasti: false, + }; + let mut blocks = vec![Block::default(), Block::default()]; + let mut stale = test_instr(Instruction::Nop, 41); + stale.except_handler = Some(handler); + test_block_push(&mut blocks[0], stale); + basicblock_clear(&mut blocks[0]); - let num_blocks = blocks.len(); - if num_blocks == 0 { - return; + test_block_push(&mut blocks[1], test_instr(Instruction::PopTop, 42)); + basicblock_append_block_instructions(&mut blocks, BlockIdx::new(0), BlockIdx::new(1)) + .expect("basicblock_append_block_instructions succeeds"); + + // CPython `basicblock_append_instructions()` obtains a slot with + // `basicblock_next_instr()`, then overwrites it with the copied + // instruction, including `i_except`. 
+ assert_eq!(blocks[0].instruction_used, 1); + assert_eq!(blocks[0].instructions[0].except_handler, None); } - let mut visited = vec![false; num_blocks]; - let mut block_stacks: Vec>> = vec![None; num_blocks]; + #[test] + fn instr_set_op0_nop_preserves_cpython_stale_target() { + let mut info = test_jump(BlockIdx::new(1), 50); + set_to_nop(&mut info); - // Entry block - visited[0] = true; - block_stacks[0] = Some(Vec::new()); + assert_eq!(info.target, BlockIdx::new(1)); - let mut todo = vec![BlockIdx(0)]; + let mut blocks = vec![Block::default(), Block::default()]; + test_block_push(&mut blocks[0], info); + blocks[0].next = BlockIdx::new(1); - while let Some(block_idx) = todo.pop() { - let bi = block_idx.idx(); - let mut stack = block_stacks[bi].take().unwrap_or_default(); - let mut last_yield_except_depth: i32 = -1; + let mut instr_sequence = instruction_sequence_new(); + cfg_to_instruction_sequence(&mut blocks, &mut instr_sequence) + .expect("non-target NOP should ignore stale CPython i_target"); + } - let instr_count = blocks[bi].instructions.len(); - for i in 0..instr_count { - // Read all needed fields (each temporary borrow ends immediately) - let target = blocks[bi].instructions[i].target; - let arg = blocks[bi].instructions[i].arg; - let is_push = blocks[bi].instructions[i].instr.is_block_push(); - let is_pop = blocks[bi].instructions[i].instr.is_pop_block(); - - if is_push { - // Determine preserve_lasti from instruction type (push_except_block) - let preserve_lasti = matches!( - blocks[bi].instructions[i].instr.pseudo(), - Some( - PseudoInstruction::SetupWith { .. } - | PseudoInstruction::SetupCleanup { .. 
} - ) - ); + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "target_block != BlockIdx::NULL")] + fn cfg_to_instruction_sequence_requires_target_for_target_opcodes() { + let mut block = Block::default(); + test_block_push(&mut block, test_jump(BlockIdx::NULL, 51)); + let mut blocks = vec![block]; - // Set preserve_lasti on handler block - if preserve_lasti && target != BlockIdx::NULL { - blocks[target.idx()].preserve_lasti = true; - } + let mut instr_sequence = instruction_sequence_new(); + let _ = cfg_to_instruction_sequence(&mut blocks, &mut instr_sequence); + } - // Propagate except stack to handler block if not visited - if target != BlockIdx::NULL && !visited[target.idx()] { - visited[target.idx()] = true; - block_stacks[target.idx()] = Some(stack.clone()); - todo.push(target); - } + #[test] + fn static_swaps_respect_cpython_no_location_line_boundary() { + let mut block = Block::default(); + let mut swap = test_instr(Instruction::Swap { i: Arg::marker() }, 60); + swap.arg = OpArg::new(2); + let mut store = test_instr( + Instruction::StoreFast { + var_num: Arg::marker(), + }, + 60, + ); + store.arg = OpArg::new(0); + let mut pop = test_instr(Instruction::PopTop, 60); + pop.lineno_override = Some(NO_LOCATION_OVERRIDE); + for info in [swap, store, pop] { + test_block_push(&mut block, info); + } - // Push handler onto except stack - stack.push(ExceptEntry { - handler_block: target, - preserve_lasti, - }); - } else if is_pop { - debug_assert!( - !stack.is_empty(), - "POP_BLOCK with empty except stack at block {bi} instruction {i}" - ); - stack.pop(); - // POP_BLOCK → NOP - let remove_no_location_nop = blocks[bi].instructions[i].remove_no_location_nop; - let folded_operand_nop = blocks[bi].instructions[i].folded_operand_nop; - let preserve_block_start_no_location_nop = - blocks[bi].instructions[i].preserve_block_start_no_location_nop; - set_to_nop(&mut blocks[bi].instructions[i]); - blocks[bi].instructions[i].remove_no_location_nop = 
remove_no_location_nop; - blocks[bi].instructions[i].folded_operand_nop = folded_operand_nop; - blocks[bi].instructions[i].preserve_block_start_no_location_nop = - preserve_block_start_no_location_nop; - } else { - // Set except_handler for this instruction from except stack top - // stack_depth placeholder: filled by fixup_handler_depths - let handler_info = stack.last().map(|e| ExceptHandlerInfo { - handler_block: e.handler_block, - stack_depth: 0, - preserve_lasti: e.preserve_lasti, - }); - blocks[bi].instructions[i].except_handler = handler_info; - - // Track YIELD_VALUE except stack depth - // Record the except stack depth at the point of yield. - // With the StopIteration wrapper, depth is naturally correct: - // - plain yield outside try: depth=1 → DEPTH1 set - // - yield inside try: depth=2+ → no DEPTH1 - // - yield-from/await: has internal SETUP_FINALLY → depth=2+ → no DEPTH1 - if let Some(Instruction::YieldValue { .. }) = - blocks[bi].instructions[i].instr.real() - { - last_yield_except_depth = stack.len() as i32; - } + apply_static_swaps_block(&mut block).expect("apply_static_swaps_block succeeds"); - // Set RESUME DEPTH1 flag based on last yield's except depth - if let Some(Instruction::Resume { context }) = - blocks[bi].instructions[i].instr.real() - { - let location = context.get(arg).location(); - match location { - oparg::ResumeLocation::AtFuncStart => {} - _ => { - if last_yield_except_depth == 1 { - blocks[bi].instructions[i].arg = - OpArg::new(oparg::ResumeContext::new(location, true).as_u32()); - } - last_yield_except_depth = -1; - } - } - } + // CPython `next_swappable_instruction()` compares `i_loc.lineno` + // directly, so a following NO_LOCATION swaperand does not match the + // first swaperand's positive line number. + assert!(matches!( + block.instructions[0].instr.real(), + Some(Instruction::Swap { .. }) + )); + assert!(matches!( + block.instructions[1].instr.real(), + Some(Instruction::StoreFast { .. 
}) + )); + assert!(matches!( + block.instructions[2].instr.real(), + Some(Instruction::PopTop) + )); - // For jump instructions, propagate except stack to target - if target != BlockIdx::NULL && !visited[target.idx()] { - visited[target.idx()] = true; - block_stacks[target.idx()] = Some(stack.clone()); - todo.push(target); - } - } + let mut block = Block::default(); + let mut swap = test_instr(Instruction::Swap { i: Arg::marker() }, 70); + swap.arg = OpArg::new(2); + let mut store = test_instr( + Instruction::StoreFast { + var_num: Arg::marker(), + }, + 70, + ); + store.arg = OpArg::new(0); + store.lineno_override = Some(NO_LOCATION_OVERRIDE); + let pop = test_instr(Instruction::PopTop, 71); + for info in [swap, store, pop] { + test_block_push(&mut block, info); } - // Propagate to fallthrough block (block.next) - let next = blocks[bi].next; - if next != BlockIdx::NULL && !visited[next.idx()] { - let has_fallthrough = blocks[bi] - .instructions - .last() - .is_none_or(|ins| !ins.instr.is_scope_exit() && !ins.instr.is_unconditional_jump()); // Empty block falls through - if has_fallthrough { - visited[next.idx()] = true; - block_stacks[next.idx()] = Some(stack); - todo.push(next); - } - } - } -} + apply_static_swaps_block(&mut block).expect("apply_static_swaps_block succeeds"); -/// Convert remaining pseudo ops to real instructions or NOP. -/// flowgraph.c convert_pseudo_ops -pub(crate) fn convert_pseudo_ops(blocks: &mut [Block], cellfixedoffsets: &[u32]) { - for block in blocks.iter_mut() { - for info in &mut block.instructions { - let Some(pseudo) = info.instr.pseudo() else { - continue; - }; - match pseudo { - // Block push pseudo ops → NOP - PseudoInstruction::SetupCleanup { .. } - | PseudoInstruction::SetupFinally { .. } - | PseudoInstruction::SetupWith { .. 
} => { - let preserve_block_start_no_location_nop = - info.preserve_block_start_no_location_nop; - set_to_nop(info); - info.preserve_block_start_no_location_nop = - preserve_block_start_no_location_nop; - } - // PopBlock in reachable blocks is converted to NOP by - // label_exception_targets. Dead blocks may still have them. - PseudoInstruction::PopBlock => { - let remove_no_location_nop = info.remove_no_location_nop; - let folded_operand_nop = info.folded_operand_nop; - let preserve_block_start_no_location_nop = - info.preserve_block_start_no_location_nop; - set_to_nop(info); - info.remove_no_location_nop = remove_no_location_nop; - info.folded_operand_nop = folded_operand_nop; - info.preserve_block_start_no_location_nop = - preserve_block_start_no_location_nop; - } - // LOAD_CLOSURE → LOAD_FAST (using cellfixedoffsets for merged layout) - PseudoInstruction::LoadClosure { i } => { - let cell_relative = i.get(info.arg) as usize; - let new_idx = cellfixedoffsets[cell_relative]; - info.arg = OpArg::new(new_idx); - info.instr = Opcode::LoadFast.into(); - } - // Jump pseudo ops are resolved during block linearization - PseudoInstruction::Jump { .. } | PseudoInstruction::JumpNoInterrupt { .. } => {} - PseudoInstruction::StoreFastMaybeNull { .. } => { - info.instr = Instruction::StoreFast { - var_num: Arg::marker(), - } - .into(); - } - // These should have been resolved earlier - PseudoInstruction::AnnotationsPlaceholder - | PseudoInstruction::JumpIfFalse { .. } - | PseudoInstruction::JumpIfTrue { .. } => { - unreachable!("Unexpected pseudo instruction in convert_pseudo_ops: {pseudo:?}") - } - } - } + // Conversely, when the first swaperand has NO_LOCATION, CPython passes + // `-1` as the line filter and does not enforce a boundary. 
+ assert!(matches!( + block.instructions[0].instr.real(), + Some(Instruction::Nop) + )); + assert!(matches!( + block.instructions[1].instr.real(), + Some(Instruction::PopTop) + )); + assert!(matches!( + block.instructions[2].instr.real(), + Some(Instruction::StoreFast { .. }) + )); } -} -/// Build cellfixedoffsets mapping: cell/free index -> localsplus index. -/// Merged cells (cellvar also in varnames) get the local slot index. -/// Non-merged cells get slots after nlocals. Free vars follow. -pub(crate) fn build_cellfixedoffsets( - varnames: &IndexSet, - cellvars: &IndexSet, - freevars: &IndexSet, -) -> Vec { - let nlocals = varnames.len(); - let ncells = cellvars.len(); - let nfrees = freevars.len(); - let mut fixed = Vec::with_capacity(ncells + nfrees); - let mut numdropped = 0usize; - for (i, cellvar) in cellvars.iter().enumerate() { - if let Some(local_idx) = varnames.get_index_of(cellvar) { - fixed.push(local_idx as u32); - numdropped += 1; - } else { - fixed.push((nlocals + i - numdropped) as u32); - } + #[test] + fn optimize_load_const_tracks_cpython_copy_of_load_const() { + let mut block = Block::default(); + test_block_push( + &mut block, + test_instr( + Instruction::LoadConst { + consti: Arg::marker(), + }, + 80, + ), + ); + let mut copy = test_instr(Instruction::Copy { i: Arg::marker() }, 80); + copy.arg = OpArg::new(1); + test_block_push(&mut block, copy); + test_block_push(&mut block, test_instr(Instruction::ToBool, 80)); + + let mut code = test_code_info(block); + let (const_idx, _) = code.metadata.consts.insert_full(ConstantData::Tuple { + elements: vec![ConstantData::Integer { + value: BigInt::from(1), + }], + }); + code.blocks[0].instructions[0].arg = OpArg::new(const_idx as u32); + + optimize_load_const(&mut code.metadata, &mut code.blocks) + .expect("optimize_load_const succeeds"); + + // CPython `basicblock_optimize_load_const()` keeps the previous + // LOAD_CONST as the effective opcode for a following `COPY 1`, so the + // COPY is NOPed and 
TO_BOOL becomes LOAD_CONST True. + assert!(matches!( + code.blocks[0].instructions[0].instr.real(), + Some(Instruction::LoadConst { .. }) + )); + assert!(matches!( + code.blocks[0].instructions[1].instr.real(), + Some(Instruction::Nop) + )); + let load_bool = &code.blocks[0].instructions[2]; + assert!(matches!( + load_bool.instr.real(), + Some(Instruction::LoadConst { .. }) + )); + assert_eq!( + code.metadata.consts[u32::from(load_bool.arg) as usize], + ConstantData::Boolean { value: true } + ); } - for i in 0..nfrees { - fixed.push((nlocals + ncells - numdropped + i) as u32); + + #[test] + fn optimize_load_fast_records_no_input_opcode_ref_at_cpython_produced_index() { + let mut block = Block::default(); + test_block_push( + &mut block, + test_instr( + Instruction::LoadFast { + var_num: Arg::marker(), + }, + 10, + ), + ); + test_block_push(&mut block, test_instr(Instruction::GetLen, 10)); + let mut swap = test_instr(Instruction::Swap { i: Arg::marker() }, 10); + swap.arg = OpArg::new(2); + test_block_push(&mut block, swap); + test_block_push(&mut block, test_instr(Instruction::PopTop, 10)); + + let mut code = test_code_info(block); + optimize_load_fast(&mut code.blocks).expect("optimize_load_fast succeeds"); + + // CPython `optimize_load_fast()` shadows the outer instruction index in + // the produced-value loop for GET_LEN, so the produced ref is recorded + // with index 0 here. The original LOAD_FAST is therefore not considered + // the consumed producer. + assert!(matches!( + code.blocks[0].instructions[0].instr.real(), + Some(Instruction::LoadFast { .. }) + )); } - fixed -} -/// Convert DEREF instruction opargs from cell-relative indices to localsplus indices -/// using the cellfixedoffsets mapping. 
-pub(crate) fn fixup_deref_opargs(blocks: &mut [Block], cellfixedoffsets: &[u32]) { - for block in blocks.iter_mut() { - for info in &mut block.instructions { - let Some(instr) = info.instr.real() else { - continue; - }; - let needs_fixup = matches!( - instr.into(), - Opcode::LoadDeref - | Opcode::StoreDeref - | Opcode::DeleteDeref - | Opcode::LoadFromDictOrDeref - | Opcode::MakeCell - ); - if needs_fixup { - let cell_relative = u32::from(info.arg) as usize; - info.arg = OpArg::new(cellfixedoffsets[cell_relative]); - } + #[test] + fn constant_sequence_loads_use_cpython_opcode_has_const_metadata() { + let mut metadata = CodeUnitMetadata { + name: "".to_owned(), + qualname: Some("".to_owned()), + consts: Default::default(), + names: IndexSet::default(), + varnames: IndexSet::default(), + cellvars: IndexSet::default(), + freevars: IndexSet::default(), + fast_hidden: IndexMap::default(), + fast_hidden_final: IndexSet::default(), + argcount: 0, + posonlyargcount: 0, + kwonlyargcount: 0, + firstlineno: OneIndexed::MIN, + }; + let (left, _) = metadata + .consts + .insert_full(ConstantData::Str { value: "a".into() }); + let (right, _) = metadata + .consts + .insert_full(ConstantData::Str { value: "b".into() }); + + let mut immortal = test_instr(Instruction::Nop, 90); + immortal.instr = Opcode::LoadConstImmortal.into(); + immortal.arg = OpArg::new(left as u32); + let mut mortal = test_instr(Instruction::Nop, 90); + mortal.instr = Opcode::LoadConstMortal.into(); + mortal.arg = OpArg::new(right as u32); + let mut build = test_instr( + Instruction::BuildTuple { + count: Arg::marker(), + }, + 90, + ); + build.arg = OpArg::new(2); + let mut block = Block::default(); + for info in [immortal, mortal, build] { + test_block_push(&mut block, info); } - } -} -#[cfg(test)] -mod tests { - use super::*; + assert!( + fold_tuple_of_constants(&mut metadata, &mut block, 2) + .expect("fold_tuple_of_constants succeeds") + ); - fn instruction_info(instr: Instruction, arg: u32, target: BlockIdx) 
-> InstructionInfo { - InstructionInfo { - instr: instr.into(), - arg: OpArg::new(arg), - target, - location: SourceLocation::default(), - end_location: SourceLocation::default(), - except_handler: None, - folded_from_nonliteral_expr: false, - lineno_override: None, - cache_entries: 0, - preserve_redundant_jump_as_nop: false, - remove_no_location_nop: false, - folded_operand_nop: false, - no_location_exit: false, - preserve_block_start_no_location_nop: false, - match_success_jump: false, - } + // CPython `loads_const()` accepts every `OPCODE_HAS_CONST` opcode, not + // just canonical LOAD_CONST, so LOAD_CONST_IMMORTAL/MORTAL participate + // in constant-sequence folding. + assert!(matches!( + block.instructions[0].instr.real(), + Some(Instruction::Nop) + )); + assert!(matches!( + block.instructions[1].instr.real(), + Some(Instruction::Nop) + )); + let folded = &block.instructions[2]; + assert!(matches!( + folded.instr.real(), + Some(Instruction::LoadConst { .. }) + )); + assert!(matches!( + &metadata.consts[u32::from(folded.arg) as usize], + ConstantData::Tuple { elements } if elements.len() == 2 + )); } #[test] - fn short_circuit_stub_allows_only_marker_instructions_before_jump() { - let final_target = BlockIdx(7); - let block = Block { - instructions: vec![ - instruction_info(Instruction::Copy { i: Arg::marker() }, 1, BlockIdx::NULL), - instruction_info(Instruction::ToBool, 0, BlockIdx::NULL), - instruction_info(Instruction::Nop, 0, BlockIdx::NULL), - instruction_info(Instruction::NotTaken, 0, BlockIdx::NULL), - instruction_info( - Instruction::PopJumpIfFalse { - delta: Arg::marker(), - }, - 0, - final_target, - ), - ], - ..Block::default() - }; - + fn resolve_line_numbers_duplicates_exit_blocks_like_cpython() { + let exit = BlockIdx::new(2); + let mut blocks = vec![Block::default(), Block::default(), Block::default()]; + blocks[0].cpython_label = InstructionSequenceLabel::from_index(0); + blocks[1].cpython_label = InstructionSequenceLabel::from_index(1); + 
blocks[2].cpython_label = InstructionSequenceLabel::from_index(2); + blocks[0].next = BlockIdx::new(1); + test_block_push(&mut blocks[0], test_cond_jump(exit, 10)); + blocks[1].next = exit; + test_block_push(&mut blocks[1], test_jump(exit, 20)); + test_block_push(&mut blocks[2], test_instr(Instruction::ReturnValue, 30)); + blocks[2].instructions[0].lineno_override = Some(NO_LOCATION_OVERRIDE); + + remove_unreachable(&mut blocks).expect("remove_unreachable succeeds"); + resolve_line_numbers(&mut blocks, OneIndexed::MIN).expect("resolve_line_numbers succeeds"); + + // CPython `duplicate_exits_without_lineno()` copies a shared exit block + // reached by jumps so each copy can inherit its sole predecessor's line. + let duplicate = blocks[0].instructions[0].target; + assert_ne!(duplicate, exit); assert_eq!( - same_short_circuit_target( - &block, - Instruction::PopJumpIfFalse { - delta: Arg::marker(), - } - .into(), - ), - Some(final_target) + blocks[duplicate.idx()].cpython_label, + InstructionSequenceLabel::from_index(3) + ); + assert_eq!( + instruction_lineno(&blocks[duplicate.idx()].instructions[0]), + 10 ); + assert_eq!(blocks[1].instructions[0].target, exit); + assert_eq!(instruction_lineno(&blocks[exit.idx()].instructions[0]), 20); } #[test] - fn short_circuit_stub_rejects_real_instruction_before_jump() { - let block = Block { - instructions: vec![ - instruction_info(Instruction::Copy { i: Arg::marker() }, 1, BlockIdx::NULL), - instruction_info(Instruction::ToBool, 0, BlockIdx::NULL), - instruction_info(Instruction::PopTop, 0, BlockIdx::NULL), - instruction_info( - Instruction::PopJumpIfFalse { - delta: Arg::marker(), - }, - 0, - BlockIdx(7), - ), - ], - ..Block::default() - }; - + fn propagate_line_numbers_treats_next_location_like_cpython() { + let mut block = Block::default(); + test_block_push(&mut block, test_instr(Instruction::Nop, 10)); + test_block_push(&mut block, test_instr(Instruction::PopTop, 20)); + block.instructions[1].lineno_override = 
Some(NEXT_LOCATION_OVERRIDE); + test_block_push(&mut block, test_instr(Instruction::ReturnValue, 30)); + block.instructions[2].lineno_override = Some(NO_LOCATION_OVERRIDE); + let mut blocks = vec![block]; + + remove_unreachable(&mut blocks).expect("remove_unreachable succeeds"); + propagate_line_numbers(&mut blocks); + + // CPython `propagate_line_numbers()` only copies over NO_LOCATION + // (`lineno == NO_LOCATION`). `NEXT_LOCATION` (`lineno == -2`) becomes the + // current previous location and is copied to following NO_LOCATION + // instructions for assemble.c to resolve later. assert_eq!( - same_short_circuit_target( - &block, - Instruction::PopJumpIfFalse { - delta: Arg::marker(), - } - .into(), - ), - None + blocks[0].instructions[1].lineno_override, + Some(NEXT_LOCATION_OVERRIDE) ); - assert!(!opposite_short_circuit_target( - &block, - Instruction::PopJumpIfTrue { - delta: Arg::marker(), - } - .into() + assert_eq!( + blocks[0].instructions[2].lineno_override, + Some(NEXT_LOCATION_OVERRIDE) + ); + } + + #[test] + fn propagate_line_numbers_updates_empty_jump_target_raw_slot_like_cpython() { + let mut blocks = vec![Block::default(), Block::default(), Block::default()]; + blocks[0].next = BlockIdx::new(2); + test_block_push(&mut blocks[0], test_cond_jump(BlockIdx::new(1), 10)); + test_block_push(&mut blocks[1], test_instr(Instruction::Nop, 20)); + blocks[1].instructions[0].lineno_override = Some(NO_LOCATION_OVERRIDE); + basicblock_clear(&mut blocks[1]); + test_block_push(&mut blocks[2], test_instr(Instruction::ReturnValue, 30)); + + remove_unreachable(&mut blocks).expect("remove_unreachable succeeds"); + propagate_line_numbers(&mut blocks); + + // CPython `propagate_line_numbers()` directly reads `target->b_instr[0]` + // for jump targets without checking `b_iused`. If + // `remove_redundant_nops()` emptied the target, that writes the stale + // backing slot rather than an active instruction. 
+ assert_eq!(instruction_lineno(&blocks[1].instructions[0]), 10); + } + + #[test] + fn basicblock_has_no_lineno_treats_next_location_like_cpython() { + let mut block = Block::default(); + test_block_push(&mut block, test_instr(Instruction::Nop, 10)); + block.instructions[0].lineno_override = Some(NEXT_LOCATION_OVERRIDE); + + // CPython `basicblock_has_no_lineno()` treats every negative lineno as + // no line number, including `NEXT_LOCATION` (`lineno == -2`). + assert!(basicblock_has_no_lineno(&block)); + + test_block_push(&mut block, test_instr(Instruction::PopTop, 11)); + assert!(!basicblock_has_no_lineno(&block)); + } + + #[test] + fn jump_threading_rechecks_new_jump_like_cpython() { + let mut blocks = vec![ + Block::default(), + Block::default(), + Block::default(), + Block::default(), + ]; + for (i, block) in blocks.iter_mut().enumerate() { + block.cpython_label = InstructionSequenceLabel::from_index(i as i32); + } + blocks[0].next = BlockIdx::new(1); + blocks[1].next = BlockIdx::new(2); + blocks[2].next = BlockIdx::new(3); + test_block_push(&mut blocks[0], test_jump(BlockIdx::new(1), 10)); + test_block_push(&mut blocks[1], test_jump(BlockIdx::new(2), 20)); + test_block_push(&mut blocks[2], test_jump(BlockIdx::new(3), 30)); + test_block_push(&mut blocks[3], test_instr(Instruction::ReturnValue, 40)); + + let mut metadata = test_code_info(Block::default()).metadata; + optimize_basic_block(&mut blocks, &mut metadata, BlockIdx::new(0)) + .expect("valid jump chain"); + + // CPython `optimize_basic_block()` continues after `jump_thread()`, so + // the appended jump is immediately checked against the next jump target. + let threaded = basicblock_last_instr(&blocks[0]).expect("threaded jump"); + assert!(matches!( + threaded.instr.pseudo(), + Some(PseudoInstruction::Jump { .. 
}) )); + assert_eq!(threaded.target, BlockIdx::new(3)); + assert_eq!(u32::from(threaded.arg), 3); } } diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 8d6ad984354..b598ab7e933 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -8,8 +8,8 @@ extern crate log; extern crate alloc; -type IndexMap = indexmap::IndexMap; -type IndexSet = indexmap::IndexSet; +type IndexMap = indexmap::IndexMap; +type IndexSet = indexmap::IndexSet; pub mod compile; pub mod error; diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 60c8d2734a4..10144597f46 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -32,6 +32,9 @@ pub struct SymbolTable { // Return True if the block is a nested class or function pub is_nested: bool, + /// Whether this function-like scope was created directly in a class block. + pub is_method: bool, + /// A set of symbols present on this scope level. pub symbols: IndexMap, @@ -57,6 +60,9 @@ pub struct SymbolTable { /// Whether this scope contains yield/yield from (is a generator function) pub is_generator: bool, + /// Whether this scope contains await or async comprehension machinery. 
+ pub is_coroutine: bool, + /// Whether this comprehension scope should be inlined (PEP 709) /// True for list/set/dict comprehensions in non-generator expressions pub comp_inlined: bool, @@ -90,6 +96,7 @@ impl SymbolTable { typ, line_number, is_nested, + is_method: false, symbols: IndexMap::default(), sub_tables: vec![], next_sub_table: 0, @@ -98,6 +105,7 @@ impl SymbolTable { needs_classdict: false, can_see_class_scope: false, is_generator: false, + is_coroutine: false, comp_inlined: false, annotation_block: None, skip_enclosing_function_scope: false, @@ -1103,6 +1111,7 @@ impl SymbolTableBuilder { | CompilerScope::Lambda | CompilerScope::Comprehension | CompilerScope::Annotation + | CompilerScope::TypeParams ) } @@ -1118,11 +1127,17 @@ impl SymbolTableBuilder { } fn enter_scope(&mut self, name: &str, typ: CompilerScope, line_number: u32) { - let is_nested = self.tables.last().is_some_and(|table| { - table.is_nested - || matches!( - table.typ, - CompilerScope::Function | CompilerScope::AsyncFunction + let parent = self.tables.last(); + let is_nested = + parent.is_some_and(|table| table.is_nested || Self::is_function_like_scope(table.typ)); + let is_method = parent.is_some_and(|table| { + table.typ == CompilerScope::Class + && matches!( + typ, + CompilerScope::Function + | CompilerScope::AsyncFunction + | CompilerScope::Lambda + | CompilerScope::Comprehension ) }); // Inherit mangled_names from parent for non-class scopes @@ -1132,6 +1147,7 @@ impl SymbolTableBuilder { .and_then(|t| t.mangled_names.clone()) .filter(|_| typ != CompilerScope::Class); let mut table = SymbolTable::new(name.to_owned(), typ, line_number, is_nested); + table.is_method = is_method; table.future_annotations = self.future_annotations; table.mangled_names = inherited_mangled_names; self.tables.push(table); @@ -1145,6 +1161,8 @@ impl SymbolTableBuilder { name: &str, line_number: u32, for_class: bool, + has_defaults: bool, + has_kwdefaults: bool, ) -> SymbolTableResult { // Check if we're in a 
class scope let in_class = self @@ -1174,6 +1192,12 @@ impl SymbolTableBuilder { if for_class { self.register_name(".generic_base", SymbolUsage::Assigned, TextRange::default())?; } + if has_defaults { + self.register_name(".defaults", SymbolUsage::Parameter, TextRange::default())?; + } + if has_kwdefaults { + self.register_name(".kwdefaults", SymbolUsage::Parameter, TextRange::default())?; + } Ok(()) } @@ -1195,6 +1219,7 @@ impl SymbolTableBuilder { let can_see_class_scope = current.typ == CompilerScope::Class || current.can_see_class_scope; let has_conditional = current.has_conditional_annotations; + let is_nested = current.is_nested || Self::is_function_like_scope(current.typ); // Create annotation block if not exists if current.annotation_block.is_none() { @@ -1202,7 +1227,7 @@ impl SymbolTableBuilder { "__annotate__".to_owned(), CompilerScope::Annotation, line_number, - true, // is_nested + is_nested, ); // Annotation scope in class can see class scope annotation_table.can_see_class_scope = can_see_class_scope; @@ -1488,6 +1513,8 @@ impl SymbolTableBuilder { &format!("", name.as_str()), self.line_index_start(type_params.range), false, + true, + Self::has_kwonlydefaults(parameters), )?; self.scan_type_params(type_params)?; } @@ -1509,6 +1536,9 @@ impl SymbolTableBuilder { }, has_type_params, // skip_defaults: already scanned above )?; + if *is_async { + self.tables.last_mut().unwrap().is_coroutine = true; + } self.scan_statements(body)?; self.leave_scope(); if type_params.is_some() { @@ -1536,6 +1566,8 @@ impl SymbolTableBuilder { &format!("", name.as_str()), self.line_index_start(type_params.range), true, // for_class: enable selective mangling + false, + false, )?; // Set class_name for mangling in type param scope self.class_name = Some(name.to_string()); @@ -1847,6 +1879,8 @@ impl SymbolTableBuilder { &format!(""), self.line_index_start(type_params.range), false, + false, + false, )?; self.scan_type_params(type_params)?; } @@ -1999,6 +2033,7 @@ impl 
SymbolTableBuilder { range: _, }) => { self.scan_expression(value, context)?; + self.tables.last_mut().unwrap().is_coroutine = true; } Expr::Yield(ExprYield { value, @@ -2326,6 +2361,9 @@ impl SymbolTableBuilder { ); // Generator expressions need the is_generator flag self.tables.last_mut().unwrap().is_generator = is_generator; + if generators.iter().any(|generator| generator.is_async) { + self.tables.last_mut().unwrap().is_coroutine = true; + } // PEP 709: Mark non-generator comprehensions for inlining. // CPython's symtable marks all non-generator comprehensions for @@ -2362,7 +2400,14 @@ impl SymbolTableBuilder { } self.scan_expression(elt1, ExpressionContext::Load)?; + // CPython symtable_handle_comprehension(): non-generator async + // comprehensions propagate ste_coroutine to the enclosing scope after + // the comprehension block is exited. + let propagate_coroutine = self.tables.last().unwrap().is_coroutine && !is_generator; self.leave_scope(); + if propagate_coroutine { + self.tables.last_mut().unwrap().is_coroutine = true; + } Ok(()) } @@ -2583,6 +2628,13 @@ impl SymbolTableBuilder { Ok(()) } + fn has_kwonlydefaults(parameters: &ast::Parameters) -> bool { + parameters + .kwonlyargs + .iter() + .any(|arg| arg.default.is_some()) + } + fn enter_scope_with_parameters( &mut self, name: &str, @@ -2704,17 +2756,6 @@ impl SymbolTableBuilder { Ok(()) } - fn add_varname_to_scope(&mut self, table_idx: usize, name: &str) { - let varnames = if table_idx + 1 == self.tables.len() { - &mut self.current_varnames - } else { - &mut self.varnames_stack[table_idx + 1] - }; - if !varnames.iter().any(|existing| existing == name) { - varnames.push(name.to_owned()); - } - } - // Mirrors CPython symtable_extend_namedexpr_scope(): assignment expressions // inside comprehensions bind in the nearest function/module-like scope, not // in the synthetic comprehension scope itself. 
@@ -2752,9 +2793,6 @@ impl SymbolTableBuilder { match table_type { CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Lambda => { - let current_comp_inlined = self.tables.last().is_some_and(|table| { - table.typ == CompilerScope::Comprehension && table.comp_inlined - }); let parent_is_global = self.tables[table_idx] .symbols .get(mangled.as_str()) @@ -2777,9 +2815,6 @@ impl SymbolTableBuilder { .entry(mangled.clone()) .or_insert_with(|| Symbol::new(mangled.as_str())); symbol.flags.insert(SymbolFlags::ASSIGNED); - if !parent_is_global && current_comp_inlined { - self.add_varname_to_scope(table_idx, mangled.as_str()); - } return Ok(()); } CompilerScope::Module => { @@ -2941,7 +2976,7 @@ impl SymbolTableBuilder { match role { SymbolUsage::Nonlocal if scope_depth < 2 => { return Err(SymbolTableError { - error: format!("cannot define nonlocal '{name}' at top level."), + error: "nonlocal declaration not allowed at module level".into(), location, }); } @@ -3046,6 +3081,9 @@ pub(crate) fn mangle_name<'a>(class_name: Option<&str>, name: &'a str) -> Cow<'a } // Strip leading underscores from class name let class_name = class_name.trim_start_matches('_'); + if class_name.is_empty() { + return name.into(); + } let mut ret = String::with_capacity(1 + class_name.len() + name.len()); ret.push('_'); ret.push_str(class_name); @@ -3068,3 +3106,21 @@ pub(crate) fn maybe_mangle_name<'a>( } mangle_name(class_name, name) } + +#[cfg(test)] +mod tests { + use super::mangle_name; + + #[test] + fn mangle_name_leaves_private_name_in_underscore_only_class() { + assert_eq!(mangle_name(Some("_"), "__a"), "__a"); + assert_eq!(mangle_name(Some("__"), "__a"), "__a"); + assert_eq!(mangle_name(Some("___"), "__a"), "__a"); + } + + #[test] + fn mangle_name_strips_leading_class_underscores() { + assert_eq!(mangle_name(Some("_a"), "__a"), "_a__a"); + assert_eq!(mangle_name(Some("__a"), "__a"), "_a__a"); + } +} diff --git a/crates/common/src/cformat.rs b/crates/common/src/cformat.rs 
index 5d24b30ce06..7dbe1076975 100644 --- a/crates/common/src/cformat.rs +++ b/crates/common/src/cformat.rs @@ -866,7 +866,7 @@ mod tests { use super::*; #[test] - fn test_fill_and_align() { + fn fill_and_align() { assert_eq!( "%10s" .parse::() @@ -898,7 +898,7 @@ mod tests { } #[test] - fn test_parse_key() { + fn parse_key() { let expected = Ok(CFormatSpecKeyed { mapping_key: Some("amount".to_owned()), spec: CFormatSpec { @@ -926,7 +926,7 @@ mod tests { } #[test] - fn test_format_parse_key_fail() { + fn format_parse_key_fail() { assert_eq!( "%(aged".parse::(), Err(CFormatError { @@ -937,7 +937,7 @@ mod tests { } #[test] - fn test_format_parse_type_fail() { + fn format_parse_type_fail() { assert_eq!( "Hello %n".parse::(), Err(CFormatError { @@ -948,7 +948,7 @@ mod tests { } #[test] - fn test_incomplete_format_fail() { + fn incomplete_format_fail() { assert_eq!( "Hello %".parse::(), Err(CFormatError { @@ -959,7 +959,7 @@ mod tests { } #[test] - fn test_parse_flags() { + fn parse_flags() { let expected = Ok(CFormatSpec { format_type: CFormatType::Number(CNumberType::DecimalD), min_field_width: Some(CFormatQuantity::Amount(10)), @@ -975,7 +975,7 @@ mod tests { } #[test] - fn test_parse_and_format_string() { + fn parse_and_format_string() { assert_eq!( "%5.4s" .parse::() @@ -1007,7 +1007,7 @@ mod tests { } #[test] - fn test_parse_and_format_unicode_string() { + fn parse_and_format_unicode_string() { assert_eq!( "%.2s" .parse::() @@ -1018,7 +1018,7 @@ mod tests { } #[test] - fn test_parse_and_format_number() { + fn parse_and_format_number() { assert_eq!( "%5d" .parse::() @@ -1092,7 +1092,7 @@ mod tests { } #[test] - fn test_parse_and_format_float() { + fn parse_and_format_float() { assert_eq!( "%f".parse::().unwrap().format_float(1.2345), "1.234500" @@ -1130,7 +1130,7 @@ mod tests { } #[test] - fn test_format_parse() { + fn format_parse() { let fmt = "Hello, my name is %s and I'm %d years old"; let expected = Ok(CFormatString { parts: vec![ diff --git 
a/crates/common/src/format.rs b/crates/common/src/format.rs index 8992eb9ca36..af20b5746c8 100644 --- a/crates/common/src/format.rs +++ b/crates/common/src/format.rs @@ -1447,7 +1447,7 @@ mod tests { use super::*; #[test] - fn test_fill_and_align() { + fn fill_and_align() { let parse_fill_and_align = |text| { let (fill, align, rest) = parse_fill_and_align(str::as_ref(text)); ( @@ -1479,7 +1479,7 @@ mod tests { } #[test] - fn test_width_only() { + fn width_only() { let expected = Ok(FormatSpec { conversion: None, fill: None, @@ -1495,7 +1495,7 @@ mod tests { } #[test] - fn test_fill_and_width() { + fn fill_and_width() { let expected = Ok(FormatSpec { conversion: None, fill: Some('<'.into()), @@ -1511,7 +1511,7 @@ mod tests { } #[test] - fn test_all() { + fn all() { let expected = Ok(FormatSpec { conversion: None, fill: Some('<'.into()), @@ -1531,7 +1531,7 @@ mod tests { } #[test] - fn test_format_bool() { + fn format_bool_basic() { assert_eq!(format_bool("b", true), Ok("1".to_owned())); assert_eq!(format_bool("b", false), Ok("0".to_owned())); assert_eq!(format_bool("d", true), Ok("1".to_owned())); @@ -1563,7 +1563,7 @@ mod tests { } #[test] - fn test_format_int() { + fn format_int() { assert_eq!( FormatSpec::parse("d") .unwrap() @@ -1609,7 +1609,7 @@ mod tests { } #[test] - fn test_format_int_sep() { + fn format_int_sep() { let spec = FormatSpec::parse(",").expect(""); assert_eq!(spec.grouping_option, Some(FormatGrouping::Comma)); assert_eq!( @@ -1619,7 +1619,7 @@ mod tests { } #[test] - fn test_format_int_width_and_grouping() { + fn format_int_width_and_grouping() { // issue #5922: width + comma grouping should pad left, not inside the number let spec = FormatSpec::parse("10,").unwrap(); let result = spec.format_int(&BigInt::from(1234)).unwrap(); @@ -1627,7 +1627,7 @@ mod tests { } #[test] - fn test_format_int_padding_with_grouping() { + fn format_int_padding_with_grouping() { // CPython behavior: f'{1234:010,}' results in "00,001,234" let spec1 = 
FormatSpec::parse("010,").unwrap(); let result1 = spec1.format_int(&BigInt::from(1234)).unwrap(); @@ -1650,7 +1650,7 @@ mod tests { } #[test] - fn test_format_int_non_aftersign_zero_padding() { + fn format_int_non_aftersign_zero_padding() { // CPython behavior: f'{1234:0>10,}' results in "000001,234" let spec = FormatSpec::parse("0>10,").unwrap(); let result = spec.format_int(&BigInt::from(1234)).unwrap(); @@ -1658,7 +1658,7 @@ mod tests { } #[test] - fn test_format_parse() { + fn format_parse() { let expected = Ok(FormatString { format_parts: vec![ FormatPart::Literal("abcd".into()), @@ -1680,12 +1680,12 @@ mod tests { } #[test] - fn test_format_parse_multi_byte_char() { + fn format_parse_multi_byte_char() { assert!(FormatString::from_str("{a:%ЫйЯЧ}".as_ref()).is_ok()); } #[test] - fn test_format_parse_fail() { + fn format_parse_fail() { assert_eq!( FormatString::from_str("{s".as_ref()), Err(FormatParseError::UnmatchedBracket) @@ -1693,7 +1693,7 @@ mod tests { } #[test] - fn test_square_brackets_inside_format() { + fn square_brackets_inside_format() { assert_eq!( FormatString::from_str("{[:123]}".as_ref()), Ok(FormatString { @@ -1721,7 +1721,7 @@ mod tests { } #[test] - fn test_format_parse_escape() { + fn format_parse_escape() { let expected = Ok(FormatString { format_parts: vec![ FormatPart::Literal("{".into()), @@ -1738,7 +1738,7 @@ mod tests { } #[test] - fn test_format_invalid_specification() { + fn format_invalid_specification() { assert_eq!( FormatSpec::parse("%3"), Err(FormatSpecError::InvalidFormatSpecifier) @@ -1770,7 +1770,7 @@ mod tests { } #[test] - fn test_parse_field_name() { + fn parse_field_name() { let parse = |s: &str| FieldName::parse(s.as_ref()); assert_eq!( parse(""), diff --git a/crates/common/src/str.rs b/crates/common/src/str.rs index 1af21e385c7..3fddef04bb8 100644 --- a/crates/common/src/str.rs +++ b/crates/common/src/str.rs @@ -664,7 +664,7 @@ mod tests { use super::*; #[test] - fn test_get_chars() { + fn get_chars_basic() { let s = 
"0123456789"; assert_eq!(get_chars(s, 3..7), "3456"); assert_eq!(get_chars(s, 3..7), &s[3..7]); diff --git a/crates/compiler-core/generate.py b/crates/compiler-core/generate.py deleted file mode 100644 index ecb4652ea5d..00000000000 --- a/crates/compiler-core/generate.py +++ /dev/null @@ -1,721 +0,0 @@ -#!/usr/bin/env python -import collections -import dataclasses -import io -import os -import pathlib -import subprocess -import sys - -import tomllib - -CRATE_ROOT = pathlib.Path(__file__).parent -CONF_FILE = CRATE_ROOT / "opcode.toml" -OUT_FILE = CRATE_ROOT / "src" / "bytecode" / "instructions.rs" - -ROOT = CRATE_ROOT.parents[1] - -try: - CPYTHON_ROOT = pathlib.Path(os.environ["CPYTHON_ROOT"]).expanduser().resolve() -except KeyError: - raise ValueError("Missing environment variable 'CPYTHON_ROOT'") - -CPYTHON_TOOLS_LIB = CPYTHON_ROOT / "Tools" / "cases_generator" - -sys.path.append(CPYTHON_TOOLS_LIB.as_posix()) - -import analyzer -from generators_common import DEFAULT_INPUT -from stack import get_stack_effect - - -@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) -class OpcodeGen: - name: str - instruction_enum: str - instructions: list - numeric_repr: str - metadata: dict[str, str] - analysis: analyzer.Analysis - - def gen(self) -> str: - methods = "\n\n".join( - getattr(self, attr).strip() - for attr in sorted(dir(self)) - if attr.startswith("fn_") - ) - - impls = "\n\n".join( - getattr(self, attr).strip() - for attr in sorted(dir(self)) - if attr.startswith("impl_") - ) - - variants = ",\n".join(instr.name for instr in self) - - return f""" - #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub enum {self.name} {{ - {variants} - }} - - impl {self.name} {{ - {methods} - }} - - {impls} - """ - - @property - def fn_as_numeric(self) -> str: - arms = ",\n".join(f"Self::{instr.name} => {instr.opcode}" for instr in self) - return f""" - #[must_use] - pub const fn as_{self.numeric_repr}(self) -> {self.numeric_repr} {{ - match self {{ - {arms}, - }} - }} - """ - 
- @property - def fn_try_from_numeric(self) -> str: - arms = ",\n".join(f"{instr.opcode} => Self::{instr.name}" for instr in self) - return f""" - pub const fn try_from_{self.numeric_repr}( - value: {self.numeric_repr} - ) -> Result {{ - Ok(match value {{ - {arms}, - _ => return Err(MarshalError::InvalidBytecode), - }}) - }} - """ - - @property - def impl_try_from_numeric(self) -> str: - return f""" - impl TryFrom<{self.numeric_repr}> for {self.name} {{ - type Error = MarshalError; - - fn try_from(value: {self.numeric_repr}) -> Result {{ - Self::try_from_{self.numeric_repr}(value) - }} - }} - """ - - @property - def impl_into_numeric(self) -> str: - return f""" - impl From<{self.name}> for {self.numeric_repr} {{ - fn from(opcode: {self.name}) -> Self {{ - opcode.as_{self.numeric_repr}() - }} - }} - """ - - def build_has_attr_fn(self, fn_attr: str, prop_attr: str, doc_flag: str) -> str: - arms = "|".join( - f"Self::{instr.name}" - for instr in self - if getattr(instr.properties, prop_attr) - ) - - if arms: - inner = f"matches!(self, {arms})" - else: - inner = "false" - - return f""" - /// Does this opcode have '{doc_flag}' set. 
- #[must_use] - pub const fn has_{fn_attr}(self) -> bool {{ - {inner} - }} - """ - - fn_has_arg = property( - lambda self: self.build_has_attr_fn("arg", "oparg", "HAS_ARG_FLAG") - ) - - fn_has_const = property( - lambda self: self.build_has_attr_fn("const", "uses_co_consts", "HAS_CONST_FLAG") - ) - - fn_has_name = property( - lambda self: self.build_has_attr_fn("name", "uses_co_names", "HAS_NAME_FLAG") - ) - - fn_has_jump = property( - lambda self: self.build_has_attr_fn("jump", "jumps", "HAS_JUMP_FLAG") - ) - - fn_has_free = property( - lambda self: self.build_has_attr_fn("free", "has_free", "HAS_FREE_FLAG") - ) - - fn_has_local = property( - lambda self: self.build_has_attr_fn("local", "uses_locals", "HAS_LOCAL_FLAG") - ) - - @property - def instrumented_mapping(self) -> dict[str, str]: - inames = {instr.name for instr in self if instr.name.startswith("Instrumented")} - names = {instr.name for instr in self} - inames - - res = {} - for iname in sorted(inames): - name = iname.removeprefix("Instrumented") - if name not in names: - continue - - res[name] = iname - - return res - - @property - def fn_to_base(self) -> str: - arms = ",\n".join( - f"Self::{iname} => Self::{name}" - for name, iname in self.instrumented_mapping.items() - ) - - arms = arms.strip() - if not arms: - inner = "None" - else: - inner = f""" - Some(match self {{ - {arms}, - _ => return None, - - }}) - """ - - return f""" - #[must_use] - pub const fn to_base(self) -> Option {{ - {inner} - }} - """ - - @property - def fn_to_instrumented(self) -> str: - arms = ",\n".join( - f"Self::{name} => Self::{iname}" - for name, iname in self.instrumented_mapping.items() - ) - - arms = arms.strip() - if not arms: - inner = "None" - else: - inner = f""" - Some(match self {{ - {arms}, - _ => return None, - - }}) - """ - - return f""" - #[must_use] - pub const fn to_instrumented(self) -> Option {{ - {inner} - }} - """ - - @property - def fn_deopt(self) -> str: - names = {instr.name for instr in self} - - deopts = 
collections.defaultdict(list) - for family in self.analysis.families.values(): - family_name = to_pascal_case(family.name) - if family_name not in names: - continue - - for member in family.members: - if member.name == family_name: - continue - - deopts[family_name].append(member.name) - - arms = "" - for target, specialized in deopts.items(): - ops = "|".join(f"Self::{op}" for op in specialized) - arms += f"{ops} => Self::{target},\n" - - arms = arms.strip() - - if not arms: - inner = "None" - else: - inner = f""" - Some(match self {{ - {arms} - _ => return None, - - }}) - """ - - return f""" - #[must_use] - pub const fn deopt(self) -> Option {{ - {inner} - }} - """ - - @property - def fn_cache_entries(self) -> str: - arms = "" - for instr in self: - name = instr.name - if getattr(instr, "family", None) and (instr.family.name != name): - continue - - if name.startswith("Instrumented"): - continue - - try: - size = instr.size - except AttributeError: - continue - - if size > 1: - arms += f"Self::{name} => {size - 1},\n" - - arms = arms.strip() - if not arms: - inner = "0" - else: - inner = f""" - match self.deoptimize() {{ - {arms} - _ => 0, - }} - """ - - return f""" - #[must_use] - pub const fn cache_entries(self) -> usize {{ - {inner} - }} - """ - - @property - def fn_stack_effect_info(self) -> str: - oparg_used = False - arms = "" - for instr in self: - name = instr.name - stack = get_stack_effect(instr) - - popped = (-stack.base_offset).to_c() - pushed = (stack.logical_sp - stack.base_offset).to_c() - - pushed_comment = "" - popped_comment = "" - - if stack_effect := self.metadata.get(name, {}).get("stack_effect"): - if npushed := stack_effect.get("pushed"): - pushed_comment = f"// TODO: Differs from CPython `{pushed}`" - pushed = npushed - - if npopped := stack_effect.get("popped"): - popped_comment = f"// TODO: Differs from CPython `{popped}`" - popped = npopped - - oparg_used = oparg_used or any("oparg" in expr for expr in (pushed, popped)) - - arms += f""" 
- Self::{name} => ( - {pushed}, {pushed_comment} - {popped}, {popped_comment} - ), - """.strip() - - arms = arms.strip() - - oparg_arg = "_oparg" - oparg_cast = "" - if oparg_used: - oparg_arg = "oparg" - oparg_cast = f""" - // Reason for converting {oparg_arg} to i32 is because of expressions like `1 + (oparg -1)` - // that causes underflow errors. - let oparg = i32::try_from({oparg_arg}).expect("{oparg_arg} does not fit in an `i32`"); - """ - - return f""" - #[must_use] - pub fn stack_effect_info(&self, {oparg_arg}: u32) -> StackEffect {{ - {oparg_cast} - - let (pushed, popped) = match self {{ - {arms} - }}; - - debug_assert!(u32::try_from(pushed).is_ok()); - debug_assert!(u32::try_from(popped).is_ok()); - - StackEffect::new(pushed as u32, popped as u32) - }} - """ - - @property - def fn_as_instruction(self) -> str: - arms = "" - for instr in self: - name = instr.name - arms += f"Self::{name} => {self.instruction_enum}::{name}" - if oparg := self.metadata.get(name, {}).get("oparg"): - oname = oparg["name"] - arms += f" {{ {oname}: Arg::marker() }}" - - arms += ",\n" - - return f""" - /// Returns self as [`{self.instruction_enum}`]. - #[must_use] - pub const fn as_instruction(self) -> {self.instruction_enum} {{ - match self {{ - {arms} - }} - }} - """ - - @property - def impl_as_instruction(self) -> str: - return f""" - impl From<{self.name}> for {self.instruction_enum} {{ - fn from(opcode: {self.name}) -> Self {{ - opcode.as_instruction() - }} - }} - """ - - @property - def fn_stack_effect(self) -> str: - return """ - /// Stack effect of [`Self::stack_effect_info`]. 
- #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 { - self.stack_effect_info(oparg).effect() - } - """ - - def __iter__(self): - yield from self.instructions - - -@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) -class InstructionGen: - name: str - opcode_enum: str - instructions: list - numeric_repr: str - metadata: dict[str, str] - - def gen(self) -> str: - methods = "\n\n".join( - getattr(self, attr).strip() - for attr in sorted(dir(self)) - if attr.startswith("fn_") - ) - - impls = "\n\n".join( - getattr(self, attr).strip() - for attr in sorted(dir(self)) - if attr.startswith("impl_") - ) - - variants = "" - for instr in self: - name = instr.name - variants += name - - if oparg := self.metadata.get(name, {}).get("oparg"): - oname, otype = oparg["name"], oparg["type"] - - variants += f"{{ {oname}: Arg<{otype}> }}" - - opcode = instr.opcode - variants += f" = {opcode},\n" - - return f""" - #[derive(Clone, Copy, Debug, Eq, PartialEq)] - #[repr({self.numeric_repr})] // TODO: Remove this `#[repr(...)]` - pub enum {self.name} {{ - {variants} - }} - - impl {self.name} {{ - {methods} - }} - - {impls} - """ - - @property - def fn_as_opcode(self) -> str: - arms = "" - for instr in self: - name = instr.name - arms += f"Self::{name}" - if oparg := self.metadata.get(name, {}).get("oparg"): - arms += " { .. }" - - arms += f"=> {self.opcode_enum}::{name},\n" - - return f""" - /// Returns self as a [`{self.opcode_enum}`]. 
- #[must_use] - pub const fn as_opcode(self) -> {self.opcode_enum} {{ - match self {{ - {arms} - }} - }} - """ - - @property - def impl_as_opcode(self) -> str: - return f""" - impl From<{self.name}> for {self.opcode_enum} {{ - fn from(instruction: {self.name}) -> Self {{ - instruction.as_opcode() - }} - }} - """ - - @property - def fn_as_numeric_repr(self) -> str: - return f""" - #[must_use] - pub const fn as_{self.numeric_repr}(self) -> {self.numeric_repr} {{ - self.as_opcode().as_{self.numeric_repr}() - }} - """ - - @property - def impl_as_numeric_repr(self) -> str: - return f""" - impl From<{self.name}> for {self.numeric_repr} {{ - fn from(instruction: {self.name}) -> Self {{ - instruction.as_{self.numeric_repr}() - }} - }} - """ - - @property - def fn_label_arg(self) -> str: - TARGET = "oparg::Label" - - arms = "" - for instr in self: - name = instr.name - if oparg := self.metadata.get(name, {}).get("oparg"): - oname, otype = oparg["name"], oparg["type"] - if otype != TARGET: - continue - - arms += f"Self::{name} {{ {oname} }} => *{oname},\n" - - arms = arms.strip() - - return f""" - #[must_use] - pub const fn label_arg(&self) -> Option> {{ - Some(match self {{ - {arms} - _ => return None, - }}) - }} - """ - - @property - def fn_to_base(self) -> str: - return f""" - #[must_use] - pub const fn to_base(self) -> Option {{ - if let Some(opcode) = self.as_opcode().to_base() {{ - Some(opcode.as_instruction()) - }} else {{ - None - }} - }} - """ - - @property - def fn_to_instrumented(self) -> str: - return f""" - #[must_use] - pub const fn to_instrumented(self) -> Option {{ - if let Some(opcode) = self.as_opcode().to_instrumented() {{ - Some(opcode.as_instruction()) - }} else {{ - None - }} - }} - """ - - @property - def fn_try_from_numeric(self) -> str: - return f""" - pub const fn try_from_{self.numeric_repr}( - value: {self.numeric_repr} - ) -> Result {{ - match {self.opcode_enum}::try_from_{self.numeric_repr}(value) {{ - Ok(opcode) => Ok(opcode.as_instruction()), 
- Err(e) => Err(e), - }} - }} - """ - - @property - def impl_try_from_numeric(self) -> str: - return f""" - impl TryFrom<{self.numeric_repr}> for {self.name} {{ - type Error = MarshalError; - - fn try_from(value: {self.numeric_repr}) -> Result {{ - Self::try_from_{self.numeric_repr}(value) - }} - }} - """ - - @property - def fn_stack_effect(self) -> str: - return """ - /// Stack effect of [`Self::stack_effect_info`]. - #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 { - self.as_opcode().stack_effect(oparg) - } - """ - - @property - def fn_cache_entries(self) -> str: - return f""" - #[must_use] - pub const fn cache_entries(self) -> usize {{ - self.as_opcode().cache_entries() - }} - """ - - @property - def fn_deopt(self) -> str: - return f""" - #[must_use] - pub const fn deopt(self) -> Option {{ - if let Some(opcode) = self.as_opcode().deopt() {{ - Some(opcode.as_instruction()) - }} else {{ - None - }} - }} - """ - - @property - def fn_stack_effect_info(self) -> str: - return f""" - #[must_use] - pub fn stack_effect_info(&self, oparg: u32) -> StackEffect {{ - self.as_opcode().stack_effect_info(oparg) - }} - """ - - def __iter__(self): - yield from self.instructions - - -def to_pascal_case(s: str) -> str: - return s.title().replace("_", "") - - -def get_analysis() -> analyzer.Analysis: - analysis = analyzer.analyze_files([DEFAULT_INPUT]) - - # We don't differentiate between real and pseudos yet - analysis.instructions |= analysis.pseudos - return analysis - - -def rustfmt(code: str) -> str: - return subprocess.check_output(["rustfmt", "--emit=stdout"], input=code, text=True) - - -def main(): - CONF = tomllib.loads(CONF_FILE.read_text()) - - analysis = get_analysis() - - outfile = io.StringIO() - for opcode_enum, conf in CONF.items(): - metadata = conf["opcodes"] - numeric_repr = conf["numeric_repr"] - instruction_enum = conf["instruction_enum"] - - opcode_range = conf["range"] - lower, upper = map(int, (opcode_range["min"], opcode_range["max"])) - bounds = 
range(lower, upper + 1) - - instructions = sorted( - ( - instr - for instr in analysis.instructions.values() - if instr.opcode in bounds - ), - key=lambda x: x.opcode, - ) - - for instr in instructions: - instr.name = to_pascal_case(instr.name) - - opcode_code = OpcodeGen( - name=opcode_enum, - instruction_enum=instruction_enum, - instructions=instructions, - numeric_repr=numeric_repr, - metadata=metadata, - analysis=analysis, - ).gen() - - outfile.write(opcode_code) - - instruction_code = InstructionGen( - name=instruction_enum, - opcode_enum=opcode_enum, - instructions=instructions, - numeric_repr=numeric_repr, - metadata=metadata, - ).gen() - - outfile.write(instruction_code) - - generated = outfile.getvalue() - - script_path = pathlib.Path(__file__).resolve().relative_to(ROOT).as_posix() - - output = rustfmt( - f""" -// This file is generated by {script_path} -// Do not edit! - -use crate::{{ - bytecode::{{ - instruction::{{Arg, StackEffect}}, - oparg, - }}, - marshal::MarshalError, -}}; - -{generated} - """ - ) - - OUT_FILE.write_text(output) - - -if __name__ == "__main__": - main() diff --git a/crates/compiler-core/opcode.toml b/crates/compiler-core/opcode.toml deleted file mode 100644 index 6ee0cb86512..00000000000 --- a/crates/compiler-core/opcode.toml +++ /dev/null @@ -1,270 +0,0 @@ -[Opcode] -instruction_enum = "Instruction" -numeric_repr = "u8" -range = { min = 0, max = 255 } - -[Opcode.opcodes.BinaryOp] -oparg = { name = "op", type = "oparg::BinaryOperator" } - -[Opcode.opcodes.BuildInterpolation] -oparg = { name = "format", type = "u32" } - -[Opcode.opcodes.BuildList] -oparg = { name = "count", type = "u32" } - -[Opcode.opcodes.BuildMap] -oparg = { name = "count", type = "u32" } - -[Opcode.opcodes.BuildSet] -oparg = { name = "count", type = "u32" } - -[Opcode.opcodes.BuildSlice] -oparg = { name = "argc", type = "oparg::BuildSliceArgCount" } - -[Opcode.opcodes.BuildString] -oparg = { name = "count", type = "u32" } - -[Opcode.opcodes.BuildTuple] -oparg = 
{ name = "count", type = "u32" } - -[Opcode.opcodes.Call] -oparg = { name = "argc", type = "u32" } - -[Opcode.opcodes.CallIntrinsic1] -oparg = { name = "func", type = "oparg::IntrinsicFunction1" } - -[Opcode.opcodes.CallIntrinsic2] -oparg = { name = "func", type = "oparg::IntrinsicFunction2" } - -[Opcode.opcodes.CallKw] -oparg = { name = "argc", type = "u32" } - -[Opcode.opcodes.CompareOp] -oparg = { name = "opname", type = "oparg::ComparisonOperator" } - -[Opcode.opcodes.ContainsOp] -oparg = { name = "invert", type = "oparg::Invert" } - -[Opcode.opcodes.ConvertValue] -oparg = { name = "oparg", type = "oparg::ConvertValueOparg" } - -[Opcode.opcodes.Copy] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.CopyFreeVars] -oparg = { name = "n", type = "u32" } - -[Opcode.opcodes.DeleteAttr] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.DeleteDeref] -oparg = { name = "i", type = "oparg::VarNum" } - -[Opcode.opcodes.DeleteFast] -oparg = { name = "var_num", type = "oparg::VarNum" } - -[Opcode.opcodes.DeleteGlobal] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.DeleteName] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.DictMerge] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.DictUpdate] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.ForIter] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.GetAwaitable] -oparg = { name = "r#where", type = "u32" } - -[Opcode.opcodes.ImportFrom] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.ImportName] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.IsOp] -oparg = { name = "invert", type = "oparg::Invert" } - -[Opcode.opcodes.JumpBackward] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.JumpBackwardNoInterrupt] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.JumpForward] -oparg = { name = "delta", type = "oparg::Label" } - 
-[Opcode.opcodes.ListAppend] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.ListExtend] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.LoadAttr] -oparg = { name = "namei", type = "oparg::LoadAttr" } - -[Opcode.opcodes.LoadCommonConstant] -oparg = { name = "idx", type = "oparg::CommonConstant" } - -[Opcode.opcodes.LoadConst] -oparg = { name = "consti", type = "oparg::ConstIdx" } - -[Opcode.opcodes.LoadDeref] -oparg = { name = "i", type = "oparg::VarNum" } - -[Opcode.opcodes.LoadFast] -oparg = { name = "var_num", type = "oparg::VarNum" } - -[Opcode.opcodes.LoadFastAndClear] -oparg = { name = "var_num", type = "oparg::VarNum" } - -[Opcode.opcodes.LoadFastBorrow] -oparg = { name = "var_num", type = "oparg::VarNum" } - -[Opcode.opcodes.LoadFastBorrowLoadFastBorrow] -oparg = { name = "var_nums", type = "oparg::VarNums" } - -[Opcode.opcodes.LoadFastCheck] -oparg = { name = "var_num", type = "oparg::VarNum" } - -[Opcode.opcodes.LoadFastLoadFast] -oparg = { name = "var_nums", type = "oparg::VarNums" } - -[Opcode.opcodes.LoadFromDictOrDeref] -oparg = { name = "i", type = "oparg::VarNum" } - -[Opcode.opcodes.LoadFromDictOrGlobals] -oparg = { name = "i", type = "oparg::NameIdx" } - -[Opcode.opcodes.LoadGlobal] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.LoadName] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.LoadSmallInt] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.LoadSpecial] -oparg = { name = "method", type = "oparg::SpecialMethod" } - -[Opcode.opcodes.LoadSuperAttr] -oparg = { name = "namei", type = "oparg::LoadSuperAttr" } - -[Opcode.opcodes.MakeCell] -oparg = { name = "i", type = "oparg::VarNum" } - -[Opcode.opcodes.MapAdd] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.MatchClass] -oparg = { name = "count", type = "u32" } - -[Opcode.opcodes.PopJumpIfFalse] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.PopJumpIfNone] -oparg = { name = "delta", type = 
"oparg::Label" } - -[Opcode.opcodes.PopJumpIfNotNone] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.PopJumpIfTrue] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.RaiseVarargs] -oparg = { name = "argc", type = "oparg::RaiseKind" } - -[Opcode.opcodes.Reraise] -oparg = { name = "depth", type = "u32" } - -[Opcode.opcodes.Send] -oparg = { name = "delta", type = "oparg::Label" } - -[Opcode.opcodes.SetAdd] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.SetFunctionAttribute] -oparg = { name = "flag", type = "oparg::MakeFunctionFlag" } - -[Opcode.opcodes.SetUpdate] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.StoreAttr] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.StoreDeref] -oparg = { name = "i", type = "oparg::VarNum" } - -[Opcode.opcodes.StoreFast] -oparg = { name = "var_num", type = "oparg::VarNum" } - -[Opcode.opcodes.StoreFastLoadFast] -oparg = { name = "var_nums", type = "oparg::VarNums" } - -[Opcode.opcodes.StoreFastStoreFast] -oparg = { name = "var_nums", type = "oparg::VarNums" } - -[Opcode.opcodes.StoreGlobal] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.StoreName] -oparg = { name = "namei", type = "oparg::NameIdx" } - -[Opcode.opcodes.Swap] -oparg = { name = "i", type = "u32" } - -[Opcode.opcodes.UnpackEx] -oparg = { name = "counts", type = "oparg::UnpackExArgs" } - -[Opcode.opcodes.UnpackSequence] -oparg = { name = "count", type = "u32" } - -[Opcode.opcodes.WithExceptStart] -stack_effect = { pushed = "7", popped = "6" } - -[Opcode.opcodes.YieldValue] -oparg = { name = "arg", type = "u32" } - -[Opcode.opcodes.Resume] -oparg = { name = "context", type = "oparg::ResumeContext" } - -[PseudoOpcode] -instruction_enum = "PseudoInstruction" -numeric_repr = "u16" -range = { min = 256, max = 65535 } - -[PseudoOpcode.opcodes.Jump] -oparg = { name = "delta", type = "oparg::Label" } - -[PseudoOpcode.opcodes.JumpIfFalse] -oparg = { name = "delta", type 
= "oparg::Label" } - -[PseudoOpcode.opcodes.JumpIfTrue] -oparg = { name = "delta", type = "oparg::Label" } - -[PseudoOpcode.opcodes.JumpNoInterrupt] -oparg = { name = "delta", type = "oparg::Label" } - -[PseudoOpcode.opcodes.LoadClosure] -oparg = { name = "i", type = "oparg::NameIdx" } - -[PseudoOpcode.opcodes.SetupCleanup] -oparg = { name = "delta", type = "oparg::Label" } -stack_effect = { pushed = "0" } - -[PseudoOpcode.opcodes.SetupFinally] -oparg = { name = "delta", type = "oparg::Label" } -stack_effect = { pushed = "0" } - -[PseudoOpcode.opcodes.SetupWith] -oparg = { name = "delta", type = "oparg::Label" } -stack_effect = { pushed = "0" } - -[PseudoOpcode.opcodes.StoreFastMaybeNull] -oparg = { name = "var_num", type = "oparg::NameIdx" } diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index 86723f40022..5872c0ffbd2 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -20,8 +20,10 @@ use num_complex::Complex64; use rustpython_wtf8::{Wtf8, Wtf8Buf}; pub use crate::bytecode::{ - instruction::{AnyInstruction, AnyOpcode, Arg, StackEffect}, - instructions::{Instruction, Opcode, PseudoInstruction, PseudoOpcode}, + instruction::{ + AnyInstruction, AnyOpcode, Arg, Instruction, Opcode, PseudoInstruction, PseudoOpcode, + StackEffect, + }, oparg::{ BinaryOperator, BuildSliceArgCount, CommonConstant, ComparisonOperator, ConvertValueOparg, IntrinsicFunction1, IntrinsicFunction2, Invert, Label, LoadAttr, LoadSuperAttr, @@ -31,7 +33,8 @@ pub use crate::bytecode::{ }; mod instruction; -mod instructions; +mod opcode_metadata; + pub mod oparg; /// Exception table entry for zero-cost exception handling @@ -412,6 +415,10 @@ impl IndexMut for [T] { } /// Per-slot kind flags for localsplus (co_localspluskinds). 
+pub const CO_FAST_ARG_POS: u8 = 0x02; +pub const CO_FAST_ARG_KW: u8 = 0x04; +pub const CO_FAST_ARG_VAR: u8 = 0x08; +pub const CO_FAST_ARG: u8 = CO_FAST_ARG_POS | CO_FAST_ARG_KW | CO_FAST_ARG_VAR; pub const CO_FAST_HIDDEN: u8 = 0x10; pub const CO_FAST_LOCAL: u8 = 0x20; pub const CO_FAST_CELL: u8 = 0x40; @@ -440,7 +447,8 @@ pub struct CodeObject { pub varnames: Box<[C::Name]>, pub cellvars: Box<[C::Name]>, pub freevars: Box<[C::Name]>, - /// Per-slot kind flags: CO_FAST_LOCAL, CO_FAST_CELL, CO_FAST_FREE, CO_FAST_HIDDEN. + /// Per-slot kind flags: CO_FAST_ARG_*, CO_FAST_LOCAL, CO_FAST_CELL, + /// CO_FAST_FREE, CO_FAST_HIDDEN. /// Length = nlocalsplus (nlocals + ncells + nfrees). pub localspluskinds: Box<[u8]>, /// Line number table (CPython 3.11+ format) @@ -460,9 +468,12 @@ bitflags! { const GENERATOR = 0x0020; const COROUTINE = 0x0080; const ITERABLE_COROUTINE = 0x0100; + const ASYNC_GENERATOR = 0x0200; + const FUTURE_ANNOTATIONS = 0x1000000; /// If a code object represents a function and has a docstring, /// this bit is set and the first item in co_consts is the docstring. 
const HAS_DOCSTRING = 0x4000000; + const METHOD = 0x8000000; } } @@ -906,8 +917,6 @@ impl PartialEq for ConstantData { match (self, other) { (Integer { value: a }, Integer { value: b }) => a == b, - // we want to compare floats *by actual value* - if we have the *exact same* float - // already in a constant cache, we want to use that (Float { value: a }, Float { value: b }) => a.to_bits() == b.to_bits(), (Complex { value: a }, Complex { value: b }) => { a.re.to_bits() == b.re.to_bits() && a.im.to_bits() == b.im.to_bits() @@ -1286,7 +1295,7 @@ mod tests { use alloc::{vec, vec::Vec}; #[test] - fn test_exception_table_encode_decode() { + fn exception_table_encode_decode() { let entries = vec![ ExceptionTableEntry::new(0, 10, 20, 2, false), ExceptionTableEntry::new(15, 25, 30, 1, true), @@ -1324,7 +1333,7 @@ mod tests { } #[test] - fn test_exception_table_empty() { + fn exception_table_empty() { let entries: Vec = vec![]; let encoded = encode_exception_table(&entries); assert!(encoded.is_empty()); @@ -1332,7 +1341,7 @@ mod tests { } #[test] - fn test_exception_table_single_entry() { + fn exception_table_single_entry() { let entries = vec![ExceptionTableEntry::new(5, 15, 100, 3, true)]; let encoded = encode_exception_table(&entries); diff --git a/crates/compiler-core/src/bytecode/instruction.rs b/crates/compiler-core/src/bytecode/instruction.rs index 62937069790..69714a0fe66 100644 --- a/crates/compiler-core/src/bytecode/instruction.rs +++ b/crates/compiler-core/src/bytecode/instruction.rs @@ -2,34 +2,715 @@ use core::{fmt, marker::PhantomData}; use crate::marshal::MarshalError; -use super::{Instruction, OpArg, OpArgByte, OpArgType, Opcode, PseudoInstruction, PseudoOpcode}; +use super::{OpArg, OpArgByte, OpArgType, oparg}; -impl Opcode { - /// Map a specialized or instrumented opcode back to its adaptive (base) variant. 
- #[must_use] - pub const fn deoptimize(self) -> Self { - match self.deopt() { - Some(v) => v, - None => { - // Instrumented opcodes map back to their base - match self.to_base() { +macro_rules! define_opcodes { + ( + #[repr($typ:ident)] + $opcode_vis:vis enum $opcode_name:ident; + + $(#[$instr_meta:meta])* + $instr_vis:vis enum $instr_name:ident { + $( + $(#[$op_meta:meta])* + $op_name:ident $({ $arg_name:ident: Arg<$arg_type:ty> $(,)? })? = $op_id:expr + ),* $(,)? + } + ) => { + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + $opcode_vis enum $opcode_name { + $($op_name),* + } + + impl $opcode_name { + #[doc = concat!("Converts this opcode to [`", stringify!($instr_name), "`].")] + #[must_use] + $opcode_vis const fn as_instruction(&self) -> $instr_name { + match self { + $( + Self::$op_name => $instr_name::$op_name $({ $arg_name: Arg::marker() })?, + )* + } + } + + /// Map a specialized or instrumented opcode back to its adaptive (base) variant. + #[must_use] + $opcode_vis const fn deoptimize(self) -> Self { + match self.deopt() { Some(v) => v, - None => self, + None => { + // Instrumented opcodes map back to their base + match self.to_base() { + Some(v) => v, + None => self, + } + } + } + } + + // NOTE: Keep private. Will be exposed under `try_from_u8/try_from_u16`. + pub(super) const fn try_from_numeric(value: $typ) -> Result { + match value { + $($op_id => Ok(Self::$op_name),)* + _ => Err($crate::marshal::MarshalError::InvalidBytecode), + } + } + + // NOTE: Keep private. Will be exposed under `as_u8/as_u16`. 
+ #[must_use] + pub(super) const fn as_numeric(self) -> $typ { + match self { + $(Self::$op_name => $op_id,)* } } } + + impl From<$opcode_name> for $instr_name { + fn from(opcode: $opcode_name) -> Self { + opcode.as_instruction() + } + } + + + impl TryFrom<$typ> for $opcode_name { + type Error = $crate::marshal::MarshalError; + + fn try_from(value: $typ) -> Result { + Self::try_from_numeric(value) + } + } + + impl From<$opcode_name> for $typ { + fn from(opcode: $opcode_name) -> Self { + opcode.as_numeric() + } + } + + #[derive(Clone, Copy, Debug)] + #[repr($typ)] // TODO: Remove this repr + $instr_vis enum $instr_name { + $( + $(#[$op_meta])* + $op_name $({ $arg_name: Arg<$arg_type> })? = $op_id // TODO: Don't assign value + ),* + } + + impl $instr_name { + #[doc = concat!("Get the corresponding [`", stringify!($opcode_name), "`].")] + #[must_use] + $instr_vis const fn as_opcode(&self) -> $opcode_name { + match self { + $( + Self::$op_name $({ $arg_name: _ })? => $opcode_name::$op_name, + )* + } + } + + #[must_use] + $instr_vis const fn label_arg(&self) -> Option> { + //define_opcodes!(@label_arm Self::$op_name $({ $arg_name } : $arg_type)?) + define_opcodes!(@match self, Self, [$($op_name $({ $arg_name : $arg_type })?),*]) + } + + #[must_use] + pub const fn to_base(self) -> Option { + if let Some(op) = self.as_opcode().to_base() { + Some(op.as_instruction()) + } else { + None + } + } + + #[must_use] + pub const fn to_instrumented(self) -> Option { + if let Some(op) = self.as_opcode().to_instrumented() { + Some(op.as_instruction()) + } else { + None + } + } + + /// Returns `true` if this is any instrumented opcode. 
+ #[must_use] + $instr_vis const fn is_instrumented(&self) -> bool { + self.as_opcode().is_instrumented() + } + + #[must_use] + $instr_vis const fn is_unconditional_jump(&self) -> bool { + self.as_opcode().is_unconditional_jump() + } + + #[must_use] + $instr_vis const fn is_block_push(&self) -> bool { + self.as_opcode().is_block_push() + } + + #[must_use] + $instr_vis const fn is_scope_exit(&self) -> bool { + self.as_opcode().is_scope_exit() + } + + #[must_use] + $instr_vis const fn is_terminator(&self) -> bool { + self.as_opcode().is_terminator() + } + + #[must_use] + $instr_vis const fn is_no_fallthrough(&self) -> bool { + self.as_opcode().is_no_fallthrough() + } + + #[must_use] + $instr_vis const fn has_target(&self) -> bool { + self.as_opcode().has_target() + } + + #[must_use] + $instr_vis const fn has_jump(&self) -> bool { + self.as_opcode().has_jump() + } + + #[must_use] + $instr_vis const fn has_arg(&self) -> bool { + self.as_opcode().has_arg() + } + + #[must_use] + $instr_vis const fn has_const(&self) -> bool { + self.as_opcode().has_const() + } + + #[must_use] + $instr_vis const fn has_eval_break(&self) -> bool { + self.as_opcode().has_eval_break() + } + + #[must_use] + $instr_vis const fn is_assembler(&self) -> bool { + self.as_opcode().is_assembler() + } + + #[must_use] + $instr_vis const fn cache_entries(&self) -> usize{ + self.as_opcode().cache_entries() + } + + /// Map a specialized or instrumented opcode back to its adaptive (base) variant. 
+ #[must_use] + $instr_vis const fn deoptimize(&self) -> Self { + self.as_opcode().deoptimize().as_instruction() + } + + #[must_use] + $instr_vis fn stack_effect_jump(&self, oparg: u32) -> i32 { + self.as_opcode().stack_effect_jump(oparg) + } + + #[must_use] + $instr_vis fn stack_effect_info(&self, oparg: u32) -> StackEffect { + self.as_opcode().stack_effect_info(oparg) + } + + #[must_use] + $instr_vis fn stack_effect(&self, oparg: u32) -> i32 { + self.as_opcode().stack_effect(oparg) + } + } + + impl From<$instr_name> for $opcode_name { + fn from(instr: $instr_name) -> Self { + instr.as_opcode() + } + } + + impl TryFrom<$typ> for $instr_name { + type Error = $crate::marshal::MarshalError; + + fn try_from(value: $typ) -> Result { + $opcode_name::try_from_numeric(value).map(Into::into) + } + } + + impl From<$instr_name> for $typ { + fn from(instr: $instr_name) -> Self { + instr.as_opcode().into() + } + } + }; + + // Base case: empty list + (@match $self:expr, $name:ident, []) => { + None + }; + + // Label field variant (with trailing variants) + (@match $self:expr, $name:ident, [$variant:ident { $field:ident : Label } , $($rest:tt)*]) => { + match $self { + $name::$variant { $field } => Some(*$field), + other => define_opcodes!(@match other, $name, [$($rest)*]), + } + }; + + // Label field variant (last in list) + (@match $self:expr, $name:ident, [$variant:ident { $field:ident : Label }]) => { + match $self { + $name::$variant { $field } => Some(*$field), + other => define_opcodes!(@match other, $name, []), + } + }; + + // Non-Label field variant (with trailing variants) + (@match $self:expr, $name:ident, [$variant:ident { $field:ident : $type:ty } , $($rest:tt)*]) => { + match $self { + $name::$variant { .. } => None, + other => define_opcodes!(@match other, $name, [$($rest)*]), + } + }; + + // Non-Label field variant (last in list) + (@match $self:expr, $name:ident, [$variant:ident { $field:ident : $type:ty }]) => { + match $self { + $name::$variant { .. 
} => None, + _ => define_opcodes!(@match _, $name, []), + } + }; + + // Unit variant (with trailing variants) + (@match $self:expr, $name:ident, [$variant:ident , $($rest:tt)*]) => { + match $self { + $name::$variant => None, + other => define_opcodes!(@match other, $name, [$($rest)*]), + } + }; + + // Unit variant (last in list) + (@match $self:expr, $name:ident, [$variant:ident]) => { + match $self { + $name::$variant => None, + _ => define_opcodes!(@match _, $name, []), + } + }; +} + +define_opcodes!( + #[repr(u8)] + pub enum Opcode; + + pub enum Instruction { + Cache = 0, + BinarySlice = 1, + BuildTemplate = 2, + BinaryOpInplaceAddUnicode = 3, + CallFunctionEx = 4, + CheckEgMatch = 5, + CheckExcMatch = 6, + CleanupThrow = 7, + DeleteSubscr = 8, + EndFor = 9, + EndSend = 10, + ExitInitCheck = 11, + FormatSimple = 12, + FormatWithSpec = 13, + GetAiter = 14, + GetAnext = 15, + GetIter = 16, + Reserved = 17, + GetLen = 18, + GetYieldFromIter = 19, + InterpreterExit = 20, + LoadBuildClass = 21, + LoadLocals = 22, + MakeFunction = 23, + MatchKeys = 24, + MatchMapping = 25, + MatchSequence = 26, + Nop = 27, + NotTaken = 28, + PopExcept = 29, + PopIter = 30, + PopTop = 31, + PushExcInfo = 32, + PushNull = 33, + ReturnGenerator = 34, + ReturnValue = 35, + SetupAnnotations = 36, + StoreSlice = 37, + StoreSubscr = 38, + ToBool = 39, + UnaryInvert = 40, + UnaryNegative = 41, + UnaryNot = 42, + WithExceptStart = 43, + BinaryOp { + op: Arg, + } = 44, + BuildInterpolation { + format: Arg, + } = 45, + BuildList { + count: Arg, + } = 46, + BuildMap { + count: Arg, + } = 47, + BuildSet { + count: Arg, + } = 48, + BuildSlice { + argc: Arg, + } = 49, + BuildString { + count: Arg, + } = 50, + BuildTuple { + count: Arg, + } = 51, + Call { + argc: Arg, + } = 52, + CallIntrinsic1 { + func: Arg, + } = 53, + CallIntrinsic2 { + func: Arg, + } = 54, + CallKw { + argc: Arg, + } = 55, + CompareOp { + opname: Arg, + } = 56, + ContainsOp { + invert: Arg, + } = 57, + ConvertValue { + oparg: 
Arg, + } = 58, + Copy { + i: Arg, + } = 59, + CopyFreeVars { + n: Arg, + } = 60, + DeleteAttr { + namei: Arg, + } = 61, + DeleteDeref { + i: Arg, + } = 62, + DeleteFast { + var_num: Arg, + } = 63, + DeleteGlobal { + namei: Arg, + } = 64, + DeleteName { + namei: Arg, + } = 65, + DictMerge { + i: Arg, + } = 66, + DictUpdate { + i: Arg, + } = 67, + EndAsyncFor = 68, + ExtendedArg = 69, + ForIter { + delta: Arg, + } = 70, + GetAwaitable { + r#where: Arg, + } = 71, + ImportFrom { + namei: Arg, + } = 72, + ImportName { + namei: Arg, + } = 73, + IsOp { + invert: Arg, + } = 74, + JumpBackward { + delta: Arg, + } = 75, + JumpBackwardNoInterrupt { + delta: Arg, + } = 76, + JumpForward { + delta: Arg, + } = 77, + ListAppend { + i: Arg, + } = 78, + ListExtend { + i: Arg, + } = 79, + LoadAttr { + namei: Arg, + } = 80, + LoadCommonConstant { + idx: Arg, + } = 81, + LoadConst { + consti: Arg, + } = 82, + LoadDeref { + i: Arg, + } = 83, + LoadFast { + var_num: Arg, + } = 84, + LoadFastAndClear { + var_num: Arg, + } = 85, + LoadFastBorrow { + var_num: Arg, + } = 86, + LoadFastBorrowLoadFastBorrow { + var_nums: Arg, + } = 87, + LoadFastCheck { + var_num: Arg, + } = 88, + LoadFastLoadFast { + var_nums: Arg, + } = 89, + LoadFromDictOrDeref { + i: Arg, + } = 90, + LoadFromDictOrGlobals { + i: Arg, + } = 91, + LoadGlobal { + namei: Arg, + } = 92, + LoadName { + namei: Arg, + } = 93, + LoadSmallInt { + i: Arg, + } = 94, + LoadSpecial { + method: Arg, + } = 95, + LoadSuperAttr { + namei: Arg, + } = 96, + MakeCell { + i: Arg, + } = 97, + MapAdd { + i: Arg, + } = 98, + MatchClass { + count: Arg, + } = 99, + PopJumpIfFalse { + delta: Arg, + } = 100, + PopJumpIfNone { + delta: Arg, + } = 101, + PopJumpIfNotNone { + delta: Arg, + } = 102, + PopJumpIfTrue { + delta: Arg, + } = 103, + RaiseVarargs { + argc: Arg, + } = 104, + Reraise { + depth: Arg, + } = 105, + Send { + delta: Arg, + } = 106, + SetAdd { + i: Arg, + } = 107, + SetFunctionAttribute { + flag: Arg, + } = 108, + SetUpdate { + i: Arg, 
+ } = 109, + StoreAttr { + namei: Arg, + } = 110, + StoreDeref { + i: Arg, + } = 111, + StoreFast { + var_num: Arg, + } = 112, + StoreFastLoadFast { + var_nums: Arg, + } = 113, + StoreFastStoreFast { + var_nums: Arg, + } = 114, + StoreGlobal { + namei: Arg, + } = 115, + StoreName { + namei: Arg, + } = 116, + Swap { + i: Arg, + } = 117, + UnpackEx { + counts: Arg, + } = 118, + UnpackSequence { + count: Arg, + } = 119, + YieldValue { + arg: Arg, + } = 120, + Resume { + context: Arg, + } = 128, + BinaryOpAddFloat = 129, + BinaryOpAddInt = 130, + BinaryOpAddUnicode = 131, + BinaryOpExtend = 132, + BinaryOpMultiplyFloat = 133, + BinaryOpMultiplyInt = 134, + BinaryOpSubscrDict = 135, + BinaryOpSubscrGetitem = 136, + BinaryOpSubscrListInt = 137, + BinaryOpSubscrListSlice = 138, + BinaryOpSubscrStrInt = 139, + BinaryOpSubscrTupleInt = 140, + BinaryOpSubtractFloat = 141, + BinaryOpSubtractInt = 142, + CallAllocAndEnterInit = 143, + CallBoundMethodExactArgs = 144, + CallBoundMethodGeneral = 145, + CallBuiltinClass = 146, + CallBuiltinFast = 147, + CallBuiltinFastWithKeywords = 148, + CallBuiltinO = 149, + CallIsinstance = 150, + CallKwBoundMethod = 151, + CallKwNonPy = 152, + CallKwPy = 153, + CallLen = 154, + CallListAppend = 155, + CallMethodDescriptorFast = 156, + CallMethodDescriptorFastWithKeywords = 157, + CallMethodDescriptorNoargs = 158, + CallMethodDescriptorO = 159, + CallNonPyGeneral = 160, + CallPyExactArgs = 161, + CallPyGeneral = 162, + CallStr1 = 163, + CallTuple1 = 164, + CallType1 = 165, + CompareOpFloat = 166, + CompareOpInt = 167, + CompareOpStr = 168, + ContainsOpDict = 169, + ContainsOpSet = 170, + ForIterGen = 171, + ForIterList = 172, + ForIterRange = 173, + ForIterTuple = 174, + JumpBackwardJit = 175, + JumpBackwardNoJit = 176, + LoadAttrClass = 177, + LoadAttrClassWithMetaclassCheck = 178, + LoadAttrGetattributeOverridden = 179, + LoadAttrInstanceValue = 180, + LoadAttrMethodLazyDict = 181, + LoadAttrMethodNoDict = 182, + LoadAttrMethodWithValues = 
183, + LoadAttrModule = 184, + LoadAttrNondescriptorNoDict = 185, + LoadAttrNondescriptorWithValues = 186, + LoadAttrProperty = 187, + LoadAttrSlot = 188, + LoadAttrWithHint = 189, + LoadConstImmortal = 190, + LoadConstMortal = 191, + LoadGlobalBuiltin = 192, + LoadGlobalModule = 193, + LoadSuperAttrAttr = 194, + LoadSuperAttrMethod = 195, + ResumeCheck = 196, + SendGen = 197, + StoreAttrInstanceValue = 198, + StoreAttrSlot = 199, + StoreAttrWithHint = 200, + StoreSubscrDict = 201, + StoreSubscrListInt = 202, + ToBoolAlwaysTrue = 203, + ToBoolBool = 204, + ToBoolInt = 205, + ToBoolList = 206, + ToBoolNone = 207, + ToBoolStr = 208, + UnpackSequenceList = 209, + UnpackSequenceTuple = 210, + UnpackSequenceTwoTuple = 211, + InstrumentedEndFor = 234, + InstrumentedPopIter = 235, + InstrumentedEndSend = 236, + InstrumentedForIter = 237, + InstrumentedInstruction = 238, + InstrumentedJumpForward = 239, + InstrumentedNotTaken = 240, + InstrumentedPopJumpIfTrue = 241, + InstrumentedPopJumpIfFalse = 242, + InstrumentedPopJumpIfNone = 243, + InstrumentedPopJumpIfNotNone = 244, + InstrumentedResume = 245, + InstrumentedReturnValue = 246, + InstrumentedYieldValue = 247, + InstrumentedEndAsyncFor = 248, + InstrumentedLoadSuperAttr = 249, + InstrumentedCall = 250, + InstrumentedCallKw = 251, + InstrumentedCallFunctionEx = 252, + InstrumentedJumpBackward = 253, + InstrumentedLine = 254, + EnterExecutor = 255, + } +); + +define_opcodes!( + #[repr(u16)] + pub enum PseudoOpcode; + + pub enum PseudoInstruction { + AnnotationsPlaceholder = 256, + Jump { delta: Arg } = 257, + JumpIfFalse { delta: Arg } = 258, + JumpIfTrue { delta: Arg } = 259, + JumpNoInterrupt { delta: Arg } = 260, + LoadClosure { i: Arg } = 261, + PopBlock = 262, + SetupCleanup { delta: Arg } = 263, + SetupFinally { delta: Arg } = 264, + SetupWith { delta: Arg } = 265, + StoreFastMaybeNull { var_num: Arg } = 266, } +); - /// Returns `true` if this is any instrumented opcode - /// (regular INSTRUMENTED_*, 
INSTRUMENTED_LINE, or INSTRUMENTED_INSTRUCTION). +impl Opcode { #[must_use] - pub const fn is_instrumented(self) -> bool { - self.to_base().is_some() - || matches!(self, Self::InstrumentedLine | Self::InstrumentedInstruction) + pub const fn is_unconditional_jump(&self) -> bool { + matches!( + self, + Self::JumpForward | Self::JumpBackward | Self::JumpBackwardNoInterrupt + ) } + /// CPython's `IS_ASSEMBLER_OPCODE`. #[must_use] - pub const fn is_unconditional_jump(&self) -> bool { + pub const fn is_assembler(&self) -> bool { matches!( self, Self::JumpForward | Self::JumpBackward | Self::JumpBackwardNoInterrupt @@ -41,11 +722,35 @@ impl Opcode { matches!(self, Self::ReturnValue | Self::RaiseVarargs | Self::Reraise) } + /// CPython's `IS_TERMINATOR_OPCODE`. + #[must_use] + pub const fn is_terminator(&self) -> bool { + self.has_jump() || self.is_scope_exit() + } + + /// CPython's `IS_SCOPE_EXIT_OPCODE || IS_UNCONDITIONAL_JUMP_OPCODE`. + #[must_use] + pub const fn is_no_fallthrough(&self) -> bool { + self.is_scope_exit() || self.is_unconditional_jump() + } + + /// CPython's `HAS_TARGET`. + #[must_use] + pub const fn has_target(&self) -> bool { + self.has_jump() || self.is_block_push() + } + #[must_use] pub const fn is_block_push(&self) -> bool { false } + /// Stack effect of [`Self::stack_effect_info`]. + #[must_use] + pub fn stack_effect(&self, oparg: u32) -> i32 { + self.stack_effect_info(oparg).effect() + } + /// Stack effect when the instruction takes its branch (jump=true). /// /// CPython equivalent: `stack_effect(opcode, oparg, jump=True)`. @@ -59,11 +764,6 @@ impl Opcode { } impl PseudoOpcode { - #[must_use] - pub const fn is_instrumented(&self) -> bool { - false - } - #[must_use] pub const fn is_block_push(&self) -> bool { matches!( @@ -72,84 +772,63 @@ impl PseudoOpcode { ) } - /// Handler entry effect for SETUP_* pseudo ops. 
- /// - /// Fallthrough effect is 0 (NOPs), but when the branch is taken the - /// handler block starts with extra values on the stack: - /// SETUP_FINALLY: +1 (exc) - /// SETUP_CLEANUP: +2 (lasti + exc) - /// SETUP_WITH: +1 (pops __enter__ result, pushes lasti + exc) #[must_use] - pub fn stack_effect_jump(&self, oparg: u32) -> i32 { - match self { - Self::SetupFinally | Self::SetupWith => 1, - Self::SetupCleanup => 2, - _ => self.stack_effect(oparg), - } - } -} - -impl Instruction { - /// Returns `true` if this is any instrumented opcode - /// (regular INSTRUMENTED_*, INSTRUMENTED_LINE, or INSTRUMENTED_INSTRUCTION). - #[must_use] - pub const fn is_instrumented(self) -> bool { - self.as_opcode().is_instrumented() + pub const fn is_scope_exit(&self) -> bool { + false } #[must_use] pub const fn is_unconditional_jump(&self) -> bool { - self.as_opcode().is_unconditional_jump() + matches!(self, Self::Jump | Self::JumpNoInterrupt) } #[must_use] - pub const fn is_block_push(&self) -> bool { - self.as_opcode().is_block_push() - } - - #[must_use] - pub const fn is_scope_exit(&self) -> bool { - self.as_opcode().is_scope_exit() - } - - /// Map a specialized or instrumented opcode back to its adaptive (base) variant. - #[must_use] - pub const fn deoptimize(self) -> Self { - self.as_opcode().deoptimize().as_instruction() + pub const fn is_assembler(&self) -> bool { + false } + /// CPython's `IS_TERMINATOR_OPCODE`. #[must_use] - pub fn stack_effect_jump(&self, oparg: u32) -> i32 { - self.as_opcode().stack_effect(oparg) + pub const fn is_terminator(&self) -> bool { + self.has_jump() } -} -impl PseudoInstruction { - /// Returns true if self is one of: - /// - [`PseudoInstruction::SetupCleanup`] - /// - [`PseudoInstruction::SetupFinally`] - /// - [`PseudoInstruction::SetupWith`] + /// CPython's `IS_SCOPE_EXIT_OPCODE || IS_UNCONDITIONAL_JUMP_OPCODE`. 
#[must_use] - pub const fn is_block_push(&self) -> bool { - self.as_opcode().is_block_push() + pub const fn is_no_fallthrough(&self) -> bool { + self.is_unconditional_jump() } + /// CPython's `HAS_TARGET`. #[must_use] - pub const fn is_unconditional_jump(&self) -> bool { - matches!( - self.as_opcode(), - PseudoOpcode::Jump | PseudoOpcode::JumpNoInterrupt - ) + pub const fn has_target(&self) -> bool { + self.has_jump() || self.is_block_push() } + /// flowgraph.c get_stack_effects block-push non-jump case. #[must_use] - pub const fn is_scope_exit(&self) -> bool { - false + pub fn stack_effect(&self, oparg: u32) -> i32 { + if self.is_block_push() { + 0 + } else { + self.stack_effect_info(oparg).effect() + } } + /// Handler entry effect for SETUP_* pseudo ops. + /// + /// Fallthrough effect is 0 (NOPs), but when the branch is taken the + /// handler block starts with extra values on the stack: + /// SETUP_FINALLY: +1 (exc) + /// SETUP_CLEANUP: +2 (lasti + exc) + /// SETUP_WITH: +1 (pops __enter__ result, pushes lasti + exc) #[must_use] pub fn stack_effect_jump(&self, oparg: u32) -> i32 { - self.as_opcode().stack_effect_jump(oparg) + match self { + Self::SetupFinally | Self::SetupWith => 1, + Self::SetupCleanup => 2, + _ => self.stack_effect(oparg), + } } } @@ -191,28 +870,68 @@ pub enum AnyInstruction { impl AnyInstruction { either_real_pseudo!( - #[must_use] - pub const fn is_unconditional_jump(&self) -> bool + #[must_use] + pub const fn is_unconditional_jump(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn is_scope_exit(&self) -> bool + #[must_use] + pub const fn is_scope_exit(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 + #[must_use] + pub const fn is_terminator(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub fn stack_effect_jump(&self, oparg: u32) -> i32 + #[must_use] + pub const fn is_no_fallthrough(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub fn 
stack_effect_info(&self, oparg: u32) -> StackEffect + #[must_use] + pub const fn has_target(&self) -> bool + ); + + either_real_pseudo!( + #[must_use] + pub const fn has_jump(&self) -> bool + ); + + either_real_pseudo!( + #[must_use] + pub const fn has_arg(&self) -> bool + ); + + either_real_pseudo!( + #[must_use] + pub const fn has_const(&self) -> bool + ); + + either_real_pseudo!( + #[must_use] + pub const fn has_eval_break(&self) -> bool + ); + + either_real_pseudo!( + #[must_use] + pub const fn is_assembler(&self) -> bool + ); + + either_real_pseudo!( + #[must_use] + pub fn stack_effect(&self, oparg: u32) -> i32 + ); + + either_real_pseudo!( + #[must_use] + pub fn stack_effect_jump(&self, oparg: u32) -> i32 + ); + + either_real_pseudo!( + #[must_use] + pub fn stack_effect_info(&self, oparg: u32) -> StackEffect ); } @@ -340,7 +1059,7 @@ impl AnyInstruction { } } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum AnyOpcode { Real(Opcode), Pseudo(PseudoOpcode), @@ -428,53 +1147,53 @@ impl AnyOpcode { } either_real_pseudo!( - #[must_use] - pub const fn has_arg(&self) -> bool + #[must_use] + pub const fn has_arg(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn has_jump(&self) -> bool + #[must_use] + pub const fn has_jump(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn has_free(&self) -> bool + #[must_use] + pub const fn has_free(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn has_local(&self) -> bool + #[must_use] + pub const fn has_local(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn has_name(&self) -> bool + #[must_use] + pub const fn has_name(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn has_const(&self) -> bool + #[must_use] + pub const fn has_const(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub const fn is_instrumented(&self) -> bool + #[must_use] + pub const fn is_instrumented(&self) -> bool ); 
either_real_pseudo!( - #[must_use] - pub const fn is_block_push(&self) -> bool + #[must_use] + pub const fn is_block_push(&self) -> bool ); either_real_pseudo!( - #[must_use] - pub fn stack_effect_jump(&self, oparg: u32) -> i32 + #[must_use] + pub fn stack_effect_jump(&self, oparg: u32) -> i32 ); either_real_pseudo!( - #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 + #[must_use] + pub fn stack_effect(&self, oparg: u32) -> i32 ); #[must_use] @@ -601,3 +1320,115 @@ impl fmt::Debug for Arg { // breaks the VM:/ const _: () = assert!(core::mem::size_of::() == 1); const _: () = assert!(core::mem::size_of::() == 2); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn eval_break_flags_match_cpython_jump_metadata() { + assert!(Opcode::JumpBackward.has_eval_break()); + assert!(!Opcode::JumpBackwardNoInterrupt.has_eval_break()); + assert!(!Opcode::JumpForward.has_eval_break()); + + assert!(PseudoOpcode::Jump.has_eval_break()); + assert!(!PseudoOpcode::JumpIfFalse.has_eval_break()); + assert!(!PseudoOpcode::JumpIfTrue.has_eval_break()); + assert!(!PseudoOpcode::JumpNoInterrupt.has_eval_break()); + + assert!(AnyInstruction::from(PseudoOpcode::Jump).has_eval_break()); + } + + #[test] + fn terminator_flags_match_cpython_opcode_utils() { + assert!(Opcode::JumpForward.is_terminator()); + assert!(Opcode::PopJumpIfFalse.is_terminator()); + assert!(Opcode::ForIter.is_terminator()); + assert!(Opcode::ReturnValue.is_terminator()); + assert!(!Opcode::Nop.is_terminator()); + + assert!(PseudoOpcode::JumpIfTrue.is_terminator()); + assert!(PseudoOpcode::JumpNoInterrupt.is_terminator()); + assert!(!PseudoOpcode::SetupFinally.is_terminator()); + assert!(!PseudoOpcode::SetupWith.is_terminator()); + assert!(!PseudoOpcode::SetupCleanup.is_terminator()); + assert!(!PseudoOpcode::PopBlock.is_terminator()); + + assert!(AnyInstruction::from(PseudoOpcode::JumpIfFalse).is_terminator()); + } + + #[test] + fn assembler_flags_match_cpython_opcode_utils() { + 
assert!(Opcode::JumpForward.is_assembler()); + assert!(Opcode::JumpBackward.is_assembler()); + assert!(Opcode::JumpBackwardNoInterrupt.is_assembler()); + assert!(!Opcode::PopJumpIfFalse.is_assembler()); + assert!(!Opcode::Nop.is_assembler()); + + assert!(!PseudoOpcode::Jump.is_assembler()); + assert!(!PseudoOpcode::JumpNoInterrupt.is_assembler()); + assert!(!AnyInstruction::from(PseudoOpcode::Jump).is_assembler()); + } + + #[test] + fn target_flags_match_cpython_opcode_utils() { + assert!(Opcode::JumpForward.has_target()); + assert!(Opcode::ForIter.has_target()); + assert!(!Opcode::ReturnValue.has_target()); + assert!(!Opcode::Nop.has_target()); + + assert!(PseudoOpcode::Jump.has_target()); + assert!(PseudoOpcode::SetupFinally.has_target()); + assert!(PseudoOpcode::SetupWith.has_target()); + assert!(PseudoOpcode::SetupCleanup.has_target()); + assert!(!PseudoOpcode::PopBlock.has_target()); + + assert!(AnyInstruction::from(PseudoOpcode::SetupFinally).has_target()); + } + + #[test] + fn arg_flags_match_cpython_opcode_metadata() { + assert!(Opcode::LoadConst.has_arg()); + assert!(Opcode::YieldValue.has_arg()); + assert!(!Opcode::Nop.has_arg()); + assert!(!Opcode::ReturnValue.has_arg()); + + assert!(PseudoOpcode::Jump.has_arg()); + assert!(PseudoOpcode::JumpIfFalse.has_arg()); + assert!(PseudoOpcode::JumpIfTrue.has_arg()); + assert!(PseudoOpcode::JumpNoInterrupt.has_arg()); + assert!(PseudoOpcode::LoadClosure.has_arg()); + assert!(PseudoOpcode::StoreFastMaybeNull.has_arg()); + assert!(!PseudoOpcode::AnnotationsPlaceholder.has_arg()); + assert!(!PseudoOpcode::PopBlock.has_arg()); + } + + #[test] + fn const_flags_match_cpython_opcode_metadata() { + assert!(Opcode::LoadConst.has_const()); + assert!(Opcode::LoadConstImmortal.has_const()); + assert!(Opcode::LoadConstMortal.has_const()); + assert!(!Opcode::LoadSmallInt.has_const()); + assert!(!Opcode::Nop.has_const()); + + assert!(!PseudoOpcode::LoadClosure.has_const()); + 
assert!(!AnyInstruction::from(PseudoOpcode::Jump).has_const()); + } + + #[test] + fn no_fallthrough_flags_match_cpython_basicblock_nofallthrough() { + assert!(Opcode::JumpForward.is_no_fallthrough()); + assert!(Opcode::ReturnValue.is_no_fallthrough()); + assert!(!Opcode::PopJumpIfFalse.is_no_fallthrough()); + assert!(!Opcode::ForIter.is_no_fallthrough()); + assert!(!Opcode::Nop.is_no_fallthrough()); + + assert!(PseudoOpcode::Jump.is_no_fallthrough()); + assert!(PseudoOpcode::JumpNoInterrupt.is_no_fallthrough()); + assert!(!PseudoOpcode::JumpIfFalse.is_no_fallthrough()); + assert!(!PseudoOpcode::SetupFinally.is_no_fallthrough()); + assert!(!PseudoOpcode::SetupWith.is_no_fallthrough()); + + assert!(AnyInstruction::from(PseudoOpcode::Jump).is_no_fallthrough()); + } +} diff --git a/crates/compiler-core/src/bytecode/instructions.rs b/crates/compiler-core/src/bytecode/instructions.rs deleted file mode 100644 index 7f16dc42d4a..00000000000 --- a/crates/compiler-core/src/bytecode/instructions.rs +++ /dev/null @@ -1,2780 +0,0 @@ -// This file is generated by crates/compiler-core/generate.py -// Do not edit! 
- -use crate::{ - bytecode::{ - instruction::{Arg, StackEffect}, - oparg, - }, - marshal::MarshalError, -}; - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum Opcode { - Cache, - BinarySlice, - BuildTemplate, - BinaryOpInplaceAddUnicode, - CallFunctionEx, - CheckEgMatch, - CheckExcMatch, - CleanupThrow, - DeleteSubscr, - EndFor, - EndSend, - ExitInitCheck, - FormatSimple, - FormatWithSpec, - GetAiter, - GetAnext, - GetIter, - Reserved, - GetLen, - GetYieldFromIter, - InterpreterExit, - LoadBuildClass, - LoadLocals, - MakeFunction, - MatchKeys, - MatchMapping, - MatchSequence, - Nop, - NotTaken, - PopExcept, - PopIter, - PopTop, - PushExcInfo, - PushNull, - ReturnGenerator, - ReturnValue, - SetupAnnotations, - StoreSlice, - StoreSubscr, - ToBool, - UnaryInvert, - UnaryNegative, - UnaryNot, - WithExceptStart, - BinaryOp, - BuildInterpolation, - BuildList, - BuildMap, - BuildSet, - BuildSlice, - BuildString, - BuildTuple, - Call, - CallIntrinsic1, - CallIntrinsic2, - CallKw, - CompareOp, - ContainsOp, - ConvertValue, - Copy, - CopyFreeVars, - DeleteAttr, - DeleteDeref, - DeleteFast, - DeleteGlobal, - DeleteName, - DictMerge, - DictUpdate, - EndAsyncFor, - ExtendedArg, - ForIter, - GetAwaitable, - ImportFrom, - ImportName, - IsOp, - JumpBackward, - JumpBackwardNoInterrupt, - JumpForward, - ListAppend, - ListExtend, - LoadAttr, - LoadCommonConstant, - LoadConst, - LoadDeref, - LoadFast, - LoadFastAndClear, - LoadFastBorrow, - LoadFastBorrowLoadFastBorrow, - LoadFastCheck, - LoadFastLoadFast, - LoadFromDictOrDeref, - LoadFromDictOrGlobals, - LoadGlobal, - LoadName, - LoadSmallInt, - LoadSpecial, - LoadSuperAttr, - MakeCell, - MapAdd, - MatchClass, - PopJumpIfFalse, - PopJumpIfNone, - PopJumpIfNotNone, - PopJumpIfTrue, - RaiseVarargs, - Reraise, - Send, - SetAdd, - SetFunctionAttribute, - SetUpdate, - StoreAttr, - StoreDeref, - StoreFast, - StoreFastLoadFast, - StoreFastStoreFast, - StoreGlobal, - StoreName, - Swap, - UnpackEx, - UnpackSequence, - YieldValue, - 
Resume, - BinaryOpAddFloat, - BinaryOpAddInt, - BinaryOpAddUnicode, - BinaryOpExtend, - BinaryOpMultiplyFloat, - BinaryOpMultiplyInt, - BinaryOpSubscrDict, - BinaryOpSubscrGetitem, - BinaryOpSubscrListInt, - BinaryOpSubscrListSlice, - BinaryOpSubscrStrInt, - BinaryOpSubscrTupleInt, - BinaryOpSubtractFloat, - BinaryOpSubtractInt, - CallAllocAndEnterInit, - CallBoundMethodExactArgs, - CallBoundMethodGeneral, - CallBuiltinClass, - CallBuiltinFast, - CallBuiltinFastWithKeywords, - CallBuiltinO, - CallIsinstance, - CallKwBoundMethod, - CallKwNonPy, - CallKwPy, - CallLen, - CallListAppend, - CallMethodDescriptorFast, - CallMethodDescriptorFastWithKeywords, - CallMethodDescriptorNoargs, - CallMethodDescriptorO, - CallNonPyGeneral, - CallPyExactArgs, - CallPyGeneral, - CallStr1, - CallTuple1, - CallType1, - CompareOpFloat, - CompareOpInt, - CompareOpStr, - ContainsOpDict, - ContainsOpSet, - ForIterGen, - ForIterList, - ForIterRange, - ForIterTuple, - JumpBackwardJit, - JumpBackwardNoJit, - LoadAttrClass, - LoadAttrClassWithMetaclassCheck, - LoadAttrGetattributeOverridden, - LoadAttrInstanceValue, - LoadAttrMethodLazyDict, - LoadAttrMethodNoDict, - LoadAttrMethodWithValues, - LoadAttrModule, - LoadAttrNondescriptorNoDict, - LoadAttrNondescriptorWithValues, - LoadAttrProperty, - LoadAttrSlot, - LoadAttrWithHint, - LoadConstImmortal, - LoadConstMortal, - LoadGlobalBuiltin, - LoadGlobalModule, - LoadSuperAttrAttr, - LoadSuperAttrMethod, - ResumeCheck, - SendGen, - StoreAttrInstanceValue, - StoreAttrSlot, - StoreAttrWithHint, - StoreSubscrDict, - StoreSubscrListInt, - ToBoolAlwaysTrue, - ToBoolBool, - ToBoolInt, - ToBoolList, - ToBoolNone, - ToBoolStr, - UnpackSequenceList, - UnpackSequenceTuple, - UnpackSequenceTwoTuple, - InstrumentedEndFor, - InstrumentedPopIter, - InstrumentedEndSend, - InstrumentedForIter, - InstrumentedInstruction, - InstrumentedJumpForward, - InstrumentedNotTaken, - InstrumentedPopJumpIfTrue, - InstrumentedPopJumpIfFalse, - InstrumentedPopJumpIfNone, - 
InstrumentedPopJumpIfNotNone, - InstrumentedResume, - InstrumentedReturnValue, - InstrumentedYieldValue, - InstrumentedEndAsyncFor, - InstrumentedLoadSuperAttr, - InstrumentedCall, - InstrumentedCallKw, - InstrumentedCallFunctionEx, - InstrumentedJumpBackward, - InstrumentedLine, - EnterExecutor, -} - -impl Opcode { - /// Returns self as [`Instruction`]. - #[must_use] - pub const fn as_instruction(self) -> Instruction { - match self { - Self::Cache => Instruction::Cache, - Self::BinarySlice => Instruction::BinarySlice, - Self::BuildTemplate => Instruction::BuildTemplate, - Self::BinaryOpInplaceAddUnicode => Instruction::BinaryOpInplaceAddUnicode, - Self::CallFunctionEx => Instruction::CallFunctionEx, - Self::CheckEgMatch => Instruction::CheckEgMatch, - Self::CheckExcMatch => Instruction::CheckExcMatch, - Self::CleanupThrow => Instruction::CleanupThrow, - Self::DeleteSubscr => Instruction::DeleteSubscr, - Self::EndFor => Instruction::EndFor, - Self::EndSend => Instruction::EndSend, - Self::ExitInitCheck => Instruction::ExitInitCheck, - Self::FormatSimple => Instruction::FormatSimple, - Self::FormatWithSpec => Instruction::FormatWithSpec, - Self::GetAiter => Instruction::GetAiter, - Self::GetAnext => Instruction::GetAnext, - Self::GetIter => Instruction::GetIter, - Self::Reserved => Instruction::Reserved, - Self::GetLen => Instruction::GetLen, - Self::GetYieldFromIter => Instruction::GetYieldFromIter, - Self::InterpreterExit => Instruction::InterpreterExit, - Self::LoadBuildClass => Instruction::LoadBuildClass, - Self::LoadLocals => Instruction::LoadLocals, - Self::MakeFunction => Instruction::MakeFunction, - Self::MatchKeys => Instruction::MatchKeys, - Self::MatchMapping => Instruction::MatchMapping, - Self::MatchSequence => Instruction::MatchSequence, - Self::Nop => Instruction::Nop, - Self::NotTaken => Instruction::NotTaken, - Self::PopExcept => Instruction::PopExcept, - Self::PopIter => Instruction::PopIter, - Self::PopTop => Instruction::PopTop, - 
Self::PushExcInfo => Instruction::PushExcInfo, - Self::PushNull => Instruction::PushNull, - Self::ReturnGenerator => Instruction::ReturnGenerator, - Self::ReturnValue => Instruction::ReturnValue, - Self::SetupAnnotations => Instruction::SetupAnnotations, - Self::StoreSlice => Instruction::StoreSlice, - Self::StoreSubscr => Instruction::StoreSubscr, - Self::ToBool => Instruction::ToBool, - Self::UnaryInvert => Instruction::UnaryInvert, - Self::UnaryNegative => Instruction::UnaryNegative, - Self::UnaryNot => Instruction::UnaryNot, - Self::WithExceptStart => Instruction::WithExceptStart, - Self::BinaryOp => Instruction::BinaryOp { op: Arg::marker() }, - Self::BuildInterpolation => Instruction::BuildInterpolation { - format: Arg::marker(), - }, - Self::BuildList => Instruction::BuildList { - count: Arg::marker(), - }, - Self::BuildMap => Instruction::BuildMap { - count: Arg::marker(), - }, - Self::BuildSet => Instruction::BuildSet { - count: Arg::marker(), - }, - Self::BuildSlice => Instruction::BuildSlice { - argc: Arg::marker(), - }, - Self::BuildString => Instruction::BuildString { - count: Arg::marker(), - }, - Self::BuildTuple => Instruction::BuildTuple { - count: Arg::marker(), - }, - Self::Call => Instruction::Call { - argc: Arg::marker(), - }, - Self::CallIntrinsic1 => Instruction::CallIntrinsic1 { - func: Arg::marker(), - }, - Self::CallIntrinsic2 => Instruction::CallIntrinsic2 { - func: Arg::marker(), - }, - Self::CallKw => Instruction::CallKw { - argc: Arg::marker(), - }, - Self::CompareOp => Instruction::CompareOp { - opname: Arg::marker(), - }, - Self::ContainsOp => Instruction::ContainsOp { - invert: Arg::marker(), - }, - Self::ConvertValue => Instruction::ConvertValue { - oparg: Arg::marker(), - }, - Self::Copy => Instruction::Copy { i: Arg::marker() }, - Self::CopyFreeVars => Instruction::CopyFreeVars { n: Arg::marker() }, - Self::DeleteAttr => Instruction::DeleteAttr { - namei: Arg::marker(), - }, - Self::DeleteDeref => Instruction::DeleteDeref { i: 
Arg::marker() }, - Self::DeleteFast => Instruction::DeleteFast { - var_num: Arg::marker(), - }, - Self::DeleteGlobal => Instruction::DeleteGlobal { - namei: Arg::marker(), - }, - Self::DeleteName => Instruction::DeleteName { - namei: Arg::marker(), - }, - Self::DictMerge => Instruction::DictMerge { i: Arg::marker() }, - Self::DictUpdate => Instruction::DictUpdate { i: Arg::marker() }, - Self::EndAsyncFor => Instruction::EndAsyncFor, - Self::ExtendedArg => Instruction::ExtendedArg, - Self::ForIter => Instruction::ForIter { - delta: Arg::marker(), - }, - Self::GetAwaitable => Instruction::GetAwaitable { - r#where: Arg::marker(), - }, - Self::ImportFrom => Instruction::ImportFrom { - namei: Arg::marker(), - }, - Self::ImportName => Instruction::ImportName { - namei: Arg::marker(), - }, - Self::IsOp => Instruction::IsOp { - invert: Arg::marker(), - }, - Self::JumpBackward => Instruction::JumpBackward { - delta: Arg::marker(), - }, - Self::JumpBackwardNoInterrupt => Instruction::JumpBackwardNoInterrupt { - delta: Arg::marker(), - }, - Self::JumpForward => Instruction::JumpForward { - delta: Arg::marker(), - }, - Self::ListAppend => Instruction::ListAppend { i: Arg::marker() }, - Self::ListExtend => Instruction::ListExtend { i: Arg::marker() }, - Self::LoadAttr => Instruction::LoadAttr { - namei: Arg::marker(), - }, - Self::LoadCommonConstant => Instruction::LoadCommonConstant { idx: Arg::marker() }, - Self::LoadConst => Instruction::LoadConst { - consti: Arg::marker(), - }, - Self::LoadDeref => Instruction::LoadDeref { i: Arg::marker() }, - Self::LoadFast => Instruction::LoadFast { - var_num: Arg::marker(), - }, - Self::LoadFastAndClear => Instruction::LoadFastAndClear { - var_num: Arg::marker(), - }, - Self::LoadFastBorrow => Instruction::LoadFastBorrow { - var_num: Arg::marker(), - }, - Self::LoadFastBorrowLoadFastBorrow => Instruction::LoadFastBorrowLoadFastBorrow { - var_nums: Arg::marker(), - }, - Self::LoadFastCheck => Instruction::LoadFastCheck { - var_num: 
Arg::marker(), - }, - Self::LoadFastLoadFast => Instruction::LoadFastLoadFast { - var_nums: Arg::marker(), - }, - Self::LoadFromDictOrDeref => Instruction::LoadFromDictOrDeref { i: Arg::marker() }, - Self::LoadFromDictOrGlobals => Instruction::LoadFromDictOrGlobals { i: Arg::marker() }, - Self::LoadGlobal => Instruction::LoadGlobal { - namei: Arg::marker(), - }, - Self::LoadName => Instruction::LoadName { - namei: Arg::marker(), - }, - Self::LoadSmallInt => Instruction::LoadSmallInt { i: Arg::marker() }, - Self::LoadSpecial => Instruction::LoadSpecial { - method: Arg::marker(), - }, - Self::LoadSuperAttr => Instruction::LoadSuperAttr { - namei: Arg::marker(), - }, - Self::MakeCell => Instruction::MakeCell { i: Arg::marker() }, - Self::MapAdd => Instruction::MapAdd { i: Arg::marker() }, - Self::MatchClass => Instruction::MatchClass { - count: Arg::marker(), - }, - Self::PopJumpIfFalse => Instruction::PopJumpIfFalse { - delta: Arg::marker(), - }, - Self::PopJumpIfNone => Instruction::PopJumpIfNone { - delta: Arg::marker(), - }, - Self::PopJumpIfNotNone => Instruction::PopJumpIfNotNone { - delta: Arg::marker(), - }, - Self::PopJumpIfTrue => Instruction::PopJumpIfTrue { - delta: Arg::marker(), - }, - Self::RaiseVarargs => Instruction::RaiseVarargs { - argc: Arg::marker(), - }, - Self::Reraise => Instruction::Reraise { - depth: Arg::marker(), - }, - Self::Send => Instruction::Send { - delta: Arg::marker(), - }, - Self::SetAdd => Instruction::SetAdd { i: Arg::marker() }, - Self::SetFunctionAttribute => Instruction::SetFunctionAttribute { - flag: Arg::marker(), - }, - Self::SetUpdate => Instruction::SetUpdate { i: Arg::marker() }, - Self::StoreAttr => Instruction::StoreAttr { - namei: Arg::marker(), - }, - Self::StoreDeref => Instruction::StoreDeref { i: Arg::marker() }, - Self::StoreFast => Instruction::StoreFast { - var_num: Arg::marker(), - }, - Self::StoreFastLoadFast => Instruction::StoreFastLoadFast { - var_nums: Arg::marker(), - }, - Self::StoreFastStoreFast => 
Instruction::StoreFastStoreFast { - var_nums: Arg::marker(), - }, - Self::StoreGlobal => Instruction::StoreGlobal { - namei: Arg::marker(), - }, - Self::StoreName => Instruction::StoreName { - namei: Arg::marker(), - }, - Self::Swap => Instruction::Swap { i: Arg::marker() }, - Self::UnpackEx => Instruction::UnpackEx { - counts: Arg::marker(), - }, - Self::UnpackSequence => Instruction::UnpackSequence { - count: Arg::marker(), - }, - Self::YieldValue => Instruction::YieldValue { arg: Arg::marker() }, - Self::Resume => Instruction::Resume { - context: Arg::marker(), - }, - Self::BinaryOpAddFloat => Instruction::BinaryOpAddFloat, - Self::BinaryOpAddInt => Instruction::BinaryOpAddInt, - Self::BinaryOpAddUnicode => Instruction::BinaryOpAddUnicode, - Self::BinaryOpExtend => Instruction::BinaryOpExtend, - Self::BinaryOpMultiplyFloat => Instruction::BinaryOpMultiplyFloat, - Self::BinaryOpMultiplyInt => Instruction::BinaryOpMultiplyInt, - Self::BinaryOpSubscrDict => Instruction::BinaryOpSubscrDict, - Self::BinaryOpSubscrGetitem => Instruction::BinaryOpSubscrGetitem, - Self::BinaryOpSubscrListInt => Instruction::BinaryOpSubscrListInt, - Self::BinaryOpSubscrListSlice => Instruction::BinaryOpSubscrListSlice, - Self::BinaryOpSubscrStrInt => Instruction::BinaryOpSubscrStrInt, - Self::BinaryOpSubscrTupleInt => Instruction::BinaryOpSubscrTupleInt, - Self::BinaryOpSubtractFloat => Instruction::BinaryOpSubtractFloat, - Self::BinaryOpSubtractInt => Instruction::BinaryOpSubtractInt, - Self::CallAllocAndEnterInit => Instruction::CallAllocAndEnterInit, - Self::CallBoundMethodExactArgs => Instruction::CallBoundMethodExactArgs, - Self::CallBoundMethodGeneral => Instruction::CallBoundMethodGeneral, - Self::CallBuiltinClass => Instruction::CallBuiltinClass, - Self::CallBuiltinFast => Instruction::CallBuiltinFast, - Self::CallBuiltinFastWithKeywords => Instruction::CallBuiltinFastWithKeywords, - Self::CallBuiltinO => Instruction::CallBuiltinO, - Self::CallIsinstance => 
Instruction::CallIsinstance, - Self::CallKwBoundMethod => Instruction::CallKwBoundMethod, - Self::CallKwNonPy => Instruction::CallKwNonPy, - Self::CallKwPy => Instruction::CallKwPy, - Self::CallLen => Instruction::CallLen, - Self::CallListAppend => Instruction::CallListAppend, - Self::CallMethodDescriptorFast => Instruction::CallMethodDescriptorFast, - Self::CallMethodDescriptorFastWithKeywords => { - Instruction::CallMethodDescriptorFastWithKeywords - } - Self::CallMethodDescriptorNoargs => Instruction::CallMethodDescriptorNoargs, - Self::CallMethodDescriptorO => Instruction::CallMethodDescriptorO, - Self::CallNonPyGeneral => Instruction::CallNonPyGeneral, - Self::CallPyExactArgs => Instruction::CallPyExactArgs, - Self::CallPyGeneral => Instruction::CallPyGeneral, - Self::CallStr1 => Instruction::CallStr1, - Self::CallTuple1 => Instruction::CallTuple1, - Self::CallType1 => Instruction::CallType1, - Self::CompareOpFloat => Instruction::CompareOpFloat, - Self::CompareOpInt => Instruction::CompareOpInt, - Self::CompareOpStr => Instruction::CompareOpStr, - Self::ContainsOpDict => Instruction::ContainsOpDict, - Self::ContainsOpSet => Instruction::ContainsOpSet, - Self::ForIterGen => Instruction::ForIterGen, - Self::ForIterList => Instruction::ForIterList, - Self::ForIterRange => Instruction::ForIterRange, - Self::ForIterTuple => Instruction::ForIterTuple, - Self::JumpBackwardJit => Instruction::JumpBackwardJit, - Self::JumpBackwardNoJit => Instruction::JumpBackwardNoJit, - Self::LoadAttrClass => Instruction::LoadAttrClass, - Self::LoadAttrClassWithMetaclassCheck => Instruction::LoadAttrClassWithMetaclassCheck, - Self::LoadAttrGetattributeOverridden => Instruction::LoadAttrGetattributeOverridden, - Self::LoadAttrInstanceValue => Instruction::LoadAttrInstanceValue, - Self::LoadAttrMethodLazyDict => Instruction::LoadAttrMethodLazyDict, - Self::LoadAttrMethodNoDict => Instruction::LoadAttrMethodNoDict, - Self::LoadAttrMethodWithValues => 
Instruction::LoadAttrMethodWithValues, - Self::LoadAttrModule => Instruction::LoadAttrModule, - Self::LoadAttrNondescriptorNoDict => Instruction::LoadAttrNondescriptorNoDict, - Self::LoadAttrNondescriptorWithValues => Instruction::LoadAttrNondescriptorWithValues, - Self::LoadAttrProperty => Instruction::LoadAttrProperty, - Self::LoadAttrSlot => Instruction::LoadAttrSlot, - Self::LoadAttrWithHint => Instruction::LoadAttrWithHint, - Self::LoadConstImmortal => Instruction::LoadConstImmortal, - Self::LoadConstMortal => Instruction::LoadConstMortal, - Self::LoadGlobalBuiltin => Instruction::LoadGlobalBuiltin, - Self::LoadGlobalModule => Instruction::LoadGlobalModule, - Self::LoadSuperAttrAttr => Instruction::LoadSuperAttrAttr, - Self::LoadSuperAttrMethod => Instruction::LoadSuperAttrMethod, - Self::ResumeCheck => Instruction::ResumeCheck, - Self::SendGen => Instruction::SendGen, - Self::StoreAttrInstanceValue => Instruction::StoreAttrInstanceValue, - Self::StoreAttrSlot => Instruction::StoreAttrSlot, - Self::StoreAttrWithHint => Instruction::StoreAttrWithHint, - Self::StoreSubscrDict => Instruction::StoreSubscrDict, - Self::StoreSubscrListInt => Instruction::StoreSubscrListInt, - Self::ToBoolAlwaysTrue => Instruction::ToBoolAlwaysTrue, - Self::ToBoolBool => Instruction::ToBoolBool, - Self::ToBoolInt => Instruction::ToBoolInt, - Self::ToBoolList => Instruction::ToBoolList, - Self::ToBoolNone => Instruction::ToBoolNone, - Self::ToBoolStr => Instruction::ToBoolStr, - Self::UnpackSequenceList => Instruction::UnpackSequenceList, - Self::UnpackSequenceTuple => Instruction::UnpackSequenceTuple, - Self::UnpackSequenceTwoTuple => Instruction::UnpackSequenceTwoTuple, - Self::InstrumentedEndFor => Instruction::InstrumentedEndFor, - Self::InstrumentedPopIter => Instruction::InstrumentedPopIter, - Self::InstrumentedEndSend => Instruction::InstrumentedEndSend, - Self::InstrumentedForIter => Instruction::InstrumentedForIter, - Self::InstrumentedInstruction => 
Instruction::InstrumentedInstruction, - Self::InstrumentedJumpForward => Instruction::InstrumentedJumpForward, - Self::InstrumentedNotTaken => Instruction::InstrumentedNotTaken, - Self::InstrumentedPopJumpIfTrue => Instruction::InstrumentedPopJumpIfTrue, - Self::InstrumentedPopJumpIfFalse => Instruction::InstrumentedPopJumpIfFalse, - Self::InstrumentedPopJumpIfNone => Instruction::InstrumentedPopJumpIfNone, - Self::InstrumentedPopJumpIfNotNone => Instruction::InstrumentedPopJumpIfNotNone, - Self::InstrumentedResume => Instruction::InstrumentedResume, - Self::InstrumentedReturnValue => Instruction::InstrumentedReturnValue, - Self::InstrumentedYieldValue => Instruction::InstrumentedYieldValue, - Self::InstrumentedEndAsyncFor => Instruction::InstrumentedEndAsyncFor, - Self::InstrumentedLoadSuperAttr => Instruction::InstrumentedLoadSuperAttr, - Self::InstrumentedCall => Instruction::InstrumentedCall, - Self::InstrumentedCallKw => Instruction::InstrumentedCallKw, - Self::InstrumentedCallFunctionEx => Instruction::InstrumentedCallFunctionEx, - Self::InstrumentedJumpBackward => Instruction::InstrumentedJumpBackward, - Self::InstrumentedLine => Instruction::InstrumentedLine, - Self::EnterExecutor => Instruction::EnterExecutor, - } - } - - #[must_use] - pub const fn as_u8(self) -> u8 { - match self { - Self::Cache => 0, - Self::BinarySlice => 1, - Self::BuildTemplate => 2, - Self::BinaryOpInplaceAddUnicode => 3, - Self::CallFunctionEx => 4, - Self::CheckEgMatch => 5, - Self::CheckExcMatch => 6, - Self::CleanupThrow => 7, - Self::DeleteSubscr => 8, - Self::EndFor => 9, - Self::EndSend => 10, - Self::ExitInitCheck => 11, - Self::FormatSimple => 12, - Self::FormatWithSpec => 13, - Self::GetAiter => 14, - Self::GetAnext => 15, - Self::GetIter => 16, - Self::Reserved => 17, - Self::GetLen => 18, - Self::GetYieldFromIter => 19, - Self::InterpreterExit => 20, - Self::LoadBuildClass => 21, - Self::LoadLocals => 22, - Self::MakeFunction => 23, - Self::MatchKeys => 24, - 
Self::MatchMapping => 25, - Self::MatchSequence => 26, - Self::Nop => 27, - Self::NotTaken => 28, - Self::PopExcept => 29, - Self::PopIter => 30, - Self::PopTop => 31, - Self::PushExcInfo => 32, - Self::PushNull => 33, - Self::ReturnGenerator => 34, - Self::ReturnValue => 35, - Self::SetupAnnotations => 36, - Self::StoreSlice => 37, - Self::StoreSubscr => 38, - Self::ToBool => 39, - Self::UnaryInvert => 40, - Self::UnaryNegative => 41, - Self::UnaryNot => 42, - Self::WithExceptStart => 43, - Self::BinaryOp => 44, - Self::BuildInterpolation => 45, - Self::BuildList => 46, - Self::BuildMap => 47, - Self::BuildSet => 48, - Self::BuildSlice => 49, - Self::BuildString => 50, - Self::BuildTuple => 51, - Self::Call => 52, - Self::CallIntrinsic1 => 53, - Self::CallIntrinsic2 => 54, - Self::CallKw => 55, - Self::CompareOp => 56, - Self::ContainsOp => 57, - Self::ConvertValue => 58, - Self::Copy => 59, - Self::CopyFreeVars => 60, - Self::DeleteAttr => 61, - Self::DeleteDeref => 62, - Self::DeleteFast => 63, - Self::DeleteGlobal => 64, - Self::DeleteName => 65, - Self::DictMerge => 66, - Self::DictUpdate => 67, - Self::EndAsyncFor => 68, - Self::ExtendedArg => 69, - Self::ForIter => 70, - Self::GetAwaitable => 71, - Self::ImportFrom => 72, - Self::ImportName => 73, - Self::IsOp => 74, - Self::JumpBackward => 75, - Self::JumpBackwardNoInterrupt => 76, - Self::JumpForward => 77, - Self::ListAppend => 78, - Self::ListExtend => 79, - Self::LoadAttr => 80, - Self::LoadCommonConstant => 81, - Self::LoadConst => 82, - Self::LoadDeref => 83, - Self::LoadFast => 84, - Self::LoadFastAndClear => 85, - Self::LoadFastBorrow => 86, - Self::LoadFastBorrowLoadFastBorrow => 87, - Self::LoadFastCheck => 88, - Self::LoadFastLoadFast => 89, - Self::LoadFromDictOrDeref => 90, - Self::LoadFromDictOrGlobals => 91, - Self::LoadGlobal => 92, - Self::LoadName => 93, - Self::LoadSmallInt => 94, - Self::LoadSpecial => 95, - Self::LoadSuperAttr => 96, - Self::MakeCell => 97, - Self::MapAdd => 98, - 
Self::MatchClass => 99, - Self::PopJumpIfFalse => 100, - Self::PopJumpIfNone => 101, - Self::PopJumpIfNotNone => 102, - Self::PopJumpIfTrue => 103, - Self::RaiseVarargs => 104, - Self::Reraise => 105, - Self::Send => 106, - Self::SetAdd => 107, - Self::SetFunctionAttribute => 108, - Self::SetUpdate => 109, - Self::StoreAttr => 110, - Self::StoreDeref => 111, - Self::StoreFast => 112, - Self::StoreFastLoadFast => 113, - Self::StoreFastStoreFast => 114, - Self::StoreGlobal => 115, - Self::StoreName => 116, - Self::Swap => 117, - Self::UnpackEx => 118, - Self::UnpackSequence => 119, - Self::YieldValue => 120, - Self::Resume => 128, - Self::BinaryOpAddFloat => 129, - Self::BinaryOpAddInt => 130, - Self::BinaryOpAddUnicode => 131, - Self::BinaryOpExtend => 132, - Self::BinaryOpMultiplyFloat => 133, - Self::BinaryOpMultiplyInt => 134, - Self::BinaryOpSubscrDict => 135, - Self::BinaryOpSubscrGetitem => 136, - Self::BinaryOpSubscrListInt => 137, - Self::BinaryOpSubscrListSlice => 138, - Self::BinaryOpSubscrStrInt => 139, - Self::BinaryOpSubscrTupleInt => 140, - Self::BinaryOpSubtractFloat => 141, - Self::BinaryOpSubtractInt => 142, - Self::CallAllocAndEnterInit => 143, - Self::CallBoundMethodExactArgs => 144, - Self::CallBoundMethodGeneral => 145, - Self::CallBuiltinClass => 146, - Self::CallBuiltinFast => 147, - Self::CallBuiltinFastWithKeywords => 148, - Self::CallBuiltinO => 149, - Self::CallIsinstance => 150, - Self::CallKwBoundMethod => 151, - Self::CallKwNonPy => 152, - Self::CallKwPy => 153, - Self::CallLen => 154, - Self::CallListAppend => 155, - Self::CallMethodDescriptorFast => 156, - Self::CallMethodDescriptorFastWithKeywords => 157, - Self::CallMethodDescriptorNoargs => 158, - Self::CallMethodDescriptorO => 159, - Self::CallNonPyGeneral => 160, - Self::CallPyExactArgs => 161, - Self::CallPyGeneral => 162, - Self::CallStr1 => 163, - Self::CallTuple1 => 164, - Self::CallType1 => 165, - Self::CompareOpFloat => 166, - Self::CompareOpInt => 167, - Self::CompareOpStr 
=> 168, - Self::ContainsOpDict => 169, - Self::ContainsOpSet => 170, - Self::ForIterGen => 171, - Self::ForIterList => 172, - Self::ForIterRange => 173, - Self::ForIterTuple => 174, - Self::JumpBackwardJit => 175, - Self::JumpBackwardNoJit => 176, - Self::LoadAttrClass => 177, - Self::LoadAttrClassWithMetaclassCheck => 178, - Self::LoadAttrGetattributeOverridden => 179, - Self::LoadAttrInstanceValue => 180, - Self::LoadAttrMethodLazyDict => 181, - Self::LoadAttrMethodNoDict => 182, - Self::LoadAttrMethodWithValues => 183, - Self::LoadAttrModule => 184, - Self::LoadAttrNondescriptorNoDict => 185, - Self::LoadAttrNondescriptorWithValues => 186, - Self::LoadAttrProperty => 187, - Self::LoadAttrSlot => 188, - Self::LoadAttrWithHint => 189, - Self::LoadConstImmortal => 190, - Self::LoadConstMortal => 191, - Self::LoadGlobalBuiltin => 192, - Self::LoadGlobalModule => 193, - Self::LoadSuperAttrAttr => 194, - Self::LoadSuperAttrMethod => 195, - Self::ResumeCheck => 196, - Self::SendGen => 197, - Self::StoreAttrInstanceValue => 198, - Self::StoreAttrSlot => 199, - Self::StoreAttrWithHint => 200, - Self::StoreSubscrDict => 201, - Self::StoreSubscrListInt => 202, - Self::ToBoolAlwaysTrue => 203, - Self::ToBoolBool => 204, - Self::ToBoolInt => 205, - Self::ToBoolList => 206, - Self::ToBoolNone => 207, - Self::ToBoolStr => 208, - Self::UnpackSequenceList => 209, - Self::UnpackSequenceTuple => 210, - Self::UnpackSequenceTwoTuple => 211, - Self::InstrumentedEndFor => 234, - Self::InstrumentedPopIter => 235, - Self::InstrumentedEndSend => 236, - Self::InstrumentedForIter => 237, - Self::InstrumentedInstruction => 238, - Self::InstrumentedJumpForward => 239, - Self::InstrumentedNotTaken => 240, - Self::InstrumentedPopJumpIfTrue => 241, - Self::InstrumentedPopJumpIfFalse => 242, - Self::InstrumentedPopJumpIfNone => 243, - Self::InstrumentedPopJumpIfNotNone => 244, - Self::InstrumentedResume => 245, - Self::InstrumentedReturnValue => 246, - Self::InstrumentedYieldValue => 247, - 
Self::InstrumentedEndAsyncFor => 248, - Self::InstrumentedLoadSuperAttr => 249, - Self::InstrumentedCall => 250, - Self::InstrumentedCallKw => 251, - Self::InstrumentedCallFunctionEx => 252, - Self::InstrumentedJumpBackward => 253, - Self::InstrumentedLine => 254, - Self::EnterExecutor => 255, - } - } - - #[must_use] - pub const fn cache_entries(self) -> usize { - match self.deoptimize() { - Self::StoreSubscr => 1, - Self::ToBool => 3, - Self::BinaryOp => 5, - Self::Call => 3, - Self::CallKw => 3, - Self::CompareOp => 1, - Self::ContainsOp => 1, - Self::ForIter => 1, - Self::JumpBackward => 1, - Self::LoadAttr => 9, - Self::LoadGlobal => 4, - Self::LoadSuperAttr => 1, - Self::PopJumpIfFalse => 1, - Self::PopJumpIfNone => 1, - Self::PopJumpIfNotNone => 1, - Self::PopJumpIfTrue => 1, - Self::Send => 1, - Self::StoreAttr => 4, - Self::UnpackSequence => 1, - _ => 0, - } - } - - #[must_use] - pub const fn deopt(self) -> Option { - Some(match self { - Self::ResumeCheck => Self::Resume, - Self::LoadConstMortal | Self::LoadConstImmortal => Self::LoadConst, - Self::ToBoolAlwaysTrue - | Self::ToBoolBool - | Self::ToBoolInt - | Self::ToBoolList - | Self::ToBoolNone - | Self::ToBoolStr => Self::ToBool, - Self::BinaryOpMultiplyInt - | Self::BinaryOpAddInt - | Self::BinaryOpSubtractInt - | Self::BinaryOpMultiplyFloat - | Self::BinaryOpAddFloat - | Self::BinaryOpSubtractFloat - | Self::BinaryOpAddUnicode - | Self::BinaryOpSubscrListInt - | Self::BinaryOpSubscrListSlice - | Self::BinaryOpSubscrTupleInt - | Self::BinaryOpSubscrStrInt - | Self::BinaryOpSubscrDict - | Self::BinaryOpSubscrGetitem - | Self::BinaryOpExtend - | Self::BinaryOpInplaceAddUnicode => Self::BinaryOp, - Self::StoreSubscrDict | Self::StoreSubscrListInt => Self::StoreSubscr, - Self::SendGen => Self::Send, - Self::UnpackSequenceTwoTuple | Self::UnpackSequenceTuple | Self::UnpackSequenceList => { - Self::UnpackSequence - } - Self::StoreAttrInstanceValue | Self::StoreAttrSlot | Self::StoreAttrWithHint => { - 
Self::StoreAttr - } - Self::LoadGlobalModule | Self::LoadGlobalBuiltin => Self::LoadGlobal, - Self::LoadSuperAttrAttr | Self::LoadSuperAttrMethod => Self::LoadSuperAttr, - Self::LoadAttrInstanceValue - | Self::LoadAttrModule - | Self::LoadAttrWithHint - | Self::LoadAttrSlot - | Self::LoadAttrClass - | Self::LoadAttrClassWithMetaclassCheck - | Self::LoadAttrProperty - | Self::LoadAttrGetattributeOverridden - | Self::LoadAttrMethodWithValues - | Self::LoadAttrMethodNoDict - | Self::LoadAttrMethodLazyDict - | Self::LoadAttrNondescriptorWithValues - | Self::LoadAttrNondescriptorNoDict => Self::LoadAttr, - Self::CompareOpFloat | Self::CompareOpInt | Self::CompareOpStr => Self::CompareOp, - Self::ContainsOpSet | Self::ContainsOpDict => Self::ContainsOp, - Self::JumpBackwardNoJit | Self::JumpBackwardJit => Self::JumpBackward, - Self::ForIterList | Self::ForIterTuple | Self::ForIterRange | Self::ForIterGen => { - Self::ForIter - } - Self::CallBoundMethodExactArgs - | Self::CallPyExactArgs - | Self::CallType1 - | Self::CallStr1 - | Self::CallTuple1 - | Self::CallBuiltinClass - | Self::CallBuiltinO - | Self::CallBuiltinFast - | Self::CallBuiltinFastWithKeywords - | Self::CallLen - | Self::CallIsinstance - | Self::CallListAppend - | Self::CallMethodDescriptorO - | Self::CallMethodDescriptorFastWithKeywords - | Self::CallMethodDescriptorNoargs - | Self::CallMethodDescriptorFast - | Self::CallAllocAndEnterInit - | Self::CallPyGeneral - | Self::CallBoundMethodGeneral - | Self::CallNonPyGeneral => Self::Call, - Self::CallKwBoundMethod | Self::CallKwPy | Self::CallKwNonPy => Self::CallKw, - _ => return None, - }) - } - - /// Does this opcode have 'HAS_ARG_FLAG' set. 
- #[must_use] - pub const fn has_arg(self) -> bool { - matches!( - self, - Self::BinaryOp - | Self::BuildInterpolation - | Self::BuildList - | Self::BuildMap - | Self::BuildSet - | Self::BuildSlice - | Self::BuildString - | Self::BuildTuple - | Self::Call - | Self::CallIntrinsic1 - | Self::CallIntrinsic2 - | Self::CallKw - | Self::CompareOp - | Self::ContainsOp - | Self::ConvertValue - | Self::Copy - | Self::CopyFreeVars - | Self::DeleteAttr - | Self::DeleteDeref - | Self::DeleteFast - | Self::DeleteGlobal - | Self::DeleteName - | Self::DictMerge - | Self::DictUpdate - | Self::EndAsyncFor - | Self::ExtendedArg - | Self::ForIter - | Self::GetAwaitable - | Self::ImportFrom - | Self::ImportName - | Self::IsOp - | Self::JumpBackward - | Self::JumpBackwardNoInterrupt - | Self::JumpForward - | Self::ListAppend - | Self::ListExtend - | Self::LoadAttr - | Self::LoadCommonConstant - | Self::LoadConst - | Self::LoadDeref - | Self::LoadFast - | Self::LoadFastAndClear - | Self::LoadFastBorrow - | Self::LoadFastBorrowLoadFastBorrow - | Self::LoadFastCheck - | Self::LoadFastLoadFast - | Self::LoadFromDictOrDeref - | Self::LoadFromDictOrGlobals - | Self::LoadGlobal - | Self::LoadName - | Self::LoadSmallInt - | Self::LoadSpecial - | Self::LoadSuperAttr - | Self::MakeCell - | Self::MapAdd - | Self::MatchClass - | Self::PopJumpIfFalse - | Self::PopJumpIfNone - | Self::PopJumpIfNotNone - | Self::PopJumpIfTrue - | Self::RaiseVarargs - | Self::Reraise - | Self::Send - | Self::SetAdd - | Self::SetFunctionAttribute - | Self::SetUpdate - | Self::StoreAttr - | Self::StoreDeref - | Self::StoreFast - | Self::StoreFastLoadFast - | Self::StoreFastStoreFast - | Self::StoreGlobal - | Self::StoreName - | Self::Swap - | Self::UnpackEx - | Self::UnpackSequence - | Self::YieldValue - | Self::Resume - | Self::CallAllocAndEnterInit - | Self::CallBoundMethodExactArgs - | Self::CallBoundMethodGeneral - | Self::CallBuiltinClass - | Self::CallBuiltinFast - | Self::CallBuiltinFastWithKeywords - | 
Self::CallBuiltinO - | Self::CallIsinstance - | Self::CallKwBoundMethod - | Self::CallKwNonPy - | Self::CallKwPy - | Self::CallListAppend - | Self::CallMethodDescriptorFast - | Self::CallMethodDescriptorFastWithKeywords - | Self::CallMethodDescriptorNoargs - | Self::CallMethodDescriptorO - | Self::CallNonPyGeneral - | Self::CallPyExactArgs - | Self::CallPyGeneral - | Self::CallStr1 - | Self::CallTuple1 - | Self::CallType1 - | Self::CompareOpFloat - | Self::CompareOpInt - | Self::CompareOpStr - | Self::ContainsOpDict - | Self::ContainsOpSet - | Self::ForIterGen - | Self::ForIterList - | Self::ForIterRange - | Self::ForIterTuple - | Self::JumpBackwardJit - | Self::JumpBackwardNoJit - | Self::LoadAttrClass - | Self::LoadAttrClassWithMetaclassCheck - | Self::LoadAttrGetattributeOverridden - | Self::LoadAttrInstanceValue - | Self::LoadAttrMethodLazyDict - | Self::LoadAttrMethodNoDict - | Self::LoadAttrMethodWithValues - | Self::LoadAttrModule - | Self::LoadAttrNondescriptorNoDict - | Self::LoadAttrNondescriptorWithValues - | Self::LoadAttrProperty - | Self::LoadAttrSlot - | Self::LoadAttrWithHint - | Self::LoadConstImmortal - | Self::LoadConstMortal - | Self::LoadGlobalBuiltin - | Self::LoadGlobalModule - | Self::LoadSuperAttrAttr - | Self::LoadSuperAttrMethod - | Self::SendGen - | Self::StoreAttrWithHint - | Self::UnpackSequenceList - | Self::UnpackSequenceTuple - | Self::UnpackSequenceTwoTuple - | Self::InstrumentedForIter - | Self::InstrumentedJumpForward - | Self::InstrumentedPopJumpIfTrue - | Self::InstrumentedPopJumpIfFalse - | Self::InstrumentedPopJumpIfNone - | Self::InstrumentedPopJumpIfNotNone - | Self::InstrumentedResume - | Self::InstrumentedYieldValue - | Self::InstrumentedEndAsyncFor - | Self::InstrumentedLoadSuperAttr - | Self::InstrumentedCall - | Self::InstrumentedCallKw - | Self::InstrumentedJumpBackward - | Self::EnterExecutor - ) - } - - /// Does this opcode have 'HAS_CONST_FLAG' set. 
- #[must_use] - pub const fn has_const(self) -> bool { - matches!( - self, - Self::LoadConst | Self::LoadConstImmortal | Self::LoadConstMortal - ) - } - - /// Does this opcode have 'HAS_FREE_FLAG' set. - #[must_use] - pub const fn has_free(self) -> bool { - matches!( - self, - Self::DeleteDeref | Self::LoadFromDictOrDeref | Self::MakeCell | Self::StoreDeref - ) - } - - /// Does this opcode have 'HAS_JUMP_FLAG' set. - #[must_use] - pub const fn has_jump(self) -> bool { - matches!( - self, - Self::EndAsyncFor - | Self::ForIter - | Self::JumpBackward - | Self::JumpBackwardNoInterrupt - | Self::JumpForward - | Self::PopJumpIfFalse - | Self::PopJumpIfNone - | Self::PopJumpIfNotNone - | Self::PopJumpIfTrue - | Self::Send - | Self::ForIterList - | Self::ForIterRange - | Self::ForIterTuple - | Self::JumpBackwardJit - | Self::JumpBackwardNoJit - | Self::InstrumentedForIter - | Self::InstrumentedEndAsyncFor - ) - } - - /// Does this opcode have 'HAS_LOCAL_FLAG' set. - #[must_use] - pub const fn has_local(self) -> bool { - matches!( - self, - Self::BinaryOpInplaceAddUnicode - | Self::DeleteFast - | Self::LoadDeref - | Self::LoadFast - | Self::LoadFastAndClear - | Self::LoadFastBorrow - | Self::LoadFastBorrowLoadFastBorrow - | Self::LoadFastCheck - | Self::LoadFastLoadFast - | Self::StoreFast - | Self::StoreFastLoadFast - | Self::StoreFastStoreFast - ) - } - - /// Does this opcode have 'HAS_NAME_FLAG' set. 
- #[must_use] - pub const fn has_name(self) -> bool { - matches!( - self, - Self::DeleteAttr - | Self::DeleteGlobal - | Self::DeleteName - | Self::ImportFrom - | Self::ImportName - | Self::LoadAttr - | Self::LoadFromDictOrGlobals - | Self::LoadGlobal - | Self::LoadName - | Self::LoadSuperAttr - | Self::StoreAttr - | Self::StoreGlobal - | Self::StoreName - | Self::LoadAttrGetattributeOverridden - | Self::LoadAttrWithHint - | Self::LoadSuperAttrAttr - | Self::LoadSuperAttrMethod - | Self::StoreAttrWithHint - | Self::InstrumentedLoadSuperAttr - ) - } - - /// Stack effect of [`Self::stack_effect_info`]. - #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 { - self.stack_effect_info(oparg).effect() - } - - #[must_use] - pub fn stack_effect_info(&self, oparg: u32) -> StackEffect { - // Reason for converting oparg to i32 is because of expressions like `1 + (oparg -1)` - // that causes underflow errors. - let oparg = i32::try_from(oparg).expect("oparg does not fit in an `i32`"); - - let (pushed, popped) = match self { - Self::Cache => (0, 0), - Self::BinarySlice => (1, 3), - Self::BuildTemplate => (1, 2), - Self::BinaryOpInplaceAddUnicode => (0, 2), - Self::CallFunctionEx => (1, 4), - Self::CheckEgMatch => (2, 2), - Self::CheckExcMatch => (2, 2), - Self::CleanupThrow => (2, 3), - Self::DeleteSubscr => (0, 2), - Self::EndFor => (0, 1), - Self::EndSend => (1, 2), - Self::ExitInitCheck => (0, 1), - Self::FormatSimple => (1, 1), - Self::FormatWithSpec => (1, 2), - Self::GetAiter => (1, 1), - Self::GetAnext => (2, 1), - Self::GetIter => (1, 1), - Self::Reserved => (0, 0), - Self::GetLen => (2, 1), - Self::GetYieldFromIter => (1, 1), - Self::InterpreterExit => (0, 1), - Self::LoadBuildClass => (1, 0), - Self::LoadLocals => (1, 0), - Self::MakeFunction => (1, 1), - Self::MatchKeys => (3, 2), - Self::MatchMapping => (2, 1), - Self::MatchSequence => (2, 1), - Self::Nop => (0, 0), - Self::NotTaken => (0, 0), - Self::PopExcept => (0, 1), - Self::PopIter => (0, 1), - 
Self::PopTop => (0, 1), - Self::PushExcInfo => (2, 1), - Self::PushNull => (1, 0), - Self::ReturnGenerator => (1, 0), - Self::ReturnValue => (1, 1), - Self::SetupAnnotations => (0, 0), - Self::StoreSlice => (0, 4), - Self::StoreSubscr => (0, 3), - Self::ToBool => (1, 1), - Self::UnaryInvert => (1, 1), - Self::UnaryNegative => (1, 1), - Self::UnaryNot => (1, 1), - Self::WithExceptStart => ( - 7, // TODO: Differs from CPython `6` - 6, // TODO: Differs from CPython `5` - ), - Self::BinaryOp => (1, 2), - Self::BuildInterpolation => (1, 2 + (oparg & 1)), - Self::BuildList => (1, oparg), - Self::BuildMap => (1, oparg * 2), - Self::BuildSet => (1, oparg), - Self::BuildSlice => (1, oparg), - Self::BuildString => (1, oparg), - Self::BuildTuple => (1, oparg), - Self::Call => (1, 2 + oparg), - Self::CallIntrinsic1 => (1, 1), - Self::CallIntrinsic2 => (1, 2), - Self::CallKw => (1, 3 + oparg), - Self::CompareOp => (1, 2), - Self::ContainsOp => (1, 2), - Self::ConvertValue => (1, 1), - Self::Copy => (2 + (oparg - 1), 1 + (oparg - 1)), - Self::CopyFreeVars => (0, 0), - Self::DeleteAttr => (0, 1), - Self::DeleteDeref => (0, 0), - Self::DeleteFast => (0, 0), - Self::DeleteGlobal => (0, 0), - Self::DeleteName => (0, 0), - Self::DictMerge => (4 + (oparg - 1), 5 + (oparg - 1)), - Self::DictUpdate => (1 + (oparg - 1), 2 + (oparg - 1)), - Self::EndAsyncFor => (0, 2), - Self::ExtendedArg => (0, 0), - Self::ForIter => (2, 1), - Self::GetAwaitable => (1, 1), - Self::ImportFrom => (2, 1), - Self::ImportName => (1, 2), - Self::IsOp => (1, 2), - Self::JumpBackward => (0, 0), - Self::JumpBackwardNoInterrupt => (0, 0), - Self::JumpForward => (0, 0), - Self::ListAppend => (1 + (oparg - 1), 2 + (oparg - 1)), - Self::ListExtend => (1 + (oparg - 1), 2 + (oparg - 1)), - Self::LoadAttr => (1 + (oparg & 1), 1), - Self::LoadCommonConstant => (1, 0), - Self::LoadConst => (1, 0), - Self::LoadDeref => (1, 0), - Self::LoadFast => (1, 0), - Self::LoadFastAndClear => (1, 0), - Self::LoadFastBorrow => (1, 0), 
- Self::LoadFastBorrowLoadFastBorrow => (2, 0), - Self::LoadFastCheck => (1, 0), - Self::LoadFastLoadFast => (2, 0), - Self::LoadFromDictOrDeref => (1, 1), - Self::LoadFromDictOrGlobals => (1, 1), - Self::LoadGlobal => (1 + (oparg & 1), 0), - Self::LoadName => (1, 0), - Self::LoadSmallInt => (1, 0), - Self::LoadSpecial => (2, 1), - Self::LoadSuperAttr => (1 + (oparg & 1), 3), - Self::MakeCell => (0, 0), - Self::MapAdd => (1 + (oparg - 1), 3 + (oparg - 1)), - Self::MatchClass => (1, 3), - Self::PopJumpIfFalse => (0, 1), - Self::PopJumpIfNone => (0, 1), - Self::PopJumpIfNotNone => (0, 1), - Self::PopJumpIfTrue => (0, 1), - Self::RaiseVarargs => (0, oparg), - Self::Reraise => (oparg, 1 + oparg), - Self::Send => (2, 2), - Self::SetAdd => (1 + (oparg - 1), 2 + (oparg - 1)), - Self::SetFunctionAttribute => (1, 2), - Self::SetUpdate => (1 + (oparg - 1), 2 + (oparg - 1)), - Self::StoreAttr => (0, 2), - Self::StoreDeref => (0, 1), - Self::StoreFast => (0, 1), - Self::StoreFastLoadFast => (1, 1), - Self::StoreFastStoreFast => (0, 2), - Self::StoreGlobal => (0, 1), - Self::StoreName => (0, 1), - Self::Swap => (2 + (oparg - 2), 2 + (oparg - 2)), - Self::UnpackEx => (1 + (oparg & 0xFF) + (oparg >> 8), 1), - Self::UnpackSequence => (oparg, 1), - Self::YieldValue => (1, 1), - Self::Resume => (0, 0), - Self::BinaryOpAddFloat => (1, 2), - Self::BinaryOpAddInt => (1, 2), - Self::BinaryOpAddUnicode => (1, 2), - Self::BinaryOpExtend => (1, 2), - Self::BinaryOpMultiplyFloat => (1, 2), - Self::BinaryOpMultiplyInt => (1, 2), - Self::BinaryOpSubscrDict => (1, 2), - Self::BinaryOpSubscrGetitem => (0, 2), - Self::BinaryOpSubscrListInt => (1, 2), - Self::BinaryOpSubscrListSlice => (1, 2), - Self::BinaryOpSubscrStrInt => (1, 2), - Self::BinaryOpSubscrTupleInt => (1, 2), - Self::BinaryOpSubtractFloat => (1, 2), - Self::BinaryOpSubtractInt => (1, 2), - Self::CallAllocAndEnterInit => (0, 2 + oparg), - Self::CallBoundMethodExactArgs => (0, 2 + oparg), - Self::CallBoundMethodGeneral => (0, 2 + 
oparg), - Self::CallBuiltinClass => (1, 2 + oparg), - Self::CallBuiltinFast => (1, 2 + oparg), - Self::CallBuiltinFastWithKeywords => (1, 2 + oparg), - Self::CallBuiltinO => (1, 2 + oparg), - Self::CallIsinstance => (1, 2 + oparg), - Self::CallKwBoundMethod => (0, 3 + oparg), - Self::CallKwNonPy => (1, 3 + oparg), - Self::CallKwPy => (0, 3 + oparg), - Self::CallLen => (1, 3), - Self::CallListAppend => (0, 3), - Self::CallMethodDescriptorFast => (1, 2 + oparg), - Self::CallMethodDescriptorFastWithKeywords => (1, 2 + oparg), - Self::CallMethodDescriptorNoargs => (1, 2 + oparg), - Self::CallMethodDescriptorO => (1, 2 + oparg), - Self::CallNonPyGeneral => (1, 2 + oparg), - Self::CallPyExactArgs => (0, 2 + oparg), - Self::CallPyGeneral => (0, 2 + oparg), - Self::CallStr1 => (1, 3), - Self::CallTuple1 => (1, 3), - Self::CallType1 => (1, 3), - Self::CompareOpFloat => (1, 2), - Self::CompareOpInt => (1, 2), - Self::CompareOpStr => (1, 2), - Self::ContainsOpDict => (1, 2), - Self::ContainsOpSet => (1, 2), - Self::ForIterGen => (1, 1), - Self::ForIterList => (2, 1), - Self::ForIterRange => (2, 1), - Self::ForIterTuple => (2, 1), - Self::JumpBackwardJit => (0, 0), - Self::JumpBackwardNoJit => (0, 0), - Self::LoadAttrClass => (1 + (oparg & 1), 1), - Self::LoadAttrClassWithMetaclassCheck => (1 + (oparg & 1), 1), - Self::LoadAttrGetattributeOverridden => (1, 1), - Self::LoadAttrInstanceValue => (1 + (oparg & 1), 1), - Self::LoadAttrMethodLazyDict => (2, 1), - Self::LoadAttrMethodNoDict => (2, 1), - Self::LoadAttrMethodWithValues => (2, 1), - Self::LoadAttrModule => (1 + (oparg & 1), 1), - Self::LoadAttrNondescriptorNoDict => (1, 1), - Self::LoadAttrNondescriptorWithValues => (1, 1), - Self::LoadAttrProperty => (0, 1), - Self::LoadAttrSlot => (1 + (oparg & 1), 1), - Self::LoadAttrWithHint => (1 + (oparg & 1), 1), - Self::LoadConstImmortal => (1, 0), - Self::LoadConstMortal => (1, 0), - Self::LoadGlobalBuiltin => (1 + (oparg & 1), 0), - Self::LoadGlobalModule => (1 + (oparg & 1), 
0), - Self::LoadSuperAttrAttr => (1, 3), - Self::LoadSuperAttrMethod => (2, 3), - Self::ResumeCheck => (0, 0), - Self::SendGen => (1, 2), - Self::StoreAttrInstanceValue => (0, 2), - Self::StoreAttrSlot => (0, 2), - Self::StoreAttrWithHint => (0, 2), - Self::StoreSubscrDict => (0, 3), - Self::StoreSubscrListInt => (0, 3), - Self::ToBoolAlwaysTrue => (1, 1), - Self::ToBoolBool => (1, 1), - Self::ToBoolInt => (1, 1), - Self::ToBoolList => (1, 1), - Self::ToBoolNone => (1, 1), - Self::ToBoolStr => (1, 1), - Self::UnpackSequenceList => (oparg, 1), - Self::UnpackSequenceTuple => (oparg, 1), - Self::UnpackSequenceTwoTuple => (2, 1), - Self::InstrumentedEndFor => (1, 2), - Self::InstrumentedPopIter => (0, 1), - Self::InstrumentedEndSend => (1, 2), - Self::InstrumentedForIter => (2, 1), - Self::InstrumentedInstruction => (0, 0), - Self::InstrumentedJumpForward => (0, 0), - Self::InstrumentedNotTaken => (0, 0), - Self::InstrumentedPopJumpIfTrue => (0, 1), - Self::InstrumentedPopJumpIfFalse => (0, 1), - Self::InstrumentedPopJumpIfNone => (0, 1), - Self::InstrumentedPopJumpIfNotNone => (0, 1), - Self::InstrumentedResume => (0, 0), - Self::InstrumentedReturnValue => (1, 1), - Self::InstrumentedYieldValue => (1, 1), - Self::InstrumentedEndAsyncFor => (0, 2), - Self::InstrumentedLoadSuperAttr => (1 + (oparg & 1), 3), - Self::InstrumentedCall => (1, 2 + oparg), - Self::InstrumentedCallKw => (1, 3 + oparg), - Self::InstrumentedCallFunctionEx => (1, 4), - Self::InstrumentedJumpBackward => (0, 0), - Self::InstrumentedLine => (0, 0), - Self::EnterExecutor => (0, 0), - }; - - debug_assert!(u32::try_from(pushed).is_ok()); - debug_assert!(u32::try_from(popped).is_ok()); - - StackEffect::new(pushed as u32, popped as u32) - } - - #[must_use] - pub const fn to_base(self) -> Option { - Some(match self { - Self::InstrumentedCall => Self::Call, - Self::InstrumentedCallFunctionEx => Self::CallFunctionEx, - Self::InstrumentedCallKw => Self::CallKw, - Self::InstrumentedEndAsyncFor => 
Self::EndAsyncFor, - Self::InstrumentedEndFor => Self::EndFor, - Self::InstrumentedEndSend => Self::EndSend, - Self::InstrumentedForIter => Self::ForIter, - Self::InstrumentedJumpBackward => Self::JumpBackward, - Self::InstrumentedJumpForward => Self::JumpForward, - Self::InstrumentedLoadSuperAttr => Self::LoadSuperAttr, - Self::InstrumentedNotTaken => Self::NotTaken, - Self::InstrumentedPopIter => Self::PopIter, - Self::InstrumentedPopJumpIfFalse => Self::PopJumpIfFalse, - Self::InstrumentedPopJumpIfNone => Self::PopJumpIfNone, - Self::InstrumentedPopJumpIfNotNone => Self::PopJumpIfNotNone, - Self::InstrumentedPopJumpIfTrue => Self::PopJumpIfTrue, - Self::InstrumentedResume => Self::Resume, - Self::InstrumentedReturnValue => Self::ReturnValue, - Self::InstrumentedYieldValue => Self::YieldValue, - _ => return None, - }) - } - - #[must_use] - pub const fn to_instrumented(self) -> Option { - Some(match self { - Self::Call => Self::InstrumentedCall, - Self::CallFunctionEx => Self::InstrumentedCallFunctionEx, - Self::CallKw => Self::InstrumentedCallKw, - Self::EndAsyncFor => Self::InstrumentedEndAsyncFor, - Self::EndFor => Self::InstrumentedEndFor, - Self::EndSend => Self::InstrumentedEndSend, - Self::ForIter => Self::InstrumentedForIter, - Self::JumpBackward => Self::InstrumentedJumpBackward, - Self::JumpForward => Self::InstrumentedJumpForward, - Self::LoadSuperAttr => Self::InstrumentedLoadSuperAttr, - Self::NotTaken => Self::InstrumentedNotTaken, - Self::PopIter => Self::InstrumentedPopIter, - Self::PopJumpIfFalse => Self::InstrumentedPopJumpIfFalse, - Self::PopJumpIfNone => Self::InstrumentedPopJumpIfNone, - Self::PopJumpIfNotNone => Self::InstrumentedPopJumpIfNotNone, - Self::PopJumpIfTrue => Self::InstrumentedPopJumpIfTrue, - Self::Resume => Self::InstrumentedResume, - Self::ReturnValue => Self::InstrumentedReturnValue, - Self::YieldValue => Self::InstrumentedYieldValue, - _ => return None, - }) - } - - pub const fn try_from_u8(value: u8) -> Result { - Ok(match 
value { - 0 => Self::Cache, - 1 => Self::BinarySlice, - 2 => Self::BuildTemplate, - 3 => Self::BinaryOpInplaceAddUnicode, - 4 => Self::CallFunctionEx, - 5 => Self::CheckEgMatch, - 6 => Self::CheckExcMatch, - 7 => Self::CleanupThrow, - 8 => Self::DeleteSubscr, - 9 => Self::EndFor, - 10 => Self::EndSend, - 11 => Self::ExitInitCheck, - 12 => Self::FormatSimple, - 13 => Self::FormatWithSpec, - 14 => Self::GetAiter, - 15 => Self::GetAnext, - 16 => Self::GetIter, - 17 => Self::Reserved, - 18 => Self::GetLen, - 19 => Self::GetYieldFromIter, - 20 => Self::InterpreterExit, - 21 => Self::LoadBuildClass, - 22 => Self::LoadLocals, - 23 => Self::MakeFunction, - 24 => Self::MatchKeys, - 25 => Self::MatchMapping, - 26 => Self::MatchSequence, - 27 => Self::Nop, - 28 => Self::NotTaken, - 29 => Self::PopExcept, - 30 => Self::PopIter, - 31 => Self::PopTop, - 32 => Self::PushExcInfo, - 33 => Self::PushNull, - 34 => Self::ReturnGenerator, - 35 => Self::ReturnValue, - 36 => Self::SetupAnnotations, - 37 => Self::StoreSlice, - 38 => Self::StoreSubscr, - 39 => Self::ToBool, - 40 => Self::UnaryInvert, - 41 => Self::UnaryNegative, - 42 => Self::UnaryNot, - 43 => Self::WithExceptStart, - 44 => Self::BinaryOp, - 45 => Self::BuildInterpolation, - 46 => Self::BuildList, - 47 => Self::BuildMap, - 48 => Self::BuildSet, - 49 => Self::BuildSlice, - 50 => Self::BuildString, - 51 => Self::BuildTuple, - 52 => Self::Call, - 53 => Self::CallIntrinsic1, - 54 => Self::CallIntrinsic2, - 55 => Self::CallKw, - 56 => Self::CompareOp, - 57 => Self::ContainsOp, - 58 => Self::ConvertValue, - 59 => Self::Copy, - 60 => Self::CopyFreeVars, - 61 => Self::DeleteAttr, - 62 => Self::DeleteDeref, - 63 => Self::DeleteFast, - 64 => Self::DeleteGlobal, - 65 => Self::DeleteName, - 66 => Self::DictMerge, - 67 => Self::DictUpdate, - 68 => Self::EndAsyncFor, - 69 => Self::ExtendedArg, - 70 => Self::ForIter, - 71 => Self::GetAwaitable, - 72 => Self::ImportFrom, - 73 => Self::ImportName, - 74 => Self::IsOp, - 75 => 
Self::JumpBackward, - 76 => Self::JumpBackwardNoInterrupt, - 77 => Self::JumpForward, - 78 => Self::ListAppend, - 79 => Self::ListExtend, - 80 => Self::LoadAttr, - 81 => Self::LoadCommonConstant, - 82 => Self::LoadConst, - 83 => Self::LoadDeref, - 84 => Self::LoadFast, - 85 => Self::LoadFastAndClear, - 86 => Self::LoadFastBorrow, - 87 => Self::LoadFastBorrowLoadFastBorrow, - 88 => Self::LoadFastCheck, - 89 => Self::LoadFastLoadFast, - 90 => Self::LoadFromDictOrDeref, - 91 => Self::LoadFromDictOrGlobals, - 92 => Self::LoadGlobal, - 93 => Self::LoadName, - 94 => Self::LoadSmallInt, - 95 => Self::LoadSpecial, - 96 => Self::LoadSuperAttr, - 97 => Self::MakeCell, - 98 => Self::MapAdd, - 99 => Self::MatchClass, - 100 => Self::PopJumpIfFalse, - 101 => Self::PopJumpIfNone, - 102 => Self::PopJumpIfNotNone, - 103 => Self::PopJumpIfTrue, - 104 => Self::RaiseVarargs, - 105 => Self::Reraise, - 106 => Self::Send, - 107 => Self::SetAdd, - 108 => Self::SetFunctionAttribute, - 109 => Self::SetUpdate, - 110 => Self::StoreAttr, - 111 => Self::StoreDeref, - 112 => Self::StoreFast, - 113 => Self::StoreFastLoadFast, - 114 => Self::StoreFastStoreFast, - 115 => Self::StoreGlobal, - 116 => Self::StoreName, - 117 => Self::Swap, - 118 => Self::UnpackEx, - 119 => Self::UnpackSequence, - 120 => Self::YieldValue, - 128 => Self::Resume, - 129 => Self::BinaryOpAddFloat, - 130 => Self::BinaryOpAddInt, - 131 => Self::BinaryOpAddUnicode, - 132 => Self::BinaryOpExtend, - 133 => Self::BinaryOpMultiplyFloat, - 134 => Self::BinaryOpMultiplyInt, - 135 => Self::BinaryOpSubscrDict, - 136 => Self::BinaryOpSubscrGetitem, - 137 => Self::BinaryOpSubscrListInt, - 138 => Self::BinaryOpSubscrListSlice, - 139 => Self::BinaryOpSubscrStrInt, - 140 => Self::BinaryOpSubscrTupleInt, - 141 => Self::BinaryOpSubtractFloat, - 142 => Self::BinaryOpSubtractInt, - 143 => Self::CallAllocAndEnterInit, - 144 => Self::CallBoundMethodExactArgs, - 145 => Self::CallBoundMethodGeneral, - 146 => Self::CallBuiltinClass, - 147 => 
Self::CallBuiltinFast, - 148 => Self::CallBuiltinFastWithKeywords, - 149 => Self::CallBuiltinO, - 150 => Self::CallIsinstance, - 151 => Self::CallKwBoundMethod, - 152 => Self::CallKwNonPy, - 153 => Self::CallKwPy, - 154 => Self::CallLen, - 155 => Self::CallListAppend, - 156 => Self::CallMethodDescriptorFast, - 157 => Self::CallMethodDescriptorFastWithKeywords, - 158 => Self::CallMethodDescriptorNoargs, - 159 => Self::CallMethodDescriptorO, - 160 => Self::CallNonPyGeneral, - 161 => Self::CallPyExactArgs, - 162 => Self::CallPyGeneral, - 163 => Self::CallStr1, - 164 => Self::CallTuple1, - 165 => Self::CallType1, - 166 => Self::CompareOpFloat, - 167 => Self::CompareOpInt, - 168 => Self::CompareOpStr, - 169 => Self::ContainsOpDict, - 170 => Self::ContainsOpSet, - 171 => Self::ForIterGen, - 172 => Self::ForIterList, - 173 => Self::ForIterRange, - 174 => Self::ForIterTuple, - 175 => Self::JumpBackwardJit, - 176 => Self::JumpBackwardNoJit, - 177 => Self::LoadAttrClass, - 178 => Self::LoadAttrClassWithMetaclassCheck, - 179 => Self::LoadAttrGetattributeOverridden, - 180 => Self::LoadAttrInstanceValue, - 181 => Self::LoadAttrMethodLazyDict, - 182 => Self::LoadAttrMethodNoDict, - 183 => Self::LoadAttrMethodWithValues, - 184 => Self::LoadAttrModule, - 185 => Self::LoadAttrNondescriptorNoDict, - 186 => Self::LoadAttrNondescriptorWithValues, - 187 => Self::LoadAttrProperty, - 188 => Self::LoadAttrSlot, - 189 => Self::LoadAttrWithHint, - 190 => Self::LoadConstImmortal, - 191 => Self::LoadConstMortal, - 192 => Self::LoadGlobalBuiltin, - 193 => Self::LoadGlobalModule, - 194 => Self::LoadSuperAttrAttr, - 195 => Self::LoadSuperAttrMethod, - 196 => Self::ResumeCheck, - 197 => Self::SendGen, - 198 => Self::StoreAttrInstanceValue, - 199 => Self::StoreAttrSlot, - 200 => Self::StoreAttrWithHint, - 201 => Self::StoreSubscrDict, - 202 => Self::StoreSubscrListInt, - 203 => Self::ToBoolAlwaysTrue, - 204 => Self::ToBoolBool, - 205 => Self::ToBoolInt, - 206 => Self::ToBoolList, - 207 => 
Self::ToBoolNone, - 208 => Self::ToBoolStr, - 209 => Self::UnpackSequenceList, - 210 => Self::UnpackSequenceTuple, - 211 => Self::UnpackSequenceTwoTuple, - 234 => Self::InstrumentedEndFor, - 235 => Self::InstrumentedPopIter, - 236 => Self::InstrumentedEndSend, - 237 => Self::InstrumentedForIter, - 238 => Self::InstrumentedInstruction, - 239 => Self::InstrumentedJumpForward, - 240 => Self::InstrumentedNotTaken, - 241 => Self::InstrumentedPopJumpIfTrue, - 242 => Self::InstrumentedPopJumpIfFalse, - 243 => Self::InstrumentedPopJumpIfNone, - 244 => Self::InstrumentedPopJumpIfNotNone, - 245 => Self::InstrumentedResume, - 246 => Self::InstrumentedReturnValue, - 247 => Self::InstrumentedYieldValue, - 248 => Self::InstrumentedEndAsyncFor, - 249 => Self::InstrumentedLoadSuperAttr, - 250 => Self::InstrumentedCall, - 251 => Self::InstrumentedCallKw, - 252 => Self::InstrumentedCallFunctionEx, - 253 => Self::InstrumentedJumpBackward, - 254 => Self::InstrumentedLine, - 255 => Self::EnterExecutor, - _ => return Err(MarshalError::InvalidBytecode), - }) - } -} - -impl From for Instruction { - fn from(opcode: Opcode) -> Self { - opcode.as_instruction() - } -} - -impl From for u8 { - fn from(opcode: Opcode) -> Self { - opcode.as_u8() - } -} - -impl TryFrom for Opcode { - type Error = MarshalError; - - fn try_from(value: u8) -> Result { - Self::try_from_u8(value) - } -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -#[repr(u8)] // TODO: Remove this `#[repr(...)]` -pub enum Instruction { - Cache = 0, - BinarySlice = 1, - BuildTemplate = 2, - BinaryOpInplaceAddUnicode = 3, - CallFunctionEx = 4, - CheckEgMatch = 5, - CheckExcMatch = 6, - CleanupThrow = 7, - DeleteSubscr = 8, - EndFor = 9, - EndSend = 10, - ExitInitCheck = 11, - FormatSimple = 12, - FormatWithSpec = 13, - GetAiter = 14, - GetAnext = 15, - GetIter = 16, - Reserved = 17, - GetLen = 18, - GetYieldFromIter = 19, - InterpreterExit = 20, - LoadBuildClass = 21, - LoadLocals = 22, - MakeFunction = 23, - MatchKeys = 24, - 
MatchMapping = 25, - MatchSequence = 26, - Nop = 27, - NotTaken = 28, - PopExcept = 29, - PopIter = 30, - PopTop = 31, - PushExcInfo = 32, - PushNull = 33, - ReturnGenerator = 34, - ReturnValue = 35, - SetupAnnotations = 36, - StoreSlice = 37, - StoreSubscr = 38, - ToBool = 39, - UnaryInvert = 40, - UnaryNegative = 41, - UnaryNot = 42, - WithExceptStart = 43, - BinaryOp { - op: Arg, - } = 44, - BuildInterpolation { - format: Arg, - } = 45, - BuildList { - count: Arg, - } = 46, - BuildMap { - count: Arg, - } = 47, - BuildSet { - count: Arg, - } = 48, - BuildSlice { - argc: Arg, - } = 49, - BuildString { - count: Arg, - } = 50, - BuildTuple { - count: Arg, - } = 51, - Call { - argc: Arg, - } = 52, - CallIntrinsic1 { - func: Arg, - } = 53, - CallIntrinsic2 { - func: Arg, - } = 54, - CallKw { - argc: Arg, - } = 55, - CompareOp { - opname: Arg, - } = 56, - ContainsOp { - invert: Arg, - } = 57, - ConvertValue { - oparg: Arg, - } = 58, - Copy { - i: Arg, - } = 59, - CopyFreeVars { - n: Arg, - } = 60, - DeleteAttr { - namei: Arg, - } = 61, - DeleteDeref { - i: Arg, - } = 62, - DeleteFast { - var_num: Arg, - } = 63, - DeleteGlobal { - namei: Arg, - } = 64, - DeleteName { - namei: Arg, - } = 65, - DictMerge { - i: Arg, - } = 66, - DictUpdate { - i: Arg, - } = 67, - EndAsyncFor = 68, - ExtendedArg = 69, - ForIter { - delta: Arg, - } = 70, - GetAwaitable { - r#where: Arg, - } = 71, - ImportFrom { - namei: Arg, - } = 72, - ImportName { - namei: Arg, - } = 73, - IsOp { - invert: Arg, - } = 74, - JumpBackward { - delta: Arg, - } = 75, - JumpBackwardNoInterrupt { - delta: Arg, - } = 76, - JumpForward { - delta: Arg, - } = 77, - ListAppend { - i: Arg, - } = 78, - ListExtend { - i: Arg, - } = 79, - LoadAttr { - namei: Arg, - } = 80, - LoadCommonConstant { - idx: Arg, - } = 81, - LoadConst { - consti: Arg, - } = 82, - LoadDeref { - i: Arg, - } = 83, - LoadFast { - var_num: Arg, - } = 84, - LoadFastAndClear { - var_num: Arg, - } = 85, - LoadFastBorrow { - var_num: Arg, - } = 86, - 
LoadFastBorrowLoadFastBorrow { - var_nums: Arg, - } = 87, - LoadFastCheck { - var_num: Arg, - } = 88, - LoadFastLoadFast { - var_nums: Arg, - } = 89, - LoadFromDictOrDeref { - i: Arg, - } = 90, - LoadFromDictOrGlobals { - i: Arg, - } = 91, - LoadGlobal { - namei: Arg, - } = 92, - LoadName { - namei: Arg, - } = 93, - LoadSmallInt { - i: Arg, - } = 94, - LoadSpecial { - method: Arg, - } = 95, - LoadSuperAttr { - namei: Arg, - } = 96, - MakeCell { - i: Arg, - } = 97, - MapAdd { - i: Arg, - } = 98, - MatchClass { - count: Arg, - } = 99, - PopJumpIfFalse { - delta: Arg, - } = 100, - PopJumpIfNone { - delta: Arg, - } = 101, - PopJumpIfNotNone { - delta: Arg, - } = 102, - PopJumpIfTrue { - delta: Arg, - } = 103, - RaiseVarargs { - argc: Arg, - } = 104, - Reraise { - depth: Arg, - } = 105, - Send { - delta: Arg, - } = 106, - SetAdd { - i: Arg, - } = 107, - SetFunctionAttribute { - flag: Arg, - } = 108, - SetUpdate { - i: Arg, - } = 109, - StoreAttr { - namei: Arg, - } = 110, - StoreDeref { - i: Arg, - } = 111, - StoreFast { - var_num: Arg, - } = 112, - StoreFastLoadFast { - var_nums: Arg, - } = 113, - StoreFastStoreFast { - var_nums: Arg, - } = 114, - StoreGlobal { - namei: Arg, - } = 115, - StoreName { - namei: Arg, - } = 116, - Swap { - i: Arg, - } = 117, - UnpackEx { - counts: Arg, - } = 118, - UnpackSequence { - count: Arg, - } = 119, - YieldValue { - arg: Arg, - } = 120, - Resume { - context: Arg, - } = 128, - BinaryOpAddFloat = 129, - BinaryOpAddInt = 130, - BinaryOpAddUnicode = 131, - BinaryOpExtend = 132, - BinaryOpMultiplyFloat = 133, - BinaryOpMultiplyInt = 134, - BinaryOpSubscrDict = 135, - BinaryOpSubscrGetitem = 136, - BinaryOpSubscrListInt = 137, - BinaryOpSubscrListSlice = 138, - BinaryOpSubscrStrInt = 139, - BinaryOpSubscrTupleInt = 140, - BinaryOpSubtractFloat = 141, - BinaryOpSubtractInt = 142, - CallAllocAndEnterInit = 143, - CallBoundMethodExactArgs = 144, - CallBoundMethodGeneral = 145, - CallBuiltinClass = 146, - CallBuiltinFast = 147, - 
CallBuiltinFastWithKeywords = 148, - CallBuiltinO = 149, - CallIsinstance = 150, - CallKwBoundMethod = 151, - CallKwNonPy = 152, - CallKwPy = 153, - CallLen = 154, - CallListAppend = 155, - CallMethodDescriptorFast = 156, - CallMethodDescriptorFastWithKeywords = 157, - CallMethodDescriptorNoargs = 158, - CallMethodDescriptorO = 159, - CallNonPyGeneral = 160, - CallPyExactArgs = 161, - CallPyGeneral = 162, - CallStr1 = 163, - CallTuple1 = 164, - CallType1 = 165, - CompareOpFloat = 166, - CompareOpInt = 167, - CompareOpStr = 168, - ContainsOpDict = 169, - ContainsOpSet = 170, - ForIterGen = 171, - ForIterList = 172, - ForIterRange = 173, - ForIterTuple = 174, - JumpBackwardJit = 175, - JumpBackwardNoJit = 176, - LoadAttrClass = 177, - LoadAttrClassWithMetaclassCheck = 178, - LoadAttrGetattributeOverridden = 179, - LoadAttrInstanceValue = 180, - LoadAttrMethodLazyDict = 181, - LoadAttrMethodNoDict = 182, - LoadAttrMethodWithValues = 183, - LoadAttrModule = 184, - LoadAttrNondescriptorNoDict = 185, - LoadAttrNondescriptorWithValues = 186, - LoadAttrProperty = 187, - LoadAttrSlot = 188, - LoadAttrWithHint = 189, - LoadConstImmortal = 190, - LoadConstMortal = 191, - LoadGlobalBuiltin = 192, - LoadGlobalModule = 193, - LoadSuperAttrAttr = 194, - LoadSuperAttrMethod = 195, - ResumeCheck = 196, - SendGen = 197, - StoreAttrInstanceValue = 198, - StoreAttrSlot = 199, - StoreAttrWithHint = 200, - StoreSubscrDict = 201, - StoreSubscrListInt = 202, - ToBoolAlwaysTrue = 203, - ToBoolBool = 204, - ToBoolInt = 205, - ToBoolList = 206, - ToBoolNone = 207, - ToBoolStr = 208, - UnpackSequenceList = 209, - UnpackSequenceTuple = 210, - UnpackSequenceTwoTuple = 211, - InstrumentedEndFor = 234, - InstrumentedPopIter = 235, - InstrumentedEndSend = 236, - InstrumentedForIter = 237, - InstrumentedInstruction = 238, - InstrumentedJumpForward = 239, - InstrumentedNotTaken = 240, - InstrumentedPopJumpIfTrue = 241, - InstrumentedPopJumpIfFalse = 242, - InstrumentedPopJumpIfNone = 243, - 
InstrumentedPopJumpIfNotNone = 244, - InstrumentedResume = 245, - InstrumentedReturnValue = 246, - InstrumentedYieldValue = 247, - InstrumentedEndAsyncFor = 248, - InstrumentedLoadSuperAttr = 249, - InstrumentedCall = 250, - InstrumentedCallKw = 251, - InstrumentedCallFunctionEx = 252, - InstrumentedJumpBackward = 253, - InstrumentedLine = 254, - EnterExecutor = 255, -} - -impl Instruction { - #[must_use] - pub const fn as_u8(self) -> u8 { - self.as_opcode().as_u8() - } - - /// Returns self as a [`Opcode`]. - #[must_use] - pub const fn as_opcode(self) -> Opcode { - match self { - Self::Cache => Opcode::Cache, - Self::BinarySlice => Opcode::BinarySlice, - Self::BuildTemplate => Opcode::BuildTemplate, - Self::BinaryOpInplaceAddUnicode => Opcode::BinaryOpInplaceAddUnicode, - Self::CallFunctionEx => Opcode::CallFunctionEx, - Self::CheckEgMatch => Opcode::CheckEgMatch, - Self::CheckExcMatch => Opcode::CheckExcMatch, - Self::CleanupThrow => Opcode::CleanupThrow, - Self::DeleteSubscr => Opcode::DeleteSubscr, - Self::EndFor => Opcode::EndFor, - Self::EndSend => Opcode::EndSend, - Self::ExitInitCheck => Opcode::ExitInitCheck, - Self::FormatSimple => Opcode::FormatSimple, - Self::FormatWithSpec => Opcode::FormatWithSpec, - Self::GetAiter => Opcode::GetAiter, - Self::GetAnext => Opcode::GetAnext, - Self::GetIter => Opcode::GetIter, - Self::Reserved => Opcode::Reserved, - Self::GetLen => Opcode::GetLen, - Self::GetYieldFromIter => Opcode::GetYieldFromIter, - Self::InterpreterExit => Opcode::InterpreterExit, - Self::LoadBuildClass => Opcode::LoadBuildClass, - Self::LoadLocals => Opcode::LoadLocals, - Self::MakeFunction => Opcode::MakeFunction, - Self::MatchKeys => Opcode::MatchKeys, - Self::MatchMapping => Opcode::MatchMapping, - Self::MatchSequence => Opcode::MatchSequence, - Self::Nop => Opcode::Nop, - Self::NotTaken => Opcode::NotTaken, - Self::PopExcept => Opcode::PopExcept, - Self::PopIter => Opcode::PopIter, - Self::PopTop => Opcode::PopTop, - Self::PushExcInfo => 
Opcode::PushExcInfo, - Self::PushNull => Opcode::PushNull, - Self::ReturnGenerator => Opcode::ReturnGenerator, - Self::ReturnValue => Opcode::ReturnValue, - Self::SetupAnnotations => Opcode::SetupAnnotations, - Self::StoreSlice => Opcode::StoreSlice, - Self::StoreSubscr => Opcode::StoreSubscr, - Self::ToBool => Opcode::ToBool, - Self::UnaryInvert => Opcode::UnaryInvert, - Self::UnaryNegative => Opcode::UnaryNegative, - Self::UnaryNot => Opcode::UnaryNot, - Self::WithExceptStart => Opcode::WithExceptStart, - Self::BinaryOp { .. } => Opcode::BinaryOp, - Self::BuildInterpolation { .. } => Opcode::BuildInterpolation, - Self::BuildList { .. } => Opcode::BuildList, - Self::BuildMap { .. } => Opcode::BuildMap, - Self::BuildSet { .. } => Opcode::BuildSet, - Self::BuildSlice { .. } => Opcode::BuildSlice, - Self::BuildString { .. } => Opcode::BuildString, - Self::BuildTuple { .. } => Opcode::BuildTuple, - Self::Call { .. } => Opcode::Call, - Self::CallIntrinsic1 { .. } => Opcode::CallIntrinsic1, - Self::CallIntrinsic2 { .. } => Opcode::CallIntrinsic2, - Self::CallKw { .. } => Opcode::CallKw, - Self::CompareOp { .. } => Opcode::CompareOp, - Self::ContainsOp { .. } => Opcode::ContainsOp, - Self::ConvertValue { .. } => Opcode::ConvertValue, - Self::Copy { .. } => Opcode::Copy, - Self::CopyFreeVars { .. } => Opcode::CopyFreeVars, - Self::DeleteAttr { .. } => Opcode::DeleteAttr, - Self::DeleteDeref { .. } => Opcode::DeleteDeref, - Self::DeleteFast { .. } => Opcode::DeleteFast, - Self::DeleteGlobal { .. } => Opcode::DeleteGlobal, - Self::DeleteName { .. } => Opcode::DeleteName, - Self::DictMerge { .. } => Opcode::DictMerge, - Self::DictUpdate { .. } => Opcode::DictUpdate, - Self::EndAsyncFor => Opcode::EndAsyncFor, - Self::ExtendedArg => Opcode::ExtendedArg, - Self::ForIter { .. } => Opcode::ForIter, - Self::GetAwaitable { .. } => Opcode::GetAwaitable, - Self::ImportFrom { .. } => Opcode::ImportFrom, - Self::ImportName { .. } => Opcode::ImportName, - Self::IsOp { .. 
} => Opcode::IsOp, - Self::JumpBackward { .. } => Opcode::JumpBackward, - Self::JumpBackwardNoInterrupt { .. } => Opcode::JumpBackwardNoInterrupt, - Self::JumpForward { .. } => Opcode::JumpForward, - Self::ListAppend { .. } => Opcode::ListAppend, - Self::ListExtend { .. } => Opcode::ListExtend, - Self::LoadAttr { .. } => Opcode::LoadAttr, - Self::LoadCommonConstant { .. } => Opcode::LoadCommonConstant, - Self::LoadConst { .. } => Opcode::LoadConst, - Self::LoadDeref { .. } => Opcode::LoadDeref, - Self::LoadFast { .. } => Opcode::LoadFast, - Self::LoadFastAndClear { .. } => Opcode::LoadFastAndClear, - Self::LoadFastBorrow { .. } => Opcode::LoadFastBorrow, - Self::LoadFastBorrowLoadFastBorrow { .. } => Opcode::LoadFastBorrowLoadFastBorrow, - Self::LoadFastCheck { .. } => Opcode::LoadFastCheck, - Self::LoadFastLoadFast { .. } => Opcode::LoadFastLoadFast, - Self::LoadFromDictOrDeref { .. } => Opcode::LoadFromDictOrDeref, - Self::LoadFromDictOrGlobals { .. } => Opcode::LoadFromDictOrGlobals, - Self::LoadGlobal { .. } => Opcode::LoadGlobal, - Self::LoadName { .. } => Opcode::LoadName, - Self::LoadSmallInt { .. } => Opcode::LoadSmallInt, - Self::LoadSpecial { .. } => Opcode::LoadSpecial, - Self::LoadSuperAttr { .. } => Opcode::LoadSuperAttr, - Self::MakeCell { .. } => Opcode::MakeCell, - Self::MapAdd { .. } => Opcode::MapAdd, - Self::MatchClass { .. } => Opcode::MatchClass, - Self::PopJumpIfFalse { .. } => Opcode::PopJumpIfFalse, - Self::PopJumpIfNone { .. } => Opcode::PopJumpIfNone, - Self::PopJumpIfNotNone { .. } => Opcode::PopJumpIfNotNone, - Self::PopJumpIfTrue { .. } => Opcode::PopJumpIfTrue, - Self::RaiseVarargs { .. } => Opcode::RaiseVarargs, - Self::Reraise { .. } => Opcode::Reraise, - Self::Send { .. } => Opcode::Send, - Self::SetAdd { .. } => Opcode::SetAdd, - Self::SetFunctionAttribute { .. } => Opcode::SetFunctionAttribute, - Self::SetUpdate { .. } => Opcode::SetUpdate, - Self::StoreAttr { .. } => Opcode::StoreAttr, - Self::StoreDeref { .. 
} => Opcode::StoreDeref, - Self::StoreFast { .. } => Opcode::StoreFast, - Self::StoreFastLoadFast { .. } => Opcode::StoreFastLoadFast, - Self::StoreFastStoreFast { .. } => Opcode::StoreFastStoreFast, - Self::StoreGlobal { .. } => Opcode::StoreGlobal, - Self::StoreName { .. } => Opcode::StoreName, - Self::Swap { .. } => Opcode::Swap, - Self::UnpackEx { .. } => Opcode::UnpackEx, - Self::UnpackSequence { .. } => Opcode::UnpackSequence, - Self::YieldValue { .. } => Opcode::YieldValue, - Self::Resume { .. } => Opcode::Resume, - Self::BinaryOpAddFloat => Opcode::BinaryOpAddFloat, - Self::BinaryOpAddInt => Opcode::BinaryOpAddInt, - Self::BinaryOpAddUnicode => Opcode::BinaryOpAddUnicode, - Self::BinaryOpExtend => Opcode::BinaryOpExtend, - Self::BinaryOpMultiplyFloat => Opcode::BinaryOpMultiplyFloat, - Self::BinaryOpMultiplyInt => Opcode::BinaryOpMultiplyInt, - Self::BinaryOpSubscrDict => Opcode::BinaryOpSubscrDict, - Self::BinaryOpSubscrGetitem => Opcode::BinaryOpSubscrGetitem, - Self::BinaryOpSubscrListInt => Opcode::BinaryOpSubscrListInt, - Self::BinaryOpSubscrListSlice => Opcode::BinaryOpSubscrListSlice, - Self::BinaryOpSubscrStrInt => Opcode::BinaryOpSubscrStrInt, - Self::BinaryOpSubscrTupleInt => Opcode::BinaryOpSubscrTupleInt, - Self::BinaryOpSubtractFloat => Opcode::BinaryOpSubtractFloat, - Self::BinaryOpSubtractInt => Opcode::BinaryOpSubtractInt, - Self::CallAllocAndEnterInit => Opcode::CallAllocAndEnterInit, - Self::CallBoundMethodExactArgs => Opcode::CallBoundMethodExactArgs, - Self::CallBoundMethodGeneral => Opcode::CallBoundMethodGeneral, - Self::CallBuiltinClass => Opcode::CallBuiltinClass, - Self::CallBuiltinFast => Opcode::CallBuiltinFast, - Self::CallBuiltinFastWithKeywords => Opcode::CallBuiltinFastWithKeywords, - Self::CallBuiltinO => Opcode::CallBuiltinO, - Self::CallIsinstance => Opcode::CallIsinstance, - Self::CallKwBoundMethod => Opcode::CallKwBoundMethod, - Self::CallKwNonPy => Opcode::CallKwNonPy, - Self::CallKwPy => Opcode::CallKwPy, - 
Self::CallLen => Opcode::CallLen, - Self::CallListAppend => Opcode::CallListAppend, - Self::CallMethodDescriptorFast => Opcode::CallMethodDescriptorFast, - Self::CallMethodDescriptorFastWithKeywords => { - Opcode::CallMethodDescriptorFastWithKeywords - } - Self::CallMethodDescriptorNoargs => Opcode::CallMethodDescriptorNoargs, - Self::CallMethodDescriptorO => Opcode::CallMethodDescriptorO, - Self::CallNonPyGeneral => Opcode::CallNonPyGeneral, - Self::CallPyExactArgs => Opcode::CallPyExactArgs, - Self::CallPyGeneral => Opcode::CallPyGeneral, - Self::CallStr1 => Opcode::CallStr1, - Self::CallTuple1 => Opcode::CallTuple1, - Self::CallType1 => Opcode::CallType1, - Self::CompareOpFloat => Opcode::CompareOpFloat, - Self::CompareOpInt => Opcode::CompareOpInt, - Self::CompareOpStr => Opcode::CompareOpStr, - Self::ContainsOpDict => Opcode::ContainsOpDict, - Self::ContainsOpSet => Opcode::ContainsOpSet, - Self::ForIterGen => Opcode::ForIterGen, - Self::ForIterList => Opcode::ForIterList, - Self::ForIterRange => Opcode::ForIterRange, - Self::ForIterTuple => Opcode::ForIterTuple, - Self::JumpBackwardJit => Opcode::JumpBackwardJit, - Self::JumpBackwardNoJit => Opcode::JumpBackwardNoJit, - Self::LoadAttrClass => Opcode::LoadAttrClass, - Self::LoadAttrClassWithMetaclassCheck => Opcode::LoadAttrClassWithMetaclassCheck, - Self::LoadAttrGetattributeOverridden => Opcode::LoadAttrGetattributeOverridden, - Self::LoadAttrInstanceValue => Opcode::LoadAttrInstanceValue, - Self::LoadAttrMethodLazyDict => Opcode::LoadAttrMethodLazyDict, - Self::LoadAttrMethodNoDict => Opcode::LoadAttrMethodNoDict, - Self::LoadAttrMethodWithValues => Opcode::LoadAttrMethodWithValues, - Self::LoadAttrModule => Opcode::LoadAttrModule, - Self::LoadAttrNondescriptorNoDict => Opcode::LoadAttrNondescriptorNoDict, - Self::LoadAttrNondescriptorWithValues => Opcode::LoadAttrNondescriptorWithValues, - Self::LoadAttrProperty => Opcode::LoadAttrProperty, - Self::LoadAttrSlot => Opcode::LoadAttrSlot, - 
Self::LoadAttrWithHint => Opcode::LoadAttrWithHint, - Self::LoadConstImmortal => Opcode::LoadConstImmortal, - Self::LoadConstMortal => Opcode::LoadConstMortal, - Self::LoadGlobalBuiltin => Opcode::LoadGlobalBuiltin, - Self::LoadGlobalModule => Opcode::LoadGlobalModule, - Self::LoadSuperAttrAttr => Opcode::LoadSuperAttrAttr, - Self::LoadSuperAttrMethod => Opcode::LoadSuperAttrMethod, - Self::ResumeCheck => Opcode::ResumeCheck, - Self::SendGen => Opcode::SendGen, - Self::StoreAttrInstanceValue => Opcode::StoreAttrInstanceValue, - Self::StoreAttrSlot => Opcode::StoreAttrSlot, - Self::StoreAttrWithHint => Opcode::StoreAttrWithHint, - Self::StoreSubscrDict => Opcode::StoreSubscrDict, - Self::StoreSubscrListInt => Opcode::StoreSubscrListInt, - Self::ToBoolAlwaysTrue => Opcode::ToBoolAlwaysTrue, - Self::ToBoolBool => Opcode::ToBoolBool, - Self::ToBoolInt => Opcode::ToBoolInt, - Self::ToBoolList => Opcode::ToBoolList, - Self::ToBoolNone => Opcode::ToBoolNone, - Self::ToBoolStr => Opcode::ToBoolStr, - Self::UnpackSequenceList => Opcode::UnpackSequenceList, - Self::UnpackSequenceTuple => Opcode::UnpackSequenceTuple, - Self::UnpackSequenceTwoTuple => Opcode::UnpackSequenceTwoTuple, - Self::InstrumentedEndFor => Opcode::InstrumentedEndFor, - Self::InstrumentedPopIter => Opcode::InstrumentedPopIter, - Self::InstrumentedEndSend => Opcode::InstrumentedEndSend, - Self::InstrumentedForIter => Opcode::InstrumentedForIter, - Self::InstrumentedInstruction => Opcode::InstrumentedInstruction, - Self::InstrumentedJumpForward => Opcode::InstrumentedJumpForward, - Self::InstrumentedNotTaken => Opcode::InstrumentedNotTaken, - Self::InstrumentedPopJumpIfTrue => Opcode::InstrumentedPopJumpIfTrue, - Self::InstrumentedPopJumpIfFalse => Opcode::InstrumentedPopJumpIfFalse, - Self::InstrumentedPopJumpIfNone => Opcode::InstrumentedPopJumpIfNone, - Self::InstrumentedPopJumpIfNotNone => Opcode::InstrumentedPopJumpIfNotNone, - Self::InstrumentedResume => Opcode::InstrumentedResume, - 
Self::InstrumentedReturnValue => Opcode::InstrumentedReturnValue, - Self::InstrumentedYieldValue => Opcode::InstrumentedYieldValue, - Self::InstrumentedEndAsyncFor => Opcode::InstrumentedEndAsyncFor, - Self::InstrumentedLoadSuperAttr => Opcode::InstrumentedLoadSuperAttr, - Self::InstrumentedCall => Opcode::InstrumentedCall, - Self::InstrumentedCallKw => Opcode::InstrumentedCallKw, - Self::InstrumentedCallFunctionEx => Opcode::InstrumentedCallFunctionEx, - Self::InstrumentedJumpBackward => Opcode::InstrumentedJumpBackward, - Self::InstrumentedLine => Opcode::InstrumentedLine, - Self::EnterExecutor => Opcode::EnterExecutor, - } - } - - #[must_use] - pub const fn cache_entries(self) -> usize { - self.as_opcode().cache_entries() - } - - #[must_use] - pub const fn deopt(self) -> Option { - if let Some(opcode) = self.as_opcode().deopt() { - Some(opcode.as_instruction()) - } else { - None - } - } - - #[must_use] - pub const fn label_arg(&self) -> Option> { - Some(match self { - Self::ForIter { delta } => *delta, - Self::JumpBackward { delta } => *delta, - Self::JumpBackwardNoInterrupt { delta } => *delta, - Self::JumpForward { delta } => *delta, - Self::PopJumpIfFalse { delta } => *delta, - Self::PopJumpIfNone { delta } => *delta, - Self::PopJumpIfNotNone { delta } => *delta, - Self::PopJumpIfTrue { delta } => *delta, - Self::Send { delta } => *delta, - _ => return None, - }) - } - - /// Stack effect of [`Self::stack_effect_info`]. 
- #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 { - self.as_opcode().stack_effect(oparg) - } - - #[must_use] - pub fn stack_effect_info(&self, oparg: u32) -> StackEffect { - self.as_opcode().stack_effect_info(oparg) - } - - #[must_use] - pub const fn to_base(self) -> Option { - if let Some(opcode) = self.as_opcode().to_base() { - Some(opcode.as_instruction()) - } else { - None - } - } - - #[must_use] - pub const fn to_instrumented(self) -> Option { - if let Some(opcode) = self.as_opcode().to_instrumented() { - Some(opcode.as_instruction()) - } else { - None - } - } - - pub const fn try_from_u8(value: u8) -> Result { - match Opcode::try_from_u8(value) { - Ok(opcode) => Ok(opcode.as_instruction()), - Err(e) => Err(e), - } - } -} - -impl From for u8 { - fn from(instruction: Instruction) -> Self { - instruction.as_u8() - } -} - -impl From for Opcode { - fn from(instruction: Instruction) -> Self { - instruction.as_opcode() - } -} - -impl TryFrom for Instruction { - type Error = MarshalError; - - fn try_from(value: u8) -> Result { - Self::try_from_u8(value) - } -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum PseudoOpcode { - AnnotationsPlaceholder, - Jump, - JumpIfFalse, - JumpIfTrue, - JumpNoInterrupt, - LoadClosure, - PopBlock, - SetupCleanup, - SetupFinally, - SetupWith, - StoreFastMaybeNull, -} - -impl PseudoOpcode { - /// Returns self as [`PseudoInstruction`]. 
- #[must_use] - pub const fn as_instruction(self) -> PseudoInstruction { - match self { - Self::AnnotationsPlaceholder => PseudoInstruction::AnnotationsPlaceholder, - Self::Jump => PseudoInstruction::Jump { - delta: Arg::marker(), - }, - Self::JumpIfFalse => PseudoInstruction::JumpIfFalse { - delta: Arg::marker(), - }, - Self::JumpIfTrue => PseudoInstruction::JumpIfTrue { - delta: Arg::marker(), - }, - Self::JumpNoInterrupt => PseudoInstruction::JumpNoInterrupt { - delta: Arg::marker(), - }, - Self::LoadClosure => PseudoInstruction::LoadClosure { i: Arg::marker() }, - Self::PopBlock => PseudoInstruction::PopBlock, - Self::SetupCleanup => PseudoInstruction::SetupCleanup { - delta: Arg::marker(), - }, - Self::SetupFinally => PseudoInstruction::SetupFinally { - delta: Arg::marker(), - }, - Self::SetupWith => PseudoInstruction::SetupWith { - delta: Arg::marker(), - }, - Self::StoreFastMaybeNull => PseudoInstruction::StoreFastMaybeNull { - var_num: Arg::marker(), - }, - } - } - - #[must_use] - pub const fn as_u16(self) -> u16 { - match self { - Self::AnnotationsPlaceholder => 256, - Self::Jump => 257, - Self::JumpIfFalse => 258, - Self::JumpIfTrue => 259, - Self::JumpNoInterrupt => 260, - Self::LoadClosure => 261, - Self::PopBlock => 262, - Self::SetupCleanup => 263, - Self::SetupFinally => 264, - Self::SetupWith => 265, - Self::StoreFastMaybeNull => 266, - } - } - - #[must_use] - pub const fn cache_entries(self) -> usize { - 0 - } - - #[must_use] - pub const fn deopt(self) -> Option { - None - } - - /// Does this opcode have 'HAS_ARG_FLAG' set. - #[must_use] - pub const fn has_arg(self) -> bool { - matches!( - self, - Self::Jump - | Self::JumpIfFalse - | Self::JumpIfTrue - | Self::JumpNoInterrupt - | Self::LoadClosure - | Self::StoreFastMaybeNull - ) - } - - /// Does this opcode have 'HAS_CONST_FLAG' set. - #[must_use] - pub const fn has_const(self) -> bool { - false - } - - /// Does this opcode have 'HAS_FREE_FLAG' set. 
- #[must_use] - pub const fn has_free(self) -> bool { - false - } - - /// Does this opcode have 'HAS_JUMP_FLAG' set. - #[must_use] - pub const fn has_jump(self) -> bool { - matches!( - self, - Self::Jump | Self::JumpIfFalse | Self::JumpIfTrue | Self::JumpNoInterrupt - ) - } - - /// Does this opcode have 'HAS_LOCAL_FLAG' set. - #[must_use] - pub const fn has_local(self) -> bool { - matches!(self, Self::LoadClosure | Self::StoreFastMaybeNull) - } - - /// Does this opcode have 'HAS_NAME_FLAG' set. - #[must_use] - pub const fn has_name(self) -> bool { - false - } - - /// Stack effect of [`Self::stack_effect_info`]. - #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 { - self.stack_effect_info(oparg).effect() - } - - #[must_use] - pub fn stack_effect_info(&self, _oparg: u32) -> StackEffect { - let (pushed, popped) = match self { - Self::AnnotationsPlaceholder => (0, 0), - Self::Jump => (0, 0), - Self::JumpIfFalse => (1, 1), - Self::JumpIfTrue => (1, 1), - Self::JumpNoInterrupt => (0, 0), - Self::LoadClosure => (1, 0), - Self::PopBlock => (0, 0), - Self::SetupCleanup => ( - 0, // TODO: Differs from CPython `2` - 0, - ), - Self::SetupFinally => ( - 0, // TODO: Differs from CPython `1` - 0, - ), - Self::SetupWith => ( - 0, // TODO: Differs from CPython `1` - 0, - ), - Self::StoreFastMaybeNull => (0, 1), - }; - - debug_assert!(u32::try_from(pushed).is_ok()); - debug_assert!(u32::try_from(popped).is_ok()); - - StackEffect::new(pushed as u32, popped as u32) - } - - #[must_use] - pub const fn to_base(self) -> Option { - None - } - - #[must_use] - pub const fn to_instrumented(self) -> Option { - None - } - - pub const fn try_from_u16(value: u16) -> Result { - Ok(match value { - 256 => Self::AnnotationsPlaceholder, - 257 => Self::Jump, - 258 => Self::JumpIfFalse, - 259 => Self::JumpIfTrue, - 260 => Self::JumpNoInterrupt, - 261 => Self::LoadClosure, - 262 => Self::PopBlock, - 263 => Self::SetupCleanup, - 264 => Self::SetupFinally, - 265 => Self::SetupWith, - 266 => 
Self::StoreFastMaybeNull, - _ => return Err(MarshalError::InvalidBytecode), - }) - } -} - -impl From for PseudoInstruction { - fn from(opcode: PseudoOpcode) -> Self { - opcode.as_instruction() - } -} - -impl From for u16 { - fn from(opcode: PseudoOpcode) -> Self { - opcode.as_u16() - } -} - -impl TryFrom for PseudoOpcode { - type Error = MarshalError; - - fn try_from(value: u16) -> Result { - Self::try_from_u16(value) - } -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -#[repr(u16)] // TODO: Remove this `#[repr(...)]` -pub enum PseudoInstruction { - AnnotationsPlaceholder = 256, - Jump { delta: Arg } = 257, - JumpIfFalse { delta: Arg } = 258, - JumpIfTrue { delta: Arg } = 259, - JumpNoInterrupt { delta: Arg } = 260, - LoadClosure { i: Arg } = 261, - PopBlock = 262, - SetupCleanup { delta: Arg } = 263, - SetupFinally { delta: Arg } = 264, - SetupWith { delta: Arg } = 265, - StoreFastMaybeNull { var_num: Arg } = 266, -} - -impl PseudoInstruction { - #[must_use] - pub const fn as_u16(self) -> u16 { - self.as_opcode().as_u16() - } - - /// Returns self as a [`PseudoOpcode`]. - #[must_use] - pub const fn as_opcode(self) -> PseudoOpcode { - match self { - Self::AnnotationsPlaceholder => PseudoOpcode::AnnotationsPlaceholder, - Self::Jump { .. } => PseudoOpcode::Jump, - Self::JumpIfFalse { .. } => PseudoOpcode::JumpIfFalse, - Self::JumpIfTrue { .. } => PseudoOpcode::JumpIfTrue, - Self::JumpNoInterrupt { .. } => PseudoOpcode::JumpNoInterrupt, - Self::LoadClosure { .. } => PseudoOpcode::LoadClosure, - Self::PopBlock => PseudoOpcode::PopBlock, - Self::SetupCleanup { .. } => PseudoOpcode::SetupCleanup, - Self::SetupFinally { .. } => PseudoOpcode::SetupFinally, - Self::SetupWith { .. } => PseudoOpcode::SetupWith, - Self::StoreFastMaybeNull { .. 
} => PseudoOpcode::StoreFastMaybeNull, - } - } - - #[must_use] - pub const fn cache_entries(self) -> usize { - self.as_opcode().cache_entries() - } - - #[must_use] - pub const fn deopt(self) -> Option { - if let Some(opcode) = self.as_opcode().deopt() { - Some(opcode.as_instruction()) - } else { - None - } - } - - #[must_use] - pub const fn label_arg(&self) -> Option> { - Some(match self { - Self::Jump { delta } => *delta, - Self::JumpIfFalse { delta } => *delta, - Self::JumpIfTrue { delta } => *delta, - Self::JumpNoInterrupt { delta } => *delta, - Self::SetupCleanup { delta } => *delta, - Self::SetupFinally { delta } => *delta, - Self::SetupWith { delta } => *delta, - _ => return None, - }) - } - - /// Stack effect of [`Self::stack_effect_info`]. - #[must_use] - pub fn stack_effect(&self, oparg: u32) -> i32 { - self.as_opcode().stack_effect(oparg) - } - - #[must_use] - pub fn stack_effect_info(&self, oparg: u32) -> StackEffect { - self.as_opcode().stack_effect_info(oparg) - } - - #[must_use] - pub const fn to_base(self) -> Option { - if let Some(opcode) = self.as_opcode().to_base() { - Some(opcode.as_instruction()) - } else { - None - } - } - - #[must_use] - pub const fn to_instrumented(self) -> Option { - if let Some(opcode) = self.as_opcode().to_instrumented() { - Some(opcode.as_instruction()) - } else { - None - } - } - - pub const fn try_from_u16(value: u16) -> Result { - match PseudoOpcode::try_from_u16(value) { - Ok(opcode) => Ok(opcode.as_instruction()), - Err(e) => Err(e), - } - } -} - -impl From for u16 { - fn from(instruction: PseudoInstruction) -> Self { - instruction.as_u16() - } -} - -impl From for PseudoOpcode { - fn from(instruction: PseudoInstruction) -> Self { - instruction.as_opcode() - } -} - -impl TryFrom for PseudoInstruction { - type Error = MarshalError; - - fn try_from(value: u16) -> Result { - Self::try_from_u16(value) - } -} diff --git a/crates/compiler-core/src/bytecode/oparg.rs b/crates/compiler-core/src/bytecode/oparg.rs index 
6cb58b1a68b..03628604a3f 100644 --- a/crates/compiler-core/src/bytecode/oparg.rs +++ b/crates/compiler-core/src/bytecode/oparg.rs @@ -1,7 +1,7 @@ use core::fmt; use crate::{ - bytecode::{CodeUnit, instructions::Instruction}, + bytecode::{CodeUnit, Instruction}, marshal::MarshalError, }; diff --git a/crates/compiler-core/src/bytecode/opcode_metadata.rs b/crates/compiler-core/src/bytecode/opcode_metadata.rs new file mode 100644 index 00000000000..64c8d3c5330 --- /dev/null +++ b/crates/compiler-core/src/bytecode/opcode_metadata.rs @@ -0,0 +1,833 @@ +// This file is generated by tools/opcode_metadata/generate_rs_opcode_metadata.py +// Do not edit! + +use crate::{bytecode::instruction::StackEffect, marshal::MarshalError}; + +impl super::Opcode { + /// Returns [`Self`] as [`u8`]. + #[must_use] + pub const fn as_u8(self) -> u8 { + self.as_numeric() + } + + #[must_use] + pub const fn cache_entries(self) -> usize { + match self.deoptimize() { + Self::StoreSubscr => 1, + Self::ToBool => 3, + Self::BinaryOp => 5, + Self::Call => 3, + Self::CallKw => 3, + Self::CompareOp => 1, + Self::ContainsOp => 1, + Self::ForIter => 1, + Self::JumpBackward => 1, + Self::LoadAttr => 9, + Self::LoadGlobal => 4, + Self::LoadSuperAttr => 1, + Self::PopJumpIfFalse => 1, + Self::PopJumpIfNone => 1, + Self::PopJumpIfNotNone => 1, + Self::PopJumpIfTrue => 1, + Self::Send => 1, + Self::StoreAttr => 4, + Self::UnpackSequence => 1, + _ => 0, + } + } + + #[must_use] + pub const fn deopt(self) -> Option { + Some(match self { + Self::ResumeCheck => Self::Resume, + Self::LoadConstMortal | Self::LoadConstImmortal => Self::LoadConst, + Self::ToBoolAlwaysTrue + | Self::ToBoolBool + | Self::ToBoolInt + | Self::ToBoolList + | Self::ToBoolNone + | Self::ToBoolStr => Self::ToBool, + Self::BinaryOpMultiplyInt + | Self::BinaryOpAddInt + | Self::BinaryOpSubtractInt + | Self::BinaryOpMultiplyFloat + | Self::BinaryOpAddFloat + | Self::BinaryOpSubtractFloat + | Self::BinaryOpAddUnicode + | Self::BinaryOpSubscrListInt 
+ | Self::BinaryOpSubscrListSlice + | Self::BinaryOpSubscrTupleInt + | Self::BinaryOpSubscrStrInt + | Self::BinaryOpSubscrDict + | Self::BinaryOpSubscrGetitem + | Self::BinaryOpExtend + | Self::BinaryOpInplaceAddUnicode => Self::BinaryOp, + Self::StoreSubscrDict | Self::StoreSubscrListInt => Self::StoreSubscr, + Self::SendGen => Self::Send, + Self::UnpackSequenceTwoTuple | Self::UnpackSequenceTuple | Self::UnpackSequenceList => { + Self::UnpackSequence + } + Self::StoreAttrInstanceValue | Self::StoreAttrSlot | Self::StoreAttrWithHint => { + Self::StoreAttr + } + Self::LoadGlobalModule | Self::LoadGlobalBuiltin => Self::LoadGlobal, + Self::LoadSuperAttrAttr | Self::LoadSuperAttrMethod => Self::LoadSuperAttr, + Self::LoadAttrInstanceValue + | Self::LoadAttrModule + | Self::LoadAttrWithHint + | Self::LoadAttrSlot + | Self::LoadAttrClass + | Self::LoadAttrClassWithMetaclassCheck + | Self::LoadAttrProperty + | Self::LoadAttrGetattributeOverridden + | Self::LoadAttrMethodWithValues + | Self::LoadAttrMethodNoDict + | Self::LoadAttrMethodLazyDict + | Self::LoadAttrNondescriptorWithValues + | Self::LoadAttrNondescriptorNoDict => Self::LoadAttr, + Self::CompareOpFloat | Self::CompareOpInt | Self::CompareOpStr => Self::CompareOp, + Self::ContainsOpSet | Self::ContainsOpDict => Self::ContainsOp, + Self::JumpBackwardNoJit | Self::JumpBackwardJit => Self::JumpBackward, + Self::ForIterList | Self::ForIterTuple | Self::ForIterRange | Self::ForIterGen => { + Self::ForIter + } + Self::CallBoundMethodExactArgs + | Self::CallPyExactArgs + | Self::CallType1 + | Self::CallStr1 + | Self::CallTuple1 + | Self::CallBuiltinClass + | Self::CallBuiltinO + | Self::CallBuiltinFast + | Self::CallBuiltinFastWithKeywords + | Self::CallLen + | Self::CallIsinstance + | Self::CallListAppend + | Self::CallMethodDescriptorO + | Self::CallMethodDescriptorFastWithKeywords + | Self::CallMethodDescriptorNoargs + | Self::CallMethodDescriptorFast + | Self::CallAllocAndEnterInit + | Self::CallPyGeneral + | 
Self::CallBoundMethodGeneral + | Self::CallNonPyGeneral => Self::Call, + Self::CallKwBoundMethod | Self::CallKwPy | Self::CallKwNonPy => Self::CallKw, + _ => return None, + }) + } + + /// Does this opcode have 'HAS_ARG_FLAG' set. + #[must_use] + pub const fn has_arg(self) -> bool { + matches!( + self, + Self::BinaryOp + | Self::BuildInterpolation + | Self::BuildList + | Self::BuildMap + | Self::BuildSet + | Self::BuildSlice + | Self::BuildString + | Self::BuildTuple + | Self::Call + | Self::CallIntrinsic1 + | Self::CallIntrinsic2 + | Self::CallKw + | Self::CompareOp + | Self::ContainsOp + | Self::ConvertValue + | Self::Copy + | Self::CopyFreeVars + | Self::DeleteAttr + | Self::DeleteDeref + | Self::DeleteFast + | Self::DeleteGlobal + | Self::DeleteName + | Self::DictMerge + | Self::DictUpdate + | Self::EndAsyncFor + | Self::ExtendedArg + | Self::ForIter + | Self::GetAwaitable + | Self::ImportFrom + | Self::ImportName + | Self::IsOp + | Self::JumpBackward + | Self::JumpBackwardNoInterrupt + | Self::JumpForward + | Self::ListAppend + | Self::ListExtend + | Self::LoadAttr + | Self::LoadCommonConstant + | Self::LoadConst + | Self::LoadDeref + | Self::LoadFast + | Self::LoadFastAndClear + | Self::LoadFastBorrow + | Self::LoadFastBorrowLoadFastBorrow + | Self::LoadFastCheck + | Self::LoadFastLoadFast + | Self::LoadFromDictOrDeref + | Self::LoadFromDictOrGlobals + | Self::LoadGlobal + | Self::LoadName + | Self::LoadSmallInt + | Self::LoadSpecial + | Self::LoadSuperAttr + | Self::MakeCell + | Self::MapAdd + | Self::MatchClass + | Self::PopJumpIfFalse + | Self::PopJumpIfNone + | Self::PopJumpIfNotNone + | Self::PopJumpIfTrue + | Self::RaiseVarargs + | Self::Reraise + | Self::Send + | Self::SetAdd + | Self::SetFunctionAttribute + | Self::SetUpdate + | Self::StoreAttr + | Self::StoreDeref + | Self::StoreFast + | Self::StoreFastLoadFast + | Self::StoreFastStoreFast + | Self::StoreGlobal + | Self::StoreName + | Self::Swap + | Self::UnpackEx + | Self::UnpackSequence + | 
Self::YieldValue + | Self::Resume + | Self::CallAllocAndEnterInit + | Self::CallBoundMethodExactArgs + | Self::CallBoundMethodGeneral + | Self::CallBuiltinClass + | Self::CallBuiltinFast + | Self::CallBuiltinFastWithKeywords + | Self::CallBuiltinO + | Self::CallIsinstance + | Self::CallKwBoundMethod + | Self::CallKwNonPy + | Self::CallKwPy + | Self::CallListAppend + | Self::CallMethodDescriptorFast + | Self::CallMethodDescriptorFastWithKeywords + | Self::CallMethodDescriptorNoargs + | Self::CallMethodDescriptorO + | Self::CallNonPyGeneral + | Self::CallPyExactArgs + | Self::CallPyGeneral + | Self::CallStr1 + | Self::CallTuple1 + | Self::CallType1 + | Self::CompareOpFloat + | Self::CompareOpInt + | Self::CompareOpStr + | Self::ContainsOpDict + | Self::ContainsOpSet + | Self::ForIterGen + | Self::ForIterList + | Self::ForIterRange + | Self::ForIterTuple + | Self::JumpBackwardJit + | Self::JumpBackwardNoJit + | Self::LoadAttrClass + | Self::LoadAttrClassWithMetaclassCheck + | Self::LoadAttrGetattributeOverridden + | Self::LoadAttrInstanceValue + | Self::LoadAttrMethodLazyDict + | Self::LoadAttrMethodNoDict + | Self::LoadAttrMethodWithValues + | Self::LoadAttrModule + | Self::LoadAttrNondescriptorNoDict + | Self::LoadAttrNondescriptorWithValues + | Self::LoadAttrProperty + | Self::LoadAttrSlot + | Self::LoadAttrWithHint + | Self::LoadConstImmortal + | Self::LoadConstMortal + | Self::LoadGlobalBuiltin + | Self::LoadGlobalModule + | Self::LoadSuperAttrAttr + | Self::LoadSuperAttrMethod + | Self::SendGen + | Self::StoreAttrWithHint + | Self::UnpackSequenceList + | Self::UnpackSequenceTuple + | Self::UnpackSequenceTwoTuple + | Self::InstrumentedForIter + | Self::InstrumentedJumpForward + | Self::InstrumentedPopJumpIfTrue + | Self::InstrumentedPopJumpIfFalse + | Self::InstrumentedPopJumpIfNone + | Self::InstrumentedPopJumpIfNotNone + | Self::InstrumentedResume + | Self::InstrumentedYieldValue + | Self::InstrumentedEndAsyncFor + | Self::InstrumentedLoadSuperAttr + | 
Self::InstrumentedCall + | Self::InstrumentedCallKw + | Self::InstrumentedJumpBackward + | Self::EnterExecutor + ) + } + + /// Does this opcode have 'HAS_CONST_FLAG' set. + #[must_use] + pub const fn has_const(self) -> bool { + matches!( + self, + Self::LoadConst | Self::LoadConstImmortal | Self::LoadConstMortal + ) + } + + /// Does this opcode have 'HAS_EVAL_BREAK_FLAG' set. + #[must_use] + pub const fn has_eval_break(self) -> bool { + matches!( + self, + Self::CallFunctionEx + | Self::Call + | Self::JumpBackward + | Self::Resume + | Self::CallBuiltinClass + | Self::CallBuiltinFast + | Self::CallBuiltinFastWithKeywords + | Self::CallBuiltinO + | Self::CallKwNonPy + | Self::CallMethodDescriptorFast + | Self::CallMethodDescriptorFastWithKeywords + | Self::CallMethodDescriptorNoargs + | Self::CallMethodDescriptorO + | Self::CallNonPyGeneral + | Self::CallStr1 + | Self::CallTuple1 + | Self::JumpBackwardJit + | Self::JumpBackwardNoJit + | Self::InstrumentedResume + | Self::InstrumentedCall + | Self::InstrumentedCallFunctionEx + | Self::InstrumentedJumpBackward + ) + } + + /// Does this opcode have 'HAS_FREE_FLAG' set. + #[must_use] + pub const fn has_free(self) -> bool { + matches!( + self, + Self::DeleteDeref | Self::LoadFromDictOrDeref | Self::MakeCell | Self::StoreDeref + ) + } + + /// Does this opcode have 'HAS_JUMP_FLAG' set. + #[must_use] + pub const fn has_jump(self) -> bool { + matches!( + self, + Self::EndAsyncFor + | Self::ForIter + | Self::JumpBackward + | Self::JumpBackwardNoInterrupt + | Self::JumpForward + | Self::PopJumpIfFalse + | Self::PopJumpIfNone + | Self::PopJumpIfNotNone + | Self::PopJumpIfTrue + | Self::Send + | Self::ForIterList + | Self::ForIterRange + | Self::ForIterTuple + | Self::JumpBackwardJit + | Self::JumpBackwardNoJit + | Self::InstrumentedForIter + | Self::InstrumentedEndAsyncFor + ) + } + + /// Does this opcode have 'HAS_LOCAL_FLAG' set. 
+ #[must_use] + pub const fn has_local(self) -> bool { + matches!( + self, + Self::BinaryOpInplaceAddUnicode + | Self::DeleteFast + | Self::LoadDeref + | Self::LoadFast + | Self::LoadFastAndClear + | Self::LoadFastBorrow + | Self::LoadFastBorrowLoadFastBorrow + | Self::LoadFastCheck + | Self::LoadFastLoadFast + | Self::StoreFast + | Self::StoreFastLoadFast + | Self::StoreFastStoreFast + ) + } + + /// Does this opcode have 'HAS_NAME_FLAG' set. + #[must_use] + pub const fn has_name(self) -> bool { + matches!( + self, + Self::DeleteAttr + | Self::DeleteGlobal + | Self::DeleteName + | Self::ImportFrom + | Self::ImportName + | Self::LoadAttr + | Self::LoadFromDictOrGlobals + | Self::LoadGlobal + | Self::LoadName + | Self::LoadSuperAttr + | Self::StoreAttr + | Self::StoreGlobal + | Self::StoreName + | Self::LoadAttrGetattributeOverridden + | Self::LoadAttrWithHint + | Self::LoadSuperAttrAttr + | Self::LoadSuperAttrMethod + | Self::StoreAttrWithHint + | Self::InstrumentedLoadSuperAttr + ) + } + + #[must_use] + pub const fn is_instrumented(self) -> bool { + matches!( + self, + Self::InstrumentedEndFor + | Self::InstrumentedPopIter + | Self::InstrumentedEndSend + | Self::InstrumentedForIter + | Self::InstrumentedInstruction + | Self::InstrumentedJumpForward + | Self::InstrumentedNotTaken + | Self::InstrumentedPopJumpIfTrue + | Self::InstrumentedPopJumpIfFalse + | Self::InstrumentedPopJumpIfNone + | Self::InstrumentedPopJumpIfNotNone + | Self::InstrumentedResume + | Self::InstrumentedReturnValue + | Self::InstrumentedYieldValue + | Self::InstrumentedEndAsyncFor + | Self::InstrumentedLoadSuperAttr + | Self::InstrumentedCall + | Self::InstrumentedCallKw + | Self::InstrumentedCallFunctionEx + | Self::InstrumentedJumpBackward + | Self::InstrumentedLine + ) + } + + #[must_use] + pub fn stack_effect_info(&self, oparg: u32) -> StackEffect { + // Reason for converting oparg to i32 is because of expressions like `1 + (oparg -1)` + // that causes underflow errors. 
+ let oparg = i32::try_from(oparg).expect("oparg does not fit in an `i32`"); + + let (pushed, popped) = match self { + Self::Cache => (0, 0), + Self::BinarySlice => (1, 3), + Self::BuildTemplate => (1, 2), + Self::BinaryOpInplaceAddUnicode => (0, 2), + Self::CallFunctionEx => (1, 4), + Self::CheckEgMatch => (2, 2), + Self::CheckExcMatch => (2, 2), + Self::CleanupThrow => (2, 3), + Self::DeleteSubscr => (0, 2), + Self::EndFor => (0, 1), + Self::EndSend => (1, 2), + Self::ExitInitCheck => (0, 1), + Self::FormatSimple => (1, 1), + Self::FormatWithSpec => (1, 2), + Self::GetAiter => (1, 1), + Self::GetAnext => (2, 1), + Self::GetIter => (1, 1), + Self::Reserved => (0, 0), + Self::GetLen => (2, 1), + Self::GetYieldFromIter => (1, 1), + Self::InterpreterExit => (0, 1), + Self::LoadBuildClass => (1, 0), + Self::LoadLocals => (1, 0), + Self::MakeFunction => (1, 1), + Self::MatchKeys => (3, 2), + Self::MatchMapping => (2, 1), + Self::MatchSequence => (2, 1), + Self::Nop => (0, 0), + Self::NotTaken => (0, 0), + Self::PopExcept => (0, 1), + Self::PopIter => (0, 1), + Self::PopTop => (0, 1), + Self::PushExcInfo => (2, 1), + Self::PushNull => (1, 0), + Self::ReturnGenerator => (1, 0), + Self::ReturnValue => (1, 1), + Self::SetupAnnotations => (0, 0), + Self::StoreSlice => (0, 4), + Self::StoreSubscr => (0, 3), + Self::ToBool => (1, 1), + Self::UnaryInvert => (1, 1), + Self::UnaryNegative => (1, 1), + Self::UnaryNot => (1, 1), + Self::WithExceptStart => (6, 5), + Self::BinaryOp => (1, 2), + Self::BuildInterpolation => (1, 2 + (oparg & 1)), + Self::BuildList => (1, oparg), + Self::BuildMap => (1, oparg * 2), + Self::BuildSet => (1, oparg), + Self::BuildSlice => (1, oparg), + Self::BuildString => (1, oparg), + Self::BuildTuple => (1, oparg), + Self::Call => (1, 2 + oparg), + Self::CallIntrinsic1 => (1, 1), + Self::CallIntrinsic2 => (1, 2), + Self::CallKw => (1, 3 + oparg), + Self::CompareOp => (1, 2), + Self::ContainsOp => (1, 2), + Self::ConvertValue => (1, 1), + Self::Copy => (2 
+ (oparg - 1), 1 + (oparg - 1)), + Self::CopyFreeVars => (0, 0), + Self::DeleteAttr => (0, 1), + Self::DeleteDeref => (0, 0), + Self::DeleteFast => (0, 0), + Self::DeleteGlobal => (0, 0), + Self::DeleteName => (0, 0), + Self::DictMerge => (4 + (oparg - 1), 5 + (oparg - 1)), + Self::DictUpdate => (1 + (oparg - 1), 2 + (oparg - 1)), + Self::EndAsyncFor => (0, 2), + Self::ExtendedArg => (0, 0), + Self::ForIter => (2, 1), + Self::GetAwaitable => (1, 1), + Self::ImportFrom => (2, 1), + Self::ImportName => (1, 2), + Self::IsOp => (1, 2), + Self::JumpBackward => (0, 0), + Self::JumpBackwardNoInterrupt => (0, 0), + Self::JumpForward => (0, 0), + Self::ListAppend => (1 + (oparg - 1), 2 + (oparg - 1)), + Self::ListExtend => (1 + (oparg - 1), 2 + (oparg - 1)), + Self::LoadAttr => (1 + (oparg & 1), 1), + Self::LoadCommonConstant => (1, 0), + Self::LoadConst => (1, 0), + Self::LoadDeref => (1, 0), + Self::LoadFast => (1, 0), + Self::LoadFastAndClear => (1, 0), + Self::LoadFastBorrow => (1, 0), + Self::LoadFastBorrowLoadFastBorrow => (2, 0), + Self::LoadFastCheck => (1, 0), + Self::LoadFastLoadFast => (2, 0), + Self::LoadFromDictOrDeref => (1, 1), + Self::LoadFromDictOrGlobals => (1, 1), + Self::LoadGlobal => (1 + (oparg & 1), 0), + Self::LoadName => (1, 0), + Self::LoadSmallInt => (1, 0), + Self::LoadSpecial => (2, 1), + Self::LoadSuperAttr => (1 + (oparg & 1), 3), + Self::MakeCell => (0, 0), + Self::MapAdd => (1 + (oparg - 1), 3 + (oparg - 1)), + Self::MatchClass => (1, 3), + Self::PopJumpIfFalse => (0, 1), + Self::PopJumpIfNone => (0, 1), + Self::PopJumpIfNotNone => (0, 1), + Self::PopJumpIfTrue => (0, 1), + Self::RaiseVarargs => (0, oparg), + Self::Reraise => (oparg, 1 + oparg), + Self::Send => (2, 2), + Self::SetAdd => (1 + (oparg - 1), 2 + (oparg - 1)), + Self::SetFunctionAttribute => (1, 2), + Self::SetUpdate => (1 + (oparg - 1), 2 + (oparg - 1)), + Self::StoreAttr => (0, 2), + Self::StoreDeref => (0, 1), + Self::StoreFast => (0, 1), + Self::StoreFastLoadFast => (1, 1), + 
Self::StoreFastStoreFast => (0, 2), + Self::StoreGlobal => (0, 1), + Self::StoreName => (0, 1), + Self::Swap => (2 + (oparg - 2), 2 + (oparg - 2)), + Self::UnpackEx => (1 + (oparg & 0xFF) + (oparg >> 8), 1), + Self::UnpackSequence => (oparg, 1), + Self::YieldValue => (1, 1), + Self::Resume => (0, 0), + Self::BinaryOpAddFloat => (1, 2), + Self::BinaryOpAddInt => (1, 2), + Self::BinaryOpAddUnicode => (1, 2), + Self::BinaryOpExtend => (1, 2), + Self::BinaryOpMultiplyFloat => (1, 2), + Self::BinaryOpMultiplyInt => (1, 2), + Self::BinaryOpSubscrDict => (1, 2), + Self::BinaryOpSubscrGetitem => (0, 2), + Self::BinaryOpSubscrListInt => (1, 2), + Self::BinaryOpSubscrListSlice => (1, 2), + Self::BinaryOpSubscrStrInt => (1, 2), + Self::BinaryOpSubscrTupleInt => (1, 2), + Self::BinaryOpSubtractFloat => (1, 2), + Self::BinaryOpSubtractInt => (1, 2), + Self::CallAllocAndEnterInit => (0, 2 + oparg), + Self::CallBoundMethodExactArgs => (0, 2 + oparg), + Self::CallBoundMethodGeneral => (0, 2 + oparg), + Self::CallBuiltinClass => (1, 2 + oparg), + Self::CallBuiltinFast => (1, 2 + oparg), + Self::CallBuiltinFastWithKeywords => (1, 2 + oparg), + Self::CallBuiltinO => (1, 2 + oparg), + Self::CallIsinstance => (1, 2 + oparg), + Self::CallKwBoundMethod => (0, 3 + oparg), + Self::CallKwNonPy => (1, 3 + oparg), + Self::CallKwPy => (0, 3 + oparg), + Self::CallLen => (1, 3), + Self::CallListAppend => (0, 3), + Self::CallMethodDescriptorFast => (1, 2 + oparg), + Self::CallMethodDescriptorFastWithKeywords => (1, 2 + oparg), + Self::CallMethodDescriptorNoargs => (1, 2 + oparg), + Self::CallMethodDescriptorO => (1, 2 + oparg), + Self::CallNonPyGeneral => (1, 2 + oparg), + Self::CallPyExactArgs => (0, 2 + oparg), + Self::CallPyGeneral => (0, 2 + oparg), + Self::CallStr1 => (1, 3), + Self::CallTuple1 => (1, 3), + Self::CallType1 => (1, 3), + Self::CompareOpFloat => (1, 2), + Self::CompareOpInt => (1, 2), + Self::CompareOpStr => (1, 2), + Self::ContainsOpDict => (1, 2), + Self::ContainsOpSet => (1, 
2), + Self::ForIterGen => (1, 1), + Self::ForIterList => (2, 1), + Self::ForIterRange => (2, 1), + Self::ForIterTuple => (2, 1), + Self::JumpBackwardJit => (0, 0), + Self::JumpBackwardNoJit => (0, 0), + Self::LoadAttrClass => (1 + (oparg & 1), 1), + Self::LoadAttrClassWithMetaclassCheck => (1 + (oparg & 1), 1), + Self::LoadAttrGetattributeOverridden => (1, 1), + Self::LoadAttrInstanceValue => (1 + (oparg & 1), 1), + Self::LoadAttrMethodLazyDict => (2, 1), + Self::LoadAttrMethodNoDict => (2, 1), + Self::LoadAttrMethodWithValues => (2, 1), + Self::LoadAttrModule => (1 + (oparg & 1), 1), + Self::LoadAttrNondescriptorNoDict => (1, 1), + Self::LoadAttrNondescriptorWithValues => (1, 1), + Self::LoadAttrProperty => (0, 1), + Self::LoadAttrSlot => (1 + (oparg & 1), 1), + Self::LoadAttrWithHint => (1 + (oparg & 1), 1), + Self::LoadConstImmortal => (1, 0), + Self::LoadConstMortal => (1, 0), + Self::LoadGlobalBuiltin => (1 + (oparg & 1), 0), + Self::LoadGlobalModule => (1 + (oparg & 1), 0), + Self::LoadSuperAttrAttr => (1, 3), + Self::LoadSuperAttrMethod => (2, 3), + Self::ResumeCheck => (0, 0), + Self::SendGen => (1, 2), + Self::StoreAttrInstanceValue => (0, 2), + Self::StoreAttrSlot => (0, 2), + Self::StoreAttrWithHint => (0, 2), + Self::StoreSubscrDict => (0, 3), + Self::StoreSubscrListInt => (0, 3), + Self::ToBoolAlwaysTrue => (1, 1), + Self::ToBoolBool => (1, 1), + Self::ToBoolInt => (1, 1), + Self::ToBoolList => (1, 1), + Self::ToBoolNone => (1, 1), + Self::ToBoolStr => (1, 1), + Self::UnpackSequenceList => (oparg, 1), + Self::UnpackSequenceTuple => (oparg, 1), + Self::UnpackSequenceTwoTuple => (2, 1), + Self::InstrumentedEndFor => (1, 2), + Self::InstrumentedPopIter => (0, 1), + Self::InstrumentedEndSend => (1, 2), + Self::InstrumentedForIter => (2, 1), + Self::InstrumentedInstruction => (0, 0), + Self::InstrumentedJumpForward => (0, 0), + Self::InstrumentedNotTaken => (0, 0), + Self::InstrumentedPopJumpIfTrue => (0, 1), + Self::InstrumentedPopJumpIfFalse => (0, 1), + 
Self::InstrumentedPopJumpIfNone => (0, 1), + Self::InstrumentedPopJumpIfNotNone => (0, 1), + Self::InstrumentedResume => (0, 0), + Self::InstrumentedReturnValue => (1, 1), + Self::InstrumentedYieldValue => (1, 1), + Self::InstrumentedEndAsyncFor => (0, 2), + Self::InstrumentedLoadSuperAttr => (1 + (oparg & 1), 3), + Self::InstrumentedCall => (1, 2 + oparg), + Self::InstrumentedCallKw => (1, 3 + oparg), + Self::InstrumentedCallFunctionEx => (1, 4), + Self::InstrumentedJumpBackward => (0, 0), + Self::InstrumentedLine => (0, 0), + Self::EnterExecutor => (0, 0), + }; + + debug_assert!(u32::try_from(pushed).is_ok()); + debug_assert!(u32::try_from(popped).is_ok()); + + StackEffect::new(pushed as u32, popped as u32) + } + + #[must_use] + pub const fn to_base(self) -> Option { + Some(match self { + Self::InstrumentedCall => Self::Call, + Self::InstrumentedCallFunctionEx => Self::CallFunctionEx, + Self::InstrumentedCallKw => Self::CallKw, + Self::InstrumentedEndAsyncFor => Self::EndAsyncFor, + Self::InstrumentedEndFor => Self::EndFor, + Self::InstrumentedEndSend => Self::EndSend, + Self::InstrumentedForIter => Self::ForIter, + Self::InstrumentedJumpBackward => Self::JumpBackward, + Self::InstrumentedJumpForward => Self::JumpForward, + Self::InstrumentedLoadSuperAttr => Self::LoadSuperAttr, + Self::InstrumentedNotTaken => Self::NotTaken, + Self::InstrumentedPopIter => Self::PopIter, + Self::InstrumentedPopJumpIfFalse => Self::PopJumpIfFalse, + Self::InstrumentedPopJumpIfNone => Self::PopJumpIfNone, + Self::InstrumentedPopJumpIfNotNone => Self::PopJumpIfNotNone, + Self::InstrumentedPopJumpIfTrue => Self::PopJumpIfTrue, + Self::InstrumentedResume => Self::Resume, + Self::InstrumentedReturnValue => Self::ReturnValue, + Self::InstrumentedYieldValue => Self::YieldValue, + _ => return None, + }) + } + + #[must_use] + pub const fn to_instrumented(self) -> Option { + Some(match self { + Self::Call => Self::InstrumentedCall, + Self::CallFunctionEx => Self::InstrumentedCallFunctionEx, 
+ Self::CallKw => Self::InstrumentedCallKw, + Self::EndAsyncFor => Self::InstrumentedEndAsyncFor, + Self::EndFor => Self::InstrumentedEndFor, + Self::EndSend => Self::InstrumentedEndSend, + Self::ForIter => Self::InstrumentedForIter, + Self::JumpBackward => Self::InstrumentedJumpBackward, + Self::JumpForward => Self::InstrumentedJumpForward, + Self::LoadSuperAttr => Self::InstrumentedLoadSuperAttr, + Self::NotTaken => Self::InstrumentedNotTaken, + Self::PopIter => Self::InstrumentedPopIter, + Self::PopJumpIfFalse => Self::InstrumentedPopJumpIfFalse, + Self::PopJumpIfNone => Self::InstrumentedPopJumpIfNone, + Self::PopJumpIfNotNone => Self::InstrumentedPopJumpIfNotNone, + Self::PopJumpIfTrue => Self::InstrumentedPopJumpIfTrue, + Self::Resume => Self::InstrumentedResume, + Self::ReturnValue => Self::InstrumentedReturnValue, + Self::YieldValue => Self::InstrumentedYieldValue, + _ => return None, + }) + } + + pub const fn try_from_u8(value: u8) -> Result { + Self::try_from_numeric(value) + } +} + +impl super::PseudoOpcode { + /// Returns [`Self`] as [`u16`]. + #[must_use] + pub const fn as_u16(self) -> u16 { + self.as_numeric() + } + + #[must_use] + pub const fn cache_entries(self) -> usize { + 0 + } + + #[must_use] + pub const fn deopt(self) -> Option { + None + } + + /// Does this opcode have 'HAS_ARG_FLAG' set. + #[must_use] + pub const fn has_arg(self) -> bool { + matches!( + self, + Self::Jump + | Self::JumpIfFalse + | Self::JumpIfTrue + | Self::JumpNoInterrupt + | Self::LoadClosure + | Self::StoreFastMaybeNull + ) + } + + /// Does this opcode have 'HAS_CONST_FLAG' set. + #[must_use] + pub const fn has_const(self) -> bool { + false + } + + /// Does this opcode have 'HAS_EVAL_BREAK_FLAG' set. + #[must_use] + pub const fn has_eval_break(self) -> bool { + matches!(self, Self::Jump) + } + + /// Does this opcode have 'HAS_FREE_FLAG' set. + #[must_use] + pub const fn has_free(self) -> bool { + false + } + + /// Does this opcode have 'HAS_JUMP_FLAG' set. 
+ #[must_use] + pub const fn has_jump(self) -> bool { + matches!( + self, + Self::Jump | Self::JumpIfFalse | Self::JumpIfTrue | Self::JumpNoInterrupt + ) + } + + /// Does this opcode have 'HAS_LOCAL_FLAG' set. + #[must_use] + pub const fn has_local(self) -> bool { + matches!(self, Self::LoadClosure | Self::StoreFastMaybeNull) + } + + /// Does this opcode have 'HAS_NAME_FLAG' set. + #[must_use] + pub const fn has_name(self) -> bool { + false + } + + #[must_use] + pub const fn is_instrumented(self) -> bool { + false + } + + #[must_use] + pub fn stack_effect_info(&self, _oparg: u32) -> StackEffect { + let (pushed, popped) = match self { + Self::AnnotationsPlaceholder => (0, 0), + Self::Jump => (0, 0), + Self::JumpIfFalse => (1, 1), + Self::JumpIfTrue => (1, 1), + Self::JumpNoInterrupt => (0, 0), + Self::LoadClosure => (1, 0), + Self::PopBlock => (0, 0), + Self::SetupCleanup => (2, 0), + Self::SetupFinally => (1, 0), + Self::SetupWith => (1, 0), + Self::StoreFastMaybeNull => (0, 1), + }; + + debug_assert!(u32::try_from(pushed).is_ok()); + debug_assert!(u32::try_from(popped).is_ok()); + + StackEffect::new(pushed as u32, popped as u32) + } + + #[must_use] + pub const fn to_base(self) -> Option { + None + } + + #[must_use] + pub const fn to_instrumented(self) -> Option { + None + } + + pub const fn try_from_u16(value: u16) -> Result { + Self::try_from_numeric(value) + } +} diff --git a/crates/compiler-core/src/lib.rs b/crates/compiler-core/src/lib.rs index 245713d1a14..8fe1146e514 100644 --- a/crates/compiler-core/src/lib.rs +++ b/crates/compiler-core/src/lib.rs @@ -1,4 +1,5 @@ #![no_std] +#![recursion_limit = "256"] // Needed for `define_opcodes!` macro #![doc(html_logo_url = "https://raw.githubusercontent.com/RustPython/RustPython/main/logo.png")] #![doc(html_root_url = "https://docs.rs/rustpython-compiler-core/")] diff --git a/crates/compiler-core/src/marshal.rs b/crates/compiler-core/src/marshal.rs index 829c1dc9519..0c28a83e72a 100644 --- 
a/crates/compiler-core/src/marshal.rs +++ b/crates/compiler-core/src/marshal.rs @@ -194,6 +194,22 @@ pub fn deserialize_code( rdr: &mut R, bag: Bag, ) -> Result> { + let mut refs: Vec> = Vec::new(); + deserialize_code_inner(rdr, bag, MAX_MARSHAL_STACK_DEPTH, &mut refs) +} + +/// Inner code-object deserializer that shares a ref table with caller. +/// Used when decoding a code object embedded in another marshal stream so +/// that TYPE_REF entries inside the code can resolve across nested values. +fn deserialize_code_inner( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, +) -> Result> { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } // 1–5: scalar fields let arg_count = rdr.read_u32()?; let posonlyarg_count = rdr.read_u32()?; @@ -202,24 +218,24 @@ pub fn deserialize_code( let flags = CodeFlags::from_bits_truncate(rdr.read_u32()?); // 6: co_code - let code_bytes = read_marshal_bytes(rdr)?; + let code_bytes = read_marshal_bytes(rdr, &bag, refs)?; // 7: co_consts - let constants = read_marshal_const_tuple(rdr, bag)?; + let constants = read_marshal_const_tuple(rdr, bag, depth, refs)?; // 8: co_names - let names = read_marshal_name_tuple(rdr, &bag)?; + let names = read_marshal_name_tuple(rdr, &bag, refs)?; // 9: co_localsplusnames - let localsplusnames = read_marshal_str_vec(rdr)?; + let localsplusnames = read_marshal_str_vec(rdr, &bag, refs)?; // 10: co_localspluskinds - let localspluskinds = read_marshal_bytes(rdr)?; + let localspluskinds = read_marshal_bytes(rdr, &bag, refs)?; // 11–13: filename, name, qualname - let source_path = bag.make_name(&read_marshal_str(rdr)?); - let obj_name = bag.make_name(&read_marshal_str(rdr)?); - let qualname = bag.make_name(&read_marshal_str(rdr)?); + let source_path = bag.make_name(&read_marshal_str(rdr, &bag, refs)?); + let obj_name = bag.make_name(&read_marshal_str(rdr, &bag, refs)?); + let qualname = bag.make_name(&read_marshal_str(rdr, &bag, refs)?); // 14: co_firstlineno let first_line_raw = 
rdr.read_u32()? as i32; @@ -230,8 +246,8 @@ pub fn deserialize_code( }; // 15–16: linetable, exceptiontable - let linetable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice(); - let exceptiontable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice(); + let linetable = read_marshal_bytes(rdr, &bag, refs)?.into_boxed_slice(); + let exceptiontable = read_marshal_bytes(rdr, &bag, refs)?.into_boxed_slice(); // Split localsplusnames/kinds → varnames/cellvars/freevars let lp = split_localplus( @@ -275,72 +291,238 @@ pub fn deserialize_code( }) } -/// Read a marshal bytes object (TYPE_STRING = b's'). -fn read_marshal_bytes(rdr: &mut R) -> Result> { - let type_byte = rdr.read_u8()? & !FLAG_REF; +/// Reserve a ref slot if `FLAG_REF` was present, returning its index. +fn reserve_ref_slot(has_flag: bool, refs: &mut Vec>) -> Option { + if has_flag { + let idx = refs.len(); + refs.push(None); + Some(idx) + } else { + None + } +} + +/// Resolve a TYPE_REF index, returning the previously stored value. +fn resolve_ref(idx: usize, refs: &[Option]) -> Result { + refs.get(idx) + .and_then(|v| v.clone()) + .ok_or(MarshalError::InvalidBytecode) +} + +/// Read a marshal bytes object (TYPE_STRING = b's'), resolving TYPE_REF +/// and registering this read in the ref table when `FLAG_REF` is set. +fn read_marshal_bytes( + rdr: &mut R, + bag: &Bag, + refs: &mut Vec>, +) -> Result> { + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? 
as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Bytes { value } => Ok(value.to_vec()), + _ => Err(MarshalError::BadType), + }; + } + if type_byte != Type::Bytes as u8 { return Err(MarshalError::BadType); } + + let slot = reserve_ref_slot(has_flag, refs); let len = rdr.read_u32()?; - Ok(rdr.read_slice(len)?.to_vec()) + let bytes = rdr.read_slice(len)?.to_vec(); + if let Some(idx) = slot { + refs[idx] = + Some(bag.make_constant::(BorrowedConstant::Bytes { value: &bytes })); + } + Ok(bytes) } -/// Read a marshal string object. -fn read_marshal_str(rdr: &mut R) -> Result { - let type_byte = rdr.read_u8()? & !FLAG_REF; - let s = match type_byte { +/// Read a marshal string object, resolving TYPE_REF and registering +/// this read in the ref table when `FLAG_REF` is set. +fn read_marshal_str( + rdr: &mut R, + bag: &Bag, + refs: &mut Vec>, +) -> Result { + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Str { value } => Ok(value.to_string_lossy().into_owned()), + _ => Err(MarshalError::BadType), + }; + } + + let slot = reserve_ref_slot(has_flag, refs); + let owned = match type_byte { b'u' | b't' | b'a' | b'A' => { let len = rdr.read_u32()?; - rdr.read_str(len)? + alloc::string::String::from(rdr.read_str(len)?) } b'z' | b'Z' => { let len = rdr.read_u8()? as u32; - rdr.read_str(len)? + alloc::string::String::from(rdr.read_str(len)?) } _ => return Err(MarshalError::BadType), }; - Ok(alloc::string::String::from(s)) + if let Some(idx) = slot { + refs[idx] = Some(bag.make_constant::(BorrowedConstant::Str { + value: Wtf8::new(owned.as_str()), + })); + } + Ok(owned) } /// Read a marshal tuple of strings, returning owned Strings. 
-fn read_marshal_str_vec(rdr: &mut R) -> Result> { - let type_byte = rdr.read_u8()? & !FLAG_REF; +fn read_marshal_str_vec( + rdr: &mut R, + bag: &Bag, + refs: &mut Vec>, +) -> Result> { + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Tuple { elements } => elements + .iter() + .map(|c| match c.borrow_constant() { + BorrowedConstant::Str { value } => Ok(value.to_string_lossy().into_owned()), + _ => Err(MarshalError::BadType), + }) + .collect(), + _ => Err(MarshalError::BadType), + }; + } + let n = match type_byte { b'(' => rdr.read_u32()? as usize, b')' => rdr.read_u8()? as usize, _ => return Err(MarshalError::BadType), }; - (0..n).map(|_| read_marshal_str(rdr)).collect() + let slot = reserve_ref_slot(has_flag, refs); + let items: Vec = (0..n) + .map(|_| read_marshal_str(rdr, bag, refs)) + .collect::>()?; + if let Some(idx) = slot { + let elements: Vec = items + .iter() + .map(|s| { + bag.make_constant::(BorrowedConstant::Str { + value: Wtf8::new(s.as_str()), + }) + }) + .collect(); + refs[idx] = Some(bag.make_constant::(BorrowedConstant::Tuple { + elements: &elements, + })); + } + Ok(items) } fn read_marshal_name_tuple( rdr: &mut R, bag: &Bag, + refs: &mut Vec>, ) -> Result::Name]>> { - let type_byte = rdr.read_u8()? & !FLAG_REF; - let n = match type_byte { - b'(' => rdr.read_u32()? as usize, - b')' => rdr.read_u8()? as usize, - _ => return Err(MarshalError::BadType), - }; - (0..n) - .map(|_| Ok(bag.make_name(&read_marshal_str(rdr)?))) - .collect::>>() - .map(Vec::into_boxed_slice) + let names = read_marshal_str_vec(rdr, bag, refs)?; + Ok(names + .iter() + .map(|s| bag.make_name(s)) + .collect::>() + .into_boxed_slice()) } -/// Read a marshal tuple of constants. +/// Read a marshal tuple of constants. 
Shares the ref table with the +/// surrounding code-object decode so that nested TYPE_REF entries (for +/// strings, bytes, code objects, etc.) resolve correctly. fn read_marshal_const_tuple( rdr: &mut R, bag: Bag, + depth: usize, + refs: &mut Vec>, ) -> Result> { - let type_byte = rdr.read_u8()? & !FLAG_REF; + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Tuple { elements } => Ok(elements.iter().cloned().collect()), + _ => Err(MarshalError::BadType), + }; + } + let n = match type_byte { b'(' => rdr.read_u32()? as usize, b')' => rdr.read_u8()? as usize, _ => return Err(MarshalError::BadType), }; - (0..n).map(|_| deserialize_value(rdr, bag)).collect() + let slot = reserve_ref_slot(has_flag, refs); + let child_depth = depth - 1; + let items: Vec = (0..n) + .map(|_| read_const_value(rdr, bag, child_depth, refs)) + .collect::>()?; + if let Some(idx) = slot { + refs[idx] = + Some(bag.make_constant::(BorrowedConstant::Tuple { elements: &items })); + } + Ok(items.into_iter().collect()) +} + +/// Read a single value while staying inside an existing code-object ref +/// space. Unlike `deserialize_value_depth`, encountering `Type::Code` +/// here reuses the caller's ref table instead of opening a fresh one — +/// this matches CPython's single global ref space for objects nested +/// inside a code object's const tuple. +fn read_const_value( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, +) -> Result { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } + let raw = rdr.read_u8()?; + let flag = raw & FLAG_REF != 0; + let type_code = raw & !FLAG_REF; + + if type_code == Type::Ref as u8 { + let idx = rdr.read_u32()? 
as usize; + return resolve_ref(idx, refs); + } + + let slot = reserve_ref_slot(flag, refs); + let typ = Type::try_from(type_code)?; + let value = if matches!(typ, Type::Code) { + let code = deserialize_code_inner(rdr, bag, depth - 1, refs)?; + bag.make_code(code) + } else { + deserialize_value_typed(rdr, bag, depth, refs, typ)? + }; + if let Some(idx) = slot { + refs[idx] = Some(value.clone()); + } + Ok(value) } pub trait MarshalBag: Copy { @@ -393,6 +575,13 @@ pub trait MarshalBag: Copy { } fn constant_bag(self) -> Self::ConstantBag; + + fn constant_ref_from_value( + &self, + _value: &Self::Value, + ) -> Option<::Constant> { + None + } } impl MarshalBag for Bag { @@ -487,6 +676,13 @@ impl MarshalBag for Bag { fn constant_bag(self) -> Self::ConstantBag { self } + + fn constant_ref_from_value( + &self, + value: &Self::Value, + ) -> Option<::Constant> { + Some(value.clone()) + } } pub const MAX_MARSHAL_STACK_DEPTH: usize = 2000; @@ -506,6 +702,23 @@ fn deserialize_value_depth( return Err(MarshalError::InvalidBytecode); } let raw = rdr.read_u8()?; + deserialize_value_after_header(rdr, bag, depth, refs, raw) +} + +/// Continue deserializing a value after the header byte has already been +/// consumed. Shared by `deserialize_value_depth` and the dict-key branch, +/// where the header byte is read up front to detect the TYPE_NULL +/// terminator. +fn deserialize_value_after_header( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, + raw: u8, +) -> Result { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } let flag = raw & FLAG_REF != 0; let type_code = raw & !FLAG_REF; @@ -528,7 +741,23 @@ fn deserialize_value_depth( }; let typ = Type::try_from(type_code)?; - let value = deserialize_value_typed(rdr, bag, depth, refs, typ)?; + // CPython's r_object() uses one global ref table: TYPE_CODE reserves its + // slot before reading code fields, and those fields may use later TYPE_REF + // indexes. 
Keep the same indexes even when Bag::Value and Constant differ. + let value = if matches!(typ, Type::Code) { + let mut inner_refs: Vec::Constant>> = refs + .iter() + .map(|value| { + value + .as_ref() + .and_then(|value| bag.constant_ref_from_value(value)) + }) + .collect(); + let code = deserialize_code_inner(rdr, bag.constant_bag(), depth - 1, &mut inner_refs)?; + bag.make_code(code) + } else { + deserialize_value_typed(rdr, bag, depth, refs, typ)? + }; if let Some(idx) = slot { refs[idx] = Some(value.clone()); @@ -630,32 +859,10 @@ fn deserialize_value_typed( let mut pairs = Vec::new(); loop { let raw = rdr.read_u8()?; - let type_code = raw & !FLAG_REF; - if type_code == b'0' { + if raw & !FLAG_REF == b'0' { break; } - // TYPE_REF for key - let k = if type_code == Type::Ref as u8 { - let idx = rdr.read_u32()? as usize; - refs.get(idx) - .and_then(|v| v.clone()) - .ok_or(MarshalError::InvalidBytecode)? - } else { - let flag = raw & FLAG_REF != 0; - let key_slot = if flag { - let idx = refs.len(); - refs.push(None); - Some(idx) - } else { - None - }; - let key_type = Type::try_from(type_code)?; - let k = deserialize_value_typed(rdr, bag, d, refs, key_type)?; - if let Some(idx) = key_slot { - refs[idx] = Some(k.clone()); - } - k - }; + let k = deserialize_value_after_header(rdr, bag, d, refs, raw)?; let v = deserialize_value_depth(rdr, bag, d, refs)?; pairs.push((k, v)); } @@ -667,7 +874,7 @@ fn deserialize_value_typed( let value = rdr.read_slice(len)?; bag.make_bytes(value) } - Type::Code => bag.make_code(deserialize_code(rdr, bag.constant_bag())?), + Type::Code => return Err(MarshalError::BadType), Type::Slice => { let d = depth - 1; let start = deserialize_value_depth(rdr, bag, d, refs)?; @@ -1288,3 +1495,119 @@ fn lt_read_signed_varint(data: &[u8], pos: &mut usize) -> i32 { (val >> 1) as i32 } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::bytecode::{BasicBag, ConstantData}; + + fn hex_to_bytes(hex: &str) -> Vec { + (0..hex.len()) + .step_by(2) + 
.map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap()) + .collect() + } + + fn decode_code(hex: &str) -> CodeObject { + let bytes = hex_to_bytes(hex); + let value = deserialize_value(&mut &bytes[..], BasicBag).expect("decode failed"); + match value { + ConstantData::Code { code } => *code, + other => panic!("expected Code, got {other:?}"), + } + } + + fn decode_tuple(hex: &str) -> Vec { + let bytes = hex_to_bytes(hex); + let value = deserialize_value(&mut &bytes[..], BasicBag).expect("decode failed"); + match value { + ConstantData::Tuple { elements } => elements, + other => panic!("expected Tuple, got {other:?}"), + } + } + + /// CPython 3.14 marshal output for: `compile("x = 1", "", "exec")`. + /// Exercises FLAG_REF on the code object and TYPE_REF for qualname + /// pointing back at the obj_name slot. + #[test] + fn cpython_314_trivial_assignment() { + let hex = "e30000000000000000000000000100000000000000f30a00000080005e017400520123002902\ + e9010000004e2901da0178a900f300000000da033c743eda083c6d6f64756c653e72070000000100\ + 0000730a000000f003010101d8040582017205000000"; + let code = decode_code(hex); + assert_eq!(code.obj_name.as_str(), ""); + assert_eq!(code.qualname.as_str(), ""); + assert_eq!(code.source_path.as_str(), ""); + assert_eq!(code.arg_count, 0); + assert_eq!(code.max_stackdepth, 1); + assert_eq!(code.names.len(), 1); + assert_eq!(code.names[0].as_str(), "x"); + assert_eq!(code.constants.len(), 2); + // (1, None) + let consts: &[ConstantData] = &code.constants; + assert!(matches!( + consts[0], + ConstantData::Integer { ref value } if *value == 1.into(), + )); + assert!(matches!(consts[1], ConstantData::None)); + } + + /// CPython 3.14 marshal output for a module with a nested function + /// and a string constant. Verifies that nested code objects inside + /// a const tuple share the surrounding code's ref space. 
+ #[test] + fn cpython_314_nested_code_and_string_const() { + let hex = "e30000000000000000000000000100000000000000f310000000800052001700740052017401\ + 520223002903630200000000000000000000000200000003000000f3120000008000570\ + 12c0000000000000000000000230029014ea9002902da0161da016273020000002626da033c743e\ + da0361646472070000000200000073090000008000d80b0c8d35804cf300000000da0568656c6c\ + 6f4e29027207000000da084752454554494e47720300000072080000007206000000da083c6d6f\ + 64756c653e720b000000010000007311000000f003010101f204010111f006000c1382087208000000"; + let code = decode_code(hex); + assert_eq!(code.obj_name.as_str(), ""); + assert_eq!(code.names.len(), 2); + assert_eq!(code.names[0].as_str(), "add"); + assert_eq!(code.names[1].as_str(), "GREETING"); + assert_eq!(code.constants.len(), 3); + // Inner code, "hello", None + let consts: &[ConstantData] = &code.constants; + let inner = match &consts[0] { + ConstantData::Code { code } => code, + other => panic!("expected nested Code, got {other:?}"), + }; + assert_eq!(inner.obj_name.as_str(), "add"); + assert_eq!(inner.qualname.as_str(), "add"); + assert_eq!(inner.arg_count, 2); + assert_eq!(inner.varnames.len(), 2); + assert_eq!(inner.varnames[0].as_str(), "a"); + assert_eq!(inner.varnames[1].as_str(), "b"); + assert!(matches!( + consts[1], + ConstantData::Str { ref value } if value.as_str().ok() == Some("hello"), + )); + assert!(matches!(consts[2], ConstantData::None)); + } + + /// CPython 3.14 marshal output for: + /// `(compile("x = 1", "", "exec"),)`. + /// The outer tuple occupies ref slot 0 and the code object occupies + /// slot 1, so code-object fields must preserve that global ref offset. 
+ #[test] + fn cpython_314_code_inside_tuple_preserves_ref_indexes() { + let hex = "a901630000000000000000000000000100000000000000f30a00000080005e017400\ + 520123002902e9010000004e2901da0178a900f300000000da033c743eda083c6d6f\ + 64756c653e720700000001000000730a000000f003010101d8040582017205000000"; + let tuple = decode_tuple(hex); + assert_eq!(tuple.len(), 1); + let code = match &tuple[0] { + ConstantData::Code { code } => code, + other => panic!("expected nested Code, got {other:?}"), + }; + assert_eq!(code.obj_name.as_str(), ""); + assert_eq!(code.qualname.as_str(), ""); + assert_eq!(code.source_path.as_str(), ""); + assert_eq!(code.names.len(), 1); + assert_eq!(code.names[0].as_str(), "x"); + assert_eq!(code.constants.len(), 2); + } +} diff --git a/crates/compiler-core/src/varint.rs b/crates/compiler-core/src/varint.rs index c07b8b58e6a..e4ae7cdf682 100644 --- a/crates/compiler-core/src/varint.rs +++ b/crates/compiler-core/src/varint.rs @@ -106,7 +106,7 @@ mod tests { use super::*; #[test] - fn test_le_varint_roundtrip() { + fn le_varint_roundtrip() { // Little-endian is only used internally in linetable, // no read function needed outside of linetable parsing. 
let mut buf = Vec::new(); @@ -118,7 +118,7 @@ mod tests { } #[test] - fn test_be_varint_roundtrip() { + fn be_varint_roundtrip() { for &val in &[0u32, 1, 63, 64, 127, 128, 4095, 4096, 1_000_000] { let mut buf = Vec::new(); write_varint_be(&mut buf, val); @@ -129,7 +129,7 @@ mod tests { } #[test] - fn test_be_varint_with_start() { + fn be_varint_with_start() { let mut buf = Vec::new(); write_varint_with_start(&mut buf, 42); write_varint_with_start(&mut buf, 100); diff --git a/crates/compiler/src/lib.rs b/crates/compiler/src/lib.rs index 1193661843b..9c0884c7520 100644 --- a/crates/compiler/src/lib.rs +++ b/crates/compiler/src/lib.rs @@ -1,4 +1,9 @@ +pub use ruff_python_ast::token::TokenKind; +use ruff_python_parser::{LexicalErrorType, ParseErrorType}; use ruff_source_file::{PositionEncoding, SourceFile, SourceFileBuilder, SourceLocation}; +use ruff_text_size::TextSlice; +use thiserror::Error; + use rustpython_codegen::{compile, symboltable}; pub use rustpython_codegen::compile::CompileOpts; @@ -9,20 +14,19 @@ pub use ruff_python_ast as ast; pub use ruff_python_parser as parser; pub use rustpython_codegen as codegen; pub use rustpython_compiler_core as core; -use thiserror::Error; #[derive(Error, Debug)] pub enum CompileErrorType { #[error(transparent)] Codegen(#[from] codegen::error::CodegenErrorType), #[error(transparent)] - Parse(#[from] parser::ParseErrorType), + Parse(#[from] ParseErrorType), } #[derive(Error, Debug)] pub struct ParseError { #[source] - pub error: parser::ParseErrorType, + pub error: ParseErrorType, pub raw_location: ruff_text_size::TextRange, pub location: SourceLocation, pub end_location: SourceLocation, @@ -54,57 +58,140 @@ impl CompileError { // For EOF errors (unclosed brackets), find the unclosed bracket position // and adjust both the error location and message let mut is_unclosed_bracket = false; - let (error_type, location, end_location) = if matches!( - &error.error, - parser::ParseErrorType::Lexical(parser::LexicalErrorType::Eof) - ) 
{ - if let Some((bracket_char, bracket_offset)) = find_unclosed_bracket(source_text) { - let bracket_text_size = ruff_text_size::TextSize::new(bracket_offset as u32); - let loc = source_code.source_location(bracket_text_size, PositionEncoding::Utf8); + let (error_type, location, end_location) = match &error.error { + ParseErrorType::Lexical(LexicalErrorType::Eof) => { + if let Some((bracket_char, bracket_offset)) = find_unclosed_bracket(source_text) { + let bracket_text_size = ruff_text_size::TextSize::new(bracket_offset as u32); + let loc = + source_code.source_location(bracket_text_size, PositionEncoding::Utf8); + let end_loc = SourceLocation { + line: loc.line, + character_offset: loc.character_offset.saturating_add(1), + }; + let msg = format!("'{bracket_char}' was never closed"); + is_unclosed_bracket = true; + (ParseErrorType::OtherError(msg), loc, end_loc) + } else { + let loc = + source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + (error.error, loc, end_loc) + } + } + + ParseErrorType::Lexical(LexicalErrorType::IndentationError) => { + // For IndentationError, point the offset to the end of the line content + // instead of the beginning + let loc = + source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let line_idx = loc.line.to_zero_indexed(); + let line = source_text.split('\n').nth(line_idx).unwrap_or(""); + let line_end_col = line.chars().count() + 1; // 1-indexed, past last char let end_loc = SourceLocation { line: loc.line, - character_offset: loc.character_offset.saturating_add(1), + character_offset: ruff_source_file::OneIndexed::new(line_end_col) + .unwrap_or(loc.character_offset), }; - let msg = format!("'{bracket_char}' was never closed"); - is_unclosed_bracket = true; - (parser::ParseErrorType::OtherError(msg), loc, end_loc) - } else { + (error.error, end_loc, end_loc) + } + 
ParseErrorType::ExpectedToken { expected, found } + if matches!((expected, found), (TokenKind::Comma, TokenKind::Int)) => + { let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let end_loc = + let mut end_loc = source_code.source_location(error.location.end(), PositionEncoding::Utf8); - (error.error, loc, end_loc) + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + let msg = "invalid syntax. Perhaps you forgot a comma?".into(); + (ParseErrorType::OtherError(msg), loc, end_loc) } - } else if matches!( - &error.error, - parser::ParseErrorType::Lexical(parser::LexicalErrorType::IndentationError) - ) { - // For IndentationError, point the offset to the end of the line content - // instead of the beginning - let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let line_idx = loc.line.to_zero_indexed(); - let line = source_text.split('\n').nth(line_idx).unwrap_or(""); - let line_end_col = line.chars().count() + 1; // 1-indexed, past last char - let end_loc = SourceLocation { - line: loc.line, - character_offset: ruff_source_file::OneIndexed::new(line_end_col) - .unwrap_or(loc.character_offset), - }; - (error.error, end_loc, end_loc) - } else { - let loc = source_code.source_location(error.location.start(), PositionEncoding::Utf8); - let mut end_loc = - source_code.source_location(error.location.end(), PositionEncoding::Utf8); - - // If the error range ends at the start of a new line (column 1), - // adjust it to the end of the previous line - if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { - let prev_line_end = 
error.location.end() - ruff_text_size::TextSize::from(1); - end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); - end_loc.character_offset = end_loc.character_offset.saturating_add(1); + + ParseErrorType::InvalidAssignmentTarget => { + let loc = + source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + let expr_str = source_file.source_text().slice(error.location); + + let msg = parser::parse_expression(expr_str).map_or_else( + |_| match expr_str { + "yield" => "assignment to yield expression not possible".into(), + _ => format!("cannot assign to {expr_str}"), + }, + |parsed| match *parsed.syntax().body { + ast::Expr::Call(_) => "cannot assign to function call".into(), + ast::Expr::BinOp(_) => "cannot assign to expression".into(), + ast::Expr::If(_) => "cannot assign to conditional expression".into(), + ast::Expr::Generator(_) => "cannot assign to generator expression".into(), + ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::NumberLiteral(_) => { + "cannot assign to literal here. Maybe you meant '==' instead of '='?" + .into() + } + ast::Expr::EllipsisLiteral(_) => { + "cannot assign to ellipsis here. Maybe you meant '==' instead of '='?" 
+ .into() + } + _ => format!("cannot assign to {expr_str}"), + }, + ); + + (ParseErrorType::OtherError(msg), loc, end_loc) + } + + ParseErrorType::InvalidNamedAssignmentTarget => { + let loc = + source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + let target = source_file.source_text().slice(error.location); + let msg = format!("cannot use assignment expressions with {target}"); + (ParseErrorType::OtherError(msg), loc, end_loc) } - (error.error, loc, end_loc) + _ => { + let loc = + source_code.source_location(error.location.start(), PositionEncoding::Utf8); + let mut end_loc = + source_code.source_location(error.location.end(), PositionEncoding::Utf8); + + // If the error range ends at the start of a new line (column 1), + // adjust it to the end of the previous line + if end_loc.character_offset.get() == 1 && end_loc.line > loc.line { + let prev_line_end = error.location.end() - ruff_text_size::TextSize::from(1); + end_loc = source_code.source_location(prev_line_end, PositionEncoding::Utf8); + end_loc.character_offset = end_loc.character_offset.saturating_add(1); + } + + (error.error, loc, end_loc) + } }; Self::Parse(ParseError { @@ -339,29 +426,33 @@ pub fn _compile_symtable( res.map_err(|e| e.into_codegen_error(source_file.name().to_owned()).into()) } -#[test] -fn test_compile() { - let code = "x = 'abc'"; - let compiled = compile(code, Mode::Single, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); 
-} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic_compile() { + let code = "x = 'abc'"; + let compiled = compile(code, Mode::Single, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_phello() { - let code = r#" + #[test] + fn compile_phello() { + let code = r#" initialized = True def main(): print("Hello world!") if __name__ == '__main__': main() "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_if_elif_else() { - let code = r#" + #[test] + fn compile_if_elif_else() { + let code = r#" if False: pass elif False: @@ -371,31 +462,31 @@ elif False: else: pass "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_lambda() { - let code = r#" + #[test] + fn compile_lambda() { + let code = r#" lambda: 'a' "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_lambda2() { - let code = r#" + #[test] + fn compile_lambda2() { + let code = r#" (lambda x: f'hello, {x}')('world}') "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_lambda3() { - let code = r#" + #[test] + fn compile_lambda3() { + let code = r#" def g(): pass def f(): @@ -406,69 +497,69 @@ def 
f(): else: return g "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_int() { - let code = r#" + #[test] + fn compile_int() { + let code = r#" a = 0xFF "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_bigint() { - let code = r#" + #[test] + fn compile_bigint() { + let code = r#" a = 0xFFFFFFFFFFFFFFFFFFFFFFFF "#; - let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_compile_fstring() { - let code1 = r#" + #[test] + fn compile_fstring() { + let code1 = r#" assert f"1" == '1' "#; - let compiled = compile(code1, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); + let compiled = compile(code1, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); - let code2 = r#" + let code2 = r#" assert f"{1}" == '1' "#; - let compiled = compile(code2, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); - let code3 = r#" + let compiled = compile(code2, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + let code3 = r#" assert f"{1+1}" == '2' "#; - let compiled = compile(code3, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); + let compiled = compile(code3, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); - let code4 = r#" + let code4 = r#" assert 
f"{{{(lambda: f'{1}')}" == '{1' "#; - let compiled = compile(code4, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); + let compiled = compile(code4, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); - let code5 = r#" + let code5 = r#" assert f"a{1}" == 'a1' "#; - let compiled = compile(code5, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); + let compiled = compile(code5, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); - let code6 = r#" + let code6 = r#" assert f"{{{(lambda x: f'hello, {x}')('world}')}" == '{hello, world}' "#; - let compiled = compile(code6, Mode::Exec, "<>", CompileOpts::default()); - dbg!(compiled.expect("compile error")); -} + let compiled = compile(code6, Mode::Exec, "<>", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } -#[test] -fn test_simple_enum() { - let code = r#" + #[test] + fn simple_enum() { + let code = r#" import enum @enum._simple_enum(enum.IntFlag, boundary=enum.KEEP) class RegexFlag: @@ -476,6 +567,7 @@ class RegexFlag: DEBUG = 1 print(RegexFlag.NOFLAG & RegexFlag.DEBUG) "#; - let compiled = compile(code, Mode::Exec, "", CompileOpts::default()); - dbg!(compiled.expect("compile error")); + let compiled = compile(code, Mode::Exec, "", CompileOpts::default()); + dbg!(compiled.expect("compile error")); + } } diff --git a/crates/derive-impl/Cargo.toml b/crates/derive-impl/Cargo.toml index 383bf229171..7197c51ecfa 100644 --- a/crates/derive-impl/Cargo.toml +++ b/crates/derive-impl/Cargo.toml @@ -16,7 +16,6 @@ rustpython-doc = { workspace = true } itertools = { workspace = true } syn = { workspace = true, features = ["full", "extra-traits"] } -maplit = { workspace = true } proc-macro2 = { workspace = true } quote = { workspace = true } syn-ext = { workspace = true, features = ["full"] } diff --git a/crates/derive-impl/src/compile_bytecode.rs 
b/crates/derive-impl/src/compile_bytecode.rs index a640d59c577..05197440e69 100644 --- a/crates/derive-impl/src/compile_bytecode.rs +++ b/crates/derive-impl/src/compile_bytecode.rs @@ -83,12 +83,14 @@ impl CompilationSource { CompilationSourceKind::Dir { base, rel_path } => { self.compile_dir(base, &base.join(rel_path), "", mode, compiler) } - _ => Ok(hashmap! { - module_name.to_string() => CompiledModule { + _ => Ok(core::iter::once(( + module_name.to_string(), + CompiledModule { code: self.compile_single(mode, module_name, compiler)?, package: false, }, - }), + )) + .collect()), } } diff --git a/crates/derive-impl/src/from_args.rs b/crates/derive-impl/src/from_args.rs index 8149a3aa65e..adcbbd418d2 100644 --- a/crates/derive-impl/src/from_args.rs +++ b/crates/derive-impl/src/from_args.rs @@ -6,7 +6,7 @@ use syn::{Attribute, Data, DeriveInput, Expr, Field, Ident, Result, Token, parse /// The kind of the python parameter, this corresponds to the value of Parameter.kind /// (https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind) -#[derive(Default)] +#[derive(Clone, Copy, Default, Eq, PartialEq)] enum ParameterKind { PositionalOnly, #[default] @@ -77,9 +77,10 @@ impl ArgAttribute { } fn parse_argument(&mut self, meta: ParseNestedMeta<'_>) -> Result<()> { - if let ParameterKind::Flatten = self.kind { + if self.kind == ParameterKind::Flatten { return Err(meta.error("can't put additional arguments on a flatten arg")); } + if meta.path.is_ident("default") && meta.input.peek(Token![=]) { if matches!(self.default, Some(Some(_))) { return Err(meta.error("Default already set")); diff --git a/crates/derive-impl/src/lib.rs b/crates/derive-impl/src/lib.rs index 91f606bba1f..3d4b7991511 100644 --- a/crates/derive-impl/src/lib.rs +++ b/crates/derive-impl/src/lib.rs @@ -4,9 +4,6 @@ extern crate proc_macro; -#[macro_use] -extern crate maplit; - #[macro_use] mod error; #[macro_use] diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs 
index 5e25e381df6..0cc0faa113a 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -644,8 +644,8 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result T // their own __init__ in __dict__. let slot_init = quote!(); - let extra_attrs_tokens = if extra_attrs.is_empty() { - quote!() - } else { - quote!(, #(#extra_attrs),*) - }; + let extra_attrs_tokens = quote!(#(#extra_attrs),*); quote! { - #[pyclass(flags(BASETYPE, HAS_DICT), with(#(#with_items),*) #extra_attrs_tokens)] + #[pyclass(flags(BASETYPE, HAS_DICT), with(#(#with_items),*), #extra_attrs_tokens)] impl #generics #self_ty { #(#items)* } diff --git a/crates/derive-impl/src/pymodule.rs b/crates/derive-impl/src/pymodule.rs index cee8b1be4a2..32d7a0fa6bf 100644 --- a/crates/derive-impl/src/pymodule.rs +++ b/crates/derive-impl/src/pymodule.rs @@ -84,13 +84,13 @@ fn negate_cfg_attrs(cfg_attrs: &[Attribute]) -> Vec { if cfg_attrs.is_empty() { return vec![]; } - let predicates: Vec<_> = cfg_attrs + let predicates = cfg_attrs .iter() .map(|attr| match &attr.meta { syn::Meta::List(list) => list.tokens.clone(), _ => unreachable!("only #[cfg(...)] should be here"), }) - .collect(); + .collect::>(); if predicates.len() == 1 { let predicate = &predicates[0]; vec![parse_quote!(#[cfg(not(#predicate))])] @@ -295,19 +295,20 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result, Vec<_>) = with_items.iter().partition(|w| w.cfg_attrs.is_empty()); - let uncond_paths: Vec<_> = uncond_withs.iter().map(|w| &w.path).collect(); + let uncond_paths = uncond_withs.iter().map(|w| &w.path).collect::>(); let method_defs = if with_items.is_empty() { quote!(#function_items) } else { // For cfg-gated with items, generate conditional const declarations // so the total array size adapts to the cfg at compile time - let cond_const_names: Vec<_> = cond_withs + let cond_const_names = cond_withs .iter() .enumerate() .map(|(i, _)| format_ident!("__WITH_METHODS_{}", 
i)) - .collect(); - let cond_const_decls: Vec<_> = cond_withs + .collect::>(); + + let cond_const_decls= cond_withs .iter() .zip(&cond_const_names) .map(|(w, name)| { @@ -321,7 +322,7 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result>(); quote!({ const OWN_METHODS: &'static [::rustpython_vm::function::PyMethodDef] = &#function_items; @@ -340,7 +341,7 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result = with_items + let init_with_calls = with_items .iter() .map(|w| { let cfg_attrs = &w.cfg_attrs; @@ -350,7 +351,7 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result>(); items.extend([ parse_quote! { @@ -702,8 +703,7 @@ impl ModuleItem for FunctionItem { let r = loop_unit(); args.context.errors.ok_or_push(r); } - let py_names: Vec<_> = py_names.into_iter().collect(); - py_names + py_names.into_iter().collect::>() } }; let call_flags = infer_native_call_flags(func.sig(), 0); diff --git a/crates/derive-impl/src/pystructseq.rs b/crates/derive-impl/src/pystructseq.rs index 4059aba63b7..874f85741f0 100644 --- a/crates/derive-impl/src/pystructseq.rs +++ b/crates/derive-impl/src/pystructseq.rs @@ -102,12 +102,12 @@ fn parse_fields(input: &mut DeriveInput) -> Result { bail_span!(input, "Only #[pystruct_sequence(...)] form is allowed"); }; - let idents: Vec<_> = l + let idents = l .nested .iter() .filter_map(|n| n.get_ident()) .cloned() - .collect(); + .collect::>(); for ident in idents { match ident.to_string().as_str() { @@ -205,7 +205,7 @@ pub(crate) fn impl_pystruct_sequence_data( let n_unnamed_fields = field_info.n_unnamed_fields(); // Generate field index constants for visible fields (with cfg guards) - let field_indices: Vec<_> = visible_fields + let field_indices = visible_fields .iter() .enumerate() .map(|(i, field)| { @@ -216,78 +216,58 @@ pub(crate) fn impl_pystruct_sequence_data( pub const #const_name: usize = #i; } }) - .collect(); + .collect::>(); // Generate field name entries 
with cfg guards for named fields - let named_field_names: Vec<_> = named_fields + let named_field_names = named_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { stringify!(#ident), } - } else { - quote! { - #(#cfg_attrs)* - { stringify!(#ident) }, - } + quote! { + #(#cfg_attrs)* + { stringify!(#ident) }, } }) - .collect(); + .collect::>(); // Generate field name entries with cfg guards for skipped fields - let skipped_field_names: Vec<_> = skipped_fields + let skipped_field_names = skipped_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { stringify!(#ident), } - } else { - quote! { - #(#cfg_attrs)* - { stringify!(#ident) }, - } + quote! { + #(#cfg_attrs)* + { stringify!(#ident) }, } }) - .collect(); + .collect::>(); // Generate into_tuple items with cfg guards - let visible_tuple_items: Vec<_> = visible_fields + let visible_tuple_items = visible_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { - ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm), - } - } else { - quote! { - #(#cfg_attrs)* - { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, - } + quote! { + #(#cfg_attrs)* + { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, } }) - .collect(); + .collect::>(); - let skipped_tuple_items: Vec<_> = skipped_fields + let skipped_tuple_items = skipped_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { - ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm), - } - } else { - quote! { - #(#cfg_attrs)* - { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, - } + quote! 
{ + #(#cfg_attrs)* + { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, } }) - .collect(); + .collect::>(); // Generate TryFromObject impl only when try_from_object=true let try_from_object_impl = if try_from_object { @@ -317,44 +297,33 @@ pub(crate) fn impl_pystruct_sequence_data( // Generate try_from_elements trait override only when try_from_object=true let try_from_elements_trait_override = if try_from_object { - let visible_field_inits: Vec<_> = visible_fields + let visible_field_inits = visible_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { #ident: iter.next().unwrap().clone().try_into_value(vm)?, } - } else { - quote! { - #(#cfg_attrs)* - #ident: iter.next().unwrap().clone().try_into_value(vm)?, - } + quote! { + #(#cfg_attrs)* + #ident: iter.next().unwrap().clone().try_into_value(vm)?, } }) - .collect(); - let skipped_field_inits: Vec<_> = skipped_fields + .collect::>(); + + let skipped_field_inits = skipped_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { - #ident: match iter.next() { - Some(v) => v.clone().try_into_value(vm)?, - None => vm.ctx.none(), - }, - } - } else { - quote! { - #(#cfg_attrs)* - #ident: match iter.next() { - Some(v) => v.clone().try_into_value(vm)?, - None => vm.ctx.none(), - }, - } + quote! { + #(#cfg_attrs)* + #ident: match iter.next() { + Some(v) => v.clone().try_into_value(vm)?, + None => vm.ctx.none(), + }, } }) - .collect(); + .collect::>(); + quote! 
{ fn try_from_elements( elements: Vec<::rustpython_vm::PyObjectRef>, @@ -426,6 +395,7 @@ impl ItemMeta for PyStructSequenceMeta { fn from_inner(inner: ItemMetaInner) -> Self { Self { inner } } + fn inner(&self) -> &ItemMetaInner { &self.inner } diff --git a/crates/derive-impl/src/pytraverse.rs b/crates/derive-impl/src/pytraverse.rs index c4ec3823298..c75eee87ba8 100644 --- a/crates/derive-impl/src/pytraverse.rs +++ b/crates/derive-impl/src/pytraverse.rs @@ -38,30 +38,31 @@ fn field_to_traverse_code(field: &Field) -> Result { .iter() .filter_map(pytraverse_arg) .collect::, _>>()?; - let do_trace = if pytraverse_attrs.len() > 1 { + + if pytraverse_attrs.len() > 1 { bail_span!( field, "found multiple #[pytraverse] attributes on the same field, expect at most one" ) - } else if pytraverse_attrs.is_empty() { - // default to always traverse every field - true - } else { - !pytraverse_attrs[0].skip - }; + } + let name = field.ident.as_ref().ok_or_else(|| { syn::Error::new_spanned( field.clone(), "Field should have a name in non-tuple struct", ) })?; - if do_trace { - Ok(quote!( + + // default to always traverse every field + let do_trace = pytraverse_attrs.first().is_none_or(|attr| !attr.skip); + + Ok(if do_trace { + quote!( ::rustpython_vm::object::Traverse::traverse(&self.#name, tracer_fn); - )) + ) } else { - Ok(quote!()) - } + quote!() + }) } /// not trace corresponding field @@ -76,20 +77,16 @@ fn gen_trace_code(item: &mut DeriveInput) -> Result { .iter_mut() .map(|f| -> Result { field_to_traverse_code(f) }) .collect::>()?; - let res = res.into_iter().collect::(); - Ok(res) - } - syn::Fields::Unnamed(fields) => { - let res: TokenStream = (0..fields.unnamed.len()) - .map(|i| { - let i = syn::Index::from(i); - quote!( - ::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn); - ) - }) - .collect(); - Ok(res) + Ok(res.into_iter().collect::()) } + syn::Fields::Unnamed(fields) => Ok((0..fields.unnamed.len()) + .map(|i| { + let i = syn::Index::from(i); + quote!( + 
::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn); + ) + }) + .collect::()), _ => Err(syn::Error::new_spanned( fields, "Only named and unnamed fields are supported", @@ -116,12 +113,11 @@ pub(crate) fn impl_pytraverse(mut item: DeriveInput) -> Result { let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); - let ret = quote! { + Ok(quote! { unsafe impl #impl_generics ::rustpython_vm::object::Traverse for #ty #ty_generics #where_clause { fn traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) { #trace_code } } - }; - Ok(ret) + }) } diff --git a/crates/derive-impl/src/util.rs b/crates/derive-impl/src/util.rs index 3ad41679c3d..52e1fa236f3 100644 --- a/crates/derive-impl/src/util.rs +++ b/crates/derive-impl/src/util.rs @@ -80,7 +80,7 @@ impl ToTokens for ValidatedItemNursery { let cfgs = &item.cfgs; let tokens = &item.tokens; quote! { - #( #cfgs )* + #(#cfgs)* { #tokens } @@ -99,12 +99,15 @@ pub(crate) trait ContentItem { type AttrName: core::str::FromStr + core::fmt::Display; fn inner(&self) -> &ContentItemInner; + fn index(&self) -> usize { self.inner().index } + fn attr_name(&self) -> &Self::AttrName { &self.inner().attr_name } + fn new_syn_error(&self, span: Span, message: &str) -> syn::Error { syn::Error::new(span, format!("#[{}] {}", self.attr_name(), message)) } @@ -142,6 +145,7 @@ impl ItemMetaInner { Ok(None) } })?; + if !lits.is_empty() { bail_span!(meta_ident, "#[{meta_ident}(..)] cannot contain literal") } @@ -153,6 +157,10 @@ impl ItemMetaInner { }) } + pub(crate) fn contains_key(&self, key: &str) -> bool { + self.meta_map.contains_key(key) + } + pub(crate) fn item_name(&self) -> String { self.item_ident.to_string() } @@ -215,10 +223,6 @@ impl ItemMetaInner { Ok(value) } - pub(crate) fn _has_key(&self, key: &str) -> bool { - matches!(self.meta_map.get(key), Some((_, _))) - } - pub(crate) fn _bool(&self, key: &str) -> Result { let value = if let Some((_, meta)) = self.meta_map.get(key) { match meta { 
@@ -517,6 +521,7 @@ impl ExceptionItemMeta { impl core::ops::Deref for ExceptionItemMeta { type Target = ClassItemMeta; + fn deref(&self) -> &Self::Target { &self.0 } @@ -524,8 +529,11 @@ impl core::ops::Deref for ExceptionItemMeta { pub(crate) trait AttributeExt: SynAttributeExt { fn promoted_nested(&self) -> Result; + fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)>; + fn try_remove_name(&mut self, name: &str) -> Result>; + fn fill_nested_meta(&mut self, name: &str, new_item: F) -> Result<()> where F: Fn() -> NestedMeta; @@ -544,6 +552,7 @@ impl AttributeExt for Attribute { })?; Ok(list.nested) } + fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)> { Ok((self.get_ident().unwrap(), self.promoted_nested()?)) } @@ -564,14 +573,10 @@ impl AttributeExt for Attribute { let mut found = None; for (i, item) in nested.iter().enumerate() { - let ident = if let Some(ident) = item.get_ident() { - ident - } else { - continue; - }; - if *ident != item_name { + if item.get_ident().is_none_or(|ident| ident != item_name) { continue; } + if found.is_some() { return Err(syn::Error::new( item.span(), diff --git a/crates/doc/src/lib.rs b/crates/doc/src/lib.rs index 8b7f5d8f75b..c3621c158b6 100644 --- a/crates/doc/src/lib.rs +++ b/crates/doc/src/lib.rs @@ -7,7 +7,7 @@ mod test { use super::DB; #[test] - fn test_db() { + fn db_basic() { let doc = DB.get("array._array_reconstructor"); assert!(doc.is_some()); } diff --git a/crates/host_env/Cargo.toml b/crates/host_env/Cargo.toml index 039c925868c..718e3f41aab 100644 --- a/crates/host_env/Cargo.toml +++ b/crates/host_env/Cargo.toml @@ -11,30 +11,62 @@ license.workspace = true [dependencies] rustpython-wtf8 = { workspace = true } +bitflags = { workspace = true } +getrandom = { workspace = true } libc = { workspace = true } num-traits = { workspace = true } +parking_lot = { workspace = true } +paste = { workspace = true } [target.'cfg(unix)'.dependencies] nix = { workspace = true } 
+rustix = { workspace = true } + +[target.'cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))'.dependencies] +num_cpus = "1.17.0" [target.'cfg(all(unix, not(target_os = "ios"), not(target_os = "redox")))'.dependencies] termios = { workspace = true } +[target.'cfg(any(unix, windows))'.dependencies] +memmap2 = "0.9.10" +libloading = "0.9" + +[target.'cfg(all(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "android"), not(any(target_env = "musl", target_env = "sgx"))))'.dependencies] +libffi = { workspace = true, features = ["system"] } + [target.'cfg(windows)'.dependencies] +junction = { workspace = true } +schannel = { workspace = true } widestring = { workspace = true } windows-sys = { workspace = true, features = [ "Win32_Foundation", "Win32_Globalization", + "Win32_NetworkManagement_IpHelper", + "Win32_NetworkManagement_Ndis", "Win32_Networking_WinSock", + "Win32_Security", + "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_Diagnostics_Debug", + "Win32_System_Environment", + "Win32_System_IO", "Win32_System_Ioctl", + "Win32_System_JobObjects", + "Win32_System_Kernel", "Win32_System_LibraryLoader", + "Win32_System_Memory", + "Win32_System_Pipes", + "Win32_System_Performance", + "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", + "Win32_System_WindowsProgramming", + "Win32_UI_Shell", + "Win32_UI_WindowsAndMessaging", ] } [lints] diff --git a/crates/host_env/src/cert_store.rs b/crates/host_env/src/cert_store.rs new file mode 100644 index 00000000000..f95716faa99 --- /dev/null +++ b/crates/host_env/src/cert_store.rs @@ -0,0 +1,134 @@ +use std::io; + +use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore}; +use windows_sys::Win32::{ + Foundation::{CRYPT_E_NOT_FOUND, GetLastError}, + Security::Cryptography::{ + CERT_CONTEXT, CRL_CONTEXT, CertCloseStore, CertEnumCRLsInStore, 
CertOpenSystemStoreW, + PKCS_7_ASN_ENCODING, X509_ASN_ENCODING, + }, +}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum EncodingType { + X509Asn, + Pkcs7Asn, + Other(u32), +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum CertificateUses { + All, + Oids(Vec), +} + +#[derive(Debug)] +pub struct CertificateEntry { + pub der: Vec, + pub encoding: EncodingType, + pub valid_uses: io::Result, +} + +#[derive(Debug)] +pub struct CertificateEntries { + pub had_open_store: bool, + pub entries: Vec, +} + +#[derive(Debug)] +pub struct CrlEntry { + pub der: Vec, + pub encoding: EncodingType, +} + +fn encoding_type(raw: u32) -> EncodingType { + if raw & X509_ASN_ENCODING != 0 { + EncodingType::X509Asn + } else if raw & PKCS_7_ASN_ENCODING != 0 { + EncodingType::Pkcs7Asn + } else { + EncodingType::Other(raw) + } +} + +pub fn enum_certificates(store_name: &str) -> CertificateEntries { + let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; + let mut had_open_store = false; + let mut entries = Vec::new(); + + for open in open_fns { + let Ok(store) = open(store_name) else { + continue; + }; + had_open_store = true; + + for cert in store.certs() { + let encoding = unsafe { + let ptr = cert.as_ptr() as *const CERT_CONTEXT; + encoding_type((*ptr).dwCertEncodingType) + }; + let valid_uses = cert.valid_uses().map_or_else( + |err| Err(io::Error::other(err)), + |uses| { + Ok(match uses { + ValidUses::All => CertificateUses::All, + ValidUses::Oids(oids) => CertificateUses::Oids(oids.into_iter().collect()), + }) + }, + ); + entries.push(CertificateEntry { + der: cert.to_der().to_owned(), + encoding, + valid_uses, + }); + } + } + + CertificateEntries { + had_open_store, + entries, + } +} + +pub fn enum_crls(store_name: &str) -> io::Result> { + let store_name_wide: Vec = store_name + .encode_utf16() + .chain(core::iter::once(0)) + .collect(); + + let store = unsafe { CertOpenSystemStoreW(0, store_name_wide.as_ptr()) }; + if store.is_null() { + return 
Err(io::Error::last_os_error()); + } + + let mut result = Vec::new(); + let mut crl_context: *const CRL_CONTEXT = core::ptr::null(); + loop { + crl_context = unsafe { CertEnumCRLsInStore(store, crl_context) }; + if crl_context.is_null() { + let err = unsafe { GetLastError() }; + if err != CRYPT_E_NOT_FOUND as u32 { + unsafe { + CertCloseStore(store, 0); + } + return Err(io::Error::from_raw_os_error(err as i32)); + } + break; + } + + let crl = unsafe { &*crl_context }; + let der = + unsafe { core::slice::from_raw_parts(crl.pbCrlEncoded, crl.cbCrlEncoded as usize) } + .to_vec(); + result.push(CrlEntry { + der, + encoding: encoding_type(crl.dwCertEncodingType), + }); + } + + unsafe { + CertCloseStore(store, 0); + } + + Ok(result) +} diff --git a/crates/host_env/src/crt_fd_unsupported.rs b/crates/host_env/src/crt_fd_unsupported.rs new file mode 100644 index 00000000000..b6a7918b31c --- /dev/null +++ b/crates/host_env/src/crt_fd_unsupported.rs @@ -0,0 +1,173 @@ +use alloc::fmt; +use core::marker::PhantomData; +use std::{ffi, io}; + +pub type Offset = i64; +pub type Raw = i32; + +const EBADF: i32 = 9; + +#[repr(transparent)] +pub struct Owned { + fd: Raw, +} + +impl fmt::Debug for Owned { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("crt_fd::Owned") + .field(&self.as_raw()) + .finish() + } +} + +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Borrowed<'fd> { + fd: Raw, + _marker: PhantomData<&'fd Owned>, +} + +impl PartialEq for Borrowed<'_> { + fn eq(&self, other: &Self) -> bool { + self.as_raw() == other.as_raw() + } +} + +impl Eq for Borrowed<'_> {} + +impl fmt::Debug for Borrowed<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("crt_fd::Borrowed") + .field(&self.as_raw()) + .finish() + } +} + +impl Owned { + /// Create a `crt_fd::Owned` from a raw file descriptor. + /// + /// # Safety + /// + /// `fd` must be a valid file descriptor for the embedding host. 
+ #[inline] + pub const unsafe fn from_raw(fd: Raw) -> Self { + Self { fd } + } + + /// Create a `crt_fd::Owned` from a raw file descriptor. + /// + /// Returns an error if `fd` is negative. + /// + /// # Safety + /// + /// `fd` must be a valid file descriptor for the embedding host. + #[inline] + pub unsafe fn try_from_raw(fd: Raw) -> io::Result { + if fd < 0 { + Err(ebadf()) + } else { + Ok(unsafe { Self::from_raw(fd) }) + } + } + + #[inline] + pub const fn borrow(&self) -> Borrowed<'_> { + unsafe { Borrowed::borrow_raw(self.as_raw()) } + } + + #[inline] + pub const fn as_raw(&self) -> Raw { + self.fd + } + + #[inline] + pub fn into_raw(self) -> Raw { + let fd = self.fd; + core::mem::forget(self); + fd + } + + pub fn leak<'fd>(self) -> Borrowed<'fd> { + unsafe { Borrowed::borrow_raw(self.into_raw()) } + } +} + +impl Drop for Owned { + fn drop(&mut self) {} +} + +impl<'fd> Borrowed<'fd> { + /// Create a `crt_fd::Borrowed` from a raw file descriptor. + /// + /// # Safety + /// + /// `fd` must be a valid file descriptor for the embedding host. + #[inline] + pub const unsafe fn borrow_raw(fd: Raw) -> Self { + Self { + fd, + _marker: PhantomData, + } + } + + /// Create a `crt_fd::Borrowed` from a raw file descriptor. + /// + /// Returns an error if `fd` is negative. + /// + /// # Safety + /// + /// `fd` must be a valid file descriptor for the embedding host. 
+ #[inline] + pub unsafe fn try_borrow_raw(fd: Raw) -> io::Result { + if fd < 0 { + Err(ebadf()) + } else { + Ok(unsafe { Self::borrow_raw(fd) }) + } + } + + #[inline] + pub const fn as_raw(self) -> Raw { + self.fd + } +} + +#[inline] +fn ebadf() -> io::Error { + io::Error::from_raw_os_error(EBADF) +} + +pub fn open(_path: &ffi::CStr, _flags: i32, _mode: i32) -> io::Result { + Err(unsupported()) +} + +pub fn openat(_dir: Borrowed<'_>, _path: &ffi::CStr, _flags: i32, _mode: i32) -> io::Result { + Err(unsupported()) +} + +pub fn fsync(_fd: Borrowed<'_>) -> io::Result<()> { + Err(ebadf()) +} + +pub fn close(_fd: Owned) -> io::Result<()> { + Err(ebadf()) +} + +pub fn ftruncate(_fd: Borrowed<'_>, _len: Offset) -> io::Result<()> { + Err(ebadf()) +} + +pub fn write(_fd: Borrowed<'_>, _buf: &[u8]) -> io::Result { + Err(ebadf()) +} + +pub fn read(_fd: Borrowed<'_>, _buf: &mut [u8]) -> io::Result { + Err(ebadf()) +} + +fn unsupported() -> io::Error { + io::Error::new( + io::ErrorKind::Unsupported, + "host file descriptors are unsupported on this platform", + ) +} diff --git a/crates/host_env/src/ctypes.rs b/crates/host_env/src/ctypes.rs new file mode 100644 index 00000000000..ae64b06cfe0 --- /dev/null +++ b/crates/host_env/src/ctypes.rs @@ -0,0 +1,2724 @@ +use alloc::borrow::Cow; +use core::ffi::{ + CStr, c_char, c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, + c_ulong, c_ulonglong, c_ushort, c_void, +}; +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +use libffi::middle::Type; +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +use libffi::{ + low, + middle::{Arg, Cif, Closure, CodePtr}, +}; +#[cfg(any(unix, windows))] +use libloading::Library; +#[cfg(unix)] +use libloading::os::unix::Library 
as UnixLibrary; +#[cfg(any(unix, windows))] +use parking_lot::{Mutex, RwLock}; +use rustpython_wtf8::Wtf8; +use rustpython_wtf8::Wtf8Buf; +#[cfg(any(unix, windows))] +use std::{collections::HashMap, ffi::OsStr, sync::OnceLock}; + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub type FfiType = Type; + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub type FfiArg<'a> = Arg<'a>; + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub type FfiCodePtr = CodePtr; + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub type FfiCif = low::ffi_cif; + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +type CallbackIntResult = low::ffi_arg; + +#[cfg(not(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +)))] +type CallbackIntResult = c_int; + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub type WChar = libc::wchar_t; +#[cfg(not(any(unix, windows, target_os = "wasi")))] +pub type WChar = u32; + +#[cfg(any(unix, windows, target_os = "wasi"))] +type TimeT = libc::time_t; +#[cfg(not(any(unix, windows, target_os = "wasi")))] +type TimeT = i64; + +std::thread_local! { + /// Thread-local ctypes errno, separate from the platform errno. 
+ #[allow(clippy::missing_const_for_thread_local)] + static CTYPES_LOCAL_ERRNO: core::cell::Cell = const { core::cell::Cell::new(0) }; +} + +pub fn get_errno() -> i32 { + CTYPES_LOCAL_ERRNO.with(|e| e.get()) +} + +pub fn set_errno(value: i32) -> i32 { + CTYPES_LOCAL_ERRNO.with(|e| { + let old = e.get(); + e.set(value); + old + }) +} + +#[cfg(not(windows))] +pub fn with_swapped_errno(f: F) -> R +where + F: FnOnce() -> R, +{ + let saved_errno = crate::os::get_errno(); + let saved_ctypes_errno = CTYPES_LOCAL_ERRNO.with(|e| e.get()); + crate::os::set_errno(saved_ctypes_errno); + + let result = f(); + + let new_error = crate::os::get_errno(); + CTYPES_LOCAL_ERRNO.with(|e| e.set(new_error)); + crate::os::set_errno(saved_errno); + + result +} + +pub fn with_callback_errno_preserved(use_errno: bool, f: F) -> R +where + F: FnOnce() -> R, +{ + if !use_errno { + return f(); + } + + let saved = crate::os::get_errno(); + let result = f(); + let _current = crate::os::get_errno(); + crate::os::set_errno(saved); + result +} + +#[cfg(windows)] +std::thread_local! { + /// Thread-local ctypes last_error, separate from the Windows last error. 
+ static CTYPES_LOCAL_LAST_ERROR: core::cell::Cell = const { core::cell::Cell::new(0) }; +} + +#[cfg(windows)] +pub fn get_last_error() -> u32 { + CTYPES_LOCAL_LAST_ERROR.with(|e| e.get()) +} + +#[cfg(windows)] +pub fn set_last_error(value: u32) -> u32 { + CTYPES_LOCAL_LAST_ERROR.with(|e| { + let old = e.get(); + e.set(value); + old + }) +} + +#[cfg(windows)] +pub fn with_swapped_last_error(f: F) -> R +where + F: FnOnce() -> R, +{ + let saved_last_error = crate::windows::get_last_error(); + let saved_ctypes_last_error = CTYPES_LOCAL_LAST_ERROR.with(|e| e.get()); + crate::windows::set_last_error(saved_ctypes_last_error); + + let result = f(); + + let new_error = crate::windows::get_last_error(); + CTYPES_LOCAL_LAST_ERROR.with(|e| e.set(new_error)); + crate::windows::set_last_error(saved_last_error); + + result +} + +#[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") +))] +const LONG_DOUBLE_SIZE: usize = core::mem::size_of::(); + +#[cfg(target_os = "windows")] +const LONG_DOUBLE_SIZE: usize = core::mem::size_of::(); + +#[cfg(not(any( + all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ), + target_os = "windows" +)))] +const LONG_DOUBLE_SIZE: usize = core::mem::size_of::(); + +pub fn simple_type_size(ty: &str) -> Option { + match ty { + "c" | "b" => Some(core::mem::size_of::()), + "u" => Some(core::mem::size_of::()), + "B" | "?" 
=> Some(core::mem::size_of::()), + "h" | "v" => Some(core::mem::size_of::()), + "H" => Some(core::mem::size_of::()), + "i" => Some(core::mem::size_of::()), + "I" => Some(core::mem::size_of::()), + "l" => Some(core::mem::size_of::()), + "L" => Some(core::mem::size_of::()), + "q" => Some(core::mem::size_of::()), + "Q" => Some(core::mem::size_of::()), + "f" => Some(core::mem::size_of::()), + "d" => Some(core::mem::size_of::()), + "g" => Some(LONG_DOUBLE_SIZE), + "z" | "Z" | "P" | "X" | "O" => Some(core::mem::size_of::()), + "void" => Some(0), + _ => None, + } +} + +pub fn simple_type_align(ty: &str) -> Option { + match ty { + "c" | "b" => Some(core::mem::align_of::()), + "u" => Some(core::mem::align_of::()), + "B" | "?" => Some(core::mem::align_of::()), + "h" | "v" => Some(core::mem::align_of::()), + "H" => Some(core::mem::align_of::()), + "i" => Some(core::mem::align_of::()), + "I" => Some(core::mem::align_of::()), + "l" => Some(core::mem::align_of::()), + "L" => Some(core::mem::align_of::()), + "q" => Some(core::mem::align_of::()), + "Q" => Some(core::mem::align_of::()), + "f" => Some(core::mem::align_of::()), + "d" => Some(core::mem::align_of::()), + "g" => { + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + { + Some(core::mem::align_of::()) + } + #[cfg(not(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + )))] + { + Some(core::mem::align_of::()) + } + } + "z" | "Z" | "P" | "X" | "O" => Some(core::mem::align_of::()), + "void" => Some(0), + _ => None, + } +} + +pub fn c_long_bytes_endian(value: i128, swapped: bool) -> Vec { + let value = value as c_long; + int_to_sized_bytes_endian(value as i64, core::mem::size_of::(), swapped) +} + +pub fn c_ulong_bytes_endian(value: i128, swapped: bool) -> Vec { + let value = value as c_ulong; + uint_to_sized_bytes_endian(value as u64, core::mem::size_of::(), swapped) +} + +pub fn simple_type_pep3118_code(code: char) -> char { + match 
code { + 'i' if core::mem::size_of::() == 2 => 'h', + 'i' if core::mem::size_of::() == 4 => 'i', + 'i' if core::mem::size_of::() == 8 => 'q', + 'I' if core::mem::size_of::() == 2 => 'H', + 'I' if core::mem::size_of::() == 4 => 'I', + 'I' if core::mem::size_of::() == 8 => 'Q', + 'l' if core::mem::size_of::() == 4 => 'l', + 'l' if core::mem::size_of::() == 8 => 'q', + 'L' if core::mem::size_of::() == 4 => 'L', + 'L' if core::mem::size_of::() == 8 => 'Q', + '?' if core::mem::size_of::() == 1 => '?', + '?' if core::mem::size_of::() == 2 => 'H', + '?' if core::mem::size_of::() == 4 => 'L', + '?' if core::mem::size_of::() == 8 => 'Q', + _ => code, + } +} + +pub enum StringAtError { + NullPointer, + TooLong, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RawMemoryViewError { + NullPointer, + NegativeSize, +} + +#[derive(Debug, Clone, Copy)] +pub struct RawMemoryView { + ptr: usize, + size: usize, + readonly: bool, +} + +impl RawMemoryView { + pub fn new(ptr: usize, size: isize, readonly: bool) -> Result { + if ptr == 0 { + return Err(RawMemoryViewError::NullPointer); + } + if size < 0 { + return Err(RawMemoryViewError::NegativeSize); + } + Ok(Self { + ptr, + size: size as usize, + readonly, + }) + } + + pub fn size(self) -> usize { + self.size + } + + pub fn readonly(self) -> bool { + self.readonly + } + + /// # Safety + /// + /// The stored pointer must remain valid for `self.size` bytes. + pub unsafe fn bytes(self) -> &'static [u8] { + unsafe { borrow_memory(self.ptr as *const u8, self.size) } + } + + /// # Safety + /// + /// The stored pointer must remain valid and uniquely writable for + /// `self.size` bytes. + pub unsafe fn bytes_mut(self) -> &'static mut [u8] { + unsafe { borrow_memory_mut(self.ptr as *mut u8, self.size) } + } +} + +// These match the current RustPython _ctypes surface exactly. 
+pub const RTLD_LOCAL: i32 = 0; +pub const RTLD_GLOBAL: i32 = 0; +pub const SIZEOF_TIME_T: usize = core::mem::size_of::(); + +#[cfg(all(unix, not(target_os = "wasi")))] +pub fn dlopen_mode(load_flags: Option) -> i32 { + load_flags.unwrap_or(libc::RTLD_NOW | libc::RTLD_LOCAL) | libc::RTLD_NOW +} + +#[cfg(not(all(unix, not(target_os = "wasi"))))] +pub fn dlopen_mode(load_flags: Option) -> i32 { + load_flags.unwrap_or(0) +} + +#[cfg(target_os = "macos")] +pub fn dyld_shared_cache_contains_path(path: &str) -> Result { + let c_path = alloc::ffi::CString::new(path)?; + + unsafe extern "C" { + fn _dyld_shared_cache_contains_path(path: *const c_char) -> bool; + } + + Ok(unsafe { _dyld_shared_cache_contains_path(c_path.as_ptr()) }) +} + +/// # Safety +/// +/// `ptr` must be valid to read until the first NUL byte. +pub unsafe fn strlen(ptr: *const c_char) -> usize { + #[cfg(any(unix, windows, target_os = "wasi"))] + { + unsafe { libc::strlen(ptr) } + } + #[cfg(not(any(unix, windows, target_os = "wasi")))] + { + let mut len = 0; + while unsafe { *ptr.add(len) } != 0 { + len += 1; + } + len + } +} + +/// # Safety +/// +/// `ptr` must be valid to read until the first NUL wide character. +pub unsafe fn wcslen(ptr: *const WChar) -> usize { + let mut len = 0; + while unsafe { *ptr.add(len) } != 0 as WChar { + len += 1; + } + len +} + +/// # Safety +/// +/// `ptr` must be a valid NUL-terminated C string. 
+pub unsafe fn read_c_string_bytes(ptr: *const c_char) -> Vec { + unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec() +} + +#[inline] +pub fn read_pointer_from_buffer(buffer: &[u8]) -> usize { + const PTR_SIZE: usize = core::mem::size_of::(); + buffer + .first_chunk::() + .copied() + .map_or(0, usize::from_ne_bytes) +} + +pub const WCHAR_SIZE: usize = core::mem::size_of::(); + +#[inline] +pub fn wchar_from_bytes(bytes: &[u8]) -> Option { + if bytes.len() < WCHAR_SIZE { + return None; + } + Some(if WCHAR_SIZE == 2 { + u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 + } else { + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + }) +} + +#[inline] +pub fn wchar_to_bytes(ch: u32, buffer: &mut [u8]) { + if WCHAR_SIZE == 2 { + if buffer.len() >= 2 { + buffer[..2].copy_from_slice(&(ch as u16).to_ne_bytes()); + } + } else if buffer.len() >= 4 { + buffer[..4].copy_from_slice(&ch.to_ne_bytes()); + } +} + +pub fn wstring_from_bytes(buffer: &[u8]) -> String { + let mut chars = Vec::new(); + for chunk in buffer.chunks(WCHAR_SIZE) { + if chunk.len() < WCHAR_SIZE { + break; + } + let Some(code) = wchar_from_bytes(chunk) else { + break; + }; + if code == 0 { + break; + } + if let Some(ch) = char::from_u32(code) { + chars.push(ch); + } + } + chars.into_iter().collect() +} + +pub fn wchar_array_field_value(buffer: &[u8]) -> String { + let wchars: Vec = buffer + .chunks(WCHAR_SIZE) + .filter_map(|chunk| wchar_from_bytes(chunk).filter(|&wchar| wchar != 0)) + .map(|wchar| wchar as WChar) + .collect(); + wide_chars_to_wtf8(&wchars).to_string() +} + +pub fn write_wchar_array_value(buffer: &mut [u8], s: &Wtf8) -> Result<(), WCharArrayWriteError> { + let wchar_count = buffer.len() / WCHAR_SIZE; + let char_count = s.code_points().count(); + + if char_count > wchar_count { + return Err(WCharArrayWriteError::TooLong); + } + + for (i, ch) in s.code_points().enumerate() { + let offset = i * WCHAR_SIZE; + wchar_to_bytes(ch.to_u32(), &mut buffer[offset..]); + } + + let 
terminator_offset = char_count * WCHAR_SIZE; + if terminator_offset + WCHAR_SIZE <= buffer.len() { + wchar_to_bytes(0, &mut buffer[terminator_offset..]); + } + Ok(()) +} + +pub fn encode_wtf8_to_wchar_padded(s: &Wtf8, size: usize) -> Vec { + let mut wchar_bytes = Vec::with_capacity(size); + for cp in s.code_points().take(size / WCHAR_SIZE) { + let mut bytes = [0u8; 4]; + wchar_to_bytes(cp.to_u32(), &mut bytes); + wchar_bytes.extend_from_slice(&bytes[..WCHAR_SIZE]); + } + while wchar_bytes.len() < size { + wchar_bytes.push(0); + } + wchar_bytes +} + +pub fn wchar_null_terminated_bytes(s: &Wtf8) -> Vec { + let wchars: Vec = s + .code_points() + .map(|cp| cp.to_u32() as WChar) + .chain(core::iter::once(0)) + .collect(); + vec_into_bytes(wchars) +} + +pub fn vec_into_bytes(vec: Vec) -> Vec { + let len = vec.len() * core::mem::size_of::(); + let cap = vec.capacity() * core::mem::size_of::(); + let ptr = vec.as_ptr() as *mut u8; + core::mem::forget(vec); + unsafe { Vec::from_raw_parts(ptr, len, cap) } +} + +pub enum IntegerValue { + Signed(i64), + Unsigned(u64), +} + +pub enum AddressValue { + ByteString(u8), + Integer(IntegerValue), + Float(f64), + Pointer(usize), + Bytes(Vec), +} + +pub enum AddressWriteValue<'a> { + Pointer(usize), + U8(u8), + I16(i16), + I32(i32), + I64(i64), + Float(f64), + Bytes(&'a [u8]), +} + +pub enum ArrayElementWriteValue<'a> { + Byte(u8), + Wchar(u32), + Pointer { value: usize, size: usize }, + Float { value: f64, size: usize }, + Bytes { bytes: &'a [u8], size: usize }, +} + +pub enum WCharArrayWriteError { + TooLong, +} + +pub enum SimpleStorageValue { + Byte(u8), + Wchar(u32), + Signed(i128), + Float(f64), + Bool(bool), + Pointer(usize), + ObjectId(usize), + Zero, +} + +pub enum DecodedValue { + Bytes(Vec), + Signed(i64), + Unsigned(u64), + Float(f64), + Bool(bool), + Pointer(usize), + String(String), + None, +} + +pub enum CallbackResultValue { + Signed(i64), + Unsigned(u64), + Float(f64), + Pointer(usize), + Bool(bool), +} + +#[cfg(all( + 
any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub enum FfiArgRef<'a> { + U8(&'a u8), + I8(&'a i8), + U16(&'a u16), + I16(&'a i16), + U32(&'a u32), + I32(&'a i32), + U64(&'a u64), + I64(&'a i64), + F32(&'a f32), + F64(&'a f64), + Pointer(&'a usize), +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +#[derive(Debug, Clone, Copy)] +pub enum FfiValue { + U8(u8), + I8(i8), + U16(u16), + I16(i16), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + F32(f32), + F64(f64), + Pointer(usize), +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub enum CallResult { + Void, + Pointer(usize), + Value(low::ffi_arg), +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub enum CdeclArgValue { + Pointer(isize), + Int(isize), +} + +pub const POINTER_SIZE: usize = core::mem::size_of::(); +pub const POINTER_FORMAT: &str = "X{}"; + +pub fn pointer_size() -> usize { + POINTER_SIZE +} + +pub fn pointer_format() -> &'static str { + POINTER_FORMAT +} + +pub fn has_pointer_width(buffer: &[u8]) -> bool { + buffer.len() >= POINTER_SIZE +} + +pub fn pointer_bytes(value: usize) -> Vec { + pointer_to_sized_bytes(value, POINTER_SIZE) +} + +pub fn null_pointer_bytes() -> Vec { + vec![0; POINTER_SIZE] +} + +pub fn zeroed_bytes(size: usize) -> Vec { + vec![0; size] +} + +pub fn copy_to_sized_bytes(src: &[u8], size: usize) -> Vec { + let mut result = zeroed_bytes(size); + let len = src.len().min(size); + result[..len].copy_from_slice(&src[..len]); + result +} + +pub fn char_array_assignment_bytes(src: 
&[u8]) -> &[u8] { + if let Some(null_pos) = src.iter().position(|&b| b == 0) { + &src[..=null_pos] + } else { + src + } +} + +pub fn char_array_field_value(buffer: &[u8]) -> &[u8] { + let end = buffer.iter().position(|&b| b == 0).unwrap_or(buffer.len()); + &buffer[..end] +} + +pub fn write_char_array_value(buffer: &mut [u8], src: &[u8]) { + buffer[..src.len()].copy_from_slice(src); + if src.len() < buffer.len() { + buffer[src.len()] = 0; + } +} + +pub fn write_char_array_raw(buffer: &mut [u8], src: &[u8]) { + buffer[..src.len()].copy_from_slice(src); +} + +pub fn write_prefix_limited(buffer: &mut [u8], src: &[u8], size: usize) { + let copy_size = size.min(buffer.len()).min(src.len()); + if copy_size > 0 { + buffer[..copy_size].copy_from_slice(&src[..copy_size]); + } +} + +pub fn pointer_to_sized_bytes_endian(value: usize, size: usize, swapped: bool) -> Vec { + let mut bytes = pointer_to_sized_bytes(value, size); + if swapped { + bytes.reverse(); + } + bytes +} + +pub fn write_pointer_to_buffer_at(buffer: &mut [u8], offset: usize, size: usize, value: usize) { + if offset + size <= buffer.len() { + let ptr_bytes = pointer_to_sized_bytes(value, size); + buffer[offset..offset + size].copy_from_slice(&ptr_bytes); + } +} + +pub fn write_array_element(buffer: &mut [u8], offset: usize, value: ArrayElementWriteValue<'_>) { + match value { + ArrayElementWriteValue::Byte(value) => { + if offset < buffer.len() { + buffer[offset] = value; + } + } + ArrayElementWriteValue::Wchar(value) => { + if offset + WCHAR_SIZE <= buffer.len() { + wchar_to_bytes(value, &mut buffer[offset..]); + } + } + ArrayElementWriteValue::Pointer { value, size } => { + write_pointer_to_buffer_at(buffer, offset, size, value); + } + ArrayElementWriteValue::Float { value, size } => { + if offset + size <= buffer.len() + && let Some(float_bytes) = float_to_sized_bytes(value, size) + { + buffer[offset..offset + size].copy_from_slice(&float_bytes); + } + } + ArrayElementWriteValue::Bytes { bytes, size } => { + 
let copy_len = bytes.len().min(size); + if offset + copy_len <= buffer.len() { + buffer[offset..offset + copy_len].copy_from_slice(&bytes[..copy_len]); + } + } + } +} + +pub fn read_array_element( + buffer: &[u8], + offset: usize, + element_size: usize, + type_code: Option<&str>, +) -> DecodedValue { + let Some(rest) = buffer.get(offset..) else { + return DecodedValue::Signed(0); + }; + match type_code { + Some("c") => DecodedValue::Bytes(vec![buffer.get(offset).copied().unwrap_or(0)]), + Some("u") => { + let value = wchar_from_bytes(rest) + .and_then(char::from_u32) + .map(|c| c.to_string()) + .unwrap_or_default(); + DecodedValue::String(value) + } + Some("z") => { + if offset + element_size > buffer.len() { + return DecodedValue::None; + } + let ptr_bytes = &buffer[offset..offset + element_size]; + let ptr_val = read_pointer_from_buffer(ptr_bytes); + unsafe { + match read_c_string_from_address(ptr_val) { + Some(bytes) => DecodedValue::Bytes(bytes), + None => DecodedValue::None, + } + } + } + Some("Z") => { + if offset + element_size > buffer.len() { + return DecodedValue::None; + } + let ptr_bytes = &buffer[offset..offset + element_size]; + let ptr_val = read_pointer_from_buffer(ptr_bytes); + unsafe { + match read_wide_string_from_address(ptr_val) { + Some(s) => DecodedValue::String(s.to_string()), + None => DecodedValue::None, + } + } + } + Some("f") => DecodedValue::Float( + rest.first_chunk::<4>() + .copied() + .map_or(0.0, f32::from_ne_bytes) as f64, + ), + Some("d" | "g") => DecodedValue::Float( + rest.first_chunk::<8>() + .copied() + .map_or(0.0, f64::from_ne_bytes), + ), + _ => { + if let Some(bytes) = rest.get(..element_size) { + let is_unsigned = matches!(type_code, Some("B" | "H" | "I" | "L" | "Q")); + match int_from_bytes(bytes, element_size, is_unsigned) { + IntegerValue::Signed(value) => DecodedValue::Signed(value), + IntegerValue::Unsigned(value) => DecodedValue::Unsigned(value), + } + } else { + DecodedValue::Signed(0) + } + } + } +} + +pub fn 
int_from_bytes(bytes: &[u8], size: usize, unsigned: bool) -> IntegerValue { + match (size, unsigned) { + (1, false) => IntegerValue::Signed(bytes[0] as i8 as i64), + (1, true) => IntegerValue::Unsigned(bytes[0].into()), + (2, false) => IntegerValue::Signed(i16::from_ne_bytes([bytes[0], bytes[1]]).into()), + (2, true) => IntegerValue::Unsigned(u16::from_ne_bytes([bytes[0], bytes[1]]).into()), + (4, false) => IntegerValue::Signed( + i32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).into(), + ), + (4, true) => IntegerValue::Unsigned( + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).into(), + ), + (8, false) => IntegerValue::Signed(i64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ])), + (8, true) => IntegerValue::Unsigned(u64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ])), + _ => IntegerValue::Signed(0), + } +} + +pub fn int_to_sized_bytes(value: i64, size: usize) -> Vec { + match size { + 1 => (value as i8).to_ne_bytes().to_vec(), + 2 => (value as i16).to_ne_bytes().to_vec(), + 4 => (value as i32).to_ne_bytes().to_vec(), + 8 => value.to_ne_bytes().to_vec(), + _ => vec![0u8; size], + } +} + +pub fn uint_to_sized_bytes(value: u64, size: usize) -> Vec { + match size { + 1 => (value as u8).to_ne_bytes().to_vec(), + 2 => (value as u16).to_ne_bytes().to_vec(), + 4 => (value as u32).to_ne_bytes().to_vec(), + 8 => value.to_ne_bytes().to_vec(), + _ => vec![0u8; size], + } +} + +pub fn int_to_sized_bytes_endian(value: i64, size: usize, swapped: bool) -> Vec { + if swapped { + #[cfg(target_endian = "little")] + { + match size { + 1 => (value as i8).to_ne_bytes().to_vec(), + 2 => (value as i16).to_be_bytes().to_vec(), + 4 => (value as i32).to_be_bytes().to_vec(), + 8 => value.to_be_bytes().to_vec(), + _ => vec![0u8; size], + } + } + #[cfg(target_endian = "big")] + { + match size { + 1 => (value as i8).to_ne_bytes().to_vec(), + 2 => (value as 
i16).to_le_bytes().to_vec(), + 4 => (value as i32).to_le_bytes().to_vec(), + 8 => value.to_le_bytes().to_vec(), + _ => vec![0u8; size], + } + } + } else { + int_to_sized_bytes(value, size) + } +} + +pub fn uint_to_sized_bytes_endian(value: u64, size: usize, swapped: bool) -> Vec { + if swapped { + #[cfg(target_endian = "little")] + { + match size { + 1 => (value as u8).to_ne_bytes().to_vec(), + 2 => (value as u16).to_be_bytes().to_vec(), + 4 => (value as u32).to_be_bytes().to_vec(), + 8 => value.to_be_bytes().to_vec(), + _ => vec![0u8; size], + } + } + #[cfg(target_endian = "big")] + { + match size { + 1 => (value as u8).to_ne_bytes().to_vec(), + 2 => (value as u16).to_le_bytes().to_vec(), + 4 => (value as u32).to_le_bytes().to_vec(), + 8 => value.to_le_bytes().to_vec(), + _ => vec![0u8; size], + } + } + } else { + uint_to_sized_bytes(value, size) + } +} + +pub fn float_to_sized_bytes(value: f64, size: usize) -> Option> { + match size { + 4 => Some((value as f32).to_ne_bytes().to_vec()), + 8 => Some(value.to_ne_bytes().to_vec()), + _ => None, + } +} + +pub fn float_to_sized_bytes_endian(value: f64, size: usize, swapped: bool) -> Option> { + if swapped { + #[cfg(target_endian = "little")] + { + match size { + 4 => Some((value as f32).to_be_bytes().to_vec()), + 8 => Some(value.to_be_bytes().to_vec()), + _ => None, + } + } + #[cfg(target_endian = "big")] + { + match size { + 4 => Some((value as f32).to_le_bytes().to_vec()), + 8 => Some(value.to_le_bytes().to_vec()), + _ => None, + } + } + } else { + float_to_sized_bytes(value, size) + } +} + +pub fn pointer_to_sized_bytes(value: usize, size: usize) -> Vec { + let mut result = vec![0u8; size]; + let bytes = value.to_ne_bytes(); + let len = core::cmp::min(bytes.len(), size); + result[..len].copy_from_slice(&bytes[..len]); + result +} + +pub fn wchar_code_to_bytes_endian(ch: u32, swapped: bool) -> Vec { + let mut buffer = vec![0u8; WCHAR_SIZE]; + wchar_to_bytes(ch, &mut buffer); + if swapped { + buffer.reverse(); + } + 
buffer +} + +pub fn simple_storage_value_to_bytes_endian( + type_code: &str, + value: SimpleStorageValue, + swapped: bool, +) -> Vec { + match type_code { + "c" => match value { + SimpleStorageValue::Byte(value) => vec![value], + _ => vec![0], + }, + "u" => match value { + SimpleStorageValue::Wchar(value) => wchar_code_to_bytes_endian(value, swapped), + _ => vec![0; WCHAR_SIZE], + }, + "b" => match value { + SimpleStorageValue::Signed(value) => vec![(value as i8) as u8], + _ => vec![0], + }, + "B" => match value { + SimpleStorageValue::Signed(value) => vec![value as u8], + _ => vec![0], + }, + "h" => match value { + SimpleStorageValue::Signed(value) => { + int_to_sized_bytes_endian((value as i16).into(), 2, swapped) + } + _ => vec![0; 2], + }, + "H" => match value { + SimpleStorageValue::Signed(value) => { + uint_to_sized_bytes_endian((value as u16).into(), 2, swapped) + } + _ => vec![0; 2], + }, + "i" => match value { + SimpleStorageValue::Signed(value) => { + int_to_sized_bytes_endian((value as i32).into(), 4, swapped) + } + _ => vec![0; 4], + }, + "I" => match value { + SimpleStorageValue::Signed(value) => { + uint_to_sized_bytes_endian((value as u32).into(), 4, swapped) + } + _ => vec![0; 4], + }, + "l" => match value { + SimpleStorageValue::Signed(value) => c_long_bytes_endian(value, swapped), + _ => vec![0; simple_type_size("l").expect("invalid ctypes simple type")], + }, + "L" => match value { + SimpleStorageValue::Signed(value) => c_ulong_bytes_endian(value, swapped), + _ => vec![0; simple_type_size("L").expect("invalid ctypes simple type")], + }, + "q" => match value { + SimpleStorageValue::Signed(value) => { + int_to_sized_bytes_endian(value as i64, 8, swapped) + } + _ => vec![0; 8], + }, + "Q" => match value { + SimpleStorageValue::Signed(value) => { + uint_to_sized_bytes_endian(value as u64, 8, swapped) + } + _ => vec![0; 8], + }, + "f" => match value { + SimpleStorageValue::Float(value) => { + float_to_sized_bytes_endian(value, 4, swapped).expect("f32 
size is fixed") + } + _ => vec![0; 4], + }, + "d" => match value { + SimpleStorageValue::Float(value) => { + float_to_sized_bytes_endian(value, 8, swapped).expect("f64 size is fixed") + } + _ => vec![0; 8], + }, + "g" => { + let value = match value { + SimpleStorageValue::Float(value) => value, + _ => 0.0, + }; + let mut result = + float_to_sized_bytes_endian(value, 8, swapped).expect("f64 size is fixed"); + result.resize( + simple_type_size("g").expect("invalid ctypes simple type"), + 0, + ); + result + } + "?" => match value { + SimpleStorageValue::Bool(value) => vec![if value { 1 } else { 0 }], + _ => vec![0], + }, + "v" => match value { + SimpleStorageValue::Bool(value) => { + let value: i16 = if value { -1 } else { 0 }; + int_to_sized_bytes_endian(value.into(), 2, swapped) + } + _ => vec![0; 2], + }, + "P" | "z" | "Z" => match value { + SimpleStorageValue::Pointer(value) => { + uint_to_sized_bytes_endian(value as u64, pointer_size(), swapped) + } + _ => null_pointer_bytes(), + }, + "O" => match value { + SimpleStorageValue::ObjectId(value) => { + uint_to_sized_bytes_endian(value as u64, pointer_size(), swapped) + } + _ => null_pointer_bytes(), + }, + _ => vec![0], + } +} + +pub fn utf16z_bytes(s: &Wtf8) -> Vec { + vec_into_bytes::(s.encode_wide().chain(core::iter::once(0)).collect()) +} + +pub fn null_terminated_bytes(bytes: &[u8]) -> Vec { + let mut buffer = bytes.to_vec(); + buffer.push(0); + buffer +} + +pub fn decode_type_code(type_code: &str, bytes: &[u8]) -> DecodedValue { + match type_code { + "c" => DecodedValue::Bytes(bytes.to_vec()), + "b" => DecodedValue::Signed(if !bytes.is_empty() { + bytes[0] as i8 as i64 + } else { + 0 + }), + "B" => DecodedValue::Unsigned(if !bytes.is_empty() { + bytes[0].into() + } else { + 0 + }), + "h" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Signed(if bytes.len() >= SIZE { + c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")).into() + } else { + 0 + }) + } + "H" => { + const 
SIZE: usize = core::mem::size_of::(); + DecodedValue::Unsigned(if bytes.len() >= SIZE { + c_ushort::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")).into() + } else { + 0 + }) + } + "i" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Signed(if bytes.len() >= SIZE { + c_int::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")).into() + } else { + 0 + }) + } + "I" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Unsigned(if bytes.len() >= SIZE { + c_uint::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")).into() + } else { + 0 + }) + } + "l" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Signed(if bytes.len() >= SIZE { + #[allow( + clippy::unnecessary_cast, + clippy::useless_conversion, + reason = "c_long width is platform-dependent" + )] + let val: i64 = + c_long::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) as i64; + val + } else { + 0 + }) + } + "L" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Unsigned(if bytes.len() >= SIZE { + #[allow( + clippy::unnecessary_cast, + clippy::useless_conversion, + reason = "c_ulong width is platform-dependent" + )] + let val: u64 = + c_ulong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) as u64; + val + } else { + 0 + }) + } + "q" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Signed(if bytes.len() >= SIZE { + c_longlong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }) + } + "Q" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Unsigned(if bytes.len() >= SIZE { + c_ulonglong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }) + } + "f" => { + const SIZE: usize = core::mem::size_of::(); + DecodedValue::Float(if bytes.len() >= SIZE { + c_float::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) as f64 + } else { + 0.0 + }) + } + "d" | "g" => { + const SIZE: usize = 
core::mem::size_of::(); + DecodedValue::Float(if bytes.len() >= SIZE { + c_double::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0.0 + }) + } + "?" => DecodedValue::Bool(!bytes.is_empty() && bytes[0] != 0), + "v" => { + const SIZE: usize = core::mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + DecodedValue::Bool(val != 0) + } + "z" => unsafe { + match read_c_string_from_address(read_pointer_from_buffer(bytes)) { + Some(bytes) => DecodedValue::Bytes(bytes), + None => DecodedValue::None, + } + }, + "Z" => unsafe { + match read_wide_string_from_address(read_pointer_from_buffer(bytes)) { + Some(s) => DecodedValue::String(s.to_string()), + None => DecodedValue::None, + } + }, + "P" => DecodedValue::Pointer(read_pointer_from_buffer(bytes)), + "u" => { + let val = if bytes.len() >= core::mem::size_of::() { + let wc = if core::mem::size_of::() == 2 { + u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 + } else { + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + }; + char::from_u32(wc).unwrap_or('\0') + } else { + '\0' + }; + DecodedValue::String(val.to_string()) + } + _ => DecodedValue::None, + } +} + +/// # Safety +/// +/// `ptr` must point to a valid callback argument storage for the given ctypes +/// `type_code`. 
+pub unsafe fn callback_arg_value(type_code: Option<&str>, ptr: *const c_void) -> DecodedValue { + match type_code { + Some("b") => DecodedValue::Signed(unsafe { *(ptr as *const i8) as i64 }), + Some("B") => DecodedValue::Unsigned(unsafe { *(ptr as *const u8) as u64 }), + Some("c") => DecodedValue::Bytes(vec![unsafe { *(ptr as *const u8) }]), + Some("h") => DecodedValue::Signed(unsafe { *(ptr as *const i16) as i64 }), + Some("H") => DecodedValue::Unsigned(unsafe { *(ptr as *const u16) as u64 }), + Some("i") => DecodedValue::Signed(unsafe { *(ptr as *const i32) as i64 }), + Some("I") => DecodedValue::Unsigned(unsafe { *(ptr as *const u32) as u64 }), + Some("l") => DecodedValue::Signed({ + #[allow( + clippy::unnecessary_cast, + clippy::useless_conversion, + reason = "c_long width is platform-dependent" + )] + let val: i64 = unsafe { *(ptr as *const c_long) as i64 }; + val + }), + Some("L") => DecodedValue::Unsigned({ + #[allow( + clippy::unnecessary_cast, + clippy::useless_conversion, + reason = "c_ulong width is platform-dependent" + )] + let val: u64 = unsafe { *(ptr as *const c_ulong) as u64 }; + val + }), + Some("q") => DecodedValue::Signed(unsafe { *(ptr as *const c_longlong) }), + Some("Q") => DecodedValue::Unsigned(unsafe { *(ptr as *const c_ulonglong) }), + Some("f") => DecodedValue::Float(unsafe { *(ptr as *const f32) as f64 }), + Some("d") => DecodedValue::Float(unsafe { *(ptr as *const f64) }), + Some("z") => { + let cstr_ptr = unsafe { *(ptr as *const *const c_char) }; + if cstr_ptr.is_null() { + DecodedValue::None + } else { + DecodedValue::Bytes(unsafe { read_c_string_bytes(cstr_ptr) }) + } + } + Some("Z") => { + let wstr_ptr = unsafe { *(ptr as *const *const WChar) }; + if wstr_ptr.is_null() { + DecodedValue::None + } else { + DecodedValue::String(unsafe { read_wide_string(wstr_ptr) }.to_string()) + } + } + Some("P") => DecodedValue::Pointer(unsafe { *(ptr as *const usize) }), + Some("?") => DecodedValue::Bool(unsafe { *(ptr as *const u8) != 0 }), + _ 
=> DecodedValue::None, + } +} + +/// # Safety +/// +/// `args` must point to a libffi callback argument array with a valid entry at +/// `index`, and that entry must be valid for the given ctypes `type_code`. +pub unsafe fn callback_arg_value_at( + type_code: Option<&str>, + args: *const *const c_void, + index: usize, +) -> DecodedValue { + let ptr = unsafe { *args.add(index) }; + unsafe { callback_arg_value(type_code, ptr) } +} + +/// # Safety +/// +/// `result` must point to valid callback result storage for the given ctypes +/// `type_code`. +pub unsafe fn write_callback_result( + type_code: Option<&str>, + result: *mut c_void, + value: CallbackResultValue, +) { + match (type_code, value) { + (Some("b"), CallbackResultValue::Signed(v)) => unsafe { *(result as *mut i8) = v as i8 }, + (Some("B" | "c"), CallbackResultValue::Unsigned(v)) => unsafe { + *(result as *mut u8) = v as u8 + }, + (Some("h"), CallbackResultValue::Signed(v)) => unsafe { *(result as *mut i16) = v as i16 }, + (Some("H"), CallbackResultValue::Unsigned(v)) => unsafe { + *(result as *mut u16) = v as u16 + }, + (Some("i"), CallbackResultValue::Signed(v)) => unsafe { + *(result as *mut CallbackIntResult) = v as i32 as CallbackIntResult + }, + (Some("I"), CallbackResultValue::Unsigned(v)) => unsafe { + *(result as *mut u32) = v as u32 + }, + (Some("l"), CallbackResultValue::Signed(v)) => unsafe { + *(result as *mut c_long) = v as c_long + }, + (Some("L"), CallbackResultValue::Unsigned(v)) => unsafe { + *(result as *mut c_ulong) = v as c_ulong + }, + (Some("q"), CallbackResultValue::Signed(v)) => unsafe { *(result as *mut i64) = v }, + (Some("Q"), CallbackResultValue::Unsigned(v)) => unsafe { *(result as *mut u64) = v }, + (Some("f"), CallbackResultValue::Float(v)) => unsafe { *(result as *mut f32) = v as f32 }, + (Some("d"), CallbackResultValue::Float(v)) => unsafe { *(result as *mut f64) = v }, + (Some("P" | "z" | "Z"), CallbackResultValue::Pointer(v)) => unsafe { + *(result as *mut usize) = v + }, 
+ (Some("?"), CallbackResultValue::Bool(v)) => unsafe { *(result as *mut u8) = u8::from(v) }, + _ => {} + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_value_from_type_code(type_code: &str, buffer: &[u8]) -> FfiValue { + match type_code { + "c" | "b" => FfiValue::I8(buffer.first().map_or(0, |&b| b as i8)), + "B" => FfiValue::U8(buffer.first().copied().unwrap_or(0)), + "h" => FfiValue::I16(buffer.first_chunk().copied().map_or(0, i16::from_ne_bytes)), + "H" => FfiValue::U16(buffer.first_chunk().copied().map_or(0, u16::from_ne_bytes)), + "i" => FfiValue::I32(buffer.first_chunk().copied().map_or(0, i32::from_ne_bytes)), + "I" => FfiValue::U32(buffer.first_chunk().copied().map_or(0, u32::from_ne_bytes)), + "l" | "q" => FfiValue::I64(if let Some(&bytes) = buffer.first_chunk::<8>() { + i64::from_ne_bytes(bytes) + } else if let Some(&bytes) = buffer.first_chunk::<4>() { + i32::from_ne_bytes(bytes).into() + } else { + 0 + }), + "L" | "Q" => FfiValue::U64(if let Some(&bytes) = buffer.first_chunk::<8>() { + u64::from_ne_bytes(bytes) + } else if let Some(&bytes) = buffer.first_chunk::<4>() { + u32::from_ne_bytes(bytes).into() + } else { + 0 + }), + "f" => FfiValue::F32( + buffer + .first_chunk::<4>() + .copied() + .map_or(0.0, f32::from_ne_bytes), + ), + "d" | "g" => FfiValue::F64( + buffer + .first_chunk::<8>() + .copied() + .map_or(0.0, f64::from_ne_bytes), + ), + "z" | "Z" | "P" | "O" => FfiValue::Pointer(read_pointer_from_buffer(buffer)), + "?" 
=> FfiValue::U8(if buffer.first().is_some_and(|&b| b != 0) { + 1 + } else { + 0 + }), + "u" => FfiValue::U32(buffer.first_chunk().copied().map_or(0, u32::from_ne_bytes)), + _ => FfiValue::Pointer(0), + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_value_from_type(buffer: &[u8], ty: Type) -> Option { + if core::ptr::eq(ty.as_raw_ptr(), Type::u8().as_raw_ptr()) { + Some(FfiValue::U8(*buffer.first()?)) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::i8().as_raw_ptr()) { + Some(FfiValue::I8(*buffer.first()? as i8)) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::u16().as_raw_ptr()) { + Some(FfiValue::U16(u16::from_ne_bytes( + *buffer.first_chunk::<2>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::i16().as_raw_ptr()) { + Some(FfiValue::I16(i16::from_ne_bytes( + *buffer.first_chunk::<2>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::u32().as_raw_ptr()) { + Some(FfiValue::U32(u32::from_ne_bytes( + *buffer.first_chunk::<4>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::i32().as_raw_ptr()) { + Some(FfiValue::I32(i32::from_ne_bytes( + *buffer.first_chunk::<4>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::u64().as_raw_ptr()) { + Some(FfiValue::U64(u64::from_ne_bytes( + *buffer.first_chunk::<8>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::i64().as_raw_ptr()) { + Some(FfiValue::I64(i64::from_ne_bytes( + *buffer.first_chunk::<8>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::f32().as_raw_ptr()) { + Some(FfiValue::F32(f32::from_ne_bytes( + *buffer.first_chunk::<4>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::f64().as_raw_ptr()) { + Some(FfiValue::F64(f64::from_ne_bytes( + *buffer.first_chunk::<8>()?, + ))) + } else if core::ptr::eq(ty.as_raw_ptr(), Type::pointer().as_raw_ptr()) { + Some(FfiValue::Pointer(read_pointer_from_buffer(buffer))) + } 
else { + None + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_type_from_code(ty: &str) -> Option { + match ty { + "c" => Some(Type::u8()), + "u" => Some(if core::mem::size_of::() == 2 { + Type::u16() + } else { + Type::u32() + }), + "b" => Some(Type::i8()), + "B" | "?" => Some(Type::u8()), + "h" | "v" => Some(Type::i16()), + "H" => Some(Type::u16()), + "i" => Some(Type::i32()), + "I" => Some(Type::u32()), + "l" => Some(if core::mem::size_of::() == 8 { + Type::i64() + } else { + Type::i32() + }), + "L" => Some(if core::mem::size_of::() == 8 { + Type::u64() + } else { + Type::u32() + }), + "q" => Some(Type::i64()), + "Q" => Some(Type::u64()), + "f" => Some(Type::f32()), + "d" | "g" => Some(Type::f64()), + "z" | "Z" | "P" | "X" | "O" => Some(Type::pointer()), + "void" => Some(Type::void()), + _ => None, + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_type_from_tag(tag: u8) -> Type { + match tag { + b'c' | b'b' => Type::i8(), + b'B' | b'?' 
=> Type::u8(), + b'h' | b'v' => Type::i16(), + b'H' => Type::u16(), + b'i' => Type::i32(), + b'I' => Type::u32(), + b'l' => { + if core::mem::size_of::() == 8 { + Type::i64() + } else { + Type::i32() + } + } + b'L' => { + if core::mem::size_of::() == 8 { + Type::u64() + } else { + Type::u32() + } + } + b'q' => Type::i64(), + b'Q' => Type::u64(), + b'f' => Type::f32(), + b'd' | b'g' => Type::f64(), + b'u' => { + if core::mem::size_of::() == 2 { + Type::u16() + } else { + Type::u32() + } + } + _ => Type::pointer(), + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_type_from_format(fmt: &str) -> Type { + match fmt.trim_start_matches(['<', '>', '!', '@', '=']) { + "b" => Type::i8(), + "B" => Type::u8(), + "h" => Type::i16(), + "H" => Type::u16(), + "i" | "l" => Type::i32(), + "I" | "L" => Type::u32(), + "q" => Type::i64(), + "Q" => Type::u64(), + "f" => Type::f32(), + "d" => Type::f64(), + "P" | "z" | "Z" | "O" => Type::pointer(), + _ => Type::u8(), + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_repeat_type(elem_type: Type, len: usize) -> Type { + Type::structure(core::iter::repeat_n(elem_type, len)) +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_byte_struct(size: usize) -> Type { + ffi_repeat_type(Type::u8(), size) +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_pointer_type() -> Type { + Type::pointer() +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = 
"windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_i32_type() -> Type { + Type::i32() +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_f64_type() -> Type { + Type::f64() +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_void_type() -> Type { + Type::void() +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_type_for_return_size(size: usize) -> Type { + if size <= 4 { + Type::i32() + } else if size <= 8 { + Type::i64() + } else { + Type::pointer() + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CTypeParamKind { + Structure, + Union, + Array, + Pointer, + Simple, +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_type_for_layout( + kind: CTypeParamKind, + ffi_field_types: &[Type], + size: usize, + length: usize, + format: Option<&str>, +) -> Type { + const MAX_FFI_STRUCT_SIZE: usize = 1024 * 1024; + + match kind { + CTypeParamKind::Structure | CTypeParamKind::Union => { + if !ffi_field_types.is_empty() { + Type::structure(ffi_field_types.iter().cloned()) + } else if size <= MAX_FFI_STRUCT_SIZE { + ffi_byte_struct(size) + } else { + ffi_pointer_type() + } + } + CTypeParamKind::Array => { + if size > MAX_FFI_STRUCT_SIZE || length > MAX_FFI_STRUCT_SIZE { + 
ffi_pointer_type() + } else if let Some(fmt) = format { + ffi_repeat_type(ffi_type_from_format(fmt), length) + } else { + ffi_byte_struct(size) + } + } + CTypeParamKind::Pointer => ffi_pointer_type(), + CTypeParamKind::Simple => { + if let Some(fmt) = format { + ffi_type_from_format(fmt) + } else { + Type::u8() + } + } + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn callproc( + code_ptr: CodePtr, + ffi_arg_types: Vec, + ffi_return_type: Type, + ffi_args: &[Arg<'_>], + restype_is_none: bool, + is_pointer_return: bool, +) -> CallResult { + let cif = Cif::new(ffi_arg_types, ffi_return_type); + if restype_is_none { + unsafe { cif.call::<()>(code_ptr, ffi_args) }; + CallResult::Void + } else if is_pointer_return { + CallResult::Pointer(unsafe { cif.call::(code_ptr, ffi_args) }) + } else { + CallResult::Value(unsafe { cif.call::(code_ptr, ffi_args) }) + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn call_cdecl_i32(code_ptr: usize, arg_types: Vec, arg_values: &[isize]) -> c_int { + let ffi_args: Vec<_> = arg_values.iter().map(Arg::new).collect(); + let cif = Cif::new(arg_types, Type::c_int()); + let code_ptr = CodePtr::from_ptr(code_ptr as *const _); + unsafe { cif.call(code_ptr, &ffi_args) } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn call_cdecl_i32_values(code_ptr: usize, args: &[CdeclArgValue]) -> c_int { + let mut arg_values = Vec::with_capacity(args.len()); + let mut arg_types = Vec::with_capacity(args.len()); + for arg in args { + match *arg { + CdeclArgValue::Pointer(value) => { + arg_values.push(value); + arg_types.push(Type::pointer()); 
+ } + CdeclArgValue::Int(value) => { + arg_values.push(value); + arg_types.push(Type::isize()); + } + } + } + call_cdecl_i32(code_ptr, arg_types, &arg_values) +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_arg(value: FfiArgRef<'_>) -> Arg<'_> { + match value { + FfiArgRef::U8(v) => Arg::new(v), + FfiArgRef::I8(v) => Arg::new(v), + FfiArgRef::U16(v) => Arg::new(v), + FfiArgRef::I16(v) => Arg::new(v), + FfiArgRef::U32(v) => Arg::new(v), + FfiArgRef::I32(v) => Arg::new(v), + FfiArgRef::U64(v) => Arg::new(v), + FfiArgRef::I64(v) => Arg::new(v), + FfiArgRef::F32(v) => Arg::new(v), + FfiArgRef::F64(v) => Arg::new(v), + FfiArgRef::Pointer(v) => Arg::new(v), + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn ffi_arg_from_value(value: &FfiValue) -> Arg<'_> { + match value { + FfiValue::U8(v) => ffi_arg(FfiArgRef::U8(v)), + FfiValue::I8(v) => ffi_arg(FfiArgRef::I8(v)), + FfiValue::U16(v) => ffi_arg(FfiArgRef::U16(v)), + FfiValue::I16(v) => ffi_arg(FfiArgRef::I16(v)), + FfiValue::U32(v) => ffi_arg(FfiArgRef::U32(v)), + FfiValue::I32(v) => ffi_arg(FfiArgRef::I32(v)), + FfiValue::U64(v) => ffi_arg(FfiArgRef::U64(v)), + FfiValue::I64(v) => ffi_arg(FfiArgRef::I64(v)), + FfiValue::F32(v) => ffi_arg(FfiArgRef::F32(v)), + FfiValue::F64(v) => ffi_arg(FfiArgRef::F64(v)), + FfiValue::Pointer(v) => ffi_arg(FfiArgRef::Pointer(v)), + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn code_ptr_from_addr(addr: usize) -> Option { + if addr == 0 { + None + } else { + Some(CodePtr(addr as *mut _)) + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + 
target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn null_code_ptr() -> CodePtr { + CodePtr(core::ptr::null_mut()) +} + +#[cfg(windows)] +pub enum ComMethodError { + NullComPointer, + NullVtablePointer, + NullFunctionPointer, +} + +#[cfg(windows)] +pub const HRESULT_E_POINTER: i32 = crate::windows::HRESULT_E_POINTER; + +#[cfg(windows)] +pub const HRESULT_S_OK: i32 = crate::windows::HRESULT_S_OK; + +#[cfg(windows)] +pub fn format_error_message(code: Option) -> Option { + crate::windows::format_error_message(code) +} + +#[cfg(windows)] +pub fn resolve_com_vtable_entry(com_ptr: usize, idx: usize) -> Result { + if com_ptr == 0 { + return Err(ComMethodError::NullComPointer); + } + let vtable_ptr = unsafe { *(com_ptr as *const usize) }; + if vtable_ptr == 0 { + return Err(ComMethodError::NullVtablePointer); + } + let fptr = unsafe { + let vtable = vtable_ptr as *const usize; + *vtable.add(idx) + }; + if fptr == 0 { + return Err(ComMethodError::NullFunctionPointer); + } + Ok(CodePtr(fptr as *mut _)) +} + +#[cfg(windows)] +pub fn copy_com_pointer(src_ptr: usize, dst_addr: usize) -> i32 { + if dst_addr == 0 { + return HRESULT_E_POINTER; + } + + if src_ptr != 0 { + unsafe { + let iunknown = src_ptr as *mut *const usize; + let vtable = *iunknown; + if vtable.is_null() { + return HRESULT_E_POINTER; + } + let addref_fn: extern "system" fn(*mut c_void) -> u32 = + core::mem::transmute(*vtable.add(1)); + addref_fn(src_ptr as *mut c_void); + } + } + + unsafe { + *(dst_addr as *mut usize) = src_ptr; + } + + HRESULT_S_OK +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub struct CallbackThunk { + #[allow(dead_code)] + closure: Closure<'static>, + userdata_ptr: *mut U, + code_ptr: CodePtr, +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + 
target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +impl CallbackThunk { + pub fn new( + ffi_arg_types: Vec, + ffi_res_type: Type, + userdata: Box, + callback: unsafe extern "C" fn(&low::ffi_cif, &mut c_void, *const *const c_void, &U), + ) -> Self { + let cif = Cif::new(ffi_arg_types, ffi_res_type); + let userdata_ptr = Box::into_raw(userdata); + let userdata_ref: &'static U = unsafe { &*userdata_ptr }; + let closure = Closure::new(cif, callback, userdata_ref); + let code_ptr = CodePtr(*closure.code_ptr() as *mut _); + Self { + closure, + userdata_ptr, + code_ptr, + } + } + + pub fn code_ptr(&self) -> CodePtr { + self.code_ptr + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +impl Drop for CallbackThunk { + fn drop(&mut self) { + unsafe { + drop(Box::from_raw(self.userdata_ptr)); + } + } +} + +#[cfg(all( + any( + target_os = "linux", + target_os = "macos", + target_os = "windows", + target_os = "android" + ), + not(any(target_env = "musl", target_env = "sgx")) +))] +pub fn call_result_bytes(raw_result: &CallResult) -> Option<(Vec, usize)> { + match raw_result { + CallResult::Void => None, + CallResult::Pointer(ptr) => { + let bytes = ptr.to_ne_bytes(); + Some((bytes.to_vec(), core::mem::size_of::())) + } + CallResult::Value(val) => { + let bytes = val.to_ne_bytes(); + Some((bytes.to_vec(), core::mem::size_of_val(val))) + } + } +} + +/// # Safety +/// +/// `ptr` must point to `len` readable bytes. +pub unsafe fn bytes_at(ptr: *const u8, len: usize) -> Vec { + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() +} + +/// # Safety +/// +/// The caller must ensure `ptr..ptr+size` remains valid for the lifetime of the returned slice. 
+pub unsafe fn borrow_memory(ptr: *const u8, size: usize) -> &'static [u8] { + unsafe { core::slice::from_raw_parts(ptr, size) } +} + +/// # Safety +/// +/// The caller must ensure `ptr..ptr+size` remains valid and uniquely borrowed for the lifetime of the returned slice. +pub unsafe fn borrow_memory_mut(ptr: *mut u8, size: usize) -> &'static mut [u8] { + unsafe { core::slice::from_raw_parts_mut(ptr, size) } +} + +/// # Safety +/// +/// `slice` must point to memory that is valid and writable for its full length. +#[allow( + clippy::mut_from_ref, + reason = "ctypes borrowed buffers may wrap writable memory behind a shared slice" +)] +pub unsafe fn borrowed_slice_as_mut(slice: &[u8]) -> &mut [u8] { + unsafe { core::slice::from_raw_parts_mut(slice.as_ptr() as *mut u8, slice.len()) } +} + +pub fn wide_chars_to_wtf8(wchars: &[WChar]) -> Wtf8Buf { + #[cfg(windows)] + { + let wide: Vec = wchars.to_vec(); + Wtf8Buf::from_wide(&wide) + } + #[cfg(not(windows))] + { + #[allow( + clippy::useless_conversion, + reason = "wchar_t is i32 on some platforms and u32 on others" + )] + let s: String = wchars + .iter() + .filter_map(|&c| u32::try_from(c).ok().and_then(char::from_u32)) + .collect(); + Wtf8Buf::from_string(s) + } +} + +/// # Safety +/// +/// `ptr` must be a valid NUL-terminated wide C string. +pub unsafe fn read_wide_string(ptr: *const WChar) -> Wtf8Buf { + let len = unsafe { wcslen(ptr) }; + let wchars = unsafe { core::slice::from_raw_parts(ptr, len) }; + wide_chars_to_wtf8(wchars) +} + +/// # Safety +/// +/// `addr` must either be zero or a valid NUL-terminated C string pointer. +pub unsafe fn read_c_string_from_address(addr: usize) -> Option> { + if addr == 0 { + None + } else { + Some(unsafe { read_c_string_bytes(addr as *const c_char) }) + } +} + +/// # Safety +/// +/// `addr` must either be zero or a valid NUL-terminated wide C string pointer. 
+pub unsafe fn read_wide_string_from_address(addr: usize) -> Option { + if addr == 0 { + None + } else { + Some(unsafe { read_wide_string(addr as *const WChar) }) + } +} + +/// # Safety +/// +/// `ptr` must point to `len` readable wide characters. +pub unsafe fn read_wide_string_with_len(ptr: *const WChar, len: usize) -> Wtf8Buf { + let wchars = unsafe { core::slice::from_raw_parts(ptr, len) }; + wide_chars_to_wtf8(wchars) +} + +pub fn string_at(ptr: usize, size: isize) -> Result, StringAtError> { + if ptr == 0 { + return Err(StringAtError::NullPointer); + } + if size < 0 { + // SAFETY: caller passed a non-null C string pointer; same precondition as previous VM path. + return Ok(unsafe { read_c_string_bytes(ptr as _) }); + } + let len = { + let size_usize = size as usize; + if size_usize > isize::MAX as usize / 2 { + return Err(StringAtError::TooLong); + } + size_usize + }; + // SAFETY: caller requested exactly `len` readable bytes from non-null pointer. + Ok(unsafe { bytes_at(ptr as *const u8, len) }) +} + +pub fn wstring_at(ptr: usize, size: isize) -> Result { + if ptr == 0 { + return Err(StringAtError::NullPointer); + } + let w_ptr = ptr as *const WChar; + if size < 0 { + // SAFETY: caller passed a non-null NUL-terminated wide string pointer. + return Ok(unsafe { read_wide_string(w_ptr) }); + } + let len = { + let size_usize = size as usize; + if size_usize > isize::MAX as usize / core::mem::size_of::() { + return Err(StringAtError::TooLong); + } + size_usize + }; + // SAFETY: caller requested exactly `len` readable wide characters from non-null pointer. + Ok(unsafe { read_wide_string_with_len(w_ptr, len) }) +} + +/// # Safety +/// +/// `start` must be valid to read `len` elements following `step`. 
+pub unsafe fn read_bytes_strided(start: *const u8, len: usize, step: isize) -> Vec { + if step == 1 { + return unsafe { bytes_at(start, len) }; + } + let mut result = Vec::with_capacity(len); + let mut cur = start; + for _ in 0..len { + result.push(unsafe { *cur }); + cur = unsafe { cur.offset(step) }; + } + result +} + +pub fn pointer_item_address(ptr_value: usize, index: isize, element_size: usize) -> usize { + let offset = index * element_size as isize; + (ptr_value as isize + offset) as usize +} + +pub fn offset_address(base: usize, offset: isize) -> usize { + (base as isize + offset) as usize +} + +/// # Safety +/// +/// `ptr_value + start * element_size` must be valid to read `len` bytes following +/// `step * element_size`. +pub unsafe fn read_pointer_char_slice( + ptr_value: usize, + start: isize, + len: usize, + step: isize, + element_size: usize, +) -> Vec { + let start_addr = pointer_item_address(ptr_value, start, element_size) as *const u8; + if step == 1 { + unsafe { bytes_at(start_addr, len) } + } else { + unsafe { read_bytes_strided(start_addr, len, step * element_size as isize) } + } +} + +/// # Safety +/// +/// `start` must be valid to read `len` wide characters following `step`. +pub unsafe fn read_wide_string_strided(start: *const WChar, len: usize, step: isize) -> Wtf8Buf { + if step == 1 { + return unsafe { read_wide_string_with_len(start, len) }; + } + let mut wchars = Vec::with_capacity(len); + let mut cur = start; + for _ in 0..len { + wchars.push(unsafe { *cur }); + cur = unsafe { cur.offset(step) }; + } + wide_chars_to_wtf8(&wchars) +} + +/// # Safety +/// +/// `ptr_value + start * sizeof(wchar_t)` must be valid to read `len` wide +/// characters following `step`. 
+pub unsafe fn read_pointer_wchar_slice( + ptr_value: usize, + start: isize, + len: usize, + step: isize, +) -> Wtf8Buf { + let wchar_size = core::mem::size_of::(); + let start_addr = (ptr_value as isize + start * wchar_size as isize) as *const WChar; + unsafe { read_wide_string_strided(start_addr, len, step) } +} + +/// # Safety +/// +/// `addr` must be readable for `size` bytes and match the alignment/validity +/// requirements implied by `type_code`. +pub unsafe fn read_value_at_address( + addr: usize, + size: usize, + type_code: Option<&str>, +) -> AddressValue { + let ptr = addr as *const u8; + match type_code { + Some("c") => AddressValue::ByteString(unsafe { *ptr }), + Some("b") => AddressValue::Integer(IntegerValue::Signed(unsafe { *ptr as i8 as i64 })), + Some("B") => AddressValue::Integer(IntegerValue::Unsigned(unsafe { (*ptr).into() })), + Some("h") => AddressValue::Integer(IntegerValue::Signed( + unsafe { core::ptr::read_unaligned(ptr as *const i16) }.into(), + )), + Some("H") => AddressValue::Integer(IntegerValue::Unsigned( + unsafe { core::ptr::read_unaligned(ptr as *const u16) }.into(), + )), + Some("i") => AddressValue::Integer(IntegerValue::Signed( + unsafe { core::ptr::read_unaligned(ptr as *const i32) }.into(), + )), + Some("I") => AddressValue::Integer(IntegerValue::Unsigned( + unsafe { core::ptr::read_unaligned(ptr as *const u32) }.into(), + )), + Some("l") => AddressValue::Integer(IntegerValue::Signed(unsafe { + core::ptr::read_unaligned(ptr as *const c_long) + } as i64)), + Some("L") => AddressValue::Integer(IntegerValue::Unsigned(unsafe { + core::ptr::read_unaligned(ptr as *const c_ulong) + } as u64)), + Some("q") => AddressValue::Integer(IntegerValue::Signed(unsafe { + core::ptr::read_unaligned(ptr as *const i64) + })), + Some("Q") => AddressValue::Integer(IntegerValue::Unsigned(unsafe { + core::ptr::read_unaligned(ptr as *const u64) + })), + Some("f") => { + AddressValue::Float(unsafe { core::ptr::read_unaligned(ptr as *const f32) as f64 
}) + } + Some("d" | "g") => { + AddressValue::Float(unsafe { core::ptr::read_unaligned(ptr as *const f64) }) + } + Some("P" | "z" | "Z") => { + AddressValue::Pointer(unsafe { core::ptr::read_unaligned(ptr as *const usize) }) + } + _ => AddressValue::Bytes(unsafe { bytes_at(ptr, size) }), + } +} + +/// # Safety +/// +/// `addr` must be valid to write one `u8`. +pub unsafe fn write_u8_at_address(addr: usize, value: u8) { + unsafe { *(addr as *mut u8) = value }; +} + +/// # Safety +/// +/// `addr` must be valid to write one `i16`. +pub unsafe fn write_i16_at_address(addr: usize, value: i16) { + unsafe { core::ptr::write_unaligned(addr as *mut i16, value) }; +} + +/// # Safety +/// +/// `addr` must be valid to write one `i32`. +pub unsafe fn write_i32_at_address(addr: usize, value: i32) { + unsafe { core::ptr::write_unaligned(addr as *mut i32, value) }; +} + +/// # Safety +/// +/// `addr` must be valid to write one `i64`. +pub unsafe fn write_i64_at_address(addr: usize, value: i64) { + unsafe { core::ptr::write_unaligned(addr as *mut i64, value) }; +} + +/// # Safety +/// +/// `addr` must be valid to write one `usize`. +pub unsafe fn write_pointer_at_address(addr: usize, value: usize) { + unsafe { core::ptr::write_unaligned(addr as *mut usize, value) }; +} + +/// # Safety +/// +/// `addr` must be valid to write one `f32`. +pub unsafe fn write_f32_at_address(addr: usize, value: f32) { + unsafe { core::ptr::write_unaligned(addr as *mut f32, value) }; +} + +/// # Safety +/// +/// `addr` must be valid to write one `f64`. +pub unsafe fn write_f64_at_address(addr: usize, value: f64) { + unsafe { core::ptr::write_unaligned(addr as *mut f64, value) }; +} + +/// # Safety +/// +/// `addr` must be valid for writing the storage required by `value`. 
+pub unsafe fn write_value_to_address(addr: usize, size: usize, value: AddressWriteValue<'_>) { + match value { + AddressWriteValue::Pointer(value) => unsafe { write_pointer_at_address(addr, value) }, + AddressWriteValue::U8(value) => unsafe { write_u8_at_address(addr, value) }, + AddressWriteValue::I16(value) => unsafe { write_i16_at_address(addr, value) }, + AddressWriteValue::I32(value) => unsafe { write_i32_at_address(addr, value) }, + AddressWriteValue::I64(value) => unsafe { write_i64_at_address(addr, value) }, + AddressWriteValue::Float(value) => match size { + 4 => unsafe { write_f32_at_address(addr, value as f32) }, + 8 => unsafe { write_f64_at_address(addr, value) }, + _ => {} + }, + AddressWriteValue::Bytes(bytes) => unsafe { copy_bytes_to_address(addr, bytes, size) }, + } +} + +/// # Safety +/// +/// `addr` must be valid to write `min(bytes.len(), size)` bytes. +pub unsafe fn copy_bytes_to_address(addr: usize, bytes: &[u8], size: usize) { + let copy_len = bytes.len().min(size); + unsafe { core::ptr::copy_nonoverlapping(bytes.as_ptr(), addr as *mut u8, copy_len) }; +} + +pub fn write_simple_storage_buffer(buffer: &mut Cow<'_, [u8]>, bytes: &[u8]) { + match buffer { + Cow::Borrowed(slice) => { + // SAFETY: ctypes borrowed buffers are created only from writable Python buffers. + unsafe { + copy_bytes_to_address(slice.as_ptr() as usize, bytes, slice.len()); + } + } + Cow::Owned(vec) => { + vec.copy_from_slice(bytes); + } + } +} + +pub fn write_cow_bytes_at_offset(buffer: &mut Cow<'_, [u8]>, offset: usize, bytes: &[u8]) { + if offset + bytes.len() > buffer.len() { + return; + } + + match buffer { + Cow::Borrowed(slice) => { + // SAFETY: callers only construct borrowed ctypes buffers for writable memory. 
+ unsafe { + copy_bytes_to_address(slice.as_ptr() as usize + offset, bytes, bytes.len()); + } + } + Cow::Owned(vec) => { + vec[offset..offset + bytes.len()].copy_from_slice(bytes); + } + } +} + +pub fn resize_owned_bytes(old_data: &[u8], new_size: usize) -> Vec { + let mut new_data = vec![0u8; new_size]; + let copy_len = old_data.len().min(new_size); + new_data[..copy_len].copy_from_slice(&old_data[..copy_len]); + new_data +} + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub fn memmove_addr() -> usize { + libc::memmove as *const () as usize +} + +#[cfg(not(any(unix, windows, target_os = "wasi")))] +pub fn memmove_addr() -> usize { + 0 +} + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub fn memset_addr() -> usize { + libc::memset as *const () as usize +} + +#[cfg(not(any(unix, windows, target_os = "wasi")))] +pub fn memset_addr() -> usize { + 0 +} + +#[cfg(any(unix, windows))] +pub enum LookupSymbolError { + LibraryNotFound, + LibraryClosed, + Load(String), +} + +#[cfg(any(unix, windows))] +struct SharedLibrary { + lib: Mutex>, +} + +#[cfg(any(unix, windows))] +impl SharedLibrary { + #[cfg(windows)] + fn new(name: impl AsRef) -> Result { + Ok(Self { + lib: Mutex::new(unsafe { Some(Library::new(name.as_ref())?) 
}), + }) + } + + #[cfg(unix)] + fn new_with_mode(name: impl AsRef, mode: i32) -> Result { + Ok(Self { + lib: Mutex::new(Some(unsafe { + UnixLibrary::open(Some(name.as_ref()), mode)?.into() + })), + }) + } + + #[cfg(unix)] + fn from_raw_handle(handle: *mut c_void) -> Self { + Self { + lib: Mutex::new(Some(unsafe { UnixLibrary::from_raw(handle).into() })), + } + } + + fn get_pointer(&self) -> usize { + let lib_lock = self.lib.lock(); + if let Some(l) = &*lib_lock { + unsafe { core::mem::transmute_copy::(l) } + } else { + 0 + } + } + + fn lookup_data_symbol_addr(&self, symbol_name: &[u8]) -> Result { + let lib_lock = self.lib.lock(); + let Some(lib) = &*lib_lock else { + return Err(LookupSymbolError::LibraryClosed); + }; + let pointer = unsafe { + lib.get::<*const u8>(symbol_name) + .map_err(|err| LookupSymbolError::Load(err.to_string()))? + }; + Ok(*pointer as usize) + } + + fn lookup_function_symbol_addr(&self, symbol_name: &[u8]) -> Result { + let lib_lock = self.lib.lock(); + let Some(lib) = &*lib_lock else { + return Err(LookupSymbolError::LibraryClosed); + }; + let pointer = unsafe { + lib.get::(symbol_name) + .map_err(|err| LookupSymbolError::Load(err.to_string()))? 
+ }; + Ok(*pointer as *const () as usize) + } +} + +#[cfg(any(unix, windows))] +struct ExternalLibs { + libraries: HashMap, +} + +#[cfg(any(unix, windows))] +impl ExternalLibs { + fn new() -> Self { + Self { + libraries: HashMap::new(), + } + } + + fn get_lib(&self, key: usize) -> Option<&SharedLibrary> { + self.libraries.get(&key) + } + + #[cfg(windows)] + fn open_library( + &mut self, + library_path: impl AsRef, + ) -> Result { + let new_lib = SharedLibrary::new(library_path)?; + let key = new_lib.get_pointer(); + if self.libraries.contains_key(&key) { + drop(new_lib); + return Ok(key); + } + self.libraries.insert(key, new_lib); + Ok(key) + } + + #[cfg(unix)] + fn open_library_with_mode( + &mut self, + library_path: impl AsRef, + mode: i32, + ) -> Result { + let new_lib = SharedLibrary::new_with_mode(library_path, mode)?; + let key = new_lib.get_pointer(); + if self.libraries.contains_key(&key) { + drop(new_lib); + return Ok(key); + } + self.libraries.insert(key, new_lib); + Ok(key) + } + + #[cfg(unix)] + fn insert_raw_library_handle(&mut self, handle: *mut c_void) -> usize { + let key = handle as usize; + self.libraries + .insert(key, SharedLibrary::from_raw_handle(handle)); + key + } + + fn drop_library(&mut self, key: usize) { + self.libraries.remove(&key); + } +} + +#[cfg(any(unix, windows))] +fn libcache() -> &'static RwLock { + static LIBCACHE: OnceLock> = OnceLock::new(); + LIBCACHE.get_or_init(|| RwLock::new(ExternalLibs::new())) +} + +#[cfg(windows)] +pub fn open_library(name: impl AsRef) -> Result { + libcache().write().open_library(name) +} + +#[cfg(unix)] +pub fn open_library_with_mode( + name: impl AsRef, + mode: i32, +) -> Result { + libcache().write().open_library_with_mode(name, mode) +} + +#[cfg(not(unix))] +pub fn open_library_with_mode( + _name: impl AsRef, + _mode: i32, +) -> Result { + Err("dlopen() error".to_string()) +} + +#[cfg(unix)] +pub fn insert_raw_library_handle(handle: *mut c_void) -> usize { + 
libcache().write().insert_raw_library_handle(handle) +} + +#[cfg(not(unix))] +pub fn insert_raw_library_handle(_handle: *mut c_void) -> usize { + 0 +} + +#[cfg(any(unix, windows))] +pub fn drop_library(handle: usize) { + libcache().write().drop_library(handle); +} + +#[cfg(not(any(unix, windows)))] +pub fn drop_library(_handle: usize) {} + +#[cfg(any(unix, windows))] +pub fn lookup_data_symbol_addr( + handle: usize, + symbol_name: &[u8], +) -> Result { + let cache = libcache().read(); + cache + .get_lib(handle) + .ok_or(LookupSymbolError::LibraryNotFound)? + .lookup_data_symbol_addr(symbol_name) +} + +#[cfg(any(unix, windows))] +pub fn lookup_function_symbol_addr( + handle: usize, + symbol_name: &[u8], +) -> Result { + let cache = libcache().read(); + cache + .get_lib(handle) + .ok_or(LookupSymbolError::LibraryNotFound)? + .lookup_function_symbol_addr(symbol_name) +} + +#[cfg(all(unix, not(target_os = "wasi")))] +pub fn dlopen_self(mode: c_int) -> Result<*mut c_void, String> { + let handle = unsafe { libc::dlopen(core::ptr::null(), mode) }; + if handle.is_null() { + let err = unsafe { libc::dlerror() }; + Err(if err.is_null() { + "dlopen() error".to_string() + } else { + unsafe { CStr::from_ptr(err) } + .to_string_lossy() + .into_owned() + }) + } else { + Ok(handle) + } +} + +#[cfg(not(any(windows, all(unix, not(target_os = "wasi")))))] +pub fn dlopen_self(_mode: c_int) -> Result<*mut c_void, String> { + Err("dlopen() error".to_string()) +} + +#[cfg(all(unix, not(target_os = "wasi")))] +pub fn dlsym_checked(handle: usize, symbol_name: &CStr) -> Result<*mut c_void, String> { + unsafe { + libc::dlerror(); + } + + let ptr = unsafe { libc::dlsym(handle as *mut c_void, symbol_name.as_ptr()) }; + let err = unsafe { libc::dlerror() }; + if !err.is_null() { + return Err(unsafe { CStr::from_ptr(err) } + .to_string_lossy() + .into_owned()); + } + if ptr.is_null() { + return Err(format!( + "symbol '{}' not found", + symbol_name.to_string_lossy() + )); + } + Ok(ptr) +} + 
+#[cfg(not(any(windows, all(unix, not(target_os = "wasi")))))] +pub fn dlsym_checked(_handle: usize, symbol_name: &CStr) -> Result<*mut c_void, String> { + Err(format!( + "symbol '{}' not found", + symbol_name.to_string_lossy() + )) +} diff --git a/crates/host_env/src/errno.rs b/crates/host_env/src/errno.rs new file mode 100644 index 00000000000..e816a04a55f --- /dev/null +++ b/crates/host_env/src/errno.rs @@ -0,0 +1,56 @@ +// spell-checker:disable + +/// Return the platform `strerror(errno)` message as an owned `String`. +/// Returns `None` when the runtime gives no description for `errno`. +#[cfg(any(unix, windows))] +#[must_use] +pub fn strerror_string(errno: i32) -> Option { + let ptr = unsafe { libc::strerror(errno) }; + if ptr.is_null() { + return None; + } + let s = unsafe { core::ffi::CStr::from_ptr(ptr) }.to_string_lossy(); + Some(s.into_owned()) +} + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub mod errors { + pub use libc::*; + #[cfg(windows)] + pub use windows_sys::Win32::{ + Foundation::*, + Networking::WinSock::{ + WSABASEERR, WSADESCRIPTION_LEN, WSAEACCES, WSAEADDRINUSE, WSAEADDRNOTAVAIL, + WSAEAFNOSUPPORT, WSAEALREADY, WSAEBADF, WSAECANCELLED, WSAECONNABORTED, + WSAECONNREFUSED, WSAECONNRESET, WSAEDESTADDRREQ, WSAEDISCON, WSAEDQUOT, WSAEFAULT, + WSAEHOSTDOWN, WSAEHOSTUNREACH, WSAEINPROGRESS, WSAEINTR, WSAEINVAL, + WSAEINVALIDPROCTABLE, WSAEINVALIDPROVIDER, WSAEISCONN, WSAELOOP, WSAEMFILE, + WSAEMSGSIZE, WSAENAMETOOLONG, WSAENETDOWN, WSAENETRESET, WSAENETUNREACH, WSAENOBUFS, + WSAENOMORE, WSAENOPROTOOPT, WSAENOTCONN, WSAENOTEMPTY, WSAENOTSOCK, WSAEOPNOTSUPP, + WSAEPFNOSUPPORT, WSAEPROCLIM, WSAEPROTONOSUPPORT, WSAEPROTOTYPE, + WSAEPROVIDERFAILEDINIT, WSAEREFUSED, WSAEREMOTE, WSAESHUTDOWN, WSAESOCKTNOSUPPORT, + WSAESTALE, WSAETIMEDOUT, WSAETOOMANYREFS, WSAEUSERS, WSAEWOULDBLOCK, WSAID_ACCEPTEX, + WSAID_CONNECTEX, WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_TRANSMITFILE, + WSAID_TRANSMITPACKETS, WSAID_WSAPOLL, WSAID_WSARECVMSG, 
WSANO_DATA, WSANO_RECOVERY, + WSANOTINITIALISED, WSAPROTOCOL_LEN, WSASERVICE_NOT_FOUND, WSASYS_STATUS_LEN, + WSASYSCALLFAILURE, WSASYSNOTREADY, WSATRY_AGAIN, WSATYPE_NOT_FOUND, WSAVERNOTSUPPORTED, + }, + }; + #[cfg(windows)] + macro_rules! reexport_wsa { + ($($errname:ident),*$(,)?) => { + paste::paste! { + $(pub const $errname: i32 = windows_sys::Win32::Networking::WinSock:: [] as i32;)* + } + } + } + #[cfg(windows)] + reexport_wsa! { + EADDRINUSE, EADDRNOTAVAIL, EAFNOSUPPORT, EALREADY, ECONNABORTED, ECONNREFUSED, ECONNRESET, + EDESTADDRREQ, EDQUOT, EHOSTDOWN, EHOSTUNREACH, EINPROGRESS, EISCONN, ELOOP, EMSGSIZE, + ENETDOWN, ENETRESET, ENETUNREACH, ENOBUFS, ENOPROTOOPT, ENOTCONN, ENOTSOCK, EOPNOTSUPP, + EPFNOSUPPORT, EPROTONOSUPPORT, EPROTOTYPE, EREMOTE, ESHUTDOWN, ESOCKTNOSUPPORT, ESTALE, + ETIMEDOUT, ETOOMANYREFS, EUSERS, EWOULDBLOCK, + // TODO: EBADF should be here once winerrs are translated to errnos but it messes up some things atm + } +} diff --git a/crates/host_env/src/faulthandler.rs b/crates/host_env/src/faulthandler.rs new file mode 100644 index 00000000000..3afbdebb42b --- /dev/null +++ b/crates/host_env/src/faulthandler.rs @@ -0,0 +1,511 @@ +#![allow( + clippy::missing_safety_doc, + reason = "These wrappers expose low-level fault handler hooks with raw OS ABI semantics." +)] +#![allow( + clippy::result_unit_err, + reason = "These helpers preserve the existing fault-handler error surface." 
+)] +#![allow(static_mut_refs)] + +#[cfg(unix)] +use alloc::vec::Vec; +#[cfg(unix)] +use parking_lot::Mutex; +#[cfg(windows)] +use windows_sys::Win32::System::{ + Diagnostics::Debug::{ + AddVectoredExceptionHandler, EXCEPTION_POINTERS, PVECTORED_EXCEPTION_HANDLER, + RaiseException, RemoveVectoredExceptionHandler, SEM_NOGPFAULTERRORBOX, SetErrorMode, + }, + Threading::GetCurrentThreadId, +}; + +#[cfg(windows)] +pub type ExceptionPointers = EXCEPTION_POINTERS; + +#[cfg(unix)] +struct FatalSignalHandler { + signum: libc::c_int, + enabled: bool, + name: &'static str, + previous: libc::sigaction, +} + +#[cfg(windows)] +struct FatalSignalHandler { + signum: libc::c_int, + enabled: bool, + name: &'static str, + previous: libc::sighandler_t, +} + +#[cfg(unix)] +impl FatalSignalHandler { + const fn new(signum: libc::c_int, name: &'static str) -> Self { + Self { + signum, + enabled: false, + name, + previous: unsafe { core::mem::zeroed() }, + } + } +} + +#[cfg(windows)] +impl FatalSignalHandler { + const fn new(signum: libc::c_int, name: &'static str) -> Self { + Self { + signum, + enabled: false, + name, + previous: 0, + } + } +} + +#[cfg(unix)] +const FATAL_SIGNAL_COUNT: usize = 5; +#[cfg(windows)] +const FATAL_SIGNAL_COUNT: usize = 4; + +#[cfg(unix)] +static mut FATAL_SIGNAL_HANDLERS: [FatalSignalHandler; FATAL_SIGNAL_COUNT] = [ + FatalSignalHandler::new(libc::SIGBUS, "Bus error"), + FatalSignalHandler::new(libc::SIGILL, "Illegal instruction"), + FatalSignalHandler::new(libc::SIGFPE, "Floating-point exception"), + FatalSignalHandler::new(libc::SIGABRT, "Aborted"), + FatalSignalHandler::new(libc::SIGSEGV, "Segmentation fault"), +]; + +#[cfg(windows)] +static mut FATAL_SIGNAL_HANDLERS: [FatalSignalHandler; FATAL_SIGNAL_COUNT] = [ + FatalSignalHandler::new(libc::SIGILL, "Illegal instruction"), + FatalSignalHandler::new(libc::SIGFPE, "Floating-point exception"), + FatalSignalHandler::new(libc::SIGABRT, "Aborted"), + FatalSignalHandler::new(libc::SIGSEGV, "Segmentation 
fault"), +]; + +#[cfg(unix)] +const USER_SIGNAL_CAPACITY: usize = 64; + +#[cfg(unix)] +#[derive(Clone, Copy)] +pub struct UserSignal { + pub fd: i32, + pub all_threads: bool, + pub chain: bool, +} + +#[cfg(unix)] +#[derive(Clone, Copy)] +struct RegisteredUserSignal { + enabled: bool, + fd: i32, + all_threads: bool, + chain: bool, + previous: libc::sigaction, +} + +#[cfg(unix)] +impl Default for RegisteredUserSignal { + fn default() -> Self { + Self { + enabled: false, + fd: 2, + all_threads: true, + chain: false, + previous: unsafe { core::mem::zeroed() }, + } + } +} + +#[cfg(unix)] +static USER_SIGNALS: Mutex>> = Mutex::new(None); + +pub fn write_fd(fd: i32, buf: &[u8]) { + let _ = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len() as _) }; +} + +#[cfg(any(unix, windows))] +pub fn is_fatal_signal(signum: libc::c_int) -> bool { + unsafe { + FATAL_SIGNAL_HANDLERS + .iter() + .any(|handler| handler.signum == signum) + } +} + +#[cfg(any(unix, windows))] +pub fn fatal_signal_name(signum: libc::c_int) -> Option<&'static str> { + unsafe { + FATAL_SIGNAL_HANDLERS + .iter() + .find(|handler| handler.signum == signum) + .map(|handler| handler.name) + } +} + +#[cfg(any(unix, windows))] +pub fn abort_process() -> ! 
{ + unsafe { libc::abort() } +} + +#[cfg(any(unix, windows))] +pub fn raise_signal(signum: libc::c_int) { + unsafe { + libc::raise(signum); + } +} + +#[cfg(unix)] +#[inline] +pub fn current_thread_id() -> u64 { + unsafe { libc::pthread_self() as u64 } +} + +#[cfg(windows)] +#[inline] +pub fn current_thread_id() -> u64 { + unsafe { GetCurrentThreadId() as u64 } +} + +#[cfg(unix)] +pub fn install_sigaction( + signum: libc::c_int, + handler: extern "C" fn(libc::c_int), + flags: libc::c_int, + previous: &mut libc::sigaction, +) -> bool { + let mut action: libc::sigaction = unsafe { core::mem::zeroed() }; + action.sa_sigaction = handler as *const () as libc::sighandler_t; + action.sa_flags = flags; + unsafe { libc::sigaction(signum, &action, previous) == 0 } +} + +#[cfg(unix)] +unsafe fn disable_fatal_signal_handler(handler: &mut FatalSignalHandler) { + if !handler.enabled { + return; + } + handler.enabled = false; + restore_sigaction(handler.signum, &handler.previous); +} + +#[cfg(unix)] +pub fn enable_fatal_handlers(handler: extern "C" fn(libc::c_int), flags: libc::c_int) -> bool { + unsafe { + let mut installed = Vec::new(); + for entry in &mut FATAL_SIGNAL_HANDLERS { + if entry.enabled { + continue; + } + + if !install_sigaction(entry.signum, handler, flags, &mut entry.previous) { + for signum in installed { + disable_fatal_signal(signum); + } + return false; + } + entry.enabled = true; + installed.push(entry.signum); + } + } + true +} + +#[cfg(unix)] +pub fn disable_fatal_signal(signum: libc::c_int) { + unsafe { + if let Some(handler) = FATAL_SIGNAL_HANDLERS + .iter_mut() + .find(|handler| handler.signum == signum) + { + disable_fatal_signal_handler(handler); + } + } +} + +#[cfg(unix)] +pub fn disable_fatal_handlers() { + unsafe { + for handler in &mut FATAL_SIGNAL_HANDLERS { + disable_fatal_signal_handler(handler); + } + } +} + +#[cfg(unix)] +pub fn restore_sigaction(signum: libc::c_int, previous: &libc::sigaction) { + unsafe { + libc::sigaction(signum, previous, 
core::ptr::null_mut()); + } +} + +#[cfg(unix)] +pub fn signal_default_and_raise(signum: libc::c_int) { + unsafe { + libc::signal(signum, libc::SIG_DFL); + libc::raise(signum); + } +} + +#[cfg(unix)] +pub fn exit_immediately(code: libc::c_int) -> ! { + unsafe { libc::_exit(code) } +} + +#[cfg(unix)] +pub fn get_user_signal(signum: usize) -> Option { + let guard = USER_SIGNALS.lock(); + guard + .as_ref() + .and_then(|signals| signals.get(signum)) + .and_then(|signal| { + signal.enabled.then_some(UserSignal { + fd: signal.fd, + all_threads: signal.all_threads, + chain: signal.chain, + }) + }) +} + +#[cfg(unix)] +pub fn register_user_signal( + signum: libc::c_int, + fd: i32, + all_threads: bool, + chain: bool, + handler: extern "C" fn(libc::c_int), +) -> std::io::Result<()> { + if signum < 0 || signum as usize >= USER_SIGNAL_CAPACITY { + return Err(std::io::Error::from_raw_os_error(libc::EINVAL)); + } + let signum = signum as usize; + let mut guard = USER_SIGNALS.lock(); + if guard.is_none() { + *guard = Some(vec![RegisteredUserSignal::default(); USER_SIGNAL_CAPACITY]); + } + let signals = guard + .as_mut() + .expect("user signal table must be initialized"); + let entry = &mut signals[signum]; + + if !entry.enabled { + let mut previous = unsafe { core::mem::zeroed() }; + if !install_sigaction( + signum as libc::c_int, + handler, + if chain { + libc::SA_NODEFER + } else { + libc::SA_RESTART + }, + &mut previous, + ) { + return Err(std::io::Error::last_os_error()); + } + entry.previous = previous; + } + + entry.enabled = true; + entry.fd = fd; + entry.all_threads = all_threads; + entry.chain = chain; + Ok(()) +} + +#[cfg(unix)] +pub fn unregister_user_signal(signum: libc::c_int) -> bool { + if signum < 0 { + return false; + } + let signum = signum as usize; + let mut guard = USER_SIGNALS.lock(); + let Some(signals) = guard.as_mut() else { + return false; + }; + let Some(entry) = signals.get_mut(signum) else { + return false; + }; + if !entry.enabled { + return false; + } 
+ + let previous = entry.previous; + *entry = RegisteredUserSignal::default(); + restore_sigaction(signum as libc::c_int, &previous); + true +} + +#[cfg(unix)] +pub fn reraise_user_signal(signum: libc::c_int, handler: extern "C" fn(libc::c_int)) -> bool { + if signum < 0 { + return false; + } + let signum_usize = signum as usize; + let previous = { + let guard = USER_SIGNALS.lock(); + let Some(signals) = guard.as_ref() else { + return false; + }; + let Some(entry) = signals.get(signum_usize) else { + return false; + }; + if !entry.enabled || !entry.chain { + return false; + } + entry.previous + }; + + let saved_errno = crate::os::get_errno(); + restore_sigaction(signum, &previous); + crate::os::set_errno(saved_errno); + raise_signal(signum); + + let mut ignored_previous = unsafe { core::mem::zeroed() }; + let _ = install_sigaction(signum, handler, libc::SA_NODEFER, &mut ignored_previous); + + crate::os::set_errno(saved_errno); + true +} + +#[cfg(windows)] +pub fn install_signal_handler( + signum: libc::c_int, + handler: extern "C" fn(libc::c_int), +) -> Result { + let previous = unsafe { libc::signal(signum, handler as *const () as libc::sighandler_t) }; + if previous == libc::SIG_ERR as libc::sighandler_t { + Err(()) + } else { + Ok(previous) + } +} + +#[cfg(windows)] +unsafe fn disable_fatal_signal_handler(handler: &mut FatalSignalHandler) { + if !handler.enabled { + return; + } + handler.enabled = false; + restore_signal_handler(handler.signum, handler.previous); +} + +#[cfg(windows)] +pub fn enable_fatal_handlers(handler: extern "C" fn(libc::c_int), _flags: libc::c_int) -> bool { + unsafe { + for entry in &mut FATAL_SIGNAL_HANDLERS { + if entry.enabled { + continue; + } + + let Ok(previous) = install_signal_handler(entry.signum, handler) else { + return false; + }; + entry.previous = previous; + entry.enabled = true; + } + } + true +} + +#[cfg(windows)] +pub fn disable_fatal_signal(signum: libc::c_int) { + unsafe { + if let Some(handler) = FATAL_SIGNAL_HANDLERS 
+ .iter_mut() + .find(|handler| handler.signum == signum) + { + disable_fatal_signal_handler(handler); + } + } +} + +#[cfg(windows)] +pub fn disable_fatal_handlers() { + unsafe { + for handler in &mut FATAL_SIGNAL_HANDLERS { + disable_fatal_signal_handler(handler); + } + } +} + +#[cfg(windows)] +pub fn restore_signal_handler(signum: libc::c_int, previous: libc::sighandler_t) { + unsafe { + libc::signal(signum, previous); + } +} + +#[cfg(windows)] +pub fn signal_default_and_raise(signum: libc::c_int) { + unsafe { + libc::signal(signum, libc::SIG_DFL); + libc::raise(signum); + } +} + +#[cfg(windows)] +pub fn add_vectored_exception_handler(handler: PVECTORED_EXCEPTION_HANDLER) -> usize { + unsafe { AddVectoredExceptionHandler(1, handler) as usize } +} + +#[cfg(windows)] +pub fn remove_vectored_exception_handler(handle: usize) { + if handle != 0 { + unsafe { + RemoveVectoredExceptionHandler(handle as *mut core::ffi::c_void); + } + } +} + +#[cfg(windows)] +pub fn suppress_crash_report() { + unsafe { + let mode = SetErrorMode(SEM_NOGPFAULTERRORBOX); + SetErrorMode(mode | SEM_NOGPFAULTERRORBOX); + } +} + +#[cfg(windows)] +pub fn raise_exception(code: u32, flags: u32) { + unsafe { + RaiseException(code, flags, 0, core::ptr::null()); + } +} + +#[cfg(windows)] +pub fn ignore_exception(code: u32) -> bool { + if (code & 0x8000_0000) == 0 { + return true; + } + code == 0xE06D7363 || code == 0xE0434352 +} + +#[cfg(windows)] +pub fn exception_description(code: u32) -> Option<&'static str> { + match code { + 0xC0000005 => Some("access violation"), + 0xC000008C => Some("float divide by zero"), + 0xC0000091 => Some("float overflow"), + 0xC0000094 => Some("int divide by zero"), + 0xC0000095 => Some("integer overflow"), + 0xC0000006 => Some("page error"), + 0xC00000FD => Some("stack overflow"), + 0xC000001D => Some("illegal instruction"), + _ => None, + } +} + +#[cfg(windows)] +pub unsafe fn exception_code(exc_info: *mut EXCEPTION_POINTERS) -> u32 { + let record = unsafe { 
&*(*exc_info).ExceptionRecord }; + record.ExceptionCode as u32 +} + +#[cfg(windows)] +#[inline] +pub fn is_access_violation(code: u32) -> bool { + code == 0xC0000005 +} diff --git a/crates/host_env/src/fcntl.rs b/crates/host_env/src/fcntl.rs index b4dba53fa3d..2467a8727bc 100644 --- a/crates/host_env/src/fcntl.rs +++ b/crates/host_env/src/fcntl.rs @@ -1,21 +1,59 @@ use std::io; +#[cfg(unix)] +use std::os::fd::BorrowedFd; + +use crate::os::CheckLibcResult; + +pub fn normalize_ioctl_request(request: i64) -> libc::c_ulong { + (request as u32) as libc::c_ulong +} + pub fn fcntl_int(fd: i32, cmd: i32, arg: i32) -> io::Result { - let ret = unsafe { libc::fcntl(fd, cmd, arg) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) + unsafe { libc::fcntl(fd, cmd, arg) }.check_libc_neg() +} + +pub fn validate_fd(fd: i32) -> io::Result<()> { + fcntl_int(fd, libc::F_GETFD, 0).map(|_| ()) +} + +#[cfg(unix)] +pub fn get_inheritable(fd: BorrowedFd<'_>) -> io::Result { + use nix::fcntl as nix_fcntl; + + let flags = nix_fcntl::FdFlag::from_bits_truncate( + nix_fcntl::fcntl(fd, nix_fcntl::FcntlArg::F_GETFD).map_err(io::Error::from)?, + ); + Ok(!flags.contains(nix_fcntl::FdFlag::FD_CLOEXEC)) +} + +#[cfg(unix)] +pub fn get_blocking(fd: BorrowedFd<'_>) -> io::Result { + use nix::fcntl as nix_fcntl; + + let flags = nix_fcntl::OFlag::from_bits_truncate( + nix_fcntl::fcntl(fd, nix_fcntl::FcntlArg::F_GETFL).map_err(io::Error::from)?, + ); + Ok(!flags.contains(nix_fcntl::OFlag::O_NONBLOCK)) +} + +#[cfg(unix)] +pub fn set_blocking(fd: BorrowedFd<'_>, blocking: bool) -> io::Result<()> { + use nix::fcntl as nix_fcntl; + + let flags = nix_fcntl::OFlag::from_bits_truncate( + nix_fcntl::fcntl(fd, nix_fcntl::FcntlArg::F_GETFL).map_err(io::Error::from)?, + ); + let mut new_flags = flags; + new_flags.set(nix_fcntl::OFlag::O_NONBLOCK, !blocking); + if flags != new_flags { + nix_fcntl::fcntl(fd, nix_fcntl::FcntlArg::F_SETFL(new_flags)).map_err(io::Error::from)?; } + Ok(()) } pub fn 
fcntl_with_bytes(fd: i32, cmd: i32, arg: &mut [u8]) -> io::Result { - let ret = unsafe { libc::fcntl(fd, cmd, arg.as_mut_ptr()) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) - } + unsafe { libc::fcntl(fd, cmd, arg.as_mut_ptr()) }.check_libc_neg() } /// # Safety @@ -27,36 +65,56 @@ pub unsafe fn ioctl_ptr( request: libc::c_ulong, arg: *mut libc::c_void, ) -> io::Result { - let ret = unsafe { libc::ioctl(fd, request as _, arg) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) - } + unsafe { libc::ioctl(fd, request as _, arg) }.check_libc_neg() } pub fn ioctl_int(fd: i32, request: libc::c_ulong, arg: i32) -> io::Result { - let ret = unsafe { libc::ioctl(fd, request as _, arg) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) - } + unsafe { libc::ioctl(fd, request as _, arg) }.check_libc_neg() } #[cfg(not(any(target_os = "wasi", target_os = "redox")))] pub fn flock(fd: i32, operation: i32) -> io::Result { - let ret = unsafe { libc::flock(fd, operation) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) - } + unsafe { libc::flock(fd, operation) }.check_libc_neg() +} + +#[cfg(not(any(target_os = "wasi", target_os = "redox")))] +pub enum LockfError { + InvalidCmd, + Overflow(String), + Io(io::Error), } #[cfg(not(any(target_os = "wasi", target_os = "redox")))] -pub fn lockf(fd: i32, cmd: i32, lock: &libc::flock) -> io::Result { - let ret = unsafe { +pub fn lockf(fd: i32, cmd: i32, len: i64, start: i64, whence: i32) -> Result { + fn convert_field(value: T) -> Result + where + T: TryInto, + T::Error: core::fmt::Display, + { + value + .try_into() + .map_err(|err| LockfError::Overflow(err.to_string())) + } + + let l_type = if cmd == libc::LOCK_UN { + libc::F_UNLCK + } else if (cmd & libc::LOCK_SH) != 0 { + libc::F_RDLCK + } else if (cmd & libc::LOCK_EX) != 0 { + libc::F_WRLCK + } else { + return Err(LockfError::InvalidCmd); + }; + + let lock = libc::flock { + l_type: 
convert_field(l_type)?, + l_whence: convert_field(whence)?, + l_start: convert_field(start)?, + l_len: convert_field(len)?, + ..unsafe { core::mem::zeroed() } + }; + + unsafe { libc::fcntl( fd, if (cmd & libc::LOCK_NB) != 0 { @@ -64,12 +122,9 @@ pub fn lockf(fd: i32, cmd: i32, lock: &libc::flock) -> io::Result { } else { libc::F_SETLKW }, - lock, + &lock, ) - }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) } + .check_libc_neg() + .map_err(LockfError::Io) } diff --git a/crates/host_env/src/grp.rs b/crates/host_env/src/grp.rs new file mode 100644 index 00000000000..131369ce949 --- /dev/null +++ b/crates/host_env/src/grp.rs @@ -0,0 +1,53 @@ +use std::io; + +pub struct Group { + pub name: String, + pub passwd: String, + pub gid: u32, + pub mem: Vec, +} + +fn cstr_lossy(s: alloc::ffi::CString) -> String { + s.into_string() + .unwrap_or_else(|e| e.into_cstring().to_string_lossy().into_owned()) +} + +impl From for Group { + fn from(group: nix::unistd::Group) -> Self { + Self { + name: group.name, + passwd: cstr_lossy(group.passwd), + gid: group.gid.as_raw(), + mem: group.mem, + } + } +} + +pub fn getgrgid(gid: libc::gid_t) -> io::Result> { + nix::unistd::Group::from_gid(nix::unistd::Gid::from_raw(gid)) + .map(|group| group.map(Into::into)) + .map_err(io::Error::from) +} + +pub fn getgrnam(name: &str) -> io::Result> { + nix::unistd::Group::from_name(name) + .map(|group| group.map(Into::into)) + .map_err(io::Error::from) +} + +pub fn getgrall() -> Vec { + use core::ptr::NonNull; + + static GETGRALL: parking_lot::Mutex<()> = parking_lot::Mutex::new(()); + let _guard = GETGRALL.lock(); + let mut list = Vec::new(); + + unsafe { libc::setgrent() }; + while let Some(ptr) = NonNull::new(unsafe { libc::getgrent() }) { + let group = nix::unistd::Group::from(unsafe { ptr.as_ref() }); + list.push(group.into()); + } + unsafe { libc::endgrent() }; + + list +} diff --git a/crates/host_env/src/io.rs b/crates/host_env/src/io.rs new file mode 100644 index 
00000000000..4ae1e4b3641 --- /dev/null +++ b/crates/host_env/src/io.rs @@ -0,0 +1,264 @@ +#[cfg(any(unix, target_os = "wasi"))] +use core::ffi::CStr; +use std::io; + +#[cfg(any(unix, target_os = "wasi"))] +use crate::fileutils; +use crate::{crt_fd, os}; + +bitflags::bitflags! { + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + pub struct FileMode: u8 { + const CREATED = 0b0001; + const READABLE = 0b0010; + const WRITABLE = 0b0100; + const APPENDING = 0b1000; + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum FileModeError { + Invalid, + BadRwa, +} + +impl FileModeError { + pub fn error_msg(self, mode_str: &str) -> String { + match self { + Self::Invalid => format!("invalid mode: {mode_str}"), + Self::BadRwa => { + "Must have exactly one of create/read/write/append mode and at most one plus" + .to_owned() + } + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct ParsedFileMode { + pub mode: FileMode, + pub flags: i32, +} + +impl FileMode { + pub const fn raw_mode(self) -> &'static str { + if self.contains(Self::CREATED) { + if self.contains(Self::READABLE) { + "xb+" + } else { + "xb" + } + } else if self.contains(Self::APPENDING) { + if self.contains(Self::READABLE) { + "ab+" + } else { + "ab" + } + } else if self.contains(Self::READABLE) { + if self.contains(Self::WRITABLE) { + "rb+" + } else { + "rb" + } + } else { + "wb" + } + } +} + +pub fn parse_fileio_mode(mode_str: &str) -> Result { + let mut flags = 0; + let mut plus = false; + let mut rwa = false; + let mut mode = FileMode::empty(); + for c in mode_str.bytes() { + match c { + b'x' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::WRITABLE | FileMode::CREATED); + flags |= libc::O_EXCL | libc::O_CREAT; + } + b'r' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::READABLE); + } + b'w' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::WRITABLE); + flags |= 
libc::O_CREAT | libc::O_TRUNC; + } + b'a' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::WRITABLE | FileMode::APPENDING); + flags |= libc::O_APPEND | libc::O_CREAT; + } + b'+' => { + if plus { + return Err(FileModeError::BadRwa); + } + plus = true; + mode.insert(FileMode::READABLE | FileMode::WRITABLE); + } + b'b' => {} + _ => return Err(FileModeError::Invalid), + } + } + + if !rwa { + return Err(FileModeError::BadRwa); + } + + if mode.contains(FileMode::READABLE | FileMode::WRITABLE) { + flags |= libc::O_RDWR; + } else if mode.contains(FileMode::READABLE) { + flags |= libc::O_RDONLY; + } else { + flags |= libc::O_WRONLY; + } + + #[cfg(windows)] + { + flags |= libc::O_BINARY | libc::O_NOINHERIT; + } + #[cfg(unix)] + { + flags |= libc::O_CLOEXEC; + } + + Ok(ParsedFileMode { mode, flags }) +} + +#[derive(Clone, Copy, Debug)] +pub struct FileTargetInfo { + pub blksize: Option, +} + +#[cfg(any(unix, target_os = "wasi"))] +pub fn inspect_file_target(fd: crt_fd::Borrowed<'_>) -> io::Result { + let status = fileutils::fstat(fd)?; + if (status.st_mode & libc::S_IFMT) == libc::S_IFDIR { + return Err(io::Error::from_raw_os_error(libc::EISDIR)); + } + #[allow(clippy::useless_conversion, reason = "needed for 32-bit platforms")] + let blksize = (status.st_blksize > 1).then(|| i64::from(status.st_blksize)); + Ok(FileTargetInfo { blksize }) +} + +#[cfg(windows)] +pub fn inspect_file_target(fd: crt_fd::Borrowed<'_>) -> io::Result { + if !crate::nt::fd_exists(fd) { + return Err(io::Error::from_raw_os_error( + crate::nt::ERROR_INVALID_HANDLE_I32, + )); + } + Ok(FileTargetInfo { blksize: None }) +} + +#[cfg(any(unix, target_os = "wasi"))] +pub fn open_path(path: &CStr, flags: i32, mode: i32) -> io::Result { + crt_fd::open(path, flags, mode) +} + +#[cfg(windows)] +pub fn open_path(path: &widestring::WideCStr, flags: i32, mode: i32) -> io::Result { + crt_fd::wopen(path, flags, mode) +} + +#[cfg(windows)] +pub fn 
should_forget_fd_after_inspect_error(err: &io::Error, _fd_is_own: bool) -> bool { + err.raw_os_error() == Some(crate::nt::ERROR_INVALID_HANDLE_I32) +} + +#[cfg(any(unix, target_os = "wasi"))] +pub fn should_forget_fd_after_inspect_error(err: &io::Error, fd_is_own: bool) -> bool { + let errno = err.raw_os_error(); + (errno == Some(libc::EISDIR) || errno == Some(libc::EBADF)) + && (!fd_is_own || errno == Some(libc::EBADF)) +} + +pub fn seek_to_end(fd: crt_fd::Borrowed<'_>) -> io::Result { + os::seek_fd(fd, 0, libc::SEEK_END) +} + +pub fn is_seekable(fd: crt_fd::Borrowed<'_>) -> bool { + os::seek_fd(fd, 0, libc::SEEK_CUR).is_ok() +} + +pub fn validate_whence(whence: i32) -> bool { + let standard = (0..=2).contains(&whence); + #[cfg(any(target_os = "dragonfly", target_os = "freebsd", target_os = "linux"))] + { + standard || matches!(whence, libc::SEEK_DATA | libc::SEEK_HOLE) + } + #[cfg(not(any(target_os = "dragonfly", target_os = "freebsd", target_os = "linux")))] + { + standard + } +} + +pub fn is_interrupted_errno(errno: i32) -> bool { + errno == libc::EINTR +} + +pub fn is_interrupted_error(err: &io::Error) -> bool { + err.raw_os_error() == Some(libc::EINTR) +} + +pub fn is_would_block_error(err: &io::Error) -> bool { + err.kind() == io::ErrorKind::WouldBlock || err.raw_os_error() == Some(libc::EAGAIN) +} + +pub fn seek( + fd: crt_fd::Borrowed<'_>, + offset: crt_fd::Offset, + how: i32, +) -> io::Result { + os::seek_fd(fd, offset, how) +} + +pub fn tell(fd: crt_fd::Borrowed<'_>) -> io::Result { + os::seek_fd(fd, 0, libc::SEEK_CUR) +} + +pub fn isatty(fd: i32) -> bool { + os::isatty(fd) +} + +pub fn read_once(fd: crt_fd::Borrowed<'_>, buf: &mut [u8]) -> io::Result { + crt_fd::read(fd, buf) +} + +pub fn read_all(fd: crt_fd::Borrowed<'_>, out: &mut Vec) -> io::Result<()> { + let mut fd = fd; + std::io::Read::read_to_end(&mut fd, out).map(|_| ()) +} + +pub fn write_once(fd: crt_fd::Borrowed<'_>, buf: &[u8]) -> io::Result { + crt_fd::write(fd, buf) +} + +pub fn 
close_owned_fd(fd: crt_fd::Owned) -> io::Result<()> { + crt_fd::close(fd) +} + +/// Async-signal-safe raw write to the platform stderr file descriptor. +/// Avoids `std::io::stderr()` locking so it is safe to call from fork +/// children and signal handlers. +#[cfg(unix)] +pub fn write_stderr_raw(buf: &[u8]) { + unsafe { + let _ = libc::write(libc::STDERR_FILENO, buf.as_ptr().cast(), buf.len()); + } +} diff --git a/crates/host_env/src/io_unsupported.rs b/crates/host_env/src/io_unsupported.rs new file mode 100644 index 00000000000..e46f05af900 --- /dev/null +++ b/crates/host_env/src/io_unsupported.rs @@ -0,0 +1,225 @@ +use core::ffi::CStr; +use std::io; + +use crate::crt_fd; + +const EBADF: i32 = 9; +const EAGAIN: i32 = 11; +const EINTR: i32 = 4; +const EISDIR: i32 = 21; + +const O_RDONLY: i32 = 0; +const O_WRONLY: i32 = 1; +const O_RDWR: i32 = 2; +const O_APPEND: i32 = 0x0008; +const O_CREAT: i32 = 0x0200; +const O_TRUNC: i32 = 0x0400; +const O_EXCL: i32 = 0x0800; + +bitflags::bitflags! 
{ + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + pub struct FileMode: u8 { + const CREATED = 0b0001; + const READABLE = 0b0010; + const WRITABLE = 0b0100; + const APPENDING = 0b1000; + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum FileModeError { + Invalid, + BadRwa, +} + +impl FileModeError { + pub fn error_msg(self, mode_str: &str) -> String { + match self { + Self::Invalid => format!("invalid mode: {mode_str}"), + Self::BadRwa => { + "Must have exactly one of create/read/write/append mode and at most one plus" + .to_owned() + } + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct ParsedFileMode { + pub mode: FileMode, + pub flags: i32, +} + +impl FileMode { + pub const fn raw_mode(self) -> &'static str { + if self.contains(Self::CREATED) { + if self.contains(Self::READABLE) { + "xb+" + } else { + "xb" + } + } else if self.contains(Self::APPENDING) { + if self.contains(Self::READABLE) { + "ab+" + } else { + "ab" + } + } else if self.contains(Self::READABLE) { + if self.contains(Self::WRITABLE) { + "rb+" + } else { + "rb" + } + } else { + "wb" + } + } +} + +pub fn parse_fileio_mode(mode_str: &str) -> Result { + let mut flags = 0; + let mut plus = false; + let mut binary = false; + let mut rwa = false; + let mut mode = FileMode::empty(); + for c in mode_str.bytes() { + match c { + b'x' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::WRITABLE | FileMode::CREATED); + flags |= O_EXCL | O_CREAT; + } + b'r' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::READABLE); + } + b'w' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::WRITABLE); + flags |= O_CREAT | O_TRUNC; + } + b'a' => { + if rwa { + return Err(FileModeError::BadRwa); + } + rwa = true; + mode.insert(FileMode::WRITABLE | FileMode::APPENDING); + flags |= O_APPEND | O_CREAT; + } + b'+' => { + if plus { + return Err(FileModeError::BadRwa); + } + plus = 
true; + mode.insert(FileMode::READABLE | FileMode::WRITABLE); + } + b'b' => { + if binary { + return Err(FileModeError::Invalid); + } + binary = true; + } + _ => return Err(FileModeError::Invalid), + } + } + + if !rwa { + return Err(FileModeError::BadRwa); + } + + if mode.contains(FileMode::READABLE | FileMode::WRITABLE) { + flags |= O_RDWR; + } else if mode.contains(FileMode::READABLE) { + flags |= O_RDONLY; + } else { + flags |= O_WRONLY; + } + + Ok(ParsedFileMode { mode, flags }) +} + +#[derive(Clone, Copy, Debug)] +pub struct FileTargetInfo { + pub blksize: Option, +} + +pub fn inspect_file_target(_fd: crt_fd::Borrowed<'_>) -> io::Result { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn open_path(_path: &CStr, _flags: i32, _mode: i32) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "host filesystem is unsupported on this platform", + )) +} + +pub fn should_forget_fd_after_inspect_error(err: &io::Error, fd_is_own: bool) -> bool { + let errno = err.raw_os_error(); + (errno == Some(EISDIR) || errno == Some(EBADF)) && (!fd_is_own || errno == Some(EBADF)) +} + +pub fn seek_to_end(_fd: crt_fd::Borrowed<'_>) -> io::Result { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn is_seekable(_fd: crt_fd::Borrowed<'_>) -> bool { + false +} + +pub fn validate_whence(whence: i32) -> bool { + (0..=2).contains(&whence) +} + +pub fn is_interrupted_errno(errno: i32) -> bool { + errno == EINTR +} + +pub fn is_interrupted_error(err: &io::Error) -> bool { + err.raw_os_error() == Some(EINTR) +} + +pub fn is_would_block_error(err: &io::Error) -> bool { + err.raw_os_error() == Some(EAGAIN) +} + +pub fn seek( + _fd: crt_fd::Borrowed<'_>, + _offset: crt_fd::Offset, + _how: i32, +) -> io::Result { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn tell(_fd: crt_fd::Borrowed<'_>) -> io::Result { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn isatty(_fd: i32) -> bool { + false +} + +pub fn read_once(_fd: crt_fd::Borrowed<'_>, _buf: &mut [u8]) 
-> io::Result { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn read_all(_fd: crt_fd::Borrowed<'_>, _out: &mut Vec) -> io::Result<()> { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn write_once(_fd: crt_fd::Borrowed<'_>, _buf: &[u8]) -> io::Result { + Err(io::Error::from_raw_os_error(EBADF)) +} + +pub fn close_owned_fd(_fd: crt_fd::Owned) -> io::Result<()> { + Err(io::Error::from_raw_os_error(EBADF)) +} diff --git a/crates/host_env/src/lib.rs b/crates/host_env/src/lib.rs index 80c2109a46f..99f67b2b496 100644 --- a/crates/host_env/src/lib.rs +++ b/crates/host_env/src/lib.rs @@ -1,18 +1,35 @@ +#![allow(clippy::must_use_candidate)] + extern crate alloc; #[macro_use] mod macros; pub use macros::*; +pub mod ctypes; +#[cfg(any(unix, windows, target_os = "wasi"))] +pub mod errno; +#[cfg(any(unix, windows, target_os = "wasi"))] +pub mod io; +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] +#[path = "io_unsupported.rs"] +pub mod io; pub mod os; +#[cfg(any(unix, windows))] +pub mod thread; #[cfg(any(unix, windows, target_os = "wasi"))] pub mod crt_fd; +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] +#[path = "crt_fd_unsupported.rs"] +pub mod crt_fd; #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] pub mod fileutils; #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] pub mod fs; +#[cfg(any(unix, windows))] +pub mod locale; #[cfg(windows)] pub mod windows; @@ -21,22 +38,51 @@ pub mod windows; pub mod fcntl; #[cfg(any(unix, windows, target_os = "wasi"))] pub mod select; +#[cfg(any(unix, windows))] +pub mod socket; #[cfg(unix)] pub mod syslog; #[cfg(all(unix, not(target_os = "redox"), not(target_os = "ios")))] pub mod termios; #[cfg(unix)] +pub mod grp; +#[cfg(unix)] +pub mod posix; +#[cfg(target_os = "wasi")] +#[path = "posix_wasi.rs"] pub mod posix; +#[cfg(unix)] +pub mod pwd; +#[cfg(unix)] +pub mod resource; #[cfg(all(unix, not(target_os = "redox"), not(target_os = "android")))] pub mod shm; -#[cfg(unix)] 
+#[cfg(any(unix, windows))] pub mod signal; pub mod time; +#[cfg(windows)] +pub mod cert_store; +#[cfg(any(unix, windows))] +pub mod faulthandler; +#[cfg(any(unix, windows))] +pub mod mmap; #[cfg(windows)] pub mod msvcrt; +#[cfg(any(unix, windows))] +pub mod multiprocessing; #[cfg(windows)] pub mod nt; #[cfg(windows)] +pub mod overlapped; +#[cfg(windows)] +pub mod testconsole; +#[cfg(windows)] pub mod winapi; +#[cfg(windows)] +pub mod winreg; +#[cfg(windows)] +pub mod winsound; +#[cfg(windows)] +pub mod wmi; diff --git a/crates/host_env/src/locale.rs b/crates/host_env/src/locale.rs new file mode 100644 index 00000000000..52fa7904421 --- /dev/null +++ b/crates/host_env/src/locale.rs @@ -0,0 +1,164 @@ +use alloc::vec::Vec; +use core::{ffi::CStr, ptr}; + +#[cfg(windows)] +#[repr(C)] +struct RawLconv { + decimal_point: *mut libc::c_char, + thousands_sep: *mut libc::c_char, + grouping: *mut libc::c_char, + int_curr_symbol: *mut libc::c_char, + currency_symbol: *mut libc::c_char, + mon_decimal_point: *mut libc::c_char, + mon_thousands_sep: *mut libc::c_char, + mon_grouping: *mut libc::c_char, + positive_sign: *mut libc::c_char, + negative_sign: *mut libc::c_char, + int_frac_digits: libc::c_char, + frac_digits: libc::c_char, + p_cs_precedes: libc::c_char, + p_sep_by_space: libc::c_char, + n_cs_precedes: libc::c_char, + n_sep_by_space: libc::c_char, + p_sign_posn: libc::c_char, + n_sign_posn: libc::c_char, +} + +#[cfg(windows)] +unsafe extern "C" { + fn localeconv() -> *mut RawLconv; +} + +#[cfg(unix)] +use libc::localeconv; + +#[derive(Debug, Clone)] +pub struct LocaleConv { + pub decimal_point: Vec, + pub thousands_sep: Vec, + pub grouping: Vec, + pub int_curr_symbol: Vec, + pub currency_symbol: Vec, + pub mon_decimal_point: Vec, + pub mon_thousands_sep: Vec, + pub mon_grouping: Vec, + pub positive_sign: Vec, + pub negative_sign: Vec, + pub int_frac_digits: libc::c_char, + pub frac_digits: libc::c_char, + pub p_cs_precedes: libc::c_char, + pub p_sep_by_space: 
libc::c_char, + pub n_cs_precedes: libc::c_char, + pub n_sep_by_space: libc::c_char, + pub p_sign_posn: libc::c_char, + pub n_sign_posn: libc::c_char, +} + +fn copy_cstr(ptr: *const libc::c_char) -> Vec { + if ptr.is_null() { + Vec::new() + } else { + unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec() + } +} + +fn copy_grouping(ptr: *const libc::c_char) -> Vec { + if ptr.is_null() { + return Vec::new(); + } + let mut out = Vec::new(); + let mut cur = ptr; + unsafe { + while ![0, libc::c_char::MAX].contains(&*cur) { + out.push(*cur); + cur = cur.add(1); + } + } + out +} + +pub fn localeconv_data() -> LocaleConv { + let lc = unsafe { localeconv() }; + unsafe { + LocaleConv { + decimal_point: copy_cstr((*lc).decimal_point), + thousands_sep: copy_cstr((*lc).thousands_sep), + grouping: copy_grouping((*lc).grouping), + int_curr_symbol: copy_cstr((*lc).int_curr_symbol), + currency_symbol: copy_cstr((*lc).currency_symbol), + mon_decimal_point: copy_cstr((*lc).mon_decimal_point), + mon_thousands_sep: copy_cstr((*lc).mon_thousands_sep), + mon_grouping: copy_grouping((*lc).mon_grouping), + positive_sign: copy_cstr((*lc).positive_sign), + negative_sign: copy_cstr((*lc).negative_sign), + int_frac_digits: (*lc).int_frac_digits, + frac_digits: (*lc).frac_digits, + p_cs_precedes: (*lc).p_cs_precedes, + p_sep_by_space: (*lc).p_sep_by_space, + n_cs_precedes: (*lc).n_cs_precedes, + n_sep_by_space: (*lc).n_sep_by_space, + p_sign_posn: (*lc).p_sign_posn, + n_sign_posn: (*lc).n_sign_posn, + } + } +} + +pub fn strcoll(string1: &CStr, string2: &CStr) -> libc::c_int { + unsafe { libc::strcoll(string1.as_ptr(), string2.as_ptr()) } +} + +pub fn strxfrm(string: &CStr, _initial_len: usize) -> Vec { + let len = unsafe { libc::strxfrm(ptr::null_mut(), string.as_ptr(), 0) }; + let mut buff = vec![0u8; len + 1]; + unsafe { + libc::strxfrm(buff.as_mut_ptr() as _, string.as_ptr(), buff.len()); + } + buff.truncate(len); + buff +} + +pub fn setlocale(category: i32, locale: Option<&CStr>) -> Option> { 
+ let result = unsafe { + match locale { + None => libc::setlocale(category, ptr::null()), + Some(locale) => libc::setlocale(category, locale.as_ptr()), + } + }; + (!result.is_null()).then(|| unsafe { CStr::from_ptr(result) }.to_bytes().to_vec()) +} + +#[cfg(windows)] +pub fn acp() -> u32 { + unsafe { windows_sys::Win32::Globalization::GetACP() } +} + +#[cfg(windows)] +pub fn decode_ansi_bytes(bytes: &[u8]) -> Option { + use core::ptr; + use windows_sys::Win32::Globalization::{CP_ACP, MultiByteToWideChar}; + + if bytes.is_empty() { + return Some(String::new()); + } + let len_i32 = i32::try_from(bytes.len()).ok()?; + + let len = + unsafe { MultiByteToWideChar(CP_ACP, 0, bytes.as_ptr(), len_i32, ptr::null_mut(), 0) }; + if len <= 0 { + return None; + } + let mut wide = vec![0u16; len as usize]; + unsafe { + MultiByteToWideChar(CP_ACP, 0, bytes.as_ptr(), len_i32, wide.as_mut_ptr(), len); + } + Some(String::from_utf16_lossy(&wide)) +} + +#[cfg(all( + unix, + not(any(target_os = "ios", target_os = "android", target_os = "redox")) +))] +pub fn nl_langinfo_codeset() -> Option> { + let codeset = unsafe { libc::nl_langinfo(libc::CODESET) }; + (!codeset.is_null()).then(|| unsafe { CStr::from_ptr(codeset) }.to_bytes().to_vec()) +} diff --git a/crates/host_env/src/mmap.rs b/crates/host_env/src/mmap.rs new file mode 100644 index 00000000000..62dc50f1c31 --- /dev/null +++ b/crates/host_env/src/mmap.rs @@ -0,0 +1,371 @@ +#![allow( + clippy::not_unsafe_ptr_arg_deref, + reason = "These helpers are thin wrappers around raw Windows mapping APIs." 
+)] + +use std::io; + +#[cfg(windows)] +use crate::windows::{CheckWin32Bool, HandleToOwned}; +#[cfg(unix)] +use crate::{crt_fd, fileutils, posix}; +use memmap2::{Mmap, MmapMut, MmapOptions}; +#[cfg(windows)] +use std::os::windows::io::{AsRawHandle, IntoRawHandle}; +#[cfg(windows)] +use windows_sys::Win32::{ + Foundation::{ + CloseHandle, DUPLICATE_SAME_ACCESS, DuplicateHandle, GetLastError, HANDLE, + INVALID_HANDLE_VALUE, + }, + Storage::FileSystem::{FILE_BEGIN, GetFileSize, SetEndOfFile, SetFilePointerEx}, + System::{ + Memory::{ + CreateFileMappingW, FILE_MAP_COPY, FILE_MAP_READ, FILE_MAP_WRITE, FlushViewOfFile, + MEMORY_MAPPED_VIEW_ADDRESS, MapViewOfFile, PAGE_READONLY, PAGE_READWRITE, + PAGE_WRITECOPY, UnmapViewOfFile, + }, + Threading::GetCurrentProcess, + }, +}; + +#[cfg(windows)] +pub type Handle = HANDLE; +#[cfg(windows)] +pub const INVALID_HANDLE: Handle = INVALID_HANDLE_VALUE; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum AccessMode { + Default = 0, + Read = 1, + Write = 2, + Copy = 3, +} + +#[cfg(windows)] +#[derive(Debug)] +pub struct NamedMmap { + map_handle: Handle, + view_ptr: *mut u8, + len: usize, +} + +#[derive(Debug)] +pub enum MappedFile { + Read(Mmap), + Write(MmapMut), +} + +impl MappedFile { + pub fn as_slice(&self) -> &[u8] { + match self { + Self::Read(mmap) => &mmap[..], + Self::Write(mmap) => &mmap[..], + } + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + match self { + Self::Read(_) => panic!("mmap can't modify a readonly memory map."), + Self::Write(mmap) => &mut mmap[..], + } + } + + pub fn as_ptr(&self) -> *const u8 { + match self { + Self::Read(mmap) => mmap.as_ptr(), + Self::Write(mmap) => mmap.as_ptr(), + } + } + + pub fn flush_range(&self, offset: usize, size: usize) -> io::Result<()> { + match self { + Self::Read(_) => Ok(()), + Self::Write(mmap) => mmap.flush_range(offset, size), + } + } + + #[cfg(all(unix, not(target_os = "redox")))] + pub fn madvise_range(&self, start: usize, length: usize, advice: i32) -> 
io::Result<()> { + let ptr = unsafe { self.as_ptr().add(start) }; + posix::madvise(ptr as usize, length, advice) + } +} + +#[cfg(windows)] +unsafe impl Send for NamedMmap {} +#[cfg(windows)] +unsafe impl Sync for NamedMmap {} + +#[cfg(windows)] +impl NamedMmap { + pub fn as_slice(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self.view_ptr, self.len) } + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + unsafe { core::slice::from_raw_parts_mut(self.view_ptr, self.len) } + } + + pub fn ptr_at(&self, offset: usize) -> *const core::ffi::c_void { + unsafe { self.view_ptr.add(offset) as *const _ } + } + + pub fn flush_range(&self, offset: usize, size: usize) -> io::Result<()> { + flush_view(self.ptr_at(offset), size) + } +} + +#[cfg(windows)] +impl Drop for NamedMmap { + fn drop(&mut self) { + unsafe { + if !self.view_ptr.is_null() { + UnmapViewOfFile(MEMORY_MAPPED_VIEW_ADDRESS { + Value: self.view_ptr as *mut _, + }); + } + if !self.map_handle.is_null() { + CloseHandle(self.map_handle); + } + } + } +} + +#[cfg(windows)] +pub fn duplicate_handle(handle: Handle) -> io::Result { + let mut new_handle: Handle = INVALID_HANDLE; + unsafe { + DuplicateHandle( + GetCurrentProcess(), + handle, + GetCurrentProcess(), + &mut new_handle, + 0, + 0, + DUPLICATE_SAME_ACCESS, + ) + } + .check_win32_bool()?; + Ok(new_handle) +} + +#[cfg(windows)] +pub fn get_file_len(handle: Handle) -> io::Result { + let mut high: u32 = 0; + let low = unsafe { GetFileSize(handle, &mut high) }; + if low == u32::MAX { + let err = io::Error::last_os_error(); + if err.raw_os_error() != Some(0) { + return Err(err); + } + } + Ok(((high as i64) << 32) | (low as i64)) +} + +#[cfg(unix)] +pub fn file_len(fd: crt_fd::Borrowed<'_>) -> io::Result { + #[allow(clippy::useless_conversion, reason = "needed for 32-bit platforms")] + Ok(fileutils::fstat(fd)?.st_size.into()) +} + +#[cfg(unix)] +pub fn prepare_file_mapping(fd: crt_fd::Borrowed<'_>) { + #[cfg(target_os = "macos")] + { + let _ = 
posix::full_fsync(fd.into()); + } + #[cfg(not(target_os = "macos"))] + { + let _ = fd; + } +} + +#[cfg(windows)] +pub fn is_invalid_handle_value(handle: isize) -> bool { + handle == INVALID_HANDLE as isize +} + +#[cfg(windows)] +pub fn extend_file(handle: Handle, size: i64) -> io::Result<()> { + unsafe { SetFilePointerEx(handle, size, core::ptr::null_mut(), FILE_BEGIN) } + .check_win32_bool()?; + unsafe { SetEndOfFile(handle) }.check_win32_bool() +} + +#[cfg(unix)] +pub fn close_descriptor(fd: i32) { + if fd >= 0 { + let _ = crt_fd::close(unsafe { crt_fd::Owned::from_raw(fd) }); + } +} + +#[cfg(windows)] +pub fn close_handle(handle: Handle) { + unsafe { CloseHandle(handle) }; +} + +#[cfg(windows)] +pub fn flush_view(ptr: *const core::ffi::c_void, size: usize) -> io::Result<()> { + unsafe { FlushViewOfFile(ptr, size) }.check_win32_bool() +} + +#[cfg(windows)] +pub fn last_error() -> u32 { + unsafe { GetLastError() } +} + +#[cfg(windows)] +pub fn create_named_mapping( + file_handle: Handle, + tag: &str, + access: AccessMode, + offset: i64, + map_size: usize, +) -> io::Result { + let (fl_protect, desired_access) = match access { + AccessMode::Default | AccessMode::Write => (PAGE_READWRITE, FILE_MAP_WRITE), + AccessMode::Read => (PAGE_READONLY, FILE_MAP_READ), + AccessMode::Copy => (PAGE_WRITECOPY, FILE_MAP_COPY), + }; + + let total_size = (offset as u64) + .checked_add(map_size as u64) + .ok_or_else(|| io::Error::from_raw_os_error(libc::EOVERFLOW))?; + let size_hi = (total_size >> 32) as u32; + let size_lo = total_size as u32; + let tag_wide: Vec = tag.encode_utf16().chain(core::iter::once(0)).collect(); + + let map_handle = unsafe { + CreateFileMappingW( + file_handle, + core::ptr::null(), + fl_protect, + size_hi, + size_lo, + tag_wide.as_ptr(), + ) + } + .into_owned() + .ok_or_else(io::Error::last_os_error)?; + + let off_hi = (offset as u64 >> 32) as u32; + let off_lo = offset as u32; + let view = unsafe { + MapViewOfFile( + map_handle.as_raw_handle() as Handle, + 
desired_access, + off_hi, + off_lo, + map_size, + ) + }; + if view.Value.is_null() { + // `map_handle` is closed automatically when dropped on this error path. + return Err(io::Error::last_os_error()); + } + + Ok(NamedMmap { + map_handle: map_handle.into_raw_handle() as Handle, + view_ptr: view.Value as *mut u8, + len: map_size, + }) +} + +#[cfg(unix)] +pub fn map_anon(size: usize) -> io::Result { + let mut mmap_opt = MmapOptions::new(); + mmap_opt.len(size).map_anon().map(MappedFile::Write) +} + +#[cfg(windows)] +pub fn map_anon(size: usize) -> io::Result { + let mut mmap_opt = MmapOptions::new(); + mmap_opt.len(size).map_anon().map(MappedFile::Write) +} + +#[cfg(unix)] +pub fn map_file( + fd: crt_fd::Borrowed<'_>, + offset: i64, + size: usize, + access: AccessMode, +) -> io::Result<(crt_fd::Owned, MappedFile)> { + let new_fd: crt_fd::Owned = posix::dup_noninheritable(fd.into())?.into(); + let mut mmap_opt = MmapOptions::new(); + let mmap_opt = mmap_opt.offset(offset as u64).len(size); + + let mapped = match access { + AccessMode::Default | AccessMode::Write => { + unsafe { mmap_opt.map_mut(&new_fd) }.map(MappedFile::Write)? 
+ } + AccessMode::Read => unsafe { mmap_opt.map(&new_fd) }.map(MappedFile::Read)?, + AccessMode::Copy => unsafe { mmap_opt.map_copy(&new_fd) }.map(MappedFile::Write)?, + }; + + Ok((new_fd, mapped)) +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn validate_advice(advice: i32) -> bool { + match advice { + libc::MADV_NORMAL + | libc::MADV_RANDOM + | libc::MADV_SEQUENTIAL + | libc::MADV_WILLNEED + | libc::MADV_DONTNEED => true, + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd" + ))] + libc::MADV_FREE => true, + #[cfg(target_os = "linux")] + libc::MADV_DONTFORK + | libc::MADV_DOFORK + | libc::MADV_MERGEABLE + | libc::MADV_UNMERGEABLE + | libc::MADV_HUGEPAGE + | libc::MADV_NOHUGEPAGE + | libc::MADV_REMOVE + | libc::MADV_DONTDUMP + | libc::MADV_DODUMP + | libc::MADV_HWPOISON => true, + #[cfg(target_os = "freebsd")] + libc::MADV_NOSYNC + | libc::MADV_AUTOSYNC + | libc::MADV_NOCORE + | libc::MADV_CORE + | libc::MADV_PROTECT => true, + _ => false, + } +} + +#[cfg(windows)] +pub fn map_handle( + handle: Handle, + offset: i64, + size: usize, + access: AccessMode, +) -> io::Result { + use std::{ + fs::File, + os::windows::io::{FromRawHandle, RawHandle}, + }; + + let file = unsafe { File::from_raw_handle(handle as RawHandle) }; + let mut mmap_opt = MmapOptions::new(); + let mmap_opt = mmap_opt.offset(offset as u64).len(size); + + let result = match access { + AccessMode::Default | AccessMode::Write => { + unsafe { mmap_opt.map_mut(&file) }.map(MappedFile::Write) + } + AccessMode::Read => unsafe { mmap_opt.map(&file) }.map(MappedFile::Read), + AccessMode::Copy => unsafe { mmap_opt.map_copy(&file) }.map(MappedFile::Write), + }; + + core::mem::forget(file); + result +} diff --git a/crates/host_env/src/msvcrt.rs b/crates/host_env/src/msvcrt.rs index 6905c2d6556..b94a6d060c7 100644 --- a/crates/host_env/src/msvcrt.rs +++ b/crates/host_env/src/msvcrt.rs @@ -2,13 +2,20 @@ use alloc::{string::String, vec::Vec}; use std::io; 
use crate::crt_fd; +use crate::os::CheckLibcResult; use windows_sys::Win32::System::Diagnostics::Debug; +pub type ErrorMode = u32; + pub const LK_UNLCK: i32 = 0; pub const LK_LOCK: i32 = 1; pub const LK_NBLCK: i32 = 2; pub const LK_RLCK: i32 = 3; pub const LK_NBRLCK: i32 = 4; +pub const SEM_FAILCRITICALERRORS: ErrorMode = Debug::SEM_FAILCRITICALERRORS; +pub const SEM_NOALIGNMENTFAULTEXCEPT: ErrorMode = Debug::SEM_NOALIGNMENTFAULTEXCEPT; +pub const SEM_NOGPFAULTERRORBOX: ErrorMode = Debug::SEM_NOGPFAULTERRORBOX; +pub const SEM_NOOPENFILEERRORBOX: ErrorMode = Debug::SEM_NOOPENFILEERRORBOX; unsafe extern "C" { fn _getch() -> i32; @@ -37,9 +44,7 @@ pub fn getch() -> Vec { #[must_use] pub fn getwch() -> String { let value = unsafe { _getwch() }; - char::from_u32(value) - .unwrap_or_else(|| panic!("invalid unicode {value:#x} from _getwch")) - .to_string() + char::from_u32(value).unwrap().to_string() } #[must_use] @@ -50,9 +55,7 @@ pub fn getche() -> Vec { #[must_use] pub fn getwche() -> String { let value = unsafe { _getwche() }; - char::from_u32(value) - .unwrap_or_else(|| panic!("invalid unicode {value:#x} from _getwche")) - .to_string() + char::from_u32(value).unwrap().to_string() } pub fn putch(c: u8) { @@ -87,45 +90,27 @@ pub fn kbhit() -> i32 { } pub fn locking(fd: i32, mode: i32, nbytes: i64) -> io::Result<()> { - let ret = unsafe { suppress_iph!(_locking(fd, mode, nbytes)) }; - if ret == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } + unsafe { suppress_iph!(_locking(fd, mode, nbytes)) }.check_libc_neg()?; + Ok(()) } pub fn heapmin() -> io::Result<()> { - let ret = unsafe { suppress_iph!(_heapmin()) }; - if ret == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } + unsafe { suppress_iph!(_heapmin()) }.check_libc_neg()?; + Ok(()) } pub fn setmode(fd: crt_fd::Borrowed<'_>, flags: i32) -> io::Result { - let ret = unsafe { suppress_iph!(_setmode(fd, flags)) }; - if ret == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) - } + 
unsafe { suppress_iph!(_setmode(fd, flags)) }.check_libc_neg() } pub fn open_osfhandle(handle: isize, flags: i32) -> io::Result { - let ret = unsafe { suppress_iph!(libc::open_osfhandle(handle, flags)) }; - if ret == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(ret) - } + unsafe { suppress_iph!(libc::open_osfhandle(handle, flags)) }.check_libc_neg() } pub fn get_error_mode() -> u32 { unsafe { suppress_iph!(Debug::GetErrorMode()) } } -pub fn set_error_mode(mode: Debug::THREAD_ERROR_MODE) -> u32 { +pub fn set_error_mode(mode: ErrorMode) -> u32 { unsafe { suppress_iph!(Debug::SetErrorMode(mode)) } } diff --git a/crates/host_env/src/multiprocessing.rs b/crates/host_env/src/multiprocessing.rs new file mode 100644 index 00000000000..4e79b2573cb --- /dev/null +++ b/crates/host_env/src/multiprocessing.rs @@ -0,0 +1,511 @@ +#![allow( + clippy::not_unsafe_ptr_arg_deref, + reason = "Semaphore helpers intentionally mirror OS handle and pointer APIs." +)] +#![allow( + clippy::result_unit_err, + reason = "These helpers preserve the existing host-facing error surface." 
+)] + +#[cfg(unix)] +use alloc::ffi::CString; +#[cfg(windows)] +use std::io; + +#[cfg(unix)] +use libc::sem_t; +#[cfg(unix)] +use nix::errno::Errno; + +#[cfg(unix)] +#[derive(Debug)] +pub struct SemHandle { + raw: *mut sem_t, +} + +#[cfg(unix)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum SemError { + WouldBlock, + TimedOut, + Interrupted, + AlreadyExists, + NotFound, + InvalidInput, + Other(i32), +} + +#[cfg(unix)] +impl SemError { + fn from_errno(err: Errno) -> Self { + match err { + Errno::EAGAIN => Self::WouldBlock, + Errno::ETIMEDOUT => Self::TimedOut, + Errno::EINTR => Self::Interrupted, + Errno::EEXIST => Self::AlreadyExists, + Errno::ENOENT => Self::NotFound, + Errno::EINVAL => Self::InvalidInput, + other => Self::Other(other as i32), + } + } + + pub fn raw_os_error(self) -> i32 { + match self { + Self::WouldBlock => Errno::EAGAIN as i32, + Self::TimedOut => Errno::ETIMEDOUT as i32, + Self::Interrupted => Errno::EINTR as i32, + Self::AlreadyExists => Errno::EEXIST as i32, + Self::NotFound => Errno::ENOENT as i32, + Self::InvalidInput => Errno::EINVAL as i32, + Self::Other(code) => code, + } + } + + pub fn description(self) -> String { + Errno::from_raw(self.raw_os_error()).desc().to_owned() + } +} + +#[cfg(unix)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum TryAcquireStatus { + Acquired, + WouldBlock, + Interrupted, + Error(SemError), +} + +#[cfg(unix)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum WaitStatus { + Acquired, + TimedOut, + Interrupted, + Error(SemError), +} + +#[cfg(windows)] +use windows_sys::Win32::{ + Foundation::{ + CloseHandle, ERROR_TOO_MANY_POSTS, GetLastError, HANDLE, INVALID_HANDLE_VALUE, WAIT_FAILED, + WAIT_OBJECT_0, WAIT_TIMEOUT, + }, + Networking::WinSock::{SOCKET, WSAGetLastError, closesocket, recv, send}, + System::Threading::{ + CreateSemaphoreW, GetCurrentThreadId, INFINITE, ReleaseSemaphore, WaitForSingleObjectEx, + }, +}; + +#[cfg(windows)] +pub type RawHandle = HANDLE; +#[cfg(windows)] +pub 
type RawSocket = SOCKET; +#[cfg(windows)] +pub const INFINITE_TIMEOUT: u32 = INFINITE; + +#[cfg(windows)] +#[derive(Debug)] +pub struct SemHandle { + raw: HANDLE, +} + +unsafe impl Send for SemHandle {} +unsafe impl Sync for SemHandle {} + +#[cfg(unix)] +impl SemHandle { + pub fn create( + name: &str, + value: u32, + unlink: bool, + ) -> Result<(Self, Option), SemError> { + let cname = semaphore_name(name).map_err(|_| SemError::InvalidInput)?; + let raw = + unsafe { libc::sem_open(cname.as_ptr(), libc::O_CREAT | libc::O_EXCL, 0o600, value) }; + if raw == libc::SEM_FAILED { + return Err(SemError::from_errno(Errno::last())); + } + if unlink { + if unsafe { libc::sem_unlink(cname.as_ptr()) } != 0 { + let err = SemError::from_errno(Errno::last()); + unsafe { + libc::sem_close(raw); + } + Err(err) + } else { + Ok((Self { raw }, None)) + } + } else { + Ok((Self { raw }, Some(name.to_owned()))) + } + } + + pub fn open_existing(name: &str) -> Result { + let cname = semaphore_name(name).map_err(|_| SemError::InvalidInput)?; + let raw = unsafe { libc::sem_open(cname.as_ptr(), 0) }; + if raw == libc::SEM_FAILED { + Err(SemError::from_errno(Errno::last())) + } else { + Ok(Self { raw }) + } + } + + #[inline] + pub fn as_ptr(&self) -> *mut sem_t { + self.raw + } +} + +#[cfg(windows)] +impl SemHandle { + pub fn create(value: i32, maxvalue: i32) -> io::Result { + use crate::windows::CheckWin32Handle; + let handle = + unsafe { CreateSemaphoreW(core::ptr::null(), value, maxvalue, core::ptr::null()) } + .check_nonnull()?; + Ok(Self { raw: handle }) + } + + #[inline] + pub fn from_raw(raw: HANDLE) -> Self { + Self { raw } + } + + #[inline] + pub fn as_raw(&self) -> HANDLE { + self.raw + } +} + +#[cfg(unix)] +impl Drop for SemHandle { + fn drop(&mut self) { + if !self.raw.is_null() { + unsafe { + libc::sem_close(self.raw); + } + } + } +} + +#[cfg(windows)] +impl Drop for SemHandle { + fn drop(&mut self) { + if self.raw != 0 as HANDLE && self.raw != INVALID_HANDLE_VALUE { + unsafe { + 
CloseHandle(self.raw); + } + } + } +} + +#[cfg(unix)] +#[inline] +pub fn current_thread_id() -> u64 { + unsafe { libc::pthread_self() as u64 } +} + +#[cfg(windows)] +#[inline] +pub fn current_thread_id() -> u32 { + unsafe { GetCurrentThreadId() } +} + +#[cfg(windows)] +#[inline] +pub fn wait_for_single_object(handle: HANDLE, timeout_ms: u32) -> u32 { + unsafe { WaitForSingleObjectEx(handle, timeout_ms, 0) } +} + +#[cfg(windows)] +#[inline] +pub fn wait_object_0() -> u32 { + WAIT_OBJECT_0 +} + +#[cfg(windows)] +#[inline] +pub fn wait_timeout() -> u32 { + WAIT_TIMEOUT +} + +#[cfg(windows)] +#[inline] +pub fn close_socket(socket: SOCKET) -> io::Result<()> { + let res = unsafe { closesocket(socket) }; + if res != 0 { + Err(io::Error::from_raw_os_error(unsafe { WSAGetLastError() })) + } else { + Ok(()) + } +} + +#[cfg(windows)] +pub fn recv_socket(socket: SOCKET, size: usize) -> io::Result> { + let len = i32::try_from(size).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidInput, "socket receive size too large") + })?; + let mut buf = vec![0u8; size]; + let n_read = unsafe { recv(socket, buf.as_mut_ptr() as *mut _, len, 0) }; + if n_read < 0 { + Err(io::Error::from_raw_os_error(unsafe { WSAGetLastError() })) + } else { + buf.truncate(n_read as usize); + Ok(buf) + } +} + +#[cfg(windows)] +pub fn send_socket(socket: SOCKET, buf: &[u8]) -> io::Result { + let len = i32::try_from(buf.len()) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "socket send buffer too large"))?; + let ret = unsafe { send(socket, buf.as_ptr() as *const _, len, 0) }; + if ret < 0 { + Err(io::Error::from_raw_os_error(unsafe { WSAGetLastError() })) + } else { + Ok(ret) + } +} + +#[cfg(windows)] +#[inline] +pub fn wait_failed() -> u32 { + WAIT_FAILED +} + +#[cfg(windows)] +pub fn release_semaphore(handle: HANDLE) -> Result<(), u32> { + if unsafe { ReleaseSemaphore(handle, 1, core::ptr::null_mut()) } == 0 { + Err(unsafe { GetLastError() }) + } else { + Ok(()) + } +} + +#[cfg(windows)] +pub 
fn get_semaphore_value(handle: HANDLE) -> Result { + match wait_for_single_object(handle, 0) { + WAIT_OBJECT_0 => { + let mut previous: i32 = 0; + if unsafe { ReleaseSemaphore(handle, 1, &mut previous) } == 0 { + Err(()) + } else { + Ok(previous + 1) + } + } + WAIT_TIMEOUT => Ok(0), + _ => Err(()), + } +} + +#[cfg(windows)] +#[inline] +pub fn is_too_many_posts(err: u32) -> bool { + err == ERROR_TOO_MANY_POSTS +} + +#[cfg(unix)] +pub fn semaphore_name(name: &str) -> Result { + let mut full = String::with_capacity(name.len() + 1); + if !name.starts_with('/') { + full.push('/'); + } + full.push_str(name); + CString::new(full) +} + +#[cfg(unix)] +pub fn sem_unlink(name: &str) -> Result<(), SemError> { + let cname = semaphore_name(name).map_err(|_| SemError::InvalidInput)?; + let res = unsafe { libc::sem_unlink(cname.as_ptr()) }; + if res < 0 { + Err(SemError::from_errno(Errno::last())) + } else { + Ok(()) + } +} + +#[cfg(all(unix, not(target_vendor = "apple")))] +/// # Safety +/// +/// `handle` must point to a valid `sem_t` that remains alive for the duration +/// of this call and is valid to pass to `sem_getvalue`. 
+pub unsafe fn get_semaphore_value(handle: *mut sem_t) -> Result { + let mut sval: libc::c_int = 0; + let res = unsafe { libc::sem_getvalue(handle, &mut sval) }; + if res < 0 { + Err(SemError::from_errno(Errno::last())) + } else { + Ok(if sval < 0 { 0 } else { sval }) + } +} + +#[cfg(unix)] +#[allow(clippy::not_unsafe_ptr_arg_deref)] +pub fn sem_trywait_status(handle: *mut sem_t) -> TryAcquireStatus { + if unsafe { libc::sem_trywait(handle) } == 0 { + TryAcquireStatus::Acquired + } else { + match Errno::last() { + Errno::EAGAIN => TryAcquireStatus::WouldBlock, + Errno::EINTR => TryAcquireStatus::Interrupted, + err => TryAcquireStatus::Error(SemError::from_errno(err)), + } + } +} + +#[cfg(unix)] +#[allow(clippy::not_unsafe_ptr_arg_deref)] +pub fn sem_post(handle: *mut sem_t) -> Result<(), SemError> { + if unsafe { libc::sem_post(handle) } < 0 { + Err(SemError::from_errno(Errno::last())) + } else { + Ok(()) + } +} + +#[cfg(unix)] +pub fn sem_value_max() -> i32 { + let val = unsafe { libc::sysconf(libc::_SC_SEM_VALUE_MAX) }; + if val < 0 || val > i32::MAX as libc::c_long { + i32::MAX + } else { + val as i32 + } +} + +#[cfg(unix)] +pub fn gettimeofday() -> Result { + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: 0, + }; + if unsafe { libc::gettimeofday(&mut tv, core::ptr::null_mut()) } < 0 { + Err(SemError::from_errno(Errno::last())) + } else { + Ok(tv) + } +} + +#[cfg(unix)] +pub fn deadline_from_timeout(timeout: f64) -> Result { + let timeout = if timeout < 0.0 { 0.0 } else { timeout }; + if !timeout.is_finite() { + return Err(SemError::InvalidInput); + } + let tv = gettimeofday()?; + let sec_f64 = timeout.floor(); + if sec_f64 > libc::time_t::MAX as f64 { + return Err(SemError::InvalidInput); + } + let sec = sec_f64 as libc::time_t; + let nsec = (1e9 * (timeout - sec as f64) + 0.5) as libc::c_long; + let tv_nsec = (tv.tv_usec as libc::c_long) + .checked_mul(1000) + .and_then(|base| base.checked_add(nsec)) + .ok_or(SemError::InvalidInput)?; + let mut deadline 
= libc::timespec { + tv_sec: tv.tv_sec.checked_add(sec).ok_or(SemError::InvalidInput)?, + tv_nsec: tv_nsec as _, + }; + deadline.tv_sec = deadline + .tv_sec + .checked_add((deadline.tv_nsec / 1_000_000_000) as libc::time_t) + .ok_or(SemError::InvalidInput)?; + deadline.tv_nsec %= 1_000_000_000; + Ok(deadline) +} + +#[cfg(unix)] +#[allow(clippy::not_unsafe_ptr_arg_deref)] +pub fn sem_wait_status(handle: *mut sem_t, deadline: Option<&libc::timespec>) -> WaitStatus { + #[cfg(not(target_vendor = "apple"))] + if let Some(deadline) = deadline { + if unsafe { libc::sem_timedwait(handle, deadline) } == 0 { + WaitStatus::Acquired + } else { + match Errno::last() { + Errno::ETIMEDOUT => WaitStatus::TimedOut, + Errno::EINTR => WaitStatus::Interrupted, + err => WaitStatus::Error(SemError::from_errno(err)), + } + } + } else { + if unsafe { libc::sem_wait(handle) } == 0 { + WaitStatus::Acquired + } else { + match Errno::last() { + Errno::EINTR => WaitStatus::Interrupted, + err => WaitStatus::Error(SemError::from_errno(err)), + } + } + } + + #[cfg(target_vendor = "apple")] + { + debug_assert!(deadline.is_none()); + if unsafe { libc::sem_wait(handle) } == 0 { + WaitStatus::Acquired + } else { + match Errno::last() { + Errno::EINTR => WaitStatus::Interrupted, + err => WaitStatus::Error(SemError::from_errno(err)), + } + } + } +} + +#[cfg(target_vendor = "apple")] +pub enum PollWaitStep { + Acquired, + Timeout, + Continue(u64), +} + +#[cfg(target_vendor = "apple")] +#[allow(clippy::not_unsafe_ptr_arg_deref)] +pub fn sem_timedwait_poll_step( + handle: *mut sem_t, + deadline: &libc::timespec, + delay: u64, +) -> Result { + if unsafe { libc::sem_trywait(handle) } == 0 { + return Ok(PollWaitStep::Acquired); + } + let err = Errno::last(); + if err != Errno::EAGAIN { + return Err(SemError::from_errno(err)); + } + + let now = gettimeofday()?; + let deadline_usec = deadline.tv_sec * 1_000_000 + deadline.tv_nsec / 1000; + #[allow(clippy::unnecessary_cast)] + let now_usec = now.tv_sec as i64 * 
1_000_000 + now.tv_usec as i64; + if now_usec >= deadline_usec { + return Ok(PollWaitStep::Timeout); + } + + let difference = (deadline_usec - now_usec) as u64; + let mut delay = delay + 1000; + if delay > 20000 { + delay = 20000; + } + if delay > difference { + delay = difference; + } + + let mut tv_delay = libc::timeval { + tv_sec: (delay / 1_000_000) as _, + tv_usec: (delay % 1_000_000) as _, + }; + unsafe { + libc::select( + 0, + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + &mut tv_delay, + ); + } + Ok(PollWaitStep::Continue(delay)) +} diff --git a/crates/host_env/src/nt.rs b/crates/host_env/src/nt.rs index c6771aad40a..4c77b30e616 100644 --- a/crates/host_env/src/nt.rs +++ b/crates/host_env/src/nt.rs @@ -1,30 +1,361 @@ +#![allow( + clippy::not_unsafe_ptr_arg_deref, + reason = "This module mirrors raw Win32 path, handle, and CRT entry points." +)] + // cspell:ignore hchmod -use std::{ffi::OsStr, io, os::windows::io::AsRawHandle}; +use std::{ + ffi::{OsStr, OsString}, + io, + os::windows::{ffi::OsStringExt, io::AsRawHandle}, + path::Path, +}; -use crate::{crt_fd, windows::ToWideString}; +use core::sync::atomic::{AtomicBool, Ordering}; + +use crate::{ + crt_fd, + fileutils::{ + StatStruct, + windows::{FILE_INFO_BY_NAME_CLASS, get_file_information_by_name, stat_basic_info_to_stat}, + }, + windows::{CheckWin32Bool, CheckWin32Handle, CheckWin32Sentinel, HandleToOwned, ToWideString}, +}; +use libc::intptr_t; use windows_sys::Win32::{ - Foundation::HANDLE, + Foundation::{ + CloseHandle, ERROR_INVALID_HANDLE, GetLastError, HANDLE, INVALID_HANDLE_VALUE, MAX_PATH, + }, + Globalization::{CP_UTF8, MultiByteToWideChar, WideCharToMultiByte}, Storage::FileSystem::{ - FILE_ATTRIBUTE_READONLY, FILE_BASIC_INFO, FileBasicInfo, GetFileAttributesW, - GetFileInformationByHandleEx, INVALID_FILE_ATTRIBUTES, SetFileAttributesW, - SetFileInformationByHandle, + CreateFileW, FILE_BASIC_INFO, FILE_FLAG_BACKUP_SEMANTICS, FILE_FLAG_OPEN_REPARSE_POINT, + 
FILE_READ_ATTRIBUTES, FILE_TYPE_UNKNOWN, FileBasicInfo, FindClose, FindFirstFileW, + GetFileAttributesW, GetFileInformationByHandleEx, GetFileType, GetFullPathNameW, + INVALID_FILE_ATTRIBUTES, OPEN_EXISTING, SetFileAttributesW, SetFileInformationByHandle, + WIN32_FIND_DATAW, }, + System::{Console, Threading}, +}; + +pub type Handle = HANDLE; +pub const MAX_PATH_USIZE: usize = MAX_PATH as usize; +pub const ERROR_INVALID_HANDLE_I32: i32 = ERROR_INVALID_HANDLE as i32; +pub const LOAD_LIBRARY_SEARCH_APPLICATION_DIR: u32 = + windows_sys::Win32::System::LibraryLoader::LOAD_LIBRARY_SEARCH_APPLICATION_DIR; +pub const LOAD_LIBRARY_SEARCH_DEFAULT_DIRS: u32 = + windows_sys::Win32::System::LibraryLoader::LOAD_LIBRARY_SEARCH_DEFAULT_DIRS; +pub const LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR: u32 = + windows_sys::Win32::System::LibraryLoader::LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR; +pub const LOAD_LIBRARY_SEARCH_SYSTEM32: u32 = + windows_sys::Win32::System::LibraryLoader::LOAD_LIBRARY_SEARCH_SYSTEM32; +pub const LOAD_LIBRARY_SEARCH_USER_DIRS: u32 = + windows_sys::Win32::System::LibraryLoader::LOAD_LIBRARY_SEARCH_USER_DIRS; + +pub use windows_sys::Win32::Storage::FileSystem::{ + FILE_ATTRIBUTE_ARCHIVE, FILE_ATTRIBUTE_COMPRESSED, FILE_ATTRIBUTE_DEVICE, + FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_ENCRYPTED, FILE_ATTRIBUTE_HIDDEN, + FILE_ATTRIBUTE_INTEGRITY_STREAM, FILE_ATTRIBUTE_NO_SCRUB_DATA, FILE_ATTRIBUTE_NORMAL, + FILE_ATTRIBUTE_NOT_CONTENT_INDEXED, FILE_ATTRIBUTE_OFFLINE, FILE_ATTRIBUTE_READONLY, + FILE_ATTRIBUTE_REPARSE_POINT, FILE_ATTRIBUTE_SPARSE_FILE, FILE_ATTRIBUTE_SYSTEM, + FILE_ATTRIBUTE_TEMPORARY, FILE_ATTRIBUTE_VIRTUAL, }; +#[cfg(target_env = "msvc")] +unsafe extern "C" { + fn _cwait(termstat: *mut i32, procHandle: intptr_t, action: i32) -> intptr_t; + fn _wexecv(cmdname: *const u16, argv: *const *const u16) -> intptr_t; + fn _wexecve(cmdname: *const u16, argv: *const *const u16, envp: *const *const u16) -> intptr_t; + fn _wspawnv(mode: i32, cmdname: *const u16, argv: *const *const 
u16) -> intptr_t; + fn _wspawnve( + mode: i32, + cmdname: *const u16, + argv: *const *const u16, + envp: *const *const u16, + ) -> intptr_t; +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum TestType { + RegularFile, + Directory, + Symlink, + Junction, + LinkReparsePoint, + RegularReparsePoint, +} + +const IO_REPARSE_TAG_SYMLINK: u32 = 0xA000000C; +const S_IFMT: u16 = libc::S_IFMT as u16; +const S_IFDIR_MODE: u16 = libc::S_IFDIR as u16; +const S_IFCHR_MODE: u16 = libc::S_IFCHR as u16; +const S_IFIFO_MODE: u16 = crate::fileutils::windows::S_IFIFO as u16; + +#[repr(C)] +#[derive(Default)] +struct FileAttributeTagInfo { + file_attributes: u32, + reparse_tag: u32, +} + +fn win32_large_integer_to_time(li: i64) -> (libc::time_t, i32) { + let nsec = ((li % 10_000_000) * 100) as i32; + let sec = (li / 10_000_000 - crate::fileutils::windows::SECS_BETWEEN_EPOCHS) as libc::time_t; + (sec, nsec) +} + +fn win32_filetime_to_time(ft_low: u32, ft_high: u32) -> (libc::time_t, i32) { + let ticks = ((ft_high as i64) << 32) | (ft_low as i64); + let nsec = ((ticks % 10_000_000) * 100) as i32; + let sec = (ticks / 10_000_000 - crate::fileutils::windows::SECS_BETWEEN_EPOCHS) as libc::time_t; + (sec, nsec) +} + +fn win32_attribute_data_to_stat( + info: &windows_sys::Win32::Storage::FileSystem::BY_HANDLE_FILE_INFORMATION, + reparse_tag: u32, + basic_info: Option<&windows_sys::Win32::Storage::FileSystem::FILE_BASIC_INFO>, + id_info: Option<&windows_sys::Win32::Storage::FileSystem::FILE_ID_INFO>, +) -> StatStruct { + use windows_sys::Win32::Storage::FileSystem::{ + FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_READONLY, FILE_ATTRIBUTE_REPARSE_POINT, + }; + + let mut st_mode: u16 = 0; + if info.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY != 0 { + st_mode |= S_IFDIR_MODE | 0o111; + } else { + st_mode |= libc::S_IFREG as u16; + } + if info.dwFileAttributes & FILE_ATTRIBUTE_READONLY != 0 { + st_mode |= 0o444; + } else { + st_mode |= 0o666; + } + + let st_size = ((info.nFileSizeHigh as u64) 
<< 32) | (info.nFileSizeLow as u64); + let st_dev = id_info.map_or(info.dwVolumeSerialNumber, |id| id.VolumeSerialNumber as u32); + let st_nlink = info.nNumberOfLinks as i32; + + let (st_birthtime, st_birthtime_nsec, st_mtime, st_mtime_nsec, st_atime, st_atime_nsec) = + if let Some(bi) = basic_info { + let (birth, birth_nsec) = win32_large_integer_to_time(bi.CreationTime); + let (mtime, mtime_nsec) = win32_large_integer_to_time(bi.LastWriteTime); + let (atime, atime_nsec) = win32_large_integer_to_time(bi.LastAccessTime); + (birth, birth_nsec, mtime, mtime_nsec, atime, atime_nsec) + } else { + let (birth, birth_nsec) = win32_filetime_to_time( + info.ftCreationTime.dwLowDateTime, + info.ftCreationTime.dwHighDateTime, + ); + let (mtime, mtime_nsec) = win32_filetime_to_time( + info.ftLastWriteTime.dwLowDateTime, + info.ftLastWriteTime.dwHighDateTime, + ); + let (atime, atime_nsec) = win32_filetime_to_time( + info.ftLastAccessTime.dwLowDateTime, + info.ftLastAccessTime.dwHighDateTime, + ); + (birth, birth_nsec, mtime, mtime_nsec, atime, atime_nsec) + }; + + let (st_ino, st_ino_high) = if let Some(id) = id_info { + let bytes = id.FileId.Identifier; + ( + u64::from_le_bytes(bytes[0..8].try_into().unwrap()), + u64::from_le_bytes(bytes[8..16].try_into().unwrap()), + ) + } else { + ( + ((info.nFileIndexHigh as u64) << 32) | (info.nFileIndexLow as u64), + 0, + ) + }; + + if info.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 + && reparse_tag == IO_REPARSE_TAG_SYMLINK + { + st_mode = (st_mode & !S_IFMT) | crate::fileutils::windows::S_IFLNK as u16; + } + + StatStruct { + st_dev, + st_ino, + st_ino_high, + st_mode, + st_nlink, + st_uid: 0, + st_gid: 0, + st_rdev: 0, + st_size, + st_atime, + st_atime_nsec, + st_mtime, + st_mtime_nsec, + st_ctime: 0, + st_ctime_nsec: 0, + st_birthtime, + st_birthtime_nsec, + st_file_attributes: info.dwFileAttributes, + st_reparse_tag: reparse_tag, + } +} + +pub fn visible_env_vars() -> impl Iterator { + crate::os::vars().filter(|(key, _)| 
!key.starts_with('=')) +} + +#[derive(Debug)] +pub enum ReadlinkError { + Io(io::Error), + NotSymbolicLink, + InvalidReparseData, +} + +#[derive(Debug)] +pub enum ReadConsoleError { + Io(io::Error), + BufferTooSmall { available: usize, required: usize }, +} + +pub fn access(path: &Path, mode: u8) -> bool { + let wide = path.as_os_str().to_wide_with_nul(); + let attr = unsafe { GetFileAttributesW(wide.as_ptr()) }; + attr != INVALID_FILE_ATTRIBUTES + && (mode & 2 == 0 + || attr & FILE_ATTRIBUTE_READONLY == 0 + || attr & windows_sys::Win32::Storage::FileSystem::FILE_ATTRIBUTE_DIRECTORY != 0) +} + +pub fn remove(path: &Path) -> io::Result<()> { + use windows_sys::Win32::Storage::FileSystem::{ + DeleteFileW, RemoveDirectoryW, WIN32_FIND_DATAW, + }; + use windows_sys::Win32::System::SystemServices::{ + IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, + }; + + let wide_path = path.as_os_str().to_wide_with_nul(); + let attrs = unsafe { GetFileAttributesW(wide_path.as_ptr()) }; + + let mut is_directory = false; + let mut is_link = false; + + if attrs != INVALID_FILE_ATTRIBUTES { + is_directory = + (attrs & windows_sys::Win32::Storage::FileSystem::FILE_ATTRIBUTE_DIRECTORY) != 0; + + if is_directory + && (attrs & windows_sys::Win32::Storage::FileSystem::FILE_ATTRIBUTE_REPARSE_POINT) != 0 + { + let mut find_data: WIN32_FIND_DATAW = unsafe { core::mem::zeroed() }; + let handle = unsafe { FindFirstFileW(wide_path.as_ptr(), &mut find_data) }; + if handle != INVALID_HANDLE_VALUE { + is_link = find_data.dwReserved0 == IO_REPARSE_TAG_SYMLINK + || find_data.dwReserved0 == IO_REPARSE_TAG_MOUNT_POINT; + unsafe { FindClose(handle) }; + } + } + } + + if is_directory && is_link { + unsafe { RemoveDirectoryW(wide_path.as_ptr()) } + } else { + unsafe { DeleteFileW(wide_path.as_ptr()) } + } + .check_win32_bool() +} + +pub fn supports_virtual_terminal() -> bool { + let mut mode = 0; + let handle = unsafe { Console::GetStdHandle(Console::STD_ERROR_HANDLE) }; + (unsafe { 
Console::GetConsoleMode(handle, &mut mode) }) != 0 + && mode & Console::ENABLE_VIRTUAL_TERMINAL_PROCESSING != 0 +} + +pub fn symlink( + src: &Path, + dst: &Path, + src_wide: &widestring::WideCStr, + dst_wide: &widestring::WideCStr, + target_is_directory: bool, +) -> io::Result<()> { + use windows_sys::Win32::Storage::FileSystem::WIN32_FILE_ATTRIBUTE_DATA; + use windows_sys::Win32::Storage::FileSystem::{ + CreateSymbolicLinkW, FILE_ATTRIBUTE_DIRECTORY, GetFileAttributesExW, + SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE, SYMBOLIC_LINK_FLAG_DIRECTORY, + }; + + static HAS_UNPRIVILEGED_FLAG: AtomicBool = AtomicBool::new(true); + + fn check_dir(src: &Path, dst: &Path) -> bool { + use windows_sys::Win32::Storage::FileSystem::GetFileExInfoStandard; + + let Some(dst_parent) = dst.parent() else { + return false; + }; + let resolved = if src.is_absolute() { + src.to_path_buf() + } else { + dst_parent.join(src) + }; + let wide = match widestring::WideCString::from_os_str(&resolved) { + Ok(wide) => wide, + Err(_) => return false, + }; + let mut info: WIN32_FILE_ATTRIBUTE_DATA = unsafe { core::mem::zeroed() }; + let ok = unsafe { + GetFileAttributesExW( + wide.as_ptr(), + GetFileExInfoStandard, + (&mut info as *mut WIN32_FILE_ATTRIBUTE_DATA).cast(), + ) + }; + ok != 0 && (info.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0 + } + + let mut flags = 0u32; + if HAS_UNPRIVILEGED_FLAG.load(Ordering::Relaxed) { + flags |= SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE; + } + if target_is_directory || check_dir(src, dst) { + flags |= SYMBOLIC_LINK_FLAG_DIRECTORY; + } + + let mut result = unsafe { CreateSymbolicLinkW(dst_wide.as_ptr(), src_wide.as_ptr(), flags) }; + if !result + && HAS_UNPRIVILEGED_FLAG.load(Ordering::Relaxed) + && unsafe { windows_sys::Win32::Foundation::GetLastError() } + == windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER + { + let flags = flags & !SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE; + result = unsafe { CreateSymbolicLinkW(dst_wide.as_ptr(), 
src_wide.as_ptr(), flags) }; + if result + || unsafe { windows_sys::Win32::Foundation::GetLastError() } + != windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER + { + HAS_UNPRIVILEGED_FLAG.store(false, Ordering::Relaxed); + } + } + + if result { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } +} + #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn win32_hchmod(handle: HANDLE, mode: u32, write_bit: u32) -> io::Result<()> { let mut info: FILE_BASIC_INFO = unsafe { core::mem::zeroed() }; - let ret = unsafe { + unsafe { GetFileInformationByHandleEx( handle, FileBasicInfo, (&mut info as *mut FILE_BASIC_INFO).cast(), core::mem::size_of::() as u32, ) - }; - if ret == 0 { - return Err(io::Error::last_os_error()); } + .check_win32_bool()?; if mode & write_bit != 0 { info.FileAttributes &= !FILE_ATTRIBUTE_READONLY; @@ -32,19 +363,15 @@ pub fn win32_hchmod(handle: HANDLE, mode: u32, write_bit: u32) -> io::Result<()> info.FileAttributes |= FILE_ATTRIBUTE_READONLY; } - let ret = unsafe { + unsafe { SetFileInformationByHandle( handle, FileBasicInfo, (&info as *const FILE_BASIC_INFO).cast(), core::mem::size_of::() as u32, ) - }; - if ret == 0 { - Err(io::Error::last_os_error()) - } else { - Ok(()) } + .check_win32_bool() } pub fn fchmod(fd: i32, mode: u32, write_bit: u32) -> io::Result<()> { @@ -55,19 +382,1881 @@ pub fn fchmod(fd: i32, mode: u32, write_bit: u32) -> io::Result<()> { pub fn win32_lchmod(path: &OsStr, mode: u32, write_bit: u32) -> io::Result<()> { let wide = path.to_wide_with_nul(); - let attr = unsafe { GetFileAttributesW(wide.as_ptr()) }; - if attr == INVALID_FILE_ATTRIBUTES { - return Err(io::Error::last_os_error()); - } + let attr = unsafe { GetFileAttributesW(wide.as_ptr()) }.check_ne(INVALID_FILE_ATTRIBUTES)?; let new_attr = if mode & write_bit != 0 { attr & !FILE_ATTRIBUTE_READONLY } else { attr | FILE_ATTRIBUTE_READONLY }; - let ret = unsafe { SetFileAttributesW(wide.as_ptr(), new_attr) }; - if ret == 0 { + unsafe { 
SetFileAttributesW(wide.as_ptr(), new_attr) }.check_win32_bool() +} + +pub fn chmod_follow(path: &widestring::WideCStr, mode: u32, write_bit: u32) -> io::Result<()> { + use windows_sys::Win32::Storage::FileSystem::{ + FILE_FLAG_BACKUP_SEMANTICS, FILE_READ_ATTRIBUTES, FILE_SHARE_DELETE, FILE_SHARE_READ, + FILE_SHARE_WRITE, FILE_WRITE_ATTRIBUTES, OPEN_EXISTING, + }; + + let handle = unsafe { + CreateFileW( + path.as_ptr(), + FILE_READ_ATTRIBUTES | FILE_WRITE_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, + core::ptr::null_mut(), + ) + }; + use std::os::windows::io::AsRawHandle; + let handle = handle.into_owned().ok_or_else(io::Error::last_os_error)?; + win32_hchmod(handle.as_raw_handle() as HANDLE, mode, write_bit) +} + +pub fn find_first_file_name(path: &Path) -> io::Result { + let wide_path = path.as_os_str().to_wide_with_nul(); + let mut find_data: WIN32_FIND_DATAW = unsafe { core::mem::zeroed() }; + + let handle = unsafe { FindFirstFileW(wide_path.as_ptr(), &mut find_data) }.check_valid()?; + unsafe { FindClose(handle) }; + + let len = find_data + .cFileName + .iter() + .position(|&c| c == 0) + .unwrap_or(find_data.cFileName.len()); + Ok(OsString::from_wide(&find_data.cFileName[..len])) +} + +pub fn path_isdevdrive(path: &Path) -> io::Result { + use windows_sys::Win32::Storage::FileSystem::{ + FILE_SHARE_READ, FILE_SHARE_WRITE, GetDriveTypeW, GetVolumePathNameW, + }; + use windows_sys::Win32::System::IO::DeviceIoControl; + use windows_sys::Win32::System::Ioctl::FSCTL_QUERY_PERSISTENT_VOLUME_STATE; + use windows_sys::Win32::System::WindowsProgramming::DRIVE_FIXED; + + const PERSISTENT_VOLUME_STATE_DEV_VOLUME: u32 = 0x0000_2000; + + #[repr(C)] + struct FileFsPersistentVolumeInformation { + volume_flags: u32, + flag_mask: u32, + version: u32, + reserved: u32, + } + + let wide_path = path.as_os_str().to_wide_with_nul(); + let mut volume = [0u16; MAX_PATH as usize]; + unsafe { 
GetVolumePathNameW(wide_path.as_ptr(), volume.as_mut_ptr(), volume.len() as _) } + .check_win32_bool()?; + if unsafe { GetDriveTypeW(volume.as_ptr()) } != DRIVE_FIXED { + return Ok(false); + } + + let handle = unsafe { + CreateFileW( + volume.as_ptr(), + FILE_READ_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, + core::ptr::null_mut(), + ) + } + .check_valid()?; + + let mut volume_state = FileFsPersistentVolumeInformation { + volume_flags: 0, + flag_mask: PERSISTENT_VOLUME_STATE_DEV_VOLUME, + version: 1, + reserved: 0, + }; + let ok = unsafe { + DeviceIoControl( + handle, + FSCTL_QUERY_PERSISTENT_VOLUME_STATE, + (&volume_state as *const FileFsPersistentVolumeInformation).cast(), + core::mem::size_of::() as u32, + (&mut volume_state as *mut FileFsPersistentVolumeInformation).cast(), + core::mem::size_of::() as u32, + core::ptr::null_mut(), + core::ptr::null_mut(), + ) + }; + unsafe { CloseHandle(handle) }; + + if ok == 0 { + let err = io::Error::last_os_error(); + if err.raw_os_error() + == Some(windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER as i32) + { + return Ok(false); + } + return Err(err); + } + + Ok((volume_state.volume_flags & PERSISTENT_VOLUME_STATE_DEV_VOLUME) != 0) +} + +pub fn is_reparse_tag_name_surrogate(tag: u32) -> bool { + (tag & 0x20000000) != 0 +} + +pub fn file_info_error_is_trustworthy(error: u32) -> bool { + use windows_sys::Win32::Foundation; + matches!( + error, + Foundation::ERROR_FILE_NOT_FOUND + | Foundation::ERROR_PATH_NOT_FOUND + | Foundation::ERROR_NOT_READY + | Foundation::ERROR_BAD_NET_NAME + | Foundation::ERROR_BAD_NETPATH + | Foundation::ERROR_BAD_PATHNAME + | Foundation::ERROR_INVALID_NAME + | Foundation::ERROR_FILENAME_EXCED_RANGE + ) +} + +pub fn test_info( + attributes: u32, + reparse_tag: u32, + disk_device: bool, + tested_type: TestType, +) -> bool { + use windows_sys::Win32::Storage::FileSystem::{ + FILE_ATTRIBUTE_DIRECTORY, 
FILE_ATTRIBUTE_REPARSE_POINT, + }; + use windows_sys::Win32::System::SystemServices::{ + IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, + }; + + match tested_type { + TestType::RegularFile => { + disk_device && attributes != 0 && (attributes & FILE_ATTRIBUTE_DIRECTORY) == 0 + } + TestType::Directory => (attributes & FILE_ATTRIBUTE_DIRECTORY) != 0, + TestType::Symlink => { + (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 + && reparse_tag == IO_REPARSE_TAG_SYMLINK + } + TestType::Junction => { + (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 + && reparse_tag == IO_REPARSE_TAG_MOUNT_POINT + } + TestType::LinkReparsePoint => { + (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 + && is_reparse_tag_name_surrogate(reparse_tag) + } + TestType::RegularReparsePoint => { + (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 + && reparse_tag != 0 + && !is_reparse_tag_name_surrogate(reparse_tag) + } + } +} + +pub fn test_file_type_by_handle(handle: HANDLE, tested_type: TestType, disk_only: bool) -> bool { + use windows_sys::Win32::Storage::FileSystem::{ + FILE_ATTRIBUTE_TAG_INFO, FILE_TYPE_DISK, FileAttributeTagInfo as FileAttributeTagInfoClass, + }; + + let disk_device = unsafe { GetFileType(handle) } == FILE_TYPE_DISK; + if disk_only && !disk_device { + return false; + } + + if tested_type != TestType::RegularFile && tested_type != TestType::Directory { + let mut info: FILE_ATTRIBUTE_TAG_INFO = unsafe { core::mem::zeroed() }; + let ret = unsafe { + GetFileInformationByHandleEx( + handle, + FileAttributeTagInfoClass, + (&mut info as *mut FILE_ATTRIBUTE_TAG_INFO).cast(), + core::mem::size_of::() as u32, + ) + }; + if ret == 0 { + return false; + } + test_info( + info.FileAttributes, + info.ReparseTag, + disk_device, + tested_type, + ) + } else { + let mut info: FILE_BASIC_INFO = unsafe { core::mem::zeroed() }; + let ret = unsafe { + GetFileInformationByHandleEx( + handle, + FileBasicInfo, + (&mut info as *mut FILE_BASIC_INFO).cast(), + core::mem::size_of::() as u32, + ) + 
}; + if ret == 0 { + return false; + } + test_info(info.FileAttributes, 0, disk_device, tested_type) + } +} + +fn win32_xstat_attributes_from_dir( + path: &OsStr, +) -> io::Result<( + windows_sys::Win32::Storage::FileSystem::BY_HANDLE_FILE_INFORMATION, + u32, +)> { + use windows_sys::Win32::Storage::FileSystem::{ + BY_HANDLE_FILE_INFORMATION, FILE_ATTRIBUTE_REPARSE_POINT, + }; + + let wide: Vec = path.to_wide_with_nul(); + let mut find_data: WIN32_FIND_DATAW = unsafe { core::mem::zeroed() }; + + let handle = unsafe { FindFirstFileW(wide.as_ptr(), &mut find_data) }.check_valid()?; + unsafe { FindClose(handle) }; + + let mut info: BY_HANDLE_FILE_INFORMATION = unsafe { core::mem::zeroed() }; + info.dwFileAttributes = find_data.dwFileAttributes; + info.ftCreationTime = find_data.ftCreationTime; + info.ftLastAccessTime = find_data.ftLastAccessTime; + info.ftLastWriteTime = find_data.ftLastWriteTime; + info.nFileSizeHigh = find_data.nFileSizeHigh; + info.nFileSizeLow = find_data.nFileSizeLow; + + let reparse_tag = if find_data.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 { + find_data.dwReserved0 + } else { + 0 + }; + + Ok((info, reparse_tag)) +} + +fn win32_xstat_slow_impl(path: &OsStr, traverse: bool) -> io::Result { + use windows_sys::Win32::{ + Foundation::{ + ERROR_ACCESS_DENIED, ERROR_CANT_ACCESS_FILE, ERROR_INVALID_FUNCTION, + ERROR_INVALID_PARAMETER, ERROR_NOT_SUPPORTED, ERROR_SHARING_VIOLATION, GENERIC_READ, + }, + Storage::FileSystem::{ + BY_HANDLE_FILE_INFORMATION, FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_NORMAL, + FILE_ATTRIBUTE_REPARSE_POINT, FILE_BASIC_INFO, FILE_ID_INFO, FILE_SHARE_READ, + FILE_SHARE_WRITE, FILE_TYPE_CHAR, FILE_TYPE_PIPE, + FileAttributeTagInfo as FileAttributeTagInfoClass, FileBasicInfo, FileIdInfo, + GetFileAttributesW, GetFileInformationByHandle, + }, + }; + + let wide: Vec = path.to_wide_with_nul(); + let access = FILE_READ_ATTRIBUTES; + let mut flags = FILE_FLAG_BACKUP_SEMANTICS; + if !traverse { + flags |= 
FILE_FLAG_OPEN_REPARSE_POINT; + } + + let mut h_file = unsafe { + CreateFileW( + wide.as_ptr(), + access, + 0, + core::ptr::null(), + OPEN_EXISTING, + flags, + core::ptr::null_mut(), + ) + }; + + let mut file_info: BY_HANDLE_FILE_INFORMATION = unsafe { core::mem::zeroed() }; + let mut tag_info = FileAttributeTagInfo::default(); + let mut is_unhandled_tag = false; + + if h_file == INVALID_HANDLE_VALUE { + let error = io::Error::last_os_error(); + match error.raw_os_error().unwrap_or(0) as u32 { + ERROR_ACCESS_DENIED | ERROR_SHARING_VIOLATION => { + let (info, reparse_tag) = win32_xstat_attributes_from_dir(path)?; + file_info = info; + tag_info.reparse_tag = reparse_tag; + + if file_info.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 + && (traverse || !is_reparse_tag_name_surrogate(tag_info.reparse_tag)) + { + return Err(error); + } + } + ERROR_INVALID_PARAMETER => { + h_file = unsafe { + CreateFileW( + wide.as_ptr(), + access | GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, + core::ptr::null(), + OPEN_EXISTING, + flags, + core::ptr::null_mut(), + ) + }; + if h_file == INVALID_HANDLE_VALUE { + return Err(error); + } + } + ERROR_CANT_ACCESS_FILE if traverse => { + is_unhandled_tag = true; + h_file = unsafe { + CreateFileW( + wide.as_ptr(), + access, + 0, + core::ptr::null(), + OPEN_EXISTING, + flags | FILE_FLAG_OPEN_REPARSE_POINT, + core::ptr::null_mut(), + ) + }; + if h_file == INVALID_HANDLE_VALUE { + return Err(error); + } + } + _ => return Err(error), + } + } + + let result = (|| -> io::Result { + if h_file != INVALID_HANDLE_VALUE { + let file_type = unsafe { GetFileType(h_file) }; + if file_type != windows_sys::Win32::Storage::FileSystem::FILE_TYPE_DISK { + if file_type == FILE_TYPE_UNKNOWN { + let err = io::Error::last_os_error(); + if err.raw_os_error().unwrap_or(0) != 0 { + return Err(err); + } + } + let file_attributes = unsafe { GetFileAttributesW(wide.as_ptr()) }; + let mut st_mode = 0; + if file_attributes != INVALID_FILE_ATTRIBUTES + && 
file_attributes & FILE_ATTRIBUTE_DIRECTORY != 0 + { + st_mode = S_IFDIR_MODE; + } else if file_type == FILE_TYPE_CHAR { + st_mode = S_IFCHR_MODE; + } else if file_type == FILE_TYPE_PIPE { + st_mode = S_IFIFO_MODE; + } + return Ok(StatStruct { + st_mode, + ..Default::default() + }); + } + + if !traverse || is_unhandled_tag { + let mut local_tag_info: FileAttributeTagInfo = unsafe { core::mem::zeroed() }; + let ret = unsafe { + GetFileInformationByHandleEx( + h_file, + FileAttributeTagInfoClass, + (&mut local_tag_info as *mut FileAttributeTagInfo).cast(), + core::mem::size_of::() as u32, + ) + }; + if ret == 0 { + match io::Error::last_os_error().raw_os_error().unwrap_or(0) as u32 { + ERROR_INVALID_PARAMETER | ERROR_INVALID_FUNCTION | ERROR_NOT_SUPPORTED => { + local_tag_info.file_attributes = FILE_ATTRIBUTE_NORMAL; + local_tag_info.reparse_tag = 0; + } + _ => return Err(io::Error::last_os_error()), + } + } else if local_tag_info.file_attributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 { + if is_reparse_tag_name_surrogate(local_tag_info.reparse_tag) { + if is_unhandled_tag { + return Err(io::Error::from_raw_os_error( + ERROR_CANT_ACCESS_FILE as i32, + )); + } + } else if !is_unhandled_tag { + unsafe { CloseHandle(h_file) }; + h_file = INVALID_HANDLE_VALUE; + return win32_xstat_slow_impl(path, true); + } + } + tag_info = local_tag_info; + } + + if unsafe { GetFileInformationByHandle(h_file, &mut file_info) } == 0 { + match io::Error::last_os_error().raw_os_error().unwrap_or(0) as u32 { + ERROR_INVALID_PARAMETER | ERROR_INVALID_FUNCTION | ERROR_NOT_SUPPORTED => { + return Ok(StatStruct { + st_mode: 0x6000, + ..Default::default() + }); + } + _ => return Err(io::Error::last_os_error()), + } + } + + let mut basic_info: FILE_BASIC_INFO = unsafe { core::mem::zeroed() }; + let has_basic_info = unsafe { + GetFileInformationByHandleEx( + h_file, + FileBasicInfo, + (&mut basic_info as *mut FILE_BASIC_INFO).cast(), + core::mem::size_of::() as u32, + ) + } != 0; + + let mut id_info: 
FILE_ID_INFO = unsafe { core::mem::zeroed() }; + let has_id_info = unsafe { + GetFileInformationByHandleEx( + h_file, + FileIdInfo, + (&mut id_info as *mut FILE_ID_INFO).cast(), + core::mem::size_of::() as u32, + ) + } != 0; + + let mut result = win32_attribute_data_to_stat( + &file_info, + tag_info.reparse_tag, + if has_basic_info { + Some(&basic_info) + } else { + None + }, + if has_id_info { Some(&id_info) } else { None }, + ); + result.update_st_mode_from_path(path, file_info.dwFileAttributes); + Ok(result) + } else { + let mut result = + win32_attribute_data_to_stat(&file_info, tag_info.reparse_tag, None, None); + result.update_st_mode_from_path(path, file_info.dwFileAttributes); + Ok(result) + } + })(); + + if h_file != INVALID_HANDLE_VALUE { + unsafe { CloseHandle(h_file) }; + } + result +} + +pub fn win32_xstat(path: &OsStr, traverse: bool) -> io::Result { + use windows_sys::Win32::{Foundation, Storage::FileSystem::FILE_ATTRIBUTE_REPARSE_POINT}; + + match get_file_information_by_name(path, FILE_INFO_BY_NAME_CLASS::FileStatBasicByNameInfo) { + Ok(stat_info) => { + if (stat_info.FileAttributes & FILE_ATTRIBUTE_REPARSE_POINT == 0) + || (!traverse && is_reparse_tag_name_surrogate(stat_info.ReparseTag)) + { + let mut result = stat_basic_info_to_stat(&stat_info); + if result.st_ino != 0 || result.st_ino_high != 0 { + result.update_st_mode_from_path(path, stat_info.FileAttributes); + result.st_ctime = result.st_birthtime; + result.st_ctime_nsec = result.st_birthtime_nsec; + return Ok(result); + } + } + } + Err(err) => { + if let Some(errno) = err.raw_os_error() + && matches!( + errno as u32, + Foundation::ERROR_FILE_NOT_FOUND + | Foundation::ERROR_PATH_NOT_FOUND + | Foundation::ERROR_NOT_READY + | Foundation::ERROR_BAD_NET_NAME + ) + { + return Err(err); + } + } + } + + let mut result = win32_xstat_slow_impl(path, traverse)?; + result.st_ctime = result.st_birthtime; + result.st_ctime_nsec = result.st_birthtime_nsec; + Ok(result) +} + +pub fn 
test_file_type_by_name(path: &Path, tested_type: TestType) -> bool { + match get_file_information_by_name( + path.as_os_str(), + FILE_INFO_BY_NAME_CLASS::FileStatBasicByNameInfo, + ) { + Ok(info) => { + let disk_device = matches!( + info.DeviceType, + windows_sys::Win32::Storage::FileSystem::FILE_DEVICE_DISK + | windows_sys::Win32::System::Ioctl::FILE_DEVICE_VIRTUAL_DISK + | windows_sys::Win32::Storage::FileSystem::FILE_DEVICE_CD_ROM + ); + let result = test_info( + info.FileAttributes, + info.ReparseTag, + disk_device, + tested_type, + ); + if !result + || !matches!(tested_type, TestType::RegularFile | TestType::Directory) + || (info.FileAttributes + & windows_sys::Win32::Storage::FileSystem::FILE_ATTRIBUTE_REPARSE_POINT) + == 0 + { + return result; + } + } + Err(err) => { + if let Some(code) = err.raw_os_error() + && file_info_error_is_trustworthy(code as u32) + { + return false; + } + } + } + + let mut flags = FILE_FLAG_BACKUP_SEMANTICS; + if !matches!(tested_type, TestType::RegularFile | TestType::Directory) { + flags |= FILE_FLAG_OPEN_REPARSE_POINT; + } + let wide_path = path.as_os_str().to_wide_with_nul(); + let handle = unsafe { + CreateFileW( + wide_path.as_ptr(), + FILE_READ_ATTRIBUTES, + 0, + core::ptr::null(), + OPEN_EXISTING, + flags, + core::ptr::null_mut(), + ) + }; + if handle != INVALID_HANDLE_VALUE { + let result = test_file_type_by_handle(handle, tested_type, false); + unsafe { CloseHandle(handle) }; + return result; + } + + match unsafe { GetLastError() } { + windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED + | windows_sys::Win32::Foundation::ERROR_SHARING_VIOLATION + | windows_sys::Win32::Foundation::ERROR_CANT_ACCESS_FILE + | windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER => { + let stat = win32_xstat( + path.as_os_str(), + matches!(tested_type, TestType::RegularFile | TestType::Directory), + ); + if let Ok(st) = stat { + let disk_device = (st.st_mode & libc::S_IFREG as u16) != 0; + return test_info( + st.st_file_attributes, + 
st.st_reparse_tag, + disk_device, + tested_type, + ); + } + } + _ => {} + } + + false +} + +pub fn test_file_exists_by_name(path: &Path, follow_links: bool) -> bool { + match get_file_information_by_name( + path.as_os_str(), + FILE_INFO_BY_NAME_CLASS::FileStatBasicByNameInfo, + ) { + Ok(info) => { + if (info.FileAttributes + & windows_sys::Win32::Storage::FileSystem::FILE_ATTRIBUTE_REPARSE_POINT) + == 0 + || (!follow_links && is_reparse_tag_name_surrogate(info.ReparseTag)) + { + return true; + } + } + Err(err) => { + if let Some(code) = err.raw_os_error() + && file_info_error_is_trustworthy(code as u32) + { + return false; + } + } + } + + let wide_path = path.as_os_str().to_wide_with_nul(); + let mut flags = FILE_FLAG_BACKUP_SEMANTICS; + if !follow_links { + flags |= FILE_FLAG_OPEN_REPARSE_POINT; + } + let handle = unsafe { + CreateFileW( + wide_path.as_ptr(), + FILE_READ_ATTRIBUTES, + 0, + core::ptr::null(), + OPEN_EXISTING, + flags, + core::ptr::null_mut(), + ) + }; + if handle != INVALID_HANDLE_VALUE { + if follow_links { + unsafe { CloseHandle(handle) }; + return true; + } + let is_regular_reparse_point = + test_file_type_by_handle(handle, TestType::RegularReparsePoint, false); + unsafe { CloseHandle(handle) }; + if !is_regular_reparse_point { + return true; + } + let handle = unsafe { + CreateFileW( + wide_path.as_ptr(), + FILE_READ_ATTRIBUTES, + 0, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, + core::ptr::null_mut(), + ) + }; + if handle != INVALID_HANDLE_VALUE { + unsafe { CloseHandle(handle) }; + return true; + } + } + + match unsafe { GetLastError() } { + windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED + | windows_sys::Win32::Foundation::ERROR_SHARING_VIOLATION + | windows_sys::Win32::Foundation::ERROR_CANT_ACCESS_FILE + | windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER => { + return win32_xstat(path.as_os_str(), follow_links).is_ok(); + } + _ => {} + } + + false +} + +pub fn path_exists_via_open(path: &Path, 
follow_links: bool) -> bool { + let wide_path = path.as_os_str().to_wide_with_nul(); + let mut flags = FILE_FLAG_BACKUP_SEMANTICS; + if !follow_links { + flags |= FILE_FLAG_OPEN_REPARSE_POINT; + } + let handle = unsafe { + CreateFileW( + wide_path.as_ptr(), + FILE_READ_ATTRIBUTES, + 0, + core::ptr::null(), + OPEN_EXISTING, + flags, + core::ptr::null_mut(), + ) + }; + if handle != INVALID_HANDLE_VALUE { + if follow_links { + unsafe { CloseHandle(handle) }; + return true; + } + let is_regular_reparse_point = + test_file_type_by_handle(handle, TestType::RegularReparsePoint, false); + unsafe { CloseHandle(handle) }; + if !is_regular_reparse_point { + return true; + } + let handle = unsafe { + CreateFileW( + wide_path.as_ptr(), + FILE_READ_ATTRIBUTES, + 0, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, + core::ptr::null_mut(), + ) + }; + if handle != INVALID_HANDLE_VALUE { + unsafe { CloseHandle(handle) }; + return true; + } + } + false +} + +pub fn fd_exists(fd: crate::crt_fd::Borrowed<'_>) -> bool { + let handle = match crate::crt_fd::as_handle(fd) { + Ok(handle) => handle, + Err(_) => return false, + }; + let file_type = unsafe { GetFileType(handle.as_raw_handle() as _) }; + if file_type != FILE_TYPE_UNKNOWN { + true + } else { + unsafe { GetLastError() == 0 } + } +} + +pub fn pipe() -> io::Result<(i32, i32)> { + use std::os::windows::io::{AsRawHandle, IntoRawHandle}; + use windows_sys::Win32::Security::SECURITY_ATTRIBUTES; + use windows_sys::Win32::System::Pipes::CreatePipe; + + let mut attr = SECURITY_ATTRIBUTES { + nLength: core::mem::size_of::() as u32, + lpSecurityDescriptor: core::ptr::null_mut(), + bInheritHandle: 0, + }; + + let (read_handle, write_handle) = unsafe { + let mut read = core::mem::MaybeUninit::::uninit(); + let mut write = core::mem::MaybeUninit::::uninit(); + CreatePipe( + read.as_mut_ptr() as *mut _, + write.as_mut_ptr() as *mut _, + &mut attr as *mut _, + 0, + ) + .check_win32_bool()?; + (read.assume_init() as HANDLE, 
write.assume_init() as HANDLE) + }; + // RAII wrappers: both handles are auto-closed on any early return below. + let read_handle = read_handle + .into_owned() + .expect("CreatePipe returned valid read handle"); + let write_handle = write_handle + .into_owned() + .expect("CreatePipe returned valid write handle"); + + const O_NOINHERIT: i32 = 0x80; + let read_fd = crate::msvcrt::open_osfhandle(read_handle.as_raw_handle() as isize, O_NOINHERIT)?; + // Ownership of the read handle now belongs to the CRT fd. + let _ = read_handle.into_raw_handle(); + + let write_fd = match crate::msvcrt::open_osfhandle( + write_handle.as_raw_handle() as isize, + libc::O_WRONLY | O_NOINHERIT, + ) { + Ok(fd) => { + let _ = write_handle.into_raw_handle(); + fd + } + Err(err) => { + // Close the CRT fd we already created; `write_handle` auto-closes via Drop. + let _ = unsafe { crt_fd::Owned::from_raw(read_fd) }; + return Err(err); + } + }; + + Ok((read_fd, write_fd)) +} + +pub fn mkdir(path: &widestring::WideCStr, mode: i32) -> io::Result<()> { + use windows_sys::Win32::Foundation::LocalFree; + use windows_sys::Win32::Security::Authorization::{ + ConvertStringSecurityDescriptorToSecurityDescriptorW, SDDL_REVISION_1, + }; + use windows_sys::Win32::Security::SECURITY_ATTRIBUTES; + + let ok = if mode == 0o700 { + let mut sec_attr = SECURITY_ATTRIBUTES { + nLength: core::mem::size_of::() as u32, + lpSecurityDescriptor: core::ptr::null_mut(), + bInheritHandle: 0, + }; + let sddl: Vec = "D:P(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)(A;OICI;FA;;;OW)\0" + .encode_utf16() + .collect(); + unsafe { + ConvertStringSecurityDescriptorToSecurityDescriptorW( + sddl.as_ptr(), + SDDL_REVISION_1, + &mut sec_attr.lpSecurityDescriptor, + core::ptr::null_mut(), + ) + } + .check_win32_bool()?; + let ok = unsafe { + windows_sys::Win32::Storage::FileSystem::CreateDirectoryW( + path.as_ptr(), + (&sec_attr as *const SECURITY_ATTRIBUTES).cast(), + ) + }; + unsafe { LocalFree(sec_attr.lpSecurityDescriptor) }; + ok + } else { + 
unsafe { + windows_sys::Win32::Storage::FileSystem::CreateDirectoryW( + path.as_ptr(), + core::ptr::null_mut(), + ) + } + }; + + ok.check_win32_bool() +} + +unsafe extern "C" { + fn _umask(mask: i32) -> i32; + fn _wputenv(envstring: *const u16) -> libc::c_int; +} + +pub fn umask(mask: i32) -> io::Result { + let result = unsafe { _umask(mask) }; + if result < 0 { + Err(crate::os::errno_io_error()) + } else { + Ok(result) + } +} + +/// Update the CRT environment via `_wputenv`. +/// `envstring` must point to a nul-terminated wide string of the form `KEY=value`. +pub fn wputenv(envstring: &widestring::WideCStr) -> io::Result<()> { + let result = unsafe { crate::suppress_iph!(_wputenv(envstring.as_ptr())) }; + if result != 0 { + Err(crate::os::errno_io_error()) + } else { + Ok(()) + } +} + +fn set_fd_inheritable(fd: i32, inheritable: bool) -> io::Result<()> { + let borrowed = unsafe { crt_fd::Borrowed::borrow_raw(fd) }; + let handle = crt_fd::as_handle(borrowed)?; + set_handle_inheritable(handle.as_raw_handle() as _, inheritable) +} + +pub fn dup(fd: i32) -> io::Result { + let fd2 = unsafe { crate::suppress_iph!(libc::dup(fd)) }; + if fd2 < 0 { + return Err(crate::os::errno_io_error()); + } + if let Err(err) = set_fd_inheritable(fd2, false) { + let _ = unsafe { crt_fd::Owned::from_raw(fd2) }; + return Err(err); + } + Ok(fd2) +} + +pub fn dup2(fd: i32, fd2: i32, inheritable: bool) -> io::Result { + let result = unsafe { crate::suppress_iph!(libc::dup2(fd, fd2)) }; + if result < 0 { + return Err(crate::os::errno_io_error()); + } + if !inheritable && let Err(err) = set_fd_inheritable(fd2, false) { + let _ = unsafe { crt_fd::Owned::from_raw(fd2) }; + return Err(err); + } + Ok(fd2) +} + +pub fn readlink(path: &Path) -> Result { + use windows_sys::Win32::Storage::FileSystem::{ + FILE_FLAG_BACKUP_SEMANTICS, FILE_FLAG_OPEN_REPARSE_POINT, FILE_SHARE_DELETE, + FILE_SHARE_READ, FILE_SHARE_WRITE, + }; + use windows_sys::Win32::System::IO::DeviceIoControl; + use 
windows_sys::Win32::System::Ioctl::FSCTL_GET_REPARSE_POINT; + use windows_sys::Win32::System::SystemServices::{ + IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, + }; + + let wide_path = path.as_os_str().to_wide_with_nul(); + let handle = unsafe { + CreateFileW( + wide_path.as_ptr(), + 0, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, + core::ptr::null_mut(), + ) + }; + + if handle == INVALID_HANDLE_VALUE { + return Err(ReadlinkError::Io(io::Error::last_os_error())); + } + + const BUFFER_SIZE: usize = 16384; + let mut buffer = vec![0u8; BUFFER_SIZE]; + let mut bytes_returned: u32 = 0; + let ok = unsafe { + DeviceIoControl( + handle, + FSCTL_GET_REPARSE_POINT, + core::ptr::null(), + 0, + buffer.as_mut_ptr() as *mut _, + BUFFER_SIZE as u32, + &mut bytes_returned, + core::ptr::null_mut(), + ) + }; + unsafe { CloseHandle(handle) }; + if ok == 0 { + return Err(ReadlinkError::Io(io::Error::last_os_error())); + } + + let reparse_tag = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); + let (substitute_offset, substitute_length, path_buffer_start) = + if reparse_tag == IO_REPARSE_TAG_SYMLINK { + ( + u16::from_le_bytes([buffer[8], buffer[9]]) as usize, + u16::from_le_bytes([buffer[10], buffer[11]]) as usize, + 20usize, + ) + } else if reparse_tag == IO_REPARSE_TAG_MOUNT_POINT { + ( + u16::from_le_bytes([buffer[8], buffer[9]]) as usize, + u16::from_le_bytes([buffer[10], buffer[11]]) as usize, + 16usize, + ) + } else { + return Err(ReadlinkError::NotSymbolicLink); + }; + + let path_start = path_buffer_start + substitute_offset; + let path_end = path_start + substitute_length; + if path_end > buffer.len() { + return Err(ReadlinkError::InvalidReparseData); + } + + let path_slice = &buffer[path_start..path_end]; + let mut wide_chars: Vec = path_slice + .chunks_exact(2) + .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .collect(); + + if 
wide_chars.len() > 4 + && wide_chars[0] == b'\\' as u16 + && wide_chars[1] == b'?' as u16 + && wide_chars[2] == b'?' as u16 + && wide_chars[3] == b'\\' as u16 + { + wide_chars[1] = b'\\' as u16; + } + + Ok(OsString::from_wide(&wide_chars)) +} + +pub fn kill(pid: u32, sig: u32) -> io::Result<()> { + if sig == Console::CTRL_C_EVENT || sig == Console::CTRL_BREAK_EVENT { + let ok = unsafe { Console::GenerateConsoleCtrlEvent(sig, pid) }; + if ok == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } else { + let handle = unsafe { Threading::OpenProcess(Threading::PROCESS_ALL_ACCESS, 0, pid) } + .check_nonnull()?; + let result = unsafe { Threading::TerminateProcess(handle, sig) }.check_win32_bool(); + unsafe { CloseHandle(handle) }; + result + } +} + +pub fn getfinalpathname(path: &Path) -> io::Result { + use windows_sys::Win32::Storage::FileSystem::{GetFinalPathNameByHandleW, VOLUME_NAME_DOS}; + + let wide = path.as_os_str().to_wide_with_nul(); + let handle = unsafe { + CreateFileW( + wide.as_ptr(), + 0, + 0, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, + core::ptr::null_mut(), + ) + } + .check_valid()?; + + let mut buffer = vec![0u16; MAX_PATH as usize]; + let result = loop { + let ret = unsafe { + GetFinalPathNameByHandleW( + handle, + buffer.as_mut_ptr(), + buffer.len() as u32, + VOLUME_NAME_DOS, + ) + }; + if ret == 0 { + break Err(io::Error::last_os_error()); + } + if ret as usize >= buffer.len() { + buffer.resize(ret as usize, 0); + continue; + } + break Ok(OsString::from_wide(&buffer[..ret as usize])); + }; + unsafe { CloseHandle(handle) }; + result +} + +pub fn getfullpathname(path: &Path) -> io::Result { + let wide = path.as_os_str().to_wide_with_nul(); + let mut buffer = vec![0u16; MAX_PATH as usize]; + let mut ret = unsafe { + windows_sys::Win32::Storage::FileSystem::GetFullPathNameW( + wide.as_ptr(), + buffer.len() as u32, + buffer.as_mut_ptr(), + core::ptr::null_mut(), + ) + } + .check_ne(0)?; + if ret as usize > 
buffer.len() { + buffer.resize(ret as usize, 0); + ret = unsafe { + windows_sys::Win32::Storage::FileSystem::GetFullPathNameW( + wide.as_ptr(), + buffer.len() as u32, + buffer.as_mut_ptr(), + core::ptr::null_mut(), + ) + } + .check_ne(0)?; + } + let _ = ret; + Ok(widestring::WideCString::from_vec_truncate(buffer).to_os_string()) +} + +pub fn getvolumepathname(path: &Path) -> io::Result { + let wide = path.as_os_str().to_wide_with_nul(); + let buflen = core::cmp::max(wide.len(), MAX_PATH as usize); + let mut buffer = vec![0u16; buflen]; + unsafe { + windows_sys::Win32::Storage::FileSystem::GetVolumePathNameW( + wide.as_ptr(), + buffer.as_mut_ptr(), + buflen as u32, + ) + } + .check_win32_bool()?; + Ok(widestring::WideCString::from_vec_truncate(buffer).to_os_string()) +} + +pub fn getdiskusage(path: &Path) -> io::Result<(u64, u64)> { + use windows_sys::Win32::Storage::FileSystem::GetDiskFreeSpaceExW; + + let wide = path.as_os_str().to_wide_with_nul(); + let mut free_to_me = 0u64; + let mut total = 0u64; + let mut free = 0u64; + let ok = unsafe { GetDiskFreeSpaceExW(wide.as_ptr(), &mut free_to_me, &mut total, &mut free) }; + if ok != 0 { + return Ok((total, free)); + } + + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_DIRECTORY as i32) + && let Some(parent) = path.parent() + { + let parent = widestring::WideCString::from_os_str(parent).unwrap(); + let ok = + unsafe { GetDiskFreeSpaceExW(parent.as_ptr(), &mut free_to_me, &mut total, &mut free) }; + if ok != 0 { + return Ok((total, free)); + } + } + Err(err) +} + +pub fn get_handle_inheritable(handle: intptr_t) -> io::Result { + let mut flags = 0; + let ok = + unsafe { windows_sys::Win32::Foundation::GetHandleInformation(handle as _, &mut flags) }; + if ok == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(flags & windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT != 0) + } +} + +pub fn set_handle_inheritable(handle: intptr_t, inheritable: bool) -> 
io::Result<()> { + let flags = if inheritable { + windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT + } else { + 0 + }; + let ok = unsafe { + windows_sys::Win32::Foundation::SetHandleInformation( + handle as _, + windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT, + flags, + ) + }; + if ok == 0 { Err(io::Error::last_os_error()) } else { Ok(()) } } + +pub fn getlogin() -> io::Result { + let mut buffer = [0u16; 257]; + let mut size = buffer.len() as u32; + let ok = unsafe { + windows_sys::Win32::System::WindowsProgramming::GetUserNameW(buffer.as_mut_ptr(), &mut size) + }; + if ok == 0 { + return Err(io::Error::last_os_error()); + } + Ok(OsString::from_wide(&buffer[..(size - 1) as usize]) + .to_str() + .unwrap() + .to_string()) +} + +pub fn listdrives() -> io::Result> { + let mut buffer = [0u16; 256]; + let len = unsafe { + windows_sys::Win32::Storage::FileSystem::GetLogicalDriveStringsW( + buffer.len() as u32, + buffer.as_mut_ptr(), + ) + }; + if len == 0 { + return Err(io::Error::last_os_error()); + } + if len as usize >= buffer.len() { + return Err(io::Error::from_raw_os_error( + windows_sys::Win32::Foundation::ERROR_MORE_DATA as i32, + )); + } + Ok(buffer[..(len - 1) as usize] + .split(|&c| c == 0) + .map(OsString::from_wide) + .collect()) +} + +pub fn listvolumes() -> io::Result> { + let mut result = Vec::new(); + let mut buffer = [0u16; MAX_PATH as usize + 1]; + + let find = unsafe { + windows_sys::Win32::Storage::FileSystem::FindFirstVolumeW( + buffer.as_mut_ptr(), + buffer.len() as u32, + ) + }; + if find == INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + loop { + let len = buffer.iter().position(|&c| c == 0).unwrap_or(buffer.len()); + result.push(OsString::from_wide(&buffer[..len])); + + let ok = unsafe { + windows_sys::Win32::Storage::FileSystem::FindNextVolumeW( + find, + buffer.as_mut_ptr(), + buffer.len() as u32, + ) + }; + if ok == 0 { + let err = io::Error::last_os_error(); + unsafe { 
windows_sys::Win32::Storage::FileSystem::FindVolumeClose(find) }; + if err.raw_os_error() + == Some(windows_sys::Win32::Foundation::ERROR_NO_MORE_FILES as i32) + { + break; + } + return Err(err); + } + } + + Ok(result) +} + +pub fn listmounts(volume: &Path) -> io::Result> { + let wide = volume.as_os_str().to_wide_with_nul(); + let mut buflen: u32 = MAX_PATH + 1; + let mut buffer = vec![0u16; buflen as usize]; + + loop { + let ok = unsafe { + windows_sys::Win32::Storage::FileSystem::GetVolumePathNamesForVolumeNameW( + wide.as_ptr(), + buffer.as_mut_ptr(), + buflen, + &mut buflen, + ) + }; + if ok != 0 { + break; + } + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_MORE_DATA as i32) { + buffer.resize(buflen as usize, 0); + continue; + } + return Err(err); + } + + let mut result = Vec::new(); + let mut start = 0; + for (i, &c) in buffer.iter().enumerate() { + if c == 0 { + if i > start { + result.push(OsString::from_wide(&buffer[start..i])); + } + start = i + 1; + if start < buffer.len() && buffer[start] == 0 { + break; + } + } + } + Ok(result) +} + +pub fn getppid() -> u32 { + use windows_sys::Win32::System::Threading::{GetCurrentProcess, PROCESS_BASIC_INFORMATION}; + + type NtQueryInformationProcessFn = unsafe extern "system" fn( + process_handle: isize, + process_information_class: u32, + process_information: *mut core::ffi::c_void, + process_information_length: u32, + return_length: *mut u32, + ) -> i32; + + let ntdll = unsafe { + windows_sys::Win32::System::LibraryLoader::GetModuleHandleW(windows_sys::w!("ntdll.dll")) + }; + if ntdll.is_null() { + return 0; + } + + let func = unsafe { + windows_sys::Win32::System::LibraryLoader::GetProcAddress( + ntdll, + c"NtQueryInformationProcess".as_ptr() as *const u8, + ) + }; + let Some(func) = func else { + return 0; + }; + let nt_query: NtQueryInformationProcessFn = unsafe { core::mem::transmute(func) }; + + let mut info: PROCESS_BASIC_INFORMATION = unsafe { 
core::mem::zeroed() }; + let status = unsafe { + nt_query( + GetCurrentProcess() as isize, + 0, + (&mut info as *mut PROCESS_BASIC_INFORMATION).cast(), + core::mem::size_of::() as u32, + core::ptr::null_mut(), + ) + }; + + if status >= 0 + && info.InheritedFromUniqueProcessId != 0 + && info.InheritedFromUniqueProcessId < u32::MAX as usize + { + info.InheritedFromUniqueProcessId as u32 + } else { + 0 + } +} + +pub fn path_skip_root(path: &widestring::WideCStr) -> Option { + let mut end: *const u16 = core::ptr::null(); + let hr = unsafe { windows_sys::Win32::UI::Shell::PathCchSkipRoot(path.as_ptr(), &mut end) }; + if hr >= 0 { + assert!(!end.is_null()); + Some( + unsafe { end.offset_from(path.as_ptr()) } + .try_into() + .expect("len must be non-negative"), + ) + } else { + None + } +} + +pub fn get_terminal_size_handle(h: HANDLE) -> io::Result<(usize, usize)> { + let mut csbi = core::mem::MaybeUninit::uninit(); + let ret = unsafe { Console::GetConsoleScreenBufferInfo(h, csbi.as_mut_ptr()) }; + if ret == 0 { + let err = unsafe { GetLastError() }; + if err != windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED { + return Err(io::Error::last_os_error()); + } + let conout: Vec = "CONOUT$\0".encode_utf16().collect(); + let console_handle = unsafe { + CreateFileW( + conout.as_ptr(), + windows_sys::Win32::Foundation::GENERIC_READ + | windows_sys::Win32::Foundation::GENERIC_WRITE, + windows_sys::Win32::Storage::FileSystem::FILE_SHARE_READ + | windows_sys::Win32::Storage::FileSystem::FILE_SHARE_WRITE, + core::ptr::null(), + windows_sys::Win32::Storage::FileSystem::OPEN_EXISTING, + 0, + core::ptr::null_mut(), + ) + }; + if console_handle == INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + let ret = unsafe { Console::GetConsoleScreenBufferInfo(console_handle, csbi.as_mut_ptr()) }; + unsafe { CloseHandle(console_handle) }; + if ret == 0 { + return Err(io::Error::last_os_error()); + } + } + let csbi = unsafe { csbi.assume_init() }; + let window = 
csbi.srWindow; + let columns = (window.Right - window.Left + 1) as usize; + let lines = (window.Bottom - window.Top + 1) as usize; + Ok((columns, lines)) +} + +pub fn handle_from_fd(fd: i32) -> HANDLE { + unsafe { crate::suppress_iph!(libc::get_osfhandle(fd)) as HANDLE } +} + +pub fn console_type(handle: HANDLE) -> char { + if is_invalid_handle(handle) { + return '\0'; + } + let mut mode: u32 = 0; + if unsafe { Console::GetConsoleMode(handle, &mut mode) } == 0 { + return '\0'; + } + let mut peek_count: u32 = 0; + if unsafe { Console::GetNumberOfConsoleInputEvents(handle, &mut peek_count) } != 0 { + 'r' + } else { + 'w' + } +} + +pub fn is_invalid_handle(handle: Handle) -> bool { + handle == INVALID_HANDLE_VALUE || handle.is_null() +} + +pub fn console_type_from_fd(fd: i32) -> char { + if fd < 0 { + '\0' + } else { + console_type(handle_from_fd(fd)) + } +} + +pub fn console_type_from_name(name: &str) -> char { + if name.eq_ignore_ascii_case("CONIN$") { + return 'r'; + } + if name.eq_ignore_ascii_case("CONOUT$") { + return 'w'; + } + if name.eq_ignore_ascii_case("CON") { + return 'x'; + } + + let wide: Vec = name.encode_utf16().chain(core::iter::once(0)).collect(); + let mut buf = [0u16; MAX_PATH as usize]; + let length = unsafe { + GetFullPathNameW( + wide.as_ptr(), + buf.len() as u32, + buf.as_mut_ptr(), + core::ptr::null_mut(), + ) + }; + if length == 0 || length as usize > buf.len() { + return '\0'; + } + + let full_path = &buf[..length as usize]; + let path_part = if full_path.len() >= 4 + && full_path[0] == b'\\' as u16 + && full_path[1] == b'\\' as u16 + && (full_path[2] == b'.' as u16 || full_path[2] == b'?' as u16) + && full_path[3] == b'\\' as u16 + { + &full_path[4..] 
+ } else { + full_path + }; + + let path_str = String::from_utf16_lossy(path_part); + if path_str.eq_ignore_ascii_case("CONIN$") { + 'r' + } else if path_str.eq_ignore_ascii_case("CONOUT$") { + 'w' + } else if path_str.eq_ignore_ascii_case("CON") { + 'x' + } else { + '\0' + } +} + +fn copy_from_small_buf(buf: &mut [u8; 4], dest: &mut [u8]) -> usize { + let mut n = 0; + while buf[0] != 0 && n < dest.len() { + dest[n] = buf[0]; + n += 1; + for i in 1..buf.len() { + buf[i - 1] = buf[i]; + } + buf[buf.len() - 1] = 0; + } + n +} + +fn find_last_utf8_boundary(buf: &[u8], len: usize) -> usize { + let len = len.min(buf.len()); + for count in 1..=4.min(len) { + let c = buf[len - count]; + if c < 0x80 { + return len; + } + if c >= 0xc0 { + let expected = if c < 0xe0 { + 2 + } else if c < 0xf0 { + 3 + } else { + 4 + }; + if count < expected { + return len - count; + } + return len; + } + } + len +} + +fn wchar_to_utf8_count(data: &[u8], mut len: usize, mut n: u32) -> usize { + let mut start: usize = 0; + loop { + let mut mid = 0; + for i in (len / 2)..=len { + mid = find_last_utf8_boundary(data, i); + if mid != 0 { + break; + } + } + if mid == len { + return start + len; + } + if mid == 0 { + mid = if len > 1 { len - 1 } else { 1 }; + } + let wlen = unsafe { + MultiByteToWideChar( + CP_UTF8, + 0, + data[start..].as_ptr(), + mid as i32, + core::ptr::null_mut(), + 0, + ) + } as u32; + if wlen <= n { + start += mid; + len -= mid; + n -= wlen; + } else { + len = mid; + } + } +} + +pub fn read_console_into( + handle: HANDLE, + dest: &mut [u8], + smallbuf: &mut [u8; 4], +) -> Result { + if dest.is_empty() { + return Ok(0); + } + + let mut wlen = (dest.len() / 4) as u32; + if wlen == 0 { + wlen = 1; + } + + let mut read_len = copy_from_small_buf(smallbuf, dest); + if read_len > 0 { + wlen = wlen.saturating_sub(1); + } + if read_len >= dest.len() || wlen == 0 { + return Ok(read_len); + } + + let mut wbuf = vec![0u16; wlen as usize]; + let mut nread: u32 = 0; + if unsafe { + 
Console::ReadConsoleW( + handle, + wbuf.as_mut_ptr().cast(), + wlen, + &mut nread, + core::ptr::null(), + ) + } == 0 + { + return Err(ReadConsoleError::Io(io::Error::last_os_error())); + } + if nread == 0 || wbuf[0] == 0x1A { + return Ok(read_len); + } + + let remaining = dest.len() - read_len; + let u8n = if remaining < 4 { + let converted = unsafe { + WideCharToMultiByte( + CP_UTF8, + 0, + wbuf.as_ptr(), + nread as i32, + smallbuf.as_mut_ptr().cast(), + smallbuf.len() as i32, + core::ptr::null(), + core::ptr::null_mut(), + ) + }; + if converted > 0 { + copy_from_small_buf(smallbuf, &mut dest[read_len..]) as i32 + } else { + 0 + } + } else { + unsafe { + WideCharToMultiByte( + CP_UTF8, + 0, + wbuf.as_ptr(), + nread as i32, + dest[read_len..].as_mut_ptr().cast(), + remaining as i32, + core::ptr::null(), + core::ptr::null_mut(), + ) + } + }; + + if u8n > 0 { + read_len += u8n as usize; + return Ok(read_len); + } + + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER as i32) + { + let needed = unsafe { + WideCharToMultiByte( + CP_UTF8, + 0, + wbuf.as_ptr(), + nread as i32, + core::ptr::null_mut(), + 0, + core::ptr::null(), + core::ptr::null_mut(), + ) + }; + if needed > 0 { + return Err(ReadConsoleError::BufferTooSmall { + available: remaining, + required: needed as usize, + }); + } + } + Err(ReadConsoleError::Io(err)) +} + +pub fn read_console_all(handle: HANDLE, smallbuf: &mut [u8; 4]) -> io::Result> { + let mut result = Vec::new(); + let mut tmp = [0u8; 4]; + let n = copy_from_small_buf(smallbuf, &mut tmp); + result.extend_from_slice(&tmp[..n]); + + let mut wbuf = vec![0u16; 8192]; + loop { + let mut nread: u32 = 0; + if unsafe { + Console::ReadConsoleW( + handle, + wbuf.as_mut_ptr().cast(), + wbuf.len() as u32, + &mut nread, + core::ptr::null(), + ) + } == 0 + { + return Err(io::Error::last_os_error()); + } + if nread == 0 || wbuf[0] == 0x1A { + break; + } + + let needed = unsafe { + 
WideCharToMultiByte( + CP_UTF8, + 0, + wbuf.as_ptr(), + nread as i32, + core::ptr::null_mut(), + 0, + core::ptr::null(), + core::ptr::null_mut(), + ) + }; + if needed == 0 { + return Err(io::Error::last_os_error()); + } + let offset = result.len(); + result.resize(offset + needed as usize, 0); + if unsafe { + WideCharToMultiByte( + CP_UTF8, + 0, + wbuf.as_ptr(), + nread as i32, + result[offset..].as_mut_ptr().cast(), + needed, + core::ptr::null(), + core::ptr::null_mut(), + ) + } == 0 + { + return Err(io::Error::last_os_error()); + } + if nread < wbuf.len() as u32 { + break; + } + } + + Ok(result) +} + +pub fn write_console_utf8(handle: HANDLE, data: &[u8], max_bytes: usize) -> io::Result { + if data.is_empty() { + return Ok(0); + } + + let mut len = data.len().min(max_bytes); + let max_wlen: u32 = 32766 / 2; + len = len.min(max_wlen as usize * 3); + + let wlen = loop { + len = find_last_utf8_boundary(data, len); + let wlen = unsafe { + MultiByteToWideChar( + CP_UTF8, + 0, + data.as_ptr(), + len as i32, + core::ptr::null_mut(), + 0, + ) + }; + if wlen as u32 <= max_wlen { + break wlen; + } + len /= 2; + }; + if wlen == 0 { + return Ok(0); + } + + let mut wbuf = vec![0u16; wlen as usize]; + let wlen = unsafe { + MultiByteToWideChar( + CP_UTF8, + 0, + data.as_ptr(), + len as i32, + wbuf.as_mut_ptr(), + wlen, + ) + }; + if wlen == 0 { + return Err(io::Error::last_os_error()); + } + + let mut written: u32 = 0; + if unsafe { + Console::WriteConsoleW( + handle, + wbuf.as_ptr().cast(), + wlen as u32, + &mut written, + core::ptr::null(), + ) + } == 0 + { + return Err(io::Error::last_os_error()); + } + + if written < wlen as u32 { + len = wchar_to_utf8_count(data, len, written); + } + Ok(len) +} + +pub fn open_console_path_fd(path: &widestring::WideCStr, writable: bool) -> io::Result { + use windows_sys::Win32::{ + Foundation::{GENERIC_READ, GENERIC_WRITE}, + Storage::FileSystem::{FILE_SHARE_READ, FILE_SHARE_WRITE}, + }; + + let access = if writable { + GENERIC_WRITE + } 
else { + GENERIC_READ + }; + + let mut handle = unsafe { + CreateFileW( + path.as_ptr(), + GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE, + core::ptr::null(), + OPEN_EXISTING, + 0, + core::ptr::null_mut(), + ) + }; + if handle == INVALID_HANDLE_VALUE { + handle = unsafe { + CreateFileW( + path.as_ptr(), + access, + FILE_SHARE_READ | FILE_SHARE_WRITE, + core::ptr::null(), + OPEN_EXISTING, + 0, + core::ptr::null_mut(), + ) + }; + } + if handle == INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + let osf_flags = if writable { + libc::O_WRONLY | libc::O_BINARY | 0x80 + } else { + libc::O_RDONLY | libc::O_BINARY | 0x80 + }; + match crate::msvcrt::open_osfhandle(handle as isize, osf_flags) { + Ok(fd) => Ok(fd), + Err(err) => { + unsafe { CloseHandle(handle) }; + Err(err) + } + } +} + +#[cfg(target_env = "msvc")] +pub fn cwait(pid: intptr_t, opt: i32) -> io::Result<(intptr_t, i32)> { + let mut status = 0; + let pid = unsafe { crate::suppress_iph!(_cwait(&mut status, pid, opt)) }; + if pid == -1 { + Err(crate::os::errno_io_error()) + } else { + Ok((pid, status)) + } +} + +#[cfg(target_env = "msvc")] +fn null_terminated_ptrs(strings: &[&widestring::WideCStr]) -> Vec<*const u16> { + strings + .iter() + .map(|s| s.as_ptr()) + .chain(core::iter::once(core::ptr::null())) + .collect() +} + +#[cfg(target_env = "msvc")] +pub fn spawnv( + mode: i32, + path: &widestring::WideCStr, + argv: &[&widestring::WideCStr], +) -> io::Result { + let argv_ptrs = null_terminated_ptrs(argv); + let result = unsafe { crate::suppress_iph!(_wspawnv(mode, path.as_ptr(), argv_ptrs.as_ptr())) }; + if result == -1 { + Err(crate::os::errno_io_error()) + } else { + Ok(result) + } +} + +#[cfg(target_env = "msvc")] +pub fn spawnve( + mode: i32, + path: &widestring::WideCStr, + argv: &[&widestring::WideCStr], + envp: &[&widestring::WideCStr], +) -> io::Result { + let argv_ptrs = null_terminated_ptrs(argv); + let envp_ptrs = null_terminated_ptrs(envp); + let result = 
unsafe { + crate::suppress_iph!(_wspawnve( + mode, + path.as_ptr(), + argv_ptrs.as_ptr(), + envp_ptrs.as_ptr() + )) + }; + if result == -1 { + Err(crate::os::errno_io_error()) + } else { + Ok(result) + } +} + +#[cfg(target_env = "msvc")] +pub fn execv(path: &widestring::WideCStr, argv: &[&widestring::WideCStr]) -> io::Result<()> { + let argv_ptrs = null_terminated_ptrs(argv); + let result = unsafe { crate::suppress_iph!(_wexecv(path.as_ptr(), argv_ptrs.as_ptr())) }; + if result == -1 { + Err(crate::os::errno_io_error()) + } else { + Ok(()) + } +} + +#[cfg(target_env = "msvc")] +pub fn execve( + path: &widestring::WideCStr, + argv: &[&widestring::WideCStr], + envp: &[&widestring::WideCStr], +) -> io::Result<()> { + let argv_ptrs = null_terminated_ptrs(argv); + let envp_ptrs = null_terminated_ptrs(envp); + let result = unsafe { + crate::suppress_iph!(_wexecve( + path.as_ptr(), + argv_ptrs.as_ptr(), + envp_ptrs.as_ptr() + )) + }; + if result == -1 { + Err(crate::os::errno_io_error()) + } else { + Ok(()) + } +} diff --git a/crates/host_env/src/os.rs b/crates/host_env/src/os.rs index c6dec6bbfeb..8afa07aea45 100644 --- a/crates/host_env/src/os.rs +++ b/crates/host_env/src/os.rs @@ -1,7 +1,15 @@ // spell-checker:disable // TODO: we can move more os-specific bindings/interfaces from stdlib::{os, posix, nt} to here +#[cfg(any(unix, windows, target_os = "wasi"))] +use crate::crt_fd; +#[cfg(windows)] +use crate::fs; +#[cfg(any(unix, windows))] +use core::ffi::CStr; use core::str::Utf8Error; +#[cfg(windows)] +use core::time::Duration; use std::{ env, ffi::{OsStr, OsString}, @@ -9,6 +17,17 @@ use std::{ path::PathBuf, process::ExitCode, }; +#[cfg(windows)] +use { + std::{os::windows::io::AsRawHandle, path::Path}, + windows_sys::Win32::{ + Foundation::FILETIME, + Storage::FileSystem::{ + FILE_FLAG_BACKUP_SEMANTICS, INVALID_SET_FILE_POINTER, SetFilePointer, SetFileTime, + }, + System::SystemInformation::{GetSystemInfo, SYSTEM_INFO}, + }, +}; /// Convert exit code to 
std::process::ExitCode /// @@ -71,7 +90,32 @@ pub unsafe fn remove_var(key: impl AsRef) { } pub fn set_current_dir(path: impl AsRef) -> io::Result<()> { - env::set_current_dir(path) + env::set_current_dir(&path)?; + + #[cfg(windows)] + { + use std::os::windows::ffi::OsStrExt; + use windows_sys::Win32::System::Environment::SetEnvironmentVariableW; + + if let Ok(cwd) = env::current_dir() { + let cwd_str = cwd.as_os_str(); + let mut cwd_wide: Vec = cwd_str.encode_wide().collect(); + + let is_unc_like_path = cwd_wide.len() >= 2 + && ((cwd_wide[0] == b'\\' as u16 && cwd_wide[1] == b'\\' as u16) + || (cwd_wide[0] == b'/' as u16 && cwd_wide[1] == b'/' as u16)); + + if !is_unc_like_path { + let env_name: [u16; 4] = [b'=' as u16, cwd_wide[0], b':' as u16, 0]; + cwd_wide.push(0); + unsafe { + SetEnvironmentVariableW(env_name.as_ptr(), cwd_wide.as_ptr()); + } + } + } + } + + Ok(()) } #[must_use] @@ -79,10 +123,236 @@ pub fn process_id() -> u32 { std::process::id() } +#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] +pub fn cpu_count() -> usize { + num_cpus::get() +} + +#[cfg(not(any(not(target_arch = "wasm32"), target_os = "wasi")))] +pub fn cpu_count() -> usize { + 1 +} + +#[cfg(unix)] +pub fn page_size() -> usize { + rustix::param::page_size() +} + +#[cfg(target_arch = "wasm32")] +pub const fn page_size() -> usize { + // WebAssembly's page size is a constant defined by the spec. + 1024 * 64 +} + +#[cfg(windows)] +pub fn page_size() -> usize { + let mut info = SYSTEM_INFO::default(); + unsafe { + GetSystemInfo(&mut info); + } + info.dwPageSize as _ +} + +#[cfg(unix)] +pub fn alloc_granularity() -> usize { + // On Unix-likes, the page size is the smallest allocation unit rather than a separate concept + // of allocation granularity. + page_size() +} + +#[cfg(target_arch = "wasm32")] +pub const fn alloc_granularity() -> usize { + // Like Unix, WebAssembly doesn't separate page size and alloc granularity. 
+ page_size() +} + +#[cfg(windows)] +pub fn alloc_granularity() -> usize { + let mut info = SYSTEM_INFO::default(); + unsafe { + GetSystemInfo(&mut info); + } + info.dwAllocationGranularity as _ +} + +pub fn device_encoding(_fd: i32) -> Option { + #[cfg(any( + target_os = "android", + target_os = "redox", + all(target_arch = "wasm32", not(target_os = "wasi")) + ))] + { + Some("UTF-8".to_owned()) + } + + #[cfg(windows)] + { + use windows_sys::Win32::System::Console; + let cp = match _fd { + 0 => unsafe { Console::GetConsoleCP() }, + 1 | 2 => unsafe { Console::GetConsoleOutputCP() }, + _ => 0, + }; + + Some(format!("cp{cp}")) + } + + #[cfg(not(any( + target_os = "android", + target_os = "redox", + windows, + all(target_arch = "wasm32", not(target_os = "wasi")) + )))] + { + let encoding = unsafe { + let encoding = libc::nl_langinfo(libc::CODESET); + if encoding.is_null() || encoding.read() == b'\0' as libc::c_char { + "UTF-8".to_owned() + } else { + core::ffi::CStr::from_ptr(encoding) + .to_string_lossy() + .into_owned() + } + }; + + Some(encoding) + } +} + pub fn exit(code: i32) -> ! { std::process::exit(code) } +/// Wrapper around the C `abort()` call: terminates the process abnormally. +pub fn abort() -> ! { + unsafe extern "C" { + fn abort() -> !; + } + unsafe { abort() } +} + +/// Read `size` cryptographically random bytes from the OS. 
+pub fn urandom(size: usize) -> io::Result> { + let mut buf = vec![0u8; size]; + getrandom::fill(&mut buf).map_err(io::Error::from)?; + Ok(buf) +} + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub fn isatty(fd: i32) -> bool { + unsafe { suppress_iph!(libc::isatty(fd)) != 0 } +} + +#[cfg(not(any(unix, windows, target_os = "wasi")))] +pub fn isatty(_fd: i32) -> bool { + false +} + +#[cfg(any(unix, windows))] +pub fn system(command: &CStr) -> libc::c_int { + unsafe { libc::system(command.as_ptr()) } +} + +#[cfg(target_os = "linux")] +pub fn copy_file_range( + src: crt_fd::Borrowed<'_>, + offset_src: Option<&mut crt_fd::Offset>, + dst: crt_fd::Borrowed<'_>, + offset_dst: Option<&mut crt_fd::Offset>, + count: usize, +) -> io::Result { + #[allow(clippy::unnecessary_option_map_or_else)] + let p_offset_src = offset_src.map_or_else(core::ptr::null_mut, |x| x as *mut _); + #[allow(clippy::unnecessary_option_map_or_else)] + let p_offset_dst = offset_dst.map_or_else(core::ptr::null_mut, |x| x as *mut _); + + // Why not use `libc::copy_file_range`: On musl, the libc wrapper may be missing. 
+ let ret = unsafe { + libc::syscall( + libc::SYS_copy_file_range, + src.as_raw(), + p_offset_src, + dst.as_raw(), + p_offset_dst, + count, + 0u32, + ) + }; + + usize::try_from(ret).map_err(|_| io::Error::last_os_error()) +} + +pub fn rename( + from: impl AsRef, + to: impl AsRef, +) -> io::Result<()> { + std::fs::rename(from, to) +} + +#[cfg(windows)] +pub fn seek_fd( + fd: crt_fd::Borrowed<'_>, + position: crt_fd::Offset, + how: i32, +) -> io::Result { + let handle = crt_fd::as_handle(fd)?; + let mut distance_to_move: [i32; 2] = unsafe { core::mem::transmute(position) }; + let ret = unsafe { + SetFilePointer( + handle.as_raw_handle(), + distance_to_move[0], + &mut distance_to_move[1], + how as _, + ) + }; + if ret == INVALID_SET_FILE_POINTER { + Err(io::Error::last_os_error()) + } else { + distance_to_move[0] = ret as _; + Ok(unsafe { core::mem::transmute::<[i32; 2], i64>(distance_to_move) }) + } +} + +#[cfg(any(unix, target_os = "wasi"))] +pub fn seek_fd( + fd: crt_fd::Borrowed<'_>, + position: crt_fd::Offset, + how: i32, +) -> io::Result { + unsafe { suppress_iph!(libc::lseek(fd.as_raw(), position, how)) }.check_libc_neg() +} + +#[cfg(windows)] +fn filetime_from_duration(duration: Duration) -> FILETIME { + let intervals = ((duration.as_secs() as i64 + 11644473600) * 10_000_000) + + (duration.subsec_nanos() as i64 / 100); + FILETIME { + dwLowDateTime: intervals as u32, + dwHighDateTime: (intervals >> 32) as u32, + } +} + +#[cfg(windows)] +pub fn set_file_times( + path: impl AsRef, + access: Duration, + modified: Duration, +) -> io::Result<()> { + use crate::windows::CheckWin32Bool; + let access = filetime_from_duration(access); + let modified = filetime_from_duration(modified); + let file = fs::open_write_with_custom_flags(path, FILE_FLAG_BACKUP_SEMANTICS)?; + unsafe { + SetFileTime( + file.as_raw_handle() as _, + core::ptr::null(), + &access, + &modified, + ) + } + .check_win32_bool() +} + pub trait ErrorExt { fn posix_errno(&self) -> i32; } @@ -115,6 +385,52 @@ 
pub fn errno_io_error() -> io::Error { std::io::Error::last_os_error() } +/// Convert a libc-style return value into an `io::Result`. +/// +/// Negative values are treated as errors; errno is read via [`errno_io_error`]. +/// Modeled after PyPy's `rposix.handle_posix_error`. +pub trait CheckLibcResult: Sized { + /// Returns `Ok(self)` if non-negative, otherwise `Err` with the current errno. + fn check_libc_neg(self) -> io::Result; +} + +macro_rules! impl_check_libc_result { + ($($ty:ty),* $(,)?) => { + $( + impl CheckLibcResult for $ty { + #[inline] + fn check_libc_neg(self) -> io::Result { + if self < 0 { Err(errno_io_error()) } else { Ok(self) } + } + } + )* + }; +} + +impl_check_libc_result!(i16, i32, i64, isize); + +/// libc convention where `0` means success and any non-zero value indicates failure +/// (with errno set). Used by APIs like `sigemptyset`, `sigaction`, `pthread_*`, etc. +pub trait CheckLibcZero { + /// Returns `Ok(())` if `self == 0`, otherwise the current errno as an `io::Error`. + fn check_libc_zero(self) -> io::Result<()>; +} + +macro_rules! impl_check_libc_zero { + ($($ty:ty),* $(,)?) => { + $( + impl CheckLibcZero for $ty { + #[inline] + fn check_libc_zero(self) -> io::Result<()> { + if self == 0 { Ok(()) } else { Err(errno_io_error()) } + } + } + )* + }; +} + +impl_check_libc_zero!(i32, i64, isize); + #[cfg(windows)] pub fn get_errno() -> i32 { unsafe extern "C" { @@ -131,6 +447,10 @@ pub fn get_errno() -> i32 { std::io::Error::last_os_error().posix_errno() } +pub fn clear_errno() { + set_errno(0); +} + /// Set errno to the specified value. 
#[cfg(windows)] pub fn set_errno(value: i32) { @@ -145,6 +465,16 @@ pub fn set_errno(value: i32) { nix::errno::Errno::from_raw(value).set(); } +#[cfg(target_os = "wasi")] +pub fn set_errno(value: i32) { + unsafe { + *libc::__errno_location() = value; + } +} + +#[cfg(not(any(unix, windows, target_os = "wasi")))] +pub fn set_errno(_value: i32) {} + #[cfg(unix)] pub fn bytes_as_os_str(b: &[u8]) -> Result<&std::ffi::OsStr, Utf8Error> { use std::os::unix::ffi::OsStrExt; diff --git a/crates/host_env/src/overlapped.rs b/crates/host_env/src/overlapped.rs new file mode 100644 index 00000000000..d84897d620d --- /dev/null +++ b/crates/host_env/src/overlapped.rs @@ -0,0 +1,1219 @@ +#![allow( + clippy::not_unsafe_ptr_arg_deref, + reason = "This module exposes raw overlapped I/O wrappers over Win32 and Winsock APIs." +)] +#![allow( + clippy::too_many_arguments, + reason = "These helpers preserve the underlying Win32 and Winsock call shapes." +)] + +use alloc::sync::Arc; +use core::sync::atomic::{AtomicBool, Ordering}; +use std::{ + collections::HashMap, + io, + sync::{Mutex, OnceLock}, +}; + +use crate::windows::{CheckWin32Bool, CheckWin32Handle}; +use windows_sys::Win32::{ + Foundation::{CloseHandle, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, HANDLE}, + Networking::WinSock::{AF_INET, AF_INET6, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6}, + System::{ + Diagnostics::Debug::{ + FORMAT_MESSAGE_ALLOCATE_BUFFER, FORMAT_MESSAGE_FROM_SYSTEM, + FORMAT_MESSAGE_IGNORE_INSERTS, FormatMessageW, + }, + IO::{CancelIoEx, GetOverlappedResult, OVERLAPPED}, + Pipes::ConnectNamedPipe, + Threading::{CreateEventW, SetEvent}, + }, +}; + +pub type Handle = HANDLE; +pub type OverlappedIo = OVERLAPPED; +pub type SocketAddrRaw = SOCKADDR; +pub type SocketAddrV4 = SOCKADDR_IN; +pub type SocketAddrV6 = SOCKADDR_IN6; +pub const AF_INET_FAMILY: i32 = AF_INET as i32; +pub const AF_INET6_FAMILY: i32 = AF_INET6 as i32; +pub const INVALID_HANDLE_VALUE_ISIZE: isize = -1; +pub const SO_UPDATE_ACCEPT_CONTEXT_VALUE: 
i32 = + windows_sys::Win32::Networking::WinSock::SO_UPDATE_ACCEPT_CONTEXT; +pub const SO_UPDATE_CONNECT_CONTEXT_VALUE: i32 = + windows_sys::Win32::Networking::WinSock::SO_UPDATE_CONNECT_CONTEXT; +pub const TF_REUSE_SOCKET_FLAG: u32 = windows_sys::Win32::Networking::WinSock::TF_REUSE_SOCKET; + +pub struct TransferResult { + pub transferred: u32, + pub error: u32, +} + +pub struct OverlappedResult { + pub transferred: u32, + pub error: u32, +} + +pub struct Operation { + overlapped: Box, + handle: HANDLE, + pending: bool, + completed: bool, + read_buffer: Option>, + write_buffer: Option>, +} + +impl core::fmt::Debug for Operation { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Operation") + .field("handle", &self.handle) + .field("pending", &self.pending) + .field("completed", &self.completed) + .finish() + } +} + +unsafe impl Sync for Operation {} +unsafe impl Send for Operation {} + +impl Operation { + pub fn new(handle: HANDLE) -> io::Result { + let event = + unsafe { CreateEventW(core::ptr::null(), 1, 0, core::ptr::null()) }.check_nonnull()?; + + let mut overlapped: OVERLAPPED = unsafe { core::mem::zeroed() }; + overlapped.hEvent = event; + Ok(Self { + overlapped: Box::new(overlapped), + handle, + pending: false, + completed: false, + read_buffer: None, + write_buffer: None, + }) + } + + pub fn event(&self) -> HANDLE { + self.overlapped.hEvent + } + + pub fn is_completed(&self) -> bool { + self.completed + } + + pub fn read_buffer(&self) -> Option<&[u8]> { + self.read_buffer.as_deref() + } + + pub fn get_result(&mut self, wait: bool) -> io::Result { + use windows_sys::Win32::Foundation::{ + ERROR_IO_INCOMPLETE, ERROR_OPERATION_ABORTED, ERROR_SUCCESS, GetLastError, + }; + + let mut transferred = 0; + let ret = unsafe { + GetOverlappedResult( + self.handle, + &*self.overlapped, + &mut transferred, + i32::from(wait), + ) + }; + + let err = if ret == 0 { + unsafe { GetLastError() } + } else { + ERROR_SUCCESS + }; + + match 
err { + ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_OPERATION_ABORTED => { + self.completed = true; + self.pending = false; + } + ERROR_IO_INCOMPLETE => {} + _ => { + self.pending = false; + return Err(io::Error::from_raw_os_error(err as i32)); + } + } + + if self.completed + && let Some(read_buffer) = &mut self.read_buffer + && transferred != read_buffer.len() as u32 + { + read_buffer.truncate(transferred as usize); + } + + Ok(TransferResult { + transferred, + error: err, + }) + } + + pub fn cancel(&mut self) -> io::Result<()> { + let ret = if self.pending { + unsafe { CancelIoEx(self.handle, &*self.overlapped) } + } else { + 1 + }; + if ret == 0 { + let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; + if err != windows_sys::Win32::Foundation::ERROR_NOT_FOUND { + return Err(io::Error::from_raw_os_error(err as i32)); + } + } + self.pending = false; + Ok(()) + } + + pub fn connect_named_pipe(&mut self) -> io::Result<()> { + use windows_sys::Win32::Foundation::ERROR_PIPE_CONNECTED; + + self.completed = false; + let err = start_connect_named_pipe(self.handle, &mut *self.overlapped); + match err { + ERROR_IO_PENDING => { + self.pending = true; + } + ERROR_PIPE_CONNECTED => { + unsafe { SetEvent(self.overlapped.hEvent) }.check_win32_bool()?; + } + _ => return Err(io::Error::from_raw_os_error(err as i32)), + } + Ok(()) + } + + pub fn write(&mut self, buffer: &[u8]) -> io::Result { + if self.pending { + return Err(io::Error::new( + io::ErrorKind::WouldBlock, + "overlapped operation is pending", + )); + } + let len = core::cmp::min(buffer.len(), u32::MAX as usize) as u32; + self.write_buffer = Some(buffer[..len as usize].to_vec()); + let write_buf = self + .write_buffer + .as_ref() + .expect("write buffer initialized"); + self.completed = false; + let err = start_write_file(self.handle, write_buf.as_ptr(), len, &mut *self.overlapped); + + if err != ERROR_SUCCESS && err != ERROR_IO_PENDING { + return Err(io::Error::from_raw_os_error(err as i32)); + } + if err 
== ERROR_IO_PENDING { + self.pending = true; + } + + Ok(err) + } + + pub fn read(&mut self, size: u32) -> io::Result { + if self.pending { + return Err(io::Error::new( + io::ErrorKind::WouldBlock, + "overlapped operation is pending", + )); + } + self.read_buffer = Some(vec![0u8; size as usize]); + let read_buf = self.read_buffer.as_mut().expect("read buffer initialized"); + self.completed = false; + let err = start_read_file( + self.handle, + read_buf.as_mut_ptr(), + size, + &mut *self.overlapped, + ); + + if err != ERROR_SUCCESS && err != ERROR_IO_PENDING && err != ERROR_MORE_DATA { + return Err(io::Error::from_raw_os_error(err as i32)); + } + if err == ERROR_IO_PENDING { + self.pending = true; + } + + Ok(err) + } +} + +impl Drop for Operation { + fn drop(&mut self) { + if self.pending { + let _ = unsafe { CancelIoEx(self.handle, &*self.overlapped) }; + let mut transferred = 0; + let _ = + unsafe { GetOverlappedResult(self.handle, &*self.overlapped, &mut transferred, 1) }; + self.pending = false; + } + if !self.overlapped.hEvent.is_null() { + unsafe { CloseHandle(self.overlapped.hEvent) }; + } + } +} + +pub struct QueuedCompletionStatus { + pub error: u32, + pub bytes_transferred: u32, + pub completion_key: usize, + pub overlapped: usize, +} + +pub struct WaitCallbackData { + completion_port: HANDLE, + overlapped: *mut OVERLAPPED, + fired: AtomicBool, +} + +struct WaitCallbackEntry { + data: Arc, + raw_ptr: usize, +} + +pub enum WaitResult { + Timeout, + Queued(QueuedCompletionStatus), +} + +pub enum SocketAddress { + V4 { + host: String, + port: u16, + }, + V6 { + host: String, + port: u16, + flowinfo: u32, + scope_id: u32, + }, +} + +static ACCEPT_EX: OnceLock = OnceLock::new(); +static CONNECT_EX: OnceLock = OnceLock::new(); +static DISCONNECT_EX: OnceLock = OnceLock::new(); +static TRANSMIT_FILE: OnceLock = OnceLock::new(); +static WAIT_CALLBACK_REGISTRY: OnceLock>> = OnceLock::new(); + +fn wait_callback_registry() -> &'static Mutex> { + 
WAIT_CALLBACK_REGISTRY.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn winsock_extension_or_error(lock: &OnceLock) -> Result { + use windows_sys::Win32::Networking::WinSock::WSAEOPNOTSUPP; + + if let Some(func) = lock.get() { + return Ok(*func); + } + if initialize_winsock_extensions().is_ok() + && let Some(func) = lock.get() + { + return Ok(*func); + } + Err(WSAEOPNOTSUPP as u32) +} + +pub fn initialize_winsock_extensions() -> io::Result<()> { + use windows_sys::Win32::Networking::WinSock::{ + INVALID_SOCKET, IPPROTO_TCP, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCK_STREAM, SOCKET_ERROR, + WSAGetLastError, WSAIoctl, closesocket, socket, + }; + + const WSAID_ACCEPTEX: windows_sys::core::GUID = windows_sys::core::GUID { + data1: 0xb5367df1, + data2: 0xcbac, + data3: 0x11cf, + data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], + }; + const WSAID_CONNECTEX: windows_sys::core::GUID = windows_sys::core::GUID { + data1: 0x25a207b9, + data2: 0xddf3, + data3: 0x4660, + data4: [0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e], + }; + const WSAID_DISCONNECTEX: windows_sys::core::GUID = windows_sys::core::GUID { + data1: 0x7fda2e11, + data2: 0x8630, + data3: 0x436f, + data4: [0xa0, 0x31, 0xf5, 0x36, 0xa6, 0xee, 0xc1, 0x57], + }; + const WSAID_TRANSMITFILE: windows_sys::core::GUID = windows_sys::core::GUID { + data1: 0xb5367df0, + data2: 0xcbac, + data3: 0x11cf, + data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], + }; + + if ACCEPT_EX.get().is_some() + && CONNECT_EX.get().is_some() + && DISCONNECT_EX.get().is_some() + && TRANSMIT_FILE.get().is_some() + { + return Ok(()); + } + + let s = unsafe { socket(AF_INET as i32, SOCK_STREAM, IPPROTO_TCP) }; + if s == INVALID_SOCKET { + return Err(io::Error::from_raw_os_error( + unsafe { WSAGetLastError() } as i32 + )); + } + + let mut dw_bytes = 0; + + macro_rules! 
get_extension { + ($guid:expr, $lock:expr) => {{ + let mut func_ptr: usize = 0; + let ret = unsafe { + WSAIoctl( + s, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &$guid as *const _ as *const _, + core::mem::size_of_val(&$guid) as u32, + &mut func_ptr as *mut _ as *mut _, + core::mem::size_of::() as u32, + &mut dw_bytes, + core::ptr::null_mut(), + None, + ) + }; + if ret == SOCKET_ERROR { + let err = unsafe { WSAGetLastError() }; + unsafe { closesocket(s) }; + return Err(io::Error::from_raw_os_error(err as i32)); + } + let _ = $lock.set(func_ptr); + }}; + } + + get_extension!(WSAID_ACCEPTEX, ACCEPT_EX); + get_extension!(WSAID_CONNECTEX, CONNECT_EX); + get_extension!(WSAID_DISCONNECTEX, DISCONNECT_EX); + get_extension!(WSAID_TRANSMITFILE, TRANSMIT_FILE); + + unsafe { closesocket(s) }; + Ok(()) +} + +pub fn mark_as_completed(ov: &mut OVERLAPPED) { + ov.Internal = 0; + if !ov.hEvent.is_null() { + unsafe { + let _ = SetEvent(ov.hEvent); + } + } +} + +pub fn has_overlapped_io_completed(overlapped: &OVERLAPPED) -> bool { + overlapped.Internal != (windows_sys::Win32::Foundation::STATUS_PENDING as usize) +} + +pub fn cancel_overlapped(handle: HANDLE, overlapped: *const OVERLAPPED) -> io::Result<()> { + let ret = unsafe { CancelIoEx(handle, overlapped) }; + if ret == 0 { + let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; + if err != windows_sys::Win32::Foundation::ERROR_NOT_FOUND { + return Err(io::Error::from_raw_os_error(err as i32)); + } + } + Ok(()) +} + +pub fn get_overlapped_result( + handle: HANDLE, + overlapped: *const OVERLAPPED, + wait: bool, +) -> OverlappedResult { + let mut transferred = 0; + let ret = unsafe { GetOverlappedResult(handle, overlapped, &mut transferred, i32::from(wait)) }; + let error = if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + }; + OverlappedResult { transferred, error } +} + +pub fn cancel_overlapped_for_drop( + handle: HANDLE, + overlapped: *const OVERLAPPED, +) -> 
OverlappedResult { + let cancelled = unsafe { CancelIoEx(handle, overlapped) } != 0; + get_overlapped_result(handle, overlapped, cancelled) +} + +pub fn start_read_file( + handle: HANDLE, + buffer: *mut u8, + len: u32, + overlapped: *mut OVERLAPPED, +) -> u32 { + let mut transferred = 0; + let ret = unsafe { + windows_sys::Win32::Storage::FileSystem::ReadFile( + handle, + buffer.cast(), + len, + &mut transferred, + overlapped, + ) + }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + } +} + +pub fn start_write_file( + handle: HANDLE, + buffer: *const u8, + len: u32, + overlapped: *mut OVERLAPPED, +) -> u32 { + let mut transferred = 0; + let ret = unsafe { + windows_sys::Win32::Storage::FileSystem::WriteFile( + handle, + buffer.cast(), + len, + &mut transferred, + overlapped, + ) + }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + } +} + +pub fn start_wsa_recv( + handle: usize, + buffer: *mut u8, + len: u32, + flags: *mut u32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSARecv}; + + let wsabuf = WSABUF { + buf: buffer.cast(), + len, + }; + let mut transferred = 0; + let ret = unsafe { + WSARecv( + handle, + &wsabuf, + 1, + &mut transferred, + flags, + overlapped, + None, + ) + }; + if ret < 0 { + unsafe { WSAGetLastError() as u32 } + } else { + ERROR_SUCCESS + } +} + +pub fn start_wsa_send( + handle: usize, + buffer: *const u8, + len: u32, + flags: u32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSASend}; + + let wsabuf = WSABUF { + buf: buffer.cast_mut().cast(), + len, + }; + let mut transferred = 0; + let ret = unsafe { + WSASend( + handle, + &wsabuf, + 1, + &mut transferred, + flags, + overlapped, + None, + ) + }; + if ret < 0 { + unsafe { WSAGetLastError() as u32 } + } else { + ERROR_SUCCESS + } +} + +pub 
fn start_accept_ex( + listen_socket: usize, + accept_socket: usize, + buffer: *mut u8, + address_size: u32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + + type AcceptExFn = unsafe extern "system" fn( + s_listen_socket: usize, + s_accept_socket: usize, + lp_output_buffer: *mut core::ffi::c_void, + dw_receive_data_length: u32, + dw_local_address_length: u32, + dw_remote_address_length: u32, + lpdw_bytes_received: *mut u32, + lp_overlapped: *mut OVERLAPPED, + ) -> i32; + + let accept_ex = match winsock_extension_or_error(&ACCEPT_EX) { + Ok(func) => unsafe { core::mem::transmute::(func) }, + Err(err) => return err, + }; + let mut bytes_received = 0; + let ret = unsafe { + accept_ex( + listen_socket, + accept_socket, + buffer.cast(), + 0, + address_size, + address_size, + &mut bytes_received, + overlapped, + ) + }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { WSAGetLastError() as u32 } + } +} + +pub fn start_connect_ex( + socket: usize, + address: *const SOCKADDR, + address_len: i32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + + type ConnectExFn = unsafe extern "system" fn( + s: usize, + name: *const SOCKADDR, + namelen: i32, + lp_send_buffer: *const core::ffi::c_void, + dw_send_data_length: u32, + lpdw_bytes_sent: *mut u32, + lp_overlapped: *mut OVERLAPPED, + ) -> i32; + + let connect_ex = match winsock_extension_or_error(&CONNECT_EX) { + Ok(func) => unsafe { core::mem::transmute::(func) }, + Err(err) => return err, + }; + let ret = unsafe { + connect_ex( + socket, + address, + address_len, + core::ptr::null(), + 0, + core::ptr::null_mut(), + overlapped, + ) + }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { WSAGetLastError() as u32 } + } +} + +pub fn start_disconnect_ex(socket: usize, flags: u32, overlapped: *mut OVERLAPPED) -> u32 { + use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + + type DisconnectExFn = 
unsafe extern "system" fn( + s: usize, + lp_overlapped: *mut OVERLAPPED, + dw_flags: u32, + dw_reserved: u32, + ) -> i32; + + let disconnect_ex = match winsock_extension_or_error(&DISCONNECT_EX) { + Ok(func) => unsafe { core::mem::transmute::(func) }, + Err(err) => return err, + }; + let ret = unsafe { disconnect_ex(socket, overlapped, flags, 0) }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { WSAGetLastError() as u32 } + } +} + +pub fn start_transmit_file( + socket: usize, + file: HANDLE, + count_to_write: u32, + count_per_send: u32, + flags: u32, + offset: u32, + offset_high: u32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + + type TransmitFileFn = unsafe extern "system" fn( + h_socket: usize, + h_file: HANDLE, + n_number_of_bytes_to_write: u32, + n_number_of_bytes_per_send: u32, + lp_overlapped: *mut OVERLAPPED, + lp_transmit_buffers: *const core::ffi::c_void, + dw_reserved: u32, + ) -> i32; + + unsafe { + (*overlapped).Anonymous.Anonymous.Offset = offset; + (*overlapped).Anonymous.Anonymous.OffsetHigh = offset_high; + } + + let transmit_file = match winsock_extension_or_error(&TRANSMIT_FILE) { + Ok(func) => unsafe { core::mem::transmute::(func) }, + Err(err) => return err, + }; + let ret = unsafe { + transmit_file( + socket, + file, + count_to_write, + count_per_send, + overlapped, + core::ptr::null(), + flags, + ) + }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { WSAGetLastError() as u32 } + } +} + +pub fn start_connect_named_pipe(pipe: HANDLE, overlapped: *mut OVERLAPPED) -> u32 { + let ret = unsafe { ConnectNamedPipe(pipe, overlapped) }; + if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + } +} + +pub fn start_wsa_send_to( + handle: usize, + buffer: *const u8, + len: u32, + flags: u32, + address: *const SOCKADDR, + address_len: i32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use 
windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSASendTo}; + + let wsabuf = WSABUF { + buf: buffer.cast_mut().cast(), + len, + }; + let mut transferred = 0; + let ret = unsafe { + WSASendTo( + handle, + &wsabuf, + 1, + &mut transferred, + flags, + address, + address_len, + overlapped, + None, + ) + }; + if ret < 0 { + unsafe { WSAGetLastError() as u32 } + } else { + ERROR_SUCCESS + } +} + +pub fn start_wsa_recv_from( + handle: usize, + buffer: *mut u8, + len: u32, + flags: *mut u32, + address: *mut SOCKADDR, + address_len: *mut i32, + overlapped: *mut OVERLAPPED, +) -> u32 { + use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSARecvFrom}; + + let wsabuf = WSABUF { + buf: buffer.cast(), + len, + }; + let mut transferred = 0; + let ret = unsafe { + WSARecvFrom( + handle, + &wsabuf, + 1, + &mut transferred, + flags, + address, + address_len, + overlapped, + None, + ) + }; + if ret < 0 { + unsafe { WSAGetLastError() as u32 } + } else { + ERROR_SUCCESS + } +} + +pub fn connect_pipe(address: &str) -> io::Result { + use windows_sys::Win32::{ + Foundation::{GENERIC_READ, GENERIC_WRITE}, + Storage::FileSystem::{CreateFileW, FILE_FLAG_OVERLAPPED, OPEN_EXISTING}, + }; + + let address_wide: Vec = address.encode_utf16().chain(core::iter::once(0)).collect(); + let handle = unsafe { + CreateFileW( + address_wide.as_ptr(), + GENERIC_READ | GENERIC_WRITE, + 0, + core::ptr::null(), + OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, + core::ptr::null_mut(), + ) + }; + let handle = handle.check_valid()?; + Ok(handle as isize) +} + +pub fn create_io_completion_port( + handle: isize, + port: isize, + key: usize, + concurrency: u32, +) -> io::Result { + let r = unsafe { + windows_sys::Win32::System::IO::CreateIoCompletionPort( + handle as HANDLE, + port as HANDLE, + key, + concurrency, + ) + } + .check_nonnull()?; + Ok(r as isize) +} + +pub fn get_queued_completion_status(port: isize, msecs: u32) -> io::Result { + let mut bytes_transferred = 0; + let mut 
completion_key = 0; + let mut overlapped: *mut OVERLAPPED = core::ptr::null_mut(); + let ret = unsafe { + windows_sys::Win32::System::IO::GetQueuedCompletionStatus( + port as HANDLE, + &mut bytes_transferred, + &mut completion_key, + &mut overlapped, + msecs, + ) + }; + let err = if ret != 0 { + windows_sys::Win32::Foundation::ERROR_SUCCESS + } else { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + }; + if overlapped.is_null() { + if err == windows_sys::Win32::Foundation::WAIT_TIMEOUT { + Ok(WaitResult::Timeout) + } else { + Err(io::Error::from_raw_os_error(err as i32)) + } + } else { + Ok(WaitResult::Queued(QueuedCompletionStatus { + error: err, + bytes_transferred, + completion_key, + overlapped: overlapped as usize, + })) + } +} + +pub fn post_queued_completion_status( + port: isize, + bytes: u32, + key: usize, + address: usize, +) -> io::Result<()> { + unsafe { + windows_sys::Win32::System::IO::PostQueuedCompletionStatus( + port as HANDLE, + bytes, + key, + address as *mut OVERLAPPED, + ) + } + .check_win32_bool() +} + +unsafe impl Send for WaitCallbackData {} +unsafe impl Sync for WaitCallbackData {} + +unsafe extern "system" fn post_to_queue_callback( + parameter: *mut core::ffi::c_void, + timer_or_wait_fired: bool, +) { + let raw_ptr = parameter as *const WaitCallbackData; + let data = unsafe { Arc::from_raw(raw_ptr) }; + data.fired.store(true, Ordering::Release); + unsafe { + let _ = windows_sys::Win32::System::IO::PostQueuedCompletionStatus( + data.completion_port, + if timer_or_wait_fired { 1 } else { 0 }, + 0, + data.overlapped, + ); + } +} + +pub fn register_wait_with_queue( + object: isize, + completion_port: isize, + overlapped: usize, + timeout: u32, +) -> io::Result { + use windows_sys::Win32::System::Threading::{ + RegisterWaitForSingleObject, WT_EXECUTEINWAITTHREAD, WT_EXECUTEONLYONCE, + }; + + let data = Arc::new(WaitCallbackData { + completion_port: completion_port as HANDLE, + overlapped: overlapped as *mut OVERLAPPED, + fired: 
AtomicBool::new(false), + }); + let data_ptr = Arc::into_raw(data.clone()); + + let mut new_wait_object: HANDLE = core::ptr::null_mut(); + let ret = unsafe { + RegisterWaitForSingleObject( + &mut new_wait_object, + object as HANDLE, + Some(post_to_queue_callback), + data_ptr as *mut _, + timeout, + WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE, + ) + }; + if ret == 0 { + unsafe { + let _ = Arc::from_raw(data_ptr); + } + return Err(io::Error::last_os_error()); + } + + let wait_handle = new_wait_object as isize; + if let Ok(mut registry) = wait_callback_registry().lock() { + registry.insert( + wait_handle, + WaitCallbackEntry { + data, + raw_ptr: data_ptr as usize, + }, + ); + } + Ok(wait_handle) +} + +fn cleanup_wait_callback_data(wait_handle: isize) { + if let Ok(mut registry) = wait_callback_registry().lock() + && let Some(entry) = registry.remove(&wait_handle) + && !entry.data.fired.load(Ordering::Acquire) + { + unsafe { + let _ = Arc::from_raw(entry.raw_ptr as *const WaitCallbackData); + } + } +} + +pub fn unregister_wait(wait_handle: isize) -> io::Result<()> { + let ret = + unsafe { windows_sys::Win32::System::Threading::UnregisterWait(wait_handle as HANDLE) }; + cleanup_wait_callback_data(wait_handle); + ret.check_win32_bool() +} + +pub fn unregister_wait_ex(wait_handle: isize, event: isize) -> io::Result<()> { + let ret = unsafe { + windows_sys::Win32::System::Threading::UnregisterWaitEx( + wait_handle as HANDLE, + event as HANDLE, + ) + }; + cleanup_wait_callback_data(wait_handle); + ret.check_win32_bool() +} + +pub fn bind_local(socket: isize, family: i32) -> io::Result<()> { + use windows_sys::Win32::Networking::WinSock::{ + INADDR_ANY, SOCKET_ERROR, WSAGetLastError, bind, + }; + + let ret = if family == AF_INET as i32 { + let mut addr: SOCKADDR_IN = unsafe { core::mem::zeroed() }; + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + unsafe { + bind( + socket as _, + &addr as *const _ as *const SOCKADDR, + 
core::mem::size_of::() as i32, + ) + } + } else if family == AF_INET6 as i32 { + let mut addr: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + unsafe { + bind( + socket as _, + &addr as *const _ as *const SOCKADDR, + core::mem::size_of::() as i32, + ) + } + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "expected tuple of length 2 or 4", + )); + }; + + if ret == SOCKET_ERROR { + Err(io::Error::from_raw_os_error( + unsafe { WSAGetLastError() } as i32 + )) + } else { + Ok(()) + } +} + +pub fn parse_address_v4_wide(host_wide: &[u16], port: u16) -> io::Result<(Vec, i32)> { + use windows_sys::Win32::Networking::WinSock::{WSAGetLastError, WSAStringToAddressW}; + + let mut addr: SOCKADDR_IN = unsafe { core::mem::zeroed() }; + addr.sin_family = AF_INET; + + let mut addr_len = core::mem::size_of::() as i32; + + let ret = unsafe { + WSAStringToAddressW( + host_wide.as_ptr(), + AF_INET as i32, + core::ptr::null(), + &mut addr as *mut _ as *mut SOCKADDR, + &mut addr_len, + ) + }; + if ret < 0 { + return Err(io::Error::from_raw_os_error( + unsafe { WSAGetLastError() } as i32 + )); + } + + // WSAStringToAddressW overwrites the port field. 
+ addr.sin_port = port.to_be(); + + let bytes = unsafe { + core::slice::from_raw_parts( + &addr as *const _ as *const u8, + core::mem::size_of::(), + ) + }; + Ok((bytes.to_vec(), addr_len)) +} + +pub fn parse_address_v4(host: &str, port: u16) -> io::Result<(Vec, i32)> { + let host_wide: Vec = host.encode_utf16().chain([0]).collect(); + parse_address_v4_wide(&host_wide, port) +} + +pub fn parse_address_v6( + host: &str, + port: u16, + flowinfo: u32, + scope_id: u32, +) -> io::Result<(Vec, i32)> { + let host_wide: Vec = host.encode_utf16().chain([0]).collect(); + parse_address_v6_wide(&host_wide, port, flowinfo, scope_id) +} + +pub fn parse_address_v6_wide( + host_wide: &[u16], + port: u16, + flowinfo: u32, + scope_id: u32, +) -> io::Result<(Vec, i32)> { + use windows_sys::Win32::Networking::WinSock::{WSAGetLastError, WSAStringToAddressW}; + + let mut addr: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; + addr.sin6_family = AF_INET6; + + let mut addr_len = core::mem::size_of::() as i32; + + let ret = unsafe { + WSAStringToAddressW( + host_wide.as_ptr(), + AF_INET6 as i32, + core::ptr::null(), + &mut addr as *mut _ as *mut SOCKADDR, + &mut addr_len, + ) + }; + if ret < 0 { + return Err(io::Error::from_raw_os_error( + unsafe { WSAGetLastError() } as i32 + )); + } + + // WSAStringToAddressW may overwrite these fields. 
+ addr.sin6_port = port.to_be(); + addr.sin6_flowinfo = flowinfo; + addr.Anonymous.sin6_scope_id = scope_id; + + let bytes = unsafe { + core::slice::from_raw_parts( + &addr as *const _ as *const u8, + core::mem::size_of::(), + ) + }; + Ok((bytes.to_vec(), addr_len)) +} + +pub fn unparse_address(addr: *const SOCKADDR, addr_len: i32) -> io::Result { + use core::net::{Ipv4Addr, Ipv6Addr}; + + if addr.is_null() || addr_len < core::mem::size_of::() as i32 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "address buffer too small", + )); + } + + let family = unsafe { (*addr).sa_family }; + if family == AF_INET { + if addr_len < core::mem::size_of::() as i32 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "address buffer too small for AF_INET", + )); + } + let addr_in = unsafe { &*(addr as *const SOCKADDR_IN) }; + let ip_bytes = unsafe { addr_in.sin_addr.S_un.S_un_b }; + Ok(SocketAddress::V4 { + host: Ipv4Addr::new(ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4) + .to_string(), + port: u16::from_be(addr_in.sin_port), + }) + } else if family == AF_INET6 { + if addr_len < core::mem::size_of::() as i32 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "address buffer too small for AF_INET6", + )); + } + let addr = unsafe { &*(addr as *const SOCKADDR_IN6) }; + let ip_bytes = unsafe { addr.sin6_addr.u.Byte }; + let scope_id = unsafe { addr.Anonymous.sin6_scope_id }; + Ok(SocketAddress::V6 { + host: Ipv6Addr::from(ip_bytes).to_string(), + port: u16::from_be(addr.sin6_port), + flowinfo: u32::from_be(addr.sin6_flowinfo), + scope_id, + }) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "recvfrom returned unsupported address family", + )) + } +} + +pub fn format_message(error_code: u32) -> String { + use windows_sys::Win32::Foundation::LocalFree; + + const LANG_NEUTRAL: u32 = 0; + const SUBLANG_DEFAULT: u32 = 1; + + let mut buffer: *mut u16 = core::ptr::null_mut(); + let len = unsafe { + FormatMessageW( + 
FORMAT_MESSAGE_ALLOCATE_BUFFER + | FORMAT_MESSAGE_FROM_SYSTEM + | FORMAT_MESSAGE_IGNORE_INSERTS, + core::ptr::null(), + error_code, + (SUBLANG_DEFAULT << 10) | LANG_NEUTRAL, + &mut buffer as *mut _ as *mut u16, + 0, + core::ptr::null(), + ) + }; + + if len == 0 || buffer.is_null() { + if !buffer.is_null() { + unsafe { LocalFree(buffer as *mut _) }; + } + return format!("unknown error code {error_code}"); + } + + let slice = unsafe { core::slice::from_raw_parts(buffer, len as usize) }; + let msg = String::from_utf16_lossy(slice).trim_end().to_string(); + unsafe { LocalFree(buffer as *mut _) }; + msg +} + +pub fn wsa_connect(socket: isize, addr_ptr: *const SOCKADDR, addr_len: i32) -> io::Result<()> { + use windows_sys::Win32::Networking::WinSock::{SOCKET_ERROR, WSAConnect, WSAGetLastError}; + + let ret = unsafe { + WSAConnect( + socket as _, + addr_ptr, + addr_len, + core::ptr::null(), + core::ptr::null_mut(), + core::ptr::null(), + core::ptr::null(), + ) + }; + if ret == SOCKET_ERROR { + Err(io::Error::from_raw_os_error( + unsafe { WSAGetLastError() } as i32 + )) + } else { + Ok(()) + } +} diff --git a/crates/host_env/src/posix.rs b/crates/host_env/src/posix.rs index 36f2e1df966..f10e99ac938 100644 --- a/crates/host_env/src/posix.rs +++ b/crates/host_env/src/posix.rs @@ -1,17 +1,535 @@ -use std::os::fd::BorrowedFd; +use alloc::ffi::CString; +#[cfg(all(unix, not(target_os = "redox")))] +use alloc::vec::Vec; +use core::ffi::CStr; +#[cfg(all(unix, not(target_os = "redox")))] +use core::ptr::NonNull; +use std::ffi::{OsStr, OsString}; +#[cfg(target_os = "linux")] +use std::os::fd::FromRawFd; +use std::os::fd::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd}; +use std::path::Path; + +pub struct UnameInfo { + pub sysname: String, + pub nodename: String, + pub release: String, + pub version: String, + pub machine: String, +} + +#[cfg(all(unix, not(target_os = "redox")))] +#[derive(Clone, Copy, Debug)] +pub struct StatVfsInfo { + pub f_bsize: libc::c_ulong, + pub f_frsize: 
libc::c_ulong, + pub f_blocks: libc::fsblkcnt_t, + pub f_bfree: libc::fsblkcnt_t, + pub f_bavail: libc::fsblkcnt_t, + pub f_files: libc::fsfilcnt_t, + pub f_ffree: libc::fsfilcnt_t, + pub f_favail: libc::fsfilcnt_t, + pub f_flag: libc::c_ulong, + pub f_namemax: libc::c_ulong, + pub f_fsid: libc::c_ulong, +} + +#[cfg(all(unix, not(target_os = "redox")))] +#[derive(Clone, Debug)] +pub struct RawDirEntry { + pub name: Vec, + pub d_type: Option, + pub ino: u64, +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub struct FdDirStream(NonNull); -pub fn set_inheritable(fd: BorrowedFd<'_>, inheritable: bool) -> nix::Result<()> { +#[cfg(all(target_os = "linux", target_env = "gnu"))] +pub type PriorityWhichType = libc::__priority_which_t; +#[cfg(not(all(target_os = "linux", target_env = "gnu")))] +pub type PriorityWhichType = libc::c_int; + +#[cfg(target_os = "freebsd")] +pub type PriorityWhoType = i32; +#[cfg(not(target_os = "freebsd"))] +pub type PriorityWhoType = u32; + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +#[derive(Clone, Debug)] +pub enum PosixSpawnFileAction { + Open { + fd: i32, + path: CString, + oflag: i32, + mode: u32, + }, + Close { + fd: i32, + }, + Dup2 { + fd: i32, + newfd: i32, + }, +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +pub struct PosixSpawnConfig<'a> { + pub path: &'a CStr, + pub args: &'a [CString], + pub env: &'a [CString], + pub file_actions: &'a [PosixSpawnFileAction], + pub setsigdef: Option<&'a [i32]>, + pub setpgroup: Option, + pub resetids: bool, + pub setsid: bool, + pub setsigmask: Option<&'a [i32]>, + pub spawnp: bool, +} + +pub fn set_inheritable(fd: BorrowedFd<'_>, inheritable: bool) -> std::io::Result<()> { use nix::fcntl; - let flags = fcntl::FdFlag::from_bits_truncate(fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFD)?); + let flags = fcntl::FdFlag::from_bits_truncate( + fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFD).map_err(std::io::Error::from)?, + ); let mut 
new_flags = flags; new_flags.set(fcntl::FdFlag::FD_CLOEXEC, !inheritable); if flags != new_flags { - fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFD(new_flags))?; + fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFD(new_flags)).map_err(std::io::Error::from)?; } Ok(()) } +pub fn is_session_leader() -> bool { + unsafe { libc::getsid(0) == libc::getpid() } +} + +pub fn getpid() -> libc::pid_t { + unsafe { libc::getpid() } +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn dup_fd(fd: BorrowedFd<'_>) -> std::io::Result { + nix::unistd::dup(fd).map_err(std::io::Error::from) +} + +#[cfg(not(target_os = "redox"))] +pub fn symlinkat(src: &CStr, dir_fd: BorrowedFd<'_>, dst: &CStr) -> std::io::Result<()> { + nix::unistd::symlinkat(src, dir_fd, dst).map_err(std::io::Error::from) +} + +#[cfg(target_os = "redox")] +pub fn symlink(src: &CStr, dst: &CStr) -> std::io::Result<()> { + let ret = unsafe { libc::symlink(src.as_ptr(), dst.as_ptr()) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(not(target_os = "redox"))] +pub fn chroot(path: &Path) -> std::io::Result<()> { + nix::unistd::chroot(path).map_err(std::io::Error::from) +} + +#[cfg(not(target_os = "redox"))] +pub fn unlinkat(dir_fd: i32, path: &CStr) -> std::io::Result<()> { + let ret = unsafe { libc::unlinkat(dir_fd, path.as_ptr(), 0) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(any(target_os = "macos", target_os = "freebsd", target_os = "netbsd"))] +pub fn lchmod(path: &CStr, mode: libc::mode_t) -> std::io::Result<()> { + unsafe extern "C" { + fn lchmod(path: *const libc::c_char, mode: libc::mode_t) -> libc::c_int; + } + if unsafe { lchmod(path.as_ptr(), mode) } == 0 { + Ok(()) + } else { + Err(std::io::Error::last_os_error()) + } +} + +#[cfg(target_os = "macos")] +pub fn fcopyfile(in_fd: i32, out_fd: i32, flags: u32) -> std::io::Result<()> { + unsafe extern "C" { + fn fcopyfile( + in_fd: libc::c_int, + out_fd: libc::c_int, + state: *mut 
libc::c_void, + flags: u32, + ) -> libc::c_int; + } + let ret = unsafe { fcopyfile(in_fd, out_fd, core::ptr::null_mut(), flags) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(not(windows))] +pub fn make_dir(path: &CStr, mode: u32) -> std::io::Result<()> { + let ret = unsafe { libc::mkdir(path.as_ptr(), mode as _) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(all(not(windows), not(target_os = "redox")))] +pub fn make_dir_at(dir_fd: i32, path: &CStr, mode: u32) -> std::io::Result<()> { + let ret = unsafe { libc::mkdirat(dir_fd, path.as_ptr(), mode as _) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(unix)] +pub fn link_paths(src: &CStr, dst: &CStr, follow_symlinks: bool) -> std::io::Result<()> { + let flags = if follow_symlinks { + libc::AT_SYMLINK_FOLLOW + } else { + 0 + }; + let ret = unsafe { + libc::linkat( + libc::AT_FDCWD, + src.as_ptr(), + libc::AT_FDCWD, + dst.as_ptr(), + flags, + ) + }; + if ret != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(all(not(windows), not(target_os = "redox")))] +pub fn remove_dir_at(dir_fd: i32, path: &CStr) -> std::io::Result<()> { + let ret = unsafe { libc::unlinkat(dir_fd, path.as_ptr(), libc::AT_REMOVEDIR) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +fn statvfs_info_from_raw(st: libc::statvfs) -> StatVfsInfo { + let f_fsid = { + let ptr = core::ptr::addr_of!(st.f_fsid) as *const u8; + let size = core::mem::size_of_val(&st.f_fsid); + if size >= 8 { + let bytes = unsafe { core::slice::from_raw_parts(ptr, 8) }; + u64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]) as libc::c_ulong + } else if size >= 4 { + let bytes = unsafe { core::slice::from_raw_parts(ptr, 4) }; + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], 
bytes[3]]) as libc::c_ulong + } else { + 0 + } + }; + + StatVfsInfo { + f_bsize: st.f_bsize, + f_frsize: st.f_frsize, + f_blocks: st.f_blocks, + f_bfree: st.f_bfree, + f_bavail: st.f_bavail, + f_files: st.f_files, + f_ffree: st.f_ffree, + f_favail: st.f_favail, + f_flag: st.f_flag, + f_namemax: st.f_namemax, + f_fsid, + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn statvfs_path(path: &CStr) -> std::io::Result { + let mut st: libc::statvfs = unsafe { core::mem::zeroed() }; + let ret = unsafe { libc::statvfs(path.as_ptr(), &mut st) }; + if ret != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(statvfs_info_from_raw(st)) + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn statvfs_fd(fd: i32) -> std::io::Result { + let mut st: libc::statvfs = unsafe { core::mem::zeroed() }; + let ret = unsafe { libc::fstatvfs(fd, &mut st) }; + if ret != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(statvfs_info_from_raw(st)) + } +} + +#[cfg(not(target_os = "redox"))] +pub fn mknod(path: &CStr, mode: libc::mode_t, device: libc::dev_t) -> std::io::Result<()> { + let ret = unsafe { libc::mknod(path.as_ptr(), mode, device) }; + if ret == 0 { + Ok(()) + } else { + Err(std::io::Error::last_os_error()) + } +} + +#[cfg(all(not(target_os = "redox"), not(target_vendor = "apple")))] +pub fn mknodat( + dir_fd: i32, + path: &CStr, + mode: libc::mode_t, + device: libc::dev_t, +) -> std::io::Result<()> { + let ret = unsafe { libc::mknodat(dir_fd, path.as_ptr(), mode, device) }; + if ret == 0 { + Ok(()) + } else { + Err(std::io::Error::last_os_error()) + } +} + +fn uid_from_raw(uid: u32) -> nix::unistd::Uid { + nix::unistd::Uid::from_raw(uid) +} + +fn gid_from_raw(gid: u32) -> nix::unistd::Gid { + nix::unistd::Gid::from_raw(gid) +} + +pub fn fchown(fd: BorrowedFd<'_>, uid: Option, gid: Option) -> std::io::Result<()> { + nix::unistd::fchown(fd, uid.map(uid_from_raw), gid.map(gid_from_raw)) + .map_err(std::io::Error::from) +} + +#[cfg(not(windows))] +pub 
fn stat_path( + path: &OsStr, + dir_fd: Option, + follow_symlinks: bool, +) -> std::io::Result> { + use crate::os::ffi::OsStrExt; + + let path = match CString::new(path.as_bytes()) { + Ok(path) => path, + Err(_) => return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)), + }; + + let mut stat = core::mem::MaybeUninit::uninit(); + #[cfg(not(target_os = "redox"))] + if let Some(dir_fd) = dir_fd { + let flags = if follow_symlinks { + 0 + } else { + libc::AT_SYMLINK_NOFOLLOW + }; + let ret = unsafe { libc::fstatat(dir_fd, path.as_ptr(), stat.as_mut_ptr(), flags) }; + if ret < 0 { + return Err(std::io::Error::last_os_error()); + } + return Ok(Some(unsafe { stat.assume_init() })); + } + + let ret = if follow_symlinks { + unsafe { libc::stat(path.as_ptr(), stat.as_mut_ptr()) } + } else { + unsafe { libc::lstat(path.as_ptr(), stat.as_mut_ptr()) } + }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(Some(unsafe { stat.assume_init() })) + } +} + +#[cfg(not(windows))] +pub fn stat_fd(fd: crate::crt_fd::Borrowed<'_>) -> std::io::Result { + crate::fileutils::fstat(fd) +} + +#[cfg(not(target_os = "redox"))] +pub fn fchdir(fd: i32) -> std::io::Result<()> { + let ret = unsafe { libc::fchdir(fd) }; + if ret == 0 { + Ok(()) + } else { + Err(std::io::Error::last_os_error()) + } +} + +pub fn fork() -> std::io::Result { + let pid = unsafe { libc::fork() }; + if pid == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(pid) + } +} + +pub fn write_fd(fd: BorrowedFd<'_>, buf: &[u8]) -> std::io::Result { + nix::unistd::write(fd, buf).map_err(std::io::Error::from) +} + +pub fn fchownat( + dir_fd: BorrowedFd<'_>, + path: &OsStr, + uid: Option, + gid: Option, + follow_symlinks: bool, +) -> std::io::Result<()> { + let flag = if follow_symlinks { + nix::fcntl::AtFlags::empty() + } else { + nix::fcntl::AtFlags::AT_SYMLINK_NOFOLLOW + }; + nix::unistd::fchownat( + dir_fd, + path, + uid.map(uid_from_raw), + gid.map(gid_from_raw), + flag, + ) + 
.map_err(std::io::Error::from) +} + +pub fn uname_info() -> Result { + let info = rustix::system::uname(); + Ok(UnameInfo { + sysname: info.sysname().to_str()?.into(), + nodename: info.nodename().to_str()?.into(), + release: info.release().to_str()?.into(), + version: info.version().to_str()?.into(), + machine: info.machine().to_str()?.into(), + }) +} + +#[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "linux", + target_os = "android", + target_os = "netbsd", + target_os = "openbsd" +))] +pub fn pipe2(flags: libc::c_int) -> std::io::Result<(std::os::fd::OwnedFd, std::os::fd::OwnedFd)> { + nix::unistd::pipe2(nix::fcntl::OFlag::from_bits_truncate(flags)).map_err(std::io::Error::from) +} + +#[cfg(not(target_os = "redox"))] +pub fn pipe() -> std::io::Result<(OwnedFd, OwnedFd)> { + let (rfd, wfd) = nix::unistd::pipe().map_err(std::io::Error::from)?; + set_inheritable(rfd.as_fd(), false)?; + set_inheritable(wfd.as_fd(), false)?; + Ok((rfd, wfd)) +} + +pub fn sched_yield() -> std::io::Result<()> { + nix::sched::sched_yield().map_err(std::io::Error::from) +} + +#[cfg(not(target_os = "redox"))] +pub fn nice(increment: i32) -> std::io::Result { + crate::os::clear_errno(); + let res = unsafe { libc::nice(increment) }; + if res == -1 && crate::os::get_errno() != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } +} + +#[cfg(not(target_os = "redox"))] +pub fn sched_get_priority_max(policy: i32) -> std::io::Result { + let max = unsafe { libc::sched_get_priority_max(policy) }; + if max == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(max) + } +} + +#[cfg(not(target_os = "redox"))] +pub fn sched_get_priority_min(policy: i32) -> std::io::Result { + let min = unsafe { libc::sched_get_priority_min(policy) }; + if min == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(min) + } +} + +#[cfg(not(target_os = "redox"))] +pub fn fchmod(fd: BorrowedFd<'_>, mode: u32) -> std::io::Result<()> { + nix::sys::stat::fchmod( + 
fd, + nix::sys::stat::Mode::from_bits_truncate(mode as libc::mode_t), + ) + .map_err(std::io::Error::from) +} + +#[cfg(target_os = "redox")] +pub fn utimes( + path: &Path, + acc: core::time::Duration, + modif: core::time::Duration, +) -> std::io::Result<()> { + let tv = |d: core::time::Duration| libc::timeval { + tv_sec: d.as_secs() as _, + tv_usec: d.subsec_micros() as _, + }; + nix::sys::stat::utimes(path, &tv(acc).into(), &tv(modif).into()).map_err(std::io::Error::from) +} + +#[cfg(all(any(target_os = "wasi", unix), not(target_os = "redox")))] +pub fn set_file_times_at( + dir_fd: i32, + path: &CStr, + access: core::time::Duration, + modified: core::time::Duration, + follow_symlinks: bool, +) -> std::io::Result<()> { + let ts = |d: core::time::Duration| libc::timespec { + tv_sec: d.as_secs() as _, + tv_nsec: d.subsec_nanos() as _, + }; + let times = [ts(access), ts(modified)]; + let ret = unsafe { + libc::utimensat( + dir_fd, + path.as_ptr(), + times.as_ptr(), + if follow_symlinks { + 0 + } else { + libc::AT_SYMLINK_NOFOLLOW + }, + ) + }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + #[cfg(target_os = "macos")] #[must_use] pub fn get_number_of_os_threads() -> isize { @@ -89,3 +607,1118 @@ pub fn get_number_of_os_threads() -> isize { pub fn get_number_of_os_threads() -> isize { 0 } + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct Permissions { + pub is_readable: bool, + pub is_writable: bool, + pub is_executable: bool, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum AccessError { + InvalidMode, + Os(i32), +} + +impl From for AccessError { + fn from(value: std::io::Error) -> Self { + Self::Os(value.raw_os_error().unwrap_or(0)) + } +} + +const F_OK: u8 = 0; +const R_OK: u8 = 4; +const W_OK: u8 = 2; +const X_OK: u8 = 1; + +fn get_permissions(mode: u32) -> Permissions { + Permissions { + is_readable: mode & 4 != 0, + is_writable: mode & 2 != 0, + is_executable: mode & 1 != 0, + } +} + +pub fn 
get_right_permission( + mode: u32, + file_owner: u32, + file_group: u32, +) -> std::io::Result { + let owner_mode = (mode & 0o700) >> 6; + let owner_permissions = get_permissions(owner_mode); + + let group_mode = (mode & 0o070) >> 3; + let group_permissions = get_permissions(group_mode); + + let others_mode = mode & 0o007; + let others_permissions = get_permissions(others_mode); + + let user_id = nix::unistd::getuid().as_raw(); + let groups_ids = getgroups()?; + + if file_owner == user_id { + Ok(owner_permissions) + } else if groups_ids.contains(&file_group) { + Ok(group_permissions) + } else { + Ok(others_permissions) + } +} + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub fn getgroups() -> std::io::Result> { + use core::ptr; + use libc::{c_int, gid_t}; + use nix::errno::Errno; + + let ret = unsafe { libc::getgroups(0, ptr::null_mut()) }; + let mut groups = + Vec::::with_capacity(Errno::result(ret).map_err(std::io::Error::from)? as usize); + let ret = unsafe { libc::getgroups(groups.capacity() as c_int, groups.as_mut_ptr()) }; + + Errno::result(ret).map_err(std::io::Error::from).map(|s| { + unsafe { groups.set_len(s as usize) }; + groups.into_iter().collect() + }) +} + +#[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "redox")))] +pub fn getgroups() -> std::io::Result> { + nix::unistd::getgroups() + .map(|groups| groups.into_iter().map(|gid| gid.as_raw()).collect()) + .map_err(std::io::Error::from) +} + +#[cfg(target_os = "redox")] +pub fn getgroups() -> std::io::Result> { + Err(std::io::Error::from_raw_os_error(libc::EOPNOTSUPP)) +} + +pub fn check_access(path: &Path, mode: u8) -> Result { + use std::os::unix::fs::MetadataExt; + + if mode & !(R_OK | W_OK | X_OK) != 0 { + return Err(AccessError::InvalidMode); + } + + let metadata = match crate::fs::metadata(path) { + Ok(m) => m, + Err(_) => return Ok(false), + }; + + if mode == F_OK { + return Ok(true); + } + + let perm = get_right_permission(metadata.mode(), metadata.uid(), 
metadata.gid())?; + + let r_ok = (mode & R_OK == 0) || perm.is_readable; + let w_ok = (mode & W_OK == 0) || perm.is_writable; + let x_ok = (mode & X_OK == 0) || perm.is_executable; + + Ok(r_ok && w_ok && x_ok) +} + +pub fn close_fds(above: i32, keep: &[BorrowedFd<'_>]) { + #[cfg(not(target_os = "redox"))] + if close_dir_fds(above, keep).is_ok() { + return; + } + #[cfg(target_os = "redox")] + if close_filetable_fds(above, keep).is_ok() { + return; + } + close_fds_brute_force(above, keep) +} + +#[allow(clippy::too_many_arguments)] +pub fn setup_child_fds( + fds_to_keep: &[BorrowedFd<'_>], + errpipe_write: BorrowedFd<'_>, + p2cread: i32, + p2cwrite: i32, + c2pread: i32, + c2pwrite: i32, + errread: i32, + errwrite: i32, + errpipe_read: i32, +) -> std::io::Result<()> { + for &fd in fds_to_keep { + if fd.as_raw_fd() != errpipe_write.as_raw_fd() { + set_inheritable(fd, true)?; + } + } + + for fd in [p2cwrite, c2pread, errread] { + if fd >= 0 { + nix::unistd::close(fd).map_err(std::io::Error::from)?; + } + } + nix::unistd::close(errpipe_read).map_err(std::io::Error::from)?; + + let c2pwrite = if c2pwrite == 0 { + let fd = unsafe { BorrowedFd::borrow_raw(c2pwrite) }; + let dup = nix::unistd::dup(fd).map_err(std::io::Error::from)?; + set_inheritable(dup.as_fd(), true)?; + dup.into_raw_fd() + } else { + c2pwrite + }; + + let mut errwrite = errwrite; + while errwrite == 0 || errwrite == 1 { + let fd = unsafe { BorrowedFd::borrow_raw(errwrite) }; + let dup = nix::unistd::dup(fd).map_err(std::io::Error::from)?; + set_inheritable(dup.as_fd(), true)?; + errwrite = dup.into_raw_fd(); + } + + dup_into_stdio(p2cread, 0)?; + dup_into_stdio(c2pwrite, 1)?; + dup_into_stdio(errwrite, 2)?; + Ok(()) +} + +fn dup_into_stdio(fd: i32, io_fd: i32) -> std::io::Result<()> { + if fd < 0 { + return Ok(()); + } + let fd = unsafe { BorrowedFd::borrow_raw(fd) }; + if fd.as_raw_fd() == io_fd { + set_inheritable(fd, true) + } else { + match io_fd { + 0 => 
nix::unistd::dup2_stdin(fd).map_err(std::io::Error::from), + 1 => nix::unistd::dup2_stdout(fd).map_err(std::io::Error::from), + 2 => nix::unistd::dup2_stderr(fd).map_err(std::io::Error::from), + _ => unreachable!(), + } + } +} + +pub fn chdir(cwd: &CStr) -> nix::Result<()> { + nix::unistd::chdir(cwd) +} + +pub fn set_umask(child_umask: i32) { + if child_umask >= 0 { + unsafe { libc::umask(child_umask as libc::mode_t) }; + } +} + +pub fn umask(mask: libc::mode_t) -> libc::mode_t { + unsafe { libc::umask(mask) } +} + +#[cfg(not(any(target_os = "redox", target_os = "android")))] +pub fn sync() { + unsafe { libc::sync() }; +} + +pub fn getlogin() -> Option { + let ptr = unsafe { libc::getlogin() }; + if ptr.is_null() { + None + } else { + Some(unsafe { CStr::from_ptr(ptr) }.to_owned()) + } +} + +pub fn restore_signals() { + unsafe { + libc::signal(libc::SIGPIPE, libc::SIG_DFL); + libc::signal(libc::SIGXFSZ, libc::SIG_DFL); + } +} + +pub fn setsid_if_needed(call_setsid: bool) -> nix::Result<()> { + if call_setsid { + nix::unistd::setsid()?; + } + Ok(()) +} + +pub fn setpgid_if_needed(pgid_to_set: libc::pid_t) -> nix::Result<()> { + if pgid_to_set > -1 { + nix::unistd::setpgid( + nix::unistd::Pid::from_raw(0), + nix::unistd::Pid::from_raw(pgid_to_set), + )?; + } + Ok(()) +} + +pub fn setgroups_if_needed(_groups: Option<&[u32]>) -> nix::Result<()> { + #[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "redox")))] + if let Some(groups) = _groups { + let groups = groups.iter().copied().map(gid_from_raw).collect::>(); + nix::unistd::setgroups(&groups)?; + } + Ok(()) +} + +pub fn setregid_if_needed(gid: Option) -> nix::Result<()> { + if let Some(gid) = gid.filter(|&x| x != u32::MAX) { + let ret = unsafe { libc::setregid(gid as libc::gid_t, gid as libc::gid_t) }; + nix::Error::result(ret)?; + } + Ok(()) +} + +pub fn setreuid_if_needed(uid: Option) -> nix::Result<()> { + if let Some(uid) = uid.filter(|&x| x != u32::MAX) { + let ret = unsafe { libc::setreuid(uid 
as libc::uid_t, uid as libc::uid_t) }; + nix::Error::result(ret)?; + } + Ok(()) +} + +pub fn getppid() -> libc::pid_t { + nix::unistd::getppid().as_raw() +} + +pub fn getgid() -> u32 { + nix::unistd::getgid().as_raw() +} + +pub fn getegid() -> u32 { + nix::unistd::getegid().as_raw() +} + +pub fn getpgid(pid: u32) -> std::io::Result { + nix::unistd::getpgid(Some(nix::unistd::Pid::from_raw(pid as i32))) + .map(nix::unistd::Pid::as_raw) + .map_err(std::io::Error::from) +} + +pub fn getpgrp() -> libc::pid_t { + nix::unistd::getpgrp().as_raw() +} + +#[cfg(not(target_os = "redox"))] +pub fn getsid(pid: u32) -> std::io::Result { + nix::unistd::getsid(Some(nix::unistd::Pid::from_raw(pid as i32))) + .map(nix::unistd::Pid::as_raw) + .map_err(std::io::Error::from) +} + +pub fn getuid() -> u32 { + nix::unistd::getuid().as_raw() +} + +pub fn geteuid() -> u32 { + nix::unistd::geteuid().as_raw() +} + +#[cfg(not(any(target_os = "wasi", target_os = "android")))] +pub fn setgid(gid: u32) -> std::io::Result<()> { + nix::unistd::setgid(gid_from_raw(gid)).map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] +pub fn setegid(egid: u32) -> std::io::Result<()> { + nix::unistd::setegid(gid_from_raw(egid)).map_err(std::io::Error::from) +} + +pub fn setpgid(pid: u32, pgid: u32) -> std::io::Result<()> { + nix::unistd::setpgid( + nix::unistd::Pid::from_raw(pid as i32), + nix::unistd::Pid::from_raw(pgid as i32), + ) + .map_err(std::io::Error::from) +} + +pub fn setpgrp() -> std::io::Result<()> { + nix::unistd::setpgid(nix::unistd::Pid::from_raw(0), nix::unistd::Pid::from_raw(0)) + .map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "redox")))] +pub fn setsid() -> std::io::Result<()> { + nix::unistd::setsid() + .map(drop) + .map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "redox")))] +pub fn tcgetpgrp(fd: BorrowedFd<'_>) -> std::io::Result { + nix::unistd::tcgetpgrp(fd) 
+ .map(nix::unistd::Pid::as_raw) + .map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "redox")))] +pub fn tcsetpgrp(fd: BorrowedFd<'_>, pgid: libc::pid_t) -> std::io::Result<()> { + nix::unistd::tcsetpgrp(fd, nix::unistd::Pid::from_raw(pgid)).map_err(std::io::Error::from) +} + +#[cfg(not(target_os = "redox"))] +pub fn getpriority(which: PriorityWhichType, who: PriorityWhoType) -> std::io::Result { + crate::os::clear_errno(); + let retval = unsafe { libc::getpriority(which, who) }; + if crate::os::get_errno() != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(retval) + } +} + +#[cfg(not(target_os = "redox"))] +pub fn setpriority( + which: PriorityWhichType, + who: PriorityWhoType, + priority: i32, +) -> std::io::Result<()> { + let retval = unsafe { libc::setpriority(which, who, priority) }; + if retval == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +pub fn waitpid(pid: libc::pid_t, status: &mut i32, opt: i32) -> std::io::Result { + let res = unsafe { libc::waitpid(pid, status, opt) }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } +} + +pub fn kill(pid: i32, sig: i32) -> std::io::Result<()> { + let ret = unsafe { libc::kill(pid, sig) }; + if ret == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(not(any(target_os = "wasi", target_os = "android")))] +pub fn setuid(uid: u32) -> std::io::Result<()> { + nix::unistd::setuid(uid_from_raw(uid)).map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] +pub fn seteuid(euid: u32) -> std::io::Result<()> { + nix::unistd::seteuid(uid_from_raw(euid)).map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] +pub fn setreuid(ruid: u32, euid: u32) -> std::io::Result<()> { + let ret = unsafe { libc::setreuid(ruid as libc::uid_t, euid as libc::uid_t) }; + nix::Error::result(ret) + 
.map(drop) + .map_err(std::io::Error::from) +} + +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +pub fn setresuid(ruid: u32, euid: u32, suid: u32) -> std::io::Result<()> { + let ret = unsafe { + libc::setresuid( + ruid as libc::uid_t, + euid as libc::uid_t, + suid as libc::uid_t, + ) + }; + nix::Error::result(ret) + .map(drop) + .map_err(std::io::Error::from) +} + +#[cfg(not(target_os = "redox"))] +pub fn openpty() -> std::io::Result<(OwnedFd, OwnedFd)> { + let pty = nix::pty::openpty(None, None).map_err(std::io::Error::from)?; + set_inheritable(pty.master.as_fd(), false)?; + set_inheritable(pty.slave.as_fd(), false)?; + Ok((pty.master, pty.slave)) +} + +pub fn ttyname(fd: BorrowedFd<'_>) -> std::io::Result { + nix::unistd::ttyname(fd) + .map(std::path::PathBuf::into_os_string) + .map_err(std::io::Error::from) +} + +pub fn execv(path: &CStr, argv: &[&CStr]) -> std::io::Result<()> { + match nix::unistd::execv(path, argv) { + Ok(never) => match never {}, + Err(err) => Err(err.into()), + } +} + +pub fn execve(path: &CStr, argv: &[&CStr], env: &[&CStr]) -> std::io::Result<()> { + match nix::unistd::execve(path, argv, env) { + Ok(never) => match never {}, + Err(err) => Err(err.into()), + } +} + +#[cfg(any(target_os = "android", target_os = "linux", target_os = "openbsd"))] +pub fn getresuid() -> std::io::Result<(u32, u32, u32)> { + let ret = nix::unistd::getresuid().map_err(std::io::Error::from)?; + Ok(( + ret.real.as_raw(), + ret.effective.as_raw(), + ret.saved.as_raw(), + )) +} + +#[cfg(any(target_os = "android", target_os = "linux", target_os = "openbsd"))] +pub fn getresgid() -> std::io::Result<(u32, u32, u32)> { + let ret = nix::unistd::getresgid().map_err(std::io::Error::from)?; + Ok(( + ret.real.as_raw(), + ret.effective.as_raw(), + ret.saved.as_raw(), + )) +} + +#[cfg(any(target_os = "freebsd", target_os = "linux", target_os = "openbsd"))] +pub fn setresgid(rgid: u32, egid: u32, sgid: u32) -> 
std::io::Result<()> { + let ret = unsafe { + libc::setresgid( + rgid as libc::gid_t, + egid as libc::gid_t, + sgid as libc::gid_t, + ) + }; + nix::Error::result(ret) + .map(drop) + .map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] +pub fn setregid(rgid: u32, egid: u32) -> std::io::Result<()> { + let ret = unsafe { libc::setregid(rgid as libc::gid_t, egid as libc::gid_t) }; + nix::Error::result(ret) + .map(drop) + .map_err(std::io::Error::from) +} + +#[cfg(any(target_os = "freebsd", target_os = "linux", target_os = "openbsd"))] +pub fn initgroups(user: &CStr, gid: u32) -> std::io::Result<()> { + nix::unistd::initgroups(user, gid_from_raw(gid)).map_err(std::io::Error::from) +} + +#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "redox")))] +pub fn setgroups_raw(groups: &[u32]) -> std::io::Result<()> { + let gids = groups.iter().copied().map(gid_from_raw).collect::>(); + nix::unistd::setgroups(&gids).map_err(std::io::Error::from) +} + +pub fn dup_noninheritable(fd: BorrowedFd<'_>) -> std::io::Result { + let fd = nix::unistd::dup(fd).map_err(std::io::Error::from)?; + set_inheritable(fd.as_fd(), false)?; + Ok(fd) +} + +pub fn dup2(fd: BorrowedFd<'_>, fd2: OwnedFd, inheritable: bool) -> std::io::Result { + let mut fd2 = core::mem::ManuallyDrop::new(fd2); + nix::unistd::dup2(fd, &mut fd2).map_err(std::io::Error::from)?; + let fd2 = core::mem::ManuallyDrop::into_inner(fd2); + if !inheritable { + set_inheritable(fd2.as_fd(), false)?; + } + Ok(fd2) +} + +#[cfg(all(unix, not(target_os = "redox")))] +impl FdDirStream { + pub fn from_fd(fd: BorrowedFd<'_>) -> std::io::Result { + let new_fd = dup_fd(fd)?; + let raw_fd = new_fd.into_raw_fd(); + let ptr = unsafe { libc::fdopendir(raw_fd) }; + match NonNull::new(ptr) { + Some(ptr) => Ok(Self(ptr)), + None => { + unsafe { libc::close(raw_fd) }; + Err(std::io::Error::last_os_error()) + } + } + } + + pub fn next_entry(&mut self) -> std::io::Result> { 
+ loop { + crate::os::clear_errno(); + let ptr = unsafe { libc::readdir(self.0.as_ptr()) }; + if ptr.is_null() { + let err = crate::os::get_errno(); + return if err == 0 { + Ok(None) + } else { + Err(std::io::Error::from_raw_os_error(err)) + }; + } + + let entry = unsafe { &*ptr }; + let name = unsafe { CStr::from_ptr(entry.d_name.as_ptr()) }.to_bytes(); + if name == b"." || name == b".." { + continue; + } + #[cfg(target_os = "freebsd")] + let ino = entry.d_fileno as u64; + #[cfg(not(target_os = "freebsd"))] + let ino = entry.d_ino as u64; + + return Ok(Some(RawDirEntry { + name: name.to_vec(), + d_type: (entry.d_type != libc::DT_UNKNOWN).then_some(entry.d_type), + ino, + })); + } + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +impl Drop for FdDirStream { + fn drop(&mut self) { + unsafe { + libc::rewinddir(self.0.as_ptr()); + libc::closedir(self.0.as_ptr()); + } + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +impl core::fmt::Debug for FdDirStream { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_tuple("FdDirStream").field(&self.0).finish() + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +unsafe impl Send for FdDirStream {} +#[cfg(all(unix, not(target_os = "redox")))] +unsafe impl Sync for FdDirStream {} + +pub fn get_terminal_size(fd: libc::c_int) -> std::io::Result<(u16, u16)> { + let mut w = libc::winsize { + ws_row: 0, + ws_col: 0, + ws_xpixel: 0, + ws_ypixel: 0, + }; + let ret = unsafe { libc::ioctl(fd, libc::TIOCGWINSZ, &mut w) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok((w.ws_col, w.ws_row)) + } +} + +#[cfg(target_os = "macos")] +pub fn full_fsync(fd: BorrowedFd<'_>) -> std::io::Result<()> { + let ret = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_FULLFSYNC) }; + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn madvise(addr: usize, len: usize, advice: i32) -> std::io::Result<()> { + let ret = 
unsafe { libc::madvise(addr as *mut libc::c_void, len, advice) }; + if ret != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +pub fn pathconf(path: &CStr, name: i32) -> std::io::Result> { + crate::os::clear_errno(); + debug_assert_eq!(crate::os::get_errno(), 0); + let raw = unsafe { libc::pathconf(path.as_ptr(), name) }; + if raw == -1 { + if crate::os::get_errno() == 0 { + Ok(None) + } else { + Err(std::io::Error::last_os_error()) + } + } else { + Ok(Some(raw)) + } +} + +pub fn fpathconf(fd: i32, name: i32) -> std::io::Result> { + crate::os::clear_errno(); + debug_assert_eq!(crate::os::get_errno(), 0); + let raw = unsafe { libc::fpathconf(fd, name) }; + if raw == -1 { + if crate::os::get_errno() == 0 { + Ok(None) + } else { + Err(std::io::Error::last_os_error()) + } + } else { + Ok(Some(raw)) + } +} + +pub fn sysconf(name: i32) -> std::io::Result { + crate::os::set_errno(0); + let raw = unsafe { libc::sysconf(name) }; + if raw == -1 && crate::os::get_errno() != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(raw) + } +} + +#[cfg(target_os = "linux")] +/// # Safety +/// +/// `buf` must be valid for writes of `buflen` bytes. 
+pub unsafe fn getrandom( + buf: *mut libc::c_void, + buflen: usize, + flags: u32, +) -> std::io::Result { + let len = unsafe { libc::syscall(libc::SYS_getrandom, buf, buflen, flags as usize) as isize }; + if len < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(len as usize) + } +} + +pub fn wcoredump(status: i32) -> bool { + libc::WCOREDUMP(status) +} + +pub fn wifcontinued(status: i32) -> bool { + libc::WIFCONTINUED(status) +} + +pub fn wifstopped(status: i32) -> bool { + libc::WIFSTOPPED(status) +} + +pub fn wifsignaled(status: i32) -> bool { + libc::WIFSIGNALED(status) +} + +pub fn wifexited(status: i32) -> bool { + libc::WIFEXITED(status) +} + +pub fn wexitstatus(status: i32) -> i32 { + libc::WEXITSTATUS(status) +} + +pub fn wstopsig(status: i32) -> i32 { + libc::WSTOPSIG(status) +} + +pub fn wtermsig(status: i32) -> i32 { + libc::WTERMSIG(status) +} + +#[cfg(target_os = "linux")] +pub fn pidfd_open(pid: libc::pid_t, flags: u32) -> std::io::Result { + let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, flags) as libc::c_long }; + if fd == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(unsafe { OwnedFd::from_raw_fd(fd as libc::c_int) }) + } +} + +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +pub fn getgrouplist(user: &CStr, gid: u32) -> std::io::Result> { + nix::unistd::getgrouplist(user, gid_from_raw(gid)) + .map(|groups| groups.into_iter().map(|gid| gid.as_raw()).collect()) + .map_err(std::io::Error::from) +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +pub fn validate_posix_spawn_signal(sig: i32) -> bool { + nix::sys::signal::Signal::try_from(sig).is_ok() +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +pub const fn supports_posix_spawn_setsid() -> bool { + cfg!(any( + target_os = "linux", + target_os = "haiku", + target_os = "solaris", + target_os = "illumos", + target_os = "hurd", + )) +} 
+ +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +fn build_posix_spawn_file_actions( + actions: &[PosixSpawnFileAction], +) -> std::io::Result { + let mut file_actions = + nix::spawn::PosixSpawnFileActions::init().map_err(std::io::Error::from)?; + for action in actions { + match action { + PosixSpawnFileAction::Open { + fd, + path, + oflag, + mode, + } => file_actions + .add_open( + *fd, + path.as_c_str(), + nix::fcntl::OFlag::from_bits_retain(*oflag), + nix::sys::stat::Mode::from_bits_retain(*mode as libc::mode_t), + ) + .map_err(std::io::Error::from)?, + PosixSpawnFileAction::Close { fd } => { + file_actions.add_close(*fd).map_err(std::io::Error::from)? + } + PosixSpawnFileAction::Dup2 { fd, newfd } => file_actions + .add_dup2(*fd, *newfd) + .map_err(std::io::Error::from)?, + } + } + Ok(file_actions) +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +fn build_sigset(signals: &[i32]) -> nix::sys::signal::SigSet { + let mut set = nix::sys::signal::SigSet::empty(); + for &sig in signals { + let sig = nix::sys::signal::Signal::try_from(sig).expect("validated signal"); + set.add(sig); + } + set +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +fn build_posix_spawn_attrs( + config: &PosixSpawnConfig<'_>, +) -> std::io::Result { + let mut attrp = nix::spawn::PosixSpawnAttr::init().map_err(std::io::Error::from)?; + let mut flags = nix::spawn::PosixSpawnFlags::empty(); + + if let Some(sigs) = config.setsigdef { + let set = build_sigset(sigs); + attrp.set_sigdefault(&set).map_err(std::io::Error::from)?; + flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_SETSIGDEF); + } + + if let Some(pgid) = config.setpgroup { + attrp + .set_pgroup(nix::unistd::Pid::from_raw(pgid)) + .map_err(std::io::Error::from)?; + flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_SETPGROUP); + } + + if config.resetids { + flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_RESETIDS); + } + + if 
config.setsid { + #[cfg(any( + target_os = "linux", + target_os = "haiku", + target_os = "solaris", + target_os = "illumos", + target_os = "hurd", + ))] + { + flags.insert(nix::spawn::PosixSpawnFlags::from_bits_retain( + libc::POSIX_SPAWN_SETSID, + )); + } + #[cfg(not(any( + target_os = "linux", + target_os = "haiku", + target_os = "solaris", + target_os = "illumos", + target_os = "hurd", + )))] + { + return Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "setsid parameter is not supported on this platform", + )); + } + } + + if let Some(sigs) = config.setsigmask { + let set = build_sigset(sigs); + attrp.set_sigmask(&set).map_err(std::io::Error::from)?; + flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_SETSIGMASK); + } + + if !flags.is_empty() { + attrp.set_flags(flags).map_err(std::io::Error::from)?; + } + + Ok(attrp) +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] +pub fn posix_spawn(config: PosixSpawnConfig<'_>) -> std::io::Result { + let file_actions = build_posix_spawn_file_actions(config.file_actions)?; + let attrp = build_posix_spawn_attrs(&config)?; + let pid = if config.spawnp { + nix::spawn::posix_spawnp(config.path, &file_actions, &attrp, config.args, config.env) + } else { + nix::spawn::posix_spawn(config.path, &file_actions, &attrp, config.args, config.env) + } + .map_err(std::io::Error::from)?; + Ok(pid.into()) +} + +#[cfg(target_os = "linux")] +pub fn sendfile( + out_fd: BorrowedFd<'_>, + in_fd: BorrowedFd<'_>, + offset: &mut crate::crt_fd::Offset, + count: usize, +) -> std::io::Result { + nix::sys::sendfile::sendfile(out_fd, in_fd, Some(offset), count).map_err(std::io::Error::from) +} + +#[cfg(target_os = "macos")] +pub fn sendfile( + in_fd: BorrowedFd<'_>, + out_fd: BorrowedFd<'_>, + offset: crate::crt_fd::Offset, + count: i64, + headers: Option<&[&[u8]]>, + trailers: Option<&[&[u8]]>, +) -> (std::io::Result<()>, i64) { + let (res, written) = + nix::sys::sendfile::sendfile(in_fd, out_fd, 
offset, Some(count), headers, trailers); + (res.map_err(std::io::Error::from), written) +} + +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd" +))] +pub fn sched_getscheduler(pid: libc::pid_t) -> std::io::Result { + let policy = unsafe { libc::sched_getscheduler(pid) }; + if policy == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(policy) + } +} + +#[cfg(all( + not(target_env = "musl"), + any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd" + ) +))] +pub fn sched_setscheduler( + pid: i32, + policy: i32, + param: &libc::sched_param, +) -> std::io::Result { + let ret = unsafe { libc::sched_setscheduler(pid, policy, param) }; + if ret == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(ret) + } +} + +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd" +))] +pub fn sched_getparam(pid: libc::pid_t) -> std::io::Result { + let mut param = core::mem::MaybeUninit::uninit(); + let ret = unsafe { libc::sched_getparam(pid, param.as_mut_ptr()) }; + if ret == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(unsafe { param.assume_init() }) + } +} + +#[cfg(all( + not(target_env = "musl"), + any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd" + ) +))] +pub fn sched_setparam(pid: i32, param: &libc::sched_param) -> std::io::Result { + let ret = unsafe { libc::sched_setparam(pid, param) }; + if ret == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(ret) + } +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +pub fn exec_replace>( + exec_list: &[T], + argv: *const *const libc::c_char, + envp: Option<*const *const libc::c_char>, +) -> nix::errno::Errno { + let mut first_err = None; + for exec in exec_list { + if let Some(envp) = envp { + unsafe { libc::execve(exec.as_ref().as_ptr(), argv, envp) }; + } else { + unsafe { 
libc::execv(exec.as_ref().as_ptr(), argv) }; + } + let e = nix::errno::Errno::last(); + if e != nix::errno::Errno::ENOENT && e != nix::errno::Errno::ENOTDIR && first_err.is_none() + { + first_err = Some(e); + } + } + first_err.unwrap_or_else(nix::errno::Errno::last) +} + +fn should_keep(above: i32, keep: &[BorrowedFd<'_>], fd: i32) -> bool { + fd > above + && keep + .binary_search_by_key(&fd, BorrowedFd::as_raw_fd) + .is_err() +} + +#[cfg(not(target_os = "redox"))] +fn close_dir_fds(above: i32, keep: &[BorrowedFd<'_>]) -> nix::Result<()> { + use nix::{dir::Dir, fcntl::OFlag}; + use std::os::fd::AsRawFd; + + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "netbsd", + target_os = "openbsd", + target_vendor = "apple", + ))] + let fd_dir_name = c"/dev/fd"; + + #[cfg(any(target_os = "linux", target_os = "android"))] + let fd_dir_name = c"/proc/self/fd"; + + let mut dir = Dir::open( + fd_dir_name, + OFlag::O_RDONLY | OFlag::O_DIRECTORY, + nix::sys::stat::Mode::empty(), + )?; + let dirfd = dir.as_raw_fd(); + 'outer: for e in dir.iter() { + let e = e?; + let mut parser = IntParser::default(); + for &c in e.file_name().to_bytes() { + if parser.feed(c).is_err() { + continue 'outer; + } + } + let fd = parser.num; + if fd != dirfd && should_keep(above, keep, fd) { + let _ = nix::unistd::close(fd); + } + } + Ok(()) +} + +#[cfg(target_os = "redox")] +fn close_filetable_fds(above: i32, keep: &[BorrowedFd<'_>]) -> nix::Result<()> { + use nix::fcntl; + use std::os::fd::AsRawFd; + + let filetable = fcntl::open( + c"/scheme/thisproc/current/filetable", + fcntl::OFlag::O_RDONLY, + nix::sys::stat::Mode::empty(), + )?; + let read_one = || -> nix::Result<_> { + let mut byte = 0; + let n = nix::unistd::read(&filetable, std::slice::from_mut(&mut byte))?; + Ok((n > 0).then_some(byte)) + }; + while let Some(c) = read_one()? { + let mut parser = IntParser::default(); + if parser.feed(c).is_err() { + continue; + } + let done = loop { + let Some(c) = read_one()? 
else { break true }; + if parser.feed(c).is_err() { + break false; + } + }; + + let fd = parser.num; + if fd != filetable.as_raw_fd() && should_keep(above, keep, fd) { + let _ = nix::unistd::close(fd); + } + if done { + break; + } + } + Ok(()) +} + +fn close_fds_brute_force(above: i32, keep: &[BorrowedFd<'_>]) { + debug_assert!( + keep.windows(2) + .all(|fds| fds[0].as_raw_fd() <= fds[1].as_raw_fd()), + "close_fds_brute_force requires `keep` to be sorted ascending" + ); + + let max_fd = nix::unistd::sysconf(nix::unistd::SysconfVar::OPEN_MAX) + .ok() + .flatten() + .unwrap_or(256) as i32; + + let mut prev = above; + for fd in keep + .iter() + .map(BorrowedFd::as_raw_fd) + .chain(core::iter::once(max_fd)) + { + for candidate in prev + 1..fd { + unsafe { libc::close(candidate) }; + } + prev = fd; + } +} + +#[derive(Default)] +struct IntParser { + num: i32, +} + +struct NonDigit; + +impl IntParser { + fn feed(&mut self, c: u8) -> Result<(), NonDigit> { + let digit = (c as char).to_digit(10).ok_or(NonDigit)?; + self.num *= 10; + self.num += digit as i32; + Ok(()) + } +} diff --git a/crates/host_env/src/posix_wasi.rs b/crates/host_env/src/posix_wasi.rs new file mode 100644 index 00000000000..d4eff8c6866 --- /dev/null +++ b/crates/host_env/src/posix_wasi.rs @@ -0,0 +1,85 @@ +use alloc::ffi::CString; +use core::{ffi::CStr, time::Duration}; +use std::{ffi::OsStr, io}; + +use crate::os::CheckLibcResult; + +pub fn make_dir(path: &CStr, mode: u32) -> io::Result<()> { + unsafe { libc::mkdir(path.as_ptr(), mode as _) }.check_libc_neg()?; + Ok(()) +} + +pub fn make_dir_at(dir_fd: i32, path: &CStr, mode: u32) -> io::Result<()> { + unsafe { libc::mkdirat(dir_fd, path.as_ptr(), mode as _) }.check_libc_neg()?; + Ok(()) +} + +pub fn remove_dir_at(dir_fd: i32, path: &CStr) -> io::Result<()> { + unsafe { libc::unlinkat(dir_fd, path.as_ptr(), libc::AT_REMOVEDIR) }.check_libc_neg()?; + Ok(()) +} + +pub fn stat_path( + path: &OsStr, + dir_fd: Option, + follow_symlinks: bool, +) -> 
io::Result> { + use crate::os::ffi::OsStrExt; + + let path = match CString::new(path.as_bytes()) { + Ok(path) => path, + Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidInput)), + }; + + let mut stat = core::mem::MaybeUninit::uninit(); + if let Some(dir_fd) = dir_fd { + let flags = if follow_symlinks { + 0 + } else { + libc::AT_SYMLINK_NOFOLLOW + }; + unsafe { libc::fstatat(dir_fd, path.as_ptr(), stat.as_mut_ptr(), flags) } + .check_libc_neg()?; + return Ok(Some(unsafe { stat.assume_init() })); + } + + let ret = if follow_symlinks { + unsafe { libc::stat(path.as_ptr(), stat.as_mut_ptr()) } + } else { + unsafe { libc::lstat(path.as_ptr(), stat.as_mut_ptr()) } + }; + ret.check_libc_neg()?; + Ok(Some(unsafe { stat.assume_init() })) +} + +pub fn stat_fd(fd: crate::crt_fd::Borrowed<'_>) -> io::Result { + crate::fileutils::fstat(fd) +} + +pub fn set_file_times_at( + dir_fd: i32, + path: &CStr, + access: Duration, + modified: Duration, + follow_symlinks: bool, +) -> io::Result<()> { + let ts = |d: Duration| libc::timespec { + tv_sec: d.as_secs() as _, + tv_nsec: d.subsec_nanos() as _, + }; + let times = [ts(access), ts(modified)]; + unsafe { + libc::utimensat( + dir_fd, + path.as_ptr(), + times.as_ptr(), + if follow_symlinks { + 0 + } else { + libc::AT_SYMLINK_NOFOLLOW + }, + ) + } + .check_libc_neg()?; + Ok(()) +} diff --git a/crates/host_env/src/pwd.rs b/crates/host_env/src/pwd.rs new file mode 100644 index 00000000000..db04b46b6d3 --- /dev/null +++ b/crates/host_env/src/pwd.rs @@ -0,0 +1,61 @@ +use nix::unistd::{self, User}; +use std::io; + +#[derive(Debug, Clone)] +pub struct Passwd { + pub name: String, + pub passwd: String, + pub uid: u32, + pub gid: u32, + pub gecos: String, + pub dir: String, + pub shell: String, +} + +impl From for Passwd { + fn from(user: User) -> Self { + let cstr_lossy = |s: alloc::ffi::CString| { + s.into_string() + .unwrap_or_else(|e| e.into_cstring().to_string_lossy().into_owned()) + }; + let pathbuf_lossy = |p: 
std::path::PathBuf| { + p.into_os_string() + .into_string() + .unwrap_or_else(|s| s.to_string_lossy().into_owned()) + }; + Self { + name: user.name, + passwd: cstr_lossy(user.passwd), + uid: user.uid.as_raw(), + gid: user.gid.as_raw(), + gecos: cstr_lossy(user.gecos), + dir: pathbuf_lossy(user.dir), + shell: pathbuf_lossy(user.shell), + } + } +} + +pub fn getpwnam(name: &str) -> Option { + User::from_name(name).ok().flatten().map(Into::into) +} + +pub fn getpwuid(uid: libc::uid_t) -> io::Result> { + User::from_uid(unistd::Uid::from_raw(uid)) + .map(|user| user.map(Into::into)) + .map_err(io::Error::from) +} + +#[cfg(not(target_os = "android"))] +pub fn getpwall() -> Vec { + static GETPWALL: parking_lot::Mutex<()> = parking_lot::Mutex::new(()); + let _guard = GETPWALL.lock(); + let mut list = Vec::new(); + + unsafe { libc::setpwent() }; + while let Some(ptr) = core::ptr::NonNull::new(unsafe { libc::getpwent() }) { + list.push(User::from(unsafe { ptr.as_ref() }).into()); + } + unsafe { libc::endpwent() }; + + list +} diff --git a/crates/host_env/src/resource.rs b/crates/host_env/src/resource.rs new file mode 100644 index 00000000000..587428fe9b0 --- /dev/null +++ b/crates/host_env/src/resource.rs @@ -0,0 +1,74 @@ +use std::io; + +use crate::os::CheckLibcResult; + +#[derive(Debug, Clone, Copy)] +pub struct RUsage { + pub ru_utime: libc::timeval, + pub ru_stime: libc::timeval, + pub ru_maxrss: libc::c_long, + pub ru_ixrss: libc::c_long, + pub ru_idrss: libc::c_long, + pub ru_isrss: libc::c_long, + pub ru_minflt: libc::c_long, + pub ru_majflt: libc::c_long, + pub ru_nswap: libc::c_long, + pub ru_inblock: libc::c_long, + pub ru_oublock: libc::c_long, + pub ru_msgsnd: libc::c_long, + pub ru_msgrcv: libc::c_long, + pub ru_nsignals: libc::c_long, + pub ru_nvcsw: libc::c_long, + pub ru_nivcsw: libc::c_long, +} + +impl From for RUsage { + fn from(rusage: libc::rusage) -> Self { + Self { + ru_utime: rusage.ru_utime, + ru_stime: rusage.ru_stime, + ru_maxrss: rusage.ru_maxrss, + 
ru_ixrss: rusage.ru_ixrss, + ru_idrss: rusage.ru_idrss, + ru_isrss: rusage.ru_isrss, + ru_minflt: rusage.ru_minflt, + ru_majflt: rusage.ru_majflt, + ru_nswap: rusage.ru_nswap, + ru_inblock: rusage.ru_inblock, + ru_oublock: rusage.ru_oublock, + ru_msgsnd: rusage.ru_msgsnd, + ru_msgrcv: rusage.ru_msgrcv, + ru_nsignals: rusage.ru_nsignals, + ru_nvcsw: rusage.ru_nvcsw, + ru_nivcsw: rusage.ru_nivcsw, + } + } +} + +pub fn getrusage(who: i32) -> io::Result { + let mut rusage = core::mem::MaybeUninit::::uninit(); + unsafe { libc::getrusage(who, rusage.as_mut_ptr()) }.check_libc_neg()?; + Ok(unsafe { rusage.assume_init() }.into()) +} + +pub fn getrlimit(resource: libc::rlim_t) -> io::Result { + let mut rlimit = core::mem::MaybeUninit::::uninit(); + unsafe { libc::getrlimit(resource as _, rlimit.as_mut_ptr()) }.check_libc_neg()?; + Ok(unsafe { rlimit.assume_init() }) +} + +pub fn setrlimit(resource: libc::rlim_t, limits: libc::rlimit) -> io::Result<()> { + unsafe { libc::setrlimit(resource as _, &limits) }.check_libc_neg()?; + Ok(()) +} + +#[cfg(not(any(target_os = "redox", target_os = "wasi")))] +pub fn disable_core_dumps() { + let rl = libc::rlimit { + rlim_cur: 0, + rlim_max: 0, + }; + unsafe { + let _ = libc::setrlimit(libc::RLIMIT_CORE, &rl); + } +} diff --git a/crates/host_env/src/select.rs b/crates/host_env/src/select.rs index 922bafe0697..385a6a110b2 100644 --- a/crates/host_env/src/select.rs +++ b/crates/host_env/src/select.rs @@ -3,24 +3,31 @@ use std::io; #[cfg(unix)] pub mod platform { + pub use libc::pollfd; pub use libc::{FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO, fd_set, select, timeval}; + use std::io; pub use std::os::unix::io::RawFd; #[must_use] pub const fn check_err(x: i32) -> bool { x < 0 } + + pub fn last_select_error() -> io::Error { + io::Error::last_os_error() + } } #[allow(non_snake_case)] #[cfg(windows)] pub mod platform { pub use WinSock::{FD_SET as fd_set, FD_SETSIZE, SOCKET as RawFd, TIMEVAL as timeval, select}; + use std::io; use 
windows_sys::Win32::Networking::WinSock; /// # Safety /// - /// Requirements forwarded from the caller. + /// `set` must be a valid mutable pointer to an initialized WinSock fd_set. pub unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) { let mut slot = unsafe { (&raw mut (*set).fd_array).cast::() }; let fd_count = unsafe { (*set).fd_count }; @@ -40,14 +47,14 @@ pub mod platform { /// # Safety /// - /// Requirements forwarded from the caller. + /// `set` must be a valid mutable pointer to a WinSock fd_set. pub unsafe fn FD_ZERO(set: *mut fd_set) { unsafe { (*set).fd_count = 0 }; } /// # Safety /// - /// Requirements forwarded from the caller. + /// `set` must be a valid mutable pointer to an initialized WinSock fd_set. pub unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool { use WinSock::__WSAFDIsSet; unsafe { __WSAFDIsSet(fd as _, set) != 0 } @@ -57,11 +64,16 @@ pub mod platform { pub fn check_err(x: i32) -> bool { x == WinSock::SOCKET_ERROR } + + pub fn last_select_error() -> io::Error { + io::Error::from_raw_os_error(unsafe { WinSock::WSAGetLastError() }) + } } #[cfg(target_os = "wasi")] pub mod platform { pub use libc::{FD_SETSIZE, timeval}; + use std::io; pub use std::os::fd::RawFd; pub const fn check_err(x: i32) -> bool { @@ -75,6 +87,9 @@ pub mod platform { } #[allow(non_snake_case)] + /// # Safety + /// + /// `set` must be a valid pointer to an initialized fd_set. pub unsafe fn FD_ISSET(fd: RawFd, set: *const fd_set) -> bool { let set = unsafe { &*set }; for p in &set.__fds[..set.__nfds] { @@ -86,6 +101,9 @@ pub mod platform { } #[allow(non_snake_case)] + /// # Safety + /// + /// `set` must be a valid mutable pointer to an initialized fd_set. 
pub unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) { let set = unsafe { &mut *set }; for p in &set.__fds[..set.__nfds] { @@ -94,12 +112,16 @@ pub mod platform { } } let n = set.__nfds; - assert!(n < set.__fds.len(), "fd_set full"); - set.__fds[n] = fd; - set.__nfds = n + 1; + if n < FD_SETSIZE { + set.__fds[n] = fd; + set.__nfds = n + 1; + } } #[allow(non_snake_case)] + /// # Safety + /// + /// `set` must be a valid mutable pointer to an fd_set. pub unsafe fn FD_ZERO(set: *mut fd_set) { unsafe { (*set).__nfds = 0 }; } @@ -113,15 +135,21 @@ pub mod platform { timeout: *const timeval, ) -> libc::c_int; } + + pub fn last_select_error() -> io::Error { + io::Error::last_os_error() + } } pub use platform::{RawFd, timeval}; +#[cfg(unix)] +pub type PollFd = platform::pollfd; + #[repr(transparent)] pub struct FdSet(MaybeUninit); impl FdSet { - #[must_use] pub fn new() -> Self { let mut fdset = MaybeUninit::zeroed(); unsafe { platform::FD_ZERO(fdset.as_mut_ptr()) }; @@ -174,16 +202,118 @@ pub fn select( ) }; if platform::check_err(ret) { - Err(io::Error::last_os_error()) + Err(platform::last_select_error()) } else { Ok(ret) } } -#[must_use] pub fn sec_to_timeval(sec: f64) -> timeval { timeval { tv_sec: sec.trunc() as _, tv_usec: (sec.fract() * 1e6) as _, } } + +#[cfg(unix)] +#[inline] +pub fn search_poll_fd(fds: &[PollFd], fd: i32) -> Result { + fds.binary_search_by_key(&fd, |pfd| pfd.fd) +} + +#[cfg(unix)] +pub fn insert_poll_fd(fds: &mut Vec, fd: i32, events: i16) { + match search_poll_fd(fds, fd) { + Ok(i) => fds[i].events = events, + Err(i) => fds.insert( + i, + PollFd { + fd, + events, + revents: 0, + }, + ), + } +} + +#[cfg(unix)] +pub fn get_poll_fd_mut(fds: &mut [PollFd], fd: i32) -> Option<&mut PollFd> { + search_poll_fd(fds, fd).ok().map(move |i| &mut fds[i]) +} + +#[cfg(unix)] +pub fn remove_poll_fd(fds: &mut Vec, fd: i32) -> Option { + search_poll_fd(fds, fd).ok().map(|i| fds.remove(i)) +} + +#[cfg(unix)] +pub fn poll_fds(fds: &mut [PollFd], timeout: i32) -> 
std::io::Result { + let res = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as _, timeout) }; + if res < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } +} + +#[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))] +pub mod epoll { + use std::os::fd::{AsFd, IntoRawFd, OwnedFd}; + + pub use rustix::event::Timespec; + pub use rustix::event::epoll::{Event, EventData, EventFlags}; + + #[derive(Debug)] + pub enum WaitError { + Interrupted, + Io(std::io::Error), + } + + pub fn create() -> std::io::Result { + rustix::event::epoll::create(rustix::event::epoll::CreateFlags::CLOEXEC).map_err(Into::into) + } + + pub fn close(fd: OwnedFd) -> nix::Result<()> { + nix::unistd::close(fd.into_raw_fd()) + } + + pub fn add(epoll: &OwnedFd, fd: F, data: u64, events: u32) -> std::io::Result<()> { + rustix::event::epoll::add( + epoll, + fd, + EventData::new_u64(data), + EventFlags::from_bits_retain(events), + ) + .map_err(Into::into) + } + + pub fn modify(epoll: &OwnedFd, fd: F, data: u64, events: u32) -> std::io::Result<()> { + rustix::event::epoll::modify( + epoll, + fd, + EventData::new_u64(data), + EventFlags::from_bits_retain(events), + ) + .map_err(Into::into) + } + + pub fn delete(epoll: &OwnedFd, fd: F) -> std::io::Result<()> { + rustix::event::epoll::delete(epoll, fd).map_err(Into::into) + } + + pub fn wait( + epoll: &OwnedFd, + events: &mut Vec, + timeout: Option<&Timespec>, + ) -> Result { + events.clear(); + match rustix::event::epoll::wait(epoll, rustix::buffer::spare_capacity(events), timeout) { + Ok(n) => { + unsafe { events.set_len(n) }; + Ok(n) + } + Err(rustix::io::Errno::INTR) => Err(WaitError::Interrupted), + Err(err) => Err(WaitError::Io(err.into())), + } + } +} diff --git a/crates/host_env/src/shm.rs b/crates/host_env/src/shm.rs index 08a2ae9787c..78e7d3921bc 100644 --- a/crates/host_env/src/shm.rs +++ b/crates/host_env/src/shm.rs @@ -1,23 +1,16 @@ use core::ffi::CStr; use std::io; +use crate::os::CheckLibcResult; + pub fn 
shm_open(name: &CStr, flags: libc::c_int, mode: libc::c_uint) -> io::Result { #[cfg(target_os = "freebsd")] let mode = mode.try_into().unwrap(); - let fd = unsafe { libc::shm_open(name.as_ptr(), flags, mode) }; - if fd == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(fd) - } + unsafe { libc::shm_open(name.as_ptr(), flags, mode) }.check_libc_neg() } pub fn shm_unlink(name: &CStr) -> io::Result<()> { - let ret = unsafe { libc::shm_unlink(name.as_ptr()) }; - if ret == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } + unsafe { libc::shm_unlink(name.as_ptr()) }.check_libc_neg()?; + Ok(()) } diff --git a/crates/host_env/src/signal.rs b/crates/host_env/src/signal.rs index 21794cc1827..f9f2a8206b9 100644 --- a/crates/host_env/src/signal.rs +++ b/crates/host_env/src/signal.rs @@ -1,8 +1,22 @@ +use std::io; +#[cfg(windows)] +use std::sync::Once; + +#[cfg(unix)] +use crate::os::CheckLibcResult; +#[cfg(any(unix, windows))] +use crate::os::CheckLibcZero; + +#[cfg(any(unix, windows))] +pub use libc::sighandler_t; + +#[cfg(unix)] #[must_use] pub fn timeval_to_double(tv: &libc::timeval) -> f64 { tv.tv_sec as f64 + (tv.tv_usec as f64 / 1_000_000.0) } +#[cfg(unix)] #[must_use] pub fn double_to_timeval(val: f64) -> libc::timeval { libc::timeval { @@ -11,6 +25,7 @@ pub fn double_to_timeval(val: f64) -> libc::timeval { } } +#[cfg(unix)] #[must_use] pub fn itimerval_to_tuple(it: &libc::itimerval) -> (f64, f64) { ( @@ -18,3 +33,310 @@ pub fn itimerval_to_tuple(it: &libc::itimerval) -> (f64, f64) { timeval_to_double(&it.it_interval), ) } + +#[cfg(all(unix, not(target_os = "redox")))] +unsafe extern "C" { + #[link_name = "siginterrupt"] + fn c_siginterrupt(sig: i32, flag: i32) -> i32; +} + +#[cfg(any(target_os = "linux", target_os = "android"))] +mod ffi { + unsafe extern "C" { + pub(super) fn getitimer( + which: libc::c_int, + curr_value: *mut libc::itimerval, + ) -> libc::c_int; + pub(super) fn setitimer( + which: libc::c_int, + new_value: *const libc::itimerval, 
+ old_value: *mut libc::itimerval, + ) -> libc::c_int; + } +} + +#[cfg(any(unix, windows))] +/// # Safety +/// +/// The caller must ensure `signalnum` is a valid platform signal number. +pub unsafe fn probe_handler(signalnum: i32) -> Option { + let handler = unsafe { libc::signal(signalnum, libc::SIG_IGN) }; + if handler == libc::SIG_ERR as sighandler_t { + None + } else { + unsafe { libc::signal(signalnum, handler) }; + Some(handler) + } +} + +#[cfg(any(unix, windows))] +/// # Safety +/// +/// The caller must ensure `signalnum` is a valid platform signal number and +/// `handler` is accepted by the platform signal ABI. +pub unsafe fn install_handler(signalnum: i32, handler: sighandler_t) -> io::Result { + let old = unsafe { libc::signal(signalnum, handler) }; + if old == libc::SIG_ERR as sighandler_t { + return Err(io::Error::last_os_error()); + } + #[cfg(all(unix, not(target_os = "redox")))] + let _ = siginterrupt(signalnum, 1); + Ok(old) +} + +#[cfg(any(unix, windows))] +pub fn raise_signal(signalnum: i32) -> io::Result<()> { + unsafe { libc::raise(signalnum) }.check_libc_zero() +} + +#[cfg(unix)] +pub fn alarm(seconds: u32) -> u32 { + unsafe { libc::alarm(seconds) } +} + +#[cfg(unix)] +pub fn pause() { + unsafe { libc::pause() }; +} + +#[cfg(unix)] +pub fn set_sigint_default_onstack() -> io::Result<()> { + let mut action: libc::sigaction = unsafe { core::mem::zeroed() }; + action.sa_sigaction = libc::SIG_DFL; + action.sa_flags = libc::SA_ONSTACK; + unsafe { libc::sigemptyset(&mut action.sa_mask) }.check_libc_zero()?; + unsafe { libc::sigaction(libc::SIGINT, &action, core::ptr::null_mut()) }.check_libc_zero() +} + +#[cfg(unix)] +pub fn send_sigint_to_self() -> io::Result<()> { + unsafe { libc::kill(libc::getpid(), libc::SIGINT) }.check_libc_zero() +} + +#[cfg(unix)] +pub fn setitimer(which: i32, new: &libc::itimerval) -> io::Result { + let mut old = core::mem::MaybeUninit::::uninit(); + #[cfg(any(target_os = "linux", target_os = "android"))] + let ret = unsafe { 
ffi::setitimer(which, new, old.as_mut_ptr()) }; + #[cfg(not(any(target_os = "linux", target_os = "android")))] + let ret = unsafe { libc::setitimer(which, new, old.as_mut_ptr()) }; + ret.check_libc_zero()?; + Ok(unsafe { old.assume_init() }) +} + +#[cfg(unix)] +pub fn getitimer(which: i32) -> io::Result { + let mut old = core::mem::MaybeUninit::::uninit(); + #[cfg(any(target_os = "linux", target_os = "android"))] + let ret = unsafe { ffi::getitimer(which, old.as_mut_ptr()) }; + #[cfg(not(any(target_os = "linux", target_os = "android")))] + let ret = unsafe { libc::getitimer(which, old.as_mut_ptr()) }; + ret.check_libc_zero()?; + Ok(unsafe { old.assume_init() }) +} + +#[cfg(unix)] +pub fn sigemptyset() -> io::Result { + let mut set: libc::sigset_t = unsafe { core::mem::zeroed() }; + unsafe { libc::sigemptyset(&mut set) }.check_libc_zero()?; + Ok(set) +} + +#[cfg(unix)] +pub fn sigaddset(set: &mut libc::sigset_t, signum: i32) -> io::Result<()> { + unsafe { libc::sigaddset(set, signum) }.check_libc_zero() +} + +#[cfg(unix)] +pub fn pthread_sigmask(how: i32, set: &libc::sigset_t) -> io::Result { + let mut old_mask: libc::sigset_t = unsafe { core::mem::zeroed() }; + let err = unsafe { libc::pthread_sigmask(how, set, &mut old_mask) }; + if err != 0 { + Err(io::Error::from_raw_os_error(err)) + } else { + Ok(old_mask) + } +} + +#[cfg(target_os = "linux")] +pub fn pidfd_send_signal(pidfd: i32, sig: i32, flags: u32) -> io::Result<()> { + let ret = unsafe { + libc::syscall( + libc::SYS_pidfd_send_signal, + pidfd, + sig, + core::ptr::null::(), + flags, + ) as libc::c_long + }; + ret.check_libc_neg()?; + Ok(()) +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn siginterrupt(signalnum: i32, flag: i32) -> io::Result<()> { + unsafe { c_siginterrupt(signalnum, flag) }.check_libc_neg()?; + Ok(()) +} + +#[cfg(windows)] +pub const VALID_SIGNALS: &[i32] = &[ + libc::SIGINT, + libc::SIGILL, + libc::SIGFPE, + libc::SIGSEGV, + libc::SIGTERM, + 21, // SIGBREAK / _SIGBREAK + 
libc::SIGABRT, +]; + +#[cfg(windows)] +pub const SIGBREAK: i32 = 21; +#[cfg(windows)] +pub const CTRL_C_EVENT: u32 = 0; +#[cfg(windows)] +pub const CTRL_BREAK_EVENT: u32 = 1; +#[cfg(windows)] +pub const INVALID_SOCKET: libc::SOCKET = windows_sys::Win32::Networking::WinSock::INVALID_SOCKET; + +#[cfg(windows)] +pub fn is_valid_signal(signalnum: i32) -> bool { + VALID_SIGNALS.contains(&signalnum) +} + +#[cfg(windows)] +fn init_winsock() { + static WSA_INIT: Once = Once::new(); + WSA_INIT.call_once(|| unsafe { + let mut wsa_data = core::mem::MaybeUninit::uninit(); + let _ = windows_sys::Win32::Networking::WinSock::WSAStartup(0x0101, wsa_data.as_mut_ptr()); + }); +} + +#[cfg(windows)] +pub fn wakeup_fd_is_socket(fd: libc::SOCKET) -> io::Result { + use windows_sys::Win32::Networking::WinSock; + + init_winsock(); + let mut res = 0i32; + let mut res_size = core::mem::size_of::() as i32; + let getsockopt_res = unsafe { + WinSock::getsockopt( + fd, + WinSock::SOL_SOCKET, + WinSock::SO_ERROR, + &mut res as *mut i32 as *mut _, + &mut res_size, + ) + }; + if getsockopt_res == 0 { + return Ok(true); + } + + let err = io::Error::last_os_error(); + if err.raw_os_error() != Some(WinSock::WSAENOTSOCK) { + return Err(err); + } + + let fd_i32 = + i32::try_from(fd).map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid fd"))?; + let borrowed = unsafe { crate::crt_fd::Borrowed::try_borrow_raw(fd_i32) }?; + crate::fileutils::fstat(borrowed)?; + Ok(false) +} + +#[cfg(windows)] +pub fn notify_signal( + signum: i32, + wakeup_fd: libc::SOCKET, + wakeup_is_socket: bool, + sigint_event: Option, +) { + if signum == libc::SIGINT + && let Some(handle) = sigint_event + { + unsafe { + windows_sys::Win32::System::Threading::SetEvent(handle as _); + } + } + + if wakeup_fd == INVALID_SOCKET { + return; + } + + let sigbyte = signum as u8; + if wakeup_is_socket { + unsafe { + let _ = windows_sys::Win32::Networking::WinSock::send( + wakeup_fd, + &sigbyte as *const u8 as *const _, + 1, + 0, + ); 
+ } + } else { + unsafe { + let _ = libc::write(wakeup_fd as _, &sigbyte as *const u8 as *const _, 1); + } + } +} + +#[cfg(unix)] +pub fn notify_signal(signum: i32, wakeup_fd: i32) { + if wakeup_fd == -1 { + return; + } + let sigbyte = signum as u8; + unsafe { + let _ = libc::write(wakeup_fd, &sigbyte as *const u8 as *const _, 1); + } +} + +#[cfg(unix)] +pub fn strsignal(signalnum: i32) -> Option { + let s = unsafe { libc::strsignal(signalnum) }; + if s.is_null() { + None + } else { + let cstr = unsafe { core::ffi::CStr::from_ptr(s) }; + Some(cstr.to_string_lossy().into_owned()) + } +} + +#[cfg(windows)] +pub fn strsignal(signalnum: i32) -> Option { + let name = match signalnum { + libc::SIGINT => "Interrupt", + libc::SIGILL => "Illegal instruction", + libc::SIGFPE => "Floating-point exception", + libc::SIGSEGV => "Segmentation fault", + libc::SIGTERM => "Terminated", + 21 => "Break", + libc::SIGABRT => "Aborted", + _ => return None, + }; + Some(name.to_owned()) +} + +#[cfg(unix)] +pub fn valid_signals(max_signum: usize) -> io::Result> { + let mut mask: libc::sigset_t = unsafe { core::mem::zeroed() }; + unsafe { libc::sigfillset(&mut mask) }.check_libc_zero()?; + let mut signals = Vec::new(); + for signum in 1..max_signum { + if unsafe { libc::sigismember(&mask, signum as i32) } == 1 { + signals.push(signum as i32); + } + } + Ok(signals) +} + +#[cfg(unix)] +pub fn sigset_contains(mask: &libc::sigset_t, signum: i32) -> bool { + unsafe { libc::sigismember(mask, signum) == 1 } +} + +#[cfg(windows)] +pub fn valid_signals(_max_signum: usize) -> io::Result> { + Ok(VALID_SIGNALS.to_vec()) +} diff --git a/crates/host_env/src/socket.rs b/crates/host_env/src/socket.rs new file mode 100644 index 00000000000..c409132a725 --- /dev/null +++ b/crates/host_env/src/socket.rs @@ -0,0 +1,817 @@ +#[cfg(unix)] +use crate::os::CheckLibcResult; +#[cfg(unix)] +use core::ffi::CStr; +#[cfg(unix)] +use std::os::fd::AsRawFd; +#[cfg(unix)] +use std::{io, os::fd::BorrowedFd}; + +#[cfg(all(unix, 
not(target_os = "redox")))] +pub fn sethostname(hostname: &str) -> io::Result<()> { + nix::unistd::sethostname(hostname).map_err(io::Error::from) +} + +#[cfg(unix)] +pub fn close_socket_ignore_connreset(socket: libc::c_int) -> io::Result<()> { + let ret = unsafe { libc::close(socket) }; + if ret < 0 { + let err = io::Error::last_os_error(); + if err.raw_os_error() != Some(libc::ECONNRESET) { + return Err(err); + } + } + Ok(()) +} + +#[cfg(unix)] +pub fn getsockopt_int(fd: libc::c_int, level: i32, name: i32) -> io::Result { + let mut flag: libc::c_int = 0; + let mut flagsize = core::mem::size_of::() as libc::socklen_t; + unsafe { + libc::getsockopt( + fd, + level, + name, + &mut flag as *mut libc::c_int as *mut _, + &mut flagsize, + ) + } + .check_libc_neg()?; + Ok(flag) +} + +#[cfg(unix)] +pub fn getsockopt_bytes( + fd: libc::c_int, + level: i32, + name: i32, + buflen: usize, +) -> io::Result> { + let mut buf = vec![0u8; buflen]; + let mut optlen = buflen as libc::socklen_t; + unsafe { libc::getsockopt(fd, level, name, buf.as_mut_ptr() as *mut _, &mut optlen) } + .check_libc_neg()?; + buf.truncate(optlen as usize); + Ok(buf) +} + +#[cfg(unix)] +pub fn setsockopt_bytes(fd: libc::c_int, level: i32, name: i32, value: &[u8]) -> io::Result<()> { + unsafe { + libc::setsockopt( + fd, + level, + name, + value.as_ptr() as *const _, + value.len() as libc::socklen_t, + ) + } + .check_libc_neg()?; + Ok(()) +} + +#[cfg(unix)] +pub fn setsockopt_int(fd: libc::c_int, level: i32, name: i32, value: i32) -> io::Result<()> { + unsafe { + libc::setsockopt( + fd, + level, + name, + &value as *const i32 as *const _, + core::mem::size_of::() as libc::socklen_t, + ) + } + .check_libc_neg()?; + Ok(()) +} + +#[cfg(unix)] +pub fn setsockopt_none(fd: libc::c_int, level: i32, name: i32, optlen: u32) -> io::Result<()> { + unsafe { + libc::setsockopt( + fd, + level, + name, + core::ptr::null(), + optlen as libc::socklen_t, + ) + } + .check_libc_neg()?; + Ok(()) +} + +#[cfg(any( + target_os = 
"dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", +))] +pub fn if_nameindex() -> io::Result> { + let list = nix::net::if_::if_nameindex().map_err(io::Error::from)?; + Ok(list + .to_slice() + .iter() + .map(|iface| (iface.index(), iface.name().to_string_lossy().into_owned())) + .collect()) +} + +#[cfg(unix)] +pub fn if_nametoindex_checked(name: &CStr) -> io::Result { + let ret = unsafe { libc::if_nametoindex(name.as_ptr()) }; + if ret == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret) + } +} + +#[cfg(unix)] +pub fn if_indextoname_checked(index: u32) -> io::Result { + let mut buf = [0u8; libc::IF_NAMESIZE]; + let ret = unsafe { + libc::if_indextoname(index as libc::c_uint, buf.as_mut_ptr() as *mut libc::c_char) + }; + if ret.is_null() { + Err(io::Error::last_os_error()) + } else { + let buf = unsafe { CStr::from_ptr(buf.as_ptr() as *const libc::c_char) }; + Ok(buf.to_string_lossy().into_owned()) + } +} + +#[cfg(unix)] +pub fn gai_error_string(err: i32) -> String { + unsafe { CStr::from_ptr(libc::gai_strerror(err)) } + .to_string_lossy() + .into_owned() +} + +#[cfg(unix)] +pub fn h_error_string(err: i32) -> String { + unsafe { CStr::from_ptr(libc::hstrerror(err)) } + .to_string_lossy() + .into_owned() +} + +#[cfg(all(unix, not(target_os = "redox")))] +#[derive(Debug, Clone)] +pub struct AncillaryMessage { + pub level: i32, + pub kind: i32, + pub data: Vec, +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub type SocketAddressBytes = [u8; core::mem::size_of::()]; + +#[cfg(all(unix, not(target_os = "redox")))] +#[derive(Debug, Clone)] +pub struct RawSocketAddress { + pub storage: SocketAddressBytes, + pub len: usize, +} + +#[cfg(all(unix, not(target_os = "redox")))] +#[derive(Debug, Clone)] +pub struct RecvMsgResult { + pub data: Vec, + pub ancdata: Vec, + pub msg_flags: i32, + pub address: Option, +} + +#[cfg(all(unix, 
not(target_os = "redox")))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AncillaryPackError { + ItemTooLarge, + TooMuchData, + UnexpectedNullHeader, +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn checked_cmsg_len(len: usize) -> Option { + let cmsg_len = |length| unsafe { libc::CMSG_LEN(length) }; + if len as u64 > (i32::MAX as u64 - cmsg_len(0) as u64) { + return None; + } + let res = cmsg_len(len as _) as usize; + if res > i32::MAX as usize || res < len { + return None; + } + Some(res) +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn checked_cmsg_space(len: usize) -> Option { + let cmsg_space = |length| unsafe { libc::CMSG_SPACE(length) }; + if len as u64 > (i32::MAX as u64 - cmsg_space(1) as u64) { + return None; + } + let res = cmsg_space(len as _) as usize; + if res > i32::MAX as usize || res < len { + return None; + } + Some(res) +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn pack_ancillary_messages(cmsgs: &[(i32, i32, &[u8])]) -> Result, AncillaryPackError> { + use core::{mem, ptr}; + + if cmsgs.is_empty() { + return Ok(vec![]); + } + + let capacity = cmsgs + .iter() + .map(|(_, _, buf)| buf.len()) + .try_fold(0usize, |sum, len| { + let space = checked_cmsg_space(len).ok_or(AncillaryPackError::ItemTooLarge)?; + usize::checked_add(sum, space).ok_or(AncillaryPackError::TooMuchData) + })?; + + let mut cmsg_buffer = vec![0u8; capacity]; + let mut mhdr = unsafe { mem::zeroed::() }; + mhdr.msg_control = cmsg_buffer.as_mut_ptr().cast(); + mhdr.msg_controllen = capacity as _; + + let mut pmhdr: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(&mhdr) }; + for (lvl, typ, data) in cmsgs { + if pmhdr.is_null() { + return Err(AncillaryPackError::UnexpectedNullHeader); + } + let cmsg_len = checked_cmsg_len(data.len()).ok_or(AncillaryPackError::ItemTooLarge)?; + unsafe { + (*pmhdr).cmsg_level = *lvl; + (*pmhdr).cmsg_type = *typ; + (*pmhdr).cmsg_len = cmsg_len as _; + ptr::copy_nonoverlapping(data.as_ptr(), libc::CMSG_DATA(pmhdr), 
data.len()); + pmhdr = libc::CMSG_NXTHDR(&mhdr, pmhdr); + } + } + + Ok(cmsg_buffer) +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn parse_ancillary_messages(control: &[u8]) -> Vec { + use core::mem; + + if control.is_empty() { + return Vec::new(); + } + + let mut msg = unsafe { mem::zeroed::() }; + msg.msg_control = control.as_ptr() as *mut _; + msg.msg_controllen = control.len() as _; + + let ctrl_buf = msg.msg_control as *const u8; + let ctrl_end = unsafe { ctrl_buf.add(msg.msg_controllen as _) }; + + let mut result = Vec::new(); + let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(&msg) }; + while !cmsg.is_null() { + let cmsg_ref = unsafe { &*cmsg }; + let data_ptr = unsafe { libc::CMSG_DATA(cmsg) }; + let data_len_from_cmsg = cmsg_ref.cmsg_len as usize - (data_ptr as usize - cmsg as usize); + let available = ctrl_end as usize - data_ptr as usize; + let data_len = data_len_from_cmsg.min(available); + let data = unsafe { core::slice::from_raw_parts(data_ptr, data_len) }; + result.push(AncillaryMessage { + level: cmsg_ref.cmsg_level, + kind: cmsg_ref.cmsg_type, + data: data.to_vec(), + }); + cmsg = unsafe { libc::CMSG_NXTHDR(&msg, cmsg) }; + } + + result +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn recvmsg( + fd: BorrowedFd<'_>, + bufsize: usize, + ancbufsize: usize, + flags: i32, +) -> io::Result { + use core::mem::MaybeUninit; + + let mut data_buf: Vec> = vec![MaybeUninit::uninit(); bufsize]; + let mut anc_buf: Vec> = vec![MaybeUninit::uninit(); ancbufsize]; + let mut addr_storage: libc::sockaddr_storage = unsafe { core::mem::zeroed() }; + + let mut iov = [libc::iovec { + iov_base: data_buf.as_mut_ptr().cast(), + iov_len: bufsize, + }]; + + let mut msg: libc::msghdr = unsafe { core::mem::zeroed() }; + msg.msg_name = (&mut addr_storage as *mut libc::sockaddr_storage).cast(); + msg.msg_namelen = core::mem::size_of::() as libc::socklen_t; + msg.msg_iov = iov.as_mut_ptr(); + msg.msg_iovlen = 1; + if ancbufsize > 0 { + 
msg.msg_control = anc_buf.as_mut_ptr().cast(); + msg.msg_controllen = ancbufsize as _; + } + + let ret = unsafe { libc::recvmsg(fd.as_raw_fd(), &mut msg, flags) }.check_libc_neg()?; + + let data = unsafe { + data_buf.set_len(ret as usize); + core::mem::transmute::>, Vec>(data_buf) + }; + let control = unsafe { + core::slice::from_raw_parts(anc_buf.as_ptr().cast::(), msg.msg_controllen as usize) + }; + let ancdata = parse_ancillary_messages(control); + let address = if msg.msg_namelen > 0 { + let storage = unsafe { + core::mem::transmute::(addr_storage) + }; + Some(RawSocketAddress { + storage, + len: msg.msg_namelen as usize, + }) + } else { + None + }; + + Ok(RecvMsgResult { + data, + ancdata, + msg_flags: msg.msg_flags, + address, + }) +} + +#[cfg(target_os = "linux")] +pub fn sendmsg_afalg( + fd: BorrowedFd<'_>, + buffers: &[io::IoSlice<'_>], + op: u32, + iv: Option<&[u8]>, + assoclen: Option, + flags: i32, +) -> io::Result { + let mut control_buf = Vec::new(); + + { + let op_bytes = op.to_ne_bytes(); + let space = unsafe { libc::CMSG_SPACE(core::mem::size_of::() as u32) } as usize; + let old_len = control_buf.len(); + control_buf.resize(old_len + space, 0u8); + + let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; + unsafe { + (*cmsg).cmsg_len = libc::CMSG_LEN(core::mem::size_of::() as u32) as _; + (*cmsg).cmsg_level = libc::SOL_ALG; + (*cmsg).cmsg_type = libc::ALG_SET_OP; + let data = libc::CMSG_DATA(cmsg); + core::ptr::copy_nonoverlapping(op_bytes.as_ptr(), data, op_bytes.len()); + } + } + + if let Some(iv_bytes) = iv { + let iv_struct_size = 4 + iv_bytes.len(); + let space = unsafe { libc::CMSG_SPACE(iv_struct_size as u32) } as usize; + let old_len = control_buf.len(); + control_buf.resize(old_len + space, 0u8); + + let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; + unsafe { + (*cmsg).cmsg_len = libc::CMSG_LEN(iv_struct_size as u32) as _; + (*cmsg).cmsg_level = libc::SOL_ALG; + (*cmsg).cmsg_type = libc::ALG_SET_IV; + 
let data = libc::CMSG_DATA(cmsg); + let ivlen = (iv_bytes.len() as u32).to_ne_bytes(); + core::ptr::copy_nonoverlapping(ivlen.as_ptr(), data, 4); + core::ptr::copy_nonoverlapping(iv_bytes.as_ptr(), data.add(4), iv_bytes.len()); + } + } + + if let Some(assoclen_val) = assoclen { + let assoclen_bytes = assoclen_val.to_ne_bytes(); + let space = unsafe { libc::CMSG_SPACE(core::mem::size_of::() as u32) } as usize; + let old_len = control_buf.len(); + control_buf.resize(old_len + space, 0u8); + + let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; + unsafe { + (*cmsg).cmsg_len = libc::CMSG_LEN(core::mem::size_of::() as u32) as _; + (*cmsg).cmsg_level = libc::SOL_ALG; + (*cmsg).cmsg_type = libc::ALG_SET_AEAD_ASSOCLEN; + let data = libc::CMSG_DATA(cmsg); + core::ptr::copy_nonoverlapping(assoclen_bytes.as_ptr(), data, assoclen_bytes.len()); + } + } + + let iovecs: Vec = buffers + .iter() + .map(|buf| libc::iovec { + iov_base: buf.as_ptr() as *mut _, + iov_len: buf.len(), + }) + .collect(); + + let mut msghdr: libc::msghdr = unsafe { core::mem::zeroed() }; + msghdr.msg_iov = iovecs.as_ptr() as *mut _; + msghdr.msg_iovlen = iovecs.len() as _; + if !control_buf.is_empty() { + msghdr.msg_control = control_buf.as_mut_ptr() as *mut _; + msghdr.msg_controllen = control_buf.len() as _; + } + + let ret = unsafe { libc::sendmsg(fd.as_raw_fd(), &msghdr, flags) }.check_libc_neg()?; + Ok(ret as usize) +} + +#[cfg(windows)] +use core::{ffi::CStr, ptr::NonNull}; +#[cfg(windows)] +use std::io; +#[cfg(windows)] +use windows_sys::Win32::{ + NetworkManagement::{ + IpHelper::{ + ConvertInterfaceLuidToNameW, FreeMibTable, GetIfTable2Ex, MIB_IF_ROW2, MIB_IF_TABLE2, + MibIfTableRaw, if_indextoname, if_nametoindex, + }, + Ndis::{IF_MAX_STRING_SIZE, NET_LUID_LH}, + }, + Networking::WinSock::{ + FROM_PROTOCOL_INFO, INVALID_SOCKET, SOCKET, SOCKET_ERROR, WSA_FLAG_OVERLAPPED, + WSADuplicateSocketW, WSAGetLastError, WSAIoctl, WSAPROTOCOL_INFOW, WSASocketW, + }, +}; + +#[cfg(windows)] 
+pub use windows_sys::Win32::Networking::WinSock::{ + AF_APPLETALK, AF_DECnet, AF_IPX, AF_LINK, AI_ADDRCONFIG, AI_ALL, AI_CANONNAME, AI_NUMERICSERV, + AI_V4MAPPED, INADDR_ANY, INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_NONE, IP_ADD_MEMBERSHIP, + IP_DROP_MEMBERSHIP, IP_HDRINCL, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, + IP_OPTIONS, IP_RECVDSTADDR, IP_TOS, IP_TTL, IPPORT_RESERVED, IPPROTO_AH, IPPROTO_CBT, + IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, IPPROTO_FRAGMENT, IPPROTO_GGP, IPPROTO_HOPOPTS, + IPPROTO_ICLFXBM, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IGP, + IPPROTO_IP, IPPROTO_IP as IPPROTO_IPIP, IPPROTO_IPV4, IPPROTO_IPV6, IPPROTO_L2TP, IPPROTO_ND, + IPPROTO_NONE, IPPROTO_PGM, IPPROTO_PIM, IPPROTO_PUP, IPPROTO_RAW, IPPROTO_RDP, IPPROTO_ROUTING, + IPPROTO_SCTP, IPPROTO_ST, IPPROTO_TCP, IPPROTO_UDP, IPV6_CHECKSUM, IPV6_DONTFRAG, + IPV6_HOPLIMIT, IPV6_HOPOPTS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, IPV6_MULTICAST_HOPS, + IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_PKTINFO, IPV6_RECVRTHDR, IPV6_RECVTCLASS, + IPV6_RTHDR, IPV6_TCLASS, IPV6_UNICAST_HOPS, IPV6_V6ONLY, MSG_BCAST, MSG_CTRUNC, MSG_DONTROUTE, + MSG_MCAST, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_MAXSERV, + NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, RCVALL_IPLEVEL, RCVALL_OFF, RCVALL_ON, + RCVALL_SOCKETLEVELONLY, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS, + SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, SO_BROADCAST, SO_ERROR, SO_KEEPALIVE, SO_LINGER, + SO_OOBINLINE, SO_RCVBUF, SO_REUSEADDR, SO_SNDBUF, SO_TYPE, SO_USELOOPBACK, SOCK_DGRAM, + SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET, SOCK_STREAM, SOCKET_ERROR as SOCKET_ERROR_CODE, SOL_SOCKET, + SOMAXCONN, TCP_NODELAY, WSAEBADF, WSAECONNRESET, WSAENOTSOCK, WSAEWOULDBLOCK, getprotobyname, + getservbyname, getservbyport, getsockopt, setsockopt, +}; + +#[cfg(windows)] +pub const SO_EXCLUSIVEADDRUSE: i32 = -5; +#[cfg(windows)] +pub const EAI_MEMORY: i32 = 
windows_sys::Win32::Networking::WinSock::WSA_NOT_ENOUGH_MEMORY; +#[cfg(windows)] +pub const EAI_FAMILY: i32 = windows_sys::Win32::Networking::WinSock::WSAEAFNOSUPPORT; +#[cfg(windows)] +pub const EAI_BADFLAGS: i32 = windows_sys::Win32::Networking::WinSock::WSAEINVAL; +#[cfg(windows)] +pub const EAI_SOCKTYPE: i32 = windows_sys::Win32::Networking::WinSock::WSAESOCKTNOSUPPORT; +#[cfg(windows)] +pub const EAI_NODATA: i32 = windows_sys::Win32::Networking::WinSock::WSAHOST_NOT_FOUND; +#[cfg(windows)] +pub const EAI_NONAME: i32 = windows_sys::Win32::Networking::WinSock::WSAHOST_NOT_FOUND; +#[cfg(windows)] +pub const EAI_FAIL: i32 = windows_sys::Win32::Networking::WinSock::WSANO_RECOVERY; +#[cfg(windows)] +pub const EAI_AGAIN: i32 = windows_sys::Win32::Networking::WinSock::WSATRY_AGAIN; +#[cfg(windows)] +pub const EAI_SERVICE: i32 = windows_sys::Win32::Networking::WinSock::WSATYPE_NOT_FOUND; +#[cfg(windows)] +pub const IF_NAMESIZE: usize = IF_MAX_STRING_SIZE as usize; +#[cfg(windows)] +pub const AF_UNSPEC: i32 = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as i32; +#[cfg(windows)] +pub const AF_INET: i32 = windows_sys::Win32::Networking::WinSock::AF_INET as i32; +#[cfg(windows)] +pub const AF_INET6: i32 = windows_sys::Win32::Networking::WinSock::AF_INET6 as i32; +#[cfg(windows)] +pub const AI_PASSIVE: i32 = windows_sys::Win32::Networking::WinSock::AI_PASSIVE as i32; +#[cfg(windows)] +pub const AI_NUMERICHOST: i32 = windows_sys::Win32::Networking::WinSock::AI_NUMERICHOST as i32; +#[cfg(windows)] +pub const FROM_PROTOCOL_INFO_VALUE: i32 = FROM_PROTOCOL_INFO; + +#[cfg(windows)] +pub type RawSocket = SOCKET; + +#[cfg(windows)] +pub const INVALID_RAW_SOCKET: RawSocket = INVALID_SOCKET as RawSocket; + +#[cfg(windows)] +#[repr(C)] +pub struct TcpKeepalive { + pub onoff: u32, + pub keepalivetime: u32, + pub keepaliveinterval: u32, +} + +#[cfg(windows)] +pub struct SharedSocket { + pub raw: RawSocket, + pub family: i32, + pub socket_type: i32, + pub protocol: i32, +} + 
+#[cfg(windows)] +pub fn last_socket_error() -> io::Error { + io::Error::from_raw_os_error(unsafe { WSAGetLastError() }) +} + +#[cfg(windows)] +pub fn set_socket_inheritable(socket: RawSocket, inheritable: bool) -> io::Result<()> { + crate::nt::set_handle_inheritable(socket as _, inheritable) +} + +#[cfg(windows)] +pub fn close_socket_ignore_connreset(socket: RawSocket) -> io::Result<()> { + let ret = unsafe { windows_sys::Win32::Networking::WinSock::closesocket(socket) }; + if ret != 0 { + let err = last_socket_error(); + if err.raw_os_error() != Some(WSAECONNRESET) { + return Err(err); + } + } + Ok(()) +} + +#[cfg(windows)] +pub fn getsockopt_int(socket: RawSocket, level: i32, name: i32) -> io::Result { + let mut flag = 0i32; + let mut optlen = core::mem::size_of::() as i32; + let ret = unsafe { + getsockopt( + socket, + level, + name, + &mut flag as *mut i32 as *mut _, + &mut optlen, + ) + }; + if ret == SOCKET_ERROR { + Err(crate::os::errno_io_error()) + } else { + Ok(flag) + } +} + +#[cfg(windows)] +pub fn getsockopt_bytes( + socket: RawSocket, + level: i32, + name: i32, + buflen: usize, +) -> io::Result> { + let mut buf = vec![0u8; buflen]; + let mut optlen = buflen as i32; + let ret = unsafe { getsockopt(socket, level, name, buf.as_mut_ptr() as *mut _, &mut optlen) }; + if ret == SOCKET_ERROR { + Err(crate::os::errno_io_error()) + } else { + buf.truncate(optlen as usize); + Ok(buf) + } +} + +#[cfg(windows)] +pub fn setsockopt_bytes(socket: RawSocket, level: i32, name: i32, value: &[u8]) -> io::Result<()> { + let ret = unsafe { + setsockopt( + socket, + level, + name, + value.as_ptr() as *const _, + value.len() as i32, + ) + }; + if ret == SOCKET_ERROR { + Err(crate::os::errno_io_error()) + } else { + Ok(()) + } +} + +#[cfg(windows)] +pub fn setsockopt_int(socket: RawSocket, level: i32, name: i32, value: i32) -> io::Result<()> { + let ret = unsafe { + setsockopt( + socket, + level, + name, + &value as *const i32 as *const _, + core::mem::size_of::() as i32, + 
) + }; + if ret == SOCKET_ERROR { + Err(crate::os::errno_io_error()) + } else { + Ok(()) + } +} + +#[cfg(windows)] +pub fn setsockopt_none(socket: RawSocket, level: i32, name: i32, optlen: u32) -> io::Result<()> { + let ret = unsafe { setsockopt(socket, level, name, core::ptr::null(), optlen as i32) }; + if ret == SOCKET_ERROR { + Err(crate::os::errno_io_error()) + } else { + Ok(()) + } +} + +#[cfg(windows)] +pub fn protocol_info_size() -> usize { + core::mem::size_of::() +} + +#[cfg(windows)] +pub fn socket_from_share_data(bytes: &[u8]) -> io::Result { + let mut info: WSAPROTOCOL_INFOW = unsafe { core::mem::zeroed() }; + unsafe { + core::ptr::copy_nonoverlapping( + bytes.as_ptr(), + &mut info as *mut WSAPROTOCOL_INFOW as *mut u8, + protocol_info_size(), + ); + } + + let raw = unsafe { + WSASocketW( + FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + &info, + 0, + WSA_FLAG_OVERLAPPED, + ) + }; + if raw == INVALID_SOCKET { + return Err(last_socket_error()); + } + + crate::nt::set_handle_inheritable(raw as _, false)?; + + Ok(SharedSocket { + raw, + family: info.iAddressFamily, + socket_type: info.iSocketType, + protocol: info.iProtocol, + }) +} + +#[cfg(windows)] +pub fn share_socket(socket: RawSocket, process_id: u32) -> io::Result> { + let mut info = core::mem::MaybeUninit::::uninit(); + let ret = unsafe { WSADuplicateSocketW(socket, process_id, info.as_mut_ptr()) }; + if ret == SOCKET_ERROR { + return Err(last_socket_error()); + } + let info = unsafe { info.assume_init() }; + let bytes = unsafe { + core::slice::from_raw_parts( + &info as *const WSAPROTOCOL_INFOW as *const u8, + core::mem::size_of::(), + ) + }; + Ok(bytes.to_vec()) +} + +#[cfg(windows)] +pub fn ioctl_u32(socket: RawSocket, cmd: u32, option: u32) -> io::Result { + let mut recv = 0u32; + let ret = unsafe { + WSAIoctl( + socket, + cmd, + &option as *const u32 as *const _, + core::mem::size_of::() as u32, + core::ptr::null_mut(), + 0, + &mut recv, + core::ptr::null_mut(), + None, + ) + 
}; + if ret == SOCKET_ERROR { + Err(last_socket_error()) + } else { + Ok(recv) + } +} + +#[cfg(windows)] +pub fn ioctl_keepalive(socket: RawSocket, keepalive: TcpKeepalive) -> io::Result { + let mut recv = 0u32; + let ret = unsafe { + WSAIoctl( + socket, + windows_sys::Win32::Networking::WinSock::SIO_KEEPALIVE_VALS, + &keepalive as *const TcpKeepalive as *const _, + core::mem::size_of::() as u32, + core::ptr::null_mut(), + 0, + &mut recv, + core::ptr::null_mut(), + None, + ) + }; + if ret == SOCKET_ERROR { + Err(last_socket_error()) + } else { + Ok(recv) + } +} + +#[cfg(windows)] +pub fn if_nametoindex_checked(name: &CStr) -> io::Result { + crate::os::set_errno(libc::ENODEV); + let ret = unsafe { if_nametoindex(name.as_ptr() as _) }; + if ret == 0 { + Err(crate::os::errno_io_error()) + } else { + Ok(ret) + } +} + +#[cfg(windows)] +pub fn if_indextoname_checked(index: u32) -> io::Result { + let mut buf = [0; IF_MAX_STRING_SIZE as usize + 1]; + crate::os::set_errno(libc::ENXIO); + let ret = unsafe { if_indextoname(index, buf.as_mut_ptr()) }; + if ret.is_null() { + Err(crate::os::errno_io_error()) + } else { + let buf = unsafe { CStr::from_ptr(buf.as_ptr() as _) }; + Ok(buf.to_string_lossy().into_owned()) + } +} + +#[cfg(windows)] +pub fn if_nameindex() -> io::Result> { + fn get_name(luid: &NET_LUID_LH) -> io::Result { + let mut buf = [0u16; IF_MAX_STRING_SIZE as usize + 1]; + let ret = unsafe { ConvertInterfaceLuidToNameW(luid, buf.as_mut_ptr(), buf.len()) }; + if ret != 0 { + return Err(io::Error::from_raw_os_error(ret as i32)); + } + let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len()); + Ok(String::from_utf16_lossy(&buf[..len])) + } + + struct MibTable { + ptr: NonNull, + } + + impl MibTable { + fn get_raw() -> io::Result { + let mut ptr = core::ptr::null_mut(); + let ret = unsafe { GetIfTable2Ex(MibIfTableRaw, &mut ptr) }; + if ret == 0 { + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + Ok(Self { ptr }) + } else { + 
Err(io::Error::from_raw_os_error(ret as i32)) + } + } + + fn as_slice(&self) -> &[MIB_IF_ROW2] { + unsafe { + let p = self.ptr.as_ptr(); + let ptr = &raw const (*p).Table as *const MIB_IF_ROW2; + core::slice::from_raw_parts(ptr, (*p).NumEntries as usize) + } + } + } + + impl Drop for MibTable { + fn drop(&mut self) { + unsafe { FreeMibTable(self.ptr.as_ptr() as *mut _) }; + } + } + + let table = MibTable::get_raw()?; + table + .as_slice() + .iter() + .map(|entry| Ok((entry.InterfaceIndex, get_name(&entry.InterfaceLuid)?))) + .collect() +} diff --git a/crates/host_env/src/syslog.rs b/crates/host_env/src/syslog.rs index a4100ac2a21..8820b8f1c5d 100644 --- a/crates/host_env/src/syslog.rs +++ b/crates/host_env/src/syslog.rs @@ -1,9 +1,7 @@ use alloc::boxed::Box; use core::ffi::CStr; -use std::{ - os::raw::c_char, - sync::{OnceLock, RwLock}, -}; +use parking_lot::RwLock; +use std::{os::raw::c_char, sync::OnceLock}; #[derive(Debug)] enum GlobalIdent { @@ -27,10 +25,7 @@ fn global_ident() -> &'static RwLock> { #[must_use] pub fn is_open() -> bool { - global_ident() - .read() - .expect("syslog lock poisoned") - .is_some() + global_ident().read().is_some() } pub fn openlog(ident: Option>, logoption: i32, facility: i32) { @@ -38,19 +33,20 @@ pub fn openlog(ident: Option>, logoption: i32, facility: i32) { Some(ident) => GlobalIdent::Explicit(ident), None => GlobalIdent::Implicit, }; - let mut locked_ident = global_ident().write().expect("syslog lock poisoned"); + let mut locked_ident = global_ident().write(); unsafe { libc::openlog(ident.as_ptr(), logoption, facility) }; *locked_ident = Some(ident); } pub fn syslog(priority: i32, msg: &CStr) { + let _locked_ident = global_ident().read(); let cformat = c"%s"; unsafe { libc::syslog(priority, cformat.as_ptr(), msg.as_ptr()) }; } pub fn closelog() { - if is_open() { - let mut locked_ident = global_ident().write().expect("syslog lock poisoned"); + let mut locked_ident = global_ident().write(); + if locked_ident.is_some() { unsafe 
{ libc::closelog() }; *locked_ident = None; } @@ -63,7 +59,7 @@ pub fn setlogmask(maskpri: i32) -> i32 { #[must_use] pub const fn log_mask(pri: i32) -> i32 { - pri << 1 + 1 << pri } #[must_use] diff --git a/crates/host_env/src/termios.rs b/crates/host_env/src/termios.rs index 76bcd0c9f01..074d03a455b 100644 --- a/crates/host_env/src/termios.rs +++ b/crates/host_env/src/termios.rs @@ -1,11 +1,105 @@ -pub fn tcgetattr(fd: i32) -> std::io::Result<::termios::Termios> { - ::termios::Termios::from_fd(fd) +pub type Termios = ::termios::Termios; + +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "illumos", + target_os = "linux", + target_os = "macos", + target_os = "openbsd", + target_os = "solaris" +))] +pub use ::termios::os::target::TAB3; +#[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" +))] +pub use ::termios::os::target::TCSASOFT; +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "solaris" +))] +pub use ::termios::os::target::{B460800, B921600}; +#[cfg(any(target_os = "android", target_os = "linux"))] +pub use ::termios::os::target::{ + B500000, B576000, B1000000, B1152000, B1500000, B2000000, B2500000, B3000000, B3500000, + B4000000, CBAUDEX, +}; +#[cfg(any( + target_os = "android", + target_os = "illumos", + target_os = "linux", + target_os = "macos", + target_os = "solaris" +))] +pub use ::termios::os::target::{ + BS0, BS1, BSDLY, CR0, CR1, CR2, CR3, CRDLY, FF0, FF1, FFDLY, NL0, NL1, NLDLY, OFDEL, OFILL, + TAB1, TAB2, VT0, VT1, VTDLY, +}; +#[cfg(any( + target_os = "android", + target_os = "illumos", + target_os = "linux", + target_os = "solaris" +))] +pub use ::termios::os::target::{CBAUD, CIBAUD, IUCLC, OLCUC, XCASE}; +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "illumos", + target_os = "linux", + target_os = "macos", + 
target_os = "solaris" +))] +pub use ::termios::os::target::{TAB0, TABDLY}; +#[cfg(any(target_os = "android", target_os = "linux"))] +pub use ::termios::os::target::{VSWTC, VSWTC as VSWTCH}; +#[cfg(any(target_os = "illumos", target_os = "solaris"))] +pub use ::termios::os::target::{VSWTCH, VSWTCH as VSWTC}; +pub use ::termios::{ + B0, B50, B75, B110, B134, B150, B200, B300, B600, B1200, B1800, B2400, B4800, B9600, B19200, + B38400, BRKINT, CLOCAL, CREAD, CS5, CS6, CS7, CS8, CSIZE, CSTOPB, ECHO, ECHOE, ECHOK, ECHONL, + HUPCL, ICANON, ICRNL, IEXTEN, IGNBRK, IGNCR, IGNPAR, INLCR, INPCK, ISIG, ISTRIP, IXANY, IXOFF, + IXON, NOFLSH, OCRNL, ONLCR, ONLRET, ONOCR, OPOST, PARENB, PARMRK, PARODD, TCIFLUSH, TCIOFF, + TCIOFLUSH, TCION, TCOFLUSH, TCOOFF, TCOON, TCSADRAIN, TCSAFLUSH, TCSANOW, TOSTOP, VEOF, VEOL, + VERASE, VINTR, VKILL, VMIN, VQUIT, VSTART, VSTOP, VSUSP, VTIME, + os::target::{ + B57600, B115200, B230400, CRTSCTS, ECHOCTL, ECHOKE, ECHOPRT, EXTA, EXTB, FLUSHO, IMAXBEL, + NCCS, PENDIN, VDISCARD, VEOL2, VLNEXT, VREPRINT, VWERASE, + }, +}; + +pub fn tcgetattr(fd: i32) -> std::io::Result { + Termios::from_fd(fd) } -pub fn tcsetattr(fd: i32, when: i32, termios: &::termios::Termios) -> std::io::Result<()> { +pub fn tcsetattr(fd: i32, when: i32, termios: &Termios) -> std::io::Result<()> { ::termios::tcsetattr(fd, when, termios) } +pub fn cfgetispeed(termios: &Termios) -> libc::speed_t { + ::termios::cfgetispeed(termios) +} + +pub fn cfgetospeed(termios: &Termios) -> libc::speed_t { + ::termios::cfgetospeed(termios) +} + +pub fn cfsetispeed(termios: &mut Termios, speed: libc::speed_t) -> std::io::Result<()> { + ::termios::cfsetispeed(termios, speed) +} + +pub fn cfsetospeed(termios: &mut Termios, speed: libc::speed_t) -> std::io::Result<()> { + ::termios::cfsetospeed(termios, speed) +} + pub fn tcsendbreak(fd: i32, duration: i32) -> std::io::Result<()> { ::termios::tcsendbreak(fd, duration) } diff --git a/crates/host_env/src/testconsole.rs 
b/crates/host_env/src/testconsole.rs new file mode 100644 index 00000000000..02b82f4bf4d --- /dev/null +++ b/crates/host_env/src/testconsole.rs @@ -0,0 +1,45 @@ +use std::io; +use windows_sys::Win32::{ + Foundation::HANDLE, + System::Console::{INPUT_RECORD, KEY_EVENT, WriteConsoleInputW}, +}; + +use crate::windows::{CheckWin32Bool, CheckWin32Handle}; + +pub fn write_console_input(fd: i32, data: &[u16]) -> io::Result<()> { + let handle = (unsafe { libc::get_osfhandle(fd) } as HANDLE).check_valid()?; + + let size = data.len() as u32; + let mut records: Vec = Vec::with_capacity(data.len()); + for &wc in data { + let mut rec: INPUT_RECORD = unsafe { core::mem::zeroed() }; + rec.EventType = KEY_EVENT as u16; + rec.Event.KeyEvent.bKeyDown = 1; + rec.Event.KeyEvent.wRepeatCount = 1; + rec.Event.KeyEvent.uChar.UnicodeChar = wc; + records.push(rec); + } + + let mut total: u32 = 0; + while total < size { + let mut wrote: u32 = 0; + unsafe { + WriteConsoleInputW( + handle, + records[total as usize..].as_ptr(), + size - total, + &mut wrote, + ) + } + .check_win32_bool()?; + if wrote == 0 { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "WriteConsoleInputW made no progress", + )); + } + total += wrote; + } + + Ok(()) +} diff --git a/crates/host_env/src/thread.rs b/crates/host_env/src/thread.rs new file mode 100644 index 00000000000..5b0d42477b6 --- /dev/null +++ b/crates/host_env/src/thread.rs @@ -0,0 +1,42 @@ +#[cfg(any(target_os = "linux", target_os = "macos"))] +use alloc::ffi::CString; + +#[cfg(unix)] +pub fn current_thread_id() -> u64 { + unsafe { libc::pthread_self() as u64 } +} + +#[cfg(windows)] +pub fn current_thread_id() -> u64 { + unsafe { windows_sys::Win32::System::Threading::GetCurrentThreadId() as u64 } +} + +#[cfg(target_os = "linux")] +pub fn set_current_thread_name(name: &str) { + if CString::new(name).is_ok() { + let truncated = if name.len() > 15 { + let mut end = 15; + while !name.is_char_boundary(end) { + end -= 1; + } + 
CString::new(&name[..end]).expect("slice of null-free string is null-free") + } else { + CString::new(name).expect("name was already checked for nul bytes") + }; + unsafe { + libc::pthread_setname_np(libc::pthread_self(), truncated.as_ptr()); + } + } +} + +#[cfg(target_os = "macos")] +pub fn set_current_thread_name(name: &str) { + if let Ok(c_name) = CString::new(name) { + unsafe { + libc::pthread_setname_np(c_name.as_ptr()); + } + } +} + +#[cfg(not(any(target_os = "linux", target_os = "macos")))] +pub fn set_current_thread_name(_name: &str) {} diff --git a/crates/host_env/src/time.rs b/crates/host_env/src/time.rs index 0bcabe9957b..3e3227d78de 100644 --- a/crates/host_env/src/time.rs +++ b/crates/host_env/src/time.rs @@ -1,6 +1,11 @@ +#[cfg(unix)] +use alloc::ffi::CString; use core::time::Duration; use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH}; +#[cfg(target_env = "msvc")] +use alloc::string::String; + pub const SEC_TO_MS: i64 = 1000; pub const MS_TO_US: i64 = 1000; pub const SEC_TO_US: i64 = SEC_TO_MS * MS_TO_US; @@ -10,21 +15,422 @@ pub const SEC_TO_NS: i64 = SEC_TO_MS * MS_TO_NS; pub const NS_TO_MS: i64 = 1000 * 1000; pub const NS_TO_US: i64 = 1000; +/// Access to the C runtime's `tzset` / `timezone` / `daylight` / `tzname` +/// globals used by Python's `time` module. +/// +/// Not available under MSVC (which exposes these only via the +/// `_get_tzname`-style helpers) or on `wasm32` (no libc tz state). 
+#[cfg(all(not(target_env = "msvc"), not(target_arch = "wasm32")))] +pub mod tz { + unsafe extern "C" { + #[cfg(not(target_os = "freebsd"))] + #[link_name = "daylight"] + static c_daylight: core::ffi::c_int; + #[link_name = "timezone"] + static c_timezone: core::ffi::c_long; + #[link_name = "tzname"] + static c_tzname: [*const core::ffi::c_char; 2]; + #[link_name = "tzset"] + fn c_tzset(); + } + + pub fn tzset() { + unsafe { c_tzset() } + } + + #[must_use] + pub fn timezone() -> core::ffi::c_long { + unsafe { c_timezone } + } + + #[cfg(not(target_os = "freebsd"))] + #[must_use] + pub fn daylight() -> core::ffi::c_int { + unsafe { c_daylight } + } + + /// Snapshot of `tzname[0]` / `tzname[1]` as owned `String`s. + /// Reads the C globals once and copies the bytes out so callers don't + /// have to handle the raw pointers themselves. + #[must_use] + pub fn tzname_strings() -> (String, String) { + unsafe fn to_str(s: *const core::ffi::c_char) -> String { + unsafe { core::ffi::CStr::from_ptr(s) } + .to_string_lossy() + .into_owned() + } + unsafe { (to_str(c_tzname[0]), to_str(c_tzname[1])) } + } +} + pub fn duration_since_system_now() -> Result { SystemTime::now().duration_since(UNIX_EPOCH) } +#[cfg(unix)] +pub type TimeT = libc::time_t; + +#[cfg(windows)] +pub type TimeT = libc::time_t; + +#[cfg(unix)] +#[derive(Clone, Copy, Debug)] +pub struct ProcessTimes { + pub user: f64, + pub system: f64, + pub children_user: f64, + pub children_system: f64, + pub elapsed: f64, +} + +#[cfg(unix)] +#[cfg_attr(target_env = "musl", allow(deprecated))] +pub fn current_time_t() -> TimeT { + unsafe { libc::time(core::ptr::null_mut()) } +} + +#[cfg(unix)] +#[cfg_attr(target_env = "musl", allow(deprecated))] +pub fn gmtime_from_timestamp(when: TimeT) -> Option { + let mut out = core::mem::MaybeUninit::::uninit(); + let ret = unsafe { libc::gmtime_r(&when, out.as_mut_ptr()) }; + (!ret.is_null()).then(|| unsafe { out.assume_init() }) +} + +#[cfg(unix)] +#[cfg_attr(target_env = "musl", 
allow(deprecated))] +pub fn localtime_from_timestamp(when: TimeT) -> Option { + let mut out = core::mem::MaybeUninit::::uninit(); + let ret = unsafe { libc::localtime_r(&when, out.as_mut_ptr()) }; + (!ret.is_null()).then(|| unsafe { out.assume_init() }) +} + +#[cfg(unix)] +pub fn mktime(tm: &mut libc::tm) -> TimeT { + unsafe { libc::mktime(tm) } +} + +#[cfg(windows)] +unsafe extern "C" { + fn _gmtime64_s(tm: *mut libc::tm, time: *const libc::time_t) -> libc::c_int; + fn _localtime64_s(tm: *mut libc::tm, time: *const libc::time_t) -> libc::c_int; + #[link_name = "_mktime64"] + fn c_mktime(tm: *mut libc::tm) -> libc::time_t; +} + +#[cfg(windows)] +#[cfg_attr(target_env = "musl", allow(deprecated))] +pub fn current_time_t() -> TimeT { + unsafe { libc::time(core::ptr::null_mut()) } +} + +#[cfg(windows)] +#[cfg_attr(target_env = "musl", allow(deprecated))] +pub fn gmtime_from_timestamp(when: TimeT) -> Option { + let mut out = core::mem::MaybeUninit::::uninit(); + let err = unsafe { _gmtime64_s(out.as_mut_ptr(), &when) }; + (err == 0).then(|| unsafe { out.assume_init() }) +} + +#[cfg(windows)] +#[cfg_attr(target_env = "musl", allow(deprecated))] +pub fn localtime_from_timestamp(when: TimeT) -> Option { + let mut out = core::mem::MaybeUninit::::uninit(); + let err = unsafe { _localtime64_s(out.as_mut_ptr(), &when) }; + (err == 0).then(|| unsafe { out.assume_init() }) +} + +#[cfg(windows)] +pub fn mktime(tm: &mut libc::tm) -> TimeT { + unsafe { crate::suppress_iph!(c_mktime(tm)) } +} + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub fn strerror(errno: i32) -> String { + unsafe { core::ffi::CStr::from_ptr(libc::strerror(errno)) } + .to_string_lossy() + .into_owned() +} + +#[cfg(unix)] +pub fn nix_errno_display(errno: i32) -> String { + nix::errno::Errno::from_raw(errno).to_string() +} + +#[cfg(all(unix, not(any(target_os = "redox", target_os = "android"))))] +pub fn getloadavg() -> std::io::Result<[f64; 3]> { + let mut loadavg = [0f64; 3]; + let ok = unsafe { 
libc::getloadavg(loadavg.as_mut_ptr(), 3) }; + if ok != 3 { + Err(std::io::Error::last_os_error()) + } else { + Ok(loadavg) + } +} + +#[cfg(unix)] +pub fn waitstatus_to_exitcode(status: libc::c_int) -> Option { + if libc::WIFEXITED(status) { + return Some(libc::WEXITSTATUS(status)); + } + if libc::WIFSIGNALED(status) { + return Some(-libc::WTERMSIG(status)); + } + None +} + +#[cfg(any(unix, all(target_arch = "wasm32", target_os = "emscripten")))] +pub fn process_times() -> std::io::Result { + let mut t = libc::tms { + tms_utime: 0, + tms_stime: 0, + tms_cutime: 0, + tms_cstime: 0, + }; + + let tick_for_second = unsafe { libc::sysconf(libc::_SC_CLK_TCK) }; + if tick_for_second <= 0 { + return Err(std::io::Error::last_os_error()); + } + let tick_for_second = tick_for_second as f64; + let c = unsafe { libc::times(&mut t as *mut _) }; + if c == (-1i8) as libc::clock_t { + return Err(std::io::Error::last_os_error()); + } + + Ok(ProcessTimes { + user: t.tms_utime as f64 / tick_for_second, + system: t.tms_stime as f64 / tick_for_second, + children_user: t.tms_cutime as f64 / tick_for_second, + children_system: t.tms_cstime as f64 / tick_for_second, + elapsed: c as f64 / tick_for_second, + }) +} + +#[cfg(unix)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct ClockId(libc::clockid_t); + +#[cfg(unix)] +impl ClockId { + pub const fn from_raw(raw: libc::clockid_t) -> Self { + Self(raw) + } + + pub const fn as_raw(self) -> libc::clockid_t { + self.0 + } + + pub const CLOCK_MONOTONIC: Self = Self(libc::CLOCK_MONOTONIC); + pub const CLOCK_REALTIME: Self = Self(libc::CLOCK_REALTIME); + + #[cfg(not(any( + target_os = "illumos", + target_os = "netbsd", + target_os = "solaris", + target_os = "openbsd", + target_os = "wasi", + )))] + pub const CLOCK_PROCESS_CPUTIME_ID: Self = Self(libc::CLOCK_PROCESS_CPUTIME_ID); + + #[cfg(not(any( + target_os = "illumos", + target_os = "netbsd", + target_os = "solaris", + target_os = "openbsd", + target_os = "redox", + )))] + pub const 
CLOCK_THREAD_CPUTIME_ID: Self = Self(libc::CLOCK_THREAD_CPUTIME_ID); +} + +#[cfg(unix)] +fn nix_clock_id(id: ClockId) -> nix::time::ClockId { + nix::time::ClockId::from_raw(id.as_raw()) +} + +#[cfg(unix)] +pub fn clock_gettime(id: ClockId) -> std::io::Result { + nix::time::clock_gettime(nix_clock_id(id)) + .map(Duration::from) + .map_err(std::io::Error::from) +} + +#[cfg(all(unix, not(target_os = "redox")))] +pub fn clock_getres(id: ClockId) -> std::io::Result { + nix::time::clock_getres(nix_clock_id(id)) + .map(Duration::from) + .map_err(std::io::Error::from) +} + +#[cfg(all(unix, not(target_os = "redox"), not(target_vendor = "apple")))] +pub fn clock_settime(id: ClockId, time: Duration) -> std::io::Result<()> { + let ts = nix::sys::time::TimeSpec::from(time); + nix::time::clock_settime(nix_clock_id(id), ts) + .map(drop) + .map_err(std::io::Error::from) +} + +#[cfg(all(unix, not(target_os = "redox"), target_os = "macos"))] +pub fn clock_settime(id: ClockId, time: Duration) -> std::io::Result<()> { + let ts = nix::sys::time::TimeSpec::from(time); + let ret = unsafe { libc::clock_settime(id.as_raw(), ts.as_ref()) }; + if ret != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(unix)] +pub fn nanosleep(duration: Duration) -> std::io::Result<()> { + let ts = nix::sys::time::TimeSpec::from(duration); + let ret = unsafe { libc::nanosleep(ts.as_ref(), core::ptr::null_mut()) }; + if ret != 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(target_os = "solaris")] +pub fn gethrvtime_duration() -> Duration { + Duration::from_nanos(unsafe { libc::gethrvtime() }) +} + +#[cfg(target_env = "msvc")] +#[cfg(not(target_arch = "wasm32"))] +#[derive(Clone, Debug)] +pub struct WindowsTimeZoneInfo { + pub bias: i32, + pub standard_bias: i32, + pub daylight_bias: i32, + pub standard_name: String, + pub daylight_name: String, +} + +#[cfg(target_env = "msvc")] +#[cfg(not(target_arch = "wasm32"))] +fn decode_tz_name(name: &[u16]) -> 
String { + widestring::decode_utf16_lossy(name.iter().copied()) + .take_while(|&c| c != '\0') + .collect() +} + #[cfg(target_env = "msvc")] #[cfg(not(target_arch = "wasm32"))] #[must_use] -pub fn get_tz_info() -> windows_sys::Win32::System::Time::TIME_ZONE_INFORMATION { +pub fn get_tz_info() -> WindowsTimeZoneInfo { let mut info = unsafe { core::mem::zeroed() }; unsafe { windows_sys::Win32::System::Time::GetTimeZoneInformation(&mut info) }; - info + WindowsTimeZoneInfo { + bias: info.Bias as i32, + standard_bias: info.StandardBias as i32, + daylight_bias: info.DaylightBias as i32, + standard_name: decode_tz_name(&info.StandardName), + daylight_name: decode_tz_name(&info.DaylightName), + } +} + +#[cfg(windows)] +fn u64_from_filetime(time: windows_sys::Win32::Foundation::FILETIME) -> u64 { + u64::from(time.dwLowDateTime) | (u64::from(time.dwHighDateTime) << 32) +} + +#[cfg(windows)] +#[derive(Clone, Copy, Debug)] +pub struct ProcessTimes100ns { + pub user: u64, + pub system: u64, +} + +#[cfg(windows)] +pub fn query_performance_frequency() -> Option { + let mut freq = core::mem::MaybeUninit::uninit(); + (unsafe { + windows_sys::Win32::System::Performance::QueryPerformanceFrequency(freq.as_mut_ptr()) + } != 0) + .then(|| unsafe { freq.assume_init() }) +} + +#[cfg(windows)] +pub fn query_performance_counter() -> i64 { + let mut counter = core::mem::MaybeUninit::uninit(); + unsafe { + windows_sys::Win32::System::Performance::QueryPerformanceCounter(counter.as_mut_ptr()); + counter.assume_init() + } +} + +#[cfg(windows)] +pub fn get_system_time_adjustment() -> Option { + let mut time_adjustment = core::mem::MaybeUninit::uninit(); + let mut time_increment = core::mem::MaybeUninit::uninit(); + let mut is_time_adjustment_disabled = core::mem::MaybeUninit::uninit(); + (unsafe { + windows_sys::Win32::System::SystemInformation::GetSystemTimeAdjustment( + time_adjustment.as_mut_ptr(), + time_increment.as_mut_ptr(), + is_time_adjustment_disabled.as_mut_ptr(), + ) + } != 0) + 
.then(|| unsafe { time_increment.assume_init() }) +} + +#[cfg(windows)] +pub fn tick_count64() -> u64 { + unsafe { windows_sys::Win32::System::SystemInformation::GetTickCount64() } +} + +#[cfg(windows)] +pub fn get_thread_time_100ns() -> Option { + let mut creation_time = core::mem::MaybeUninit::uninit(); + let mut exit_time = core::mem::MaybeUninit::uninit(); + let mut kernel_time = core::mem::MaybeUninit::uninit(); + let mut user_time = core::mem::MaybeUninit::uninit(); + (unsafe { + windows_sys::Win32::System::Threading::GetThreadTimes( + windows_sys::Win32::System::Threading::GetCurrentThread(), + creation_time.as_mut_ptr(), + exit_time.as_mut_ptr(), + kernel_time.as_mut_ptr(), + user_time.as_mut_ptr(), + ) + } != 0) + .then(|| unsafe { + u64_from_filetime(kernel_time.assume_init()) + + u64_from_filetime(user_time.assume_init()) + }) +} + +#[cfg(windows)] +pub fn get_process_time_100ns() -> Option { + get_process_times_100ns().map(|times| times.user + times.system) +} + +#[cfg(windows)] +pub fn get_process_times_100ns() -> Option { + let mut creation_time = core::mem::MaybeUninit::uninit(); + let mut exit_time = core::mem::MaybeUninit::uninit(); + let mut kernel_time = core::mem::MaybeUninit::uninit(); + let mut user_time = core::mem::MaybeUninit::uninit(); + (unsafe { + windows_sys::Win32::System::Threading::GetProcessTimes( + windows_sys::Win32::System::Threading::GetCurrentProcess(), + creation_time.as_mut_ptr(), + exit_time.as_mut_ptr(), + kernel_time.as_mut_ptr(), + user_time.as_mut_ptr(), + ) + } != 0) + .then(|| unsafe { + ProcessTimes100ns { + user: u64_from_filetime(user_time.assume_init()), + system: u64_from_filetime(kernel_time.assume_init()), + } + }) } #[cfg(any(unix, windows))] -#[must_use] pub fn asctime_from_tm(tm: &libc::tm) -> String { const WDAY_NAME: [&str; 7] = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"]; const MON_NAME: [&str; 12] = [ @@ -41,3 +447,196 @@ pub fn asctime_from_tm(tm: &libc::tm) -> String { tm.tm_year + 1900 ) } + 
+#[cfg(any(unix, windows))] +#[derive(Clone, Debug)] +pub struct CheckedTm { + pub tm: libc::tm, + #[cfg(unix)] + pub zone: Option, +} + +#[cfg(any(unix, windows))] +#[derive(Clone, Debug)] +pub struct CheckedTmParts { + pub year: i64, + pub tm_mon: i32, + pub tm_mday: i32, + pub tm_hour: i32, + pub tm_min: i32, + pub tm_sec: i32, + pub tm_wday: i32, + pub tm_yday: i32, + pub tm_isdst: i32, + #[cfg(unix)] + pub zone: Option, + #[cfg(unix)] + pub gmtoff: Option, +} + +#[cfg(any(unix, windows))] +#[derive(Clone, Copy, Debug)] +pub struct MktimeTmParts { + pub year: i32, + pub tm_sec: i32, + pub tm_min: i32, + pub tm_hour: i32, + pub tm_mday: i32, + pub tm_mon: i32, + pub tm_yday: i32, + pub tm_isdst: i32, +} + +#[cfg(any(unix, windows))] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum CheckedTmError { + YearOutOfRange, + MonthOutOfRange, + DayOfMonthOutOfRange, + HourOutOfRange, + MinuteOutOfRange, + SecondsOutOfRange, + DayOfWeekOutOfRange, + DayOfYearOutOfRange, + EmbeddedNul, +} + +#[cfg(any(unix, windows))] +pub fn checked_tm_from_parts(parts: CheckedTmParts) -> Result { + if parts.year < i64::from(i32::MIN) + 1900 || parts.year > i64::from(i32::MAX) { + return Err(CheckedTmError::YearOutOfRange); + } + + let mut tm: libc::tm = unsafe { core::mem::zeroed() }; + tm.tm_year = parts.year as i32 - 1900; + tm.tm_mon = parts.tm_mon; + tm.tm_mday = parts.tm_mday; + tm.tm_hour = parts.tm_hour; + tm.tm_min = parts.tm_min; + tm.tm_sec = parts.tm_sec; + tm.tm_wday = parts.tm_wday; + tm.tm_yday = parts.tm_yday; + tm.tm_isdst = parts.tm_isdst; + + if tm.tm_mon == -1 { + tm.tm_mon = 0; + } else if !(0..=11).contains(&tm.tm_mon) { + return Err(CheckedTmError::MonthOutOfRange); + } + if tm.tm_mday == 0 { + tm.tm_mday = 1; + } else if !(0..=31).contains(&tm.tm_mday) { + return Err(CheckedTmError::DayOfMonthOutOfRange); + } + if !(0..=23).contains(&tm.tm_hour) { + return Err(CheckedTmError::HourOutOfRange); + } + if !(0..=59).contains(&tm.tm_min) { + return 
Err(CheckedTmError::MinuteOutOfRange); + } + if !(0..=61).contains(&tm.tm_sec) { + return Err(CheckedTmError::SecondsOutOfRange); + } + if tm.tm_wday < 0 { + return Err(CheckedTmError::DayOfWeekOutOfRange); + } + if tm.tm_yday == -1 { + tm.tm_yday = 0; + } else if !(0..=365).contains(&tm.tm_yday) { + return Err(CheckedTmError::DayOfYearOutOfRange); + } + + #[cfg(unix)] + { + let zone = match parts.zone { + Some(zone) => Some(CString::new(zone).map_err(|_| CheckedTmError::EmbeddedNul)?), + None => None, + }; + if let Some(zone) = &zone { + tm.tm_zone = zone.as_ptr().cast_mut(); + } + if let Some(gmtoff) = parts.gmtoff { + tm.tm_gmtoff = gmtoff as _; + } + Ok(CheckedTm { tm, zone }) + } + #[cfg(windows)] + { + Ok(CheckedTm { tm }) + } +} + +#[cfg(any(unix, windows))] +pub fn mktime_tm_from_parts(parts: MktimeTmParts) -> Result { + if parts.year < i32::MIN + 1900 { + return Err(CheckedTmError::YearOutOfRange); + } + let mut tm: libc::tm = unsafe { core::mem::zeroed() }; + tm.tm_sec = parts.tm_sec; + tm.tm_min = parts.tm_min; + tm.tm_hour = parts.tm_hour; + tm.tm_mday = parts.tm_mday; + tm.tm_mon = parts.tm_mon - 1; + tm.tm_year = parts.year - 1900; + tm.tm_wday = -1; + tm.tm_yday = parts.tm_yday - 1; + tm.tm_isdst = parts.tm_isdst; + Ok(tm) +} + +#[cfg(unix)] +pub fn strftime_ascii(fmt: &str, tm: &libc::tm) -> Result { + let fmt_c = CString::new(fmt).map_err(|_| CheckedTmError::EmbeddedNul)?; + let mut size = 1024usize; + let max_scale = 256usize.saturating_mul(fmt.len().max(1)); + loop { + let mut out = vec![0u8; size]; + let written = unsafe { + libc::strftime( + out.as_mut_ptr().cast(), + out.len(), + fmt_c.as_ptr(), + tm as *const libc::tm, + ) + }; + if written > 0 || size >= max_scale { + return Ok(String::from_utf8_lossy(&out[..written]).into_owned()); + } + size = size.saturating_mul(2); + } +} + +#[cfg(windows)] +unsafe extern "C" { + fn wcsftime( + s: *mut libc::wchar_t, + max: libc::size_t, + format: *const libc::wchar_t, + tm: *const libc::tm, + ) -> 
libc::size_t; +} + +#[cfg(windows)] +pub fn strftime_ascii(fmt: &str, tm: &libc::tm) -> Result { + if fmt.contains('\0') { + return Err(CheckedTmError::EmbeddedNul); + } + let fmt_wide: Vec = fmt.encode_utf16().chain(core::iter::once(0)).collect(); + let mut size = 1024usize; + let max_scale = 256usize.saturating_mul(fmt.len().max(1)); + loop { + let mut out = vec![0u16; size]; + let written = unsafe { + crate::suppress_iph!(wcsftime( + out.as_mut_ptr(), + out.len(), + fmt_wide.as_ptr(), + tm as *const libc::tm, + )) + }; + if written > 0 || size >= max_scale { + return Ok(String::from_utf16_lossy(&out[..written])); + } + size = size.saturating_mul(2); + } +} diff --git a/crates/host_env/src/winapi.rs b/crates/host_env/src/winapi.rs index 70d7feea01d..4e4536c3518 100644 --- a/crates/host_env/src/winapi.rs +++ b/crates/host_env/src/winapi.rs @@ -1,15 +1,799 @@ -use windows_sys::Win32::Foundation::HANDLE; +#![allow( + clippy::not_unsafe_ptr_arg_deref, + reason = "This module mirrors Win32 APIs with raw handle and pointer parameters." 
+)] + +use std::{io, path::Path}; +use windows_sys::Win32::{ + Foundation::{HANDLE, HMODULE, WAIT_FAILED}, + System::Threading::PROCESS_INFORMATION, +}; + +use crate::windows::{CheckWin32Bool, CheckWin32Handle}; + +pub use windows_sys::Win32::{ + Foundation::{ + DUPLICATE_CLOSE_SOURCE, DUPLICATE_SAME_ACCESS, ERROR_ACCESS_DENIED, ERROR_ALREADY_EXISTS, + ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_NETNAME_DELETED, ERROR_NO_DATA, + ERROR_NO_SYSTEM_RESOURCES, ERROR_NOT_FOUND, ERROR_OPERATION_ABORTED, ERROR_PIPE_BUSY, + ERROR_PIPE_CONNECTED, ERROR_PORT_UNREACHABLE, ERROR_PRIVILEGE_NOT_HELD, ERROR_SEM_TIMEOUT, + ERROR_SUCCESS, GENERIC_READ, GENERIC_WRITE, STILL_ACTIVE, WAIT_ABANDONED_0, WAIT_OBJECT_0, + WAIT_TIMEOUT, + }, + Globalization::{ + LCMAP_FULLWIDTH, LCMAP_HALFWIDTH, LCMAP_HIRAGANA, LCMAP_KATAKANA, LCMAP_LINGUISTIC_CASING, + LCMAP_LOWERCASE, LCMAP_SIMPLIFIED_CHINESE, LCMAP_TITLECASE, LCMAP_TRADITIONAL_CHINESE, + LCMAP_UPPERCASE, + }, + Storage::FileSystem::{ + COPY_FILE_ALLOW_DECRYPTED_DESTINATION, COPY_FILE_COPY_SYMLINK, COPY_FILE_FAIL_IF_EXISTS, + COPY_FILE_NO_BUFFERING, COPY_FILE_NO_OFFLOAD, COPY_FILE_OPEN_SOURCE_FOR_WRITE, + COPY_FILE_REQUEST_COMPRESSED_TRAFFIC, COPY_FILE_REQUEST_SECURITY_PRIVILEGES, + COPY_FILE_RESTARTABLE, COPY_FILE_RESUME_FROM_PAUSE, COPYFILE2_CALLBACK_CHUNK_FINISHED, + COPYFILE2_CALLBACK_CHUNK_STARTED, COPYFILE2_CALLBACK_ERROR, + COPYFILE2_CALLBACK_POLL_CONTINUE, COPYFILE2_CALLBACK_STREAM_FINISHED, + COPYFILE2_CALLBACK_STREAM_STARTED, COPYFILE2_PROGRESS_CANCEL, COPYFILE2_PROGRESS_CONTINUE, + COPYFILE2_PROGRESS_PAUSE, COPYFILE2_PROGRESS_QUIET, COPYFILE2_PROGRESS_STOP, + FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, FILE_GENERIC_READ, FILE_GENERIC_WRITE, + FILE_TYPE_CHAR, FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE, OPEN_EXISTING, + PIPE_ACCESS_DUPLEX, PIPE_ACCESS_INBOUND, SYNCHRONIZE, + }, + System::{ + Console::{STD_ERROR_HANDLE, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE}, + Memory::{ + FILE_MAP_ALL_ACCESS, 
FILE_MAP_COPY, FILE_MAP_EXECUTE, FILE_MAP_READ, FILE_MAP_WRITE, + MEM_COMMIT, MEM_FREE, MEM_IMAGE, MEM_MAPPED, MEM_PRIVATE, MEM_RESERVE, PAGE_EXECUTE, + PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, PAGE_EXECUTE_WRITECOPY, PAGE_GUARD, + PAGE_NOACCESS, PAGE_NOCACHE, PAGE_READONLY, PAGE_READWRITE, PAGE_WRITECOMBINE, + PAGE_WRITECOPY, SEC_COMMIT, SEC_IMAGE, SEC_LARGE_PAGES, SEC_NOCACHE, SEC_RESERVE, + SEC_WRITECOMBINE, + }, + Pipes::{ + NMPWAIT_WAIT_FOREVER, PIPE_READMODE_MESSAGE, PIPE_TYPE_MESSAGE, + PIPE_UNLIMITED_INSTANCES, PIPE_WAIT, + }, + SystemServices::LOCALE_NAME_MAX_LENGTH, + Threading::{ + ABOVE_NORMAL_PRIORITY_CLASS, BELOW_NORMAL_PRIORITY_CLASS, CREATE_BREAKAWAY_FROM_JOB, + CREATE_DEFAULT_ERROR_MODE, CREATE_NEW_CONSOLE, CREATE_NEW_PROCESS_GROUP, + CREATE_NO_WINDOW, DETACHED_PROCESS, HIGH_PRIORITY_CLASS, IDLE_PRIORITY_CLASS, + NORMAL_PRIORITY_CLASS, PROCESS_ALL_ACCESS, PROCESS_DUP_HANDLE, REALTIME_PRIORITY_CLASS, + STARTF_FORCEOFFFEEDBACK, STARTF_FORCEONFEEDBACK, STARTF_PREVENTPINNING, + STARTF_RUNFULLSCREEN, STARTF_TITLEISAPPID, STARTF_TITLEISLINKNAME, + STARTF_UNTRUSTEDSOURCE, STARTF_USECOUNTCHARS, STARTF_USEFILLATTRIBUTE, + STARTF_USEHOTKEY, STARTF_USEPOSITION, STARTF_USESHOWWINDOW, STARTF_USESIZE, + STARTF_USESTDHANDLES, + }, + }, + UI::WindowsAndMessaging::SW_HIDE, +}; + +pub type Handle = HANDLE; +pub type StdHandle = windows_sys::Win32::System::Console::STD_HANDLE; +pub type FileType = windows_sys::Win32::Storage::FileSystem::FILE_TYPE; +pub const MAX_PATH_USIZE: usize = windows_sys::Win32::Foundation::MAX_PATH as usize; +pub const INFINITE_TIMEOUT: u32 = windows_sys::Win32::System::Threading::INFINITE; +pub const CREATE_UNICODE_ENVIRONMENT_FLAG: u32 = + windows_sys::Win32::System::Threading::CREATE_UNICODE_ENVIRONMENT; +pub const EXTENDED_STARTUPINFO_PRESENT_FLAG: u32 = + windows_sys::Win32::System::Threading::EXTENDED_STARTUPINFO_PRESENT; +pub const LCMAP_BYTEREV_FLAG: u32 = windows_sys::Win32::Globalization::LCMAP_BYTEREV; +pub const LCMAP_HASH_FLAG: 
u32 = windows_sys::Win32::Globalization::LCMAP_HASH; +pub const LCMAP_SORTHANDLE_FLAG: u32 = windows_sys::Win32::Globalization::LCMAP_SORTHANDLE; +pub const LCMAP_SORTKEY_FLAG: u32 = windows_sys::Win32::Globalization::LCMAP_SORTKEY; + +pub struct PeekNamedPipeResult { + pub data: Option>, + pub available: u32, + pub left_this_message: u32, +} + +pub struct ReadFileResult { + pub data: Vec, + pub error: u32, +} + +pub struct WriteFileResult { + pub written: u32, + pub error: u32, +} + +pub enum BatchedWaitResult { + All, + Indices(Vec), +} + +pub enum BatchedWaitError { + Timeout, + Interrupted, + Os(u32), +} + +pub enum BuildEnvironmentBlockError { + ContainsNul, + IllegalName, +} + +pub enum MimeRegistryReadError { + Os(u32), + Callback(E), +} + +pub struct AttrList { + handlelist: Vec, + attrlist: Vec, +} + +pub struct StartupInfoData { + pub flags: u32, + pub show_window: u16, + pub std_input: HANDLE, + pub std_output: HANDLE, + pub std_error: HANDLE, +} + +pub struct ProcessInfo { + pub process: HANDLE, + pub thread: HANDLE, + pub process_id: u32, + pub thread_id: u32, +} #[must_use] pub fn get_acp() -> u32 { unsafe { windows_sys::Win32::Globalization::GetACP() } } +pub fn close_handle(handle: HANDLE) -> i32 { + unsafe { windows_sys::Win32::Foundation::CloseHandle(handle) } +} + +impl AttrList { + pub fn as_mut_ptr(&mut self) -> *mut core::ffi::c_void { + self.attrlist.as_mut_ptr().cast() + } +} + +impl Drop for AttrList { + fn drop(&mut self) { + unsafe { + windows_sys::Win32::System::Threading::DeleteProcThreadAttributeList( + self.attrlist.as_mut_ptr().cast(), + ) + }; + } +} + +pub fn create_file_w( + file_name: &widestring::WideCStr, + desired_access: u32, + share_mode: u32, + creation_disposition: u32, + flags_and_attributes: u32, +) -> io::Result { + let handle = unsafe { + windows_sys::Win32::Storage::FileSystem::CreateFileW( + file_name.as_ptr(), + desired_access, + share_mode, + core::ptr::null(), + creation_disposition, + flags_and_attributes, + 
core::ptr::null_mut(), + ) + }; + handle.check_valid() +} + +/// # Safety +/// `startup_info` must point to a valid `STARTUPINFOW` (or extended). +unsafe fn create_process_w_raw( + app_name: Option<&widestring::WideCStr>, + command_line: Option<&mut [u16]>, + inherit_handles: i32, + creation_flags: u32, + env: Option<&[u16]>, + current_dir: Option<&widestring::WideCStr>, + startup_info: *mut windows_sys::Win32::System::Threading::STARTUPINFOW, +) -> io::Result { + let mut procinfo = core::mem::MaybeUninit::::uninit(); + unsafe { + windows_sys::Win32::System::Threading::CreateProcessW( + app_name.map_or(core::ptr::null(), |s| s.as_ptr()), + command_line.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()), + core::ptr::null(), + core::ptr::null(), + inherit_handles, + creation_flags, + env.map_or(core::ptr::null(), |s| s.as_ptr().cast()), + current_dir.map_or(core::ptr::null(), |s| s.as_ptr()), + startup_info, + procinfo.as_mut_ptr(), + ) + } + .check_win32_bool()?; + Ok(unsafe { procinfo.assume_init() }) +} + +/// Win32 `CreateProcessW` requires `lpCommandLine` to be NUL-terminated. +/// The buffer is passed `&mut [u16]` because `CreateProcessW` may modify it +/// in place. +#[inline] +fn validate_command_line_terminated(buf: &[u16]) -> io::Result<()> { + if buf.last() == Some(&0) { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "command_line buffer passed to create_process must be NUL-terminated", + )) + } +} + +/// Win32 `CreateProcessW` with `CREATE_UNICODE_ENVIRONMENT` requires +/// `lpEnvironment` to be a sequence of `KEY=value\0` strings followed by a +/// final terminating `\0` — i.e. the block ends with two consecutive zero +/// `u16`s. +#[inline] +fn validate_environment_block_terminated(buf: &[u16]) -> io::Result<()> { + if buf.len() >= 2 && buf[buf.len() - 2..] 
== [0, 0] { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "env block passed to create_process must end with a double NUL terminator", + )) + } +} + +#[allow( + clippy::too_many_arguments, + reason = "This is the semantic host wrapper for Win32 CreateProcess parameters." +)] +pub fn create_process( + app_name: Option<&widestring::WideCStr>, + command_line: Option<&mut [u16]>, + inherit_handles: i32, + creation_flags: u32, + env: Option<&[u16]>, + current_dir: Option<&widestring::WideCStr>, + startup_info: StartupInfoData, + handle_list: Option>, +) -> io::Result { + if let Some(cmd) = command_line.as_deref() { + validate_command_line_terminated(cmd)?; + } + if let Some(env_block) = env { + validate_environment_block_terminated(env_block)?; + } + + let mut si: windows_sys::Win32::System::Threading::STARTUPINFOEXW = + unsafe { core::mem::zeroed() }; + si.StartupInfo.cb = core::mem::size_of_val(&si) as _; + si.StartupInfo.dwFlags = startup_info.flags; + si.StartupInfo.wShowWindow = startup_info.show_window; + si.StartupInfo.hStdInput = startup_info.std_input; + si.StartupInfo.hStdOutput = startup_info.std_output; + si.StartupInfo.hStdError = startup_info.std_error; + + let mut attrlist = create_handle_list_attribute_list(handle_list)?; + si.lpAttributeList = attrlist + .as_mut() + .map_or_else(core::ptr::null_mut, |l| l.as_mut_ptr() as _); + + let procinfo = unsafe { + create_process_w_raw( + app_name, + command_line, + inherit_handles, + creation_flags | EXTENDED_STARTUPINFO_PRESENT_FLAG | CREATE_UNICODE_ENVIRONMENT_FLAG, + env, + current_dir, + &mut si as *mut _ as *mut _, + )? 
+ }; + + Ok(ProcessInfo { + process: procinfo.hProcess, + thread: procinfo.hThread, + process_id: procinfo.dwProcessId, + thread_id: procinfo.dwThreadId, + }) +} + +pub fn create_junction(src: &Path, dst: &Path) -> io::Result<()> { + junction::create(src, dst) +} + +pub fn build_environment_block( + entries: Vec<(String, String)>, +) -> Result, BuildEnvironmentBlockError> { + use std::collections::HashMap; + + let mut last_entry: HashMap> = HashMap::new(); + for (key, value) in entries { + if key.contains('\0') || value.contains('\0') { + return Err(BuildEnvironmentBlockError::ContainsNul); + } + if key.is_empty() || key[1..].contains('=') { + return Err(BuildEnvironmentBlockError::IllegalName); + } + + let key_upper = key.to_uppercase(); + let mut entry: Vec = key.encode_utf16().collect(); + entry.push(b'=' as u16); + entry.extend(value.encode_utf16()); + entry.push(0); + last_entry.insert(key_upper, entry); + } + + let mut entries: Vec<(String, Vec)> = last_entry.into_iter().collect(); + entries.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut out = Vec::new(); + for (_, entry) in entries { + out.extend(entry); + } + if out.is_empty() { + out.push(0); + } + out.push(0); + Ok(out) +} + +pub fn create_handle_list_attribute_list( + handlelist: Option>, +) -> io::Result> { + let Some(handlelist) = handlelist else { + return Ok(None); + }; + + let mut size = 0; + let first = unsafe { + windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList( + core::ptr::null_mut(), + 1, + 0, + &mut size, + ) + }; + if first != 0 + || unsafe { windows_sys::Win32::Foundation::GetLastError() } + != windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER + { + return Err(io::Error::last_os_error()); + } + + let mut attrs = AttrList { + handlelist, + attrlist: vec![0u8; size], + }; + unsafe { + windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList( + attrs.attrlist.as_mut_ptr().cast(), + 1, + 0, + &mut size, + ) + } + .check_win32_bool()?; + + unsafe { + 
windows_sys::Win32::System::Threading::UpdateProcThreadAttribute( + attrs.attrlist.as_mut_ptr().cast(), + 0, + (2 & 0xffff) | 0x20000, + attrs.handlelist.as_mut_ptr().cast(), + (attrs.handlelist.len() * core::mem::size_of::()) as _, + core::ptr::null_mut(), + core::ptr::null(), + ) + } + .check_win32_bool()?; + + Ok(Some(attrs)) +} + +pub fn get_std_handle(std_handle: StdHandle) -> io::Result> { + let handle = unsafe { windows_sys::Win32::System::Console::GetStdHandle(std_handle) }; + if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { + Err(io::Error::last_os_error()) + } else if handle.is_null() { + Ok(None) + } else { + Ok(Some(handle)) + } +} + +pub fn open_process( + desired_access: u32, + inherit_handle: bool, + process_id: u32, +) -> io::Result { + unsafe { + windows_sys::Win32::System::Threading::OpenProcess( + desired_access, + i32::from(inherit_handle), + process_id, + ) + } + .check_nonnull() +} + +pub fn create_pipe(size: u32) -> io::Result<(HANDLE, HANDLE)> { + unsafe { + let mut read = core::mem::MaybeUninit::::uninit(); + let mut write = core::mem::MaybeUninit::::uninit(); + windows_sys::Win32::System::Pipes::CreatePipe( + read.as_mut_ptr(), + write.as_mut_ptr(), + core::ptr::null(), + size, + ) + .check_win32_bool()?; + Ok((read.assume_init(), write.assume_init())) + } +} + +pub fn create_event_w( + manual_reset: bool, + initial_state: bool, + name: Option<&widestring::WideCStr>, +) -> io::Result { + let name_ptr = name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { + windows_sys::Win32::System::Threading::CreateEventW( + core::ptr::null(), + i32::from(manual_reset), + i32::from(initial_state), + name_ptr, + ) + } + .check_nonnull() +} + +pub fn set_event(handle: HANDLE) -> io::Result<()> { + unsafe { windows_sys::Win32::System::Threading::SetEvent(handle) }.check_win32_bool() +} + +pub fn reset_event(handle: HANDLE) -> io::Result<()> { + unsafe { windows_sys::Win32::System::Threading::ResetEvent(handle) }.check_win32_bool() 
+} + +pub fn wait_for_single_object(handle: HANDLE, milliseconds: u32) -> io::Result { + let ret = + unsafe { windows_sys::Win32::System::Threading::WaitForSingleObject(handle, milliseconds) }; + if ret == WAIT_FAILED { + Err(io::Error::last_os_error()) + } else { + Ok(ret) + } +} + +pub fn wait_for_multiple_objects( + handles: &[HANDLE], + wait_all: bool, + milliseconds: u32, +) -> io::Result { + let ret = unsafe { + windows_sys::Win32::System::Threading::WaitForMultipleObjects( + handles.len() as u32, + handles.as_ptr(), + i32::from(wait_all), + milliseconds, + ) + }; + if ret == WAIT_FAILED { + Err(io::Error::last_os_error()) + } else { + Ok(ret) + } +} + +pub fn batched_wait_for_multiple_objects( + handles: &[HANDLE], + wait_all: bool, + milliseconds: u32, + sigint_event: Option, +) -> Result { + use alloc::sync::Arc; + use core::sync::atomic::{AtomicU32, Ordering}; + use windows_sys::Win32::{ + Foundation::{CloseHandle, WAIT_ABANDONED_0}, + System::{ + SystemInformation::GetTickCount64, + Threading::{ + CreateThread, GetExitCodeThread, INFINITE, ResumeThread, TerminateThread, + WaitForMultipleObjects, + }, + }, + }; + + const MAXIMUM_WAIT_OBJECTS: usize = 64; + let batch_size = MAXIMUM_WAIT_OBJECTS - 1; + let mut batches: Vec<&[HANDLE]> = Vec::new(); + let mut i = 0; + while i < handles.len() { + let end = core::cmp::min(i + batch_size, handles.len()); + batches.push(&handles[i..end]); + i = end; + } + + if wait_all { + let mut err = None; + let deadline = if milliseconds != INFINITE { + Some(unsafe { GetTickCount64() } + milliseconds as u64) + } else { + None + }; + + for batch in &batches { + let timeout = if let Some(deadline) = deadline { + let now = unsafe { GetTickCount64() }; + if now >= deadline { + err = Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT); + break; + } + (deadline - now) as u32 + } else { + INFINITE + }; + + let result = + unsafe { WaitForMultipleObjects(batch.len() as u32, batch.as_ptr(), 1, timeout) }; + if result == WAIT_FAILED { + 
err = Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }); + break; + } + if result == windows_sys::Win32::Foundation::WAIT_TIMEOUT { + err = Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT); + break; + } + + if let Some(sigint_event) = sigint_event { + let sig_result = unsafe { + windows_sys::Win32::System::Threading::WaitForSingleObject(sigint_event, 0) + }; + if sig_result == WAIT_OBJECT_0 { + err = Some(windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT); + break; + } + if sig_result == WAIT_FAILED { + err = Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }); + break; + } + } + } + + return match err { + Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT) => Err(BatchedWaitError::Timeout), + Some(windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT) => { + Err(BatchedWaitError::Interrupted) + } + Some(err) => Err(BatchedWaitError::Os(err)), + None => Ok(BatchedWaitResult::All), + }; + } + + let cancel_event = create_event_w(true, false, None) + .map_err(|err| BatchedWaitError::Os(err.raw_os_error().unwrap_or_default() as u32))?; + + struct BatchData { + handles: Vec, + cancel_event: HANDLE, + handle_base: usize, + result: AtomicU32, + thread: core::cell::UnsafeCell, + } + + unsafe impl Send for BatchData {} + unsafe impl Sync for BatchData {} + + extern "system" fn batch_wait_thread(param: *mut core::ffi::c_void) -> u32 { + let data = unsafe { &*(param as *const BatchData) }; + let result = unsafe { + windows_sys::Win32::System::Threading::WaitForMultipleObjects( + data.handles.len() as u32, + data.handles.as_ptr(), + 0, + windows_sys::Win32::System::Threading::INFINITE, + ) + }; + data.result.store(result, Ordering::SeqCst); + + if result == WAIT_FAILED { + let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; + let _ = set_event(data.cancel_event); + err + } else if (WAIT_ABANDONED_0..WAIT_ABANDONED_0 + MAXIMUM_WAIT_OBJECTS as u32) + .contains(&result) + { + data.result.store(WAIT_FAILED, Ordering::SeqCst); + let _ = 
set_event(data.cancel_event); + windows_sys::Win32::Foundation::ERROR_ABANDONED_WAIT_0 + } else { + 0 + } + } + + let batch_data: Vec> = batches + .iter() + .enumerate() + .map(|(idx, batch)| { + let base = idx * batch_size; + let mut handles_with_cancel = batch.to_vec(); + handles_with_cancel.push(cancel_event); + Arc::new(BatchData { + handles: handles_with_cancel, + cancel_event, + handle_base: base, + result: AtomicU32::new(WAIT_FAILED), + thread: core::cell::UnsafeCell::new(core::ptr::null_mut()), + }) + }) + .collect(); + + let mut thread_handles: Vec = Vec::new(); + for data in &batch_data { + let thread = unsafe { + CreateThread( + core::ptr::null(), + 1, + Some(batch_wait_thread), + Arc::as_ptr(data) as *const _ as *mut _, + 4, + core::ptr::null_mut(), + ) + }; + if thread.is_null() { + for &handle in &thread_handles { + unsafe { TerminateThread(handle, 0) }; + unsafe { CloseHandle(handle) }; + } + unsafe { CloseHandle(cancel_event) }; + return Err(BatchedWaitError::Os( + io::Error::last_os_error() + .raw_os_error() + .unwrap_or_default() as u32, + )); + } + unsafe { *data.thread.get() = thread }; + thread_handles.push(thread); + } + + for &thread in &thread_handles { + unsafe { ResumeThread(thread) }; + } + + let mut thread_handles_raw = thread_handles.clone(); + if let Some(sigint_event) = sigint_event { + thread_handles_raw.push(sigint_event); + } + let result = unsafe { + WaitForMultipleObjects( + thread_handles_raw.len() as u32, + thread_handles_raw.as_ptr(), + 0, + milliseconds, + ) + }; + + let err = if result == WAIT_FAILED { + Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }) + } else if result == windows_sys::Win32::Foundation::WAIT_TIMEOUT { + Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT) + } else if sigint_event.is_some() + && result == WAIT_OBJECT_0 + (thread_handles_raw.len() - 1) as u32 + { + Some(windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT) + } else { + None + }; + + let _ = set_event(cancel_event); + unsafe { + 
WaitForMultipleObjects( + thread_handles.len() as u32, + thread_handles.as_ptr(), + 1, + INFINITE, + ) + }; + + let mut thread_err = err; + for data in &batch_data { + if thread_err.is_none() && data.result.load(Ordering::SeqCst) == WAIT_FAILED { + let mut exit_code = 0; + let thread = unsafe { *data.thread.get() }; + if unsafe { GetExitCodeThread(thread, &mut exit_code) } == 0 { + thread_err = Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }); + } else if exit_code != 0 { + thread_err = Some(exit_code); + } + } + let thread = unsafe { *data.thread.get() }; + unsafe { CloseHandle(thread) }; + } + unsafe { CloseHandle(cancel_event) }; + + match thread_err { + Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT) => Err(BatchedWaitError::Timeout), + Some(windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT) => { + Err(BatchedWaitError::Interrupted) + } + Some(err) => Err(BatchedWaitError::Os(err)), + None => { + let mut triggered_indices = Vec::new(); + for data in &batch_data { + let result = data.result.load(Ordering::SeqCst); + let triggered = result as i32 - WAIT_OBJECT_0 as i32; + if triggered >= 0 && (triggered as usize) < data.handles.len() - 1 { + triggered_indices.push(data.handle_base + triggered as usize); + } + } + Ok(BatchedWaitResult::Indices(triggered_indices)) + } + } +} + +pub fn duplicate_handle( + src_process: HANDLE, + src: HANDLE, + target_process: HANDLE, + access: u32, + inherit: i32, + options: u32, +) -> io::Result { + let target = unsafe { + let mut target = core::mem::MaybeUninit::::uninit(); + let ok = windows_sys::Win32::Foundation::DuplicateHandle( + src_process, + src, + target_process, + target.as_mut_ptr(), + access, + inherit, + options, + ); + if ok == 0 { + return Err(io::Error::last_os_error()); + } + target.assume_init() + }; + Ok(target) +} + #[must_use] pub fn get_current_process() -> HANDLE { unsafe { windows_sys::Win32::System::Threading::GetCurrentProcess() } } +pub fn get_exit_code_process(handle: HANDLE) -> 
io::Result { + let mut exit_code = core::mem::MaybeUninit::::uninit(); + unsafe { + windows_sys::Win32::System::Threading::GetExitCodeProcess(handle, exit_code.as_mut_ptr()) + } + .check_win32_bool()?; + Ok(unsafe { exit_code.assume_init() }) +} + +pub fn get_file_type(handle: HANDLE) -> io::Result { + let file_type = unsafe { windows_sys::Win32::Storage::FileSystem::GetFileType(handle) }; + if file_type == 0 && unsafe { windows_sys::Win32::Foundation::GetLastError() } != 0 { + Err(io::Error::last_os_error()) + } else { + Ok(file_type) + } +} + +pub fn terminate_process(handle: HANDLE, exit_code: u32) -> i32 { + unsafe { windows_sys::Win32::System::Threading::TerminateProcess(handle, exit_code) } +} + +pub fn exit_process(exit_code: u32) -> ! { + unsafe { windows_sys::Win32::System::Threading::ExitProcess(exit_code) } +} + #[must_use] pub fn get_last_error() -> u32 { unsafe { windows_sys::Win32::Foundation::GetLastError() } @@ -19,3 +803,557 @@ pub fn get_last_error() -> u32 { pub fn get_version() -> u32 { unsafe { windows_sys::Win32::System::SystemInformation::GetVersion() } } + +pub fn create_job_object_w(name: Option<&widestring::WideCStr>) -> io::Result { + let name_ptr = name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { windows_sys::Win32::System::JobObjects::CreateJobObjectW(core::ptr::null(), name_ptr) } + .check_nonnull() +} + +pub fn assign_process_to_job_object(job: HANDLE, process: HANDLE) -> io::Result<()> { + unsafe { windows_sys::Win32::System::JobObjects::AssignProcessToJobObject(job, process) } + .check_win32_bool() +} + +pub fn terminate_job_object(job: HANDLE, exit_code: u32) -> io::Result<()> { + unsafe { windows_sys::Win32::System::JobObjects::TerminateJobObject(job, exit_code) } + .check_win32_bool() +} + +pub fn set_job_object_kill_on_close(job: HANDLE) -> io::Result<()> { + use windows_sys::Win32::System::JobObjects::{ + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION, + JobObjectExtendedLimitInformation, 
SetInformationJobObject, + }; + + let mut info: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = unsafe { core::mem::zeroed() }; + info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + unsafe { + SetInformationJobObject( + job, + JobObjectExtendedLimitInformation, + (&info as *const JOBOBJECT_EXTENDED_LIMIT_INFORMATION).cast(), + core::mem::size_of::() as u32, + ) + } + .check_win32_bool() +} + +pub fn get_module_file_name(module: HMODULE, buffer: &mut [u16]) -> u32 { + unsafe { + windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW( + module, + buffer.as_mut_ptr(), + buffer.len() as u32, + ) + } +} + +pub fn get_short_path_name_w(path: &widestring::WideCStr) -> io::Result> { + get_path_name_impl( + path, + windows_sys::Win32::Storage::FileSystem::GetShortPathNameW, + ) +} + +pub fn get_long_path_name_w(path: &widestring::WideCStr) -> io::Result> { + get_path_name_impl( + path, + windows_sys::Win32::Storage::FileSystem::GetLongPathNameW, + ) +} + +fn get_path_name_impl( + path: &widestring::WideCStr, + api_fn: unsafe extern "system" fn(*const u16, *mut u16, u32) -> u32, +) -> io::Result> { + let size = unsafe { api_fn(path.as_ptr(), core::ptr::null_mut(), 0) }; + if size == 0 { + return Err(io::Error::last_os_error()); + } + + let mut buffer = vec![0u16; size as usize]; + let result = unsafe { api_fn(path.as_ptr(), buffer.as_mut_ptr(), buffer.len() as u32) }; + if result == 0 { + return Err(io::Error::last_os_error()); + } + buffer.truncate(result as usize); + Ok(buffer) +} + +pub fn open_mutex_w( + desired_access: u32, + inherit_handle: bool, + name: &widestring::WideCStr, +) -> io::Result { + unsafe { + windows_sys::Win32::System::Threading::OpenMutexW( + desired_access, + i32::from(inherit_handle), + name.as_ptr(), + ) + } + .check_nonnull() +} + +pub fn release_mutex(handle: HANDLE) -> i32 { + unsafe { windows_sys::Win32::System::Threading::ReleaseMutex(handle) } +} + +pub fn create_named_pipe_w( + name: &widestring::WideCStr, + 
open_mode: u32, + pipe_mode: u32, + max_instances: u32, + out_buffer_size: u32, + in_buffer_size: u32, + default_timeout: u32, +) -> io::Result { + unsafe { + windows_sys::Win32::System::Pipes::CreateNamedPipeW( + name.as_ptr(), + open_mode, + pipe_mode, + max_instances, + out_buffer_size, + in_buffer_size, + default_timeout, + core::ptr::null(), + ) + } + .check_valid() +} + +pub fn create_file_mapping_w( + file_handle: HANDLE, + protect: u32, + max_size_high: u32, + max_size_low: u32, + name: Option<&widestring::WideCStr>, +) -> io::Result { + let name_ptr = name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { + windows_sys::Win32::System::Memory::CreateFileMappingW( + file_handle, + core::ptr::null(), + protect, + max_size_high, + max_size_low, + name_ptr, + ) + } + .check_nonnull() +} + +pub fn open_file_mapping_w( + desired_access: u32, + inherit_handle: bool, + name: &widestring::WideCStr, +) -> io::Result { + unsafe { + windows_sys::Win32::System::Memory::OpenFileMappingW( + desired_access, + i32::from(inherit_handle), + name.as_ptr(), + ) + } + .check_nonnull() +} + +pub fn map_view_of_file( + file_map: HANDLE, + desired_access: u32, + file_offset_high: u32, + file_offset_low: u32, + number_bytes: usize, +) -> io::Result { + let address = unsafe { + windows_sys::Win32::System::Memory::MapViewOfFile( + file_map, + desired_access, + file_offset_high, + file_offset_low, + number_bytes, + ) + }; + let ptr = address.Value; + if ptr.is_null() { + Err(io::Error::last_os_error()) + } else { + Ok(ptr as isize) + } +} + +pub fn unmap_view_of_file(address: isize) -> io::Result<()> { + let view = windows_sys::Win32::System::Memory::MEMORY_MAPPED_VIEW_ADDRESS { + Value: address as *mut core::ffi::c_void, + }; + unsafe { windows_sys::Win32::System::Memory::UnmapViewOfFile(view) }.check_win32_bool() +} + +pub fn virtual_query_size(address: isize) -> io::Result { + let mut mbi: windows_sys::Win32::System::Memory::MEMORY_BASIC_INFORMATION = + unsafe { 
core::mem::zeroed() }; + let ret = unsafe { + windows_sys::Win32::System::Memory::VirtualQuery( + address as *const core::ffi::c_void, + &mut mbi, + core::mem::size_of::(), + ) + }; + if ret == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(mbi.RegionSize) + } +} + +pub fn copy_file2( + src: &widestring::WideCStr, + dst: &widestring::WideCStr, + flags: u32, +) -> io::Result<()> { + let mut params: windows_sys::Win32::Storage::FileSystem::COPYFILE2_EXTENDED_PARAMETERS = + unsafe { core::mem::zeroed() }; + params.dwSize = core::mem::size_of_val(¶ms) as u32; + params.dwCopyFlags = flags; + + let hr = unsafe { + windows_sys::Win32::Storage::FileSystem::CopyFile2(src.as_ptr(), dst.as_ptr(), ¶ms) + }; + if hr < 0 { + let err = if (hr as u32 >> 16) == 0x8007 { + (hr as u32) & 0xFFFF + } else { + hr as u32 + }; + Err(io::Error::from_raw_os_error(err as i32)) + } else { + Ok(()) + } +} + +pub fn read_windows_mimetype_registry_in_batches( + mut on_entries: F, +) -> Result<(), MimeRegistryReadError> +where + F: FnMut(&mut Vec<(String, String)>) -> Result<(), E>, +{ + use windows_sys::Win32::System::Registry::{ + HKEY, HKEY_CLASSES_ROOT, KEY_READ, REG_SZ, RegCloseKey, RegEnumKeyExW, RegOpenKeyExW, + RegQueryValueExW, + }; + + let mut hkcr: HKEY = core::ptr::null_mut(); + let err = + unsafe { RegOpenKeyExW(HKEY_CLASSES_ROOT, core::ptr::null(), 0, KEY_READ, &mut hkcr) }; + if err != 0 { + return Err(MimeRegistryReadError::Os(err)); + } + + let mut index = 0; + let mut entries = Vec::new(); + loop { + let mut ext_buf = [0u16; 128]; + let mut cch_ext = ext_buf.len() as u32; + let err = unsafe { + RegEnumKeyExW( + hkcr, + index, + ext_buf.as_mut_ptr(), + &mut cch_ext, + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + ) + }; + index += 1; + + if err == windows_sys::Win32::Foundation::ERROR_NO_MORE_ITEMS { + break; + } + if err != 0 && err != windows_sys::Win32::Foundation::ERROR_MORE_DATA { + unsafe { RegCloseKey(hkcr) }; + 
return Err(MimeRegistryReadError::Os(err)); + } + if cch_ext == 0 || ext_buf[0] != b'.' as u16 { + continue; + } + + let ext_wide = &ext_buf[..cch_ext as usize]; + let mut subkey: HKEY = core::ptr::null_mut(); + let err = unsafe { RegOpenKeyExW(hkcr, ext_buf.as_ptr(), 0, KEY_READ, &mut subkey) }; + if err == windows_sys::Win32::Foundation::ERROR_FILE_NOT_FOUND + || err == windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED + { + continue; + } + if err != 0 { + unsafe { RegCloseKey(hkcr) }; + return Err(MimeRegistryReadError::Os(err)); + } + + let content_type_key: Vec = "Content Type\0".encode_utf16().collect(); + let mut type_buf = [0u16; 256]; + let mut cb_type = (type_buf.len() * 2) as u32; + let mut reg_type = 0; + let err = unsafe { + RegQueryValueExW( + subkey, + content_type_key.as_ptr(), + core::ptr::null_mut(), + &mut reg_type, + type_buf.as_mut_ptr().cast(), + &mut cb_type, + ) + }; + unsafe { RegCloseKey(subkey) }; + + if err != 0 || reg_type != REG_SZ || cb_type == 0 { + continue; + } + + let type_len = (cb_type as usize / 2).saturating_sub(1); + let type_str = String::from_utf16_lossy(&type_buf[..type_len]); + let ext_str = String::from_utf16_lossy(ext_wide); + if type_str.is_empty() { + continue; + } + + entries.push((type_str, ext_str)); + if entries.len() >= 64 { + on_entries(&mut entries).map_err(MimeRegistryReadError::Callback)?; + } + } + + unsafe { RegCloseKey(hkcr) }; + if !entries.is_empty() { + on_entries(&mut entries).map_err(MimeRegistryReadError::Callback)?; + } + Ok(()) +} + +pub fn lc_map_string_ex( + locale: &widestring::WideCStr, + flags: u32, + src: &[u16], +) -> io::Result> { + let src_len = src.len() as i32; + let dest_size = unsafe { + windows_sys::Win32::Globalization::LCMapStringEx( + locale.as_ptr(), + flags, + src.as_ptr(), + src_len, + core::ptr::null_mut(), + 0, + core::ptr::null(), + core::ptr::null(), + 0, + ) + }; + if dest_size <= 0 { + return Err(io::Error::last_os_error()); + } + + let mut dest = vec![0u16; dest_size as 
usize]; + let nmapped = unsafe { + windows_sys::Win32::Globalization::LCMapStringEx( + locale.as_ptr(), + flags, + src.as_ptr(), + src_len, + dest.as_mut_ptr(), + dest_size, + core::ptr::null(), + core::ptr::null(), + 0, + ) + }; + if nmapped <= 0 { + return Err(io::Error::last_os_error()); + } + dest.truncate(nmapped as usize); + Ok(dest) +} + +pub fn connect_named_pipe(handle: HANDLE) -> io::Result<()> { + let ret = unsafe { + windows_sys::Win32::System::Pipes::ConnectNamedPipe(handle, core::ptr::null_mut()) + }; + if ret == 0 { + let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; + if err != windows_sys::Win32::Foundation::ERROR_PIPE_CONNECTED { + return Err(io::Error::from_raw_os_error(err as i32)); + } + } + Ok(()) +} + +pub fn wait_named_pipe_w(name: &widestring::WideCStr, timeout: u32) -> io::Result<()> { + unsafe { windows_sys::Win32::System::Pipes::WaitNamedPipeW(name.as_ptr(), timeout) } + .check_win32_bool() +} + +pub fn peek_named_pipe(handle: HANDLE, size: Option) -> io::Result { + let mut available = 0; + let mut left_this_message = 0; + match size { + Some(size) => { + let mut data = vec![0u8; size as usize]; + let mut read = 0; + unsafe { + windows_sys::Win32::System::Pipes::PeekNamedPipe( + handle, + data.as_mut_ptr().cast(), + size, + &mut read, + &mut available, + &mut left_this_message, + ) + } + .check_win32_bool()?; + data.truncate(read as usize); + Ok(PeekNamedPipeResult { + data: Some(data), + available, + left_this_message, + }) + } + None => { + unsafe { + windows_sys::Win32::System::Pipes::PeekNamedPipe( + handle, + core::ptr::null_mut(), + 0, + core::ptr::null_mut(), + &mut available, + &mut left_this_message, + ) + } + .check_win32_bool()?; + Ok(PeekNamedPipeResult { + data: None, + available, + left_this_message, + }) + } + } +} + +pub fn write_file(handle: HANDLE, buffer: &[u8]) -> io::Result { + let len = core::cmp::min(buffer.len(), u32::MAX as usize) as u32; + let mut written = 0; + let ret = unsafe { + 
windows_sys::Win32::Storage::FileSystem::WriteFile( + handle, + buffer.as_ptr().cast(), + len, + &mut written, + core::ptr::null_mut(), + ) + }; + let err = if ret == 0 { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + } else { + 0 + }; + if ret == 0 { + Err(io::Error::from_raw_os_error(err as i32)) + } else { + Ok(WriteFileResult { + written, + error: err, + }) + } +} + +pub fn read_file(handle: HANDLE, size: u32) -> io::Result { + let mut data = vec![0u8; size as usize]; + let mut read = 0; + let ret = unsafe { + windows_sys::Win32::Storage::FileSystem::ReadFile( + handle, + data.as_mut_ptr().cast(), + size, + &mut read, + core::ptr::null_mut(), + ) + }; + let err = if ret == 0 { + unsafe { windows_sys::Win32::Foundation::GetLastError() } + } else { + 0 + }; + if ret == 0 && err != windows_sys::Win32::Foundation::ERROR_MORE_DATA { + return Err(io::Error::from_raw_os_error(err as i32)); + } + data.truncate(read as usize); + Ok(ReadFileResult { data, error: err }) +} + +pub fn set_named_pipe_handle_state( + handle: HANDLE, + mode: Option, + max_collection_count: Option, + collect_data_timeout: Option, +) -> io::Result<()> { + let mut dw_args = [ + mode.unwrap_or_default(), + max_collection_count.unwrap_or_default(), + collect_data_timeout.unwrap_or_default(), + ]; + let mut p_args = [core::ptr::null_mut(); 3]; + for (index, arg) in [mode, max_collection_count, collect_data_timeout] + .into_iter() + .enumerate() + { + if arg.is_some() { + p_args[index] = &mut dw_args[index]; + } + } + unsafe { + windows_sys::Win32::System::Pipes::SetNamedPipeHandleState( + handle, p_args[0], p_args[1], p_args[2], + ) + } + .check_win32_bool() +} + +pub fn create_mutex_w( + initial_owner: bool, + name: Option<&widestring::WideCStr>, +) -> io::Result { + let name_ptr = name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { + windows_sys::Win32::System::Threading::CreateMutexW( + core::ptr::null(), + i32::from(initial_owner), + name_ptr, + ) + } + .check_nonnull() +} 
+ +pub fn open_event_w( + desired_access: u32, + inherit_handle: bool, + name: &widestring::WideCStr, +) -> io::Result { + unsafe { + windows_sys::Win32::System::Threading::OpenEventW( + desired_access, + i32::from(inherit_handle), + name.as_ptr(), + ) + } + .check_nonnull() +} + +pub fn need_current_directory_for_exe_path_w(exe_name: &widestring::WideCStr) -> bool { + unsafe { + windows_sys::Win32::System::Environment::NeedCurrentDirectoryForExePathW(exe_name.as_ptr()) + != 0 + } +} diff --git a/crates/host_env/src/windows.rs b/crates/host_env/src/windows.rs index 9667edf9149..bde8d679737 100644 --- a/crates/host_env/src/windows.rs +++ b/crates/host_env/src/windows.rs @@ -1,15 +1,405 @@ use rustpython_wtf8::Wtf8; use std::{ ffi::{OsStr, OsString}, + io, os::windows::ffi::{OsStrExt, OsStringExt}, }; +use windows_sys::Win32::{ + Foundation::{ + E_POINTER, ERROR_INSUFFICIENT_BUFFER, ERROR_INVALID_FLAGS, ERROR_NO_UNICODE_TRANSLATION, + MAX_PATH, S_OK, + }, + Networking::WinSock::WSAStartup, + Storage::FileSystem::{ + GetFileVersionInfoSizeW, GetFileVersionInfoW, VS_FIXEDFILEINFO, VerQueryValueW, + }, + System::{ + Diagnostics::Debug::{ + FORMAT_MESSAGE_ALLOCATE_BUFFER, FORMAT_MESSAGE_FROM_SYSTEM, + FORMAT_MESSAGE_IGNORE_INSERTS, FormatMessageW, + }, + LibraryLoader::{GetModuleFileNameW, GetModuleHandleW}, + SystemInformation::{GetVersionExW, OSVERSIONINFOEXW, OSVERSIONINFOW}, + Threading::{GetCurrentThreadStackLimits, SetThreadStackGuarantee}, + }, +}; /// _MAX_ENV from Windows CRT stdlib.h - maximum environment variable size pub const _MAX_ENV: usize = 32767; +pub const HRESULT_E_POINTER: i32 = E_POINTER; +pub const HRESULT_S_OK: i32 = S_OK; +pub const CP_ACP: u32 = windows_sys::Win32::Globalization::CP_ACP; +pub const CP_OEMCP: u32 = windows_sys::Win32::Globalization::CP_OEMCP; +pub const CP_UTF7: u32 = windows_sys::Win32::Globalization::CP_UTF7; +pub const CP_UTF8: u32 = windows_sys::Win32::Globalization::CP_UTF8; +pub const MB_ERR_INVALID_CHARS: u32 = 
windows_sys::Win32::Globalization::MB_ERR_INVALID_CHARS; +pub const WC_ERR_INVALID_CHARS: u32 = windows_sys::Win32::Globalization::WC_ERR_INVALID_CHARS; +pub const WC_NO_BEST_FIT_CHARS: u32 = windows_sys::Win32::Globalization::WC_NO_BEST_FIT_CHARS; +pub const ERROR_INVALID_FLAGS_I32: i32 = ERROR_INVALID_FLAGS as i32; +pub const ERROR_NO_UNICODE_TRANSLATION_I32: i32 = ERROR_NO_UNICODE_TRANSLATION as i32; +pub const ERROR_INSUFFICIENT_BUFFER_I32: i32 = ERROR_INSUFFICIENT_BUFFER as i32; + +pub fn init_winsock() { + static WSA_INIT: parking_lot::Once = parking_lot::Once::new(); + WSA_INIT.call_once(|| unsafe { + let mut wsa_data = core::mem::MaybeUninit::uninit(); + let _ = WSAStartup(0x0101, wsa_data.as_mut_ptr()); + }) +} + +/// Win32 BOOL convention: 0 = failure, nonzero = success. +/// Reads error via `GetLastError()` through [`io::Error::last_os_error`]. +pub trait CheckWin32Bool { + fn check_win32_bool(self) -> io::Result<()>; +} + +impl CheckWin32Bool for i32 { + #[inline] + fn check_win32_bool(self) -> io::Result<()> { + if self == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } +} + +/// Convenience checks for Win32 `HANDLE` return values. +pub trait CheckWin32Handle: Sized { + /// Returns `Ok(self)` if the handle is non-NULL, otherwise the last OS error. + /// Use for APIs whose failure sentinel is NULL (Create*Event, Create*Mutex, + /// CreateFileMapping, GetModuleHandle, etc.). + fn check_nonnull(self) -> io::Result; + + /// Returns `Ok(self)` unless the handle equals `INVALID_HANDLE_VALUE`. + /// Use for APIs whose failure sentinel is `INVALID_HANDLE_VALUE` + /// (CreateFileW, CreateNamedPipeW, etc.). 
+ fn check_valid(self) -> io::Result; +} + +impl CheckWin32Handle for windows_sys::Win32::Foundation::HANDLE { + #[inline] + fn check_nonnull(self) -> io::Result { + if self.is_null() { + Err(io::Error::last_os_error()) + } else { + Ok(self) + } + } + + #[inline] + fn check_valid(self) -> io::Result { + if self == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { + Err(io::Error::last_os_error()) + } else { + Ok(self) + } + } +} + +/// Generic sentinel check for Win32 return values not covered by [`CheckWin32Bool`] or +/// [`CheckWin32Handle`]. Use for APIs that signal failure with a specific value +/// (`INVALID_FILE_ATTRIBUTES`, `WAIT_FAILED`, `GetFileVersionInfoSizeW` returning `0`, etc.). +pub trait CheckWin32Sentinel: Sized + Copy + PartialEq { + /// Returns `Ok(self)` if `self != fail`, otherwise the last OS error. + #[inline] + fn check_ne(self, fail: Self) -> io::Result { + if self == fail { + Err(io::Error::last_os_error()) + } else { + Ok(self) + } + } +} + +impl CheckWin32Sentinel for u32 {} +impl CheckWin32Sentinel for i32 {} +impl CheckWin32Sentinel for u64 {} + +use std::os::windows::io::FromRawHandle; +pub use std::os::windows::io::{BorrowedHandle, OwnedHandle}; + +/// Conversion of raw Win32 `HANDLE` into an [`OwnedHandle`] (RAII auto-close). +pub trait HandleToOwned: Sized { + /// Wraps the handle in an [`OwnedHandle`] that calls `CloseHandle` on drop. + /// Returns `None` if the handle is NULL or `INVALID_HANDLE_VALUE`. + fn into_owned(self) -> Option; +} + +impl HandleToOwned for windows_sys::Win32::Foundation::HANDLE { + #[inline] + fn into_owned(self) -> Option { + if self.is_null() || self == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { + None + } else { + // SAFETY: caller is asserting via the public `Create*`/`Open*` paths + // that this handle is owned and unaliased. 
+ Some(unsafe { OwnedHandle::from_raw_handle(self.cast()) }) + } + } +} + +#[derive(Clone, Debug)] +pub struct WindowsVersionInfo { + pub major: u32, + pub minor: u32, + pub build: u32, + pub platform: u32, + pub service_pack: String, + pub service_pack_major: u16, + pub service_pack_minor: u16, + pub suite_mask: u16, + pub product_type: u8, +} + +fn get_kernel32_version() -> io::Result<(u32, u32, u32)> { + unsafe { + let module_name: Vec = OsStr::new("kernel32.dll").to_wide_with_nul(); + let h_kernel32 = GetModuleHandleW(module_name.as_ptr()).check_nonnull()?; + + let mut kernel32_path = [0u16; MAX_PATH as usize]; + let len = GetModuleFileNameW( + h_kernel32, + kernel32_path.as_mut_ptr(), + kernel32_path.len() as u32, + ); + if len == 0 { + return Err(io::Error::last_os_error()); + } + + let ver_block_size = GetFileVersionInfoSizeW(kernel32_path.as_ptr(), core::ptr::null_mut()); + if ver_block_size == 0 { + return Err(io::Error::last_os_error()); + } + + let mut ver_block = vec![0u8; ver_block_size as usize]; + GetFileVersionInfoW( + kernel32_path.as_ptr(), + 0, + ver_block_size, + ver_block.as_mut_ptr() as *mut _, + ) + .check_win32_bool()?; + + let sub_block: Vec = OsStr::new("").to_wide_with_nul(); + + let mut ffi_ptr: *mut VS_FIXEDFILEINFO = core::ptr::null_mut(); + let mut ffi_len: u32 = 0; + VerQueryValueW( + ver_block.as_ptr() as *const _, + sub_block.as_ptr(), + &mut ffi_ptr as *mut *mut VS_FIXEDFILEINFO as *mut *mut _, + &mut ffi_len as *mut u32, + ) + .check_win32_bool()?; + if ffi_ptr.is_null() { + return Err(io::Error::last_os_error()); + } + + let ffi = *ffi_ptr; + let real_major = (ffi.dwProductVersionMS >> 16) & 0xFFFF; + let real_minor = ffi.dwProductVersionMS & 0xFFFF; + let real_build = (ffi.dwProductVersionLS >> 16) & 0xFFFF; + + Ok((real_major, real_minor, real_build)) + } +} + +pub fn get_windows_version() -> io::Result { + let mut version: OSVERSIONINFOEXW = unsafe { core::mem::zeroed() }; + version.dwOSVersionInfoSize = 
core::mem::size_of::() as u32; + unsafe { + let os_vi = &mut version as *mut OSVERSIONINFOEXW as *mut OSVERSIONINFOW; + GetVersionExW(os_vi) + } + .check_win32_bool()?; + + let service_pack = { + let (last, _) = version + .szCSDVersion + .iter() + .take_while(|&x| x != &0) + .enumerate() + .last() + .unwrap_or((0, &0)); + let sp = OsString::from_wide(&version.szCSDVersion[..last]); + sp.into_string() + .map_err(|_| io::Error::other("service pack is not ASCII"))? + }; + let (major, minor, build) = get_kernel32_version()?; + Ok(WindowsVersionInfo { + major, + minor, + build, + platform: version.dwPlatformId, + service_pack, + service_pack_major: version.wServicePackMajor, + service_pack_minor: version.wServicePackMinor, + suite_mask: version.wSuiteMask, + product_type: version.wProductType, + }) +} + +pub fn current_thread_stack_bounds() -> (usize, usize) { + let mut low: usize = 0; + let mut high: usize = 0; + unsafe { + GetCurrentThreadStackLimits(&mut low as *mut usize, &mut high as *mut usize); + let mut guarantee: u32 = 0; + SetThreadStackGuarantee(&mut guarantee); + low += guarantee as usize; + } + (low, high) +} + +pub fn set_last_error(error: u32) { + unsafe { windows_sys::Win32::Foundation::SetLastError(error) } +} + +pub fn get_last_error() -> u32 { + unsafe { windows_sys::Win32::Foundation::GetLastError() } +} + +pub fn format_error_message(code: Option) -> Option { + let error_code = code.unwrap_or_else(get_last_error); + let mut buffer: *mut u16 = core::ptr::null_mut(); + let len = unsafe { + FormatMessageW( + FORMAT_MESSAGE_ALLOCATE_BUFFER + | FORMAT_MESSAGE_FROM_SYSTEM + | FORMAT_MESSAGE_IGNORE_INSERTS, + core::ptr::null(), + error_code, + 0, + &mut buffer as *mut *mut u16 as *mut u16, + 0, + core::ptr::null(), + ) + }; + + if len == 0 || buffer.is_null() { + return None; + } + + let message = unsafe { + let slice = core::slice::from_raw_parts(buffer, len as usize); + let msg = String::from_utf16_lossy(slice).trim_end().to_string(); + 
windows_sys::Win32::Foundation::LocalFree(buffer as *mut _); + msg + }; + Some(message) +} + +pub fn wide_char_to_multi_byte_len( + code_page: u32, + flags: u32, + wide: &[u16], + track_default_char: bool, +) -> io::Result<(usize, bool)> { + let mut used_default_char = 0i32; + let pused = if track_default_char { + &mut used_default_char as *mut i32 + } else { + core::ptr::null_mut() + }; + let size = unsafe { + windows_sys::Win32::Globalization::WideCharToMultiByte( + code_page, + flags, + wide.as_ptr(), + wide.len() as i32, + core::ptr::null_mut(), + 0, + core::ptr::null(), + pused, + ) + }; + if size <= 0 { + Err(io::Error::last_os_error()) + } else { + Ok((size as usize, used_default_char != 0)) + } +} + +pub fn wide_char_to_multi_byte( + code_page: u32, + flags: u32, + wide: &[u16], + out: &mut [u8], + track_default_char: bool, +) -> io::Result<(usize, bool)> { + let mut used_default_char = 0i32; + let pused = if track_default_char { + &mut used_default_char as *mut i32 + } else { + core::ptr::null_mut() + }; + let size = unsafe { + windows_sys::Win32::Globalization::WideCharToMultiByte( + code_page, + flags, + wide.as_ptr(), + wide.len() as i32, + out.as_mut_ptr().cast(), + out.len() as i32, + core::ptr::null(), + pused, + ) + }; + if size <= 0 { + Err(io::Error::last_os_error()) + } else { + Ok((size as usize, used_default_char != 0)) + } +} + +pub fn multi_byte_to_wide_len(code_page: u32, flags: u32, bytes: &[u8]) -> io::Result { + let size = unsafe { + windows_sys::Win32::Globalization::MultiByteToWideChar( + code_page, + flags, + bytes.as_ptr().cast(), + bytes.len() as i32, + core::ptr::null_mut(), + 0, + ) + }; + if size <= 0 { + Err(io::Error::last_os_error()) + } else { + Ok(size as usize) + } +} + +pub fn multi_byte_to_wide( + code_page: u32, + flags: u32, + bytes: &[u8], + out: &mut [u16], +) -> io::Result { + let size = unsafe { + windows_sys::Win32::Globalization::MultiByteToWideChar( + code_page, + flags, + bytes.as_ptr().cast(), + bytes.len() as 
i32, + out.as_mut_ptr(), + out.len() as i32, + ) + }; + if size <= 0 { + Err(io::Error::last_os_error()) + } else { + Ok(size as usize) + } +} pub trait ToWideString { fn to_wide(&self) -> Vec; fn to_wide_with_nul(&self) -> Vec; + fn to_wide_cstring(&self) -> widestring::WideCString { + widestring::WideCString::from_vec_truncate(self.to_wide()) + } } impl ToWideString for T diff --git a/crates/host_env/src/winreg.rs b/crates/host_env/src/winreg.rs new file mode 100644 index 00000000000..5324ac258e2 --- /dev/null +++ b/crates/host_env/src/winreg.rs @@ -0,0 +1,524 @@ +#![allow( + clippy::missing_safety_doc, + reason = "This module intentionally exposes raw Win32 registry wrappers." +)] +#![allow( + clippy::not_unsafe_ptr_arg_deref, + reason = "These wrappers mirror Win32 APIs that operate on caller-provided pointers." +)] +#![allow( + clippy::too_many_arguments, + reason = "These helpers preserve the underlying Win32 registry call shapes." +)] + +extern crate alloc; + +use alloc::string::FromUtf16Error; +use std::ffi::OsStr; + +use crate::windows::ToWideString; +use windows_sys::Win32::{ + Foundation, + Security::SECURITY_ATTRIBUTES, + System::{Environment, Registry}, +}; + +pub type HKEY = Registry::HKEY; +pub const ERROR_MORE_DATA: u32 = Foundation::ERROR_MORE_DATA; +pub const ERROR_INVALID_HANDLE: u32 = Foundation::ERROR_INVALID_HANDLE; +pub const HKEY_CLASSES_ROOT: HKEY = Registry::HKEY_CLASSES_ROOT; +pub const HKEY_CURRENT_USER: HKEY = Registry::HKEY_CURRENT_USER; +pub const HKEY_LOCAL_MACHINE: HKEY = Registry::HKEY_LOCAL_MACHINE; +pub const HKEY_USERS: HKEY = Registry::HKEY_USERS; +pub const HKEY_PERFORMANCE_DATA: HKEY = Registry::HKEY_PERFORMANCE_DATA; +pub const HKEY_CURRENT_CONFIG: HKEY = Registry::HKEY_CURRENT_CONFIG; +pub const HKEY_DYN_DATA: HKEY = Registry::HKEY_DYN_DATA; + +pub const KEY_ALL_ACCESS: u32 = Registry::KEY_ALL_ACCESS; +pub const KEY_CREATE_LINK: u32 = Registry::KEY_CREATE_LINK; +pub const KEY_CREATE_SUB_KEY: u32 = 
Registry::KEY_CREATE_SUB_KEY; +pub const KEY_ENUMERATE_SUB_KEYS: u32 = Registry::KEY_ENUMERATE_SUB_KEYS; +pub const KEY_EXECUTE: u32 = Registry::KEY_EXECUTE; +pub const KEY_NOTIFY: u32 = Registry::KEY_NOTIFY; +pub const KEY_QUERY_VALUE: u32 = Registry::KEY_QUERY_VALUE; +pub const KEY_READ: u32 = Registry::KEY_READ; +pub const KEY_SET_VALUE: u32 = Registry::KEY_SET_VALUE; +pub const KEY_WOW64_32KEY: u32 = Registry::KEY_WOW64_32KEY; +pub const KEY_WOW64_64KEY: u32 = Registry::KEY_WOW64_64KEY; +pub const KEY_WRITE: u32 = Registry::KEY_WRITE; + +pub const REG_BINARY: u32 = Registry::REG_BINARY; +pub const REG_CREATED_NEW_KEY: u32 = Registry::REG_CREATED_NEW_KEY; +pub const REG_DWORD: u32 = Registry::REG_DWORD; +pub const REG_DWORD_BIG_ENDIAN: u32 = Registry::REG_DWORD_BIG_ENDIAN; +pub const REG_DWORD_LITTLE_ENDIAN: u32 = Registry::REG_DWORD_LITTLE_ENDIAN; +pub const REG_EXPAND_SZ: u32 = Registry::REG_EXPAND_SZ; +pub const REG_FULL_RESOURCE_DESCRIPTOR: u32 = Registry::REG_FULL_RESOURCE_DESCRIPTOR; +pub const REG_LINK: u32 = Registry::REG_LINK; +pub const REG_MULTI_SZ: u32 = Registry::REG_MULTI_SZ; +pub const REG_NONE: u32 = Registry::REG_NONE; +pub const REG_NOTIFY_CHANGE_ATTRIBUTES: u32 = Registry::REG_NOTIFY_CHANGE_ATTRIBUTES; +pub const REG_NOTIFY_CHANGE_LAST_SET: u32 = Registry::REG_NOTIFY_CHANGE_LAST_SET; +pub const REG_NOTIFY_CHANGE_NAME: u32 = Registry::REG_NOTIFY_CHANGE_NAME; +pub const REG_NOTIFY_CHANGE_SECURITY: u32 = Registry::REG_NOTIFY_CHANGE_SECURITY; +pub const REG_OPENED_EXISTING_KEY: u32 = Registry::REG_OPENED_EXISTING_KEY; +pub const REG_OPTION_BACKUP_RESTORE: u32 = Registry::REG_OPTION_BACKUP_RESTORE; +pub const REG_OPTION_CREATE_LINK: u32 = Registry::REG_OPTION_CREATE_LINK; +pub const REG_OPTION_NON_VOLATILE: u32 = Registry::REG_OPTION_NON_VOLATILE; +pub const REG_OPTION_OPEN_LINK: u32 = Registry::REG_OPTION_OPEN_LINK; +pub const REG_OPTION_RESERVED: u32 = Registry::REG_OPTION_RESERVED; +pub const REG_OPTION_VOLATILE: u32 = 
Registry::REG_OPTION_VOLATILE; +pub const REG_QWORD: u32 = Registry::REG_QWORD; +pub const REG_QWORD_LITTLE_ENDIAN: u32 = Registry::REG_QWORD_LITTLE_ENDIAN; +pub const REG_RESOURCE_LIST: u32 = Registry::REG_RESOURCE_LIST; +pub const REG_RESOURCE_REQUIREMENTS_LIST: u32 = Registry::REG_RESOURCE_REQUIREMENTS_LIST; +pub const REG_SZ: u32 = Registry::REG_SZ; +pub const REG_WHOLE_HIVE_VOLATILE: u32 = Registry::REG_WHOLE_HIVE_VOLATILE as u32; +pub const REG_REFRESH_HIVE: u32 = 0x00000002; +pub const REG_NO_LAZY_FLUSH: u32 = 0x00000004; +pub const REG_LEGAL_OPTION: u32 = Registry::REG_OPTION_RESERVED + | Registry::REG_OPTION_NON_VOLATILE + | Registry::REG_OPTION_VOLATILE + | Registry::REG_OPTION_CREATE_LINK + | Registry::REG_OPTION_BACKUP_RESTORE + | Registry::REG_OPTION_OPEN_LINK; +pub const REG_LEGAL_CHANGE_FILTER: u32 = Registry::REG_NOTIFY_CHANGE_NAME + | Registry::REG_NOTIFY_CHANGE_ATTRIBUTES + | Registry::REG_NOTIFY_CHANGE_LAST_SET + | Registry::REG_NOTIFY_CHANGE_SECURITY; + +pub fn bytes_as_wide_slice(bytes: &[u8]) -> &[u16] { + let (prefix, u16_slice, suffix) = unsafe { bytes.align_to::() }; + debug_assert!( + prefix.is_empty() && suffix.is_empty(), + "Registry data should be u16-aligned" + ); + u16_slice +} + +pub fn close_key(hkey: Registry::HKEY) -> u32 { + unsafe { Registry::RegCloseKey(hkey) } +} + +pub unsafe fn connect_registry( + computer_name: Option<&widestring::WideCStr>, + key: Registry::HKEY, + out_key: *mut Registry::HKEY, +) -> u32 { + let name_ptr = computer_name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { Registry::RegConnectRegistryW(name_ptr, key, out_key) } +} + +pub unsafe fn create_key( + key: Registry::HKEY, + sub_key: &widestring::WideCStr, + out_key: *mut Registry::HKEY, +) -> u32 { + unsafe { Registry::RegCreateKeyW(key, sub_key.as_ptr(), out_key) } +} + +pub unsafe fn create_key_ex( + key: Registry::HKEY, + sub_key: &widestring::WideCStr, + reserved: u32, + class: *mut u16, + options: u32, + sam: u32, + security: *const 
SECURITY_ATTRIBUTES, + result: *mut Registry::HKEY, + disposition: *mut u32, +) -> u32 { + unsafe { + Registry::RegCreateKeyExW( + key, + sub_key.as_ptr(), + reserved, + class, + options, + sam, + security, + result, + disposition, + ) + } +} + +pub unsafe fn delete_key(key: Registry::HKEY, sub_key: &widestring::WideCStr) -> u32 { + unsafe { Registry::RegDeleteKeyW(key, sub_key.as_ptr()) } +} + +pub unsafe fn delete_key_ex( + key: Registry::HKEY, + sub_key: &widestring::WideCStr, + sam: u32, + reserved: u32, +) -> u32 { + unsafe { Registry::RegDeleteKeyExW(key, sub_key.as_ptr(), sam, reserved) } +} + +pub unsafe fn delete_value(key: Registry::HKEY, value_name: Option<&widestring::WideCStr>) -> u32 { + let name_ptr = value_name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { Registry::RegDeleteValueW(key, name_ptr) } +} + +pub unsafe fn enum_key_ex( + key: Registry::HKEY, + index: u32, + name: *mut u16, + name_len: *mut u32, +) -> u32 { + unsafe { + Registry::RegEnumKeyExW( + key, + index, + name, + name_len, + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + ) + } +} + +pub unsafe fn query_info_key( + key: Registry::HKEY, + sub_keys: *mut u32, + values: *mut u32, + max_value_name_len: *mut u32, + max_value_len: *mut u32, +) -> u32 { + unsafe { + Registry::RegQueryInfoKeyW( + key, + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + sub_keys, + core::ptr::null_mut(), + core::ptr::null_mut(), + values, + max_value_name_len, + max_value_len, + core::ptr::null_mut(), + core::ptr::null_mut(), + ) + } +} + +pub struct QueryInfo { + pub sub_keys: u32, + pub values: u32, + pub last_write_time: u64, +} + +pub fn query_info_key_full(key: Registry::HKEY) -> Result { + let mut sub_keys = 0; + let mut values = 0; + let mut last_write_time: Foundation::FILETIME = unsafe { core::mem::zeroed() }; + let err = unsafe { + Registry::RegQueryInfoKeyW( + key, + core::ptr::null_mut(), + 
core::ptr::null_mut(), + 0 as _, + &mut sub_keys, + core::ptr::null_mut(), + core::ptr::null_mut(), + &mut values, + core::ptr::null_mut(), + core::ptr::null_mut(), + core::ptr::null_mut(), + &mut last_write_time, + ) + }; + if err != 0 { + return Err(err); + } + Ok(QueryInfo { + sub_keys, + values, + last_write_time: ((last_write_time.dwHighDateTime as u64) << 32) + | last_write_time.dwLowDateTime as u64, + }) +} + +pub unsafe fn enum_value( + key: Registry::HKEY, + index: u32, + value_name: *mut u16, + value_name_len: *mut u32, + value_type: *mut u32, + data: *mut u8, + data_len: *mut u32, +) -> u32 { + unsafe { + Registry::RegEnumValueW( + key, + index, + value_name, + value_name_len, + core::ptr::null_mut(), + value_type, + data, + data_len, + ) + } +} + +pub fn flush_key(key: Registry::HKEY) -> u32 { + unsafe { Registry::RegFlushKey(key) } +} + +pub unsafe fn load_key( + key: Registry::HKEY, + sub_key: &widestring::WideCStr, + file_name: &widestring::WideCStr, +) -> u32 { + unsafe { Registry::RegLoadKeyW(key, sub_key.as_ptr(), file_name.as_ptr()) } +} + +pub unsafe fn open_key_ex( + key: Registry::HKEY, + sub_key: &widestring::WideCStr, + options: u32, + sam: u32, + out_key: *mut Registry::HKEY, +) -> u32 { + unsafe { Registry::RegOpenKeyExW(key, sub_key.as_ptr(), options, sam, out_key) } +} + +pub unsafe fn query_value_ex( + key: Registry::HKEY, + value_name: Option<&widestring::WideCStr>, + value_type: *mut u32, + data: *mut u8, + data_len: *mut u32, +) -> u32 { + let name_ptr = value_name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { + Registry::RegQueryValueExW( + key, + name_ptr, + core::ptr::null_mut(), + value_type, + data, + data_len, + ) + } +} + +pub unsafe fn save_key(key: Registry::HKEY, file_name: &widestring::WideCStr) -> u32 { + unsafe { Registry::RegSaveKeyW(key, file_name.as_ptr(), core::ptr::null_mut()) } +} + +pub unsafe fn set_value_ex( + key: Registry::HKEY, + value_name: Option<&widestring::WideCStr>, + typ: u32, + ptr: *const u8, 
+ len: u32, +) -> u32 { + let name_ptr = value_name.map_or(core::ptr::null(), |n| n.as_ptr()); + unsafe { Registry::RegSetValueExW(key, name_ptr, 0, typ, ptr, len) } +} + +pub fn disable_reflection_key(key: Registry::HKEY) -> u32 { + unsafe { Registry::RegDisableReflectionKey(key) } +} + +pub fn enable_reflection_key(key: Registry::HKEY) -> u32 { + unsafe { Registry::RegEnableReflectionKey(key) } +} + +pub unsafe fn query_reflection_key(key: Registry::HKEY, result: *mut i32) -> u32 { + unsafe { Registry::RegQueryReflectionKey(key, result) } +} + +pub enum ExpandEnvironmentStringsError { + Os, + Utf16(FromUtf16Error), +} + +pub enum QueryStringError { + Code(u32), + Utf16(FromUtf16Error), +} + +pub fn query_default_value( + hkey: Registry::HKEY, + sub_key: Option<&OsStr>, +) -> Result { + let child_key = if let Some(sub_key) = sub_key.filter(|s| !s.is_empty()) { + let wide_sub_key = sub_key.to_wide_cstring(); + let mut out_key = core::ptr::null_mut(); + let res = unsafe { + open_key_ex( + hkey, + &wide_sub_key, + 0, + Registry::KEY_QUERY_VALUE, + &mut out_key, + ) + }; + if res != 0 { + return Err(QueryStringError::Code(res)); + } + Some(out_key) + } else { + None + }; + + let target_key = child_key.unwrap_or(hkey); + let mut buf_size: u32 = 256; + let mut buffer: Vec = vec![0; buf_size as usize]; + let mut reg_type: u32 = 0; + + let result = loop { + let mut size = buf_size; + let res = unsafe { + query_value_ex( + target_key, + None, + &mut reg_type, + buffer.as_mut_ptr(), + &mut size, + ) + }; + if res == Foundation::ERROR_MORE_DATA { + buf_size *= 2; + buffer.resize(buf_size as usize, 0); + continue; + } + if res == Foundation::ERROR_FILE_NOT_FOUND { + break Ok(String::new()); + } + if res != 0 { + break Err(QueryStringError::Code(res)); + } + if reg_type != Registry::REG_SZ { + break Err(QueryStringError::Code(Foundation::ERROR_INVALID_DATA)); + } + + let u16_slice = bytes_as_wide_slice(&buffer[..size as usize]); + let len = u16_slice + .iter() + .position(|&c| 
c == 0) + .unwrap_or(u16_slice.len()); + break String::from_utf16(&u16_slice[..len]).map_err(QueryStringError::Utf16); + }; + + if let Some(ck) = child_key { + close_key(ck); + } + + result +} + +pub fn query_value_bytes(hkey: Registry::HKEY, value_name: &OsStr) -> Result<(Vec, u32), u32> { + let wide_name = value_name.to_wide_cstring(); + let mut buf_size: u32 = 0; + let res = unsafe { + query_value_ex( + hkey, + Some(&wide_name), + core::ptr::null_mut(), + core::ptr::null_mut(), + &mut buf_size, + ) + }; + if res == Foundation::ERROR_MORE_DATA || buf_size == 0 { + buf_size = 256; + } else if res != 0 { + return Err(res); + } + + let mut ret_buf = vec![0u8; buf_size as usize]; + let mut typ = 0; + + loop { + let mut ret_size = buf_size; + let res = unsafe { + query_value_ex( + hkey, + Some(&wide_name), + &mut typ, + ret_buf.as_mut_ptr(), + &mut ret_size, + ) + }; + if res != Foundation::ERROR_MORE_DATA { + if res != 0 { + return Err(res); + } + ret_buf.truncate(ret_size as usize); + return Ok((ret_buf, typ)); + } + buf_size *= 2; + ret_buf.resize(buf_size as usize, 0); + } +} + +pub fn set_default_value(hkey: Registry::HKEY, sub_key: &OsStr, typ: u32, value: &OsStr) -> u32 { + let child_key = if !sub_key.is_empty() { + let wide_sub_key = sub_key.to_wide_cstring(); + let mut out_key = core::ptr::null_mut(); + let res = unsafe { + create_key_ex( + hkey, + &wide_sub_key, + 0, + core::ptr::null_mut(), + 0, + Registry::KEY_SET_VALUE, + core::ptr::null(), + &mut out_key, + core::ptr::null_mut(), + ) + }; + if res != 0 { + return res; + } + Some(out_key) + } else { + None + }; + + let target_key = child_key.unwrap_or(hkey); + let wide_value = value.to_wide_with_nul(); + let res = unsafe { + set_value_ex( + target_key, + None, + typ, + wide_value.as_ptr() as *const u8, + (wide_value.len() * 2) as u32, + ) + }; + + if let Some(ck) = child_key { + close_key(ck); + } + res +} + +pub fn expand_environment_strings(input: &OsStr) -> Result { + let wide_input = 
input.to_wide_with_nul(); + let required_size = unsafe { + Environment::ExpandEnvironmentStringsW(wide_input.as_ptr(), core::ptr::null_mut(), 0) + }; + if required_size == 0 { + return Err(ExpandEnvironmentStringsError::Os); + } + + let mut out = vec![0u16; required_size as usize]; + let written = unsafe { + Environment::ExpandEnvironmentStringsW(wide_input.as_ptr(), out.as_mut_ptr(), required_size) + }; + if written == 0 { + return Err(ExpandEnvironmentStringsError::Os); + } + + let len = out.iter().position(|&c| c == 0).unwrap_or(out.len()); + String::from_utf16(&out[..len]).map_err(ExpandEnvironmentStringsError::Utf16) +} diff --git a/crates/host_env/src/winsound.rs b/crates/host_env/src/winsound.rs new file mode 100644 index 00000000000..07fc5751df1 --- /dev/null +++ b/crates/host_env/src/winsound.rs @@ -0,0 +1,81 @@ +// spell-checker:ignore pszSound fdwSound winmm + +use std::io; + +#[link(name = "winmm")] +unsafe extern "system" { + fn PlaySoundW(pszSound: *const u16, hmod: isize, fdwSound: u32) -> i32; +} + +unsafe extern "system" { + fn Beep(dwFreq: u32, dwDuration: u32) -> i32; + fn MessageBeep(uType: u32) -> i32; +} + +/// `SND_ASYNC` flag value from `mmsystem.h`. +const SND_ASYNC: u32 = 0x0001; +/// `SND_MEMORY` flag value from `mmsystem.h`. +const SND_MEMORY: u32 = 0x0004; + +/// Source for a `PlaySound` call. +pub enum PlaySoundSource<'a> { + /// Stop currently playing sound (NULL `pszSound`). + Stop, + /// Play sound data from memory; pass with `SND_MEMORY` set in `flags`. + Memory(&'a [u8]), + /// Play sound by filename or system alias. + Name(&'a widestring::WideCStr), +} + +/// Returns `Ok(())` when `PlaySoundW` returns non-zero, an error otherwise. +/// +/// Rejects `Memory(_)` together with `SND_ASYNC` because async playback +/// requires the buffer to outlive the call; combining them with a borrowed +/// slice would let WinMM read freed memory. 
+pub fn play_sound(source: PlaySoundSource<'_>, flags: u32) -> Result<(), PlaySoundError> { + if matches!(source, PlaySoundSource::Memory(_)) && flags & SND_ASYNC != 0 { + return Err(PlaySoundError::MemoryAsyncRejected); + } + // `SND_MEMORY` requires a `Memory(_)` source; an empty pointer would + // dereference garbage. + if !matches!(source, PlaySoundSource::Memory(_)) && flags & SND_MEMORY != 0 { + return Err(PlaySoundError::MemoryFlagWithoutBuffer); + } + let ptr: *const u16 = match source { + PlaySoundSource::Stop => core::ptr::null(), + PlaySoundSource::Memory(buf) => buf.as_ptr().cast(), + PlaySoundSource::Name(s) => s.as_ptr(), + }; + let ok = unsafe { PlaySoundW(ptr, 0, flags) }; + if ok == 0 { + Err(PlaySoundError::CallFailed) + } else { + Ok(()) + } +} + +#[derive(Debug, Clone, Copy)] +pub enum PlaySoundError { + /// `PlaySoundW` returned 0; there is no documented errno for this path. + CallFailed, + /// `Memory(_)` source cannot be combined with `SND_ASYNC` in the safe API. + MemoryAsyncRejected, + /// `SND_MEMORY` set in `flags` but no `Memory(_)` buffer supplied. + MemoryFlagWithoutBuffer, +} + +/// `Beep(freq, duration)`. `false` on failure. +#[must_use] +pub fn beep(frequency: u32, duration_ms: u32) -> bool { + unsafe { Beep(frequency, duration_ms) != 0 } +} + +/// `MessageBeep(type)`. On failure returns `Err` populated from `GetLastError`. +pub fn message_beep(beep_type: u32) -> io::Result<()> { + let ok = unsafe { MessageBeep(beep_type) }; + if ok == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } +} diff --git a/crates/host_env/src/wmi.rs b/crates/host_env/src/wmi.rs new file mode 100644 index 00000000000..a6d77f77e9f --- /dev/null +++ b/crates/host_env/src/wmi.rs @@ -0,0 +1,676 @@ +#![allow( + clippy::upper_case_acronyms, + reason = "These names mirror the Windows COM and ABI types they wrap." 
+)] +#![allow(non_snake_case)] +#![allow(unsafe_op_in_unsafe_fn)] + +use core::ffi::c_void; +use core::ptr::{null, null_mut}; +use windows_sys::Win32::Foundation::{ + CloseHandle, ERROR_BROKEN_PIPE, ERROR_MORE_DATA, ERROR_NOT_ENOUGH_MEMORY, GetLastError, HANDLE, + WAIT_OBJECT_0, WAIT_TIMEOUT, +}; +use windows_sys::Win32::Storage::FileSystem::{ReadFile, WriteFile}; +use windows_sys::Win32::System::Pipes::CreatePipe; +use windows_sys::Win32::System::Threading::{ + CreateEventW, CreateThread, GetExitCodeThread, SetEvent, WaitForSingleObject, +}; + +pub const BUFFER_SIZE: usize = 8192; + +pub enum ExecQueryError { + MoreData, + Code(u32), +} + +type HRESULT = i32; + +#[repr(C)] +struct GUID { + data1: u32, + data2: u16, + data3: u16, + data4: [u8; 8], +} + +#[repr(C, align(8))] +struct VARIANT([u64; 3]); + +impl VARIANT { + fn zeroed() -> Self { + Self([0u64; 3]) + } +} + +const CLSID_WBEM_LOCATOR: GUID = GUID { + data1: 0x4590F811, + data2: 0x1D3A, + data3: 0x11D0, + data4: [0x89, 0x1F, 0x00, 0xAA, 0x00, 0x4B, 0x2E, 0x24], +}; + +const IID_IWBEM_LOCATOR: GUID = GUID { + data1: 0xDC12A687, + data2: 0x737F, + data3: 0x11CF, + data4: [0x88, 0x4D, 0x00, 0xAA, 0x00, 0x4B, 0x2E, 0x24], +}; + +const COINIT_APARTMENTTHREADED: u32 = 0x2; +const CLSCTX_INPROC_SERVER: u32 = 0x1; +const RPC_C_AUTHN_LEVEL_DEFAULT: u32 = 0; +const RPC_C_IMP_LEVEL_IMPERSONATE: u32 = 3; +const RPC_C_AUTHN_LEVEL_CALL: u32 = 3; +const RPC_C_AUTHN_WINNT: u32 = 10; +const RPC_C_AUTHZ_NONE: u32 = 0; +const EOAC_NONE: u32 = 0; +const RPC_E_TOO_LATE: HRESULT = 0x80010119_u32 as i32; +const WBEM_FLAG_FORWARD_ONLY: i32 = 0x20; +const WBEM_FLAG_RETURN_IMMEDIATELY: i32 = 0x10; +const WBEM_S_FALSE: HRESULT = 1; +const WBEM_S_NO_MORE_DATA: HRESULT = 0x40005; +const WBEM_INFINITE: i32 = -1; +const WBEM_FLAVOR_MASK_ORIGIN: i32 = 0x60; +const WBEM_FLAVOR_ORIGIN_SYSTEM: i32 = 0x40; + +#[link(name = "ole32")] +unsafe extern "system" { + fn CoInitializeEx(pvReserved: *mut c_void, dwCoInit: u32) -> HRESULT; + fn 
CoUninitialize(); + fn CoInitializeSecurity( + pSecDesc: *const c_void, + cAuthSvc: i32, + asAuthSvc: *const c_void, + pReserved1: *const c_void, + dwAuthnLevel: u32, + dwImpLevel: u32, + pAuthList: *const c_void, + dwCapabilities: u32, + pReserved3: *const c_void, + ) -> HRESULT; + fn CoCreateInstance( + rclsid: *const GUID, + pUnkOuter: *mut c_void, + dwClsContext: u32, + riid: *const GUID, + ppv: *mut *mut c_void, + ) -> HRESULT; + fn CoSetProxyBlanket( + pProxy: *mut c_void, + dwAuthnSvc: u32, + dwAuthzSvc: u32, + pServerPrincName: *const u16, + dwAuthnLevel: u32, + dwImpLevel: u32, + pAuthInfo: *const c_void, + dwCapabilities: u32, + ) -> HRESULT; +} + +#[link(name = "oleaut32")] +unsafe extern "system" { + fn SysAllocString(psz: *const u16) -> *mut u16; + fn SysFreeString(bstrString: *mut u16); + fn VariantClear(pvarg: *mut VARIANT) -> HRESULT; +} + +#[link(name = "propsys")] +unsafe extern "system" { + fn VariantToString(varIn: *const VARIANT, pszBuf: *mut u16, cchBuf: u32) -> HRESULT; +} + +unsafe fn com_release(this: *mut c_void) { + if !this.is_null() { + let vtable = *(this as *const *const usize); + let release: unsafe extern "system" fn(*mut c_void) -> u32 = + core::mem::transmute(*vtable.add(2)); + release(this); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn locator_connect_server( + this: *mut c_void, + network_resource: *const u16, + user: *const u16, + password: *const u16, + locale: *const u16, + security_flags: i32, + authority: *const u16, + ctx: *mut c_void, + services: *mut *mut c_void, +) -> HRESULT { + let vtable = *(this as *const *const usize); + let method: unsafe extern "system" fn( + *mut c_void, + *const u16, + *const u16, + *const u16, + *const u16, + i32, + *const u16, + *mut c_void, + *mut *mut c_void, + ) -> HRESULT = core::mem::transmute(*vtable.add(3)); + method( + this, + network_resource, + user, + password, + locale, + security_flags, + authority, + ctx, + services, + ) +} + +unsafe fn services_exec_query( + this: 
*mut c_void, + query_language: *const u16, + query: *const u16, + flags: i32, + ctx: *mut c_void, + enumerator: *mut *mut c_void, +) -> HRESULT { + let vtable = *(this as *const *const usize); + let method: unsafe extern "system" fn( + *mut c_void, + *const u16, + *const u16, + i32, + *mut c_void, + *mut *mut c_void, + ) -> HRESULT = core::mem::transmute(*vtable.add(20)); + method(this, query_language, query, flags, ctx, enumerator) +} + +unsafe fn enum_next( + this: *mut c_void, + timeout: i32, + count: u32, + objects: *mut *mut c_void, + returned: *mut u32, +) -> HRESULT { + let vtable = *(this as *const *const usize); + let method: unsafe extern "system" fn( + *mut c_void, + i32, + u32, + *mut *mut c_void, + *mut u32, + ) -> HRESULT = core::mem::transmute(*vtable.add(4)); + method(this, timeout, count, objects, returned) +} + +unsafe fn object_begin_enumeration(this: *mut c_void, enum_flags: i32) -> HRESULT { + let vtable = *(this as *const *const usize); + let method: unsafe extern "system" fn(*mut c_void, i32) -> HRESULT = + core::mem::transmute(*vtable.add(8)); + method(this, enum_flags) +} + +unsafe fn object_next( + this: *mut c_void, + flags: i32, + name: *mut *mut u16, + val: *mut VARIANT, + cim_type: *mut i32, + flavor: *mut i32, +) -> HRESULT { + let vtable = *(this as *const *const usize); + let method: unsafe extern "system" fn( + *mut c_void, + i32, + *mut *mut u16, + *mut VARIANT, + *mut i32, + *mut i32, + ) -> HRESULT = core::mem::transmute(*vtable.add(9)); + method(this, flags, name, val, cim_type, flavor) +} + +unsafe fn object_end_enumeration(this: *mut c_void) -> HRESULT { + let vtable = *(this as *const *const usize); + let method: unsafe extern "system" fn(*mut c_void) -> HRESULT = + core::mem::transmute(*vtable.add(10)); + method(this) +} + +fn hresult_from_win32(err: u32) -> HRESULT { + if err == 0 { + 0 + } else { + ((err & 0xFFFF) | 0x80070000) as HRESULT + } +} + +fn succeeded(hr: HRESULT) -> bool { + hr >= 0 +} + +fn failed(hr: HRESULT) 
-> bool { + hr < 0 +} + +fn wide_str(s: &str) -> Vec { + s.encode_utf16().chain(core::iter::once(0)).collect() +} + +unsafe fn wcslen(s: *const u16) -> usize { + let mut len = 0; + while unsafe { *s.add(len) } != 0 { + len += 1; + } + len +} + +unsafe fn wait_event(event: HANDLE, timeout: u32) -> u32 { + match unsafe { WaitForSingleObject(event, timeout) } { + WAIT_OBJECT_0 => 0, + WAIT_TIMEOUT => WAIT_TIMEOUT, + _ => unsafe { GetLastError() }, + } +} + +struct QueryThreadData { + query: Vec, + write_pipe: HANDLE, + init_event: HANDLE, + connect_event: HANDLE, +} + +unsafe impl Send for QueryThreadData {} + +unsafe extern "system" fn query_thread(param: *mut c_void) -> u32 { + unsafe { query_thread_impl(param) } +} + +unsafe fn query_thread_impl(param: *mut c_void) -> u32 { + let data = unsafe { Box::from_raw(param as *mut QueryThreadData) }; + let write_pipe = data.write_pipe; + let init_event = data.init_event; + let connect_event = data.connect_event; + + let mut locator: *mut c_void = null_mut(); + let mut services: *mut c_void = null_mut(); + let mut enumerator: *mut c_void = null_mut(); + let mut hr: HRESULT = 0; + + let bstr_query = unsafe { SysAllocString(data.query.as_ptr()) }; + if bstr_query.is_null() { + hr = hresult_from_win32(ERROR_NOT_ENOUGH_MEMORY); + } + + drop(data); + + if succeeded(hr) { + hr = unsafe { CoInitializeEx(null_mut(), COINIT_APARTMENTTHREADED) }; + } + + if failed(hr) { + unsafe { + CloseHandle(write_pipe); + if !bstr_query.is_null() { + SysFreeString(bstr_query); + } + } + return hr as u32; + } + + hr = unsafe { + CoInitializeSecurity( + null(), + -1, + null(), + null(), + RPC_C_AUTHN_LEVEL_DEFAULT, + RPC_C_IMP_LEVEL_IMPERSONATE, + null(), + EOAC_NONE, + null(), + ) + }; + if hr == RPC_E_TOO_LATE { + hr = 0; + } + + if succeeded(hr) { + hr = unsafe { + CoCreateInstance( + &CLSID_WBEM_LOCATOR, + null_mut(), + CLSCTX_INPROC_SERVER, + &IID_IWBEM_LOCATOR, + &mut locator, + ) + }; + } + if succeeded(hr) && unsafe { SetEvent(init_event) } 
== 0 { + hr = hresult_from_win32(unsafe { GetLastError() }); + } + + if succeeded(hr) { + let root_cimv2 = wide_str("ROOT\\CIMV2"); + let bstr_root = unsafe { SysAllocString(root_cimv2.as_ptr()) }; + hr = unsafe { + locator_connect_server( + locator, + bstr_root, + null(), + null(), + null(), + 0, + null(), + null_mut(), + &mut services, + ) + }; + if !bstr_root.is_null() { + unsafe { SysFreeString(bstr_root) }; + } + } + if succeeded(hr) && unsafe { SetEvent(connect_event) } == 0 { + hr = hresult_from_win32(unsafe { GetLastError() }); + } + + if succeeded(hr) { + hr = unsafe { + CoSetProxyBlanket( + services, + RPC_C_AUTHN_WINNT, + RPC_C_AUTHZ_NONE, + null(), + RPC_C_AUTHN_LEVEL_CALL, + RPC_C_IMP_LEVEL_IMPERSONATE, + null(), + EOAC_NONE, + ) + }; + } + if succeeded(hr) { + let wql = wide_str("WQL"); + let bstr_wql = unsafe { SysAllocString(wql.as_ptr()) }; + hr = unsafe { + services_exec_query( + services, + bstr_wql, + bstr_query, + WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, + null_mut(), + &mut enumerator, + ) + }; + if !bstr_wql.is_null() { + unsafe { SysFreeString(bstr_wql) }; + } + } + + let mut value: *mut c_void; + let mut start_of_enum = true; + let null_sep: u16 = 0; + let eq_sign: u16 = b'=' as u16; + + while succeeded(hr) { + let mut got: u32 = 0; + let mut written: u32 = 0; + value = null_mut(); + hr = unsafe { enum_next(enumerator, WBEM_INFINITE, 1, &mut value, &mut got) }; + + if hr == WBEM_S_FALSE { + hr = 0; + break; + } + if failed(hr) || got != 1 || value.is_null() { + continue; + } + + if !start_of_enum + && unsafe { + WriteFile( + write_pipe, + &null_sep as *const u16 as *const _, + 2, + &mut written, + null_mut(), + ) + } == 0 + { + hr = hresult_from_win32(unsafe { GetLastError() }); + unsafe { com_release(value) }; + break; + } + start_of_enum = false; + + hr = unsafe { object_begin_enumeration(value, 0) }; + if failed(hr) { + unsafe { com_release(value) }; + break; + } + + while succeeded(hr) { + let mut prop_name: *mut u16 = 
null_mut(); + let mut prop_value = VARIANT::zeroed(); + let mut flavor: i32 = 0; + + hr = unsafe { + object_next( + value, + 0, + &mut prop_name, + &mut prop_value, + null_mut(), + &mut flavor, + ) + }; + + if hr == WBEM_S_NO_MORE_DATA { + hr = 0; + break; + } + + if succeeded(hr) && (flavor & WBEM_FLAVOR_MASK_ORIGIN) != WBEM_FLAVOR_ORIGIN_SYSTEM { + let mut prop_str = [0u16; BUFFER_SIZE]; + hr = unsafe { + VariantToString(&prop_value, prop_str.as_mut_ptr(), BUFFER_SIZE as u32) + }; + + if succeeded(hr) { + let cb_str1 = (unsafe { wcslen(prop_name) } * 2) as u32; + let cb_str2 = (unsafe { wcslen(prop_str.as_ptr()) } * 2) as u32; + + if unsafe { + WriteFile( + write_pipe, + prop_name as *const _, + cb_str1, + &mut written, + null_mut(), + ) + } == 0 + || unsafe { + WriteFile( + write_pipe, + &eq_sign as *const u16 as *const _, + 2, + &mut written, + null_mut(), + ) + } == 0 + || unsafe { + WriteFile( + write_pipe, + prop_str.as_ptr() as *const _, + cb_str2, + &mut written, + null_mut(), + ) + } == 0 + || unsafe { + WriteFile( + write_pipe, + &null_sep as *const u16 as *const _, + 2, + &mut written, + null_mut(), + ) + } == 0 + { + hr = hresult_from_win32(unsafe { GetLastError() }); + } + } + + unsafe { + VariantClear(&mut prop_value); + SysFreeString(prop_name); + } + } + } + + unsafe { + object_end_enumeration(value); + com_release(value); + } + } + + unsafe { + if !bstr_query.is_null() { + SysFreeString(bstr_query); + } + if !enumerator.is_null() { + com_release(enumerator); + } + if !services.is_null() { + com_release(services); + } + if !locator.is_null() { + com_release(locator); + } + CoUninitialize(); + CloseHandle(write_pipe); + } + + hr as u32 +} + +pub fn exec_query(query_str: &str) -> Result { + let query_wide = wide_str(query_str); + + let mut h_thread: HANDLE = null_mut(); + let mut err: u32 = 0; + let mut buffer = [0u16; BUFFER_SIZE]; + let mut offset: u32 = 0; + let mut bytes_read: u32 = 0; + + let mut read_pipe: HANDLE = null_mut(); + let mut 
write_pipe: HANDLE = null_mut(); + + unsafe { + let init_event = CreateEventW(null(), 1, 0, null()); + let connect_event = CreateEventW(null(), 1, 0, null()); + + if init_event.is_null() + || connect_event.is_null() + || CreatePipe(&mut read_pipe, &mut write_pipe, null(), 0) == 0 + { + err = GetLastError(); + } else { + let thread_data = Box::new(QueryThreadData { + query: query_wide, + write_pipe, + init_event, + connect_event, + }); + let thread_data_ptr = Box::into_raw(thread_data); + + h_thread = CreateThread( + null(), + 0, + Some(query_thread), + thread_data_ptr as *const _ as *mut _, + 0, + null_mut(), + ); + + if h_thread.is_null() { + err = GetLastError(); + let data = Box::from_raw(thread_data_ptr); + CloseHandle(data.write_pipe); + } + } + + if err == 0 { + err = wait_event(init_event, 1000); + if err == 0 { + err = wait_event(connect_event, 100); + } + } + + while err == 0 { + let buf_ptr = (buffer.as_mut_ptr() as *mut u8).add(offset as usize); + let buf_remaining = (BUFFER_SIZE * 2) as u32 - offset; + + if ReadFile( + read_pipe, + buf_ptr as *mut _, + buf_remaining, + &mut bytes_read, + null_mut(), + ) != 0 + { + offset += bytes_read; + if offset >= (BUFFER_SIZE * 2) as u32 { + err = ERROR_MORE_DATA; + } + } else { + err = GetLastError(); + } + } + + if !read_pipe.is_null() { + CloseHandle(read_pipe); + } + + if !h_thread.is_null() { + let thread_err: u32; + match WaitForSingleObject(h_thread, 100) { + WAIT_OBJECT_0 => { + let mut exit_code: u32 = 0; + if GetExitCodeThread(h_thread, &mut exit_code) == 0 { + thread_err = GetLastError(); + } else { + thread_err = exit_code; + } + } + WAIT_TIMEOUT => { + thread_err = WAIT_TIMEOUT; + } + _ => { + thread_err = GetLastError(); + } + } + if err == 0 || err == ERROR_BROKEN_PIPE { + err = thread_err; + } + + CloseHandle(h_thread); + } + + CloseHandle(init_event); + CloseHandle(connect_event); + } + + if err == ERROR_MORE_DATA { + return Err(ExecQueryError::MoreData); + } + if err != 0 { + return 
Err(ExecQueryError::Code(err)); + } + if offset == 0 { + return Ok(String::new()); + } + + let char_count = (offset as usize) / 2 - 1; + Ok(String::from_utf16_lossy(&buffer[..char_count])) +} diff --git a/crates/jit/tests/bool_tests.rs b/crates/jit/tests/bool_tests.rs index 191993938df..8a5f4ea9db3 100644 --- a/crates/jit/tests/bool_tests.rs +++ b/crates/jit/tests/bool_tests.rs @@ -1,43 +1,45 @@ -#[test] -fn test_return() { - let return_ = jit_function! { return_(a: bool) -> bool => r##" +#[cfg(test)] +mod tests { + #[test] + fn basic_return() { + let return_ = jit_function! { return_(a: bool) -> bool => r##" def return_(a: bool): return a "## }; - assert_eq!(return_(true), Ok(true)); - assert_eq!(return_(false), Ok(false)); -} + assert_eq!(return_(true), Ok(true)); + assert_eq!(return_(false), Ok(false)); + } -#[test] -fn test_const() { - let const_true = jit_function! { const_true(a: i64) -> bool => r##" + #[test] + fn basic_const() { + let const_true = jit_function! { const_true(a: i64) -> bool => r##" def const_true(a: int): return True "## }; - assert_eq!(const_true(0), Ok(true)); + assert_eq!(const_true(0), Ok(true)); - let const_false = jit_function! { const_false(a: i64) -> bool => r##" + let const_false = jit_function! { const_false(a: i64) -> bool => r##" def const_false(a: int): return False "## }; - assert_eq!(const_false(0), Ok(false)); -} + assert_eq!(const_false(0), Ok(false)); + } -#[test] -fn test_not() { - let not_ = jit_function! { not_(a: bool) -> bool => r##" + #[test] + fn basic_not() { + let not_ = jit_function! { not_(a: bool) -> bool => r##" def not_(a: bool): return not a "## }; - assert_eq!(not_(true), Ok(false)); - assert_eq!(not_(false), Ok(true)); -} + assert_eq!(not_(true), Ok(false)); + assert_eq!(not_(false), Ok(true)); + } -#[test] -fn test_if_not() { - let if_not = jit_function! { if_not(a: bool) -> i64 => r##" + #[test] + fn basic_if_not() { + let if_not = jit_function! 
{ if_not(a: bool) -> i64 => r##" def if_not(a: bool): if not a: return 0 @@ -47,156 +49,157 @@ fn test_if_not() { return -1 "## }; - assert_eq!(if_not(true), Ok(1)); - assert_eq!(if_not(false), Ok(0)); -} + assert_eq!(if_not(true), Ok(1)); + assert_eq!(if_not(false), Ok(0)); + } -#[test] -fn test_eq() { - let eq = jit_function! { eq(a:bool, b:bool) -> i64 => r##" + #[test] + fn basic_eq() { + let eq = jit_function! { eq(a:bool, b:bool) -> i64 => r##" def eq(a: bool, b: bool): if a == b: return 1 return 0 "## }; - assert_eq!(eq(false, false), Ok(1)); - assert_eq!(eq(true, true), Ok(1)); - assert_eq!(eq(false, true), Ok(0)); - assert_eq!(eq(true, false), Ok(0)); -} + assert_eq!(eq(false, false), Ok(1)); + assert_eq!(eq(true, true), Ok(1)); + assert_eq!(eq(false, true), Ok(0)); + assert_eq!(eq(true, false), Ok(0)); + } -#[test] -fn test_eq_with_integers() { - let eq = jit_function! { eq(a:bool, b:i64) -> i64 => r##" + #[test] + fn eq_with_integers() { + let eq = jit_function! { eq(a:bool, b:i64) -> i64 => r##" def eq(a: bool, b: int): if a == b: return 1 return 0 "## }; - assert_eq!(eq(false, 0), Ok(1)); - assert_eq!(eq(true, 1), Ok(1)); - assert_eq!(eq(false, 1), Ok(0)); - assert_eq!(eq(true, 0), Ok(0)); -} + assert_eq!(eq(false, 0), Ok(1)); + assert_eq!(eq(true, 1), Ok(1)); + assert_eq!(eq(false, 1), Ok(0)); + assert_eq!(eq(true, 0), Ok(0)); + } -#[test] -fn test_gt() { - let gt = jit_function! { gt(a:bool, b:bool) -> i64 => r##" + #[test] + fn basic_gt() { + let gt = jit_function! { gt(a:bool, b:bool) -> i64 => r##" def gt(a: bool, b: bool): if a > b: return 1 return 0 "## }; - assert_eq!(gt(false, false), Ok(0)); - assert_eq!(gt(true, true), Ok(0)); - assert_eq!(gt(false, true), Ok(0)); - assert_eq!(gt(true, false), Ok(1)); -} + assert_eq!(gt(false, false), Ok(0)); + assert_eq!(gt(true, true), Ok(0)); + assert_eq!(gt(false, true), Ok(0)); + assert_eq!(gt(true, false), Ok(1)); + } -#[test] -fn test_gt_with_integers() { - let gt = jit_function! 
{ gt(a:i64, b:bool) -> i64 => r##" + #[test] + fn gt_with_integers() { + let gt = jit_function! { gt(a:i64, b:bool) -> i64 => r##" def gt(a: int, b: bool): if a > b: return 1 return 0 "## }; - assert_eq!(gt(0, false), Ok(0)); - assert_eq!(gt(1, true), Ok(0)); - assert_eq!(gt(0, true), Ok(0)); - assert_eq!(gt(1, false), Ok(1)); -} + assert_eq!(gt(0, false), Ok(0)); + assert_eq!(gt(1, true), Ok(0)); + assert_eq!(gt(0, true), Ok(0)); + assert_eq!(gt(1, false), Ok(1)); + } -#[test] -fn test_lt() { - let lt = jit_function! { lt(a:bool, b:bool) -> i64 => r##" + #[test] + fn basic_lt() { + let lt = jit_function! { lt(a:bool, b:bool) -> i64 => r##" def lt(a: bool, b: bool): if a < b: return 1 return 0 "## }; - assert_eq!(lt(false, false), Ok(0)); - assert_eq!(lt(true, true), Ok(0)); - assert_eq!(lt(false, true), Ok(1)); - assert_eq!(lt(true, false), Ok(0)); -} + assert_eq!(lt(false, false), Ok(0)); + assert_eq!(lt(true, true), Ok(0)); + assert_eq!(lt(false, true), Ok(1)); + assert_eq!(lt(true, false), Ok(0)); + } -#[test] -fn test_lt_with_integers() { - let lt = jit_function! { lt(a:i64, b:bool) -> i64 => r##" + #[test] + fn lt_with_integers() { + let lt = jit_function! { lt(a:i64, b:bool) -> i64 => r##" def lt(a: int, b: bool): if a < b: return 1 return 0 "## }; - assert_eq!(lt(0, false), Ok(0)); - assert_eq!(lt(1, true), Ok(0)); - assert_eq!(lt(0, true), Ok(1)); - assert_eq!(lt(1, false), Ok(0)); -} + assert_eq!(lt(0, false), Ok(0)); + assert_eq!(lt(1, true), Ok(0)); + assert_eq!(lt(0, true), Ok(1)); + assert_eq!(lt(1, false), Ok(0)); + } -#[test] -fn test_gte() { - let gte = jit_function! { gte(a:bool, b:bool) -> i64 => r##" + #[test] + fn basic_gte() { + let gte = jit_function! 
{ gte(a:bool, b:bool) -> i64 => r##" def gte(a: bool, b: bool): if a >= b: return 1 return 0 "## }; - assert_eq!(gte(false, false), Ok(1)); - assert_eq!(gte(true, true), Ok(1)); - assert_eq!(gte(false, true), Ok(0)); - assert_eq!(gte(true, false), Ok(1)); -} + assert_eq!(gte(false, false), Ok(1)); + assert_eq!(gte(true, true), Ok(1)); + assert_eq!(gte(false, true), Ok(0)); + assert_eq!(gte(true, false), Ok(1)); + } -#[test] -fn test_gte_with_integers() { - let gte = jit_function! { gte(a:bool, b:i64) -> i64 => r##" + #[test] + fn gte_with_integers() { + let gte = jit_function! { gte(a:bool, b:i64) -> i64 => r##" def gte(a: bool, b: int): if a >= b: return 1 return 0 "## }; - assert_eq!(gte(false, 0), Ok(1)); - assert_eq!(gte(true, 1), Ok(1)); - assert_eq!(gte(false, 1), Ok(0)); - assert_eq!(gte(true, 0), Ok(1)); -} + assert_eq!(gte(false, 0), Ok(1)); + assert_eq!(gte(true, 1), Ok(1)); + assert_eq!(gte(false, 1), Ok(0)); + assert_eq!(gte(true, 0), Ok(1)); + } -#[test] -fn test_lte() { - let lte = jit_function! { lte(a:bool, b:bool) -> i64 => r##" + #[test] + fn basic_lte() { + let lte = jit_function! { lte(a:bool, b:bool) -> i64 => r##" def lte(a: bool, b: bool): if a <= b: return 1 return 0 "## }; - assert_eq!(lte(false, false), Ok(1)); - assert_eq!(lte(true, true), Ok(1)); - assert_eq!(lte(false, true), Ok(1)); - assert_eq!(lte(true, false), Ok(0)); -} + assert_eq!(lte(false, false), Ok(1)); + assert_eq!(lte(true, true), Ok(1)); + assert_eq!(lte(false, true), Ok(1)); + assert_eq!(lte(true, false), Ok(0)); + } -#[test] -fn test_lte_with_integers() { - let lte = jit_function! { lte(a:bool, b:i64) -> i64 => r##" + #[test] + fn lte_with_integers() { + let lte = jit_function! 
{ lte(a:bool, b:i64) -> i64 => r##" def lte(a: bool, b: int): if a <= b: return 1 return 0 "## }; - assert_eq!(lte(false, 0), Ok(1)); - assert_eq!(lte(true, 1), Ok(1)); - assert_eq!(lte(false, 1), Ok(1)); - assert_eq!(lte(true, 0), Ok(0)); + assert_eq!(lte(false, 0), Ok(1)); + assert_eq!(lte(true, 1), Ok(1)); + assert_eq!(lte(false, 1), Ok(1)); + assert_eq!(lte(true, 0), Ok(0)); + } } diff --git a/crates/jit/tests/float_tests.rs b/crates/jit/tests/float_tests.rs index b5fcba9fc6a..b9bbb3ea63c 100644 --- a/crates/jit/tests/float_tests.rs +++ b/crates/jit/tests/float_tests.rs @@ -1,379 +1,382 @@ -macro_rules! assert_approx_eq { - ($left:expr, $right:expr) => { - match ($left, $right) { - (Ok(lhs), Ok(rhs)) => approx::assert_relative_eq!(lhs, rhs), - (lhs, rhs) => assert_eq!(lhs, rhs), - } - }; -} - -macro_rules! assert_bits_eq { - ($left:expr, $right:expr) => { - match ($left, $right) { - (Ok(lhs), Ok(rhs)) => assert!(lhs.to_bits() == rhs.to_bits()), - (lhs, rhs) => assert_eq!(lhs, rhs), - } - }; -} - -#[test] -fn test_add() { - let add = jit_function! { add(a:f64, b:f64) -> f64 => r##" +#[cfg(test)] +mod tests { + macro_rules! assert_approx_eq { + ($left:expr, $right:expr) => { + match ($left, $right) { + (Ok(lhs), Ok(rhs)) => approx::assert_relative_eq!(lhs, rhs), + (lhs, rhs) => assert_eq!(lhs, rhs), + } + }; + } + + macro_rules! assert_bits_eq { + ($left:expr, $right:expr) => { + match ($left, $right) { + (Ok(lhs), Ok(rhs)) => assert!(lhs.to_bits() == rhs.to_bits()), + (lhs, rhs) => assert_eq!(lhs, rhs), + } + }; + } + + #[test] + fn basic_add() { + let add = jit_function! 
{ add(a:f64, b:f64) -> f64 => r##" def add(a: float, b: float): return a + b "## }; - assert_approx_eq!(add(5.5, 10.2), Ok(15.7)); - assert_approx_eq!(add(-4.5, 7.6), Ok(3.1)); - assert_approx_eq!(add(-5.2, -3.9), Ok(-9.1)); - assert_bits_eq!(add(-5.2, f64::NAN), Ok(f64::NAN)); - assert_eq!(add(2.0, f64::INFINITY), Ok(f64::INFINITY)); - assert_eq!(add(-2.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); - assert_eq!(add(1.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); -} - -#[test] -fn test_add_with_integer() { - let add = jit_function! { add(a:f64, b:i64) -> f64 => r##" + assert_approx_eq!(add(5.5, 10.2), Ok(15.7)); + assert_approx_eq!(add(-4.5, 7.6), Ok(3.1)); + assert_approx_eq!(add(-5.2, -3.9), Ok(-9.1)); + assert_bits_eq!(add(-5.2, f64::NAN), Ok(f64::NAN)); + assert_eq!(add(2.0, f64::INFINITY), Ok(f64::INFINITY)); + assert_eq!(add(-2.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); + assert_eq!(add(1.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); + } + + #[test] + fn add_with_integer() { + let add = jit_function! { add(a:f64, b:i64) -> f64 => r##" def add(a: float, b: int): return a + b "## }; - assert_approx_eq!(add(5.5, 10), Ok(15.5)); - assert_approx_eq!(add(-4.6, 7), Ok(2.4)); - assert_approx_eq!(add(-5.2, -3), Ok(-8.2)); -} + assert_approx_eq!(add(5.5, 10), Ok(15.5)); + assert_approx_eq!(add(-4.6, 7), Ok(2.4)); + assert_approx_eq!(add(-5.2, -3), Ok(-8.2)); + } -#[test] -fn test_sub() { - let sub = jit_function! { sub(a:f64, b:f64) -> f64 => r##" + #[test] + fn basic_sub() { + let sub = jit_function! 
{ sub(a:f64, b:f64) -> f64 => r##" def sub(a: float, b: float): return a - b "## }; - assert_approx_eq!(sub(5.2, 3.6), Ok(1.6)); - assert_approx_eq!(sub(3.4, 4.2), Ok(-0.8)); - assert_approx_eq!(sub(-2.1, 1.3), Ok(-3.4)); - assert_approx_eq!(sub(3.1, -1.3), Ok(4.4)); - assert_bits_eq!(sub(-5.2, f64::NAN), Ok(f64::NAN)); - assert_eq!(sub(f64::INFINITY, 2.0), Ok(f64::INFINITY)); - assert_eq!(sub(-2.0, f64::NEG_INFINITY), Ok(f64::INFINITY)); - assert_eq!(sub(1.0, f64::INFINITY), Ok(f64::NEG_INFINITY)); -} - -#[test] -fn test_sub_with_integer() { - let sub = jit_function! { sub(a:i64, b:f64) -> f64 => r##" + assert_approx_eq!(sub(5.2, 3.6), Ok(1.6)); + assert_approx_eq!(sub(3.4, 4.2), Ok(-0.8)); + assert_approx_eq!(sub(-2.1, 1.3), Ok(-3.4)); + assert_approx_eq!(sub(3.1, -1.3), Ok(4.4)); + assert_bits_eq!(sub(-5.2, f64::NAN), Ok(f64::NAN)); + assert_eq!(sub(f64::INFINITY, 2.0), Ok(f64::INFINITY)); + assert_eq!(sub(-2.0, f64::NEG_INFINITY), Ok(f64::INFINITY)); + assert_eq!(sub(1.0, f64::INFINITY), Ok(f64::NEG_INFINITY)); + } + + #[test] + fn sub_with_integer() { + let sub = jit_function! { sub(a:i64, b:f64) -> f64 => r##" def sub(a: int, b: float): return a - b "## }; - assert_approx_eq!(sub(5, 3.6), Ok(1.4)); - assert_approx_eq!(sub(3, -4.2), Ok(7.2)); - assert_approx_eq!(sub(-2, 1.3), Ok(-3.3)); - assert_approx_eq!(sub(-3, -1.3), Ok(-1.7)); -} + assert_approx_eq!(sub(5, 3.6), Ok(1.4)); + assert_approx_eq!(sub(3, -4.2), Ok(7.2)); + assert_approx_eq!(sub(-2, 1.3), Ok(-3.3)); + assert_approx_eq!(sub(-3, -1.3), Ok(-1.7)); + } -#[test] -fn test_mul() { - let mul = jit_function! { mul(a:f64, b:f64) -> f64 => r##" + #[test] + fn basic_mul() { + let mul = jit_function! 
{ mul(a:f64, b:f64) -> f64 => r##" def mul(a: float, b: float): return a * b "## }; - assert_approx_eq!(mul(5.2, 2.0), Ok(10.4)); - assert_approx_eq!(mul(3.4, -1.7), Ok(-5.779999999999999)); - assert_bits_eq!(mul(1.0, 0.0), Ok(0.0f64)); - assert_bits_eq!(mul(1.0, -0.0), Ok(-0.0f64)); - assert_bits_eq!(mul(-1.0, 0.0), Ok(-0.0f64)); - assert_bits_eq!(mul(-1.0, -0.0), Ok(0.0f64)); - assert_bits_eq!(mul(-5.2, f64::NAN), Ok(f64::NAN)); - assert_eq!(mul(1.0, f64::INFINITY), Ok(f64::INFINITY)); - assert_eq!(mul(1.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); - assert_eq!(mul(-1.0, f64::INFINITY), Ok(f64::NEG_INFINITY)); - assert!(mul(0.0, f64::INFINITY).unwrap().is_nan()); - assert_eq!(mul(f64::NEG_INFINITY, f64::INFINITY), Ok(f64::NEG_INFINITY)); -} - -#[test] -fn test_mul_with_integer() { - let mul = jit_function! { mul(a:f64, b:i64) -> f64 => r##" + assert_approx_eq!(mul(5.2, 2.0), Ok(10.4)); + assert_approx_eq!(mul(3.4, -1.7), Ok(-5.779999999999999)); + assert_bits_eq!(mul(1.0, 0.0), Ok(0.0f64)); + assert_bits_eq!(mul(1.0, -0.0), Ok(-0.0f64)); + assert_bits_eq!(mul(-1.0, 0.0), Ok(-0.0f64)); + assert_bits_eq!(mul(-1.0, -0.0), Ok(0.0f64)); + assert_bits_eq!(mul(-5.2, f64::NAN), Ok(f64::NAN)); + assert_eq!(mul(1.0, f64::INFINITY), Ok(f64::INFINITY)); + assert_eq!(mul(1.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); + assert_eq!(mul(-1.0, f64::INFINITY), Ok(f64::NEG_INFINITY)); + assert!(mul(0.0, f64::INFINITY).unwrap().is_nan()); + assert_eq!(mul(f64::NEG_INFINITY, f64::INFINITY), Ok(f64::NEG_INFINITY)); + } + + #[test] + fn mul_with_integer() { + let mul = jit_function! 
{ mul(a:f64, b:i64) -> f64 => r##" def mul(a: float, b: int): return a * b "## }; - assert_approx_eq!(mul(5.2, 2), Ok(10.4)); - assert_approx_eq!(mul(3.4, -1), Ok(-3.4)); - assert_bits_eq!(mul(1.0, 0), Ok(0.0f64)); - assert_bits_eq!(mul(-0.0, 1), Ok(-0.0f64)); - assert_bits_eq!(mul(0.0, -1), Ok(-0.0f64)); - assert_bits_eq!(mul(-0.0, -1), Ok(0.0f64)); -} - -#[test] -fn test_power() { - let pow = jit_function! { pow(a:f64, b:f64) -> f64 => r##" + assert_approx_eq!(mul(5.2, 2), Ok(10.4)); + assert_approx_eq!(mul(3.4, -1), Ok(-3.4)); + assert_bits_eq!(mul(1.0, 0), Ok(0.0f64)); + assert_bits_eq!(mul(-0.0, 1), Ok(-0.0f64)); + assert_bits_eq!(mul(0.0, -1), Ok(-0.0f64)); + assert_bits_eq!(mul(-0.0, -1), Ok(0.0f64)); + } + + #[test] + fn basic_power() { + let pow = jit_function! { pow(a:f64, b:f64) -> f64 => r##" def pow(a:float, b: float): return a**b "##}; - // Test base cases - assert_approx_eq!(pow(0.0, 0.0), Ok(1.0)); - assert_approx_eq!(pow(0.0, 1.0), Ok(0.0)); - assert_approx_eq!(pow(1.0, 0.0), Ok(1.0)); - assert_approx_eq!(pow(1.0, 1.0), Ok(1.0)); - assert_approx_eq!(pow(1.0, -1.0), Ok(1.0)); - assert_approx_eq!(pow(-1.0, 0.0), Ok(1.0)); - assert_approx_eq!(pow(-1.0, 1.0), Ok(-1.0)); - assert_approx_eq!(pow(-1.0, -1.0), Ok(-1.0)); - - // NaN and Infinity cases - assert_approx_eq!(pow(f64::NAN, 0.0), Ok(1.0)); - //assert_approx_eq!(pow(f64::NAN, 1.0), Ok(f64::NAN)); // Return the correct answer but fails compare - //assert_approx_eq!(pow(0.0, f64::NAN), Ok(f64::NAN)); // Return the correct answer but fails compare - assert_approx_eq!(pow(f64::INFINITY, 0.0), Ok(1.0)); - assert_approx_eq!(pow(f64::INFINITY, 1.0), Ok(f64::INFINITY)); - assert_approx_eq!(pow(f64::INFINITY, f64::INFINITY), Ok(f64::INFINITY)); - // Negative infinity cases: - // For any exponent of 0.0, the result is 1.0. - assert_approx_eq!(pow(f64::NEG_INFINITY, 0.0), Ok(1.0)); - // For negative infinity base, when b is an odd integer, result is -infinity; - // when b is even, result is +infinity. 
- assert_approx_eq!(pow(f64::NEG_INFINITY, 1.0), Ok(f64::NEG_INFINITY)); - assert_approx_eq!(pow(f64::NEG_INFINITY, 2.0), Ok(f64::INFINITY)); - assert_approx_eq!(pow(f64::NEG_INFINITY, 3.0), Ok(f64::NEG_INFINITY)); - // Exponent -infinity gives 0.0. - assert_approx_eq!(pow(f64::NEG_INFINITY, f64::NEG_INFINITY), Ok(0.0)); - - // Test positive float base, positive float exponent - assert_approx_eq!(pow(2.0, 2.0), Ok(4.0)); - assert_approx_eq!(pow(3.0, 3.0), Ok(27.0)); - assert_approx_eq!(pow(4.0, 4.0), Ok(256.0)); - assert_approx_eq!(pow(2.0, 3.0), Ok(8.0)); - assert_approx_eq!(pow(2.0, 4.0), Ok(16.0)); - // Test negative float base, positive float exponent (integral exponents only) - assert_approx_eq!(pow(-2.0, 2.0), Ok(4.0)); - assert_approx_eq!(pow(-3.0, 3.0), Ok(-27.0)); - assert_approx_eq!(pow(-4.0, 4.0), Ok(256.0)); - assert_approx_eq!(pow(-2.0, 3.0), Ok(-8.0)); - assert_approx_eq!(pow(-2.0, 4.0), Ok(16.0)); - // Test positive float base, positive float exponent - assert_approx_eq!(pow(2.5, 2.0), Ok(6.25)); - assert_approx_eq!(pow(3.5, 3.0), Ok(42.875)); - assert_approx_eq!(pow(4.5, 4.0), Ok(410.0625)); - assert_approx_eq!(pow(2.5, 3.0), Ok(15.625)); - assert_approx_eq!(pow(2.5, 4.0), Ok(39.0625)); - // Test negative float base, positive float exponent (integral exponents only) - assert_approx_eq!(pow(-2.5, 2.0), Ok(6.25)); - assert_approx_eq!(pow(-3.5, 3.0), Ok(-42.875)); - assert_approx_eq!(pow(-4.5, 4.0), Ok(410.0625)); - assert_approx_eq!(pow(-2.5, 3.0), Ok(-15.625)); - assert_approx_eq!(pow(-2.5, 4.0), Ok(39.0625)); - // Test positive float base, positive float exponent with non-integral exponents - assert_approx_eq!(pow(2.0, 2.5), Ok(5.656854249492381)); - assert_approx_eq!(pow(3.0, 3.5), Ok(46.76537180435969)); - assert_approx_eq!(pow(4.0, 4.5), Ok(512.0)); - assert_approx_eq!(pow(2.0, 3.5), Ok(11.313708498984761)); - assert_approx_eq!(pow(2.0, 4.5), Ok(22.627416997969522)); - // Test positive float base, negative float exponent - 
assert_approx_eq!(pow(2.0, -2.5), Ok(0.1767766952966369)); - assert_approx_eq!(pow(3.0, -3.5), Ok(0.021383343303319473)); - assert_approx_eq!(pow(4.0, -4.5), Ok(0.001953125)); - assert_approx_eq!(pow(2.0, -3.5), Ok(0.08838834764831845)); - assert_approx_eq!(pow(2.0, -4.5), Ok(0.04419417382415922)); - // Test negative float base, negative float exponent (integral exponents only) - assert_approx_eq!(pow(-2.0, -2.0), Ok(0.25)); - assert_approx_eq!(pow(-3.0, -3.0), Ok(-0.037037037037037035)); - assert_approx_eq!(pow(-4.0, -4.0), Ok(0.00390625)); - assert_approx_eq!(pow(-2.0, -3.0), Ok(-0.125)); - assert_approx_eq!(pow(-2.0, -4.0), Ok(0.0625)); - - // Currently negative float base with non-integral exponent is not supported: - // assert_approx_eq!(pow(-2.0, 2.5), Ok(5.656854249492381)); - // assert_approx_eq!(pow(-3.0, 3.5), Ok(-46.76537180435969)); - // assert_approx_eq!(pow(-4.0, 4.5), Ok(512.0)); - // assert_approx_eq!(pow(-2.0, -2.5), Ok(0.1767766952966369)); - // assert_approx_eq!(pow(-3.0, -3.5), Ok(0.021383343303319473)); - // assert_approx_eq!(pow(-4.0, -4.5), Ok(0.001953125)); - - // Extra cases **NOTE** these are not all working: - // * If they are commented in then they work - // * If they are commented out with a number that is the current return value it throws vs the expected value - // * If they are commented out with a "fail to run" that means I couldn't get them to work, could add a case for really big or small values - // 1e308^2.0 - assert_approx_eq!(pow(1e308, 2.0), Ok(f64::INFINITY)); - // 1e308^(1e-2) - assert_approx_eq!(pow(1e308, 1e-2), Ok(1202.2644346174131)); - // 1e-308^2.0 - //assert_approx_eq!(pow(1e-308, 2.0), Ok(0.0)); // --8.403311421507407 - // 1e-308^-2.0 - assert_approx_eq!(pow(1e-308, -2.0), Ok(f64::INFINITY)); - // 1e100^(1e50) - //assert_approx_eq!(pow(1e100, 1e50), Ok(1.0000000000000002e+150)); // fail to run (Crashes as "illegal hardware instruction") - // 1e50^(1e-100) - assert_approx_eq!(pow(1e50, 1e-100), Ok(1.0)); - // 
1e308^(-1e2) - //assert_approx_eq!(pow(1e308, -1e2), Ok(0.0)); // 2.961801792837933e25 - // 1e-308^(1e2) - //assert_approx_eq!(pow(1e-308, 1e2), Ok(f64::INFINITY)); // 1.6692559244043896e46 - // 1e308^(-1e308) - // assert_approx_eq!(pow(1e308, -1e308), Ok(0.0)); // fail to run (Crashes as "illegal hardware instruction") - // 1e-308^(1e308) - // assert_approx_eq!(pow(1e-308, 1e308), Ok(0.0)); // fail to run (Crashes as "illegal hardware instruction") -} - -#[test] -fn test_div() { - let div = jit_function! { div(a:f64, b:f64) -> f64 => r##" + // Test base cases + assert_approx_eq!(pow(0.0, 0.0), Ok(1.0)); + assert_approx_eq!(pow(0.0, 1.0), Ok(0.0)); + assert_approx_eq!(pow(1.0, 0.0), Ok(1.0)); + assert_approx_eq!(pow(1.0, 1.0), Ok(1.0)); + assert_approx_eq!(pow(1.0, -1.0), Ok(1.0)); + assert_approx_eq!(pow(-1.0, 0.0), Ok(1.0)); + assert_approx_eq!(pow(-1.0, 1.0), Ok(-1.0)); + assert_approx_eq!(pow(-1.0, -1.0), Ok(-1.0)); + + // NaN and Infinity cases + assert_approx_eq!(pow(f64::NAN, 0.0), Ok(1.0)); + //assert_approx_eq!(pow(f64::NAN, 1.0), Ok(f64::NAN)); // Return the correct answer but fails compare + //assert_approx_eq!(pow(0.0, f64::NAN), Ok(f64::NAN)); // Return the correct answer but fails compare + assert_approx_eq!(pow(f64::INFINITY, 0.0), Ok(1.0)); + assert_approx_eq!(pow(f64::INFINITY, 1.0), Ok(f64::INFINITY)); + assert_approx_eq!(pow(f64::INFINITY, f64::INFINITY), Ok(f64::INFINITY)); + // Negative infinity cases: + // For any exponent of 0.0, the result is 1.0. + assert_approx_eq!(pow(f64::NEG_INFINITY, 0.0), Ok(1.0)); + // For negative infinity base, when b is an odd integer, result is -infinity; + // when b is even, result is +infinity. + assert_approx_eq!(pow(f64::NEG_INFINITY, 1.0), Ok(f64::NEG_INFINITY)); + assert_approx_eq!(pow(f64::NEG_INFINITY, 2.0), Ok(f64::INFINITY)); + assert_approx_eq!(pow(f64::NEG_INFINITY, 3.0), Ok(f64::NEG_INFINITY)); + // Exponent -infinity gives 0.0. 
+ assert_approx_eq!(pow(f64::NEG_INFINITY, f64::NEG_INFINITY), Ok(0.0)); + + // Test positive float base, positive float exponent + assert_approx_eq!(pow(2.0, 2.0), Ok(4.0)); + assert_approx_eq!(pow(3.0, 3.0), Ok(27.0)); + assert_approx_eq!(pow(4.0, 4.0), Ok(256.0)); + assert_approx_eq!(pow(2.0, 3.0), Ok(8.0)); + assert_approx_eq!(pow(2.0, 4.0), Ok(16.0)); + // Test negative float base, positive float exponent (integral exponents only) + assert_approx_eq!(pow(-2.0, 2.0), Ok(4.0)); + assert_approx_eq!(pow(-3.0, 3.0), Ok(-27.0)); + assert_approx_eq!(pow(-4.0, 4.0), Ok(256.0)); + assert_approx_eq!(pow(-2.0, 3.0), Ok(-8.0)); + assert_approx_eq!(pow(-2.0, 4.0), Ok(16.0)); + // Test positive float base, positive float exponent + assert_approx_eq!(pow(2.5, 2.0), Ok(6.25)); + assert_approx_eq!(pow(3.5, 3.0), Ok(42.875)); + assert_approx_eq!(pow(4.5, 4.0), Ok(410.0625)); + assert_approx_eq!(pow(2.5, 3.0), Ok(15.625)); + assert_approx_eq!(pow(2.5, 4.0), Ok(39.0625)); + // Test negative float base, positive float exponent (integral exponents only) + assert_approx_eq!(pow(-2.5, 2.0), Ok(6.25)); + assert_approx_eq!(pow(-3.5, 3.0), Ok(-42.875)); + assert_approx_eq!(pow(-4.5, 4.0), Ok(410.0625)); + assert_approx_eq!(pow(-2.5, 3.0), Ok(-15.625)); + assert_approx_eq!(pow(-2.5, 4.0), Ok(39.0625)); + // Test positive float base, positive float exponent with non-integral exponents + assert_approx_eq!(pow(2.0, 2.5), Ok(5.656854249492381)); + assert_approx_eq!(pow(3.0, 3.5), Ok(46.76537180435969)); + assert_approx_eq!(pow(4.0, 4.5), Ok(512.0)); + assert_approx_eq!(pow(2.0, 3.5), Ok(11.313708498984761)); + assert_approx_eq!(pow(2.0, 4.5), Ok(22.627416997969522)); + // Test positive float base, negative float exponent + assert_approx_eq!(pow(2.0, -2.5), Ok(0.1767766952966369)); + assert_approx_eq!(pow(3.0, -3.5), Ok(0.021383343303319473)); + assert_approx_eq!(pow(4.0, -4.5), Ok(0.001953125)); + assert_approx_eq!(pow(2.0, -3.5), Ok(0.08838834764831845)); + assert_approx_eq!(pow(2.0, -4.5), 
Ok(0.04419417382415922)); + // Test negative float base, negative float exponent (integral exponents only) + assert_approx_eq!(pow(-2.0, -2.0), Ok(0.25)); + assert_approx_eq!(pow(-3.0, -3.0), Ok(-0.037037037037037035)); + assert_approx_eq!(pow(-4.0, -4.0), Ok(0.00390625)); + assert_approx_eq!(pow(-2.0, -3.0), Ok(-0.125)); + assert_approx_eq!(pow(-2.0, -4.0), Ok(0.0625)); + + // Currently negative float base with non-integral exponent is not supported: + // assert_approx_eq!(pow(-2.0, 2.5), Ok(5.656854249492381)); + // assert_approx_eq!(pow(-3.0, 3.5), Ok(-46.76537180435969)); + // assert_approx_eq!(pow(-4.0, 4.5), Ok(512.0)); + // assert_approx_eq!(pow(-2.0, -2.5), Ok(0.1767766952966369)); + // assert_approx_eq!(pow(-3.0, -3.5), Ok(0.021383343303319473)); + // assert_approx_eq!(pow(-4.0, -4.5), Ok(0.001953125)); + + // Extra cases **NOTE** these are not all working: + // * If they are commented in then they work + // * If they are commented out with a number that is the current return value it throws vs the expected value + // * If they are commented out with a "fail to run" that means I couldn't get them to work, could add a case for really big or small values + // 1e308^2.0 + assert_approx_eq!(pow(1e308, 2.0), Ok(f64::INFINITY)); + // 1e308^(1e-2) + assert_approx_eq!(pow(1e308, 1e-2), Ok(1202.2644346174131)); + // 1e-308^2.0 + //assert_approx_eq!(pow(1e-308, 2.0), Ok(0.0)); // --8.403311421507407 + // 1e-308^-2.0 + assert_approx_eq!(pow(1e-308, -2.0), Ok(f64::INFINITY)); + // 1e100^(1e50) + //assert_approx_eq!(pow(1e100, 1e50), Ok(1.0000000000000002e+150)); // fail to run (Crashes as "illegal hardware instruction") + // 1e50^(1e-100) + assert_approx_eq!(pow(1e50, 1e-100), Ok(1.0)); + // 1e308^(-1e2) + //assert_approx_eq!(pow(1e308, -1e2), Ok(0.0)); // 2.961801792837933e25 + // 1e-308^(1e2) + //assert_approx_eq!(pow(1e-308, 1e2), Ok(f64::INFINITY)); // 1.6692559244043896e46 + // 1e308^(-1e308) + // assert_approx_eq!(pow(1e308, -1e308), Ok(0.0)); // fail to run 
(Crashes as "illegal hardware instruction") + // 1e-308^(1e308) + // assert_approx_eq!(pow(1e-308, 1e308), Ok(0.0)); // fail to run (Crashes as "illegal hardware instruction") + } + + #[test] + fn basic_div() { + let div = jit_function! { div(a:f64, b:f64) -> f64 => r##" def div(a: float, b: float): return a / b "## }; - assert_approx_eq!(div(5.2, 2.0), Ok(2.6)); - assert_approx_eq!(div(3.4, -1.7), Ok(-2.0)); - assert_eq!(div(1.0, 0.0), Ok(f64::INFINITY)); - assert_eq!(div(1.0, -0.0), Ok(f64::NEG_INFINITY)); - assert_eq!(div(-1.0, 0.0), Ok(f64::NEG_INFINITY)); - assert_eq!(div(-1.0, -0.0), Ok(f64::INFINITY)); - assert_bits_eq!(div(-5.2, f64::NAN), Ok(f64::NAN)); - assert_eq!(div(f64::INFINITY, 2.0), Ok(f64::INFINITY)); - assert_bits_eq!(div(-2.0, f64::NEG_INFINITY), Ok(0.0f64)); - assert_bits_eq!(div(1.0, f64::INFINITY), Ok(0.0f64)); - assert_bits_eq!(div(2.0, f64::NEG_INFINITY), Ok(-0.0f64)); - assert_bits_eq!(div(-1.0, f64::INFINITY), Ok(-0.0f64)); -} - -#[test] -fn test_div_with_integer() { - let div = jit_function! { div(a:f64, b:i64) -> f64 => r##" + assert_approx_eq!(div(5.2, 2.0), Ok(2.6)); + assert_approx_eq!(div(3.4, -1.7), Ok(-2.0)); + assert_eq!(div(1.0, 0.0), Ok(f64::INFINITY)); + assert_eq!(div(1.0, -0.0), Ok(f64::NEG_INFINITY)); + assert_eq!(div(-1.0, 0.0), Ok(f64::NEG_INFINITY)); + assert_eq!(div(-1.0, -0.0), Ok(f64::INFINITY)); + assert_bits_eq!(div(-5.2, f64::NAN), Ok(f64::NAN)); + assert_eq!(div(f64::INFINITY, 2.0), Ok(f64::INFINITY)); + assert_bits_eq!(div(-2.0, f64::NEG_INFINITY), Ok(0.0f64)); + assert_bits_eq!(div(1.0, f64::INFINITY), Ok(0.0f64)); + assert_bits_eq!(div(2.0, f64::NEG_INFINITY), Ok(-0.0f64)); + assert_bits_eq!(div(-1.0, f64::INFINITY), Ok(-0.0f64)); + } + + #[test] + fn div_with_integer() { + let div = jit_function! 
{ div(a:f64, b:i64) -> f64 => r##" def div(a: float, b: int): return a / b "## }; - assert_approx_eq!(div(5.2, 2), Ok(2.6)); - assert_approx_eq!(div(3.4, -1), Ok(-3.4)); - assert_eq!(div(1.0, 0), Ok(f64::INFINITY)); - assert_eq!(div(1.0, -0), Ok(f64::INFINITY)); - assert_eq!(div(-1.0, 0), Ok(f64::NEG_INFINITY)); - assert_eq!(div(-1.0, -0), Ok(f64::NEG_INFINITY)); - assert_eq!(div(f64::INFINITY, 2), Ok(f64::INFINITY)); - assert_eq!(div(f64::NEG_INFINITY, 3), Ok(f64::NEG_INFINITY)); -} - -#[test] -fn test_if_bool() { - let if_bool = jit_function! { if_bool(a:f64) -> i64 => r##" + assert_approx_eq!(div(5.2, 2), Ok(2.6)); + assert_approx_eq!(div(3.4, -1), Ok(-3.4)); + assert_eq!(div(1.0, 0), Ok(f64::INFINITY)); + assert_eq!(div(1.0, -0), Ok(f64::INFINITY)); + assert_eq!(div(-1.0, 0), Ok(f64::NEG_INFINITY)); + assert_eq!(div(-1.0, -0), Ok(f64::NEG_INFINITY)); + assert_eq!(div(f64::INFINITY, 2), Ok(f64::INFINITY)); + assert_eq!(div(f64::NEG_INFINITY, 3), Ok(f64::NEG_INFINITY)); + } + + #[test] + fn basic_if_bool() { + let if_bool = jit_function! { if_bool(a:f64) -> i64 => r##" def if_bool(a: float): if a: return 1 return 0 "## }; - assert_eq!(if_bool(5.2), Ok(1)); - assert_eq!(if_bool(-3.4), Ok(1)); - assert_eq!(if_bool(f64::NAN), Ok(1)); - assert_eq!(if_bool(f64::INFINITY), Ok(1)); + assert_eq!(if_bool(5.2), Ok(1)); + assert_eq!(if_bool(-3.4), Ok(1)); + assert_eq!(if_bool(f64::NAN), Ok(1)); + assert_eq!(if_bool(f64::INFINITY), Ok(1)); - assert_eq!(if_bool(0.0), Ok(0)); -} + assert_eq!(if_bool(0.0), Ok(0)); + } -#[test] -fn test_float_eq() { - let float_eq = jit_function! { float_eq(a: f64, b: f64) -> bool => r##" + #[test] + fn basic_float_eq() { + let float_eq = jit_function! 
{ float_eq(a: f64, b: f64) -> bool => r##" def float_eq(a: float, b: float): return a == b "## }; - assert_eq!(float_eq(2.0, 2.0), Ok(true)); - assert_eq!(float_eq(3.4, -1.7), Ok(false)); - assert_eq!(float_eq(0.0, 0.0), Ok(true)); - assert_eq!(float_eq(-0.0, -0.0), Ok(true)); - assert_eq!(float_eq(-0.0, 0.0), Ok(true)); - assert_eq!(float_eq(-5.2, f64::NAN), Ok(false)); - assert_eq!(float_eq(f64::NAN, f64::NAN), Ok(false)); - assert_eq!(float_eq(f64::INFINITY, f64::NEG_INFINITY), Ok(false)); -} - -#[test] -fn test_float_ne() { - let float_ne = jit_function! { float_ne(a: f64, b: f64) -> bool => r##" + assert_eq!(float_eq(2.0, 2.0), Ok(true)); + assert_eq!(float_eq(3.4, -1.7), Ok(false)); + assert_eq!(float_eq(0.0, 0.0), Ok(true)); + assert_eq!(float_eq(-0.0, -0.0), Ok(true)); + assert_eq!(float_eq(-0.0, 0.0), Ok(true)); + assert_eq!(float_eq(-5.2, f64::NAN), Ok(false)); + assert_eq!(float_eq(f64::NAN, f64::NAN), Ok(false)); + assert_eq!(float_eq(f64::INFINITY, f64::NEG_INFINITY), Ok(false)); + } + + #[test] + fn basic_float_ne() { + let float_ne = jit_function! { float_ne(a: f64, b: f64) -> bool => r##" def float_ne(a: float, b: float): return a != b "## }; - assert_eq!(float_ne(2.0, 2.0), Ok(false)); - assert_eq!(float_ne(3.4, -1.7), Ok(true)); - assert_eq!(float_ne(0.0, 0.0), Ok(false)); - assert_eq!(float_ne(-0.0, -0.0), Ok(false)); - assert_eq!(float_ne(-0.0, 0.0), Ok(false)); - assert_eq!(float_ne(-5.2, f64::NAN), Ok(true)); - assert_eq!(float_ne(f64::NAN, f64::NAN), Ok(true)); - assert_eq!(float_ne(f64::INFINITY, f64::NEG_INFINITY), Ok(true)); -} - -#[test] -fn test_float_gt() { - let float_gt = jit_function! 
{ float_gt(a: f64, b: f64) -> bool => r##" + assert_eq!(float_ne(2.0, 2.0), Ok(false)); + assert_eq!(float_ne(3.4, -1.7), Ok(true)); + assert_eq!(float_ne(0.0, 0.0), Ok(false)); + assert_eq!(float_ne(-0.0, -0.0), Ok(false)); + assert_eq!(float_ne(-0.0, 0.0), Ok(false)); + assert_eq!(float_ne(-5.2, f64::NAN), Ok(true)); + assert_eq!(float_ne(f64::NAN, f64::NAN), Ok(true)); + assert_eq!(float_ne(f64::INFINITY, f64::NEG_INFINITY), Ok(true)); + } + + #[test] + fn basic_float_gt() { + let float_gt = jit_function! { float_gt(a: f64, b: f64) -> bool => r##" def float_gt(a: float, b: float): return a > b "## }; - assert_eq!(float_gt(2.0, 2.0), Ok(false)); - assert_eq!(float_gt(3.4, -1.7), Ok(true)); - assert_eq!(float_gt(0.0, 0.0), Ok(false)); - assert_eq!(float_gt(-0.0, -0.0), Ok(false)); - assert_eq!(float_gt(-0.0, 0.0), Ok(false)); - assert_eq!(float_gt(-5.2, f64::NAN), Ok(false)); - assert_eq!(float_gt(f64::NAN, f64::NAN), Ok(false)); - assert_eq!(float_gt(f64::INFINITY, f64::NEG_INFINITY), Ok(true)); -} - -#[test] -fn test_float_gte() { - let float_gte = jit_function! { float_gte(a: f64, b: f64) -> bool => r##" + assert_eq!(float_gt(2.0, 2.0), Ok(false)); + assert_eq!(float_gt(3.4, -1.7), Ok(true)); + assert_eq!(float_gt(0.0, 0.0), Ok(false)); + assert_eq!(float_gt(-0.0, -0.0), Ok(false)); + assert_eq!(float_gt(-0.0, 0.0), Ok(false)); + assert_eq!(float_gt(-5.2, f64::NAN), Ok(false)); + assert_eq!(float_gt(f64::NAN, f64::NAN), Ok(false)); + assert_eq!(float_gt(f64::INFINITY, f64::NEG_INFINITY), Ok(true)); + } + + #[test] + fn basic_float_gte() { + let float_gte = jit_function! 
{ float_gte(a: f64, b: f64) -> bool => r##" def float_gte(a: float, b: float): return a >= b "## }; - assert_eq!(float_gte(2.0, 2.0), Ok(true)); - assert_eq!(float_gte(3.4, -1.7), Ok(true)); - assert_eq!(float_gte(0.0, 0.0), Ok(true)); - assert_eq!(float_gte(-0.0, -0.0), Ok(true)); - assert_eq!(float_gte(-0.0, 0.0), Ok(true)); - assert_eq!(float_gte(-5.2, f64::NAN), Ok(false)); - assert_eq!(float_gte(f64::NAN, f64::NAN), Ok(false)); - assert_eq!(float_gte(f64::INFINITY, f64::NEG_INFINITY), Ok(true)); -} - -#[test] -fn test_float_lt() { - let float_lt = jit_function! { float_lt(a: f64, b: f64) -> bool => r##" + assert_eq!(float_gte(2.0, 2.0), Ok(true)); + assert_eq!(float_gte(3.4, -1.7), Ok(true)); + assert_eq!(float_gte(0.0, 0.0), Ok(true)); + assert_eq!(float_gte(-0.0, -0.0), Ok(true)); + assert_eq!(float_gte(-0.0, 0.0), Ok(true)); + assert_eq!(float_gte(-5.2, f64::NAN), Ok(false)); + assert_eq!(float_gte(f64::NAN, f64::NAN), Ok(false)); + assert_eq!(float_gte(f64::INFINITY, f64::NEG_INFINITY), Ok(true)); + } + + #[test] + fn basic_float_lt() { + let float_lt = jit_function! { float_lt(a: f64, b: f64) -> bool => r##" def float_lt(a: float, b: float): return a < b "## }; - assert_eq!(float_lt(2.0, 2.0), Ok(false)); - assert_eq!(float_lt(3.4, -1.7), Ok(false)); - assert_eq!(float_lt(0.0, 0.0), Ok(false)); - assert_eq!(float_lt(-0.0, -0.0), Ok(false)); - assert_eq!(float_lt(-0.0, 0.0), Ok(false)); - assert_eq!(float_lt(-5.2, f64::NAN), Ok(false)); - assert_eq!(float_lt(f64::NAN, f64::NAN), Ok(false)); - assert_eq!(float_lt(f64::INFINITY, f64::NEG_INFINITY), Ok(false)); -} - -#[test] -fn test_float_lte() { - let float_lte = jit_function! 
{ float_lte(a: f64, b: f64) -> bool => r##" + assert_eq!(float_lt(2.0, 2.0), Ok(false)); + assert_eq!(float_lt(3.4, -1.7), Ok(false)); + assert_eq!(float_lt(0.0, 0.0), Ok(false)); + assert_eq!(float_lt(-0.0, -0.0), Ok(false)); + assert_eq!(float_lt(-0.0, 0.0), Ok(false)); + assert_eq!(float_lt(-5.2, f64::NAN), Ok(false)); + assert_eq!(float_lt(f64::NAN, f64::NAN), Ok(false)); + assert_eq!(float_lt(f64::INFINITY, f64::NEG_INFINITY), Ok(false)); + } + + #[test] + fn basic_float_lte() { + let float_lte = jit_function! { float_lte(a: f64, b: f64) -> bool => r##" def float_lte(a: float, b: float): return a <= b "## }; - assert_eq!(float_lte(2.0, 2.0), Ok(true)); - assert_eq!(float_lte(3.4, -1.7), Ok(false)); - assert_eq!(float_lte(0.0, 0.0), Ok(true)); - assert_eq!(float_lte(-0.0, -0.0), Ok(true)); - assert_eq!(float_lte(-0.0, 0.0), Ok(true)); - assert_eq!(float_lte(-5.2, f64::NAN), Ok(false)); - assert_eq!(float_lte(f64::NAN, f64::NAN), Ok(false)); - assert_eq!(float_lte(f64::INFINITY, f64::NEG_INFINITY), Ok(false)); + assert_eq!(float_lte(2.0, 2.0), Ok(true)); + assert_eq!(float_lte(3.4, -1.7), Ok(false)); + assert_eq!(float_lte(0.0, 0.0), Ok(true)); + assert_eq!(float_lte(-0.0, -0.0), Ok(true)); + assert_eq!(float_lte(-0.0, 0.0), Ok(true)); + assert_eq!(float_lte(-5.2, f64::NAN), Ok(false)); + assert_eq!(float_lte(f64::NAN, f64::NAN), Ok(false)); + assert_eq!(float_lte(f64::INFINITY, f64::NEG_INFINITY), Ok(false)); + } } diff --git a/crates/jit/tests/int_tests.rs b/crates/jit/tests/int_tests.rs index 5ab2697e075..23cf98aafe1 100644 --- a/crates/jit/tests/int_tests.rs +++ b/crates/jit/tests/int_tests.rs @@ -1,326 +1,329 @@ -use core::f64; +#[cfg(test)] +mod tests { + use core::f64; -#[test] -fn test_add() { - let add = jit_function! { add(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_add() { + let add = jit_function! 
{ add(a:i64, b:i64) -> i64 => r##" def add(a: int, b: int): return a + b "## }; - assert_eq!(add(5, 10), Ok(15)); - assert_eq!(add(-5, 12), Ok(7)); - assert_eq!(add(-5, -3), Ok(-8)); -} + assert_eq!(add(5, 10), Ok(15)); + assert_eq!(add(-5, 12), Ok(7)); + assert_eq!(add(-5, -3), Ok(-8)); + } -#[test] -fn test_sub() { - let sub = jit_function! { sub(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_sub() { + let sub = jit_function! { sub(a:i64, b:i64) -> i64 => r##" def sub(a: int, b: int): return a - b "## }; - assert_eq!(sub(5, 10), Ok(-5)); - assert_eq!(sub(12, 10), Ok(2)); - assert_eq!(sub(7, 10), Ok(-3)); - assert_eq!(sub(-3, -10), Ok(7)); -} + assert_eq!(sub(5, 10), Ok(-5)); + assert_eq!(sub(12, 10), Ok(2)); + assert_eq!(sub(7, 10), Ok(-3)); + assert_eq!(sub(-3, -10), Ok(7)); + } -#[test] -fn test_mul() { - let mul = jit_function! { mul(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_mul() { + let mul = jit_function! { mul(a:i64, b:i64) -> i64 => r##" def mul(a: int, b: int): return a * b "## }; - assert_eq!(mul(5, 10), Ok(50)); - assert_eq!(mul(0, 5), Ok(0)); - assert_eq!(mul(5, 0), Ok(0)); - assert_eq!(mul(0, 0), Ok(0)); - assert_eq!(mul(-5, 10), Ok(-50)); - assert_eq!(mul(5, -10), Ok(-50)); - assert_eq!(mul(-5, -10), Ok(50)); - assert_eq!(mul(999999, 999999), Ok(999998000001)); - assert_eq!(mul(i64::MAX, 1), Ok(i64::MAX)); - assert_eq!(mul(1, i64::MAX), Ok(i64::MAX)); -} - -#[test] -fn test_div() { - let div = jit_function! { div(a:i64, b:i64) -> f64 => r##" + assert_eq!(mul(5, 10), Ok(50)); + assert_eq!(mul(0, 5), Ok(0)); + assert_eq!(mul(5, 0), Ok(0)); + assert_eq!(mul(0, 0), Ok(0)); + assert_eq!(mul(-5, 10), Ok(-50)); + assert_eq!(mul(5, -10), Ok(-50)); + assert_eq!(mul(-5, -10), Ok(50)); + assert_eq!(mul(999999, 999999), Ok(999998000001)); + assert_eq!(mul(i64::MAX, 1), Ok(i64::MAX)); + assert_eq!(mul(1, i64::MAX), Ok(i64::MAX)); + } + + #[test] + fn basic_div() { + let div = jit_function! 
{ div(a:i64, b:i64) -> f64 => r##" def div(a: int, b: int): return a / b "## }; - assert_eq!(div(0, 1), Ok(0.0)); - assert_eq!(div(5, 1), Ok(5.0)); - assert_eq!(div(5, 10), Ok(0.5)); - assert_eq!(div(5, 2), Ok(2.5)); - assert_eq!(div(12, 10), Ok(1.2)); - assert_eq!(div(7, 10), Ok(0.7)); - assert_eq!(div(-3, -1), Ok(3.0)); - assert_eq!(div(-3, 1), Ok(-3.0)); - assert_eq!(div(1, 1000), Ok(0.001)); - assert_eq!(div(1, 100000), Ok(0.00001)); - assert_eq!(div(2, 3), Ok(0.6666666666666666)); - assert_eq!(div(1, 3), Ok(0.3333333333333333)); - assert_eq!(div(i64::MAX, 2), Ok(4611686018427387904.0)); - assert_eq!(div(i64::MIN, 2), Ok(-4611686018427387904.0)); - assert_eq!(div(i64::MIN, -1), Ok(9223372036854775808.0)); // Overflow case - assert_eq!(div(i64::MIN, i64::MAX), Ok(-1.0)); -} - -#[test] -fn test_floor_div() { - let floor_div = jit_function! { floor_div(a:i64, b:i64) -> i64 => r##" + assert_eq!(div(0, 1), Ok(0.0)); + assert_eq!(div(5, 1), Ok(5.0)); + assert_eq!(div(5, 10), Ok(0.5)); + assert_eq!(div(5, 2), Ok(2.5)); + assert_eq!(div(12, 10), Ok(1.2)); + assert_eq!(div(7, 10), Ok(0.7)); + assert_eq!(div(-3, -1), Ok(3.0)); + assert_eq!(div(-3, 1), Ok(-3.0)); + assert_eq!(div(1, 1000), Ok(0.001)); + assert_eq!(div(1, 100000), Ok(0.00001)); + assert_eq!(div(2, 3), Ok(0.6666666666666666)); + assert_eq!(div(1, 3), Ok(0.3333333333333333)); + assert_eq!(div(i64::MAX, 2), Ok(4611686018427387904.0)); + assert_eq!(div(i64::MIN, 2), Ok(-4611686018427387904.0)); + assert_eq!(div(i64::MIN, -1), Ok(9223372036854775808.0)); // Overflow case + assert_eq!(div(i64::MIN, i64::MAX), Ok(-1.0)); + } + + #[test] + fn basic_floor_div() { + let floor_div = jit_function! 
{ floor_div(a:i64, b:i64) -> i64 => r##" def floor_div(a: int, b: int): return a // b "## }; - assert_eq!(floor_div(5, 10), Ok(0)); - assert_eq!(floor_div(5, 2), Ok(2)); - assert_eq!(floor_div(12, 10), Ok(1)); - assert_eq!(floor_div(7, 10), Ok(0)); - assert_eq!(floor_div(-3, -1), Ok(3)); -} + assert_eq!(floor_div(5, 10), Ok(0)); + assert_eq!(floor_div(5, 2), Ok(2)); + assert_eq!(floor_div(12, 10), Ok(1)); + assert_eq!(floor_div(7, 10), Ok(0)); + assert_eq!(floor_div(-3, -1), Ok(3)); + } -#[test] + #[test] -fn test_exp() { - let exp = jit_function! { exp(a: i64, b: i64) -> i64 => r##" + fn basic_exp() { + let exp = jit_function! { exp(a: i64, b: i64) -> i64 => r##" def exp(a: int, b: int): return a ** b "## }; - assert_eq!(exp(2, 3), Ok(8)); - assert_eq!(exp(3, 2), Ok(9)); - assert_eq!(exp(5, 0), Ok(1)); - assert_eq!(exp(0, 0), Ok(1)); - assert_eq!(exp(-5, 0), Ok(1)); - assert_eq!(exp(0, 1), Ok(0)); - assert_eq!(exp(0, 5), Ok(0)); - assert_eq!(exp(-2, 2), Ok(4)); - assert_eq!(exp(-3, 4), Ok(81)); - assert_eq!(exp(-2, 3), Ok(-8)); - assert_eq!(exp(-3, 3), Ok(-27)); - assert_eq!(exp(1000, 2), Ok(1000000)); -} - -#[test] -fn test_mod() { - let modulo = jit_function! { modulo(a:i64, b:i64) -> i64 => r##" + assert_eq!(exp(2, 3), Ok(8)); + assert_eq!(exp(3, 2), Ok(9)); + assert_eq!(exp(5, 0), Ok(1)); + assert_eq!(exp(0, 0), Ok(1)); + assert_eq!(exp(-5, 0), Ok(1)); + assert_eq!(exp(0, 1), Ok(0)); + assert_eq!(exp(0, 5), Ok(0)); + assert_eq!(exp(-2, 2), Ok(4)); + assert_eq!(exp(-3, 4), Ok(81)); + assert_eq!(exp(-2, 3), Ok(-8)); + assert_eq!(exp(-3, 3), Ok(-27)); + assert_eq!(exp(1000, 2), Ok(1000000)); + } + + #[test] + fn basic_mod() { + let modulo = jit_function! 
{ modulo(a:i64, b:i64) -> i64 => r##" def modulo(a: int, b: int): return a % b "## }; - assert_eq!(modulo(5, 10), Ok(5)); - assert_eq!(modulo(5, 2), Ok(1)); - assert_eq!(modulo(12, 10), Ok(2)); - assert_eq!(modulo(7, 10), Ok(7)); - assert_eq!(modulo(-3, 1), Ok(0)); - assert_eq!(modulo(-5, 10), Ok(-5)); -} - -#[test] -fn test_power() { - let power = jit_function! { power(a:i64, b:i64) -> i64 => r##" + assert_eq!(modulo(5, 10), Ok(5)); + assert_eq!(modulo(5, 2), Ok(1)); + assert_eq!(modulo(12, 10), Ok(2)); + assert_eq!(modulo(7, 10), Ok(7)); + assert_eq!(modulo(-3, 1), Ok(0)); + assert_eq!(modulo(-5, 10), Ok(-5)); + } + + #[test] + fn basic_power() { + let power = jit_function! { power(a:i64, b:i64) -> i64 => r##" def power(a: int, b: int): return a ** b "## - }; - assert_eq!(power(10, 2), Ok(100)); - assert_eq!(power(5, 1), Ok(5)); - assert_eq!(power(1, 0), Ok(1)); -} - -#[test] -fn test_lshift() { - let lshift = jit_function! { lshift(a:i64, b:i64) -> i64 => r##" + }; + assert_eq!(power(10, 2), Ok(100)); + assert_eq!(power(5, 1), Ok(5)); + assert_eq!(power(1, 0), Ok(1)); + } + + #[test] + fn basic_lshift() { + let lshift = jit_function! { lshift(a:i64, b:i64) -> i64 => r##" def lshift(a: int, b: int): return a << b "## }; - assert_eq!(lshift(5, 10), Ok(5120)); - assert_eq!(lshift(5, 2), Ok(20)); - assert_eq!(lshift(12, 10), Ok(12288)); - assert_eq!(lshift(7, 10), Ok(7168)); - assert_eq!(lshift(-3, 1), Ok(-6)); - assert_eq!(lshift(-10, 2), Ok(-40)); -} - -#[test] -fn test_rshift() { - let rshift = jit_function! { rshift(a:i64, b:i64) -> i64 => r##" + assert_eq!(lshift(5, 10), Ok(5120)); + assert_eq!(lshift(5, 2), Ok(20)); + assert_eq!(lshift(12, 10), Ok(12288)); + assert_eq!(lshift(7, 10), Ok(7168)); + assert_eq!(lshift(-3, 1), Ok(-6)); + assert_eq!(lshift(-10, 2), Ok(-40)); + } + + #[test] + fn basic_rshift() { + let rshift = jit_function! 
{ rshift(a:i64, b:i64) -> i64 => r##" def rshift(a: int, b: int): return a >> b "## }; - assert_eq!(rshift(5120, 10), Ok(5)); - assert_eq!(rshift(20, 2), Ok(5)); - assert_eq!(rshift(12288, 10), Ok(12)); - assert_eq!(rshift(7168, 10), Ok(7)); - assert_eq!(rshift(-3, 1), Ok(-2)); - assert_eq!(rshift(-10, 2), Ok(-3)); -} - -#[test] -fn test_and() { - let bitand = jit_function! { bitand(a:i64, b:i64) -> i64 => r##" + assert_eq!(rshift(5120, 10), Ok(5)); + assert_eq!(rshift(20, 2), Ok(5)); + assert_eq!(rshift(12288, 10), Ok(12)); + assert_eq!(rshift(7168, 10), Ok(7)); + assert_eq!(rshift(-3, 1), Ok(-2)); + assert_eq!(rshift(-10, 2), Ok(-3)); + } + + #[test] + fn basic_and() { + let bitand = jit_function! { bitand(a:i64, b:i64) -> i64 => r##" def bitand(a: int, b: int): return a & b "## }; - assert_eq!(bitand(5120, 10), Ok(0)); - assert_eq!(bitand(20, 16), Ok(16)); - assert_eq!(bitand(12488, 4249), Ok(4232)); - assert_eq!(bitand(7168, 2), Ok(0)); - assert_eq!(bitand(-3, 1), Ok(1)); - assert_eq!(bitand(-10, 2), Ok(2)); -} - -#[test] -fn test_or() { - let bitor = jit_function! { bitor(a:i64, b:i64) -> i64 => r##" + assert_eq!(bitand(5120, 10), Ok(0)); + assert_eq!(bitand(20, 16), Ok(16)); + assert_eq!(bitand(12488, 4249), Ok(4232)); + assert_eq!(bitand(7168, 2), Ok(0)); + assert_eq!(bitand(-3, 1), Ok(1)); + assert_eq!(bitand(-10, 2), Ok(2)); + } + + #[test] + fn basic_or() { + let bitor = jit_function! { bitor(a:i64, b:i64) -> i64 => r##" def bitor(a: int, b: int): return a | b "## }; - assert_eq!(bitor(5120, 10), Ok(5130)); - assert_eq!(bitor(20, 16), Ok(20)); - assert_eq!(bitor(12488, 4249), Ok(12505)); - assert_eq!(bitor(7168, 2), Ok(7170)); - assert_eq!(bitor(-3, 1), Ok(-3)); - assert_eq!(bitor(-10, 2), Ok(-10)); -} - -#[test] -fn test_xor() { - let bitxor = jit_function! 
{ bitxor(a:i64, b:i64) -> i64 => r##" + assert_eq!(bitor(5120, 10), Ok(5130)); + assert_eq!(bitor(20, 16), Ok(20)); + assert_eq!(bitor(12488, 4249), Ok(12505)); + assert_eq!(bitor(7168, 2), Ok(7170)); + assert_eq!(bitor(-3, 1), Ok(-3)); + assert_eq!(bitor(-10, 2), Ok(-10)); + } + + #[test] + fn basic_xor() { + let bitxor = jit_function! { bitxor(a:i64, b:i64) -> i64 => r##" def bitxor(a: int, b: int): return a ^ b "## }; - assert_eq!(bitxor(5120, 10), Ok(5130)); - assert_eq!(bitxor(20, 16), Ok(4)); - assert_eq!(bitxor(12488, 4249), Ok(8273)); - assert_eq!(bitxor(7168, 2), Ok(7170)); - assert_eq!(bitxor(-3, 1), Ok(-4)); - assert_eq!(bitxor(-10, 2), Ok(-12)); -} - -#[test] -fn test_eq() { - let eq = jit_function! { eq(a:i64, b:i64) -> i64 => r##" + assert_eq!(bitxor(5120, 10), Ok(5130)); + assert_eq!(bitxor(20, 16), Ok(4)); + assert_eq!(bitxor(12488, 4249), Ok(8273)); + assert_eq!(bitxor(7168, 2), Ok(7170)); + assert_eq!(bitxor(-3, 1), Ok(-4)); + assert_eq!(bitxor(-10, 2), Ok(-12)); + } + + #[test] + fn basic_eq() { + let eq = jit_function! { eq(a:i64, b:i64) -> i64 => r##" def eq(a: int, b: int): if a == b: return 1 return 0 "## }; - assert_eq!(eq(0, 0), Ok(1)); - assert_eq!(eq(1, 1), Ok(1)); - assert_eq!(eq(0, 1), Ok(0)); - assert_eq!(eq(-200, 200), Ok(0)); -} + assert_eq!(eq(0, 0), Ok(1)); + assert_eq!(eq(1, 1), Ok(1)); + assert_eq!(eq(0, 1), Ok(0)); + assert_eq!(eq(-200, 200), Ok(0)); + } -#[test] -fn test_gt() { - let gt = jit_function! { gt(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_gt() { + let gt = jit_function! { gt(a:i64, b:i64) -> i64 => r##" def gt(a: int, b: int): if a > b: return 1 return 0 "## }; - assert_eq!(gt(5, 2), Ok(1)); - assert_eq!(gt(2, 5), Ok(0)); - assert_eq!(gt(2, 2), Ok(0)); - assert_eq!(gt(5, 5), Ok(0)); - assert_eq!(gt(-1, -10), Ok(1)); - assert_eq!(gt(1, -1), Ok(1)); -} - -#[test] -fn test_lt() { - let lt = jit_function! 
{ lt(a:i64, b:i64) -> i64 => r##" + assert_eq!(gt(5, 2), Ok(1)); + assert_eq!(gt(2, 5), Ok(0)); + assert_eq!(gt(2, 2), Ok(0)); + assert_eq!(gt(5, 5), Ok(0)); + assert_eq!(gt(-1, -10), Ok(1)); + assert_eq!(gt(1, -1), Ok(1)); + } + + #[test] + fn basic_lt() { + let lt = jit_function! { lt(a:i64, b:i64) -> i64 => r##" def lt(a: int, b: int): if a < b: return 1 return 0 "## }; - assert_eq!(lt(-1, -5), Ok(0)); - assert_eq!(lt(10, 0), Ok(0)); - assert_eq!(lt(0, 1), Ok(1)); - assert_eq!(lt(-10, -1), Ok(1)); - assert_eq!(lt(100, 100), Ok(0)); -} + assert_eq!(lt(-1, -5), Ok(0)); + assert_eq!(lt(10, 0), Ok(0)); + assert_eq!(lt(0, 1), Ok(1)); + assert_eq!(lt(-10, -1), Ok(1)); + assert_eq!(lt(100, 100), Ok(0)); + } -#[test] -fn test_gte() { - let gte = jit_function! { gte(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_gte() { + let gte = jit_function! { gte(a:i64, b:i64) -> i64 => r##" def gte(a: int, b: int): if a >= b: return 1 return 0 "## }; - assert_eq!(gte(-64, -64), Ok(1)); - assert_eq!(gte(100, -1), Ok(1)); - assert_eq!(gte(1, 2), Ok(0)); - assert_eq!(gte(1, 0), Ok(1)); -} + assert_eq!(gte(-64, -64), Ok(1)); + assert_eq!(gte(100, -1), Ok(1)); + assert_eq!(gte(1, 2), Ok(0)); + assert_eq!(gte(1, 0), Ok(1)); + } -#[test] -fn test_lte() { - let lte = jit_function! { lte(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_lte() { + let lte = jit_function! { lte(a:i64, b:i64) -> i64 => r##" def lte(a: int, b: int): if a <= b: return 1 return 0 "## }; - assert_eq!(lte(-100, -100), Ok(1)); - assert_eq!(lte(-100, 100), Ok(1)); - assert_eq!(lte(10, 1), Ok(0)); - assert_eq!(lte(0, -2), Ok(0)); -} + assert_eq!(lte(-100, -100), Ok(1)); + assert_eq!(lte(-100, 100), Ok(1)); + assert_eq!(lte(10, 1), Ok(0)); + assert_eq!(lte(0, -2), Ok(0)); + } -#[test] -fn test_minus() { - let minus = jit_function! { minus(a:i64) -> i64 => r##" + #[test] + fn basic_minus() { + let minus = jit_function! 
{ minus(a:i64) -> i64 => r##" def minus(a: int): return -a "## }; - assert_eq!(minus(5), Ok(-5)); - assert_eq!(minus(12), Ok(-12)); - assert_eq!(minus(-7), Ok(7)); - assert_eq!(minus(-3), Ok(3)); - assert_eq!(minus(0), Ok(0)); -} + assert_eq!(minus(5), Ok(-5)); + assert_eq!(minus(12), Ok(-12)); + assert_eq!(minus(-7), Ok(7)); + assert_eq!(minus(-3), Ok(3)); + assert_eq!(minus(0), Ok(0)); + } -#[test] -fn test_plus() { - let plus = jit_function! { plus(a:i64) -> i64 => r##" + #[test] + fn basic_plus() { + let plus = jit_function! { plus(a:i64) -> i64 => r##" def plus(a: int): return +a "## }; - assert_eq!(plus(5), Ok(5)); - assert_eq!(plus(12), Ok(12)); - assert_eq!(plus(-7), Ok(-7)); - assert_eq!(plus(-3), Ok(-3)); - assert_eq!(plus(0), Ok(0)); -} + assert_eq!(plus(5), Ok(5)); + assert_eq!(plus(12), Ok(12)); + assert_eq!(plus(-7), Ok(-7)); + assert_eq!(plus(-3), Ok(-3)); + assert_eq!(plus(0), Ok(0)); + } -#[test] -fn test_not() { - let not_ = jit_function! { not_(a: i64) -> bool => r##" + #[test] + fn basic_not() { + let not_ = jit_function! { not_(a: i64) -> bool => r##" def not_(a: int): return not a "## }; - assert_eq!(not_(0), Ok(true)); - assert_eq!(not_(1), Ok(false)); - assert_eq!(not_(-1), Ok(false)); + assert_eq!(not_(0), Ok(true)); + assert_eq!(not_(1), Ok(false)); + assert_eq!(not_(-1), Ok(false)); + } } diff --git a/crates/jit/tests/misc_tests.rs b/crates/jit/tests/misc_tests.rs index 25d66c46c06..b73100ad6ec 100644 --- a/crates/jit/tests/misc_tests.rs +++ b/crates/jit/tests/misc_tests.rs @@ -1,74 +1,76 @@ -use rustpython_jit::{AbiValue, JitArgumentError}; +#[cfg(test)] +mod tests { + use rustpython_jit::{AbiValue, JitArgumentError}; -// TODO currently broken -// #[test] -// fn test_no_return_value() { -// let func = jit_function! { func() => r##" -// def func(): -// pass -// "## }; -// -// assert_eq!(func(), Ok(())); -// } + // TODO currently broken + // #[test] + // fn test_no_return_value() { + // let func = jit_function! 
{ func() => r##" + // def func(): + // pass + // "## }; + // + // assert_eq!(func(), Ok(())); + // } -#[test] -fn test_invoke() { - let func = jit_function! { func => r##" + #[test] + fn invoke() { + let func = jit_function! { func => r##" def func(a: int, b: float): return 1 "## }; - assert_eq!( - func.invoke(&[AbiValue::Int(1)]), - Err(JitArgumentError::WrongNumberOfArguments) - ); - assert_eq!( - func.invoke(&[AbiValue::Int(1), AbiValue::Float(2.0), AbiValue::Int(0)]), - Err(JitArgumentError::WrongNumberOfArguments) - ); - assert_eq!( - func.invoke(&[AbiValue::Int(1), AbiValue::Int(1)]), - Err(JitArgumentError::ArgumentTypeMismatch) - ); - assert_eq!( - func.invoke(&[AbiValue::Int(1), AbiValue::Float(2.0)]), - Ok(Some(AbiValue::Int(1))) - ); -} + assert_eq!( + func.invoke(&[AbiValue::Int(1)]), + Err(JitArgumentError::WrongNumberOfArguments) + ); + assert_eq!( + func.invoke(&[AbiValue::Int(1), AbiValue::Float(2.0), AbiValue::Int(0)]), + Err(JitArgumentError::WrongNumberOfArguments) + ); + assert_eq!( + func.invoke(&[AbiValue::Int(1), AbiValue::Int(1)]), + Err(JitArgumentError::ArgumentTypeMismatch) + ); + assert_eq!( + func.invoke(&[AbiValue::Int(1), AbiValue::Float(2.0)]), + Ok(Some(AbiValue::Int(1))) + ); + } -#[test] -fn test_args_builder() { - let func = jit_function! { func=> r##" + #[test] + fn args_builder() { + let func = jit_function! 
{ func=> r##" def func(a: int, b: float): return 1 "## }; - let mut args_builder = func.args_builder(); - assert_eq!(args_builder.set(0, AbiValue::Int(1)), Ok(())); - assert!(args_builder.is_set(0)); - assert!(!args_builder.is_set(1)); - assert_eq!( - args_builder.set(1, AbiValue::Int(1)), - Err(JitArgumentError::ArgumentTypeMismatch) - ); - assert!(args_builder.is_set(0)); - assert!(!args_builder.is_set(1)); - assert!(args_builder.into_args().is_none()); + let mut args_builder = func.args_builder(); + assert_eq!(args_builder.set(0, AbiValue::Int(1)), Ok(())); + assert!(args_builder.is_set(0)); + assert!(!args_builder.is_set(1)); + assert_eq!( + args_builder.set(1, AbiValue::Int(1)), + Err(JitArgumentError::ArgumentTypeMismatch) + ); + assert!(args_builder.is_set(0)); + assert!(!args_builder.is_set(1)); + assert!(args_builder.into_args().is_none()); - let mut args_builder = func.args_builder(); - assert_eq!(args_builder.set(0, AbiValue::Int(1)), Ok(())); - assert_eq!(args_builder.set(1, AbiValue::Float(1.0)), Ok(())); - assert!(args_builder.is_set(0)); - assert!(args_builder.is_set(1)); + let mut args_builder = func.args_builder(); + assert_eq!(args_builder.set(0, AbiValue::Int(1)), Ok(())); + assert_eq!(args_builder.set(1, AbiValue::Float(1.0)), Ok(())); + assert!(args_builder.is_set(0)); + assert!(args_builder.is_set(1)); - let args = args_builder.into_args(); - assert!(args.is_some()); - assert_eq!(args.unwrap().invoke(), Some(AbiValue::Int(1))); -} + let args = args_builder.into_args(); + assert!(args.is_some()); + assert_eq!(args.unwrap().invoke(), Some(AbiValue::Int(1))); + } -#[test] -fn test_if_else() { - let if_else = jit_function! { if_else(a:i64) -> i64 => r##" + #[test] + fn basic_if_else() { + let if_else = jit_function! 
{ if_else(a:i64) -> i64 => r##" def if_else(a: int): if a: return 42 @@ -79,15 +81,15 @@ fn test_if_else() { return 0 "## }; - assert_eq!(if_else(0), Ok(0)); - assert_eq!(if_else(1), Ok(42)); - assert_eq!(if_else(-1), Ok(42)); - assert_eq!(if_else(100), Ok(42)); -} + assert_eq!(if_else(0), Ok(0)); + assert_eq!(if_else(1), Ok(42)); + assert_eq!(if_else(-1), Ok(42)); + assert_eq!(if_else(100), Ok(42)); + } -#[test] -fn test_while_loop() { - let while_loop = jit_function! { while_loop(a:i64) -> i64 => r##" + #[test] + fn basic_while_loop() { + let while_loop = jit_function! { while_loop(a:i64) -> i64 => r##" def while_loop(a: int): b = 0 while a > 0: @@ -95,32 +97,33 @@ fn test_while_loop() { a -= 1 return b "## }; - assert_eq!(while_loop(0), Ok(0)); - assert_eq!(while_loop(-1), Ok(0)); - assert_eq!(while_loop(1), Ok(1)); - assert_eq!(while_loop(10), Ok(10)); -} + assert_eq!(while_loop(0), Ok(0)); + assert_eq!(while_loop(-1), Ok(0)); + assert_eq!(while_loop(1), Ok(1)); + assert_eq!(while_loop(10), Ok(10)); + } -#[test] -fn test_unpack_tuple() { - let unpack_tuple = jit_function! { unpack_tuple(a:i64, b:i64) -> i64 => r##" + #[test] + fn basic_unpack_tuple() { + let unpack_tuple = jit_function! { unpack_tuple(a:i64, b:i64) -> i64 => r##" def unpack_tuple(a: int, b: int): a, b = b, a return a "## }; - assert_eq!(unpack_tuple(0, 1), Ok(1)); - assert_eq!(unpack_tuple(1, 2), Ok(2)); -} + assert_eq!(unpack_tuple(0, 1), Ok(1)); + assert_eq!(unpack_tuple(1, 2), Ok(2)); + } -#[test] -fn test_recursive_fib() { - let fib = jit_function! { fib(n: i64) -> i64 => r##" + #[test] + fn recursive_fib() { + let fib = jit_function! 
{ fib(n: i64) -> i64 => r##" def fib(n: int) -> int: if n == 0 or n == 1: return 1 return fib(n-1) + fib(n-2) "## }; - assert_eq!(fib(10), Ok(89)); + assert_eq!(fib(10), Ok(89)); + } } diff --git a/crates/jit/tests/none_tests.rs b/crates/jit/tests/none_tests.rs index 561f1f01c96..a1d36af80b8 100644 --- a/crates/jit/tests/none_tests.rs +++ b/crates/jit/tests/none_tests.rs @@ -1,16 +1,18 @@ -#[test] -fn test_not() { - let not_ = jit_function! { not_(x: i64) -> bool => r##" +#[cfg(test)] +mod tests { + #[test] + fn basic_not() { + let not_ = jit_function! { not_(x: i64) -> bool => r##" def not_(x: int): return not None "## }; - assert_eq!(not_(0), Ok(true)); -} + assert_eq!(not_(0), Ok(true)); + } -#[test] -fn test_if_not() { - let if_not = jit_function! { if_not(x: i64) -> i64 => r##" + #[test] + fn basic_if_not() { + let if_not = jit_function! { if_not(x: i64) -> i64 => r##" def if_not(x: int): if not None: return 1 @@ -20,5 +22,6 @@ fn test_if_not() { return -1 "## }; - assert_eq!(if_not(0), Ok(1)); + assert_eq!(if_not(0), Ok(1)); + } } diff --git a/crates/literal/src/float.rs b/crates/literal/src/float.rs index 0856f646b22..79caca0592c 100644 --- a/crates/literal/src/float.rs +++ b/crates/literal/src/float.rs @@ -3,7 +3,7 @@ use alloc::borrow::ToOwned; use alloc::format; use alloc::string::{String, ToString}; use core::f64; -use num_traits::{Float, Zero}; +use num_traits::Zero; pub fn parse_str(literal: &str) -> Option { parse_inner(literal.trim().as_bytes()) @@ -209,6 +209,111 @@ pub fn format_general( } } +fn prefer_cpython_tie_repr(s: String, value: f64) -> String { + let Some(exponent_pos) = s.find('e') else { + return s; + }; + let Some(digit_pos) = s[..exponent_pos].bytes().rposition(|b| b.is_ascii_digit()) else { + return s; + }; + + let digit = s.as_bytes()[digit_pos]; + if digit == b'0' { + return s; + } + let decremented = digit - 1; + if !(decremented - b'0').is_multiple_of(2) { + return s; + } + + let mut candidate = s.clone(); + 
candidate.replace_range( + digit_pos..digit_pos + 1, + core::str::from_utf8(&[decremented]).unwrap(), + ); + if parse_str(&candidate).is_none_or(|parsed| parsed.to_bits() != value.to_bits()) { + return s; + } + + let Some(current_distance) = decimal_distance_to_f64(&s, value) else { + return s; + }; + let Some(candidate_distance) = decimal_distance_to_f64(&candidate, value) else { + return s; + }; + + if candidate_distance <= current_distance { + candidate + } else { + s + } +} + +fn checked_pow_u128(base: u128, exp: u32) -> Option { + let mut result = 1u128; + for _ in 0..exp { + result = result.checked_mul(base)?; + } + Some(result) +} + +fn parse_decimal_rational(s: &str) -> Option<(u128, u32)> { + let exponent_pos = s.find('e')?; + let exponent = s[exponent_pos + 1..].parse::().ok()?; + let significand = s[..exponent_pos] + .strip_prefix('-') + .unwrap_or(&s[..exponent_pos]); + let dot_pos = significand.find('.'); + let frac_digits = dot_pos + .map(|pos| significand.len().saturating_sub(pos + 1)) + .unwrap_or(0); + let mut digits = String::with_capacity(significand.len()); + for ch in significand.chars() { + if ch != '.' { + digits.push(ch); + } + } + let mut int = digits.parse::().ok()?; + let mut scale = i32::try_from(frac_digits).ok()? 
- exponent; + if scale < 0 { + int = int.checked_mul(checked_pow_u128(10, (-scale) as u32)?)?; + scale = 0; + } + Some((int, scale as u32)) +} + +fn f64_mantissa_exponent(value: f64) -> Option<(u128, i32)> { + let bits = value.abs().to_bits(); + let exponent = ((bits >> 52) & 0x7ff) as i32; + let fraction = bits & ((1u64 << 52) - 1); + if exponent == 0 { + Some((u128::from(fraction), 1 - 1023 - 52)) + } else if exponent < 0x7ff { + Some((u128::from((1u64 << 52) | fraction), exponent - 1023 - 52)) + } else { + None + } +} + +fn decimal_distance_to_f64(s: &str, value: f64) -> Option { + let (decimal_int, decimal_scale) = parse_decimal_rational(s)?; + let (mantissa, binary_exponent) = f64_mantissa_exponent(value)?; + if binary_exponent >= 0 || decimal_scale > 38 { + return None; + } + + let binary_scale = u32::try_from(-binary_exponent).ok()?; + let common_twos = decimal_scale.max(binary_scale); + let decimal_scaled = + decimal_int.checked_mul(checked_pow_u128(2, common_twos - decimal_scale)?)?; + let five_power = checked_pow_u128(5, decimal_scale)?; + let binary_scaled = mantissa + .checked_mul(checked_pow_u128(2, common_twos - binary_scale)?)? 
+ .checked_mul(five_power)?; + + Some(decimal_scaled.abs_diff(binary_scaled)) +} + // TODO: rewrite using format_general pub fn to_string(value: f64) -> String { let lit = format!("{value:e}"); @@ -223,7 +328,7 @@ pub fn to_string(value: f64) -> String { value.to_string() } } else { - format!("{significand}e{exponent:+#03}") + prefer_cpython_tie_repr(format!("{significand}e{exponent:+#03}"), value) } } else { let mut s = value.to_string(); @@ -232,6 +337,22 @@ pub fn to_string(value: f64) -> String { } } +#[cfg(test)] +mod tests { + use super::to_string; + + #[test] + fn repr_uses_cpython_tie_digit_for_power_of_two() { + assert_eq!(to_string(2.0f64.powi(-25)), "2.9802322387695312e-08"); + assert_eq!(to_string((-2.0f64).powi(-25)), "-2.9802322387695312e-08"); + assert_eq!(to_string(2.0f64.powi(-26)), "1.4901161193847656e-08"); + assert_eq!( + to_string(2.0f64.powi(-14) - 2.0f64.powi(-25)), + "6.1005353927612305e-05" + ); + } +} + pub fn from_hex(s: &str) -> Option { if let Ok(f) = hexf_parse::parse_hexf64(s, false) { return Some(f); @@ -281,22 +402,23 @@ pub fn from_hex(s: &str) -> Option { } pub fn to_hex(value: f64) -> String { - let (mantissa, exponent, sign) = value.integer_decode(); - let sign_fmt = if sign < 0 { "-" } else { "" }; + let bits = value.to_bits(); + let sign_fmt = if bits >> 63 != 0 { "-" } else { "" }; match value { value if value.is_zero() => format!("{sign_fmt}0x0.0p+0"), value if value.is_infinite() => format!("{sign_fmt}inf"), value if value.is_nan() => "nan".to_owned(), _ => { - const BITS: i16 = 52; - const FRACT_MASK: u64 = 0xf_ffff_ffff_ffff; - format!( - "{}{:#x}.{:013x}p{:+}", - sign_fmt, - mantissa >> BITS, - mantissa & FRACT_MASK, - exponent + BITS - ) + const FRACT_MASK: u64 = (1u64 << 52) - 1; + const EXP_MASK: u64 = 0x7ff; + let exponent = (bits >> 52) & EXP_MASK; + let fraction = bits & FRACT_MASK; + if exponent == 0 { + format!("{sign_fmt}0x0.{fraction:013x}p-1022") + } else { + let exponent = i32::try_from(exponent).unwrap() - 
1023; + format!("{sign_fmt}0x1.{fraction:013x}p{exponent:+}") + } } } } @@ -304,6 +426,10 @@ pub fn to_hex(value: f64) -> String { #[test] fn test_to_hex() { use rand::Rng; + assert_eq!(to_hex(f64::from_bits(1)), "0x0.0000000000001p-1022"); + assert_eq!(to_hex(f64::from_bits(2)), "0x0.0000000000002p-1022"); + assert_eq!(to_hex(-f64::from_bits(1)), "-0x0.0000000000001p-1022"); + assert_eq!(to_hex(f64::MIN_POSITIVE), "0x1.0000000000000p-1022"); for _ in 0..20000 { let bytes = rand::rng().random::(); let f = f64::from_bits(bytes); diff --git a/crates/sre_engine/tests/tests.rs b/crates/sre_engine/tests/tests.rs index 795c9a05d42..f1edb64cdbf 100644 --- a/crates/sre_engine/tests/tests.rs +++ b/crates/sre_engine/tests/tests.rs @@ -1,200 +1,203 @@ // spell-checker:disable -use rustpython_sre_engine::{Request, State, StrDrive}; - -struct Pattern { - #[allow(unused)] - pattern: &'static str, - code: &'static [u32], -} +#[cfg(test)] +mod tests { + use rustpython_sre_engine::{Request, State, StrDrive}; + + struct Pattern { + #[expect(unused, reason = "Needed for automated script")] + pattern: &'static str, + code: &'static [u32], + } -impl Pattern { - fn state<'a, S: StrDrive>(&self, string: S) -> (Request<'a, S>, State) { - let req = Request::new(string, 0, usize::MAX, self.code, false); - let state = State::default(); - (req, state) + impl Pattern { + fn state<'a, S: StrDrive>(&self, string: S) -> (Request<'a, S>, State) { + let req = Request::new(string, 0, usize::MAX, self.code, false); + let state = State::default(); + (req, state) + } } -} -#[test] -fn test_2427() { - // pattern lookbehind = re.compile(r'(?x)++x') - // START GENERATED by generate_tests.py - #[rustfmt::skip] let p = Pattern { pattern: "(?>x)++x", code: &[14, 4, 0, 2, 4294967295, 28, 8, 1, 4294967295, 27, 4, 16, 120, 1, 1, 16, 120, 1] }; - // END GENERATED - let (req, mut state) = p.state("xxx"); - assert!(!state.py_match(&req)); -} + #[test] + fn possessive_atomic_group() { + // pattern p = 
re.compile('(?>x)++x') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { pattern: "(?>x)++x", code: &[14, 4, 0, 2, 4294967295, 28, 8, 1, 4294967295, 27, 4, 16, 120, 1, 1, 16, 120, 1] }; + // END GENERATED + let (req, mut state) = p.state("xxx"); + assert!(!state.py_match(&req)); + } -#[test] -fn test_bug_20998() { - // pattern p = re.compile('[a-c]+', re.I) - // START GENERATED by generate_tests.py - #[rustfmt::skip] let p = Pattern { pattern: "[a-c]+", code: &[14, 4, 0, 1, 4294967295, 24, 10, 1, 4294967295, 39, 5, 22, 97, 99, 0, 1, 1] }; - // END GENERATED - let (mut req, mut state) = p.state("ABC"); - req.match_all = true; - assert!(state.py_match(&req)); - assert_eq!(state.cursor.position, 3); -} + #[test] + fn bug_20998() { + // pattern p = re.compile('[a-c]+', re.I) + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { pattern: "[a-c]+", code: &[14, 4, 0, 1, 4294967295, 24, 10, 1, 4294967295, 39, 5, 22, 97, 99, 0, 1, 1] }; + // END GENERATED + let (mut req, mut state) = p.state("ABC"); + req.match_all = true; + assert!(state.py_match(&req)); + assert_eq!(state.cursor.position, 3); + } -#[test] -fn test_bigcharset() { - // pattern p = re.compile('[a-z]*', re.I) - // START GENERATED by generate_tests.py - #[rustfmt::skip] let p = Pattern { pattern: "[a-z]*", code: &[14, 4, 0, 0, 4294967295, 24, 97, 0, 4294967295, 39, 92, 10, 3, 33685760, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 
33686018, 33686018, 33686018, 33686018, 33686018, 0, 0, 0, 134217726, 0, 0, 0, 0, 0, 131072, 0, 2147483648, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] }; - // END GENERATED - let (req, mut state) = p.state("x "); - assert!(state.py_match(&req)); - assert_eq!(state.cursor.position, 1); -} + #[test] + fn bigcharset() { + // pattern p = re.compile('[a-z]*', re.I) + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { pattern: "[a-z]*", code: &[14, 4, 0, 0, 4294967295, 24, 97, 0, 4294967295, 39, 92, 10, 3, 33685760, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 0, 0, 0, 134217726, 0, 0, 0, 0, 0, 131072, 0, 2147483648, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] }; + // END GENERATED + let (req, mut state) = p.state("x "); + assert!(state.py_match(&req)); + assert_eq!(state.cursor.position, 1); + } -#[test] -fn test_search_nonascii() { - #[allow(unused)] + #[test] + fn search_nonascii() { + #[allow(unused)] // pattern p = re.compile('\xe0+') // START GENERATED by generate_tests.py #[rustfmt::skip] let p = Pattern { pattern: "\u{e0}+", code: &[14, 4, 0, 1, 4294967295, 24, 6, 1, 4294967295, 16, 224, 1, 1] }; - // END GENERATED + // END GENERATED + } } diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index 25a4b8c12f2..4547368322d 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -16,15 +16,16 @@ host_env = ["rustpython-vm/host_env"] 
compiler = ["rustpython-vm/compiler"] threading = ["rustpython-common/threading", "rustpython-vm/threading"] sqlite = ["dep:libsqlite3-sys"] -# SSL backends - default to rustls -ssl = [] -ssl-rustls = ["ssl", "rustls", "rustls-native-certs", "rustls-pemfile", "rustls-platform-verifier", "x509-cert", "x509-parser", "der", "pem-rfc7468", "webpki-roots", "aws-lc-rs", "oid-registry", "pkcs8"] -ssl-rustls-fips = ["ssl-rustls", "aws-lc-rs/fips"] +# SSL backends +ssl = ["host_env"] +ssl-rustls = ["__ssl-rustls", "rustls/custom-provider"] ssl-openssl = ["ssl", "openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"] -ssl-vendor = ["ssl-openssl", "openssl/vendored"] +ssl-openssl-vendor = ["ssl-openssl", "openssl/vendored"] tkinter = ["dep:tk-sys", "dep:tcl-sys", "dep:widestring"] flame-it = ["flame"] +__ssl-rustls = ["ssl", "rustls", "rustls-native-certs", "rustls-pemfile", "rustls-platform-verifier", "x509-cert", "x509-parser", "der", "pem-rfc7468", "webpki-roots", "oid-registry", "pkcs8"] + [dependencies] # rustpython crates rustpython-derive = { workspace = true } @@ -37,7 +38,6 @@ ruff_python_ast = { workspace = true } ruff_text_size = { workspace = true } ruff_source_file = { workspace = true } -ahash = { workspace = true } ascii = { workspace = true } crossbeam-utils = { workspace = true } flame = { workspace = true, optional = true } @@ -45,13 +45,13 @@ hex = { workspace = true } itertools = { workspace = true } indexmap = { workspace = true } libc = { workspace = true } -nix = { workspace = true } num-complex = { workspace = true } malachite-bigint = { workspace = true } num-traits = { workspace = true } num_enum = { workspace = true } parking_lot = { workspace = true } phf = { workspace = true, default-features = true, features = ["macros"] } +rapidhash = { workspace = true } memchr = { workspace = true } base64 = { workspace = true } @@ -101,16 +101,8 @@ chrono.workspace = true mac_address = { workspace = true } uuid = { workspace = true, features = 
["v1"] } -[target.'cfg(all(unix, not(target_os = "redox"), not(target_os = "ios")))'.dependencies] -termios = { workspace = true } - -[target.'cfg(unix)'.dependencies] -rustix = { workspace = true } - # mmap + socket dependencies [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -memmap2 = { workspace = true } -page_size = { workspace = true } gethostname = { workspace = true } socket2 = { workspace = true, features = ["all"] } dns-lookup = { workspace = true } @@ -122,7 +114,7 @@ openssl-probe = { workspace = true, optional = true } foreign-types-shared = { workspace = true, optional = true } # Rustls dependencies (optional, for ssl-rustls feature) -rustls = { workspace = true, default-features = false, features = ["std", "tls12", "aws_lc_rs"], optional = true } +rustls = { workspace = true, default-features = false, features = ["std", "tls12"], optional = true } rustls-native-certs = { workspace = true, optional = true } rustls-pemfile = { workspace = true, optional = true } rustls-platform-verifier = { workspace = true, optional = true } @@ -131,7 +123,6 @@ x509-parser = { workspace = true, optional = true } der = { workspace = true, optional = true } pem-rfc7468 = { workspace = true, features = ["alloc"], optional = true } webpki-roots = { workspace = true, optional = true } -aws-lc-rs = { workspace = true, optional = true } oid-registry = { workspace = true, features = ["x509", "pkcs1", "nist_algs"], optional = true } pkcs8 = { workspace = true, features = ["encryption", "pkcs5", "pem"], optional = true } @@ -142,26 +133,8 @@ liblzma-sys = { workspace = true } [target.'cfg(windows)'.dependencies] paste = { workspace = true } -schannel = { workspace = true } widestring = { workspace = true } -[target.'cfg(windows)'.dependencies.windows-sys] -workspace = true -features = [ - "Win32_Foundation", - "Win32_Networking_WinSock", - "Win32_NetworkManagement_IpHelper", - "Win32_NetworkManagement_Ndis", - "Win32_Security_Cryptography", - 
"Win32_Storage_FileSystem", - "Win32_System_Diagnostics_Debug", - "Win32_System_Environment", - "Win32_System_Console", - "Win32_System_IO", - "Win32_System_Memory", - "Win32_System_Threading" -] - [target.'cfg(target_os = "macos")'.dependencies] system-configuration = { workspace = true } diff --git a/crates/stdlib/src/_opcode.rs b/crates/stdlib/src/_opcode.rs index dc38e45443c..bfdb06b305f 100644 --- a/crates/stdlib/src/_opcode.rs +++ b/crates/stdlib/src/_opcode.rs @@ -238,7 +238,7 @@ output = re.sub(r'(0xdeadbeef', tmp } #[test] - fn test_if_ors() { + fn if_ors() { assert_dis_snapshot!( r#" if True or False or False: @@ -248,7 +248,7 @@ if True or False or False: } #[test] - fn test_if_ands() { + fn if_ands() { assert_dis_snapshot!( r#" if True and False and False: @@ -258,7 +258,7 @@ if True and False and False: } #[test] - fn test_if_mixed() { + fn if_mixed() { assert_dis_snapshot!( r#" if (True and False) or (False and True): @@ -268,7 +268,7 @@ if (True and False) or (False and True): } #[test] - fn test_nested_bool_op() { + fn nested_bool_op() { assert_dis_snapshot!( r#" x = Test() and False or False @@ -277,7 +277,7 @@ x = Test() and False or False } #[test] - fn test_const_no_op() { + fn const_no_op() { assert_dis_snapshot!( r#" x = not True @@ -286,7 +286,7 @@ x = not True } #[test] - fn test_constant_true_if_pass_keeps_line_anchor_nop() { + fn constant_true_if_pass_keeps_line_anchor_nop() { assert_dis_snapshot!( r#" if 1: @@ -296,7 +296,7 @@ if 1: } #[test] - fn test_nested_double_async_with() { + fn nested_double_async_with() { assert_dis_snapshot!( r#" async def test(): @@ -314,7 +314,7 @@ async def test(): } #[test] - fn test_bare_function_annotations_check_attribute_and_subscript_expressions() { + fn bare_function_annotations_check_attribute_and_subscript_expressions() { assert_dis_snapshot!( r#" def f(one: int): diff --git a/crates/stdlib/src/_sqlite3.rs b/crates/stdlib/src/_sqlite3.rs index fd7fec4feb3..5dfdf6f4f1e 100644 --- 
a/crates/stdlib/src/_sqlite3.rs +++ b/crates/stdlib/src/_sqlite3.rs @@ -44,6 +44,7 @@ mod _sqlite3 { sqlite3_value_text, sqlite3_value_type, }; use malachite_bigint::Sign; + use num_traits::ToPrimitive; use rustpython_common::{ atomic::{Ordering, PyAtomic, Radium}, hash::PyHash, @@ -2551,28 +2552,35 @@ mod _sqlite3 { value: Option, vm: &VirtualMachine, ) -> PyResult<()> { - let Some(value) = value else { - return Err(vm.new_type_error("Blob doesn't support slice deletion")); - }; self.ensure_connection_open(vm)?; let inner = self.inner(vm)?; if let Some(index) = needle.try_index_opt(vm) { // Handle single item assignment: blob[i] = b - let Some(value) = value.downcast_ref::() else { + let Some(value) = value else { + return Err(vm.new_type_error("Blob doesn't support item deletion")); + }; + let Some(int_val) = value.downcast_ref::() else { return Err(vm.new_type_error(format!( "'{}' object cannot be interpreted as an integer", value.class() ))); }; - let value = value.try_to_primitive::(vm)?; let blob_len = inner.blob.bytes(); let index = Self::wrapped_index(index?, blob_len, vm)?; - Self::expect_write(blob_len, 1, index, vm)?; - let ret = inner.blob.write_single(value, index); + // Mirror CPython ass_subscript_index: use PyLong_AsLong, treat any + // overflow (e.g. 2**65) as -1, then validate the [0, 255] range. + let val = int_val.as_bigint().to_i64().unwrap_or(-1); + if !(0..=255).contains(&val) { + return Err(vm.new_value_error("byte must be in range(0, 256)")); + } + let ret = inner.blob.write_single(val as u8, index); self.check(ret, vm) } else if let Some(slice) = needle.downcast_ref::() { // Handle slice assignment: blob[a:b:c] = b"..." 
+ let Some(value) = value else { + return Err(vm.new_type_error("Blob doesn't support slice deletion")); + }; let value_buf = PyBuffer::try_from_borrowed_object(vm, &value)?; let buf = value_buf @@ -2604,25 +2612,25 @@ mod _sqlite3 { self.check(ret, vm) } else { let span_len = range.end - range.start; + let range_start = range.start; let mut temp_buf = vec![0u8; span_len]; let ret = inner.blob.read( temp_buf.as_mut_ptr().cast(), span_len as c_int, - range.start as c_int, + range_start as c_int, ); self.check(ret, vm)?; - let mut i_in_temp: usize = 0; - for i_in_src in 0..slice_len { - temp_buf[i_in_temp] = buf[i_in_src]; - i_in_temp += step as usize; + let iter = SaturatedSliceIter::from_adjust_indices(range, step, slice_len); + for (i_in_src, abs_idx) in iter.enumerate() { + temp_buf[abs_idx - range_start] = buf[i_in_src]; } let ret = inner.blob.write( temp_buf.as_ptr().cast(), span_len as c_int, - range.start as c_int, + range_start as c_int, ); self.check(ret, vm) } diff --git a/crates/stdlib/src/_testconsole.rs b/crates/stdlib/src/_testconsole.rs index 0db508e3da5..78cba3b397d 100644 --- a/crates/stdlib/src/_testconsole.rs +++ b/crates/stdlib/src/_testconsole.rs @@ -5,23 +5,14 @@ mod _testconsole { use crate::vm::{ PyObjectRef, PyResult, VirtualMachine, convert::IntoPyException, function::ArgBytesLike, }; - use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE; - - type Handle = windows_sys::Win32::Foundation::HANDLE; + use rustpython_host_env::testconsole as host_testconsole; #[pyfunction] fn write_input(file: PyObjectRef, s: ArgBytesLike, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Console::{INPUT_RECORD, KEY_EVENT, WriteConsoleInputW}; - // Get the fd from the file object via fileno() let fd_obj = vm.call_method(&file, "fileno", ())?; let fd: i32 = fd_obj.try_into_value(vm)?; - let handle = unsafe { libc::get_osfhandle(fd) } as Handle; - if handle == INVALID_HANDLE_VALUE { - return 
Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - let data = s.borrow_buf(); let data = &*data; @@ -33,39 +24,7 @@ mod _testconsole { .chunks_exact(2) .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) .collect(); - - let size = wchars.len() as u32; - - // Create INPUT_RECORD array - let mut records: Vec = Vec::with_capacity(wchars.len()); - for &wc in &wchars { - // SAFETY: zeroing and accessing the union field for KEY_EVENT - let mut rec: INPUT_RECORD = unsafe { core::mem::zeroed() }; - rec.EventType = KEY_EVENT as u16; - rec.Event.KeyEvent.bKeyDown = 1; // TRUE - rec.Event.KeyEvent.wRepeatCount = 1; - rec.Event.KeyEvent.uChar.UnicodeChar = wc; - records.push(rec); - } - - let mut total: u32 = 0; - while total < size { - let mut wrote: u32 = 0; - let res = unsafe { - WriteConsoleInputW( - handle, - records[total as usize..].as_ptr(), - size - total, - &mut wrote, - ) - }; - if res == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - total += wrote; - } - - Ok(()) + host_testconsole::write_console_input(fd, &wchars).map_err(|e| e.into_pyexception(vm)) } #[pyfunction] diff --git a/crates/stdlib/src/contextvars.rs b/crates/stdlib/src/contextvars.rs index ee5942755c1..0a6e0f12314 100644 --- a/crates/stdlib/src/contextvars.rs +++ b/crates/stdlib/src/contextvars.rs @@ -28,7 +28,7 @@ mod _contextvars { use indexmap::IndexMap; // TODO: Real hamt implementation - type Hamt = IndexMap, PyObjectRef, ahash::RandomState>; + type Hamt = IndexMap, PyObjectRef, rapidhash::quality::RandomState>; #[pyclass(no_attr, name = "Hamt", module = "contextvars")] #[derive(Debug, PyPayload)] diff --git a/crates/stdlib/src/faulthandler.rs b/crates/stdlib/src/faulthandler.rs index 9ea970767a9..b3af501d83a 100644 --- a/crates/stdlib/src/faulthandler.rs +++ b/crates/stdlib/src/faulthandler.rs @@ -3,9 +3,10 @@ pub(crate) use decl::module_def; #[allow(static_mut_refs)] // TODO: group code only with static mut refs #[pymodule(name = "faulthandler")] 
mod decl { + #[cfg(any(unix, windows))] + use crate::vm::frame::Frame; use crate::vm::{ PyObjectRef, PyResult, VirtualMachine, - frame::Frame, function::{ArgIntoFloat, OptionalArg}, }; use alloc::sync::Arc; @@ -13,76 +14,11 @@ mod decl { use core::time::Duration; use parking_lot::{Condvar, Mutex}; #[cfg(any(unix, windows))] + use rustpython_host_env::faulthandler as host_faulthandler; + #[cfg(any(unix, windows))] use rustpython_host_env::os::{get_errno, set_errno}; use std::thread; - /// fault_handler_t - #[cfg(unix)] - struct FaultHandler { - signum: libc::c_int, - enabled: bool, - name: &'static str, - previous: libc::sigaction, - } - - #[cfg(windows)] - struct FaultHandler { - signum: libc::c_int, - enabled: bool, - name: &'static str, - previous: libc::sighandler_t, - } - - #[cfg(unix)] - impl FaultHandler { - const fn new(signum: libc::c_int, name: &'static str) -> Self { - Self { - signum, - enabled: false, - name, - // SAFETY: sigaction is a C struct that can be zero-initialized - previous: unsafe { core::mem::zeroed() }, - } - } - } - - #[cfg(windows)] - impl FaultHandler { - const fn new(signum: libc::c_int, name: &'static str) -> Self { - Self { - signum, - enabled: false, - name, - previous: 0, - } - } - } - - /// faulthandler_handlers[] - /// Number of fatal signals - #[cfg(unix)] - const FAULTHANDLER_NSIGNALS: usize = 5; - #[cfg(windows)] - const FAULTHANDLER_NSIGNALS: usize = 4; - - // Signal handlers use mutable statics matching faulthandler.c implementation. 
- #[cfg(unix)] - static mut FAULTHANDLER_HANDLERS: [FaultHandler; FAULTHANDLER_NSIGNALS] = [ - FaultHandler::new(libc::SIGBUS, "Bus error"), - FaultHandler::new(libc::SIGILL, "Illegal instruction"), - FaultHandler::new(libc::SIGFPE, "Floating-point exception"), - FaultHandler::new(libc::SIGABRT, "Aborted"), - FaultHandler::new(libc::SIGSEGV, "Segmentation fault"), - ]; - - #[cfg(windows)] - static mut FAULTHANDLER_HANDLERS: [FaultHandler; FAULTHANDLER_NSIGNALS] = [ - FaultHandler::new(libc::SIGILL, "Illegal instruction"), - FaultHandler::new(libc::SIGFPE, "Floating-point exception"), - FaultHandler::new(libc::SIGABRT, "Aborted"), - FaultHandler::new(libc::SIGSEGV, "Segmentation fault"), - ]; - /// fatal_error state struct FatalErrorState { enabled: AtomicBool, @@ -124,7 +60,7 @@ mod decl { #[cfg(any(unix, windows))] fn puts_bytes(fd: i32, s: &[u8]) { - let _ = unsafe { libc::write(fd, s.as_ptr().cast::(), s.len() as _) }; + host_faulthandler::write_fd(fd, s); } // _Py_DumpHexadecimal (traceback.c) @@ -165,14 +101,9 @@ mod decl { } /// Get current thread ID - #[cfg(unix)] - fn current_thread_id() -> u64 { - unsafe { libc::pthread_self() as u64 } - } - - #[cfg(windows)] + #[cfg(any(unix, windows))] fn current_thread_id() -> u64 { - unsafe { windows_sys::Win32::System::Threading::GetCurrentThreadId() as u64 } + host_faulthandler::current_thread_id() } // write_thread_id (traceback.c:1240-1256) @@ -259,6 +190,7 @@ mod decl { } /// MAX_STRING_LENGTH in traceback.c + #[cfg(any(unix, windows))] const MAX_STRING_LENGTH: usize = 500; /// Truncate a UTF-8 string to at most `max_bytes` without splitting a @@ -434,29 +366,6 @@ mod decl { // Signal handlers - /// faulthandler_disable_fatal_handler (faulthandler.c:310-321) - #[cfg(unix)] - unsafe fn faulthandler_disable_fatal_handler(handler: &mut FaultHandler) { - if !handler.enabled { - return; - } - handler.enabled = false; - unsafe { - libc::sigaction(handler.signum, &handler.previous, core::ptr::null_mut()); - } - } - - 
#[cfg(windows)] - unsafe fn faulthandler_disable_fatal_handler(handler: &mut FaultHandler) { - if !handler.enabled { - return; - } - handler.enabled = false; - unsafe { - libc::signal(handler.signum, handler.previous); - } - } - // faulthandler_fatal_error #[cfg(unix)] extern "C" fn faulthandler_fatal_error(signum: libc::c_int) { @@ -468,20 +377,10 @@ mod decl { let fd = FATAL_ERROR.fd.load(Ordering::Relaxed); - let handler = unsafe { - FAULTHANDLER_HANDLERS - .iter_mut() - .find(|h| h.signum == signum) - }; - - if let Some(h) = handler { - // Disable handler (restores previous) - unsafe { - faulthandler_disable_fatal_handler(h); - } - + if let Some(name) = host_faulthandler::fatal_signal_name(signum) { + host_faulthandler::disable_fatal_signal(signum); puts(fd, "Fatal Python error: "); - puts(fd, h.name); + puts(fd, name); puts(fd, "\n\n"); } else { puts(fd, "Fatal Python error from unexpected signum: "); @@ -498,15 +397,10 @@ mod decl { // We cannot just restore the previous handler because Rust's runtime // may have installed its own SIGSEGV handler (for stack overflow detection) // that doesn't terminate the process on software-raised signals. 
- unsafe { - libc::signal(signum, libc::SIG_DFL); - libc::raise(signum); - } + host_faulthandler::signal_default_and_raise(signum); // Fallback if raise() somehow didn't terminate the process - unsafe { - libc::_exit(1); - } + host_faulthandler::exit_immediately(1); } // faulthandler_fatal_error for Windows @@ -520,18 +414,10 @@ mod decl { let fd = FATAL_ERROR.fd.load(Ordering::Relaxed); - let handler = unsafe { - FAULTHANDLER_HANDLERS - .iter_mut() - .find(|h| h.signum == signum) - }; - - if let Some(h) = handler { - unsafe { - faulthandler_disable_fatal_handler(h); - } + if let Some(name) = host_faulthandler::fatal_signal_name(signum) { + host_faulthandler::disable_fatal_signal(signum); puts(fd, "Fatal Python error: "); - puts(fd, h.name); + puts(fd, name); puts(fd, "\n\n"); } else { puts(fd, "Fatal Python error from unexpected signum: "); @@ -544,10 +430,7 @@ mod decl { set_errno(save_errno); - unsafe { - libc::signal(signum, libc::SIG_DFL); - libc::raise(signum); - } + host_faulthandler::signal_default_and_raise(signum); // Fallback rustpython_host_env::os::exit(1); @@ -559,20 +442,12 @@ mod decl { #[cfg(windows)] fn faulthandler_ignore_exception(code: u32) -> bool { - // bpo-30557: ignore exceptions which are not errors - if (code & 0x80000000) == 0 { - return true; - } - // bpo-31701: ignore MSC and COM exceptions - if code == 0xE06D7363 || code == 0xE0434352 { - return true; - } - false + host_faulthandler::ignore_exception(code) } #[cfg(windows)] unsafe extern "system" fn faulthandler_exc_handler( - exc_info: *mut windows_sys::Win32::System::Diagnostics::Debug::EXCEPTION_POINTERS, + exc_info: *mut host_faulthandler::ExceptionPointers, ) -> i32 { const EXCEPTION_CONTINUE_SEARCH: i32 = 0; @@ -580,8 +455,7 @@ mod decl { return EXCEPTION_CONTINUE_SEARCH; } - let record = unsafe { &*(*exc_info).ExceptionRecord }; - let code = record.ExceptionCode as u32; + let code = unsafe { host_faulthandler::exception_code(exc_info) }; if faulthandler_ignore_exception(code) { 
return EXCEPTION_CONTINUE_SEARCH; @@ -590,32 +464,17 @@ mod decl { let fd = FATAL_ERROR.fd.load(Ordering::Relaxed); puts(fd, "Windows fatal exception: "); - match code { - 0xC0000005 => puts(fd, "access violation"), - 0xC000008C => puts(fd, "float divide by zero"), - 0xC0000091 => puts(fd, "float overflow"), - 0xC0000094 => puts(fd, "int divide by zero"), - 0xC0000095 => puts(fd, "integer overflow"), - 0xC0000006 => puts(fd, "page error"), - 0xC00000FD => puts(fd, "stack overflow"), - 0xC000001D => puts(fd, "illegal instruction"), - _ => { - puts(fd, "code "); - dump_hexadecimal(fd, code as u64, 8); - } + if let Some(description) = host_faulthandler::exception_description(code) { + puts(fd, description); + } else { + puts(fd, "code "); + dump_hexadecimal(fd, code as u64, 8); } puts(fd, "\n\n"); // Disable SIGSEGV handler for access violations to avoid double output - if code == 0xC0000005 { - unsafe { - for handler in &mut FAULTHANDLER_HANDLERS { - if handler.signum == libc::SIGSEGV { - faulthandler_disable_fatal_handler(handler); - break; - } - } - } + if host_faulthandler::is_access_violation(code) { + host_faulthandler::disable_fatal_signal(libc::SIGSEGV); } let all_threads = FATAL_ERROR.all_threads.load(Ordering::Relaxed); @@ -631,23 +490,8 @@ mod decl { return true; } - unsafe { - for handler in &mut FAULTHANDLER_HANDLERS { - if handler.enabled { - continue; - } - - let mut action: libc::sigaction = core::mem::zeroed(); - action.sa_sigaction = faulthandler_fatal_error as *const () as libc::sighandler_t; - // SA_NODEFER flag - action.sa_flags = libc::SA_NODEFER; - - if libc::sigaction(handler.signum, &action, &mut handler.previous) != 0 { - return false; - } - - handler.enabled = true; - } + if !host_faulthandler::enable_fatal_handlers(faulthandler_fatal_error, libc::SA_NODEFER) { + return false; } FATAL_ERROR.enabled.store(true, Ordering::Relaxed); @@ -660,31 +504,15 @@ mod decl { return true; } - unsafe { - for handler in &mut FAULTHANDLER_HANDLERS { - if 
handler.enabled { - continue; - } - - handler.previous = libc::signal( - handler.signum, - faulthandler_fatal_error as *const () as libc::sighandler_t, - ); - - // SIG_ERR is -1 as sighandler_t (which is usize on Windows) - if handler.previous == libc::SIG_ERR as libc::sighandler_t { - return false; - } - - handler.enabled = true; - } + if !host_faulthandler::enable_fatal_handlers(faulthandler_fatal_error, 0) { + return false; } // Register Windows vectored exception handler #[cfg(windows)] { - use windows_sys::Win32::System::Diagnostics::Debug::AddVectoredExceptionHandler; - let h = unsafe { AddVectoredExceptionHandler(1, Some(faulthandler_exc_handler)) }; + let h = + host_faulthandler::add_vectored_exception_handler(Some(faulthandler_exc_handler)); EXC_HANDLER.store(h as usize, Ordering::Relaxed); } @@ -699,22 +527,13 @@ mod decl { return; } - unsafe { - for handler in &mut FAULTHANDLER_HANDLERS { - faulthandler_disable_fatal_handler(handler); - } - } + host_faulthandler::disable_fatal_handlers(); // Remove Windows vectored exception handler #[cfg(windows)] { - use windows_sys::Win32::System::Diagnostics::Debug::RemoveVectoredExceptionHandler; let h = EXC_HANDLER.swap(0, Ordering::Relaxed); - if h != 0 { - unsafe { - RemoveVectoredExceptionHandler(h as *mut core::ffi::c_void); - } - } + host_faulthandler::remove_vectored_exception_handler(h); } } @@ -944,112 +763,27 @@ mod decl { } } - #[cfg(unix)] - mod user_signals { - use parking_lot::Mutex; - - const NSIG: usize = 64; - - #[derive(Clone, Copy)] - pub(super) struct UserSignal { - pub enabled: bool, - pub fd: i32, - pub all_threads: bool, - pub chain: bool, - pub previous: libc::sigaction, - } - - impl Default for UserSignal { - fn default() -> Self { - Self { - enabled: false, - fd: 2, // stderr - all_threads: true, - chain: false, - // SAFETY: sigaction is a C struct that can be zero-initialized - previous: unsafe { core::mem::zeroed() }, - } - } - } - - static USER_SIGNALS: Mutex>> = Mutex::new(None); - - 
pub(super) fn get_user_signal(signum: usize) -> Option { - let guard = USER_SIGNALS.lock(); - guard.as_ref().and_then(|v| v.get(signum).copied()) - } - - pub(super) fn set_user_signal(signum: usize, signal: UserSignal) { - let mut guard = USER_SIGNALS.lock(); - if guard.is_none() { - *guard = Some(vec![UserSignal::default(); NSIG]); - } - if let Some(ref mut v) = *guard - && signum < v.len() - { - v[signum] = signal; - } - } - - pub(super) fn clear_user_signal(signum: usize) -> Option { - let mut guard = USER_SIGNALS.lock(); - if let Some(ref mut v) = *guard - && signum < v.len() - && v[signum].enabled - { - let old = v[signum]; - v[signum] = UserSignal::default(); - return Some(old); - } - None - } - - pub(super) fn is_enabled(signum: usize) -> bool { - let guard = USER_SIGNALS.lock(); - guard - .as_ref() - .and_then(|v| v.get(signum)) - .is_some_and(|s| s.enabled) - } - } - #[cfg(unix)] extern "C" fn faulthandler_user_signal(signum: libc::c_int) { let save_errno = get_errno(); - let user = match user_signals::get_user_signal(signum as usize) { - Some(u) if u.enabled => u, + let user = match host_faulthandler::get_user_signal(signum as usize) { + Some(u) => u, _ => return, }; faulthandler_dump_traceback(user.fd, user.all_threads); if user.chain { - // Restore the previous handler and re-raise - unsafe { - libc::sigaction(signum, &user.previous, core::ptr::null_mut()); - } set_errno(save_errno); - unsafe { - libc::raise(signum); - } - // Re-install our handler with the same flags as register() - let save_errno2 = get_errno(); - unsafe { - let mut action: libc::sigaction = core::mem::zeroed(); - action.sa_sigaction = faulthandler_user_signal as *const () as libc::sighandler_t; - action.sa_flags = libc::SA_NODEFER; - libc::sigaction(signum, &action, core::ptr::null_mut()); - } - set_errno(save_errno2); + let _ = host_faulthandler::reraise_user_signal(signum, faulthandler_user_signal); } } #[cfg(unix)] fn check_signum(signum: i32, vm: &VirtualMachine) -> PyResult<()> 
{ // Check if it's a fatal signal (faulthandler.c uses faulthandler_handlers array) - let is_fatal = unsafe { FAULTHANDLER_HANDLERS.iter().any(|h| h.signum == signum) }; - if is_fatal { + if host_faulthandler::is_fatal_signal(signum) { return Err(vm.new_runtime_error(format!( "signal {signum} cannot be registered, use enable() instead" ))); @@ -1084,46 +818,19 @@ mod decl { let fd = get_fd_from_file_opt(args.file, vm)?; - let signum = args.signum as usize; - - // Get current handler to save as previous - let previous = if !user_signals::is_enabled(signum) { - unsafe { - let mut action: libc::sigaction = core::mem::zeroed(); - action.sa_sigaction = faulthandler_user_signal as *const () as libc::sighandler_t; - // SA_RESTART by default; SA_NODEFER only when chaining - // (faulthandler.c:860-864) - action.sa_flags = if args.chain { - libc::SA_NODEFER - } else { - libc::SA_RESTART - }; - - let mut prev: libc::sigaction = core::mem::zeroed(); - if libc::sigaction(args.signum, &action, &mut prev) != 0 { - return Err(vm.new_os_error(format!( - "Failed to register signal handler for signal {}", - args.signum - ))); - } - prev - } - } else { - // Already registered, keep previous handler - user_signals::get_user_signal(signum) - .map_or(unsafe { core::mem::zeroed() }, |u| u.previous) - }; - - user_signals::set_user_signal( - signum, - user_signals::UserSignal { - enabled: true, - fd, - all_threads: args.all_threads, - chain: args.chain, - previous, - }, - ); + host_faulthandler::register_user_signal( + args.signum, + fd, + args.all_threads, + args.chain, + faulthandler_user_signal, + ) + .map_err(|_| { + vm.new_os_error(format!( + "Failed to register signal handler for signal {}", + args.signum + )) + })?; Ok(()) } @@ -1132,16 +839,7 @@ mod decl { #[pyfunction] fn unregister(signum: i32, vm: &VirtualMachine) -> PyResult { check_signum(signum, vm)?; - - if let Some(old) = user_signals::clear_user_signal(signum as usize) { - // Restore previous handler - unsafe { - 
libc::sigaction(signum, &old.previous, core::ptr::null_mut()); - } - Ok(true) - } else { - Ok(false) - } + Ok(host_faulthandler::unregister_user_signal(signum)) } // Test functions for faulthandler testing @@ -1188,10 +886,7 @@ mod decl { #[cfg(not(target_arch = "wasm32"))] { suppress_crash_report(); - - unsafe { - libc::abort(); - } + host_faulthandler::abort_process(); } } @@ -1200,10 +895,7 @@ mod decl { #[cfg(not(target_arch = "wasm32"))] { suppress_crash_report(); - - unsafe { - libc::raise(libc::SIGFPE); - } + host_faulthandler::raise_signal(libc::SIGFPE); } } @@ -1226,28 +918,14 @@ mod decl { fn suppress_crash_report() { #[cfg(windows)] { - use windows_sys::Win32::System::Diagnostics::Debug::{ - SEM_NOGPFAULTERRORBOX, SetErrorMode, - }; - unsafe { - let mode = SetErrorMode(SEM_NOGPFAULTERRORBOX); - SetErrorMode(mode | SEM_NOGPFAULTERRORBOX); - } + host_faulthandler::suppress_crash_report(); } #[cfg(unix)] { - // Disable core dumps #[cfg(not(any(target_os = "redox", target_os = "wasi")))] { - use libc::{RLIMIT_CORE, rlimit, setrlimit}; - let rl = rlimit { - rlim_cur: 0, - rlim_max: 0, - }; - unsafe { - let _ = setrlimit(RLIMIT_CORE, &rl); - } + rustpython_host_env::resource::disable_core_dumps(); } } } @@ -1285,11 +963,7 @@ mod decl { #[cfg(windows)] #[pyfunction] fn _raise_exception(args: RaiseExceptionArgs, _vm: &VirtualMachine) { - use windows_sys::Win32::System::Diagnostics::Debug::RaiseException; - suppress_crash_report(); - unsafe { - RaiseException(args.code, args.flags, 0, core::ptr::null()); - } + host_faulthandler::raise_exception(args.code, args.flags); } } diff --git a/crates/stdlib/src/fcntl.rs b/crates/stdlib/src/fcntl.rs index 6d27b3cee57..5081c9e9c14 100644 --- a/crates/stdlib/src/fcntl.rs +++ b/crates/stdlib/src/fcntl.rs @@ -9,6 +9,7 @@ mod fcntl { use crate::vm::{ PyResult, VirtualMachine, builtins::PyIntRef, + convert::ToPyException, function::{ArgMemoryBuffer, ArgStrOrBytesLike, Either, OptionalArg}, stdlib::_io, }; @@ -95,10 +96,7 @@ mod 
fcntl { mutate_flag: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - // Convert to unsigned - handles both positive u32 values and negative i32 values - // that represent the same bit pattern (e.g., TIOCSWINSZ on some platforms). - // First truncate to u32 (takes lower 32 bits), then zero-extend to c_ulong. - let request = (request as u32) as libc::c_ulong; + let request = host_fcntl::normalize_ioctl_request(request); let arg = arg.unwrap_or_else(|| Either::B(0)); match arg { Either::A(buf_kind) => { @@ -158,39 +156,20 @@ mod fcntl { whence: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - macro_rules! try_into_l_type { - ($l_type:path) => { - $l_type - .try_into() - .map_err(|e| vm.new_overflow_error(format!("{e}"))) - }; - } - - let mut l: libc::flock = unsafe { core::mem::zeroed() }; - l.l_type = if cmd == libc::LOCK_UN { - try_into_l_type!(libc::F_UNLCK) - } else if (cmd & libc::LOCK_SH) != 0 { - try_into_l_type!(libc::F_RDLCK) - } else if (cmd & libc::LOCK_EX) != 0 { - try_into_l_type!(libc::F_WRLCK) - } else { - return Err(vm.new_value_error("unrecognized lockf argument")); - }?; - l.l_start = match start { + let start = match start { OptionalArg::Present(s) => s.try_to_primitive(vm)?, OptionalArg::Missing => 0, }; - l.l_len = match len { + let len = match len { OptionalArg::Present(l_) => l_.try_to_primitive(vm)?, OptionalArg::Missing => 0, }; - l.l_whence = match whence { - OptionalArg::Present(w) => w - .try_into() - .map_err(|e| vm.new_overflow_error(format!("{e}")))?, + let whence = match whence { + OptionalArg::Present(w) => w, OptionalArg::Missing => 0, }; - let ret = host_fcntl::lockf(fd, cmd, &l).map_err(|_| vm.new_last_errno_error())?; + let ret = + host_fcntl::lockf(fd, cmd, len, start, whence).map_err(|err| err.to_pyexception(vm))?; Ok(vm.ctx.new_int(ret).into()) } } diff --git a/crates/stdlib/src/grp.rs b/crates/stdlib/src/grp.rs index 70aa7d4e4c6..34e9929d2c4 100644 --- a/crates/stdlib/src/grp.rs +++ b/crates/stdlib/src/grp.rs @@ -10,8 
+10,7 @@ mod grp { exceptions, types::PyStructSequence, }; - use core::ptr::NonNull; - use nix::unistd; + use rustpython_host_env::grp as host_grp; #[pystruct_sequence_data] struct GroupData { @@ -29,15 +28,11 @@ mod grp { impl PyGroup {} impl GroupData { - fn from_unistd_group(group: unistd::Group, vm: &VirtualMachine) -> Self { - let cstr_lossy = |s: alloc::ffi::CString| { - s.into_string() - .unwrap_or_else(|e| e.into_cstring().to_string_lossy().into_owned()) - }; + fn from_group(group: host_grp::Group, vm: &VirtualMachine) -> Self { Self { gr_name: group.name, - gr_passwd: cstr_lossy(group.passwd), - gr_gid: group.gid.as_raw(), + gr_passwd: group.passwd, + gr_gid: group.gid, gr_mem: vm .ctx .new_list(group.mem.iter().map(|s| s.to_pyobject(vm)).collect()), @@ -48,11 +43,9 @@ mod grp { #[pyfunction] fn getgrgid(gid: PyIntRef, vm: &VirtualMachine) -> PyResult { let gr_gid = gid.as_bigint(); - let gid = libc::gid_t::try_from(gr_gid) - .map(unistd::Gid::from_raw) - .ok(); + let gid = libc::gid_t::try_from(gr_gid).ok(); let group = gid - .map(unistd::Group::from_gid) + .map(host_grp::getgrgid) .transpose() .map_err(|err| err.into_pyexception(vm))? .flatten(); @@ -63,7 +56,7 @@ mod grp { .into(), ) })?; - Ok(GroupData::from_unistd_group(group, vm)) + Ok(GroupData::from_group(group, vm)) } #[pyfunction] @@ -72,7 +65,7 @@ mod grp { if gr_name.contains('\0') { return Err(exceptions::cstring_error(vm)); } - let group = unistd::Group::from_name(gr_name).map_err(|err| err.into_pyexception(vm))?; + let group = host_grp::getgrnam(gr_name).map_err(|err| err.into_pyexception(vm))?; let group = group.ok_or_else(|| { vm.new_key_error( vm.ctx @@ -80,24 +73,14 @@ mod grp { .into(), ) })?; - Ok(GroupData::from_unistd_group(group, vm)) + Ok(GroupData::from_group(group, vm)) } #[pyfunction] fn getgrall(vm: &VirtualMachine) -> Vec { - // setgrent, getgrent, etc are not thread safe. 
Could use fgetgrent_r, but this is easier - static GETGRALL: parking_lot::Mutex<()> = parking_lot::Mutex::new(()); - let _guard = GETGRALL.lock(); - let mut list = Vec::new(); - - unsafe { libc::setgrent() }; - while let Some(ptr) = NonNull::new(unsafe { libc::getgrent() }) { - let group = unistd::Group::from(unsafe { ptr.as_ref() }); - let group = GroupData::from_unistd_group(group, vm).to_pyobject(vm); - list.push(group); - } - unsafe { libc::endgrent() }; - - list + host_grp::getgrall() + .into_iter() + .map(|group| GroupData::from_group(group, vm).to_pyobject(vm)) + .collect() } } diff --git a/crates/stdlib/src/json.rs b/crates/stdlib/src/json.rs index a32397ad59d..dc2fbbc8892 100644 --- a/crates/stdlib/src/json.rs +++ b/crates/stdlib/src/json.rs @@ -189,30 +189,51 @@ mod _json { fn parse_number(&self, bytes: &[u8], vm: &VirtualMachine) -> Option<(PyResult, usize)> { flame_guard!("JsonScanner::parse_number"); - let mut has_neg = false; - let mut has_decimal = false; - let mut has_exponent = false; - let mut has_e_sign = false; + // RFC 8259 defines JSON numbers in ASCII syntax, including digits, + // '-', '.', 'e'/'E', and an optional exponent sign, so byte iteration + // is equivalent to char iteration here. let mut i = 0; - // JSON numbers are ASCII per RFC 8259 (digits, '-', '+', '.', 'e', 'E'), - // so byte iteration is equivalent to char iteration here. - for &b in bytes { - match b { - b'-' if i == 0 => has_neg = true, - b'0'..=b'9' => {} - b'.' 
if !has_decimal => has_decimal = true, - b'e' | b'E' if !has_exponent => has_exponent = true, - b'+' | b'-' if !has_e_sign => has_e_sign = true, - _ => break, - } + if bytes.get(i) == Some(&b'-') { i += 1; } - if i == 0 || (i == 1 && has_neg) { - return None; + match bytes.get(i) { + Some(b'0') => i += 1, + Some(b'1'..=b'9') => { + i += 1; + while matches!(bytes.get(i), Some(b'0'..=b'9')) { + i += 1; + } + } + _ => return None, + } + + let mut is_float = false; + if bytes.get(i) == Some(&b'.') && matches!(bytes.get(i + 1), Some(b'0'..=b'9')) { + is_float = true; + i += 2; + while matches!(bytes.get(i), Some(b'0'..=b'9')) { + i += 1; + } + } + + if matches!(bytes.get(i), Some(b'e' | b'E')) { + let mut exponent_end = i + 1; + if matches!(bytes.get(exponent_end), Some(b'+' | b'-')) { + exponent_end += 1; + } + if matches!(bytes.get(exponent_end), Some(b'0'..=b'9')) { + is_float = true; + exponent_end += 1; + while matches!(bytes.get(exponent_end), Some(b'0'..=b'9')) { + exponent_end += 1; + } + i = exponent_end; + } } + // SAFETY: the loop above accepts only ASCII bytes, so bytes[..i] is valid UTF-8. 
let buf = unsafe { core::str::from_utf8_unchecked(&bytes[..i]) }; - let ret = if has_decimal || has_exponent { + let ret = if is_float { // float if let Some(ref parse_float) = self.parse_float { parse_float.call((buf,), vm) diff --git a/crates/stdlib/src/json/machinery.rs b/crates/stdlib/src/json/machinery.rs index 30882588e34..2acd6c8cb71 100644 --- a/crates/stdlib/src/json/machinery.rs +++ b/crates/stdlib/src/json/machinery.rs @@ -44,10 +44,6 @@ static ESCAPE_CHARS: [&str; 0x20] = [ // And which one need to be escaped (1) // The characters that need escaping are 0x00 to 0x1F, 0x22 ("), 0x5C (\), 0x7F (DEL) // Non-ASCII unicode characters can be safely included in a JSON string -#[allow( - clippy::unusual_byte_groupings, - reason = "groups of 16 are intentional here" -)] static NEEDS_ESCAPING_BITSET: [u64; 4] = [ //fedcba9876543210_fedcba9876543210_fedcba9876543210_fedcba9876543210 0b0000000000000000_0000000000000100_1111111111111111_1111111111111111, // 3_2_1_0 diff --git a/crates/stdlib/src/lib.rs b/crates/stdlib/src/lib.rs index 4670b07f06c..b1ae53af853 100644 --- a/crates/stdlib/src/lib.rs +++ b/crates/stdlib/src/lib.rs @@ -128,11 +128,11 @@ mod openssl; #[cfg(all( feature = "host_env", not(target_arch = "wasm32"), - feature = "ssl-rustls" + feature = "__ssl-rustls" ))] -mod ssl; +pub mod ssl; -#[cfg(all(feature = "ssl-openssl", feature = "ssl-rustls", not(clippy)))] +#[cfg(all(feature = "ssl-openssl", feature = "__ssl-rustls", not(clippy)))] compile_error!(r#"features "ssl-openssl" and "ssl-rustls" are mutually exclusive"#); #[cfg(all( @@ -246,7 +246,7 @@ pub fn stdlib_module_defs(ctx: &Context) -> Vec<&'static builtins::PyModuleDef> #[cfg(all( feature = "host_env", not(target_arch = "wasm32"), - feature = "ssl-rustls" + feature = "__ssl-rustls" ))] ssl::module_def(ctx), statistics::module_def(ctx), diff --git a/crates/stdlib/src/locale.rs b/crates/stdlib/src/locale.rs index 251c9586e18..74e9053fbfb 100644 --- a/crates/stdlib/src/locale.rs +++ 
b/crates/stdlib/src/locale.rs @@ -2,55 +2,16 @@ pub(crate) use _locale::module_def; -#[cfg(windows)] -#[repr(C)] -struct lconv { - decimal_point: *mut libc::c_char, - thousands_sep: *mut libc::c_char, - grouping: *mut libc::c_char, - int_curr_symbol: *mut libc::c_char, - currency_symbol: *mut libc::c_char, - mon_decimal_point: *mut libc::c_char, - mon_thousands_sep: *mut libc::c_char, - mon_grouping: *mut libc::c_char, - positive_sign: *mut libc::c_char, - negative_sign: *mut libc::c_char, - int_frac_digits: libc::c_char, - frac_digits: libc::c_char, - p_cs_precedes: libc::c_char, - p_sep_by_space: libc::c_char, - n_cs_precedes: libc::c_char, - n_sep_by_space: libc::c_char, - p_sign_posn: libc::c_char, - n_sign_posn: libc::c_char, - int_p_cs_precedes: libc::c_char, - int_p_sep_by_space: libc::c_char, - int_n_cs_precedes: libc::c_char, - int_n_sep_by_space: libc::c_char, - int_p_sign_posn: libc::c_char, - int_n_sign_posn: libc::c_char, -} - -#[cfg(windows)] -unsafe extern "C" { - fn localeconv() -> *mut lconv; -} - -#[cfg(unix)] -use libc::localeconv; - #[pymodule] mod _locale { use alloc::ffi::CString; - use core::{ffi::CStr, ptr}; + use rustpython_host_env::locale as host_locale; use rustpython_vm::{ PyObjectRef, PyResult, VirtualMachine, builtins::{PyDictRef, PyIntRef, PyListRef, PyTypeRef, PyUtf8StrRef}, convert::ToPyException, function::OptionalArg, }; - #[cfg(windows)] - use windows_sys::Win32::Globalization::GetACP; #[cfg(all( unix, @@ -78,19 +39,11 @@ mod _locale { vm.ctx.new_int(libc::c_char::MAX) } - unsafe fn copy_grouping(group: *const libc::c_char, vm: &VirtualMachine) -> PyListRef { + fn copy_grouping(group: &[libc::c_char], vm: &VirtualMachine) -> PyListRef { let mut group_vec: Vec = Vec::new(); - if group.is_null() { - return vm.ctx.new_list(group_vec); - } - - unsafe { - let mut ptr = group; - while ![0, libc::c_char::MAX].contains(&*ptr) { - let val = vm.ctx.new_int(*ptr); - group_vec.push(val.into()); - ptr = ptr.add(1); - } + for &value in group 
{ + let val = vm.ctx.new_int(value); + group_vec.push(val.into()); } // https://github.com/python/cpython/blob/677320348728ce058fa3579017e985af74a236d4/Modules/_localemodule.c#L80 if !group_vec.is_empty() { @@ -99,47 +52,21 @@ mod _locale { vm.ctx.new_list(group_vec) } - unsafe fn pystr_from_raw_cstr( - vm: &VirtualMachine, - raw_ptr: *const libc::c_char, - ) -> PyObjectRef { - let slice = unsafe { CStr::from_ptr(raw_ptr) }; - + fn pystr_from_bytes(vm: &VirtualMachine, bytes: &[u8]) -> PyObjectRef { // Fast path: ASCII/UTF-8 - if let Ok(s) = slice.to_str() { + if let Ok(s) = core::str::from_utf8(bytes) { return vm.new_pyobj(s); } // On Windows, locale strings use the ANSI code page encoding #[cfg(windows)] { - use windows_sys::Win32::Globalization::{CP_ACP, MultiByteToWideChar}; - let bytes = slice.to_bytes(); - unsafe { - let len = MultiByteToWideChar( - CP_ACP, - 0, - bytes.as_ptr(), - bytes.len() as i32, - ptr::null_mut(), - 0, - ); - if len > 0 { - let mut wide = vec![0u16; len as usize]; - MultiByteToWideChar( - CP_ACP, - 0, - bytes.as_ptr(), - bytes.len() as i32, - wide.as_mut_ptr(), - len, - ); - return vm.new_pyobj(String::from_utf16_lossy(&wide)); - } + if let Some(decoded) = host_locale::decode_ansi_bytes(bytes) { + return vm.new_pyobj(decoded); } } - vm.new_pyobj(String::from_utf8_lossy(slice.to_bytes()).into_owned()) + vm.new_pyobj(String::from_utf8_lossy(bytes).into_owned()) } #[pyattr(name = "Error", once)] @@ -155,73 +82,59 @@ mod _locale { fn strcoll(string1: PyUtf8StrRef, string2: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult { let cstr1 = CString::new(string1.as_str()).map_err(|e| e.to_pyexception(vm))?; let cstr2 = CString::new(string2.as_str()).map_err(|e| e.to_pyexception(vm))?; - Ok(vm.new_pyobj(unsafe { libc::strcoll(cstr1.as_ptr(), cstr2.as_ptr()) })) + Ok(vm.new_pyobj(host_locale::strcoll(&cstr1, &cstr2))) } #[pyfunction] fn strxfrm(string: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult { // 
https://github.com/python/cpython/blob/eaae563b6878aa050b4ad406b67728b6b066220e/Modules/_localemodule.c#L390-L442 let n1 = string.byte_len() + 1; - let mut buff = vec![0u8; n1]; - let cstr = CString::new(string.as_str()).map_err(|e| e.to_pyexception(vm))?; - let n2 = unsafe { libc::strxfrm(buff.as_mut_ptr() as _, cstr.as_ptr(), n1) }; - buff = vec![0u8; n2 + 1]; - unsafe { - libc::strxfrm(buff.as_mut_ptr() as _, cstr.as_ptr(), n2 + 1); - } + let buff = host_locale::strxfrm(&cstr, n1); Ok(vm.new_pyobj(String::from_utf8(buff).expect("strxfrm returned invalid utf-8 string"))) } #[pyfunction] fn localeconv(vm: &VirtualMachine) -> PyResult { let result = vm.ctx.new_dict(); + let lc = host_locale::localeconv_data(); - unsafe { - macro_rules! set_string_field { - ($lc:expr, $field:ident) => {{ - result.set_item( - stringify!($field), - pystr_from_raw_cstr(vm, (*$lc).$field), - vm, - )? - }}; - } - - macro_rules! set_int_field { - ($lc:expr, $field:ident) => {{ result.set_item(stringify!($field), vm.new_pyobj((*$lc).$field), vm)? }}; - } + macro_rules! set_string_field { + ($lc:expr, $field:ident) => {{ result.set_item(stringify!($field), pystr_from_bytes(vm, &$lc.$field), vm)? }}; + } - macro_rules! set_group_field { - ($lc:expr, $field:ident) => {{ - result.set_item( - stringify!($field), - copy_grouping((*$lc).$field, vm).into(), - vm, - )? - }}; - } + macro_rules! set_int_field { + ($lc:expr, $field:ident) => {{ result.set_item(stringify!($field), vm.new_pyobj($lc.$field), vm)? 
}}; + } - let lc = super::localeconv(); - set_group_field!(lc, mon_grouping); - set_group_field!(lc, grouping); - set_int_field!(lc, int_frac_digits); - set_int_field!(lc, frac_digits); - set_int_field!(lc, p_cs_precedes); - set_int_field!(lc, p_sep_by_space); - set_int_field!(lc, n_cs_precedes); - set_int_field!(lc, p_sign_posn); - set_int_field!(lc, n_sign_posn); - set_string_field!(lc, decimal_point); - set_string_field!(lc, thousands_sep); - set_string_field!(lc, int_curr_symbol); - set_string_field!(lc, currency_symbol); - set_string_field!(lc, mon_decimal_point); - set_string_field!(lc, mon_thousands_sep); - set_int_field!(lc, n_sep_by_space); - set_string_field!(lc, positive_sign); - set_string_field!(lc, negative_sign); + macro_rules! set_group_field { + ($lc:expr, $field:ident) => {{ + result.set_item( + stringify!($field), + copy_grouping(&$lc.$field, vm).into(), + vm, + )? + }}; } + + set_group_field!(lc, mon_grouping); + set_group_field!(lc, grouping); + set_int_field!(lc, int_frac_digits); + set_int_field!(lc, frac_digits); + set_int_field!(lc, p_cs_precedes); + set_int_field!(lc, p_sep_by_space); + set_int_field!(lc, n_cs_precedes); + set_int_field!(lc, p_sign_posn); + set_int_field!(lc, n_sign_posn); + set_string_field!(lc, decimal_point); + set_string_field!(lc, thousands_sep); + set_string_field!(lc, int_curr_symbol); + set_string_field!(lc, currency_symbol); + set_string_field!(lc, mon_decimal_point); + set_string_field!(lc, mon_thousands_sep); + set_int_field!(lc, n_sep_by_space); + set_string_field!(lc, positive_sign); + set_string_field!(lc, negative_sign); Ok(result) } @@ -268,37 +181,32 @@ mod _locale { return Err(vm.new_exception_msg(error, "unsupported locale setting".into())); } - unsafe { - let result = match args.locale.flatten() { - None => libc::setlocale(args.category, ptr::null()), - Some(locale) => { - let locale_str = locale.as_str(); - // On Windows, validate encoding name length - #[cfg(windows)] - { - let valid = if 
args.category == LC_ALL { - check_locale_name_all(locale_str) - } else { - check_locale_name(locale_str) - }; - if !valid { - return Err( - vm.new_exception_msg(error, "unsupported locale setting".into()) - ); - } + let result = match args.locale.flatten() { + None => host_locale::setlocale(args.category, None), + Some(locale) => { + let locale_str = locale.as_str(); + #[cfg(windows)] + { + let valid = if args.category == LC_ALL { + check_locale_name_all(locale_str) + } else { + check_locale_name(locale_str) + }; + if !valid { + return Err( + vm.new_exception_msg(error, "unsupported locale setting".into()) + ); } - let c_locale: CString = - CString::new(locale_str).map_err(|e| e.to_pyexception(vm))?; - libc::setlocale(args.category, c_locale.as_ptr()) } - }; - - if result.is_null() { - return Err(vm.new_exception_msg(error, "unsupported locale setting".into())); + let c_locale: CString = + CString::new(locale_str).map_err(|e| e.to_pyexception(vm))?; + host_locale::setlocale(args.category, Some(&c_locale)) } - - Ok(pystr_from_raw_cstr(vm, result)) - } + }; + let Some(result) = result else { + return Err(vm.new_exception_msg(error, "unsupported locale setting".into())); + }; + Ok(pystr_from_bytes(vm, &result)) } /// Get the current locale encoding. 
@@ -306,26 +214,21 @@ mod _locale { fn getencoding() -> String { #[cfg(windows)] { - // On Windows, use GetACP() to get the ANSI code page - let acp = unsafe { GetACP() }; + let acp = host_locale::acp(); format!("cp{acp}") } #[cfg(not(windows))] { - // On Unix, use nl_langinfo(CODESET) or fallback to UTF-8 #[cfg(all( unix, not(any(target_os = "ios", target_os = "android", target_os = "redox")) ))] { - unsafe { - let codeset = libc::nl_langinfo(libc::CODESET); - if !codeset.is_null() - && let Ok(s) = CStr::from_ptr(codeset).to_str() - && !s.is_empty() - { - return s.to_string(); - } + if let Some(codeset) = host_locale::nl_langinfo_codeset() + && let Ok(s) = core::str::from_utf8(&codeset) + && !s.is_empty() + { + return s.to_string(); } "UTF-8".to_string() } diff --git a/crates/stdlib/src/mmap.rs b/crates/stdlib/src/mmap.rs index 9cb84b4efa6..2332ee0e1ce 100644 --- a/crates/stdlib/src/mmap.rs +++ b/crates/stdlib/src/mmap.rs @@ -23,67 +23,18 @@ mod mmap { }; use core::ops::{Deref, DerefMut}; use crossbeam_utils::atomic::AtomicCell; - use memmap2::{Mmap, MmapMut, MmapOptions}; use num_traits::Signed; - use std::io::{self, Write}; + #[cfg(windows)] + use std::io; + use std::io::Write; - #[cfg(unix)] - use nix::{sys::stat::fstat, unistd}; #[cfg(unix)] use rustpython_host_env::crt_fd; + #[cfg(any(unix, windows))] + use rustpython_host_env::mmap as host_mmap; #[cfg(windows)] - use rustpython_host_env::suppress_iph; - #[cfg(windows)] - use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}; - #[cfg(windows)] - use windows_sys::Win32::{ - Foundation::{ - CloseHandle, DUPLICATE_SAME_ACCESS, DuplicateHandle, HANDLE, INVALID_HANDLE_VALUE, - }, - Storage::FileSystem::{FILE_BEGIN, GetFileSize, SetEndOfFile, SetFilePointerEx}, - System::Memory::{ - CreateFileMappingW, FILE_MAP_COPY, FILE_MAP_READ, FILE_MAP_WRITE, FlushViewOfFile, - MapViewOfFile, PAGE_READONLY, PAGE_READWRITE, PAGE_WRITECOPY, UnmapViewOfFile, - }, - System::Threading::GetCurrentProcess, - 
}; - - #[cfg(unix)] - fn validate_advice(vm: &VirtualMachine, advice: i32) -> PyResult { - match advice { - libc::MADV_NORMAL - | libc::MADV_RANDOM - | libc::MADV_SEQUENTIAL - | libc::MADV_WILLNEED - | libc::MADV_DONTNEED => Ok(advice), - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd" - ))] - libc::MADV_FREE => Ok(advice), - #[cfg(target_os = "linux")] - libc::MADV_DONTFORK - | libc::MADV_DOFORK - | libc::MADV_MERGEABLE - | libc::MADV_UNMERGEABLE - | libc::MADV_HUGEPAGE - | libc::MADV_NOHUGEPAGE - | libc::MADV_REMOVE - | libc::MADV_DONTDUMP - | libc::MADV_DODUMP - | libc::MADV_HWPOISON => Ok(advice), - #[cfg(target_os = "freebsd")] - libc::MADV_NOSYNC - | libc::MADV_AUTOSYNC - | libc::MADV_NOCORE - | libc::MADV_CORE - | libc::MADV_PROTECT => Ok(advice), - _ => Err(vm.new_value_error("Not a valid Advice value")), - } - } + use rustpython_host_env::nt as host_nt; #[repr(C)] #[derive(PartialEq, Eq, Debug)] @@ -180,16 +131,15 @@ mod mmap { #[pyattr] const ACCESS_COPY: u32 = AccessMode::Copy as u32; - #[cfg(not(target_arch = "wasm32"))] #[pyattr(name = "PAGESIZE", once)] fn page_size(_vm: &VirtualMachine) -> usize { - page_size::get() + rustpython_host_env::os::page_size() } #[cfg(not(target_arch = "wasm32"))] #[pyattr(name = "ALLOCATIONGRANULARITY", once)] fn granularity(_vm: &VirtualMachine) -> usize { - page_size::get_granularity() + rustpython_host_env::os::alloc_granularity() } #[pyattr(name = "error", once)] @@ -197,68 +147,19 @@ mod mmap { vm.ctx.exceptions.os_error.to_owned() } - /// Named file mapping on Windows using raw Win32 APIs. - /// Supports tagname parameter for inter-process shared memory. - #[cfg(windows)] - struct NamedMmap { - map_handle: HANDLE, - view_ptr: *mut u8, - len: usize, - } - - #[cfg(windows)] - // SAFETY: The memory mapping is managed by the OS and is safe to share - // across threads. Access is synchronized by PyMutex in PyMmap. 
- unsafe impl Send for NamedMmap {} - #[cfg(windows)] - unsafe impl Sync for NamedMmap {} - - #[cfg(windows)] - impl core::fmt::Debug for NamedMmap { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("NamedMmap") - .field("map_handle", &self.map_handle) - .field("view_ptr", &self.view_ptr) - .field("len", &self.len) - .finish() - } - } - - #[cfg(windows)] - impl Drop for NamedMmap { - fn drop(&mut self) { - unsafe { - if !self.view_ptr.is_null() { - UnmapViewOfFile( - windows_sys::Win32::System::Memory::MEMORY_MAPPED_VIEW_ADDRESS { - Value: self.view_ptr as *mut _, - }, - ); - } - if !self.map_handle.is_null() { - CloseHandle(self.map_handle); - } - } - } - } - #[derive(Debug)] enum MmapObj { - Write(MmapMut), - Read(Mmap), + Mapped(host_mmap::MappedFile), #[cfg(windows)] - Named(NamedMmap), + Named(host_mmap::NamedMmap), } impl MmapObj { fn as_slice(&self) -> &[u8] { match self { - Self::Read(mmap) => &mmap[..], - Self::Write(mmap) => &mmap[..], + Self::Mapped(mmap) => mmap.as_slice(), #[cfg(windows)] - Self::Named(named) => unsafe { - core::slice::from_raw_parts(named.view_ptr, named.len) - }, + Self::Named(named) => named.as_slice(), } } } @@ -272,7 +173,7 @@ mod mmap { #[cfg(unix)] fd: AtomicCell, #[cfg(windows)] - handle: AtomicCell, // HANDLE is isize on Windows + handle: AtomicCell, // host_mmap::Handle is isize on Windows offset: i64, size: AtomicCell, pos: AtomicCell, // relative to offset @@ -286,15 +187,13 @@ mod mmap { #[cfg(unix)] { let fd = self.fd.swap(-1); - if fd >= 0 { - unsafe { libc::close(fd) }; - } + host_mmap::close_descriptor(fd); } #[cfg(windows)] { - let handle = self.handle.swap(INVALID_HANDLE_VALUE as isize); - if handle != INVALID_HANDLE_VALUE as isize { - unsafe { CloseHandle(handle as HANDLE) }; + let handle = self.handle.swap(host_mmap::INVALID_HANDLE as isize); + if handle != host_mmap::INVALID_HANDLE as isize { + host_mmap::close_handle(handle as host_mmap::Handle); } } } @@ -484,14 +383,11 
@@ mod mmap { // fcntl(2) is necessary to force DISKSYNC and get around mmap(2) bug #[cfg(target_os = "macos")] if let Ok(fd) = fd { - use std::os::fd::AsRawFd; - unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_FULLFSYNC) }; + host_mmap::prepare_file_mapping(fd); } if let Ok(fd) = fd { - let metadata = fstat(fd) - .map_err(|err| io::Error::from_raw_os_error(err as i32).to_pyexception(vm))?; - let file_len = metadata.st_size as i64; + let file_len = host_mmap::file_len(fd).map_err(|err| err.to_pyexception(vm))?; if map_size == 0 { if file_len == 0 { @@ -510,22 +406,22 @@ mod mmap { } } - let mut mmap_opt = MmapOptions::new(); - let mmap_opt = mmap_opt.offset(offset as u64).len(map_size); - let (fd, mmap) = || -> std::io::Result<_> { if let Ok(fd) = fd { - let new_fd: crt_fd::Owned = unistd::dup(fd)?.into(); - let mmap = match access { - AccessMode::Default | AccessMode::Write => { - MmapObj::Write(unsafe { mmap_opt.map_mut(&new_fd) }?) - } - AccessMode::Read => MmapObj::Read(unsafe { mmap_opt.map(&new_fd) }?), - AccessMode::Copy => MmapObj::Write(unsafe { mmap_opt.map_copy(&new_fd) }?), - }; + let (new_fd, mmap) = host_mmap::map_file( + fd, + offset, + map_size, + match access { + AccessMode::Default => host_mmap::AccessMode::Default, + AccessMode::Read => host_mmap::AccessMode::Read, + AccessMode::Write => host_mmap::AccessMode::Write, + AccessMode::Copy => host_mmap::AccessMode::Copy, + }, + )?; Ok((Some(new_fd), mmap)) } else { - let mmap = MmapObj::Write(mmap_opt.map_anon()?); + let mmap = host_mmap::map_anon(map_size)?; Ok((None, mmap)) } }() @@ -533,7 +429,7 @@ mod mmap { Ok(Self { closed: AtomicCell::new(false), - mmap: PyMutex::new(Some(mmap)), + mmap: PyMutex::new(Some(MmapObj::Mapped(mmap))), fd: AtomicCell::new(fd.map_or(-1, |fd| fd.into_raw())), offset, size: AtomicCell::new(map_size), @@ -568,163 +464,105 @@ mod mmap { _ => None, }; - // Get file handle from fileno - // fileno -1 or 0 means anonymous mapping - let fh: Option = if fileno != -1 && fileno != 
0 { - // Convert CRT file descriptor to Windows HANDLE + // Get file handle from fileno. fileno -1 means anonymous mapping. + let fh: Option = if fileno != -1 { + // Convert CRT file descriptor to a Windows file mapping handle. // Use suppress_iph! to avoid crashes when the fd is invalid. // This is critical because socket fds wrapped via _open_osfhandle // may cause crashes in _get_osfhandle on Windows. // See Python bug https://bugs.python.org/issue30114 - let handle = unsafe { suppress_iph!(libc::get_osfhandle(fileno)) }; + let handle = host_nt::handle_from_fd(fileno); // Check for invalid handle value (-1 on Windows) - if handle == -1 || handle == INVALID_HANDLE_VALUE as isize { + if host_mmap::is_invalid_handle_value(handle as isize) { return Err(vm.new_os_error(format!("Invalid file descriptor: {fileno}"))); } - Some(handle as HANDLE) + Some(handle as host_mmap::Handle) } else { None }; // Get file size if we have a file handle and map_size is 0 - let mut duplicated_handle: HANDLE = INVALID_HANDLE_VALUE; + let mut duplicated_handle: host_mmap::Handle = host_mmap::INVALID_HANDLE; if let Some(fh) = fh { // Duplicate handle so Python code can close the original - let mut new_handle: HANDLE = INVALID_HANDLE_VALUE; - let result = unsafe { - DuplicateHandle( - GetCurrentProcess(), - fh, - GetCurrentProcess(), - &mut new_handle, - 0, - 0, // not inheritable - DUPLICATE_SAME_ACCESS, - ) - }; - if result == 0 { - return Err(io::Error::last_os_error().to_pyexception(vm)); - } - duplicated_handle = new_handle; + duplicated_handle = + host_mmap::duplicate_handle(fh).map_err(|e| e.to_pyexception(vm))?; // Get file size - let mut high: u32 = 0; - let low = unsafe { GetFileSize(fh, &mut high) }; - if low == u32::MAX { - let err = io::Error::last_os_error(); - if err.raw_os_error() != Some(0) { - unsafe { CloseHandle(duplicated_handle) }; + let file_len = match host_mmap::get_file_len(fh) { + Ok(len) => len, + Err(err) => { + host_mmap::close_handle(duplicated_handle); 
return Err(err.to_pyexception(vm)); } - } - let file_len = ((high as i64) << 32) | (low as i64); + }; if map_size == 0 { if file_len == 0 { - unsafe { CloseHandle(duplicated_handle) }; + host_mmap::close_handle(duplicated_handle); return Err(vm.new_value_error("cannot mmap an empty file")); } if offset >= file_len { - unsafe { CloseHandle(duplicated_handle) }; + host_mmap::close_handle(duplicated_handle); return Err(vm.new_value_error("mmap offset is greater than file size")); } if file_len - offset > isize::MAX as i64 { - unsafe { CloseHandle(duplicated_handle) }; + host_mmap::close_handle(duplicated_handle); return Err(vm.new_value_error("mmap length is too large")); } map_size = (file_len - offset) as usize; } else { // If map_size > file_len, extend the file (Windows behavior) let required_size = offset.checked_add(map_size as i64).ok_or_else(|| { - unsafe { CloseHandle(duplicated_handle) }; + host_mmap::close_handle(duplicated_handle); vm.new_overflow_error("mmap size would cause file size overflow") })?; - if required_size > file_len { - // Extend file using SetFilePointerEx + SetEndOfFile - let result = unsafe { - SetFilePointerEx( - duplicated_handle, - required_size, - core::ptr::null_mut(), - FILE_BEGIN, - ) - }; - if result == 0 { - let err = io::Error::last_os_error(); - unsafe { CloseHandle(duplicated_handle) }; - return Err(err.to_pyexception(vm)); - } - let result = unsafe { SetEndOfFile(duplicated_handle) }; - if result == 0 { - let err = io::Error::last_os_error(); - unsafe { CloseHandle(duplicated_handle) }; - return Err(err.to_pyexception(vm)); - } + if required_size > file_len + && let Err(err) = host_mmap::extend_file(duplicated_handle, required_size) + { + host_mmap::close_handle(duplicated_handle); + return Err(err.to_pyexception(vm)); } } } // When tagname is provided, use raw Win32 APIs for named shared memory if let Some(ref tag) = tag_str { - let (fl_protect, desired_access) = match access { - AccessMode::Default | AccessMode::Write => 
(PAGE_READWRITE, FILE_MAP_WRITE), - AccessMode::Read => (PAGE_READONLY, FILE_MAP_READ), - AccessMode::Copy => (PAGE_WRITECOPY, FILE_MAP_COPY), - }; - let fh = if let Some(fh) = fh { // Close the duplicated handle - we'll use the original // file handle for CreateFileMappingW - if duplicated_handle != INVALID_HANDLE_VALUE { - unsafe { CloseHandle(duplicated_handle) }; + if duplicated_handle != host_mmap::INVALID_HANDLE { + host_mmap::close_handle(duplicated_handle); } fh } else { - INVALID_HANDLE_VALUE + host_mmap::INVALID_HANDLE }; - let tag_wide: Vec = tag.encode_utf16().chain(core::iter::once(0)).collect(); - - let total_size = (offset as u64) - .checked_add(map_size as u64) - .ok_or_else(|| vm.new_overflow_error("mmap offset plus size would overflow"))?; - let size_hi = (total_size >> 32) as u32; - let size_lo = total_size as u32; - - let map_handle = unsafe { - CreateFileMappingW( - fh, - core::ptr::null(), - fl_protect, - size_hi, - size_lo, - tag_wide.as_ptr(), - ) - }; - if map_handle.is_null() { - return Err(io::Error::last_os_error().to_pyexception(vm)); - } - - let off_hi = (offset as u64 >> 32) as u32; - let off_lo = offset as u32; - - let view = - unsafe { MapViewOfFile(map_handle, desired_access, off_hi, off_lo, map_size) }; - if view.Value.is_null() { - unsafe { CloseHandle(map_handle) }; - return Err(io::Error::last_os_error().to_pyexception(vm)); - } - - let named = NamedMmap { - map_handle, - view_ptr: view.Value as *mut u8, - len: map_size, - }; + let named = host_mmap::create_named_mapping( + fh, + tag, + match access { + AccessMode::Default => host_mmap::AccessMode::Default, + AccessMode::Read => host_mmap::AccessMode::Read, + AccessMode::Write => host_mmap::AccessMode::Write, + AccessMode::Copy => host_mmap::AccessMode::Copy, + }, + offset, + map_size, + ) + .map_err(|err| { + if err.raw_os_error() == Some(libc::EOVERFLOW) { + vm.new_overflow_error("mmap offset plus size would overflow") + } else { + err.to_pyexception(vm) + } + })?; return 
Ok(Self { closed: AtomicCell::new(false), mmap: PyMutex::new(Some(MmapObj::Named(named))), - handle: AtomicCell::new(INVALID_HANDLE_VALUE as isize), + handle: AtomicCell::new(host_mmap::INVALID_HANDLE as isize), offset, size: AtomicCell::new(map_size), pos: AtomicCell::new(0), @@ -733,34 +571,14 @@ mod mmap { }); } - let mut mmap_opt = MmapOptions::new(); - let mmap_opt = mmap_opt.offset(offset as u64).len(map_size); - - let (handle, mmap) = if duplicated_handle != INVALID_HANDLE_VALUE { - // Safety: We just duplicated this handle and it's valid - let owned_handle = - unsafe { OwnedHandle::from_raw_handle(duplicated_handle as RawHandle) }; - - let mmap_result = match access { - AccessMode::Default | AccessMode::Write => { - unsafe { mmap_opt.map_mut(&owned_handle) }.map(MmapObj::Write) - } - AccessMode::Read => unsafe { mmap_opt.map(&owned_handle) }.map(MmapObj::Read), - AccessMode::Copy => { - unsafe { mmap_opt.map_copy(&owned_handle) }.map(MmapObj::Write) - } - }; - - let mmap = mmap_result.map_err(|e| e.to_pyexception(vm))?; - - // Keep the handle alive - let raw = owned_handle.as_raw_handle() as isize; - core::mem::forget(owned_handle); - (raw, mmap) + let (handle, mmap) = if duplicated_handle != host_mmap::INVALID_HANDLE { + let mmap = Self::create_mmap_windows(duplicated_handle, offset, map_size, &access) + .map_err(|e| e.to_pyexception(vm))?; + (duplicated_handle as isize, mmap) } else { // Anonymous mapping - let mmap = mmap_opt.map_anon().map_err(|e| e.to_pyexception(vm))?; - (INVALID_HANDLE_VALUE as isize, MmapObj::Write(mmap)) + let mmap = host_mmap::map_anon(map_size).map_err(|e| e.to_pyexception(vm))?; + (host_mmap::INVALID_HANDLE as isize, MmapObj::Mapped(mmap)) }; Ok(Self { @@ -855,12 +673,9 @@ mod mmap { fn as_bytes_mut(&self) -> BorrowedValueMut<'_, [u8]> { PyMutexGuard::map(self.mmap.lock(), |m| { match m.as_mut().expect("mmap closed or invalid") { - MmapObj::Read(_) => panic!("mmap can't modify a readonly memory map."), - MmapObj::Write(mmap) => 
&mut mmap[..], + MmapObj::Mapped(mmap) => mmap.as_mut_slice(), #[cfg(windows)] - MmapObj::Named(named) => unsafe { - core::slice::from_raw_parts_mut(named.view_ptr, named.len) - }, + MmapObj::Named(named) => named.as_mut_slice(), } }) .into() @@ -898,12 +713,9 @@ mod mmap { } match self.check_valid(vm)?.deref_mut().as_mut().unwrap() { - MmapObj::Write(mmap) => Ok(f(&mut mmap[..])), + MmapObj::Mapped(mmap) => Ok(f(mmap.as_mut_slice())), #[cfg(windows)] - MmapObj::Named(named) => Ok(f(unsafe { - core::slice::from_raw_parts_mut(named.view_ptr, named.len) - })), - _ => unreachable!("already checked"), + MmapObj::Named(named) => Ok(f(named.as_mut_slice())), } } @@ -1010,18 +822,15 @@ mod mmap { } match self.check_valid(vm)?.deref().as_ref().unwrap() { - MmapObj::Read(_mmap) => {} - MmapObj::Write(mmap) => { + MmapObj::Mapped(mmap) => { mmap.flush_range(offset, size) .map_err(|e| e.to_pyexception(vm))?; } #[cfg(windows)] MmapObj::Named(named) => { - let ptr = unsafe { named.view_ptr.add(offset) }; - let result = unsafe { FlushViewOfFile(ptr as *const _, size) }; - if result == 0 { - return Err(io::Error::last_os_error().to_pyexception(vm)); - } + named + .flush_range(offset, size) + .map_err(|e| e.to_pyexception(vm))?; } } @@ -1032,22 +841,18 @@ mod mmap { #[pymethod] fn madvise(&self, options: AdviseOptions, vm: &VirtualMachine) -> PyResult<()> { let (option, start, length) = options.values(self.__len__(), vm)?; - let advice = validate_advice(vm, option)?; + if !host_mmap::validate_advice(option) { + return Err(vm.new_value_error("Not a valid Advice value")); + } let guard = self.check_valid(vm)?; let mmap = guard.deref().as_ref().unwrap(); - let ptr = match mmap { - MmapObj::Read(m) => m.as_ptr(), - MmapObj::Write(m) => m.as_ptr(), - }; - - // Apply madvise to the specified range (start, length) - let ptr_with_offset = unsafe { ptr.add(start) }; - let result = - unsafe { libc::madvise(ptr_with_offset as *mut libc::c_void, length, advice) }; - if result != 0 { - return 
Err(io::Error::last_os_error().to_pyexception(vm)); + match mmap { + MmapObj::Mapped(m) => m.madvise_range(start, length, option), + #[cfg(windows)] + MmapObj::Named(_) => unreachable!("unix-only method"), } + .map_err(|e| e.to_pyexception(vm))?; Ok(()) } @@ -1202,7 +1007,7 @@ mod mmap { return Err(vm.new_os_error("mmap: cannot resize a named memory mapping")); } - let is_anonymous = handle == INVALID_HANDLE_VALUE as isize; + let is_anonymous = host_mmap::is_invalid_handle_value(handle); if is_anonymous { // For anonymous mmap, we need to: @@ -1214,19 +1019,16 @@ mod mmap { let copy_size = core::cmp::min(old_size, newsize); // Create new anonymous mmap - let mut new_mmap_opts = MmapOptions::new(); - let mut new_mmap = new_mmap_opts - .len(newsize) - .map_anon() - .map_err(|e| e.to_pyexception(vm))?; + let mut new_mmap = + host_mmap::map_anon(newsize).map_err(|e| e.to_pyexception(vm))?; // Copy data from old mmap to new mmap if let Some(old_mmap) = mmap_guard.as_ref() { let src = &old_mmap.as_slice()[..copy_size]; - new_mmap[..copy_size].copy_from_slice(src); + new_mmap.as_mut_slice()[..copy_size].copy_from_slice(src); } - *mmap_guard = Some(MmapObj::Write(new_mmap)); + *mmap_guard = Some(MmapObj::Mapped(new_mmap)); self.size.store(newsize); } else { // File-backed mmap resize @@ -1234,34 +1036,26 @@ mod mmap { // Drop the current mmap to release the file mapping *mmap_guard = None; - // Resize the file let required_size = self.offset + newsize as i64; - let result = unsafe { - SetFilePointerEx( - handle as HANDLE, - required_size, - core::ptr::null_mut(), - FILE_BEGIN, - ) - }; - if result == 0 { + if let Err(err) = host_mmap::extend_file(handle as host_mmap::Handle, required_size) + { // Restore original mmap on error - let err = io::Error::last_os_error(); - self.try_restore_mmap(&mut mmap_guard, handle as HANDLE, self.size.load()); - return Err(err.to_pyexception(vm)); - } - - let result = unsafe { SetEndOfFile(handle as HANDLE) }; - if result == 0 { - let err = 
io::Error::last_os_error(); - self.try_restore_mmap(&mut mmap_guard, handle as HANDLE, self.size.load()); + self.try_restore_mmap( + &mut mmap_guard, + handle as host_mmap::Handle, + self.size.load(), + ); return Err(err.to_pyexception(vm)); } // Create new mmap with the new size - let new_mmap = - Self::create_mmap_windows(handle as HANDLE, self.offset, newsize, &self.access) - .map_err(|e| e.to_pyexception(vm))?; + let new_mmap = Self::create_mmap_windows( + handle as host_mmap::Handle, + self.offset, + newsize, + &self.access, + ) + .map_err(|e| e.to_pyexception(vm))?; *mmap_guard = Some(new_mmap); self.size.store(newsize); @@ -1319,7 +1113,7 @@ mod mmap { #[pymethod] fn size(&self, vm: &VirtualMachine) -> std::io::Result { let fd = unsafe { crt_fd::Borrowed::try_borrow_raw(self.fd.load())? }; - let file_len = fstat(fd)?.st_size; + let file_len = host_mmap::file_len(fd)?; Ok(PyInt::from(file_len).into_ref(&vm.ctx)) } @@ -1327,20 +1121,13 @@ mod mmap { #[pymethod] fn size(&self, vm: &VirtualMachine) -> PyResult { let handle = self.handle.load(); - if handle == INVALID_HANDLE_VALUE as isize { + if host_mmap::is_invalid_handle_value(handle) { // Anonymous mapping, return the mmap size return Ok(PyInt::from(self.__len__()).into_ref(&vm.ctx)); } - let mut high: u32 = 0; - let low = unsafe { GetFileSize(handle as HANDLE, &mut high) }; - if low == u32::MAX { - let err = io::Error::last_os_error(); - if err.raw_os_error() != Some(0) { - return Err(err.to_pyexception(vm)); - } - } - let file_len = ((high as i64) << 32) | (low as i64); + let file_len = host_mmap::get_file_len(handle as host_mmap::Handle) + .map_err(|e| e.to_pyexception(vm))?; Ok(PyInt::from(file_len).into_ref(&vm.ctx)) } @@ -1426,39 +1213,35 @@ mod mmap { impl PyMmap { #[cfg(windows)] fn create_mmap_windows( - handle: HANDLE, + handle: host_mmap::Handle, offset: i64, size: usize, access: &AccessMode, ) -> io::Result { - use std::fs::File; - - // Create an owned handle wrapper for memmap2 - // We need to 
create a File from the handle - let file = unsafe { File::from_raw_handle(handle as RawHandle) }; - - let mut mmap_opt = MmapOptions::new(); - let mmap_opt = mmap_opt.offset(offset as u64).len(size); - - let result = match access { - AccessMode::Default | AccessMode::Write => { - unsafe { mmap_opt.map_mut(&file) }.map(MmapObj::Write) - } - AccessMode::Read => unsafe { mmap_opt.map(&file) }.map(MmapObj::Read), - AccessMode::Copy => unsafe { mmap_opt.map_copy(&file) }.map(MmapObj::Write), - }; - - // Don't close the file handle - we're borrowing it - core::mem::forget(file); - - result + host_mmap::map_handle( + handle, + offset, + size, + match access { + AccessMode::Default => host_mmap::AccessMode::Default, + AccessMode::Read => host_mmap::AccessMode::Read, + AccessMode::Write => host_mmap::AccessMode::Write, + AccessMode::Copy => host_mmap::AccessMode::Copy, + }, + ) + .map(MmapObj::Mapped) } /// Try to restore mmap after a failed resize operation. /// Returns true if restoration succeeded, false otherwise. /// If restoration fails, marks the mmap as closed. 
#[cfg(windows)] - fn try_restore_mmap(&self, mmap_guard: &mut Option, handle: HANDLE, size: usize) { + fn try_restore_mmap( + &self, + mmap_guard: &mut Option, + handle: host_mmap::Handle, + size: usize, + ) { match Self::create_mmap_windows(handle, self.offset, size, &self.access) { Ok(mmap) => *mmap_guard = Some(mmap), Err(_) => self.closed.store(true), diff --git a/crates/stdlib/src/multiprocessing.rs b/crates/stdlib/src/multiprocessing.rs index 36f3991022b..53f6692577e 100644 --- a/crates/stdlib/src/multiprocessing.rs +++ b/crates/stdlib/src/multiprocessing.rs @@ -6,18 +6,12 @@ mod _multiprocessing { use crate::vm::{ Context, FromArgs, Py, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyDict, PyType, PyTypeRef}, + convert::ToPyException, function::{ArgBytesLike, FuncArgs, KwArgs}, types::Constructor, }; use core::sync::atomic::{AtomicI32, AtomicU32, Ordering}; - use windows_sys::Win32::Foundation::{ - CloseHandle, ERROR_TOO_MANY_POSTS, HANDLE, INVALID_HANDLE_VALUE, WAIT_FAILED, - WAIT_OBJECT_0, WAIT_TIMEOUT, - }; - use windows_sys::Win32::Networking::WinSock::{self, SOCKET}; - use windows_sys::Win32::System::Threading::{ - CreateSemaphoreW, GetCurrentThreadId, INFINITE, ReleaseSemaphore, WaitForSingleObjectEx, - }; + use rustpython_host_env::multiprocessing as host_multiprocessing; // These match the values in Lib/multiprocessing/synchronize.py const RECURSIVE_MUTEX: i32 = 0; @@ -26,7 +20,8 @@ mod _multiprocessing { macro_rules! 
ismine { ($self:expr) => { $self.count.load(Ordering::Acquire) > 0 - && $self.last_tid.load(Ordering::Acquire) == unsafe { GetCurrentThreadId() } + && $self.last_tid.load(Ordering::Acquire) + == host_multiprocessing::current_thread_id() }; } @@ -56,54 +51,7 @@ mod _multiprocessing { count: AtomicI32, } - #[derive(Debug)] - struct SemHandle { - raw: HANDLE, - } - - unsafe impl Send for SemHandle {} - unsafe impl Sync for SemHandle {} - - impl SemHandle { - fn create(value: i32, maxvalue: i32, vm: &VirtualMachine) -> PyResult { - let handle = - unsafe { CreateSemaphoreW(core::ptr::null(), value, maxvalue, core::ptr::null()) }; - if handle == 0 as HANDLE { - return Err(vm.new_last_os_error()); - } - Ok(Self { raw: handle }) - } - - #[inline] - fn as_raw(&self) -> HANDLE { - self.raw - } - } - - impl Drop for SemHandle { - fn drop(&mut self) { - if self.raw != 0 as HANDLE && self.raw != INVALID_HANDLE_VALUE { - unsafe { - CloseHandle(self.raw); - } - } - } - } - - /// _GetSemaphoreValue - get value of semaphore by briefly acquiring and releasing - fn get_semaphore_value(handle: HANDLE) -> Result { - match unsafe { WaitForSingleObjectEx(handle, 0, 0) } { - WAIT_OBJECT_0 => { - let mut previous: i32 = 0; - if unsafe { ReleaseSemaphore(handle, 1, &mut previous) } == 0 { - return Err(()); - } - Ok(previous + 1) - } - WAIT_TIMEOUT => Ok(0), - _ => Err(()), - } - } + type SemHandle = host_multiprocessing::SemHandle; #[pyclass(with(Constructor), flags(BASETYPE))] impl SemLock { @@ -147,13 +95,13 @@ mod _multiprocessing { let full_msecs: u32 = if !blocking { 0 } else if timeout_obj.as_ref().is_none_or(|o| vm.is_none(o)) { - INFINITE + host_multiprocessing::INFINITE_TIMEOUT } else { let timeout: f64 = timeout_obj.unwrap().try_float(vm)?.to_f64(); let timeout = timeout * 1000.0; // convert to ms if timeout < 0.0 { 0 - } else if timeout >= 0.5 * INFINITE as f64 { + } else if timeout >= 0.5 * host_multiprocessing::INFINITE_TIMEOUT as f64 { return Err(vm.new_overflow_error("timeout 
is too large")); } else { (timeout + 0.5) as u32 @@ -167,14 +115,14 @@ mod _multiprocessing { } // Check whether we can acquire without blocking - match unsafe { WaitForSingleObjectEx(self.handle.as_raw(), 0, 0) } { - WAIT_OBJECT_0 => { + match host_multiprocessing::wait_for_single_object(self.handle.as_raw(), 0) { + x if x == host_multiprocessing::wait_object_0() => { self.last_tid - .store(unsafe { GetCurrentThreadId() }, Ordering::Release); + .store(host_multiprocessing::current_thread_id(), Ordering::Release); self.count.fetch_add(1, Ordering::Release); return Ok(true); } - WAIT_FAILED => return Err(vm.new_last_os_error()), + x if x == host_multiprocessing::wait_failed() => return Err(vm.new_last_os_error()), _ => {} } @@ -183,7 +131,7 @@ mod _multiprocessing { let poll_ms: u32 = 100; let mut elapsed: u32 = 0; loop { - let wait_ms = if full_msecs == INFINITE { + let wait_ms = if full_msecs == host_multiprocessing::INFINITE_TIMEOUT { poll_ms } else { let remaining = full_msecs.saturating_sub(elapsed); @@ -194,22 +142,26 @@ mod _multiprocessing { }; let handle = self.handle.as_raw(); - let res = vm.allow_threads(|| unsafe { WaitForSingleObjectEx(handle, wait_ms, 0) }); + let res = vm.allow_threads(|| { + host_multiprocessing::wait_for_single_object(handle, wait_ms) + }); match res { - WAIT_OBJECT_0 => { + x if x == host_multiprocessing::wait_object_0() => { self.last_tid - .store(unsafe { GetCurrentThreadId() }, Ordering::Release); + .store(host_multiprocessing::current_thread_id(), Ordering::Release); self.count.fetch_add(1, Ordering::Release); return Ok(true); } - WAIT_TIMEOUT => { + x if x == host_multiprocessing::wait_timeout() => { vm.check_signals()?; - if full_msecs != INFINITE { + if full_msecs != host_multiprocessing::INFINITE_TIMEOUT { elapsed = elapsed.saturating_add(wait_ms); } } - WAIT_FAILED => return Err(vm.new_last_os_error()), + x if x == host_multiprocessing::wait_failed() => { + return Err(vm.new_last_os_error()); + } _ => { return 
Err(vm.new_runtime_error(format!( "WaitForSingleObject() gave unrecognized value {res}" @@ -234,9 +186,8 @@ mod _multiprocessing { } } - if unsafe { ReleaseSemaphore(self.handle.as_raw(), 1, core::ptr::null_mut()) } == 0 { - let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; - if err == ERROR_TOO_MANY_POSTS { + if let Err(err) = host_multiprocessing::release_semaphore(self.handle.as_raw()) { + if host_multiprocessing::is_too_many_posts(err) { return Err(vm.new_value_error("semaphore or lock released too many times")); } return Err(vm.new_last_os_error()); @@ -273,9 +224,7 @@ mod _multiprocessing { ) -> PyResult { // On Windows, _rebuild receives the handle directly (no sem_open) let zelf = Self { - handle: SemHandle { - raw: handle as HANDLE, - }, + handle: SemHandle::from_raw(handle as host_multiprocessing::RawHandle), kind, maxvalue, name, @@ -308,13 +257,14 @@ mod _multiprocessing { #[pymethod] fn _get_value(&self, vm: &VirtualMachine) -> PyResult { - get_semaphore_value(self.handle.as_raw()).map_err(|_| vm.new_last_os_error()) + host_multiprocessing::get_semaphore_value(self.handle.as_raw()) + .map_err(|_| vm.new_last_os_error()) } #[pymethod] fn _is_zero(&self, vm: &VirtualMachine) -> PyResult { - let val = - get_semaphore_value(self.handle.as_raw()).map_err(|_| vm.new_last_os_error())?; + let val = host_multiprocessing::get_semaphore_value(self.handle.as_raw()) + .map_err(|_| vm.new_last_os_error())?; Ok(val == 0) } @@ -346,7 +296,8 @@ mod _multiprocessing { return Err(vm.new_value_error("invalid value")); } - let handle = SemHandle::create(args.value, args.maxvalue, vm)?; + let handle = + SemHandle::create(args.value, args.maxvalue).map_err(|e| e.to_pyexception(vm))?; let name = if args.unlink { None } else { Some(args.name) }; Ok(Self { @@ -372,37 +323,22 @@ mod _multiprocessing { #[pyfunction] fn closesocket(socket: usize, vm: &VirtualMachine) -> PyResult<()> { - let res = unsafe { WinSock::closesocket(socket as SOCKET) }; - if res != 0 { 
- Err(vm.new_last_os_error()) - } else { - Ok(()) - } + host_multiprocessing::close_socket(socket as host_multiprocessing::RawSocket) + .map_err(|_| vm.new_last_os_error()) } #[pyfunction] fn recv(socket: usize, size: usize, vm: &VirtualMachine) -> PyResult> { - let mut buf = vec![0u8; size]; - let n_read = - unsafe { WinSock::recv(socket as SOCKET, buf.as_mut_ptr() as *mut _, size as i32, 0) }; - if n_read < 0 { - Err(vm.new_last_os_error()) - } else { - buf.truncate(n_read as usize); - Ok(buf) - } + host_multiprocessing::recv_socket(socket as host_multiprocessing::RawSocket, size) + .map_err(|_| vm.new_last_os_error()) } #[pyfunction] fn send(socket: usize, buf: ArgBytesLike, vm: &VirtualMachine) -> PyResult { - let ret = buf.with_ref(|b| unsafe { - WinSock::send(socket as SOCKET, b.as_ptr() as *const _, b.len() as i32, 0) - }); - if ret < 0 { - Err(vm.new_last_os_error()) - } else { - Ok(ret) - } + buf.with_ref(|b| { + host_multiprocessing::send_socket(socket as host_multiprocessing::RawSocket, b) + }) + .map_err(|_| vm.new_last_os_error()) } } @@ -414,20 +350,23 @@ mod _multiprocessing { use crate::vm::{ Context, FromArgs, Py, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyDict, PyType, PyTypeRef}, + convert::ToPyException, function::{FuncArgs, KwArgs}, types::Constructor, }; - use alloc::ffi::CString; use core::sync::atomic::{AtomicI32, AtomicU64, Ordering}; + #[cfg(target_vendor = "apple")] use libc::sem_t; - use nix::errno::Errno; + use rustpython_host_env::multiprocessing::{ + self as host_multiprocessing, SemError, TryAcquireStatus, WaitStatus, + }; /// Error type for sem_timedwait operations #[cfg(target_vendor = "apple")] enum SemWaitError { Timeout, SignalException(PyBaseExceptionRef), - OsError(Errno), + OsError(SemError), } /// macOS fallback for sem_timedwait using select + sem_trywait polling @@ -441,60 +380,19 @@ mod _multiprocessing { let mut delay: u64 = 0; loop { - // poll: try to acquire - if unsafe { 
libc::sem_trywait(sem) } == 0 { - return Ok(()); - } - let err = Errno::last(); - if err != Errno::EAGAIN { - return Err(SemWaitError::OsError(err)); - } - - // get current time - let mut now = libc::timeval { - tv_sec: 0, - tv_usec: 0, - }; - if unsafe { libc::gettimeofday(&mut now, core::ptr::null_mut()) } < 0 { - return Err(SemWaitError::OsError(Errno::last())); - } - - // check for timeout - let deadline_usec = deadline.tv_sec * 1_000_000 + deadline.tv_nsec / 1000; - #[allow(clippy::unnecessary_cast)] - let now_usec = now.tv_sec as i64 * 1_000_000 + now.tv_usec as i64; - - if now_usec >= deadline_usec { - return Err(SemWaitError::Timeout); - } - - // calculate how much time is left - let difference = (deadline_usec - now_usec) as u64; - - // check delay not too long -- maximum is 20 msecs - delay += 1000; - if delay > 20000 { - delay = 20000; - } - if delay > difference { - delay = difference; + match vm.allow_threads(|| { + host_multiprocessing::sem_timedwait_poll_step(sem, deadline, delay) + }) { + Ok(host_multiprocessing::PollWaitStep::Acquired) => return Ok(()), + Ok(host_multiprocessing::PollWaitStep::Timeout) => { + return Err(SemWaitError::Timeout); + } + Ok(host_multiprocessing::PollWaitStep::Continue(next_delay)) => { + delay = next_delay; + } + Err(err) => return Err(SemWaitError::OsError(err)), } - // sleep using select - let mut tv_delay = libc::timeval { - tv_sec: (delay / 1_000_000) as _, - tv_usec: (delay % 1_000_000) as _, - }; - vm.allow_threads(|| unsafe { - libc::select( - 0, - core::ptr::null_mut(), - core::ptr::null_mut(), - core::ptr::null_mut(), - &mut tv_delay, - ) - }); - // check for signals - preserve the exception (e.g., KeyboardInterrupt) if let Err(exc) = vm.check_signals() { return Err(SemWaitError::SignalException(exc)); @@ -510,7 +408,8 @@ mod _multiprocessing { macro_rules! 
ismine { ($self:expr) => { $self.count.load(Ordering::Acquire) > 0 - && $self.last_tid.load(Ordering::Acquire) == current_thread_id() + && $self.last_tid.load(Ordering::Acquire) + == host_multiprocessing::current_thread_id() }; } @@ -540,70 +439,7 @@ mod _multiprocessing { count: AtomicI32, // int } - #[derive(Debug)] - struct SemHandle { - raw: *mut sem_t, - } - - unsafe impl Send for SemHandle {} - unsafe impl Sync for SemHandle {} - - impl SemHandle { - fn create( - name: &str, - value: u32, - unlink: bool, - vm: &VirtualMachine, - ) -> PyResult<(Self, Option)> { - let cname = semaphore_name(vm, name)?; - // SEM_CREATE(name, val, max) sem_open(name, O_CREAT | O_EXCL, 0600, val) - let raw = unsafe { - libc::sem_open(cname.as_ptr(), libc::O_CREAT | libc::O_EXCL, 0o600, value) - }; - if raw == libc::SEM_FAILED { - let err = Errno::last(); - return Err(os_error(vm, err)); - } - if unlink { - // SEM_UNLINK(name) sem_unlink(name) - unsafe { - libc::sem_unlink(cname.as_ptr()); - } - Ok((Self { raw }, None)) - } else { - Ok((Self { raw }, Some(name.to_owned()))) - } - } - - fn open_existing(name: &str, vm: &VirtualMachine) -> PyResult { - let cname = semaphore_name(vm, name)?; - let raw = unsafe { libc::sem_open(cname.as_ptr(), 0) }; - if raw == libc::SEM_FAILED { - let err = Errno::last(); - return Err(os_error(vm, err)); - } - Ok(Self { raw }) - } - - #[inline] - fn as_ptr(&self) -> *mut sem_t { - self.raw - } - } - - impl Drop for SemHandle { - fn drop(&mut self) { - // Guard against default/uninitialized state. - // Note: SEM_FAILED is (sem_t*)-1, not null, but valid handles are never null - // and SEM_FAILED is never stored (error is returned immediately on sem_open failure). 
- if !self.raw.is_null() { - // SEM_CLOSE(sem) sem_close(sem) - unsafe { - libc::sem_close(self.raw); - } - } - } - } + type SemHandle = host_multiprocessing::SemHandle; #[pyclass(with(Constructor), flags(BASETYPE))] impl SemLock { @@ -659,54 +495,26 @@ mod _multiprocessing { let timeout_obj = timeout_obj.unwrap(); // This accepts both int and float, converting to f64 let timeout: f64 = timeout_obj.try_float(vm)?.to_f64(); - let timeout = if timeout < 0.0 { 0.0 } else { timeout }; - - let mut tv = libc::timeval { - tv_sec: 0, - tv_usec: 0, - }; - let res = unsafe { libc::gettimeofday(&mut tv, core::ptr::null_mut()) }; - if res < 0 { - return Err(vm.new_os_error("gettimeofday failed".to_string())); - } - - // deadline calculation: - // long sec = (long) timeout; - // long nsec = (long) (1e9 * (timeout - sec) + 0.5); - // deadline.tv_sec = now.tv_sec + sec; - // deadline.tv_nsec = now.tv_usec * 1000 + nsec; - // deadline.tv_sec += (deadline.tv_nsec / 1000000000); - // deadline.tv_nsec %= 1000000000; - let sec = timeout as libc::c_long; - let nsec = (1e9 * (timeout - sec as f64) + 0.5) as libc::c_long; - let mut deadline = libc::timespec { - tv_sec: tv.tv_sec + sec as libc::time_t, - tv_nsec: (tv.tv_usec as libc::c_long * 1000 + nsec) as _, - }; - deadline.tv_sec += (deadline.tv_nsec / 1_000_000_000) as libc::time_t; - deadline.tv_nsec %= 1_000_000_000; - Some(deadline) + Some( + host_multiprocessing::deadline_from_timeout(timeout) + .map_err(|_| vm.new_os_error("gettimeofday failed".to_string()))?, + ) } else { None }; // Check whether we can acquire without releasing the GIL and blocking - let mut res; - loop { - res = unsafe { libc::sem_trywait(self.handle.as_ptr()) }; - if res >= 0 { - break; - } - let err = Errno::last(); - if err == Errno::EINTR { - vm.check_signals()?; - continue; + let try_status = loop { + match host_multiprocessing::sem_trywait_status(self.handle.as_ptr()) { + TryAcquireStatus::Interrupted => { + vm.check_signals()?; + } + status => break 
status, } - break; - } + }; // if (res < 0 && errno == EAGAIN && blocking) - if res < 0 && Errno::last() == Errno::EAGAIN && blocking { + if matches!(try_status, TryAcquireStatus::WouldBlock) && blocking { // Couldn't acquire immediately, need to block. // // Save errno inside the allow_threads closure, before @@ -715,49 +523,20 @@ mod _multiprocessing { #[cfg(not(target_vendor = "apple"))] { - let mut saved_errno; loop { let sem_ptr = self.handle.as_ptr(); // Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS - let (r, e) = if let Some(ref dl) = deadline { - vm.allow_threads(|| { - let r = unsafe { libc::sem_timedwait(sem_ptr, dl) }; - ( - r, - if r < 0 { - Errno::last() - } else { - Errno::from_raw(0) - }, - ) - }) - } else { - vm.allow_threads(|| { - let r = unsafe { libc::sem_wait(sem_ptr) }; - ( - r, - if r < 0 { - Errno::last() - } else { - Errno::from_raw(0) - }, - ) - }) - }; - res = r; - saved_errno = e; - - if res >= 0 { - break; - } - if saved_errno == Errno::EINTR { - vm.check_signals()?; - continue; + match vm.allow_threads(|| { + host_multiprocessing::sem_wait_status(sem_ptr, deadline.as_ref()) + }) { + WaitStatus::Acquired => break, + WaitStatus::Interrupted => { + vm.check_signals()?; + continue; + } + WaitStatus::TimedOut => return Ok(false), + WaitStatus::Error(err) => return Err(os_error(vm, err)), } - break; - } - if res < 0 { - return handle_wait_error(vm, saved_errno); } } #[cfg(target_vendor = "apple")] @@ -778,50 +557,35 @@ mod _multiprocessing { } } else { // No timeout: use sem_wait (available on macOS) - let mut saved_errno; loop { let sem_ptr = self.handle.as_ptr(); - let (r, e) = vm.allow_threads(|| { - let r = unsafe { libc::sem_wait(sem_ptr) }; - ( - r, - if r < 0 { - Errno::last() - } else { - Errno::from_raw(0) - }, - ) - }); - res = r; - saved_errno = e; - if res >= 0 { - break; + match vm.allow_threads(|| { + host_multiprocessing::sem_wait_status(sem_ptr, None) + }) { + WaitStatus::Acquired => break, + WaitStatus::Interrupted => { + 
vm.check_signals()?; + continue; + } + WaitStatus::TimedOut => return Ok(false), + WaitStatus::Error(err) => return Err(os_error(vm, err)), } - if saved_errno == Errno::EINTR { - vm.check_signals()?; - continue; - } - break; - } - if res < 0 { - return handle_wait_error(vm, saved_errno); } } } - } else if res < 0 { + } else if !matches!(try_status, TryAcquireStatus::Acquired) { // Non-blocking path failed, or blocking=false - let err = Errno::last(); - match err { - Errno::EAGAIN | Errno::ETIMEDOUT => return Ok(false), - Errno::EINTR => { - return vm.check_signals().map(|_| false); - } - _ => return Err(os_error(vm, err)), + match try_status { + TryAcquireStatus::WouldBlock => return Ok(false), + TryAcquireStatus::Interrupted => return vm.check_signals().map(|_| false), + TryAcquireStatus::Error(err) => return Err(os_error(vm, err)), + TryAcquireStatus::Acquired => unreachable!(), } } self.count.fetch_add(1, Ordering::Release); - self.last_tid.store(current_thread_id(), Ordering::Release); + self.last_tid + .store(host_multiprocessing::current_thread_id(), Ordering::Release); Ok(true) } @@ -849,11 +613,9 @@ mod _multiprocessing { #[cfg(not(target_vendor = "apple"))] { // Linux: use sem_getvalue - let mut sval: libc::c_int = 0; - let res = unsafe { libc::sem_getvalue(self.handle.as_ptr(), &mut sval) }; - if res < 0 { - return Err(os_error(vm, Errno::last())); - } + let sval = + unsafe { host_multiprocessing::get_semaphore_value(self.handle.as_ptr()) } + .map_err(|err| os_error(vm, err))?; if sval >= self.maxvalue { return Err(vm.new_value_error("semaphore or lock released too many times")); } @@ -864,27 +626,29 @@ mod _multiprocessing { // We will only check properly the maxvalue == 1 case if self.maxvalue == 1 { // make sure that already locked - if unsafe { libc::sem_trywait(self.handle.as_ptr()) } < 0 { - if Errno::last() != Errno::EAGAIN { - return Err(os_error(vm, Errno::last())); + match host_multiprocessing::sem_trywait_status(self.handle.as_ptr()) { + 
TryAcquireStatus::WouldBlock => {} + TryAcquireStatus::Acquired => { + if let Err(err) = + host_multiprocessing::sem_post(self.handle.as_ptr()) + { + return Err(os_error(vm, err)); + } + return Err( + vm.new_value_error("semaphore or lock released too many times") + ); } - // it is already locked as expected - } else { - // it was not locked so undo wait and raise - if unsafe { libc::sem_post(self.handle.as_ptr()) } < 0 { - return Err(os_error(vm, Errno::last())); + TryAcquireStatus::Interrupted => { + return Err(os_error(vm, SemError::Interrupted)); } - return Err( - vm.new_value_error("semaphore or lock released too many times") - ); + TryAcquireStatus::Error(err) => return Err(os_error(vm, err)), } } } } - let res = unsafe { libc::sem_post(self.handle.as_ptr()) }; - if res < 0 { - return Err(os_error(vm, Errno::last())); + if let Err(err) = host_multiprocessing::sem_post(self.handle.as_ptr()) { + return Err(os_error(vm, err)); } self.count.fetch_sub(1, Ordering::Release); @@ -926,7 +690,7 @@ mod _multiprocessing { let Some(ref name_str) = name else { return Err(vm.new_value_error("cannot rebuild SemLock without name")); }; - let handle = SemHandle::open_existing(name_str, vm)?; + let handle = SemHandle::open_existing(name_str).map_err(|err| os_error(vm, err))?; // return newsemlockobject(type, handle, kind, maxvalue, name_copy); let zelf = Self { handle, @@ -976,14 +740,8 @@ mod _multiprocessing { #[cfg(not(target_vendor = "apple"))] { // Linux: use sem_getvalue - let mut sval: libc::c_int = 0; - let res = unsafe { libc::sem_getvalue(self.handle.as_ptr(), &mut sval) }; - if res < 0 { - return Err(os_error(vm, Errno::last())); - } - // some posix implementations use negative numbers to indicate - // the number of waiting threads - Ok(if sval < 0 { 0 } else { sval }) + unsafe { host_multiprocessing::get_semaphore_value(self.handle.as_ptr()) } + .map_err(|err| os_error(vm, err)) } #[cfg(target_vendor = "apple")] { @@ -1004,15 +762,17 @@ mod _multiprocessing { { // 
macOS: HAVE_BROKEN_SEM_GETVALUE // Try to acquire - if EAGAIN, value is 0 - if unsafe { libc::sem_trywait(self.handle.as_ptr()) } < 0 { - if Errno::last() == Errno::EAGAIN { - return Ok(true); + match host_multiprocessing::sem_trywait_status(self.handle.as_ptr()) { + TryAcquireStatus::WouldBlock => return Ok(true), + TryAcquireStatus::Interrupted => { + return Err(os_error(vm, SemError::Interrupted)); } - return Err(os_error(vm, Errno::last())); + TryAcquireStatus::Error(err) => return Err(os_error(vm, err)), + TryAcquireStatus::Acquired => {} } // Successfully acquired - undo and return false - if unsafe { libc::sem_post(self.handle.as_ptr()) } < 0 { - return Err(os_error(vm, Errno::last())); + if let Err(err) = host_multiprocessing::sem_post(self.handle.as_ptr()) { + return Err(os_error(vm, err)); } Ok(false) } @@ -1027,14 +787,7 @@ mod _multiprocessing { class.set_attr(ctx.intern_str("SEMAPHORE"), ctx.new_int(SEMAPHORE).into()); // SEM_VALUE_MAX from system, or INT_MAX if negative // We use a reasonable default - let sem_value_max: i32 = unsafe { - let val = libc::sysconf(libc::_SC_SEM_VALUE_MAX); - if val < 0 || val > i32::MAX as libc::c_long { - i32::MAX - } else { - val as i32 - } - }; + let sem_value_max = host_multiprocessing::sem_value_max(); class.set_attr( ctx.intern_str("SEM_VALUE_MAX"), ctx.new_int(sem_value_max).into(), @@ -1057,7 +810,14 @@ mod _multiprocessing { } let value = args.value as u32; - let (handle, name) = SemHandle::create(&args.name, value, args.unlink, vm)?; + let (handle, name) = + SemHandle::create(&args.name, value, args.unlink).map_err(|err| { + if err == SemError::InvalidInput && args.name.contains('\0') { + vm.new_value_error("embedded null character") + } else { + os_error(vm, err) + } + })?; // return newsemlockobject(type, handle, kind, maxvalue, name_copy); Ok(Self { @@ -1075,12 +835,13 @@ mod _multiprocessing { // _PyMp_sem_unlink. 
#[pyfunction] fn sem_unlink(name: String, vm: &VirtualMachine) -> PyResult<()> { - let cname = semaphore_name(vm, &name)?; - let res = unsafe { libc::sem_unlink(cname.as_ptr()) }; - if res < 0 { - return Err(os_error(vm, Errno::last())); - } - Ok(()) + host_multiprocessing::sem_unlink(&name).map_err(|err| { + if err == SemError::InvalidInput && name.contains('\0') { + vm.new_value_error("embedded null character") + } else { + os_error(vm, err) + } + }) } /// Module-level flags dict. @@ -1111,39 +872,8 @@ mod _multiprocessing { flags } - fn semaphore_name(vm: &VirtualMachine, name: &str) -> PyResult { - // POSIX semaphore names must start with / - let mut full = String::with_capacity(name.len() + 1); - if !name.starts_with('/') { - full.push('/'); - } - full.push_str(name); - CString::new(full).map_err(|_| vm.new_value_error("embedded null character")) - } - - fn handle_wait_error(vm: &VirtualMachine, saved_errno: Errno) -> PyResult { - match saved_errno { - Errno::EAGAIN | Errno::ETIMEDOUT => Ok(false), - Errno::EINTR => vm.check_signals().map(|_| false), - _ => Err(os_error(vm, saved_errno)), - } - } - - fn os_error(vm: &VirtualMachine, err: Errno) -> PyBaseExceptionRef { - // _PyMp_SetError maps to PyErr_SetFromErrno - let exc_type = match err { - Errno::EEXIST => vm.ctx.exceptions.file_exists_error.to_owned(), - Errno::ENOENT => vm.ctx.exceptions.file_not_found_error.to_owned(), - _ => vm.ctx.exceptions.os_error.to_owned(), - }; - vm.new_os_subtype_error(exc_type, Some(err as i32), err.desc().to_owned()) - .upcast() - } - - /// Get current thread identifier. - /// PyThread_get_thread_ident on Unix (pthread_self). 
- fn current_thread_id() -> u64 { - unsafe { libc::pthread_self() as u64 } + fn os_error(vm: &VirtualMachine, err: SemError) -> PyBaseExceptionRef { + err.to_pyexception(vm) } } diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index 24cbe8a40a3..76309d3c21d 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -65,7 +65,7 @@ mod _ssl { LazyLock, PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }, - socket::{self, PySocket}, + socket::{self, PySocket, SockWaitKind, sock_wait}, vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ @@ -2292,47 +2292,52 @@ mod _ssl { self.0.get_timeout().map(|d| Instant::now() + d) } - fn select(&self, needs: SslNeeds, deadline: &SocketDeadline) -> SelectRet { - let sock = match self.0.sock_opt() { - Some(s) => s, - None => return SelectRet::Closed, + fn select( + &self, + needs: SslNeeds, + deadline: &SocketDeadline, + vm: &VirtualMachine, + ) -> PyResult { + let Some(sock) = self.0.sock_opt() else { + return Ok(SelectRet::Closed); }; // For blocking sockets without timeout, call sock_select with None timeout // to actually block waiting for data instead of busy-looping let timeout = match &deadline { Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { Some(d) => Some(d), - None => return SelectRet::TimedOut, + None => return Ok(SelectRet::TimedOut), }, Err(true) => None, // Blocking: no timeout, wait indefinitely - Err(false) => return SelectRet::Nonblocking, + Err(false) => return Ok(SelectRet::Nonblocking), }; - let res = socket::sock_select( - &sock, - match needs { - SslNeeds::Read => socket::SelectKind::Read, - SslNeeds::Write => socket::SelectKind::Write, - }, - timeout, - ); - match res { - Ok(true) => SelectRet::TimedOut, - _ => SelectRet::Ok, - } + let wait_kind = match needs { + SslNeeds::Read => SockWaitKind::Read, + SslNeeds::Write => SockWaitKind::Write, + }; + 
sock_wait(&*sock, wait_kind, timeout, vm).map(|timed_out| { + if timed_out { + SelectRet::TimedOut + } else { + SelectRet::Ok + } + }) } fn socket_needs( &self, err: &ssl::Error, deadline: &SocketDeadline, - ) -> (Option, SelectRet) { + vm: &VirtualMachine, + ) -> PyResult<(Option, SelectRet)> { let needs = match err.code() { ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read), ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write), _ => None, }; - let state = needs.map_or(SelectRet::Ok, |needs| self.select(needs, deadline)); - (needs, state) + let state = + needs.map_or(Ok(SelectRet::Ok), |needs| self.select(needs, deadline, vm))?; + Ok((needs, state)) } } @@ -2850,7 +2855,7 @@ mod _ssl { break; } // Wait briefly for peer's close_notify before retrying - match socket_stream.select(SslNeeds::Read, &deadline) { + match socket_stream.select(SslNeeds::Read, &deadline, vm)? { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( vm, @@ -2888,7 +2893,7 @@ mod _ssl { }; // Wait on the socket - match socket_stream.select(needs, &deadline) { + match socket_stream.select(needs, &deadline, vm)? 
{ SelectRet::TimedOut => { let msg = if err == sys::SSL_ERROR_WANT_READ { "The read operation timed out" @@ -2984,7 +2989,7 @@ mod _ssl { let (needs, state) = stream .get_ref() .expect("handshake called in bio mode; should only be called in socket mode") - .socket_needs(&err, &timeout); + .socket_needs(&err, &timeout, vm)?; match state { SelectRet::TimedOut => { // Clean up SNI ex_data before returning error @@ -3038,7 +3043,7 @@ mod _ssl { .get_ref() .expect("write called in bio mode; should only be called in socket mode"); let timeout = socket_ref.timeout_deadline(); - let state = socket_ref.select(SslNeeds::Write, &timeout); + let state = socket_ref.select(SslNeeds::Write, &timeout, vm)?; match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -3058,7 +3063,7 @@ mod _ssl { let (needs, state) = stream .get_ref() .expect("write called in bio mode; should only be called in socket mode") - .socket_needs(&err, &timeout); + .socket_needs(&err, &timeout, vm)?; match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -3229,7 +3234,7 @@ mod _ssl { let (needs, state) = stream .get_ref() .expect("read called in bio mode; should only be called in socket mode") - .socket_needs(&err, &timeout); + .socket_needs(&err, &timeout, vm)?; match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -4102,35 +4107,30 @@ mod windows { #[pyfunction] fn enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult> { - use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore}; - use windows_sys::Win32::Security::Cryptography; - - // TODO: check every store for it, not just 2 of them: - // https://github.com/python/cpython/blob/3.8/Modules/_ssl.c#L5603-L5610 - let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; - let stores = open_fns - .iter() - .filter_map(|open| open(store_name.as_str()).ok()) - .collect::>(); - let certs = stores.iter().flat_map(|s| 
s.certs()).map(|c| { - let cert = vm.ctx.new_bytes(c.to_der().to_owned()); - let enc_type = unsafe { - let ptr = c.as_ptr() as *const Cryptography::CERT_CONTEXT; - (*ptr).dwCertEncodingType - }; - let enc_type = match enc_type { - Cryptography::X509_ASN_ENCODING => vm.new_pyobj(ascii!("x509_asn")), - Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), - other => vm.new_pyobj(other), + let certs = rustpython_host_env::cert_store::enum_certificates(store_name.as_str()); + let certs = certs.entries.into_iter().map(|c| { + let cert = vm.ctx.new_bytes(c.der); + let enc_type = match c.encoding { + rustpython_host_env::cert_store::EncodingType::X509Asn => { + vm.new_pyobj(ascii!("x509_asn")) + } + rustpython_host_env::cert_store::EncodingType::Pkcs7Asn => { + vm.new_pyobj(ascii!("pkcs_7_asn")) + } + rustpython_host_env::cert_store::EncodingType::Other(other) => vm.new_pyobj(other), }; - let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? { - ValidUses::All => vm.ctx.new_bool(true).into(), - ValidUses::Oids(oids) => PyFrozenSet::from_iter( - vm, - oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), - )? - .into_ref(&vm.ctx) - .into(), + let usage: PyObjectRef = match c.valid_uses.map_err(|e| e.to_pyexception(vm))? { + rustpython_host_env::cert_store::CertificateUses::All => { + vm.ctx.new_bool(true).into() + } + rustpython_host_env::cert_store::CertificateUses::Oids(oids) => { + PyFrozenSet::from_iter( + vm, + oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), + )? 
+ .into_ref(&vm.ctx) + .into() + } }; Ok(vm.new_tuple((cert, enc_type, usage)).into()) }); @@ -4147,7 +4147,7 @@ mod bio { use openssl_sys as sys; use std::marker::PhantomData; - pub struct MemBioSlice<'a>(*mut sys::BIO, PhantomData<&'a [u8]>); + pub(super) struct MemBioSlice<'a>(*mut sys::BIO, PhantomData<&'a [u8]>); impl Drop for MemBioSlice<'_> { fn drop(&mut self) { @@ -4158,7 +4158,7 @@ mod bio { } impl<'a> MemBioSlice<'a> { - pub fn new(buf: &'a [u8]) -> Result, ErrorStack> { + pub(super) fn new(buf: &'a [u8]) -> Result, ErrorStack> { openssl::init(); assert!(buf.len() <= c_int::MAX as usize); @@ -4170,7 +4170,7 @@ mod bio { Ok(MemBioSlice(bio, PhantomData)) } - pub fn as_ptr(&self) -> *mut sys::BIO { + pub(super) fn as_ptr(&self) -> *mut sys::BIO { self.0 } } diff --git a/crates/stdlib/src/overlapped.rs b/crates/stdlib/src/overlapped.rs index bc1fb62341e..4ce3d3ba830 100644 --- a/crates/stdlib/src/overlapped.rs +++ b/crates/stdlib/src/overlapped.rs @@ -11,133 +11,51 @@ mod _overlapped { AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytesRef, PyModule, PyStrRef, PyTupleRef, PyType}, common::lock::PyMutex, - convert::ToPyObject, + convert::{ToPyException, ToPyObject}, function::OptionalArg, object::{Traverse, TraverseFn}, protocol::PyBuffer, types::{Constructor, Destructor}, }; - use windows_sys::Win32::{ - Foundation::{self, GetLastError, HANDLE}, - Networking::WinSock::{AF_INET, AF_INET6, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6}, - System::IO::OVERLAPPED, + use rustpython_host_env::{ + overlapped as host_overlapped, winapi as host_winapi, windows as host_windows, }; pub(crate) fn module_exec(vm: &VirtualMachine, module: &Py) -> PyResult<()> { let _ = vm.import("_socket", 0)?; - initialize_winsock_extensions(vm)?; + host_overlapped::initialize_winsock_extensions() + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm))?; __module_exec(vm, module); Ok(()) } #[pyattr] - use 
windows_sys::Win32::{ - Foundation::{ - ERROR_IO_PENDING, ERROR_NETNAME_DELETED, ERROR_OPERATION_ABORTED, ERROR_PIPE_BUSY, - ERROR_PORT_UNREACHABLE, ERROR_SEM_TIMEOUT, - }, - Networking::WinSock::{ - SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, TF_REUSE_SOCKET, - }, - System::Threading::INFINITE, - }; + const ERROR_IO_PENDING: u32 = host_winapi::ERROR_IO_PENDING; + #[pyattr] + const ERROR_NETNAME_DELETED: u32 = host_winapi::ERROR_NETNAME_DELETED; + #[pyattr] + const ERROR_OPERATION_ABORTED: u32 = host_winapi::ERROR_OPERATION_ABORTED; + #[pyattr] + const ERROR_PIPE_BUSY: u32 = host_winapi::ERROR_PIPE_BUSY; + #[pyattr] + const ERROR_PORT_UNREACHABLE: u32 = host_winapi::ERROR_PORT_UNREACHABLE; + #[pyattr] + const ERROR_SEM_TIMEOUT: u32 = host_winapi::ERROR_SEM_TIMEOUT; + #[pyattr] + const SO_UPDATE_ACCEPT_CONTEXT: i32 = host_overlapped::SO_UPDATE_ACCEPT_CONTEXT_VALUE; + #[pyattr] + const SO_UPDATE_CONNECT_CONTEXT: i32 = host_overlapped::SO_UPDATE_CONNECT_CONTEXT_VALUE; + #[pyattr] + const TF_REUSE_SOCKET: u32 = host_overlapped::TF_REUSE_SOCKET_FLAG; + #[pyattr] + const INFINITE: u32 = host_winapi::INFINITE_TIMEOUT; #[pyattr] - const INVALID_HANDLE_VALUE: isize = - unsafe { core::mem::transmute(windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE) }; + const INVALID_HANDLE_VALUE: isize = host_overlapped::INVALID_HANDLE_VALUE_ISIZE; #[pyattr] const NULL: isize = 0; - // Function pointers for Winsock extension functions - static ACCEPT_EX: std::sync::OnceLock = std::sync::OnceLock::new(); - static CONNECT_EX: std::sync::OnceLock = std::sync::OnceLock::new(); - static DISCONNECT_EX: std::sync::OnceLock = std::sync::OnceLock::new(); - static TRANSMIT_FILE: std::sync::OnceLock = std::sync::OnceLock::new(); - - fn initialize_winsock_extensions(vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Networking::WinSock::{ - INVALID_SOCKET, IPPROTO_TCP, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCK_STREAM, - SOCKET_ERROR, WSAGetLastError, WSAIoctl, closesocket, 
socket, - }; - - // GUIDs for extension functions - const WSAID_ACCEPTEX: windows_sys::core::GUID = windows_sys::core::GUID { - data1: 0xb5367df1, - data2: 0xcbac, - data3: 0x11cf, - data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], - }; - const WSAID_CONNECTEX: windows_sys::core::GUID = windows_sys::core::GUID { - data1: 0x25a207b9, - data2: 0xddf3, - data3: 0x4660, - data4: [0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e], - }; - const WSAID_DISCONNECTEX: windows_sys::core::GUID = windows_sys::core::GUID { - data1: 0x7fda2e11, - data2: 0x8630, - data3: 0x436f, - data4: [0xa0, 0x31, 0xf5, 0x36, 0xa6, 0xee, 0xc1, 0x57], - }; - const WSAID_TRANSMITFILE: windows_sys::core::GUID = windows_sys::core::GUID { - data1: 0xb5367df0, - data2: 0xcbac, - data3: 0x11cf, - data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], - }; - - // Check all four locks to prevent partial initialization - if ACCEPT_EX.get().is_some() - && CONNECT_EX.get().is_some() - && DISCONNECT_EX.get().is_some() - && TRANSMIT_FILE.get().is_some() - { - return Ok(()); - } - - let s = unsafe { socket(AF_INET as i32, SOCK_STREAM, IPPROTO_TCP) }; - if s == INVALID_SOCKET { - let err = unsafe { WSAGetLastError() } as u32; - return Err(set_from_windows_err(err, vm)); - } - - let mut dw_bytes: u32 = 0; - - macro_rules! 
get_extension { - ($guid:expr, $lock:expr) => {{ - let mut func_ptr: usize = 0; - let ret = unsafe { - WSAIoctl( - s, - SIO_GET_EXTENSION_FUNCTION_POINTER, - &$guid as *const _ as *const _, - core::mem::size_of_val(&$guid) as u32, - &mut func_ptr as *mut _ as *mut _, - core::mem::size_of::() as u32, - &mut dw_bytes, - core::ptr::null_mut(), - None, - ) - }; - if ret == SOCKET_ERROR { - let err = unsafe { WSAGetLastError() } as u32; - unsafe { closesocket(s) }; - return Err(set_from_windows_err(err, vm)); - } - let _ = $lock.set(func_ptr); - }}; - } - - get_extension!(WSAID_ACCEPTEX, ACCEPT_EX); - get_extension!(WSAID_CONNECTEX, CONNECT_EX); - get_extension!(WSAID_DISCONNECTEX, DISCONNECT_EX); - get_extension!(WSAID_TRANSMITFILE, TRANSMIT_FILE); - - unsafe { closesocket(s) }; - Ok(()) - } - #[pyattr] #[pyclass(name, traverse)] #[derive(PyPayload)] @@ -146,8 +64,8 @@ mod _overlapped { } struct OverlappedInner { - overlapped: OVERLAPPED, - handle: HANDLE, + overlapped: host_overlapped::OverlappedIo, + handle: host_overlapped::Handle, error: u32, data: OverlappedData, } @@ -223,7 +141,7 @@ mod _overlapped { result: Option, // The actual read buffer allocated_buffer: PyBytesRef, - address: SOCKADDR_IN6, + address: host_overlapped::SocketAddrV6, address_length: i32, } @@ -242,7 +160,7 @@ mod _overlapped { result: Option, /* Buffer passed by the user */ user_buffer: PyBuffer, - address: SOCKADDR_IN6, + address: host_overlapped::SocketAddrV6, address_length: i32, } @@ -270,16 +188,9 @@ mod _overlapped { } } - fn mark_as_completed(ov: &mut OVERLAPPED) { - ov.Internal = 0; - if !ov.hEvent.is_null() { - unsafe { windows_sys::Win32::System::Threading::SetEvent(ov.hEvent) }; - } - } - fn set_from_windows_err(err: u32, vm: &VirtualMachine) -> PyBaseExceptionRef { let err = if err == 0 { - unsafe { GetLastError() } + host_winapi::get_last_error() } else { err }; @@ -292,51 +203,16 @@ mod _overlapped { exc.upcast() } - fn HasOverlappedIoCompleted(overlapped: &OVERLAPPED) -> bool { 
- overlapped.Internal != (Foundation::STATUS_PENDING as usize) - } - /// Parse a Python address tuple to SOCKADDR fn parse_address(addr_obj: &PyTupleRef, vm: &VirtualMachine) -> PyResult<(Vec, i32)> { - use windows_sys::Win32::Networking::WinSock::{WSAGetLastError, WSAStringToAddressW}; - match addr_obj.len() { 2 => { // IPv4: (host, port) let host: PyStrRef = addr_obj[0].clone().try_into_value(vm)?; let port: u16 = addr_obj[1].clone().try_to_value(vm)?; - - let mut addr: SOCKADDR_IN = unsafe { core::mem::zeroed() }; - addr.sin_family = AF_INET; - let host_wide: Vec = host.as_wtf8().encode_wide().chain([0]).collect(); - let mut addr_len = core::mem::size_of::() as i32; - - let ret = unsafe { - WSAStringToAddressW( - host_wide.as_ptr(), - AF_INET as i32, - core::ptr::null(), - &mut addr as *mut _ as *mut SOCKADDR, - &mut addr_len, - ) - }; - - if ret < 0 { - let err = unsafe { WSAGetLastError() } as u32; - return Err(set_from_windows_err(err, vm)); - } - - // Restore port (WSAStringToAddressW overwrites it) - addr.sin_port = port.to_be(); - - let bytes = unsafe { - core::slice::from_raw_parts( - &addr as *const _ as *const u8, - core::mem::size_of::(), - ) - }; - Ok((bytes.to_vec(), addr_len)) + host_overlapped::parse_address_v4_wide(&host_wide, port) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } 4 => { // IPv6: (host, port, flowinfo, scope_id) @@ -344,71 +220,33 @@ mod _overlapped { let port: u16 = addr_obj[1].clone().try_to_value(vm)?; let flowinfo: u32 = addr_obj[2].clone().try_to_value(vm)?; let scope_id: u32 = addr_obj[3].clone().try_to_value(vm)?; - - let mut addr: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; - addr.sin6_family = AF_INET6; - let host_wide: Vec = host.as_wtf8().encode_wide().chain([0]).collect(); - let mut addr_len = core::mem::size_of::() as i32; - - let ret = unsafe { - WSAStringToAddressW( - host_wide.as_ptr(), - AF_INET6 as i32, - core::ptr::null(), - &mut addr as *mut _ as *mut SOCKADDR, - &mut 
addr_len, - ) - }; - - if ret < 0 { - let err = unsafe { WSAGetLastError() } as u32; - return Err(set_from_windows_err(err, vm)); - } - - // Restore fields that WSAStringToAddressW might overwrite - addr.sin6_port = port.to_be(); - addr.sin6_flowinfo = flowinfo; - addr.Anonymous.sin6_scope_id = scope_id; - - let bytes = unsafe { - core::slice::from_raw_parts( - &addr as *const _ as *const u8, - core::mem::size_of::(), - ) - }; - Ok((bytes.to_vec(), addr_len)) + host_overlapped::parse_address_v6_wide(&host_wide, port, flowinfo, scope_id) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } _ => Err(vm.new_value_error("illegal address_as_bytes argument")), } } /// Parse a SOCKADDR_IN6 (which can also hold IPv4 addresses) to a Python address tuple - fn unparse_address(addr: &SOCKADDR_IN6, _addr_len: i32, vm: &VirtualMachine) -> PyResult { - use core::net::{Ipv4Addr, Ipv6Addr}; - - unsafe { - let family = addr.sin6_family; - if family == AF_INET { - // IPv4 address stored in SOCKADDR_IN6 structure - let addr_in = &*(addr as *const SOCKADDR_IN6 as *const SOCKADDR_IN); - let ip_bytes = addr_in.sin_addr.S_un.S_un_b; - let ip_str = - Ipv4Addr::new(ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4) - .to_string(); - let port = u16::from_be(addr_in.sin_port); - Ok((ip_str, port).to_pyobject(vm)) - } else if family == AF_INET6 { - // IPv6 address - let ip_bytes = addr.sin6_addr.u.Byte; - let ip_str = Ipv6Addr::from(ip_bytes).to_string(); - let port = u16::from_be(addr.sin6_port); - let flowinfo = u32::from_be(addr.sin6_flowinfo); - let scope_id = addr.Anonymous.sin6_scope_id; - Ok((ip_str, port, flowinfo, scope_id).to_pyobject(vm)) - } else { - Err(vm.new_value_error("recvfrom returned unsupported address family")) - } + fn unparse_address( + addr: &host_overlapped::SocketAddrV6, + addr_len: i32, + vm: &VirtualMachine, + ) -> PyResult { + match host_overlapped::unparse_address( + addr as *const _ as *const 
host_overlapped::SocketAddrRaw, + addr_len, + ) + .map_err(|_| vm.new_value_error("recvfrom returned unsupported address family"))? + { + host_overlapped::SocketAddress::V4 { host, port } => Ok((host, port).to_pyobject(vm)), + host_overlapped::SocketAddress::V6 { + host, + port, + flowinfo, + scope_id, + } => Ok((host, port, flowinfo, scope_id).to_pyobject(vm)), } } @@ -423,7 +261,7 @@ mod _overlapped { #[pygetset] fn pending(&self, _vm: &VirtualMachine) -> bool { let inner = self.inner.lock(); - !HasOverlappedIoCompleted(&inner.overlapped) + !host_overlapped::has_overlapped_io_completed(&inner.overlapped) && !matches!(inner.data, OverlappedData::NotStarted) } @@ -448,25 +286,17 @@ mod _overlapped { ) { return Ok(()); } - let ret = if !HasOverlappedIoCompleted(&inner.overlapped) { - unsafe { - windows_sys::Win32::System::IO::CancelIoEx(inner.handle, &inner.overlapped) - } - } else { - 1 - }; - // CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between - if ret == 0 && unsafe { GetLastError() } != Foundation::ERROR_NOT_FOUND { - return Err(set_from_windows_err(0, vm)); + if !host_overlapped::has_overlapped_io_completed(&inner.overlapped) { + host_overlapped::cancel_overlapped(inner.handle, &inner.overlapped).map_err( + |err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm), + )?; } Ok(()) } #[pymethod] fn getresult(zelf: &Py, wait: OptionalArg, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::{ - ERROR_BROKEN_PIPE, ERROR_MORE_DATA, ERROR_SUCCESS, - }; + use host_winapi::{ERROR_BROKEN_PIPE, ERROR_MORE_DATA, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); let wait = wait.unwrap_or(false); @@ -479,22 +309,10 @@ mod _overlapped { return Err(vm.new_value_error("operation failed to start")); } - // Get the result - let mut transferred: u32 = 0; - let ret = unsafe { - windows_sys::Win32::System::IO::GetOverlappedResult( - inner.handle, - &inner.overlapped, - &mut transferred, - if wait { 1 } else { 0 }, - ) - }; - - let 
err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; + let result = + host_overlapped::get_overlapped_result(inner.handle, &inner.overlapped, wait); + let transferred = result.transferred; + let err = result.error; inner.error = err; // Handle errors @@ -566,10 +384,9 @@ mod _overlapped { // ReadFile #[pymethod] fn ReadFile(zelf: &Py, handle: isize, size: u32, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::{ + use host_winapi::{ ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, }; - use windows_sys::Win32::Storage::FileSystem::ReadFile; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -581,30 +398,20 @@ mod _overlapped { let buf = vec![0u8; core::cmp::max(size, 1) as usize]; let buf = vm.ctx.new_bytes(buf); - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; inner.data = OverlappedData::Read(buf.clone()); - let mut nread: u32 = 0; - let ret = unsafe { - ReadFile( - handle as HANDLE, - buf.as_bytes().as_ptr() as *mut _, - size, - &mut nread, - &mut inner.overlapped, - ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; + let err = host_overlapped::start_read_file( + handle as host_overlapped::Handle, + buf.as_bytes().as_ptr() as *mut u8, + size, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_BROKEN_PIPE => { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), @@ -623,17 +430,16 @@ mod _overlapped { buf: PyBuffer, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ + use host_winapi::{ ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, }; - use windows_sys::Win32::Storage::FileSystem::ReadFile; let mut inner = zelf.inner.lock(); if !matches!(inner.data, 
OverlappedData::None) { return Err(vm.new_value_error("operation already attempted")); } - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { return Err(vm.new_value_error("buffer too large")); @@ -641,33 +447,23 @@ mod _overlapped { // For async read, buffer must be contiguous - we can't use a temporary copy // because Windows writes data directly to the buffer after this call returns - let Some(contiguous) = buf.as_contiguous_mut() else { + let Some(mut contiguous) = buf.as_contiguous_mut() else { return Err(vm.new_buffer_error("buffer is not contiguous")); }; inner.data = OverlappedData::ReadInto(buf.clone()); - let mut nread: u32 = 0; - let ret = unsafe { - ReadFile( - handle as HANDLE, - contiguous.as_ptr() as *mut _, - buf_len as u32, - &mut nread, - &mut inner.overlapped, - ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; + let err = host_overlapped::start_read_file( + handle as host_overlapped::Handle, + contiguous.as_mut_ptr(), + buf_len as u32, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_BROKEN_PIPE => { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), @@ -687,10 +483,9 @@ mod _overlapped { flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ + use host_winapi::{ ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, }; - use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSARecv}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -704,37 +499,21 @@ mod _overlapped { let buf = vec![0u8; core::cmp::max(size, 1) as usize]; let buf = vm.ctx.new_bytes(buf); - inner.handle = handle as HANDLE; + inner.handle = handle as 
host_overlapped::Handle; inner.data = OverlappedData::Read(buf.clone()); - let wsabuf = WSABUF { - buf: buf.as_bytes().as_ptr() as *mut _, - len: size, - }; - let mut nread: u32 = 0; - - let ret = unsafe { - WSARecv( - handle as _, - &wsabuf, - 1, - &mut nread, - &mut flags, - &mut inner.overlapped, - None, - ) - }; - - let err = if ret < 0 { - unsafe { WSAGetLastError() as u32 } - } else { - ERROR_SUCCESS - }; + let err = host_overlapped::start_wsa_recv( + handle as usize, + buf.as_bytes().as_ptr() as *mut u8, + size, + &mut flags, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_BROKEN_PIPE => { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), @@ -754,10 +533,9 @@ mod _overlapped { flags: u32, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ + use host_winapi::{ ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, }; - use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSARecv}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -765,46 +543,30 @@ mod _overlapped { } let mut flags = flags; - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { return Err(vm.new_value_error("buffer too large")); } - let Some(contiguous) = buf.as_contiguous_mut() else { + let Some(mut contiguous) = buf.as_contiguous_mut() else { return Err(vm.new_buffer_error("buffer is not contiguous")); }; inner.data = OverlappedData::ReadInto(buf.clone()); - let wsabuf = WSABUF { - buf: contiguous.as_ptr() as *mut _, - len: buf_len as u32, - }; - let mut nread: u32 = 0; - - let ret = unsafe { - WSARecv( - handle as _, - &wsabuf, - 1, - &mut nread, - &mut flags, - &mut inner.overlapped, - None, - ) - }; - - let err = if ret < 
0 { - unsafe { WSAGetLastError() as u32 } - } else { - ERROR_SUCCESS - }; + let err = host_overlapped::start_wsa_recv( + handle as usize, + contiguous.as_mut_ptr(), + buf_len as u32, + &mut flags, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_BROKEN_PIPE => { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), @@ -823,15 +585,14 @@ mod _overlapped { buf: PyBuffer, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Storage::FileSystem::WriteFile; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted")); } - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { return Err(vm.new_value_error("buffer too large")); @@ -845,22 +606,12 @@ mod _overlapped { inner.data = OverlappedData::Write(buf.clone()); - let mut written: u32 = 0; - let ret = unsafe { - WriteFile( - handle as HANDLE, - contiguous.as_ptr() as *const _, - buf_len as u32, - &mut written, - &mut inner.overlapped, - ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; + let err = host_overlapped::start_write_file( + handle as host_overlapped::Handle, + contiguous.as_ptr(), + buf_len as u32, + &mut inner.overlapped, + ); inner.error = err; match err { @@ -881,15 +632,14 @@ mod _overlapped { flags: u32, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSASend}; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = 
zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted")); } - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { return Err(vm.new_value_error("buffer too large")); @@ -901,29 +651,13 @@ mod _overlapped { inner.data = OverlappedData::Write(buf.clone()); - let wsabuf = WSABUF { - buf: contiguous.as_ptr() as *mut _, - len: buf_len as u32, - }; - let mut written: u32 = 0; - - let ret = unsafe { - WSASend( - handle as _, - &wsabuf, - 1, - &mut written, - flags, - &mut inner.overlapped, - None, - ) - }; - - let err = if ret < 0 { - unsafe { WSAGetLastError() as u32 } - } else { - ERROR_SUCCESS - }; + let err = host_overlapped::start_wsa_send( + handle as usize, + contiguous.as_ptr(), + buf_len as u32, + flags, + &mut inner.overlapped, + ); inner.error = err; match err { @@ -943,8 +677,7 @@ mod _overlapped { accept_socket: isize, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -952,46 +685,20 @@ mod _overlapped { } // Buffer size: local address + remote address - let size = core::mem::size_of::() + 16; + let size = core::mem::size_of::() + 16; let buf = vec![0u8; size * 2]; let buf = vm.ctx.new_bytes(buf); - inner.handle = listen_socket as HANDLE; + inner.handle = listen_socket as host_overlapped::Handle; inner.data = OverlappedData::Accept(buf.clone()); - let mut bytes_received: u32 = 0; - - type AcceptExFn = unsafe extern "system" fn( - sListenSocket: usize, - sAcceptSocket: usize, - lpOutputBuffer: *mut core::ffi::c_void, - dwReceiveDataLength: u32, - dwLocalAddressLength: u32, - dwRemoteAddressLength: u32, - lpdwBytesReceived: *mut u32, - 
lpOverlapped: *mut OVERLAPPED, - ) -> i32; - - let accept_ex: AcceptExFn = unsafe { core::mem::transmute(*ACCEPT_EX.get().unwrap()) }; - - let ret = unsafe { - accept_ex( - listen_socket as _, - accept_socket as _, - buf.as_bytes().as_ptr() as *mut _, - 0, - size as u32, - size as u32, - &mut bytes_received, - &mut inner.overlapped, - ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { WSAGetLastError() as u32 } - }; + let err = host_overlapped::start_accept_ex( + listen_socket as usize, + accept_socket as usize, + buf.as_bytes().as_ptr() as *mut u8, + size as u32, + &mut inner.overlapped, + ); inner.error = err; match err { @@ -1011,8 +718,7 @@ mod _overlapped { address: PyTupleRef, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -1021,46 +727,22 @@ mod _overlapped { let (addr_bytes, addr_len) = parse_address(&address, vm)?; - inner.handle = socket as HANDLE; + inner.handle = socket as host_overlapped::Handle; // Store addr_bytes in OverlappedData to keep it alive during async operation inner.data = OverlappedData::Connect(addr_bytes); - type ConnectExFn = unsafe extern "system" fn( - s: usize, - name: *const SOCKADDR, - namelen: i32, - lpSendBuffer: *const core::ffi::c_void, - dwSendDataLength: u32, - lpdwBytesSent: *mut u32, - lpOverlapped: *mut OVERLAPPED, - ) -> i32; - - let connect_ex: ConnectExFn = - unsafe { core::mem::transmute(*CONNECT_EX.get().unwrap()) }; - // Get pointer to the stored address data let addr_ptr = match &inner.data { OverlappedData::Connect(bytes) => bytes.as_ptr(), _ => unreachable!(), }; - let ret = unsafe { - connect_ex( - socket as _, - addr_ptr as *const SOCKADDR, - addr_len, - core::ptr::null(), - 0, - core::ptr::null_mut(), - &mut inner.overlapped, 
- ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { WSAGetLastError() as u32 } - }; + let err = host_overlapped::start_connect_ex( + socket as usize, + addr_ptr as *const host_overlapped::SocketAddrRaw, + addr_len, + &mut inner.overlapped, + ); inner.error = err; match err { @@ -1080,34 +762,18 @@ mod _overlapped { flags: u32, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted")); } - inner.handle = socket as HANDLE; + inner.handle = socket as host_overlapped::Handle; inner.data = OverlappedData::Disconnect; - type DisconnectExFn = unsafe extern "system" fn( - s: usize, - lpOverlapped: *mut OVERLAPPED, - dwFlags: u32, - dwReserved: u32, - ) -> i32; - - let disconnect_ex: DisconnectExFn = - unsafe { core::mem::transmute(*DISCONNECT_EX.get().unwrap()) }; - - let ret = unsafe { disconnect_ex(socket as _, &mut inner.overlapped, flags, 0) }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { WSAGetLastError() as u32 } - }; + let err = + host_overlapped::start_disconnect_ex(socket as usize, flags, &mut inner.overlapped); inner.error = err; match err { @@ -1136,49 +802,25 @@ mod _overlapped { flags: u32, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted")); } - inner.handle = socket as HANDLE; + inner.handle = socket as host_overlapped::Handle; inner.data = OverlappedData::TransmitFile; - 
inner.overlapped.Anonymous.Anonymous.Offset = offset; - inner.overlapped.Anonymous.Anonymous.OffsetHigh = offset_high; - - type TransmitFileFn = unsafe extern "system" fn( - hSocket: usize, - hFile: HANDLE, - nNumberOfBytesToWrite: u32, - nNumberOfBytesPerSend: u32, - lpOverlapped: *mut OVERLAPPED, - lpTransmitBuffers: *const core::ffi::c_void, - dwReserved: u32, - ) -> i32; - - let transmit_file: TransmitFileFn = - unsafe { core::mem::transmute(*TRANSMIT_FILE.get().unwrap()) }; - - let ret = unsafe { - transmit_file( - socket as _, - file as HANDLE, - count_to_write, - count_per_send, - &mut inner.overlapped, - core::ptr::null(), - flags, - ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { WSAGetLastError() as u32 } - }; + let err = host_overlapped::start_transmit_file( + socket as usize, + file as host_overlapped::Handle, + count_to_write, + count_per_send, + flags, + offset, + offset_high, + &mut inner.overlapped, + ); inner.error = err; match err { @@ -1193,31 +835,25 @@ mod _overlapped { // ConnectNamedPipe #[pymethod] fn ConnectNamedPipe(zelf: &Py, pipe: isize, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::{ - ERROR_IO_PENDING, ERROR_PIPE_CONNECTED, ERROR_SUCCESS, - }; - use windows_sys::Win32::System::Pipes::ConnectNamedPipe; + use host_winapi::{ERROR_IO_PENDING, ERROR_PIPE_CONNECTED, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted")); } - inner.handle = pipe as HANDLE; + inner.handle = pipe as host_overlapped::Handle; inner.data = OverlappedData::ConnectNamedPipe; - let ret = unsafe { ConnectNamedPipe(pipe as HANDLE, &mut inner.overlapped) }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; + let err = host_overlapped::start_connect_named_pipe( + pipe as host_overlapped::Handle, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_PIPE_CONNECTED 
=> { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Ok(true) } ERROR_SUCCESS | ERROR_IO_PENDING => Ok(false), @@ -1238,8 +874,7 @@ mod _overlapped { address: PyTupleRef, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; - use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSASendTo}; + use host_winapi::{ERROR_IO_PENDING, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -1248,7 +883,7 @@ mod _overlapped { let (addr_bytes, addr_len) = parse_address(&address, vm)?; - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { return Err(vm.new_value_error("buffer too large")); @@ -1264,37 +899,21 @@ mod _overlapped { address: addr_bytes, }); - let wsabuf = WSABUF { - buf: contiguous.as_ptr() as *mut _, - len: buf_len as u32, - }; - let mut written: u32 = 0; - // Get pointer to the stored address data let addr_ptr = match &inner.data { OverlappedData::WriteTo(wt) => wt.address.as_ptr(), _ => unreachable!(), }; - let ret = unsafe { - WSASendTo( - handle as _, - &wsabuf, - 1, - &mut written, - flags, - addr_ptr as *const SOCKADDR, - addr_len, - &mut inner.overlapped, - None, - ) - }; - - let err = if ret < 0 { - unsafe { WSAGetLastError() as u32 } - } else { - ERROR_SUCCESS - }; + let err = host_overlapped::start_wsa_send_to( + handle as usize, + contiguous.as_ptr(), + buf_len as u32, + flags, + addr_ptr as *const host_overlapped::SocketAddrRaw, + addr_len, + &mut inner.overlapped, + ); inner.error = err; match err { @@ -1315,10 +934,9 @@ mod _overlapped { flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::{ + use host_winapi::{ ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, }; - use windows_sys::Win32::Networking::WinSock::{WSABUF, 
WSAGetLastError, WSARecvFrom}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -1332,10 +950,10 @@ mod _overlapped { let buf = vec![0u8; core::cmp::max(size, 1) as usize]; let buf = vm.ctx.new_bytes(buf); - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; - let address: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; - let address_length = core::mem::size_of::() as i32; + let address: host_overlapped::SocketAddrV6 = unsafe { core::mem::zeroed() }; + let address_length = core::mem::size_of::() as i32; inner.data = OverlappedData::ReadFrom(OverlappedReadFrom { result: None, @@ -1344,45 +962,29 @@ mod _overlapped { address_length, }); - let wsabuf = WSABUF { - buf: buf.as_bytes().as_ptr() as *mut _, - len: size, - }; - let mut nread: u32 = 0; - // Get mutable reference to address in inner.data let (addr_ptr, addr_len_ptr) = match &mut inner.data { OverlappedData::ReadFrom(rf) => ( - &mut rf.address as *mut SOCKADDR_IN6, + &mut rf.address as *mut host_overlapped::SocketAddrV6, &mut rf.address_length as *mut i32, ), _ => unreachable!(), }; - let ret = unsafe { - WSARecvFrom( - handle as _, - &wsabuf, - 1, - &mut nread, - &mut flags, - addr_ptr as *mut SOCKADDR, - addr_len_ptr, - &mut inner.overlapped, - None, - ) - }; - - let err = if ret < 0 { - unsafe { WSAGetLastError() as u32 } - } else { - ERROR_SUCCESS - }; + let err = host_overlapped::start_wsa_recv_from( + handle as usize, + buf.as_bytes().as_ptr() as *mut u8, + size, + &mut flags, + addr_ptr as *mut host_overlapped::SocketAddrRaw, + addr_len_ptr, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_BROKEN_PIPE => { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), @@ -1403,10 +1005,9 @@ mod _overlapped { flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - 
use windows_sys::Win32::Foundation::{ + use host_winapi::{ ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, }; - use windows_sys::Win32::Networking::WinSock::{WSABUF, WSAGetLastError, WSARecvFrom}; let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { @@ -1414,9 +1015,9 @@ mod _overlapped { } let mut flags = flags.unwrap_or(0); - inner.handle = handle as HANDLE; + inner.handle = handle as host_overlapped::Handle; - let Some(contiguous) = buf.as_contiguous_mut() else { + let Some(mut contiguous) = buf.as_contiguous_mut() else { return Err(vm.new_buffer_error("buffer is not contiguous")); }; @@ -1425,8 +1026,8 @@ mod _overlapped { return Err(vm.new_value_error("buffer too large")); } - let address: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; - let address_length = core::mem::size_of::() as i32; + let address: host_overlapped::SocketAddrV6 = unsafe { core::mem::zeroed() }; + let address_length = core::mem::size_of::() as i32; inner.data = OverlappedData::ReadFromInto(OverlappedReadFromInto { result: None, @@ -1435,45 +1036,29 @@ mod _overlapped { address_length, }); - let wsabuf = WSABUF { - buf: contiguous.as_ptr() as *mut _, - len: size, - }; - let mut nread: u32 = 0; - // Get mutable reference to address in inner.data let (addr_ptr, addr_len_ptr) = match &mut inner.data { OverlappedData::ReadFromInto(rfi) => ( - &mut rfi.address as *mut SOCKADDR_IN6, + &mut rfi.address as *mut host_overlapped::SocketAddrV6, &mut rfi.address_length as *mut i32, ), _ => unreachable!(), }; - let ret = unsafe { - WSARecvFrom( - handle as _, - &wsabuf, - 1, - &mut nread, - &mut flags, - addr_ptr as *mut SOCKADDR, - addr_len_ptr, - &mut inner.overlapped, - None, - ) - }; - - let err = if ret < 0 { - unsafe { WSAGetLastError() as u32 } - } else { - ERROR_SUCCESS - }; + let err = host_overlapped::start_wsa_recv_from( + handle as usize, + contiguous.as_mut_ptr(), + size, + &mut flags, + addr_ptr as *mut host_overlapped::SocketAddrRaw, + 
addr_len_ptr, + &mut inner.overlapped, + ); inner.error = err; match err { ERROR_BROKEN_PIPE => { - mark_as_completed(&mut inner.overlapped); + host_overlapped::mark_as_completed(&mut inner.overlapped); Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), @@ -1492,26 +1077,20 @@ mod _overlapped { let mut event = event.unwrap_or(INVALID_HANDLE_VALUE); if event == INVALID_HANDLE_VALUE { - event = unsafe { - windows_sys::Win32::System::Threading::CreateEventW( - core::ptr::null(), - Foundation::TRUE, - Foundation::FALSE, - core::ptr::null(), - ) as isize - }; - if event == NULL { - return Err(set_from_windows_err(0, vm)); - } + event = host_winapi::create_event_w(true, false, None) + .map(|handle| handle as isize) + .map_err(|err| { + set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm) + })?; } - let mut overlapped: OVERLAPPED = unsafe { core::mem::zeroed() }; + let mut overlapped: host_overlapped::OverlappedIo = unsafe { core::mem::zeroed() }; if event != NULL { - overlapped.hEvent = event as HANDLE; + overlapped.hEvent = event as host_overlapped::Handle; } let inner = OverlappedInner { overlapped, - handle: NULL as HANDLE, + handle: NULL as host_overlapped::Handle, error: 0, data: OverlappedData::None, }; @@ -1523,35 +1102,18 @@ mod _overlapped { impl Destructor for Overlapped { fn del(zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Foundation::{ - ERROR_NOT_FOUND, ERROR_OPERATION_ABORTED, ERROR_SUCCESS, - }; - use windows_sys::Win32::System::IO::{CancelIoEx, GetOverlappedResult}; + use host_winapi::{ERROR_NOT_FOUND, ERROR_OPERATION_ABORTED, ERROR_SUCCESS}; let mut inner = zelf.inner.lock(); - let olderr = unsafe { GetLastError() }; + let olderr = host_winapi::get_last_error(); // Cancel pending I/O and wait for completion - if !HasOverlappedIoCompleted(&inner.overlapped) + if !host_overlapped::has_overlapped_io_completed(&inner.overlapped) && !matches!(inner.data, 
OverlappedData::NotStarted) { - let cancelled = unsafe { CancelIoEx(inner.handle, &inner.overlapped) } != 0; - let mut transferred: u32 = 0; - let ret = unsafe { - GetOverlappedResult( - inner.handle, - &inner.overlapped, - &mut transferred, - if cancelled { 1 } else { 0 }, - ) - }; - - let err = if ret != 0 { - ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; - match err { + match host_overlapped::cancel_overlapped_for_drop(inner.handle, &inner.overlapped) + .error + { ERROR_SUCCESS | ERROR_NOT_FOUND | ERROR_OPERATION_ABORTED => {} _ => { let msg = format!( @@ -1569,14 +1131,12 @@ mod _overlapped { // Close the event handle if !inner.overlapped.hEvent.is_null() { - unsafe { - Foundation::CloseHandle(inner.overlapped.hEvent); - } + let _ = host_winapi::close_handle(inner.overlapped.hEvent); inner.overlapped.hEvent = core::ptr::null_mut(); } // Restore last error - unsafe { Foundation::SetLastError(olderr) }; + host_windows::set_last_error(olderr); Ok(()) } @@ -1584,30 +1144,8 @@ mod _overlapped { #[pyfunction] fn ConnectPipe(address: String, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::{GENERIC_READ, GENERIC_WRITE}; - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_FLAG_OVERLAPPED, OPEN_EXISTING, - }; - - let address_wide: Vec = address.encode_utf16().chain(core::iter::once(0)).collect(); - - let handle = unsafe { - CreateFileW( - address_wide.as_ptr(), - GENERIC_READ | GENERIC_WRITE, - 0, - core::ptr::null(), - OPEN_EXISTING, - FILE_FLAG_OVERLAPPED, - core::ptr::null_mut(), - ) - }; - - if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { - return Err(set_from_windows_err(0, vm)); - } - - Ok(handle as isize) + host_overlapped::connect_pipe(&address) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] @@ -1618,53 +1156,26 @@ mod _overlapped { concurrency: u32, vm: &VirtualMachine, ) -> PyResult { - let r = unsafe { - 
windows_sys::Win32::System::IO::CreateIoCompletionPort( - handle as HANDLE, - port as HANDLE, - key, - concurrency, - ) as isize - }; - if r == 0 { - return Err(set_from_windows_err(0, vm)); - } - Ok(r) + host_overlapped::create_io_completion_port(handle, port, key, concurrency) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn GetQueuedCompletionStatus(port: isize, msecs: u32, vm: &VirtualMachine) -> PyResult { - let mut bytes_transferred = 0; - let mut completion_key = 0; - let mut overlapped: *mut OVERLAPPED = core::ptr::null_mut(); - let ret = unsafe { - windows_sys::Win32::System::IO::GetQueuedCompletionStatus( - port as HANDLE, - &mut bytes_transferred, - &mut completion_key, - &mut overlapped, - msecs, - ) - }; - let err = if ret != 0 { - Foundation::ERROR_SUCCESS - } else { - unsafe { GetLastError() } - }; - if overlapped.is_null() { - if err == Foundation::WAIT_TIMEOUT { - return Ok(vm.ctx.none()); - } - return Err(set_from_windows_err(err, vm)); + match host_overlapped::get_queued_completion_status(port, msecs) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm))? 
+ { + host_overlapped::WaitResult::Timeout => Ok(vm.ctx.none()), + host_overlapped::WaitResult::Queued(status) => Ok(vm + .ctx + .new_tuple(vec![ + status.error.to_pyobject(vm), + status.bytes_transferred.to_pyobject(vm), + status.completion_key.to_pyobject(vm), + status.overlapped.to_pyobject(vm), + ]) + .into()), } - - let value = vm.ctx.new_tuple(vec![ - err.to_pyobject(vm), - bytes_transferred.to_pyobject(vm), - completion_key.to_pyobject(vm), - (overlapped as usize).to_pyobject(vm), - ]); - Ok(value.into()) } #[pyfunction] @@ -1675,64 +1186,8 @@ mod _overlapped { address: usize, vm: &VirtualMachine, ) -> PyResult<()> { - let ret = unsafe { - windows_sys::Win32::System::IO::PostQueuedCompletionStatus( - port as HANDLE, - bytes, - key, - address as *mut OVERLAPPED, - ) - }; - if ret == 0 { - return Err(set_from_windows_err(0, vm)); - } - Ok(()) - } - - // Registry to track callback data for proper cleanup - // Uses Arc for reference counting to prevent use-after-free when callback - // and UnregisterWait race - the data stays alive until both are done - static WAIT_CALLBACK_REGISTRY: std::sync::OnceLock< - std::sync::Mutex>>, - > = std::sync::OnceLock::new(); - - fn wait_callback_registry() -> &'static std::sync::Mutex< - std::collections::HashMap>, - > { - WAIT_CALLBACK_REGISTRY - .get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new())) - } - - // Callback data for RegisterWaitWithQueue - // Uses Arc to ensure the data stays alive while callback is executing - struct PostCallbackData { - completion_port: HANDLE, - overlapped: *mut OVERLAPPED, - } - - // SAFETY: The pointers are handles/addresses passed from Python and are - // only used to call Windows APIs. They are not dereferenced as Rust pointers. 
- unsafe impl Send for PostCallbackData {} - unsafe impl Sync for PostCallbackData {} - - unsafe extern "system" fn post_to_queue_callback( - parameter: *mut core::ffi::c_void, - timer_or_wait_fired: bool, - ) { - // Reconstruct Arc from raw pointer - this gives us ownership of one reference - // The Arc prevents use-after-free since we own a reference count - let data = unsafe { alloc::sync::Arc::from_raw(parameter as *const PostCallbackData) }; - - unsafe { - let _ = windows_sys::Win32::System::IO::PostQueuedCompletionStatus( - data.completion_port, - if timer_or_wait_fired { 1 } else { 0 }, - 0, - data.overlapped, - ); - } - // Arc is dropped here, decrementing refcount - // Memory is freed only when all references (callback + registry) are gone + host_overlapped::post_queued_completion_status(port, bytes, key, address) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] @@ -1743,193 +1198,45 @@ mod _overlapped { timeout: u32, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Threading::{ - RegisterWaitForSingleObject, WT_EXECUTEINWAITTHREAD, WT_EXECUTEONLYONCE, - }; - - let data = alloc::sync::Arc::new(PostCallbackData { - completion_port: completion_port as HANDLE, - overlapped: overlapped as *mut OVERLAPPED, - }); - - // Create raw pointer for the callback - this increments refcount - let data_ptr = alloc::sync::Arc::into_raw(data.clone()); - - let mut new_wait_object: HANDLE = core::ptr::null_mut(); - let ret = unsafe { - RegisterWaitForSingleObject( - &mut new_wait_object, - object as HANDLE, - Some(post_to_queue_callback), - data_ptr as *mut _, - timeout, - WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE, - ) - }; - - if ret == 0 { - // Registration failed - reconstruct Arc to drop the extra reference - unsafe { - let _ = alloc::sync::Arc::from_raw(data_ptr); - } - return Err(set_from_windows_err(0, vm)); - } - - // Store in registry for cleanup tracking - let wait_handle = new_wait_object as 
isize; - if let Ok(mut registry) = wait_callback_registry().lock() { - registry.insert(wait_handle, data); - } - - Ok(wait_handle) - } - - // Helper to cleanup callback data when unregistering - // Just removes from registry - Arc ensures memory stays alive if callback is running - fn cleanup_wait_callback_data(wait_handle: isize) { - if let Ok(mut registry) = wait_callback_registry().lock() { - // Removing from registry drops one Arc reference - // If callback already ran, this frees the memory - // If callback is still pending/running, it holds the other reference - registry.remove(&wait_handle); - } + host_overlapped::register_wait_with_queue(object, completion_port, overlapped, timeout) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn UnregisterWait(wait_handle: isize, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Threading::UnregisterWait; - - let ret = unsafe { UnregisterWait(wait_handle as HANDLE) }; - // Cleanup callback data regardless of UnregisterWait result - // (callback may have already fired, or may never fire) - cleanup_wait_callback_data(wait_handle); - if ret == 0 { - return Err(set_from_windows_err(0, vm)); - } - Ok(()) + host_overlapped::unregister_wait(wait_handle) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn UnregisterWaitEx(wait_handle: isize, event: isize, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Threading::UnregisterWaitEx; - - let ret = unsafe { UnregisterWaitEx(wait_handle as HANDLE, event as HANDLE) }; - // Cleanup callback data regardless of UnregisterWaitEx result - cleanup_wait_callback_data(wait_handle); - if ret == 0 { - return Err(set_from_windows_err(0, vm)); - } - Ok(()) + host_overlapped::unregister_wait_ex(wait_handle, event) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn BindLocal(socket: isize, family: 
i32, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Networking::WinSock::{ - INADDR_ANY, SOCKET_ERROR, WSAGetLastError, bind, - }; - - let ret = if family == AF_INET as i32 { - let mut addr: SOCKADDR_IN = unsafe { core::mem::zeroed() }; - addr.sin_family = AF_INET; - addr.sin_port = 0; - addr.sin_addr.S_un.S_addr = INADDR_ANY; - unsafe { - bind( - socket as _, - &addr as *const _ as *const SOCKADDR, - core::mem::size_of::() as i32, - ) - } - } else if family == AF_INET6 as i32 { - // in6addr_any is all zeros, which we have from zeroed() - let mut addr: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; - addr.sin6_family = AF_INET6; - addr.sin6_port = 0; - unsafe { - bind( - socket as _, - &addr as *const _ as *const SOCKADDR, - core::mem::size_of::() as i32, - ) - } - } else { + if family != host_overlapped::AF_INET_FAMILY && family != host_overlapped::AF_INET6_FAMILY { return Err(vm.new_value_error("expected tuple of length 2 or 4")); - }; - - if ret == SOCKET_ERROR { - let err = unsafe { WSAGetLastError() } as u32; - return Err(set_from_windows_err(err, vm)); } - Ok(()) + host_overlapped::bind_local(socket, family) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn FormatMessage(error_code: u32, _vm: &VirtualMachine) -> String { - use windows_sys::Win32::Foundation::LocalFree; - use windows_sys::Win32::System::Diagnostics::Debug::{ - FORMAT_MESSAGE_ALLOCATE_BUFFER, FORMAT_MESSAGE_FROM_SYSTEM, - FORMAT_MESSAGE_IGNORE_INSERTS, FormatMessageW, - }; - - // LANG_NEUTRAL = 0, SUBLANG_DEFAULT = 1 - const LANG_NEUTRAL: u32 = 0; - const SUBLANG_DEFAULT: u32 = 1; - - let mut buffer: *mut u16 = core::ptr::null_mut(); - - let len = unsafe { - FormatMessageW( - FORMAT_MESSAGE_ALLOCATE_BUFFER - | FORMAT_MESSAGE_FROM_SYSTEM - | FORMAT_MESSAGE_IGNORE_INSERTS, - core::ptr::null(), - error_code, - (SUBLANG_DEFAULT << 10) | LANG_NEUTRAL, - &mut buffer as *mut _ as *mut u16, - 0, - core::ptr::null(), - ) - }; - - if 
len == 0 || buffer.is_null() { - if !buffer.is_null() { - unsafe { LocalFree(buffer as *mut _) }; - } - return format!("unknown error code {error_code}"); - } - - // Convert to Rust string, trimming trailing whitespace - let slice = unsafe { core::slice::from_raw_parts(buffer, len as usize) }; - let msg = String::from_utf16_lossy(slice).trim_end().to_string(); - - unsafe { LocalFree(buffer as *mut _) }; - - msg + host_overlapped::format_message(error_code) } #[pyfunction] fn WSAConnect(socket: isize, address: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Networking::WinSock::{SOCKET_ERROR, WSAConnect, WSAGetLastError}; - let (addr_bytes, addr_len) = parse_address(&address, vm)?; - - let ret = unsafe { - WSAConnect( - socket as _, - addr_bytes.as_ptr() as *const SOCKADDR, - addr_len, - core::ptr::null(), - core::ptr::null_mut(), - core::ptr::null(), - core::ptr::null(), - ) - }; - - if ret == SOCKET_ERROR { - let err = unsafe { WSAGetLastError() } as u32; - return Err(set_from_windows_err(err, vm)); - } - Ok(()) + host_overlapped::wsa_connect( + socket, + addr_bytes.as_ptr() as *const host_overlapped::SocketAddrRaw, + addr_len, + ) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] @@ -1944,39 +1251,23 @@ mod _overlapped { return Err(vm.new_value_error("EventAttributes must be None")); } - let name_wide: Option> = - name.map(|n| n.encode_utf16().chain(core::iter::once(0)).collect()); - let name_ptr = name_wide.as_ref().map_or(core::ptr::null(), |n| n.as_ptr()); - - let event = unsafe { - windows_sys::Win32::System::Threading::CreateEventW( - core::ptr::null(), - if manual_reset { 1 } else { 0 }, - if initial_state { 1 } else { 0 }, - name_ptr, - ) as isize - }; - if event == NULL { - return Err(set_from_windows_err(0, vm)); - } - Ok(event) + let name_wide: Option = name + .map(|n| widestring::WideCString::from_str(&n).map_err(|err| err.to_pyexception(vm))) + .transpose()?; + 
host_winapi::create_event_w(manual_reset, initial_state, name_wide.as_deref()) + .map(|h| h as isize) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn SetEvent(handle: isize, vm: &VirtualMachine) -> PyResult<()> { - let ret = unsafe { windows_sys::Win32::System::Threading::SetEvent(handle as HANDLE) }; - if ret == 0 { - return Err(set_from_windows_err(0, vm)); - } - Ok(()) + host_winapi::set_event(handle as host_winapi::Handle) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } #[pyfunction] fn ResetEvent(handle: isize, vm: &VirtualMachine) -> PyResult<()> { - let ret = unsafe { windows_sys::Win32::System::Threading::ResetEvent(handle as HANDLE) }; - if ret == 0 { - return Err(set_from_windows_err(0, vm)); - } - Ok(()) + host_winapi::reset_event(handle as host_winapi::Handle) + .map_err(|err| set_from_windows_err(err.raw_os_error().unwrap_or(0) as u32, vm)) } } diff --git a/crates/stdlib/src/posixsubprocess.rs b/crates/stdlib/src/posixsubprocess.rs index 059e07e2d36..0371d12e3c2 100644 --- a/crates/stdlib/src/posixsubprocess.rs +++ b/crates/stdlib/src/posixsubprocess.rs @@ -4,19 +4,13 @@ use crate::vm::{ builtins::PyListRef, function::ArgSequence, ospath::OsPath, - stdlib::posix, {PyObjectRef, PyResult, TryFromObject, VirtualMachine}, }; -use itertools::Itertools; -use nix::{ - errno::Errno, - unistd::{self, Pid}, -}; +use rustpython_host_env::posix as host_posix; use std::{ io::prelude::*, os::fd::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd}, }; -use unistd::{Gid, Uid}; use alloc::ffi::CString; @@ -48,7 +42,8 @@ mod _posixsubprocess { let extra_groups = args .groups_list .as_ref() - .map(|l| Vec::::try_from_borrowed_object(vm, l.as_object())) + .map(|l| Vec::::try_from_borrowed_object(vm, l.as_object())) + .map(|res| res.map(|groups| groups.into_iter().map(|gid| gid.0).collect::>())) .transpose()?; let argv = args.args.iter().collect::>(); let envp = 
args.env_list.as_ref().map(CharPtrVec::from_iter); @@ -57,9 +52,9 @@ mod _posixsubprocess { envp: envp.as_deref(), extra_groups: extra_groups.as_deref(), }; - match unsafe { nix::unistd::fork() }.map_err(|err| err.into_pyexception(vm))? { - nix::unistd::ForkResult::Child => exec(&args, procargs, vm), - nix::unistd::ForkResult::Parent { child } => Ok(child.as_raw()), + match host_posix::fork().map_err(|err| err.into_pyexception(vm))? { + 0 => exec(&args, procargs, vm), + child => Ok(child), } } } @@ -143,7 +138,7 @@ impl TryFromObject for Fd { impl Write for Fd { fn write(&mut self, buf: &[u8]) -> std::io::Result { - Ok(unistd::write(self, buf)?) + host_posix::write_fd(self.as_fd(), buf) } fn flush(&mut self) -> std::io::Result<()> { @@ -201,6 +196,45 @@ impl AsRawFd for MaybeFd { } } +#[derive(Copy, Clone)] +struct RawUid(u32); + +#[derive(Copy, Clone)] +struct RawGid(u32); + +fn try_from_id(vm: &VirtualMachine, obj: PyObjectRef, typ_name: &str) -> PyResult { + use core::cmp::Ordering; + let i = obj + .try_to_ref::(vm) + .map_err(|_| { + vm.new_type_error(format!( + "an integer is required (got type {})", + obj.class().name() + )) + })? + .try_to_primitive::(vm)?; + + match i.cmp(&-1) { + Ordering::Greater => Ok(i + .try_into() + .map_err(|_| vm.new_overflow_error(format!("{typ_name} is larger than maximum")))?), + Ordering::Less => Err(vm.new_overflow_error(format!("{typ_name} is less than minimum"))), + Ordering::Equal => Ok(-1i32 as u32), + } +} + +impl TryFromObject for RawUid { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + try_from_id(vm, obj, "uid").map(Self) + } +} + +impl TryFromObject for RawGid { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + try_from_id(vm, obj, "gid").map(Self) + } +} + // impl gen_args! { @@ -221,9 +255,9 @@ gen_args! 
{ restore_signals: bool, call_setsid: bool, pgid_to_set: libc::pid_t, - gid: Option, + gid: Option, groups_list: Option, - uid: Option, + uid: Option, child_umask: i32, preexec_fn: Option, } @@ -232,7 +266,7 @@ gen_args! { struct ProcArgs<'a> { argv: &'a CharPtrSlice<'a>, envp: Option<&'a CharPtrSlice<'a>>, - extra_groups: Option<&'a [Gid]>, + extra_groups: Option<&'a [u32]>, } fn exec(args: &ForkExecArgs<'_>, procargs: ProcArgs<'_>, vm: &VirtualMachine) -> ! { @@ -246,7 +280,8 @@ fn exec(args: &ForkExecArgs<'_>, procargs: ProcArgs<'_>, vm: &VirtualMachine) -> let _ = write!(pipe, "SubprocessError:0:{}", ctx.as_msg()); } else { // errno is written in hex format - let _ = write!(pipe, "OSError:{:x}:{}", e as i32, ctx.as_msg()); + let errno = e.raw_os_error().unwrap_or(0); + let _ = write!(pipe, "OSError:{errno:x}:{}", ctx.as_msg()); } rustpython_host_env::os::exit(255) } @@ -276,94 +311,34 @@ fn exec_inner( procargs: ProcArgs<'_>, ctx: &mut ExecErrorContext, vm: &VirtualMachine, -) -> nix::Result { - for &fd in args.fds_to_keep.as_slice() { - if fd.as_raw_fd() != args.errpipe_write.as_raw_fd() { - posix::set_inheritable(fd, true)? 
- } - } - - for &fd in &[args.p2cwrite, args.c2pread, args.errread] { - if let MaybeFd::Valid(fd) = fd { - unistd::close(fd)?; - } - } - unistd::close(args.errpipe_read)?; - - let c2pwrite = match args.c2pwrite { - MaybeFd::Valid(c2pwrite) if c2pwrite.as_raw_fd() == 0 => { - let fd = unistd::dup(c2pwrite)?; - posix::set_inheritable(fd.as_fd(), true)?; - MaybeFd::Valid(fd.into()) - } - fd => fd, - }; - - let mut errwrite = args.errwrite; - loop { - match errwrite { - MaybeFd::Valid(fd) if fd.as_raw_fd() == 0 || fd.as_raw_fd() == 1 => { - let fd = unistd::dup(fd)?; - posix::set_inheritable(fd.as_fd(), true)?; - errwrite = MaybeFd::Valid(fd.into()); - } - _ => break, - } - } - - fn dup_into_stdio(fd: MaybeFd, io_fd: i32, dup2_stdio: F) -> nix::Result<()> - where - F: Fn(Fd) -> nix::Result<()>, - { - match fd { - MaybeFd::Valid(fd) if fd.as_raw_fd() == io_fd => { - posix::set_inheritable(fd.as_fd(), true) - } - MaybeFd::Valid(fd) => dup2_stdio(fd), - MaybeFd::Invalid => Ok(()), - } - } - dup_into_stdio(args.p2cread, 0, unistd::dup2_stdin)?; - dup_into_stdio(c2pwrite, 1, unistd::dup2_stdout)?; - dup_into_stdio(errwrite, 2, unistd::dup2_stderr)?; +) -> std::io::Result { + host_posix::setup_child_fds( + args.fds_to_keep.as_slice(), + args.errpipe_write.as_fd(), + args.p2cread.as_raw_fd(), + args.p2cwrite.as_raw_fd(), + args.c2pread.as_raw_fd(), + args.c2pwrite.as_raw_fd(), + args.errread.as_raw_fd(), + args.errwrite.as_raw_fd(), + args.errpipe_read.as_raw_fd(), + )?; if let Some(ref cwd) = args.cwd { - unistd::chdir(cwd.s.as_c_str()).inspect_err(|_| *ctx = ExecErrorContext::ChDir)? + host_posix::chdir(cwd.s.as_c_str()).inspect_err(|_| *ctx = ExecErrorContext::ChDir)? 
} - if args.child_umask >= 0 { - unsafe { libc::umask(args.child_umask as libc::mode_t) }; - } + host_posix::set_umask(args.child_umask); if args.restore_signals { - unsafe { - libc::signal(libc::SIGPIPE, libc::SIG_DFL); - libc::signal(libc::SIGXFSZ, libc::SIG_DFL); - } + host_posix::restore_signals(); } - if args.call_setsid { - unistd::setsid()?; - } - - if args.pgid_to_set > -1 { - unistd::setpgid(Pid::from_raw(0), Pid::from_raw(args.pgid_to_set))?; - } - - if let Some(_groups) = procargs.extra_groups { - #[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "redox")))] - unistd::setgroups(_groups)?; - } - - if let Some(gid) = args.gid.filter(|x| x.as_raw() != u32::MAX) { - let ret = unsafe { libc::setregid(gid.as_raw(), gid.as_raw()) }; - nix::Error::result(ret)?; - } - - if let Some(uid) = args.uid.filter(|x| x.as_raw() != u32::MAX) { - let ret = unsafe { libc::setreuid(uid.as_raw(), uid.as_raw()) }; - nix::Error::result(ret)?; - } + host_posix::setsid_if_needed(args.call_setsid)?; + host_posix::setpgid_if_needed(args.pgid_to_set)?; + host_posix::setgroups_if_needed(procargs.extra_groups)?; + host_posix::setregid_if_needed(args.gid.map(|gid| gid.0))?; + host_posix::setreuid_if_needed(args.uid.map(|uid| uid.0))?; // Call preexec_fn after all process setup but before closing FDs if let Some(ref preexec_fn) = args.preexec_fn { @@ -372,7 +347,7 @@ fn exec_inner( Err(_e) => { // Cannot safely stringify exception after fork *ctx = ExecErrorContext::PreExec; - return Err(Errno::UnknownErrno); + return Err(std::io::Error::from_raw_os_error(0)); } } } @@ -380,158 +355,13 @@ fn exec_inner( *ctx = ExecErrorContext::Exec; if args.close_fds { - close_fds(KeepFds { - above: 2, - keep: &args.fds_to_keep, - }); + host_posix::close_fds(2, args.fds_to_keep.as_slice()); } - let mut first_err = None; - for exec in args.exec_list.as_slice() { - // not using nix's versions of these functions because those allocate the char-ptr array, - // and we can't allocate - if let 
Some(envp) = procargs.envp { - unsafe { libc::execve(exec.s.as_ptr(), procargs.argv.as_ptr(), envp.as_ptr()) }; - } else { - unsafe { libc::execv(exec.s.as_ptr(), procargs.argv.as_ptr()) }; - } - let e = Errno::last(); - if e != Errno::ENOENT && e != Errno::ENOTDIR && first_err.is_none() { - first_err = Some(e) - } - } - Err(first_err.unwrap_or_else(Errno::last)) -} - -#[derive(Copy, Clone)] -struct KeepFds<'a> { - above: i32, - keep: &'a [BorrowedFd<'a>], -} - -impl KeepFds<'_> { - fn should_keep(self, fd: i32) -> bool { - fd > self.above - && self - .keep - .binary_search_by_key(&fd, BorrowedFd::as_raw_fd) - .is_err() - } -} - -fn close_fds(keep: KeepFds<'_>) { - #[cfg(not(target_os = "redox"))] - if close_dir_fds(keep).is_ok() { - return; - } - #[cfg(target_os = "redox")] - if close_filetable_fds(keep).is_ok() { - return; - } - close_fds_brute_force(keep) -} - -#[cfg(not(target_os = "redox"))] -fn close_dir_fds(keep: KeepFds<'_>) -> nix::Result<()> { - use nix::{dir::Dir, fcntl::OFlag}; - - #[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "netbsd", - target_os = "openbsd", - target_vendor = "apple", - ))] - let fd_dir_name = c"/dev/fd"; - - #[cfg(any(target_os = "linux", target_os = "android"))] - let fd_dir_name = c"/proc/self/fd"; - - let mut dir = Dir::open( - fd_dir_name, - OFlag::O_RDONLY | OFlag::O_DIRECTORY, - nix::sys::stat::Mode::empty(), - )?; - let dirfd = dir.as_raw_fd(); - 'outer: for e in dir.iter() { - let e = e?; - let mut parser = IntParser::default(); - for &c in e.file_name().to_bytes() { - if parser.feed(c).is_err() { - continue 'outer; - } - } - let fd = parser.num; - if fd != dirfd && keep.should_keep(fd) { - let _ = unistd::close(fd); - } - } - Ok(()) -} - -#[cfg(target_os = "redox")] -fn close_filetable_fds(keep: KeepFds<'_>) -> nix::Result<()> { - use nix::fcntl; - use std::os::fd::{FromRawFd, OwnedFd}; - let filetable = fcntl::open( - c"/scheme/thisproc/current/filetable", - fcntl::OFlag::O_RDONLY, - 
nix::sys::stat::Mode::empty(), - )?; - let read_one = || -> nix::Result<_> { - let mut byte = 0; - let n = nix::unistd::read(&filetable, std::slice::from_mut(&mut byte))?; - Ok((n > 0).then_some(byte)) - }; - while let Some(c) = read_one()? { - let mut parser = IntParser::default(); - if parser.feed(c).is_err() { - continue; - } - let done = loop { - let Some(c) = read_one()? else { break true }; - if parser.feed(c).is_err() { - break false; - } - }; - - let fd = parser.num as i32; - if fd != filetable.as_raw_fd() && keep.should_keep(fd) { - let _ = unistd::close(fd); - } - if done { - break; - } - } - Ok(()) -} - -fn close_fds_brute_force(keep: KeepFds<'_>) { - let max_fd = nix::unistd::sysconf(nix::unistd::SysconfVar::OPEN_MAX) - .ok() - .flatten() - .unwrap_or(256) as i32; - let fds = itertools::chain![ - Some(keep.above), - keep.keep.iter().map(BorrowedFd::as_raw_fd), - Some(max_fd) - ]; - for fd in fds.tuple_windows().flat_map(|(start, end)| start + 1..end) { - unsafe { libc::close(fd) }; - } -} - -#[derive(Default)] -struct IntParser { - num: i32, -} - -struct NonDigit; -impl IntParser { - fn feed(&mut self, c: u8) -> Result<(), NonDigit> { - let digit = (c as char).to_digit(10).ok_or(NonDigit)?; - self.num *= 10; - self.num += digit as i32; - Ok(()) - } + let err = host_posix::exec_replace( + args.exec_list.as_slice(), + procargs.argv.as_ptr(), + procargs.envp.map(CharPtrSlice::as_ptr), + ); + Err(std::io::Error::from_raw_os_error(err as i32)) } diff --git a/crates/stdlib/src/resource.rs b/crates/stdlib/src/resource.rs index c984a294775..bac708435c9 100644 --- a/crates/stdlib/src/resource.rs +++ b/crates/stdlib/src/resource.rs @@ -10,7 +10,7 @@ mod resource { convert::{ToPyException, ToPyObject}, types::PyStructSequence, }; - use core::mem; + use rustpython_host_env::resource as host_resource; use std::io; #[cfg_attr(target_os = "android", expect(deprecated))] @@ -92,8 +92,8 @@ mod resource { #[pyclass(with(PyStructSequence))] impl PyRUsage {} - impl From 
for RUsageData { - fn from(rusage: libc::rusage) -> Self { + impl From for RUsageData { + fn from(rusage: host_resource::RUsage) -> Self { let tv = |tv: libc::timeval| tv.tv_sec as f64 + (tv.tv_usec as f64 / 1_000_000.0); Self { ru_utime: tv(rusage.ru_utime), @@ -118,14 +118,7 @@ mod resource { #[pyfunction] fn getrusage(who: i32, vm: &VirtualMachine) -> PyResult { - let res = unsafe { - let mut rusage = mem::MaybeUninit::::uninit(); - if libc::getrusage(who, rusage.as_mut_ptr()) == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(rusage.assume_init()) - } - }; + let res = host_resource::getrusage(who); res.map(RUsageData::from).map_err(|e| { if e.kind() == io::ErrorKind::InvalidInput { vm.new_value_error("invalid who parameter") @@ -175,13 +168,7 @@ mod resource { return Err(vm.new_value_error("invalid resource specified")); } - let rlimit = unsafe { - let mut rlimit = mem::MaybeUninit::::uninit(); - if libc::getrlimit(resource as _, rlimit.as_mut_ptr()) == -1 { - return Err(vm.new_last_errno_error()); - } - rlimit.assume_init() - }; + let rlimit = host_resource::getrlimit(resource).map_err(|_| vm.new_last_errno_error())?; Ok(Limits(rlimit)) } @@ -193,13 +180,7 @@ mod resource { return Err(vm.new_value_error("invalid resource specified")); } - let res = unsafe { - if libc::setrlimit(resource as _, &limits.0) == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } - }; + let res = host_resource::setrlimit(resource, limits.0); res.map_err(|e| match e.kind() { io::ErrorKind::InvalidInput => { diff --git a/crates/stdlib/src/select.rs b/crates/stdlib/src/select.rs index 0110e07339f..12e55db57f2 100644 --- a/crates/stdlib/src/select.rs +++ b/crates/stdlib/src/select.rs @@ -5,7 +5,7 @@ pub(crate) use decl::module_def; use crate::vm::{ PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::PyListRef, }; -use rustpython_host_env::select::{self as host_select, FdSet, RawFd}; +use rustpython_host_env::select::{self as host_select, FdSet, 
RawFd, platform::FD_SETSIZE}; use std::io; #[derive(Traverse)] @@ -43,7 +43,7 @@ mod decl { #[expect(clippy::unnecessary_wraps, reason = "Needs to comply with a signature")] pub(crate) fn module_exec(vm: &VirtualMachine, module: &Py) -> PyResult<()> { #[cfg(windows)] - crate::vm::windows::init_winsock(); + rustpython_host_env::windows::init_winsock(); #[cfg(unix)] { @@ -81,8 +81,22 @@ mod decl { let seq2set = |list: &PyObject| -> PyResult<(Vec, FdSet)> { let v: Vec = list.try_to_value(vm)?; + + let too_many_fds = cfg_select! { + windows => v.len() > FD_SETSIZE as usize, + _ => v.len() > FD_SETSIZE, + }; + if too_many_fds { + return Err(vm.new_value_error("too many file descriptors in select()")); + } + let mut fds = FdSet::new(); for fd in &v { + #[cfg(unix)] + if fd.fno as usize >= FD_SETSIZE { + return Err(vm.new_value_error("file descriptor out of range in select()")); + } + fds.insert(fd.fno); } Ok((v, fds)) @@ -97,11 +111,15 @@ mod decl { return Ok((empty.clone(), empty.clone(), empty)); } - let nfds: i32 = [&mut r, &mut w, &mut x] - .iter_mut() - .filter_map(|set| set.highest()) - .max() - .map_or(0, |n| n + 1) as _; + let nfds = cfg_select! 
{ + windows => 0, // value is ignored on windows + + _ => [&mut r, &mut w, &mut x] + .iter_mut() + .filter_map(|set| set.highest()) + .max() + .map_or(0, |n| n + 1) as _, + }; loop { let mut tv = timeout.map(host_select::sec_to_timeval); @@ -166,7 +184,6 @@ mod decl { stdlib::_io::Fildes, }; use core::{convert::TryFrom, time::Duration}; - use libc::pollfd; use num_traits::{Signed, ToPrimitive}; use std::time::Instant; @@ -216,34 +233,7 @@ mod decl { #[derive(Default, Debug, PyPayload)] pub(crate) struct PyPoll { // keep sorted - fds: PyMutex>, - } - - #[inline] - fn search(fds: &[pollfd], fd: i32) -> Result { - fds.binary_search_by_key(&fd, |pfd| pfd.fd) - } - - fn insert_fd(fds: &mut Vec, fd: i32, events: i16) { - match search(fds, fd) { - Ok(i) => fds[i].events = events, - Err(i) => fds.insert( - i, - pollfd { - fd, - events, - revents: 0, - }, - ), - } - } - - fn get_fd_mut(fds: &mut [pollfd], fd: i32) -> Option<&mut pollfd> { - search(fds, fd).ok().map(move |i| &mut fds[i]) - } - - fn remove_fd(fds: &mut Vec, fd: i32) -> Option { - search(fds, fd).ok().map(|i| fds.remove(i)) + fds: PyMutex>, } // new EventMask type @@ -281,7 +271,7 @@ mod decl { OptionalArg::Present(event_mask) => event_mask.0, OptionalArg::Missing => DEFAULT_EVENTS, }; - insert_fd(&mut self.fds.lock(), fd, mask); + host_select::insert_poll_fd(&mut self.fds.lock(), fd, mask); } #[pymethod] @@ -293,7 +283,7 @@ mod decl { ) -> PyResult<()> { let mut fds = self.fds.lock(); // CPython raises KeyError if fd is not registered, match that behavior - let pfd = get_fd_mut(&mut fds, fd) + let pfd = host_select::get_poll_fd_mut(&mut fds, fd) .ok_or_else(|| vm.new_key_error(vm.ctx.new_int(fd).into()))?; pfd.events = eventmask.0; Ok(()) @@ -301,7 +291,7 @@ mod decl { #[pymethod] fn unregister(&self, Fildes(fd): Fildes, vm: &VirtualMachine) -> PyResult<()> { - let removed = remove_fd(&mut self.fds.lock(), fd); + let removed = host_select::remove_poll_fd(&mut self.fds.lock(), fd); removed .map(drop) 
.ok_or_else(|| vm.new_key_error(vm.ctx.new_int(fd).into())) @@ -323,13 +313,12 @@ mod decl { let deadline = timeout.map(|d| Instant::now() + d); let mut poll_timeout = timeout_ms; loop { - let res = vm.allow_threads(|| unsafe { - libc::poll(fds.as_mut_ptr(), fds.len() as _, poll_timeout) - }); - match nix::Error::result(res) { + match vm.allow_threads(|| host_select::poll_fds(&mut fds, poll_timeout)) { Ok(_) => break, - Err(nix::Error::EINTR) => vm.check_signals()?, - Err(e) => return Err(e.into_pyexception(vm)), + Err(err) if err.raw_os_error() == Some(libc::EINTR) => { + vm.check_signals()? + } + Err(err) => return Err(err.into_pyexception(vm)), } if let Some(d) = deadline { if let Some(remaining) = d.checked_duration_since(Instant::now()) { @@ -379,8 +368,7 @@ mod decl { types::Constructor, }; use core::ops::Deref; - use rustix::event::epoll::{self, EventData, EventFlags}; - use std::os::fd::{AsRawFd, IntoRawFd, OwnedFd}; + use std::os::fd::{AsRawFd, OwnedFd}; use std::time::Instant; #[pyclass(module = "select", name = "epoll")] @@ -422,7 +410,7 @@ mod decl { #[pyclass(with(Constructor))] impl PyEpoll { fn new() -> std::io::Result { - let epoll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)?; + let epoll_fd = host_select::epoll::create()?; let epoll_fd = Some(epoll_fd).into(); Ok(Self { epoll_fd }) } @@ -431,7 +419,7 @@ mod decl { fn close(&self) -> std::io::Result<()> { let fd = self.epoll_fd.write().take(); if let Some(fd) = fd { - nix::unistd::close(fd.into_raw_fd())?; + host_select::epoll::close(fd)?; } Ok(()) } @@ -468,26 +456,28 @@ mod decl { vm: &VirtualMachine, ) -> PyResult<()> { let events = match eventmask { - OptionalArg::Present(mask) => EventFlags::from_bits_retain(mask), - OptionalArg::Missing => EventFlags::IN | EventFlags::PRI | EventFlags::OUT, + OptionalArg::Present(mask) => mask, + OptionalArg::Missing => (host_select::epoll::EventFlags::IN + | host_select::epoll::EventFlags::PRI + | host_select::epoll::EventFlags::OUT) + .bits(), }; let 
epoll_fd = &*self.get_epoll(vm)?; - let data = EventData::new_u64(fd.as_raw_fd() as u64); - epoll::add(epoll_fd, fd, data, events).map_err(|e| e.into_pyexception(vm)) + host_select::epoll::add(epoll_fd, fd, fd.as_raw_fd() as u64, events) + .map_err(|e| e.into_pyexception(vm)) } #[pymethod] fn modify(&self, fd: Fildes, eventmask: u32, vm: &VirtualMachine) -> PyResult<()> { - let events = EventFlags::from_bits_retain(eventmask); let epoll_fd = &*self.get_epoll(vm)?; - let data = EventData::new_u64(fd.as_raw_fd() as u64); - epoll::modify(epoll_fd, fd, data, events).map_err(|e| e.into_pyexception(vm)) + host_select::epoll::modify(epoll_fd, fd, fd.as_raw_fd() as u64, eventmask) + .map_err(|e| e.into_pyexception(vm)) } #[pymethod] fn unregister(&self, fd: Fildes, vm: &VirtualMachine) -> PyResult<()> { let epoll_fd = &*self.get_epoll(vm)?; - epoll::delete(epoll_fd, fd).map_err(|e| e.into_pyexception(vm)) + host_select::epoll::delete(epoll_fd, fd).map_err(|e| e.into_pyexception(vm)) } #[pymethod] @@ -495,11 +485,10 @@ mod decl { let poll::TimeoutArg(timeout) = args.timeout; let maxevents = args.maxevents; - let mut poll_timeout = - timeout - .map(rustix::event::Timespec::try_from) - .transpose() - .map_err(|_| vm.new_overflow_error("timeout is too large"))?; + let mut poll_timeout = timeout + .map(host_select::epoll::Timespec::try_from) + .transpose() + .map_err(|_| vm.new_overflow_error("timeout is too large"))?; let deadline = timeout.map(|d| Instant::now() + d); let maxevents = match maxevents { @@ -512,22 +501,19 @@ mod decl { _ => maxevents as usize, }; - let mut events = Vec::::with_capacity(maxevents); + let mut events = Vec::::with_capacity(maxevents); let epoll = &*self.get_epoll(vm)?; loop { - events.clear(); match vm.allow_threads(|| { - epoll::wait( - epoll, - rustix::buffer::spare_capacity(&mut events), - poll_timeout.as_ref(), - ) + host_select::epoll::wait(epoll, &mut events, poll_timeout.as_ref()) }) { Ok(_) => break, - Err(rustix::io::Errno::INTR) => 
vm.check_signals()?, - Err(e) => return Err(e.into_pyexception(vm)), + Err(host_select::epoll::WaitError::Interrupted) => vm.check_signals()?, + Err(host_select::epoll::WaitError::Io(e)) => { + return Err(e.into_pyexception(vm)); + } } if let Some(deadline) = deadline { if let Some(new_timeout) = deadline.checked_duration_since(Instant::now()) { diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap index 04684cb6c52..1b0ca25c15d 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap @@ -1,5 +1,6 @@ --- source: crates/stdlib/src/_opcode.rs +assertion_line: 300 expression: "dis(r#\"\nasync def test():\n for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):\n with self.subTest(type=type(stop_exc)):\n try:\n async with egg():\n raise stop_exc\n except Exception as ex:\n self.assertIs(ex, stop_exc)\n else:\n self.fail(f'{stop_exc} was suppressed')\n\"#)" --- 0 RESUME 0 @@ -13,9 +14,9 @@ expression: "dis(r#\"\nasync def test():\n for stop_exc in (StopIteration('sp Disassembly of ", line 1>: 1 RETURN_GENERATOR POP_TOP - RESUME 0 + L1: RESUME 0 - 2 L1: LOAD_GLOBAL 1 (StopIteration + NULL) + 2 LOAD_GLOBAL 1 (StopIteration + NULL) LOAD_CONST 0 ('spam') CALL 1 LOAD_GLOBAL 3 (StopAsyncIteration + NULL) @@ -81,21 +82,23 @@ Disassembly of ", line 1>: L18: CLEANUP_THROW L19: END_SEND TO_BOOL - POP_JUMP_IF_TRUE 2 (to L20) - NOT_TAKEN - RERAISE 2 - L20: POP_TOP - L21: POP_EXCEPT + POP_JUMP_IF_TRUE 2 (to L22) + L20: NOT_TAKEN + L21: RERAISE 2 + L22: POP_TOP + L23: POP_EXCEPT POP_TOP POP_TOP POP_TOP - JUMP_FORWARD 3 (to L23) - L22: COPY 3 + JUMP_FORWARD 3 (to L25) + + -- L24: COPY 3 POP_EXCEPT RERAISE 1 - L23: NOP - 10 L24: LOAD_GLOBAL 4 (self) + 5 L25: NOP + + 10 L26: 
LOAD_GLOBAL 4 (self) LOAD_ATTR 13 (fail + NULL|self) LOAD_FAST 0 (stop_exc) FORMAT_SIMPLE @@ -103,81 +106,82 @@ Disassembly of ", line 1>: BUILD_STRING 2 CALL 1 POP_TOP - JUMP_FORWARD 45 (to L31) + JUMP_FORWARD 45 (to L33) - -- L25: PUSH_EXC_INFO + -- L27: PUSH_EXC_INFO 7 LOAD_GLOBAL 14 (Exception) CHECK_EXC_MATCH - POP_JUMP_IF_FALSE 32 (to L29) + POP_JUMP_IF_FALSE 32 (to L31) NOT_TAKEN STORE_FAST 1 (ex) - 8 L26: LOAD_GLOBAL 4 (self) + 8 L28: LOAD_GLOBAL 4 (self) LOAD_ATTR 17 (assertIs + NULL|self) LOAD_FAST_LOAD_FAST 16 (ex, stop_exc) CALL 2 POP_TOP - L27: POP_EXCEPT + L29: POP_EXCEPT LOAD_CONST 3 (None) STORE_FAST 1 (ex) DELETE_FAST 1 (ex) - JUMP_FORWARD 8 (to L31) + JUMP_FORWARD 8 (to L33) - -- L28: LOAD_CONST 3 (None) + -- L30: LOAD_CONST 3 (None) STORE_FAST 1 (ex) DELETE_FAST 1 (ex) RERAISE 1 - 7 L29: RERAISE 0 + 7 L31: RERAISE 0 - -- L30: COPY 3 + -- L32: COPY 3 POP_EXCEPT RERAISE 1 - 3 L31: LOAD_CONST 3 (None) + 3 L33: LOAD_CONST 3 (None) LOAD_CONST 3 (None) LOAD_CONST 3 (None) CALL 3 POP_TOP JUMP_BACKWARD 188 (to L2) - L32: PUSH_EXC_INFO + L34: PUSH_EXC_INFO WITH_EXCEPT_START TO_BOOL - POP_JUMP_IF_TRUE 2 (to L33) + POP_JUMP_IF_TRUE 2 (to L35) NOT_TAKEN RERAISE 2 - L33: POP_TOP - L34: POP_EXCEPT + L35: POP_TOP + L36: POP_EXCEPT POP_TOP POP_TOP POP_TOP JUMP_BACKWARD 205 (to L2) - L35: COPY 3 + + -- L37: COPY 3 POP_EXCEPT RERAISE 1 - - -- L36: CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + L38: CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) RERAISE 1 ExceptionTable: - L1 to L3 -> L36 [0] lasti - L3 to L4 -> L32 [3] lasti - L5 to L7 -> L25 [3] + L1 to L3 -> L38 [0] lasti + L3 to L4 -> L34 [3] lasti + L5 to L7 -> L27 [3] L7 to L8 -> L12 [7] - L8 to L10 -> L25 [3] + L8 to L10 -> L27 [3] L10 to L11 -> L14 [5] lasti - L11 to L12 -> L36 [0] lasti - L12 to L13 -> L25 [3] - L14 to L16 -> L22 [7] lasti + L11 to L12 -> L38 [0] lasti + L12 to L13 -> L27 [3] + L14 to L16 -> L24 [7] lasti L16 to L17 -> L18 [10] - L17 to L21 -> L22 [7] lasti - L21 to L23 -> L25 
[3] - L24 to L25 -> L32 [3] lasti - L25 to L26 -> L30 [4] lasti - L26 to L27 -> L28 [4] lasti - L27 to L28 -> L32 [3] lasti - L28 to L30 -> L30 [4] lasti - L30 to L31 -> L32 [3] lasti - L31 to L32 -> L36 [0] lasti - L32 to L34 -> L35 [5] lasti - L34 to L36 -> L36 [0] lasti + L17 to L20 -> L24 [7] lasti + L21 to L23 -> L24 [7] lasti + L23 to L25 -> L27 [3] + L26 to L27 -> L34 [3] lasti + L27 to L28 -> L32 [4] lasti + L28 to L29 -> L30 [4] lasti + L29 to L30 -> L34 [3] lasti + L30 to L32 -> L32 [4] lasti + L32 to L33 -> L34 [3] lasti + L33 to L34 -> L38 [0] lasti + L34 to L36 -> L37 [5] lasti + L36 to L38 -> L38 [0] lasti diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index ecedb136f53..366de2ecc21 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -3,11 +3,13 @@ pub(crate) use _socket::module_def; #[cfg(feature = "ssl")] -pub(super) use _socket::{PySocket, SelectKind, sock_select, timeout_error_msg}; +pub(super) use _socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg}; #[pymodule] mod _socket { use crate::common::lock::{PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard}; + #[cfg(all(unix, not(target_os = "redox")))] + use crate::vm::convert::ToPyException; use crate::vm::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ @@ -23,11 +25,15 @@ mod _socket { utils::ToCString, }; use rustpython_host_env::os::ErrorExt; + #[cfg(any(unix, windows))] + use rustpython_host_env::socket as host_socket; + #[cfg(windows)] + use rustpython_host_env::windows as host_windows; #[expect(clippy::unnecessary_wraps, reason = "Needs to comply with a signature")] pub(crate) fn module_exec(vm: &VirtualMachine, module: &Py) -> PyResult<()> { #[cfg(windows)] - crate::vm::windows::init_winsock(); + host_windows::init_winsock(); __module_exec(vm, module); Ok(()) @@ -50,72 +56,36 @@ mod _socket { #[cfg(unix)] use libc as c; - #[cfg(windows)] mod c { - pub(super) use 
windows_sys::Win32::NetworkManagement::IpHelper::{ - if_indextoname, if_nametoindex, - }; - - pub(super) use windows_sys::Win32::Networking::WinSock::{ - INADDR_ANY, INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_NONE, - }; - - pub(super) use windows_sys::Win32::Networking::WinSock::{ - AF_APPLETALK, AF_DECnet, AF_IPX, AF_LINK, AI_ADDRCONFIG, AI_ALL, AI_CANONNAME, - AI_NUMERICSERV, AI_V4MAPPED, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_HDRINCL, - IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_OPTIONS, IP_RECVDSTADDR, - IP_TOS, IP_TTL, IPPORT_RESERVED, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, - IPPROTO_FRAGMENT, IPPROTO_GGP, IPPROTO_HOPOPTS, IPPROTO_ICMP, IPPROTO_ICMPV6, - IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IP, IPPROTO_IP as IPPROTO_IPIP, IPPROTO_IPV4, - IPPROTO_IPV6, IPPROTO_ND, IPPROTO_NONE, IPPROTO_PIM, IPPROTO_PUP, IPPROTO_RAW, - IPPROTO_ROUTING, IPPROTO_TCP, IPPROTO_UDP, IPV6_CHECKSUM, IPV6_DONTFRAG, IPV6_HOPLIMIT, - IPV6_HOPOPTS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, IPV6_MULTICAST_HOPS, - IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_PKTINFO, IPV6_RECVRTHDR, IPV6_RECVTCLASS, - IPV6_RTHDR, IPV6_TCLASS, IPV6_UNICAST_HOPS, IPV6_V6ONLY, MSG_BCAST, MSG_CTRUNC, - MSG_DONTROUTE, MSG_MCAST, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, - NI_MAXHOST, NI_MAXSERV, NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, - RCVALL_IPLEVEL, RCVALL_OFF, RCVALL_ON, RCVALL_SOCKETLEVELONLY, SD_BOTH as SHUT_RDWR, - SD_RECEIVE as SHUT_RD, SD_SEND as SHUT_WR, SIO_KEEPALIVE_VALS, SIO_LOOPBACK_FAST_PATH, - SIO_RCVALL, SO_BROADCAST, SO_ERROR, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_RCVBUF, - SO_REUSEADDR, SO_SNDBUF, SO_TYPE, SO_USELOOPBACK, SOCK_DGRAM, SOCK_RAW, SOCK_RDM, - SOCK_SEQPACKET, SOCK_STREAM, SOL_SOCKET, SOMAXCONN, TCP_NODELAY, WSAEBADF, - WSAECONNRESET, WSAENOTSOCK, WSAEWOULDBLOCK, + pub(super) use rustpython_host_env::socket::{ + AF_APPLETALK, AF_DECnet, AF_INET, AF_INET6, AF_IPX, AF_LINK, AF_UNSPEC, AI_ADDRCONFIG, + AI_ALL, AI_CANONNAME, 
AI_NUMERICHOST, AI_NUMERICSERV, AI_PASSIVE, AI_V4MAPPED, + EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NODATA, EAI_NONAME, + EAI_SERVICE, EAI_SOCKTYPE, INADDR_ANY, INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_NONE, + IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_HDRINCL, IP_MULTICAST_IF, IP_MULTICAST_LOOP, + IP_MULTICAST_TTL, IP_OPTIONS, IP_RECVDSTADDR, IP_TOS, IP_TTL, IPPORT_RESERVED, + IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, IPPROTO_FRAGMENT, IPPROTO_GGP, + IPPROTO_HOPOPTS, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IP, + IPPROTO_IP as IPPROTO_IPIP, IPPROTO_IPV4, IPPROTO_IPV6, IPPROTO_ND, IPPROTO_NONE, + IPPROTO_PIM, IPPROTO_PUP, IPPROTO_RAW, IPPROTO_ROUTING, IPPROTO_TCP, IPPROTO_UDP, + IPV6_CHECKSUM, IPV6_DONTFRAG, IPV6_HOPLIMIT, IPV6_HOPOPTS, IPV6_JOIN_GROUP, + IPV6_LEAVE_GROUP, IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, + IPV6_PKTINFO, IPV6_RECVRTHDR, IPV6_RECVTCLASS, IPV6_RTHDR, IPV6_TCLASS, + IPV6_UNICAST_HOPS, IPV6_V6ONLY, MSG_BCAST, MSG_CTRUNC, MSG_DONTROUTE, MSG_MCAST, + MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_MAXSERV, + NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, RCVALL_IPLEVEL, RCVALL_OFF, + RCVALL_ON, RCVALL_SOCKETLEVELONLY, SD_BOTH as SHUT_RDWR, SD_RECEIVE as SHUT_RD, + SD_SEND as SHUT_WR, SIO_KEEPALIVE_VALS, SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, + SO_BROADCAST, SO_ERROR, SO_EXCLUSIVEADDRUSE, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, + SO_RCVBUF, SO_REUSEADDR, SO_SNDBUF, SO_TYPE, SO_USELOOPBACK, SOCK_DGRAM, SOCK_RAW, + SOCK_RDM, SOCK_SEQPACKET, SOCK_STREAM, SOL_SOCKET, SOMAXCONN, TCP_NODELAY, WSAEBADF, + WSAENOTSOCK, WSAEWOULDBLOCK, getprotobyname, getservbyname, getservbyport, }; - - pub(super) use windows_sys::Win32::Networking::WinSock::{ - INVALID_SOCKET, SOCKET_ERROR, WSA_FLAG_OVERLAPPED, WSADuplicateSocketW, - WSAGetLastError, WSAIoctl, WSAPROTOCOL_INFOW, WSASocketW, - }; - - pub(super) use windows_sys::Win32::Networking::WinSock::{ - 
SO_REUSEADDR as SO_EXCLUSIVEADDRUSE, getprotobyname, getservbyname, getservbyport, - getsockopt, setsockopt, - }; - - pub(super) use windows_sys::Win32::Networking::WinSock::{ - WSA_NOT_ENOUGH_MEMORY as EAI_MEMORY, WSAEAFNOSUPPORT as EAI_FAMILY, - WSAEINVAL as EAI_BADFLAGS, WSAESOCKTNOSUPPORT as EAI_SOCKTYPE, - WSAHOST_NOT_FOUND as EAI_NODATA, WSAHOST_NOT_FOUND as EAI_NONAME, - WSANO_RECOVERY as EAI_FAIL, WSATRY_AGAIN as EAI_AGAIN, - WSATYPE_NOT_FOUND as EAI_SERVICE, - }; - - pub(super) const IF_NAMESIZE: usize = - windows_sys::Win32::NetworkManagement::Ndis::IF_MAX_STRING_SIZE as _; - pub(super) const AF_UNSPEC: i32 = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as _; - pub(super) const AF_INET: i32 = windows_sys::Win32::Networking::WinSock::AF_INET as _; - pub(super) const AF_INET6: i32 = windows_sys::Win32::Networking::WinSock::AF_INET6 as _; - pub(super) const AI_PASSIVE: i32 = windows_sys::Win32::Networking::WinSock::AI_PASSIVE as _; - pub(super) const AI_NUMERICHOST: i32 = - windows_sys::Win32::Networking::WinSock::AI_NUMERICHOST as _; - pub(super) const FROM_PROTOCOL_INFO: i32 = -1; } - // constants #[pyattr(name = "has_ipv6")] const HAS_IPV6: bool = true; - #[pyattr] // put IPPROTO_MAX later use c::{ @@ -842,7 +812,7 @@ mod _socket { #[cfg(windows)] #[pyattr] - use windows_sys::Win32::Networking::WinSock::{ + use host_socket::{ IPPROTO_CBT, IPPROTO_ICLFXBM, IPPROTO_IGP, IPPROTO_L2TP, IPPROTO_PGM, IPPROTO_RDP, IPPROTO_SCTP, IPPROTO_ST, }; @@ -913,9 +883,6 @@ mod _socket { }; } - #[cfg(windows)] - use windows_sys::Win32::NetworkManagement::IpHelper; - fn get_raw_sock(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { #[cfg(unix)] type CastFrom = libc::c_long; @@ -1093,20 +1060,20 @@ mod _socket { fn sock_op( &self, vm: &VirtualMachine, - select: SelectKind, + wait_kind: SockWaitKind, f: F, ) -> Result where F: FnMut() -> io::Result, { let timeout = self.get_timeout().ok(); - self.sock_op_timeout_err(vm, select, timeout, f) + 
self.sock_op_timeout_err(vm, wait_kind, timeout, f) } fn sock_op_timeout_err( &self, vm: &VirtualMachine, - select: SelectKind, + wait_kind: SockWaitKind, timeout: Option, mut f: F, ) -> Result @@ -1116,19 +1083,9 @@ mod _socket { let deadline = timeout.map(Deadline::new); loop { - if deadline.is_some() || matches!(select, SelectKind::Connect) { - let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; + if deadline.is_some() || matches!(wait_kind, SockWaitKind::Connect) { let sock = self.sock()?; - let res = vm.allow_threads(|| sock_select(&sock, select, interval)); - match res { - Ok(true) => return Err(IoOrPyException::Timeout), - Err(e) if e.kind() == io::ErrorKind::Interrupted => { - vm.check_signals()?; - continue; - } - Err(e) => return Err(e.into()), - Ok(false) => {} // no timeout, continue as normal - } + sock_wait_deadline(&sock, wait_kind, &deadline, vm)?; } let err = loop { @@ -1255,11 +1212,7 @@ mod _socket { } let cstr = alloc::ffi::CString::new(ifname) .map_err(|_| vm.new_os_error("invalid interface name".to_owned()))?; - let idx = unsafe { libc::if_nametoindex(cstr.as_ptr()) }; - if idx == 0 { - return Err(io::Error::last_os_error().into()); - } - idx as i32 + host_socket::if_nametoindex_checked(cstr.as_c_str())? as i32 }; // Create sockaddr_can @@ -1376,16 +1329,11 @@ mod _socket { }; if wait_connect { - // basically, connect() is async, and it registers an "error" on the socket when it's - // done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up - // from poll and the error is EISCONN then we know that the connect is done - self.sock_op(vm, SelectKind::Connect, || { + self.sock_op(vm, SockWaitKind::Connect, || { let sock = self.sock()?; let err = sock.take_error()?; match err { - Some(e) if e.posix_errno() == libc::EISCONN => Ok(()), Some(e) => Err(e), - // TODO: is this accurate? 
None => Ok(()), } }) @@ -1478,6 +1426,13 @@ mod _socket { let mut socket_kind = args.r#type.unwrap_or(-1); let mut proto = args.proto.unwrap_or(-1); + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call( + (vm.ctx.new_str("socket.__new__"), family, socket_kind, proto), + vm, + )?; + } + let fileno = args.fileno; let sock; @@ -1487,7 +1442,7 @@ mod _socket { use crate::vm::builtins::PyBytes; if let Ok(bytes) = fileno_obj.clone().downcast::() { let bytes_data = bytes.as_bytes(); - let expected_size = core::mem::size_of::(); + let expected_size = host_socket::protocol_info_size(); if bytes_data.len() != expected_size { return Err(vm @@ -1497,37 +1452,11 @@ mod _socket { .into()); } - let mut info: c::WSAPROTOCOL_INFOW = unsafe { core::mem::zeroed() }; - unsafe { - core::ptr::copy_nonoverlapping( - bytes_data.as_ptr(), - &mut info as *mut c::WSAPROTOCOL_INFOW as *mut u8, - expected_size, - ); - } - - let fd = unsafe { - c::WSASocketW( - c::FROM_PROTOCOL_INFO, - c::FROM_PROTOCOL_INFO, - c::FROM_PROTOCOL_INFO, - &info, - 0, - c::WSA_FLAG_OVERLAPPED, - ) - }; - - if fd == c::INVALID_SOCKET { - return Err(Self::wsa_error().into()); - } - - crate::vm::stdlib::nt::raw_set_handle_inheritable(fd as _, false)?; - - family = info.iAddressFamily; - socket_kind = info.iSocketType; - proto = info.iProtocol; - - sock = unsafe { sock_from_raw_unchecked(fd as RawSocket) }; + let shared = host_socket::socket_from_share_data(bytes_data)?; + family = shared.family; + socket_kind = shared.socket_type; + proto = shared.protocol; + sock = unsafe { sock_from_raw_unchecked(shared.raw as RawSocket) }; return Ok(zelf.init_inner(family, socket_kind, proto, sock)?); } @@ -1633,6 +1562,18 @@ mod _socket { #[pymethod] fn bind(&self, address: PyObjectRef, vm: &VirtualMachine) -> Result<(), IoOrPyException> { let sock_addr = self.extract_address(address, "bind", vm)?; + + if let Some(addr) = sock_addr.as_socket() + && let Ok(audit) = vm.sys_module.get_attr("audit", vm) + { + let 
(ip, port) = match addr { + SocketAddr::V4(addr) => (addr.ip().to_string(), addr.port()), + SocketAddr::V6(addr) => (addr.ip().to_string(), addr.port()), + }; + + audit.call((vm.ctx.new_str("socket.bind"), (ip, port)), vm)?; + } + Ok(self.sock()?.bind(&sock_addr)?) } @@ -1650,7 +1591,8 @@ mod _socket { ) -> Result<(RawSocket, PyObjectRef), IoOrPyException> { // Use accept_raw() instead of accept() to avoid socket2's set_common_flags() // which tries to set SO_NOSIGPIPE and fails with EINVAL on Unix domain sockets on macOS - let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock()?.accept_raw())?; + let (sock, addr) = + self.sock_op(vm, SockWaitKind::Read, || self.sock()?.accept_raw())?; let fd = into_sock_fileno(sock); Ok((fd, get_addr_tuple(&addr, vm))) } @@ -1665,7 +1607,7 @@ mod _socket { let flags = flags.unwrap_or(0); let mut buffer = Vec::with_capacity(bufsize); let sock = self.sock()?; - let n = self.sock_op(vm, SelectKind::Read, || { + let n = self.sock_op(vm, SockWaitKind::Read, || { sock.recv_with_flags(buffer.spare_capacity_mut(), flags) })?; unsafe { buffer.set_len(n) }; @@ -1696,7 +1638,7 @@ mod _socket { }; let buf = &mut buf[..read_len]; - self.sock_op(vm, SelectKind::Read, || { + self.sock_op(vm, SockWaitKind::Read, || { sock.recv_with_flags(unsafe { slice_as_uninit(buf) }, flags) }) } @@ -1713,7 +1655,7 @@ mod _socket { .to_usize() .ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom"))?; let mut buffer = Vec::with_capacity(bufsize); - let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || { self.sock()? 
.recv_from_with_flags(buffer.spare_capacity_mut(), flags) })?; @@ -1744,7 +1686,7 @@ mod _socket { }; let flags = flags.unwrap_or(0); let sock = self.sock()?; - let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || { sock.recv_from_with_flags(unsafe { slice_as_uninit(buf) }, flags) })?; Ok((n, get_addr_tuple(&addr, vm))) @@ -1760,7 +1702,7 @@ mod _socket { let flags = flags.unwrap_or(0); let buf = bytes.borrow_buf(); let buf = &*buf; - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { self.sock()?.send_with_flags(buf, flags) }) } @@ -1784,7 +1726,7 @@ mod _socket { // now we have like 3 layers of interrupt loop :) while buf_offset < buf.len() { let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; - self.sock_op_timeout_err(vm, SelectKind::Write, interval, || { + self.sock_op_timeout_err(vm, SockWaitKind::Write, interval, || { let subbuf = &buf[buf_offset..]; buf_offset += self.sock()?.send_with_flags(subbuf, flags)?; Ok(()) @@ -1817,7 +1759,7 @@ mod _socket { let addr = self.extract_address(address, "sendto", vm)?; let buf = bytes.borrow_buf(); let buf = &*buf; - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { self.sock()?.send_to_with_flags(buf, &addr, flags) }) } @@ -1875,7 +1817,7 @@ mod _socket { } } - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { let sock = self.sock()?; sock.sendmsg(&msg, flags) }) @@ -1888,6 +1830,8 @@ mod _socket { #[cfg(target_os = "linux")] #[pymethod] fn sendmsg_afalg(&self, args: SendmsgAfalgArgs, vm: &VirtualMachine) -> PyResult { + use std::os::fd::BorrowedFd; + let msg = args.msg; let op = args.op; let iv = args.iv; @@ -1902,100 +1846,17 @@ mod _socket { OptionalArg::Missing => None, }; - // Build control messages for AF_ALG - let mut control_buf = Vec::new(); - - // Add ALG_SET_OP control message - { - let op_bytes = op.to_ne_bytes(); - let 
space = - unsafe { libc::CMSG_SPACE(core::mem::size_of::() as u32) } as usize; - let old_len = control_buf.len(); - control_buf.resize(old_len + space, 0u8); - - let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; - unsafe { - (*cmsg).cmsg_len = libc::CMSG_LEN(core::mem::size_of::() as u32) as _; - (*cmsg).cmsg_level = libc::SOL_ALG; - (*cmsg).cmsg_type = libc::ALG_SET_OP; - let data = libc::CMSG_DATA(cmsg); - core::ptr::copy_nonoverlapping(op_bytes.as_ptr(), data, op_bytes.len()); - } - } - - // Add ALG_SET_IV control message if iv is provided - if let Some(iv_data) = iv { - let iv_bytes = iv_data.borrow_buf(); - // struct af_alg_iv { __u32 ivlen; __u8 iv[]; } - let iv_struct_size = 4 + iv_bytes.len(); - let space = unsafe { libc::CMSG_SPACE(iv_struct_size as u32) } as usize; - let old_len = control_buf.len(); - control_buf.resize(old_len + space, 0u8); - - let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; - unsafe { - (*cmsg).cmsg_len = libc::CMSG_LEN(iv_struct_size as u32) as _; - (*cmsg).cmsg_level = libc::SOL_ALG; - (*cmsg).cmsg_type = libc::ALG_SET_IV; - let data = libc::CMSG_DATA(cmsg); - // Write ivlen - let ivlen = (iv_bytes.len() as u32).to_ne_bytes(); - core::ptr::copy_nonoverlapping(ivlen.as_ptr(), data, 4); - // Write iv - core::ptr::copy_nonoverlapping(iv_bytes.as_ptr(), data.add(4), iv_bytes.len()); - } - } - - // Add ALG_SET_AEAD_ASSOCLEN control message if assoclen is provided - if let Some(assoclen_val) = assoclen { - let assoclen_bytes = assoclen_val.to_ne_bytes(); - let space = - unsafe { libc::CMSG_SPACE(core::mem::size_of::() as u32) } as usize; - let old_len = control_buf.len(); - control_buf.resize(old_len + space, 0u8); - - let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; - unsafe { - (*cmsg).cmsg_len = libc::CMSG_LEN(core::mem::size_of::() as u32) as _; - (*cmsg).cmsg_level = libc::SOL_ALG; - (*cmsg).cmsg_type = libc::ALG_SET_AEAD_ASSOCLEN; - let data = libc::CMSG_DATA(cmsg); - 
core::ptr::copy_nonoverlapping( - assoclen_bytes.as_ptr(), - data, - assoclen_bytes.len(), - ); - } - } - - // Build buffers let buffers = msg.iter().map(|buf| buf.borrow_buf()).collect::>(); - let iovecs: Vec = buffers + let buffers = buffers .iter() - .map(|buf| libc::iovec { - iov_base: buf.as_ptr() as *mut _, - iov_len: buf.len(), - }) - .collect(); - - // Set up msghdr - let mut msghdr: libc::msghdr = unsafe { core::mem::zeroed() }; - msghdr.msg_iov = iovecs.as_ptr() as *mut _; - msghdr.msg_iovlen = iovecs.len() as _; - if !control_buf.is_empty() { - msghdr.msg_control = control_buf.as_mut_ptr() as *mut _; - msghdr.msg_controllen = control_buf.len() as _; - } + .map(|buf| io::IoSlice::new(buf)) + .collect::>(); + let iv = iv.map(|iv| iv.borrow_buf().to_vec()); - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { let sock = self.sock()?; - let fd = sock_fileno(&sock); - let ret = unsafe { libc::sendmsg(fd as libc::c_int, &msghdr, flags) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret as usize) - } + let fd = unsafe { BorrowedFd::borrow_raw(sock_fileno(&sock)) }; + host_socket::sendmsg_afalg(fd, &buffers, op, iv.as_deref(), assoclen, flags) }) .map_err(|e| e.into_pyexception(vm)) } @@ -2012,8 +1873,6 @@ mod _socket { flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use core::mem::MaybeUninit; - if bufsize < 0 { return Err(vm.new_value_error("negative buffer size in recvmsg")); } @@ -2026,62 +1885,29 @@ mod _socket { let ancbufsize = ancbufsize as usize; let flags = flags.unwrap_or(0); - // Allocate buffers - let mut data_buf: Vec> = vec![MaybeUninit::uninit(); bufsize]; - let mut anc_buf: Vec> = vec![MaybeUninit::uninit(); ancbufsize]; - let mut addr_storage: libc::sockaddr_storage = unsafe { core::mem::zeroed() }; - - // Set up iovec - let mut iov = [libc::iovec { - iov_base: data_buf.as_mut_ptr().cast(), - iov_len: bufsize, - }]; - - // Set up msghdr - let mut msg: libc::msghdr = unsafe { 
core::mem::zeroed() }; - msg.msg_name = (&mut addr_storage as *mut libc::sockaddr_storage).cast(); - msg.msg_namelen = core::mem::size_of::() as libc::socklen_t; - msg.msg_iov = iov.as_mut_ptr(); - msg.msg_iovlen = 1; - if ancbufsize > 0 { - msg.msg_control = anc_buf.as_mut_ptr().cast(); - msg.msg_controllen = ancbufsize as _; - } - - let n = self - .sock_op(vm, SelectKind::Read, || { + let msg = self + .sock_op(vm, SockWaitKind::Read, || { let sock = self.sock()?; - let fd = sock_fileno(&sock); - let ret = unsafe { libc::recvmsg(fd as libc::c_int, &mut msg, flags) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret as usize) - } + let fd = unsafe { std::os::fd::BorrowedFd::borrow_raw(sock_fileno(&sock)) }; + host_socket::recvmsg(fd, bufsize, ancbufsize, flags) }) .map_err(|e| e.into_pyexception(vm))?; - // Build data bytes - let data = unsafe { - data_buf.set_len(n); - core::mem::transmute::>, Vec>(data_buf) - }; - // Build ancdata list - let ancdata = Self::parse_ancillary_data(&msg, vm); + let ancdata = Self::parse_ancillary_data(&msg.ancdata, vm); // Build address tuple - let address = if msg.msg_namelen > 0 { + let address = if let Some(address) = msg.address { let storage: socket2::SockAddrStorage = - unsafe { core::mem::transmute(addr_storage) }; - let addr = unsafe { socket2::SockAddr::new(storage, msg.msg_namelen) }; + unsafe { core::mem::transmute(address.storage) }; + let addr = unsafe { socket2::SockAddr::new(storage, address.len as _) }; get_addr_tuple(&addr, vm) } else { vm.ctx.none() }; Ok(vm.ctx.new_tuple(vec![ - vm.ctx.new_bytes(data).into(), + vm.ctx.new_bytes(msg.data).into(), ancdata, vm.ctx.new_int(msg.msg_flags).into(), address, @@ -2090,35 +1916,18 @@ mod _socket { /// Parse ancillary data from a received message header #[cfg(all(unix, not(target_os = "redox")))] - fn parse_ancillary_data(msg: &libc::msghdr, vm: &VirtualMachine) -> PyObjectRef { + fn parse_ancillary_data( + control: &[host_socket::AncillaryMessage], + vm: 
&VirtualMachine, + ) -> PyObjectRef { let mut result = Vec::new(); - - // Calculate buffer end for truncation handling - let ctrl_buf = msg.msg_control as *const u8; - let ctrl_end = unsafe { ctrl_buf.add(msg.msg_controllen as _) }; - - let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(msg) }; - while !cmsg.is_null() { - let cmsg_ref = unsafe { &*cmsg }; - let data_ptr = unsafe { libc::CMSG_DATA(cmsg) }; - - // Calculate data length, respecting buffer truncation - let data_len_from_cmsg = - cmsg_ref.cmsg_len as usize - (data_ptr as usize - cmsg as usize); - let available = ctrl_end as usize - data_ptr as usize; - let data_len = data_len_from_cmsg.min(available); - - let data = unsafe { core::slice::from_raw_parts(data_ptr, data_len) }; - + for cmsg in control { let tuple = vm.ctx.new_tuple(vec![ - vm.ctx.new_int(cmsg_ref.cmsg_level).into(), - vm.ctx.new_int(cmsg_ref.cmsg_type).into(), - vm.ctx.new_bytes(data.to_vec()).into(), + vm.ctx.new_int(cmsg.level).into(), + vm.ctx.new_int(cmsg.kind).into(), + vm.ctx.new_bytes(cmsg.data.clone()).into(), ]); - result.push(tuple.into()); - - cmsg = unsafe { libc::CMSG_NXTHDR(msg, cmsg) }; } vm.ctx.new_list(result).into() @@ -2130,53 +1939,22 @@ mod _socket { cmsgs: &[(i32, i32, ArgBytesLike)], vm: &VirtualMachine, ) -> PyResult> { - use core::{mem, ptr}; - if cmsgs.is_empty() { return Ok(vec![]); } - - let capacity = cmsgs + let data = cmsgs .iter() - .map(|(_, _, buf)| buf.len()) - .try_fold(0, |sum, len| { - let space = checked_cmsg_space(len).ok_or_else(|| { - vm.new_os_error("ancillary data item too large".to_owned()) - })?; - usize::checked_add(sum, space) - .ok_or_else(|| vm.new_os_error("too much ancillary data".to_owned())) - })?; - - let mut cmsg_buffer = vec![0u8; capacity]; - - // make a dummy msghdr so we can use the CMSG_* apis - let mut mhdr = unsafe { mem::zeroed::() }; - mhdr.msg_control = cmsg_buffer.as_mut_ptr().cast(); - mhdr.msg_controllen = capacity as _; - - let mut pmhdr: *mut libc::cmsghdr 
= unsafe { libc::CMSG_FIRSTHDR(&mhdr) }; - for (lvl, typ, buf) in cmsgs { - if pmhdr.is_null() { - return Err(vm.new_runtime_error( - "unexpected NULL result from CMSG_FIRSTHDR/CMSG_NXTHDR", - )); - } - let data = &*buf.borrow_buf(); - assert_eq!(data.len(), buf.len()); - // Safe because we know that pmhdr is valid, and we initialized it with - // sufficient space - unsafe { - (*pmhdr).cmsg_level = *lvl; - (*pmhdr).cmsg_type = *typ; - (*pmhdr).cmsg_len = libc::CMSG_LEN(data.len() as _) as _; - ptr::copy_nonoverlapping(data.as_ptr(), libc::CMSG_DATA(pmhdr), data.len()); - } - - // Safe because mhdr is valid - pmhdr = unsafe { libc::CMSG_NXTHDR(&mhdr, pmhdr) }; - } + .map(|(lvl, typ, buf)| { + let data = buf.borrow_buf(); + (*lvl, *typ, data.to_vec()) + }) + .collect::>(); + let data_refs = data + .iter() + .map(|(lvl, typ, data)| (*lvl, *typ, data.as_slice())) + .collect::>(); - Ok(cmsg_buffer) + host_socket::pack_ancillary_messages(&data_refs).map_err(|err| err.to_pyexception(vm)) } #[pymethod] @@ -2270,20 +2048,7 @@ mod _socket { let fd = sock_fileno(&sock); let buflen = buflen.unwrap_or(0); if buflen == 0 { - let mut flag: libc::c_int = 0; - let mut flagsize = core::mem::size_of::() as _; - let ret = unsafe { - c::getsockopt( - fd as _, - level, - name, - &mut flag as *mut libc::c_int as *mut _, - &mut flagsize, - ) - }; - if ret < 0 { - return Err(rustpython_host_env::os::errno_io_error().into()); - } + let flag = host_socket::getsockopt_int(fd as _, level, name)?; Ok(vm.ctx.new_int(flag).into()) } else { if buflen <= 0 || buflen > 1024 { @@ -2291,21 +2056,7 @@ mod _socket { .new_os_error("getsockopt buflen out of range".to_owned()) .into()); } - let mut buf = vec![0u8; buflen as usize]; - let mut buflen = buflen as _; - let ret = unsafe { - c::getsockopt( - fd as _, - level, - name, - buf.as_mut_ptr() as *mut _, - &mut buflen, - ) - }; - if ret < 0 { - return Err(rustpython_host_env::os::errno_io_error().into()); - } - buf.truncate(buflen as usize); + let buf = 
host_socket::getsockopt_bytes(fd as _, level, name, buflen as usize)?; Ok(vm.ctx.new_bytes(buf).into()) } } @@ -2321,33 +2072,23 @@ mod _socket { ) -> Result<(), IoOrPyException> { let sock = self.sock()?; let fd = sock_fileno(&sock); - let ret = match (value, optlen) { - (Some(Either::A(b)), OptionalArg::Missing) => b.with_ref(|b| unsafe { - c::setsockopt(fd as _, level, name, b.as_ptr() as *const _, b.len() as _) - }), - (Some(Either::B(ref val)), OptionalArg::Missing) => unsafe { - c::setsockopt( - fd as _, - level, - name, - val as *const i32 as *const _, - core::mem::size_of::() as _, - ) - }, - (None, OptionalArg::Present(optlen)) => unsafe { - c::setsockopt(fd as _, level, name, core::ptr::null(), optlen as _) - }, + match (value, optlen) { + (Some(Either::A(b)), OptionalArg::Missing) => { + b.with_ref(|b| host_socket::setsockopt_bytes(fd as _, level, name, b))? + } + (Some(Either::B(val)), OptionalArg::Missing) => { + host_socket::setsockopt_int(fd as _, level, name, val)? + } + (None, OptionalArg::Present(optlen)) => { + host_socket::setsockopt_none(fd as _, level, name, optlen)? + } _ => { return Err(vm .new_type_error("expected the value arg xor the optlen arg") .into()); } - }; - if ret < 0 { - Err(rustpython_host_env::os::errno_io_error().into()) - } else { - Ok(()) } + Ok(()) } #[pymethod] @@ -2365,11 +2106,6 @@ mod _socket { Ok(self.sock()?.shutdown(how)?) 
} - #[cfg(windows)] - fn wsa_error() -> io::Error { - io::Error::from_raw_os_error(unsafe { c::WSAGetLastError() }) - } - #[cfg(windows)] #[pymethod] fn ioctl( @@ -2383,7 +2119,6 @@ mod _socket { let sock = self.sock()?; let fd = sock_fileno(&sock); - let mut recv: u32 = 0; // Convert cmd to u32, returning ValueError for invalid/negative values let cmd_int = cmd @@ -2403,23 +2138,7 @@ mod _socket { .into()); } let option_val: u32 = TryFromObject::try_from_object(vm, option)?; - let ret = unsafe { - c::WSAIoctl( - fd as _, - cmd, - &option_val as *const u32 as *const _, - core::mem::size_of::() as u32, - core::ptr::null_mut(), - 0, - &mut recv, - core::ptr::null_mut(), - None, - ) - }; - if ret == c::SOCKET_ERROR { - return Err(Self::wsa_error().into()); - } - Ok(recv) + host_socket::ioctl_u32(fd as _, cmd, option_val).map_err(Into::into) } c::SIO_KEEPALIVE_VALS => { let tuple: PyTupleRef = option @@ -2433,36 +2152,18 @@ mod _socket { .into()); } - #[repr(C)] - struct TcpKeepalive { - onoff: u32, - keepalivetime: u32, - keepaliveinterval: u32, - } - - let ka = TcpKeepalive { + let ka = host_socket::TcpKeepalive { onoff: TryFromObject::try_from_object(vm, tuple[0].clone())?, keepalivetime: TryFromObject::try_from_object(vm, tuple[1].clone())?, keepaliveinterval: TryFromObject::try_from_object(vm, tuple[2].clone())?, }; - let ret = unsafe { - c::WSAIoctl( - fd as _, - cmd, - &ka as *const TcpKeepalive as *const _, - core::mem::size_of::() as u32, - core::ptr::null_mut(), - 0, - &mut recv, - core::ptr::null_mut(), - None, - ) - }; - if ret == c::SOCKET_ERROR { - return Err(Self::wsa_error().into()); + if cmd != c::SIO_KEEPALIVE_VALS { + return Err(vm + .new_value_error(format!("invalid ioctl command {cmd}")) + .into()); } - Ok(recv) + host_socket::ioctl_keepalive(fd as _, ka).map_err(Into::into) } _ => Err(vm .new_value_error(format!("invalid ioctl command {cmd}")) @@ -2475,24 +2176,7 @@ mod _socket { fn share(&self, process_id: u32, _vm: &VirtualMachine) -> Result, 
IoOrPyException> { let sock = self.sock()?; let fd = sock_fileno(&sock); - - let mut info: MaybeUninit = MaybeUninit::uninit(); - - let ret = unsafe { c::WSADuplicateSocketW(fd as _, process_id, info.as_mut_ptr()) }; - - if ret == c::SOCKET_ERROR { - return Err(Self::wsa_error().into()); - } - - let info = unsafe { info.assume_init() }; - let bytes = unsafe { - core::slice::from_raw_parts( - &info as *const c::WSAPROTOCOL_INFOW as *const u8, - core::mem::size_of::(), - ) - }; - - Ok(bytes.to_vec()) + host_socket::share_socket(fd as _, process_id).map_err(Into::into) } #[pygetset(name = "type")] @@ -2607,19 +2291,7 @@ mod _socket { let ifname = if ifindex == 0 { String::new() } else { - let mut buf = [0u8; libc::IF_NAMESIZE]; - let ret = unsafe { - libc::if_indextoname( - ifindex as libc::c_uint, - buf.as_mut_ptr() as *mut libc::c_char, - ) - }; - if ret.is_null() { - String::new() - } else { - let nul_pos = memchr::memchr(b'\0', &buf).unwrap_or(buf.len()); - String::from_utf8_lossy(&buf[..nul_pos]).into_owned() - } + host_socket::if_indextoname_checked(ifindex as u32).unwrap_or_default() }; return vm.ctx.new_tuple(vec![vm.ctx.new_str(ifname).into()]).into(); } @@ -2647,6 +2319,10 @@ mod _socket { #[pyfunction] fn gethostname(vm: &VirtualMachine) -> PyResult { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("socket.gethostname"),), vm)?; + } + gethostname::gethostname() .into_string() .map(|hostname| vm.ctx.new_str(hostname)) @@ -2655,8 +2331,8 @@ mod _socket { #[cfg(all(unix, not(any(target_os = "redox", target_os = "android"))))] #[pyfunction] - fn sethostname(hostname: PyUtf8StrRef) -> nix::Result<()> { - nix::unistd::sethostname(hostname.as_str()) + fn sethostname(hostname: PyUtf8StrRef) -> std::io::Result<()> { + host_socket::sethostname(hostname.as_str()) } #[pyfunction] @@ -2769,68 +2445,135 @@ mod _socket { } #[derive(Copy, Clone)] - pub(crate) enum SelectKind { + pub(crate) enum SockWaitKind { Read, Write, Connect, } 
- /// returns true if timed out - pub(crate) fn sock_select( + /// returns Ok(true) on timeout + pub(crate) fn sock_wait( + sock: &Socket, + wait_kind: SockWaitKind, + timeout: Option, + vm: &VirtualMachine, + ) -> PyResult { + match sock_wait_deadline(sock, wait_kind, &timeout.map(Deadline::new), vm) { + Ok(()) => Ok(false), + Err(IoOrPyException::Timeout) => Ok(true), + Err(e) => Err(e.into_pyexception(vm)), + } + } + + /// returns Err(IoOrPyException::Timeout) on timeout + fn sock_wait_deadline( sock: &Socket, - kind: SelectKind, - interval: Option, - ) -> io::Result { + wait_kind: SockWaitKind, + deadline: &Option, + vm: &VirtualMachine, + ) -> Result<(), IoOrPyException> { #[cfg(unix)] { - use nix::poll::*; - use std::os::fd::AsFd; - let events = match kind { - SelectKind::Read => PollFlags::POLLIN, - SelectKind::Write => PollFlags::POLLOUT, - SelectKind::Connect => PollFlags::POLLOUT | PollFlags::POLLERR, - }; - let mut pollfd = [PollFd::new(sock.as_fd(), events)]; - let timeout = match interval { - Some(d) => d.try_into().unwrap_or(PollTimeout::MAX), - None => PollTimeout::NONE, - }; - let ret = poll(&mut pollfd, timeout)?; - Ok(ret == 0) + use rustpython_host_env::select::{PollFd, poll_fds}; + + let mut events = 0; + if matches!(wait_kind, SockWaitKind::Read) { + events |= libc::POLLIN | libc::POLLPRI; + } + if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) { + events |= libc::POLLOUT; + } + let mut fds = [PollFd { + fd: sock_fileno(sock), + events, + revents: 0, + }; 1]; + + loop { + let (timeout, is_capped) = deadline + .as_ref() + .map(|d| { + d.time_until().map(|t| { + let timeout_ms = t.as_millis(); + let is_capped = timeout_ms > i32::MAX as u128; + let timeout = if is_capped { + i32::MAX + } else { + timeout_ms as i32 + }; + (timeout, is_capped) + }) + }) + .transpose()? 
+ .unwrap_or((-1, false)); + + match vm.allow_threads(|| poll_fds(&mut fds, timeout)) { + Ok(0) => { + if is_capped { + continue; + } + break Err(IoOrPyException::Timeout); + } + + Ok(_) => { + if fds[0].revents & libc::POLLNVAL != 0 { + break Err(io::Error::from_raw_os_error(libc::EBADF).into()); + } + break Ok(()); + } + + Err(e) => { + if e.kind() == io::ErrorKind::Interrupted { + vm.check_signals()?; + continue; + } + break Err(e.into()); + } + } + } } #[cfg(windows)] { - use rustpython_host_env::select as host_select; + use rustpython_host_env::select::{FdSet, select, timeval}; - let fd = sock_fileno(sock); + let fd = sock_fileno(sock) as usize; - let mut reads = host_select::FdSet::new(); - let mut writes = host_select::FdSet::new(); - let mut errs = host_select::FdSet::new(); + let mut reads = FdSet::new(); + let mut writes = FdSet::new(); + let mut errs = FdSet::new(); - let fd = fd as usize; - match kind { - SelectKind::Read => reads.insert(fd), - SelectKind::Write => writes.insert(fd), - SelectKind::Connect => { - writes.insert(fd); - errs.insert(fd); - } + if matches!(wait_kind, SockWaitKind::Read) { + reads.insert(fd); + errs.insert(fd); + } + if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) { + writes.insert(fd); + errs.insert(fd); } - let mut interval = interval.map(|dur| host_select::timeval { - tv_sec: dur.as_secs() as _, - tv_usec: dur.subsec_micros() as _, - }); - - host_select::select( - fd as i32 + 1, - &mut reads, - &mut writes, - &mut errs, - interval.as_mut(), - ) - .map(|ret| ret == 0) + let mut timeout = deadline + .as_ref() + .map(|d| { + d.time_until().map(|dur| timeval { + tv_sec: dur.as_secs() as _, + tv_usec: dur.subsec_micros() as _, + }) + }) + .transpose()?; + + match vm.allow_threads(|| { + select( + 0, // nfds is ignored on windows + &mut reads, + &mut writes, + &mut errs, + timeout.as_mut(), + ) + }) { + Ok(0) => Err(IoOrPyException::Timeout), + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } } } @@ -3095,29 
+2838,44 @@ mod _socket { #[cfg(not(target_os = "redox"))] #[pyfunction] fn if_nametoindex(name: FsPath, vm: &VirtualMachine) -> PyResult { - let name = name.to_cstring(vm)?; - // in case 'if_nametoindex' does not set errno - rustpython_host_env::os::set_errno(libc::ENODEV); - let ret = unsafe { c::if_nametoindex(name.as_ptr() as _) }; - if ret == 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(ret) + #[cfg(windows)] + { + let name = name.to_cstring(vm)?; + host_socket::if_nametoindex_checked(&name).map_err(|_| vm.new_last_errno_error()) + } + #[cfg(not(windows))] + { + let name = name.to_cstring(vm)?; + // in case 'if_nametoindex' does not set errno + rustpython_host_env::os::set_errno(libc::ENODEV); + let ret = unsafe { c::if_nametoindex(name.as_ptr() as _) }; + if ret == 0 { + Err(vm.new_last_errno_error()) + } else { + Ok(ret) + } } } #[cfg(not(target_os = "redox"))] #[pyfunction] fn if_indextoname(index: IfIndex, vm: &VirtualMachine) -> PyResult { - let mut buf = [0; c::IF_NAMESIZE + 1]; - // in case 'if_indextoname' does not set errno - rustpython_host_env::os::set_errno(libc::ENXIO); - let ret = unsafe { c::if_indextoname(index, buf.as_mut_ptr()) }; - if ret.is_null() { - Err(vm.new_last_errno_error()) - } else { - let buf = unsafe { ffi::CStr::from_ptr(buf.as_ptr() as _) }; - Ok(buf.to_string_lossy().into_owned()) + #[cfg(windows)] + { + host_socket::if_indextoname_checked(index).map_err(|_| vm.new_last_errno_error()) + } + #[cfg(not(windows))] + { + let mut buf = [0; c::IF_NAMESIZE + 1]; + // in case 'if_indextoname' does not set errno + rustpython_host_env::os::set_errno(libc::ENXIO); + let ret = unsafe { c::if_indextoname(index, buf.as_mut_ptr()) }; + if ret.is_null() { + Err(vm.new_last_errno_error()) + } else { + let buf = unsafe { ffi::CStr::from_ptr(buf.as_ptr() as _) }; + Ok(buf.to_string_lossy().into_owned()) + } } } @@ -3136,75 +2894,22 @@ mod _socket { fn if_nameindex(vm: &VirtualMachine) -> PyResult> { #[cfg(not(windows))] { - let list = 
nix::net::if_::if_nameindex() + let list = host_socket::if_nameindex() .map_err(|err| err.into_pyexception(vm))? - .to_slice() - .iter() - .map(|iface| { - let tup: (u32, String) = - (iface.index(), iface.name().to_string_lossy().into_owned()); - tup.to_pyobject(vm) - }) + .into_iter() + .map(|tup| tup.to_pyobject(vm)) .collect(); Ok(list) } #[cfg(windows)] { - use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH; - - let table = MibTable::get_raw().map_err(|err| err.into_pyexception(vm))?; - let list = table.as_slice().iter().map(|entry| { - let name = - get_name(&entry.InterfaceLuid).map_err(|err| err.into_pyexception(vm))?; - let tup = (entry.InterfaceIndex, name.to_string_lossy()); - Ok(tup.to_pyobject(vm)) - }); - let list = list.collect::>()?; - return Ok(list); - - fn get_name(luid: &NET_LUID_LH) -> io::Result { - let mut buf = [0; c::IF_NAMESIZE + 1]; - let ret = unsafe { - IpHelper::ConvertInterfaceLuidToNameW(luid, buf.as_mut_ptr(), buf.len()) - }; - if ret == 0 { - Ok(widestring::WideCString::from_ustr_truncate( - widestring::WideStr::from_slice(&buf[..]), - )) - } else { - Err(io::Error::from_raw_os_error(ret as i32)) - } - } - struct MibTable { - ptr: core::ptr::NonNull, - } - impl MibTable { - fn get_raw() -> io::Result { - let mut ptr = core::ptr::null_mut(); - let ret = unsafe { IpHelper::GetIfTable2Ex(IpHelper::MibIfTableRaw, &mut ptr) }; - if ret == 0 { - let ptr = unsafe { core::ptr::NonNull::new_unchecked(ptr) }; - Ok(Self { ptr }) - } else { - Err(io::Error::from_raw_os_error(ret as i32)) - } - } - } - impl MibTable { - fn as_slice(&self) -> &[IpHelper::MIB_IF_ROW2] { - unsafe { - let p = self.ptr.as_ptr(); - let ptr = &raw const (*p).Table as *const IpHelper::MIB_IF_ROW2; - core::slice::from_raw_parts(ptr, (*p).NumEntries as usize) - } - } - } - impl Drop for MibTable { - fn drop(&mut self) { - unsafe { IpHelper::FreeMibTable(self.ptr.as_ptr() as *mut _) }; - } - } + let list = host_socket::if_nameindex() + .map_err(|err| 
err.into_pyexception(vm))? + .into_iter() + .map(|tup| tup.to_pyobject(vm)) + .collect(); + Ok(list) } } @@ -3326,7 +3031,7 @@ mod _socket { } #[cfg(windows)] { - windows_sys::Win32::Networking::WinSock::INVALID_SOCKET as RawSocket + host_socket::INVALID_RAW_SOCKET as RawSocket } }; @@ -3341,19 +3046,14 @@ mod _socket { let strerr = { #[cfg(unix)] { - let s = match err_kind { - SocketError::GaiError => unsafe { - ffi::CStr::from_ptr(libc::gai_strerror(err.error_num())) - }, - SocketError::HError => unsafe { - ffi::CStr::from_ptr(libc::hstrerror(err.error_num())) - }, - }; - s.to_str().unwrap() + match err_kind { + SocketError::GaiError => host_socket::gai_error_string(err.error_num()), + SocketError::HError => host_socket::h_error_string(err.error_num()), + } } #[cfg(windows)] { - "getaddrinfo failed" + "getaddrinfo failed".to_owned() } }; let exception_cls = match err_kind { @@ -3433,7 +3133,7 @@ mod _socket { let newsock = sock.try_clone()?; let fd = into_sock_fileno(newsock); #[cfg(windows)] - crate::vm::stdlib::nt::raw_set_handle_inheritable(fd as _, false)?; + host_socket::set_socket_inheritable(fd as _, false)?; Ok(fd) } @@ -3443,18 +3143,7 @@ mod _socket { } fn close_inner(x: RawSocket) -> io::Result<()> { - #[cfg(unix)] - use libc::close; - #[cfg(windows)] - use windows_sys::Win32::Networking::WinSock::closesocket as close; - let ret = unsafe { close(x as _) }; - if ret < 0 { - let err = std::io::Error::last_os_error(); - if err.raw_os_error() != Some(errcode!(ECONNRESET)) { - return Err(err); - } - } - Ok(()) + host_socket::close_socket_ignore_connreset(x as _) } enum SocketError { @@ -3462,45 +3151,17 @@ mod _socket { GaiError, } - #[cfg(all(unix, not(target_os = "redox")))] - fn checked_cmsg_len(len: usize) -> Option { - // SAFETY: CMSG_LEN is always safe - let cmsg_len = |length| unsafe { libc::CMSG_LEN(length) }; - if len as u64 > (i32::MAX as u64 - cmsg_len(0) as u64) { - return None; - } - let res = cmsg_len(len as _) as usize; - if res > i32::MAX as 
usize || res < len { - return None; - } - Some(res) - } - - #[cfg(all(unix, not(target_os = "redox")))] - fn checked_cmsg_space(len: usize) -> Option { - // SAFETY: CMSG_SPACE is always safe - let cmsg_space = |length| unsafe { libc::CMSG_SPACE(length) }; - if len as u64 > (i32::MAX as u64 - cmsg_space(1) as u64) { - return None; - } - let res = cmsg_space(len as _) as usize; - if res > i32::MAX as usize || res < len { - return None; - } - Some(res) - } - #[cfg(all(unix, not(target_os = "redox")))] #[pyfunction(name = "CMSG_LEN")] fn cmsg_len(length: usize, vm: &VirtualMachine) -> PyResult { - checked_cmsg_len(length) + host_socket::checked_cmsg_len(length) .ok_or_else(|| vm.new_overflow_error("CMSG_LEN() argument out of range")) } #[cfg(all(unix, not(target_os = "redox")))] #[pyfunction(name = "CMSG_SPACE")] fn cmsg_space(length: usize, vm: &VirtualMachine) -> PyResult { - checked_cmsg_space(length) + host_socket::checked_cmsg_space(length) .ok_or_else(|| vm.new_overflow_error("CMSG_SPACE() argument out of range")) } } diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index a7aacf2d1a6..d36481a0062 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -25,6 +25,9 @@ mod compat; // SSL exception types (shared with openssl backend) mod error; +// Utilities for setting a Rustls cryptography provider. 
+pub mod providers; + pub(crate) use _ssl::module_def; #[allow(non_snake_case)] @@ -36,13 +39,13 @@ mod _ssl { hash::PyHash, lock::{PyMutex, PyRwLock}, }, - socket::{PySocket, SelectKind, sock_select, timeout_error_msg}, + socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg}, vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef, - PyUtf8StrRef, + PyBaseExceptionRef, PyByteArray, PyBytesRef, PyListRef, PyStrRef, PyType, + PyTypeRef, PyUtf8StrRef, }, convert::IntoPyException, function::{ @@ -64,7 +67,6 @@ mod _ssl { sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; - use rustls::crypto::aws_lc_rs::ALL_CIPHER_SUITES; use std::{ collections::{HashMap, hash_map::DefaultHasher}, io::BufRead, @@ -75,7 +77,7 @@ mod _ssl { use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock}; use pem_rfc7468::{LineEnding, encode_string}; use rustls::{ - ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection, + ClientConnection, Connection, HandshakeKind, RootCertStore, ServerConfig, ServerConnection, client::{ClientSessionMemoryCache, ClientSessionStore}, crypto::SupportedKxGroup, pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName}, @@ -94,11 +96,12 @@ mod _ssl { // Import compat module (OpenSSL compatibility layer) use super::compat::{ ClientConfigOptions, MultiCertResolver, ProtocolSettings, ServerConfigOptions, SslError, - TlsConnection, create_client_config, create_server_config, curve_name_to_kx_group, - extract_cipher_info, get_cipher_encryption_desc, is_blocking_io_error, - normalize_cipher_name, ssl_do_handshake, + create_client_config, create_server_config, curve_name_to_kx_group, extract_cipher_info, + get_cipher_encryption_desc, is_blocking_io_error, normalize_cipher_name, ssl_do_handshake, }; + use super::providers::CryptoExt; + // Type aliases for better 
readability // Additional type alias for certificate/key pairs (SessionCache and SniCertName defined below) @@ -154,8 +157,28 @@ mod _ssl { // Buffer sizes and limits (OpenSSL/CPython compatibility) const PEM_BUFSIZE: usize = 1024; + // OpenSSL: ssl/ssl_local.h + const SSL3_RT_HEADER_LENGTH: usize = 5; + // This is the maximum MAC (digest) size used by the SSL library. Currently + // maximum of 20 is used by SHA1, but we reserve for future extension for + // 512-bit hashes. + const SSL3_RT_MAX_MD_SIZE: usize = 64; + // Maximum plaintext length: defined by SSL/TLS standards const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; + // Maximum compression overhead: defined by SSL/TLS standards + const SSL3_RT_MAX_COMPRESSED_OVERHEAD: usize = 1024; + // The standards give a maximum encryption overhead of 1024 bytes. In + // practice the value is lower than this. The overhead is the maximum number + // of padding bytes (256) plus the mac size. + const SSL3_RT_MAX_ENCRYPTED_OVERHEAD: usize = 256 + SSL3_RT_MAX_MD_SIZE; + const SSL3_RT_MAX_COMPRESSED_LENGTH: usize = + SSL3_RT_MAX_PLAIN_LENGTH + SSL3_RT_MAX_COMPRESSED_OVERHEAD; + const SSL3_RT_MAX_ENCRYPTED_LENGTH: usize = + SSL3_RT_MAX_ENCRYPTED_OVERHEAD + SSL3_RT_MAX_COMPRESSED_LENGTH; + pub(crate) const SSL3_RT_MAX_PACKET_SIZE: usize = + SSL3_RT_MAX_ENCRYPTED_LENGTH + SSL3_RT_HEADER_LENGTH; + // SSL session cache size (common practice, similar to OpenSSL defaults) const SSL_SESSION_CACHE_SIZE: usize = 256; @@ -167,21 +190,26 @@ mod _ssl { #[pyattr] const CERT_REQUIRED: i32 = 2; - // Certificate requirements + // SSL Verification Flags / Certificate requirements #[pyattr] const VERIFY_DEFAULT: i32 = 0; #[pyattr] const VERIFY_CRL_CHECK_LEAF: i32 = 4; #[pyattr] const VERIFY_CRL_CHECK_CHAIN: i32 = 12; + /// VERIFY_X509_STRICT flag for RFC 5280 strict compliance + /// When set, performs additional validation including AKI extension checks #[pyattr] - const VERIFY_X509_STRICT: i32 = 32; + pub(crate) const VERIFY_X509_STRICT: i32 = 32; 
#[pyattr] const VERIFY_ALLOW_PROXY_CERTS: i32 = 64; #[pyattr] const VERIFY_X509_TRUSTED_FIRST: i32 = 32768; + /// VERIFY_X509_PARTIAL_CHAIN flag for partial chain validation + /// When set, accept certificates if any certificate in the chain is in the trust store + /// (not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. #[pyattr] - const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; + pub(crate) const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; // Options (OpenSSL-compatible flags, mostly no-op in rustls) #[pyattr] @@ -398,8 +426,7 @@ mod _ssl { // Session data structure for tracking TLS sessions #[derive(Debug, Clone)] struct SessionData { - #[allow(dead_code)] - server_name: String, + _server_name: String, session_id: Vec, creation_time: SystemTime, lifetime: u64, @@ -477,7 +504,7 @@ mod _ssl { let creation_time = SystemTime::now(); let server_name_str = server_name.to_str(); let session_data = SessionData { - server_name: server_name_str.as_ref().to_string(), + _server_name: server_name_str.as_ref().to_string(), session_id: generate_session_id_from_metadata( server_name_str.as_ref(), &creation_time, @@ -521,7 +548,7 @@ mod _ssl { let creation_time = SystemTime::now(); let server_name_str = server_name.to_str(); let session_data = SessionData { - server_name: server_name_str.to_string(), + _server_name: server_name_str.to_string(), session_id: generate_session_id_from_metadata( server_name_str.as_ref(), &creation_time, @@ -612,7 +639,7 @@ mod _ssl { return Err("No cipher can be selected".to_string()); } - let all_suites = ALL_CIPHER_SUITES; + let all_suites = CryptoExt::get_ext().all_ciphers_or_default(); let mut selected = Vec::new(); for part in cipher_str.split(':') { @@ -720,10 +747,6 @@ mod _ssl { #[pytraverse(skip)] verify_flags: PyRwLock, // Rustls configuration (built lazily) - #[allow(dead_code)] - #[pytraverse(skip)] - client_config: PyRwLock>>, - #[allow(dead_code)] #[pytraverse(skip)] server_config: PyRwLock>>, // Certificate 
store @@ -747,19 +770,11 @@ mod _ssl { #[pytraverse(skip)] cert_keys: PyRwLock>, // Options - #[allow(dead_code)] #[pytraverse(skip)] options: PyRwLock, // ALPN protocols - #[allow(dead_code)] #[pytraverse(skip)] alpn_protocols: PyRwLock>>, - // ALPN strict matching flag - // When false (default), mimics OpenSSL behavior: no ALPN negotiation failure - // When true, requires ALPN match (Rustls default behavior) - #[allow(dead_code)] - #[pytraverse(skip)] - require_alpn_match: PyRwLock, // TLS 1.3 features #[pytraverse(skip)] post_handshake_auth: PyRwLock, @@ -1047,6 +1062,8 @@ mod _ssl { #[pymethod] fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> { + let crypto_ext = CryptoExt::get_ext(); + // Parse certfile argument (str or bytes) to path let cert_path = Self::parse_path_arg(&args.certfile, vm)?; @@ -1189,15 +1206,14 @@ mod _ssl { } // Additional validation: Create CertifiedKey to ensure rustls accepts it - let signing_key = - rustls::crypto::aws_lc_rs::sign::any_supported_type(&key).map_err(|_| { - vm.new_os_subtype_error( - PySSLError::class(&vm.ctx).to_owned(), - None, - "[SSL: KEY_VALUES_MISMATCH] key values mismatch", - ) - .upcast() - })?; + let signing_key = crypto_ext.any_supported_key(&key).map_err(|_| { + vm.new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "[SSL: KEY_VALUES_MISMATCH] key values mismatch", + ) + .upcast() + })?; let certified_key = CertifiedKey::new(full_chain.clone(), signing_key); if certified_key.keys_match().is_err() { @@ -1372,26 +1388,17 @@ mod _ssl { ) -> PyResult<()> { #[cfg(windows)] { - // Windows: Use schannel to load from both ROOT and CA stores - use schannel::cert_store::CertStore; - let store_names = ["ROOT", "CA"]; - let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; for store_name in store_names { - for open_fn in &open_fns { - if let Ok(cert_store) = open_fn(store_name) { - for cert_ctx in cert_store.certs() { - let der_bytes = 
cert_ctx.to_der(); - let cert = - rustls::pki_types::CertificateDer::from(der_bytes.to_vec()); - let is_ca = cert::is_ca_certificate(cert.as_ref()); - if store.add(cert).is_ok() { - *self.x509_cert_count.write() += 1; - if is_ca { - *self.ca_cert_count.write() += 1; - } - } + let certs = rustpython_host_env::cert_store::enum_certificates(store_name); + for cert_ctx in certs.entries { + let cert = rustls::pki_types::CertificateDer::from(cert_ctx.der.to_vec()); + let is_ca = cert::is_ca_certificate(cert.as_ref()); + if store.add(cert).is_ok() { + *self.x509_cert_count.write() += 1; + if is_ca { + *self.ca_cert_count.write() += 1; } } } @@ -1521,7 +1528,8 @@ mod _ssl { // Dynamically generate cipher list from rustls ALL_CIPHER_SUITES // This automatically includes all cipher suites supported by the current rustls version - let cipher_list = ALL_CIPHER_SUITES + let cipher_list = CryptoExt::get_ext() + .all_ciphers_or_default() .iter() .map(|suite| { // Extract cipher information using unified helper @@ -1834,6 +1842,9 @@ mod _ssl { args: WrapSocketArgs, vm: &VirtualMachine, ) -> PyResult> { + let socket_mod = vm.import("socket", 0)?; + let socket_class = socket_mod.get_attr("socket", vm)?; + // Convert server_hostname to Option // Handle both missing argument and None value let hostname = match args.server_hostname.into_option().flatten() { @@ -1886,6 +1897,12 @@ mod _ssl { // Create _SSLSocket instance let ssl_socket = PySSLSocket { sock: args.sock.clone(), + sock_send_method: socket_class.get_attr("send", vm)?, + sock_recv_method: socket_class.get_attr("recv", vm)?, + tls_record_header_buf: vm + .ctx + .new_bytearray(Vec::with_capacity(TLS_RECORD_HEADER_SIZE)) + .into(), context: PyRwLock::new(zelf), server_side: args.server_side, server_hostname: PyRwLock::new(hostname), @@ -1895,12 +1912,12 @@ mod _ssl { owner: PyRwLock::new(args.owner.into_option()), // Filter out Python None objects - only store actual SSLSession objects session: 
PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), - verified_chain: PyRwLock::new(None), incoming_bio: None, outgoing_bio: None, sni_state: PyRwLock::new(None), pending_context: PyRwLock::new(None), client_hello_buffer: PyMutex::new(None), + sni_callback_processed: PyMutex::new(false), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), write_buffered_len: PyMutex::new(0), @@ -1957,7 +1974,12 @@ mod _ssl { // Create _SSLSocket instance with BIO mode let ssl_socket = PySSLSocket { - sock: vm.ctx.none(), // No socket in BIO mode + // No socket in BIO mode + sock: vm.ctx.none(), + sock_send_method: vm.ctx.none(), + sock_recv_method: vm.ctx.none(), + + tls_record_header_buf: vm.ctx.none(), context: PyRwLock::new(zelf), server_side, server_hostname: PyRwLock::new(hostname), @@ -1967,12 +1989,12 @@ mod _ssl { owner: PyRwLock::new(args.owner.into_option()), // Filter out Python None objects - only store actual SSLSession objects session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), - verified_chain: PyRwLock::new(None), incoming_bio: Some(args.incoming), outgoing_bio: Some(args.outgoing), sni_state: PyRwLock::new(None), pending_context: PyRwLock::new(None), client_hello_buffer: PyMutex::new(None), + sni_callback_processed: PyMutex::new(false), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), write_buffered_len: PyMutex::new(0), @@ -2201,6 +2223,8 @@ mod _ssl { (protocol,): Self::Args, vm: &VirtualMachine, ) -> PyResult { + let crypto_ext = CryptoExt::get_ext(); + // Validate protocol match protocol { PROTOCOL_TLS | PROTOCOL_TLS_CLIENT | PROTOCOL_TLS_SERVER | PROTOCOL_TLSv1_2 @@ -2270,7 +2294,6 @@ mod _ssl { check_hostname: PyRwLock::new(protocol == PROTOCOL_TLS_CLIENT), verify_mode: PyRwLock::new(default_verify_mode), verify_flags: PyRwLock::new(default_verify_flags), - client_config: PyRwLock::new(None), server_config: 
PyRwLock::new(None), root_certs: PyRwLock::new(RootCertStore::empty()), ca_certs_der: PyRwLock::new(Vec::new()), @@ -2279,7 +2302,6 @@ mod _ssl { cert_keys: PyRwLock::new(Vec::new()), options: PyRwLock::new(default_options), alpn_protocols: PyRwLock::new(Vec::new()), - require_alpn_match: PyRwLock::new(false), post_handshake_auth: PyRwLock::new(false), num_tickets: PyRwLock::new(2), // TLS 1.3 default minimum_version: PyRwLock::new(min_version), @@ -2295,7 +2317,7 @@ mod _ssl { rustls_server_session_store: rustls::server::ServerSessionMemoryCache::new( SSL_SESSION_CACHE_SIZE, ), - server_ticketer: rustls::crypto::aws_lc_rs::Ticketer::new() + server_ticketer: (crypto_ext.ticketer)() .expect("Failed to create shared ticketer for TLS 1.2 session resumption"), accept_count: AtomicUsize::new(0), session_hits: AtomicUsize::new(0), @@ -2311,6 +2333,15 @@ mod _ssl { pub(crate) struct PySSLSocket { // Underlying socket sock: PyObjectRef, + // Cached socket.socket.send + #[pytraverse(skip)] + sock_send_method: PyObjectRef, + // Cached socket.socket.recv + #[pytraverse(skip)] + sock_recv_method: PyObjectRef, + // Header of currently read TLS record. 
+ #[pytraverse(skip)] + tls_record_header_buf: PyObjectRef, // SSL context context: PyRwLock>, // Server-side or client-side @@ -2321,7 +2352,7 @@ mod _ssl { server_hostname: PyRwLock>, // TLS connection state #[pytraverse(skip)] - connection: PyMutex>, + connection: PyMutex>, // Handshake completed flag #[pytraverse(skip)] handshake_done: PyMutex, @@ -2332,10 +2363,6 @@ mod _ssl { owner: PyRwLock>, // Session for resumption session: PyRwLock>, - // Verified certificate chain (built during verification) - #[allow(dead_code)] - #[pytraverse(skip)] - verified_chain: PyRwLock>>>, // MemoryBIO mode (optional) incoming_bio: Option>, outgoing_bio: Option>, @@ -2347,6 +2374,9 @@ mod _ssl { // Buffer to store ClientHello for connection recreation #[pytraverse(skip)] client_hello_buffer: PyMutex>>, + // Whether the Python SNI callback has already been run for this handshake + #[pytraverse(skip)] + sni_callback_processed: PyMutex, // Shutdown state for tracking close-notify exchange #[pytraverse(skip)] shutdown_state: PyMutex, @@ -2376,6 +2406,9 @@ mod _ssl { Completed, // unwrap() completed successfully } + /// TLS record header size (content_type + version + length). 
+ const TLS_RECORD_HEADER_SIZE: usize = 5; + #[pyclass(with(Constructor, Representable), flags(BASETYPE))] impl PySSLSocket { // Check if this is BIO mode @@ -2556,7 +2589,7 @@ mod _ssl { .connection .lock() .as_ref() - .is_some_and(|conn| conn.is_session_resumed()); + .is_some_and(|conn| conn.handshake_kind() == Some(HandshakeKind::Resumed)); *self.session_was_reused.lock() = was_resumed; @@ -2584,7 +2617,7 @@ mod _ssl { // Internal implementation with timeout control pub(crate) fn sock_wait_for_io_impl( &self, - kind: SelectKind, + wait_kind: SockWaitKind, vm: &VirtualMachine, ) -> PyResult { if self.is_bio_mode() { @@ -2609,16 +2642,13 @@ mod _ssl { .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, kind, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; - - Ok(timed_out) + sock_wait(&socket, wait_kind, timeout, vm) } // Internal implementation with explicit timeout override pub(crate) fn sock_wait_for_io_with_timeout( &self, - kind: SelectKind, + wait_kind: SockWaitKind, timeout: Option, vm: &VirtualMachine, ) -> PyResult { @@ -2639,19 +2669,16 @@ mod _ssl { .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, kind, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; - - Ok(timed_out) + sock_wait(&socket, wait_kind, timeout, vm).map_err(|e| e.into_pyexception(vm)) } // SNI (Server Name Indication) Helper Methods: // These methods support the server-side handshake SNI callback mechanism /// Check if this is the first read during handshake (for SNI callback) - /// Returns true if we haven't processed ClientHello yet, regardless of SNI presence + /// Returns true until the SNI callback has been processed. 
pub(crate) fn is_first_sni_read(&self) -> bool { - self.client_hello_buffer.lock().is_none() + !*self.sni_callback_processed.lock() } /// Check if SNI callback is configured @@ -2660,9 +2687,13 @@ mod _ssl { self.context.read().sni_callback.read().is_some() } - /// Save ClientHello data from PyObjectRef for potential connection recreation + /// Save ClientHello data for potential connection recreation. pub(crate) fn save_client_hello_from_bytes(&self, bytes_data: &[u8]) { - *self.client_hello_buffer.lock() = Some(bytes_data.to_vec()); + let mut buffer = self.client_hello_buffer.lock(); + match buffer.as_mut() { + Some(existing) => existing.extend_from_slice(bytes_data), + None => *buffer = Some(bytes_data.to_vec()), + } } /// Get the extracted SNI name from resolver @@ -2780,23 +2811,85 @@ mod _ssl { return read_method.call((vm.ctx.new_int(size),), vm); } - // Normal socket mode - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - - // Call socket.socket.recv(self.sock, size, flags) - let recv_method = socket_class.get_attr("recv", vm)?; - recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm) + self.sock_recv_method + .call((self.sock.clone(), vm.ctx.new_int(size)), vm) } - /// Peek at socket data without consuming it (MSG_PEEK). - /// Used during TLS shutdown to avoid consuming post-TLS cleartext data. - pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult { - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - let recv_method = socket_class.get_attr("recv", vm)?; - let msg_peek = socket_mod.get_attr("MSG_PEEK", vm)?; - recv_method.call((self.sock.clone(), vm.ctx.new_int(size), msg_peek), vm) + // Helper to receive data for at most one TLS record. + // May return incomplete data but never returns more when completes a + // previously incomplete TLS record. 
+ pub(crate) fn sock_recv_at_most_one_tls_record( + &self, + vm: &VirtualMachine, + ) -> PyResult { + let obj_to_bytes = |bytes_obj| { + PyBytesRef::try_from_object(vm, bytes_obj) + .map_err(|_| vm.new_os_error("Expected bytes from recv".to_string())) + }; + + let tls_record_header_buf = self + .tls_record_header_buf + .clone() + .downcast::() + .expect("BUG: tls_record_header_buf is not PyByteArray"); + + let buf_len = tls_record_header_buf.borrow_buf().len(); + + let (mut with_header, mut remaining_record_body_len) = + if buf_len < TLS_RECORD_HEADER_SIZE { + // We do not have a full TLS record header, start receiving one. + let bytes_obj = self.sock_recv(TLS_RECORD_HEADER_SIZE - buf_len, vm)?; + let bytes = obj_to_bytes(bytes_obj)?; + + let mut buf = tls_record_header_buf.borrow_buf_mut(); + buf.extend_from_slice(bytes.as_bytes()); + + if buf.len() < TLS_RECORD_HEADER_SIZE { + return Ok(bytes); + } + + // Parse the remaining length. + let record_body_len = u16::from_be_bytes([buf[3], buf[4]]); + // Validity of length value will be checked by rustls. + + // Zero-length TLS record. + if record_body_len == 0 { + buf.clear(); + return Ok(bytes); + } + + let mut bytes_vec = bytes.as_bytes().to_vec(); + bytes_vec.reserve(record_body_len as usize); + (Some(bytes_vec), record_body_len) + } else { + let buf = tls_record_header_buf.borrow_buf(); + let remaining_record_body_len = u16::from_be_bytes([buf[3], buf[4]]); + (None, remaining_record_body_len) + }; + + // We have full record header and are in a process of receiving a record. 
+ let bytes_obj = self.sock_recv(remaining_record_body_len as usize, vm)?; + let bytes = obj_to_bytes(bytes_obj)?; + + if let Some(with_header) = with_header.as_mut() { + with_header.extend_from_slice(bytes.as_bytes()); + } + + let mut buf = tls_record_header_buf.borrow_buf_mut(); + remaining_record_body_len -= bytes.len() as u16; + if remaining_record_body_len == 0 { + // Record received completely, need to start a new one beginning with its header. + buf.clear(); + } else { + // Update remaining length in the header. + buf.as_mut_slice()[3..5].copy_from_slice(&remaining_record_body_len.to_be_bytes()); + } + + if let Some(with_header) = with_header { + Ok(vm.ctx.new_bytes(with_header)) + } else { + Ok(bytes) + } } /// Socket send - just sends data, caller must handle pending flush @@ -2809,13 +2902,8 @@ mod _ssl { return write_method.call((vm.ctx.new_bytes(data.to_vec()),), vm); } - // Normal socket mode - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - - // Call socket.socket.send(self.sock, data) - let send_method = socket_class.get_attr("send", vm)?; - send_method.call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm) + self.sock_send_method + .call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm) } /// Flush any pending TLS output data to the socket @@ -2851,13 +2939,12 @@ mod _ssl { socket_timeout }; - // Use sock_select directly with calculated timeout + // Use sock_wait directly with calculated timeout let py_socket: PyRef = self.sock.clone().try_into_value(vm)?; let socket = py_socket .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, SelectKind::Write, timeout_to_use) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + let timed_out = sock_wait(&socket, SockWaitKind::Write, timeout_to_use, vm)?; if timed_out { // Keep unsent data in pending buffer @@ -2918,7 +3005,7 @@ mod _ssl { let mut sent_total = 0; 
while sent_total < buf.len() { - let timed_out = self.sock_wait_for_io_impl(SelectKind::Write, vm)?; + let timed_out = self.sock_wait_for_io_impl(SockWaitKind::Write, vm)?; if timed_out { // Save unsent data to pending buffer self.pending_tls_output @@ -2990,8 +3077,7 @@ mod _ssl { let socket = py_socket .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, SelectKind::Write, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + let timed_out = sock_wait(&socket, SockWaitKind::Write, timeout, vm)?; if timed_out { return Err( @@ -3007,7 +3093,7 @@ mod _ssl { let mut pending = self.pending_tls_output.lock(); pending.drain(..sent); } - // If sent == 0, loop will retry with sock_select + // If sent == 0, loop will retry with sock_wait } Err(e) => { if is_blocking_io_error(&e, vm) { @@ -3121,7 +3207,7 @@ mod _ssl { /// Returns the configured ServerConnection. fn initialize_server_connection( &self, - conn_guard: &mut Option, + conn_guard: &mut Option, vm: &VirtualMachine, ) -> PyResult<()> { let ctx = self.context.read(); @@ -3291,11 +3377,11 @@ mod _ssl { vm.new_value_error(format!("Failed to create server connection: {e}")) })?; - *conn_guard = Some(TlsConnection::Server(conn)); + *conn_guard = Some(Connection::Server(conn)); // If ClientHello buffer exists (from SNI callback), re-inject it if let Some(ref hello_data) = *self.client_hello_buffer.lock() - && let Some(TlsConnection::Server(ref mut server)) = *conn_guard + && let Some(Connection::Server(ref mut server)) = *conn_guard { let mut cursor = std::io::Cursor::new(hello_data.as_slice()); let _ = server.read_tls(&mut cursor); @@ -3418,14 +3504,14 @@ mod _ssl { vm.new_value_error(format!("Failed to create client connection: {e}")) })?; - *conn_guard = Some(TlsConnection::Client(conn)); + *conn_guard = Some(Connection::Client(conn)); } } // Perform the actual handshake by exchanging data with the socket/BIO let conn = 
conn_guard.as_mut().expect("unreachable"); - let is_client = matches!(conn, TlsConnection::Client(_)); + let is_client = matches!(conn, Connection::Client(_)); let handshake_result = ssl_do_handshake(conn, self, vm); drop(conn_guard); @@ -3451,6 +3537,7 @@ mod _ssl { // Now safe to call Python callback (no locks held) self.invoke_sni_callback(sni_name.as_deref(), vm)?; + *self.sni_callback_processed.lock() = true; // Clear connection to trigger recreation *self.connection.lock() = None; @@ -4180,7 +4267,7 @@ mod _ssl { // Wait for socket to be readable let timed_out = self.sock_wait_for_io_with_timeout( - SelectKind::Read, + SockWaitKind::Read, remaining_timeout, vm, )?; @@ -4251,7 +4338,7 @@ mod _ssl { } // Helper: Write all pending TLS data (including close_notify) to outgoing buffer/BIO - fn write_pending_tls(&self, conn: &mut TlsConnection, vm: &VirtualMachine) -> PyResult<()> { + fn write_pending_tls(&self, conn: &mut Connection, vm: &VirtualMachine) -> PyResult<()> { // First, flush any previously pending TLS output // Must succeed before sending new data to maintain order self.flush_pending_tls_output(vm, None)?; @@ -4261,7 +4348,7 @@ mod _ssl { break; } - let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let mut buf = vec![0u8; SSL3_RT_MAX_PACKET_SIZE]; let written = conn .write_tls(&mut buf.as_mut_slice()) .map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?; @@ -4281,7 +4368,7 @@ mod _ssl { // Returns true if peer closed connection (with or without close_notify) fn try_read_close_notify( &self, - conn: &mut TlsConnection, + conn: &mut Connection, vm: &VirtualMachine, ) -> PyResult { // In socket mode, peek first to avoid consuming post-TLS cleartext @@ -4289,11 +4376,11 @@ mod _ssl { // transitions to cleartext. Without peeking, sock_recv may consume // cleartext data meant for the application after unwrap(). 
if self.incoming_bio.is_none() { - return self.try_read_close_notify_socket(conn, vm); + return Ok(self.try_read_close_notify_socket(conn, vm)); } // BIO mode: read from incoming BIO - match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + match self.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(bytes_obj) => { let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; let data = bytes.borrow_buf(); @@ -4332,81 +4419,31 @@ mod _ssl { /// /// Equivalent to OpenSSL's `SSL_set_read_ahead(ssl, 0)` — rustls has no /// such knob, so we enforce record-level reads manually via peek. - fn try_read_close_notify_socket( - &self, - conn: &mut TlsConnection, - vm: &VirtualMachine, - ) -> PyResult { - // Peek at the first 5 bytes (TLS record header size) - let peeked_obj = match self.sock_peek(5, vm) { - Ok(obj) => obj, - Err(e) => { - if is_blocking_io_error(&e, vm) { - return Ok(false); - } - return Ok(true); - } - }; - - let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)?; - let peek_data = peeked.borrow_buf(); - - if peek_data.is_empty() { - return Ok(true); // EOF - } - - // TLS record content types: ChangeCipherSpec(20), Alert(21), - // Handshake(22), ApplicationData(23) - let content_type = peek_data[0]; - if !(20..=23).contains(&content_type) { - // Not a TLS record - post-TLS cleartext data. - // Peer has completed TLS shutdown; don't consume this data. 
- return Ok(true); - } - - // Determine how many bytes to read for exactly one TLS record - let recv_size = if peek_data.len() >= 5 { - let record_length = u16::from_be_bytes([peek_data[3], peek_data[4]]) as usize; - 5 + record_length - } else { - // Partial header available - read just these bytes for now - peek_data.len() - }; - - drop(peek_data); - drop(peeked); - - // Now consume exactly one TLS record from the socket - match self.sock_recv(recv_size, vm) { - Ok(bytes_obj) => { - let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; - let data = bytes.borrow_buf(); - + fn try_read_close_notify_socket(&self, conn: &mut Connection, vm: &VirtualMachine) -> bool { + // Consume at most one TLS record from the socket + match self.sock_recv_at_most_one_tls_record(vm) { + Ok(data) => { if data.is_empty() { - return Ok(true); + return true; } let data_slice: &[u8] = data.as_ref(); let mut cursor = std::io::Cursor::new(data_slice); let _ = conn.read_tls(&mut cursor); let _ = conn.process_new_packets(); - Ok(false) + false } Err(e) => { if is_blocking_io_error(&e, vm) { - return Ok(false); + return false; } - Ok(true) + true } } } // Helper: Check if peer has sent close_notify - fn check_peer_closed( - &self, - conn: &mut TlsConnection, - vm: &VirtualMachine, - ) -> PyResult { + fn check_peer_closed(&self, conn: &mut Connection, vm: &VirtualMachine) -> PyResult { // Process any remaining packets and check peer_has_closed let io_state = conn .process_new_packets() @@ -4468,12 +4505,12 @@ mod _ssl { let conn_guard = self.connection.lock(); if let Some(conn) = conn_guard.as_ref() { let version = match conn { - TlsConnection::Client(_) => { + Connection::Client(_) => { return Err(vm.new_value_error( "Post-handshake authentication requires server socket", )); } - TlsConnection::Server(server) => server.protocol_version(), + Connection::Server(server) => server.protocol_version(), }; // Post-handshake auth is only available in TLS 1.3 @@ -4835,35 +4872,34 @@ mod _ssl { 
#[pyfunction] fn RAND_status() -> i32 { - 1 // Always have good randomness with aws-lc-rs + 1 // The configured rustls provider supplies cryptographic randomness. } #[pyfunction] fn RAND_add(_string: PyObjectRef, _entropy: f64) { - // No-op: aws-lc-rs handles its own entropy + // No-op: the configured rustls provider handles its own entropy. // Accept any type (str, bytes, bytearray) } #[pyfunction] fn RAND_bytes(n: i64, vm: &VirtualMachine) -> PyResult { - use aws_lc_rs::rand::{SecureRandom, SystemRandom}; - // Validate n is not negative if n < 0 { return Err(vm.new_value_error("num must be positive")); } let n_usize = n as usize; - let rng = SystemRandom::new(); let mut buf = vec![0u8; n_usize]; - rng.fill(&mut buf) + CryptoExt::get_provider() + .secure_random + .fill(&mut buf) .map_err(|_| vm.new_os_error("Failed to generate random bytes"))?; Ok(PyBytesRef::from(vm.ctx.new_bytes(buf))) } #[pyfunction] fn RAND_pseudo_bytes(n: i64, vm: &VirtualMachine) -> PyResult<(PyBytesRef, bool)> { - // In rustls/aws-lc-rs, all random bytes are cryptographically strong + // Rustls providers expose cryptographically strong random bytes. 
let bytes = RAND_bytes(n, vm)?; Ok((bytes, true)) } @@ -4932,39 +4968,28 @@ mod _ssl { store_name: PyUtf8StrRef, vm: &VirtualMachine, ) -> PyResult> { - use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore}; - use windows_sys::Win32::Security::Cryptography; - let store_name_str = store_name.as_str(); - - // Try both Current User and Local Machine stores - let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; - let stores = open_fns - .iter() - .filter_map(|open| open(store_name_str).ok()) - .collect::>(); - - // If no stores could be opened, raise OSError - if stores.is_empty() { + let certs = rustpython_host_env::cert_store::enum_certificates(store_name_str); + if !certs.had_open_store { return Err(vm.new_os_error(format!( "failed to open certificate store {store_name_str:?}" ))); } - let certs = stores.iter().flat_map(|s| s.certs()).map(|c| { - let cert = vm.ctx.new_bytes(c.to_der().to_owned()); - let enc_type = unsafe { - let ptr = c.as_ptr() as *const Cryptography::CERT_CONTEXT; - (*ptr).dwCertEncodingType - }; - let enc_type = match enc_type { - Cryptography::X509_ASN_ENCODING => vm.new_pyobj("x509_asn"), - Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj("pkcs_7_asn"), - other => vm.new_pyobj(other), + let certs = certs.entries.into_iter().map(|c| { + let cert = vm.ctx.new_bytes(c.der); + let enc_type = match c.encoding { + rustpython_host_env::cert_store::EncodingType::X509Asn => vm.new_pyobj("x509_asn"), + rustpython_host_env::cert_store::EncodingType::Pkcs7Asn => { + vm.new_pyobj("pkcs_7_asn") + } + rustpython_host_env::cert_store::EncodingType::Other(other) => vm.new_pyobj(other), }; - let usage: PyObjectRef = match c.valid_uses() { - Ok(ValidUses::All) => vm.ctx.new_bool(true).into(), - Ok(ValidUses::Oids(oids)) => { + let usage: PyObjectRef = match c.valid_uses { + Ok(rustpython_host_env::cert_store::CertificateUses::All) => { + vm.ctx.new_bool(true).into() + } + 
Ok(rustpython_host_env::cert_store::CertificateUses::Oids(oids)) => { match crate::builtins::PyFrozenSet::from_iter( vm, oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), @@ -4983,54 +5008,30 @@ mod _ssl { #[cfg(windows)] #[pyfunction] fn enum_crls(store_name: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult> { - use windows_sys::Win32::Security::Cryptography::{ - CRL_CONTEXT, CertCloseStore, CertEnumCRLsInStore, CertOpenSystemStoreW, - X509_ASN_ENCODING, - }; - let store_name_str = store_name.as_str(); - let store_name_wide: Vec = store_name_str - .encode_utf16() - .chain(core::iter::once(0)) - .collect(); - - // Open system store - let store = unsafe { CertOpenSystemStoreW(0, store_name_wide.as_ptr()) }; - - if store.is_null() { - return Err(vm.new_os_error(format!( + let crls = rustpython_host_env::cert_store::enum_crls(store_name_str).map_err(|_| { + vm.new_os_error(format!( "failed to open certificate store {store_name_str:?}" - ))); - } - - let mut result = Vec::new(); - - let mut crl_context: *const CRL_CONTEXT = core::ptr::null(); - loop { - crl_context = unsafe { CertEnumCRLsInStore(store, crl_context) }; - if crl_context.is_null() { - break; - } - - let crl = unsafe { &*crl_context }; - let crl_bytes = - unsafe { core::slice::from_raw_parts(crl.pbCrlEncoded, crl.cbCrlEncoded as usize) }; - - let enc_type = if crl.dwCertEncodingType == X509_ASN_ENCODING { - vm.new_pyobj("x509_asn") - } else { - vm.new_pyobj(crl.dwCertEncodingType) - }; - - result.push( - vm.new_tuple((vm.ctx.new_bytes(crl_bytes.to_vec()), enc_type)) - .into(), - ); - } - - unsafe { CertCloseStore(store, 0) }; + )) + })?; - Ok(result) + Ok(crls + .into_iter() + .map(|crl| { + let enc_type = match crl.encoding { + rustpython_host_env::cert_store::EncodingType::X509Asn => { + vm.new_pyobj("x509_asn") + } + rustpython_host_env::cert_store::EncodingType::Pkcs7Asn => { + vm.new_pyobj("pkcs_7_asn") + } + rustpython_host_env::cert_store::EncodingType::Other(other) => { + vm.new_pyobj(other) + 
} + }; + vm.new_tuple((vm.ctx.new_bytes(crl.der), enc_type)).into() + }) + .collect()) } // Certificate type for SSL module (pure Rust implementation) diff --git a/crates/stdlib/src/ssl/cert.rs b/crates/stdlib/src/ssl/cert.rs index 835e3f37c6b..e304781b644 100644 --- a/crates/stdlib/src/ssl/cert.rs +++ b/crates/stdlib/src/ssl/cert.rs @@ -22,7 +22,10 @@ use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; use std::collections::HashSet; use x509_parser::prelude::*; -use super::compat::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; +use super::{ + _ssl::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}, + providers::CryptoExt, +}; // Certificate Verification Constants @@ -1151,7 +1154,7 @@ pub(super) fn load_cert_chain_from_file( let private_key = if let Some(pwd) = password { // Try to parse as encrypted PKCS#8 use der::SecretDocument; - use pkcs8::EncryptedPrivateKeyInfo; + use pkcs8::EncryptedPrivateKeyInfoRef; use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; let pem_str = String::from_utf8_lossy(&key_contents); @@ -1177,7 +1180,7 @@ pub(super) fn load_cert_chain_from_file( Ok((label, doc)) => { if label == "ENCRYPTED PRIVATE KEY" { // Parse encrypted key info from DER - match EncryptedPrivateKeyInfo::try_from(doc.as_bytes()) { + match EncryptedPrivateKeyInfoRef::try_from(doc.as_bytes()) { Ok(encrypted_key) => { // Decrypt with password match encrypted_key.decrypt(pwd.as_bytes()) { @@ -1268,9 +1271,7 @@ pub(super) fn validate_cert_key_match( // For rustls, the actual validation happens when creating CertifiedKey // We can attempt to create a signing key to verify the key is valid - use rustls::crypto::aws_lc_rs::sign::any_supported_type; - - match any_supported_type(private_key) { + match CryptoExt::get_ext().any_supported_key(private_key) { Ok(_signing_key) => { // If we can create a signing key, the private key is valid // Rustls will validate the cert-key match when building config diff --git a/crates/stdlib/src/ssl/compat.rs 
b/crates/stdlib/src/ssl/compat.rs index ed3880940b9..7ed65ce8f4c 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -15,28 +15,29 @@ #[path = "../openssl/ssl_data_31.rs"] mod ssl_data; -use crate::socket::{SelectKind, timeout_error_msg}; +use crate::socket::{SockWaitKind, timeout_error_msg}; use crate::vm::VirtualMachine; use alloc::sync::Arc; use parking_lot::RwLock as ParkingRwLock; -use rustls::RootCertStore; +use rustls::Connection; use rustls::client::ClientConfig; -use rustls::client::ClientConnection; -use rustls::crypto::SupportedKxGroup; +use rustls::crypto::{CryptoProvider, SupportedKxGroup}; use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer}; -use rustls::server::ResolvesServerCert; -use rustls::server::ServerConfig; -use rustls::server::ServerConnection; +use rustls::server::{ProducesTickets, ResolvesServerCert, ServerConfig, WebPkiClientVerifier}; use rustls::sign::CertifiedKey; +use rustls::{RootCertStore, SupportedCipherSuite}; use rustpython_vm::builtins::{PyBaseException, PyBaseExceptionRef}; use rustpython_vm::convert::IntoPyException; use rustpython_vm::function::ArgBytesLike; use rustpython_vm::{AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject}; use std::io::Read; -use std::sync::Once; + +use super::providers::CryptoExt; // Import PySSLSocket from parent module -use super::_ssl::PySSLSocket; +use super::_ssl::{ + PySSLSocket, SSL3_RT_MAX_PACKET_SIZE, VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT, +}; // Import error types and helper functions from error module use super::error::{ @@ -44,39 +45,8 @@ use super::error::{ create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, }; -// SSL Verification Flags -/// VERIFY_X509_STRICT flag for RFC 5280 strict compliance -/// When set, performs additional validation including AKI extension checks -pub(super) const VERIFY_X509_STRICT: i32 = 0x20; - -/// VERIFY_X509_PARTIAL_CHAIN flag for partial 
chain validation -/// When set, accept certificates if any certificate in the chain is in the trust store -/// (not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. -pub(super) const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; - -// CryptoProvider Initialization: - -/// Ensure the default CryptoProvider is installed (thread-safe, runs once) -/// -/// This is necessary because rustls 0.23+ requires a process-level CryptoProvider -/// to be installed before using default_provider(). We use Once to ensure this -/// happens exactly once, even if called from multiple threads. -static INIT_PROVIDER: Once = Once::new(); - -fn ensure_default_provider() { - INIT_PROVIDER.call_once(|| { - let _ = rustls::crypto::CryptoProvider::install_default( - rustls::crypto::aws_lc_rs::default_provider(), - ); - }); -} - // OpenSSL Constants: -// OpenSSL TLS record maximum plaintext size (ssl/ssl_local.h) -// #define SSL3_RT_MAX_PLAIN_LENGTH 16384 -const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; - // OpenSSL error library codes (include/openssl/err.h) // #define ERR_LIB_SSL 20 const ERR_LIB_SSL: i32 = 20; @@ -95,74 +65,15 @@ const X509_V_FLAG_CRL_CHECK: i32 = 4; // verification. They are used to map rustls certificate errors to OpenSSL // error codes for compatibility. 
-pub(super) use x509::{ - X509_V_ERR_CERT_HAS_EXPIRED, X509_V_ERR_CERT_NOT_YET_VALID, X509_V_ERR_CERT_REVOKED, - X509_V_ERR_HOSTNAME_MISMATCH, X509_V_ERR_INVALID_PURPOSE, X509_V_ERR_IP_ADDRESS_MISMATCH, - X509_V_ERR_UNABLE_TO_GET_CRL, X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY, - X509_V_ERR_UNSPECIFIED, -}; - -#[allow(dead_code)] -mod x509 { - pub(super) const X509_V_OK: i32 = 0; - pub(crate) const X509_V_ERR_UNSPECIFIED: i32 = 1; - pub(super) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: i32 = 2; - pub(crate) const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; - pub(super) const X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: i32 = 4; - pub(super) const X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: i32 = 5; - pub(super) const X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: i32 = 6; - pub(super) const X509_V_ERR_CERT_SIGNATURE_FAILURE: i32 = 7; - pub(super) const X509_V_ERR_CRL_SIGNATURE_FAILURE: i32 = 8; - pub(crate) const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; - pub(crate) const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; - pub(super) const X509_V_ERR_CRL_NOT_YET_VALID: i32 = 11; - pub(super) const X509_V_ERR_CRL_HAS_EXPIRED: i32 = 12; - pub(super) const X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: i32 = 13; - pub(super) const X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: i32 = 14; - pub(super) const X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: i32 = 15; - pub(super) const X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: i32 = 16; - pub(super) const X509_V_ERR_OUT_OF_MEM: i32 = 17; - pub(super) const X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: i32 = 18; - pub(super) const X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: i32 = 19; - pub(crate) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; - pub(super) const X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: i32 = 21; - pub(super) const X509_V_ERR_CERT_CHAIN_TOO_LONG: i32 = 22; - pub(crate) const X509_V_ERR_CERT_REVOKED: i32 = 23; - pub(super) const X509_V_ERR_INVALID_CA: i32 = 24; - pub(super) const X509_V_ERR_PATH_LENGTH_EXCEEDED: i32 = 25; - 
pub(crate) const X509_V_ERR_INVALID_PURPOSE: i32 = 26; - pub(super) const X509_V_ERR_CERT_UNTRUSTED: i32 = 27; - pub(super) const X509_V_ERR_CERT_REJECTED: i32 = 28; - pub(super) const X509_V_ERR_SUBJECT_ISSUER_MISMATCH: i32 = 29; - pub(super) const X509_V_ERR_AKID_SKID_MISMATCH: i32 = 30; - pub(super) const X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: i32 = 31; - pub(super) const X509_V_ERR_KEYUSAGE_NO_CERTSIGN: i32 = 32; - pub(super) const X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER: i32 = 33; - pub(super) const X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION: i32 = 34; - pub(super) const X509_V_ERR_KEYUSAGE_NO_CRL_SIGN: i32 = 35; - pub(super) const X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION: i32 = 36; - pub(super) const X509_V_ERR_INVALID_NON_CA: i32 = 37; - pub(super) const X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED: i32 = 38; - pub(super) const X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE: i32 = 39; - pub(super) const X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED: i32 = 40; - pub(super) const X509_V_ERR_INVALID_EXTENSION: i32 = 41; - pub(super) const X509_V_ERR_INVALID_POLICY_EXTENSION: i32 = 42; - pub(super) const X509_V_ERR_NO_EXPLICIT_POLICY: i32 = 43; - pub(super) const X509_V_ERR_DIFFERENT_CRL_SCOPE: i32 = 44; - pub(super) const X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE: i32 = 45; - pub(super) const X509_V_ERR_UNNESTED_RESOURCE: i32 = 46; - pub(super) const X509_V_ERR_PERMITTED_VIOLATION: i32 = 47; - pub(super) const X509_V_ERR_EXCLUDED_VIOLATION: i32 = 48; - pub(super) const X509_V_ERR_SUBTREE_MINMAX: i32 = 49; - pub(super) const X509_V_ERR_APPLICATION_VERIFICATION: i32 = 50; - pub(super) const X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE: i32 = 51; - pub(super) const X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX: i32 = 52; - pub(super) const X509_V_ERR_UNSUPPORTED_NAME_SYNTAX: i32 = 53; - pub(super) const X509_V_ERR_CRL_PATH_VALIDATION_ERROR: i32 = 54; - pub(crate) const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; - pub(super) const X509_V_ERR_EMAIL_MISMATCH: i32 = 63; - pub(crate) const 
X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; -} +pub(super) const X509_V_ERR_UNSPECIFIED: i32 = 1; +pub(super) const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; +pub(super) const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; +pub(super) const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; +pub(super) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; +pub(super) const X509_V_ERR_CERT_REVOKED: i32 = 23; +pub(super) const X509_V_ERR_INVALID_PURPOSE: i32 = 26; +pub(super) const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; +pub(super) const X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; // Certificate Error Conversion Functions: @@ -263,126 +174,6 @@ pub(super) fn create_ssl_cert_verification_error( Ok(exc.upcast()) } -/// Unified TLS connection type (client or server) -#[derive(Debug)] -pub(super) enum TlsConnection { - Client(ClientConnection), - Server(ServerConnection), -} - -impl TlsConnection { - /// Check if handshake is in progress - pub(super) fn is_handshaking(&self) -> bool { - match self { - Self::Client(conn) => conn.is_handshaking(), - Self::Server(conn) => conn.is_handshaking(), - } - } - - /// Check if connection wants to read data - pub(super) fn wants_read(&self) -> bool { - match self { - Self::Client(conn) => conn.wants_read(), - Self::Server(conn) => conn.wants_read(), - } - } - - /// Check if connection wants to write data - pub(super) fn wants_write(&self) -> bool { - match self { - Self::Client(conn) => conn.wants_write(), - Self::Server(conn) => conn.wants_write(), - } - } - - /// Read TLS data from socket - pub(super) fn read_tls(&mut self, reader: &mut dyn std::io::Read) -> std::io::Result { - match self { - Self::Client(conn) => conn.read_tls(reader), - Self::Server(conn) => conn.read_tls(reader), - } - } - - /// Write TLS data to socket - pub(super) fn write_tls(&mut self, writer: &mut dyn std::io::Write) -> std::io::Result { - match self { - Self::Client(conn) => conn.write_tls(writer), - Self::Server(conn) => conn.write_tls(writer), - } - } - - /// Process 
new TLS packets - pub(super) fn process_new_packets(&mut self) -> Result { - match self { - Self::Client(conn) => conn.process_new_packets(), - Self::Server(conn) => conn.process_new_packets(), - } - } - - /// Get reader for plaintext data (rustls native type) - pub(super) fn reader(&mut self) -> rustls::Reader<'_> { - match self { - Self::Client(conn) => conn.reader(), - Self::Server(conn) => conn.reader(), - } - } - - /// Get writer for plaintext data (rustls native type) - pub(super) fn writer(&mut self) -> rustls::Writer<'_> { - match self { - Self::Client(conn) => conn.writer(), - Self::Server(conn) => conn.writer(), - } - } - - /// Check if session was resumed - pub(super) fn is_session_resumed(&self) -> bool { - use rustls::HandshakeKind; - match self { - Self::Client(conn) => { - matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) - } - Self::Server(conn) => { - matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) - } - } - } - - /// Send close_notify alert - pub(super) fn send_close_notify(&mut self) { - match self { - Self::Client(conn) => conn.send_close_notify(), - Self::Server(conn) => conn.send_close_notify(), - } - } - - /// Get negotiated ALPN protocol - pub(super) fn alpn_protocol(&self) -> Option<&[u8]> { - match self { - Self::Client(conn) => conn.alpn_protocol(), - Self::Server(conn) => conn.alpn_protocol(), - } - } - - /// Get negotiated cipher suite - pub(super) fn negotiated_cipher_suite(&self) -> Option { - match self { - Self::Client(conn) => conn.negotiated_cipher_suite(), - Self::Server(conn) => conn.negotiated_cipher_suite(), - } - } - - /// Get peer certificates - pub(super) fn peer_certificates( - &self, - ) -> Option<&[rustls::pki_types::CertificateDer<'static>]> { - match self { - Self::Client(conn) => conn.peer_certificates(), - Self::Server(conn) => conn.peer_certificates(), - } - } -} - /// Error types matching OpenSSL error codes #[derive(Debug)] pub(super) enum SslError { @@ -584,6 +375,16 @@ impl SslError { 
// Use the proper cert verification error creator create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen") } + Self::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + create_ssl_eof_error(vm).upcast() + } + Self::Io(err) if err.raw_os_error().is_none() => vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + format!("SSL error: {err}"), + ) + .upcast(), Self::Io(err) => err.into_pyexception(vm), Self::SniCallbackRestart => { // This should be handled at PySSLSocket level @@ -633,7 +434,7 @@ pub(super) struct ServerConfigOptions { /// Session storage for server-side session resumption pub session_storage: Option>, /// Shared ticketer for TLS 1.2 session tickets (stateless resumption) - pub ticketer: Option>, + pub ticketer: Option>, } /// Options for creating a client TLS configuration @@ -666,15 +467,14 @@ pub(super) struct ClientConfigOptions { /// This helper function consolidates the duplicated CryptoProvider creation logic /// for both server and client configurations. 
fn create_custom_crypto_provider( - cipher_suites: Option>, - kx_groups: Option>, -) -> Arc { - use rustls::crypto::aws_lc_rs::{ALL_CIPHER_SUITES, ALL_KX_GROUPS}; - let default_provider = rustls::crypto::aws_lc_rs::default_provider(); - - Arc::new(rustls::crypto::CryptoProvider { - cipher_suites: cipher_suites.unwrap_or_else(|| ALL_CIPHER_SUITES.to_vec()), - kx_groups: kx_groups.unwrap_or_else(|| ALL_KX_GROUPS.to_vec()), + cipher_suites: Option>, + kx_groups: Option>, +) -> Arc { + let default_provider = CryptoExt::get_provider(); + + Arc::new(CryptoProvider { + cipher_suites: cipher_suites.unwrap_or_else(|| default_provider.cipher_suites.clone()), + kx_groups: kx_groups.unwrap_or_else(|| default_provider.kx_groups.clone()), signature_verification_algorithms: default_provider.signature_verification_algorithms, secure_random: default_provider.secure_random, key_provider: default_provider.key_provider, @@ -686,11 +486,6 @@ fn create_custom_crypto_provider( /// This abstracts the complex rustls ServerConfig building logic, /// matching SSL_CTX initialization for server sockets. pub(super) fn create_server_config(options: ServerConfigOptions) -> Result { - use rustls::server::WebPkiClientVerifier; - - // Ensure default CryptoProvider is installed - ensure_default_provider(); - // Create custom crypto provider using helper function let custom_provider = create_custom_crypto_provider( options.protocol_settings.cipher_suites.clone(), @@ -868,9 +663,6 @@ fn apply_alpn_with_fallback(config_alpn: &mut Vec>, alpn_protocols: &[Ve /// This abstracts the complex rustls ClientConfig building logic, /// matching SSL_CTX initialization for client sockets. 
pub(super) fn create_client_config(options: ClientConfigOptions) -> Result { - // Ensure default CryptoProvider is installed - ensure_default_provider(); - // Create custom crypto provider using helper function let custom_provider = create_custom_crypto_provider( options.protocol_settings.cipher_suites.clone(), @@ -938,7 +730,6 @@ pub(super) fn create_client_config(options: ClientConfigOptions) -> Result SslResult { - // Peek at what is available without consuming it. - let peeked_obj = match socket.sock_peek(SSL3_RT_MAX_PLAIN_LENGTH, vm) { - Ok(d) => d, - Err(e) => { - if is_blocking_io_error(&e, vm) { - return Err(SslError::WantRead); - } - return Err(SslError::Py(e)); - } - }; - - let peeked = ArgBytesLike::try_from_object(vm, peeked_obj) - .map_err(|_| SslError::Syscall("Expected bytes-like object from peek".to_string()))?; - let peeked_bytes = peeked.borrow_buf(); - - if peeked_bytes.is_empty() { - // Empty peek means the peer has closed the TCP connection (FIN). - // Non-blocking sockets would have returned EAGAIN/EWOULDBLOCK - // (caught above as WantRead), so empty bytes here always means EOF. - return Err(SslError::Eof); - } - - if peeked_bytes.len() < TLS_RECORD_HEADER_SIZE { - // Not enough data for a TLS record header yet. - // Read all available bytes so rustls can buffer the partial header; - // this avoids busy-waiting because the kernel buffer is now empty - // and select() will only wake us when new data arrives. - return socket.sock_recv(peeked_bytes.len(), vm).map_err(|e| { - if is_blocking_io_error(&e, vm) { - SslError::WantRead - } else { - SslError::Py(e) - } - }); - } - - // Parse the TLS record length from the header. - let record_body_len = u16::from_be_bytes([peeked_bytes[3], peeked_bytes[4]]) as usize; - let total_record_size = TLS_RECORD_HEADER_SIZE + record_body_len; - - let recv_size = if peeked_bytes.len() >= total_record_size { - // Complete record available — consume exactly one record. 
- total_record_size - } else { - // Incomplete record — consume everything so the kernel buffer is - // drained and select() will block until more data arrives. - peeked_bytes.len() - }; - - // Must drop the borrow before calling sock_recv (which re-enters Python). - drop(peeked_bytes); - drop(peeked); - - socket.sock_recv(recv_size, vm).map_err(|e| { +fn recv_at_most_one_tls_record( + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + let bytes = socket.sock_recv_at_most_one_tls_record(vm).map_err(|e| { if is_blocking_io_error(&e, vm) { SslError::WantRead } else { SslError::Py(e) } - }) + })?; + if bytes.is_empty() { + Err(SslError::Eof) + } else { + Ok(bytes.into()) + } } -/// Read a single TLS record for post-handshake I/O while preserving the +/// Read up to a single TLS record for post-handshake I/O while preserving the /// SSL-vs-socket error precedence from the old sock_recv() path. -fn recv_one_tls_record_for_data( - conn: &mut TlsConnection, +fn recv_at_most_one_tls_record_for_data( + conn: &mut Connection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { - match recv_one_tls_record(socket, vm) { + match recv_at_most_one_tls_record(socket, vm) { Ok(data) => Ok(data), Err(SslError::Eof) => { if let Err(rustls_err) = conn.process_new_packets() { @@ -1275,7 +1017,7 @@ fn recv_one_tls_record_for_data( } fn handshake_read_data( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, is_bio: bool, is_server: bool, @@ -1285,14 +1027,15 @@ fn handshake_read_data( return Ok((false, false)); } - // SERVER-SPECIFIC: Check if this is the first read (for SNI callback) - // Must check BEFORE reading data, so we can detect first time - let is_first_sni_read = is_server && socket.is_first_sni_read(); + // SERVER-SPECIFIC: Check if this is before the SNI callback. + // sock_recv() may return only part of a TLS record, so keep capturing + // ClientHello fragments until process_new_packets() has produced a response. 
+ let is_first_sni_read = is_server && socket.has_sni_callback() && socket.is_first_sni_read(); // Wait for data in socket mode if !is_bio { let timed_out = socket - .sock_wait_for_io_impl(SelectKind::Read, vm) + .sock_wait_for_io_impl(SockWaitKind::Read, vm) .map_err(SslError::Py)?; if timed_out { @@ -1308,9 +1051,9 @@ fn handshake_read_data( // record. This matches OpenSSL's default (no read-ahead) behaviour // and keeps remaining data in the kernel buffer where select() can // detect it. - recv_one_tls_record(socket, vm)? + recv_at_most_one_tls_record(socket, vm)? } else { - match socket.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + match socket.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(d) => d, Err(e) => { if is_blocking_io_error(&e, vm) { @@ -1324,7 +1067,7 @@ fn handshake_read_data( } }; - // SERVER-SPECIFIC: Save ClientHello on first read for potential connection recreation + // SERVER-SPECIFIC: Save ClientHello fragments for potential connection recreation. if is_first_sni_read { // Extract bytes from PyObjectRef use rustpython_vm::builtins::PyBytes; @@ -1344,7 +1087,7 @@ fn handshake_read_data( /// Tries to send NewSessionTicket in non-blocking mode to avoid deadlocks. /// Returns true if handshake is complete and we should exit. fn handle_handshake_complete( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, _is_server: bool, vm: &VirtualMachine, @@ -1419,7 +1162,7 @@ fn handle_handshake_complete( /// /// Returns Ok(Some(n)) if n bytes were read, Ok(None) if would block, /// or Err on real errors. 
-fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult> { +fn try_read_plaintext(conn: &mut Connection, buf: &mut [u8]) -> SslResult> { let mut reader = conn.reader(); match reader.read(buf) { Ok(0) => { @@ -1448,7 +1191,7 @@ fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult SslResult<()> { @@ -1458,12 +1201,9 @@ pub(super) fn ssl_do_handshake( } let is_bio = socket.is_bio_mode(); - let is_server = matches!(conn, TlsConnection::Server(_)); + let is_server = matches!(conn, Connection::Server(_)); let mut first_iteration = true; // Track if this is the first loop iteration - let mut iteration_count = 0; - loop { - iteration_count += 1; let mut made_progress = false; // IMPORTANT: In BIO mode, force initial write even if wants_write() is false @@ -1506,10 +1246,10 @@ pub(super) fn ssl_do_handshake( return Err(SslError::from_rustls(e)); } - // SERVER-SPECIFIC: Check SNI callback after processing packets - // SNI name is extracted during process_new_packets() - // Invoke callback on FIRST read if callback is configured, regardless of SNI presence - if is_server && is_first_sni_read && socket.has_sni_callback() { + // SERVER-SPECIFIC: Check SNI callback after processing packets. + // A partial TLS record can be read without producing any handshake + // response. Wait until rustls has processed a complete ClientHello. + if is_server && is_first_sni_read && socket.has_sni_callback() && conn.wants_write() { // IMPORTANT: Do NOT call the callback here! // The connection lock is still held, which would cause deadlock. 
// Return SniCallbackRestart to signal do_handshake to: @@ -1533,7 +1273,7 @@ pub(super) fn ssl_do_handshake( if conn.wants_write() { // Write all pending TLS data to outgoing BIO loop { - let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let mut buf = vec![0u8; SSL3_RT_MAX_PACKET_SIZE]; let n = match conn.write_tls(&mut buf.as_mut_slice()) { Ok(n) => n, Err(_) => break, @@ -1581,11 +1321,6 @@ pub(super) fn ssl_do_handshake( if !should_continue { break; } - - // Safety check: prevent truly infinite loops (should never happen) - if iteration_count > 1000 { - break; - } } // If we exit the loop without completing handshake, return appropriate error @@ -1599,9 +1334,9 @@ pub(super) fn ssl_do_handshake( return Err(SslError::WantRead); } // Neither wants_read nor wants_write - this is a real error - Err(SslError::Syscall(format!( - "SSL handshake failed: incomplete after {iteration_count} iterations", - ))) + Err(SslError::Syscall( + "SSL handshake failed: incomplete handshake".to_string(), + )) } else { // Handshake completed successfully (shouldn't reach here normally) Ok(()) @@ -1615,7 +1350,7 @@ pub(super) fn ssl_do_handshake( /// /// = SSL_read_ex() pub(super) fn ssl_read( - conn: &mut TlsConnection, + conn: &mut Connection, buf: &mut [u8], socket: &PySSLSocket, vm: &VirtualMachine, @@ -1753,7 +1488,7 @@ pub(super) fn ssl_read( // Blocking socket or socket with timeout: try to read more data from socket. // Even though rustls says it doesn't want to read, more TLS records may arrive. // Use single-record reading to avoid consuming close_notify alongside data. 
- let data = recv_one_tls_record_for_data(conn, socket, vm)?; + let data = recv_at_most_one_tls_record_for_data(conn, socket, vm)?; let bytes_read = data .clone() @@ -1788,11 +1523,6 @@ pub(super) fn ssl_read( // Successfully read and processed TLS data // Continue loop to try reading plaintext } - Err(SslError::Io(ref io_err)) if io_err.to_string().contains("message buffer full") => { - // This case should be rare now that ssl_read_tls_records handles buffer full - // Just continue loop to try again - continue; - } Err(e) => { // Other errors - check for buffered plaintext before propagating match try_read_plaintext(conn, buf)? { @@ -1817,7 +1547,7 @@ pub(super) fn ssl_read( /// /// = SSL_write_ex() pub(super) fn ssl_write( - conn: &mut TlsConnection, + conn: &mut Connection, data: &[u8], socket: &PySSLSocket, vm: &VirtualMachine, @@ -1944,7 +1674,7 @@ pub(super) fn ssl_write( return Err(SslError::WantRead); } // For socket mode, try to read TLS data - let recv_result = socket.sock_recv(4096, vm).map_err(SslError::Py)?; + let recv_result = recv_at_most_one_tls_record_for_data(conn, socket, vm)?; ssl_read_tls_records(conn, recv_result, false, vm)?; conn.process_new_packets().map_err(SslError::from_rustls)?; // Continue loop @@ -1994,7 +1724,7 @@ pub(super) fn ssl_write( // Helper functions (private-ish, used by public SSL functions) /// Write TLS records from rustls to socket -fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { +fn ssl_write_tls_records(conn: &mut Connection) -> SslResult> { let mut buf = Vec::new(); let n = conn .write_tls(&mut buf as &mut dyn std::io::Write) @@ -2005,7 +1735,7 @@ fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { /// Read TLS records from socket to rustls fn ssl_read_tls_records( - conn: &mut TlsConnection, + conn: &mut Connection, data: PyObjectRef, is_bio: bool, vm: &VirtualMachine, @@ -2068,6 +1798,9 @@ fn ssl_read_tls_records( } Ok(n) => { offset += n; + if offset < bytes_data.len() { + 
conn.process_new_packets().map_err(SslError::from_rustls)?; + } } Err(e) => { return Err(SslError::Io(e)); @@ -2075,14 +1808,12 @@ fn ssl_read_tls_records( } } else { offset += read_bytes; + if offset < bytes_data.len() { + conn.process_new_packets().map_err(SslError::from_rustls)?; + } } } Err(e) => { - // Check if it's a buffer full error (unlikely but handle it) - if e.to_string().contains("buffer full") { - conn.process_new_packets().map_err(SslError::from_rustls)?; - continue; - } // Real error - propagate it return Err(SslError::Io(e)); } @@ -2118,7 +1849,7 @@ fn is_connection_closed_error(exc: &Py, vm: &VirtualMachine) -> /// Ensure TLS data is available for reading /// Returns the number of bytes read from the socket fn ssl_ensure_data_available( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { @@ -2140,7 +1871,7 @@ fn ssl_ensure_data_available( { // Socket has timeout - use select to enforce it let timed_out = socket - .sock_wait_for_io_impl(SelectKind::Read, vm) + .sock_wait_for_io_impl(SockWaitKind::Read, vm) .map_err(SslError::Py)?; if timed_out { // Socket not ready within timeout - raise socket.timeout @@ -2157,9 +1888,9 @@ fn ssl_ensure_data_available( // consuming a close_notify that arrives alongside application data, // keeping it in the kernel buffer where select() can detect it. let data = if !is_bio { - recv_one_tls_record_for_data(conn, socket, vm)? + recv_at_most_one_tls_record_for_data(conn, socket, vm)? 
} else { - match socket.sock_recv(2048, vm) { + match socket.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(data) => data, Err(e) => { if is_blocking_io_error(&e, vm) { @@ -2367,8 +2098,7 @@ pub(super) fn curve_name_to_kx_group( curve: &str, ) -> Result, String> { // Get the default crypto provider's key exchange groups - let provider = rustls::crypto::aws_lc_rs::default_provider(); - let all_groups = &provider.kx_groups; + let all_groups = CryptoExt::get_ext().all_kx_or_default(); match curve { // P-256 (also known as secp256r1 or prime256v1) @@ -2393,14 +2123,12 @@ pub(super) fn curve_name_to_kx_group( .map(|g| vec![*g]) .ok_or_else(|| "X25519 not supported by crypto provider".to_owned()), // P-521 (also known as secp521r1 or prime521v1) - // Now supported with aws-lc-rs crypto provider "prime521v1" | "secp521r1" => all_groups .iter() .find(|g| g.name() == rustls::NamedGroup::secp521r1) .map(|g| vec![*g]) .ok_or_else(|| "secp521r1 not supported by crypto provider".to_owned()), // X448 - // Now supported with aws-lc-rs crypto provider "X448" | "x448" => all_groups .iter() .find(|g| g.name() == rustls::NamedGroup::X448) diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index d12cd834d1b..4e5def82bd5 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -125,7 +125,10 @@ pub(crate) mod ssl_error { ) } - #[allow(dead_code, reason = "This seems like a false positive")] + #[cfg_attr( + all(feature = "ssl-openssl", not(feature = "ssl-rustls")), + expect(dead_code) + )] pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { vm.new_os_subtype_error( PySSLZeroReturnError::class(&vm.ctx).to_owned(), @@ -134,7 +137,10 @@ pub(crate) mod ssl_error { ) } - #[allow(dead_code, reason = "This seems like a false positive")] + #[cfg_attr( + all(feature = "ssl-openssl", not(feature = "ssl-rustls")), + expect(dead_code) + )] pub(crate) fn create_ssl_syscall_error( vm: &VirtualMachine, msg: impl Into, diff --git 
a/crates/stdlib/src/ssl/oid.rs b/crates/stdlib/src/ssl/oid.rs index 175951628d1..ca059ff5000 100644 --- a/crates/stdlib/src/ssl/oid.rs +++ b/crates/stdlib/src/ssl/oid.rs @@ -394,7 +394,7 @@ mod tests { use super::*; #[test] - fn test_find_by_nid() { + fn find_by_nid_ok() { let entry = find_by_nid(13).unwrap(); assert_eq!(entry.short_name, "CN"); assert_eq!(entry.long_name, "commonName"); @@ -402,48 +402,48 @@ mod tests { } #[test] - fn test_find_by_oid_string() { + fn find_by_oid_string_ok() { let entry = find_by_oid_string("2.5.4.3").unwrap(); assert_eq!(entry.nid, 13); assert_eq!(entry.short_name, "CN"); } #[test] - fn test_find_by_name_short() { + fn find_by_name_short() { let entry = find_by_name("CN").unwrap(); assert_eq!(entry.nid, 13); assert_eq!(entry.oid_string(), "2.5.4.3"); } #[test] - fn test_find_by_name_long() { + fn find_by_name_long() { let entry = find_by_name("commonName").unwrap(); assert_eq!(entry.nid, 13); assert_eq!(entry.short_name, "CN"); } #[test] - fn test_find_by_name_case_insensitive() { + fn find_by_name_case_insensitive() { let entry = find_by_name("COMMONNAME").unwrap(); assert_eq!(entry.nid, 13); } #[test] - fn test_subject_alt_name() { + fn subject_alt_name() { let entry = find_by_nid(85).unwrap(); assert_eq!(entry.short_name, "subjectAltName"); assert_eq!(entry.oid_string(), "2.5.29.17"); } #[test] - fn test_server_auth_eku() { + fn server_auth_eku() { let entry = find_by_nid(129).unwrap(); assert_eq!(entry.short_name, "serverAuth"); assert_eq!(entry.oid_string(), "1.3.6.1.5.5.7.3.1"); } #[test] - fn test_no_duplicate_nids() { + fn no_duplicate_nids() { let table = &*OID_TABLE; assert_eq!( table.entries.len(), @@ -453,7 +453,7 @@ mod tests { } #[test] - fn test_oid_count() { + fn oid_count() { let table = &*OID_TABLE; // We should have 50+ OIDs defined assert!( diff --git a/crates/stdlib/src/ssl/providers.rs b/crates/stdlib/src/ssl/providers.rs new file mode 100644 index 00000000000..478d02ff933 --- /dev/null +++ 
b/crates/stdlib/src/ssl/providers.rs @@ -0,0 +1,132 @@ +//! Utilities for user-settable cryptography providers. +//! +//! This has two main moving parts: [`CryptoProvider`] and [`CryptoExt`]. [`CryptoProvider`] +//! is always implemented by the cryptography crate because it's a trait from Rustls. RustPython +//! needs some extra data such as all of the cipher suites supported by an implementation. +//! The [`CryptoExt`] table stores that extra data if it exists and provides convenience methods +//! as a fallback. +//! +//! Both the [`CryptoProvider`] and [`CryptoExt`] are process-level structs that need to be +//! set before any TLS operations. [`CryptoExt::set_provider`] is thread-safe and runs once. +//! It sets both once per process. + +use alloc::sync::Arc; +use std::sync::OnceLock; + +use rustls::{ + Error, SignatureScheme, SupportedCipherSuite, + crypto::{CryptoProvider, SupportedKxGroup}, + pki_types::PrivateKeyDer, + server::ProducesTickets, + sign::SigningKey, +}; + +static CRYPTO_EXT: OnceLock = OnceLock::new(); + +#[derive(Clone, Copy)] +pub struct CryptoExt { + pub all_cipher_suites: Option<&'static [SupportedCipherSuite]>, + pub all_kx_groups: Option<&'static [&'static dyn SupportedKxGroup]>, + #[allow(clippy::type_complexity)] + pub any_supported_key: Option) -> Result, Error>>, + pub ticketer: fn() -> Result, Error>, +} + +impl CryptoExt { + #[inline] + #[must_use] + pub fn get_ext() -> &'static Self { + CRYPTO_EXT + .get() + .expect("A CryptoProvider must be set before TLS") + } + + #[inline] + #[must_use] + pub fn get_provider() -> &'static CryptoProvider { + CryptoProvider::get_default().expect("A CryptoProvider must be set before TLS") + } + + /// Returns all [`SupportedCipherSuite`] or the provider's defaults. + /// + /// # Panics + /// Panics if a [`CryptoProvider`] hasn't been set. 
+ #[must_use] + pub fn all_ciphers_or_default(&self) -> &'static [SupportedCipherSuite] { + self.all_cipher_suites.unwrap_or_else(|| { + CryptoProvider::get_default() + .expect("A CryptoProvider has been set if CryptoExt is set") + .cipher_suites + .as_slice() + }) + } + + /// Returns all [`SupportedKxGroup`] or the provider's defaults. + /// + /// # Panics + /// Panics if a [`CryptoProvider`] hasn't been set. + #[must_use] + pub fn all_kx_or_default(&self) -> &'static [&'static dyn SupportedKxGroup] { + self.all_kx_groups.unwrap_or_else(|| { + CryptoProvider::get_default() + .expect("A CryptoProvider has been set if CryptoExt is set") + .kx_groups + .as_slice() + }) + } + + /// Return the first supported [`SigningKey`] for a [`PrivateKeyDer`]. + /// + /// Ideally, this function should be provided by the backend implementation or + /// the user. This fallback filters out insecure algorithms then picks the first available key + /// if it exists. + /// + /// # Panics + /// Panics if a [`CryptoProvider`] hasn't been set. + pub fn any_supported_key(&self, der: &PrivateKeyDer<'_>) -> Result, Error> { + self.any_supported_key.map_or_else( + || { + let provider = CryptoProvider::get_default() + .expect("A CryptoProvider has been set if CryptoExt is set"); + let key = provider.key_provider.load_private_key(der.clone_key())?; + + for scheme in provider + .signature_verification_algorithms + .mapping + .iter() + .filter_map(|(scheme, _)| { + (!matches!( + scheme, + SignatureScheme::RSA_PKCS1_SHA1 + | SignatureScheme::ECDSA_SHA1_Legacy + | SignatureScheme::Unknown(_), + )) + .then_some(*scheme) + }) + { + if key.choose_scheme(&[scheme]).is_some() { + return Ok(key); + } + } + + Err(Error::General( + "failed to parse private key as RSA, ECDSA, or EdDSA".into(), + )) + }, + |f| f(der), + ) + } + + /// Set a process-level [`CryptoProvider`] and [`CryptoExt`]. + /// + /// A provider must be set before any cryptographic operations. 
All crypto ops panic if a provider + /// is unset. + pub fn set_provider(provider: CryptoProvider, extension: Self) -> Result<(), Error> { + provider + .install_default() + .map_err(|_| Error::General("A default CryptoProvider is already set".into()))?; + CRYPTO_EXT + .set(extension) + .map_err(|_| Error::General("A CryptoExt is already set".into())) + } +} diff --git a/crates/stdlib/src/syslog.rs b/crates/stdlib/src/syslog.rs index fab82f14f56..52424972a0a 100644 --- a/crates/stdlib/src/syslog.rs +++ b/crates/stdlib/src/syslog.rs @@ -53,12 +53,29 @@ mod syslog { fn openlog(args: OpenLogArgs, vm: &VirtualMachine) -> PyResult<()> { let logoption = args.logoption.unwrap_or(0); let facility = args.facility.unwrap_or(LOG_USER); - let ident = match args.ident.flatten() { + let ident = match args.ident.clone().flatten() { Some(args) => Some(args.to_cstring(vm)?), None => get_argv(vm).map(|argv| argv.to_cstring(vm)).transpose()?, } .map(|ident| ident.into_boxed_c_str()); + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + let audit_ident: PyObjectRef = args.ident.flatten().map_or_else( + || get_argv(vm).map_or_else(|| vm.ctx.none(), Into::into), + Into::into, + ); + + audit.call( + ( + vm.ctx.new_str("syslog.openlog"), + audit_ident, + logoption, + facility, + ), + vm, + )?; + } + host_syslog::openlog(ident, logoption, facility); Ok(()) } @@ -78,6 +95,10 @@ mod syslog { None => (LOG_INFO, args.priority.try_into_value(vm)?), }; + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("syslog.syslog"), priority, msg.clone()), vm)?; + } + if !host_syslog::is_open() { openlog(OpenLogArgs::default(), vm)?; } @@ -88,13 +109,22 @@ mod syslog { } #[pyfunction] - fn closelog() { + fn closelog(vm: &VirtualMachine) -> PyResult<()> { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("syslog.closelog"),), vm)?; + } + host_syslog::closelog(); + Ok(()) } #[pyfunction] - fn setlogmask(maskpri: i32) -> i32 { - 
host_syslog::setlogmask(maskpri) + fn setlogmask(maskpri: i32, vm: &VirtualMachine) -> PyResult { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("syslog.setlogmask"), maskpri), vm)?; + } + + Ok(host_syslog::setlogmask(maskpri)) } #[inline] diff --git a/crates/stdlib/src/termios.rs b/crates/stdlib/src/termios.rs index 919b4ff702a..9731f33b39f 100644 --- a/crates/stdlib/src/termios.rs +++ b/crates/stdlib/src/termios.rs @@ -31,56 +31,6 @@ mod termios { // TCSBRKP, TIOCGICOUNT, TIOCGLCKTRMIOS, TIOCSERCONFIG, TIOCSERGETLSR, TIOCSERGETMULTI, // TIOCSERGSTRUCT, TIOCSERGWILD, TIOCSERSETMULTI, TIOCSERSWILD, TIOCSER_TEMT, // TIOCSLCKTRMIOS, TIOCSSERIAL, TIOCTTYGSTRUCT - #[cfg(any(target_os = "illumos", target_os = "solaris"))] - #[pyattr] - use libc::{CSTART, CSTOP, CSWTCH}; - #[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd" - ))] - #[pyattr] - use libc::{FIOASYNC, TIOCGETD, TIOCSETD}; - #[pyattr] - use libc::{FIOCLEX, FIONBIO, TIOCGWINSZ, TIOCSWINSZ}; - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd" - ))] - #[pyattr] - use libc::{ - FIONCLEX, FIONREAD, TIOCEXCL, TIOCM_CAR, TIOCM_CD, TIOCM_CTS, TIOCM_DSR, TIOCM_DTR, - TIOCM_LE, TIOCM_RI, TIOCM_RNG, TIOCM_RTS, TIOCM_SR, TIOCM_ST, TIOCMBIC, TIOCMBIS, TIOCMGET, - TIOCMSET, TIOCNXCL, TIOCSCTTY, - }; - #[cfg(any(target_os = "android", target_os = "linux"))] - #[pyattr] - use libc::{ - IBSHIFT, TCFLSH, TCGETA, TCGETS, TCSBRK, TCSETA, TCSETAF, TCSETAW, TCSETS, TCSETSF, - TCSETSW, TCXONC, TIOCGSERIAL, TIOCGSOFTCAR, TIOCINQ, TIOCLINUX, TIOCSSOFTCAR, XTABS, - }; - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "linux", - target_os = "macos" - ))] - #[pyattr] - use libc::{TIOCCONS, TIOCGPGRP, TIOCOUTQ, TIOCSPGRP, TIOCSTI}; - 
#[cfg(any(target_os = "dragonfly", target_os = "freebsd", target_os = "macos"))] - #[pyattr] - use libc::{ - TIOCNOTTY, TIOCPKT, TIOCPKT_DATA, TIOCPKT_DOSTOP, TIOCPKT_FLUSHREAD, TIOCPKT_FLUSHWRITE, - TIOCPKT_NOSTOP, TIOCPKT_START, TIOCPKT_STOP, - }; #[cfg(any( target_os = "android", target_os = "freebsd", @@ -91,7 +41,7 @@ mod termios { target_os = "solaris" ))] #[pyattr] - use termios::os::target::TAB3; + use host_termios::TAB3; #[cfg(any( target_os = "dragonfly", target_os = "freebsd", @@ -100,7 +50,18 @@ mod termios { target_os = "openbsd" ))] #[pyattr] - use termios::os::target::TCSASOFT; + use host_termios::TCSASOFT; + #[pyattr] + use host_termios::{ + B0, B50, B75, B110, B134, B150, B200, B300, B600, B1200, B1800, B2400, B4800, B9600, + B19200, B38400, B57600, B115200, B230400, BRKINT, CLOCAL, CREAD, CRTSCTS, CS5, CS6, CS7, + CS8, CSIZE, CSTOPB, ECHO, ECHOCTL, ECHOE, ECHOK, ECHOKE, ECHONL, ECHOPRT, EXTA, EXTB, + FLUSHO, HUPCL, ICANON, ICRNL, IEXTEN, IGNBRK, IGNCR, IGNPAR, IMAXBEL, INLCR, INPCK, ISIG, + ISTRIP, IXANY, IXOFF, IXON, NCCS, NOFLSH, OCRNL, ONLCR, ONLRET, ONOCR, OPOST, PARENB, + PARMRK, PARODD, PENDIN, TCIFLUSH, TCIOFF, TCIOFLUSH, TCION, TCOFLUSH, TCOOFF, TCOON, + TCSADRAIN, TCSAFLUSH, TCSANOW, TOSTOP, VDISCARD, VEOF, VEOL, VEOL2, VERASE, VINTR, VKILL, + VLNEXT, VMIN, VQUIT, VREPRINT, VSTART, VSTOP, VSUSP, VTIME, VWERASE, + }; #[cfg(any( target_os = "android", target_os = "freebsd", @@ -110,10 +71,10 @@ mod termios { target_os = "solaris" ))] #[pyattr] - use termios::os::target::{B460800, B921600}; + use host_termios::{B460800, B921600}; #[cfg(any(target_os = "android", target_os = "linux"))] #[pyattr] - use termios::os::target::{ + use host_termios::{ B500000, B576000, B1000000, B1152000, B1500000, B2000000, B2500000, B3000000, B3500000, B4000000, CBAUDEX, }; @@ -125,7 +86,7 @@ mod termios { target_os = "solaris" ))] #[pyattr] - use termios::os::target::{ + use host_termios::{ BS0, BS1, BSDLY, CR0, CR1, CR2, CR3, CRDLY, FF0, FF1, FFDLY, NL0, NL1, 
NLDLY, OFDEL, OFILL, TAB1, TAB2, VT0, VT1, VTDLY, }; @@ -136,7 +97,7 @@ mod termios { target_os = "solaris" ))] #[pyattr] - use termios::os::target::{CBAUD, CIBAUD, IUCLC, OLCUC, XCASE}; + use host_termios::{CBAUD, CIBAUD, IUCLC, OLCUC, XCASE}; #[cfg(any( target_os = "android", target_os = "freebsd", @@ -146,38 +107,74 @@ mod termios { target_os = "solaris" ))] #[pyattr] - use termios::os::target::{TAB0, TABDLY}; + use host_termios::{TAB0, TABDLY}; #[cfg(any(target_os = "android", target_os = "linux"))] #[pyattr] - use termios::os::target::{VSWTC, VSWTC as VSWTCH}; + use host_termios::{VSWTC, VSWTC as VSWTCH}; #[cfg(any(target_os = "illumos", target_os = "solaris"))] #[pyattr] - use termios::os::target::{VSWTCH, VSWTCH as VSWTC}; + use host_termios::{VSWTCH, VSWTCH as VSWTC}; + #[cfg(any(target_os = "illumos", target_os = "solaris"))] #[pyattr] - use termios::{ - B0, B50, B75, B110, B134, B150, B200, B300, B600, B1200, B1800, B2400, B4800, B9600, - B19200, B38400, BRKINT, CLOCAL, CREAD, CS5, CS6, CS7, CS8, CSIZE, CSTOPB, ECHO, ECHOE, - ECHOK, ECHONL, HUPCL, ICANON, ICRNL, IEXTEN, IGNBRK, IGNCR, IGNPAR, INLCR, INPCK, ISIG, - ISTRIP, IXANY, IXOFF, IXON, NOFLSH, OCRNL, ONLCR, ONLRET, ONOCR, OPOST, PARENB, PARMRK, - PARODD, TCIFLUSH, TCIOFF, TCIOFLUSH, TCION, TCOFLUSH, TCOOFF, TCOON, TCSADRAIN, TCSAFLUSH, - TCSANOW, TOSTOP, VEOF, VEOL, VERASE, VINTR, VKILL, VMIN, VQUIT, VSTART, VSTOP, VSUSP, - VTIME, - os::target::{ - B57600, B115200, B230400, CRTSCTS, ECHOCTL, ECHOKE, ECHOPRT, EXTA, EXTB, FLUSHO, - IMAXBEL, NCCS, PENDIN, VDISCARD, VEOL2, VLNEXT, VREPRINT, VWERASE, - }, + use libc::{CSTART, CSTOP, CSWTCH}; + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + ))] + #[pyattr] + use libc::{FIOASYNC, TIOCGETD, TIOCSETD}; + #[pyattr] + use libc::{FIOCLEX, FIONBIO, TIOCGWINSZ, TIOCSWINSZ}; + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = 
"linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + ))] + #[pyattr] + use libc::{ + FIONCLEX, FIONREAD, TIOCEXCL, TIOCM_CAR, TIOCM_CD, TIOCM_CTS, TIOCM_DSR, TIOCM_DTR, + TIOCM_LE, TIOCM_RI, TIOCM_RNG, TIOCM_RTS, TIOCM_SR, TIOCM_ST, TIOCMBIC, TIOCMBIS, TIOCMGET, + TIOCMSET, TIOCNXCL, TIOCSCTTY, + }; + #[cfg(any(target_os = "android", target_os = "linux"))] + #[pyattr] + use libc::{ + IBSHIFT, TCFLSH, TCGETA, TCGETS, TCSBRK, TCSETA, TCSETAF, TCSETAW, TCSETS, TCSETSF, + TCSETSW, TCXONC, TIOCGSERIAL, TIOCGSOFTCAR, TIOCINQ, TIOCLINUX, TIOCSSOFTCAR, XTABS, + }; + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "linux", + target_os = "macos" + ))] + #[pyattr] + use libc::{TIOCCONS, TIOCGPGRP, TIOCOUTQ, TIOCSPGRP, TIOCSTI}; + #[cfg(any(target_os = "dragonfly", target_os = "freebsd", target_os = "macos"))] + #[pyattr] + use libc::{ + TIOCNOTTY, TIOCPKT, TIOCPKT_DATA, TIOCPKT_DOSTOP, TIOCPKT_FLUSHREAD, TIOCPKT_FLUSHWRITE, + TIOCPKT_NOSTOP, TIOCPKT_START, TIOCPKT_STOP, }; #[pyfunction] fn tcgetattr(fd: i32, vm: &VirtualMachine) -> PyResult> { let termios = host_termios::tcgetattr(fd).map_err(|e| termios_error(e, vm))?; - let noncanon = (termios.c_lflag & termios::ICANON) == 0; + let noncanon = (termios.c_lflag & host_termios::ICANON) == 0; let cc = termios .c_cc .iter() .enumerate() .map(|(i, &c)| match i { - termios::VMIN | termios::VTIME if noncanon => vm.ctx.new_int(c).into(), + host_termios::VMIN | host_termios::VTIME if noncanon => vm.ctx.new_int(c).into(), _ => vm.ctx.new_bytes(vec![c as _]).into(), }) .collect::>(); @@ -186,8 +183,8 @@ mod termios { termios.c_oflag.to_pyobject(vm), termios.c_cflag.to_pyobject(vm), termios.c_lflag.to_pyobject(vm), - termios::cfgetispeed(&termios).to_pyobject(vm), - termios::cfgetospeed(&termios).to_pyobject(vm), + host_termios::cfgetispeed(&termios).to_pyobject(vm), + host_termios::cfgetospeed(&termios).to_pyobject(vm), vm.ctx.new_list(cc).into(), ]; 
Ok(out) @@ -204,9 +201,9 @@ mod termios { termios.c_oflag = oflag.try_into_value(vm)?; termios.c_cflag = cflag.try_into_value(vm)?; termios.c_lflag = lflag.try_into_value(vm)?; - termios::cfsetispeed(&mut termios, ispeed.try_into_value(vm)?) + host_termios::cfsetispeed(&mut termios, ispeed.try_into_value(vm)?) .map_err(|e| termios_error(e, vm))?; - termios::cfsetospeed(&mut termios, ospeed.try_into_value(vm)?) + host_termios::cfsetospeed(&mut termios, ospeed.try_into_value(vm)?) .map_err(|e| termios_error(e, vm))?; let cc = PyListRef::try_from_object(vm, cc)?; let cc = cc.borrow_vec(); diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 869f3418f11..83e41fa1f5f 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -44,14 +44,12 @@ rustpython-literal = { workspace = true } rustpython-sre_engine = { workspace = true } ascii = { workspace = true } -ahash = { workspace = true } bitflags = { workspace = true } bstr = { workspace = true } crossbeam-utils = { workspace = true } chrono = { workspace = true } constant_time_eq = { workspace = true } flame = { workspace = true, optional = true } -getrandom = { workspace = true } hex = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } @@ -65,6 +63,7 @@ num-traits = { workspace = true } num_enum = { workspace = true } parking_lot = { workspace = true } paste = { workspace = true } +rapidhash = { workspace = true } scopeguard = { workspace = true } serde = { workspace = true, optional = true } static_assertions = { workspace = true } @@ -73,7 +72,6 @@ strum_macros = { workspace = true } thiserror = { workspace = true } memchr = { workspace = true } -caseless = { workspace = true } flamer = { workspace = true, optional = true } half = { workspace = true } psm = { workspace = true } @@ -88,61 +86,18 @@ icu_properties = { workspace = true } writeable = { workspace = true } [target.'cfg(unix)'.dependencies] -rustix = { workspace = true } -nix = { workspace = true } exitcode = 
{ workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] rustyline = { workspace = true } which = { workspace = true } -errno = { workspace = true } widestring = { workspace = true } -[target.'cfg(all(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "android"), not(any(target_env = "musl", target_env = "sgx"))))'.dependencies] -libffi = { workspace = true, features = ["system"] } -libloading = { workspace = true } - -[target.'cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))'.dependencies] -num_cpus = { workspace = true } - -[target.'cfg(windows)'.dependencies] -junction = { workspace = true } - -[target.'cfg(windows)'.dependencies.windows-sys] -workspace = true -features = [ - "Win32_Foundation", - "Win32_Globalization", - "Win32_Media_Audio", - "Win32_Networking_WinSock", - "Win32_Security", - "Win32_Security_Authorization", - "Win32_Storage_FileSystem", - "Win32_System_Console", - "Win32_System_Diagnostics_Debug", - "Win32_System_Environment", - "Win32_System_IO", - "Win32_System_Ioctl", - "Win32_System_JobObjects", - "Win32_System_Kernel", - "Win32_System_LibraryLoader", - "Win32_System_Memory", - "Win32_System_Performance", - "Win32_System_Pipes", - "Win32_System_Registry", - "Win32_System_SystemInformation", - "Win32_System_SystemServices", - "Win32_System_Threading", - "Win32_System_Time", - "Win32_System_WindowsProgramming", - "Win32_UI_Shell", - "Win32_UI_WindowsAndMessaging", -] - [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies] wasm-bindgen = { workspace = true, optional = true } [build-dependencies] +chrono = { workspace = true } glob = { workspace = true } itertools = { workspace = true } diff --git a/crates/vm/build.rs b/crates/vm/build.rs index 6c65aa8633b..36e7a5d9d27 100644 --- a/crates/vm/build.rs +++ b/crates/vm/build.rs @@ -3,6 +3,8 @@ reason = "build scripts cannot use rustpython-host_env" )] +use chrono::{Local, prelude::DateTime}; +use 
core::time::Duration; use itertools::Itertools; use std::{ env, @@ -24,6 +26,9 @@ fn main() { } println!("cargo:rerun-if-changed=../../Lib/importlib/_bootstrap.py"); + // = 3.14.0alpha + python_version(3, 14, 0, "alpha", 0); + println!("cargo:rustc-env=RUSTPYTHON_GIT_HASH={}", git_hash()); println!( "cargo:rustc-env=RUSTPYTHON_GIT_TIMESTAMP={}", @@ -31,10 +36,17 @@ fn main() { ); println!("cargo:rustc-env=RUSTPYTHON_GIT_TAG={}", git_tag()); println!("cargo:rustc-env=RUSTPYTHON_GIT_BRANCH={}", git_branch()); + println!( + "cargo:rustc-env=RUSTPYTHON_GIT_IDENTIFIER={}", + git_identifier() + ); + println!("cargo:rustc-env=RUSTPYTHON_BUILD_INFO={}", get_build_info()); println!("cargo:rustc-env=RUSTC_VERSION={}", rustc_version()); let release_level = option_env!("RUSTPYTHON_RELEASE_LEVEL").unwrap_or("alpha"); println!("cargo:rustc-env=RUSTPYTHON_RELEASE_LEVEL={release_level}"); + let release_level_n = release_to_n(release_level); + println!("cargo:rustc-env=RUSTPYTHON_RELEASE_LEVEL_N={release_level_n}"); let release_serial = option_env!("RUSTPYTHON_RELEASE_SERIAL").unwrap_or("0"); println!("cargo:rustc-env=RUSTPYTHON_RELEASE_SERIAL={release_serial}"); @@ -81,11 +93,119 @@ fn git(args: &[&str]) -> io::Result { command("git", args) } +#[must_use] +fn get_build_info() -> String { + // See: https://reproducible-builds.org/docs/timestamps/ + let revision = git_hash(); + let separator = if revision.is_empty() { "" } else { ":" }; + let identifier = git_identifier(); + + format!( + "{id}{sep}{revision}, {date:.20}, {time:.9}", + id = if identifier.is_empty() { + "default" + } else { + &identifier + }, + sep = separator, + revision = revision, + date = get_git_date(), + time = get_git_time(), + ) +} + +fn git_identifier() -> String { + let tag = git_tag(); + if tag.is_empty() || tag.eq_ignore_ascii_case("undefined") { + git_branch() + } else { + tag + } +} + +fn get_git_timestamp_datetime() -> DateTime { + let timestamp = git_timestamp().parse::().unwrap_or_default(); + let 
datetime = UNIX_EPOCH + Duration::from_secs(timestamp); + datetime.into() +} + +#[must_use] +fn get_git_date() -> String { + let datetime = get_git_timestamp_datetime(); + + datetime.format("%b %e %Y").to_string() +} + +#[must_use] +fn get_git_time() -> String { + let datetime = get_git_timestamp_datetime(); + + datetime.format("%H:%M:%S").to_string() +} + fn rustc_version() -> String { let rustc = env::var_os("RUSTC").unwrap_or_else(|| "rustc".into()); command(rustc, &["-V"]).unwrap_or_else(|_| "rustc [unknown]".into()) } +fn python_version(major: usize, minor: usize, micro: usize, release: &str, serial: usize) { + println!("cargo:rustc-env=MAJOR_CPY={major}"); + println!("cargo:rustc-env=MINOR_CPY={minor}"); + println!("cargo:rustc-env=MICRO_CPY={micro}"); + println!("cargo:rustc-env=RELEASE_LEVEL_CPY={release}"); + println!( + "cargo:rustc-env=RELEASE_LEVEL_N_CPY={}", + release_to_n(release) + ); + println!("cargo:rustc-env=SERIAL_CPY={serial}"); + + println!("cargo:rustc-env=WINVER_CPY={major}.{minor}",); + + let cpy_version = format!("{major}.{minor}.{micro}.{release}"); + + let (left, right) = get_version(&cpy_version); + println!("cargo:rustc-env=RUSTPYTHON_VERSION_LEFT={left}"); + println!("cargo:rustc-env=RUSTPYTHON_VERSION_RIGHT={right}"); +} + +#[must_use] +fn get_version(cpy_version: &str) -> (String, String) { + // Windows: include MSC v. for compatibility with ctypes.util.find_library + // MSC v.1929 = VS 2019, version 14+ makes find_msvcrt() return None + let msc_info = cfg_select! { + windows => {{ + // Include both RustPython identifier and MSC v. for compatibility + if cfg!(target_pointer_width = "64") { + " MSC v.1929 64 bit (AMD64)" + } else { + " MSC v.1929 32 bit (Intel)" + } + }}, + _ => "", + }; + + // `left` and `right` are split by \n like PyPy. Passing a string with a newline to rustc + // truncates everything from the newline onward, so we have to manually combine them later. 
+ let left = format!("{:.80} ({:.80})", cpy_version, get_build_info()); + let right = format!( + "[RustPython {} with {:.80}{}]", + env!("CARGO_PKG_VERSION"), + rustc_version(), + msc_info, + ); + (left, right) +} + +fn release_to_n(release: &str) -> usize { + match release { + "alpha" => 0xA, + "beta" => 0xB, + "candidate" => 0xC, + "final" => 0xD, + _ => panic!("`release` must be one of: 'alpha', 'beta', 'candidate', 'final'"), + } +} + fn command(cmd: impl AsRef, args: &[&str]) -> io::Result { Command::new(&cmd).args(args).output().and_then(|output| { // TODO: Switch to exit_ok()? when stable. diff --git a/crates/vm/src/buffer.rs b/crates/vm/src/buffer.rs index 9a15598eeef..c3ad10e89a7 100644 --- a/crates/vm/src/buffer.rs +++ b/crates/vm/src/buffer.rs @@ -18,7 +18,7 @@ type UnpackFunc = fn(&VirtualMachine, &[u8]) -> PyObjectRef; static OVERFLOW_MSG: &str = "total struct size too long"; // not a const to reduce code size -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum Endianness { Native, Little, @@ -40,7 +40,12 @@ impl Endianness { Some(b'>' | b'!') => Self::Big, _ => return Self::Native, }; - chars.next().unwrap(); + + // SAFETY: + // We just ensured with `chars.peek()` that this is safe + unsafe { + let _ = chars.next().unwrap_unchecked(); + } e } } @@ -48,13 +53,17 @@ impl Endianness { trait ByteOrder { fn convert(i: I) -> I; } + enum BigEndian {} + impl ByteOrder for BigEndian { fn convert(i: I) -> I { i.to_be() } } + enum LittleEndian {} + impl ByteOrder for LittleEndian { fn convert(i: I) -> I { i.to_le() @@ -66,7 +75,7 @@ type NativeEndian = cfg_select! 
{ target_endian = "little" => LittleEndian, }; -#[derive(Copy, Clone, num_enum::TryFromPrimitive)] +#[derive(Copy, Clone, num_enum::TryFromPrimitive, Eq, PartialEq)] #[repr(u8)] pub(crate) enum FormatType { Pad = b'x', @@ -105,6 +114,7 @@ impl FormatType { fn info(self, e: Endianness) -> &'static FormatInfo { use FormatType::*; use mem::{align_of, size_of}; + macro_rules! native_info { ($t:ty) => {{ &FormatInfo { @@ -115,6 +125,7 @@ impl FormatType { } }}; } + macro_rules! nonnative_info { ($t:ty, $end:ty) => {{ &FormatInfo { @@ -125,6 +136,7 @@ impl FormatType { } }}; } + macro_rules! match_nonnative { ($zelf:expr, $end:ty) => {{ match $zelf { @@ -158,6 +170,7 @@ impl FormatType { } }}; } + match e { Endianness::Native => match self { Pad | Str | Pascal => &FormatInfo { @@ -381,6 +394,7 @@ pub(crate) struct FormatInfo { pub pack: Option, pub unpack: Option, } + impl fmt::Debug for FormatInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FormatInfo") diff --git a/crates/vm/src/builtins/bool.rs b/crates/vm/src/builtins/bool.rs index a2bab0acf08..4bb980d71a2 100644 --- a/crates/vm/src/builtins/bool.rs +++ b/crates/vm/src/builtins/bool.rs @@ -37,8 +37,7 @@ impl PyObjectRef { pub fn try_to_bool(self, vm: &VirtualMachine) -> PyResult { if self.is(&vm.ctx.true_value) { return Ok(true); - } - if self.is(&vm.ctx.false_value) { + } else if self.is(&vm.ctx.false_value) { return Ok(false); } @@ -83,10 +82,9 @@ impl Constructor for PyBool { fn slot_new(zelf: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { let x: Self::Args = args.bind(vm)?; if !zelf.fast_isinstance(vm.ctx.types.type_type) { - let actual_class = zelf.class(); - let actual_type = &actual_class.name(); return Err(vm.new_type_error(format!( - "requires a 'type' object but received a '{actual_type}'" + "requires a 'type' object but received a '{}'", + zelf.class().name() ))); } let val = x.map_or(Ok(false), |val| val.try_to_bool(vm))?; diff --git 
a/crates/vm/src/builtins/capsule.rs b/crates/vm/src/builtins/capsule.rs index 33680fb1973..43efa0fb214 100644 --- a/crates/vm/src/builtins/capsule.rs +++ b/crates/vm/src/builtins/capsule.rs @@ -4,7 +4,7 @@ use crate::{ class::PyClassImpl, types::{Destructor, Representable}, }; -use core::ffi::c_void; +use core::ffi::{CStr, c_void}; use core::sync::atomic::AtomicPtr; /// PyCapsule - a container for C pointers. @@ -13,6 +13,8 @@ use core::sync::atomic::AtomicPtr; #[derive(Debug)] pub struct PyCapsule { ptr: AtomicPtr, + context: AtomicPtr, + name: Option<&'static CStr>, destructor: Option, } @@ -27,10 +29,13 @@ impl PyPayload for PyCapsule { impl PyCapsule { pub fn new( ptr: *mut c_void, + name: Option<&'static CStr>, destructor: Option, ) -> Self { Self { ptr: ptr.into(), + context: core::ptr::null_mut::().into(), + name, destructor, } } @@ -39,6 +44,24 @@ impl PyCapsule { self.ptr.load(core::sync::atomic::Ordering::Relaxed) } + pub fn set_pointer(&self, pointer: *mut c_void) { + self.ptr + .store(pointer, core::sync::atomic::Ordering::Relaxed); + } + + pub fn context(&self) -> *mut c_void { + self.context.load(core::sync::atomic::Ordering::Relaxed) + } + + pub fn set_context(&self, context: *mut c_void) { + self.context + .store(context, core::sync::atomic::Ordering::Relaxed); + } + + pub fn name(&self) -> Option<&CStr> { + self.name + } + fn destructor(&self) -> Option { self.destructor } diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index 1593c250f86..4b2b7c7541e 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -205,7 +205,7 @@ impl PyDict { /// Set item variant which can be called with multiple /// key types, such as str to name a notable one. 
- pub(crate) fn inner_setitem( + pub fn inner_setitem( &self, key: &K, value: PyObjectRef, @@ -248,10 +248,21 @@ impl PyDict { pub fn size(&self) -> dict_inner::DictSize { self.entries.size() } + + pub fn next_entry(&self, position: usize) -> Option<(usize, PyObjectRef, PyObjectRef)> { + self.entries.next_entry(position) + } + + pub fn inner_getitem_opt( + &self, + key: &K, + vm: &VirtualMachine, + ) -> PyResult> { + self.entries.get(vm, key) + } } // Python dict methods: -#[allow(clippy::len_without_is_empty)] #[pyclass( with( Py, @@ -929,7 +940,6 @@ macro_rules! dict_view { } fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef { - #[allow(clippy::redundant_closure_call)] $result_fn(vm, key, value) } @@ -1005,7 +1015,6 @@ macro_rules! dict_view { self.internal.lock().length_hint(|_| self.size.entries_size) } - #[allow(clippy::redundant_closure_call)] #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef { let iter = builtins_iter(vm); @@ -1024,7 +1033,6 @@ macro_rules! dict_view { impl SelfIter for $iter_name {} impl IterNext for $iter_name { - #[allow(clippy::redundant_closure_call)] fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); let next = if let IterStatus::Active(dict) = &internal.status { @@ -1076,7 +1084,6 @@ macro_rules! dict_view { } } - #[allow(clippy::redundant_closure_call)] #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef { let iter = builtins_reversed(vm); @@ -1103,7 +1110,6 @@ macro_rules! 
dict_view { impl SelfIter for $reverse_iter_name {} impl IterNext for $reverse_iter_name { - #[allow(clippy::redundant_closure_call)] fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); let next = if let IterStatus::Active(dict) = &internal.status { diff --git a/crates/vm/src/builtins/frame.rs b/crates/vm/src/builtins/frame.rs index 62af017663b..0cb16b2359f 100644 --- a/crates/vm/src/builtins/frame.rs +++ b/crates/vm/src/builtins/frame.rs @@ -43,19 +43,19 @@ pub(crate) mod stack_analysis { } impl Kind { - fn from_i64(v: i64) -> Option { - match v { - 1 => Some(Self::Iterator), - 2 => Some(Self::Except), - 3 => Some(Self::Object), - 4 => Some(Self::Null), - 5 => Some(Self::Lasti), - _ => None, - } + const fn from_i64(v: i64) -> Option { + Some(match v { + 1 => Self::Iterator, + 2 => Self::Except, + 3 => Self::Object, + 4 => Self::Null, + 5 => Self::Lasti, + _ => return None, + }) } } - pub(crate) fn push_value(stack: i64, kind: i64) -> i64 { + pub(crate) const fn push_value(stack: i64, kind: i64) -> i64 { if (stack as u64) >= WILL_OVERFLOW { OVERFLOWED } else { @@ -63,20 +63,20 @@ pub(crate) mod stack_analysis { } } - pub(crate) fn pop_value(stack: i64) -> i64 { + pub(crate) const fn pop_value(stack: i64) -> i64 { stack >> BITS_PER_BLOCK } - pub(crate) fn top_of_stack(stack: i64) -> i64 { + pub(crate) const fn top_of_stack(stack: i64) -> i64 { stack & MASK } - fn peek(stack: i64, n: u32) -> i64 { + const fn peek(stack: i64, n: u32) -> i64 { debug_assert!(n >= 1); (stack >> (BITS_PER_BLOCK * (n - 1))) & MASK } - fn stack_swap(stack: i64, n: u32) -> i64 { + const fn stack_swap(stack: i64, n: u32) -> i64 { debug_assert!(n >= 1); let to_swap = peek(stack, n); let top = top_of_stack(stack); @@ -85,7 +85,7 @@ pub(crate) mod stack_analysis { (replaced_low & !MASK) | to_swap } - fn pop_to_level(mut stack: i64, level: u32) -> i64 { + const fn pop_to_level(mut stack: i64, level: u32) -> i64 { if level == 0 { return EMPTY_STACK; } @@ 
-97,20 +97,21 @@ pub(crate) mod stack_analysis { stack } - fn compatible_kind(from: i64, to: i64) -> bool { + #[must_use] + const fn compatible_kind(from: i64, to: i64) -> bool { if to == 0 { - return false; - } - if to == Kind::Object as i64 { - return from != Kind::Null as i64; - } - if to == Kind::Null as i64 { - return true; + false + } else if to == Kind::Object as i64 { + from != Kind::Null as i64 + } else if to == Kind::Null as i64 { + true + } else { + from == to } - from == to } - pub(crate) fn compatible_stack(from_stack: i64, to_stack: i64) -> bool { + #[must_use] + pub(crate) const fn compatible_stack(from_stack: i64, to_stack: i64) -> bool { if from_stack < 0 || to_stack < 0 { return false; } @@ -131,14 +132,17 @@ pub(crate) mod stack_analysis { to == 0 } - pub(crate) fn explain_incompatible_stack(to_stack: i64) -> &'static str { + pub(crate) const fn explain_incompatible_stack(to_stack: i64) -> &'static str { debug_assert!(to_stack != 0); + if to_stack == OVERFLOWED { return "stack is too deep to analyze"; } + if to_stack == UNINITIALIZED { return "can't jump into an exception handler, or code may be unreachable"; } + match Kind::from_i64(top_of_stack(to_stack)) { Some(Kind::Except) => "can't jump into an 'except' block as there's no exception", Some(Kind::Lasti) => "can't jump into a re-raising block as there's no location", diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 5b12401f5ac..6052e4fe256 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -33,11 +33,13 @@ fn format_missing_args( missing: &mut Vec, ) -> String { let count = missing.len(); + let last = if missing.len() > 1 { missing.pop() } else { None }; + let (and, right): (&str, String) = if let Some(last) = last { ( if missing.len() == 1 { @@ -45,11 +47,12 @@ fn format_missing_args( } else { "', and '" }, - format!("{last}"), + last.to_string(), ) } else { ("", String::new()) }; + format!( "{qualname}() 
missing {count} required {kind} argument{}: '{}{}{right}'", if count == 1 { "" } else { "s" }, @@ -561,7 +564,8 @@ impl Py { let is_gen = code.flags.contains(bytecode::CodeFlags::GENERATOR); let is_coro = code.flags.contains(bytecode::CodeFlags::COROUTINE); - let use_datastack = !(is_gen || is_coro); + let is_async_gen = code.flags.contains(bytecode::CodeFlags::ASYNC_GENERATOR); + let use_datastack = !(is_gen || is_coro || is_async_gen); // Construct frame: let frame = Frame::new( @@ -576,35 +580,30 @@ impl Py { .into_ref(&vm.ctx); self.fill_locals_from_args(&frame, func_args, vm)?; - match (is_gen, is_coro) { - (true, false) => { - let obj = PyGenerator::new(frame.clone(), self.__name__(), self.__qualname__()) - .into_pyobject(vm); - frame.set_generator(&obj); - Ok(obj) - } - (false, true) => { - let obj = PyCoroutine::new(frame.clone(), self.__name__(), self.__qualname__()) - .into_pyobject(vm); - frame.set_generator(&obj); - Ok(obj) - } - (true, true) => { - let obj = PyAsyncGen::new(frame.clone(), self.__name__(), self.__qualname__()) - .into_pyobject(vm); - frame.set_generator(&obj); - Ok(obj) - } - (false, false) => { - let result = vm.run_frame(frame.clone()); - // Release data stack memory after frame execution completes. - unsafe { - if let Some(base) = frame.materialize_localsplus() { - vm.datastack_pop(base); - } + if is_async_gen { + let obj = PyAsyncGen::new(frame.clone(), self.__name__(), self.__qualname__()) + .into_pyobject(vm); + frame.set_generator(&obj); + Ok(obj) + } else if is_gen { + let obj = PyGenerator::new(frame.clone(), self.__name__(), self.__qualname__()) + .into_pyobject(vm); + frame.set_generator(&obj); + Ok(obj) + } else if is_coro { + let obj = PyCoroutine::new(frame.clone(), self.__name__(), self.__qualname__()) + .into_pyobject(vm); + frame.set_generator(&obj); + Ok(obj) + } else { + let result = vm.run_frame(frame.clone()); + // Release data stack memory after frame execution completes. 
+ unsafe { + if let Some(base) = frame.materialize_localsplus() { + vm.datastack_pop(base); } - result } + result } } @@ -686,11 +685,11 @@ impl Py { .intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS) ); debug_assert_eq!(code.kwonlyarg_count, 0); - debug_assert!( - !code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - ); + debug_assert!(!code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + )); let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) { None @@ -738,10 +737,11 @@ impl Py { // Generator/coroutine code objects are SIMPLE_FUNCTION in call // specialization classification, but their call path must still // go through invoke() to produce generator/coroutine objects. - if code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - { + if code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ) { return self.invoke(FuncArgs::from(args), vm); } let frame = self.prepare_exact_args_frame(args, vm); @@ -757,10 +757,11 @@ impl Py { } pub(crate) fn datastack_frame_size_bytes_for_code(code: &Py) -> Option { - if code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - { + if code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ) { return None; } let nlocalsplus = code.localspluskinds.len(); @@ -1268,12 +1269,12 @@ impl PyBoundMethod { } #[inline] - pub(crate) fn function_obj(&self) -> &PyObjectRef { + pub(crate) const fn function_obj(&self) -> &PyObjectRef { &self.function } #[inline] - pub(crate) fn self_obj(&self) -> &PyObjectRef { + pub(crate) const fn self_obj(&self) -> &PyObjectRef { &self.object } @@ -1398,6 +1399,7 @@ impl Representable for PyBoundMethod { pub(crate) 
struct PyCell { contents: PyMutex>, } + pub(crate) type PyCellRef = PyRef; impl PyPayload for PyCell { @@ -1426,6 +1428,7 @@ impl PyCell { pub(crate) fn get(&self) -> Option { self.contents.lock().clone() } + pub(crate) fn set(&self, x: Option) { *self.contents.lock() = x; } @@ -1435,6 +1438,7 @@ impl PyCell { self.get() .ok_or_else(|| vm.new_value_error("Cell is empty")) } + #[pygetset(setter)] fn set_cell_contents(&self, x: PySetterValue) { match x { @@ -1462,9 +1466,11 @@ pub(crate) fn vectorcall_function( && !code.flags.contains(bytecode::CodeFlags::VARARGS) && !code.flags.contains(bytecode::CodeFlags::VARKEYWORDS) && code.kwonlyarg_count == 0 - && !code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE); + && !code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ); if is_simple && nargs == code.arg_count as usize { // FAST PATH: simple positional-only call, exact arg count. 
@@ -1488,6 +1494,7 @@ pub(crate) fn vectorcall_function( args.truncate(nargs); FuncArgs::from(args) }; + zelf.invoke(func_args, vm) } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index a59d9367f03..f4af27ea492 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -245,7 +245,6 @@ impl PyList { Self::from(self.borrow_vec().to_vec()).into_ref(&vm.ctx) } - #[allow(clippy::len_without_is_empty)] pub fn __len__(&self) -> usize { self.borrow_vec().len() } diff --git a/crates/vm/src/builtins/module.rs b/crates/vm/src/builtins/module.rs index f9d3df8df97..51d90e9b32e 100644 --- a/crates/vm/src/builtins/module.rs +++ b/crates/vm/src/builtins/module.rs @@ -82,10 +82,6 @@ impl PyModuleDef { } } -#[allow( - clippy::new_without_default, - reason = "avoid a misleading Default implementation" -)] #[pyclass(module = false, name = "module")] #[derive(Debug)] pub struct PyModule { @@ -112,7 +108,10 @@ pub struct ModuleInitArgs { } impl PyModule { - #[allow(clippy::new_without_default)] + #[expect( + clippy::new_without_default, + reason = "avoid a misleading Default implementation" + )] #[must_use] pub const fn new() -> Self { Self { diff --git a/crates/vm/src/builtins/namespace.rs b/crates/vm/src/builtins/namespace.rs index 1768cf7d985..bd574fd97be 100644 --- a/crates/vm/src/builtins/namespace.rs +++ b/crates/vm/src/builtins/namespace.rs @@ -54,6 +54,21 @@ impl PyNamespace { let cls: PyObjectRef = zelf.class().to_owned().into(); let result = cls.call((), vm)?; + if !zelf.class().is(result.class()) { + return Err(vm.new_type_error(format!( + "expect {} type, but {}() returned '{}' object", + Self::class(&vm.ctx).slot_name(), + zelf.class() + .__qualname__(vm) + .downcast_ref::() + .map_or_else( + || zelf.class().name().to_string(), + |n| n.as_wtf8().to_string() + ), + result.class().name(), + ))); + } + // Copy the current namespace dict to the new instance let src_dict = zelf.dict().unwrap(); let dst_dict = 
result.dict().unwrap(); diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 960c3581301..212b14c7487 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -45,10 +45,10 @@ use rustpython_common::{ hash, lock::PyMutex, str::DeduceStrKind, - wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk, Wtf8Concat}, + wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Concat}, }; -use icu_casemap::TitlecaseMapper; +use icu_casemap::{CaseMapper, TitlecaseMapper}; use icu_locale::LanguageIdentifier; use icu_properties::props::{ BidiClass, BinaryProperty, CaseIgnorable, Cased, EnumeratedProperty, GeneralCategory, @@ -743,20 +743,31 @@ impl PyStr { } } - // casefold is much more aggressive than lower + // Case folding is a Unicode standard operation to erase case differences. + // + // Lower, upper, and title case are special properties. Case folding erases those + // differences. For ASCII, case folding is the same as lower case but other scripts have + // their own, well-defined mappings. 
#[pymethod] fn casefold(&self) -> Self { match self.as_str_kind() { - PyKindStr::Ascii(s) => caseless::default_case_fold_str(s.as_str()).into(), - PyKindStr::Utf8(s) => caseless::default_case_fold_str(s).into(), - PyKindStr::Wtf8(w) => w - .chunks() - .map(|c| match c { - Wtf8Chunk::Utf8(s) => Wtf8Buf::from_string(caseless::default_case_fold_str(s)), - Wtf8Chunk::Surrogate(c) => Wtf8Buf::from(c), - }) - .collect::() - .into(), + PyKindStr::Ascii(s) => s.to_ascii_lowercase().into(), + PyKindStr::Utf8(s) => CaseMapper::new().fold_string(s).to_string().into(), + PyKindStr::Wtf8(w) => { + let mut out = VecFmtWriter(Vec::with_capacity(w.len())); + let mapper = CaseMapper::new(); + for chunk in w.as_bytes().utf8_chunks() { + mapper + .fold(chunk.valid()) + .write_to(&mut out) + .expect("Writing to an in-memory buffer cannot fail."); + out.0.extend(chunk.invalid()); + } + // SAFETY: + // * CaseMapper only produces valid UTF-8 + // * Surrogates are appended as-is + unsafe { Wtf8Buf::from_bytes_unchecked(out.0) }.into() + } } } diff --git a/crates/vm/src/builtins/super.rs b/crates/vm/src/builtins/super.rs index f75b9b36327..c44b61d71e9 100644 --- a/crates/vm/src/builtins/super.rs +++ b/crates/vm/src/builtins/super.rs @@ -86,6 +86,7 @@ impl Initializer for PySuper { if frame.code.arg_count == 0 { return Err(vm.new_runtime_error("super(): no arguments")); } + // SAFETY: Frame is current and not concurrently mutated. use rustpython_compiler_core::bytecode::CO_FAST_CELL; let obj = unsafe { frame.fastlocals() }[0] @@ -165,9 +166,9 @@ impl GetAttr for PySuper { Some(o) => o.clone(), None => return skip(zelf, name), }; + // We want __class__ to return the class of the super object // (i.e. super, or a subclass), not the class of su->obj. 
- if name.as_bytes() == b"__class__" { return skip(zelf, name); } @@ -280,21 +281,23 @@ pub(crate) fn init(context: &'static Context) { let super_type = &context.types.super_type; PySuper::extend_class(context, super_type); - let super_doc = "super() -> same as super(__class__, )\n\ - super(type) -> unbound super object\n\ - super(type, obj) -> bound super object; requires isinstance(obj, type)\n\ - super(type, type2) -> bound super object; requires issubclass(type2, type)\n\ - Typical use to call a cooperative superclass method:\n\ - class C(B):\n \ - def meth(self, arg):\n \ - super().meth(arg)\n\ - This works for class methods too:\n\ - class C(B):\n \ - @classmethod\n \ - def cmeth(cls, arg):\n \ - super().cmeth(arg)\n"; + const SUPER_DOC: &str = "\ +super() -> same as super(__class__, ) +super(type) -> unbound super object +super(type, obj) -> bound super object; requires isinstance(obj, type) +super(type, type2) -> bound super object; requires issubclass(type2, type) +Typical use to call a cooperative superclass method: +class C(B): + def meth(self, arg): + super().meth(arg) +This works for class methods too: +class C(B): + @classmethod + def cmeth(cls, arg): + super().cmeth(arg) +"; extend_class!(context, super_type, { - "__doc__" => context.new_str(super_doc), + "__doc__" => context.new_str(SUPER_DOC), }); } diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index 81e7ce6f9f7..4606509fd19 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -1,3 +1,5 @@ +// cspell:ignore pyhash + use super::{ PositionIterInternal, PyGenericAlias, PyStrRef, PyType, PyTypeRef, iter::builtins_iter, }; @@ -296,13 +298,13 @@ impl PyTuple { #[inline] #[must_use] - pub fn len(&self) -> usize { + pub const fn len(&self) -> usize { self.elements.len() } #[inline] #[must_use] - pub fn is_empty(&self) -> bool { + pub const fn is_empty(&self) -> bool { self.elements.is_empty() } @@ -725,23 +727,29 @@ pub(crate) fn 
init(context: &'static Context) { } pub(super) fn tuple_hash(elements: &[PyObjectRef], vm: &VirtualMachine) -> PyResult { - #[cfg(target_pointer_width = "64")] - const PRIME1: PyUHash = 11400714785074694791; - #[cfg(target_pointer_width = "64")] - const PRIME2: PyUHash = 14029467366897019727; - #[cfg(target_pointer_width = "64")] - const PRIME5: PyUHash = 2870177450012600261; - #[cfg(target_pointer_width = "64")] - const ROTATE: u32 = 31; - - #[cfg(target_pointer_width = "32")] - const PRIME1: PyUHash = 2654435761; - #[cfg(target_pointer_width = "32")] - const PRIME2: PyUHash = 2246822519; - #[cfg(target_pointer_width = "32")] - const PRIME5: PyUHash = 374761393; - #[cfg(target_pointer_width = "32")] - const ROTATE: u32 = 13; + const PRIME1: PyUHash = cfg_select! { + target_pointer_width = "64" => 11400714785074694791, + target_pointer_width = "32" => 2654435761, + _ => unreachable!(), + }; + + const PRIME2: PyUHash = cfg_select! { + target_pointer_width = "64" => 14029467366897019727, + target_pointer_width = "32" => 2246822519, + _ => unreachable!(), + }; + + const PRIME5: PyUHash = cfg_select! { + target_pointer_width = "64" => 2870177450012600261, + target_pointer_width = "32" => 374761393, + _ => unreachable!(), + }; + + const ROTATE: u32 = cfg_select! { + target_pointer_width = "64" => 31, + target_pointer_width = "32" => 13, + _ => unreachable!(), + }; let mut acc = PRIME5; let len = elements.len() as PyUHash; @@ -755,8 +763,10 @@ pub(super) fn tuple_hash(elements: &[PyObjectRef], vm: &VirtualMachine) -> PyRes acc = acc.wrapping_add(len ^ (PRIME5 ^ 3527539)); - if acc as PyHash == -1 { + let acc_pyhash = acc as PyHash; + if acc_pyhash == -1 { return Ok(1546275796); } - Ok(acc as PyHash) + + Ok(acc_pyhash) } diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 8b920c2fee1..c097a18a516 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -408,7 +408,8 @@ cfg_select! 
{ /// For attributes we do not use a dict, but an IndexMap, which is an Hash Table /// that maintains order and is compatible with the standard HashMap This is probably /// faster and only supports strings as keys. -pub(crate) type PyAttributes = IndexMap<&'static PyStrInterned, PyObjectRef, ahash::RandomState>; +pub(crate) type PyAttributes = + IndexMap<&'static PyStrInterned, PyObjectRef, rapidhash::quality::RandomState>; unsafe impl Traverse for PyAttributes { fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { @@ -2961,7 +2962,7 @@ mod tests { } #[test] - fn test_linearise() { + fn linearise() { let context = Context::genesis(); let object = context.types.object_type.to_owned(); let type_type = context.types.type_type.to_owned(); diff --git a/crates/vm/src/builtins/weakproxy.rs b/crates/vm/src/builtins/weakproxy.rs index 8cdc206db88..077e8c34963 100644 --- a/crates/vm/src/builtins/weakproxy.rs +++ b/crates/vm/src/builtins/weakproxy.rs @@ -136,7 +136,7 @@ impl IterNext for PyWeakProxy { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let obj = zelf.try_upgrade(vm)?; if obj.class().slots.iternext.load().is_none() { - return Err(vm.new_type_error("Weakref proxy referenced a non-iterator".to_owned())); + return Err(vm.new_type_error("Weakref proxy referenced a non-iterator")); } PyIter::new(obj).next(vm) } diff --git a/crates/vm/src/builtins/weakref.rs b/crates/vm/src/builtins/weakref.rs index c7c4466a1d2..dbf2941c2fc 100644 --- a/crates/vm/src/builtins/weakref.rs +++ b/crates/vm/src/builtins/weakref.rs @@ -32,6 +32,7 @@ impl PyPayload for PyWeak { impl Callable for PyWeak { type Args = (); + #[inline] fn call(zelf: &Py, _: Self::Args, vm: &VirtualMachine) -> PyResult { Ok(vm.unwrap_or_none(zelf.upgrade())) @@ -50,10 +51,10 @@ impl Constructor for PyWeak { .ok_or_else(|| vm.new_type_error("__new__ expected at least 1 argument, got 0"))?; let callback = positional.next(); if let Some(_extra) = positional.next() { - return Err(vm.new_type_error(format!( - 
"__new__ expected at most 2 arguments, got {}", - 3 + positional.count() - ))); + let got = positional.count() + 3; + return Err( + vm.new_type_error(format!("__new__ expected at most 2 arguments, got {got}")) + ); } let weak = referent.downgrade_with_typ(callback, cls, vm)?; Ok(weak.into()) @@ -151,7 +152,7 @@ impl Representable for PyWeak { #[inline] fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { let id = zelf.get_id(); - let repr = if let Some(o) = zelf.upgrade() { + Ok(if let Some(o) = zelf.upgrade() { format!( "", id, @@ -160,8 +161,7 @@ impl Representable for PyWeak { ) } else { format!("") - }; - Ok(repr) + }) } } diff --git a/crates/vm/src/dict_inner.rs b/crates/vm/src/dict_inner.rs index 6cf1d52e436..8112bbe252b 100644 --- a/crates/vm/src/dict_inner.rs +++ b/crates/vm/src/dict_inner.rs @@ -1245,7 +1245,7 @@ mod tests { use crate::{Interpreter, common::ascii}; #[test] - fn test_insert() { + fn insert_basic() { Interpreter::without_stdlib(Default::default()).enter(|vm| { let dict = Dict::default(); assert_eq!(0, dict.len()); @@ -1290,8 +1290,8 @@ mod tests { } hash_tests! { - test_abc: "abc", - test_x: "x", + abc: "abc", + x: "x", } fn check_hash_equivalence(text: &str) { diff --git a/crates/vm/src/eval.rs b/crates/vm/src/eval.rs index be09b3e4cc9..b69a4e05bfb 100644 --- a/crates/vm/src/eval.rs +++ b/crates/vm/src/eval.rs @@ -16,7 +16,7 @@ mod tests { use crate::Interpreter; #[test] - fn test_print_42() { + fn print_42() { Interpreter::without_stdlib(Default::default()).enter(|vm| { let source = String::from("print('Hello world')"); let vars = vm.new_scope_with_builtins(); diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 51ce99de82b..b91e567e731 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -1385,28 +1385,18 @@ impl ToOSErrorBuilder for std::io::Error { // Use C runtime's strerror for POSIX errno values. // For Windows-specific error codes, fall back to FormatMessage. 
const MAX_POSIX_ERRNO: i32 = 127; - if errno > 0 && errno <= MAX_POSIX_ERRNO { - let ptr = unsafe { libc::strerror(errno) }; - if !ptr.is_null() { - let s = unsafe { core::ffi::CStr::from_ptr(ptr) }.to_string_lossy(); - if !s.starts_with("Unknown error") { - break 'msg s.into_owned(); - } - } + if errno > 0 + && errno <= MAX_POSIX_ERRNO + && let Some(s) = crate::host_env::errno::strerror_string(errno) + && !s.starts_with("Unknown error") + { + break 'msg s; } self.to_string() }; #[cfg(unix)] - let msg = { - let ptr = unsafe { libc::strerror(errno) }; - if !ptr.is_null() { - unsafe { core::ffi::CStr::from_ptr(ptr) } - .to_string_lossy() - .into_owned() - } else { - self.to_string() - } - }; + let msg = + crate::host_env::errno::strerror_string(errno).unwrap_or_else(|| self.to_string()); #[cfg(not(any(windows, unix)))] let msg = self.to_string(); @@ -1433,17 +1423,160 @@ impl IntoPyException for std::io::Error { } } +#[cfg(all(unix, not(target_os = "redox")))] +impl ToPyException for rustpython_host_env::fcntl::LockfError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::InvalidCmd => vm.new_value_error("unrecognized lockf argument"), + Self::Overflow(e) => vm.new_overflow_error(e.clone()), + Self::Io(err) => err.to_pyexception(vm), + } + } +} + #[cfg(unix)] -impl IntoPyException for nix::Error { - fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { - std::io::Error::from(self).into_pyexception(vm) +impl ToPyException for rustpython_host_env::posix::AccessError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::InvalidMode => vm.new_value_error( + "One of the flags is wrong, there are only 4 possibilities F_OK, R_OK, W_OK and X_OK", + ), + Self::Os(errno) => std::io::Error::from_raw_os_error(*errno).to_pyexception(vm), + } + } +} + +#[cfg(all(unix, not(target_os = "redox")))] +impl ToPyException for rustpython_host_env::socket::AncillaryPackError { + fn 
to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::ItemTooLarge => vm.new_os_error("ancillary data item too large"), + Self::TooMuchData => vm.new_os_error("too much ancillary data"), + Self::UnexpectedNullHeader => { + vm.new_runtime_error("unexpected NULL result from CMSG_FIRSTHDR/CMSG_NXTHDR") + } + } + } +} + +#[cfg(any(unix, windows))] +impl ToPyException for rustpython_host_env::time::CheckedTmError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::YearOutOfRange => vm.new_overflow_error("year out of range"), + Self::MonthOutOfRange => vm.new_value_error("month out of range"), + Self::DayOfMonthOutOfRange => vm.new_value_error("day of month out of range"), + Self::HourOutOfRange => vm.new_value_error("hour out of range"), + Self::MinuteOutOfRange => vm.new_value_error("minute out of range"), + Self::SecondsOutOfRange => vm.new_value_error("seconds out of range"), + Self::DayOfWeekOutOfRange => vm.new_value_error("day of week out of range"), + Self::DayOfYearOutOfRange => vm.new_value_error("day of year out of range"), + Self::EmbeddedNul => vm.new_value_error("embedded null character"), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::winapi::BuildEnvironmentBlockError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::ContainsNul => vm.new_value_error("embedded null character"), + Self::IllegalName => vm.new_value_error("illegal environment variable name"), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::winapi::BatchedWaitError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::Timeout => vm + .new_os_subtype_error( + vm.ctx.exceptions.timeout_error.to_owned(), + None, + "timed out", + ) + .upcast(), + Self::Interrupted => vm + .new_errno_error(libc::EINTR, "Interrupted system call") + .upcast(), + Self::Os(err) => 
vm.new_os_error(*err as i32), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::nt::ReadlinkError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::Io(err) => err.to_pyexception(vm), + Self::NotSymbolicLink => { + vm.new_os_error("The file or directory is not a reparse point") + } + Self::InvalidReparseData => vm.new_os_error("Invalid reparse data"), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::nt::ReadConsoleError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::Io(err) => err.to_pyexception(vm), + Self::BufferTooSmall { + available, + required, + } => vm.new_system_error(format!( + "Buffer had room for {available} bytes but {required} bytes required", + )), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::winreg::ExpandEnvironmentStringsError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::Os => vm.new_os_error("ExpandEnvironmentStringsW failed"), + Self::Utf16(e) => vm.new_value_error(format!("UTF16 error: {e}")), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::winreg::QueryStringError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::Code(err) => std::io::Error::from_raw_os_error(*err as i32).to_pyexception(vm), + Self::Utf16(e) => vm.new_value_error(format!("UTF16 error: {e}")), + } + } +} + +#[cfg(windows)] +impl ToPyException for rustpython_host_env::wmi::ExecQueryError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::MoreData => vm.new_os_error(format!( + "Query returns more than {} characters", + rustpython_host_env::wmi::BUFFER_SIZE + )), + Self::Code(err) => std::io::Error::from_raw_os_error(*err as i32).to_pyexception(vm), + } } } #[cfg(unix)] -impl IntoPyException for rustix::io::Errno { - fn 
into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { - std::io::Error::from(self).into_pyexception(vm) +impl ToPyException for rustpython_host_env::multiprocessing::SemError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + let excs = &vm.ctx.exceptions; + let exc_type = match self { + Self::AlreadyExists => excs.file_exists_error.to_owned(), + Self::NotFound => excs.file_not_found_error.to_owned(), + _ => excs.os_error.to_owned(), + }; + vm.new_os_subtype_error(exc_type, Some(self.raw_os_error()), self.description()) + .upcast() } } diff --git a/crates/vm/src/format.rs b/crates/vm/src/format.rs index 77322d0629a..7e18dd75d2f 100644 --- a/crates/vm/src/format.rs +++ b/crates/vm/src/format.rs @@ -10,34 +10,29 @@ use crate::common::format::*; use crate::common::wtf8::{Wtf8, Wtf8Buf}; /// Get locale information from C `localeconv()` for the 'n' format specifier. -#[cfg(unix)] +#[cfg(any(unix, windows))] pub(crate) fn get_locale_info() -> LocaleInfo { - use core::ffi::CStr; - unsafe { - let lc = libc::localeconv(); - if lc.is_null() { - return LocaleInfo { - thousands_sep: String::new(), - decimal_point: ".".to_string(), - grouping: vec![], - }; - } - let thousands_sep = CStr::from_ptr((*lc).thousands_sep) - .to_string_lossy() - .into_owned(); - let decimal_point = CStr::from_ptr((*lc).decimal_point) - .to_string_lossy() - .into_owned(); - let grouping = parse_grouping((*lc).grouping); - LocaleInfo { - thousands_sep, - decimal_point, - grouping, - } + let lc = crate::host_env::locale::localeconv_data(); + #[allow( + clippy::unnecessary_cast, + reason = "libc::c_char is not u8 on all platforms" + )] + let mut grouping = lc.grouping.iter().map(|&c| c as u8).collect::>(); + if !grouping.is_empty() { + grouping.push(0); + } + LocaleInfo { + thousands_sep: String::from_utf8_lossy(&lc.thousands_sep).into_owned(), + decimal_point: if lc.decimal_point.is_empty() { + ".".to_string() + } else { + 
String::from_utf8_lossy(&lc.decimal_point).into_owned() + }, + grouping, } } -#[cfg(not(unix))] +#[cfg(not(any(unix, windows)))] pub(crate) fn get_locale_info() -> LocaleInfo { LocaleInfo { thousands_sep: String::new(), @@ -46,30 +41,6 @@ pub(crate) fn get_locale_info() -> LocaleInfo { } } -/// Parse C `lconv.grouping` into a `Vec`. -/// Reads bytes until 0 or CHAR_MAX, then appends 0 (meaning "repeat last group"). -#[cfg(unix)] -unsafe fn parse_grouping(grouping: *const libc::c_char) -> Vec { - let mut result = Vec::new(); - if grouping.is_null() { - return result; - } - - unsafe { - let mut ptr = grouping; - while ![0, libc::c_char::MAX].contains(&*ptr) { - result.push(*ptr as _); - ptr = ptr.add(1); - } - } - - if !result.is_empty() { - result.push(0); - } - - result -} - impl IntoPyException for FormatSpecError { fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { match self { diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 1bafe7f26a4..f1ed31d7189 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -710,10 +710,11 @@ impl Frame { // For generators/coroutines, initialize prev_line to the def line // so that preamble instructions (RETURN_GENERATOR, POP_TOP) don't // fire spurious LINE events. 
- let prev_line = if code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - { + let prev_line = if code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ) { code.first_line_number.map_or(0, |line| line.get() as u32) } else { 0 @@ -9523,9 +9524,7 @@ impl ExecutingFrame<'_> { // Returns the exception object; RERAISE will re-raise it if arg.fast_isinstance(vm.ctx.exceptions.stop_iteration) { let flags = &self.code.flags; - let msg = if flags - .contains(bytecode::CodeFlags::COROUTINE | bytecode::CodeFlags::GENERATOR) - { + let msg = if flags.contains(bytecode::CodeFlags::ASYNC_GENERATOR) { "async generator raised StopIteration" } else if flags.contains(bytecode::CodeFlags::COROUTINE) { "coroutine raised StopIteration" diff --git a/crates/vm/src/function/builtin.rs b/crates/vm/src/function/builtin.rs index 444df64a8ef..4fed5e4cf23 100644 --- a/crates/vm/src/function/builtin.rs +++ b/crates/vm/src/function/builtin.rs @@ -221,7 +221,7 @@ mod tests { use core::mem::size_of_val; #[test] - fn test_into_native_fn_noalloc() { + fn into_native_fn_noalloc() { fn py_func(_b: bool, _vm: &crate::VirtualMachine) -> i32 { 1 } diff --git a/crates/vm/src/function/method.rs b/crates/vm/src/function/method.rs index 6c59172b633..ee483e361e6 100644 --- a/crates/vm/src/function/method.rs +++ b/crates/vm/src/function/method.rs @@ -299,9 +299,14 @@ impl Py { unsafe { &*(&self.method as *const _) } } - pub fn build_function(&self, vm: &VirtualMachine) -> PyRef { + pub fn build_function( + &self, + vm: &VirtualMachine, + zelf: Option, + ) -> PyRef { let mut function = unsafe { self.method() }.to_function(); function._method_def_owner = Some(self.to_owned().into()); + function.zelf = zelf; PyRef::new_ref( function, vm.ctx.types.builtin_function_or_method_type.to_owned(), diff --git a/crates/vm/src/gc_state.rs b/crates/vm/src/gc_state.rs index d29b9bb6c99..ecb34fcc869 100644 
--- a/crates/vm/src/gc_state.rs +++ b/crates/vm/src/gc_state.rs @@ -874,7 +874,7 @@ mod tests { use super::*; #[test] - fn test_gc_state_default() { + fn gc_state_default() { let state = GcState::new(); assert!(state.is_enabled()); assert_eq!(state.get_debug(), GcDebugFlags::empty()); @@ -883,7 +883,7 @@ mod tests { } #[test] - fn test_gc_enable_disable() { + fn gc_enable_disable() { let state = GcState::new(); assert!(state.is_enabled()); state.disable(); @@ -893,14 +893,14 @@ mod tests { } #[test] - fn test_gc_threshold() { + fn gc_threshold() { let state = GcState::new(); state.set_threshold(100, Some(20), Some(30)); assert_eq!(state.get_threshold(), (100, 20, 30)); } #[test] - fn test_gc_debug_flags() { + fn gc_debug_flags() { let state = GcState::new(); state.set_debug(GcDebugFlags::STATS | GcDebugFlags::COLLECTABLE); assert_eq!( diff --git a/crates/vm/src/getpath.rs b/crates/vm/src/getpath.rs index 51f85046ea7..63f922a6365 100644 --- a/crates/vm/src/getpath.rs +++ b/crates/vm/src/getpath.rs @@ -369,9 +369,11 @@ fn get_executable_path() -> Option { } /// Parse pyvenv.cfg and extract the 'home' key value -#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] fn parse_pyvenv_home(pyvenv_cfg: &Path) -> Option { + #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] let content = crate::host_env::fs::read_to_string(pyvenv_cfg).ok()?; + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + let content = std::fs::read_to_string(pyvenv_cfg).ok()?; for line in content.lines() { if let Some((key, value)) = line.split_once('=') @@ -384,17 +386,12 @@ fn parse_pyvenv_home(pyvenv_cfg: &Path) -> Option { None } -#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] -fn parse_pyvenv_home(_pyvenv_cfg: &Path) -> Option { - None -} - #[cfg(test)] mod tests { use super::*; #[test] - fn test_init_path_config() { + fn init_path_config_basic() { let settings = Settings::default(); let paths = init_path_config(&settings); // Just verify it doesn't 
panic and returns valid paths @@ -402,7 +399,7 @@ mod tests { } #[test] - fn test_search_up() { + fn search_up() { // Test with a path that doesn't have any landmarks let result = search_up_file( crate::host_env::os::temp_dir(), @@ -412,7 +409,7 @@ mod tests { } #[test] - fn test_default_prefix() { + fn default_prefix_basic() { let prefix = default_prefix(); assert!(!prefix.is_empty()); } diff --git a/crates/vm/src/import.rs b/crates/vm/src/import.rs index f33290ff83b..798bc258b7f 100644 --- a/crates/vm/src/import.rs +++ b/crates/vm/src/import.rs @@ -9,7 +9,7 @@ use crate::{ }; pub(crate) fn check_pyc_magic_number_bytes(buf: &[u8]) -> bool { - buf.starts_with(&crate::version::PYC_MAGIC_NUMBER_BYTES[..2]) + buf.starts_with(&crate::version::PYC_MAGIC_NUMBER_BYTES) } pub(crate) fn init_importlib_base(vm: &mut VirtualMachine) -> PyResult { diff --git a/crates/vm/src/intern.rs b/crates/vm/src/intern.rs index da1d63f8791..981732f2dd2 100644 --- a/crates/vm/src/intern.rs +++ b/crates/vm/src/intern.rs @@ -11,7 +11,7 @@ use core::{borrow::Borrow, ops::Deref}; #[derive(Debug)] pub(crate) struct StringPool { - inner: PyRwLock>, + inner: PyRwLock>, } impl Default for StringPool { diff --git a/crates/vm/src/lib.rs b/crates/vm/src/lib.rs index 5e2e88a074a..cca5b43457c 100644 --- a/crates/vm/src/lib.rs +++ b/crates/vm/src/lib.rs @@ -6,13 +6,19 @@ //! //! Some stdlib modules are implemented here, but most of them are in the `rustpython-stdlib` module. The -// to allow `mod foo {}` in foo.rs; clippy thinks this is a mistake/misunderstanding of -// how `mod` works, but we want this sometimes for pymodule declarations #![deny(clippy::disallowed_methods)] -#![allow(clippy::module_inception)] -// we want to mirror python naming conventions when defining python structs, so that does mean -// uppercase acronyms, e.g. 
TextIOWrapper instead of TextIoWrapper -#![allow(clippy::upper_case_acronyms)] +#![allow( + clippy::module_inception, + reason = " + to allow `mod foo {}` in foo.rs; clippy thinks this is a mistake/misunderstanding of + how `mod` works, but we want this sometimes for pymodule declarations" +)] +#![allow( + clippy::upper_case_acronyms, + reason = " +we want to mirror python naming conventions when defining python structs, so that does mean +uppercase acronyms, e.g. TextIOWrapper instead of TextIoWrapper" +)] #![doc(html_logo_url = "https://raw.githubusercontent.com/RustPython/RustPython/main/logo.png")] #![doc(html_root_url = "https://docs.rs/rustpython-vm/")] diff --git a/crates/vm/src/readline.rs b/crates/vm/src/readline.rs index b154bd1365c..bd0ecd73912 100644 --- a/crates/vm/src/readline.rs +++ b/crates/vm/src/readline.rs @@ -14,7 +14,7 @@ pub enum ReadlineResult { Interrupt, Io(std::io::Error), #[cfg(unix)] - OsError(nix::Error), + OsError(String), Other(OtherError), } @@ -163,7 +163,7 @@ pub mod rustyline_readline { Err(ReadlineError::Io(e)) => ReadlineResult::Io(e), Err(ReadlineError::Signal(_)) => continue, #[cfg(unix)] - Err(ReadlineError::Errno(num)) => ReadlineResult::OsError(num), + Err(ReadlineError::Errno(num)) => ReadlineResult::OsError(num.to_string()), Err(e) => ReadlineResult::Other(e.into()), }; } diff --git a/crates/vm/src/signal.rs b/crates/vm/src/signal.rs index dbeeaeb4bf8..6c8e9cb6d35 100644 --- a/crates/vm/src/signal.rs +++ b/crates/vm/src/signal.rs @@ -92,6 +92,7 @@ pub(crate) fn set_triggered() { } #[inline(always)] +#[cfg(not(target_arch = "wasm32"))] pub(crate) fn is_triggered() -> bool { ANY_TRIGGERED.load(Ordering::Relaxed) } diff --git a/crates/vm/src/sliceable.rs b/crates/vm/src/sliceable.rs index 21fe6057e60..dd4dd50979d 100644 --- a/crates/vm/src/sliceable.rs +++ b/crates/vm/src/sliceable.rs @@ -160,7 +160,7 @@ impl SliceableSequenceMutOp for Vec { } } -#[allow(clippy::len_without_is_empty)] +#[expect(clippy::len_without_is_empty, 
reason = "Doesn't match CPython code")] pub trait SliceableSequenceOp { type Item; type Sliced; diff --git a/crates/vm/src/stdlib/_ast/pyast.rs b/crates/vm/src/stdlib/_ast/pyast.rs index c9017006f1d..7eae6f00986 100644 --- a/crates/vm/src/stdlib/_ast/pyast.rs +++ b/crates/vm/src/stdlib/_ast/pyast.rs @@ -1,5 +1,3 @@ -#![allow(clippy::all)] - use super::*; use crate::builtins::{PyGenericAlias, PyTuple, PyTupleRef, PyTypeRef, make_union}; use crate::common::ascii; diff --git a/crates/vm/src/stdlib/_ast/statement.rs b/crates/vm/src/stdlib/_ast/statement.rs index 4c78ad4b73e..43d1162a402 100644 --- a/crates/vm/src/stdlib/_ast/statement.rs +++ b/crates/vm/src/stdlib/_ast/statement.rs @@ -36,111 +36,99 @@ impl Node for ast::Stmt { } } - #[allow(clippy::if_same_then_else)] + #[expect(clippy::if_same_then_else, reason = "Looks better here")] fn ast_from_object( - _vm: &VirtualMachine, + vm: &VirtualMachine, source_file: &SourceFile, - _object: PyObjectRef, + object: PyObjectRef, ) -> PyResult { - let _cls = _object.class(); - Ok(if _cls.is(pyast::NodeStmtFunctionDef::static_type()) { + let cls = object.class(); + Ok(if cls.is(pyast::NodeStmtFunctionDef::static_type()) { Self::FunctionDef(ast::StmtFunctionDef::ast_from_object( - _vm, + vm, source_file, - _object, + object, )?) - } else if _cls.is(pyast::NodeStmtAsyncFunctionDef::static_type()) { + } else if cls.is(pyast::NodeStmtAsyncFunctionDef::static_type()) { Self::FunctionDef(ast::StmtFunctionDef::ast_from_object( - _vm, - source_file, - _object, - )?) - } else if _cls.is(pyast::NodeStmtClassDef::static_type()) { - Self::ClassDef(ast::StmtClassDef::ast_from_object( - _vm, + vm, source_file, - _object, + object, )?) - } else if _cls.is(pyast::NodeStmtReturn::static_type()) { - Self::Return(ast::StmtReturn::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtDelete::static_type()) { - Self::Delete(ast::StmtDelete::ast_from_object(_vm, source_file, _object)?) 
- } else if _cls.is(pyast::NodeStmtAssign::static_type()) { - Self::Assign(ast::StmtAssign::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtTypeAlias::static_type()) { + } else if cls.is(pyast::NodeStmtClassDef::static_type()) { + Self::ClassDef(ast::StmtClassDef::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtReturn::static_type()) { + Self::Return(ast::StmtReturn::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtDelete::static_type()) { + Self::Delete(ast::StmtDelete::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtAssign::static_type()) { + Self::Assign(ast::StmtAssign::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtTypeAlias::static_type()) { Self::TypeAlias(ast::StmtTypeAlias::ast_from_object( - _vm, + vm, source_file, - _object, + object, )?) - } else if _cls.is(pyast::NodeStmtAugAssign::static_type()) { + } else if cls.is(pyast::NodeStmtAugAssign::static_type()) { Self::AugAssign(ast::StmtAugAssign::ast_from_object( - _vm, + vm, source_file, - _object, + object, )?) - } else if _cls.is(pyast::NodeStmtAnnAssign::static_type()) { + } else if cls.is(pyast::NodeStmtAnnAssign::static_type()) { Self::AnnAssign(ast::StmtAnnAssign::ast_from_object( - _vm, + vm, source_file, - _object, + object, )?) - } else if _cls.is(pyast::NodeStmtFor::static_type()) { - Self::For(ast::StmtFor::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtAsyncFor::static_type()) { - Self::For(ast::StmtFor::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtWhile::static_type()) { - Self::While(ast::StmtWhile::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtIf::static_type()) { - Self::If(ast::StmtIf::ast_from_object(_vm, source_file, _object)?) 
- } else if _cls.is(pyast::NodeStmtWith::static_type()) { - Self::With(ast::StmtWith::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtAsyncWith::static_type()) { - Self::With(ast::StmtWith::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtMatch::static_type()) { - Self::Match(ast::StmtMatch::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtRaise::static_type()) { - Self::Raise(ast::StmtRaise::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtTry::static_type()) { - Self::Try(ast::StmtTry::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtTryStar::static_type()) { - Self::Try(ast::StmtTry::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtAssert::static_type()) { - Self::Assert(ast::StmtAssert::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtImport::static_type()) { - Self::Import(ast::StmtImport::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtImportFrom::static_type()) { + } else if cls.is(pyast::NodeStmtFor::static_type()) { + Self::For(ast::StmtFor::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtAsyncFor::static_type()) { + Self::For(ast::StmtFor::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtWhile::static_type()) { + Self::While(ast::StmtWhile::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtIf::static_type()) { + Self::If(ast::StmtIf::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtWith::static_type()) { + Self::With(ast::StmtWith::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtAsyncWith::static_type()) { + Self::With(ast::StmtWith::ast_from_object(vm, source_file, object)?) 
+ } else if cls.is(pyast::NodeStmtMatch::static_type()) { + Self::Match(ast::StmtMatch::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtRaise::static_type()) { + Self::Raise(ast::StmtRaise::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtTry::static_type()) { + Self::Try(ast::StmtTry::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtTryStar::static_type()) { + Self::Try(ast::StmtTry::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtAssert::static_type()) { + Self::Assert(ast::StmtAssert::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtImport::static_type()) { + Self::Import(ast::StmtImport::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtImportFrom::static_type()) { Self::ImportFrom(ast::StmtImportFrom::ast_from_object( - _vm, - source_file, - _object, - )?) - } else if _cls.is(pyast::NodeStmtGlobal::static_type()) { - Self::Global(ast::StmtGlobal::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtNonlocal::static_type()) { - Self::Nonlocal(ast::StmtNonlocal::ast_from_object( - _vm, - source_file, - _object, - )?) - } else if _cls.is(pyast::NodeStmtExpr::static_type()) { - Self::Expr(ast::StmtExpr::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtPass::static_type()) { - Self::Pass(ast::StmtPass::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtBreak::static_type()) { - Self::Break(ast::StmtBreak::ast_from_object(_vm, source_file, _object)?) - } else if _cls.is(pyast::NodeStmtContinue::static_type()) { - Self::Continue(ast::StmtContinue::ast_from_object( - _vm, + vm, source_file, - _object, + object, )?) 
- } else if _vm.is_none(&_object) { - return Err(_vm.new_value_error("None disallowed in statement list")); + } else if cls.is(pyast::NodeStmtGlobal::static_type()) { + Self::Global(ast::StmtGlobal::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtNonlocal::static_type()) { + Self::Nonlocal(ast::StmtNonlocal::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtExpr::static_type()) { + Self::Expr(ast::StmtExpr::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtPass::static_type()) { + Self::Pass(ast::StmtPass::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtBreak::static_type()) { + Self::Break(ast::StmtBreak::ast_from_object(vm, source_file, object)?) + } else if cls.is(pyast::NodeStmtContinue::static_type()) { + Self::Continue(ast::StmtContinue::ast_from_object(vm, source_file, object)?) + } else if vm.is_none(&object) { + return Err(vm.new_value_error("None disallowed in statement list")); } else { - return Err(_vm.new_type_error(format!( + return Err(vm.new_type_error(format!( "expected some sort of stmt, but got {}", - _object.repr(_vm)? + object.repr(vm)? 
))); }) } @@ -201,6 +189,7 @@ impl Node for ast::StmtFunctionDef { node_add_location(&dict, range, vm, source_file); node.into() } + fn ast_from_object( _vm: &VirtualMachine, source_file: &SourceFile, diff --git a/crates/vm/src/stdlib/_codecs.rs b/crates/vm/src/stdlib/_codecs.rs index 0fc9c792b61..a9402edc3a2 100644 --- a/crates/vm/src/stdlib/_codecs.rs +++ b/crates/vm/src/stdlib/_codecs.rs @@ -375,6 +375,7 @@ fn delegate_pycodecs( mod _codecs_windows { use crate::{PyResult, VirtualMachine}; use crate::{builtins::PyStrRef, builtins::PyUtf8StrRef, function::ArgBytesLike}; + use rustpython_host_env::windows as host_windows; #[derive(FromArgs)] struct MbcsEncodeArgs { @@ -387,9 +388,6 @@ mod _codecs_windows { #[pyfunction] fn mbcs_encode(args: MbcsEncodeArgs, vm: &VirtualMachine) -> PyResult<(Vec, usize)> { use crate::host_env::windows::ToWideString; - use windows_sys::Win32::Globalization::{ - CP_ACP, WC_NO_BEST_FIT_CHARS, WideCharToMultiByte, - }; let errors = args.errors.as_ref().map_or("strict", |s| s.as_str()); let s = match args.s.to_str() { @@ -411,56 +409,31 @@ mod _codecs_windows { let wide: Vec = std::ffi::OsStr::new(s).to_wide(); // Get the required buffer size - let size = unsafe { - WideCharToMultiByte( - CP_ACP, - WC_NO_BEST_FIT_CHARS, - wide.as_ptr(), - wide.len() as i32, - core::ptr::null_mut(), - 0, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - - if size == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("mbcs_encode failed: {err}"))); - } - - let mut buffer = vec![0u8; size as usize]; - let mut used_default_char: i32 = 0; - - let result = unsafe { - WideCharToMultiByte( - CP_ACP, - WC_NO_BEST_FIT_CHARS, - wide.as_ptr(), - wide.len() as i32, - buffer.as_mut_ptr().cast(), - size, - core::ptr::null(), - if errors == "strict" { - &mut used_default_char - } else { - core::ptr::null_mut() - }, - ) - }; - - if result == 0 { - let err = std::io::Error::last_os_error(); - return 
Err(vm.new_os_error(format!("mbcs_encode failed: {err}"))); - } - - if errors == "strict" && used_default_char != 0 { + let (size, _) = host_windows::wide_char_to_multi_byte_len( + host_windows::CP_ACP, + host_windows::WC_NO_BEST_FIT_CHARS, + &wide, + false, + ) + .map_err(|err| vm.new_os_error(format!("mbcs_encode failed: {err}")))?; + + let mut buffer = vec![0u8; size]; + let (result, used_default_char) = host_windows::wide_char_to_multi_byte( + host_windows::CP_ACP, + host_windows::WC_NO_BEST_FIT_CHARS, + &wide, + &mut buffer, + errors == "strict", + ) + .map_err(|err| vm.new_os_error(format!("mbcs_encode failed: {err}")))?; + + if errors == "strict" && used_default_char { return Err(vm.new_unicode_encode_error( "'mbcs' codec can't encode characters: invalid character", )); } - buffer.truncate(result as usize); + buffer.truncate(result); Ok((buffer, char_len)) } @@ -477,10 +450,6 @@ mod _codecs_windows { #[pyfunction] fn mbcs_decode(args: MbcsDecodeArgs, vm: &VirtualMachine) -> PyResult<(String, usize)> { - use windows_sys::Win32::Globalization::{ - CP_ACP, MB_ERR_INVALID_CHARS, MultiByteToWideChar, - }; - let _errors = args.errors.as_ref().map_or("strict", |s| s.as_str()); let data = args.data.borrow_buf(); let len = data.len(); @@ -490,72 +459,42 @@ mod _codecs_windows { } // Get the required buffer size for UTF-16 - let size = unsafe { - MultiByteToWideChar( - CP_ACP, - MB_ERR_INVALID_CHARS, - data.as_ptr().cast(), - len as i32, - core::ptr::null_mut(), - 0, - ) - }; + let size = host_windows::multi_byte_to_wide_len( + host_windows::CP_ACP, + host_windows::MB_ERR_INVALID_CHARS, + data.as_ref(), + ); - if size == 0 { + if size.is_err() { // Try without MB_ERR_INVALID_CHARS for non-strict mode (replacement behavior) - let size = unsafe { - MultiByteToWideChar( - CP_ACP, - 0, - data.as_ptr().cast(), - len as i32, - core::ptr::null_mut(), - 0, - ) - }; - if size == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("mbcs_decode 
failed: {err}"))); - } + let size = host_windows::multi_byte_to_wide_len(host_windows::CP_ACP, 0, data.as_ref()) + .map_err(|err| vm.new_os_error(format!("mbcs_decode failed: {err}")))?; - let mut buffer = vec![0u16; size as usize]; - let result = unsafe { - MultiByteToWideChar( - CP_ACP, - 0, - data.as_ptr().cast(), - len as i32, - buffer.as_mut_ptr(), - size, - ) - }; - if result == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("mbcs_decode failed: {err}"))); - } - buffer.truncate(result as usize); + let mut buffer = vec![0u16; size]; + let result = host_windows::multi_byte_to_wide( + host_windows::CP_ACP, + 0, + data.as_ref(), + &mut buffer, + ) + .map_err(|err| vm.new_os_error(format!("mbcs_decode failed: {err}")))?; + buffer.truncate(result); let s = String::from_utf16(&buffer) .map_err(|e| vm.new_unicode_decode_error(format!("mbcs_decode failed: {e}")))?; return Ok((s, len)); } // Strict mode succeeded - no invalid characters - let mut buffer = vec![0u16; size as usize]; - let result = unsafe { - MultiByteToWideChar( - CP_ACP, - MB_ERR_INVALID_CHARS, - data.as_ptr().cast(), - len as i32, - buffer.as_mut_ptr(), - size, - ) - }; - if result == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("mbcs_decode failed: {err}"))); - } - buffer.truncate(result as usize); + let size = size.unwrap(); + let mut buffer = vec![0u16; size]; + let result = host_windows::multi_byte_to_wide( + host_windows::CP_ACP, + host_windows::MB_ERR_INVALID_CHARS, + data.as_ref(), + &mut buffer, + ) + .map_err(|err| vm.new_os_error(format!("mbcs_decode failed: {err}")))?; + buffer.truncate(result); let s = String::from_utf16(&buffer) .map_err(|e| vm.new_unicode_decode_error(format!("mbcs_decode failed: {e}")))?; @@ -573,9 +512,6 @@ mod _codecs_windows { #[pyfunction] fn oem_encode(args: OemEncodeArgs, vm: &VirtualMachine) -> PyResult<(Vec, usize)> { use crate::host_env::windows::ToWideString; - use 
windows_sys::Win32::Globalization::{ - CP_OEMCP, WC_NO_BEST_FIT_CHARS, WideCharToMultiByte, - }; let errors = args.errors.as_ref().map_or("strict", |s| s.as_str()); let s = match args.s.to_str() { @@ -597,56 +533,31 @@ mod _codecs_windows { let wide: Vec = std::ffi::OsStr::new(s).to_wide(); // Get the required buffer size - let size = unsafe { - WideCharToMultiByte( - CP_OEMCP, - WC_NO_BEST_FIT_CHARS, - wide.as_ptr(), - wide.len() as i32, - core::ptr::null_mut(), - 0, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - - if size == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("oem_encode failed: {err}"))); - } - - let mut buffer = vec![0u8; size as usize]; - let mut used_default_char: i32 = 0; - - let result = unsafe { - WideCharToMultiByte( - CP_OEMCP, - WC_NO_BEST_FIT_CHARS, - wide.as_ptr(), - wide.len() as i32, - buffer.as_mut_ptr().cast(), - size, - core::ptr::null(), - if errors == "strict" { - &mut used_default_char - } else { - core::ptr::null_mut() - }, - ) - }; - - if result == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("oem_encode failed: {err}"))); - } - - if errors == "strict" && used_default_char != 0 { + let (size, _) = host_windows::wide_char_to_multi_byte_len( + host_windows::CP_OEMCP, + host_windows::WC_NO_BEST_FIT_CHARS, + &wide, + false, + ) + .map_err(|err| vm.new_os_error(format!("oem_encode failed: {err}")))?; + + let mut buffer = vec![0u8; size]; + let (result, used_default_char) = host_windows::wide_char_to_multi_byte( + host_windows::CP_OEMCP, + host_windows::WC_NO_BEST_FIT_CHARS, + &wide, + &mut buffer, + errors == "strict", + ) + .map_err(|err| vm.new_os_error(format!("oem_encode failed: {err}")))?; + + if errors == "strict" && used_default_char { return Err(vm.new_unicode_encode_error( "'oem' codec can't encode characters: invalid character", )); } - buffer.truncate(result as usize); + buffer.truncate(result); Ok((buffer, char_len)) } @@ -663,10 
+574,6 @@ mod _codecs_windows { #[pyfunction] fn oem_decode(args: OemDecodeArgs, vm: &VirtualMachine) -> PyResult<(String, usize)> { - use windows_sys::Win32::Globalization::{ - CP_OEMCP, MB_ERR_INVALID_CHARS, MultiByteToWideChar, - }; - let _errors = args.errors.as_ref().map_or("strict", |s| s.as_str()); let data = args.data.borrow_buf(); let len = data.len(); @@ -676,72 +583,43 @@ mod _codecs_windows { } // Get the required buffer size for UTF-16 - let size = unsafe { - MultiByteToWideChar( - CP_OEMCP, - MB_ERR_INVALID_CHARS, - data.as_ptr().cast(), - len as i32, - core::ptr::null_mut(), - 0, - ) - }; + let size = host_windows::multi_byte_to_wide_len( + host_windows::CP_OEMCP, + host_windows::MB_ERR_INVALID_CHARS, + data.as_ref(), + ); - if size == 0 { + if size.is_err() { // Try without MB_ERR_INVALID_CHARS for non-strict mode (replacement behavior) - let size = unsafe { - MultiByteToWideChar( - CP_OEMCP, - 0, - data.as_ptr().cast(), - len as i32, - core::ptr::null_mut(), - 0, - ) - }; - if size == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("oem_decode failed: {err}"))); - } + let size = + host_windows::multi_byte_to_wide_len(host_windows::CP_OEMCP, 0, data.as_ref()) + .map_err(|err| vm.new_os_error(format!("oem_decode failed: {err}")))?; - let mut buffer = vec![0u16; size as usize]; - let result = unsafe { - MultiByteToWideChar( - CP_OEMCP, - 0, - data.as_ptr().cast(), - len as i32, - buffer.as_mut_ptr(), - size, - ) - }; - if result == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("oem_decode failed: {err}"))); - } - buffer.truncate(result as usize); + let mut buffer = vec![0u16; size]; + let result = host_windows::multi_byte_to_wide( + host_windows::CP_OEMCP, + 0, + data.as_ref(), + &mut buffer, + ) + .map_err(|err| vm.new_os_error(format!("oem_decode failed: {err}")))?; + buffer.truncate(result); let s = String::from_utf16(&buffer) .map_err(|e| 
vm.new_unicode_decode_error(format!("oem_decode failed: {e}")))?; return Ok((s, len)); } // Strict mode succeeded - no invalid characters - let mut buffer = vec![0u16; size as usize]; - let result = unsafe { - MultiByteToWideChar( - CP_OEMCP, - MB_ERR_INVALID_CHARS, - data.as_ptr().cast(), - len as i32, - buffer.as_mut_ptr(), - size, - ) - }; - if result == 0 { - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("oem_decode failed: {err}"))); - } - buffer.truncate(result as usize); + let size = size.unwrap(); + let mut buffer = vec![0u16; size]; + let result = host_windows::multi_byte_to_wide( + host_windows::CP_OEMCP, + host_windows::MB_ERR_INVALID_CHARS, + data.as_ref(), + &mut buffer, + ) + .map_err(|err| vm.new_os_error(format!("oem_decode failed: {err}")))?; + buffer.truncate(result); let s = String::from_utf16(&buffer) .map_err(|e| vm.new_unicode_decode_error(format!("oem_decode failed: {e}")))?; @@ -768,17 +646,12 @@ mod _codecs_windows { /// Get WideCharToMultiByte flags for encoding. /// Matches encode_code_page_flags() in CPython. 
fn encode_code_page_flags(code_page: u32, errors: &str) -> u32 { - use windows_sys::Win32::Globalization::{WC_ERR_INVALID_CHARS, WC_NO_BEST_FIT_CHARS}; - if code_page == 65001 { - // CP_UTF8 - WC_ERR_INVALID_CHARS - } else if code_page == 65000 { - // CP_UTF7 only supports flags=0 - 0 - } else if errors == "replace" { + if code_page == host_windows::CP_UTF8 { + host_windows::WC_ERR_INVALID_CHARS + } else if code_page == host_windows::CP_UTF7 || errors == "replace" { 0 } else { - WC_NO_BEST_FIT_CHARS + host_windows::WC_NO_BEST_FIT_CHARS } } @@ -790,80 +663,56 @@ mod _codecs_windows { wide: &[u16], vm: &VirtualMachine, ) -> PyResult>> { - use windows_sys::Win32::Globalization::WideCharToMultiByte; - let flags = encode_code_page_flags(code_page, "strict"); - let use_default_char = code_page != 65001 && code_page != 65000; - let mut used_default_char: i32 = 0; - let pused = if use_default_char { - &mut used_default_char as *mut i32 - } else { - core::ptr::null_mut() - }; - - let size = unsafe { - WideCharToMultiByte( - code_page, - flags, - wide.as_ptr(), - wide.len() as i32, - core::ptr::null_mut(), - 0, - core::ptr::null(), - pused, - ) - }; - - if size <= 0 { - let err_code = std::io::Error::last_os_error().raw_os_error().unwrap_or(0); - if err_code == 1113 { - // ERROR_NO_UNICODE_TRANSLATION - return Ok(None); + let use_default_char = + code_page != host_windows::CP_UTF8 && code_page != host_windows::CP_UTF7; + + let size = match host_windows::wide_char_to_multi_byte_len( + code_page, + flags, + wide, + use_default_char, + ) { + Ok((size, used_default_char)) => { + if use_default_char && used_default_char { + return Ok(None); + } + size + } + Err(err) => { + let err_code = err.raw_os_error().unwrap_or(0); + if err_code == host_windows::ERROR_NO_UNICODE_TRANSLATION_I32 { + return Ok(None); + } + return Err(vm.new_os_error(format!("code_page_encode: {err}"))); } - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("code_page_encode: 
{err}"))); - } - - if use_default_char && used_default_char != 0 { - return Ok(None); - } - - let mut buffer = vec![0u8; size as usize]; - used_default_char = 0; - let pused = if use_default_char { - &mut used_default_char as *mut i32 - } else { - core::ptr::null_mut() - }; - - let result = unsafe { - WideCharToMultiByte( - code_page, - flags, - wide.as_ptr(), - wide.len() as i32, - buffer.as_mut_ptr().cast(), - size, - core::ptr::null(), - pused, - ) }; - if result <= 0 { - let err_code = std::io::Error::last_os_error().raw_os_error().unwrap_or(0); - if err_code == 1113 { - return Ok(None); + let mut buffer = vec![0u8; size]; + let result = match host_windows::wide_char_to_multi_byte( + code_page, + flags, + wide, + &mut buffer, + use_default_char, + ) { + Ok((result, used_default_char)) => { + if use_default_char && used_default_char { + return Ok(None); + } + result } - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("code_page_encode: {err}"))); - } - - if use_default_char && used_default_char != 0 { - return Ok(None); - } + Err(err) => { + let err_code = err.raw_os_error().unwrap_or(0); + if err_code == host_windows::ERROR_NO_UNICODE_TRANSLATION_I32 { + return Ok(None); + } + return Err(vm.new_os_error(format!("code_page_encode: {err}"))); + } + }; - buffer.truncate(result as usize); + buffer.truncate(result); Ok(Some(buffer)) } @@ -876,11 +725,11 @@ mod _codecs_windows { vm: &VirtualMachine, ) -> PyResult<(Vec, usize)> { use crate::builtins::{PyBytes, PyStr, PyTuple}; - use windows_sys::Win32::Globalization::WideCharToMultiByte; let char_len = s.char_len(); let flags = encode_code_page_flags(code_page, errors); - let use_default_char = code_page != 65001 && code_page != 65000; + let use_default_char = + code_page != host_windows::CP_UTF8 && code_page != host_windows::CP_UTF7; let encoding_str = vm.ctx.new_str(encoding_name); let reason_str = vm.ctx.new_str("invalid character"); @@ -902,28 +751,19 @@ mod _codecs_windows { 
wchars[1] = ((ch - 0x10000) & 0x3FF) as u16 + 0xDC00; 2 }; - let mut used_default_char: i32 = 0; - let pused = if use_default_char { - &mut used_default_char as *mut i32 - } else { - core::ptr::null_mut() - }; - let outsize = unsafe { - WideCharToMultiByte( - code_page, - flags, - wchars.as_ptr(), - wchar_len, - core::ptr::null_mut(), - 0, - core::ptr::null(), - pused, - ) - }; - if outsize <= 0 || (use_default_char && used_default_char != 0) { - break; + match host_windows::wide_char_to_multi_byte_len( + code_page, + flags, + &wchars[..wchar_len], + use_default_char, + ) { + Ok((_outsize, used_default_char)) + if !use_default_char || !used_default_char => + { + fail_pos += 1; + } + _ => break, } - fail_pos += 1; } return Err(vm.new_unicode_encode_error_real( encoding_str, @@ -961,29 +801,16 @@ mod _codecs_windows { } if !is_surrogate { - let mut used_default_char: i32 = 0; - let pused = if use_default_char { - &mut used_default_char as *mut i32 - } else { - core::ptr::null_mut() - }; - let mut buf = [0u8; 8]; - let outsize = unsafe { - WideCharToMultiByte( - code_page, - flags, - wchars.as_ptr(), - wchar_len, - buf.as_mut_ptr().cast(), - buf.len() as i32, - core::ptr::null(), - pused, - ) - }; - - if outsize > 0 && (!use_default_char || used_default_char == 0) { - output.extend_from_slice(&buf[..outsize as usize]); + if let Ok((outsize, used_default_char)) = host_windows::wide_char_to_multi_byte( + code_page, + flags, + &wchars[..wchar_len], + &mut buf, + use_default_char, + ) && (!use_default_char || !used_default_char) + { + output.extend_from_slice(&buf[..outsize]); pos += 1; continue; } @@ -1096,51 +923,41 @@ mod _codecs_windows { data: &[u8], vm: &VirtualMachine, ) -> PyResult>> { - use windows_sys::Win32::Globalization::{MB_ERR_INVALID_CHARS, MultiByteToWideChar}; - - let mut flags = MB_ERR_INVALID_CHARS; + let mut flags = host_windows::MB_ERR_INVALID_CHARS; loop { - let size = unsafe { - MultiByteToWideChar( - code_page, - flags, - data.as_ptr().cast(), - 
data.len() as i32, - core::ptr::null_mut(), - 0, - ) + let size = match host_windows::multi_byte_to_wide_len(code_page, flags, data) { + Ok(size) => size, + Err(err) => { + let err_code = err.raw_os_error().unwrap_or(0); + if flags != 0 && err_code == host_windows::ERROR_INVALID_FLAGS_I32 { + flags = 0; + continue; + } + if err_code == host_windows::ERROR_NO_UNICODE_TRANSLATION_I32 { + return Ok(None); + } + return Err(vm.new_os_error(format!("code_page_decode: {err}"))); + } }; - if size > 0 { - let mut buffer = vec![0u16; size as usize]; - let result = unsafe { - MultiByteToWideChar( - code_page, - flags, - data.as_ptr().cast(), - data.len() as i32, - buffer.as_mut_ptr(), - size, - ) - }; - if result > 0 { - buffer.truncate(result as usize); + let mut buffer = vec![0u16; size]; + match host_windows::multi_byte_to_wide(code_page, flags, data, &mut buffer) { + Ok(result) => { + buffer.truncate(result); return Ok(Some(buffer)); } + Err(err) => { + let err_code = err.raw_os_error().unwrap_or(0); + if flags != 0 && err_code == host_windows::ERROR_INVALID_FLAGS_I32 { + flags = 0; + continue; + } + if err_code == host_windows::ERROR_NO_UNICODE_TRANSLATION_I32 { + return Ok(None); + } + return Err(vm.new_os_error(format!("code_page_decode: {err}"))); + } } - - let err_code = std::io::Error::last_os_error().raw_os_error().unwrap_or(0); - // ERROR_INVALID_FLAGS = 1004 - if flags != 0 && err_code == 1004 { - flags = 0; - continue; - } - // ERROR_NO_UNICODE_TRANSLATION = 1113 - if err_code == 1113 { - return Ok(None); - } - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("code_page_decode: {err}"))); } } @@ -1155,7 +972,6 @@ mod _codecs_windows { ) -> PyResult<(PyStrRef, usize)> { use crate::builtins::PyTuple; use crate::common::wtf8::Wtf8Buf; - use windows_sys::Win32::Globalization::{MB_ERR_INVALID_CHARS, MultiByteToWideChar}; let len = data.len(); let encoding_str = vm.ctx.new_str(encoding_name); @@ -1167,33 +983,37 @@ mod _codecs_windows { 
if errors == "strict" && is_final { // Find the exact failing byte position by trying byte by byte let mut fail_pos = 0; - let mut flags_s: u32 = MB_ERR_INVALID_CHARS; + let mut flags_s: u32 = host_windows::MB_ERR_INVALID_CHARS; let mut buf = [0u16; 2]; while fail_pos < len { let mut in_size = 1; let mut found = false; while in_size <= 4 && fail_pos + in_size <= len { - let outsize = unsafe { - MultiByteToWideChar( - code_page, - flags_s, - data[fail_pos..].as_ptr().cast(), - in_size as i32, - buf.as_mut_ptr(), - 2, - ) - }; - if outsize > 0 { - fail_pos += in_size; - found = true; - break; - } - let err_code = std::io::Error::last_os_error().raw_os_error().unwrap_or(0); - if err_code == 1004 && flags_s != 0 { - flags_s = 0; - continue; + match host_windows::multi_byte_to_wide( + code_page, + flags_s, + &data[fail_pos..fail_pos + in_size], + &mut buf, + ) { + Ok(_outsize) => { + fail_pos += in_size; + found = true; + break; + } + Err(err) => { + let err_code = err.raw_os_error().unwrap_or(0); + if err_code == host_windows::ERROR_INVALID_FLAGS_I32 && flags_s != 0 { + flags_s = 0; + continue; + } + in_size += 1; + if err_code != host_windows::ERROR_NO_UNICODE_TRANSLATION_I32 + && err_code != host_windows::ERROR_INSUFFICIENT_BUFFER_I32 + { + break; + } + } } - in_size += 1; } if !found { break; @@ -1222,46 +1042,46 @@ mod _codecs_windows { let mut wide_buf: Vec = Vec::new(); let mut pos = 0usize; - let mut flags: u32 = MB_ERR_INVALID_CHARS; + let mut flags: u32 = host_windows::MB_ERR_INVALID_CHARS; while pos < len { // Try to decode with increasing byte counts (1, 2, 3, 4) let mut in_size = 1; - let mut outsize; + let outsize; let mut buffer = [0u16; 2]; loop { - outsize = unsafe { - MultiByteToWideChar( - code_page, - flags, - data[pos..].as_ptr().cast(), - in_size as i32, - buffer.as_mut_ptr(), - 2, - ) - }; - if outsize > 0 { - break; - } - let err_code = std::io::Error::last_os_error().raw_os_error().unwrap_or(0); - if err_code == 1004 && flags != 0 { - // 
ERROR_INVALID_FLAGS - retry with flags=0 - flags = 0; - continue; - } - if err_code != 1113 && err_code != 122 { - // Not ERROR_NO_UNICODE_TRANSLATION and not ERROR_INSUFFICIENT_BUFFER - let err = std::io::Error::last_os_error(); - return Err(vm.new_os_error(format!("code_page_decode: {err}"))); - } - in_size += 1; - if in_size > 4 || pos + in_size > len { - break; + match host_windows::multi_byte_to_wide( + code_page, + flags, + &data[pos..pos + in_size], + &mut buffer, + ) { + Ok(size) => { + outsize = size; + break; + } + Err(err) => { + let err_code = err.raw_os_error().unwrap_or(0); + if err_code == host_windows::ERROR_INVALID_FLAGS_I32 && flags != 0 { + flags = 0; + continue; + } + if err_code != host_windows::ERROR_NO_UNICODE_TRANSLATION_I32 + && err_code != host_windows::ERROR_INSUFFICIENT_BUFFER_I32 + { + return Err(vm.new_os_error(format!("code_page_decode: {err}"))); + } + in_size += 1; + if in_size > 4 || pos + in_size > len { + outsize = 0; + break; + } + } } } - if outsize <= 0 { + if outsize == 0 { // Can't decode this byte sequence if pos + in_size >= len && !is_final { // Incomplete sequence at end, not final - stop here @@ -1348,7 +1168,7 @@ mod _codecs_windows { } } else { // Successfully decoded - wide_buf.extend_from_slice(&buffer[..outsize as usize]); + wide_buf.extend_from_slice(&buffer[..outsize]); pos += in_size; } } @@ -1374,23 +1194,24 @@ mod _codecs_windows { let is_final = args.r#final; if data.is_empty() { - return Ok((vm.ctx.empty_str.to_owned(), 0)); + return Ok((vm.ctx.new_str(""), 0)); } let encoding_name = code_page_encoding_name(code_page); - // Fast path: try to decode the whole buffer with strict flags - match try_decode_code_page_strict(code_page, &data, vm)? 
{ - Some(wide) => { - let s = Wtf8Buf::from_wide(&wide); - return Ok((vm.ctx.new_str(s), data.len())); - } - None => { - // Decode error - fall through to slow path - } + // Fast path: try decoding the whole buffer at once + if let Some(wide) = try_decode_code_page_strict(code_page, data.as_ref(), vm)? { + let s = Wtf8Buf::from_wide(&wide); + return Ok((vm.ctx.new_str(s), data.len())); } - // Slow path: byte by byte with error handling - decode_code_page_errors(code_page, &data, errors, is_final, &encoding_name, vm) + decode_code_page_errors( + code_page, + data.as_ref(), + errors, + is_final, + &encoding_name, + vm, + ) } } diff --git a/crates/vm/src/stdlib/_ctypes.rs b/crates/vm/src/stdlib/_ctypes.rs index 1a3f454524a..cc87ebde572 100644 --- a/crates/vm/src/stdlib/_ctypes.rs +++ b/crates/vm/src/stdlib/_ctypes.rs @@ -3,7 +3,6 @@ mod array; mod base; mod function; -mod library; mod pointer; mod simple; mod structure; @@ -15,12 +14,6 @@ use crate::{ class::PyClassImpl, types::TypeDataRef, }; -use core::ffi::{ - c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, - c_ulonglong, c_ushort, -}; -use core::mem; -use widestring::WideChar; pub(super) use array::PyCArray; pub(super) use base::{FfiArgValue, PyCData, PyCField, StgInfo, StgInfoFlags}; @@ -97,145 +90,8 @@ pub(crate) use _ctypes::module_def; // These check if an object's type's metaclass is a subclass of a specific metaclass -/// Size of long double - platform dependent -/// x86_64 macOS/Linux: 16 bytes (80-bit extended + padding) -/// ARM64: 16 bytes (128-bit) -/// Windows: 8 bytes (same as double) -#[cfg(all( - any(target_arch = "x86_64", target_arch = "aarch64"), - not(target_os = "windows") -))] -const LONG_DOUBLE_SIZE: usize = 16; - -#[cfg(target_os = "windows")] -const LONG_DOUBLE_SIZE: usize = mem::size_of::(); - -#[cfg(not(any( - all( - any(target_arch = "x86_64", target_arch = "aarch64"), - not(target_os = "windows") - ), - target_os = "windows" -)))] -const 
LONG_DOUBLE_SIZE: usize = mem::size_of::(); - -/// Type information for ctypes simple types -struct TypeInfo { - pub size: usize, - pub ffi_type_fn: fn() -> libffi::middle::Type, -} - -/// Get type information (size and ffi_type) for a ctypes type code -fn type_info(ty: &str) -> Option { - use libffi::middle::Type; - match ty { - "c" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::u8, - }), - "u" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: if mem::size_of::() == 2 { - Type::u16 - } else { - Type::u32 - }, - }), - "b" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::i8, - }), - "B" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::u8, - }), - "h" | "v" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::i16, - }), - "H" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::u16, - }), - "i" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::i32, - }), - "I" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::u32, - }), - "l" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: if mem::size_of::() == 8 { - Type::i64 - } else { - Type::i32 - }, - }), - "L" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: if mem::size_of::() == 8 { - Type::u64 - } else { - Type::u32 - }, - }), - "q" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::i64, - }), - "Q" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::u64, - }), - "f" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::f32, - }), - "d" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::f64, - }), - "g" => Some(TypeInfo { - // long double - platform dependent size - // x86_64 macOS/Linux: 16 bytes (80-bit extended + padding) - // ARM64: 16 bytes (128-bit) - // Windows: 8 bytes (same as double) - // Note: Use f64 as FFI type since Rust doesn't support long double natively - size: LONG_DOUBLE_SIZE, - ffi_type_fn: Type::f64, - }), 
- "?" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::u8, - }), - "z" | "Z" | "P" | "X" | "O" => Some(TypeInfo { - size: mem::size_of::(), - ffi_type_fn: Type::pointer, - }), - "void" => Some(TypeInfo { - size: 0, - ffi_type_fn: Type::void, - }), - _ => None, - } -} - -/// Get size for a ctypes type code -fn get_size(ty: &str) -> usize { - type_info(ty).map(|t| t.size).expect("invalid type code") -} - -/// Get alignment for simple type codes from type_info(). -/// For primitive C types (c_int, c_long, etc.), alignment equals size. -fn get_align(ty: &str) -> usize { - get_size(ty) -} - #[pymodule] pub(crate) mod _ctypes { - use super::library; use super::{PyCArray, PyCData, PyCPointer, PyCSimple, PyCStructure, PyCUnion}; use crate::builtins::{PyType, PyTypeRef}; use crate::class::StaticType; @@ -280,10 +136,16 @@ pub(crate) mod _ctypes { b'b' | b'h' | b'i' | b'l' | b'q' => { // Signed integers let n = match zelf.value { - FfiArgValue::I8(v) => v as i64, - FfiArgValue::I16(v) => v as i64, - FfiArgValue::I32(v) => v as i64, - FfiArgValue::I64(v) => v, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::I8(v)) => { + v as i64 + } + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::I16(v)) => { + v as i64 + } + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::I32(v)) => { + v as i64 + } + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::I64(v)) => v, _ => 0, }; Ok(format!("")) @@ -291,25 +153,35 @@ pub(crate) mod _ctypes { b'B' | b'H' | b'I' | b'L' | b'Q' => { // Unsigned integers let n = match zelf.value { - FfiArgValue::U8(v) => v as u64, - FfiArgValue::U16(v) => v as u64, - FfiArgValue::U32(v) => v as u64, - FfiArgValue::U64(v) => v, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::U8(v)) => { + v as u64 + } + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::U16(v)) => { + v as u64 + } + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::U32(v)) => { + v as u64 + } + 
FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::U64(v)) => v, _ => 0, }; Ok(format!("")) } b'f' => { let v = match zelf.value { - FfiArgValue::F32(v) => v as f64, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::F32(v)) => { + v as f64 + } _ => 0.0, }; Ok(format!("")) } b'd' | b'g' => { let v = match zelf.value { - FfiArgValue::F64(v) => v, - FfiArgValue::F32(v) => v as f64, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::F64(v)) => v, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::F32(v)) => { + v as f64 + } _ => 0.0, }; Ok(format!("")) @@ -317,8 +189,10 @@ pub(crate) mod _ctypes { b'c' => { // c_char - single byte let byte = match zelf.value { - FfiArgValue::I8(v) => v as u8, - FfiArgValue::U8(v) => v, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::I8(v)) => { + v as u8 + } + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::U8(v)) => v, _ => 0, }; if is_literal_char(byte) { @@ -330,7 +204,8 @@ pub(crate) mod _ctypes { b'z' | b'Z' | b'P' | b'V' => { // Pointer types let ptr = match zelf.value { - FfiArgValue::Pointer(v) => v, + FfiArgValue::Scalar(rustpython_host_env::ctypes::FfiValue::Pointer(v)) => v, + FfiArgValue::OwnedPointer(v, _) => v, _ => 0, }; if ptr == 0 { @@ -365,14 +240,14 @@ pub(crate) mod _ctypes { // TODO: get properly #[pyattr] - const RTLD_LOCAL: i32 = 0; + const RTLD_LOCAL: i32 = rustpython_host_env::ctypes::RTLD_LOCAL; // TODO: get properly #[pyattr] - const RTLD_GLOBAL: i32 = 0; + const RTLD_GLOBAL: i32 = rustpython_host_env::ctypes::RTLD_GLOBAL; #[pyattr] - const SIZEOF_TIME_T: usize = core::mem::size_of::(); + const SIZEOF_TIME_T: usize = rustpython_host_env::ctypes::SIZEOF_TIME_T; #[pyattr] const CTYPES_MAX_ARGCOUNT: usize = 1024; @@ -518,13 +393,16 @@ pub(crate) mod _ctypes { if let Ok(type_attr) = type_obj.as_object().get_attr("_type_", vm) && let Ok(type_str) = type_attr.str(vm) { - return Ok(super::get_size(type_str.as_ref())); + return Ok( + 
rustpython_host_env::ctypes::simple_type_size(type_str.as_ref()) + .expect("invalid ctypes simple type"), + ); } - return Ok(core::mem::size_of::()); + return Ok(rustpython_host_env::ctypes::pointer_size()); } // Pointer types if type_obj.fast_issubclass(PyCPointer::static_type()) { - return Ok(core::mem::size_of::()); + return Ok(rustpython_host_env::ctypes::pointer_size()); } return Err(vm.new_type_error("this type has no size")); } @@ -535,7 +413,7 @@ pub(crate) mod _ctypes { return Ok(cdata.size()); } if obj.fast_isinstance(PyCPointer::static_type()) { - return Ok(core::mem::size_of::()); + return Ok(rustpython_host_env::ctypes::pointer_size()); } Err(vm.new_type_error("this type has no size")) @@ -546,14 +424,11 @@ pub(crate) mod _ctypes { fn load_library_windows( name: String, _load_flags: OptionalArg, - vm: &VirtualMachine, + _vm: &VirtualMachine, ) -> usize { // TODO: audit functions first // TODO: load_flags - let cache = library::libcache(); - let mut cache_write = cache.write(); - let (id, _) = cache_write.get_or_insert_lib(&name, vm).unwrap(); - id + rustpython_host_env::ctypes::open_library(&name).unwrap() } #[cfg(not(windows))] @@ -563,58 +438,38 @@ pub(crate) mod _ctypes { load_flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - // Default mode: RTLD_NOW | RTLD_LOCAL, always force RTLD_NOW - let mode = load_flags.unwrap_or(libc::RTLD_NOW | libc::RTLD_LOCAL) | libc::RTLD_NOW; + let mode = rustpython_host_env::ctypes::dlopen_mode(load_flags.into_option()); match name { Some(name) => { - let cache = library::libcache(); - let mut cache_write = cache.write(); let os_str = name.as_os_str(vm)?; - let (id, _) = cache_write - .get_or_insert_lib_with_mode(&*os_str, mode, vm) - .map_err(|e| { - let name_str = os_str.to_string_lossy(); - vm.new_os_error(format!("{name_str}: {e}")) - })?; - Ok(id) + rustpython_host_env::ctypes::open_library_with_mode(&*os_str, mode).map_err(|e| { + let name_str = os_str.to_string_lossy(); + 
vm.new_os_error(format!("{name_str}: {e}")) + }) } None => { // dlopen(NULL, mode) to get the current process handle (for pythonapi) - let handle = unsafe { libc::dlopen(core::ptr::null(), mode) }; - if handle.is_null() { - let err = unsafe { libc::dlerror() }; - let msg = if err.is_null() { - "dlopen() error" - } else { - unsafe { &core::ffi::CStr::from_ptr(err).to_string_lossy() } - }; - return Err(vm.new_os_error(msg)); - } + let handle = rustpython_host_env::ctypes::dlopen_self(mode) + .map_err(|msg| vm.new_os_error(msg))?; // Add to library cache so symbol lookup works - let cache = library::libcache(); - let mut cache_write = cache.write(); - let id = cache_write.insert_raw_handle(handle); - Ok(id) + Ok(rustpython_host_env::ctypes::insert_raw_library_handle( + handle, + )) } } } #[pyfunction(name = "FreeLibrary")] fn free_library(handle: usize) { - let cache = library::libcache(); - let mut cache_write = cache.write(); - cache_write.drop_lib(handle); + rustpython_host_env::ctypes::drop_library(handle); } #[cfg(not(windows))] #[pyfunction] fn dlclose(handle: usize, _vm: &VirtualMachine) { - // Remove from cache, which triggers SharedLibrary drop. - // libloading::Library calls dlclose automatically on Drop. - let cache = library::libcache(); - let mut cache_write = cache.write(); - cache_write.drop_lib(handle); + // Remove from the host_env cache. The underlying library is closed on Drop. 
+ rustpython_host_env::ctypes::drop_library(handle); } #[cfg(not(windows))] @@ -626,29 +481,8 @@ pub(crate) mod _ctypes { ) -> PyResult { let symbol_name = alloc::ffi::CString::new(name.as_str()) .map_err(|_| vm.new_value_error("symbol name contains null byte"))?; - - // Clear previous error - unsafe { libc::dlerror() }; - - let ptr = unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) }; - - // Check for error via dlerror first - let err = unsafe { libc::dlerror() }; - if !err.is_null() { - let msg = unsafe { - core::ffi::CStr::from_ptr(err) - .to_string_lossy() - .into_owned() - }; - return Err(vm.new_os_error(msg)); - } - - // Treat NULL symbol address as error - // This handles cases like GNU IFUNCs that resolve to NULL - if ptr.is_null() { - return Err(vm.new_os_error(format!("symbol '{}' not found", name.as_str()))); - } - + let ptr = rustpython_host_env::ctypes::dlsym_checked(handle, symbol_name.as_c_str()) + .map_err(|msg| vm.new_os_error(msg))?; Ok(ptr as usize) } @@ -784,10 +618,10 @@ pub(crate) mod _ctypes { // Get buffer address: (char *)((CDataObject *)obj)->b_ptr + offset let ptr_val = if let Some(simple) = obj.downcast_ref::() { let buffer = simple.0.buffer.read(); - (buffer.as_ptr() as isize + offset_val) as usize + rustpython_host_env::ctypes::offset_address(buffer.as_ptr() as usize, offset_val) } else if let Some(cdata) = obj.downcast_ref::() { let buffer = cdata.buffer.read(); - (buffer.as_ptr() as isize + offset_val) as usize + rustpython_host_env::ctypes::offset_address(buffer.as_ptr() as usize, offset_val) } else { 0 }; @@ -795,7 +629,7 @@ pub(crate) mod _ctypes { // Create CArgObject to hold the reference Ok(CArgObject { tag: b'P', - value: FfiArgValue::Pointer(ptr_val), + value: FfiArgValue::pointer(ptr_val), obj, size: 0, offset: offset_val, @@ -866,7 +700,8 @@ pub(crate) mod _ctypes { if let Ok(s) = type_attr.str(vm) { let ty = s.to_string(); if ty.len() == 1 && super::simple::SIMPLE_TYPE_CHARS.contains(ty.as_str()) { - 
return Ok(super::get_align(&ty)); + return Ok(rustpython_host_env::ctypes::simple_type_align(&ty) + .expect("invalid ctypes simple type")); } } } @@ -937,46 +772,43 @@ pub(crate) mod _ctypes { let new_size = size as usize; let mut buffer = cdata.buffer.write(); let old_data = buffer.to_vec(); - let mut new_data = vec![0u8; new_size]; - let copy_len = old_data.len().min(new_size); - new_data[..copy_len].copy_from_slice(&old_data[..copy_len]); - *buffer = Cow::Owned(new_data); + *buffer = Cow::Owned(rustpython_host_env::ctypes::resize_owned_bytes( + &old_data, new_size, + )); Ok(()) } #[pyfunction] fn get_errno() -> i32 { - super::function::get_errno_value() + rustpython_host_env::ctypes::get_errno() } #[pyfunction] fn set_errno(value: i32) -> i32 { - super::function::set_errno_value(value) + rustpython_host_env::ctypes::set_errno(value) } #[cfg(windows)] #[pyfunction] fn get_last_error() -> u32 { - super::function::get_last_error_value() + rustpython_host_env::ctypes::get_last_error() } #[cfg(windows)] #[pyfunction] fn set_last_error(value: u32) -> u32 { - super::function::set_last_error_value(value) + rustpython_host_env::ctypes::set_last_error(value) } #[pyattr] fn _memmove_addr(_vm: &VirtualMachine) -> usize { - let f = libc::memmove; - f as *const () as usize + rustpython_host_env::ctypes::memmove_addr() } #[pyattr] fn _memset_addr(_vm: &VirtualMachine) -> usize { - let f = libc::memset; - f as *const () as usize + rustpython_host_env::ctypes::memset_addr() } #[pyattr] @@ -1092,32 +924,24 @@ pub(crate) mod _ctypes { _flags: u32, vm: &VirtualMachine, ) -> PyResult { - use libffi::middle::{Arg, Cif, CodePtr, Type}; - if func_addr == 0 { return Err(vm.new_value_error("NULL function pointer")); } - let mut ffi_args: Vec> = Vec::with_capacity(args.len()); - let mut arg_values: Vec = Vec::with_capacity(args.len()); - let mut arg_types: Vec = Vec::with_capacity(args.len()); + let mut call_args = Vec::with_capacity(args.len()); for arg in args.iter() { if 
vm.is_none(arg) { - arg_values.push(0); - arg_types.push(Type::pointer()); + call_args.push(rustpython_host_env::ctypes::CdeclArgValue::Pointer(0)); } else if let Ok(int_val) = arg.try_int(vm) { let val = int_val.as_bigint().to_i64().unwrap_or(0) as isize; - arg_values.push(val); - arg_types.push(Type::isize()); + call_args.push(rustpython_host_env::ctypes::CdeclArgValue::Int(val)); } else if let Some(bytes) = arg.downcast_ref::() { let ptr = bytes.as_bytes().as_ptr() as isize; - arg_values.push(ptr); - arg_types.push(Type::pointer()); + call_args.push(rustpython_host_env::ctypes::CdeclArgValue::Pointer(ptr)); } else if let Some(s) = arg.downcast_ref::() { let ptr = s.as_bytes().as_ptr() as isize; - arg_values.push(ptr); - arg_types.push(Type::pointer()); + call_args.push(rustpython_host_env::ctypes::CdeclArgValue::Pointer(ptr)); } else { return Err(vm.new_type_error(format!( "Don't know how to convert parameter of type '{}'", @@ -1126,13 +950,7 @@ pub(crate) mod _ctypes { } } - for val in &arg_values { - ffi_args.push(Arg::new(val)); - } - - let cif = Cif::new(arg_types, Type::c_int()); - let code_ptr = CodePtr::from_ptr(func_addr as *const _); - let result: libc::c_int = unsafe { cif.call(code_ptr, &ffi_args) }; + let result = rustpython_host_env::ctypes::call_cdecl_i32_values(func_addr, &call_args); Ok(vm.ctx.new_int(result).into()) } @@ -1168,83 +986,41 @@ pub(crate) mod _ctypes { path: Option, vm: &VirtualMachine, ) -> PyResult { - use alloc::ffi::CString; - let path = match path { Some(p) if !vm.is_none(&p) => p, _ => return Ok(false), }; let path_str = path.str(vm)?.to_string(); - let c_path = - CString::new(path_str).map_err(|_| vm.new_value_error("path contains null byte"))?; - - unsafe extern "C" { - fn _dyld_shared_cache_contains_path(path: *const libc::c_char) -> bool; - } - - let result = unsafe { _dyld_shared_cache_contains_path(c_path.as_ptr()) }; - Ok(result) + rustpython_host_env::ctypes::dyld_shared_cache_contains_path(&path_str) + .map_err(|_| 
vm.new_value_error("path contains null byte")) } #[cfg(windows)] #[pyfunction(name = "FormatError")] fn format_error_func(code: OptionalArg, _vm: &VirtualMachine) -> String { - use windows_sys::Win32::Foundation::{GetLastError, LocalFree}; - use windows_sys::Win32::System::Diagnostics::Debug::{ - FORMAT_MESSAGE_ALLOCATE_BUFFER, FORMAT_MESSAGE_FROM_SYSTEM, - FORMAT_MESSAGE_IGNORE_INSERTS, FormatMessageW, - }; - - let error_code = code.unwrap_or_else(|| unsafe { GetLastError() }); - - let mut buffer: *mut u16 = core::ptr::null_mut(); - let len = unsafe { - FormatMessageW( - FORMAT_MESSAGE_ALLOCATE_BUFFER - | FORMAT_MESSAGE_FROM_SYSTEM - | FORMAT_MESSAGE_IGNORE_INSERTS, - core::ptr::null(), - error_code, - 0, - &mut buffer as *mut *mut u16 as *mut u16, - 0, - core::ptr::null(), - ) - }; - - if len == 0 || buffer.is_null() { - return "".to_string(); - } - - unsafe { - let slice = core::slice::from_raw_parts(buffer, len as usize); - let msg = String::from_utf16_lossy(slice).trim_end().to_string(); - LocalFree(buffer as *mut _); - msg - } + rustpython_host_env::ctypes::format_error_message(code.into_option()) + .unwrap_or_else(|| "".to_string()) } #[cfg(windows)] #[pyfunction(name = "CopyComPointer")] fn copy_com_pointer(src: PyObjectRef, dst: PyObjectRef, vm: &VirtualMachine) -> i32 { - use windows_sys::Win32::Foundation::{E_POINTER, S_OK}; - // 1. Extract pointer-to-pointer address from dst (byref() result) let pdst: usize = if let Some(carg) = dst.downcast_ref::() { // byref() result: object buffer address + offset let base = if let Some(cdata) = carg.obj.downcast_ref::() { cdata.buffer.read().as_ptr() as usize } else { - return E_POINTER; + return rustpython_host_env::ctypes::HRESULT_E_POINTER; }; (base as isize + carg.offset) as usize } else { - return E_POINTER; + return rustpython_host_env::ctypes::HRESULT_E_POINTER; }; if pdst == 0 { - return E_POINTER; + return rustpython_host_env::ctypes::HRESULT_E_POINTER; } // 2. 
Extract COM pointer value from src @@ -1253,38 +1029,12 @@ pub(crate) mod _ctypes { } else if let Some(cdata) = src.downcast_ref::() { // c_void_p etc: read pointer value from buffer let buffer = cdata.buffer.read(); - if buffer.len() >= core::mem::size_of::() { - usize::from_ne_bytes( - buffer[..core::mem::size_of::()] - .try_into() - .unwrap_or([0; core::mem::size_of::()]), - ) - } else { - 0 - } + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) } else { - return E_POINTER; + return rustpython_host_env::ctypes::HRESULT_E_POINTER; }; - // 3. Call IUnknown::AddRef if src is non-NULL - if src_ptr != 0 { - unsafe { - // IUnknown vtable: [QueryInterface, AddRef, Release, ...] - let iunknown = src_ptr as *mut *const usize; - let vtable = *iunknown; - debug_assert!(!vtable.is_null(), "IUnknown vtable is null"); - let addref_fn: extern "system" fn(*mut core::ffi::c_void) -> u32 = - core::mem::transmute(*vtable.add(1)); // AddRef is index 1 - addref_fn(src_ptr as *mut core::ffi::c_void); - } - } - - // 4. 
Copy pointer: *pdst = src - unsafe { - *(pdst as *mut usize) = src_ptr; - } - - S_OK + rustpython_host_env::ctypes::copy_com_pointer(src_ptr, pdst) } #[expect(clippy::unnecessary_wraps, reason = "Needs to comply with a signature")] diff --git a/crates/vm/src/stdlib/_ctypes/array.rs b/crates/vm/src/stdlib/_ctypes/array.rs index 4c7f8708679..f7abc834564 100644 --- a/crates/vm/src/stdlib/_ctypes/array.rs +++ b/crates/vm/src/stdlib/_ctypes/array.rs @@ -1,6 +1,5 @@ use super::StgInfo; use super::base::{CDATA_BUFFER_METHODS, PyCData}; -use super::type_info; use crate::common::lock::LazyLock; use crate::sliceable::SaturatedSliceIter; use crate::{ @@ -17,6 +16,13 @@ use crate::{ use alloc::borrow::Cow; use num_traits::{Signed, ToPrimitive}; +use rustpython_host_env::ctypes::{ + ArrayElementWriteValue, DecodedValue, WCHAR_SIZE, WCharArrayWriteError, char_array_field_value, + int_to_sized_bytes, read_array_element, simple_type_size, uint_to_sized_bytes, + wchar_from_bytes, write_array_element, write_char_array_raw, write_char_array_value, + write_wchar_array_value, wstring_from_bytes, zeroed_bytes, +}; + /// Get itemsize from a PEP 3118 format string /// Extracts the type code (last char after endianness prefix) and returns its size fn get_size_from_format(fmt: &str) -> usize { @@ -26,7 +32,7 @@ fn get_size_from_format(fmt: &str) -> usize { .chars() .next() .map(|c| c.to_string()); - code.map_or(1, |c| type_info(&c).map_or(1, |t| t.size)) + code.map_or(1, |c| simple_type_size(&c).unwrap_or(1)) } /// Creates array type for (element_type, length) @@ -520,72 +526,19 @@ impl PyCArray { vec![i.to_i8().unwrap_or(0) as u8] } } - 2 => { - if let Some(v) = i.to_u16() { - v.to_ne_bytes().to_vec() - } else { - i.to_i16().unwrap_or(0).to_ne_bytes().to_vec() - } - } - 4 => { - if let Some(v) = i.to_u32() { - v.to_ne_bytes().to_vec() - } else { - i.to_i32().unwrap_or(0).to_ne_bytes().to_vec() - } - } - 8 => { - if let Some(v) = i.to_u64() { - v.to_ne_bytes().to_vec() - } else { - 
i.to_i64().unwrap_or(0).to_ne_bytes().to_vec() - } - } - _ => vec![0u8; size], - } - } - - fn bytes_to_int( - bytes: &[u8], - size: usize, - type_code: Option<&str>, - vm: &VirtualMachine, - ) -> PyObjectRef { - // Unsigned type codes: B (uchar), H (ushort), I (uint), L (ulong), Q (ulonglong) - let is_unsigned = matches!(type_code, Some("B" | "H" | "I" | "L" | "Q")); - - match (size, is_unsigned) { - (1, false) => vm.ctx.new_int(bytes[0] as i8).into(), - (1, true) => vm.ctx.new_int(bytes[0]).into(), - (2, false) => { - let val = i16::from_ne_bytes([bytes[0], bytes[1]]); - vm.ctx.new_int(val).into() - } - (2, true) => { - let val = u16::from_ne_bytes([bytes[0], bytes[1]]); - vm.ctx.new_int(val).into() - } - (4, false) => { - let val = i32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); - vm.ctx.new_int(val).into() - } - (4, true) => { - let val = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); - vm.ctx.new_int(val).into() - } - (8, false) => { - let val = i64::from_ne_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - ]); - vm.ctx.new_int(val).into() - } - (8, true) => { - let val = u64::from_ne_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - ]); - vm.ctx.new_int(val).into() - } - _ => vm.ctx.new_int(0).into(), + 2 => i.to_u16().map_or_else( + || int_to_sized_bytes(i.to_i16().unwrap_or(0).into(), 2), + |v| uint_to_sized_bytes(v.into(), 2), + ), + 4 => i.to_u32().map_or_else( + || int_to_sized_bytes(i.to_i32().unwrap_or(0).into(), 4), + |v| uint_to_sized_bytes(v.into(), 4), + ), + 8 => i.to_u64().map_or_else( + || int_to_sized_bytes(i.to_i64().unwrap_or(0), 8), + |v| uint_to_sized_bytes(v, 8), + ), + _ => zeroed_bytes(size), } } @@ -630,120 +583,15 @@ impl PyCArray { type_code: Option<&str>, vm: &VirtualMachine, ) -> PyObjectRef { - match type_code { - Some("c") => { - // Return single byte as bytes - if offset < buffer.len() { - 
vm.ctx.new_bytes(vec![buffer[offset]]).into() - } else { - vm.ctx.new_bytes(vec![0]).into() - } - } - Some("u") => { - // Return single wchar as str - if let Some(code) = wchar_from_bytes(&buffer[offset..]) { - let s = char::from_u32(code) - .map(|c| c.to_string()) - .unwrap_or_default(); - vm.ctx.new_str(s).into() - } else { - vm.ctx.new_str("").into() - } - } - Some("z") => { - // c_char_p: pointer to bytes - dereference to get string - if offset + element_size > buffer.len() { - return vm.ctx.none(); - } - - let ptr_bytes = &buffer[offset..offset + element_size]; - let ptr_val = usize::from_ne_bytes( - ptr_bytes - .try_into() - .unwrap_or([0; core::mem::size_of::()]), - ); - - if ptr_val == 0 { - return vm.ctx.none(); - } - - // Read null-terminated string from pointer address - unsafe { - let ptr = ptr_val as *const u8; - let mut len = 0; - while *ptr.add(len) != 0 { - len += 1; - } - let bytes = core::slice::from_raw_parts(ptr, len); - vm.ctx.new_bytes(bytes.to_vec()).into() - } - } - Some("Z") => { - // c_wchar_p: pointer to wchar_t - dereference to get string - if offset + element_size > buffer.len() { - return vm.ctx.none(); - } - - let ptr_bytes = &buffer[offset..offset + element_size]; - let ptr_val = usize::from_ne_bytes( - ptr_bytes - .try_into() - .unwrap_or([0; core::mem::size_of::()]), - ); - - if ptr_val == 0 { - return vm.ctx.none(); - } - - // Read null-terminated wide string using WCHAR_SIZE - unsafe { - let ptr = ptr_val as *const u8; - let mut chars = Vec::new(); - let mut pos = 0usize; - loop { - let code = if WCHAR_SIZE == 2 { - let bytes = core::slice::from_raw_parts(ptr.add(pos), 2); - u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 - } else { - let bytes = core::slice::from_raw_parts(ptr.add(pos), 4); - u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) - }; - if code == 0 { - break; - } - if let Some(ch) = char::from_u32(code) { - chars.push(ch); - } - pos += WCHAR_SIZE; - } - - let s: String = chars.into_iter().collect(); - 
vm.ctx.new_str(s).into() - } - } - Some("f") => { - // c_float - let val = buffer[offset..] - .first_chunk::<4>() - .copied() - .map_or(0.0, f32::from_ne_bytes); - vm.ctx.new_float(val as f64).into() - } - Some("d" | "g") => { - // c_double / c_longdouble - read f64 from first 8 bytes - let val = buffer[offset..] - .first_chunk::<8>() - .copied() - .map_or(0.0, f64::from_ne_bytes); - vm.ctx.new_float(val).into() - } - _ => { - if let Some(bytes) = buffer[offset..].get(..element_size) { - Self::bytes_to_int(bytes, element_size, type_code, vm) - } else { - vm.ctx.new_int(0).into() - } - } + match read_array_element(buffer, offset, element_size, type_code) { + DecodedValue::Bytes(bytes) => vm.ctx.new_bytes(bytes).into(), + DecodedValue::String(value) => vm.ctx.new_str(value).into(), + DecodedValue::Float(value) => vm.ctx.new_float(value).into(), + DecodedValue::Signed(value) => vm.ctx.new_int(value).into(), + DecodedValue::Unsigned(value) => vm.ctx.new_int(value).into(), + DecodedValue::None => vm.ctx.none(), + DecodedValue::Pointer(value) => vm.ctx.new_int(value).into(), + DecodedValue::Bool(value) => vm.ctx.new_bool(value).into(), } } @@ -763,13 +611,17 @@ impl PyCArray { match type_code { Some("c") => { if let Some(b) = value.downcast_ref::() { - if offset < buffer.len() { - buffer[offset] = b.as_bytes().first().copied().unwrap_or(0); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Byte(b.as_bytes().first().copied().unwrap_or(0)), + ); } else if let Ok(int_val) = value.try_int(vm) { - if offset < buffer.len() { - buffer[offset] = int_val.as_bigint().to_u8().unwrap_or(0); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Byte(int_val.as_bigint().to_u8().unwrap_or(0)), + ); } else { return Err(vm.new_type_error("an integer or bytes of length 1 is required")); } @@ -777,9 +629,7 @@ impl PyCArray { Some("u") => { if let Some(s) = value.downcast_ref::() { let code = s.as_wtf8().code_points().next().map_or(0, |c| 
c.to_u32()); - if offset + WCHAR_SIZE <= buffer.len() { - wchar_to_bytes(code, &mut buffer[offset..]); - } + write_array_element(buffer, offset, ArrayElementWriteValue::Wchar(code)); } else { return Err(vm.new_type_error("unicode string expected")); } @@ -799,9 +649,14 @@ impl PyCArray { value.class().name() ))); }; - if offset + element_size <= buffer.len() { - buffer[offset..offset + element_size].copy_from_slice(&ptr_val.to_ne_bytes()); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Pointer { + value: ptr_val, + size: element_size, + }, + ); if let Some(c) = converted { return zelf.0.keep_ref(index, c, vm); } @@ -817,9 +672,14 @@ impl PyCArray { } else { return Err(vm.new_type_error("unicode string or integer address expected")); }; - if offset + element_size <= buffer.len() { - buffer[offset..offset + element_size].copy_from_slice(&ptr_val.to_ne_bytes()); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Pointer { + value: ptr_val, + size: element_size, + }, + ); if let Some(c) = converted { return zelf.0.keep_ref(index, c, vm); } @@ -833,9 +693,14 @@ impl PyCArray { } else { return Err(vm.new_type_error("a float is required")); }; - if offset + 4 <= buffer.len() { - buffer[offset..offset + 4].copy_from_slice(&f32_val.to_ne_bytes()); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Float { + value: f32_val.into(), + size: 4, + }, + ); } Some("d" | "g") => { // c_double / c_longdouble: convert int/float to f64 bytes @@ -846,25 +711,39 @@ impl PyCArray { } else { return Err(vm.new_type_error("a float is required")); }; - if offset + 8 <= buffer.len() { - buffer[offset..offset + 8].copy_from_slice(&f64_val.to_ne_bytes()); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Float { + value: f64_val, + size: 8, + }, + ); // For "g" type, remaining bytes stay zero } _ => { // Handle ctypes instances (copy their buffer) if let Some(cdata) = value.downcast_ref::() { let 
src_buffer = cdata.buffer.read(); - let copy_len = src_buffer.len().min(element_size); - if offset + copy_len <= buffer.len() { - buffer[offset..offset + copy_len].copy_from_slice(&src_buffer[..copy_len]); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Bytes { + bytes: &src_buffer, + size: element_size, + }, + ); // Other types: use int_to_bytes } else if let Ok(int_val) = value.try_int(vm) { let bytes = Self::int_to_bytes(int_val.as_bigint(), element_size); - if offset + element_size <= buffer.len() { - buffer[offset..offset + element_size].copy_from_slice(&bytes); - } + write_array_element( + buffer, + offset, + ArrayElementWriteValue::Bytes { + bytes: &bytes, + size: element_size, + }, + ); } else { return Err(vm.new_type_error(format!( "expected {} instance, not {}", @@ -920,9 +799,8 @@ impl PyCArray { Cow::Borrowed(slice) => { // SAFETY: For from_buffer, the slice points to writable shared memory. // Python's from_buffer requires writable buffer, so this is safe. 
- let ptr = slice.as_ptr() as *mut u8; - let len = slice.len(); - let owned_slice = unsafe { core::slice::from_raw_parts_mut(ptr, len) }; + let owned_slice = + unsafe { rustpython_host_env::ctypes::borrowed_slice_as_mut(slice) }; Self::write_element_to_buffer( owned_slice, final_offset, @@ -1179,8 +1057,9 @@ impl AsBuffer for PyCArray { fn char_array_get_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { let zelf = obj.downcast_ref::().unwrap(); let buffer = zelf.0.buffer.read(); - let len = buffer.iter().position(|&b| b == 0).unwrap_or(buffer.len()); - vm.ctx.new_bytes(buffer[..len].to_vec()).into() + vm.ctx + .new_bytes(char_array_field_value(&buffer).to_vec()) + .into() } // CharArray_set_value @@ -1196,10 +1075,7 @@ fn char_array_set_value(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachin return Err(vm.new_value_error("byte string too long")); } - buffer.to_mut()[..src.len()].copy_from_slice(src); - if src.len() < buffer.len() { - buffer.to_mut()[src.len()] = 0; - } + write_char_array_value(buffer.to_mut(), src); Ok(()) } @@ -1224,7 +1100,7 @@ fn char_array_set_raw( if src.len() > buffer.len() { return Err(vm.new_value_error("byte string too long")); } - buffer.to_mut()[..src.len()].copy_from_slice(&src); + write_char_array_raw(buffer.to_mut(), &src); Ok(()) } @@ -1246,22 +1122,9 @@ fn wchar_array_set_value( .downcast_ref::() .ok_or_else(|| vm.new_type_error("unicode string expected"))?; let mut buffer = zelf.0.buffer.write(); - let wchar_count = buffer.len() / WCHAR_SIZE; - let char_count = s.as_wtf8().code_points().count(); - - if char_count > wchar_count { - return Err(vm.new_value_error("string too long")); - } - - for (i, ch) in s.as_wtf8().code_points().enumerate() { - let offset = i * WCHAR_SIZE; - wchar_to_bytes(ch.to_u32(), &mut buffer.to_mut()[offset..]); - } - - let terminator_offset = char_count * WCHAR_SIZE; - if terminator_offset + WCHAR_SIZE <= buffer.len() { - wchar_to_bytes(0, &mut buffer.to_mut()[terminator_offset..]); - } + 
write_wchar_array_value(buffer.to_mut(), s.as_wtf8()).map_err(|err| match err { + WCharArrayWriteError::TooLong => vm.new_value_error("string too long"), + })?; Ok(()) } @@ -1309,57 +1172,3 @@ fn add_wchar_array_getsets(array_type: &Py, vm: &VirtualMachine) { .write() .insert(vm.ctx.intern_str("value"), value_getset.into()); } - -// wchar_t helpers - Platform-independent wide character handling -// Windows: sizeof(wchar_t) == 2 (UTF-16) -// Linux/macOS: sizeof(wchar_t) == 4 (UTF-32) - -/// Size of wchar_t on this platform -pub(super) const WCHAR_SIZE: usize = core::mem::size_of::(); - -/// Read a single wchar_t from bytes (platform-endian) -#[inline] -pub(super) fn wchar_from_bytes(bytes: &[u8]) -> Option { - if bytes.len() < WCHAR_SIZE { - return None; - } - Some(if WCHAR_SIZE == 2 { - u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 - } else { - u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) - }) -} - -/// Write a single wchar_t to bytes (platform-endian) -#[inline] -pub(super) fn wchar_to_bytes(ch: u32, buffer: &mut [u8]) { - if WCHAR_SIZE == 2 { - if buffer.len() >= 2 { - buffer[..2].copy_from_slice(&(ch as u16).to_ne_bytes()); - } - } else if buffer.len() >= 4 { - buffer[..4].copy_from_slice(&ch.to_ne_bytes()); - } -} - -/// Read a null-terminated wchar_t string from bytes, returns String -fn wstring_from_bytes(buffer: &[u8]) -> String { - let mut chars = Vec::new(); - for chunk in buffer.chunks(WCHAR_SIZE) { - if chunk.len() < WCHAR_SIZE { - break; - } - let code = if WCHAR_SIZE == 2 { - u16::from_ne_bytes([chunk[0], chunk[1]]) as u32 - } else { - u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) - }; - if code == 0 { - break; // null terminator - } - if let Some(ch) = char::from_u32(code) { - chars.push(ch); - } - } - chars.into_iter().collect() -} diff --git a/crates/vm/src/stdlib/_ctypes/base.rs b/crates/vm/src/stdlib/_ctypes/base.rs index d70178be4e8..765bacd19a0 100644 --- a/crates/vm/src/stdlib/_ctypes/base.rs +++ 
b/crates/vm/src/stdlib/_ctypes/base.rs @@ -1,4 +1,5 @@ -use super::array::{WCHAR_SIZE, wchar_from_bytes, wchar_to_bytes}; +#![allow(unreachable_pub)] + use crate::builtins::{ PyBytes, PyDict, PyList, PyMemoryView, PyStr, PyTuple, PyType, PyTypeRef, PyUtf8Str, }; @@ -11,16 +12,15 @@ use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; use alloc::borrow::Cow; -use core::ffi::{ - c_double, c_float, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort, -}; use core::fmt::Debug; -use core::mem; use crossbeam_utils::atomic::AtomicCell; use num_traits::{Signed, ToPrimitive}; use rustpython_common::lock::PyRwLock; use rustpython_common::wtf8::Wtf8; -use widestring::WideChar; +use rustpython_host_env::ctypes::{ + CTypeParamKind, FfiArg, FfiType, FfiValue, char_array_assignment_bytes, char_array_field_value, + ffi_arg_from_value, ffi_type_for_layout, wchar_array_field_value, write_cow_bytes_at_offset, +}; // StgInfo - Storage information for ctypes types // Stored in TypeDataSlot of heap types (PyType::init_type_data/get_type_data) @@ -77,7 +77,7 @@ pub(super) enum ParamFunc { } #[derive(Clone)] -pub(crate) struct StgInfo { +pub struct StgInfo { pub initialized: bool, pub size: usize, // number of bytes pub align: usize, // alignment requirements @@ -100,7 +100,7 @@ pub(crate) struct StgInfo { pub big_endian: bool, // true if big endian, false if little endian // FFI field types for structure/union passing (inherited from base class) - pub ffi_field_types: Vec, + pub ffi_field_types: Vec, // Cached pointer type (non-inheritable via descriptor) pub pointer_type: Option, @@ -178,7 +178,7 @@ impl StgInfo { /// item_shape: the element's shape (will be prepended with length) /// item_flags: the element type's flags (for HASPOINTER inheritance) #[allow(clippy::too_many_arguments)] - pub(crate) fn new_array( + pub fn new_array( size: usize, align: usize, length: usize, @@ -227,78 +227,30 @@ impl StgInfo { /// 
Get libffi type for this StgInfo /// Note: For very large types, returns pointer type to avoid overflow - pub(crate) fn to_ffi_type(&self) -> libffi::middle::Type { - // Limit to avoid overflow in libffi (MAX_STRUCT_SIZE is platform-dependent) - const MAX_FFI_STRUCT_SIZE: usize = 1024 * 1024; // 1MB limit for safety - - match self.paramfunc { - ParamFunc::Structure | ParamFunc::Union => { - if !self.ffi_field_types.is_empty() { - libffi::middle::Type::structure(self.ffi_field_types.iter().cloned()) - } else if self.size <= MAX_FFI_STRUCT_SIZE { - // Small struct without field types: use bytes array - libffi::middle::Type::structure(core::iter::repeat_n( - libffi::middle::Type::u8(), - self.size, - )) - } else { - // Large struct: treat as pointer (passed by reference) - libffi::middle::Type::pointer() - } - } - ParamFunc::Array => { - if self.size > MAX_FFI_STRUCT_SIZE || self.length > MAX_FFI_STRUCT_SIZE { - // Large array: treat as pointer - libffi::middle::Type::pointer() - } else if let Some(ref fmt) = self.format { - let elem_type = Self::format_to_ffi_type(fmt); - libffi::middle::Type::structure(core::iter::repeat_n(elem_type, self.length)) - } else { - libffi::middle::Type::structure(core::iter::repeat_n( - libffi::middle::Type::u8(), - self.size, - )) - } - } - ParamFunc::Pointer => libffi::middle::Type::pointer(), - _ => { - // Simple type: derive from format - if let Some(ref fmt) = self.format { - Self::format_to_ffi_type(fmt) - } else { - libffi::middle::Type::u8() - } - } - } - } - - /// Convert format string to libffi type - fn format_to_ffi_type(fmt: &str) -> libffi::middle::Type { - // Strip endian prefix if present - let code = fmt.trim_start_matches(['<', '>', '!', '@', '=']); - match code { - "b" => libffi::middle::Type::i8(), - "B" => libffi::middle::Type::u8(), - "h" => libffi::middle::Type::i16(), - "H" => libffi::middle::Type::u16(), - "i" | "l" => libffi::middle::Type::i32(), - "I" | "L" => libffi::middle::Type::u32(), - "q" => 
libffi::middle::Type::i64(), - "Q" => libffi::middle::Type::u64(), - "f" => libffi::middle::Type::f32(), - "d" => libffi::middle::Type::f64(), - "P" | "z" | "Z" | "O" => libffi::middle::Type::pointer(), - _ => libffi::middle::Type::u8(), // default - } + pub fn to_ffi_type(&self) -> FfiType { + let kind = match self.paramfunc { + ParamFunc::Structure => CTypeParamKind::Structure, + ParamFunc::Union => CTypeParamKind::Union, + ParamFunc::Array => CTypeParamKind::Array, + ParamFunc::Pointer => CTypeParamKind::Pointer, + _ => CTypeParamKind::Simple, + }; + ffi_type_for_layout( + kind, + &self.ffi_field_types, + self.size, + self.length, + self.format.as_deref(), + ) } /// Check if this type is finalized (cannot set _fields_ again) - pub(crate) fn is_final(&self) -> bool { + pub fn is_final(&self) -> bool { self.flags.contains(StgInfoFlags::DICTFLAG_FINAL) } /// Get proto type reference (for Pointer/Array types) - pub(crate) fn proto(&self) -> &Py { + pub fn proto(&self) -> &Py { self.proto.as_deref().expect("type has proto") } } @@ -408,26 +360,13 @@ pub(super) static CDATA_BUFFER_METHODS: BufferMethods = BufferMethods { retain: |_| {}, }; -/// Convert Vec to Vec by reinterpreting the memory (same allocation). -fn vec_to_bytes(vec: Vec) -> Vec { - let len = vec.len() * core::mem::size_of::(); - let cap = vec.capacity() * core::mem::size_of::(); - let ptr = vec.as_ptr() as *mut u8; - core::mem::forget(vec); - unsafe { Vec::from_raw_parts(ptr, len, cap) } -} - /// Ensure PyBytes data is null-terminated. Returns (kept_alive_obj, pointer). /// The caller must keep the returned object alive to keep the pointer valid. 
pub(super) fn ensure_z_null_terminated( bytes: &PyBytes, vm: &VirtualMachine, ) -> (PyObjectRef, usize) { - let data = bytes.as_bytes(); - let mut buffer = data.to_vec(); - if !buffer.ends_with(&[0]) { - buffer.push(0); - } + let buffer = rustpython_host_env::ctypes::null_terminated_bytes(bytes.as_bytes()); let ptr = buffer.as_ptr() as usize; let kept_alive: PyObjectRef = vm.ctx.new_bytes(buffer).into(); (kept_alive, ptr) @@ -435,13 +374,8 @@ pub(super) fn ensure_z_null_terminated( /// Convert str to null-terminated wchar_t buffer. Returns (PyBytes holder, pointer). pub(super) fn str_to_wchar_bytes(s: &Wtf8, vm: &VirtualMachine) -> (PyObjectRef, usize) { - let wchars: Vec = s - .code_points() - .map(|cp| cp.to_u32() as libc::wchar_t) - .chain(core::iter::once(0)) - .collect(); - let ptr = wchars.as_ptr() as usize; - let bytes = vec_to_bytes(wchars); + let bytes = rustpython_host_env::ctypes::wchar_null_terminated_bytes(s); + let ptr = bytes.as_ptr() as usize; let holder: PyObjectRef = vm.ctx.new_bytes(bytes).into(); (holder, ptr) } @@ -449,7 +383,7 @@ pub(super) fn str_to_wchar_bytes(s: &Wtf8, vm: &VirtualMachine) -> (PyObjectRef, /// PyCData - base type for all ctypes data types #[pyclass(name = "_CData", module = "_ctypes")] #[derive(Debug, PyPayload)] -pub(crate) struct PyCData { +pub struct PyCData { /// Memory buffer - Owned (self-owned) or Borrowed (external reference) /// /// SAFETY: Borrowed variant's 'static lifetime is not actually static. @@ -501,7 +435,7 @@ impl PyCData { } /// Create from bytes with specified length (for arrays) - pub(crate) fn from_bytes_with_length( + pub fn from_bytes_with_length( data: Vec, objects: Option, length: usize, @@ -523,10 +457,10 @@ impl PyCData { /// The returned slice's 'static lifetime is a lie. /// Actually only valid for the lifetime of the memory pointed to by ptr. 
/// PyCData_AtAddress - pub(crate) unsafe fn at_address(ptr: *const u8, size: usize) -> Self { + pub unsafe fn at_address(ptr: *const u8, size: usize) -> Self { // = PyCData_AtAddress // SAFETY: Caller must ensure ptr is valid for the lifetime of returned PyCData - let slice: &'static [u8] = unsafe { core::slice::from_raw_parts(ptr, size) }; + let slice = unsafe { rustpython_host_env::ctypes::borrow_memory(ptr, size) }; Self { buffer: PyRwLock::new(Cow::Borrowed(slice)), base: PyRwLock::new(None), @@ -543,7 +477,7 @@ impl PyCData { /// Similar to from_base_with_offset, but also stores a copy of the data. /// This is used for arrays where we need our own buffer for the buffer protocol, /// but still maintain the base reference for KeepRef and tracking. - pub(crate) fn from_base_with_data( + pub fn from_base_with_data( base_obj: PyObjectRef, offset: usize, idx: usize, @@ -568,7 +502,7 @@ impl PyCData { /// /// # Safety /// ptr must point into base_obj's buffer and remain valid as long as base_obj is alive. - pub(crate) unsafe fn from_base_obj( + pub unsafe fn from_base_obj( ptr: *mut u8, size: usize, base_obj: PyObjectRef, @@ -576,7 +510,7 @@ impl PyCData { ) -> Self { // = PyCData_FromBaseObj // SAFETY: ptr points into base_obj's buffer, kept alive via base reference - let slice: &'static [u8] = unsafe { core::slice::from_raw_parts(ptr, size) }; + let slice = unsafe { rustpython_host_env::ctypes::borrow_memory(ptr, size) }; Self { buffer: PyRwLock::new(Cow::Borrowed(slice)), base: PyRwLock::new(Some(base_obj)), @@ -596,7 +530,7 @@ impl PyCData { /// /// # Safety /// ptr must point to valid memory that remains valid as long as source is alive. 
- pub(crate) unsafe fn from_buffer_shared( + pub unsafe fn from_buffer_shared( ptr: *const u8, size: usize, length: usize, @@ -604,7 +538,7 @@ impl PyCData { vm: &VirtualMachine, ) -> Self { // SAFETY: Caller must ensure ptr is valid for the lifetime of source - let slice: &'static [u8] = unsafe { core::slice::from_raw_parts(ptr, size) }; + let slice = unsafe { rustpython_host_env::ctypes::borrow_memory(ptr, size) }; // Python stores the reference in a dict with key "-1" (unique_key pattern) let objects_dict = vm.ctx.new_dict(); @@ -627,7 +561,7 @@ impl PyCData { /// Validates buffer, creates memoryview, and returns PyCData sharing memory with source. /// /// CDataType_from_buffer_impl - pub(crate) fn from_buffer_impl( + pub fn from_buffer_impl( cls: &Py, source: PyObjectRef, offset: isize, @@ -687,7 +621,7 @@ impl PyCData { /// Copies data from buffer and creates new independent instance. /// /// CDataType_from_buffer_copy_impl - pub(crate) fn from_buffer_copy_impl( + pub fn from_buffer_copy_impl( cls: &Py, source: &[u8], offset: isize, @@ -721,13 +655,13 @@ impl PyCData { } #[inline] - pub(crate) fn size(&self) -> usize { + pub fn size(&self) -> usize { self.buffer.read().len() } /// Check if this buffer is borrowed (external memory reference) #[inline] - pub(crate) fn is_borrowed(&self) -> bool { + pub fn is_borrowed(&self) -> bool { matches!(&*self.buffer.read(), Cow::Borrowed(_)) } @@ -738,35 +672,15 @@ impl PyCData { /// /// # Safety /// For borrowed buffers, caller must ensure the memory is writable. 
- pub(crate) fn write_bytes_at_offset(&self, offset: usize, bytes: &[u8]) { - let buffer = self.buffer.read(); - if offset + bytes.len() > buffer.len() { - return; // Out of bounds - } - - match &*buffer { - Cow::Borrowed(slice) => { - // For borrowed memory, write directly - // SAFETY: We assume the caller knows this memory is writable - // (e.g., from from_address pointing to a ctypes buffer) - unsafe { - let ptr = slice.as_ptr() as *mut u8; - core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr.add(offset), bytes.len()); - } - } - Cow::Owned(_) => { - // For owned memory, use to_mut() through write lock - drop(buffer); - let mut buffer = self.buffer.write(); - buffer.to_mut()[offset..offset + bytes.len()].copy_from_slice(bytes); - } - } + pub fn write_bytes_at_offset(&self, offset: usize, bytes: &[u8]) { + let mut buffer = self.buffer.write(); + write_cow_bytes_at_offset(&mut buffer, offset, bytes); } /// Generate unique key for nested references (unique_key) /// Creates a hierarchical key by walking up the b_base chain. /// Format: "index:parent_index:grandparent_index:..." - pub(crate) fn unique_key(&self, index: usize) -> String { + pub fn unique_key(&self, index: usize) -> String { let mut key = format!("{index:x}"); // Walk up the base chain to build hierarchical key if self.base.read().is_some() { @@ -785,12 +699,7 @@ impl PyCData { /// /// If this object has a base (is embedded in another structure/union/array), /// the reference is stored in the root object's b_objects with a hierarchical key. - pub(crate) fn keep_ref( - &self, - index: usize, - keep: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<()> { + pub fn keep_ref(&self, index: usize, keep: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { // Optimization: no need to store None if vm.is_none(&keep) { return Ok(()); @@ -850,7 +759,7 @@ impl PyCData { /// Walks up to root object (same as keep_ref) so the reference /// lives as long as the owning ctypes object. 
/// Uses unique_key (hierarchical) so nested fields don't collide. - pub(crate) fn keep_alive(&self, index: usize, obj: PyObjectRef) { + pub fn keep_alive(&self, index: usize, obj: PyObjectRef) { let key = self.unique_key(index); if let Some(base_obj) = self.base.read().clone() { let root = Self::find_root_object(&base_obj); @@ -923,7 +832,7 @@ impl PyCData { /// Get kept objects from a CData instance /// Returns the _objects of the CData, or an empty dict if None. - pub(crate) fn get_kept_objects(value: &PyObject, vm: &VirtualMachine) -> PyObjectRef { + pub fn get_kept_objects(value: &PyObject, vm: &VirtualMachine) -> PyObjectRef { value .downcast_ref::() .and_then(|cdata| cdata.objects.read().clone()) @@ -939,7 +848,7 @@ impl PyCData { /// PyCData_set /// Sets a field value at the given offset, handling type conversion and KeepRef #[allow(clippy::too_many_arguments)] - pub(crate) fn set_field( + pub fn set_field( &self, proto: &PyObject, value: PyObjectRef, @@ -957,7 +866,7 @@ impl PyCData { if is_char_array { if let Some(bytes_val) = value.downcast_ref::() { let src = bytes_val.as_bytes(); - let to_copy = PyCField::bytes_for_char_array(src); + let to_copy = char_array_assignment_bytes(src); let copy_len = core::cmp::min(to_copy.len(), size); self.write_bytes_at_offset(offset, &to_copy[..copy_len]); self.keep_ref(index, value, vm)?; @@ -969,17 +878,10 @@ impl PyCData { // For c_wchar arrays with str input, convert to wchar_t if is_wchar_array { if let Some(str_val) = value.downcast_ref::() { - // Convert str to wchar_t bytes (platform-dependent size) - let mut wchar_bytes = Vec::with_capacity(size); - for cp in str_val.as_wtf8().code_points().take(size / WCHAR_SIZE) { - let mut bytes = [0u8; 4]; - wchar_to_bytes(cp.to_u32(), &mut bytes); - wchar_bytes.extend_from_slice(&bytes[..WCHAR_SIZE]); - } - // Pad with nulls to fill the array - while wchar_bytes.len() < size { - wchar_bytes.push(0); - } + let wchar_bytes = 
rustpython_host_env::ctypes::encode_wtf8_to_wchar_padded( + str_val.as_wtf8(), + size, + ); self.write_bytes_at_offset(offset, &wchar_bytes); self.keep_ref(index, value, vm)?; return Ok(()); @@ -999,9 +901,8 @@ impl PyCData { let array_buffer = array.0.buffer.read(); array_buffer.as_ptr() as usize }; - let addr_bytes = buffer_addr.to_ne_bytes(); - let len = core::cmp::min(addr_bytes.len(), size); - self.write_bytes_at_offset(offset, &addr_bytes[..len]); + let addr_bytes = rustpython_host_env::ctypes::pointer_to_sized_bytes(buffer_addr, size); + self.write_bytes_at_offset(offset, &addr_bytes); self.keep_ref(index, value, vm)?; return Ok(()); } @@ -1055,13 +956,8 @@ impl PyCData { && let Some(bytes_val) = value.downcast_ref::() { let (kept_alive, ptr) = ensure_z_null_terminated(bytes_val, vm); - let mut result = vec![0u8; size]; - let addr_bytes = ptr.to_ne_bytes(); - let len = core::cmp::min(addr_bytes.len(), size); - result[..len].copy_from_slice(&addr_bytes[..len]); - if needs_swap { - result.reverse(); - } + let result = + rustpython_host_env::ctypes::pointer_to_sized_bytes_endian(ptr, size, needs_swap); self.write_bytes_at_offset(offset, &result); self.keep_ref(index, value, vm)?; self.keep_alive(index, kept_alive); @@ -1094,7 +990,7 @@ impl PyCData { /// PyCData_get /// Gets a field value at the given offset - pub(crate) fn get_field( + pub fn get_field( &self, proto: &PyObject, index: usize, @@ -1117,24 +1013,16 @@ impl PyCData { // c_char array → return bytes if PyCField::is_char_array(proto, vm) { let data = &buffer[offset..offset + size]; - // Find first null terminator (or use full length) - let end = data.iter().position(|&b| b == 0).unwrap_or(data.len()); - return Ok(vm.ctx.new_bytes(data[..end].to_vec()).into()); + return Ok(vm + .ctx + .new_bytes(char_array_field_value(data).to_vec()) + .into()); } // c_wchar array → return str if PyCField::is_wchar_array(proto, vm) { let data = &buffer[offset..offset + size]; - // wchar_t → char conversion, skip null 
- let chars: String = data - .chunks(WCHAR_SIZE) - .filter_map(|chunk| { - wchar_from_bytes(chunk) - .filter(|&wchar| wchar != 0) - .and_then(char::from_u32) - }) - .collect(); - return Ok(vm.ctx.new_str(chars).into()); + return Ok(vm.ctx.new_str(wchar_array_field_value(data)).into()); } // Other array types - create array with a copy of data from the base's buffer @@ -1186,7 +1074,7 @@ impl PyCData { buffer_data }; - return bytes_to_pyobject(&proto_type, &data, vm); + return Ok(bytes_to_pyobject(&proto_type, &data, vm)); } // Complex types: create ctypes instance via PyCData_FromBaseObj @@ -1299,24 +1187,21 @@ impl PyCData { .ok_or_else(|| vm.new_value_error("Invalid library handle"))? }; - // Look up the library in the cache and use lib.get() for symbol lookup - let library_cache = super::library::libcache().read(); - let library = library_cache - .get_lib(handle) - .ok_or_else(|| vm.new_value_error("Library not found"))?; - let inner_lib = library.lib.lock(); - let symbol_name_with_nul = format!("{}\0", name.as_wtf8()); - let ptr: *const u8 = if let Some(lib) = &*inner_lib { - unsafe { - lib.get::<*const u8>(symbol_name_with_nul.as_bytes()) - .map(|sym| *sym) - .map_err(|_| { - vm.new_value_error(format!("symbol '{}' not found", name.as_wtf8())) - })? 
+ let ptr = match rustpython_host_env::ctypes::lookup_data_symbol_addr( + handle, + symbol_name_with_nul.as_bytes(), + ) { + Ok(ptr) => ptr as *const u8, + Err(rustpython_host_env::ctypes::LookupSymbolError::LibraryNotFound) => { + return Err(vm.new_value_error("Library not found")); + } + Err(rustpython_host_env::ctypes::LookupSymbolError::LibraryClosed) => { + return Err(vm.new_value_error("Library closed")); + } + Err(rustpython_host_env::ctypes::LookupSymbolError::Load(_)) => { + return Err(vm.new_value_error(format!("symbol '{}' not found", name.as_wtf8()))); } - } else { - return Err(vm.new_value_error("Library closed")); }; // dlsym can return NULL for symbols that resolve to NULL (e.g., GNU IFUNC) @@ -1336,7 +1221,7 @@ impl PyCData { /// CField descriptor for Structure/Union field access #[pyclass(name = "CField", module = "_ctypes")] #[derive(Debug, PyPayload)] -pub(crate) struct PyCField { +pub struct PyCField { /// Field name pub(crate) name: String, /// Byte offset of the field within the structure/union @@ -1357,7 +1242,7 @@ pub(crate) struct PyCField { impl PyCField { /// Create a new CField descriptor (non-bitfield) - pub(crate) fn new( + pub fn new( name: String, proto: PyTypeRef, offset: isize, @@ -1377,7 +1262,7 @@ impl PyCField { } /// Create a new CField descriptor for a bitfield - pub(crate) fn new_bitfield( + pub fn new_bitfield( name: String, proto: PyTypeRef, offset: isize, @@ -1399,7 +1284,7 @@ impl PyCField { } /// Get the byte size of the field's underlying type - pub(crate) fn get_byte_size(&self) -> usize { + pub fn get_byte_size(&self) -> usize { self.byte_size_val as usize } @@ -1419,7 +1304,7 @@ impl PyCField { } /// Set anonymous flag - pub(crate) fn set_anonymous(&mut self, anonymous: bool) { + pub fn set_anonymous(&mut self, anonymous: bool) { self.anonymous = anonymous; } } @@ -1584,29 +1469,19 @@ impl PyCField { fn value_to_bytes(value: &PyObject, size: usize, vm: &VirtualMachine) -> Vec { // 1. 
Handle bytes objects if let Some(bytes) = value.downcast_ref::() { - let src = bytes.as_bytes(); - let mut result = vec![0u8; size]; - let len = core::cmp::min(src.len(), size); - result[..len].copy_from_slice(&src[..len]); - result + rustpython_host_env::ctypes::copy_to_sized_bytes(bytes.as_bytes(), size) } // 2. Handle ctypes array instances (copy their buffer) else if let Some(cdata) = value.downcast_ref::() { let buffer = cdata.buffer.read(); - let mut result = vec![0u8; size]; - let len = core::cmp::min(buffer.len(), size); - result[..len].copy_from_slice(&buffer[..len]); - result + rustpython_host_env::ctypes::copy_to_sized_bytes(&buffer, size) } // 4. Handle float values (check before int, since float.try_int would truncate) else if let Some(float_val) = value.downcast_ref::() { let f = float_val.to_f64(); match size { - 4 => { - let val = f as f32; - val.to_ne_bytes().to_vec() - } - 8 => f.to_ne_bytes().to_vec(), + 4 | 8 => rustpython_host_env::ctypes::float_to_sized_bytes(f, size) + .expect("float size checked"), _ => unreachable!("wrong payload size"), } } @@ -1614,26 +1489,23 @@ impl PyCField { else if let Ok(int_val) = value.try_int(vm) { let i = int_val.as_bigint(); match size { - 1 => { - let val = i.to_i8().unwrap_or(0); - val.to_ne_bytes().to_vec() - } - 2 => { - let val = i.to_i16().unwrap_or(0); - val.to_ne_bytes().to_vec() - } - 4 => { - let val = i.to_i32().unwrap_or(0); - val.to_ne_bytes().to_vec() - } - 8 => { - let val = i.to_i64().unwrap_or(0); - val.to_ne_bytes().to_vec() - } - _ => vec![0u8; size], + 1 => rustpython_host_env::ctypes::int_to_sized_bytes( + i.to_i8().unwrap_or(0).into(), + size, + ), + 2 => rustpython_host_env::ctypes::int_to_sized_bytes( + i.to_i16().unwrap_or(0).into(), + size, + ), + 4 => rustpython_host_env::ctypes::int_to_sized_bytes( + i.to_i32().unwrap_or(0).into(), + size, + ), + 8 => rustpython_host_env::ctypes::int_to_sized_bytes(i.to_i64().unwrap_or(0), size), + _ => 
rustpython_host_env::ctypes::zeroed_bytes(size), } } else { - vec![0u8; size] + rustpython_host_env::ctypes::zeroed_bytes(size) } } @@ -1658,8 +1530,11 @@ impl PyCField { value.class().name() ))); }; - let val = f as f32; - Ok((val.to_ne_bytes().to_vec(), None)) + Ok(( + rustpython_host_env::ctypes::float_to_sized_bytes(f, 4) + .expect("c_float size is fixed"), + None, + )) } // c_double: always convert to float first (d_set) "d" => { @@ -1673,7 +1548,11 @@ impl PyCField { value.class().name() ))); }; - Ok((f.to_ne_bytes().to_vec(), None)) + Ok(( + rustpython_host_env::ctypes::float_to_sized_bytes(f, 8) + .expect("c_double size is fixed"), + None, + )) } // c_longdouble: convert to float (treated as f64 in RustPython) "g" => { @@ -1687,7 +1566,11 @@ impl PyCField { value.class().name() ))); }; - Ok((f.to_ne_bytes().to_vec(), None)) + Ok(( + rustpython_host_env::ctypes::float_to_sized_bytes(f, 8) + .expect("c_longdouble bytes are stored as f64"), + None, + )) } "z" => { // c_char_p with bytes is handled in set_field before this call. 
@@ -1695,15 +1578,14 @@ impl PyCField { // Integer address if let Ok(int_val) = value.try_index(vm) { let v = int_val.as_bigint().to_usize().unwrap_or(0); - let mut result = vec![0u8; size]; - let bytes = v.to_ne_bytes(); - let len = core::cmp::min(bytes.len(), size); - result[..len].copy_from_slice(&bytes[..len]); - return Ok((result, None)); + return Ok(( + rustpython_host_env::ctypes::pointer_to_sized_bytes(v, size), + None, + )); } // None -> NULL pointer if vm.is_none(value) { - return Ok((vec![0u8; size], None)); + return Ok((rustpython_host_env::ctypes::zeroed_bytes(size), None)); } Ok((Self::value_to_bytes(value, size, vm), None)) } @@ -1711,24 +1593,22 @@ impl PyCField { // c_wchar_p: store pointer to null-terminated wchar_t buffer if let Some(s) = value.downcast_ref::() { let (holder, ptr) = str_to_wchar_bytes(s.as_wtf8(), vm); - let mut result = vec![0u8; size]; - let addr_bytes = ptr.to_ne_bytes(); - let len = core::cmp::min(addr_bytes.len(), size); - result[..len].copy_from_slice(&addr_bytes[..len]); - return Ok((result, Some(holder))); + return Ok(( + rustpython_host_env::ctypes::pointer_to_sized_bytes(ptr, size), + Some(holder), + )); } // Integer address if let Ok(int_val) = value.try_index(vm) { let v = int_val.as_bigint().to_usize().unwrap_or(0); - let mut result = vec![0u8; size]; - let bytes = v.to_ne_bytes(); - let len = core::cmp::min(bytes.len(), size); - result[..len].copy_from_slice(&bytes[..len]); - return Ok((result, None)); + return Ok(( + rustpython_host_env::ctypes::pointer_to_sized_bytes(v, size), + None, + )); } // None -> NULL pointer if vm.is_none(value) { - return Ok((vec![0u8; size], None)); + return Ok((rustpython_host_env::ctypes::zeroed_bytes(size), None)); } Ok((Self::value_to_bytes(value, size, vm), None)) } @@ -1736,15 +1616,14 @@ impl PyCField { // c_void_p: store integer as pointer if let Ok(int_val) = value.try_index(vm) { let v = int_val.as_bigint().to_usize().unwrap_or(0); - let mut result = vec![0u8; size]; - let 
bytes = v.to_ne_bytes(); - let len = core::cmp::min(bytes.len(), size); - result[..len].copy_from_slice(&bytes[..len]); - return Ok((result, None)); + return Ok(( + rustpython_host_env::ctypes::pointer_to_sized_bytes(v, size), + None, + )); } // None -> NULL pointer if vm.is_none(value) { - return Ok((vec![0u8; size], None)); + return Ok((rustpython_host_env::ctypes::zeroed_bytes(size), None)); } Ok((Self::value_to_bytes(value, size, vm), None)) } @@ -1785,17 +1664,6 @@ impl PyCField { } false } - - /// Convert bytes for c_char array assignment (stops at first null terminator) - /// Returns (bytes_to_copy, copy_len) - fn bytes_for_char_array(src: &[u8]) -> &[u8] { - // Find first null terminator and include it - if let Some(null_pos) = src.iter().position(|&b| b == 0) { - &src[..=null_pos] - } else { - src - } - } } #[pyclass(flags(IMMUTABLETYPE), with(Representable, GetDescriptor, Constructor))] @@ -1987,7 +1855,7 @@ fn array_paramfunc(obj: &PyObject, vm: &VirtualMachine) -> PyResult Ok(CArgObject { tag: b'P', - value: FfiArgValue::Pointer(ptr_val), + value: FfiArgValue::pointer(ptr_val), obj: obj.to_owned(), size: 0, offset: 0, @@ -2007,7 +1875,7 @@ fn pointer_paramfunc(obj: &PyObject, vm: &VirtualMachine) -> PyResult Self { + Self::Scalar(FfiValue::Pointer(value)) + } + /// Create an Arg reference to this owned value - pub(crate) fn as_arg(&self) -> libffi::middle::Arg<'_> { + pub fn as_arg(&self) -> FfiArg<'_> { match self { - Self::U8(v) => libffi::middle::Arg::new(v), - Self::I8(v) => libffi::middle::Arg::new(v), - Self::U16(v) => libffi::middle::Arg::new(v), - Self::I16(v) => libffi::middle::Arg::new(v), - Self::U32(v) => libffi::middle::Arg::new(v), - Self::I32(v) => libffi::middle::Arg::new(v), - Self::U64(v) => libffi::middle::Arg::new(v), - Self::I64(v) => libffi::middle::Arg::new(v), - Self::F32(v) => libffi::middle::Arg::new(v), - Self::F64(v) => libffi::middle::Arg::new(v), - Self::Pointer(v) => libffi::middle::Arg::new(v), - Self::OwnedPointer(v, _) 
=> libffi::middle::Arg::new(v), + Self::Scalar(value) => ffi_arg_from_value(value), + Self::OwnedPointer(v, _) => rustpython_host_env::ctypes::ffi_arg( + rustpython_host_env::ctypes::FfiArgRef::Pointer(v), + ), } } } /// Convert buffer bytes to FfiArgValue based on type code pub(super) fn buffer_to_ffi_value(type_code: &str, buffer: &[u8]) -> FfiArgValue { - match type_code { - "c" | "b" => { - let v = buffer.first().map_or(0, |&b| b as i8); - FfiArgValue::I8(v) - } - "B" => { - let v = buffer.first().copied().unwrap_or(0); - FfiArgValue::U8(v) - } - "h" => { - let v = buffer.first_chunk().copied().map_or(0, i16::from_ne_bytes); - FfiArgValue::I16(v) - } - "H" => { - let v = buffer.first_chunk().copied().map_or(0, u16::from_ne_bytes); - FfiArgValue::U16(v) - } - "i" => { - let v = buffer.first_chunk().copied().map_or(0, i32::from_ne_bytes); - FfiArgValue::I32(v) - } - "I" => { - let v = buffer.first_chunk().copied().map_or(0, u32::from_ne_bytes); - FfiArgValue::U32(v) - } - "l" | "q" => { - let v = if let Some(&bytes) = buffer.first_chunk::<8>() { - i64::from_ne_bytes(bytes) - } else if let Some(&bytes) = buffer.first_chunk::<4>() { - i32::from_ne_bytes(bytes).into() - } else { - 0 - }; - FfiArgValue::I64(v) - } - "L" | "Q" => { - let v = if let Some(&bytes) = buffer.first_chunk::<8>() { - u64::from_ne_bytes(bytes) - } else if let Some(&bytes) = buffer.first_chunk::<4>() { - u32::from_ne_bytes(bytes).into() - } else { - 0 - }; - FfiArgValue::U64(v) - } - "f" => { - let v = buffer - .first_chunk::<4>() - .copied() - .map_or(0.0, f32::from_ne_bytes); - FfiArgValue::F32(v) - } - "d" | "g" => { - let v = buffer - .first_chunk::<8>() - .copied() - .map_or(0.0, f64::from_ne_bytes); - FfiArgValue::F64(v) - } - "z" | "Z" | "P" | "O" => FfiArgValue::Pointer(read_ptr_from_buffer(buffer)), - "?" 
=> { - let v = buffer.first().is_some_and(|&b| b != 0); - FfiArgValue::U8(if v { 1 } else { 0 }) - } - "u" => { - // wchar_t - 4 bytes on most platforms - let v = buffer.first_chunk().copied().map_or(0, u32::from_ne_bytes); - FfiArgValue::U32(v) - } - _ => FfiArgValue::Pointer(0), - } + FfiArgValue::Scalar(rustpython_host_env::ctypes::ffi_value_from_type_code( + type_code, buffer, + )) } /// Convert bytes to appropriate Python object based on ctypes type @@ -2163,204 +1949,41 @@ pub(super) fn bytes_to_pyobject( cls: &Py, bytes: &[u8], vm: &VirtualMachine, -) -> PyResult { +) -> PyObjectRef { // Try to get _type_ attribute if let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) && let Ok(s) = type_attr.str(vm) { let ty = s.to_string(); - return match ty.as_str() { - "c" => Ok(vm.ctx.new_bytes(bytes.to_vec()).into()), - "b" => { - let val = if !bytes.is_empty() { bytes[0] as i8 } else { 0 }; - Ok(vm.ctx.new_int(val).into()) - } - "B" => { - let val = if !bytes.is_empty() { bytes[0] } else { 0 }; - Ok(vm.ctx.new_int(val).into()) - } - "h" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "H" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_ushort::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "i" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_int::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "I" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_uint::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) + return match rustpython_host_env::ctypes::decode_type_code(ty.as_str(), 
bytes) { + rustpython_host_env::ctypes::DecodedValue::Bytes(value) => { + vm.ctx.new_bytes(value).into() } - "l" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_long::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "L" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_ulong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) + rustpython_host_env::ctypes::DecodedValue::Signed(value) => { + vm.ctx.new_int(value).into() } - "q" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_longlong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) + rustpython_host_env::ctypes::DecodedValue::Unsigned(value) => { + vm.ctx.new_int(value).into() } - "Q" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_ulonglong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) + rustpython_host_env::ctypes::DecodedValue::Float(value) => { + vm.ctx.new_float(value).into() } - "f" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_float::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + rustpython_host_env::ctypes::DecodedValue::Bool(value) => vm.ctx.new_bool(value).into(), + rustpython_host_env::ctypes::DecodedValue::Pointer(value) => { + if value == 0 { + vm.ctx.none() } else { - 0.0 - }; - Ok(vm.ctx.new_float(val as f64).into()) - } - "d" => { - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_double::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0.0 - }; - Ok(vm.ctx.new_float(val).into()) - } - "g" => { - // long double - read as f64 for now since Rust doesn't have native long 
double - // This may lose precision on platforms where long double > 64 bits - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_double::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0.0 - }; - Ok(vm.ctx.new_float(val).into()) - } - "?" => { - let val = !bytes.is_empty() && bytes[0] != 0; - Ok(vm.ctx.new_bool(val).into()) - } - "v" => { - // VARIANT_BOOL: non-zero = True, zero = False - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_bool(val != 0).into()) - } - "z" => { - // c_char_p: read NULL-terminated string from pointer - let ptr = read_ptr_from_buffer(bytes); - if ptr == 0 { - return Ok(vm.ctx.none()); + vm.ctx.new_int(value).into() } - let c_str = unsafe { core::ffi::CStr::from_ptr(ptr as _) }; - Ok(vm.ctx.new_bytes(c_str.to_bytes().to_vec()).into()) } - "Z" => { - // c_wchar_p: read NULL-terminated wide string from pointer - let ptr = read_ptr_from_buffer(bytes); - if ptr == 0 { - return Ok(vm.ctx.none()); - } - let len = unsafe { libc::wcslen(ptr as *const libc::wchar_t) }; - let wchars = - unsafe { core::slice::from_raw_parts(ptr as *const libc::wchar_t, len) }; - // wchar_t is i32 on some platforms and u32 on others - #[allow( - clippy::unnecessary_cast, - reason = "wchar_t is i32 on some platforms and u32 on others" - )] - let s: String = wchars - .iter() - .filter_map(|&c| char::from_u32(c as u32)) - .collect(); - Ok(vm.ctx.new_str(s).into()) - } - "P" => { - // c_void_p: return pointer value as integer - let val = read_ptr_from_buffer(bytes); - if val == 0 { - return Ok(vm.ctx.none()); - } - Ok(vm.ctx.new_int(val).into()) - } - "O" => { - // py_object: return Python object from pointer - let ptr = read_ptr_from_buffer(bytes); - if ptr == 0 { - return Err(vm.new_value_error("PyObject is NULL")); - } - unsafe { - let obj = - 
PyObjectRef::from_raw(core::ptr::NonNull::new_unchecked(ptr as *mut _)); - Ok(obj) - } + rustpython_host_env::ctypes::DecodedValue::String(value) => { + vm.ctx.new_str(value).into() } - "u" => { - let val = if bytes.len() >= mem::size_of::() { - let wc = if mem::size_of::() == 2 { - u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 - } else { - u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) - }; - char::from_u32(wc).unwrap_or('\0') - } else { - '\0' - }; - Ok(vm.ctx.new_str(val).into()) - } - _ => Ok(vm.ctx.none()), + rustpython_host_env::ctypes::DecodedValue::None => vm.ctx.none(), }; } // Default: return bytes as-is - Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) + vm.ctx.new_bytes(bytes.to_vec()).into() } // Shared functions for Structure and Union types @@ -2385,16 +2008,6 @@ pub(super) fn get_usize_attr( Ok(val.to_usize().unwrap_or(default)) } -/// Read a pointer value from buffer -#[inline] -pub(super) fn read_ptr_from_buffer(buffer: &[u8]) -> usize { - const PTR_SIZE: usize = core::mem::size_of::(); - buffer - .first_chunk::() - .copied() - .map_or(0, usize::from_ne_bytes) -} - /// Check if a type is a "simple instance" (direct subclass of a simple type) /// Returns TRUE for c_int, c_void_p, etc. (simple types with _type_ attribute) /// Returns FALSE for Structure, Array, POINTER(T), etc. 
@@ -2469,8 +2082,9 @@ pub(super) fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> usiz .and_then(|type_attr| type_attr.str(vm).ok()) .and_then(|type_str| { let s = type_str.to_string(); - (s.len() == 1).then(|| super::get_size(&s)) + (s.len() == 1).then(|| rustpython_host_env::ctypes::simple_type_size(&s)) }) + .flatten() { return size; } @@ -2485,7 +2099,7 @@ pub(super) fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> usiz return s; } - core::mem::size_of::() + rustpython_host_env::ctypes::pointer_size() } /// Get the alignment of a ctypes field type @@ -2503,8 +2117,9 @@ pub(super) fn get_field_align(field_type: &PyObject, vm: &VirtualMachine) -> usi .and_then(|type_attr| type_attr.str(vm).ok()) .and_then(|type_str| { let s = type_str.to_string(); - (s.len() == 1).then(|| super::get_size(&s)) + (s.len() == 1).then(|| rustpython_host_env::ctypes::simple_type_align(&s)) }) + .flatten() { return align; } diff --git a/crates/vm/src/stdlib/_ctypes/function.rs b/crates/vm/src/stdlib/_ctypes/function.rs index e188a67d8c3..25cbcdcd9a1 100644 --- a/crates/vm/src/stdlib/_ctypes/function.rs +++ b/crates/vm/src/stdlib/_ctypes/function.rs @@ -1,11 +1,11 @@ // spell-checker:disable +#![allow(unreachable_pub)] use super::{ _ctypes::CArgObject, PyCArray, PyCData, PyCPointer, PyCStructure, StgInfo, base::{CDATA_BUFFER_METHODS, FfiArgValue, ParamFunc, StgInfoFlags}, simple::PyCSimple, - type_info, }; use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, @@ -19,13 +19,17 @@ use crate::{ use alloc::borrow::Cow; use core::ffi::c_void; use core::fmt::Debug; -use libffi::{ - low, - middle::{Arg, Cif, Closure, CodePtr, Type}, -}; -use libloading::Symbol; use num_traits::{Signed, ToPrimitive}; use rustpython_common::lock::PyRwLock; +#[cfg(windows)] +use rustpython_host_env::ctypes::ComMethodError; +use rustpython_host_env::ctypes::{ + CallResult as RawResult, FfiCif, FfiCodePtr, FfiType, FfiValue, RawMemoryView, + 
RawMemoryViewError, StringAtError, ffi_f64_type, ffi_i32_type, ffi_pointer_type, + ffi_type_for_return_size, ffi_type_from_code, ffi_type_from_tag, ffi_void_type, + has_pointer_width, null_code_ptr, offset_address, pointer_bytes, pointer_format, pointer_size, + write_pointer_to_buffer_at, write_prefix_limited, +}; // Internal function addresses for special ctypes functions pub(super) const INTERNAL_CAST_ADDR: usize = 1; @@ -33,138 +37,8 @@ pub(super) const INTERNAL_STRING_AT_ADDR: usize = 2; pub(super) const INTERNAL_WSTRING_AT_ADDR: usize = 3; pub(super) const INTERNAL_MEMORYVIEW_AT_ADDR: usize = 4; -// Thread-local errno storage for ctypes -std::thread_local! { - /// Thread-local storage for ctypes errno - /// This is separate from the system errno - ctypes swaps them during FFI calls - /// when use_errno=True is specified. - static CTYPES_LOCAL_ERRNO: core::cell::Cell = const { core::cell::Cell::new(0) }; -} - -/// Get ctypes thread-local errno value -pub(super) fn get_errno_value() -> i32 { - CTYPES_LOCAL_ERRNO.with(|e| e.get()) -} - -/// Set ctypes thread-local errno value, returns old value -pub(super) fn set_errno_value(value: i32) -> i32 { - CTYPES_LOCAL_ERRNO.with(|e| { - let old = e.get(); - e.set(value); - old - }) -} - -/// Save and restore errno around FFI call (called when use_errno=True) -/// Before: restore thread-local errno to system -/// After: save system errno to thread-local -#[cfg(not(windows))] -fn swap_errno(f: F) -> R -where - F: FnOnce() -> R, -{ - // Before call: restore thread-local errno to system - let saved = CTYPES_LOCAL_ERRNO.with(|e| e.get()); - errno::set_errno(errno::Errno(saved)); - - // Call the function - let result = f(); - - // After call: save system errno to thread-local - let new_error = errno::errno().0; - CTYPES_LOCAL_ERRNO.with(|e| e.set(new_error)); - - result -} - -#[cfg(windows)] -std::thread_local! 
{ - /// Thread-local storage for ctypes last_error (Windows only) - static CTYPES_LOCAL_LAST_ERROR: core::cell::Cell = const { core::cell::Cell::new(0) }; -} - -#[cfg(windows)] -pub(super) fn get_last_error_value() -> u32 { - CTYPES_LOCAL_LAST_ERROR.with(|e| e.get()) -} - -#[cfg(windows)] -pub(super) fn set_last_error_value(value: u32) -> u32 { - CTYPES_LOCAL_LAST_ERROR.with(|e| { - let old = e.get(); - e.set(value); - old - }) -} - -/// Save and restore last_error around FFI call (called when use_last_error=True) -#[cfg(windows)] -fn save_and_restore_last_error(f: F) -> R -where - F: FnOnce() -> R, -{ - // Before call: restore thread-local last_error to Windows - let saved = CTYPES_LOCAL_LAST_ERROR.with(|e| e.get()); - unsafe { windows_sys::Win32::Foundation::SetLastError(saved) }; - - // Call the function - let result = f(); - - // After call: save Windows last_error to thread-local - let new_error = unsafe { windows_sys::Win32::Foundation::GetLastError() }; - CTYPES_LOCAL_LAST_ERROR.with(|e| e.set(new_error)); - - result -} - -type FP = unsafe extern "C" fn(); - -/// Get FFI type for a ctypes type code -fn get_ffi_type(ty: &str) -> Option { - type_info(ty).map(|t| (t.ffi_type_fn)()) -} - // PyCFuncPtr - Function pointer implementation -/// Get FFI type from CArgObject tag character -fn ffi_type_from_tag(tag: u8) -> Type { - match tag { - b'c' | b'b' => Type::i8(), - b'B' => Type::u8(), - b'h' => Type::i16(), - b'H' => Type::u16(), - b'i' => Type::i32(), - b'I' => Type::u32(), - b'l' => { - if core::mem::size_of::() == 8 { - Type::i64() - } else { - Type::i32() - } - } - b'L' => { - if core::mem::size_of::() == 8 { - Type::u64() - } else { - Type::u32() - } - } - b'q' => Type::i64(), - b'Q' => Type::u64(), - b'f' => Type::f32(), - b'd' | b'g' => Type::f64(), - b'?' => Type::u8(), - b'u' => { - if core::mem::size_of::() == 2 { - Type::u16() - } else { - Type::u32() - } - } - _ => Type::pointer(), // 'P', 'V', 'z', 'Z', 'O', etc. 
- } -} - /// Convert any object to a pointer value for c_void_p arguments /// Follows ConvParam logic for pointer types fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult { @@ -180,44 +54,44 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult NULL if value.is(&vm.ctx.none) { - return Ok(FfiArgValue::Pointer(0)); + return Ok(FfiArgValue::pointer(0)); } // 2. PyCArray -> buffer address (PyCArrayType_paramfunc) if let Some(array) = value.downcast_ref::() { let addr = array.0.buffer.read().as_ptr() as usize; - return Ok(FfiArgValue::Pointer(addr)); + return Ok(FfiArgValue::pointer(addr)); } // 3. PyCPointer -> stored pointer value if let Some(ptr) = value.downcast_ref::() { - return Ok(FfiArgValue::Pointer(ptr.get_ptr_value())); + return Ok(FfiArgValue::pointer(ptr.get_ptr_value())); } // 4. PyCStructure -> buffer address if let Some(struct_obj) = value.downcast_ref::() { let addr = struct_obj.0.buffer.read().as_ptr() as usize; - return Ok(FfiArgValue::Pointer(addr)); + return Ok(FfiArgValue::pointer(addr)); } // 5. PyCSimple (c_void_p, c_char_p, etc.) -> value from buffer if let Some(simple) = value.downcast_ref::() { let buffer = simple.0.buffer.read(); - if buffer.len() >= core::mem::size_of::() { - let addr = super::base::read_ptr_from_buffer(&buffer); - return Ok(FfiArgValue::Pointer(addr)); + if has_pointer_width(&buffer) { + let addr = rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer); + return Ok(FfiArgValue::pointer(addr)); } } // 6. bytes -> buffer address (PyBytes_AsString) if let Some(bytes) = value.downcast_ref::() { let addr = bytes.as_bytes().as_ptr() as usize; - return Ok(FfiArgValue::Pointer(addr)); + return Ok(FfiArgValue::pointer(addr)); } // 7. Integer -> direct value (PyLong_AsVoidPtr behavior) @@ -226,10 +100,10 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult PyResult { // 2. 
None -> NULL pointer if value.is(&vm.ctx.none) { return Ok(Argument { - ffi_type: Type::pointer(), + ffi_type: ffi_pointer_type(), keep: None, - value: FfiArgValue::Pointer(0), + value: FfiArgValue::pointer(0), }); } @@ -280,33 +154,26 @@ fn conv_param(value: &PyObject, vm: &VirtualMachine) -> PyResult { // 4. Python str -> wide string pointer (like PyUnicode_AsWideCharString) if let Some(s) = value.downcast_ref::() { - // Convert to null-terminated UTF-16, preserving lone surrogates - let wide: Vec = s - .as_wtf8() - .encode_wide() - .chain(core::iter::once(0)) - .collect(); - let wide_bytes: Vec = wide.iter().flat_map(|&x| x.to_ne_bytes()).collect(); + let wide_bytes = rustpython_host_env::ctypes::utf16z_bytes(s.as_wtf8()); let keep = vm.ctx.new_bytes(wide_bytes); let addr = keep.as_bytes().as_ptr() as usize; return Ok(Argument { - ffi_type: Type::pointer(), + ffi_type: ffi_pointer_type(), keep: Some(keep.into()), - value: FfiArgValue::Pointer(addr), + value: FfiArgValue::pointer(addr), }); } // 9. Python bytes -> null-terminated buffer pointer // Need to ensure null termination like c_char_p if let Some(bytes) = value.downcast_ref::() { - let mut buffer = bytes.as_bytes().to_vec(); - buffer.push(0); // Add null terminator + let buffer = rustpython_host_env::ctypes::null_terminated_bytes(bytes.as_bytes()); let keep = vm.ctx.new_bytes(buffer); let addr = keep.as_bytes().as_ptr() as usize; return Ok(Argument { - ffi_type: Type::pointer(), + ffi_type: ffi_pointer_type(), keep: Some(keep.into()), - value: FfiArgValue::Pointer(addr), + value: FfiArgValue::pointer(addr), }); } @@ -314,18 +181,18 @@ fn conv_param(value: &PyObject, vm: &VirtualMachine) -> PyResult { if let Ok(int_val) = value.try_int(vm) { let val = int_val.as_bigint().to_i32().unwrap_or(0); return Ok(Argument { - ffi_type: Type::i32(), + ffi_type: ffi_i32_type(), keep: None, - value: FfiArgValue::I32(val), + value: FfiArgValue::Scalar(FfiValue::I32(val)), }); } // 11. 
Python float -> f64 if let Ok(float_val) = value.try_float(vm) { return Ok(Argument { - ffi_type: Type::f64(), + ffi_type: ffi_f64_type(), keep: None, - value: FfiArgValue::F64(float_val.to_f64()), + value: FfiArgValue::Scalar(FfiValue::F64(float_val.to_f64())), }); } @@ -341,29 +208,29 @@ fn conv_param(value: &PyObject, vm: &VirtualMachine) -> PyResult { } trait ArgumentType { - fn to_ffi_type(&self, vm: &VirtualMachine) -> PyResult; + fn to_ffi_type(&self, vm: &VirtualMachine) -> PyResult; fn convert_object(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult; } impl ArgumentType for PyTypeRef { - fn to_ffi_type(&self, vm: &VirtualMachine) -> PyResult { + fn to_ffi_type(&self, vm: &VirtualMachine) -> PyResult { use super::pointer::PyCPointer; use super::structure::PyCStructure; // CArgObject (from byref()) should be treated as pointer if self.fast_issubclass(CArgObject::static_type()) { - return Ok(Type::pointer()); + return Ok(ffi_pointer_type()); } // Pointer types (POINTER(T)) are always pointer FFI type // Check if type is a subclass of _Pointer (PyCPointer) if self.fast_issubclass(PyCPointer::static_type()) { - return Ok(Type::pointer()); + return Ok(ffi_pointer_type()); } // Structure types are passed as pointers if self.fast_issubclass(PyCStructure::static_type()) { - return Ok(Type::pointer()); + return Ok(ffi_pointer_type()); } // Use get_attr to traverse MRO (for subclasses like MyInt(c_int)) @@ -377,7 +244,7 @@ impl ArgumentType for PyTypeRef { .ok_or_else(|| vm.new_type_error("Unsupported argument type"))?; let typ = typ.to_string(); let typ = typ.as_str(); - get_ffi_type(typ) + ffi_type_from_code(typ) .ok_or_else(|| vm.new_type_error(format!("Unsupported argument type: {typ}"))) } @@ -398,13 +265,13 @@ impl ArgumentType for PyTypeRef { // None -> NULL pointer if vm.is_none(&converted) { - return Ok(FfiArgValue::Pointer(0)); + return Ok(FfiArgValue::pointer(0)); } // For pointer types (POINTER(T)), we need to pass the pointer VALUE stored in 
buffer if self.fast_issubclass(PyCPointer::static_type()) { if let Some(pointer) = converted.downcast_ref::() { - return Ok(FfiArgValue::Pointer(pointer.get_ptr_value())); + return Ok(FfiArgValue::pointer(pointer.get_ptr_value())); } return convert_to_pointer(&converted, vm); } @@ -440,15 +307,15 @@ impl ArgumentType for PyTypeRef { } trait ReturnType { - fn to_ffi_type(&self, vm: &VirtualMachine) -> Option; + fn to_ffi_type(&self, vm: &VirtualMachine) -> Option; } impl ReturnType for PyTypeRef { - fn to_ffi_type(&self, vm: &VirtualMachine) -> Option { + fn to_ffi_type(&self, vm: &VirtualMachine) -> Option { // Try to get _type_ attribute first (for ctypes types like c_void_p) if let Ok(type_attr) = self.as_object().get_attr(vm.ctx.intern_str("_type_"), vm) && let Some(s) = type_attr.downcast_ref::() - && let Some(ffi_type) = s.to_str().and_then(get_ffi_type) + && let Some(ffi_type) = s.to_str().and_then(ffi_type_from_code) { return Some(ffi_type); } @@ -459,25 +326,17 @@ impl ReturnType for PyTypeRef { let size = stg_info.size; // Small structs can be returned in registers // Match can_return_struct_as_int/can_return_struct_as_sint64 - return Some(if size <= 4 { - Type::i32() - } else if size <= 8 { - Type::i64() - } else { - // Large structs: use pointer-sized return - // (ABI typically returns via hidden pointer parameter) - Type::pointer() - }); + return Some(ffi_type_for_return_size(size)); } // Fallback to class name - get_ffi_type(self.name().to_string().as_str()) + ffi_type_from_code(self.name().to_string().as_str()) } } impl ReturnType for PyNone { - fn to_ffi_type(&self, _vm: &VirtualMachine) -> Option { - get_ffi_type("void") + fn to_ffi_type(&self, _vm: &VirtualMachine) -> Option { + ffi_type_from_code("void") } } @@ -500,7 +359,7 @@ impl Initializer for PyCFuncPtrType { new_type.check_not_initialized(vm)?; - let ptr_size = core::mem::size_of::(); + let ptr_size = pointer_size(); let mut stg_info = StgInfo::new(ptr_size, ptr_size); stg_info.format = 
Some("X{}".to_string()); stg_info.length = 1; @@ -577,8 +436,8 @@ fn extract_ptr_from_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult if carg.offset != 0 && let Some(cdata) = carg.obj.downcast_ref::() { - let base = cdata.buffer.read().as_ptr() as isize; - return Ok((base + carg.offset) as usize); + let base = cdata.buffer.read().as_ptr() as usize; + return Ok(offset_address(base, carg.offset)); } return extract_ptr_from_arg(&carg.obj, vm); } @@ -588,8 +447,10 @@ fn extract_ptr_from_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult } if let Some(simple) = arg.downcast_ref::() { let buffer = simple.0.buffer.read(); - if let Some(&bytes) = buffer.first_chunk::<{ size_of::() }>() { - return Ok(usize::from_ne_bytes(bytes)); + if buffer.first_chunk::<{ size_of::() }>().is_some() { + return Ok(rustpython_host_env::ctypes::read_pointer_from_buffer( + &buffer, + )); } } if let Some(cdata) = arg.downcast_ref::() { @@ -616,63 +477,19 @@ fn extract_ptr_from_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult /// string_at implementation - read bytes from memory at ptr fn string_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { - if ptr == 0 { - return Err(vm.new_value_error("NULL pointer access")); + match rustpython_host_env::ctypes::string_at(ptr, size) { + Ok(bytes) => Ok(vm.ctx.new_bytes(bytes).into()), + Err(StringAtError::NullPointer) => Err(vm.new_value_error("NULL pointer access")), + Err(StringAtError::TooLong) => Err(vm.new_overflow_error("string too long")), } - let ptr = ptr as *const u8; - let len = if size < 0 { - // size == -1 means use strlen - unsafe { libc::strlen(ptr as _) } - } else { - // Overflow check for huge size values - let size_usize = size as usize; - if size_usize > isize::MAX as usize / 2 { - return Err(vm.new_overflow_error("string too long")); - } - size_usize - }; - let bytes = unsafe { core::slice::from_raw_parts(ptr, len) }; - Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) } /// wstring_at implementation - read wide 
string from memory at ptr fn wstring_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { - if ptr == 0 { - return Err(vm.new_value_error("NULL pointer access")); - } - let w_ptr = ptr as *const libc::wchar_t; - let len = if size < 0 { - unsafe { libc::wcslen(w_ptr) } - } else { - // Overflow check for huge size values - let size_usize = size as usize; - if size_usize > isize::MAX as usize / core::mem::size_of::() { - return Err(vm.new_overflow_error("string too long")); - } - size_usize - }; - let wchars = unsafe { core::slice::from_raw_parts(w_ptr, len) }; - - // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide - // macOS/Linux: wchar_t = i32 (UTF-32) -> convert via char::from_u32 - cfg_select! { - windows => { - use rustpython_common::wtf8::Wtf8Buf; - let wide: Vec = wchars.to_vec(); - let wtf8 = Wtf8Buf::from_wide(&wide); - Ok(vm.ctx.new_str(wtf8).into()) - } - _ => { - #[allow( - clippy::useless_conversion, - reason = "wchar_t is i32 on some platforms and u32 on others" - )] - let s: String = wchars - .iter() - .filter_map(|&c| u32::try_from(c).ok().and_then(char::from_u32)) - .collect(); - Ok(vm.ctx.new_str(s).into()) - } + match rustpython_host_env::ctypes::wstring_at(ptr, size) { + Ok(text) => Ok(vm.ctx.new_str(text).into()), + Err(StringAtError::NullPointer) => Err(vm.new_value_error("NULL pointer access")), + Err(StringAtError::TooLong) => Err(vm.new_overflow_error("string too long")), } } @@ -680,24 +497,18 @@ fn wstring_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { #[pyclass(name = "_RawMemoryBuffer", module = "_ctypes")] #[derive(Debug, PyPayload)] pub(super) struct RawMemoryBuffer { - ptr: *const u8, - size: usize, - readonly: bool, + memory: RawMemoryView, } -// SAFETY: The caller ensures the pointer remains valid -unsafe impl Send for RawMemoryBuffer {} -unsafe impl Sync for RawMemoryBuffer {} - static RAW_MEMORY_BUFFER_METHODS: crate::protocol::BufferMethods = crate::protocol::BufferMethods { obj_bytes: 
|buffer| { let raw = buffer.obj_as::(); - let slice = unsafe { core::slice::from_raw_parts(raw.ptr, raw.size) }; + let slice = unsafe { raw.memory.bytes() }; rustpython_common::borrow::BorrowedValue::Ref(slice) }, obj_bytes_mut: |buffer| { let raw = buffer.obj_as::(); - let slice = unsafe { core::slice::from_raw_parts_mut(raw.ptr as *mut u8, raw.size) }; + let slice = unsafe { raw.memory.bytes_mut() }; rustpython_common::borrow::BorrowedValueMut::RefMut(slice) }, release: |_| {}, @@ -711,7 +522,7 @@ impl AsBuffer for RawMemoryBuffer { fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { Ok(PyBuffer::new( zelf.to_owned().into(), - BufferDescriptor::simple(zelf.size, zelf.readonly), + BufferDescriptor::simple(zelf.memory.size(), zelf.memory.readonly()), &RAW_MEMORY_BUFFER_METHODS, )) } @@ -721,19 +532,11 @@ impl AsBuffer for RawMemoryBuffer { fn memoryview_at_impl(ptr: usize, size: isize, readonly: bool, vm: &VirtualMachine) -> PyResult { use crate::builtins::PyMemoryView; - if ptr == 0 { - return Err(vm.new_value_error("NULL pointer access")); - } - if size < 0 { - return Err(vm.new_value_error("negative size")); - } - let len = size as usize; - let raw_buf = RawMemoryBuffer { - ptr: ptr as *const u8, - size: len, - readonly, - } - .into_pyobject(vm); + let memory = RawMemoryView::new(ptr, size, readonly).map_err(|err| match err { + RawMemoryViewError::NullPointer => vm.new_value_error("NULL pointer access"), + RawMemoryViewError::NegativeSize => vm.new_value_error("negative size"), + })?; + let raw_buf = RawMemoryBuffer { memory }.into_pyobject(vm); let mv = PyMemoryView::from_object(&raw_buf, vm)?; Ok(mv.into_pyobject(vm)) } @@ -799,7 +602,7 @@ pub(super) fn cast_impl( } else if let Some(simple) = obj.downcast_ref::() { // Simple type (c_void_p, c_char_p, etc.) 
→ value from buffer let buffer = simple.0.buffer.read(); - super::base::read_ptr_from_buffer(&buffer) + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) } else if let Some(cdata) = obj.downcast_ref::() { // Array, Structure, Union → buffer address (b_ptr) cdata.buffer.read().as_ptr() as usize @@ -858,12 +661,8 @@ pub(super) fn cast_impl( if let Some(ptr) = result.downcast_ref::() { ptr.set_ptr_value(ptr_value); } else if let Some(cdata) = result.downcast_ref::() { - let bytes = ptr_value.to_ne_bytes(); let mut buffer = cdata.buffer.write(); - let buf = buffer.to_mut(); - if buf.len() >= bytes.len() { - buf[..bytes.len()].copy_from_slice(&bytes); - } + write_pointer_to_buffer_at(buffer.to_mut(), 0, pointer_size(), ptr_value); } Ok(result) @@ -873,22 +672,18 @@ impl PyCFuncPtr { /// Get function pointer address from buffer fn get_func_ptr(&self) -> usize { let buffer = self._base.buffer.read(); - super::base::read_ptr_from_buffer(&buffer) + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) } /// Get CodePtr from buffer for FFI calls - fn get_code_ptr(&self) -> Option { + fn get_code_ptr(&self) -> Option { let addr = self.get_func_ptr(); - if addr != 0 { - Some(CodePtr(addr as *mut _)) - } else { - None - } + rustpython_host_env::ctypes::code_ptr_from_addr(addr) } /// Create buffer with function pointer address fn make_ptr_buffer(addr: usize) -> Vec { - addr.to_ne_bytes().to_vec() + pointer_bytes(addr) } } @@ -902,7 +697,7 @@ impl Constructor for PyCFuncPtr { // 3. Tuple argument: (name, dll) form // 4. 
Callable: callback creation - let ptr_size = core::mem::size_of::(); + let ptr_size = pointer_size(); if args.args.is_empty() { return Self { @@ -1017,32 +812,26 @@ impl Constructor for PyCFuncPtr { .as_bigint() .clone(), }; - let library_cache = super::library::libcache().read(); - let library = library_cache - .get_lib( - handle - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid handle"))?, - ) - .ok_or_else(|| vm.new_value_error("Library not found"))?; - let inner_lib = library.lib.lock(); - let terminated = format!("{}\0", &name); - let ptr_val = if let Some(lib) = &*inner_lib { - let pointer: Symbol<'_, FP> = unsafe { - lib.get(terminated.as_bytes()) - .map_err(|err| err.to_string()) - .map_err(|err| vm.new_attribute_error(err))? - }; - let addr = *pointer as usize; - // dlsym can return NULL for symbols that resolve to NULL (e.g., GNU IFUNC) - // Treat NULL addresses as errors - if addr == 0 { - return Err(vm.new_attribute_error(format!("function '{name}' not found"))); + let ptr_val = match rustpython_host_env::ctypes::lookup_function_symbol_addr( + handle + .to_usize() + .ok_or_else(|| vm.new_value_error("Invalid handle"))?, + terminated.as_bytes(), + ) { + Ok(addr) => { + if addr == 0 { + return Err(vm.new_attribute_error(format!("function '{name}' not found"))); + } + addr + } + Err(rustpython_host_env::ctypes::LookupSymbolError::LibraryNotFound) => { + return Err(vm.new_value_error("Library not found")); + } + Err(rustpython_host_env::ctypes::LookupSymbolError::LibraryClosed) => 0, + Err(rustpython_host_env::ctypes::LookupSymbolError::Load(err)) => { + return Err(vm.new_attribute_error(err)); } - addr - } else { - 0 }; return Self { @@ -1176,7 +965,7 @@ struct CallInfo { explicit_arg_types: Option>, restype_obj: Option, restype_is_none: bool, - ffi_return_type: Type, + ffi_return_type: FfiType, is_pointer_return: bool, } @@ -1223,13 +1012,13 @@ fn extract_call_info(zelf: &Py, vm: &VirtualMachine) -> PyResult().ok()) .and_then(|t| 
ReturnType::to_ffi_type(&t, vm)) - .unwrap_or_else(Type::i32) + .unwrap_or_else(ffi_i32_type) }; // Check if return type is a pointer type via TYPEFLAG_ISPOINTER @@ -1300,7 +1089,7 @@ fn resolve_com_method( zelf: &Py, args: &FuncArgs, vm: &VirtualMachine, -) -> PyResult<(Option, bool)> { +) -> PyResult<(Option, bool)> { let com_index = zelf.index.read(); let Some(idx) = *com_index else { return Ok((None, false)); @@ -1315,8 +1104,8 @@ fn resolve_com_method( let self_arg = &args.args[0]; let com_ptr = if let Some(simple) = self_arg.downcast_ref::() { let buffer = simple.0.buffer.read(); - if buffer.len() >= core::mem::size_of::() { - super::base::read_ptr_from_buffer(&buffer) + if has_pointer_width(&buffer) { + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) } else { 0 } @@ -1326,33 +1115,26 @@ fn resolve_com_method( return Err(vm.new_type_error("COM method first argument must be a COM pointer")); }; - if com_ptr == 0 { - return Err(vm.new_value_error("NULL COM pointer access")); - } - - // Read vtable pointer from COM object: vtable = *(void**)com_ptr - let vtable_ptr = unsafe { *(com_ptr as *const usize) }; - if vtable_ptr == 0 { - return Err(vm.new_value_error("NULL vtable pointer")); - } - - // Read function pointer from vtable: func = vtable[index] - let fptr = unsafe { - let vtable = vtable_ptr as *const usize; - *vtable.add(idx) + let code_ptr = match rustpython_host_env::ctypes::resolve_com_vtable_entry(com_ptr, idx) { + Ok(code_ptr) => code_ptr, + Err(ComMethodError::NullComPointer) => { + return Err(vm.new_value_error("NULL COM pointer access")); + } + Err(ComMethodError::NullVtablePointer) => { + return Err(vm.new_value_error("NULL vtable pointer")); + } + Err(ComMethodError::NullFunctionPointer) => { + return Err(vm.new_value_error("NULL function pointer in vtable")); + } }; - if fptr == 0 { - return Err(vm.new_value_error("NULL function pointer in vtable")); - } - - Ok((Some(CodePtr(fptr as *mut _)), true)) + Ok((Some(code_ptr), true)) } 
/// Single argument for FFI call // struct argument struct Argument { - ffi_type: Type, + ffi_type: FfiType, value: FfiArgValue, #[allow(dead_code)] keep: Option, // Object to keep alive during call @@ -1468,7 +1250,7 @@ fn build_callargs_with_paramflags( arguments.push(Argument { ffi_type, keep: None, - value: FfiArgValue::Pointer(addr), + value: FfiArgValue::pointer(addr), }); out_buffers.push((param_idx, buffer)); } else { @@ -1539,29 +1321,22 @@ fn build_callargs( } } -/// Raw result from FFI call -enum RawResult { - Void, - Pointer(usize), - Value(libffi::low::ffi_arg), -} - /// Execute FFI call -fn ctypes_callproc(code_ptr: CodePtr, arguments: &[Argument], call_info: &CallInfo) -> RawResult { - let ffi_arg_types: Vec = arguments.iter().map(|a| a.ffi_type.clone()).collect(); - let cif = Cif::new(ffi_arg_types, call_info.ffi_return_type.clone()); - let ffi_args: Vec> = arguments.iter().map(|a| a.value.as_arg()).collect(); - - if call_info.restype_is_none { - unsafe { cif.call::<()>(code_ptr, &ffi_args) }; - RawResult::Void - } else if call_info.is_pointer_return { - let result = unsafe { cif.call::(code_ptr, &ffi_args) }; - RawResult::Pointer(result) - } else { - let result = unsafe { cif.call::(code_ptr, &ffi_args) }; - RawResult::Value(result) - } +fn ctypes_callproc( + code_ptr: FfiCodePtr, + arguments: &[Argument], + call_info: &CallInfo, +) -> RawResult { + let ffi_arg_types: Vec = arguments.iter().map(|a| a.ffi_type.clone()).collect(); + let ffi_args: Vec<_> = arguments.iter().map(|a| a.value.as_arg()).collect(); + rustpython_host_env::ctypes::callproc( + code_ptr, + ffi_arg_types, + call_info.ffi_return_type.clone(), + &ffi_args, + call_info.restype_is_none, + call_info.is_pointer_return, + ) } /// Check and handle HRESULT errors (Windows) @@ -1607,17 +1382,7 @@ fn convert_raw_result( vm: &VirtualMachine, ) -> Option { // Get result as bytes for type conversion - let (result_bytes, result_size) = match raw_result { - RawResult::Void => return None, - 
RawResult::Pointer(ptr) => { - let bytes = ptr.to_ne_bytes(); - (bytes.to_vec(), core::mem::size_of::()) - } - RawResult::Value(val) => { - let bytes = val.to_ne_bytes(); - (bytes.to_vec(), core::mem::size_of::()) - } - }; + let (result_bytes, result_size) = rustpython_host_env::ctypes::call_result_bytes(raw_result)?; // 1. No restype → return as int let restype = match &call_info.restype_obj { @@ -1670,7 +1435,11 @@ fn convert_raw_result( // 5. Simple type with getfunc → use bytes_to_pyobject (info->getfunc) // is_simple_instance returns TRUE for c_int, c_void_p, etc. if super::base::is_simple_instance(&restype_type) { - return super::base::bytes_to_pyobject(&restype_type, &result_bytes, vm).ok(); + return Some(super::base::bytes_to_pyobject( + &restype_type, + &result_bytes, + vm, + )); } // 6. Complex type → create ctypes instance (PyCData_FromBaseObj) @@ -1705,10 +1474,7 @@ fn pycdata_from_ffi_result( // Copy result data into instance buffer if let Some(cdata) = instance.downcast_ref::() { let mut buffer = cdata.buffer.write(); - let copy_size = size.min(buffer.len()).min(result_bytes.len()); - if copy_size > 0 { - buffer.to_mut()[..copy_size].copy_from_slice(&result_bytes[..copy_size]); - } + write_prefix_limited(buffer.to_mut(), result_bytes, size); } Ok(instance) @@ -1782,7 +1548,7 @@ impl Callable for PyCFuncPtr { #[cfg(windows)] let (func_ptr, is_com_method) = resolve_com_method(zelf, &args, vm)?; #[cfg(not(windows))] - let (func_ptr, is_com_method) = (None::, false); + let (func_ptr, is_com_method) = (None::, false); // 3. 
Extract call info (argtypes, restype) let call_info = extract_call_info(zelf, vm)?; @@ -1800,7 +1566,7 @@ impl Callable for PyCFuncPtr { None => { debug_assert!(false, "NULL function pointer"); // In release mode, this will crash - CodePtr(core::ptr::null_mut()) + null_code_ptr() } }; @@ -1811,7 +1577,9 @@ impl Callable for PyCFuncPtr { #[cfg(not(windows))] let raw_result = { if flags & super::base::StgInfoFlags::FUNCFLAG_USE_ERRNO.bits() != 0 { - swap_errno(|| ctypes_callproc(code_ptr, &arguments, &call_info)) + rustpython_host_env::ctypes::with_swapped_errno(|| { + ctypes_callproc(code_ptr, &arguments, &call_info) + }) } else { ctypes_callproc(code_ptr, &arguments, &call_info) } @@ -1820,7 +1588,9 @@ impl Callable for PyCFuncPtr { #[cfg(windows)] let raw_result = { if flags & super::base::StgInfoFlags::FUNCFLAG_USE_LASTERROR.bits() != 0 { - save_and_restore_last_error(|| ctypes_callproc(code_ptr, &arguments, &call_info)) + rustpython_host_env::ctypes::with_swapped_last_error(|| { + ctypes_callproc(code_ptr, &arguments, &call_info) + }) } else { ctypes_callproc(code_ptr, &arguments, &call_info) } @@ -1855,7 +1625,7 @@ impl AsBuffer for PyCFuncPtr { stg_info.size, ) } else { - (Cow::Borrowed("X{}"), core::mem::size_of::()) + (Cow::Borrowed(pointer_format()), pointer_size()) }; let desc = BufferDescriptor { len: itemsize, @@ -1984,71 +1754,24 @@ fn is_simple_subclass(ty: &Py, vm: &VirtualMachine) -> bool { } /// Convert a C value to a Python object based on the type code. 
-fn ffi_to_python(ty: &Py, ptr: *const c_void, vm: &VirtualMachine) -> PyObjectRef { +fn ffi_to_python( + ty: &Py, + args: *const *const c_void, + index: usize, + vm: &VirtualMachine, +) -> PyObjectRef { let type_code = ty.type_code(vm); - let raw_value: PyObjectRef = unsafe { - match type_code.as_deref() { - Some("b") => vm.ctx.new_int(*(ptr as *const i8) as i32).into(), - Some("B") => vm.ctx.new_int(*(ptr as *const u8) as i32).into(), - Some("c") => vm.ctx.new_bytes(vec![*(ptr as *const u8)]).into(), - Some("h") => vm.ctx.new_int(*(ptr as *const i16) as i32).into(), - Some("H") => vm.ctx.new_int(*(ptr as *const u16) as i32).into(), - Some("i") => vm.ctx.new_int(*(ptr as *const i32)).into(), - Some("I") => vm.ctx.new_int(*(ptr as *const u32)).into(), - Some("l") => vm.ctx.new_int(*(ptr as *const libc::c_long)).into(), - Some("L") => vm.ctx.new_int(*(ptr as *const libc::c_ulong)).into(), - Some("q") => vm.ctx.new_int(*(ptr as *const libc::c_longlong)).into(), - Some("Q") => vm.ctx.new_int(*(ptr as *const libc::c_ulonglong)).into(), - Some("f") => vm.ctx.new_float(*(ptr as *const f32) as f64).into(), - Some("d") => vm.ctx.new_float(*(ptr as *const f64)).into(), - Some("z") => { - // c_char_p: C string pointer → Python bytes - let cstr_ptr = *(ptr as *const *const libc::c_char); - if cstr_ptr.is_null() { - vm.ctx.none() - } else { - let cstr = core::ffi::CStr::from_ptr(cstr_ptr); - vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into() - } - } - Some("Z") => { - // c_wchar_p: wchar_t* → Python str - let wstr_ptr = *(ptr as *const *const libc::wchar_t); - if wstr_ptr.is_null() { - vm.ctx.none() - } else { - let mut len = 0; - while *wstr_ptr.add(len) != 0 { - len += 1; - } - let slice = core::slice::from_raw_parts(wstr_ptr, len); - // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide - // Unix: wchar_t = i32 (UTF-32) -> convert via char::from_u32 - cfg_select! 
{ - windows => { - use rustpython_common::wtf8::Wtf8Buf; - let wide: Vec = slice.to_vec(); - let wtf8 = Wtf8Buf::from_wide(&wide); - vm.ctx.new_str(wtf8).into() - } - _ => { - #[allow( - clippy::useless_conversion, - reason = "wchar_t is i32 on some platforms and u32 on others" - )] - let s: String = slice - .iter() - .filter_map(|&c| u32::try_from(c).ok().and_then(char::from_u32)) - .collect(); - vm.ctx.new_str(s).into() - } - } - } - } - Some("P") => vm.ctx.new_int(*(ptr as *const usize)).into(), - Some("?") => vm.ctx.new_bool(*(ptr as *const u8) != 0).into(), - _ => return vm.ctx.none(), - } + let raw_value: PyObjectRef = match unsafe { + rustpython_host_env::ctypes::callback_arg_value_at(type_code.as_deref(), args, index) + } { + rustpython_host_env::ctypes::DecodedValue::Bytes(value) => vm.ctx.new_bytes(value).into(), + rustpython_host_env::ctypes::DecodedValue::Signed(value) => vm.ctx.new_int(value).into(), + rustpython_host_env::ctypes::DecodedValue::Unsigned(value) => vm.ctx.new_int(value).into(), + rustpython_host_env::ctypes::DecodedValue::Float(value) => vm.ctx.new_float(value).into(), + rustpython_host_env::ctypes::DecodedValue::Bool(value) => vm.ctx.new_bool(value).into(), + rustpython_host_env::ctypes::DecodedValue::Pointer(value) => vm.ctx.new_int(value).into(), + rustpython_host_env::ctypes::DecodedValue::String(value) => vm.ctx.new_str(value).into(), + rustpython_host_env::ctypes::DecodedValue::None => vm.ctx.none(), }; if !is_simple_subclass(ty, vm) { @@ -2064,117 +1787,92 @@ fn python_to_ffi(obj: PyResult, ty: &Py, result: *mut c_void, vm: &Virtu let Ok(obj) = obj else { return }; let type_code = ty.type_code(vm); - unsafe { - match type_code.as_deref() { - Some("b") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i8) = i.as_bigint().to_i8().unwrap_or(0); - } - } - Some("B") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u8) = i.as_bigint().to_u8().unwrap_or(0); - } - } - Some("c") => { - if let Ok(i) = obj.try_int(vm) { - 
*(result as *mut u8) = i.as_bigint().to_u8().unwrap_or(0); - } - } - Some("h") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i16) = i.as_bigint().to_i16().unwrap_or(0); - } - } - Some("H") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u16) = i.as_bigint().to_u16().unwrap_or(0); - } - } - Some("i") => { - if let Ok(i) = obj.try_int(vm) { - let val = i.as_bigint().to_i32().unwrap_or(0); - *(result as *mut libffi::low::ffi_arg) = val as libffi::low::ffi_arg; + match type_code.as_deref() { + Some("b" | "h" | "i" | "l" | "q") => { + if let Ok(i) = obj.try_int(vm) { + unsafe { + rustpython_host_env::ctypes::write_callback_result( + type_code.as_deref(), + result, + rustpython_host_env::ctypes::CallbackResultValue::Signed( + i.as_bigint().to_i64().unwrap_or(0), + ), + ); } } - Some("I") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u32) = i.as_bigint().to_u32().unwrap_or(0); - } - } - Some("l" | "q") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i64) = i.as_bigint().to_i64().unwrap_or(0); - } - } - Some("L" | "Q") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u64) = i.as_bigint().to_u64().unwrap_or(0); - } - } - Some("f") => { - if let Ok(f) = obj.try_float(vm) { - *(result as *mut f32) = f.to_f64() as f32; + } + Some("B" | "c" | "H" | "I" | "L" | "Q") => { + if let Ok(i) = obj.try_int(vm) { + unsafe { + rustpython_host_env::ctypes::write_callback_result( + type_code.as_deref(), + result, + rustpython_host_env::ctypes::CallbackResultValue::Unsigned( + i.as_bigint().to_u64().unwrap_or(0), + ), + ); } } - Some("d") => { - if let Ok(f) = obj.try_float(vm) { - *(result as *mut f64) = f.to_f64(); + } + Some("f" | "d") => { + if let Ok(f) = obj.try_float(vm) { + unsafe { + rustpython_host_env::ctypes::write_callback_result( + type_code.as_deref(), + result, + rustpython_host_env::ctypes::CallbackResultValue::Float(f.to_f64()), + ); } } - Some("P" | "z" | "Z") => { - if let Ok(i) = obj.try_int(vm) { - *(result as 
*mut usize) = i.as_bigint().to_usize().unwrap_or(0); + } + Some("P" | "z" | "Z") => { + if let Ok(i) = obj.try_int(vm) { + unsafe { + rustpython_host_env::ctypes::write_callback_result( + type_code.as_deref(), + result, + rustpython_host_env::ctypes::CallbackResultValue::Pointer( + i.as_bigint().to_usize().unwrap_or(0), + ), + ); } } - Some("?") => { - if let Ok(b) = obj.is_true(vm) { - *(result as *mut u8) = u8::from(b); + } + Some("?") => { + if let Ok(b) = obj.is_true(vm) { + unsafe { + rustpython_host_env::ctypes::write_callback_result( + type_code.as_deref(), + result, + rustpython_host_env::ctypes::CallbackResultValue::Bool(b), + ); } } - _ => {} } + _ => {} } } /// The callback function that libffi calls when the closure is invoked. unsafe extern "C" fn thunk_callback( - _cif: &low::ffi_cif, + _cif: &FfiCif, result: &mut c_void, args: *const *const c_void, userdata: &ThunkUserData, ) { with_current_vm(|vm| { - // Swap errno before call if FUNCFLAG_USE_ERRNO is set let use_errno = userdata.flags & StgInfoFlags::FUNCFLAG_USE_ERRNO.bits() != 0; - let saved_errno = if use_errno { - let current = rustpython_host_env::os::get_errno(); - // TODO: swap with ctypes stored errno (thread-local) - Some(current) - } else { - None - }; - - let py_args: Vec = userdata - .arg_types - .iter() - .enumerate() - .map(|(i, ty)| { - let arg_ptr = unsafe { *args.add(i) }; - ffi_to_python(ty, arg_ptr, vm) - }) - .collect(); - - let py_result = userdata.callable.call(py_args, vm); - - // Swap errno back after call - if use_errno { - let _current = rustpython_host_env::os::get_errno(); - // TODO: store current errno to ctypes storage - if let Some(saved) = saved_errno { - rustpython_host_env::os::set_errno(saved); - } - } + let py_result = + rustpython_host_env::ctypes::with_callback_errno_preserved(use_errno, || { + let py_args: Vec = userdata + .arg_types + .iter() + .enumerate() + .map(|(i, ty)| ffi_to_python(ty, args, i, vm)) + .collect(); + + userdata.callable.call(py_args, vm) 
+ }); // Call unraisable hook if exception occurred if let Err(exc) = &py_result { @@ -2192,29 +1890,14 @@ unsafe extern "C" fn thunk_callback( }); } -/// Holds the closure and userdata together to ensure proper lifetime. -struct ThunkData { - #[allow(dead_code)] - closure: Closure<'static>, - userdata_ptr: *mut ThunkUserData, -} - -impl Drop for ThunkData { - fn drop(&mut self) { - unsafe { - drop(Box::from_raw(self.userdata_ptr)); - } - } -} - /// CThunkObject wraps a Python callable to make it callable from C code. #[pyclass(name = "CThunkObject", module = "_ctypes")] #[derive(PyPayload)] pub(super) struct PyCThunk { callable: PyObjectRef, #[allow(dead_code)] - thunk_data: PyRwLock>, - code_ptr: CodePtr, + thunk_data: PyRwLock>>, + code_ptr: FfiCodePtr, } impl Debug for PyCThunk { @@ -2226,7 +1909,7 @@ impl Debug for PyCThunk { } impl PyCThunk { - pub(super) fn new( + pub fn new( callable: PyObjectRef, arg_types: Option, res_type: Option, @@ -2254,39 +1937,33 @@ impl PyCThunk { _ => None, }; - let ffi_arg_types: Vec = arg_type_vec + let ffi_arg_types: Vec = arg_type_vec .iter() .map(|ty| { ty.type_code(vm) - .and_then(|code| get_ffi_type(&code)) - .unwrap_or_else(Type::pointer) + .and_then(|code| ffi_type_from_code(&code)) + .unwrap_or_else(ffi_pointer_type) }) .collect(); let ffi_res_type = res_type_ref .as_ref() .and_then(|ty| ty.type_code(vm)) - .and_then(|code| get_ffi_type(&code)) - .unwrap_or_else(Type::void); - - let cif = Cif::new(ffi_arg_types, ffi_res_type); - - let userdata = Box::new(ThunkUserData { - callable: callable.clone(), - arg_types: arg_type_vec, - res_type: res_type_ref, - flags, - }); - let userdata_ptr = Box::into_raw(userdata); - let userdata_ref: &'static ThunkUserData = unsafe { &*userdata_ptr }; - - let closure = Closure::new(cif, thunk_callback, userdata_ref); - let code_ptr = CodePtr(*closure.code_ptr() as *mut _); - - let thunk_data = ThunkData { - closure, - userdata_ptr, - }; + .and_then(|code| ffi_type_from_code(&code)) + 
.unwrap_or_else(ffi_void_type); + + let thunk_data = rustpython_host_env::ctypes::CallbackThunk::new( + ffi_arg_types, + ffi_res_type, + Box::new(ThunkUserData { + callable: callable.clone(), + arg_types: arg_type_vec, + res_type: res_type_ref, + flags, + }), + thunk_callback, + ); + let code_ptr = thunk_data.code_ptr(); Ok(Self { callable, @@ -2295,7 +1972,7 @@ impl PyCThunk { }) } - pub(super) fn code_ptr(&self) -> CodePtr { + pub fn code_ptr(&self) -> FfiCodePtr { self.code_ptr } } diff --git a/crates/vm/src/stdlib/_ctypes/library.rs b/crates/vm/src/stdlib/_ctypes/library.rs deleted file mode 100644 index ac9059864d6..00000000000 --- a/crates/vm/src/stdlib/_ctypes/library.rs +++ /dev/null @@ -1,150 +0,0 @@ -use crate::VirtualMachine; -use alloc::fmt; -use libloading::Library; -use rustpython_common::lock::{PyMutex, PyRwLock}; -use std::collections::HashMap; -use std::ffi::OsStr; - -#[cfg(unix)] -use libloading::os::unix::Library as UnixLibrary; - -pub(super) struct SharedLibrary { - pub(crate) lib: PyMutex>, -} - -impl fmt::Debug for SharedLibrary { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SharedLibrary") - } -} - -impl SharedLibrary { - #[cfg(windows)] - pub(super) fn new(name: impl AsRef) -> Result { - Ok(Self { - lib: PyMutex::new(unsafe { Some(Library::new(name.as_ref())?) 
}), - }) - } - - #[cfg(unix)] - pub(super) fn new_with_mode( - name: impl AsRef, - mode: i32, - ) -> Result { - Ok(Self { - lib: PyMutex::new(Some(unsafe { - UnixLibrary::open(Some(name.as_ref()), mode)?.into() - })), - }) - } - - /// Create a SharedLibrary from a raw dlopen handle (for pythonapi / dlopen(NULL)) - #[cfg(unix)] - pub(super) fn from_raw_handle(handle: *mut libc::c_void) -> Self { - Self { - lib: PyMutex::new(Some(unsafe { UnixLibrary::from_raw(handle).into() })), - } - } - - /// Get the underlying OS handle (HMODULE on Windows, dlopen handle on Unix) - pub(super) fn get_pointer(&self) -> usize { - let lib_lock = self.lib.lock(); - if let Some(l) = &*lib_lock { - // libloading::Library internally stores the OS handle directly - // On Windows: HMODULE (*mut c_void) - // On Unix: *mut c_void from dlopen - // We use transmute_copy to read the handle without consuming the Library - unsafe { core::mem::transmute_copy::(l) } - } else { - 0 - } - } - - fn is_closed(&self) -> bool { - let lib_lock = self.lib.lock(); - lib_lock.is_none() - } -} - -pub(super) struct ExternalLibs { - libraries: HashMap, -} - -impl ExternalLibs { - fn new() -> Self { - Self { - libraries: HashMap::new(), - } - } - - pub(super) fn get_lib(&self, key: usize) -> Option<&SharedLibrary> { - self.libraries.get(&key) - } - - #[cfg(windows)] - pub(super) fn get_or_insert_lib( - &mut self, - library_path: impl AsRef, - _vm: &VirtualMachine, - ) -> Result<(usize, &SharedLibrary), libloading::Error> { - let new_lib = SharedLibrary::new(library_path)?; - let key = new_lib.get_pointer(); - - // Check if library already exists and is not closed - let should_use_cached = self.libraries.get(&key).is_some_and(|l| !l.is_closed()); - - if should_use_cached { - // new_lib will be dropped, calling FreeLibrary (decrements refcount) - // But library stays loaded because cached version maintains refcount - drop(new_lib); - return Ok((key, self.libraries.get(&key).expect("just checked"))); - } - - 
self.libraries.insert(key, new_lib); - Ok((key, self.libraries.get(&key).expect("just inserted"))) - } - - #[cfg(unix)] - pub(super) fn get_or_insert_lib_with_mode( - &mut self, - library_path: impl AsRef, - mode: i32, - _vm: &VirtualMachine, - ) -> Result<(usize, &SharedLibrary), libloading::Error> { - let new_lib = SharedLibrary::new_with_mode(library_path, mode)?; - let key = new_lib.get_pointer(); - - // Check if library already exists and is not closed - let should_use_cached = self.libraries.get(&key).is_some_and(|l| !l.is_closed()); - - if should_use_cached { - // new_lib will be dropped, calling dlclose (decrements refcount) - // But library stays loaded because cached version maintains refcount - drop(new_lib); - return Ok((key, self.libraries.get(&key).expect("just checked"))); - } - - self.libraries.insert(key, new_lib); - Ok((key, self.libraries.get(&key).expect("just inserted"))) - } - - /// Insert a raw dlopen handle into the cache (for pythonapi / dlopen(NULL)) - #[cfg(unix)] - pub(super) fn insert_raw_handle(&mut self, handle: *mut libc::c_void) -> usize { - let shared_lib = SharedLibrary::from_raw_handle(handle); - let key = handle as usize; - self.libraries.insert(key, shared_lib); - key - } - - pub(super) fn drop_lib(&mut self, key: usize) { - self.libraries.remove(&key); - } -} - -pub(super) fn libcache() -> &'static PyRwLock { - rustpython_common::static_cell! 
{ - static LIBCACHE: PyRwLock; - } - LIBCACHE.get_or_init(|| PyRwLock::new(ExternalLibs::new())) -} diff --git a/crates/vm/src/stdlib/_ctypes/pointer.rs b/crates/vm/src/stdlib/_ctypes/pointer.rs index 71b8455b83e..be165ee9f8b 100644 --- a/crates/vm/src/stdlib/_ctypes/pointer.rs +++ b/crates/vm/src/stdlib/_ctypes/pointer.rs @@ -11,6 +11,10 @@ use crate::{ }; use alloc::borrow::Cow; use num_traits::ToPrimitive; +use rustpython_host_env::ctypes::{ + AddressValue, AddressWriteValue, IntegerValue, pointer_item_address, read_pointer_char_slice, + read_pointer_wchar_slice, +}; #[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")] #[derive(Debug)] @@ -46,7 +50,7 @@ impl Initializer for PyCPointerType { } // Initialize StgInfo for pointer type - let pointer_size = core::mem::size_of::(); + let pointer_size = rustpython_host_env::ctypes::pointer_size(); let mut stg_info = StgInfo::new(pointer_size, pointer_size); stg_info.proto = proto; stg_info.paramfunc = super::base::ParamFunc::Pointer; @@ -263,7 +267,7 @@ impl Constructor for PyCPointer { // Create a new PyCPointer instance with NULL pointer (all zeros) // Initial contents is set via __init__ if provided - let cdata = PyCData::from_bytes(vec![0u8; core::mem::size_of::()], None); + let cdata = PyCData::from_bytes(rustpython_host_env::ctypes::null_pointer_bytes(), None); // pointer instance has b_length set to 2 (for index 0 and 1) cdata.length.store(2); Self(cdata).into_ref_with_type(vm, cls).map(Into::into) @@ -296,16 +300,18 @@ impl PyCPointer { /// Get the pointer value stored in buffer as usize pub(crate) fn get_ptr_value(&self) -> usize { let buffer = self.0.buffer.read(); - super::base::read_ptr_from_buffer(&buffer) + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) } /// Set the pointer value in buffer pub(crate) fn set_ptr_value(&self, value: usize) { let mut buffer = self.0.buffer.write(); - let bytes = value.to_ne_bytes(); - if buffer.len() >= bytes.len() { - 
buffer.to_mut()[..bytes.len()].copy_from_slice(&bytes); - } + rustpython_host_env::ctypes::write_pointer_to_buffer_at( + buffer.to_mut(), + 0, + rustpython_host_env::ctypes::pointer_size(), + value, + ); } /// contents getter - reads address from b_ptr and creates an instance of the pointed-to type @@ -322,7 +328,7 @@ impl PyCPointer { let proto_type = stg_info.proto(); let element_size = proto_type .stg_info_opt() - .map_or(core::mem::size_of::(), |info| info.size); + .map_or_else(rustpython_host_env::ctypes::pointer_size, |info| info.size); // Create instance that references the memory directly // PyCData.into_ref_with_type works for all ctypes (simple, structure, union, array, pointer) @@ -405,11 +411,10 @@ impl PyCPointer { let proto_type = stg_info.proto(); let element_size = proto_type .stg_info_opt() - .map_or(core::mem::size_of::(), |info| info.size); + .map_or_else(rustpython_host_env::ctypes::pointer_size, |info| info.size); // offset = index * iteminfo->size - let offset = index * element_size as isize; - let addr = (ptr_value as isize + offset) as usize; + let addr = pointer_item_address(ptr_value, index, element_size); // Check if it's a simple type (has _type_ attribute) if let Ok(type_attr) = proto_type.as_object().get_attr("_type_", vm) @@ -495,7 +500,7 @@ impl PyCPointer { let element_size = if let Some(ref proto_type) = stg_info.proto { proto_type.stg_info_opt().expect("proto has StgInfo").size } else { - core::mem::size_of::() + rustpython_host_env::ctypes::pointer_size() }; let type_code = stg_info .proto @@ -511,23 +516,8 @@ impl PyCPointer { if len == 0 { return Ok(vm.ctx.new_bytes(vec![]).into()); } - let mut result = Vec::with_capacity(len); - if step == 1 { - // Optimized contiguous copy - let start_addr = (ptr_value as isize + start * element_size as isize) as *const u8; - unsafe { - result.extend_from_slice(core::slice::from_raw_parts(start_addr, len)); - } - } else { - let mut cur = start; - for _ in 0..len { - let addr = (ptr_value as 
isize + cur * element_size as isize) as *const u8; - unsafe { - result.push(*addr); - } - cur += step; - } - } + let result = + unsafe { read_pointer_char_slice(ptr_value, start, len, step, element_size) }; return Ok(vm.ctx.new_bytes(result).into()); } @@ -536,23 +526,10 @@ impl PyCPointer { if len == 0 { return Ok(vm.ctx.new_str("").into()); } - let mut result = String::with_capacity(len); - let wchar_size = core::mem::size_of::(); - let mut cur = start; - for _ in 0..len { - let addr = (ptr_value as isize + cur * wchar_size as isize) as *const libc::wchar_t; - unsafe { - #[allow( - clippy::unnecessary_cast, - reason = "wchar_t is i32 on some platforms and u32 on others" - )] - if let Some(c) = char::from_u32(*addr as u32) { - result.push(c); - } - } - cur += step; - } - return Ok(vm.ctx.new_str(result).into()); + return Ok(vm + .ctx + .new_str(unsafe { read_pointer_wchar_slice(ptr_value, start, len, step) }) + .into()); } // other types → list with Pointer_item for each @@ -608,11 +585,10 @@ impl PyCPointer { let element_size = proto_type .stg_info_opt() - .map_or(core::mem::size_of::(), |info| info.size); + .map_or_else(rustpython_host_env::ctypes::pointer_size, |info| info.size); // Calculate address - let offset = index * element_size as isize; - let addr = (ptr_value as isize + offset) as usize; + let addr = pointer_item_address(ptr_value, index, element_size); // Write value at address // Handle Structure/Array types by copying their buffer @@ -622,10 +598,8 @@ impl PyCPointer { || cdata.fast_isinstance(PyCSimple::static_type())) { let src_buffer = cdata.buffer.read(); - let copy_len = src_buffer.len().min(element_size); unsafe { - let dest_ptr = addr as *mut u8; - core::ptr::copy_nonoverlapping(src_buffer.as_ptr(), dest_ptr, copy_len); + rustpython_host_env::ctypes::copy_bytes_to_address(addr, &src_buffer, element_size); } } else { // Handle z/Z specially to store converted value @@ -634,7 +608,11 @@ impl PyCPointer { { let (kept_alive, ptr_val) = 
super::base::ensure_z_null_terminated(bytes, vm); unsafe { - *(addr as *mut usize) = ptr_val; + rustpython_host_env::ctypes::write_value_to_address( + addr, + element_size, + AddressWriteValue::Pointer(ptr_val), + ); } zelf.0.keep_alive(index as usize, kept_alive); return zelf.0.keep_ref(index as usize, value.clone(), vm); @@ -643,7 +621,11 @@ impl PyCPointer { { let (holder, ptr_val) = super::base::str_to_wchar_bytes(s.as_wtf8(), vm); unsafe { - *(addr as *mut usize) = ptr_val; + rustpython_host_env::ctypes::write_value_to_address( + addr, + element_size, + AddressWriteValue::Pointer(ptr_val), + ); } return zelf.0.keep_ref(index as usize, holder, vm); } @@ -661,56 +643,13 @@ impl PyCPointer { type_code: Option<&str>, vm: &VirtualMachine, ) -> PyObjectRef { - unsafe { - let ptr = addr as *const u8; - match type_code { - // Single-byte types don't need read_unaligned - Some("c") => vm.ctx.new_bytes(vec![*ptr]).into(), - Some("b") => vm.ctx.new_int(*ptr as i8 as i32).into(), - Some("B") => vm.ctx.new_int(*ptr as i32).into(), - // Multi-byte types need read_unaligned for safety on strict-alignment architectures - Some("h") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as *const i16) as i32) - .into(), - Some("H") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as *const u16) as i32) - .into(), - Some("i" | "l") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as *const i32)) - .into(), - Some("I" | "L") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as *const u32)) - .into(), - Some("q") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as *const i64)) - .into(), - Some("Q") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as *const u64)) - .into(), - Some("f") => vm - .ctx - .new_float(core::ptr::read_unaligned(ptr as *const f32) as f64) - .into(), - Some("d" | "g") => vm - .ctx - .new_float(core::ptr::read_unaligned(ptr as *const f64)) - .into(), - Some("P" | "z" | "Z") => vm - .ctx - .new_int(core::ptr::read_unaligned(ptr as 
*const usize)) - .into(), - _ => { - // Default: read as bytes - let bytes = core::slice::from_raw_parts(ptr, size).to_vec(); - vm.ctx.new_bytes(bytes).into() - } - } + match unsafe { rustpython_host_env::ctypes::read_value_at_address(addr, size, type_code) } { + AddressValue::ByteString(byte) => vm.ctx.new_bytes(vec![byte]).into(), + AddressValue::Integer(IntegerValue::Signed(value)) => vm.ctx.new_int(value).into(), + AddressValue::Integer(IntegerValue::Unsigned(value)) => vm.ctx.new_int(value).into(), + AddressValue::Float(value) => vm.ctx.new_float(value).into(), + AddressValue::Pointer(value) => vm.ctx.new_int(value).into(), + AddressValue::Bytes(bytes) => vm.ctx.new_bytes(bytes).into(), } } @@ -723,8 +662,6 @@ impl PyCPointer { vm: &VirtualMachine, ) -> PyResult<()> { unsafe { - let ptr = addr as *mut u8; - // Handle c_char_p (z) and c_wchar_p (Z) - store pointer address // Note: PyBytes/PyStr cases are handled by caller (setitem_by_index) if let Some("z" | "Z") = type_code { @@ -735,7 +672,11 @@ impl PyCPointer { } else { return Err(vm.new_type_error("bytes/string or integer address expected")); }; - core::ptr::write_unaligned(ptr as *mut usize, ptr_val); + rustpython_host_env::ctypes::write_value_to_address( + addr, + size, + AddressWriteValue::Pointer(ptr_val), + ); return Ok(()); } @@ -743,56 +684,39 @@ impl PyCPointer { // Use write_unaligned for safety on strict-alignment architectures if let Ok(int_val) = value.try_int(vm) { let i = int_val.as_bigint(); - match size { - 1 => { - *ptr = i.to_u8().expect("int too large"); - } - 2 => { - core::ptr::write_unaligned( - ptr as *mut i16, - i.to_i16().expect("int too large"), - ); - } - 4 => { - core::ptr::write_unaligned( - ptr as *mut i32, - i.to_i32().expect("int too large"), - ); - } - 8 => { - core::ptr::write_unaligned( - ptr as *mut i64, - i.to_i64().expect("int too large"), - ); - } + let bytes; + let write_value = match size { + 1 => AddressWriteValue::U8(i.to_u8().expect("int too large")), + 2 => 
AddressWriteValue::I16(i.to_i16().expect("int too large")), + 4 => AddressWriteValue::I32(i.to_i32().expect("int too large")), + 8 => AddressWriteValue::I64(i.to_i64().expect("int too large")), _ => { - let bytes = i.to_signed_bytes_le(); - let copy_len = bytes.len().min(size); - core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); + bytes = i.to_signed_bytes_le(); + AddressWriteValue::Bytes(&bytes) } - } + }; + rustpython_host_env::ctypes::write_value_to_address(addr, size, write_value); return Ok(()); } // Try to get value as float if let Ok(float_val) = value.try_float(vm) { let f = float_val.to_f64(); - match size { - 4 => { - core::ptr::write_unaligned(ptr as *mut f32, f as f32); - } - 8 => { - core::ptr::write_unaligned(ptr as *mut f64, f); - } - _ => {} - } + rustpython_host_env::ctypes::write_value_to_address( + addr, + size, + AddressWriteValue::Float(f), + ); return Ok(()); } // Try bytes if let Ok(bytes) = value.try_bytes_like(vm, |b| b.to_vec()) { - let copy_len = bytes.len().min(size); - core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); + rustpython_host_env::ctypes::write_value_to_address( + addr, + size, + AddressWriteValue::Bytes(&bytes), + ); return Ok(()); } diff --git a/crates/vm/src/stdlib/_ctypes/simple.rs b/crates/vm/src/stdlib/_ctypes/simple.rs index e50985c2ab9..d51a61130f3 100644 --- a/crates/vm/src/stdlib/_ctypes/simple.rs +++ b/crates/vm/src/stdlib/_ctypes/simple.rs @@ -1,11 +1,10 @@ use super::_ctypes::CArgObject; -use super::array::{PyCArray, WCHAR_SIZE, wchar_to_bytes}; +use super::array::PyCArray; use super::base::{ CDATA_BUFFER_METHODS, FfiArgValue, PyCData, StgInfo, StgInfoFlags, buffer_to_ffi_value, bytes_to_pyobject, }; use super::function::PyCFuncPtr; -use super::get_size; use super::pointer::PyCPointer; use crate::builtins::{PyByteArray, PyBytes, PyInt, PyNone, PyStr, PyType, PyTypeRef}; use crate::convert::ToPyObject; @@ -16,6 +15,10 @@ use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, 
PyResult, Vir use alloc::borrow::Cow; use core::fmt::Debug; use num_traits::ToPrimitive; +use rustpython_host_env::ctypes::{ + SimpleStorageValue, simple_storage_value_to_bytes_endian, simple_type_align, + simple_type_pep3118_code, simple_type_size, write_simple_storage_buffer, zeroed_bytes, +}; /// Valid type codes for ctypes simple types pub(super) const SIMPLE_TYPE_CHARS: &str = cfg_select! { @@ -25,37 +28,10 @@ pub(super) const SIMPLE_TYPE_CHARS: &str = cfg_select! { _ => "cbBhHiIlLdfuzZqQPOv?g", }; -/// Convert ctypes type code to PEP 3118 format code. -/// Some ctypes codes need to be mapped to standard-size codes based on platform. -/// _ctypes_alloc_format_string_for_type -fn ctypes_code_to_pep3118(code: char) -> char { - match code { - // c_int: map based on sizeof(int) - 'i' if core::mem::size_of::() == 2 => 'h', - 'i' if core::mem::size_of::() == 4 => 'i', - 'i' if core::mem::size_of::() == 8 => 'q', - 'I' if core::mem::size_of::() == 2 => 'H', - 'I' if core::mem::size_of::() == 4 => 'I', - 'I' if core::mem::size_of::() == 8 => 'Q', - // c_long: map based on sizeof(long) - 'l' if core::mem::size_of::() == 4 => 'l', - 'l' if core::mem::size_of::() == 8 => 'q', - 'L' if core::mem::size_of::() == 4 => 'L', - 'L' if core::mem::size_of::() == 8 => 'Q', - // c_bool: map based on sizeof(bool) - typically 1 byte on all platforms - '?' if core::mem::size_of::() == 1 => '?', - '?' if core::mem::size_of::() == 2 => 'H', - '?' if core::mem::size_of::() == 4 => 'L', - '?' 
if core::mem::size_of::() == 8 => 'Q', - // Default: use the same code - _ => code, - } -} - /// _ctypes_alloc_format_string_for_type fn alloc_format_string_for_type(code: char, big_endian: bool) -> String { let prefix = if big_endian { ">" } else { "<" }; - let pep_code = ctypes_code_to_pep3118(code); + let pep_code = simple_type_pep3118_code(code); format!("{prefix}{pep_code}") } @@ -91,8 +67,8 @@ fn new_simple_type( ))); } - let size = get_size(&tp_str); - Ok(PyCSimple(PyCData::from_bytes(vec![0u8; size], None))) + let size = simple_type_size(&tp_str).expect("invalid ctypes simple type"); + Ok(PyCSimple(PyCData::from_bytes(zeroed_bytes(size), None))) } fn set_primitive(_type_: &str, value: &PyObject, vm: &VirtualMachine) -> PyResult { @@ -430,14 +406,11 @@ impl PyCSimpleType { if let Some(funcptr) = value.downcast_ref::() { let ptr_val = { let buffer = funcptr._base.buffer.read(); - buffer - .first_chunk::<{ size_of::() }>() - .copied() - .map_or(0, usize::from_ne_bytes) + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) }; return Ok(CArgObject { tag: b'P', - value: FfiArgValue::Pointer(ptr_val), + value: FfiArgValue::pointer(ptr_val), obj: value.clone(), size: 0, offset: 0, @@ -450,14 +423,11 @@ impl PyCSimpleType { if matches!(value_type_code.as_deref(), Some("z" | "Z")) { let ptr_val = { let buffer = simple.0.buffer.read(); - buffer - .first_chunk::<{ size_of::() }>() - .copied() - .map_or(0, usize::from_ne_bytes) + rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer) }; return Ok(CArgObject { tag: b'Z', - value: FfiArgValue::Pointer(ptr_val), + value: FfiArgValue::pointer(ptr_val), obj: value.clone(), size: 0, offset: 0, @@ -470,7 +440,7 @@ impl PyCSimpleType { Some("O") => { return Ok(CArgObject { tag: b'O', - value: FfiArgValue::Pointer(value.get_id()), + value: FfiArgValue::pointer(value.get_id()), obj: value, size: 0, offset: 0, @@ -579,8 +549,8 @@ impl Initializer for PyCSimpleType { } // Initialize StgInfo - let size = 
super::get_size(&type_str); - let align = super::get_align(&type_str); + let size = simple_type_size(&type_str).expect("invalid ctypes simple type"); + let align = simple_type_align(&type_str).expect("invalid ctypes simple type"); let mut stg_info = StgInfo::new(size, align); // Set format for PEP 3118 buffer protocol @@ -739,210 +709,173 @@ fn value_to_bytes_endian( swapped: bool, vm: &VirtualMachine, ) -> Vec { - // Helper macro for endian conversion - macro_rules! to_bytes { - ($val:expr) => { - if swapped { - // Use opposite endianness - #[cfg(target_endian = "little")] - { - $val.to_be_bytes().to_vec() - } - #[cfg(target_endian = "big")] - { - $val.to_le_bytes().to_vec() - } - } else { - $val.to_ne_bytes().to_vec() - } - }; - } - - match _type_ { + let storage_value = match _type_ { "c" => { // c_char - single byte (bytes, bytearray, or int 0-255) if let Some(bytes) = value.downcast_ref::() && !bytes.is_empty() { - return vec![bytes.as_bytes()[0]]; - } - if let Some(bytearray) = value.downcast_ref::() { + SimpleStorageValue::Byte(bytes.as_bytes()[0]) + } else if let Some(bytearray) = value.downcast_ref::() { let buf = bytearray.borrow_buf(); if !buf.is_empty() { - return vec![buf[0]]; + SimpleStorageValue::Byte(buf[0]) + } else { + SimpleStorageValue::Zero } - } - if let Ok(int_val) = value.try_int(vm) + } else if let Ok(int_val) = value.try_int(vm) && let Some(v) = int_val.as_bigint().to_u8() { - return vec![v]; + SimpleStorageValue::Byte(v) + } else { + SimpleStorageValue::Zero } - vec![0] } "u" => { // c_wchar - platform-dependent size (2 on Windows, 4 on Unix) if let Some(s) = value.downcast_ref::() { let mut cps = s.as_wtf8().code_points(); if let (Some(c), None) = (cps.next(), cps.next()) { - let mut buffer = vec![0u8; WCHAR_SIZE]; - wchar_to_bytes(c.to_u32(), &mut buffer); - if swapped { - buffer.reverse(); - } - return buffer; + SimpleStorageValue::Wchar(c.to_u32()) + } else { + SimpleStorageValue::Zero } + } else { + SimpleStorageValue::Zero } - 
vec![0; WCHAR_SIZE] } "b" => { // c_byte - signed char (1 byte) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as i8; - return vec![v as u8]; + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0] } "B" => { // c_ubyte - unsigned char (1 byte) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as u8; - return vec![v]; + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0] } "h" => { // c_short (2 bytes) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as i16; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0; 2] } "H" => { // c_ushort (2 bytes) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as u16; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0; 2] } "i" => { // c_int (4 bytes) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as i32; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0; 4] } "I" => { // c_uint (4 bytes) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as u32; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0; 4] } "l" => { // c_long (platform dependent) if let Ok(int_val) = value.try_index(vm) { - let v = 
int_val.as_bigint().to_i128().expect("int too large") as libc::c_long; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - const SIZE: usize = core::mem::size_of::(); - vec![0; SIZE] } "L" => { // c_ulong (platform dependent) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as libc::c_ulong; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - const SIZE: usize = core::mem::size_of::(); - vec![0; SIZE] } "q" => { // c_longlong (8 bytes) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as i64; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0; 8] } "Q" => { // c_ulonglong (8 bytes) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().expect("int too large") as u64; - return to_bytes!(v); + SimpleStorageValue::Signed(int_val.as_bigint().to_i128().expect("int too large")) + } else { + SimpleStorageValue::Zero } - vec![0; 8] } "f" => { // c_float (4 bytes) - also accepts int if let Ok(float_val) = value.try_float(vm) { - return to_bytes!(float_val.to_f64() as f32); - } - if let Ok(int_val) = value.try_int(vm) + SimpleStorageValue::Float(float_val.to_f64()) + } else if let Ok(int_val) = value.try_int(vm) && let Some(v) = int_val.as_bigint().to_f64() { - return to_bytes!(v as f32); + SimpleStorageValue::Float(v) + } else { + SimpleStorageValue::Zero } - vec![0; 4] } "d" => { // c_double (8 bytes) - also accepts int if let Ok(float_val) = value.try_float(vm) { - return to_bytes!(float_val.to_f64()); - } - if let Ok(int_val) = value.try_int(vm) + SimpleStorageValue::Float(float_val.to_f64()) + } else if let Ok(int_val) = 
value.try_int(vm) && let Some(v) = int_val.as_bigint().to_f64() { - return to_bytes!(v); + SimpleStorageValue::Float(v) + } else { + SimpleStorageValue::Zero } - vec![0; 8] } "g" => { // long double - platform dependent size // Store as f64, zero-pad to platform long double size // Note: This may lose precision on platforms where long double > 64 bits - let f64_val = if let Ok(float_val) = value.try_float(vm) { + let value = if let Ok(float_val) = value.try_float(vm) { float_val.to_f64() } else if let Ok(int_val) = value.try_int(vm) { int_val.as_bigint().to_f64().unwrap_or(0.0) } else { 0.0 }; - let f64_bytes = if swapped { - #[cfg(target_endian = "little")] - { - f64_val.to_be_bytes().to_vec() - } - #[cfg(target_endian = "big")] - { - f64_val.to_le_bytes().to_vec() - } - } else { - f64_val.to_ne_bytes().to_vec() - }; - // Pad to long double size - let long_double_size = super::get_size("g"); - let mut result = f64_bytes; - result.resize(long_double_size, 0); - result + SimpleStorageValue::Float(value) } "?" 
=> { // c_bool (1 byte) if let Ok(b) = value.to_owned().try_to_bool(vm) { - return vec![if b { 1 } else { 0 }]; + SimpleStorageValue::Bool(b) + } else { + SimpleStorageValue::Zero } - vec![0] } "v" => { // VARIANT_BOOL: True = 0xFFFF (-1 as i16), False = 0x0000 if let Ok(b) = value.to_owned().try_to_bool(vm) { - let val: i16 = if b { -1 } else { 0 }; - return to_bytes!(val); + SimpleStorageValue::Bool(b) + } else { + SimpleStorageValue::Zero } - vec![0; 2] } "P" => { // c_void_p - pointer type (platform pointer size) @@ -951,9 +884,10 @@ fn value_to_bytes_endian( .as_bigint() .to_usize() .expect("int too large for pointer"); - return to_bytes!(v); + SimpleStorageValue::Pointer(v) + } else { + SimpleStorageValue::Zero } - vec![0; core::mem::size_of::()] } "z" => { // c_char_p - pointer to char (stores pointer value from int) @@ -963,9 +897,10 @@ fn value_to_bytes_endian( .as_bigint() .to_usize() .expect("int too large for pointer"); - return to_bytes!(v); + SimpleStorageValue::Pointer(v) + } else { + SimpleStorageValue::Zero } - vec![0; core::mem::size_of::()] } "Z" => { // c_wchar_p - pointer to wchar_t (stores pointer value from int) @@ -975,19 +910,20 @@ fn value_to_bytes_endian( .as_bigint() .to_usize() .expect("int too large for pointer"); - return to_bytes!(v); + SimpleStorageValue::Pointer(v) + } else { + SimpleStorageValue::Zero } - vec![0; core::mem::size_of::()] } "O" => { // py_object - store object id as non-zero marker // The actual object is stored in _objects // Use object's id as a non-zero placeholder (indicates non-NULL) - let id = value.get_id(); - to_bytes!(id) + SimpleStorageValue::ObjectId(value.get_id()) } - _ => vec![0], - } + _ => SimpleStorageValue::Zero, + }; + simple_storage_value_to_bytes_endian(_type_, storage_value, swapped) } /// Check if value is a c_char array or pointer(c_char) @@ -1049,7 +985,7 @@ impl Constructor for PyCSimple { if _type_ == "z" { if let Some(bytes) = v.downcast_ref::() { let (kept_alive, ptr) = 
super::base::ensure_z_null_terminated(bytes, vm); - let buffer = ptr.to_ne_bytes().to_vec(); + let buffer = rustpython_host_env::ctypes::pointer_bytes(ptr); let cdata = PyCData::from_bytes(buffer, Some(v.clone())); *cdata.base.write() = Some(kept_alive); return Self(cdata).into_ref_with_type(vm, cls).map(Into::into); @@ -1058,7 +994,7 @@ impl Constructor for PyCSimple { && let Some(s) = v.downcast_ref::() { let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_wtf8(), vm); - let buffer = ptr.to_ne_bytes().to_vec(); + let buffer = rustpython_host_env::ctypes::pointer_bytes(ptr); let cdata = PyCData::from_bytes(buffer, Some(holder)); return Self(cdata).into_ref_with_type(vm, cls).map(Into::into); } @@ -1166,58 +1102,30 @@ impl PyCSimple { // Special handling for c_char_p (z) and c_wchar_p (Z) // z_get, Z_get - dereference pointer to get string if type_code == "z" { - // c_char_p: read pointer from buffer, dereference to get bytes string let buffer = zelf.0.buffer.read(); - let ptr = super::base::read_ptr_from_buffer(&buffer); - if ptr == 0 { - return Ok(vm.ctx.none()); - } - // Read null-terminated string at the address - unsafe { - let cstr = core::ffi::CStr::from_ptr(ptr as _); - return Ok(vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into()); - } + return match rustpython_host_env::ctypes::decode_type_code(&type_code, &buffer) { + rustpython_host_env::ctypes::DecodedValue::Bytes(value) => { + Ok(vm.ctx.new_bytes(value).into()) + } + rustpython_host_env::ctypes::DecodedValue::None => Ok(vm.ctx.none()), + _ => unreachable!("decode_type_code('z') only returns bytes or None"), + }; } if type_code == "Z" { - // c_wchar_p: read pointer from buffer, dereference to get wide string let buffer = zelf.0.buffer.read(); - let ptr = super::base::read_ptr_from_buffer(&buffer); - if ptr == 0 { - return Ok(vm.ctx.none()); - } - // Read null-terminated wide string at the address - // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide for surrogate pairs - // Unix: wchar_t = 
i32 (UTF-32) -> convert via char::from_u32 - unsafe { - let w_ptr = ptr as *const libc::wchar_t; - let len = libc::wcslen(w_ptr); - let wchars = core::slice::from_raw_parts(w_ptr, len); - #[cfg(windows)] - { - use rustpython_common::wtf8::Wtf8Buf; - let wide: Vec = wchars.to_vec(); - let wtf8 = Wtf8Buf::from_wide(&wide); - return Ok(vm.ctx.new_str(wtf8).into()); - } - #[cfg(not(windows))] - { - #[allow( - clippy::useless_conversion, - reason = "wchar_t is i32 on some platforms and u32 on others" - )] - let s: String = wchars - .iter() - .filter_map(|&c| u32::try_from(c).ok().and_then(char::from_u32)) - .collect(); - return Ok(vm.ctx.new_str(s).into()); + return match rustpython_host_env::ctypes::decode_type_code(&type_code, &buffer) { + rustpython_host_env::ctypes::DecodedValue::String(value) => { + Ok(vm.ctx.new_str(value).into()) } - } + rustpython_host_env::ctypes::DecodedValue::None => Ok(vm.ctx.none()), + _ => unreachable!("decode_type_code('Z') only returns string or None"), + }; } // O_get: py_object - read PyObject pointer from buffer if type_code == "O" { let buffer = zelf.0.buffer.read(); - let ptr = super::base::read_ptr_from_buffer(&buffer); + let ptr = rustpython_host_env::ctypes::read_pointer_from_buffer(&buffer); if ptr == 0 { return Err(vm.new_value_error("PyObject is NULL")); } @@ -1243,23 +1151,7 @@ impl PyCSimple { }; let cls_ref = cls.to_owned(); - bytes_to_pyobject(&cls_ref, &buffer_data, vm).or_else(|_| { - // Fallback: return bytes as integer based on type - match type_code.as_str() { - "c" => { - if !buffer.is_empty() { - Ok(vm.ctx.new_bytes(vec![buffer[0]]).into()) - } else { - Ok(vm.ctx.new_bytes(vec![0]).into()) - } - } - "?" 
=> { - let val = buffer.first().copied().unwrap_or(0); - Ok(vm.ctx.new_bool(val != 0).into()) - } - _ => Ok(vm.ctx.new_int(0).into()), - } - }) + Ok(bytes_to_pyobject(&cls_ref, &buffer_data, vm)) } #[pygetset(setter)] @@ -1281,7 +1173,8 @@ impl PyCSimple { if type_code == "z" { if let Some(bytes) = value.downcast_ref::() { let (kept_alive, ptr) = super::base::ensure_z_null_terminated(bytes, vm); - *zelf.0.buffer.write() = alloc::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); + *zelf.0.buffer.write() = + alloc::borrow::Cow::Owned(rustpython_host_env::ctypes::pointer_bytes(ptr)); *zelf.0.objects.write() = Some(value); *zelf.0.base.write() = Some(kept_alive); return Ok(()); @@ -1290,7 +1183,8 @@ impl PyCSimple { && let Some(s) = value.downcast_ref::() { let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_wtf8(), vm); - *zelf.0.buffer.write() = alloc::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); + *zelf.0.buffer.write() = + alloc::borrow::Cow::Owned(rustpython_host_env::ctypes::pointer_bytes(ptr)); *zelf.0.objects.write() = Some(holder); return Ok(()); } @@ -1310,20 +1204,7 @@ impl PyCSimple { // If the buffer is borrowed (from shared memory), write in-place // Otherwise replace with new owned buffer let mut buffer = zelf.0.buffer.write(); - match &mut *buffer { - Cow::Borrowed(slice) => { - // SAFETY: For from_buffer, the slice points to writable shared memory. - // Python's from_buffer requires writable buffer, so this is safe. - let ptr = slice.as_ptr() as *mut u8; - let len = slice.len().min(buffer_bytes.len()); - unsafe { - core::ptr::copy_nonoverlapping(buffer_bytes.as_ptr(), ptr, len); - } - } - Cow::Owned(vec) => { - vec.copy_from_slice(&buffer_bytes); - } - } + write_simple_storage_buffer(&mut buffer, &buffer_bytes); // For c_char_p (type "z"), c_wchar_p (type "Z"), and py_object (type "O"), // keep the reference in _objects @@ -1375,53 +1256,13 @@ impl PyCSimple { /// The value must be kept alive until after the FFI call completes. 
pub(crate) fn to_ffi_value( &self, - ty: libffi::middle::Type, + ty: rustpython_host_env::ctypes::FfiType, _vm: &VirtualMachine, ) -> Option { let buffer = self.0.buffer.read(); - let bytes: &[u8] = &buffer; - - let ret = if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u8().as_raw_ptr()) { - let byte = *bytes.first()?; - FfiArgValue::U8(byte) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i8().as_raw_ptr()) { - let byte = *bytes.first()?; - FfiArgValue::I8(byte as i8) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u16().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<2>()?; - FfiArgValue::U16(u16::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i16().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<2>()?; - FfiArgValue::I16(i16::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u32().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<4>()?; - FfiArgValue::U32(u32::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i32().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<4>()?; - FfiArgValue::I32(i32::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u64().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<8>()?; - FfiArgValue::U64(u64::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i64().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<8>()?; - FfiArgValue::I64(i64::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f32().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<4>()?; - FfiArgValue::F32(f32::from_ne_bytes(bytes)) - } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f64().as_raw_ptr()) { - let bytes = *bytes.first_chunk::<8>()?; - FfiArgValue::F64(f64::from_ne_bytes(bytes)) - } else if core::ptr::eq( - ty.as_raw_ptr(), - libffi::middle::Type::pointer().as_raw_ptr(), - ) { 
- let bytes = *buffer.first_chunk::<{ size_of::() }>()?; - let val = usize::from_ne_bytes(bytes); - FfiArgValue::Pointer(val) - } else { - return None; - }; - Some(ret) + Some(FfiArgValue::Scalar( + rustpython_host_env::ctypes::ffi_value_from_type(&buffer, ty)?, + )) } } diff --git a/crates/vm/src/stdlib/_io.rs b/crates/vm/src/stdlib/_io.rs index a0056523746..4d16c104694 100644 --- a/crates/vm/src/stdlib/_io.rs +++ b/crates/vm/src/stdlib/_io.rs @@ -23,6 +23,7 @@ use crate::{ AsObject, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::PyModule, }; pub use _io::{OpenArgs, io_open as open}; +use rustpython_host_env::io as host_io; fn file_closed(file: &PyObject, vm: &VirtualMachine) -> PyResult { file.get_attr("closed", vm)?.try_to_bool(vm) @@ -146,15 +147,8 @@ mod _io { use num_traits::ToPrimitive; use std::io::{self, Cursor, SeekFrom, prelude::*}; - #[allow(clippy::let_and_return)] fn validate_whence(whence: i32) -> bool { - let x = (0..=2).contains(&whence); - cfg_select! { - any(target_os = "dragonfly", target_os = "freebsd", target_os = "linux") => { - x || matches!(whence, libc::SEEK_DATA | libc::SEEK_HOLE) - } - _ => x, - } + host_io::validate_whence(whence) } fn ensure_unclosed(file: &PyObject, msg: &str, vm: &VirtualMachine) -> PyResult<()> { @@ -178,7 +172,7 @@ mod _io { if exc.fast_isinstance(vm.ctx.exceptions.os_error) && let Ok(errno_attr) = exc.as_object().get_attr("errno", vm) && let Ok(errno_val) = i32::try_from_object(vm, errno_attr) - && errno_val == libc::EINTR + && host_io::is_interrupted_errno(errno_val) { vm.check_signals()?; return Ok(None); @@ -209,7 +203,7 @@ mod _io { )) } - #[derive(FromArgs)] + #[derive(Clone, Copy, FromArgs)] pub(super) struct OptionalSize { // In a few functions, the default value is -1 rather than None. // Make sure the default value doesn't affect compatibility. 
@@ -218,7 +212,6 @@ mod _io { } impl OptionalSize { - #[allow(clippy::wrong_self_convention)] pub(super) fn to_usize(self) -> Option { self.size?.to_usize() } @@ -2912,7 +2905,7 @@ mod _io { Ok(()) } - #[allow(clippy::type_complexity)] + #[expect(clippy::type_complexity, reason = "ignore warning for now")] fn find_coder( buffer: &PyObject, encoding: &str, @@ -5081,7 +5074,8 @@ mod _io { // check file descriptor validity #[cfg(all(unix, feature = "host_env"))] if let Ok(crate::ospath::OsPathOrFd::Fd(fd)) = file.clone().try_into_value(vm) { - nix::fcntl::fcntl(fd, nix::fcntl::F_GETFD).map_err(|_| vm.new_last_errno_error())?; + rustpython_host_env::fcntl::validate_fd(fd.as_raw()) + .map_err(|_| vm.new_last_errno_error())?; } // Construct a RawIO (subclass of RawIOBase) @@ -5267,7 +5261,7 @@ mod _io { use super::*; #[test] - fn test_buffered_read() { + fn buffered_read() { let data = vec![1, 2, 3, 4]; let bytes = None; let mut buffered = BufferedIO { @@ -5278,7 +5272,7 @@ mod _io { } #[test] - fn test_buffered_seek() { + fn buffered_seek() { let data = vec![1, 2, 3, 4]; let count: u64 = 2; let mut buffered = BufferedIO { @@ -5290,7 +5284,7 @@ mod _io { } #[test] - fn test_buffered_value() { + fn buffered_value() { let data = vec![1, 2, 3, 4]; let buffered = BufferedIO { cursor: Cursor::new(data.clone()), @@ -5343,108 +5337,7 @@ mod fileio { types::{Constructor, DefaultConstructor, Destructor, Initializer, Representable}, }; use crossbeam_utils::atomic::AtomicCell; - use std::io::Read; - - bitflags::bitflags! 
{ - #[derive(Copy, Clone, Debug, PartialEq)] - struct Mode: u8 { - const CREATED = 0b0001; - const READABLE = 0b0010; - const WRITABLE = 0b0100; - const APPENDING = 0b1000; - } - } - - enum ModeError { - Invalid, - BadRwa, - } - - impl ModeError { - fn error_msg(&self, mode_str: &str) -> String { - match self { - Self::Invalid => format!("invalid mode: {mode_str}"), - Self::BadRwa => { - "Must have exactly one of create/read/write/append mode and at most one plus" - .to_owned() - } - } - } - } - - fn compute_mode(mode_str: &str) -> Result<(Mode, i32), ModeError> { - let mut flags = 0; - let mut plus = false; - let mut rwa = false; - let mut mode = Mode::empty(); - for c in mode_str.bytes() { - match c { - b'x' => { - if rwa { - return Err(ModeError::BadRwa); - } - rwa = true; - mode.insert(Mode::WRITABLE | Mode::CREATED); - flags |= libc::O_EXCL | libc::O_CREAT; - } - b'r' => { - if rwa { - return Err(ModeError::BadRwa); - } - rwa = true; - mode.insert(Mode::READABLE); - } - b'w' => { - if rwa { - return Err(ModeError::BadRwa); - } - rwa = true; - mode.insert(Mode::WRITABLE); - flags |= libc::O_CREAT | libc::O_TRUNC; - } - b'a' => { - if rwa { - return Err(ModeError::BadRwa); - } - rwa = true; - mode.insert(Mode::WRITABLE | Mode::APPENDING); - flags |= libc::O_APPEND | libc::O_CREAT; - } - b'+' => { - if plus { - return Err(ModeError::BadRwa); - } - plus = true; - mode.insert(Mode::READABLE | Mode::WRITABLE); - } - b'b' => {} - _ => return Err(ModeError::Invalid), - } - } - - if !rwa { - return Err(ModeError::BadRwa); - } - - if mode.contains(Mode::READABLE | Mode::WRITABLE) { - flags |= libc::O_RDWR - } else if mode.contains(Mode::READABLE) { - flags |= libc::O_RDONLY - } else { - flags |= libc::O_WRONLY - } - - #[cfg(windows)] - { - flags |= libc::O_BINARY | libc::O_NOINHERIT; - } - #[cfg(unix)] - { - flags |= libc::O_CLOEXEC - } - - Ok((mode, flags as _)) - } + use rustpython_host_env::io as host_io; #[pyattr] #[pyclass(module = "_io", name, base = _RawIOBase)] 
@@ -5453,7 +5346,7 @@ mod fileio { _base: _RawIOBase, fd: AtomicCell, closefd: AtomicCell, - mode: AtomicCell, + mode: AtomicCell, seekable: AtomicCell>, blksize: AtomicCell, finalizing: AtomicCell, @@ -5477,7 +5370,7 @@ mod fileio { _base: Default::default(), fd: AtomicCell::new(-1), closefd: AtomicCell::new(true), - mode: AtomicCell::new(Mode::empty()), + mode: AtomicCell::new(host_io::FileMode::empty()), seekable: AtomicCell::new(None), blksize: AtomicCell::new(super::DEFAULT_BUFFER_SIZE as _), finalizing: AtomicCell::new(false), @@ -5516,8 +5409,10 @@ mod fileio { .mode .unwrap_or_else(|| PyUtf8Str::from("rb").into_ref(&vm.ctx)); let mode_str = mode_obj.as_str(); - let (mode, flags) = - compute_mode(mode_str).map_err(|e| vm.new_value_error(e.error_msg(mode_str)))?; + let parsed = host_io::parse_fileio_mode(mode_str) + .map_err(|e| vm.new_value_error(e.error_msg(mode_str)))?; + let mode = parsed.mode; + let flags = parsed.flags; zelf.mode.store(mode); let (fd, filename) = if let Some(fd) = arg_fd { @@ -5542,9 +5437,9 @@ mod fileio { } else { let path = OsPath::try_from_fspath(name.clone(), vm)?; #[cfg(any(unix, target_os = "wasi"))] - let fd = crt_fd::open(&path.clone().into_cstring(vm)?, flags, 0o666); + let fd = host_io::open_path(&path.clone().into_cstring(vm)?, flags, 0o666); #[cfg(windows)] - let fd = crt_fd::wopen(&path.to_wide_cstring(vm)?, flags, 0o666); + let fd = host_io::open_path(&path.to_wide_cstring(vm)?, flags, 0o666); let filename = OsPathOrFd::Path(path); match fd { Ok(fd) => (fd.into_raw(), Some(filename)), @@ -5561,48 +5456,17 @@ mod fileio { // TODO: _Py_set_inheritable - let fd_fstat = rustpython_host_env::fileutils::fstat(fd); - - #[cfg(windows)] - { - if let Err(err) = fd_fstat { - // If the fd is invalid, prevent destructor from trying to close it - if err.raw_os_error() - == Some(windows_sys::Win32::Foundation::ERROR_INVALID_HANDLE as i32) - { - zelf.fd.store(-1); + match host_io::inspect_file_target(fd) { + Ok(info) => { + if let 
Some(blksize) = info.blksize { + zelf.blksize.store(blksize); } - return Err(OSErrorBuilder::with_filename(&err, filename, vm)); } - } - #[cfg(any(unix, target_os = "wasi"))] - { - match fd_fstat { - Ok(status) => { - if (status.st_mode & libc::S_IFMT) == libc::S_IFDIR { - // If fd was passed by user, don't close it on error - if !fd_is_own { - zelf.fd.store(-1); - } - let err = std::io::Error::from_raw_os_error(libc::EISDIR); - return Err(OSErrorBuilder::with_filename(&err, filename, vm)); - } - // Store st_blksize for _blksize property - if status.st_blksize > 1 { - #[allow( - clippy::useless_conversion, - reason = "needed for 32-bit platforms" - )] - zelf.blksize.store(i64::from(status.st_blksize)); - } - } - Err(err) => { - if err.raw_os_error() == Some(libc::EBADF) { - // fd is invalid, prevent destructor from trying to close it - zelf.fd.store(-1); - return Err(OSErrorBuilder::with_filename(&err, filename, vm)); - } + Err(err) => { + if host_io::should_forget_fd_after_inspect_error(&err, fd_is_own) { + zelf.fd.store(-1); } + return Err(OSErrorBuilder::with_filename(&err, filename, vm)); } } @@ -5616,8 +5480,8 @@ mod fileio { return Err(e); } - if mode.contains(Mode::APPENDING) { - let _ = os::lseek(fd, 0, libc::SEEK_END, vm); + if mode.contains(host_io::FileMode::APPENDING) { + let _ = host_io::seek_to_end(fd); } Ok(()) @@ -5698,7 +5562,7 @@ mod fileio { if self.fd.load() < 0 { return Err(io_closed_error(vm)); } - Ok(self.mode.load().contains(Mode::READABLE)) + Ok(self.mode.load().contains(host_io::FileMode::READABLE)) } #[pymethod] @@ -5706,33 +5570,12 @@ mod fileio { if self.fd.load() < 0 { return Err(io_closed_error(vm)); } - Ok(self.mode.load().contains(Mode::WRITABLE)) + Ok(self.mode.load().contains(host_io::FileMode::WRITABLE)) } #[pygetset] fn mode(&self) -> &'static str { - let mode = self.mode.load(); - if mode.contains(Mode::CREATED) { - if mode.contains(Mode::READABLE) { - "xb+" - } else { - "xb" - } - } else if mode.contains(Mode::APPENDING) { - 
if mode.contains(Mode::READABLE) { - "ab+" - } else { - "ab" - } - } else if mode.contains(Mode::READABLE) { - if mode.contains(Mode::WRITABLE) { - "rb+" - } else { - "rb" - } - } else { - "wb" - } + self.mode.load().raw_mode() } #[pymethod] @@ -5741,7 +5584,7 @@ mod fileio { read_byte: OptionalSize, vm: &VirtualMachine, ) -> PyResult>> { - if !zelf.mode.load().contains(Mode::READABLE) { + if !zelf.mode.load().contains(host_io::FileMode::READABLE) { return Err(new_unsupported_operation( vm, "File or stream is not readable".to_owned(), @@ -5752,14 +5595,14 @@ mod fileio { let mut bytes = vec![0; read_byte]; // Loop on EINTR (PEP 475) let n = loop { - match vm.allow_threads(|| crt_fd::read(handle, &mut bytes)) { + match vm.allow_threads(|| host_io::read_once(handle, &mut bytes)) { Ok(n) => break n, - Err(e) if e.raw_os_error() == Some(libc::EINTR) => { + Err(e) if host_io::is_interrupted_error(&e) => { vm.check_signals()?; continue; } // Non-blocking mode: return None if EAGAIN - Err(e) if e.raw_os_error() == Some(libc::EAGAIN) => { + Err(e) if host_io::is_would_block_error(&e) => { return Ok(None); } Err(e) => return Err(Self::io_error(zelf, e, vm)), @@ -5771,17 +5614,14 @@ mod fileio { let mut bytes = vec![]; // Loop on EINTR (PEP 475) loop { - match vm.allow_threads(|| { - let mut h = handle; - h.read_to_end(&mut bytes) - }) { - Ok(_) => break, - Err(e) if e.raw_os_error() == Some(libc::EINTR) => { + match vm.allow_threads(|| host_io::read_all(handle, &mut bytes)) { + Ok(()) => break, + Err(e) if host_io::is_interrupted_error(&e) => { vm.check_signals()?; continue; } // Non-blocking mode: return None if EAGAIN (only if no data read yet) - Err(e) if e.raw_os_error() == Some(libc::EAGAIN) => { + Err(e) if host_io::is_would_block_error(&e) => { if bytes.is_empty() { return Ok(None); } @@ -5802,7 +5642,7 @@ mod fileio { obj: ArgMemoryBuffer, vm: &VirtualMachine, ) -> PyResult> { - if !zelf.mode.load().contains(Mode::READABLE) { + if 
!zelf.mode.load().contains(host_io::FileMode::READABLE) { return Err(new_unsupported_operation( vm, "File or stream is not readable".to_owned(), @@ -5814,14 +5654,14 @@ mod fileio { let mut buf = obj.borrow_buf_mut(); // Loop on EINTR (PEP 475) let ret = loop { - match vm.allow_threads(|| crt_fd::read(handle, &mut buf)) { + match vm.allow_threads(|| host_io::read_once(handle, &mut buf)) { Ok(n) => break n, - Err(e) if e.raw_os_error() == Some(libc::EINTR) => { + Err(e) if host_io::is_interrupted_error(&e) => { vm.check_signals()?; continue; } // Non-blocking mode: return None if EAGAIN - Err(e) if e.raw_os_error() == Some(libc::EAGAIN) => { + Err(e) if host_io::is_would_block_error(&e) => { return Ok(None); } Err(e) => return Err(Self::io_error(zelf, e, vm)), @@ -5837,7 +5677,7 @@ mod fileio { obj: ArgBytesLike, vm: &VirtualMachine, ) -> PyResult> { - if !zelf.mode.load().contains(Mode::WRITABLE) { + if !zelf.mode.load().contains(host_io::FileMode::WRITABLE) { return Err(new_unsupported_operation( vm, "File or stream is not writable".to_owned(), @@ -5848,14 +5688,14 @@ mod fileio { // Loop on EINTR (PEP 475) let len = loop { - match obj.with_ref(|b| vm.allow_threads(|| crt_fd::write(handle, b))) { + match obj.with_ref(|b| vm.allow_threads(|| host_io::write_once(handle, b))) { Ok(n) => break n, - Err(e) if e.raw_os_error() == Some(libc::EINTR) => { + Err(e) if host_io::is_interrupted_error(&e) => { vm.check_signals()?; continue; } // Non-blocking mode: return None if EAGAIN - Err(e) if e.raw_os_error() == Some(libc::EAGAIN) => return Ok(None), + Err(e) if host_io::is_would_block_error(&e) => return Ok(None), Err(e) => return Err(Self::io_error(zelf, e, vm)), } }; @@ -5877,7 +5717,7 @@ mod fileio { } let fd = zelf.fd.swap(-1); let close_err = if fd >= 0 { - crt_fd::close(unsafe { crt_fd::Owned::from_raw(fd) }) + host_io::close_owned_fd(unsafe { crt_fd::Owned::from_raw(fd) }) .map_err(|err| Self::io_error(zelf, err, vm)) .err() } else { @@ -5897,7 +5737,7 @@ mod 
fileio { fn seekable(&self, vm: &VirtualMachine) -> PyResult { let fd = self.get_fd(vm)?; Ok(self.seekable.load().unwrap_or_else(|| { - let seekable = os::lseek(fd, 0, libc::SEEK_CUR, vm).is_ok(); + let seekable = host_io::is_seekable(fd); self.seekable.store(Some(seekable)); seekable })) @@ -5914,13 +5754,13 @@ mod fileio { let fd = self.get_fd(vm)?; let offset = get_offset(offset, vm)?; - os::lseek(fd, offset, how, vm) + host_io::seek(fd, offset, how).map_err(|e| e.into_pyexception(vm)) } #[pymethod] fn tell(&self, vm: &VirtualMachine) -> PyResult { let fd = self.get_fd(vm)?; - os::lseek(fd, 0, libc::SEEK_CUR, vm) + host_io::tell(fd).map_err(|e| e.into_pyexception(vm)) } #[pymethod] @@ -5928,7 +5768,7 @@ mod fileio { let fd = self.get_fd(vm)?; let len = match len.flatten() { Some(l) => get_offset(l, vm)?, - None => os::lseek(fd, 0, libc::SEEK_CUR, vm)?, + None => host_io::tell(fd).map_err(|e| e.into_pyexception(vm))?, }; os::ftruncate(fd, len).map_err(|e| e.into_pyexception(vm))?; Ok(len) @@ -5937,7 +5777,7 @@ mod fileio { #[pymethod] fn isatty(&self, vm: &VirtualMachine) -> PyResult { let fd = self.fileno(vm)?; - Ok(os::isatty(fd)) + Ok(host_io::isatty(fd)) } #[pymethod] @@ -6001,45 +5841,20 @@ mod winconsoleio { types::{Constructor, DefaultConstructor, Destructor, Initializer, Representable}, }; use crossbeam_utils::atomic::AtomicCell; - use windows_sys::Win32::{ - Foundation::{self, GENERIC_READ, GENERIC_WRITE, INVALID_HANDLE_VALUE}, - Globalization::{CP_UTF8, MultiByteToWideChar, WideCharToMultiByte}, - Storage::FileSystem::{ - CreateFileW, FILE_SHARE_READ, FILE_SHARE_WRITE, GetFullPathNameW, OPEN_EXISTING, - }, - System::Console::{ - GetConsoleMode, GetNumberOfConsoleInputEvents, ReadConsoleW, WriteConsoleW, - }, - }; - - type HANDLE = Foundation::HANDLE; + use rustpython_host_env::io as host_io; + use rustpython_host_env::nt as host_nt; + use rustpython_host_env::windows::ToWideString; + type HANDLE = host_nt::Handle; const SMALLBUF: usize = 4; const 
BUFMAX: usize = 32 * 1024 * 1024; fn handle_from_fd(fd: i32) -> HANDLE { - unsafe { rustpython_host_env::suppress_iph!(libc::get_osfhandle(fd)) as HANDLE } + host_nt::handle_from_fd(fd) } fn is_invalid_handle(handle: HANDLE) -> bool { - handle == INVALID_HANDLE_VALUE || handle.is_null() - } - - /// Check if a HANDLE is a console and what type ('r', 'w', or '\0'). - fn get_console_type(handle: HANDLE) -> char { - if is_invalid_handle(handle) { - return '\0'; - } - let mut mode: u32 = 0; - if unsafe { GetConsoleMode(handle, &mut mode) } == 0 { - return '\0'; - } - let mut peek_count: u32 = 0; - if unsafe { GetNumberOfConsoleInputEvents(handle, &mut peek_count) } != 0 { - 'r' - } else { - 'w' - } + host_nt::is_invalid_handle(handle) } /// Check if a Python object (fd or path string) refers to a console. @@ -6047,11 +5862,7 @@ mod winconsoleio { pub(super) fn pyio_get_console_type(path_or_fd: &PyObject, vm: &VirtualMachine) -> char { // Try as integer fd first if let Ok(fd) = i32::try_from_object(vm, path_or_fd.to_owned()) { - if fd >= 0 { - let handle = handle_from_fd(fd); - return get_console_type(handle); - } - return '\0'; + return host_nt::console_type_from_fd(fd); } // Try as string path @@ -6062,80 +5873,7 @@ mod winconsoleio { // Surrogate strings can't be console device names return '\0'; }; - - if name_str.eq_ignore_ascii_case("CONIN$") { - return 'r'; - } - if name_str.eq_ignore_ascii_case("CONOUT$") { - return 'w'; - } - if name_str.eq_ignore_ascii_case("CON") { - return 'x'; - } - - // Resolve full path and check for console device names - let wide: Vec = name_str.encode_utf16().chain(core::iter::once(0)).collect(); - let mut buf = [0u16; 260]; // MAX_PATH - let length = unsafe { - GetFullPathNameW( - wide.as_ptr(), - buf.len() as u32, - buf.as_mut_ptr(), - core::ptr::null_mut(), - ) - }; - if length == 0 || length as usize > buf.len() { - return '\0'; - } - let full_path = &buf[..length as usize]; - // Skip \\?\ or \\.\ prefix - let path_part = if 
full_path.len() >= 4 - && full_path[0] == b'\\' as u16 - && full_path[1] == b'\\' as u16 - && (full_path[2] == b'.' as u16 || full_path[2] == b'?' as u16) - && full_path[3] == b'\\' as u16 - { - &full_path[4..] - } else { - full_path - }; - - let path_str = String::from_utf16_lossy(path_part); - if path_str.eq_ignore_ascii_case("CONIN$") { - 'r' - } else if path_str.eq_ignore_ascii_case("CONOUT$") { - 'w' - } else if path_str.eq_ignore_ascii_case("CON") { - 'x' - } else { - '\0' - } - } - - /// Find the last valid UTF-8 boundary in a byte slice. - fn find_last_utf8_boundary(buf: &[u8], len: usize) -> usize { - let len = len.min(buf.len()); - for count in 1..=4.min(len) { - let c = buf[len - count]; - if c < 0x80 { - return len; - } - if c >= 0xc0 { - let expected = if c < 0xe0 { - 2 - } else if c < 0xf0 { - 3 - } else { - 4 - }; - if count < expected { - // Incomplete multibyte sequence - return len - count; - } - return len; - } - } - len + host_nt::console_type_from_name(name_str) } #[pyattr] @@ -6276,61 +6014,10 @@ mod winconsoleio { } let name_str = nameobj.str(vm)?; - let wide = name_str - .as_wtf8() - .encode_wide() - .chain(core::iter::once(0)) - .collect::>(); - - let access = if writable { - GENERIC_WRITE - } else { - GENERIC_READ - }; - - // Try read/write first, fall back to specific access - let mut handle: HANDLE = unsafe { - CreateFileW( - wide.as_ptr(), - GENERIC_READ | GENERIC_WRITE, - FILE_SHARE_READ | FILE_SHARE_WRITE, - core::ptr::null(), - OPEN_EXISTING, - 0, - core::ptr::null_mut(), - ) - }; - if is_invalid_handle(handle) { - handle = unsafe { - CreateFileW( - wide.as_ptr(), - access, - FILE_SHARE_READ | FILE_SHARE_WRITE, - core::ptr::null(), - OPEN_EXISTING, - 0, - core::ptr::null_mut(), - ) - }; - } - - if is_invalid_handle(handle) { - return Err(std::io::Error::last_os_error().to_pyexception(vm)); - } + let wide = name_str.as_wtf8().to_wide_cstring(); - let osf_flags = if writable { - libc::O_WRONLY | libc::O_BINARY | 0x80 /* O_NOINHERIT */ 
- } else { - libc::O_RDONLY | libc::O_BINARY | 0x80 /* O_NOINHERIT */ - }; - - fd = unsafe { libc::open_osfhandle(handle as isize, osf_flags) }; - if fd < 0 { - unsafe { - Foundation::CloseHandle(handle); - } - return Err(std::io::Error::last_os_error().to_pyexception(vm)); - } + fd = host_nt::open_console_path_fd(&wide, writable) + .map_err(|err| err.to_pyexception(vm))?; } else { // When opened by fd, never close the fd (user owns it) zelf.closefd.store(false); @@ -6341,7 +6028,7 @@ mod winconsoleio { // Validate console type if console_type == '\0' { let handle = handle_from_fd(fd); - console_type = get_console_type(handle); + console_type = host_nt::console_type(handle); } if console_type == '\0' { @@ -6371,9 +6058,8 @@ mod winconsoleio { fn internal_close(zelf: &WindowsConsoleIO) { let fd = zelf.fd.swap(-1); if fd >= 0 && zelf.closefd.load() { - unsafe { - libc::close(fd); - } + let _ = + host_io::close_owned_fd(unsafe { crate::host_env::crt_fd::Owned::from_raw(fd) }); } } @@ -6482,12 +6168,9 @@ mod winconsoleio { } let fd = zelf.fd.swap(-1); let close_err: Option = if fd >= 0 { - let result = unsafe { libc::close(fd) }; - if result < 0 { - Some(std::io::Error::last_os_error().into_pyexception(vm)) - } else { - None - } + host_io::close_owned_fd(unsafe { crate::host_env::crt_fd::Owned::from_raw(fd) }) + .err() + .map(|e| e.into_pyexception(vm)) } else { None }; @@ -6554,116 +6237,10 @@ mod winconsoleio { return Err(std::io::Error::last_os_error().to_pyexception(vm)); } - // Each character may take up to 4 bytes in UTF-8. 
- let mut wlen = (len / 4) as u32; - if wlen == 0 { - wlen = 1; - } - let dest = &mut *buf_ref; - - // Copy from internal buffer first - let mut read_len = { - let mut buf = self.buf.lock(); - Self::copy_from_buf(&mut buf, dest) - }; - if read_len > 0 { - wlen = wlen.saturating_sub(1); - } - if read_len >= len || wlen == 0 { - return Ok(read_len); - } - - // Read from console - let mut wbuf = vec![0u16; wlen as usize]; - let mut nread: u32 = 0; - let res = unsafe { - ReadConsoleW( - handle, - wbuf.as_mut_ptr() as _, - wlen, - &mut nread, - core::ptr::null(), - ) - }; - if res == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - if nread == 0 { - return Ok(read_len); - } - - // Check for Ctrl+Z (EOF) - if nread > 0 && wbuf[0] == 0x1A { - return Ok(read_len); - } - - // Convert wchar to UTF-8 - let remaining = len - read_len; - let u8n; - if remaining < 4 { - // Buffer the result in the internal small buffer - let mut buf = self.buf.lock(); - let converted = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - buf.as_mut_ptr() as _, - SMALLBUF as i32, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - if converted > 0 { - u8n = Self::copy_from_buf(&mut buf, &mut dest[read_len..]) as i32; - } else { - u8n = 0; - } - } else { - u8n = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - dest[read_len..].as_mut_ptr() as _, - remaining as i32, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - } - - if u8n > 0 { - read_len += u8n as usize; - } else { - let err = std::io::Error::last_os_error(); - if err.raw_os_error() == Some(122) { - // ERROR_INSUFFICIENT_BUFFER - let needed = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - core::ptr::null_mut(), - 0, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - if needed > 0 { - return Err(vm.new_system_error(format!( - "Buffer had room for {remaining} bytes but {needed} bytes required", - 
))); - } - } - return Err(err.into_pyexception(vm)); - } - - Ok(read_len) + let mut smallbuf = self.buf.lock(); + host_nt::read_console_into(handle, dest, &mut smallbuf) + .map_err(|err| err.to_pyexception(vm)) } #[pymethod] @@ -6677,77 +6254,9 @@ mod winconsoleio { return Err(std::io::Error::last_os_error().to_pyexception(vm)); } - let mut result = Vec::new(); - - // Copy any buffered bytes first - { - let mut buf = self.buf.lock(); - let mut tmp = [0u8; SMALLBUF]; - let n = Self::copy_from_buf(&mut buf, &mut tmp); - result.extend_from_slice(&tmp[..n]); - } - - let mut wbuf = vec![0u16; 8192]; - loop { - let mut nread: u32 = 0; - let res = unsafe { - ReadConsoleW( - handle, - wbuf.as_mut_ptr() as _, - wbuf.len() as u32, - &mut nread, - core::ptr::null(), - ) - }; - if res == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - if nread == 0 { - break; - } - // Ctrl+Z at start -> EOF - if wbuf[0] == 0x1A { - break; - } - // Convert to UTF-8 - let needed = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - core::ptr::null_mut(), - 0, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - if needed == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - let offset = result.len(); - result.resize(offset + needed as usize, 0); - let written = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - result[offset..].as_mut_ptr() as _, - needed, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - if written == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - // If we didn't fill the buffer, no more data - if nread < wbuf.len() as u32 { - break; - } - } - + let mut smallbuf = self.buf.lock(); + let result = host_nt::read_console_all(handle, &mut smallbuf) + .map_err(|err| err.into_pyexception(vm))?; Ok(vm.ctx.new_bytes(result).into()) } @@ -6775,105 +6284,19 @@ mod winconsoleio { return 
Err(std::io::Error::last_os_error().to_pyexception(vm)); } - let len = size as usize; - - let mut wlen = (len / 4) as u32; - if wlen == 0 { - wlen = 1; - } - let mut read_len = { let mut ibuf = self.buf.lock(); Self::copy_from_buf(&mut ibuf, &mut buf) }; - if read_len > 0 { - wlen = wlen.saturating_sub(1); - } - if read_len >= len || wlen == 0 { - buf.truncate(read_len); - return Ok(vm.ctx.new_bytes(buf).into()); - } - - let mut wbuf = vec![0u16; wlen as usize]; - let mut nread: u32 = 0; - let res = unsafe { - ReadConsoleW( - handle, - wbuf.as_mut_ptr() as _, - wlen, - &mut nread, - core::ptr::null(), - ) - }; - if res == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - if nread == 0 || wbuf[0] == 0x1A { + if read_len >= size as usize { buf.truncate(read_len); return Ok(vm.ctx.new_bytes(buf).into()); } - - let remaining = len - read_len; - let u8n; - if remaining < 4 { + { let mut ibuf = self.buf.lock(); - let converted = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - ibuf.as_mut_ptr() as _, - SMALLBUF as i32, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - if converted > 0 { - u8n = Self::copy_from_buf(&mut ibuf, &mut buf[read_len..]) as i32; - } else { - u8n = 0; - } - } else { - u8n = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - buf[read_len..].as_mut_ptr() as _, - remaining as i32, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - } - - if u8n > 0 { - read_len += u8n as usize; - } else { - let err = std::io::Error::last_os_error(); - if err.raw_os_error() == Some(122) { - // ERROR_INSUFFICIENT_BUFFER - let needed = unsafe { - WideCharToMultiByte( - CP_UTF8, - 0, - wbuf.as_ptr(), - nread as i32, - core::ptr::null_mut(), - 0, - core::ptr::null(), - core::ptr::null_mut(), - ) - }; - if needed > 0 { - return Err(vm.new_system_error(format!( - "Buffer had room for {remaining} bytes but {needed} bytes required", - ))); - } - } - return 
Err(err.into_pyexception(vm)); + let n = host_nt::read_console_into(handle, &mut buf[read_len..], &mut ibuf) + .map_err(|err| err.to_pyexception(vm))?; + read_len += n; } buf.truncate(read_len); @@ -6903,72 +6326,8 @@ mod winconsoleio { return Ok(0); } - let mut len = data.len().min(BUFMAX); - - // Cap at 32766/2 wchars * 3 bytes (UTF-8 to wchar ratio is at most 3:1) - let max_wlen: u32 = 32766 / 2; - len = len.min(max_wlen as usize * 3); - - // Reduce len until wlen fits within max_wlen - let wlen; - loop { - len = find_last_utf8_boundary(data, len); - let w = unsafe { - MultiByteToWideChar( - CP_UTF8, - 0, - data.as_ptr(), - len as i32, - core::ptr::null_mut(), - 0, - ) - }; - if w as u32 <= max_wlen { - wlen = w; - break; - } - len /= 2; - } - if wlen == 0 { - return Ok(0); - } - - let mut wbuf = vec![0u16; wlen as usize]; - let wlen = unsafe { - MultiByteToWideChar( - CP_UTF8, - 0, - data.as_ptr(), - len as i32, - wbuf.as_mut_ptr(), - wlen, - ) - }; - if wlen == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - - let mut n_written: u32 = 0; - let res = unsafe { - WriteConsoleW( - handle, - wbuf.as_ptr() as _, - wlen as u32, - &mut n_written, - core::ptr::null(), - ) - }; - if res == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - - // If we wrote fewer wchars than expected, recalculate bytes consumed - if n_written < wlen as u32 { - // Binary search to find how many input bytes correspond to n_written wchars - len = wchar_to_utf8_count(data, len, n_written); - } - - Ok(len) + host_nt::write_console_utf8(handle, data, BUFMAX) + .map_err(|err| err.into_pyexception(vm)) } #[pymethod(name = "__reduce__")] @@ -6977,43 +6336,6 @@ mod winconsoleio { } } - /// Find how many UTF-8 bytes correspond to n wide chars. 
- fn wchar_to_utf8_count(data: &[u8], mut len: usize, mut n: u32) -> usize { - let mut start: usize = 0; - loop { - let mut mid = 0; - for i in (len / 2)..=len { - mid = find_last_utf8_boundary(data, i); - if mid != 0 { - break; - } - } - if mid == len { - return start + len; - } - if mid == 0 { - mid = if len > 1 { len - 1 } else { 1 }; - } - let wlen = unsafe { - MultiByteToWideChar( - CP_UTF8, - 0, - data[start..].as_ptr(), - mid as i32, - core::ptr::null_mut(), - 0, - ) - } as u32; - if wlen <= n { - start += mid; - len -= mid; - n -= wlen; - } else { - len = mid; - } - } - } - impl Destructor for WindowsConsoleIO { fn slot_del(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<()> { if let Some(cio) = zelf.downcast_ref::() { diff --git a/crates/vm/src/stdlib/_signal.rs b/crates/vm/src/stdlib/_signal.rs index 3c05f867d6f..b9c41d6a4ce 100644 --- a/crates/vm/src/stdlib/_signal.rs +++ b/crates/vm/src/stdlib/_signal.rs @@ -4,6 +4,8 @@ pub(crate) use _signal::module_def; #[pymodule] pub(crate) mod _signal { + #![allow(unreachable_pub)] + #[cfg(any(unix, windows))] use crate::convert::{IntoPyException, TryFromBorrowedObject}; use crate::{Py, PyObjectRef, PyResult, VirtualMachine, signal}; @@ -13,8 +15,12 @@ pub(crate) mod _signal { function::{ArgIntoFloat, OptionalArg}, }; use core::sync::atomic::{self, Ordering}; + #[cfg(any(unix, windows))] + use rustpython_host_env::signal as host_signal; #[cfg(unix)] use rustpython_host_env::signal::{double_to_timeval, itimerval_to_tuple}; + #[cfg(unix)] + use std::os::fd::AsFd; #[allow(non_camel_case_types)] type sighandler_t = cfg_select! 
{ @@ -26,7 +32,7 @@ pub(crate) mod _signal { windows => { type WakeupFdRaw = libc::SOCKET; struct WakeupFd(WakeupFdRaw); - const INVALID_WAKEUP: libc::SOCKET = windows_sys::Win32::Networking::WinSock::INVALID_SOCKET; + const INVALID_WAKEUP: libc::SOCKET = host_signal::INVALID_SOCKET; static WAKEUP: atomic::AtomicUsize = atomic::AtomicUsize::new(INVALID_WAKEUP); // windows doesn't use the same fds for files and sockets like windows does, so we need // this to know whether to send() or write() @@ -57,14 +63,12 @@ pub(crate) mod _signal { } #[cfg(unix)] - pub(crate) use libc::SIG_ERR; - - #[cfg(unix)] - pub(crate) use nix::unistd::alarm as sig_alarm; + #[allow(unused_imports)] + pub use libc::SIG_ERR; #[cfg(unix)] #[pyattr] - pub(crate) use libc::{SIG_DFL, SIG_IGN}; + pub use libc::{SIG_DFL, SIG_IGN}; // pthread_sigmask 'how' constants #[cfg(unix)] @@ -73,54 +77,32 @@ pub(crate) mod _signal { #[cfg(not(unix))] #[pyattr] - pub(crate) const SIG_DFL: sighandler_t = 0; - + pub const SIG_DFL: sighandler_t = 0; #[cfg(not(unix))] #[pyattr] - pub(crate) const SIG_IGN: sighandler_t = 1; - + pub const SIG_IGN: sighandler_t = 1; #[cfg(not(unix))] #[allow(dead_code)] - pub(crate) const SIG_ERR: sighandler_t = -1 as _; - - #[cfg(all(unix, not(target_os = "redox")))] - unsafe extern "C" { - fn siginterrupt(sig: i32, flag: i32) -> i32; - } - - #[cfg(any(target_os = "linux", target_os = "android"))] - mod ffi { - unsafe extern "C" { - pub(super) fn getitimer( - which: libc::c_int, - curr_value: *mut libc::itimerval, - ) -> libc::c_int; - pub(super) fn setitimer( - which: libc::c_int, - new_value: *const libc::itimerval, - old_value: *mut libc::itimerval, - ) -> libc::c_int; - } - } + pub const SIG_ERR: sighandler_t = -1 as _; #[pyattr] use crate::signal::NSIG; #[cfg(any(unix, windows))] #[pyattr] - pub(crate) use libc::{SIGABRT, SIGFPE, SIGILL, SIGINT, SIGSEGV, SIGTERM}; + pub use libc::{SIGABRT, SIGFPE, SIGILL, SIGINT, SIGSEGV, SIGTERM}; #[cfg(windows)] #[pyattr] - const SIGBREAK: 
i32 = 21; // _SIGBREAK + const SIGBREAK: i32 = host_signal::SIGBREAK; // Windows-specific control events for GenerateConsoleCtrlEvent #[cfg(windows)] #[pyattr] - const CTRL_C_EVENT: u32 = 0; + const CTRL_C_EVENT: u32 = host_signal::CTRL_C_EVENT; #[cfg(windows)] #[pyattr] - const CTRL_BREAK_EVENT: u32 = 1; + const CTRL_BREAK_EVENT: u32 = host_signal::CTRL_BREAK_EVENT; #[cfg(unix)] #[pyattr] @@ -175,10 +157,9 @@ pub(crate) mod _signal { let sig_ign = vm.new_pyobj(SIG_IGN as u8); for signum in 1..NSIG { - let handler = unsafe { libc::signal(signum as i32, SIG_IGN) }; - if handler != SIG_ERR { - unsafe { libc::signal(signum as i32, handler) }; - } + let Some(handler) = (unsafe { host_signal::probe_handler(signum as i32) }) else { + continue; + }; let py_handler = if handler == SIG_DFL { Some(sig_dfl.clone()) } else if handler == SIG_IGN { @@ -210,7 +191,7 @@ pub(crate) mod _signal { #[cfg(any(unix, windows))] #[pyfunction] - pub(crate) fn signal( + pub fn signal( signalnum: i32, handler: PyObjectRef, vm: &VirtualMachine, @@ -218,16 +199,7 @@ pub(crate) mod _signal { signal::assert_in_range(signalnum, vm)?; #[cfg(windows)] { - const VALID_SIGNALS: &[i32] = &[ - libc::SIGINT, - libc::SIGILL, - libc::SIGFPE, - libc::SIGSEGV, - libc::SIGTERM, - SIGBREAK, - libc::SIGABRT, - ]; - if !VALID_SIGNALS.contains(&signalnum) { + if !host_signal::is_valid_signal(signalnum) { return Err(vm.new_value_error(format!("signal number {signalnum} out of range"))); } } @@ -246,14 +218,13 @@ pub(crate) mod _signal { }; signal::check_signals(vm)?; - let old = unsafe { libc::signal(signalnum, sig_handler) }; - if old == SIG_ERR { - return Err(vm.new_os_error("Failed to set signal".to_owned())); - } - #[cfg(all(unix, not(target_os = "redox")))] - unsafe { - siginterrupt(signalnum, 1); - } + let old = unsafe { host_signal::install_handler(signalnum, sig_handler) }; + let _old = match old { + Ok(old) => old, + Err(_) => { + return Err(vm.new_os_error("Failed to set signal".to_owned())); + } + }; 
let signal_handlers = vm.signal_handlers.get_or_init(signal::new_signal_handlers); let old_handler = signal_handlers.borrow_mut()[signalnum as usize].replace(handler); @@ -273,18 +244,13 @@ pub(crate) mod _signal { #[cfg(unix)] #[pyfunction] fn alarm(time: u32) -> u32 { - let prev_time = if time == 0 { - sig_alarm::cancel() - } else { - sig_alarm::set(time) - }; - prev_time.unwrap_or(0) + rustpython_host_env::signal::alarm(time) } #[cfg(unix)] #[pyfunction] fn pause(vm: &VirtualMachine) -> PyResult<()> { - unsafe { libc::pause() }; + host_signal::pause(); signal::check_signals(vm)?; Ok(()) } @@ -303,35 +269,25 @@ pub(crate) mod _signal { it_value: double_to_timeval(seconds), it_interval: double_to_timeval(interval), }; - let mut old = core::mem::MaybeUninit::::uninit(); - #[cfg(any(target_os = "linux", target_os = "android"))] - let ret = unsafe { ffi::setitimer(which, &new, old.as_mut_ptr()) }; - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let ret = unsafe { libc::setitimer(which, &new, old.as_mut_ptr()) }; - if ret != 0 { - let err = std::io::Error::last_os_error(); - let itimer_error = itimer_error(vm); - return Err(vm.new_exception_msg(itimer_error, err.to_string().into())); + match host_signal::setitimer(which, &new) { + Ok(old) => Ok(itimerval_to_tuple(&old)), + Err(err) => { + let itimer_error = itimer_error(vm); + Err(vm.new_exception_msg(itimer_error, err.to_string().into())) + } } - let old = unsafe { old.assume_init() }; - Ok(itimerval_to_tuple(&old)) } #[cfg(unix)] #[pyfunction] fn getitimer(which: i32, vm: &VirtualMachine) -> PyResult<(f64, f64)> { - let mut old = core::mem::MaybeUninit::::uninit(); - #[cfg(any(target_os = "linux", target_os = "android"))] - let ret = unsafe { ffi::getitimer(which, old.as_mut_ptr()) }; - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let ret = unsafe { libc::getitimer(which, old.as_mut_ptr()) }; - if ret != 0 { - let err = std::io::Error::last_os_error(); - let itimer_error = 
itimer_error(vm); - return Err(vm.new_exception_msg(itimer_error, err.to_string().into())); + match host_signal::getitimer(which) { + Ok(old) => Ok(itimerval_to_tuple(&old)), + Err(err) => { + let itimer_error = itimer_error(vm); + Err(vm.new_exception_msg(itimer_error, err.to_string().into())) + } } - let old = unsafe { old.assume_init() }; - Ok(itimerval_to_tuple(&old)) } #[pyfunction] @@ -365,54 +321,25 @@ pub(crate) mod _signal { #[cfg(windows)] let is_socket = if fd != INVALID_WAKEUP { - use windows_sys::Win32::Networking::WinSock; - - crate::windows::init_winsock(); - let mut res = 0i32; - let mut res_size = core::mem::size_of::() as i32; - let res = unsafe { - WinSock::getsockopt( - fd, - WinSock::SOL_SOCKET, - WinSock::SO_ERROR, - &mut res as *mut i32 as *mut _, - &mut res_size, - ) - }; - // if getsockopt succeeded, fd is for sure a socket - let is_socket = res == 0; - if !is_socket { - let err = std::io::Error::last_os_error(); - // if getsockopt failed for some other reason, throw - if err.raw_os_error() != Some(WinSock::WSAENOTSOCK) { - return Err(err.into_pyexception(vm)); + host_signal::wakeup_fd_is_socket(fd).map_err(|err| { + if err.kind() == std::io::ErrorKind::InvalidInput { + vm.new_value_error("invalid fd") + } else { + err.into_pyexception(vm) } - // Validate that fd is a valid file descriptor using fstat - // First check if SOCKET can be safely cast to i32 (file descriptor) - let fd_i32 = i32::try_from(fd).map_err(|_| vm.new_value_error("invalid fd"))?; - // Verify the fd is valid by trying to fstat it - let borrowed_fd = - unsafe { rustpython_host_env::crt_fd::Borrowed::try_borrow_raw(fd_i32) } - .map_err(|e| e.into_pyexception(vm))?; - rustpython_host_env::fileutils::fstat(borrowed_fd) - .map_err(|e| e.into_pyexception(vm))?; - } - is_socket + })? 
} else { false }; #[cfg(unix)] - if let Ok(fd) = unsafe { rustpython_host_env::crt_fd::Borrowed::try_borrow_raw(fd) } { - use nix::fcntl; - let oflags = fcntl::fcntl(fd, fcntl::F_GETFL).map_err(|e| e.into_pyexception(vm))?; - let nonblock = - fcntl::OFlag::from_bits_truncate(oflags).contains(fcntl::OFlag::O_NONBLOCK); - if !nonblock { - return Err(vm.new_value_error(format!( - "the fd {} must be in non-blocking mode", - fd.as_raw() - ))); - } + if let Ok(fd) = unsafe { rustpython_host_env::crt_fd::Borrowed::try_borrow_raw(fd) } + && rustpython_host_env::fcntl::get_blocking(fd.as_fd()) + .map_err(|e| e.into_pyexception(vm))? + { + return Err(vm.new_value_error(format!( + "the fd {} must be in non-blocking mode", + fd.as_raw() + ))); } let old_fd = WAKEUP.swap(fd, Ordering::Relaxed); @@ -450,33 +377,14 @@ pub(crate) mod _signal { } let flags = flags.unwrap_or(0); - let ret = unsafe { - libc::syscall( - libc::SYS_pidfd_send_signal, - pidfd, - sig, - core::ptr::null::(), - flags, - ) as libc::c_long - }; - - if ret == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + host_signal::pidfd_send_signal(pidfd, sig, flags).map_err(|_| vm.new_last_errno_error()) } #[cfg(all(unix, not(target_os = "redox")))] #[pyfunction(name = "siginterrupt")] fn py_siginterrupt(signum: i32, flag: i32, vm: &VirtualMachine) -> PyResult<()> { signal::assert_in_range(signum, vm)?; - let res = unsafe { siginterrupt(signum, flag) }; - if res < 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + host_signal::siginterrupt(signum, flag).map_err(|_| vm.new_last_errno_error()) } /// CPython: signal_raise_signal (signalmodule.c) @@ -488,25 +396,14 @@ pub(crate) mod _signal { // On Windows, only certain signals are supported #[cfg(windows)] { - // Windows supports: SIGINT(2), SIGILL(4), SIGFPE(8), SIGSEGV(11), SIGTERM(15), SIGBREAK(21), SIGABRT(22) - const VALID_SIGNALS: &[i32] = &[ - libc::SIGINT, - libc::SIGILL, - libc::SIGFPE, - libc::SIGSEGV, - libc::SIGTERM, - SIGBREAK, - 
libc::SIGABRT, - ]; - if !VALID_SIGNALS.contains(&signalnum) { + if !host_signal::is_valid_signal(signalnum) { return Err(vm .new_errno_error(libc::EINVAL, "Invalid argument") .upcast()); } } - let res = unsafe { libc::raise(signalnum) }; - if res != 0 { + if host_signal::raise_signal(signalnum).is_err() { return Err(vm.new_os_error(format!("raise_signal failed for signal {signalnum}"))); } @@ -523,13 +420,7 @@ pub(crate) mod _signal { if signalnum < 1 || signalnum >= signal::NSIG as i32 { return Err(vm.new_value_error(format!("signal number {signalnum} out of range"))); } - let s = unsafe { libc::strsignal(signalnum) }; - if s.is_null() { - Ok(None) - } else { - let cstr = unsafe { core::ffi::CStr::from_ptr(s) }; - Ok(Some(cstr.to_string_lossy().into_owned())) - } + Ok(host_signal::strsignal(signalnum)) } #[cfg(windows)] @@ -538,18 +429,7 @@ pub(crate) mod _signal { if signalnum < 1 || signalnum >= signal::NSIG as i32 { return Err(vm.new_value_error(format!("signal number {signalnum} out of range"))); } - // Windows doesn't have strsignal(), provide our own mapping - let name = match signalnum { - libc::SIGINT => "Interrupt", - libc::SIGILL => "Illegal instruction", - libc::SIGFPE => "Floating-point exception", - libc::SIGSEGV => "Segmentation fault", - libc::SIGTERM => "Terminated", - SIGBREAK => "Break", - libc::SIGABRT => "Aborted", - _ => return Ok(None), - }; - Ok(Some(name.to_owned())) + Ok(host_signal::strsignal(signalnum)) } /// CPython: signal_valid_signals (signalmodule.c) @@ -558,39 +438,16 @@ pub(crate) mod _signal { use crate::PyPayload; use crate::builtins::PySet; let set = PySet::default().into_ref(&vm.ctx); - cfg_select! 
{ - unix => { - // Use sigfillset to get all valid signals - let mut mask: libc::sigset_t = unsafe { core::mem::zeroed() }; - // SAFETY: mask is a valid pointer - if unsafe { libc::sigfillset(&mut mask) } != 0 { - return Err(vm.new_os_error("sigfillset failed".to_owned())); - } - // Convert the filled mask to a Python set - for signum in 1..signal::NSIG { - if unsafe { libc::sigismember(&mask, signum as i32) } == 1 { - set.add(vm.ctx.new_int(signum as i32).into(), vm)?; - } - } - } - windows => { - // Windows only supports a limited set of signals - for &signum in &[ - libc::SIGINT, - libc::SIGILL, - libc::SIGFPE, - libc::SIGSEGV, - libc::SIGTERM, - SIGBREAK, - libc::SIGABRT, - ] { - set.add(vm.ctx.new_int(signum).into(), vm)?; - } - } - _ => { - // Empty set for platforms without signal support (e.g., WASM) - let _ = &set; - } + #[cfg(any(unix, windows))] + for signum in host_signal::valid_signals(signal::NSIG) + .map_err(|_| vm.new_os_error("sigfillset failed".to_owned()))? + { + set.add(vm.ctx.new_int(signum).into(), vm)?; + } + #[cfg(not(any(unix, windows)))] + { + // Empty set for platforms without signal support (e.g., WASM) + let _ = &set; } Ok(set.into()) } @@ -601,8 +458,7 @@ pub(crate) mod _signal { use crate::builtins::PySet; let set = PySet::default().into_ref(&vm.ctx); for signum in 1..signal::NSIG { - // SAFETY: mask is a valid sigset_t - if unsafe { libc::sigismember(mask, signum as i32) } == 1 { + if host_signal::sigset_contains(mask, signum as i32) { set.add(vm.ctx.new_int(signum as i32).into(), vm)?; } } @@ -619,11 +475,7 @@ pub(crate) mod _signal { use crate::convert::IntoPyException; // Initialize sigset - let mut sigset: libc::sigset_t = unsafe { core::mem::zeroed() }; - // SAFETY: sigset is a valid pointer - if unsafe { libc::sigemptyset(&mut sigset) } != 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } + let mut sigset = host_signal::sigemptyset().map_err(|e| e.into_pyexception(vm))?; // Add signals to the set for 
sig in mask.iter(vm)? { @@ -643,19 +495,11 @@ pub(crate) mod _signal { signal::NSIG - 1 ))); } - // SAFETY: sigset is a valid pointer and signum is validated - if unsafe { libc::sigaddset(&mut sigset, signum) } != 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } + host_signal::sigaddset(&mut sigset, signum).map_err(|e| e.into_pyexception(vm))?; } - // Call pthread_sigmask - let mut old_mask: libc::sigset_t = unsafe { core::mem::zeroed() }; - // SAFETY: all pointers are valid - let err = unsafe { libc::pthread_sigmask(how, &sigset, &mut old_mask) }; - if err != 0 { - return Err(std::io::Error::from_raw_os_error(err).into_pyexception(vm)); - } + let old_mask = + host_signal::pthread_sigmask(how, &sigset).map_err(|e| e.into_pyexception(vm))?; // Check for pending signals signal::check_signals(vm)?; @@ -665,35 +509,18 @@ pub(crate) mod _signal { } #[cfg(any(unix, windows))] - pub(crate) extern "C" fn run_signal(signum: i32) { + pub extern "C" fn run_signal(signum: i32) { signal::TRIGGERS[signum as usize].store(true, Ordering::Relaxed); signal::set_triggered(); #[cfg(windows)] - if signum == libc::SIGINT - && let Some(handle) = signal::get_sigint_event() - { - unsafe { - windows_sys::Win32::System::Threading::SetEvent(handle as _); - } - } - let wakeup_fd = WAKEUP.load(Ordering::Relaxed); - if wakeup_fd != INVALID_WAKEUP { - let sigbyte = signum as u8; - #[cfg(windows)] - if WAKEUP_IS_SOCKET.load(Ordering::Relaxed) { - let _res = unsafe { - windows_sys::Win32::Networking::WinSock::send( - wakeup_fd, - &sigbyte as *const u8 as *const _, - 1, - 0, - ) - }; - return; - } - let _res = unsafe { libc::write(wakeup_fd as _, &sigbyte as *const u8 as *const _, 1) }; - // TODO: handle _res < 1, support warn_on_full_buffer - } + host_signal::notify_signal( + signum, + WAKEUP.load(Ordering::Relaxed), + WAKEUP_IS_SOCKET.load(Ordering::Relaxed), + signal::get_sigint_event(), + ); + #[cfg(unix)] + host_signal::notify_signal(signum, 
WAKEUP.load(Ordering::Relaxed)); } /// Reset wakeup fd after fork in child process. diff --git a/crates/vm/src/stdlib/_stat.rs b/crates/vm/src/stdlib/_stat.rs index 8809c1d0bf8..f461221daff 100644 --- a/crates/vm/src/stdlib/_stat.rs +++ b/crates/vm/src/stdlib/_stat.rs @@ -2,6 +2,11 @@ pub(crate) use _stat::module_def; #[pymodule] mod _stat { + #![allow(unreachable_pub)] + + #[cfg(windows)] + use rustpython_host_env::nt as host_nt; + // Use libc::mode_t for Mode to match the system's definition #[cfg(unix)] type Mode = libc::mode_t; @@ -25,65 +30,65 @@ mod _stat { } #[pyattr] - pub(super) const S_IFDIR: Mode = libc_const!( + pub const S_IFDIR: Mode = libc_const!( #[cfg(unix)] S_IFDIR, 0o040000 ); #[pyattr] - pub(super) const S_IFCHR: Mode = libc_const!( + pub const S_IFCHR: Mode = libc_const!( #[cfg(unix)] S_IFCHR, 0o020000 ); #[pyattr] - pub(super) const S_IFBLK: Mode = libc_const!( + pub const S_IFBLK: Mode = libc_const!( #[cfg(unix)] S_IFBLK, 0o060000 ); #[pyattr] - pub(super) const S_IFREG: Mode = libc_const!( + pub const S_IFREG: Mode = libc_const!( #[cfg(unix)] S_IFREG, 0o100000 ); #[pyattr] - pub(super) const S_IFIFO: Mode = libc_const!( + pub const S_IFIFO: Mode = libc_const!( #[cfg(unix)] S_IFIFO, 0o010000 ); #[pyattr] - pub(super) const S_IFLNK: Mode = libc_const!( + pub const S_IFLNK: Mode = libc_const!( #[cfg(unix)] S_IFLNK, 0o120000 ); #[pyattr] - pub(super) const S_IFSOCK: Mode = libc_const!( + pub const S_IFSOCK: Mode = libc_const!( #[cfg(unix)] S_IFSOCK, 0o140000 ); #[pyattr] - pub(super) const S_IFDOOR: Mode = 0; // TODO: RUSTPYTHON Support Solaris + pub const S_IFDOOR: Mode = 0; // TODO: RUSTPYTHON Support Solaris #[pyattr] - pub(super) const S_IFPORT: Mode = 0; // TODO: RUSTPYTHON Support Solaris + pub const S_IFPORT: Mode = 0; // TODO: RUSTPYTHON Support Solaris // TODO: RUSTPYTHON Support BSD // https://man.freebsd.org/cgi/man.cgi?stat(2) #[pyattr] - pub(super) const S_IFWHT: Mode = if cfg!(target_os = "macos") { + pub const S_IFWHT: Mode = if 
cfg!(target_os = "macos") { 0o160000 } else { 0 @@ -92,133 +97,133 @@ mod _stat { // Permission bits #[pyattr] - pub(super) const S_ISUID: Mode = libc_const!( + pub const S_ISUID: Mode = libc_const!( #[cfg(unix)] S_ISUID, 0o4000 ); #[pyattr] - pub(super) const S_ISGID: Mode = libc_const!( + pub const S_ISGID: Mode = libc_const!( #[cfg(unix)] S_ISGID, 0o2000 ); #[pyattr] - pub(super) const S_ENFMT: Mode = libc_const!( + pub const S_ENFMT: Mode = libc_const!( #[cfg(unix)] S_ISGID, 0o2000 ); #[pyattr] - pub(super) const S_ISVTX: Mode = libc_const!( + pub const S_ISVTX: Mode = libc_const!( #[cfg(unix)] S_ISVTX, 0o1000 ); #[pyattr] - pub(super) const S_IRWXU: Mode = libc_const!( + pub const S_IRWXU: Mode = libc_const!( #[cfg(unix)] S_IRWXU, 0o0700 ); #[pyattr] - pub(super) const S_IRUSR: Mode = libc_const!( + pub const S_IRUSR: Mode = libc_const!( #[cfg(unix)] S_IRUSR, 0o0400 ); #[pyattr] - pub(super) const S_IREAD: Mode = libc_const!( + pub const S_IREAD: Mode = libc_const!( #[cfg(unix)] S_IRUSR, 0o0400 ); #[pyattr] - pub(super) const S_IWUSR: Mode = libc_const!( + pub const S_IWUSR: Mode = libc_const!( #[cfg(unix)] S_IWUSR, 0o0200 ); #[pyattr] - pub(super) const S_IXUSR: Mode = libc_const!( + pub const S_IXUSR: Mode = libc_const!( #[cfg(unix)] S_IXUSR, 0o0100 ); #[pyattr] - pub(super) const S_IRWXG: Mode = libc_const!( + pub const S_IRWXG: Mode = libc_const!( #[cfg(unix)] S_IRWXG, 0o0070 ); #[pyattr] - pub(super) const S_IRGRP: Mode = libc_const!( + pub const S_IRGRP: Mode = libc_const!( #[cfg(unix)] S_IRGRP, 0o0040 ); #[pyattr] - pub(super) const S_IWGRP: Mode = libc_const!( + pub const S_IWGRP: Mode = libc_const!( #[cfg(unix)] S_IWGRP, 0o0020 ); #[pyattr] - pub(super) const S_IXGRP: Mode = libc_const!( + pub const S_IXGRP: Mode = libc_const!( #[cfg(unix)] S_IXGRP, 0o0010 ); #[pyattr] - pub(super) const S_IRWXO: Mode = libc_const!( + pub const S_IRWXO: Mode = libc_const!( #[cfg(unix)] S_IRWXO, 0o0007 ); #[pyattr] - pub(super) const S_IROTH: Mode = libc_const!( + pub 
const S_IROTH: Mode = libc_const!( #[cfg(unix)] S_IROTH, 0o0004 ); #[pyattr] - pub(super) const S_IWOTH: Mode = libc_const!( + pub const S_IWOTH: Mode = libc_const!( #[cfg(unix)] S_IWOTH, 0o0002 ); #[pyattr] - pub(super) const S_IXOTH: Mode = libc_const!( + pub const S_IXOTH: Mode = libc_const!( #[cfg(unix)] S_IXOTH, 0o0001 ); #[pyattr] - pub(super) const S_IWRITE: Mode = libc_const!( + pub const S_IWRITE: Mode = libc_const!( #[cfg(all(unix, not(target_os = "android"), not(target_os = "redox")))] S_IWRITE, 0o0200 ); #[pyattr] - pub(super) const S_IEXEC: Mode = libc_const!( + pub const S_IEXEC: Mode = libc_const!( #[cfg(all(unix, not(target_os = "android"), not(target_os = "redox")))] S_IEXEC, 0o0100 @@ -228,7 +233,7 @@ mod _stat { #[cfg(windows)] #[pyattr] - pub(super) use windows_sys::Win32::Storage::FileSystem::{ + pub use host_nt::{ FILE_ATTRIBUTE_ARCHIVE, FILE_ATTRIBUTE_COMPRESSED, FILE_ATTRIBUTE_DEVICE, FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_ENCRYPTED, FILE_ATTRIBUTE_HIDDEN, FILE_ATTRIBUTE_INTEGRITY_STREAM, FILE_ATTRIBUTE_NO_SCRUB_DATA, FILE_ATTRIBUTE_NORMAL, @@ -240,144 +245,142 @@ mod _stat { // Windows reparse point tags #[cfg(windows)] #[pyattr] - pub(super) const IO_REPARSE_TAG_SYMLINK: u32 = 0xA000000C; - + pub const IO_REPARSE_TAG_SYMLINK: u32 = 0xA000000C; #[cfg(windows)] #[pyattr] - pub(super) const IO_REPARSE_TAG_MOUNT_POINT: u32 = 0xA0000003; - + pub const IO_REPARSE_TAG_MOUNT_POINT: u32 = 0xA0000003; #[cfg(windows)] #[pyattr] - pub(super) const IO_REPARSE_TAG_APPEXECLINK: u32 = 0x8000001B; + pub const IO_REPARSE_TAG_APPEXECLINK: u32 = 0x8000001B; // Unix file flags (if on Unix) #[pyattr] - pub(super) const UF_NODUMP: u32 = libc_const!( + pub const UF_NODUMP: u32 = libc_const!( #[cfg(target_os = "macos")] UF_NODUMP, 0x00000001 ); #[pyattr] - pub(super) const UF_IMMUTABLE: u32 = libc_const!( + pub const UF_IMMUTABLE: u32 = libc_const!( #[cfg(target_os = "macos")] UF_IMMUTABLE, 0x00000002 ); #[pyattr] - pub(super) const UF_APPEND: u32 = libc_const!( 
+ pub const UF_APPEND: u32 = libc_const!( #[cfg(target_os = "macos")] UF_APPEND, 0x00000004 ); #[pyattr] - pub(super) const UF_OPAQUE: u32 = libc_const!( + pub const UF_OPAQUE: u32 = libc_const!( #[cfg(target_os = "macos")] UF_OPAQUE, 0x00000008 ); #[pyattr] - pub(super) const UF_COMPRESSED: u32 = libc_const!( + pub const UF_COMPRESSED: u32 = libc_const!( #[cfg(target_os = "macos")] UF_COMPRESSED, 0x00000020 ); #[pyattr] - pub(super) const UF_HIDDEN: u32 = libc_const!( + pub const UF_HIDDEN: u32 = libc_const!( #[cfg(target_os = "macos")] UF_HIDDEN, 0x00008000 ); #[pyattr] - pub(super) const SF_ARCHIVED: u32 = libc_const!( + pub const SF_ARCHIVED: u32 = libc_const!( #[cfg(target_os = "macos")] SF_ARCHIVED, 0x00010000 ); #[pyattr] - pub(super) const SF_IMMUTABLE: u32 = libc_const!( + pub const SF_IMMUTABLE: u32 = libc_const!( #[cfg(target_os = "macos")] SF_IMMUTABLE, 0x00020000 ); #[pyattr] - pub(super) const SF_APPEND: u32 = libc_const!( + pub const SF_APPEND: u32 = libc_const!( #[cfg(target_os = "macos")] SF_APPEND, 0x00040000 ); #[pyattr] - pub(super) const SF_SETTABLE: u32 = if cfg!(target_os = "macos") { + pub const SF_SETTABLE: u32 = if cfg!(target_os = "macos") { 0x3fff0000 } else { 0xffff0000 }; #[pyattr] - pub(super) const UF_NOUNLINK: u32 = 0x00000010; + pub const UF_NOUNLINK: u32 = 0x00000010; #[pyattr] - pub(super) const SF_NOUNLINK: u32 = 0x00100000; + pub const SF_NOUNLINK: u32 = 0x00100000; #[pyattr] - pub(super) const SF_SNAPSHOT: u32 = 0x00200000; + pub const SF_SNAPSHOT: u32 = 0x00200000; #[pyattr] - pub(super) const SF_FIRMLINK: u32 = 0x00800000; + pub const SF_FIRMLINK: u32 = 0x00800000; #[pyattr] - pub(super) const SF_DATALESS: u32 = 0x40000000; + pub const SF_DATALESS: u32 = 0x40000000; // MacOS specific #[cfg(target_os = "macos")] #[pyattr] - pub(super) const SF_SUPPORTED: u32 = 0x009f0000; + pub const SF_SUPPORTED: u32 = 0x009f0000; #[cfg(target_os = "macos")] #[pyattr] - pub(super) const SF_SYNTHETIC: u32 = 0xc0000000; + pub const 
SF_SYNTHETIC: u32 = 0xc0000000; // Stat result indices #[pyattr] - pub(super) const ST_MODE: u32 = 0; + pub const ST_MODE: u32 = 0; #[pyattr] - pub(super) const ST_INO: u32 = 1; + pub const ST_INO: u32 = 1; #[pyattr] - pub(super) const ST_DEV: u32 = 2; + pub const ST_DEV: u32 = 2; #[pyattr] - pub(super) const ST_NLINK: u32 = 3; + pub const ST_NLINK: u32 = 3; #[pyattr] - pub(super) const ST_UID: u32 = 4; + pub const ST_UID: u32 = 4; #[pyattr] - pub(super) const ST_GID: u32 = 5; + pub const ST_GID: u32 = 5; #[pyattr] - pub(super) const ST_SIZE: u32 = 6; + pub const ST_SIZE: u32 = 6; #[pyattr] - pub(super) const ST_ATIME: u32 = 7; + pub const ST_ATIME: u32 = 7; #[pyattr] - pub(super) const ST_MTIME: u32 = 8; + pub const ST_MTIME: u32 = 8; #[pyattr] - pub(super) const ST_CTIME: u32 = 9; + pub const ST_CTIME: u32 = 9; const S_IFMT: Mode = 0o170000; diff --git a/crates/vm/src/stdlib/_thread.rs b/crates/vm/src/stdlib/_thread.rs index c9589c1ec52..0af8d7add38 100644 --- a/crates/vm/src/stdlib/_thread.rs +++ b/crates/vm/src/stdlib/_thread.rs @@ -31,6 +31,8 @@ pub(crate) mod _thread { lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, }; use rustpython_common::str::levenshtein::{MOVE_COST, levenshtein_distance}; + #[cfg(any(unix, windows))] + use rustpython_host_env::thread as host_thread; use std::thread; // PYTHREAD_NAME: show current thread name @@ -381,50 +383,17 @@ pub(crate) mod _thread { /// Set the name of the current thread #[pyfunction] fn set_name(name: PyUtf8StrRef) { - #[cfg(target_os = "linux")] - { - use alloc::ffi::CString; - if let Ok(c_name) = CString::new(name.as_str()) { - // pthread_setname_np on Linux has a 16-byte limit including null terminator - // TODO: Potential UTF-8 boundary issue when truncating thread name on Linux. 
- // https://github.com/RustPython/RustPython/pull/6726/changes#r2689379171 - let truncated = if c_name.as_bytes().len() > 15 { - CString::new(&c_name.as_bytes()[..15]).unwrap_or(c_name) - } else { - c_name - }; - unsafe { - libc::pthread_setname_np(libc::pthread_self(), truncated.as_ptr()); - } - } - } - #[cfg(target_os = "macos")] - { - use alloc::ffi::CString; - if let Ok(c_name) = CString::new(name.as_str()) { - unsafe { - libc::pthread_setname_np(c_name.as_ptr()); - } - } - } - #[cfg(windows)] - { - // Windows doesn't have a simple pthread_setname_np equivalent - // SetThreadDescription requires Windows 10+ - let _ = name; - } - #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))] - { - let _ = name; - } + #[cfg(any(unix, windows))] + host_thread::set_current_thread_name(name.as_str()); + #[cfg(not(any(unix, windows)))] + let _ = name; } /// Get OS-level thread ID (pthread_self on Unix) /// This is important for fork compatibility - the ID must remain stable after fork #[cfg(unix)] fn current_thread_id() -> u64 { - // pthread_self() for fork compatibility - unsafe { libc::pthread_self() as u64 } + host_thread::current_thread_id() } #[cfg(not(unix))] @@ -561,10 +530,16 @@ pub(crate) mod _thread { // Increment thread count when thread actually starts executing vm.state.thread_count.fetch_add(1); - match func.invoke(args, vm) { - Ok(_obj) => {} - Err(e) if e.fast_isinstance(vm.ctx.exceptions.system_exit) => {} - Err(exc) => { + // Inner scope: drop `func` (and its Python refs) before the thread + // slot is torn down below. Otherwise the parameter `func` would drop + // at end-of-function, after cleanup_current_thread_frames has cleared + // CURRENT_THREAD_SLOT, and a weakref callback fired during that drop + // would panic in push_thread_frame. 
+ { + let func = func; + if let Err(exc) = func.invoke(args, vm) + && !exc.fast_isinstance(vm.ctx.exceptions.system_exit) + { vm.run_unraisable( exc, Some("Exception ignored in thread started by".to_owned()), @@ -1694,11 +1669,18 @@ pub(crate) mod _thread { // Increment thread count when thread actually starts executing vm_state.thread_count.fetch_add(1); - // Run the function - match func.invoke((), vm) { - Ok(_) => {} - Err(e) if e.fast_isinstance(vm.ctx.exceptions.system_exit) => {} - Err(exc) => { + // Inner scope: drop `func` (and its Python refs) before the + // outer scopeguard::defer tears down the thread slot. As a + // `move` closure capture, `func` would otherwise drop after + // all locals (including the scopeguard `_guard`), and a + // weakref callback fired during that drop would panic in + // push_thread_frame. + { + let func = func; + // Run the function + if let Err(exc) = func.invoke((), vm) + && !exc.fast_isinstance(vm.ctx.exceptions.system_exit) + { vm.run_unraisable( exc, Some("Exception ignored in thread started by".to_owned()), diff --git a/crates/vm/src/stdlib/_winapi.rs b/crates/vm/src/stdlib/_winapi.rs index 113d5bd2de4..f6faf6a8a95 100644 --- a/crates/vm/src/stdlib/_winapi.rs +++ b/crates/vm/src/stdlib/_winapi.rs @@ -9,75 +9,54 @@ mod _winapi { Py, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, builtins::PyStrRef, common::lock::PyMutex, - convert::{ToPyException, ToPyResult}, + convert::ToPyException, function::{ArgMapping, ArgSequence, OptionalArg}, types::Constructor, windows::{WinHandle, WindowsSysResult}, }; - use core::ptr::{null, null_mut}; + use core::ptr::null_mut; use rustpython_common::wtf8::Wtf8Buf; + use rustpython_host_env::overlapped as host_overlapped; use rustpython_host_env::winapi as host_winapi; use rustpython_host_env::windows::ToWideString; - use windows_sys::Win32::Foundation::{HANDLE, MAX_PATH}; #[pyattr] - use windows_sys::Win32::{ - Foundation::{ - DUPLICATE_CLOSE_SOURCE, DUPLICATE_SAME_ACCESS, 
ERROR_ACCESS_DENIED, - ERROR_ALREADY_EXISTS, ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, - ERROR_NETNAME_DELETED, ERROR_NO_DATA, ERROR_NO_SYSTEM_RESOURCES, - ERROR_OPERATION_ABORTED, ERROR_PIPE_BUSY, ERROR_PIPE_CONNECTED, - ERROR_PRIVILEGE_NOT_HELD, ERROR_SEM_TIMEOUT, GENERIC_READ, GENERIC_WRITE, STILL_ACTIVE, - WAIT_ABANDONED_0, WAIT_OBJECT_0, WAIT_TIMEOUT, - }, - Globalization::{ - LCMAP_FULLWIDTH, LCMAP_HALFWIDTH, LCMAP_HIRAGANA, LCMAP_KATAKANA, - LCMAP_LINGUISTIC_CASING, LCMAP_LOWERCASE, LCMAP_SIMPLIFIED_CHINESE, LCMAP_TITLECASE, - LCMAP_TRADITIONAL_CHINESE, LCMAP_UPPERCASE, - }, - Storage::FileSystem::{ - COPY_FILE_ALLOW_DECRYPTED_DESTINATION, COPY_FILE_COPY_SYMLINK, - COPY_FILE_FAIL_IF_EXISTS, COPY_FILE_NO_BUFFERING, COPY_FILE_NO_OFFLOAD, - COPY_FILE_OPEN_SOURCE_FOR_WRITE, COPY_FILE_REQUEST_COMPRESSED_TRAFFIC, - COPY_FILE_REQUEST_SECURITY_PRIVILEGES, COPY_FILE_RESTARTABLE, - COPY_FILE_RESUME_FROM_PAUSE, COPYFILE2_CALLBACK_CHUNK_FINISHED, - COPYFILE2_CALLBACK_CHUNK_STARTED, COPYFILE2_CALLBACK_ERROR, - COPYFILE2_CALLBACK_POLL_CONTINUE, COPYFILE2_CALLBACK_STREAM_FINISHED, - COPYFILE2_CALLBACK_STREAM_STARTED, COPYFILE2_PROGRESS_CANCEL, - COPYFILE2_PROGRESS_CONTINUE, COPYFILE2_PROGRESS_PAUSE, COPYFILE2_PROGRESS_QUIET, - COPYFILE2_PROGRESS_STOP, FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, - FILE_GENERIC_READ, FILE_GENERIC_WRITE, FILE_TYPE_CHAR, FILE_TYPE_DISK, FILE_TYPE_PIPE, - FILE_TYPE_REMOTE, FILE_TYPE_UNKNOWN, OPEN_EXISTING, PIPE_ACCESS_DUPLEX, - PIPE_ACCESS_INBOUND, SYNCHRONIZE, - }, - System::{ - Console::{STD_ERROR_HANDLE, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE}, - Memory::{ - FILE_MAP_ALL_ACCESS, FILE_MAP_COPY, FILE_MAP_EXECUTE, FILE_MAP_READ, - FILE_MAP_WRITE, MEM_COMMIT, MEM_FREE, MEM_IMAGE, MEM_MAPPED, MEM_PRIVATE, - MEM_RESERVE, PAGE_EXECUTE, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, - PAGE_EXECUTE_WRITECOPY, PAGE_GUARD, PAGE_NOACCESS, PAGE_NOCACHE, PAGE_READONLY, - PAGE_READWRITE, PAGE_WRITECOMBINE, PAGE_WRITECOPY, SEC_COMMIT, 
SEC_IMAGE, - SEC_LARGE_PAGES, SEC_NOCACHE, SEC_RESERVE, SEC_WRITECOMBINE, - }, - Pipes::{ - NMPWAIT_WAIT_FOREVER, PIPE_READMODE_MESSAGE, PIPE_TYPE_MESSAGE, - PIPE_UNLIMITED_INSTANCES, PIPE_WAIT, - }, - SystemServices::LOCALE_NAME_MAX_LENGTH, - Threading::{ - ABOVE_NORMAL_PRIORITY_CLASS, BELOW_NORMAL_PRIORITY_CLASS, - CREATE_BREAKAWAY_FROM_JOB, CREATE_DEFAULT_ERROR_MODE, CREATE_NEW_CONSOLE, - CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW, DETACHED_PROCESS, HIGH_PRIORITY_CLASS, - IDLE_PRIORITY_CLASS, INFINITE, NORMAL_PRIORITY_CLASS, PROCESS_ALL_ACCESS, - PROCESS_DUP_HANDLE, REALTIME_PRIORITY_CLASS, STARTF_FORCEOFFFEEDBACK, - STARTF_FORCEONFEEDBACK, STARTF_PREVENTPINNING, STARTF_RUNFULLSCREEN, - STARTF_TITLEISAPPID, STARTF_TITLEISLINKNAME, STARTF_UNTRUSTEDSOURCE, - STARTF_USECOUNTCHARS, STARTF_USEFILLATTRIBUTE, STARTF_USEHOTKEY, - STARTF_USEPOSITION, STARTF_USESHOWWINDOW, STARTF_USESIZE, STARTF_USESTDHANDLES, - }, - }, - UI::WindowsAndMessaging::SW_HIDE, + use host_winapi::{ + ABOVE_NORMAL_PRIORITY_CLASS, BELOW_NORMAL_PRIORITY_CLASS, + COPY_FILE_ALLOW_DECRYPTED_DESTINATION, COPY_FILE_COPY_SYMLINK, COPY_FILE_FAIL_IF_EXISTS, + COPY_FILE_NO_BUFFERING, COPY_FILE_NO_OFFLOAD, COPY_FILE_OPEN_SOURCE_FOR_WRITE, + COPY_FILE_REQUEST_COMPRESSED_TRAFFIC, COPY_FILE_REQUEST_SECURITY_PRIVILEGES, + COPY_FILE_RESTARTABLE, COPY_FILE_RESUME_FROM_PAUSE, COPYFILE2_CALLBACK_CHUNK_FINISHED, + COPYFILE2_CALLBACK_CHUNK_STARTED, COPYFILE2_CALLBACK_ERROR, + COPYFILE2_CALLBACK_POLL_CONTINUE, COPYFILE2_CALLBACK_STREAM_FINISHED, + COPYFILE2_CALLBACK_STREAM_STARTED, COPYFILE2_PROGRESS_CANCEL, COPYFILE2_PROGRESS_CONTINUE, + COPYFILE2_PROGRESS_PAUSE, COPYFILE2_PROGRESS_QUIET, COPYFILE2_PROGRESS_STOP, + CREATE_BREAKAWAY_FROM_JOB, CREATE_DEFAULT_ERROR_MODE, CREATE_NEW_CONSOLE, + CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW, DETACHED_PROCESS, DUPLICATE_CLOSE_SOURCE, + DUPLICATE_SAME_ACCESS, ERROR_ACCESS_DENIED, ERROR_ALREADY_EXISTS, ERROR_BROKEN_PIPE, + ERROR_IO_PENDING, ERROR_MORE_DATA, 
ERROR_NETNAME_DELETED, ERROR_NO_DATA, + ERROR_NO_SYSTEM_RESOURCES, ERROR_OPERATION_ABORTED, ERROR_PIPE_BUSY, ERROR_PIPE_CONNECTED, + ERROR_PRIVILEGE_NOT_HELD, ERROR_SEM_TIMEOUT, FILE_FLAG_FIRST_PIPE_INSTANCE, + FILE_FLAG_OVERLAPPED, FILE_GENERIC_READ, FILE_GENERIC_WRITE, FILE_MAP_ALL_ACCESS, + FILE_MAP_COPY, FILE_MAP_EXECUTE, FILE_MAP_READ, FILE_MAP_WRITE, FILE_TYPE_CHAR, + FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE, GENERIC_READ, GENERIC_WRITE, + HIGH_PRIORITY_CLASS, IDLE_PRIORITY_CLASS, LCMAP_FULLWIDTH, LCMAP_HALFWIDTH, LCMAP_HIRAGANA, + LCMAP_KATAKANA, LCMAP_LINGUISTIC_CASING, LCMAP_LOWERCASE, LCMAP_SIMPLIFIED_CHINESE, + LCMAP_TITLECASE, LCMAP_TRADITIONAL_CHINESE, LCMAP_UPPERCASE, LOCALE_NAME_MAX_LENGTH, + MEM_COMMIT, MEM_FREE, MEM_IMAGE, MEM_MAPPED, MEM_PRIVATE, MEM_RESERVE, + NMPWAIT_WAIT_FOREVER, NORMAL_PRIORITY_CLASS, OPEN_EXISTING, PAGE_EXECUTE, + PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, PAGE_EXECUTE_WRITECOPY, PAGE_GUARD, + PAGE_NOACCESS, PAGE_NOCACHE, PAGE_READONLY, PAGE_READWRITE, PAGE_WRITECOMBINE, + PAGE_WRITECOPY, PIPE_ACCESS_DUPLEX, PIPE_ACCESS_INBOUND, PIPE_READMODE_MESSAGE, + PIPE_TYPE_MESSAGE, PIPE_UNLIMITED_INSTANCES, PIPE_WAIT, PROCESS_ALL_ACCESS, + PROCESS_DUP_HANDLE, REALTIME_PRIORITY_CLASS, SEC_COMMIT, SEC_IMAGE, SEC_LARGE_PAGES, + SEC_NOCACHE, SEC_RESERVE, SEC_WRITECOMBINE, STARTF_FORCEOFFFEEDBACK, + STARTF_FORCEONFEEDBACK, STARTF_PREVENTPINNING, STARTF_RUNFULLSCREEN, STARTF_TITLEISAPPID, + STARTF_TITLEISLINKNAME, STARTF_UNTRUSTEDSOURCE, STARTF_USECOUNTCHARS, + STARTF_USEFILLATTRIBUTE, STARTF_USEHOTKEY, STARTF_USEPOSITION, STARTF_USESHOWWINDOW, + STARTF_USESIZE, STARTF_USESTDHANDLES, STD_ERROR_HANDLE, STD_INPUT_HANDLE, + STD_OUTPUT_HANDLE, STILL_ACTIVE, SW_HIDE, SYNCHRONIZE, WAIT_ABANDONED_0, WAIT_OBJECT_0, + WAIT_TIMEOUT, }; #[pyattr] @@ -86,20 +65,23 @@ mod _winapi { #[pyattr] const INVALID_HANDLE_VALUE: isize = -1; + #[pyattr] + const INFINITE: u32 = host_winapi::INFINITE_TIMEOUT; + #[pyattr] const COPY_FILE_DIRECTORY: u32 = 
0x00000080; #[pyfunction] fn CloseHandle(handle: WinHandle) -> WindowsSysResult { - WindowsSysResult(unsafe { windows_sys::Win32::Foundation::CloseHandle(handle.0) }) + WindowsSysResult(host_winapi::close_handle(handle.0)) } /// CreateFile - Create or open a file or I/O device. - #[pyfunction] - #[allow( + #[expect( clippy::too_many_arguments, reason = "matches Win32 CreateFile parameter structure" )] + #[pyfunction] fn CreateFile( file_name: PyStrRef, desired_access: u32, @@ -110,44 +92,26 @@ mod _winapi { _template_file: PyObjectRef, // Always NULL (0) vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Storage::FileSystem::CreateFileW; - - let file_name_wide = file_name.as_wtf8().to_wide_with_nul(); - - let handle = unsafe { - CreateFileW( - file_name_wide.as_ptr(), - desired_access, - share_mode, - null(), - creation_disposition, - flags_and_attributes, - null_mut(), - ) - }; - - if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); - } - - Ok(WinHandle(handle)) + let file_name_wide = file_name.as_wtf8().to_wide_cstring(); + host_winapi::create_file_w( + &file_name_wide, + desired_access, + share_mode, + creation_disposition, + flags_and_attributes, + ) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn GetStdHandle( - std_handle: windows_sys::Win32::System::Console::STD_HANDLE, + std_handle: host_winapi::StdHandle, vm: &VirtualMachine, ) -> PyResult> { - let handle = unsafe { windows_sys::Win32::System::Console::GetStdHandle(std_handle) }; - if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); - } - Ok(if handle.is_null() { - // NULL handle - return None - None - } else { - Some(WinHandle(handle)) - }) + host_winapi::get_std_handle(std_handle) + .map(|handle| handle.map(WinHandle)) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] @@ -156,20 +120,9 @@ mod _winapi { size: u32, vm: &VirtualMachine, ) -> 
PyResult<(WinHandle, WinHandle)> { - use windows_sys::Win32::Foundation::HANDLE; - let (read, write) = unsafe { - let mut read = core::mem::MaybeUninit::::uninit(); - let mut write = core::mem::MaybeUninit::::uninit(); - WindowsSysResult(windows_sys::Win32::System::Pipes::CreatePipe( - read.as_mut_ptr(), - write.as_mut_ptr(), - core::ptr::null(), - size, - )) - .to_pyresult(vm)?; - (read.assume_init(), write.assume_init()) - }; - Ok((WinHandle(read), WinHandle(write))) + host_winapi::create_pipe(size) + .map(|(read, write)| (WinHandle(read), WinHandle(write))) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] @@ -182,22 +135,16 @@ mod _winapi { options: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::HANDLE; - let target = unsafe { - let mut target = core::mem::MaybeUninit::::uninit(); - WindowsSysResult(windows_sys::Win32::Foundation::DuplicateHandle( - src_process.0, - src.0, - target_process.0, - target.as_mut_ptr(), - access, - inherit, - options.unwrap_or(0), - )) - .to_pyresult(vm)?; - target.assume_init() - }; - Ok(WinHandle(target)) + host_winapi::duplicate_handle( + src_process.0, + src.0, + target_process.0, + access, + inherit, + options.unwrap_or(0), + ) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] @@ -211,16 +158,8 @@ mod _winapi { } #[pyfunction] - fn GetFileType( - h: WinHandle, - vm: &VirtualMachine, - ) -> PyResult { - let file_type = unsafe { windows_sys::Win32::Storage::FileSystem::GetFileType(h.0) }; - if file_type == 0 && unsafe { windows_sys::Win32::Foundation::GetLastError() } != 0 { - Err(vm.new_last_os_error()) - } else { - Ok(file_type) - } + fn GetFileType(h: WinHandle, vm: &VirtualMachine) -> PyResult { + host_winapi::get_file_type(h.0).map_err(|e| e.to_pyexception(vm)) } #[pyfunction] @@ -260,51 +199,39 @@ mod _winapi { args: CreateProcessArgs, vm: &VirtualMachine, ) -> PyResult<(WinHandle, WinHandle, u32, u32)> { - let mut si: 
windows_sys::Win32::System::Threading::STARTUPINFOEXW = - unsafe { core::mem::zeroed() }; - si.StartupInfo.cb = core::mem::size_of_val(&si) as _; - macro_rules! si_attr { ($attr:ident, $t:ty) => {{ - si.StartupInfo.$attr = >::try_from_object( + >::try_from_object( vm, args.startup_info.get_attr(stringify!($attr), vm)?, )? .unwrap_or(0) as _ }}; ($attr:ident) => {{ - si.StartupInfo.$attr = >::try_from_object( + >::try_from_object( vm, args.startup_info.get_attr(stringify!($attr), vm)?, )? .unwrap_or(0) }}; } - si_attr!(dwFlags); - si_attr!(wShowWindow); - si_attr!(hStdInput, isize); - si_attr!(hStdOutput, isize); - si_attr!(hStdError, isize); + let startup_info = host_winapi::StartupInfoData { + flags: si_attr!(dwFlags), + show_window: si_attr!(wShowWindow), + std_input: si_attr!(hStdInput, isize), + std_output: si_attr!(hStdOutput, isize), + std_error: si_attr!(hStdError, isize), + }; - let mut env = args + let env = args .env_mapping .map(|m| getenvironment(m, vm)) .transpose()?; - let env = env.as_mut().map_or_else(null_mut, |v| v.as_mut_ptr()); - - let mut attrlist = - getattributelist(args.startup_info.get_attr("lpAttributeList", vm)?, vm)?; - si.lpAttributeList = attrlist - .as_mut() - .map_or_else(null_mut, |l| l.attrlist.as_mut_ptr() as _); - - let wstr = |s: PyStrRef| { - let ws = widestring::WideCString::from_str(s.expect_str()) - .map_err(|err| err.to_pyexception(vm))?; - Ok(ws.into_vec_with_nul()) - }; + + let handle_list = get_handle_list(args.startup_info.get_attr("lpAttributeList", vm)?, vm)?; // Validate no embedded null bytes in command name and command line + // before handing the strings off; to_wide_cstring truncates at NUL. 
if let Some(ref name) = args.name && name.as_bytes().contains(&0) { @@ -316,44 +243,31 @@ mod _winapi { return Err(crate::exceptions::cstring_error(vm)); } - let app_name = args.name.map(wstr).transpose()?; - let app_name = app_name.as_ref().map_or_else(null, |w| w.as_ptr()); - - let mut command_line = args.command_line.map(wstr).transpose()?; - let command_line = command_line - .as_mut() - .map_or_else(null_mut, |w| w.as_mut_ptr()); - - let mut current_dir = args.current_dir.map(wstr).transpose()?; - let current_dir = current_dir - .as_mut() - .map_or_else(null_mut, |w| w.as_mut_ptr()); - - let procinfo = unsafe { - let mut procinfo = core::mem::MaybeUninit::uninit(); - WindowsSysResult(windows_sys::Win32::System::Threading::CreateProcessW( - app_name, - command_line, - core::ptr::null(), - core::ptr::null(), - args.inherit_handles, - args.creation_flags - | windows_sys::Win32::System::Threading::EXTENDED_STARTUPINFO_PRESENT - | windows_sys::Win32::System::Threading::CREATE_UNICODE_ENVIRONMENT, - env as _, - current_dir, - &mut si as *mut _ as *mut _, - procinfo.as_mut_ptr(), - )) - .into_pyresult(vm)?; - procinfo.assume_init() - }; + let wcstring = |s: PyStrRef| s.as_wtf8().to_wide_cstring(); + let app_name = args.name.as_ref().map(|s| wcstring(s.clone())); + let current_dir = args.current_dir.as_ref().map(|s| wcstring(s.clone())); + let mut command_line = args + .command_line + .as_ref() + .map(|s| wcstring(s.clone()).into_vec_with_nul()); + + let procinfo = host_winapi::create_process( + app_name.as_deref(), + command_line.as_deref_mut(), + args.inherit_handles, + args.creation_flags, + env.as_deref(), + current_dir.as_deref(), + startup_info, + handle_list, + ) + .map_err(|e| e.to_pyexception(vm))?; Ok(( - WinHandle(procinfo.hProcess), - WinHandle(procinfo.hThread), - procinfo.dwProcessId, - procinfo.dwThreadId, + WinHandle(procinfo.process), + WinHandle(procinfo.thread), + procinfo.process_id, + procinfo.thread_id, )) } @@ -364,33 +278,20 @@ mod _winapi { 
process_id: u32, vm: &VirtualMachine, ) -> PyResult { - let handle = unsafe { - windows_sys::Win32::System::Threading::OpenProcess( - desired_access, - i32::from(inherit_handle), - process_id, - ) - }; - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + host_winapi::open_process(desired_access, inherit_handle, process_id) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn ExitProcess(exit_code: u32) { - unsafe { windows_sys::Win32::System::Threading::ExitProcess(exit_code) } + host_winapi::exit_process(exit_code) } #[pyfunction] fn NeedCurrentDirectoryForExePath(exe_name: PyStrRef) -> bool { - let exe_name = exe_name.as_wtf8().to_wide_with_nul(); - let return_value = unsafe { - windows_sys::Win32::System::Environment::NeedCurrentDirectoryForExePathW( - exe_name.as_ptr(), - ) - }; - return_value != 0 + let exe_name = exe_name.as_wtf8().to_wide_cstring(); + host_winapi::need_current_directory_for_exe_path_w(&exe_name) } #[pyfunction] @@ -401,8 +302,7 @@ mod _winapi { ) -> PyResult<()> { let src_path = std::path::Path::new(src_path.expect_str()); let dest_path = std::path::Path::new(dest_path.expect_str()); - - junction::create(src_path, dest_path).map_err(|e| e.to_pyexception(vm)) + host_winapi::create_junction(src_path, dest_path).map_err(|e| e.to_pyexception(vm)) } fn getenvironment(env: ArgMapping, vm: &VirtualMachine) -> PyResult> { @@ -416,125 +316,33 @@ mod _winapi { return Err(vm.new_runtime_error("environment changed size during iteration")); } - // Deduplicate case-insensitive keys, keeping the last value - use std::collections::HashMap; - let mut last_entry: HashMap = HashMap::new(); + let mut entries = Vec::with_capacity(keys.len()); for (k, v) in keys.into_iter().zip(values) { let k = PyStrRef::try_from_object(vm, k)?; - let k = k.expect_str(); + let k = k.expect_str().to_owned(); let v = PyStrRef::try_from_object(vm, v)?; - let v = v.expect_str(); - if k.contains('\0') || v.contains('\0') { - 
return Err(crate::exceptions::cstring_error(vm)); - } - if k.is_empty() || k[1..].contains('=') { - return Err(vm.new_value_error("illegal environment variable name")); - } - let key_upper = k.to_uppercase(); - let mut entry = widestring::WideString::new(); - entry.push_str(k); - entry.push_str("="); - entry.push_str(v); - entry.push_str("\0"); - last_entry.insert(key_upper, entry); - } - - // Sort by uppercase key for case-insensitive ordering - let mut entries: Vec<(String, widestring::WideString)> = last_entry.into_iter().collect(); - entries.sort_by(|a, b| a.0.cmp(&b.0)); - - let mut out = widestring::WideString::new(); - for (_, entry) in entries { - out.push(entry); - } - // Each entry ends with \0, so one more \0 terminates the block. - // For empty env, we need \0\0 as a valid empty environment block. - if out.is_empty() { - out.push_str("\0"); + let v = v.expect_str().to_owned(); + entries.push((k, v)); } - out.push_str("\0"); - Ok(out.into_vec()) - } - struct AttrList { - handlelist: Option>, - attrlist: Vec, - } - impl Drop for AttrList { - fn drop(&mut self) { - unsafe { - windows_sys::Win32::System::Threading::DeleteProcThreadAttributeList( - self.attrlist.as_mut_ptr() as *mut _, - ) - }; - } + host_winapi::build_environment_block(entries).map_err(|err| err.to_pyexception(vm)) } - fn getattributelist(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - >::try_from_object(vm, obj)? 
- .map(|mapping| { - let handlelist = mapping - .as_ref() - .get_item("handle_list", vm) - .ok() - .and_then(|obj| { - >>::try_from_object(vm, obj) - .map(|s| match s { - Some(s) if !s.is_empty() => Some(s.into_vec()), - _ => None, - }) - .transpose() - }) - .transpose()?; - - let attr_count = handlelist.is_some() as u32; - let (result, mut size) = unsafe { - let mut size = core::mem::MaybeUninit::uninit(); - let result = WindowsSysResult( - windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList( - core::ptr::null_mut(), - attr_count, - 0, - size.as_mut_ptr(), - ), - ); - (result, size.assume_init()) - }; - if !result.is_err() - || unsafe { windows_sys::Win32::Foundation::GetLastError() } - != windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER - { - return Err(vm.new_last_os_error()); - } - let mut attrlist = vec![0u8; size]; - WindowsSysResult(unsafe { - windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList( - attrlist.as_mut_ptr() as *mut _, - attr_count, - 0, - &mut size, - ) - }) - .into_pyresult(vm)?; - let mut attrs = AttrList { - handlelist, - attrlist, - }; - if let Some(ref mut handlelist) = attrs.handlelist { - WindowsSysResult(unsafe { - windows_sys::Win32::System::Threading::UpdateProcThreadAttribute( - attrs.attrlist.as_mut_ptr() as _, - 0, - (2 & 0xffff) | 0x20000, // PROC_THREAD_ATTRIBUTE_HANDLE_LIST - handlelist.as_mut_ptr() as _, - (handlelist.len() * core::mem::size_of::()) as _, - core::ptr::null_mut(), - core::ptr::null(), - ) + fn get_handle_list(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult>> { + let Some(mapping) = >::try_from_object(vm, obj)? 
else { + return Ok(None); + }; + mapping + .as_ref() + .get_item("handle_list", vm) + .ok() + .and_then(|obj| { + >>::try_from_object(vm, obj) + .map(|s| match s { + Some(s) if !s.is_empty() => Some(s.into_vec()), + _ => None, }) - .into_pyresult(vm)?; - } - Ok(attrs) + .transpose() }) .transpose() } @@ -543,18 +351,13 @@ mod _winapi { fn WaitForSingleObject(h: WinHandle, ms: i64, vm: &VirtualMachine) -> PyResult { // Negative values (e.g., -1) map to INFINITE (0xFFFFFFFF) let ms = if ms < 0 { - windows_sys::Win32::System::Threading::INFINITE + host_winapi::INFINITE_TIMEOUT } else if ms > u32::MAX as i64 { return Err(vm.new_overflow_error("timeout value is too large")); } else { ms as u32 }; - let ret = unsafe { windows_sys::Win32::System::Threading::WaitForSingleObject(h.0, ms) }; - if ret == windows_sys::Win32::Foundation::WAIT_FAILED { - Err(vm.new_last_os_error()) - } else { - Ok(ret) - } + host_winapi::wait_for_single_object(h.0, ms).map_err(|e| e.to_pyexception(vm)) } #[pyfunction] @@ -564,13 +367,10 @@ mod _winapi { milliseconds: u32, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Foundation::WAIT_FAILED; - use windows_sys::Win32::System::Threading::WaitForMultipleObjects as WinWaitForMultipleObjects; - - let handles: Vec = handle_seq + let handles: Vec = handle_seq .into_vec() .into_iter() - .map(|h| h as HANDLE) + .map(|h| h as host_winapi::Handle) .collect(); if handles.is_empty() { @@ -581,40 +381,18 @@ mod _winapi { return Err(vm.new_value_error("WaitForMultipleObjects supports at most 64 handles")); } - let ret = unsafe { - WinWaitForMultipleObjects( - handles.len() as u32, - handles.as_ptr(), - if wait_all { 1 } else { 0 }, - milliseconds, - ) - }; - - if ret == WAIT_FAILED { - Err(vm.new_last_os_error()) - } else { - Ok(ret) - } + host_winapi::wait_for_multiple_objects(&handles, wait_all, milliseconds) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn GetExitCodeProcess(h: WinHandle, vm: &VirtualMachine) -> PyResult { - unsafe { 
- let mut ec = core::mem::MaybeUninit::uninit(); - WindowsSysResult(windows_sys::Win32::System::Threading::GetExitCodeProcess( - h.0, - ec.as_mut_ptr(), - )) - .to_pyresult(vm)?; - Ok(ec.assume_init()) - } + host_winapi::get_exit_code_process(h.0).map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn TerminateProcess(h: WinHandle, exit_code: u32) -> WindowsSysResult { - WindowsSysResult(unsafe { - windows_sys::Win32::System::Threading::TerminateProcess(h.0, exit_code) - }) + WindowsSysResult(host_winapi::terminate_process(h.0, exit_code)) } #[pyfunction] @@ -623,22 +401,10 @@ mod _winapi { name: OptionalArg>, vm: &VirtualMachine, ) -> PyResult { - let handle = unsafe { - match name.flatten() { - Some(name) => { - let name_wide = name.as_wtf8().to_wide_with_nul(); - windows_sys::Win32::System::JobObjects::CreateJobObjectW( - null(), - name_wide.as_ptr(), - ) - } - None => windows_sys::Win32::System::JobObjects::CreateJobObjectW(null(), null()), - } - }; - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + let name = name.flatten().map(|name| name.as_wtf8().to_wide_cstring()); + host_winapi::create_job_object_w(name.as_deref()) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] @@ -647,59 +413,25 @@ mod _winapi { process: WinHandle, vm: &VirtualMachine, ) -> PyResult<()> { - let ret = unsafe { - windows_sys::Win32::System::JobObjects::AssignProcessToJobObject(job.0, process.0) - }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) + host_winapi::assign_process_to_job_object(job.0, process.0) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn TerminateJobObject(job: WinHandle, exit_code: u32, vm: &VirtualMachine) -> PyResult<()> { - let ret = - unsafe { windows_sys::Win32::System::JobObjects::TerminateJobObject(job.0, exit_code) }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) + host_winapi::terminate_job_object(job.0, exit_code).map_err(|e| e.to_pyexception(vm)) } 
#[pyfunction] fn SetJobObjectKillOnClose(job: WinHandle, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::JobObjects::{ - JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation, - SetInformationJobObject, - }; - let mut info: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = unsafe { core::mem::zeroed() }; - info.BasicLimitInformation.LimitFlags = - windows_sys::Win32::System::JobObjects::JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; - let ret = unsafe { - SetInformationJobObject( - job.0, - JobObjectExtendedLimitInformation, - &info as *const _ as *const core::ffi::c_void, - core::mem::size_of::() as u32, - ) - }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) + host_winapi::set_job_object_kill_on_close(job.0).map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn GetModuleFileName(handle: isize, vm: &VirtualMachine) -> PyResult { - let mut path: Vec = vec![0; MAX_PATH as usize]; - - let length = unsafe { - windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW( - handle as windows_sys::Win32::Foundation::HMODULE, - path.as_mut_ptr(), - path.len() as u32, - ) - }; + let mut path: Vec = vec![0; host_winapi::MAX_PATH_USIZE]; + + let length = host_winapi::get_module_file_name(handle as _, &mut path); if length == 0 { return Err(vm.new_runtime_error("GetModuleFileName failed")); } @@ -715,23 +447,15 @@ mod _winapi { name: PyStrRef, vm: &VirtualMachine, ) -> PyResult { - let name_wide = name.as_wtf8().to_wide_with_nul(); - let handle = unsafe { - windows_sys::Win32::System::Threading::OpenMutexW( - desired_access, - i32::from(inherit_handle), - name_wide.as_ptr(), - ) - }; - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + let name_wide = name.as_wtf8().to_wide_cstring(); + host_winapi::open_mutex_w(desired_access, inherit_handle, &name_wide) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn ReleaseMutex(handle: WinHandle) -> WindowsSysResult { - 
WindowsSysResult(unsafe { windows_sys::Win32::System::Threading::ReleaseMutex(handle.0) }) + WindowsSysResult(host_winapi::release_mutex(handle.0)) } // LOCALE_NAME_INVARIANT is an empty string in Windows API @@ -755,64 +479,27 @@ mod _winapi { src: PyStrRef, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::Globalization::{ - LCMAP_BYTEREV, LCMAP_HASH, LCMAP_SORTHANDLE, LCMAP_SORTKEY, - LCMapStringEx as WinLCMapStringEx, - }; - // Reject unsupported flags - if flags & (LCMAP_SORTHANDLE | LCMAP_HASH | LCMAP_BYTEREV | LCMAP_SORTKEY) != 0 { + if flags + & (host_winapi::LCMAP_SORTHANDLE_FLAG + | host_winapi::LCMAP_HASH_FLAG + | host_winapi::LCMAP_BYTEREV_FLAG + | host_winapi::LCMAP_SORTKEY_FLAG) + != 0 + { return Err(vm.new_value_error("unsupported flags")); } // Use ToWideString which properly handles WTF-8 (including surrogates) - let locale_wide = locale.as_wtf8().to_wide_with_nul(); + let locale_wide = locale.as_wtf8().to_wide_cstring(); let src_wide = src.as_wtf8().to_wide(); if src_wide.len() > i32::MAX as usize { return Err(vm.new_overflow_error("input string is too long")); } - // First call to get required buffer size - let dest_size = unsafe { - WinLCMapStringEx( - locale_wide.as_ptr(), - flags, - src_wide.as_ptr(), - src_wide.len() as i32, - null_mut(), - 0, - null(), - null(), - 0, - ) - }; - - if dest_size <= 0 { - return Err(vm.new_last_os_error()); - } - - // Second call to perform the mapping - let mut dest = vec![0u16; dest_size as usize]; - let nmapped = unsafe { - WinLCMapStringEx( - locale_wide.as_ptr(), - flags, - src_wide.as_ptr(), - src_wide.len() as i32, - dest.as_mut_ptr(), - dest_size, - null(), - null(), - 0, - ) - }; - - if nmapped <= 0 { - return Err(vm.new_last_os_error()); - } - - dest.truncate(nmapped as usize); + let dest = host_winapi::lc_map_string_ex(&locale_wide, flags, &src_wide) + .map_err(|e| e.to_pyexception(vm))?; // Convert UTF-16 back to WTF-8 (handles surrogates properly) let result = Wtf8Buf::from_wide(&dest); 
@@ -842,28 +529,18 @@ mod _winapi { /// CreateNamedPipe - Create a named pipe #[pyfunction] fn CreateNamedPipe(args: CreateNamedPipeArgs, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::System::Pipes::CreateNamedPipeW; - - let name_wide = args.name.as_wtf8().to_wide_with_nul(); - - let handle = unsafe { - CreateNamedPipeW( - name_wide.as_ptr(), - args.open_mode, - args.pipe_mode, - args.max_instances, - args.out_buffer_size, - args.in_buffer_size, - args.default_timeout, - null(), // security_attributes - NULL for now - ) - }; - - if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); - } - - Ok(WinHandle(handle)) + let name_wide = args.name.as_wtf8().to_wide_cstring(); + host_winapi::create_named_pipe_w( + &name_wide, + args.open_mode, + args.pipe_mode, + args.max_instances, + args.out_buffer_size, + args.in_buffer_size, + args.default_timeout, + ) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } // ==================== Overlapped class ==================== @@ -873,144 +550,51 @@ mod _winapi { #[pyclass(name = "Overlapped", module = "_winapi")] #[derive(Debug, PyPayload)] struct Overlapped { - inner: PyMutex, - } - - struct OverlappedInner { - // Box ensures the OVERLAPPED struct stays at a stable heap address - // even when the containing Overlapped Python object is moved during - // into_pyobject(). The OS holds a pointer to this struct for pending - // I/O operations, so it must not be relocated. 
- overlapped: Box, - handle: HANDLE, - pending: bool, - completed: bool, - read_buffer: Option>, - write_buffer: Option>, + inner: PyMutex, } - impl core::fmt::Debug for OverlappedInner { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("OverlappedInner") - .field("handle", &self.handle) - .field("pending", &self.pending) - .field("completed", &self.completed) - .finish() - } - } - - unsafe impl Sync for OverlappedInner {} - unsafe impl Send for OverlappedInner {} - #[pyclass(with(Constructor))] impl Overlapped { - fn new_with_handle(handle: HANDLE) -> Self { - use windows_sys::Win32::System::Threading::CreateEventW; - - let event = unsafe { CreateEventW(null(), 1, 0, null()) }; - let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = - unsafe { core::mem::zeroed() }; - overlapped.hEvent = event; - - Self { - inner: PyMutex::new(OverlappedInner { - overlapped: Box::new(overlapped), - handle, - pending: false, - completed: false, - read_buffer: None, - write_buffer: None, - }), - } + fn new_with_handle(handle: host_winapi::Handle, vm: &VirtualMachine) -> PyResult { + host_overlapped::Operation::new(handle) + .map(|inner| Self { + inner: PyMutex::new(inner), + }) + .map_err(|e| e.to_pyexception(vm)) } #[pymethod] fn GetOverlappedResult(&self, wait: bool, vm: &VirtualMachine) -> PyResult<(u32, u32)> { - use windows_sys::Win32::Foundation::{ - ERROR_IO_INCOMPLETE, ERROR_MORE_DATA, ERROR_OPERATION_ABORTED, ERROR_SUCCESS, - GetLastError, - }; - use windows_sys::Win32::System::IO::GetOverlappedResult; - let mut inner = self.inner.lock(); - - let mut transferred: u32 = 0; - - let ret = unsafe { - GetOverlappedResult( - inner.handle, - &*inner.overlapped, - &mut transferred, - if wait { 1 } else { 0 }, - ) - }; - - let err = if ret == 0 { - unsafe { GetLastError() } - } else { - ERROR_SUCCESS - }; - - match err { - ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_OPERATION_ABORTED => { - inner.completed = true; - inner.pending = 
false; - } - ERROR_IO_INCOMPLETE => {} - _ => { - inner.pending = false; - return Err(std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm)); - } - } - - if inner.completed - && let Some(read_buffer) = &mut inner.read_buffer - && transferred != read_buffer.len() as u32 - { - read_buffer.truncate(transferred as usize); - } - - Ok((transferred, err)) + inner + .get_result(wait) + .map(|result| (result.transferred, result.error)) + .map_err(|e| e.to_pyexception(vm)) } #[pymethod] fn getbuffer(&self, vm: &VirtualMachine) -> PyResult> { let inner = self.inner.lock(); - if !inner.completed { + if !inner.is_completed() { return Err(vm.new_value_error( "can't get read buffer before GetOverlappedResult() signals the operation completed", )); } Ok(inner - .read_buffer - .as_ref() - .map(|buf| vm.ctx.new_bytes(buf.clone()).into())) + .read_buffer() + .map(|buf| vm.ctx.new_bytes(buf.to_vec()).into())) } #[pymethod] fn cancel(&self, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::IO::CancelIoEx; - let mut inner = self.inner.lock(); - let ret = if inner.pending { - unsafe { CancelIoEx(inner.handle, &*inner.overlapped) } - } else { - 1 - }; - if ret == 0 { - let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; - if err != windows_sys::Win32::Foundation::ERROR_NOT_FOUND { - return Err(std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm)); - } - } - inner.pending = false; - Ok(()) + inner.cancel().map_err(|e| e.to_pyexception(vm)) } #[pygetset] fn event(&self) -> isize { let inner = self.inner.lock(); - inner.overlapped.hEvent as isize + inner.event() as isize } } @@ -1020,18 +604,9 @@ mod _winapi { fn py_new( _cls: &Py, _args: Self::Args, - _vm: &VirtualMachine, + vm: &VirtualMachine, ) -> PyResult { - Ok(Self::new_with_handle(null_mut())) - } - } - - impl Drop for OverlappedInner { - fn drop(&mut self) { - use windows_sys::Win32::Foundation::CloseHandle; - if !self.overlapped.hEvent.is_null() { - unsafe { 
CloseHandle(self.overlapped.hEvent) }; - } + Self::new_with_handle(null_mut(), vm) } } @@ -1046,122 +621,54 @@ mod _winapi { #[pyfunction] fn ConnectNamedPipe(args: ConnectNamedPipeArgs, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::{ - ERROR_IO_PENDING, ERROR_PIPE_CONNECTED, GetLastError, - }; - let handle = args.handle; let use_overlapped = args.overlapped.unwrap_or(false); if use_overlapped { - // Overlapped (async) mode - let ov = Overlapped::new_with_handle(handle.0); - - let _ret = { + let ov = Overlapped::new_with_handle(handle.0, vm)?; + { let mut inner = ov.inner.lock(); - unsafe { - windows_sys::Win32::System::Pipes::ConnectNamedPipe( - handle.0, - &mut *inner.overlapped, - ) - } - }; - - let err = unsafe { GetLastError() }; - match err { - ERROR_IO_PENDING => { - let mut inner = ov.inner.lock(); - inner.pending = true; - } - ERROR_PIPE_CONNECTED => { - let inner = ov.inner.lock(); - unsafe { - windows_sys::Win32::System::Threading::SetEvent(inner.overlapped.hEvent); - } - } - _ => { - return Err(std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm)); - } + inner + .connect_named_pipe() + .map_err(|e| e.to_pyexception(vm))?; } - Ok(ov.into_pyobject(vm)) } else { - // Synchronous mode - let ret = unsafe { - windows_sys::Win32::System::Pipes::ConnectNamedPipe(handle.0, null_mut()) - }; - - if ret == 0 { - let err = unsafe { GetLastError() }; - if err != ERROR_PIPE_CONNECTED { - return Err(std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm)); - } - } - + host_winapi::connect_named_pipe(handle.0).map_err(|e| e.to_pyexception(vm))?; Ok(vm.ctx.none()) } } /// Helper for GetShortPathName and GetLongPathName - fn get_path_name_impl( - path: &PyStrRef, - api_fn: unsafe extern "system" fn(*const u16, *mut u16, u32) -> u32, - vm: &VirtualMachine, - ) -> PyResult { - let path_wide = path.as_wtf8().to_wide_with_nul(); - - // First call to get required buffer size - let size = unsafe { api_fn(path_wide.as_ptr(), 
null_mut(), 0) }; - - if size == 0 { - return Err(vm.new_last_os_error()); - } - - // Second call to get the actual path - let mut buffer: Vec = vec![0; size as usize]; - let result = - unsafe { api_fn(path_wide.as_ptr(), buffer.as_mut_ptr(), buffer.len() as u32) }; - - if result == 0 { - return Err(vm.new_last_os_error()); - } - - // Truncate to actual length (excluding null terminator) - buffer.truncate(result as usize); - + fn path_name_result_to_pystr(wide: Vec, vm: &VirtualMachine) -> PyStrRef { // Convert UTF-16 back to WTF-8 (handles surrogates properly) - let result_str = Wtf8Buf::from_wide(&buffer); - Ok(vm.ctx.new_str(result_str)) + let result_str = Wtf8Buf::from_wide(&wide); + vm.ctx.new_str(result_str) } /// GetShortPathName - Return the short version of the provided path. #[pyfunction] fn GetShortPathName(path: PyStrRef, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Storage::FileSystem::GetShortPathNameW; - get_path_name_impl(&path, GetShortPathNameW, vm) + let path_wide = path.as_wtf8().to_wide_cstring(); + let wide = + host_winapi::get_short_path_name_w(&path_wide).map_err(|e| e.to_pyexception(vm))?; + Ok(path_name_result_to_pystr(wide, vm)) } /// GetLongPathName - Return the long version of the provided path. #[pyfunction] fn GetLongPathName(path: PyStrRef, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Storage::FileSystem::GetLongPathNameW; - get_path_name_impl(&path, GetLongPathNameW, vm) + let path_wide = path.as_wtf8().to_wide_cstring(); + let wide = + host_winapi::get_long_path_name_w(&path_wide).map_err(|e| e.to_pyexception(vm))?; + Ok(path_name_result_to_pystr(wide, vm)) } /// WaitNamedPipe - Wait for an instance of a named pipe to become available. 
#[pyfunction] fn WaitNamedPipe(name: PyStrRef, timeout: u32, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Pipes::WaitNamedPipeW; - - let name_wide = name.as_wtf8().to_wide_with_nul(); - - let success = unsafe { WaitNamedPipeW(name_wide.as_ptr(), timeout) }; - - if success == 0 { - return Err(vm.new_last_os_error()); - } - - Ok(()) + let name_wide = name.as_wtf8().to_wide_cstring(); + host_winapi::wait_named_pipe_w(&name_wide, timeout).map_err(|e| e.to_pyexception(vm)) } /// PeekNamedPipe - Peek at data in a named pipe without removing it. @@ -1171,60 +678,33 @@ mod _winapi { size: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Pipes::PeekNamedPipe as WinPeekNamedPipe; - let size = size.unwrap_or(0); if size < 0 { return Err(vm.new_value_error("negative size")); } - let mut navail: u32 = 0; - let mut nleft: u32 = 0; - if size > 0 { - let mut buf = vec![0u8; size as usize]; - let mut nread: u32 = 0; - - let ret = unsafe { - WinPeekNamedPipe( - handle.0, - buf.as_mut_ptr() as *mut _, - size as u32, - &mut nread, - &mut navail, - &mut nleft, - ) - }; - - if ret == 0 { - return Err(vm.new_last_os_error()); - } - - buf.truncate(nread as usize); + let result = host_winapi::peek_named_pipe(handle.0, Some(size as u32)) + .map_err(|e| e.to_pyexception(vm))?; + let buf = result.data.unwrap_or_default(); let bytes: PyObjectRef = vm.ctx.new_bytes(buf).into(); Ok(vm .ctx .new_tuple(vec![ bytes, - vm.ctx.new_int(navail).into(), - vm.ctx.new_int(nleft).into(), + vm.ctx.new_int(result.available).into(), + vm.ctx.new_int(result.left_this_message).into(), ]) .into()) } else { - let ret = unsafe { - WinPeekNamedPipe(handle.0, null_mut(), 0, null_mut(), &mut navail, &mut nleft) - }; - - if ret == 0 { - return Err(vm.new_last_os_error()); - } - + let result = + host_winapi::peek_named_pipe(handle.0, None).map_err(|e| e.to_pyexception(vm))?; Ok(vm .ctx .new_tuple(vec![ - vm.ctx.new_int(navail).into(), - 
vm.ctx.new_int(nleft).into(), + vm.ctx.new_int(result.available).into(), + vm.ctx.new_int(result.left_this_message).into(), ]) .into()) } @@ -1239,41 +719,18 @@ mod _winapi { name: Option, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Threading::CreateEventW as WinCreateEventW; - let _ = security_attributes; // Ignored, always NULL - let name_wide = name.map(|n| n.as_wtf8().to_wide_with_nul()); - let name_ptr = name_wide.as_ref().map_or(null(), |n| n.as_ptr()); - - let handle = unsafe { - WinCreateEventW( - null(), - i32::from(manual_reset), - i32::from(initial_state), - name_ptr, - ) - }; - - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - - Ok(WinHandle(handle)) + let name_wide = name.map(|n| n.as_wtf8().to_wide_cstring()); + host_winapi::create_event_w(manual_reset, initial_state, name_wide.as_deref()) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } /// SetEvent - Set the specified event object to the signaled state. #[pyfunction] fn SetEvent(event: WinHandle, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Threading::SetEvent as WinSetEvent; - - let ret = unsafe { WinSetEvent(event.0) }; - - if ret == 0 { - return Err(vm.new_last_os_error()); - } - - Ok(()) + host_winapi::set_event(event.0).map_err(|e| e.to_pyexception(vm)) } #[derive(FromArgs)] @@ -1289,96 +746,29 @@ mod _winapi { /// WriteFile - Write data to a file or I/O device. 
#[pyfunction] fn WriteFile(args: WriteFileArgs, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Storage::FileSystem::WriteFile as WinWriteFile; - let handle = args.handle; let use_overlapped = args.overlapped; let buf = args.buffer.borrow_buf(); - let len = core::cmp::min(buf.len(), u32::MAX as usize) as u32; if use_overlapped { - use windows_sys::Win32::Foundation::ERROR_IO_PENDING; - - let ov = Overlapped::new_with_handle(handle.0); + let ov = Overlapped::new_with_handle(handle.0, vm)?; let err = { let mut inner = ov.inner.lock(); - inner.write_buffer = Some(buf.to_vec()); - let write_buf = inner.write_buffer.as_ref().unwrap(); - let mut written: u32 = 0; - let ret = unsafe { - WinWriteFile( - handle.0, - write_buf.as_ptr() as *const _, - len, - &mut written, - &mut *inner.overlapped, - ) - }; - - let err = if ret == 0 { - unsafe { windows_sys::Win32::Foundation::GetLastError() } - } else { - 0 - }; - - if ret == 0 && err != ERROR_IO_PENDING { - return Err(vm.new_last_os_error()); - } - if ret == 0 && err == ERROR_IO_PENDING { - inner.pending = true; - } - - err + inner.write(&buf).map_err(|e| e.to_pyexception(vm))? }; - // Without GIL, the Python-level PipeConnection._send_bytes has a - // race on _send_ov when the caller (SimpleQueue) skips locking on - // Windows. Wait for completion here so the caller never sees - // ERROR_IO_PENDING and never blocks in WaitForMultipleObjects, - // keeping the _send_ov window negligibly small. 
- if err == ERROR_IO_PENDING { - let event = ov.inner.lock().overlapped.hEvent; - vm.allow_threads(|| unsafe { - windows_sys::Win32::System::Threading::WaitForSingleObject( - event, - windows_sys::Win32::System::Threading::INFINITE, - ); - }); - let result = vm - .ctx - .new_tuple(vec![ov.into_pyobject(vm), vm.ctx.new_int(0u32).into()]); - return Ok(result.into()); - } - let result = vm .ctx .new_tuple(vec![ov.into_pyobject(vm), vm.ctx.new_int(err).into()]); return Ok(result.into()); } - let mut written: u32 = 0; - let ret = unsafe { - WinWriteFile( - handle.0, - buf.as_ptr() as *const _, - len, - &mut written, - null_mut(), - ) - }; - let err = if ret == 0 { - unsafe { windows_sys::Win32::Foundation::GetLastError() } - } else { - 0 - }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } + let result = host_winapi::write_file(handle.0, &buf).map_err(|e| e.to_pyexception(vm))?; Ok(vm .ctx .new_tuple(vec![ - vm.ctx.new_int(written).into(), - vm.ctx.new_int(err).into(), + vm.ctx.new_int(result.written).into(), + vm.ctx.new_int(result.error).into(), ]) .into()) } @@ -1396,45 +786,15 @@ mod _winapi { /// ReadFile - Read data from a file or I/O device. 
#[pyfunction] fn ReadFile(args: ReadFileArgs, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Storage::FileSystem::ReadFile as WinReadFile; - let handle = args.handle; let size = args.size; let use_overlapped = args.overlapped; if use_overlapped { - use windows_sys::Win32::Foundation::ERROR_IO_PENDING; - - let ov = Overlapped::new_with_handle(handle.0); + let ov = Overlapped::new_with_handle(handle.0, vm)?; let err = { let mut inner = ov.inner.lock(); - inner.read_buffer = Some(vec![0u8; size as usize]); - let read_buf = inner.read_buffer.as_mut().unwrap(); - let mut nread: u32 = 0; - let ret = unsafe { - WinReadFile( - handle.0, - read_buf.as_mut_ptr() as *mut _, - size, - &mut nread, - &mut *inner.overlapped, - ) - }; - - let err = if ret == 0 { - unsafe { windows_sys::Win32::Foundation::GetLastError() } - } else { - 0 - }; - - if ret == 0 && err != ERROR_IO_PENDING && err != ERROR_MORE_DATA { - return Err(vm.new_last_os_error()); - } - if ret == 0 && err == ERROR_IO_PENDING { - inner.pending = true; - } - - err + inner.read(size).map_err(|e| e.to_pyexception(vm))? 
}; let result = vm .ctx @@ -1442,31 +802,12 @@ mod _winapi { return Ok(result.into()); } - let mut buf = vec![0u8; size as usize]; - let mut nread: u32 = 0; - let ret = unsafe { - WinReadFile( - handle.0, - buf.as_mut_ptr() as *mut _, - size, - &mut nread, - null_mut(), - ) - }; - let err = if ret == 0 { - unsafe { windows_sys::Win32::Foundation::GetLastError() } - } else { - 0 - }; - if ret == 0 && err != ERROR_MORE_DATA { - return Err(vm.new_last_os_error()); - } - buf.truncate(nread as usize); + let result = host_winapi::read_file(handle.0, size).map_err(|e| e.to_pyexception(vm))?; Ok(vm .ctx .new_tuple(vec![ - vm.ctx.new_bytes(buf).into(), - vm.ctx.new_int(err).into(), + vm.ctx.new_bytes(result.data).into(), + vm.ctx.new_int(result.error).into(), ]) .into()) } @@ -1480,38 +821,21 @@ mod _winapi { collect_data_timeout: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - use windows_sys::Win32::System::Pipes::SetNamedPipeHandleState as WinSetNamedPipeHandleState; - - let mut dw_args: [u32; 3] = [0; 3]; - let mut p_args: [*mut u32; 3] = [null_mut(); 3]; - let objs = [&mode, &max_collection_count, &collect_data_timeout]; - for (i, obj) in objs.iter().enumerate() { + let mut values = [None; 3]; + for (index, obj) in objs.iter().enumerate() { if !vm.is_none(obj) { - dw_args[i] = u32::try_from_object(vm, (*obj).clone())?; - p_args[i] = &mut dw_args[i]; + values[index] = Some(u32::try_from_object(vm, (*obj).clone())?); } } - - let ret = - unsafe { WinSetNamedPipeHandleState(named_pipe.0, p_args[0], p_args[1], p_args[2]) }; - - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) + host_winapi::set_named_pipe_handle_state(named_pipe.0, values[0], values[1], values[2]) + .map_err(|e| e.to_pyexception(vm)) } /// ResetEvent - Reset the specified event object to the nonsignaled state. 
#[pyfunction] fn ResetEvent(event: WinHandle, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Threading::ResetEvent as WinResetEvent; - - let ret = unsafe { WinResetEvent(event.0) }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) + host_winapi::reset_event(event.0).map_err(|e| e.to_pyexception(vm)) } /// CreateMutexW - Create or open a named or unnamed mutex object. @@ -1522,18 +846,11 @@ mod _winapi { name: Option, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Threading::CreateMutexW as WinCreateMutexW; - let _ = security_attributes; - let name_wide = name.map(|n| n.as_wtf8().to_wide_with_nul()); - let name_ptr = name_wide.as_ref().map_or(null(), |n| n.as_ptr()); - - let handle = unsafe { WinCreateMutexW(null(), i32::from(initial_owner), name_ptr) }; - - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + let name_wide = name.map(|n| n.as_wtf8().to_wide_cstring()); + host_winapi::create_mutex_w(initial_owner, name_wide.as_deref()) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } /// OpenEventW - Open an existing named event object. 
@@ -1544,21 +861,10 @@ mod _winapi { name: PyStrRef, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Threading::OpenEventW as WinOpenEventW; - - let name_wide = name.as_wtf8().to_wide_with_nul(); - let handle = unsafe { - WinOpenEventW( - desired_access, - i32::from(inherit_handle), - name_wide.as_ptr(), - ) - }; - - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + let name_wide = name.as_wtf8().to_wide_cstring(); + host_winapi::open_event_w(desired_access, inherit_handle, &name_wide) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } const MAXIMUM_WAIT_OBJECTS: usize = 64; @@ -1571,21 +877,15 @@ mod _winapi { milliseconds: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use alloc::sync::Arc; - use core::sync::atomic::{AtomicU32, Ordering}; - use windows_sys::Win32::Foundation::{CloseHandle, WAIT_FAILED, WAIT_OBJECT_0}; - use windows_sys::Win32::System::SystemInformation::GetTickCount64; - use windows_sys::Win32::System::Threading::{ - CreateEventW as WinCreateEventW, CreateThread, GetExitCodeThread, - INFINITE as WIN_INFINITE, ResumeThread, SetEvent as WinSetEvent, TerminateThread, - WaitForMultipleObjects, - }; - - let milliseconds = milliseconds.unwrap_or(WIN_INFINITE); + let milliseconds = milliseconds.unwrap_or(host_winapi::INFINITE_TIMEOUT); // Get handles from sequence let seq = ArgSequence::::try_from_object(vm, handle_seq)?; - let handles: Vec = seq.into_vec(); + let handles: Vec = seq + .into_vec() + .into_iter() + .map(|handle| handle as _) + .collect(); let nhandles = handles.len(); if nhandles == 0 { @@ -1603,299 +903,46 @@ mod _winapi { ))); } - // Create batches of handles - let batch_size = MAXIMUM_WAIT_OBJECTS - 1; // Leave room for cancel_event - let mut batches: Vec> = Vec::new(); - let mut i = 0; - while i < nhandles { - let end = core::cmp::min(i + batch_size, nhandles); - batches.push(handles[i..end].to_vec()); - i = end; - } - #[cfg(feature = "threading")] let 
sigint_event = { let is_main = crate::stdlib::_thread::get_ident() == vm.state.main_thread_ident.load(); if is_main { - let handle = crate::signal::get_sigint_event().unwrap_or_else(|| { - let handle = unsafe { WinCreateEventW(null(), 1, 0, null()) }; - if !handle.is_null() { - crate::signal::set_sigint_event(handle as isize); - } - handle as isize - }); - if handle == 0 { None } else { Some(handle) } + let handle = crate::signal::get_sigint_event().map_or_else( + || { + let handle = host_winapi::create_event_w(true, false, None) + .unwrap_or(core::ptr::null_mut()); + if !handle.is_null() { + crate::signal::set_sigint_event(handle as isize); + } + handle + }, + |handle| handle as host_winapi::Handle, + ); + if handle.is_null() { None } else { Some(handle) } } else { None } }; #[cfg(not(feature = "threading"))] - let sigint_event: Option = None; - - if wait_all { - // For wait_all, we wait sequentially for each batch - let mut err: Option = None; - let deadline = if milliseconds != WIN_INFINITE { - Some(unsafe { GetTickCount64() } + milliseconds as u64) - } else { - None - }; - - for batch in &batches { - let timeout = if let Some(deadline) = deadline { - let now = unsafe { GetTickCount64() }; - if now >= deadline { - err = Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT); - break; - } - (deadline - now) as u32 - } else { - WIN_INFINITE - }; - - let batch_handles: Vec<_> = batch.iter().map(|&h| h as _).collect(); - let result = unsafe { - WaitForMultipleObjects( - batch_handles.len() as u32, - batch_handles.as_ptr(), - 1, // wait_all = TRUE - timeout, - ) - }; - - if result == WAIT_FAILED { - err = Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }); - break; - } - if result == windows_sys::Win32::Foundation::WAIT_TIMEOUT { - err = Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT); - break; - } - - if let Some(sigint_event) = sigint_event { - let sig_result = unsafe { - windows_sys::Win32::System::Threading::WaitForSingleObject( - sigint_event as _, - 
0, - ) - }; - if sig_result == WAIT_OBJECT_0 { - err = Some(windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT); - break; - } - if sig_result == WAIT_FAILED { - err = Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }); - break; - } - } - } - - if let Some(err) = err { - if err == windows_sys::Win32::Foundation::WAIT_TIMEOUT { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.timeout_error.to_owned(), - None, - "timed out", - ) - .upcast()); - } - if err == windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT { - return Err(vm - .new_errno_error(libc::EINTR, "Interrupted system call") - .upcast()); - } - return Err(vm.new_os_error(err as i32)); - } - - Ok(vm.ctx.none()) - } else { - // For wait_any, we use threads to wait on each batch in parallel - let cancel_event = unsafe { WinCreateEventW(null(), 1, 0, null()) }; // Manual reset, not signaled - if cancel_event.is_null() { - return Err(vm.new_last_os_error()); - } - - struct BatchData { - handles: Vec, - cancel_event: isize, - handle_base: usize, - result: AtomicU32, - thread: core::cell::UnsafeCell, - } - - unsafe impl Send for BatchData {} - unsafe impl Sync for BatchData {} - - let batch_data: Vec> = batches - .iter() - .enumerate() - .map(|(idx, batch)| { - let base = idx * batch_size; - let mut handles_with_cancel = batch.clone(); - handles_with_cancel.push(cancel_event as isize); - Arc::new(BatchData { - handles: handles_with_cancel, - cancel_event: cancel_event as isize, - handle_base: base, - result: AtomicU32::new(WAIT_FAILED), - thread: core::cell::UnsafeCell::new(0), - }) - }) - .collect(); - - // Thread function - extern "system" fn batch_wait_thread(param: *mut core::ffi::c_void) -> u32 { - let data = unsafe { &*(param as *const BatchData) }; - let handles: Vec<_> = data.handles.iter().map(|&h| h as _).collect(); - let result = unsafe { - WaitForMultipleObjects( - handles.len() as u32, - handles.as_ptr(), - 0, // wait_any - WIN_INFINITE, - ) - }; - data.result.store(result, 
Ordering::SeqCst); - - if result == WAIT_FAILED { - let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; - unsafe { WinSetEvent(data.cancel_event as _) }; - return err; - } else if result >= windows_sys::Win32::Foundation::WAIT_ABANDONED_0 - && result - < windows_sys::Win32::Foundation::WAIT_ABANDONED_0 - + MAXIMUM_WAIT_OBJECTS as u32 - { - data.result.store(WAIT_FAILED, Ordering::SeqCst); - unsafe { WinSetEvent(data.cancel_event as _) }; - return windows_sys::Win32::Foundation::ERROR_ABANDONED_WAIT_0; - } - 0 - } - - // Create threads - let mut thread_handles: Vec = Vec::new(); - for data in &batch_data { - let thread = unsafe { - CreateThread( - null(), - 1, // Smallest stack - Some(batch_wait_thread), - Arc::as_ptr(data) as *const _ as *mut _, - 4, // CREATE_SUSPENDED - null_mut(), - ) - }; - if thread.is_null() { - // Cleanup on error - for h in &thread_handles { - unsafe { TerminateThread(*h as _, 0) }; - unsafe { CloseHandle(*h as _) }; - } - unsafe { CloseHandle(cancel_event) }; - return Err(vm.new_last_os_error()); - } - unsafe { *data.thread.get() = thread as isize }; - thread_handles.push(thread as isize); - } - - // Resume all threads - for &thread in &thread_handles { - unsafe { ResumeThread(thread as _) }; - } - - // Wait for any thread to complete - let mut thread_handles_raw: Vec<_> = thread_handles.iter().map(|&h| h as _).collect(); - if let Some(sigint_event) = sigint_event { - thread_handles_raw.push(sigint_event as _); - } - let result = unsafe { - WaitForMultipleObjects( - thread_handles_raw.len() as u32, - thread_handles_raw.as_ptr(), - 0, // wait_any - milliseconds, - ) - }; - - let err = if result == WAIT_FAILED { - Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }) - } else if result == windows_sys::Win32::Foundation::WAIT_TIMEOUT { - Some(windows_sys::Win32::Foundation::WAIT_TIMEOUT) - } else if sigint_event.is_some() - && result == WAIT_OBJECT_0 + thread_handles_raw.len() as u32 - { - 
Some(windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT) - } else { - None - }; - - // Signal cancel event to stop other threads - unsafe { WinSetEvent(cancel_event) }; - - // Wait for all threads to finish - let thread_handles_only: Vec<_> = thread_handles.iter().map(|&h| h as _).collect(); - unsafe { - WaitForMultipleObjects( - thread_handles_only.len() as u32, - thread_handles_only.as_ptr(), - 1, // wait_all - WIN_INFINITE, + let sigint_event: Option = None; + + match host_winapi::batched_wait_for_multiple_objects( + &handles, + wait_all, + milliseconds, + sigint_event, + ) { + Ok(host_winapi::BatchedWaitResult::All) => Ok(vm.ctx.none()), + Ok(host_winapi::BatchedWaitResult::Indices(indices)) => Ok(vm + .ctx + .new_list( + indices + .into_iter() + .map(|index| vm.ctx.new_int(index).into()) + .collect(), ) - }; - - // Check for errors from threads - let mut thread_err = err; - for data in &batch_data { - if thread_err.is_none() && data.result.load(Ordering::SeqCst) == WAIT_FAILED { - let mut exit_code: u32 = 0; - let thread = unsafe { *data.thread.get() }; - if unsafe { GetExitCodeThread(thread as _, &mut exit_code) } == 0 { - thread_err = - Some(unsafe { windows_sys::Win32::Foundation::GetLastError() }); - } else if exit_code != 0 { - thread_err = Some(exit_code); - } - } - let thread = unsafe { *data.thread.get() }; - unsafe { CloseHandle(thread as _) }; - } - - unsafe { CloseHandle(cancel_event) }; - - // Return result - if let Some(e) = thread_err { - if e == windows_sys::Win32::Foundation::WAIT_TIMEOUT { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.timeout_error.to_owned(), - None, - "timed out", - ) - .upcast()); - } - if e == windows_sys::Win32::Foundation::ERROR_CONTROL_C_EXIT { - return Err(vm - .new_errno_error(libc::EINTR, "Interrupted system call") - .upcast()); - } - return Err(vm.new_os_error(e as i32)); - } - - // Collect triggered indices - let mut triggered_indices: Vec = Vec::new(); - for data in &batch_data { - let result = 
data.result.load(Ordering::SeqCst); - let triggered = result as i32 - WAIT_OBJECT_0 as i32; - // Check if it's a valid handle index (not the cancel_event which is last) - if triggered >= 0 && (triggered as usize) < data.handles.len() - 1 { - let index = data.handle_base + triggered as usize; - triggered_indices.push(vm.ctx.new_int(index).into()); - } - } - - Ok(vm.ctx.new_list(triggered_indices).into()) + .into()), + Err(err) => Err(err.to_pyexception(vm)), } } @@ -1910,8 +957,6 @@ mod _winapi { name: Option, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Memory::CreateFileMappingW; - if let Some(ref n) = name && n.as_bytes().contains(&0) { @@ -1919,24 +964,16 @@ mod _winapi { vm.new_value_error("CreateFileMapping: name must not contain null characters") ); } - let name_wide = name.as_ref().map(|n| n.as_wtf8().to_wide_with_nul()); - let name_ptr = name_wide.as_ref().map_or(null(), |n| n.as_ptr()); - - let handle = unsafe { - CreateFileMappingW( - file_handle.0, - null(), - protect, - max_size_high, - max_size_low, - name_ptr, - ) - }; - - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + let name_wide = name.as_ref().map(|n| n.as_wtf8().to_wide_cstring()); + host_winapi::create_file_mapping_w( + file_handle.0, + protect, + max_size_high, + max_size_low, + name_wide.as_deref(), + ) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } /// OpenFileMapping - Open a named file mapping object. 
@@ -1947,26 +984,15 @@ mod _winapi { name: PyStrRef, vm: &VirtualMachine, ) -> PyResult { - use windows_sys::Win32::System::Memory::OpenFileMappingW; - if name.as_bytes().contains(&0) { return Err( vm.new_value_error("OpenFileMapping: name must not contain null characters") ); } - let name_wide = name.as_wtf8().to_wide_with_nul(); - let handle = unsafe { - OpenFileMappingW( - desired_access, - i32::from(inherit_handle), - name_wide.as_ptr(), - ) - }; - - if handle.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(WinHandle(handle)) + let name_wide = name.as_wtf8().to_wide_cstring(); + host_winapi::open_file_mapping_w(desired_access, inherit_handle, &name_wide) + .map(WinHandle) + .map_err(|e| e.to_pyexception(vm)) } /// MapViewOfFile - Map a view of a file mapping into the address space. @@ -1979,57 +1005,26 @@ mod _winapi { number_bytes: usize, vm: &VirtualMachine, ) -> PyResult { - let address = unsafe { - windows_sys::Win32::System::Memory::MapViewOfFile( - file_map.0, - desired_access, - file_offset_high, - file_offset_low, - number_bytes, - ) - }; - - let ptr = address.Value; - if ptr.is_null() { - return Err(vm.new_last_os_error()); - } - Ok(ptr as isize) + host_winapi::map_view_of_file( + file_map.0, + desired_access, + file_offset_high, + file_offset_low, + number_bytes, + ) + .map_err(|e| e.to_pyexception(vm)) } /// UnmapViewOfFile - Unmap a mapped view of a file. #[pyfunction] fn UnmapViewOfFile(address: isize, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::System::Memory::MEMORY_MAPPED_VIEW_ADDRESS; - - let view = MEMORY_MAPPED_VIEW_ADDRESS { - Value: address as *mut core::ffi::c_void, - }; - let ret = unsafe { windows_sys::Win32::System::Memory::UnmapViewOfFile(view) }; - - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) + host_winapi::unmap_view_of_file(address).map_err(|e| e.to_pyexception(vm)) } /// VirtualQuerySize - Return the size of a memory region. 
#[pyfunction] fn VirtualQuerySize(address: isize, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::System::Memory::{MEMORY_BASIC_INFORMATION, VirtualQuery}; - - let mut mbi: MEMORY_BASIC_INFORMATION = unsafe { core::mem::zeroed() }; - let ret = unsafe { - VirtualQuery( - address as *const core::ffi::c_void, - &mut mbi, - core::mem::size_of::(), - ) - }; - - if ret == 0 { - return Err(vm.new_last_os_error()); - } - Ok(mbi.RegionSize) + host_winapi::virtual_query_size(address).map_err(|e| e.to_pyexception(vm)) } /// CopyFile2 - Copy a file with extended parameters. @@ -2041,29 +1036,9 @@ mod _winapi { _progress_routine: OptionalArg, vm: &VirtualMachine, ) -> PyResult<()> { - use windows_sys::Win32::Storage::FileSystem::{ - COPYFILE2_EXTENDED_PARAMETERS, CopyFile2 as WinCopyFile2, - }; - - let src_wide = existing_file_name.as_wtf8().to_wide_with_nul(); - let dst_wide = new_file_name.as_wtf8().to_wide_with_nul(); - - let mut params: COPYFILE2_EXTENDED_PARAMETERS = unsafe { core::mem::zeroed() }; - params.dwSize = core::mem::size_of::() as u32; - params.dwCopyFlags = flags; - - let hr = unsafe { WinCopyFile2(src_wide.as_ptr(), dst_wide.as_ptr(), ¶ms) }; - - if hr < 0 { - // HRESULT failure - convert to Windows error code - let err = if (hr as u32 >> 16) == 0x8007 { - (hr as u32) & 0xFFFF - } else { - hr as u32 - }; - return Err(std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm)); - } - Ok(()) + let src_wide = existing_file_name.as_wtf8().to_wide_cstring(); + let dst_wide = new_file_name.as_wtf8().to_wide_cstring(); + host_winapi::copy_file2(&src_wide, &dst_wide, flags).map_err(|e| e.to_pyexception(vm)) } /// _mimetypes_read_windows_registry - Read MIME type associations from registry. 
@@ -2072,110 +1047,15 @@ mod _winapi { on_type_read: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - use windows_sys::Win32::System::Registry::{ - HKEY, HKEY_CLASSES_ROOT, KEY_READ, REG_SZ, RegCloseKey, RegEnumKeyExW, RegOpenKeyExW, - RegQueryValueExW, - }; - - let mut hkcr: HKEY = null_mut() as HKEY; - let err = unsafe { RegOpenKeyExW(HKEY_CLASSES_ROOT, null(), 0, KEY_READ, &mut hkcr) }; - if err != 0 { - return Err(vm.new_os_error(err as i32)); - } - scopeguard::defer! { unsafe { RegCloseKey(hkcr) }; } - - let mut i: u32 = 0; - let mut entries: Vec<(String, String)> = Vec::new(); - - loop { - let mut ext_buf = [0u16; 128]; - let mut cch_ext: u32 = ext_buf.len() as u32; - - let err = unsafe { - RegEnumKeyExW( - hkcr, - i, - ext_buf.as_mut_ptr(), - &mut cch_ext, - null_mut(), - null_mut(), - null_mut(), - null_mut(), - ) - }; - i += 1; - - if err == windows_sys::Win32::Foundation::ERROR_NO_MORE_ITEMS { - break; - } - if err != 0 && err != windows_sys::Win32::Foundation::ERROR_MORE_DATA { - return Err(vm.new_os_error(err as i32)); - } - - // Only process keys starting with '.' - if cch_ext == 0 || ext_buf[0] != b'.' as u16 { - continue; - } - - let ext_wide = &ext_buf[..cch_ext as usize]; - - // Open subkey to read Content Type - let mut subkey: HKEY = null_mut() as HKEY; - let err = unsafe { RegOpenKeyExW(hkcr, ext_buf.as_ptr(), 0, KEY_READ, &mut subkey) }; - if err == windows_sys::Win32::Foundation::ERROR_FILE_NOT_FOUND - || err == windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED - { - continue; - } - if err != 0 { - return Err(vm.new_os_error(err as i32)); + host_winapi::read_windows_mimetype_registry_in_batches(|entries| { + for (mime_type, ext) in entries.drain(..) 
{ + on_type_read.call((vm.ctx.new_str(mime_type), vm.ctx.new_str(ext)), vm)?; } - - let content_type_key: Vec = "Content Type\0".encode_utf16().collect(); - let mut type_buf = [0u16; 256]; - let mut cb_type: u32 = (type_buf.len() * 2) as u32; - let mut reg_type: u32 = 0; - - let err = unsafe { - RegQueryValueExW( - subkey, - content_type_key.as_ptr(), - null_mut(), - &mut reg_type, - type_buf.as_mut_ptr() as *mut u8, - &mut cb_type, - ) - }; - unsafe { RegCloseKey(subkey) }; - - if err != 0 || reg_type != REG_SZ || cb_type == 0 { - continue; - } - - // Convert wide strings to Rust strings - let type_len = (cb_type as usize / 2).saturating_sub(1); // exclude null terminator - let type_str = String::from_utf16_lossy(&type_buf[..type_len]); - let ext_str = String::from_utf16_lossy(ext_wide); - - if type_str.is_empty() { - continue; - } - - entries.push((type_str, ext_str)); - - // Flush buffer periodically to call Python callback - if entries.len() >= 64 { - for (mime_type, ext) in entries.drain(..) 
{ - on_type_read.call((vm.ctx.new_str(mime_type), vm.ctx.new_str(ext)), vm)?; - } - } - } - - // Process remaining entries - for (mime_type, ext) in entries { - on_type_read.call((vm.ctx.new_str(mime_type), vm.ctx.new_str(ext)), vm)?; - } - - Ok(()) + Ok(()) + }) + .map_err(|err| match err { + host_winapi::MimeRegistryReadError::Os(err) => vm.new_os_error(err as i32), + host_winapi::MimeRegistryReadError::Callback(err) => err, + }) } } diff --git a/crates/vm/src/stdlib/_wmi.rs b/crates/vm/src/stdlib/_wmi.rs index 7236c74809e..e1549ef12d0 100644 --- a/crates/vm/src/stdlib/_wmi.rs +++ b/crates/vm/src/stdlib/_wmi.rs @@ -3,570 +3,13 @@ pub(crate) use _wmi::module_def; -// COM/WMI FFI declarations (not inside pymodule to avoid macro issues) -mod wmi_ffi { - #![allow(unsafe_op_in_unsafe_fn)] - use core::ffi::c_void; - - pub(super) type HRESULT = i32; - - #[repr(C)] - pub(super) struct GUID { - pub(super) data1: u32, - pub(super) data2: u16, - pub(super) data3: u16, - pub(super) data4: [u8; 8], - } - - // Opaque VARIANT type (24 bytes covers both 32-bit and 64-bit) - #[repr(C, align(8))] - pub(super) struct VARIANT([u64; 3]); - - impl VARIANT { - pub(super) fn zeroed() -> Self { - Self([0u64; 3]) - } - } - - // CLSID_WbemLocator = {4590F811-1D3A-11D0-891F-00AA004B2E24} - pub(super) const CLSID_WBEM_LOCATOR: GUID = GUID { - data1: 0x4590F811, - data2: 0x1D3A, - data3: 0x11D0, - data4: [0x89, 0x1F, 0x00, 0xAA, 0x00, 0x4B, 0x2E, 0x24], - }; - - // IID_IWbemLocator = {DC12A687-737F-11CF-884D-00AA004B2E24} - pub(super) const IID_IWBEM_LOCATOR: GUID = GUID { - data1: 0xDC12A687, - data2: 0x737F, - data3: 0x11CF, - data4: [0x88, 0x4D, 0x00, 0xAA, 0x00, 0x4B, 0x2E, 0x24], - }; - - // COM constants - pub(super) const COINIT_APARTMENTTHREADED: u32 = 0x2; - pub(super) const CLSCTX_INPROC_SERVER: u32 = 0x1; - pub(super) const RPC_C_AUTHN_LEVEL_DEFAULT: u32 = 0; - pub(super) const RPC_C_IMP_LEVEL_IMPERSONATE: u32 = 3; - pub(super) const RPC_C_AUTHN_LEVEL_CALL: u32 = 3; - pub(super) 
const RPC_C_AUTHN_WINNT: u32 = 10; - pub(super) const RPC_C_AUTHZ_NONE: u32 = 0; - pub(super) const EOAC_NONE: u32 = 0; - pub(super) const RPC_E_TOO_LATE: HRESULT = 0x80010119_u32 as i32; - - // WMI constants - pub(super) const WBEM_FLAG_FORWARD_ONLY: i32 = 0x20; - pub(super) const WBEM_FLAG_RETURN_IMMEDIATELY: i32 = 0x10; - pub(super) const WBEM_S_FALSE: HRESULT = 1; - pub(super) const WBEM_S_NO_MORE_DATA: HRESULT = 0x40005; - pub(super) const WBEM_INFINITE: i32 = -1; - pub(super) const WBEM_FLAVOR_MASK_ORIGIN: i32 = 0x60; - pub(super) const WBEM_FLAVOR_ORIGIN_SYSTEM: i32 = 0x40; - - #[link(name = "ole32")] - unsafe extern "system" { - pub(super) fn CoInitializeEx(pvReserved: *mut c_void, dwCoInit: u32) -> HRESULT; - - pub(super) fn CoUninitialize(); - - pub(super) fn CoInitializeSecurity( - pSecDesc: *const c_void, - cAuthSvc: i32, - asAuthSvc: *const c_void, - pReserved1: *const c_void, - dwAuthnLevel: u32, - dwImpLevel: u32, - pAuthList: *const c_void, - dwCapabilities: u32, - pReserved3: *const c_void, - ) -> HRESULT; - - pub(super) fn CoCreateInstance( - rclsid: *const GUID, - pUnkOuter: *mut c_void, - dwClsContext: u32, - riid: *const GUID, - ppv: *mut *mut c_void, - ) -> HRESULT; - - pub(super) fn CoSetProxyBlanket( - pProxy: *mut c_void, - dwAuthnSvc: u32, - dwAuthzSvc: u32, - pServerPrincName: *const u16, - dwAuthnLevel: u32, - dwImpLevel: u32, - pAuthInfo: *const c_void, - dwCapabilities: u32, - ) -> HRESULT; - } - - #[link(name = "oleaut32")] - unsafe extern "system" { - pub(super) fn SysAllocString(psz: *const u16) -> *mut u16; - pub(super) fn SysFreeString(bstrString: *mut u16); - pub(super) fn VariantClear(pvarg: *mut VARIANT) -> HRESULT; - } - - #[link(name = "propsys")] - unsafe extern "system" { - pub(super) fn VariantToString( - varIn: *const VARIANT, - pszBuf: *mut u16, - cchBuf: u32, - ) -> HRESULT; - } - - /// Release a COM object (IUnknown::Release, vtable index 2) - pub(super) unsafe fn com_release(this: *mut c_void) { - if !this.is_null() { 
- let vtable = *(this as *const *const usize); - let release: unsafe extern "system" fn(*mut c_void) -> u32 = - core::mem::transmute(*vtable.add(2)); - release(this); - } - } - - /// IWbemLocator::ConnectServer (vtable index 3) - #[allow(clippy::too_many_arguments)] - pub(super) unsafe fn locator_connect_server( - this: *mut c_void, - network_resource: *const u16, - user: *const u16, - password: *const u16, - locale: *const u16, - security_flags: i32, - authority: *const u16, - ctx: *mut c_void, - services: *mut *mut c_void, - ) -> HRESULT { - let vtable = *(this as *const *const usize); - let method: unsafe extern "system" fn( - *mut c_void, - *const u16, - *const u16, - *const u16, - *const u16, - i32, - *const u16, - *mut c_void, - *mut *mut c_void, - ) -> HRESULT = core::mem::transmute(*vtable.add(3)); - method( - this, - network_resource, - user, - password, - locale, - security_flags, - authority, - ctx, - services, - ) - } - - /// IWbemServices::ExecQuery (vtable index 20) - pub(super) unsafe fn services_exec_query( - this: *mut c_void, - query_language: *const u16, - query: *const u16, - flags: i32, - ctx: *mut c_void, - enumerator: *mut *mut c_void, - ) -> HRESULT { - let vtable = *(this as *const *const usize); - let method: unsafe extern "system" fn( - *mut c_void, - *const u16, - *const u16, - i32, - *mut c_void, - *mut *mut c_void, - ) -> HRESULT = core::mem::transmute(*vtable.add(20)); - method(this, query_language, query, flags, ctx, enumerator) - } - - /// IEnumWbemClassObject::Next (vtable index 4) - pub(super) unsafe fn enum_next( - this: *mut c_void, - timeout: i32, - count: u32, - objects: *mut *mut c_void, - returned: *mut u32, - ) -> HRESULT { - let vtable = *(this as *const *const usize); - let method: unsafe extern "system" fn( - *mut c_void, - i32, - u32, - *mut *mut c_void, - *mut u32, - ) -> HRESULT = core::mem::transmute(*vtable.add(4)); - method(this, timeout, count, objects, returned) - } - - /// IWbemClassObject::BeginEnumeration 
(vtable index 8) - pub(super) unsafe fn object_begin_enumeration(this: *mut c_void, enum_flags: i32) -> HRESULT { - let vtable = *(this as *const *const usize); - let method: unsafe extern "system" fn(*mut c_void, i32) -> HRESULT = - core::mem::transmute(*vtable.add(8)); - method(this, enum_flags) - } - - /// IWbemClassObject::Next (vtable index 9) - pub(super) unsafe fn object_next( - this: *mut c_void, - flags: i32, - name: *mut *mut u16, - val: *mut VARIANT, - cim_type: *mut i32, - flavor: *mut i32, - ) -> HRESULT { - let vtable = *(this as *const *const usize); - let method: unsafe extern "system" fn( - *mut c_void, - i32, - *mut *mut u16, - *mut VARIANT, - *mut i32, - *mut i32, - ) -> HRESULT = core::mem::transmute(*vtable.add(9)); - method(this, flags, name, val, cim_type, flavor) - } - - /// IWbemClassObject::EndEnumeration (vtable index 10) - pub(super) unsafe fn object_end_enumeration(this: *mut c_void) -> HRESULT { - let vtable = *(this as *const *const usize); - let method: unsafe extern "system" fn(*mut c_void) -> HRESULT = - core::mem::transmute(*vtable.add(10)); - method(this) - } -} - #[pymodule] mod _wmi { - use super::wmi_ffi::*; use crate::builtins::PyStrRef; use crate::convert::ToPyException; use crate::{PyResult, VirtualMachine}; - use core::ffi::c_void; - use core::ptr::{null, null_mut}; - use windows_sys::Win32::Foundation::{ - CloseHandle, ERROR_BROKEN_PIPE, ERROR_MORE_DATA, ERROR_NOT_ENOUGH_MEMORY, GetLastError, - HANDLE, WAIT_OBJECT_0, WAIT_TIMEOUT, - }; - use windows_sys::Win32::Storage::FileSystem::{ReadFile, WriteFile}; - use windows_sys::Win32::System::Pipes::CreatePipe; - use windows_sys::Win32::System::Threading::{ - CreateEventW, CreateThread, GetExitCodeThread, SetEvent, WaitForSingleObject, - }; - - const BUFFER_SIZE: usize = 8192; - - fn hresult_from_win32(err: u32) -> HRESULT { - if err == 0 { - 0 - } else { - ((err & 0xFFFF) | 0x80070000) as HRESULT - } - } - - fn succeeded(hr: HRESULT) -> bool { - hr >= 0 - } + use 
rustpython_host_env::wmi as host_wmi; - fn failed(hr: HRESULT) -> bool { - hr < 0 - } - - fn wide_str(s: &str) -> Vec { - s.encode_utf16().chain(core::iter::once(0)).collect() - } - - unsafe fn wcslen(s: *const u16) -> usize { - unsafe { - let mut len = 0; - while *s.add(len) != 0 { - len += 1; - } - len - } - } - - unsafe fn wait_event(event: HANDLE, timeout: u32) -> u32 { - unsafe { - match WaitForSingleObject(event, timeout) { - WAIT_OBJECT_0 => 0, - WAIT_TIMEOUT => WAIT_TIMEOUT, - _ => GetLastError(), - } - } - } - - struct QueryThreadData { - query: Vec, - write_pipe: HANDLE, - init_event: HANDLE, - connect_event: HANDLE, - } - - // SAFETY: QueryThreadData contains HANDLEs (isize) which are safe to send across threads - unsafe impl Send for QueryThreadData {} - - unsafe extern "system" fn query_thread(param: *mut c_void) -> u32 { - unsafe { query_thread_impl(param) } - } - - unsafe fn query_thread_impl(param: *mut c_void) -> u32 { - unsafe { - let data = Box::from_raw(param as *mut QueryThreadData); - let write_pipe = data.write_pipe; - let init_event = data.init_event; - let connect_event = data.connect_event; - - let mut locator: *mut c_void = null_mut(); - let mut services: *mut c_void = null_mut(); - let mut enumerator: *mut c_void = null_mut(); - let mut hr: HRESULT = 0; - - // gh-125315: Copy the query string first - let bstr_query = SysAllocString(data.query.as_ptr()); - if bstr_query.is_null() { - hr = hresult_from_win32(ERROR_NOT_ENOUGH_MEMORY); - } - - drop(data); - - if succeeded(hr) { - hr = CoInitializeEx(null_mut(), COINIT_APARTMENTTHREADED); - } - - if failed(hr) { - CloseHandle(write_pipe); - if !bstr_query.is_null() { - SysFreeString(bstr_query); - } - return hr as u32; - } - - hr = CoInitializeSecurity( - null(), - -1, - null(), - null(), - RPC_C_AUTHN_LEVEL_DEFAULT, - RPC_C_IMP_LEVEL_IMPERSONATE, - null(), - EOAC_NONE, - null(), - ); - // gh-96684: CoInitializeSecurity will fail if another part of the app has - // already called it. 
- if hr == RPC_E_TOO_LATE { - hr = 0; - } - - if succeeded(hr) { - hr = CoCreateInstance( - &CLSID_WBEM_LOCATOR, - null_mut(), - CLSCTX_INPROC_SERVER, - &IID_IWBEM_LOCATOR, - &mut locator, - ); - } - if succeeded(hr) && SetEvent(init_event) == 0 { - hr = hresult_from_win32(GetLastError()); - } - - if succeeded(hr) { - let root_cimv2 = wide_str("ROOT\\CIMV2"); - let bstr_root = SysAllocString(root_cimv2.as_ptr()); - hr = locator_connect_server( - locator, - bstr_root, - null(), - null(), - null(), - 0, - null(), - null_mut(), - &mut services, - ); - if !bstr_root.is_null() { - SysFreeString(bstr_root); - } - } - if succeeded(hr) && SetEvent(connect_event) == 0 { - hr = hresult_from_win32(GetLastError()); - } - - if succeeded(hr) { - hr = CoSetProxyBlanket( - services, - RPC_C_AUTHN_WINNT, - RPC_C_AUTHZ_NONE, - null(), - RPC_C_AUTHN_LEVEL_CALL, - RPC_C_IMP_LEVEL_IMPERSONATE, - null(), - EOAC_NONE, - ); - } - if succeeded(hr) { - let wql = wide_str("WQL"); - let bstr_wql = SysAllocString(wql.as_ptr()); - hr = services_exec_query( - services, - bstr_wql, - bstr_query, - WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, - null_mut(), - &mut enumerator, - ); - if !bstr_wql.is_null() { - SysFreeString(bstr_wql); - } - } - - // Enumerate results and write to pipe - let mut value: *mut c_void; - let mut start_of_enum = true; - let null_sep: u16 = 0; - let eq_sign: u16 = b'=' as u16; - - while succeeded(hr) { - let mut got: u32 = 0; - let mut written: u32 = 0; - value = null_mut(); - hr = enum_next(enumerator, WBEM_INFINITE, 1, &mut value, &mut got); - - if hr == WBEM_S_FALSE { - hr = 0; - break; - } - if failed(hr) || got != 1 || value.is_null() { - continue; - } - - if !start_of_enum - && WriteFile( - write_pipe, - &null_sep as *const u16 as *const _, - 2, - &mut written, - null_mut(), - ) == 0 - { - hr = hresult_from_win32(GetLastError()); - com_release(value); - break; - } - start_of_enum = false; - - hr = object_begin_enumeration(value, 0); - if failed(hr) { - 
com_release(value); - break; - } - - while succeeded(hr) { - let mut prop_name: *mut u16 = null_mut(); - let mut prop_value = VARIANT::zeroed(); - let mut flavor: i32 = 0; - - hr = object_next( - value, - 0, - &mut prop_name, - &mut prop_value, - null_mut(), - &mut flavor, - ); - - if hr == WBEM_S_NO_MORE_DATA { - hr = 0; - break; - } - - if succeeded(hr) - && (flavor & WBEM_FLAVOR_MASK_ORIGIN) != WBEM_FLAVOR_ORIGIN_SYSTEM - { - let mut prop_str = [0u16; BUFFER_SIZE]; - hr = - VariantToString(&prop_value, prop_str.as_mut_ptr(), BUFFER_SIZE as u32); - - if succeeded(hr) { - let cb_str1 = (wcslen(prop_name) * 2) as u32; - let cb_str2 = (wcslen(prop_str.as_ptr()) * 2) as u32; - - if WriteFile( - write_pipe, - prop_name as *const _, - cb_str1, - &mut written, - null_mut(), - ) == 0 - || WriteFile( - write_pipe, - &eq_sign as *const u16 as *const _, - 2, - &mut written, - null_mut(), - ) == 0 - || WriteFile( - write_pipe, - prop_str.as_ptr() as *const _, - cb_str2, - &mut written, - null_mut(), - ) == 0 - || WriteFile( - write_pipe, - &null_sep as *const u16 as *const _, - 2, - &mut written, - null_mut(), - ) == 0 - { - hr = hresult_from_win32(GetLastError()); - } - } - - VariantClear(&mut prop_value); - SysFreeString(prop_name); - } - } - - object_end_enumeration(value); - com_release(value); - } - - // Cleanup - if !bstr_query.is_null() { - SysFreeString(bstr_query); - } - if !enumerator.is_null() { - com_release(enumerator); - } - if !services.is_null() { - com_release(services); - } - if !locator.is_null() { - com_release(locator); - } - CoUninitialize(); - CloseHandle(write_pipe); - - hr as u32 - } - } - - /// Runs a WMI query against the local machine. - /// - /// This returns a single string with 'name=value' pairs in a flat array separated - /// by null characters. 
#[pyfunction] fn exec_query(query: PyStrRef, vm: &VirtualMachine) -> PyResult { let query_str = query.expect_str(); @@ -578,128 +21,6 @@ mod _wmi { return Err(vm.new_value_error("only SELECT queries are supported")); } - let query_wide = wide_str(query_str); - - let mut h_thread: HANDLE = null_mut(); - let mut err: u32 = 0; - let mut buffer = [0u16; BUFFER_SIZE]; - let mut offset: u32 = 0; - let mut bytes_read: u32 = 0; - - let mut read_pipe: HANDLE = null_mut(); - let mut write_pipe: HANDLE = null_mut(); - - unsafe { - let init_event = CreateEventW(null(), 1, 0, null()); - let connect_event = CreateEventW(null(), 1, 0, null()); - - if init_event.is_null() - || connect_event.is_null() - || CreatePipe(&mut read_pipe, &mut write_pipe, null(), 0) == 0 - { - err = GetLastError(); - } else { - let thread_data = Box::new(QueryThreadData { - query: query_wide, - write_pipe, - init_event, - connect_event, - }); - let thread_data_ptr = Box::into_raw(thread_data); - - h_thread = CreateThread( - null(), - 0, - Some(query_thread), - thread_data_ptr as *const _ as *mut _, - 0, - null_mut(), - ); - - if h_thread.is_null() { - err = GetLastError(); - // Thread didn't start, so recover data and close write pipe - let data = Box::from_raw(thread_data_ptr); - CloseHandle(data.write_pipe); - } - } - - // gh-112278: Timeout for COM init and WMI connection - if err == 0 { - err = wait_event(init_event, 1000); - if err == 0 { - err = wait_event(connect_event, 100); - } - } - - // Read results from pipe - while err == 0 { - let buf_ptr = (buffer.as_mut_ptr() as *mut u8).add(offset as usize); - let buf_remaining = (BUFFER_SIZE * 2) as u32 - offset; - - if ReadFile( - read_pipe, - buf_ptr as *mut _, - buf_remaining, - &mut bytes_read, - null_mut(), - ) != 0 - { - offset += bytes_read; - if offset >= (BUFFER_SIZE * 2) as u32 { - err = ERROR_MORE_DATA; - } - } else { - err = GetLastError(); - } - } - - if !read_pipe.is_null() { - CloseHandle(read_pipe); - } - - if !h_thread.is_null() { - let 
thread_err: u32; - match WaitForSingleObject(h_thread, 100) { - WAIT_OBJECT_0 => { - let mut exit_code: u32 = 0; - if GetExitCodeThread(h_thread, &mut exit_code) == 0 { - thread_err = GetLastError(); - } else { - thread_err = exit_code; - } - } - WAIT_TIMEOUT => { - thread_err = WAIT_TIMEOUT; - } - _ => { - thread_err = GetLastError(); - } - } - if err == 0 || err == ERROR_BROKEN_PIPE { - err = thread_err; - } - - CloseHandle(h_thread); - } - - CloseHandle(init_event); - CloseHandle(connect_event); - } - - if err == ERROR_MORE_DATA { - return Err( - vm.new_os_error(format!("Query returns more than {BUFFER_SIZE} characters")) - ); - } else if err != 0 { - return Err(std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm)); - } - - if offset == 0 { - return Ok(String::new()); - } - - let char_count = (offset as usize) / 2 - 1; - Ok(String::from_utf16_lossy(&buffer[..char_count])) + host_wmi::exec_query(query_str).map_err(|err| err.to_pyexception(vm)) } } diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index dba40d78f38..403e8c12e62 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -160,7 +160,11 @@ mod builtins { .map(|&b| b as char) .collect(); - if name.is_empty() { None } else { Some(name) } + if name.is_empty() { + None + } else { + Some(normalize_source_encoding(&name)) + } } // Split into lines (first two only) @@ -186,15 +190,39 @@ mod builtins { lines.next().and_then(find_encoding_in_line) } + /// Match CPython's Parser/tokenizer/helpers.c:get_normal_name(). 
+ #[cfg(feature = "parser")] + fn normalize_source_encoding(name: &str) -> String { + let mut normalized = String::with_capacity(name.len().min(12)); + for ch in name.chars().take(12) { + if ch == '_' { + normalized.push('-'); + } else { + normalized.push(ch.to_ascii_lowercase()); + } + } + + if normalized == "utf-8" || normalized.starts_with("utf-8-") { + "utf-8".to_owned() + } else if normalized == "latin-1" + || normalized == "iso-8859-1" + || normalized == "iso-latin-1" + || normalized.starts_with("latin-1-") + || normalized.starts_with("iso-8859-1-") + || normalized.starts_with("iso-latin-1-") + { + "iso-8859-1".to_owned() + } else { + name.to_owned() + } + } + /// Decode source bytes to a string, handling PEP 263 encoding declarations /// and BOM. Raises SyntaxError for invalid UTF-8 without an encoding /// declaration. - /// Check if an encoding name is a UTF-8 variant after normalization. - /// Matches: utf-8, utf_8, utf8, UTF-8, etc. #[cfg(feature = "parser")] fn is_utf8_encoding(name: &str) -> bool { - let normalized: String = name.chars().filter(|&c| c != '-' && c != '_').collect(); - normalized.eq_ignore_ascii_case("utf8") + name == "utf-8" } #[cfg(feature = "parser")] @@ -206,9 +234,10 @@ mod builtins { // Validate BOM + encoding combination if has_bom && !is_utf8 { + let enc = encoding.as_deref().unwrap_or("utf-8"); return Err(vm.new_exception_msg( vm.ctx.exceptions.syntax_error.to_owned(), - format!("encoding problem for '{filename}': utf-8").into(), + format!("encoding problem: {enc} with BOM").into(), )); } @@ -737,7 +766,7 @@ mod builtins { } ReadlineResult::Io(e) => Err(vm.new_os_error(e.to_string())), #[cfg(unix)] - ReadlineResult::OsError(num) => Err(vm.new_os_error(num.to_string())), + ReadlineResult::OsError(num) => Err(vm.new_os_error(num)), ReadlineResult::Other(e) => Err(vm.new_runtime_error(e.to_string())), } } else { @@ -754,9 +783,7 @@ mod builtins { /// In this case, rustyline may hang because it uses raw mode. 
#[cfg(unix)] fn is_pty_child() -> bool { - use nix::unistd::{getpid, getsid}; - // If this process is a session leader, we're likely in a PTY child - getsid(None) == Ok(getpid()) + crate::host_env::posix::is_session_leader() } #[cfg(not(unix))] @@ -1447,6 +1474,7 @@ pub fn init_module(vm: &VirtualMachine, module: &Py) { "TimeoutError" => ctx.exceptions.timeout_error.to_owned(), "ReferenceError" => ctx.exceptions.reference_error.to_owned(), "RuntimeError" => ctx.exceptions.runtime_error.to_owned(), + "PythonFinalizationError" => ctx.exceptions.python_finalization_error.to_owned(), "NotImplementedError" => ctx.exceptions.not_implemented_error.to_owned(), "RecursionError" => ctx.exceptions.recursion_error.to_owned(), "SyntaxError" => ctx.exceptions.syntax_error.to_owned(), diff --git a/crates/vm/src/stdlib/errno.rs b/crates/vm/src/stdlib/errno.rs index d7a0a222a76..5b0d666984b 100644 --- a/crates/vm/src/stdlib/errno.rs +++ b/crates/vm/src/stdlib/errno.rs @@ -24,46 +24,7 @@ mod errno_mod { } #[cfg(any(unix, windows, target_os = "wasi"))] -pub mod errors { - pub use libc::*; - #[cfg(windows)] - pub use windows_sys::Win32::{ - Foundation::*, - Networking::WinSock::{ - WSABASEERR, WSADESCRIPTION_LEN, WSAEACCES, WSAEADDRINUSE, WSAEADDRNOTAVAIL, - WSAEAFNOSUPPORT, WSAEALREADY, WSAEBADF, WSAECANCELLED, WSAECONNABORTED, - WSAECONNREFUSED, WSAECONNRESET, WSAEDESTADDRREQ, WSAEDISCON, WSAEDQUOT, WSAEFAULT, - WSAEHOSTDOWN, WSAEHOSTUNREACH, WSAEINPROGRESS, WSAEINTR, WSAEINVAL, - WSAEINVALIDPROCTABLE, WSAEINVALIDPROVIDER, WSAEISCONN, WSAELOOP, WSAEMFILE, - WSAEMSGSIZE, WSAENAMETOOLONG, WSAENETDOWN, WSAENETRESET, WSAENETUNREACH, WSAENOBUFS, - WSAENOMORE, WSAENOPROTOOPT, WSAENOTCONN, WSAENOTEMPTY, WSAENOTSOCK, WSAEOPNOTSUPP, - WSAEPFNOSUPPORT, WSAEPROCLIM, WSAEPROTONOSUPPORT, WSAEPROTOTYPE, - WSAEPROVIDERFAILEDINIT, WSAEREFUSED, WSAEREMOTE, WSAESHUTDOWN, WSAESOCKTNOSUPPORT, - WSAESTALE, WSAETIMEDOUT, WSAETOOMANYREFS, WSAEUSERS, WSAEWOULDBLOCK, WSAID_ACCEPTEX, - WSAID_CONNECTEX, 
WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_TRANSMITFILE, - WSAID_TRANSMITPACKETS, WSAID_WSAPOLL, WSAID_WSARECVMSG, WSANO_DATA, WSANO_RECOVERY, - WSANOTINITIALISED, WSAPROTOCOL_LEN, WSASERVICE_NOT_FOUND, WSASYS_STATUS_LEN, - WSASYSCALLFAILURE, WSASYSNOTREADY, WSATRY_AGAIN, WSATYPE_NOT_FOUND, WSAVERNOTSUPPORTED, - }, - }; - #[cfg(windows)] - macro_rules! reexport_wsa { - ($($errname:ident),*$(,)?) => { - paste::paste! { - $(pub const $errname: i32 = windows_sys::Win32::Networking::WinSock:: [] as i32;)* - } - } - } - #[cfg(windows)] - reexport_wsa! { - EADDRINUSE, EADDRNOTAVAIL, EAFNOSUPPORT, EALREADY, ECONNABORTED, ECONNREFUSED, ECONNRESET, - EDESTADDRREQ, EDQUOT, EHOSTDOWN, EHOSTUNREACH, EINPROGRESS, EISCONN, ELOOP, EMSGSIZE, - ENETDOWN, ENETRESET, ENETUNREACH, ENOBUFS, ENOPROTOOPT, ENOTCONN, ENOTSOCK, EOPNOTSUPP, - EPFNOSUPPORT, EPROTONOSUPPORT, EPROTOTYPE, EREMOTE, ESHUTDOWN, ESOCKTNOSUPPORT, ESTALE, - ETIMEDOUT, ETOOMANYREFS, EUSERS, EWOULDBLOCK, - // TODO: EBADF should be here once winerrs are translated to errnos but it messes up some things atm - } -} +pub use rustpython_host_env::errno::errors; #[cfg(any(unix, windows, target_os = "wasi"))] macro_rules! 
e { diff --git a/crates/vm/src/stdlib/marshal.rs b/crates/vm/src/stdlib/marshal.rs index 60ecb1792f0..6e0fc4e7f5d 100644 --- a/crates/vm/src/stdlib/marshal.rs +++ b/crates/vm/src/stdlib/marshal.rs @@ -107,6 +107,14 @@ mod decl { _version, } = args; let version = _version.unwrap_or(marshal::FORMAT_VERSION as i32); + + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call( + (vm.ctx.new_str("marshal.dumps"), value.clone(), version), + vm, + )?; + } + if !allow_code { check_no_code(&value, vm)?; } diff --git a/crates/vm/src/stdlib/msvcrt.rs b/crates/vm/src/stdlib/msvcrt.rs index e3aa7432b71..774b3cf087d 100644 --- a/crates/vm/src/stdlib/msvcrt.rs +++ b/crates/vm/src/stdlib/msvcrt.rs @@ -13,10 +13,9 @@ mod msvcrt { use itertools::Itertools; use rustpython_host_env::msvcrt as host_msvcrt; use std::os::windows::io::AsRawHandle; - use windows_sys::Win32::System::Diagnostics::Debug; #[pyattr] - use windows_sys::Win32::System::Diagnostics::Debug::{ + use host_msvcrt::{ SEM_FAILCRITICALERRORS, SEM_NOALIGNMENTFAULTEXCEPT, SEM_NOGPFAULTERRORBOX, SEM_NOOPENFILEERRORBOX, }; @@ -139,7 +138,7 @@ mod msvcrt { #[allow(non_snake_case)] #[pyfunction] - fn SetErrorMode(mode: Debug::THREAD_ERROR_MODE, _: &VirtualMachine) -> u32 { + fn SetErrorMode(mode: host_msvcrt::ErrorMode, _: &VirtualMachine) -> u32 { host_msvcrt::set_error_mode(mode) } } diff --git a/crates/vm/src/stdlib/nt.rs b/crates/vm/src/stdlib/nt.rs index 7b3dc5c8b4d..3303b1c67e2 100644 --- a/crates/vm/src/stdlib/nt.rs +++ b/crates/vm/src/stdlib/nt.rs @@ -7,27 +7,19 @@ pub use module::raw_set_handle_inheritable; pub(crate) mod module { use crate::{ Py, PyResult, TryFromObject, VirtualMachine, - builtins::{ - PyBaseExceptionRef, PyBytes, PyDictRef, PyListRef, PyStr, PyStrRef, PyTupleRef, - }, + builtins::{PyBytes, PyDictRef, PyListRef, PyStr, PyStrRef, PyTupleRef}, convert::ToPyException, exceptions::OSErrorBuilder, function::{ArgMapping, Either, OptionalArg}, - host_env::{crt_fd, suppress_iph, 
windows::ToWideString}, + host_env::{crt_fd, windows::ToWideString}, ospath::{OsPath, OsPathOrFd}, stdlib::os::{_os, DirFd, SupportFunc, TargetIsDirectory}, }; - use core::mem::MaybeUninit; use libc::intptr_t; use rustpython_common::wtf8::Wtf8Buf; use rustpython_host_env::nt as host_nt; + use std::os::windows::ffi::OsStringExt; use std::os::windows::io::AsRawHandle; - use std::{io, os::windows::ffi::OsStringExt}; - use windows_sys::Win32::{ - Foundation::{self, INVALID_HANDLE_VALUE}, - Storage::FileSystem, - System::{Console, Threading}, - }; #[pyattr] use libc::{O_BINARY, O_NOINHERIT, O_RANDOM, O_SEQUENTIAL, O_TEMPORARY, O_TEXT}; @@ -57,7 +49,7 @@ pub(crate) mod module { const TMP_MAX: i32 = i32::MAX; #[pyattr] - use windows_sys::Win32::System::LibraryLoader::{ + use host_nt::{ LOAD_LIBRARY_SEARCH_APPLICATION_DIR as _LOAD_LIBRARY_SEARCH_APPLICATION_DIR, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS as _LOAD_LIBRARY_SEARCH_DEFAULT_DIRS, LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR as _LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR, @@ -67,11 +59,8 @@ pub(crate) mod module { #[pyfunction] pub(super) fn access(path: OsPath, mode: u8, vm: &VirtualMachine) -> PyResult { - let attr = unsafe { FileSystem::GetFileAttributesW(path.to_wide_cstring(vm)?.as_ptr()) }; - Ok(attr != FileSystem::INVALID_FILE_ATTRIBUTES - && (mode & 2 == 0 - || attr & FileSystem::FILE_ATTRIBUTE_READONLY == 0 - || attr & FileSystem::FILE_ATTRIBUTE_DIRECTORY != 0)) + let _ = path.to_wide_cstring(vm)?; + Ok(host_nt::access(path.as_ref(), mode)) } #[pyfunction] @@ -81,59 +70,14 @@ pub(crate) mod module { dir_fd: DirFd<'static, 0>, vm: &VirtualMachine, ) -> PyResult<()> { - // On Windows, use DeleteFileW directly. - // Rust's std::fs::remove_file may have different behavior for read-only files. - // See Py_DeleteFileW. 
- use windows_sys::Win32::Storage::FileSystem::{ - DeleteFileW, FindClose, FindFirstFileW, RemoveDirectoryW, WIN32_FIND_DATAW, - }; - use windows_sys::Win32::System::SystemServices::{ - IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, - }; - let [] = dir_fd.0; - let wide_path = path.to_wide_cstring(vm)?; - let attrs = unsafe { FileSystem::GetFileAttributesW(wide_path.as_ptr()) }; - - let mut is_directory = false; - let mut is_link = false; - - if attrs != FileSystem::INVALID_FILE_ATTRIBUTES { - is_directory = (attrs & FileSystem::FILE_ATTRIBUTE_DIRECTORY) != 0; - - // Check if it's a symlink or junction point - if is_directory && (attrs & FileSystem::FILE_ATTRIBUTE_REPARSE_POINT) != 0 { - let mut find_data: WIN32_FIND_DATAW = unsafe { core::mem::zeroed() }; - let handle = unsafe { FindFirstFileW(wide_path.as_ptr(), &mut find_data) }; - if handle != INVALID_HANDLE_VALUE { - is_link = find_data.dwReserved0 == IO_REPARSE_TAG_SYMLINK - || find_data.dwReserved0 == IO_REPARSE_TAG_MOUNT_POINT; - unsafe { FindClose(handle) }; - } - } - } - - let result = if is_directory && is_link { - unsafe { RemoveDirectoryW(wide_path.as_ptr()) } - } else { - unsafe { DeleteFileW(wide_path.as_ptr()) } - }; - - if result == 0 { - let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - Ok(()) + let _ = path.to_wide_cstring(vm)?; + host_nt::remove(path.as_ref()).map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } #[pyfunction] pub(super) fn _supports_virtual_terminal() -> bool { - let mut mode = 0; - let handle = unsafe { Console::GetStdHandle(Console::STD_ERROR_HANDLE) }; - if unsafe { Console::GetConsoleMode(handle, &mut mode) } == 0 { - return false; - } - mode & Console::ENABLE_VIRTUAL_TERMINAL_PROCESSING != 0 + host_nt::supports_virtual_terminal() } #[derive(FromArgs)] @@ -149,69 +93,15 @@ pub(crate) mod module { #[pyfunction] pub(super) fn symlink(args: SymlinkArgs<'_>, vm: &VirtualMachine) -> PyResult<()> { use 
crate::exceptions::ToOSErrorBuilder; - use core::sync::atomic::{AtomicBool, Ordering}; - use windows_sys::Win32::Storage::FileSystem::WIN32_FILE_ATTRIBUTE_DATA; - use windows_sys::Win32::Storage::FileSystem::{ - CreateSymbolicLinkW, FILE_ATTRIBUTE_DIRECTORY, GetFileAttributesExW, - SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE, SYMBOLIC_LINK_FLAG_DIRECTORY, - }; - - static HAS_UNPRIVILEGED_FLAG: AtomicBool = AtomicBool::new(true); - - fn check_dir(src: &OsPath, dst: &OsPath) -> bool { - use windows_sys::Win32::Storage::FileSystem::GetFileExInfoStandard; - - let dst_parent = dst.as_path().parent(); - let Some(dst_parent) = dst_parent else { - return false; - }; - let resolved = if src.as_path().is_absolute() { - src.as_path().to_path_buf() - } else { - dst_parent.join(src.as_path()) - }; - let wide = match widestring::WideCString::from_os_str(&resolved) { - Ok(wide) => wide, - Err(_) => return false, - }; - let mut info: WIN32_FILE_ATTRIBUTE_DATA = unsafe { core::mem::zeroed() }; - let ok = unsafe { - GetFileAttributesExW( - wide.as_ptr(), - GetFileExInfoStandard, - &mut info as *mut _ as *mut _, - ) - }; - ok != 0 && (info.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0 - } - - let mut flags = 0u32; - if HAS_UNPRIVILEGED_FLAG.load(Ordering::Relaxed) { - flags |= SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE; - } - if args.target_is_directory.target_is_directory || check_dir(&args.src, &args.dst) { - flags |= SYMBOLIC_LINK_FLAG_DIRECTORY; - } - let src = args.src.to_wide_cstring(vm)?; let dst = args.dst.to_wide_cstring(vm)?; - - let mut result = unsafe { CreateSymbolicLinkW(dst.as_ptr(), src.as_ptr(), flags) }; - if !result - && HAS_UNPRIVILEGED_FLAG.load(Ordering::Relaxed) - && unsafe { Foundation::GetLastError() } == Foundation::ERROR_INVALID_PARAMETER - { - let flags = flags & !SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE; - result = unsafe { CreateSymbolicLinkW(dst.as_ptr(), src.as_ptr(), flags) }; - if result - || unsafe { Foundation::GetLastError() } != 
Foundation::ERROR_INVALID_PARAMETER - { - HAS_UNPRIVILEGED_FLAG.store(false, Ordering::Relaxed); - } - } - - if !result { - let err = io::Error::last_os_error(); + if let Err(err) = host_nt::symlink( + args.src.as_ref(), + args.dst.as_ref(), + &src, + &dst, + args.target_is_directory.target_is_directory, + ) { let builder = err.to_os_error_builder(vm); let builder = builder .filename(args.src.filename(vm)) @@ -236,10 +126,7 @@ pub(crate) mod module { fn environ(vm: &VirtualMachine) -> PyDictRef { let environ = vm.ctx.new_dict(); - for (key, value) in crate::host_env::os::vars() { - // Skip hidden Windows environment variables (e.g., =C:, =D:, =ExitCode) - // These are internal cmd.exe bookkeeping variables that store per-drive - // current directories and cannot be reliably modified via _wputenv(). + for (key, value) in host_nt::visible_env_vars() { if key.starts_with('=') { continue; } @@ -251,10 +138,7 @@ pub(crate) mod module { #[pyfunction] fn _create_environ(vm: &VirtualMachine) -> PyDictRef { let environ = vm.ctx.new_dict(); - for (key, value) in crate::host_env::os::vars() { - if key.starts_with('=') { - continue; - } + for (key, value) in host_nt::visible_env_vars() { environ.set_item(&key, vm.new_pyobj(value), vm).unwrap(); } environ @@ -274,10 +158,6 @@ pub(crate) mod module { const S_IWRITE: u32 = 128; - fn win32_hchmod(handle: Foundation::HANDLE, mode: u32, vm: &VirtualMachine) -> PyResult<()> { - host_nt::win32_hchmod(handle, mode, S_IWRITE).map_err(|e| e.to_pyexception(vm)) - } - fn fchmod_impl(fd: i32, mode: u32, vm: &VirtualMachine) -> PyResult<()> { host_nt::fchmod(fd, mode, S_IWRITE).map_err(|e| e.to_pyexception(vm)) } @@ -319,30 +199,9 @@ pub(crate) mod module { let follow_symlinks = follow_symlinks.into_option().unwrap_or(false); if follow_symlinks { - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_FLAG_BACKUP_SEMANTICS, FILE_READ_ATTRIBUTES, FILE_SHARE_DELETE, - FILE_SHARE_READ, FILE_SHARE_WRITE, FILE_WRITE_ATTRIBUTES, 
OPEN_EXISTING, - }; - let wide = path.to_wide_cstring(vm)?; - let handle = unsafe { - CreateFileW( - wide.as_ptr(), - FILE_READ_ATTRIBUTES | FILE_WRITE_ATTRIBUTES, - FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - core::ptr::null(), - OPEN_EXISTING, - FILE_FLAG_BACKUP_SEMANTICS, - core::ptr::null_mut(), - ) - }; - if handle == INVALID_HANDLE_VALUE { - let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - let result = win32_hchmod(handle, mode, vm); - unsafe { Foundation::CloseHandle(handle) }; - result + host_nt::chmod_follow(&wide, mode, S_IWRITE) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } else { win32_lchmod(&path, mode, vm) } @@ -352,31 +211,8 @@ pub(crate) mod module { /// Uses FindFirstFileW to get the name as stored on the filesystem. #[pyfunction] fn _findfirstfile(path: OsPath, vm: &VirtualMachine) -> PyResult { - use crate::host_env::windows::ToWideString; - use std::os::windows::ffi::OsStringExt; - use windows_sys::Win32::Storage::FileSystem::{ - FindClose, FindFirstFileW, WIN32_FIND_DATAW, - }; - - let wide_path = path.as_ref().to_wide_with_nul(); - let mut find_data: WIN32_FIND_DATAW = unsafe { core::mem::zeroed() }; - - let handle = unsafe { FindFirstFileW(wide_path.as_ptr(), &mut find_data) }; - if handle == INVALID_HANDLE_VALUE { - let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - - unsafe { FindClose(handle) }; - - // Convert the filename from the find data to a Rust string - // cFileName is a null-terminated wide string - let len = find_data - .cFileName - .iter() - .position(|&c| c == 0) - .unwrap_or(find_data.cFileName.len()); - let filename = std::ffi::OsString::from_wide(&find_data.cFileName[..len]); + let filename = host_nt::find_first_file_name(path.as_ref()) + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; let filename_str = filename .to_str() .ok_or_else(|| 
vm.new_unicode_decode_error("filename contains invalid UTF-8"))?; @@ -406,296 +242,53 @@ pub(crate) mod module { /// _testInfo - determine file type based on attributes and reparse tag fn _test_info(attributes: u32, reparse_tag: u32, disk_device: bool, tested_type: u32) -> bool { - use windows_sys::Win32::Storage::FileSystem::{ - FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_REPARSE_POINT, + let tested_type = match tested_type { + PY_IFREG => host_nt::TestType::RegularFile, + PY_IFDIR => host_nt::TestType::Directory, + PY_IFLNK => host_nt::TestType::Symlink, + PY_IFMNT => host_nt::TestType::Junction, + PY_IFLRP => host_nt::TestType::LinkReparsePoint, + PY_IFRRP => host_nt::TestType::RegularReparsePoint, + _ => return false, }; - use windows_sys::Win32::System::SystemServices::{ - IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, - }; - - match tested_type { - PY_IFREG => { - // diskDevice && attributes && !(attributes & FILE_ATTRIBUTE_DIRECTORY) - disk_device && attributes != 0 && (attributes & FILE_ATTRIBUTE_DIRECTORY) == 0 - } - PY_IFDIR => (attributes & FILE_ATTRIBUTE_DIRECTORY) != 0, - PY_IFLNK => { - (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 - && reparse_tag == IO_REPARSE_TAG_SYMLINK - } - PY_IFMNT => { - (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 - && reparse_tag == IO_REPARSE_TAG_MOUNT_POINT - } - PY_IFLRP => { - (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 - && is_reparse_tag_name_surrogate(reparse_tag) - } - PY_IFRRP => { - (attributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 - && reparse_tag != 0 - && !is_reparse_tag_name_surrogate(reparse_tag) - } - _ => false, - } - } - - fn is_reparse_tag_name_surrogate(tag: u32) -> bool { - (tag & 0x20000000) != 0 - } - - fn file_info_error_is_trustworthy(error: u32) -> bool { - use windows_sys::Win32::Foundation; - matches!( - error, - Foundation::ERROR_FILE_NOT_FOUND - | Foundation::ERROR_PATH_NOT_FOUND - | Foundation::ERROR_NOT_READY - | Foundation::ERROR_BAD_NET_NAME - | Foundation::ERROR_BAD_NETPATH - | 
Foundation::ERROR_BAD_PATHNAME - | Foundation::ERROR_INVALID_NAME - | Foundation::ERROR_FILENAME_EXCED_RANGE - ) + host_nt::test_info(attributes, reparse_tag, disk_device, tested_type) } /// _testFileTypeByHandle - test file type using an open handle fn _test_file_type_by_handle( - handle: windows_sys::Win32::Foundation::HANDLE, + handle: host_nt::Handle, tested_type: u32, disk_only: bool, ) -> bool { - use windows_sys::Win32::Storage::FileSystem::{ - FILE_ATTRIBUTE_TAG_INFO, FILE_BASIC_INFO, FILE_TYPE_DISK, - FileAttributeTagInfo as FileAttributeTagInfoClass, FileBasicInfo, - GetFileInformationByHandleEx, GetFileType, + let tested_type = match tested_type { + PY_IFREG => host_nt::TestType::RegularFile, + PY_IFDIR => host_nt::TestType::Directory, + PY_IFLNK => host_nt::TestType::Symlink, + PY_IFMNT => host_nt::TestType::Junction, + PY_IFLRP => host_nt::TestType::LinkReparsePoint, + PY_IFRRP => host_nt::TestType::RegularReparsePoint, + _ => return false, }; - - let disk_device = unsafe { GetFileType(handle) } == FILE_TYPE_DISK; - if disk_only && !disk_device { - return false; - } - - if tested_type != PY_IFREG && tested_type != PY_IFDIR { - // For symlinks/junctions, need FileAttributeTagInfo to get reparse tag - let mut info: FILE_ATTRIBUTE_TAG_INFO = unsafe { core::mem::zeroed() }; - let ret = unsafe { - GetFileInformationByHandleEx( - handle, - FileAttributeTagInfoClass, - &mut info as *mut _ as *mut _, - core::mem::size_of::() as u32, - ) - }; - if ret == 0 { - return false; - } - _test_info( - info.FileAttributes, - info.ReparseTag, - disk_device, - tested_type, - ) - } else { - // For regular files/directories, FileBasicInfo is sufficient - let mut info: FILE_BASIC_INFO = unsafe { core::mem::zeroed() }; - let ret = unsafe { - GetFileInformationByHandleEx( - handle, - FileBasicInfo, - &mut info as *mut _ as *mut _, - core::mem::size_of::() as u32, - ) - }; - if ret == 0 { - return false; - } - _test_info(info.FileAttributes, 0, disk_device, tested_type) - } + 
host_nt::test_file_type_by_handle(handle, tested_type, disk_only) } /// _testFileTypeByName - test file type by path name fn _test_file_type_by_name(path: &std::path::Path, tested_type: u32) -> bool { - use crate::host_env::fileutils::windows::{ - FILE_INFO_BY_NAME_CLASS, get_file_information_by_name, - }; - use crate::host_env::windows::ToWideString; - use windows_sys::Win32::Foundation::{CloseHandle, INVALID_HANDLE_VALUE}; - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_ATTRIBUTE_REPARSE_POINT, FILE_FLAG_BACKUP_SEMANTICS, - FILE_FLAG_OPEN_REPARSE_POINT, FILE_READ_ATTRIBUTES, OPEN_EXISTING, + let tested_type = match tested_type { + PY_IFREG => host_nt::TestType::RegularFile, + PY_IFDIR => host_nt::TestType::Directory, + PY_IFLNK => host_nt::TestType::Symlink, + PY_IFMNT => host_nt::TestType::Junction, + PY_IFLRP => host_nt::TestType::LinkReparsePoint, + PY_IFRRP => host_nt::TestType::RegularReparsePoint, + _ => return false, }; - use windows_sys::Win32::Storage::FileSystem::{FILE_DEVICE_CD_ROM, FILE_DEVICE_DISK}; - use windows_sys::Win32::System::Ioctl::FILE_DEVICE_VIRTUAL_DISK; - - match get_file_information_by_name( - path.as_os_str(), - FILE_INFO_BY_NAME_CLASS::FileStatBasicByNameInfo, - ) { - Ok(info) => { - let disk_device = matches!( - info.DeviceType, - FILE_DEVICE_DISK | FILE_DEVICE_VIRTUAL_DISK | FILE_DEVICE_CD_ROM - ); - let result = _test_info( - info.FileAttributes, - info.ReparseTag, - disk_device, - tested_type, - ); - if !result - || (tested_type != PY_IFREG && tested_type != PY_IFDIR) - || (info.FileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) == 0 - { - return result; - } - } - Err(err) => { - if let Some(code) = err.raw_os_error() - && file_info_error_is_trustworthy(code as u32) - { - return false; - } - } - } - - let wide_path = path.to_wide_with_nul(); - - let mut flags = FILE_FLAG_BACKUP_SEMANTICS; - if tested_type != PY_IFREG && tested_type != PY_IFDIR { - flags |= FILE_FLAG_OPEN_REPARSE_POINT; - } - let handle = unsafe { - 
CreateFileW( - wide_path.as_ptr(), - FILE_READ_ATTRIBUTES, - 0, - core::ptr::null(), - OPEN_EXISTING, - flags, - core::ptr::null_mut(), - ) - }; - - if handle != INVALID_HANDLE_VALUE { - let result = _test_file_type_by_handle(handle, tested_type, false); - unsafe { CloseHandle(handle) }; - return result; - } - - match unsafe { windows_sys::Win32::Foundation::GetLastError() } { - windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED - | windows_sys::Win32::Foundation::ERROR_SHARING_VIOLATION - | windows_sys::Win32::Foundation::ERROR_CANT_ACCESS_FILE - | windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER => { - let stat = if tested_type == PY_IFREG || tested_type == PY_IFDIR { - crate::windows::win32_xstat(path.as_os_str(), true) - } else { - crate::windows::win32_xstat(path.as_os_str(), false) - }; - if let Ok(st) = stat { - let disk_device = (st.st_mode & libc::S_IFREG as u16) != 0; - return _test_info( - st.st_file_attributes, - st.st_reparse_tag, - disk_device, - tested_type, - ); - } - } - _ => {} - } - - false + host_nt::test_file_type_by_name(path, tested_type) } /// _testFileExistsByName - test if path exists fn _test_file_exists_by_name(path: &std::path::Path, follow_links: bool) -> bool { - use crate::host_env::fileutils::windows::{ - FILE_INFO_BY_NAME_CLASS, get_file_information_by_name, - }; - use crate::host_env::windows::ToWideString; - use windows_sys::Win32::Foundation::{CloseHandle, INVALID_HANDLE_VALUE}; - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_ATTRIBUTE_REPARSE_POINT, FILE_FLAG_BACKUP_SEMANTICS, - FILE_FLAG_OPEN_REPARSE_POINT, FILE_READ_ATTRIBUTES, OPEN_EXISTING, - }; - - match get_file_information_by_name( - path.as_os_str(), - FILE_INFO_BY_NAME_CLASS::FileStatBasicByNameInfo, - ) { - Ok(info) => { - if (info.FileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) == 0 - || (!follow_links && is_reparse_tag_name_surrogate(info.ReparseTag)) - { - return true; - } - } - Err(err) => { - if let Some(code) = err.raw_os_error() - && 
file_info_error_is_trustworthy(code as u32) - { - return false; - } - } - } - - let wide_path = path.to_wide_with_nul(); - let mut flags = FILE_FLAG_BACKUP_SEMANTICS; - if !follow_links { - flags |= FILE_FLAG_OPEN_REPARSE_POINT; - } - let handle = unsafe { - CreateFileW( - wide_path.as_ptr(), - FILE_READ_ATTRIBUTES, - 0, - core::ptr::null(), - OPEN_EXISTING, - flags, - core::ptr::null_mut(), - ) - }; - if handle != INVALID_HANDLE_VALUE { - if follow_links { - unsafe { CloseHandle(handle) }; - return true; - } - let is_regular_reparse_point = _test_file_type_by_handle(handle, PY_IFRRP, false); - unsafe { CloseHandle(handle) }; - if !is_regular_reparse_point { - return true; - } - let handle = unsafe { - CreateFileW( - wide_path.as_ptr(), - FILE_READ_ATTRIBUTES, - 0, - core::ptr::null(), - OPEN_EXISTING, - FILE_FLAG_BACKUP_SEMANTICS, - core::ptr::null_mut(), - ) - }; - if handle != INVALID_HANDLE_VALUE { - unsafe { CloseHandle(handle) }; - return true; - } - } - - match unsafe { windows_sys::Win32::Foundation::GetLastError() } { - windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED - | windows_sys::Win32::Foundation::ERROR_SHARING_VIOLATION - | windows_sys::Win32::Foundation::ERROR_CANT_ACCESS_FILE - | windows_sys::Win32::Foundation::ERROR_INVALID_PARAMETER => { - let stat = crate::windows::win32_xstat(path.as_os_str(), follow_links); - return stat.is_ok(); - } - _ => {} - } - - false + host_nt::test_file_exists_by_name(path, follow_links) } /// _testFileType wrapper - handles both fd and path @@ -715,23 +308,8 @@ pub(crate) mod module { /// _testFileExists wrapper - handles both fd and path fn _test_file_exists(path_or_fd: &OsPathOrFd<'_>, follow_links: bool) -> bool { - use windows_sys::Win32::Storage::FileSystem::{FILE_TYPE_UNKNOWN, GetFileType}; - match path_or_fd { - OsPathOrFd::Fd(fd) => { - if let Ok(handle) = crate::host_env::crt_fd::as_handle(*fd) { - use std::os::windows::io::AsRawHandle; - let file_type = unsafe { GetFileType(handle.as_raw_handle() as _) 
}; - // GetFileType(hfile) != FILE_TYPE_UNKNOWN || !GetLastError() - if file_type != FILE_TYPE_UNKNOWN { - return true; - } - // Check if GetLastError is 0 (no error means valid handle) - unsafe { windows_sys::Win32::Foundation::GetLastError() == 0 } - } else { - false - } - } + OsPathOrFd::Fd(fd) => host_nt::fd_exists(*fd), OsPathOrFd::Path(path) => _test_file_exists_by_name(path.as_ref(), follow_links), } } @@ -787,113 +365,18 @@ pub(crate) mod module { /// Check if a path is on a Windows Dev Drive. #[pyfunction] fn _path_isdevdrive(path: OsPath, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::CloseHandle; - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_FLAG_BACKUP_SEMANTICS, FILE_READ_ATTRIBUTES, FILE_SHARE_READ, - FILE_SHARE_WRITE, GetDriveTypeW, GetVolumePathNameW, OPEN_EXISTING, - }; - use windows_sys::Win32::System::IO::DeviceIoControl; - use windows_sys::Win32::System::Ioctl::FSCTL_QUERY_PERSISTENT_VOLUME_STATE; - use windows_sys::Win32::System::WindowsProgramming::DRIVE_FIXED; - - // PERSISTENT_VOLUME_STATE_DEV_VOLUME flag - not yet in windows-sys - const PERSISTENT_VOLUME_STATE_DEV_VOLUME: u32 = 0x00002000; - - // FILE_FS_PERSISTENT_VOLUME_INFORMATION structure - #[repr(C)] - struct FileFsPersistentVolumeInformation { - volume_flags: u32, - flag_mask: u32, - version: u32, - reserved: u32, - } - - let wide_path = path.to_wide_cstring(vm)?; - let mut volume = [0u16; Foundation::MAX_PATH as usize]; - - // Get volume path - let ret = unsafe { - GetVolumePathNameW(wide_path.as_ptr(), volume.as_mut_ptr(), volume.len() as _) - }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } - - // Check if it's a fixed drive - if unsafe { GetDriveTypeW(volume.as_ptr()) } != DRIVE_FIXED { - return Ok(false); - } - - // Open the volume - let handle = unsafe { - CreateFileW( - volume.as_ptr(), - FILE_READ_ATTRIBUTES, - FILE_SHARE_READ | FILE_SHARE_WRITE, - core::ptr::null(), - OPEN_EXISTING, - FILE_FLAG_BACKUP_SEMANTICS, - 
core::ptr::null_mut(), - ) - }; - if handle == INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); - } - - // Query persistent volume state - let mut volume_state = FileFsPersistentVolumeInformation { - volume_flags: 0, - flag_mask: PERSISTENT_VOLUME_STATE_DEV_VOLUME, - version: 1, - reserved: 0, - }; - - let ret = unsafe { - DeviceIoControl( - handle, - FSCTL_QUERY_PERSISTENT_VOLUME_STATE, - &volume_state as *const _ as *const core::ffi::c_void, - core::mem::size_of::() as u32, - &mut volume_state as *mut _ as *mut core::ffi::c_void, - core::mem::size_of::() as u32, - core::ptr::null_mut(), - core::ptr::null_mut(), - ) - }; - - unsafe { CloseHandle(handle) }; - - if ret == 0 { - let err = io::Error::last_os_error(); - // ERROR_INVALID_PARAMETER means not supported on this platform - if err.raw_os_error() == Some(Foundation::ERROR_INVALID_PARAMETER as i32) { - return Ok(false); - } - return Err(err.to_pyexception(vm)); - } - - Ok((volume_state.volume_flags & PERSISTENT_VOLUME_STATE_DEV_VOLUME) != 0) - } - - // cwait is available on MSVC only - #[cfg(target_env = "msvc")] - unsafe extern "C" { - fn _cwait(termstat: *mut i32, procHandle: intptr_t, action: i32) -> intptr_t; + let _ = path.to_wide_cstring(vm)?; + host_nt::path_isdevdrive(path.as_ref()).map_err(|err| err.to_pyexception(vm)) } #[cfg(target_env = "msvc")] #[pyfunction] fn waitpid(pid: intptr_t, opt: i32, vm: &VirtualMachine) -> PyResult<(intptr_t, u64)> { - let mut status: i32 = 0; - let pid = unsafe { suppress_iph!(_cwait(&mut status, pid, opt)) }; - if pid == -1 { - Err(vm.new_last_errno_error()) - } else { - // Cast to unsigned to handle large exit codes (like 0xC000013A) - // then shift left by 8 to match POSIX waitpid format - let ustatus = (status as u32) as u64; - Ok((pid, ustatus << 8)) - } + let (pid, status) = host_nt::cwait(pid, opt).map_err(|_| vm.new_last_errno_error())?; + // Cast to unsigned to handle large exit codes (like 0xC000013A) + // then shift left by 8 to match POSIX 
waitpid format + let ustatus = (status as u32) as u64; + Ok((pid, ustatus << 8)) } #[cfg(target_env = "msvc")] @@ -904,31 +387,7 @@ pub(crate) mod module { #[pyfunction] fn kill(pid: i32, sig: isize, vm: &VirtualMachine) -> PyResult<()> { - let sig = sig as u32; - let pid = pid as u32; - - if sig == Console::CTRL_C_EVENT || sig == Console::CTRL_BREAK_EVENT { - let ret = unsafe { Console::GenerateConsoleCtrlEvent(sig, pid) }; - let res = if ret == 0 { - Err(vm.new_last_os_error()) - } else { - Ok(()) - }; - return res; - } - - let h = unsafe { Threading::OpenProcess(Threading::PROCESS_ALL_ACCESS, 0, pid) }; - if h.is_null() { - return Err(vm.new_last_os_error()); - } - let ret = unsafe { Threading::TerminateProcess(h, sig) }; - let res = if ret == 0 { - Err(vm.new_last_os_error()) - } else { - Ok(()) - }; - unsafe { Foundation::CloseHandle(h) }; - res + host_nt::kill(pid as u32, sig as u32).map_err(|err| err.to_pyexception(vm)) } #[pyfunction] @@ -937,68 +396,13 @@ pub(crate) mod module { vm: &VirtualMachine, ) -> PyResult<_os::TerminalSizeData> { let fd = fd.unwrap_or(1); // default to stdout - - // Use _get_osfhandle for all fds let borrowed = unsafe { crt_fd::Borrowed::borrow_raw(fd) }; let handle = crt_fd::as_handle(borrowed).map_err(|e| e.to_pyexception(vm))?; - let h = handle.as_raw_handle() as Foundation::HANDLE; - - let mut csbi = MaybeUninit::uninit(); - let ret = unsafe { Console::GetConsoleScreenBufferInfo(h, csbi.as_mut_ptr()) }; - if ret == 0 { - // Check if error is due to lack of read access on a console handle - // ERROR_ACCESS_DENIED (5) means it's a console but without read permission - // In that case, try opening CONOUT$ directly with read access - let err = unsafe { Foundation::GetLastError() }; - if err != Foundation::ERROR_ACCESS_DENIED { - return Err(vm.new_last_os_error()); - } - let conout: Vec = "CONOUT$\0".encode_utf16().collect(); - let console_handle = unsafe { - FileSystem::CreateFileW( - conout.as_ptr(), - Foundation::GENERIC_READ | 
Foundation::GENERIC_WRITE, - FileSystem::FILE_SHARE_READ | FileSystem::FILE_SHARE_WRITE, - core::ptr::null(), - FileSystem::OPEN_EXISTING, - 0, - core::ptr::null_mut(), - ) - }; - if console_handle == INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); - } - let ret = - unsafe { Console::GetConsoleScreenBufferInfo(console_handle, csbi.as_mut_ptr()) }; - unsafe { Foundation::CloseHandle(console_handle) }; - if ret == 0 { - return Err(vm.new_last_os_error()); - } - } - let csbi = unsafe { csbi.assume_init() }; - let w = csbi.srWindow; - let columns = (w.Right - w.Left + 1) as usize; - let lines = (w.Bottom - w.Top + 1) as usize; + let (columns, lines) = host_nt::get_terminal_size_handle(handle.as_raw_handle() as _) + .map_err(|_| vm.new_last_os_error())?; Ok(_os::TerminalSizeData { columns, lines }) } - #[cfg(target_env = "msvc")] - unsafe extern "C" { - fn _wexecv(cmdname: *const u16, argv: *const *const u16) -> intptr_t; - fn _wexecve( - cmdname: *const u16, - argv: *const *const u16, - envp: *const *const u16, - ) -> intptr_t; - fn _wspawnv(mode: i32, cmdname: *const u16, argv: *const *const u16) -> intptr_t; - fn _wspawnve( - mode: i32, - cmdname: *const u16, - argv: *const *const u16, - envp: *const *const u16, - ) -> intptr_t; - } - #[cfg(target_env = "msvc")] #[pyfunction] fn spawnv( @@ -1008,7 +412,6 @@ pub(crate) mod module { vm: &VirtualMachine, ) -> PyResult { use crate::function::FsPath; - use core::iter::once; let path = path.to_wide_cstring(vm)?; @@ -1025,18 +428,8 @@ pub(crate) mod module { return Err(vm.new_value_error("spawnv() arg 3 first element cannot be empty")); } - let argv_spawn: Vec<*const u16> = argv - .iter() - .map(|v| v.as_ptr()) - .chain(once(core::ptr::null())) - .collect(); - - let result = unsafe { suppress_iph!(_wspawnv(mode, path.as_ptr(), argv_spawn.as_ptr())) }; - if result == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(result) - } + let argv_refs: Vec<&widestring::WideCStr> = argv.iter().map(|s| 
s.as_ref()).collect(); + host_nt::spawnv(mode, &path, &argv_refs).map_err(|_| vm.new_last_errno_error()) } #[cfg(target_env = "msvc")] @@ -1049,7 +442,6 @@ pub(crate) mod module { vm: &VirtualMachine, ) -> PyResult { use crate::function::FsPath; - use core::iter::once; let path = path.to_wide_cstring(vm)?; @@ -1066,12 +458,6 @@ pub(crate) mod module { return Err(vm.new_value_error("spawnve() arg 2 first element cannot be empty")); } - let argv_spawn: Vec<*const u16> = argv - .iter() - .map(|v| v.as_ptr()) - .chain(once(core::ptr::null())) - .collect(); - // Build environment strings as "KEY=VALUE\0" wide strings let mut env_strings: Vec = Vec::new(); for (key, value) in env { @@ -1094,25 +480,10 @@ pub(crate) mod module { ); } - let envp: Vec<*const u16> = env_strings - .iter() - .map(|s| s.as_ptr()) - .chain(once(core::ptr::null())) - .collect(); - - let result = unsafe { - suppress_iph!(_wspawnve( - mode, - path.as_ptr(), - argv_spawn.as_ptr(), - envp.as_ptr() - )) - }; - if result == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(result) - } + let argv_refs: Vec<&widestring::WideCStr> = argv.iter().map(|s| s.as_ref()).collect(); + let envp_refs: Vec<&widestring::WideCStr> = + env_strings.iter().map(|s| s.as_ref()).collect(); + host_nt::spawnve(mode, &path, &argv_refs, &envp_refs).map_err(|_| vm.new_last_errno_error()) } #[cfg(target_env = "msvc")] @@ -1122,8 +493,6 @@ pub(crate) mod module { argv: Either, vm: &VirtualMachine, ) -> PyResult<()> { - use core::iter::once; - let make_widestring = |s: &str| widestring::WideCString::from_os_str(s).map_err(|err| err.to_pyexception(vm)); @@ -1142,17 +511,8 @@ pub(crate) mod module { return Err(vm.new_value_error("execv() arg 2 first element cannot be empty")); } - let argv_execv: Vec<*const u16> = argv - .iter() - .map(|v| v.as_ptr()) - .chain(once(core::ptr::null())) - .collect(); - - if (unsafe { suppress_iph!(_wexecv(path.as_ptr(), argv_execv.as_ptr())) } == -1) { - Err(vm.new_last_errno_error()) - } else { - 
Ok(()) - } + let argv_refs: Vec<&widestring::WideCStr> = argv.iter().map(|s| s.as_ref()).collect(); + host_nt::execv(&path, &argv_refs).map_err(|_| vm.new_last_errno_error()) } #[cfg(target_env = "msvc")] @@ -1163,8 +523,6 @@ pub(crate) mod module { env: ArgMapping, vm: &VirtualMachine, ) -> PyResult<()> { - use core::iter::once; - let make_widestring = |s: &str| widestring::WideCString::from_os_str(s).map_err(|err| err.to_pyexception(vm)); @@ -1183,12 +541,6 @@ pub(crate) mod module { return Err(vm.new_value_error("execve: argv first element cannot be empty")); } - let argv_execve: Vec<*const u16> = argv - .iter() - .map(|v| v.as_ptr()) - .chain(once(core::ptr::null())) - .collect(); - let env = crate::stdlib::os::envobj_to_dict(env, vm)?; // Build environment strings as "KEY=VALUE\0" wide strings let mut env_strings: Vec = Vec::new(); @@ -1213,123 +565,38 @@ pub(crate) mod module { env_strings.push(make_widestring(&env_str)?); } - let envp: Vec<*const u16> = env_strings - .iter() - .map(|s| s.as_ptr()) - .chain(once(core::ptr::null())) - .collect(); - - if (unsafe { suppress_iph!(_wexecve(path.as_ptr(), argv_execve.as_ptr(), envp.as_ptr())) } - == -1) - { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + let argv_refs: Vec<&widestring::WideCStr> = argv.iter().map(|s| s.as_ref()).collect(); + let envp_refs: Vec<&widestring::WideCStr> = + env_strings.iter().map(|s| s.as_ref()).collect(); + host_nt::execve(&path, &argv_refs, &envp_refs).map_err(|_| vm.new_last_errno_error()) } #[pyfunction] fn _getfinalpathname(path: OsPath, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_FLAG_BACKUP_SEMANTICS, GetFinalPathNameByHandleW, OPEN_EXISTING, - VOLUME_NAME_DOS, - }; - - let wide = path.to_wide_cstring(vm)?; - let handle = unsafe { - CreateFileW( - wide.as_ptr(), - 0, - 0, - core::ptr::null(), - OPEN_EXISTING, - FILE_FLAG_BACKUP_SEMANTICS, - core::ptr::null_mut(), - ) - }; - if handle == INVALID_HANDLE_VALUE { 
- let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - - let mut buffer: Vec = vec![0; Foundation::MAX_PATH as usize]; - let result = loop { - let ret = unsafe { - GetFinalPathNameByHandleW( - handle, - buffer.as_mut_ptr(), - buffer.len() as u32, - VOLUME_NAME_DOS, - ) - }; - if ret == 0 { - let err = io::Error::last_os_error(); - let _ = unsafe { Foundation::CloseHandle(handle) }; - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - if (ret as usize) < buffer.len() { - let final_path = std::ffi::OsString::from_wide(&buffer[..ret as usize]); - break Ok(path.mode().process_path(final_path, vm)); - } - buffer.resize(ret as usize, 0); - }; - - unsafe { Foundation::CloseHandle(handle) }; - result + let _ = path.to_wide_cstring(vm)?; + let final_path = host_nt::getfinalpathname(path.as_ref()) + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; + Ok(path.mode().process_path(final_path, vm)) } #[pyfunction] fn _getfullpathname(path: OsPath, vm: &VirtualMachine) -> PyResult { - let wpath = path.to_wide_cstring(vm)?; - let mut buffer = vec![0u16; Foundation::MAX_PATH as usize]; - let ret = unsafe { - FileSystem::GetFullPathNameW( - wpath.as_ptr(), - buffer.len() as _, - buffer.as_mut_ptr(), - core::ptr::null_mut(), - ) - }; - if ret == 0 { - let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - if ret as usize > buffer.len() { - buffer.resize(ret as usize, 0); - let ret = unsafe { - FileSystem::GetFullPathNameW( - wpath.as_ptr(), - buffer.len() as _, - buffer.as_mut_ptr(), - core::ptr::null_mut(), - ) - }; - if ret == 0 { - let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - } - let buffer = widestring::WideCString::from_vec_truncate(buffer); - Ok(path.mode().process_path(buffer.to_os_string(), vm)) + let _ = path.to_wide_cstring(vm)?; + let buffer = 
host_nt::getfullpathname(path.as_ref()) + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; + Ok(path.mode().process_path(buffer, vm)) } #[pyfunction] fn _getvolumepathname(path: OsPath, vm: &VirtualMachine) -> PyResult { let wide = path.to_wide_cstring(vm)?; - let buflen = core::cmp::max(wide.len(), Foundation::MAX_PATH as usize); + let buflen = core::cmp::max(wide.len(), host_nt::MAX_PATH_USIZE); if buflen > u32::MAX as usize { return Err(vm.new_overflow_error("path too long")); } - let mut buffer = vec![0u16; buflen]; - let ret = unsafe { - FileSystem::GetVolumePathNameW(wide.as_ptr(), buffer.as_mut_ptr(), buflen as _) - }; - if ret == 0 { - let err = io::Error::last_os_error(); - return Err(OSErrorBuilder::with_filename(&err, path, vm)); - } - let buffer = widestring::WideCString::from_vec_truncate(buffer); - Ok(path.mode().process_path(buffer.to_os_string(), vm)) + let buffer = host_nt::getvolumepathname(path.as_ref()) + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; + Ok(path.mode().process_path(buffer, vm)) } /// Implements _Py_skiproot logic for Windows paths @@ -1488,15 +755,9 @@ pub(crate) mod module { .chain(core::iter::once(0)) // null-terminated .collect(); - let mut end: *const u16 = core::ptr::null(); - let hr = unsafe { - windows_sys::Win32::UI::Shell::PathCchSkipRoot(backslashed.as_ptr(), &mut end) - }; - if hr >= 0 { - assert!(!end.is_null()); - let len: usize = unsafe { end.offset_from(backslashed.as_ptr()) } - .try_into() - .expect("len must be non-negative"); + let backslashed_wide = widestring::WideCStr::from_slice_truncate(&backslashed) + .expect("backslashed is null-terminated"); + if let Some(len) = host_nt::path_skip_root(backslashed_wide) { assert!( len < backslashed.len(), // backslashed is null-terminated "path: {:?} {} < {}", @@ -1684,43 +945,13 @@ pub(crate) mod module { #[pyfunction] fn _getdiskusage(path: OsPath, vm: &VirtualMachine) -> PyResult<(u64, u64)> { - use 
FileSystem::GetDiskFreeSpaceExW; - - let wpath = path.to_wide_cstring(vm)?; - let mut _free_to_me: u64 = 0; - let mut total: u64 = 0; - let mut free: u64 = 0; - let ret = - unsafe { GetDiskFreeSpaceExW(wpath.as_ptr(), &mut _free_to_me, &mut total, &mut free) }; - if ret != 0 { - return Ok((total, free)); - } - let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(Foundation::ERROR_DIRECTORY as i32) - && let Some(parent) = path.as_ref().parent() - { - let parent = widestring::WideCString::from_os_str(parent).unwrap(); - - let ret = unsafe { - GetDiskFreeSpaceExW(parent.as_ptr(), &mut _free_to_me, &mut total, &mut free) - }; - - return if ret == 0 { - Err(err.to_pyexception(vm)) - } else { - Ok((total, free)) - }; - } - Err(err.to_pyexception(vm)) + let _ = path.to_wide_cstring(vm)?; + host_nt::getdiskusage(path.as_ref()).map_err(|err| err.to_pyexception(vm)) } #[pyfunction] fn get_handle_inheritable(handle: intptr_t, vm: &VirtualMachine) -> PyResult { - let mut flags = 0; - if unsafe { Foundation::GetHandleInformation(handle as _, &mut flags) } == 0 { - return Err(vm.new_last_os_error()); - } - Ok(flags & Foundation::HANDLE_FLAG_INHERIT != 0) + host_nt::get_handle_inheritable(handle).map_err(|err| err.to_pyexception(vm)) } #[pyfunction] @@ -1732,139 +963,41 @@ pub(crate) mod module { #[pyfunction] fn getlogin(vm: &VirtualMachine) -> PyResult { - let mut buffer = [0u16; 257]; - let mut size = buffer.len() as u32; - - let success = unsafe { - windows_sys::Win32::System::WindowsProgramming::GetUserNameW( - buffer.as_mut_ptr(), - &mut size, - ) - }; - - if success != 0 { - // Convert the buffer (which is UTF-16) to a Rust String - let username = std::ffi::OsString::from_wide(&buffer[..(size - 1) as usize]); - Ok(username.to_str().unwrap().to_string()) - } else { - Err(vm.new_os_error(format!("Error code: {success}"))) - } + host_nt::getlogin().map_err(|_| vm.new_os_error("Error code: 0".to_owned())) } pub fn raw_set_handle_inheritable(handle: intptr_t, 
inheritable: bool) -> std::io::Result<()> { - let flags = if inheritable { - Foundation::HANDLE_FLAG_INHERIT - } else { - 0 - }; - let res = unsafe { - Foundation::SetHandleInformation(handle as _, Foundation::HANDLE_FLAG_INHERIT, flags) - }; - if res == 0 { - Err(std::io::Error::last_os_error()) - } else { - Ok(()) - } + host_nt::set_handle_inheritable(handle, inheritable) } #[pyfunction] fn listdrives(vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::ERROR_MORE_DATA; - - let mut buffer = [0u16; 256]; - let len = - unsafe { FileSystem::GetLogicalDriveStringsW(buffer.len() as _, buffer.as_mut_ptr()) }; - if len == 0 { - return Err(vm.new_last_os_error()); - } - if len as usize >= buffer.len() { - return Err(std::io::Error::from_raw_os_error(ERROR_MORE_DATA as _).to_pyexception(vm)); - } - let drives: Vec<_> = buffer[..(len - 1) as usize] - .split(|&c| c == 0) - .map(|drive| vm.new_pyobj(String::from_utf16_lossy(drive))) + let drives: Vec<_> = host_nt::listdrives() + .map_err(|err| err.to_pyexception(vm))? 
+ .into_iter() + .map(|drive| vm.new_pyobj(drive.to_string_lossy().into_owned())) .collect(); Ok(vm.ctx.new_list(drives)) } #[pyfunction] fn listvolumes(vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::ERROR_NO_MORE_FILES; - - let mut result = Vec::new(); - let mut buffer = [0u16; Foundation::MAX_PATH as usize + 1]; - - let find = unsafe { FileSystem::FindFirstVolumeW(buffer.as_mut_ptr(), buffer.len() as _) }; - if find == INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); - } - - loop { - // Find the null terminator - let len = buffer.iter().position(|&c| c == 0).unwrap_or(buffer.len()); - let volume = String::from_utf16_lossy(&buffer[..len]); - result.push(vm.new_pyobj(volume)); - - let ret = unsafe { - FileSystem::FindNextVolumeW(find, buffer.as_mut_ptr(), buffer.len() as _) - }; - if ret == 0 { - let err = io::Error::last_os_error(); - unsafe { FileSystem::FindVolumeClose(find) }; - if err.raw_os_error() == Some(ERROR_NO_MORE_FILES as i32) { - break; - } - return Err(err.to_pyexception(vm)); - } - } - + let result = host_nt::listvolumes() + .map_err(|err| err.to_pyexception(vm))? 
+ .into_iter() + .map(|volume| vm.new_pyobj(volume.to_string_lossy().into_owned())) + .collect(); Ok(vm.ctx.new_list(result)) } #[pyfunction] fn listmounts(volume: OsPath, vm: &VirtualMachine) -> PyResult { - use windows_sys::Win32::Foundation::ERROR_MORE_DATA; - - let wide = volume.to_wide_cstring(vm)?; - let mut buflen: u32 = Foundation::MAX_PATH + 1; - let mut buffer: Vec = vec![0; buflen as usize]; - - loop { - let success = unsafe { - FileSystem::GetVolumePathNamesForVolumeNameW( - wide.as_ptr(), - buffer.as_mut_ptr(), - buflen, - &mut buflen, - ) - }; - if success != 0 { - break; - } - let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(ERROR_MORE_DATA as i32) { - buffer.resize(buflen as usize, 0); - continue; - } - return Err(err.to_pyexception(vm)); - } - - // Parse null-separated strings - let mut result = Vec::new(); - let mut start = 0; - for (i, &c) in buffer.iter().enumerate() { - if c == 0 { - if i > start { - let mount = String::from_utf16_lossy(&buffer[start..i]); - result.push(vm.new_pyobj(mount)); - } - start = i + 1; - if start < buffer.len() && buffer[start] == 0 { - break; // Double null = end - } - } - } - + let _ = volume.to_wide_cstring(vm)?; + let result = host_nt::listmounts(volume.as_ref()) + .map_err(|err| err.to_pyexception(vm))? 
+ .into_iter() + .map(|mount| vm.new_pyobj(mount.to_string_lossy().into_owned())) + .collect(); Ok(vm.ctx.new_list(result)) } @@ -1889,182 +1022,29 @@ pub(crate) mod module { #[pyfunction] fn mkdir(args: MkdirArgs<'_>, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Foundation::LocalFree; - use windows_sys::Win32::Security::Authorization::{ - ConvertStringSecurityDescriptorToSecurityDescriptorW, SDDL_REVISION_1, - }; - use windows_sys::Win32::Security::SECURITY_ATTRIBUTES; - let [] = args.dir_fd.0; let wide = args.path.to_wide_cstring(vm)?; - - // special case: mode 0o700 sets a protected ACL - let res = if args.mode == 0o700 { - let mut sec_attr = SECURITY_ATTRIBUTES { - nLength: core::mem::size_of::() as u32, - lpSecurityDescriptor: core::ptr::null_mut(), - bInheritHandle: 0, - }; - // Set a discretionary ACL (D) that is protected (P) and includes - // inheritable (OICI) entries that allow (A) full control (FA) to - // SYSTEM (SY), Administrators (BA), and the owner (OW). 
- let sddl: Vec = "D:P(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)(A;OICI;FA;;;OW)\0" - .encode_utf16() - .collect(); - let convert_result = unsafe { - ConvertStringSecurityDescriptorToSecurityDescriptorW( - sddl.as_ptr(), - SDDL_REVISION_1, - &mut sec_attr.lpSecurityDescriptor, - core::ptr::null_mut(), - ) - }; - if convert_result == 0 { - return Err(vm.new_last_os_error()); - } - let res = - unsafe { FileSystem::CreateDirectoryW(wide.as_ptr(), &sec_attr as *const _ as _) }; - unsafe { LocalFree(sec_attr.lpSecurityDescriptor) }; - res - } else { - unsafe { FileSystem::CreateDirectoryW(wide.as_ptr(), core::ptr::null_mut()) } - }; - - if res == 0 { - return Err(vm.new_last_os_error()); - } - Ok(()) - } - - unsafe extern "C" { - fn _umask(mask: i32) -> i32; - } - - /// Close fd and convert error to PyException (PEP 446 cleanup) - #[cold] - fn close_fd_and_raise(fd: i32, err: std::io::Error, vm: &VirtualMachine) -> PyBaseExceptionRef { - let _ = unsafe { crt_fd::Owned::from_raw(fd) }; - err.to_pyexception(vm) + host_nt::mkdir(&wide, args.mode).map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn umask(mask: i32, vm: &VirtualMachine) -> PyResult { - let result = unsafe { _umask(mask) }; - if result < 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(result) - } + host_nt::umask(mask).map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn pipe(vm: &VirtualMachine) -> PyResult<(i32, i32)> { - use windows_sys::Win32::Security::SECURITY_ATTRIBUTES; - use windows_sys::Win32::System::Pipes::CreatePipe; - - let mut attr = SECURITY_ATTRIBUTES { - nLength: core::mem::size_of::() as u32, - lpSecurityDescriptor: core::ptr::null_mut(), - bInheritHandle: 0, - }; - - let (read_handle, write_handle) = unsafe { - let mut read = MaybeUninit::::uninit(); - let mut write = MaybeUninit::::uninit(); - let res = CreatePipe( - read.as_mut_ptr() as *mut _, - write.as_mut_ptr() as *mut _, - &mut attr as *mut _, - 0, - ); - if res == 0 { - return Err(vm.new_last_os_error()); - } - (read.assume_init(), 
write.assume_init()) - }; - - // Convert handles to file descriptors - // O_NOINHERIT = 0x80 (MSVC CRT) - const O_NOINHERIT: i32 = 0x80; - let read_fd = unsafe { libc::open_osfhandle(read_handle, O_NOINHERIT) }; - let write_fd = unsafe { libc::open_osfhandle(write_handle, libc::O_WRONLY | O_NOINHERIT) }; - - if read_fd == -1 || write_fd == -1 { - unsafe { - Foundation::CloseHandle(read_handle as _); - Foundation::CloseHandle(write_handle as _); - } - return Err(vm.new_last_os_error()); - } - - Ok((read_fd, write_fd)) + host_nt::pipe().map_err(|e| e.to_pyexception(vm)) } #[pyfunction] fn getppid() -> u32 { - use windows_sys::Win32::System::Threading::{GetCurrentProcess, PROCESS_BASIC_INFORMATION}; - - type NtQueryInformationProcessFn = unsafe extern "system" fn( - process_handle: isize, - process_information_class: u32, - process_information: *mut core::ffi::c_void, - process_information_length: u32, - return_length: *mut u32, - ) -> i32; - - let ntdll = unsafe { - windows_sys::Win32::System::LibraryLoader::GetModuleHandleW(windows_sys::w!( - "ntdll.dll" - )) - }; - if ntdll.is_null() { - return 0; - } - - let func = unsafe { - windows_sys::Win32::System::LibraryLoader::GetProcAddress( - ntdll, - c"NtQueryInformationProcess".as_ptr() as *const u8, - ) - }; - let Some(func) = func else { - return 0; - }; - let nt_query: NtQueryInformationProcessFn = unsafe { core::mem::transmute(func) }; - - let mut info: PROCESS_BASIC_INFORMATION = unsafe { core::mem::zeroed() }; - - let status = unsafe { - nt_query( - GetCurrentProcess() as isize, - 0, // ProcessBasicInformation - &mut info as *mut _ as *mut core::ffi::c_void, - core::mem::size_of::() as u32, - core::ptr::null_mut(), - ) - }; - - if status >= 0 - && info.InheritedFromUniqueProcessId != 0 - && info.InheritedFromUniqueProcessId < u32::MAX as usize - { - info.InheritedFromUniqueProcessId as u32 - } else { - 0 - } + host_nt::getppid() } #[pyfunction] fn dup(fd: i32, vm: &VirtualMachine) -> PyResult { - let fd2 = unsafe 
{ suppress_iph!(libc::dup(fd)) }; - if fd2 < 0 { - return Err(vm.new_last_errno_error()); - } - let borrowed = unsafe { crt_fd::Borrowed::borrow_raw(fd2) }; - let handle = crt_fd::as_handle(borrowed).map_err(|e| close_fd_and_raise(fd2, e, vm))?; - raw_set_handle_inheritable(handle.as_raw_handle() as _, false) - .map_err(|e| close_fd_and_raise(fd2, e, vm))?; - Ok(fd2) + host_nt::dup(fd).map_err(|e| e.to_pyexception(vm)) } #[derive(FromArgs)] @@ -2079,152 +1059,21 @@ pub(crate) mod module { #[pyfunction] fn dup2(args: Dup2Args, vm: &VirtualMachine) -> PyResult { - let result = unsafe { suppress_iph!(libc::dup2(args.fd, args.fd2)) }; - if result < 0 { - return Err(vm.new_last_errno_error()); - } - if !args.inheritable { - let borrowed = unsafe { crt_fd::Borrowed::borrow_raw(args.fd2) }; - let handle = - crt_fd::as_handle(borrowed).map_err(|e| close_fd_and_raise(args.fd2, e, vm))?; - raw_set_handle_inheritable(handle.as_raw_handle() as _, false) - .map_err(|e| close_fd_and_raise(args.fd2, e, vm))?; - } - Ok(args.fd2) + host_nt::dup2(args.fd, args.fd2, args.inheritable).map_err(|e| e.to_pyexception(vm)) } /// Windows-specific readlink that preserves \\?\ prefix for junctions /// returns the substitute name from reparse data which includes the prefix #[pyfunction] fn readlink(path: OsPath, vm: &VirtualMachine) -> PyResult { - use crate::host_env::windows::ToWideString; - use windows_sys::Win32::Foundation::CloseHandle; - use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_FLAG_BACKUP_SEMANTICS, FILE_FLAG_OPEN_REPARSE_POINT, - FILE_SHARE_DELETE, FILE_SHARE_READ, FILE_SHARE_WRITE, OPEN_EXISTING, - }; - use windows_sys::Win32::System::IO::DeviceIoControl; - use windows_sys::Win32::System::Ioctl::FSCTL_GET_REPARSE_POINT; - let mode = path.mode(); - let wide_path = path.as_ref().to_wide_with_nul(); - - // Open the file/directory with reparse point flag - let handle = unsafe { - CreateFileW( - wide_path.as_ptr(), - 0, // No access needed, just reading reparse 
data - FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - core::ptr::null(), - OPEN_EXISTING, - FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, - core::ptr::null_mut(), - ) - }; - - if handle == INVALID_HANDLE_VALUE { - return Err(OSErrorBuilder::with_filename( - &io::Error::last_os_error(), - path, - vm, - )); - } - - // Buffer for reparse data - MAXIMUM_REPARSE_DATA_BUFFER_SIZE is 16384 - const BUFFER_SIZE: usize = 16384; - let mut buffer = vec![0u8; BUFFER_SIZE]; - let mut bytes_returned: u32 = 0; - - let result = unsafe { - DeviceIoControl( - handle, - FSCTL_GET_REPARSE_POINT, - core::ptr::null(), - 0, - buffer.as_mut_ptr() as *mut _, - BUFFER_SIZE as u32, - &mut bytes_returned, - core::ptr::null_mut(), - ) - }; - - unsafe { CloseHandle(handle) }; - - if result == 0 { - return Err(OSErrorBuilder::with_filename( - &io::Error::last_os_error(), - path, - vm, - )); - } - - // Parse the reparse data buffer - // REPARSE_DATA_BUFFER structure: - // DWORD ReparseTag - // WORD ReparseDataLength - // WORD Reserved - // For symlinks/junctions (IO_REPARSE_TAG_SYMLINK/MOUNT_POINT): - // WORD SubstituteNameOffset - // WORD SubstituteNameLength - // WORD PrintNameOffset - // WORD PrintNameLength - // (For symlinks only: DWORD Flags) - // PathBuffer... 
- - let reparse_tag = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); - - // Check if it's a symlink or mount point (junction) - use windows_sys::Win32::System::SystemServices::{ - IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, - }; - - let (substitute_offset, substitute_length, path_buffer_start) = - if reparse_tag == IO_REPARSE_TAG_SYMLINK { - // Symlink has Flags field (4 bytes) before PathBuffer - let sub_offset = u16::from_le_bytes([buffer[8], buffer[9]]) as usize; - let sub_length = u16::from_le_bytes([buffer[10], buffer[11]]) as usize; - // PathBuffer starts at offset 20 (after Flags at offset 16) - (sub_offset, sub_length, 20usize) - } else if reparse_tag == IO_REPARSE_TAG_MOUNT_POINT { - // Mount point (junction) has no Flags field - let sub_offset = u16::from_le_bytes([buffer[8], buffer[9]]) as usize; - let sub_length = u16::from_le_bytes([buffer[10], buffer[11]]) as usize; - // PathBuffer starts at offset 16 - (sub_offset, sub_length, 16usize) - } else { - return Err(vm.new_value_error("not a symbolic link")); - }; - - // Extract the substitute name - let path_start = path_buffer_start + substitute_offset; - let path_end = path_start + substitute_length; - - if path_end > buffer.len() { - return Err(vm.new_os_error("Invalid reparse data".to_owned())); - } - - // Convert from UTF-16LE - let path_slice = &buffer[path_start..path_end]; - let wide_chars: Vec = path_slice - .chunks_exact(2) - .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) - .collect(); - - let mut wide_chars = wide_chars; - // For mount points (junctions), the substitute name typically starts with \??\ - // Convert this to \\?\ - if wide_chars.len() > 4 - && wide_chars[0] == b'\\' as u16 - && wide_chars[1] == b'?' as u16 - && wide_chars[2] == b'?' 
as u16 - && wide_chars[3] == b'\\' as u16 - { - wide_chars[1] = b'\\' as u16; + match host_nt::readlink(path.as_ref()) { + Ok(result_path) => Ok(mode.process_path(std::path::PathBuf::from(result_path), vm)), + Err(host_nt::ReadlinkError::Io(err)) => { + Err(OSErrorBuilder::with_filename(&err, path.clone(), vm)) + } + Err(err) => Err(err.to_pyexception(vm)), } - - let result_path = std::ffi::OsString::from_wide(&wide_chars); - - Ok(mode.process_path(std::path::PathBuf::from(result_path), vm)) } pub(crate) fn support_funcs() -> Vec { diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index ea9918b91c5..917d9a71978 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -1,4 +1,5 @@ // spell-checker:disable +#![allow(unreachable_pub)] use crate::{ AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, @@ -155,8 +156,6 @@ impl ToPyObject for crt_fd::Borrowed<'_> { pub(super) mod _os { use super::{DirFd, FollowSymlinks, SupportFunc}; use crate::host_env::fileutils::StatStruct; - #[cfg(windows)] - use crate::host_env::windows::ToWideString; #[cfg(any(unix, windows))] use crate::utils::ToCString; use crate::{ @@ -178,7 +177,10 @@ pub(super) mod _os { use core::time::Duration; use crossbeam_utils::atomic::AtomicCell; use rustpython_common::wtf8::Wtf8Buf; - use rustpython_host_env::suppress_iph; + #[cfg(windows)] + use rustpython_host_env::nt as host_nt; + #[cfg(all(any(unix, target_os = "wasi"), not(target_os = "redox")))] + use rustpython_host_env::posix as host_posix; use std::{fs, io, path::PathBuf, time::SystemTime}; const OPEN_DIR_FD: bool = cfg!(not(any(windows, target_os = "redox"))); @@ -347,9 +349,9 @@ pub(super) mod _os { let c_path = path.clone().into_cstring(vm)?; #[cfg(not(target_os = "redox"))] if let Some(fd) = dir_fd.raw_opt() { - let res = unsafe { libc::mkdirat(fd, c_path.as_ptr(), mode as _) }; - return if res < 0 { - let err = crate::host_env::os::errno_io_error(); + return if let Err(err) = 
+ crate::host_env::posix::make_dir_at(fd, c_path.as_c_str(), mode as u32) + { Err(OSErrorBuilder::with_filename(&err, path, vm)) } else { Ok(()) @@ -357,9 +359,7 @@ pub(super) mod _os { } #[cfg(target_os = "redox")] let [] = dir_fd.0; - let res = unsafe { libc::mkdir(c_path.as_ptr(), mode as _) }; - if res < 0 { - let err = crate::host_env::os::errno_io_error(); + if let Err(err) = crate::host_env::posix::make_dir(c_path.as_c_str(), mode as u32) { return Err(OSErrorBuilder::with_filename(&err, path, vm)); } Ok(()) @@ -381,9 +381,7 @@ pub(super) mod _os { #[cfg(not(target_os = "redox"))] if let Some(fd) = dir_fd.raw_opt() { let c_path = path.clone().into_cstring(vm)?; - let res = unsafe { libc::unlinkat(fd, c_path.as_ptr(), libc::AT_REMOVEDIR) }; - return if res < 0 { - let err = crate::host_env::os::errno_io_error(); + return if let Err(err) = crate::host_env::posix::remove_dir_at(fd, c_path.as_c_str()) { Err(OSErrorBuilder::with_filename(&err, path, vm)) } else { Ok(()) @@ -439,35 +437,17 @@ pub(super) mod _os { } #[cfg(all(unix, not(target_os = "redox")))] { - use rustpython_host_env::os::ffi::OsStrExt; - use std::os::unix::io::IntoRawFd; - let new_fd = nix::unistd::dup(fno).map_err(|e| e.into_pyexception(vm))?; - let raw_fd = new_fd.into_raw_fd(); - let dir = OwnedDir::from_fd(raw_fd).map_err(|e| { - unsafe { libc::close(raw_fd) }; - e.into_pyexception(vm) - })?; - // OwnedDir::drop calls rewinddir (reset to start) then closedir. + let mut dir = host_posix::FdDirStream::from_fd(fno.into()) + .map_err(|e| e.into_pyexception(vm))?; let mut list = Vec::new(); - loop { - nix::errno::Errno::clear(); - let entry = unsafe { libc::readdir(dir.as_ptr()) }; - if entry.is_null() { - let err = nix::errno::Errno::last(); - if err != nix::errno::Errno::UnknownErrno { - return Err(io::Error::from(err).into_pyexception(vm)); - } - break; - } - let fname = unsafe { core::ffi::CStr::from_ptr((*entry).d_name.as_ptr()) } - .to_bytes(); - match fname { - b"." | b".." 
=> continue, - _ => list.push( - OutputMode::String - .process_path(std::ffi::OsStr::from_bytes(fname), vm), + while let Some(entry) = dir.next_entry().map_err(|e| e.into_pyexception(vm))? { + list.push( + OutputMode::String.process_path( + rustpython_host_env::os::bytes_as_os_str(&entry.name) + .expect("unix dir entry names are arbitrary bytes"), + vm, ), - } + ); } list } @@ -484,11 +464,6 @@ pub(super) mod _os { } } - #[cfg(windows)] - unsafe extern "C" { - fn _wputenv(envstring: *const u16) -> libc::c_int; - } - /// Check if wide string length exceeds Windows environment variable limit. #[cfg(windows)] fn check_env_var_len(wide_len: usize, vm: &VirtualMachine) -> PyResult<()> { @@ -516,15 +491,13 @@ pub(super) mod _os { return Err(vm.new_value_error("illegal environment variable name")); } let env_str = format!("{key_str}={value_str}"); - let wide = env_str.to_wide_with_nul(); - check_env_var_len(wide.len(), vm)?; + // env_str is guaranteed nul-free by the checks above. + let wide = widestring::WideCString::from_str(&env_str) + .expect("env_str validated to contain no NUL"); + check_env_var_len(wide.len() + 1, vm)?; - // Use _wputenv like CPython (not SetEnvironmentVariableW) to update CRT environ - let result = unsafe { suppress_iph!(_wputenv(wide.as_ptr())) }; - if result != 0 { - return Err(vm.new_last_errno_error()); - } - Ok(()) + // Use _wputenv (not SetEnvironmentVariableW) to update CRT environ. + rustpython_host_env::nt::wputenv(&wide).map_err(|e| e.into_pyexception(vm)) } #[cfg(not(windows))] @@ -563,15 +536,13 @@ pub(super) mod _os { } // "key=" to unset (empty value removes the variable) let env_str = format!("{key_str}="); - let wide = env_str.to_wide_with_nul(); - check_env_var_len(wide.len(), vm)?; + // env_str is guaranteed nul-free by the checks above. 
+ let wide = widestring::WideCString::from_str(&env_str) + .expect("env_str validated to contain no NUL"); + check_env_var_len(wide.len() + 1, vm)?; - // Use _wputenv like CPython (not SetEnvironmentVariableW) to update CRT environ - let result = unsafe { suppress_iph!(_wputenv(wide.as_ptr())) }; - if result != 0 { - return Err(vm.new_last_errno_error()); - } - Ok(()) + // Use _wputenv (not SetEnvironmentVariableW) to update CRT environ. + rustpython_host_env::nt::wputenv(&wide).map_err(|e| e.into_pyexception(vm)) } #[cfg(not(windows))] @@ -671,7 +642,10 @@ pub(super) mod _os { match self.stat(self.stat_dir_fd(), FollowSymlinks(follow_symlinks), vm) { Ok(stat_obj) => { let st_mode: i32 = stat_obj.get_attr("st_mode", vm)?.try_into_value(vm)?; - #[allow(clippy::unnecessary_cast)] + #[allow( + clippy::unnecessary_cast, + reason = "'st_mode' and 'S_IFMT' are not u32 on all platforms" + )] Ok((st_mode as u32 & libc::S_IFMT as u32) == mode_bits) } Err(e) => { @@ -856,7 +830,7 @@ pub(super) mod _os { #[cfg(windows)] #[pymethod] fn is_junction(&self, _vm: &VirtualMachine) -> bool { - junction::exists(self.pathval.clone()).unwrap_or(false) + host_nt::test_file_type_by_name(&self.pathval, host_nt::TestType::Junction) } #[pymethod] @@ -991,7 +965,7 @@ pub(super) mod _os { let lstat = { let cell = OnceCell::new(); if let Ok(stat_struct) = - crate::windows::win32_xstat(pathval.as_os_str(), false) + host_nt::win32_xstat(pathval.as_os_str(), false) { let stat_obj = StatResultData::from_stat(&stat_struct, vm).to_pyobject(vm); @@ -1031,54 +1005,12 @@ pub(super) mod _os { } } - /// Wrapper around a raw `libc::DIR*` for fd-based scandir. 
- #[cfg(all(unix, not(target_os = "redox")))] - struct OwnedDir(core::ptr::NonNull); - - #[cfg(all(unix, not(target_os = "redox")))] - impl OwnedDir { - fn from_fd(fd: crt_fd::Raw) -> io::Result { - let ptr = unsafe { libc::fdopendir(fd) }; - core::ptr::NonNull::new(ptr) - .map(OwnedDir) - .ok_or_else(io::Error::last_os_error) - } - - fn as_ptr(&self) -> *mut libc::DIR { - self.0.as_ptr() - } - } - - #[cfg(all(unix, not(target_os = "redox")))] - impl Drop for OwnedDir { - fn drop(&mut self) { - unsafe { - libc::rewinddir(self.0.as_ptr()); - libc::closedir(self.0.as_ptr()); - } - } - } - - #[cfg(all(unix, not(target_os = "redox")))] - impl core::fmt::Debug for OwnedDir { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_tuple("OwnedDir").field(&self.0).finish() - } - } - - // Safety: OwnedDir wraps a *mut libc::DIR. All access is synchronized - // through the PyMutex in ScandirIteratorFd. - #[cfg(all(unix, not(target_os = "redox")))] - unsafe impl Send for OwnedDir {} - #[cfg(all(unix, not(target_os = "redox")))] - unsafe impl Sync for OwnedDir {} - #[cfg(all(unix, not(target_os = "redox")))] #[pyattr] #[pyclass(name = "ScandirIter")] #[derive(Debug, PyPayload)] struct ScandirIteratorFd { - dir: crate::common::lock::PyMutex>, + dir: crate::common::lock::PyMutex>, /// The original fd passed to scandir(), stored in DirEntry for fstatat orig_fd: crt_fd::Raw, } @@ -1129,59 +1061,37 @@ pub(super) mod _os { #[cfg(all(unix, not(target_os = "redox")))] impl IterNext for ScandirIteratorFd { fn next(zelf: &crate::Py, vm: &VirtualMachine) -> PyResult { - use rustpython_host_env::os::ffi::OsStrExt; let mut guard = zelf.dir.lock(); let dir = match guard.as_mut() { None => return Ok(PyIterReturn::StopIteration(None)), Some(dir) => dir, }; - loop { - nix::errno::Errno::clear(); - let entry = unsafe { - let ptr = libc::readdir(dir.as_ptr()); - if ptr.is_null() { - let err = nix::errno::Errno::last(); - if err != nix::errno::Errno::UnknownErrno { - 
return Err(io::Error::from(err).into_pyexception(vm)); - } - drop(guard.take()); - return Ok(PyIterReturn::StopIteration(None)); - } - &*ptr - }; - let fname = unsafe { core::ffi::CStr::from_ptr(entry.d_name.as_ptr()) }.to_bytes(); - if fname == b"." || fname == b".." { - continue; + let Some(entry) = dir.next_entry().map_err(|e| e.into_pyexception(vm))? else { + drop(guard.take()); + return Ok(PyIterReturn::StopIteration(None)); + }; + let file_name = std::ffi::OsString::from( + rustpython_host_env::os::bytes_as_os_str(&entry.name) + .expect("unix dir entry names are arbitrary bytes"), + ); + let pathval = PathBuf::from(&file_name); + Ok(PyIterReturn::Return( + DirEntry { + file_name, + pathval, + file_type: Err(io::Error::other( + "file_type unavailable for fd-based scandir", + )), + d_type: entry.d_type, + dir_fd: Some(zelf.orig_fd), + mode: OutputMode::String, + lstat: OnceCell::new(), + stat: OnceCell::new(), + ino: AtomicCell::new(entry.ino as _), } - let file_name = std::ffi::OsString::from(std::ffi::OsStr::from_bytes(fname)); - let pathval = PathBuf::from(&file_name); - #[cfg(target_os = "freebsd")] - let ino = entry.d_fileno; - #[cfg(not(target_os = "freebsd"))] - let ino = entry.d_ino; - let d_type = entry.d_type; - return Ok(PyIterReturn::Return( - DirEntry { - file_name, - pathval, - file_type: Err(io::Error::other( - "file_type unavailable for fd-based scandir", - )), - d_type: if d_type == libc::DT_UNKNOWN { - None - } else { - Some(d_type) - }, - dir_fd: Some(zelf.orig_fd), - mode: OutputMode::String, - lstat: OnceCell::new(), - stat: OnceCell::new(), - ino: AtomicCell::new(ino as _), - } - .into_ref(&vm.ctx) - .into(), - )); - } + .into_ref(&vm.ctx) + .into(), + )) } } @@ -1209,15 +1119,8 @@ pub(super) mod _os { } #[cfg(all(unix, not(target_os = "redox")))] { - use std::os::unix::io::IntoRawFd; - // closedir() closes the fd, so duplicate it first - let new_fd = nix::unistd::dup(fno).map_err(|e| e.into_pyexception(vm))?; - let raw_fd = 
new_fd.into_raw_fd(); - let dir = OwnedDir::from_fd(raw_fd).map_err(|e| { - // fdopendir failed, close the dup'd fd - unsafe { libc::close(raw_fd) }; - e.into_pyexception(vm) - })?; + let dir = host_posix::FdDirStream::from_fd(fno.into()) + .map_err(|e| e.into_pyexception(vm))?; Ok(ScandirIteratorFd { dir: crate::common::lock::PyMutex::new(Some(dir)), orig_fd: fno.as_raw(), @@ -1333,6 +1236,7 @@ pub(super) mod _os { #[cfg(not(windows))] #[allow(clippy::useless_conversion, reason = "needed for 32-bit platforms")] let st_blksize = i64::from(stat.st_blksize); + #[cfg(not(windows))] #[allow(clippy::useless_conversion, reason = "needed for 32-bit platforms")] let st_blocks = i64::from(stat.st_blocks); @@ -1401,10 +1305,9 @@ pub(super) mod _os { dir_fd: DirFd<'_, { STAT_DIR_FD as usize }>, follow_symlinks: FollowSymlinks, ) -> io::Result> { - // TODO: replicate CPython's win32_xstat let [] = dir_fd.0; match file { - OsPathOrFd::Path(path) => crate::windows::win32_xstat(&path.path, follow_symlinks.0), + OsPathOrFd::Path(path) => host_nt::win32_xstat(&path.path, follow_symlinks.0), OsPathOrFd::Fd(fd) => crate::host_env::fileutils::fstat(fd), } .map(Some) @@ -1416,42 +1319,14 @@ pub(super) mod _os { dir_fd: DirFd<'_, { STAT_DIR_FD as usize }>, follow_symlinks: FollowSymlinks, ) -> io::Result> { - let mut stat = core::mem::MaybeUninit::uninit(); - let ret = match file { - OsPathOrFd::Path(path) => { - use rustpython_host_env::os::ffi::OsStrExt; - let path = path.as_ref().as_os_str().as_bytes(); - let path = match alloc::ffi::CString::new(path) { - Ok(x) => x, - Err(_) => return Ok(None), - }; - - #[cfg(not(target_os = "redox"))] - let fstatat_ret = dir_fd.raw_opt().map(|dir_fd| { - let flags = if follow_symlinks.0 { - 0 - } else { - libc::AT_SYMLINK_NOFOLLOW - }; - unsafe { libc::fstatat(dir_fd, path.as_ptr(), stat.as_mut_ptr(), flags) } - }); - #[cfg(target_os = "redox")] - let ([], fstatat_ret) = (dir_fd.0, None); - - fstatat_ret.unwrap_or_else(|| { - if follow_symlinks.0 
{ - unsafe { libc::stat(path.as_ptr(), stat.as_mut_ptr()) } - } else { - unsafe { libc::lstat(path.as_ptr(), stat.as_mut_ptr()) } - } - }) - } - OsPathOrFd::Fd(fd) => unsafe { libc::fstat(fd.as_raw(), stat.as_mut_ptr()) }, - }; - if ret < 0 { - return Err(io::Error::last_os_error()); + match file { + OsPathOrFd::Path(path) => host_posix::stat_path( + path.as_ref().as_os_str(), + dir_fd.raw_opt(), + follow_symlinks.0, + ), + OsPathOrFd::Fd(fd) => host_posix::stat_fd(fd).map(Some), } - Ok(Some(unsafe { stat.assume_init() })) } #[pyfunction] @@ -1494,40 +1369,7 @@ pub(super) mod _os { #[pyfunction] fn chdir(path: OsPath, vm: &VirtualMachine) -> PyResult<()> { crate::host_env::os::set_current_dir(&path.path) - .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm))?; - - #[cfg(windows)] - { - // win32_wchdir() - - // On Windows, set the per-drive CWD environment variable (=X:) - // This is required for GetFullPathNameW to work correctly with drive-relative paths - - use std::os::windows::ffi::OsStrExt; - use windows_sys::Win32::System::Environment::SetEnvironmentVariableW; - - if let Ok(cwd) = crate::host_env::os::current_dir() { - let cwd_str = cwd.as_os_str(); - let mut cwd_wide: Vec = cwd_str.encode_wide().collect(); - - // Check for UNC-like paths (\\server\share or //server/share) - // wcsncmp(new_path, L"\\\\", 2) == 0 || wcsncmp(new_path, L"//", 2) == 0 - let is_unc_like_path = cwd_wide.len() >= 2 - && ((cwd_wide[0] == b'\\' as u16 && cwd_wide[1] == b'\\' as u16) - || (cwd_wide[0] == b'/' as u16 && cwd_wide[1] == b'/' as u16)); - - if !is_unc_like_path { - // Create env var name "=X:" where X is the drive letter - let env_name: [u16; 4] = [b'=' as u16, cwd_wide[0], b':' as u16, 0]; - cwd_wide.push(0); // null-terminate the path - unsafe { - SetEnvironmentVariableW(env_name.as_ptr(), cwd_wide.as_ptr()); - } - } - } - } - - Ok(()) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } #[pyfunction] @@ -1547,7 +1389,7 @@ pub(super) mod _os { 
.argument("dst") .try_path(dst, vm)?; - fs::rename(&src.path, &dst.path).map_err(|err| { + crate::host_env::os::rename(&src.path, &dst.path).map_err(|err| { let builder = err.to_os_error_builder(vm); let builder = builder.filename(src.filename(vm)); let builder = builder.filename2(dst.filename(vm)); @@ -1570,7 +1412,7 @@ pub(super) mod _os { #[pyfunction] fn cpu_count(vm: &VirtualMachine) -> PyObjectRef { - let cpu_count = num_cpus::get(); + let cpu_count = crate::host_env::os::cpu_count(); vm.ctx.new_int(cpu_count).into() } @@ -1581,10 +1423,7 @@ pub(super) mod _os { #[pyfunction] fn abort() { - unsafe extern "C" { - fn abort(); - } - unsafe { abort() } + crate::host_env::os::abort() } #[pyfunction] @@ -1592,14 +1431,12 @@ pub(super) mod _os { if size < 0 { return Err(vm.new_value_error("negative argument not allowed")); } - let mut buf = vec![0u8; size as usize]; - getrandom::fill(&mut buf).map_err(|e| io::Error::from(e).into_pyexception(vm))?; - Ok(buf) + crate::host_env::os::urandom(size as usize).map_err(|e| e.into_pyexception(vm)) } #[pyfunction] - pub(crate) fn isatty(fd: i32) -> bool { - unsafe { suppress_iph!(libc::isatty(fd)) != 0 } + pub fn isatty(fd: i32) -> bool { + crate::host_env::os::isatty(fd) } #[pyfunction] @@ -1609,32 +1446,7 @@ pub(super) mod _os { how: i32, vm: &VirtualMachine, ) -> PyResult { - #[cfg(not(windows))] - let res = unsafe { suppress_iph!(libc::lseek(fd.as_raw(), position, how)) }; - #[cfg(windows)] - let res = unsafe { - use std::os::windows::io::AsRawHandle; - use windows_sys::Win32::Storage::FileSystem; - let handle = crt_fd::as_handle(fd).map_err(|e| e.into_pyexception(vm))?; - let mut distance_to_move: [i32; 2] = core::mem::transmute(position); - let ret = FileSystem::SetFilePointer( - handle.as_raw_handle(), - distance_to_move[0], - &mut distance_to_move[1], - how as _, - ); - if ret == FileSystem::INVALID_SET_FILE_POINTER { - -1 - } else { - distance_to_move[0] = ret as _; - core::mem::transmute::<[i32; 2], 
i64>(distance_to_move) - } - }; - if res < 0 { - Err(vm.new_last_os_error()) - } else { - Ok(res) - } + crate::host_env::os::seek_fd(fd, position, how).map_err(|e| e.into_pyexception(vm)) } #[derive(FromArgs)] @@ -1664,20 +1476,9 @@ pub(super) mod _os { .map_err(|_| vm.new_value_error("embedded null byte"))?; let follow = follow_symlinks.into_option().unwrap_or(true); - let flags = if follow { libc::AT_SYMLINK_FOLLOW } else { 0 }; - - let ret = unsafe { - libc::linkat( - libc::AT_FDCWD, - src_cstr.as_ptr(), - libc::AT_FDCWD, - dst_cstr.as_ptr(), - flags, - ) - }; - - if ret != 0 { - let err = std::io::Error::last_os_error(); + if let Err(err) = + crate::host_env::posix::link_paths(src_cstr.as_c_str(), dst_cstr.as_c_str(), follow) + { let builder = err.to_os_error_builder(vm); let builder = builder.filename(src.filename(vm)); let builder = builder.filename2(dst.filename(vm)); @@ -1714,7 +1515,7 @@ pub(super) mod _os { #[pyfunction] fn system(command: PyStrRef, vm: &VirtualMachine) -> PyResult { let cstr = command.to_cstring(vm)?; - let x = unsafe { libc::system(cstr.as_ptr()) }; + let x = crate::host_env::os::system(cstr.as_c_str()); Ok(x) } @@ -1798,31 +1599,14 @@ pub(super) mod _os { { let path_for_err = path.clone(); let path = path.into_cstring(vm)?; - - let ts = |d: Duration| libc::timespec { - tv_sec: d.as_secs() as _, - tv_nsec: d.subsec_nanos() as _, - }; - let times = [ts(acc), ts(modif)]; - - let ret = unsafe { - libc::utimensat( - dir_fd.get().as_raw(), - path.as_ptr(), - times.as_ptr(), - if _follow_symlinks.0 { - 0 - } else { - libc::AT_SYMLINK_NOFOLLOW - }, - ) - }; - if ret < 0 { - Err(OSErrorBuilder::with_filename( - &io::Error::last_os_error(), - path_for_err, - vm, - )) + if let Err(err) = crate::host_env::posix::set_file_times_at( + dir_fd.get().as_raw(), + path.as_c_str(), + acc, + modif, + _follow_symlinks.0, + ) { + Err(OSErrorBuilder::with_filename(&err, path_for_err, vm)) } else { Ok(()) } @@ -1830,21 +1614,12 @@ pub(super) mod _os { 
#[cfg(target_os = "redox")] { let [] = dir_fd.0; - - let tv = |d: Duration| libc::timeval { - tv_sec: d.as_secs() as _, - tv_usec: d.as_micros() as _, - }; - nix::sys::stat::utimes(path.as_ref(), &tv(acc).into(), &tv(modif).into()) + rustpython_host_env::posix::utimes(path.as_ref(), acc, modif) .map_err(|err| err.into_pyexception(vm)) } } #[cfg(windows)] { - use std::os::windows::prelude::*; - type DWORD = u32; - use windows_sys::Win32::{Foundation::FILETIME, Storage::FileSystem}; - let [] = dir_fd.0; if !_follow_symlinks.0 { @@ -1853,37 +1628,8 @@ pub(super) mod _os { )); } - let ft = |d: Duration| { - let intervals = ((d.as_secs() as i64 + 11644473600) * 10_000_000) - + (d.subsec_nanos() as i64 / 100); - FILETIME { - dwLowDateTime: intervals as DWORD, - dwHighDateTime: (intervals >> 32) as DWORD, - } - }; - - let acc = ft(acc); - let modif = ft(modif); - - let f = crate::host_env::fs::open_write_with_custom_flags( - &path, - windows_sys::Win32::Storage::FileSystem::FILE_FLAG_BACKUP_SEMANTICS, - ) - .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; - - let ret = unsafe { - FileSystem::SetFileTime(f.as_raw_handle() as _, core::ptr::null(), &acc, &modif) - }; - - if ret == 0 { - Err(OSErrorBuilder::with_filename( - &io::Error::last_os_error(), - path, - vm, - )) - } else { - Ok(()) - } + crate::host_env::os::set_file_times(&path, acc, modif) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } } @@ -1916,32 +1662,12 @@ pub(super) mod _os { fn times(vm: &VirtualMachine) -> PyResult { #[cfg(windows)] { - use core::mem::MaybeUninit; - use windows_sys::Win32::{Foundation::FILETIME, System::Threading}; - - let mut _create = MaybeUninit::::uninit(); - let mut _exit = MaybeUninit::::uninit(); - let mut kernel = MaybeUninit::::uninit(); - let mut user = MaybeUninit::::uninit(); - - unsafe { - let h_proc = Threading::GetCurrentProcess(); - Threading::GetProcessTimes( - h_proc, - _create.as_mut_ptr(), - _exit.as_mut_ptr(), - 
kernel.as_mut_ptr(), - user.as_mut_ptr(), - ); - } - - let kernel = unsafe { kernel.assume_init() }; - let user = unsafe { user.assume_init() }; + let times = crate::host_env::time::get_process_times_100ns() + .ok_or_else(|| vm.new_last_os_error())?; let times_result = TimesResultData { - user: user.dwHighDateTime as f64 * 429.4967296 + user.dwLowDateTime as f64 * 1e-7, - system: kernel.dwHighDateTime as f64 * 429.4967296 - + kernel.dwLowDateTime as f64 * 1e-7, + user: times.user as f64 * 1e-7, + system: times.system as f64 * 1e-7, children_user: 0.0, children_system: 0.0, elapsed: 0.0, @@ -1951,27 +1677,15 @@ pub(super) mod _os { } #[cfg(unix)] { - let mut t = libc::tms { - tms_utime: 0, - tms_stime: 0, - tms_cutime: 0, - tms_cstime: 0, - }; - - let tick_for_second = unsafe { libc::sysconf(libc::_SC_CLK_TCK) } as f64; - let c = unsafe { libc::times(&mut t as *mut _) }; - - // XXX: The signedness of `clock_t` varies from platform to platform. - if c == (-1i8) as libc::clock_t { - return Err(vm.new_os_error("Fail to get times".to_string())); - } + let times = crate::host_env::time::process_times() + .map_err(|_| vm.new_os_error("Fail to get times".to_string()))?; let times_result = TimesResultData { - user: t.tms_utime as f64 / tick_for_second, - system: t.tms_stime as f64 / tick_for_second, - children_user: t.tms_cutime as f64 / tick_for_second, - children_system: t.tms_cstime as f64 / tick_for_second, - elapsed: c as f64 / tick_for_second, + user: times.user, + system: times.system, + children_user: times.children_user, + children_system: times.children_system, + elapsed: times.elapsed, }; Ok(times_result.to_pyobject(vm)) @@ -1995,45 +1709,25 @@ pub(super) mod _os { #[cfg(target_os = "linux")] #[pyfunction] - fn copy_file_range(args: CopyFileRangeArgs<'_>, vm: &VirtualMachine) -> PyResult { - #[allow(clippy::unnecessary_option_map_or_else)] - let p_offset_src = args.offset_src.as_ref().map_or_else(core::ptr::null, |x| x); - 
#[allow(clippy::unnecessary_option_map_or_else)] - let p_offset_dst = args.offset_dst.as_ref().map_or_else(core::ptr::null, |x| x); + fn copy_file_range(mut args: CopyFileRangeArgs<'_>, vm: &VirtualMachine) -> PyResult { let count: usize = args .count .try_into() .map_err(|_| vm.new_value_error("count should >= 0"))?; - // The flags argument is provided to allow - // for future extensions and currently must be to 0. - let flags = 0u32; - - // Safety: p_offset_src and p_offset_dst is a unique pointer for offset_src and offset_dst respectively, - // and will only be freed after this function ends. - // - // Why not use `libc::copy_file_range`: On `musl-libc`, `libc::copy_file_range` is not provided. Therefore - // we use syscalls directly instead. - let ret = unsafe { - libc::syscall( - libc::SYS_copy_file_range, - args.src, - p_offset_src as *mut i64, - args.dst, - p_offset_dst as *mut i64, - count, - flags, - ) - }; - - usize::try_from(ret).map_err(|_| vm.new_last_errno_error()) + crate::host_env::os::copy_file_range( + args.src, + args.offset_src.as_mut(), + args.dst, + args.offset_dst.as_mut(), + count, + ) + .map_err(|_| vm.new_last_errno_error()) } #[pyfunction] fn strerror(e: i32) -> String { - unsafe { core::ffi::CStr::from_ptr(libc::strerror(e)) } - .to_string_lossy() - .into_owned() + crate::host_env::time::strerror(e) } #[pyfunction] @@ -2072,16 +1766,8 @@ pub(super) mod _os { #[cfg(all(unix, not(any(target_os = "redox", target_os = "android"))))] #[pyfunction] fn getloadavg(vm: &VirtualMachine) -> PyResult<(f64, f64, f64)> { - let mut loadavg = [0f64; 3]; - - // Safety: loadavg is on stack and only write by `getloadavg` and are freed - // after this function ends. 
- unsafe { - if libc::getloadavg(&mut loadavg[0] as *mut f64, 3) != 3 { - return Err(vm.new_os_error("Load averages are unobtainable".to_string())); - } - } - + let loadavg = crate::host_env::time::getloadavg() + .map_err(|_| vm.new_os_error("Load averages are unobtainable".to_string()))?; Ok((loadavg[0], loadavg[1], loadavg[2])) } @@ -2091,16 +1777,12 @@ pub(super) mod _os { let status = u32::try_from(status) .map_err(|_| vm.new_value_error(format!("invalid WEXITSTATUS: {status}")))?; - let status = status as libc::c_int; - if libc::WIFEXITED(status) { - return Ok(libc::WEXITSTATUS(status)); - } - - if libc::WIFSIGNALED(status) { - return Ok(-libc::WTERMSIG(status)); + if let Some(exitcode) = crate::host_env::time::waitstatus_to_exitcode(status as libc::c_int) + { + return Ok(exitcode); } - Err(vm.new_value_error(format!("Invalid wait status: {status}"))) + Err(vm.new_value_error(format!("Invalid wait status: {}", status as libc::c_int))) } #[cfg(windows)] @@ -2119,31 +1801,7 @@ pub(super) mod _os { return None; } - cfg_select! { - any(target_os = "android", target_os = "redox") => { - Some("UTF-8".to_owned()) - } - windows => { - use windows_sys::Win32::System::Console; - let cp = match fd { - 0 => unsafe { Console::GetConsoleCP() }, - 1 | 2 => unsafe { Console::GetConsoleOutputCP() }, - _ => 0, - }; - - Some(format!("cp{cp}")) - } - _ => { - Some(unsafe { - let encoding = libc::nl_langinfo(libc::CODESET); - if encoding.is_null() || encoding.read() == b'\0' as libc::c_char { - "UTF-8".to_owned() - } else { - core::ffi::CStr::from_ptr(encoding).to_string_lossy().into_owned() - } - }) - } - } + rustpython_host_env::os::device_encoding(fd) } #[pystruct_sequence_data] @@ -2214,27 +1872,7 @@ pub(super) mod _os { #[cfg(all(unix, not(target_os = "redox")))] impl StatvfsResultData { - fn from_statvfs(st: libc::statvfs) -> Self { - // f_fsid is a struct on some platforms (e.g., Linux fsid_t) and a scalar on others. 
- // We extract raw bytes and interpret as a native-endian integer. - // Note: The value may differ across architectures due to endianness. - let f_fsid = { - let ptr = core::ptr::addr_of!(st.f_fsid) as *const u8; - let size = core::mem::size_of_val(&st.f_fsid); - if size >= 8 { - let bytes = unsafe { core::slice::from_raw_parts(ptr, 8) }; - u64::from_ne_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - bytes[7], - ]) as libc::c_ulong - } else if size >= 4 { - let bytes = unsafe { core::slice::from_raw_parts(ptr, 4) }; - u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as libc::c_ulong - } else { - 0 - } - }; - + fn from_statvfs(st: crate::host_env::posix::StatVfsInfo) -> Self { Self { f_bsize: st.f_bsize, f_frsize: st.f_frsize, @@ -2246,7 +1884,7 @@ pub(super) mod _os { f_favail: st.f_favail, f_flag: st.f_flag, f_namemax: st.f_namemax, - f_fsid, + f_fsid: st.f_fsid, } } } @@ -2256,22 +1894,17 @@ pub(super) mod _os { #[pyfunction] #[pyfunction(name = "fstatvfs")] fn statvfs(path: OsPathOrFd<'_>, vm: &VirtualMachine) -> PyResult { - let mut st: libc::statvfs = unsafe { core::mem::zeroed() }; - let ret = match &path { + let st = match &path { OsPathOrFd::Path(p) => { let cpath = p.clone().into_cstring(vm)?; - unsafe { libc::statvfs(cpath.as_ptr(), &mut st) } + crate::host_env::posix::statvfs_path(cpath.as_c_str()) } - OsPathOrFd::Fd(fd) => unsafe { libc::fstatvfs(fd.as_raw(), &mut st) }, + OsPathOrFd::Fd(fd) => crate::host_env::posix::statvfs_fd(fd.as_raw()), }; - if ret != 0 { - return Err(OSErrorBuilder::with_filename( - &io::Error::last_os_error(), - path, - vm, - )); + if let Err(err) = st { + return Err(OSErrorBuilder::with_filename(&err, path, vm)); } - Ok(StatvfsResultData::from_statvfs(st).to_pyobject(vm)) + Ok(StatvfsResultData::from_statvfs(st.unwrap()).to_pyobject(vm)) } pub(super) fn support_funcs() -> Vec { @@ -2309,7 +1942,7 @@ pub(super) mod _os { supports } } -pub(crate) use _os::{ftruncate, isatty, lseek}; 
+pub(crate) use _os::ftruncate; pub(crate) struct SupportFunc { name: &'static str, diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index ca3a9a2dc48..b45f31821c7 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -18,7 +18,7 @@ pub mod module { use crate::{ AsObject, Py, PyObjectRef, PyResult, VirtualMachine, builtins::{PyDictRef, PyInt, PyListRef, PyTupleRef, PyUtf8Str}, - convert::{IntoPyException, ToPyObject, TryFromObject}, + convert::{IntoPyException, ToPyException, ToPyObject, TryFromObject}, exceptions::OSErrorBuilder, function::{ArgMapping, Either, KwArgs, OptionalArg}, ospath::{OsPath, OsPathOrFd}, @@ -35,17 +35,11 @@ pub mod module { ))] use crate::{builtins::PyUtf8StrRef, utils::ToCString}; use alloc::ffi::CString; - use bitflags::bitflags; use core::ffi::CStr; - use nix::{ - errno::Errno, - fcntl, - unistd::{self, Gid, Pid, Uid}, - }; use rustpython_host_env::os::ffi::OsStringExt; use std::{ fs, io, - os::fd::{AsFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd}, + os::fd::{BorrowedFd, FromRawFd, IntoRawFd, OwnedFd}, }; use strum::IntoEnumIterator; use strum_macros::{EnumIter, EnumString}; @@ -374,127 +368,20 @@ pub mod module { } } - // Flags for os_access - bitflags! 
{ - #[derive(Copy, Clone, Debug, PartialEq, Eq)] - pub struct AccessFlags: u8 { - const F_OK = _os::F_OK; - const R_OK = _os::R_OK; - const W_OK = _os::W_OK; - const X_OK = _os::X_OK; - } - } - - struct Permissions { - is_readable: bool, - is_writable: bool, - is_executable: bool, - } - - const fn get_permissions(mode: u32) -> Permissions { - Permissions { - is_readable: mode & 4 != 0, - is_writable: mode & 2 != 0, - is_executable: mode & 1 != 0, - } - } - - fn get_right_permission( - mode: u32, - file_owner: Uid, - file_group: Gid, - ) -> nix::Result { - let owner_mode = (mode & 0o700) >> 6; - let owner_permissions = get_permissions(owner_mode); - - let group_mode = (mode & 0o070) >> 3; - let group_permissions = get_permissions(group_mode); - - let others_mode = mode & 0o007; - let others_permissions = get_permissions(others_mode); - - let user_id = nix::unistd::getuid(); - let groups_ids = getgroups_impl()?; - - if file_owner == user_id { - Ok(owner_permissions) - } else if groups_ids.contains(&file_group) { - Ok(group_permissions) - } else { - Ok(others_permissions) - } - } - - #[cfg(any(target_os = "macos", target_os = "ios"))] - fn getgroups_impl() -> nix::Result> { - use core::ptr; - use libc::{c_int, gid_t}; - - let ret = unsafe { libc::getgroups(0, ptr::null_mut()) }; - let mut groups = Vec::::with_capacity(Errno::result(ret)? 
as usize); - let ret = unsafe { - libc::getgroups( - groups.capacity() as c_int, - groups.as_mut_ptr() as *mut gid_t, - ) - }; - - Errno::result(ret).map(|s| { - unsafe { groups.set_len(s as usize) }; - groups - }) - } - - #[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "redox")))] - use nix::unistd::getgroups as getgroups_impl; - - #[cfg(target_os = "redox")] - fn getgroups_impl() -> nix::Result> { - Err(nix::Error::EOPNOTSUPP) - } - #[pyfunction] fn getgroups(vm: &VirtualMachine) -> PyResult> { - let group_ids = getgroups_impl().map_err(|e| e.into_pyexception(vm))?; + let group_ids = + rustpython_host_env::posix::getgroups().map_err(|e| e.into_pyexception(vm))?; Ok(group_ids .into_iter() - .map(|gid| vm.ctx.new_int(gid.as_raw()).into()) + .map(|gid| vm.ctx.new_int(gid).into()) .collect()) } #[pyfunction] pub(super) fn access(path: OsPath, mode: u8, vm: &VirtualMachine) -> PyResult { - use std::os::unix::fs::MetadataExt; - - let flags = AccessFlags::from_bits(mode).ok_or_else(|| { - vm.new_value_error( - "One of the flags is wrong, there are only 4 possibilities F_OK, R_OK, W_OK and X_OK", - ) - })?; - - let metadata = match crate::host_env::fs::metadata(&path.path) { - Ok(m) => m, - // If the file doesn't exist, return False for any access check - Err(_) => return Ok(false), - }; - - // if it's only checking for F_OK - if flags == AccessFlags::F_OK { - return Ok(true); // File exists - } - - let user_id = metadata.uid(); - let group_id = metadata.gid(); - let mode = metadata.mode(); - - let perm = get_right_permission(mode, Uid::from_raw(user_id), Gid::from_raw(group_id)) - .map_err(|err| err.into_pyexception(vm))?; - - let r_ok = !flags.contains(AccessFlags::R_OK) || perm.is_readable; - let w_ok = !flags.contains(AccessFlags::W_OK) || perm.is_writable; - let x_ok = !flags.contains(AccessFlags::X_OK) || perm.is_executable; - - Ok(r_ok && w_ok && x_ok) + rustpython_host_env::posix::check_access(path.as_ref(), mode) + .map_err(|err| 
err.to_pyexception(vm)) } #[pyattr] @@ -536,18 +423,13 @@ pub mod module { let dst = args.dst.into_cstring(vm)?; #[cfg(not(target_os = "redox"))] { - nix::unistd::symlinkat(&*src, args.dir_fd.get(), &*dst) + rustpython_host_env::posix::symlinkat(&src, args.dir_fd.get().into(), &dst) .map_err(|err| err.into_pyexception(vm)) } #[cfg(target_os = "redox")] { let [] = args.dir_fd.0; - let res = unsafe { libc::symlink(src.as_ptr(), dst.as_ptr()) }; - if res < 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + rustpython_host_env::posix::symlink(&src, &dst).map_err(|err| err.into_pyexception(vm)) } } @@ -561,13 +443,8 @@ pub mod module { #[cfg(not(target_os = "redox"))] if let Some(fd) = dir_fd.raw_opt() { let c_path = path.clone().into_cstring(vm)?; - let res = unsafe { libc::unlinkat(fd, c_path.as_ptr(), 0) }; - return if res < 0 { - let err = crate::host_env::os::errno_io_error(); - Err(OSErrorBuilder::with_filename(&err, path, vm)) - } else { - Ok(()) - }; + return rustpython_host_env::posix::unlinkat(fd, &c_path) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)); } #[cfg(target_os = "redox")] let [] = dir_fd.0; @@ -580,12 +457,7 @@ pub mod module { fn fchdir(fd: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { warn_if_bool_fd(&fd, vm)?; let fd = i32::try_from_object(vm, fd)?; - let ret = unsafe { libc::fchdir(fd) }; - if ret == 0 { - Ok(()) - } else { - Err(io::Error::last_os_error().into_pyexception(vm)) - } + rustpython_host_env::posix::fchdir(fd).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] @@ -593,11 +465,8 @@ pub mod module { fn chroot(path: OsPath, vm: &VirtualMachine) -> PyResult<()> { use crate::exceptions::OSErrorBuilder; - nix::unistd::chroot(&*path.path).map_err(|err| { - // Use `From for io::Error` when it is available - let io_err: io::Error = err.into(); - OSErrorBuilder::with_filename(&io_err, path, vm) - }) + rustpython_host_env::posix::chroot(std::path::Path::new(&path.path)) + .map_err(|err| 
OSErrorBuilder::with_filename(&err, path, vm)) } // As of now, redox does not seems to support chown command (cf. https://gitlab.redox-os.org/redox-os/coreutils , last checked on 05/07/2020) @@ -612,7 +481,7 @@ pub mod module { vm: &VirtualMachine, ) -> PyResult<()> { let uid = if uid >= 0 { - Some(nix::unistd::Uid::from_raw(uid as u32)) + Some(uid as u32) } else if uid == -1 { None } else { @@ -620,30 +489,24 @@ pub mod module { }; let gid = if gid >= 0 { - Some(nix::unistd::Gid::from_raw(gid as u32)) + Some(gid as u32) } else if gid == -1 { None } else { return Err(vm.new_os_error("Specified gid is not valid.")); }; - let flag = if follow_symlinks.0 { - nix::fcntl::AtFlags::empty() - } else { - nix::fcntl::AtFlags::AT_SYMLINK_NOFOLLOW - }; - match path { - OsPathOrFd::Path(ref p) => { - nix::unistd::fchownat(dir_fd.get(), p.path.as_os_str(), uid, gid, flag) - } - OsPathOrFd::Fd(fd) => nix::unistd::fchown(fd, uid, gid), + OsPathOrFd::Path(ref p) => rustpython_host_env::posix::fchownat( + dir_fd.get().into(), + p.path.as_os_str(), + uid, + gid, + follow_symlinks.0, + ), + OsPathOrFd::Fd(fd) => rustpython_host_env::posix::fchown(fd.into(), uid, gid), } - .map_err(|err| { - // Use `From for io::Error` when it is available - let err = io::Error::from_raw_os_error(err as i32); - OSErrorBuilder::with_filename(&err, path, vm) - }) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } #[cfg(not(target_os = "redox"))] @@ -904,7 +767,7 @@ pub mod module { }; if num_threads > 1 { - let pid = unsafe { libc::getpid() }; + let pid = rustpython_host_env::posix::getpid(); let msg = format!( "This process (pid={pid}) is multi-threaded, use of {name}() may lead to deadlocks in the child." ); @@ -936,24 +799,23 @@ pub mod module { .call(("os.fork",), vm)?; py_os_before_fork(vm); + let pid = rustpython_host_env::posix::fork(); - let pid = unsafe { libc::fork() }; - // Save errno immediately — AfterFork callbacks may clobber it. 
- let saved_errno = nix::Error::last_raw(); - if pid == 0 { - py_os_after_fork_child(vm); - } else { - // Match CPython timing: capture this before parent after-fork hooks - // in case those hooks start threads. - let num_os_threads = get_number_of_os_threads(); - py_os_after_fork_parent(vm); - // Match CPython timing: warn only after parent callback path resumes world. - warn_if_multi_threaded("fork", num_os_threads, vm); - } - if pid == -1 { - Err(nix::Error::from_raw(saved_errno).into_pyexception(vm)) - } else { - Ok(pid) + match pid { + Ok(0) => { + py_os_after_fork_child(vm); + Ok(0) + } + Ok(pid) => { + // Match CPython timing: capture this before parent after-fork hooks + // in case those hooks start threads. + let num_os_threads = get_number_of_os_threads(); + py_os_after_fork_parent(vm); + // Match CPython timing: warn only after parent callback path resumes world. + warn_if_multi_threaded("fork", num_os_threads, vm); + Ok(pid) + } + Err(err) => Err(err.into_pyexception(vm)), } } @@ -975,45 +837,27 @@ pub mod module { #[cfg(not(target_os = "redox"))] impl MknodArgs<'_> { - fn _mknod(self, vm: &VirtualMachine) -> PyResult { - Ok(unsafe { - libc::mknod( - self.path.clone().into_cstring(vm)?.as_ptr(), - self.mode, - self.device, - ) - }) - } - #[cfg(not(target_vendor = "apple"))] fn mknod(self, vm: &VirtualMachine) -> PyResult<()> { - let ret = match self.dir_fd.raw_opt() { - None => self._mknod(vm)?, - Some(non_default_fd) => unsafe { - libc::mknodat( - non_default_fd, - self.path.clone().into_cstring(vm)?.as_ptr(), - self.mode, - self.device, - ) - }, - }; - if ret != 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) + let c_path = self.path.clone().into_cstring(vm)?; + match self.dir_fd.raw_opt() { + None => rustpython_host_env::posix::mknod(&c_path, self.mode, self.device), + Some(non_default_fd) => rustpython_host_env::posix::mknodat( + non_default_fd, + &c_path, + self.mode, + self.device, + ), } + .map_err(|err| err.into_pyexception(vm)) } 
#[cfg(target_vendor = "apple")] fn mknod(self, vm: &VirtualMachine) -> PyResult<()> { let [] = self.dir_fd.0; - let ret = self._mknod(vm)?; - if ret != 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + let c_path = self.path.clone().into_cstring(vm)?; + rustpython_host_env::posix::mknod(&c_path, self.mode, self.device) + .map_err(|err| err.into_pyexception(vm)) } } @@ -1026,49 +870,31 @@ pub mod module { #[cfg(not(target_os = "redox"))] #[pyfunction] fn nice(increment: i32, vm: &VirtualMachine) -> PyResult { - Errno::clear(); - let res = unsafe { libc::nice(increment) }; - if res == -1 && Errno::last_raw() != 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(res) - } + rustpython_host_env::posix::nice(increment).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] #[pyfunction] fn sched_get_priority_max(policy: i32, vm: &VirtualMachine) -> PyResult { - let max = unsafe { libc::sched_get_priority_max(policy) }; - if max == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(max) - } + rustpython_host_env::posix::sched_get_priority_max(policy) + .map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] #[pyfunction] fn sched_get_priority_min(policy: i32, vm: &VirtualMachine) -> PyResult { - let min = unsafe { libc::sched_get_priority_min(policy) }; - if min == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(min) - } + rustpython_host_env::posix::sched_get_priority_min(policy) + .map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn sched_yield(vm: &VirtualMachine) -> PyResult<()> { - nix::sched::sched_yield().map_err(|e| e.into_pyexception(vm)) + rustpython_host_env::posix::sched_yield().map_err(|e| e.into_pyexception(vm)) } #[pyfunction] fn get_inheritable(fd: BorrowedFd<'_>, vm: &VirtualMachine) -> PyResult { - let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFD); - match flags { - Ok(ret) => Ok((ret & libc::FD_CLOEXEC) == 0), - Err(err) => Err(err.into_pyexception(vm)), - } + 
rustpython_host_env::fcntl::get_inheritable(fd).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] @@ -1078,36 +904,18 @@ pub mod module { #[pyfunction] fn get_blocking(fd: BorrowedFd<'_>, vm: &VirtualMachine) -> PyResult { - let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL); - match flags { - Ok(ret) => Ok((ret & libc::O_NONBLOCK) == 0), - Err(err) => Err(err.into_pyexception(vm)), - } + rustpython_host_env::fcntl::get_blocking(fd).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn set_blocking(fd: BorrowedFd<'_>, blocking: bool, vm: &VirtualMachine) -> PyResult<()> { - let _set_flag = || { - use nix::fcntl::{FcntlArg, OFlag, fcntl}; - - let flags = OFlag::from_bits_truncate(fcntl(fd, FcntlArg::F_GETFL)?); - let mut new_flags = flags; - new_flags.set(OFlag::from_bits_truncate(libc::O_NONBLOCK), !blocking); - if flags != new_flags { - fcntl(fd, FcntlArg::F_SETFL(new_flags))?; - } - Ok(()) - }; - _set_flag().map_err(|err: nix::Error| err.into_pyexception(vm)) + rustpython_host_env::fcntl::set_blocking(fd, blocking) + .map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn pipe(vm: &VirtualMachine) -> PyResult<(OwnedFd, OwnedFd)> { - use nix::unistd::pipe; - let (rfd, wfd) = pipe().map_err(|err| err.into_pyexception(vm))?; - set_inheritable(rfd.as_fd(), false, vm)?; - set_inheritable(wfd.as_fd(), false, vm)?; - Ok((rfd, wfd)) + rustpython_host_env::posix::pipe().map_err(|err| err.into_pyexception(vm)) } // cfg from nix @@ -1122,8 +930,7 @@ pub mod module { ))] #[pyfunction] fn pipe2(flags: libc::c_int, vm: &VirtualMachine) -> PyResult<(OwnedFd, OwnedFd)> { - let oflags = fcntl::OFlag::from_bits_truncate(flags); - nix::unistd::pipe2(oflags).map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::pipe2(flags).map_err(|err| err.into_pyexception(vm)) } fn _chmod( @@ -1147,11 +954,7 @@ pub mod module { #[cfg(not(target_os = "redox"))] fn _fchmod(fd: BorrowedFd<'_>, mode: u32, vm: &VirtualMachine) -> PyResult<()> { - 
nix::sys::stat::fchmod( - fd, - nix::sys::stat::Mode::from_bits_truncate(mode as libc::mode_t), - ) - .map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::fchmod(fd, mode).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] @@ -1196,16 +999,9 @@ pub mod module { #[cfg(any(target_os = "macos", target_os = "freebsd", target_os = "netbsd",))] #[pyfunction] fn lchmod(path: OsPath, mode: u32, vm: &VirtualMachine) -> PyResult<()> { - unsafe extern "C" { - fn lchmod(path: *const libc::c_char, mode: libc::mode_t) -> libc::c_int; - } let c_path = path.clone().into_cstring(vm)?; - if unsafe { lchmod(c_path.as_ptr(), mode as libc::mode_t) } == 0 { - Ok(()) - } else { - let err = std::io::Error::last_os_error(); - Err(OSErrorBuilder::with_filename(&err, path, vm)) - } + rustpython_host_env::posix::lchmod(&c_path, mode as libc::mode_t) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } #[pyfunction] @@ -1228,9 +1024,7 @@ pub mod module { return Err(vm.new_value_error("execv() arg 2 first element cannot be empty")); } - unistd::execv(&path, &argv) - .map(|_ok| ()) - .map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::execv(&path, &argv).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] @@ -1278,90 +1072,86 @@ pub mod module { let env: Vec<&CStr> = env.iter().map(|entry| entry.as_c_str()).collect(); - unistd::execve(&path, &argv, &env).map_err(|err| err.into_pyexception(vm))?; + rustpython_host_env::posix::execve(&path, &argv, &env) + .map_err(|err| err.into_pyexception(vm))?; Ok(()) } #[pyfunction] fn getppid(vm: &VirtualMachine) -> PyObjectRef { - let ppid = unistd::getppid().as_raw(); + let ppid = rustpython_host_env::posix::getppid(); vm.ctx.new_int(ppid).into() } #[pyfunction] fn getgid(vm: &VirtualMachine) -> PyObjectRef { - let gid = unistd::getgid().as_raw(); + let gid = rustpython_host_env::posix::getgid(); vm.ctx.new_int(gid).into() } #[pyfunction] fn getegid(vm: &VirtualMachine) -> 
PyObjectRef { - let egid = unistd::getegid().as_raw(); + let egid = rustpython_host_env::posix::getegid(); vm.ctx.new_int(egid).into() } #[pyfunction] fn getpgid(pid: u32, vm: &VirtualMachine) -> PyResult { - let pgid = - unistd::getpgid(Some(Pid::from_raw(pid as i32))).map_err(|e| e.into_pyexception(vm))?; - Ok(vm.new_pyobj(pgid.as_raw())) + let pgid = rustpython_host_env::posix::getpgid(pid).map_err(|e| e.into_pyexception(vm))?; + Ok(vm.new_pyobj(pgid)) } #[pyfunction] fn getpgrp(vm: &VirtualMachine) -> PyObjectRef { - vm.ctx.new_int(unistd::getpgrp().as_raw()).into() + vm.ctx.new_int(rustpython_host_env::posix::getpgrp()).into() } #[cfg(not(target_os = "redox"))] #[pyfunction] fn getsid(pid: u32, vm: &VirtualMachine) -> PyResult { - let sid = - unistd::getsid(Some(Pid::from_raw(pid as i32))).map_err(|e| e.into_pyexception(vm))?; - Ok(vm.new_pyobj(sid.as_raw())) + let sid = rustpython_host_env::posix::getsid(pid).map_err(|e| e.into_pyexception(vm))?; + Ok(vm.new_pyobj(sid)) } #[pyfunction] fn getuid(vm: &VirtualMachine) -> PyObjectRef { - let uid = unistd::getuid().as_raw(); + let uid = rustpython_host_env::posix::getuid(); vm.ctx.new_int(uid).into() } #[pyfunction] fn geteuid(vm: &VirtualMachine) -> PyObjectRef { - let euid = unistd::geteuid().as_raw(); + let euid = rustpython_host_env::posix::geteuid(); vm.ctx.new_int(euid).into() } #[cfg(not(any(target_os = "wasi", target_os = "android")))] #[pyfunction] - fn setgid(gid: Gid, vm: &VirtualMachine) -> PyResult<()> { - unistd::setgid(gid).map_err(|err| err.into_pyexception(vm)) + fn setgid(gid: RawGid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::setgid(gid.0).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] #[pyfunction] - fn setegid(egid: Gid, vm: &VirtualMachine) -> PyResult<()> { - unistd::setegid(egid).map_err(|err| err.into_pyexception(vm)) + fn setegid(egid: RawGid, vm: &VirtualMachine) -> PyResult<()> { + 
rustpython_host_env::posix::setegid(egid.0).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn setpgid(pid: u32, pgid: u32, vm: &VirtualMachine) -> PyResult<()> { - unistd::setpgid(Pid::from_raw(pid as i32), Pid::from_raw(pgid as i32)) - .map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::setpgid(pid, pgid).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn setpgrp(vm: &VirtualMachine) -> PyResult<()> { // setpgrp() is equivalent to setpgid(0, 0) - unistd::setpgid(Pid::from_raw(0), Pid::from_raw(0)).map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::setpgrp().map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "redox")))] #[pyfunction] fn setsid(vm: &VirtualMachine) -> PyResult<()> { - unistd::setsid() - .map(|_ok| ()) - .map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::setsid().map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "redox")))] @@ -1369,9 +1159,7 @@ pub mod module { fn tcgetpgrp(fd: i32, vm: &VirtualMachine) -> PyResult { use std::os::fd::BorrowedFd; let fd = unsafe { BorrowedFd::borrow_raw(fd) }; - unistd::tcgetpgrp(fd) - .map(|pid| pid.as_raw()) - .map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::tcgetpgrp(fd).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "redox")))] @@ -1379,7 +1167,7 @@ pub mod module { fn tcsetpgrp(fd: i32, pgid: libc::pid_t, vm: &VirtualMachine) -> PyResult<()> { use std::os::fd::BorrowedFd; let fd = unsafe { BorrowedFd::borrow_raw(fd) }; - unistd::tcsetpgrp(fd, Pid::from_raw(pgid)).map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::tcsetpgrp(fd, pgid).map_err(|err| err.into_pyexception(vm)) } fn try_from_id(vm: &VirtualMachine, obj: PyObjectRef, typ_name: &str) -> PyResult { @@ -1408,35 +1196,40 @@ pub mod module { } } - impl TryFromObject for Uid { + #[derive(Clone, Copy)] + struct RawUid(u32); + + 
#[derive(Clone, Copy)] + struct RawGid(u32); + + impl TryFromObject for RawUid { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - try_from_id(vm, obj, "uid").map(Self::from_raw) + try_from_id(vm, obj, "uid").map(Self) } } - impl TryFromObject for Gid { + impl TryFromObject for RawGid { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - try_from_id(vm, obj, "gid").map(Self::from_raw) + try_from_id(vm, obj, "gid").map(Self) } } #[cfg(not(any(target_os = "wasi", target_os = "android")))] #[pyfunction] - fn setuid(uid: Uid) -> nix::Result<()> { - unistd::setuid(uid) + fn setuid(uid: RawUid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::setuid(uid.0).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] #[pyfunction] - fn seteuid(euid: Uid) -> nix::Result<()> { - unistd::seteuid(euid) + fn seteuid(euid: RawUid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::seteuid(euid.0).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] #[pyfunction] - fn setreuid(ruid: Uid, euid: Uid) -> nix::Result<()> { - let ret = unsafe { libc::setreuid(ruid.as_raw(), euid.as_raw()) }; - nix::Error::result(ret).map(drop) + fn setreuid(ruid: RawUid, euid: RawUid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::setreuid(ruid.0, euid.0).map_err(|err| err.into_pyexception(vm)) } // cfg from nix @@ -1447,107 +1240,96 @@ pub mod module { target_os = "openbsd" ))] #[pyfunction] - fn setresuid(ruid: Uid, euid: Uid, suid: Uid) -> nix::Result<()> { - unistd::setresuid(ruid, euid, suid) + fn setresuid(ruid: RawUid, euid: RawUid, suid: RawUid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::setresuid(ruid.0, euid.0, suid.0) + .map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] #[pyfunction] fn openpty(vm: 
&VirtualMachine) -> PyResult<(OwnedFd, OwnedFd)> { - let r = nix::pty::openpty(None, None).map_err(|err| err.into_pyexception(vm))?; - for fd in [&r.master, &r.slave] { - super::set_inheritable(fd.as_fd(), false).map_err(|e| e.into_pyexception(vm))?; - } - Ok((r.master, r.slave)) + rustpython_host_env::posix::openpty().map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn ttyname(fd: BorrowedFd<'_>, vm: &VirtualMachine) -> PyResult { - let name = unistd::ttyname(fd).map_err(|e| e.into_pyexception(vm))?; - let name = name.into_os_string().into_string().unwrap(); + let name = rustpython_host_env::posix::ttyname(fd).map_err(|e| e.into_pyexception(vm))?; + let name = name.into_string().unwrap(); Ok(vm.ctx.new_str(name).into()) } #[pyfunction] fn umask(mask: libc::mode_t) -> libc::mode_t { - unsafe { libc::umask(mask) } + rustpython_host_env::posix::umask(mask) } #[pyfunction] - fn uname() -> _os::UnameResultData { - let info = rustix::system::uname(); - _os::UnameResultData { - sysname: info.sysname().to_string_lossy().into(), - nodename: info.nodename().to_string_lossy().into(), - release: info.release().to_string_lossy().into(), - version: info.version().to_string_lossy().into(), - machine: info.machine().to_string_lossy().into(), - } + fn uname(vm: &VirtualMachine) -> PyResult<_os::UnameResultData> { + let info = rustpython_host_env::posix::uname_info() + .map_err(|err| vm.new_unicode_decode_error(err.to_string()))?; + Ok(_os::UnameResultData { + sysname: info.sysname, + nodename: info.nodename, + release: info.release, + version: info.version, + machine: info.machine, + }) } #[pyfunction] fn sync() { #[cfg(not(any(target_os = "redox", target_os = "android")))] - unsafe { - libc::sync(); - } + rustpython_host_env::posix::sync(); } // cfg from nix #[cfg(any(target_os = "android", target_os = "linux", target_os = "openbsd"))] #[pyfunction] - fn getresuid() -> nix::Result<(u32, u32, u32)> { - let ret = unistd::getresuid()?; - Ok(( - ret.real.as_raw(), - 
ret.effective.as_raw(), - ret.saved.as_raw(), - )) + fn getresuid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { + rustpython_host_env::posix::getresuid().map_err(|err| err.into_pyexception(vm)) } // cfg from nix #[cfg(any(target_os = "android", target_os = "linux", target_os = "openbsd"))] #[pyfunction] - fn getresgid() -> nix::Result<(u32, u32, u32)> { - let ret = unistd::getresgid()?; - Ok(( - ret.real.as_raw(), - ret.effective.as_raw(), - ret.saved.as_raw(), - )) + fn getresgid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { + rustpython_host_env::posix::getresgid().map_err(|err| err.into_pyexception(vm)) } // cfg from nix #[cfg(any(target_os = "freebsd", target_os = "linux", target_os = "openbsd"))] #[pyfunction] - fn setresgid(rgid: Gid, egid: Gid, sgid: Gid, vm: &VirtualMachine) -> PyResult<()> { - unistd::setresgid(rgid, egid, sgid).map_err(|err| err.into_pyexception(vm)) + fn setresgid(rgid: RawGid, egid: RawGid, sgid: RawGid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::setresgid(rgid.0, egid.0, sgid.0) + .map_err(|err| err.into_pyexception(vm)) } #[cfg(not(any(target_os = "wasi", target_os = "android", target_os = "redox")))] #[pyfunction] - fn setregid(rgid: Gid, egid: Gid) -> nix::Result<()> { - let ret = unsafe { libc::setregid(rgid.as_raw(), egid.as_raw()) }; - nix::Error::result(ret).map(drop) + fn setregid(rgid: RawGid, egid: RawGid, vm: &VirtualMachine) -> PyResult<()> { + rustpython_host_env::posix::setregid(rgid.0, egid.0).map_err(|err| err.into_pyexception(vm)) } // cfg from nix #[cfg(any(target_os = "freebsd", target_os = "linux", target_os = "openbsd"))] #[pyfunction] - fn initgroups(user_name: PyUtf8StrRef, gid: Gid, vm: &VirtualMachine) -> PyResult<()> { + fn initgroups(user_name: PyUtf8StrRef, gid: RawGid, vm: &VirtualMachine) -> PyResult<()> { let user = user_name.to_cstring(vm)?; - unistd::initgroups(&user, gid).map_err(|err| err.into_pyexception(vm)) + rustpython_host_env::posix::initgroups(&user, 
gid.0).map_err(|err| err.into_pyexception(vm)) } // cfg from nix #[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "redox")))] #[pyfunction] fn setgroups( - group_ids: crate::function::ArgIterable, + group_ids: crate::function::ArgIterable, vm: &VirtualMachine, ) -> PyResult<()> { - let gids = group_ids.iter(vm)?.collect::, _>>()?; - unistd::setgroups(&gids).map_err(|err| err.into_pyexception(vm)) + let gids = group_ids + .iter(vm)? + .map(|gid| gid.map(|gid| gid.0)) + .collect::, _>>()?; + rustpython_host_env::posix::setgroups_raw(&gids).map_err(|err| err.into_pyexception(vm)) } #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] @@ -1637,8 +1419,6 @@ pub mod module { #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] impl PosixSpawnArgs { fn spawn(self, spawnp: bool, vm: &VirtualMachine) -> PyResult { - use nix::sys::signal; - use crate::TryFromBorrowedObject; let path = self @@ -1647,8 +1427,7 @@ pub mod module { .into_cstring(vm) .map_err(|_| vm.new_value_error("path should not have nul bytes"))?; - let mut file_actions = - nix::spawn::PosixSpawnFileActions::init().map_err(|e| e.into_pyexception(vm))?; + let mut file_actions = Vec::new(); if let Some(it) = self.file_actions { for action in it.iter(vm)? 
{ let action = action?; @@ -1659,7 +1438,7 @@ pub mod module { let id = PosixSpawnFileActionIdentifier::try_from(id) .map_err(|_| vm.new_type_error("Unknown file_actions identifier"))?; let args: crate::function::FuncArgs = args.to_vec().into(); - let ret = match id { + let parsed = match id { PosixSpawnFileActionIdentifier::Open => { let (fd, path, oflag, mode): (_, OsPath, _, _) = args.bind(vm)?; let path = CString::new(path.into_bytes()).map_err(|_| { @@ -1667,90 +1446,40 @@ pub mod module { "POSIX_SPAWN_OPEN path should not have nul bytes", ) })?; - let oflag = nix::fcntl::OFlag::from_bits_retain(oflag); - let mode = nix::sys::stat::Mode::from_bits_retain(mode); - file_actions.add_open(fd, &*path, oflag, mode) + rustpython_host_env::posix::PosixSpawnFileAction::Open { + fd, + path, + oflag, + mode, + } } PosixSpawnFileActionIdentifier::Close => { let (fd,) = args.bind(vm)?; - file_actions.add_close(fd) + rustpython_host_env::posix::PosixSpawnFileAction::Close { fd } } PosixSpawnFileActionIdentifier::Dup2 => { let (fd, newfd) = args.bind(vm)?; - file_actions.add_dup2(fd, newfd) + rustpython_host_env::posix::PosixSpawnFileAction::Dup2 { fd, newfd } } }; - if let Err(err) = ret { - let err = err.into(); - return Err(OSErrorBuilder::with_filename(&err, self.path, vm)); - } - } - } - - let mut attrp = - nix::spawn::PosixSpawnAttr::init().map_err(|e| e.into_pyexception(vm))?; - let mut flags = nix::spawn::PosixSpawnFlags::empty(); - - if let Some(sigs) = self.setsigdef { - let mut set = signal::SigSet::empty(); - for sig in sigs.iter(vm)? 
{ - let sig = sig?; - let sig = signal::Signal::try_from(sig).map_err(|_| { - vm.new_value_error(format!("signal number {sig} out of range")) - })?; - set.add(sig); + file_actions.push(parsed); } - attrp - .set_sigdefault(&set) - .map_err(|e| e.into_pyexception(vm))?; - flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_SETSIGDEF); - } - - if let Some(pgid) = self.setpgroup { - attrp - .set_pgroup(nix::unistd::Pid::from_raw(pgid)) - .map_err(|e| e.into_pyexception(vm))?; - flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_SETPGROUP); - } - - if self.resetids { - flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_RESETIDS); } - if self.setsid { - // Note: POSIX_SPAWN_SETSID may not be available on all platforms - cfg_select! { - any( - target_os = "linux", - target_os = "haiku", - target_os = "solaris", - target_os = "illumos", - target_os = "hurd", - ) => { - flags.insert(nix::spawn::PosixSpawnFlags::from_bits_retain(libc::POSIX_SPAWN_SETSID)); - } - _ => { - return Err(vm.new_not_implemented_error( - "setsid parameter is not supported on this platform", - )); + let setsigdef = self + .setsigdef + .map(|sigs| { + let sigs = sigs.iter(vm)?.collect::>>()?; + for &sig in &sigs { + if !rustpython_host_env::posix::validate_posix_spawn_signal(sig) { + return Err( + vm.new_value_error(format!("signal number {sig} out of range")) + ); + } } - } - } - - if let Some(sigs) = self.setsigmask { - let mut set = signal::SigSet::empty(); - for sig in sigs.iter(vm)? 
{ - let sig = sig?; - let sig = signal::Signal::try_from(sig).map_err(|_| { - vm.new_value_error(format!("signal number {sig} out of range")) - })?; - set.add(sig); - } - attrp - .set_sigmask(&set) - .map_err(|e| e.into_pyexception(vm))?; - flags.insert(nix::spawn::PosixSpawnFlags::POSIX_SPAWN_SETSIGMASK); - } + Ok(sigs) + }) + .transpose()?; if let Some(_scheduler) = self.scheduler { // TODO: Implement scheduler parameter handling @@ -1760,10 +1489,27 @@ pub mod module { ); } - if !flags.is_empty() { - attrp.set_flags(flags).map_err(|e| e.into_pyexception(vm))?; + if self.setsid && !rustpython_host_env::posix::supports_posix_spawn_setsid() { + return Err(vm.new_not_implemented_error( + "setsid parameter is not supported on this platform", + )); } + let setsigmask = self + .setsigmask + .map(|sigs| { + let sigs = sigs.iter(vm)?.collect::>>()?; + for &sig in &sigs { + if !rustpython_host_env::posix::validate_posix_spawn_signal(sig) { + return Err( + vm.new_value_error(format!("signal number {sig} out of range")) + ); + } + } + Ok(sigs) + }) + .transpose()?; + let args: Vec = self .args .iter(vm)? @@ -1789,13 +1535,19 @@ pub mod module { .collect::>>()? 
}; - let ret = if spawnp { - nix::spawn::posix_spawnp(&path, &file_actions, &attrp, &args, &env) - } else { - nix::spawn::posix_spawn(&*path, &file_actions, &attrp, &args, &env) - }; - ret.map(Into::into) - .map_err(|err| OSErrorBuilder::with_filename(&err.into(), self.path, vm)) + rustpython_host_env::posix::posix_spawn(rustpython_host_env::posix::PosixSpawnConfig { + path: &path, + args: &args, + env: &env, + file_actions: &file_actions, + setsigdef: setsigdef.as_deref(), + setpgroup: self.setpgroup, + resetids: self.resetids, + setsid: self.setsid, + setsigmask: setsigmask.as_deref(), + spawnp, + }) + .map_err(|err| OSErrorBuilder::with_filename(&err, self.path, vm)) } } @@ -1813,42 +1565,42 @@ pub mod module { #[pyfunction(name = "WCOREDUMP")] fn wcoredump(status: i32) -> bool { - libc::WCOREDUMP(status) + rustpython_host_env::posix::wcoredump(status) } #[pyfunction(name = "WIFCONTINUED")] fn wifcontinued(status: i32) -> bool { - libc::WIFCONTINUED(status) + rustpython_host_env::posix::wifcontinued(status) } #[pyfunction(name = "WIFSTOPPED")] fn wifstopped(status: i32) -> bool { - libc::WIFSTOPPED(status) + rustpython_host_env::posix::wifstopped(status) } #[pyfunction(name = "WIFSIGNALED")] fn wifsignaled(status: i32) -> bool { - libc::WIFSIGNALED(status) + rustpython_host_env::posix::wifsignaled(status) } #[pyfunction(name = "WIFEXITED")] fn wifexited(status: i32) -> bool { - libc::WIFEXITED(status) + rustpython_host_env::posix::wifexited(status) } #[pyfunction(name = "WEXITSTATUS")] fn wexitstatus(status: i32) -> i32 { - libc::WEXITSTATUS(status) + rustpython_host_env::posix::wexitstatus(status) } #[pyfunction(name = "WSTOPSIG")] fn wstopsig(status: i32) -> i32 { - libc::WSTOPSIG(status) + rustpython_host_env::posix::wstopsig(status) } #[pyfunction(name = "WTERMSIG")] fn wtermsig(status: i32) -> i32 { - libc::WTERMSIG(status) + rustpython_host_env::posix::wtermsig(status) } #[cfg(target_os = "linux")] @@ -1859,33 +1611,23 @@ pub mod module { vm: 
&VirtualMachine, ) -> PyResult { let flags = flags.unwrap_or(0); - let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, flags) as libc::c_long }; - if fd == -1 { - Err(vm.new_last_errno_error()) - } else { - // Safety: syscall returns a new owned file descriptor. - Ok(unsafe { OwnedFd::from_raw_fd(fd as libc::c_int) }) - } + rustpython_host_env::posix::pidfd_open(pid, flags).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn waitpid(pid: libc::pid_t, opt: i32, vm: &VirtualMachine) -> PyResult<(libc::pid_t, i32)> { let mut status = 0; loop { - // Capture errno inside the closure: attach_thread (called by - // allow_threads on return) can clobber errno via syscalls. - let (res, err) = vm.allow_threads(|| { - let r = unsafe { libc::waitpid(pid, &mut status, opt) }; - (r, nix::Error::last_raw()) - }); - if res == -1 { - if err == libc::EINTR { + let res = + vm.allow_threads(|| rustpython_host_env::posix::waitpid(pid, &mut status, opt)); + match res { + Err(err) if err.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; continue; } - return Err(nix::Error::from_raw(err).into_pyexception(vm)); + Err(err) => return Err(err.into_pyexception(vm)), + Ok(res) => return Ok((res, status)), } - return Ok((res, status)); } } @@ -1896,14 +1638,7 @@ pub mod module { #[pyfunction] fn kill(pid: i32, sig: isize, vm: &VirtualMachine) -> PyResult<()> { - { - let ret = unsafe { libc::kill(pid, sig as i32) }; - if ret == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } - } + rustpython_host_env::posix::kill(pid, sig as i32).map_err(|err| err.into_pyexception(vm)) } #[pyfunction] @@ -1911,50 +1646,23 @@ pub mod module { fd: OptionalArg, vm: &VirtualMachine, ) -> PyResult<_os::TerminalSizeData> { - let (columns, lines) = { - nix::ioctl_read_bad!(winsz, libc::TIOCGWINSZ, libc::winsize); - let mut w = libc::winsize { - ws_row: 0, - ws_col: 0, - ws_xpixel: 0, - ws_ypixel: 0, - }; - unsafe { winsz(fd.unwrap_or(libc::STDOUT_FILENO), &mut w) } + let (columns, 
lines) = + rustpython_host_env::posix::get_terminal_size(fd.unwrap_or(libc::STDOUT_FILENO)) + .map(|(columns, lines)| (columns.into(), lines.into())) .map_err(|err| err.into_pyexception(vm))?; - (w.ws_col.into(), w.ws_row.into()) - }; Ok(_os::TerminalSizeData { columns, lines }) } - // from libstd: - // https://github.com/rust-lang/rust/blob/daecab3a784f28082df90cebb204998051f3557d/src/libstd/sys/unix/fs.rs#L1251 - #[cfg(target_os = "macos")] - unsafe extern "C" { - fn fcopyfile( - in_fd: libc::c_int, - out_fd: libc::c_int, - state: *mut libc::c_void, // copyfile_state_t (unused) - flags: u32, // copyfile_flags_t - ) -> libc::c_int; - } - #[cfg(target_os = "macos")] #[pyfunction] fn _fcopyfile(in_fd: i32, out_fd: i32, flags: i32, vm: &VirtualMachine) -> PyResult<()> { - let ret = unsafe { fcopyfile(in_fd, out_fd, core::ptr::null_mut(), flags as u32) }; - if ret < 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + rustpython_host_env::posix::fcopyfile(in_fd, out_fd, flags as u32) + .map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn dup(fd: BorrowedFd<'_>, vm: &VirtualMachine) -> PyResult { - let fd = nix::unistd::dup(fd).map_err(|e| e.into_pyexception(vm))?; - super::set_inheritable(fd.as_fd(), false) - .map(|()| fd) - .map_err(|e| e.into_pyexception(vm)) + rustpython_host_env::posix::dup_noninheritable(fd).map_err(|e| e.into_pyexception(vm)) } #[derive(FromArgs)] @@ -1969,13 +1677,8 @@ pub mod module { #[pyfunction] fn dup2(args: Dup2Args<'_>, vm: &VirtualMachine) -> PyResult { - let mut fd2 = core::mem::ManuallyDrop::new(args.fd2); - nix::unistd::dup2(args.fd, &mut fd2).map_err(|e| e.into_pyexception(vm))?; - let fd2 = core::mem::ManuallyDrop::into_inner(fd2); - if !args.inheritable { - super::set_inheritable(fd2.as_fd(), false).map_err(|e| e.into_pyexception(vm))? 
- } - Ok(fd2) + rustpython_host_env::posix::dup2(args.fd, args.fd2, args.inheritable) + .map_err(|e| e.into_pyexception(vm)) } pub(crate) fn support_funcs() -> Vec { @@ -2013,12 +1716,10 @@ pub mod module { // Get a pointer to the login name string. The string is statically // allocated and might be overwritten on subsequent calls to this // function or to `cuserid()`. See man getlogin(3) for more information. - let ptr = unsafe { libc::getlogin() }; - if ptr.is_null() { + let Some(login) = rustpython_host_env::posix::getlogin() else { return Err(vm.new_os_error("unable to determine login name")); - } - let slice = unsafe { CStr::from_ptr(ptr) }; - slice + }; + login .to_str() .map(|s| s.to_owned()) .map_err(|e| vm.new_unicode_decode_error(format!("unable to decode login name: {e}"))) @@ -2038,56 +1739,33 @@ pub mod module { vm: &VirtualMachine, ) -> PyResult> { let user = user.to_cstring(vm)?; - let gid = Gid::from_raw(group); - let group_ids = unistd::getgrouplist(&user, gid).map_err(|err| err.into_pyexception(vm))?; - Ok(group_ids - .into_iter() - .map(|gid| vm.new_pyobj(gid.as_raw())) - .collect()) + let group_ids = rustpython_host_env::posix::getgrouplist(&user, group) + .map_err(|err| err.into_pyexception(vm))?; + Ok(group_ids.into_iter().map(|gid| vm.new_pyobj(gid)).collect()) } - #[cfg(not(target_os = "redox"))] - type PriorityWhichType = cfg_select! { - all(target_os = "linux", target_env = "gnu") => libc::__priority_which_t, - _ => libc::c_int, - }; - - #[cfg(not(target_os = "redox"))] - type PriorityWhoType = cfg_select! 
{ - target_os = "freebsd" => i32, - _ => u32, - }; - #[cfg(not(target_os = "redox"))] #[pyfunction] fn getpriority( - which: PriorityWhichType, - who: PriorityWhoType, + which: rustpython_host_env::posix::PriorityWhichType, + who: rustpython_host_env::posix::PriorityWhoType, vm: &VirtualMachine, ) -> PyResult { - Errno::clear(); - let retval = unsafe { libc::getpriority(which, who) }; - if Errno::last_raw() != 0 { - Err(vm.new_last_errno_error()) - } else { - Ok(vm.ctx.new_int(retval).into()) - } + rustpython_host_env::posix::getpriority(which, who) + .map(|retval| vm.ctx.new_int(retval).into()) + .map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] #[pyfunction] fn setpriority( - which: PriorityWhichType, - who: PriorityWhoType, + which: rustpython_host_env::posix::PriorityWhichType, + who: rustpython_host_env::posix::PriorityWhoType, priority: i32, vm: &VirtualMachine, ) -> PyResult<()> { - let retval = unsafe { libc::setpriority(which, who, priority) }; - if retval == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(()) - } + rustpython_host_env::posix::setpriority(which, who, priority) + .map_err(|err| err.into_pyexception(vm)) } struct PathconfName(i32); @@ -2110,8 +1788,7 @@ pub mod module { } } - // Copy from [nix::unistd::PathconfVar](https://docs.rs/nix/0.21.0/nix/unistd/enum.PathconfVar.html) - // Change enum name to fit python doc + // Mirror the libc pathconf constants as Python-facing names. 
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, EnumIter, EnumString)] #[repr(i32)] #[allow(non_camel_case_types)] @@ -2284,28 +1961,14 @@ pub mod module { PathconfName(name): PathconfName, vm: &VirtualMachine, ) -> PyResult> { - Errno::clear(); - debug_assert_eq!(Errno::last_raw(), 0); - let raw = match &path { + match &path { OsPathOrFd::Path(path) => { - let path = path.clone().into_cstring(vm)?; - unsafe { libc::pathconf(path.as_ptr(), name) } - } - OsPathOrFd::Fd(fd) => unsafe { libc::fpathconf(fd.as_raw(), name) }, - }; - - if raw == -1 { - if Errno::last_raw() == 0 { - Ok(None) - } else { - Err(OSErrorBuilder::with_filename( - &io::Error::from(Errno::last()), - path, - vm, - )) + let c_path = path.clone().into_cstring(vm)?; + rustpython_host_env::posix::pathconf(&c_path, name) + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm)) } - } else { - Ok(Some(raw)) + OsPathOrFd::Fd(fd) => rustpython_host_env::posix::fpathconf(fd.as_raw(), name) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)), } } @@ -2511,12 +2174,7 @@ pub mod module { #[pyfunction] fn sysconf(name: SysconfName, vm: &VirtualMachine) -> PyResult { - crate::host_env::os::set_errno(0); - let r = unsafe { libc::sysconf(name.0) }; - if r == -1 && crate::host_env::os::get_errno() != 0 { - return Err(vm.new_last_errno_error()); - } - Ok(r) + rustpython_host_env::posix::sysconf(name.0).map_err(|err| err.into_pyexception(vm)) } #[pyattr] @@ -2559,10 +2217,10 @@ pub mod module { fn sendfile(args: SendFileArgs<'_>, vm: &VirtualMachine) -> PyResult { let mut file_offset = args.offset; - let res = nix::sys::sendfile::sendfile( + let res = rustpython_host_env::posix::sendfile( args.out_fd, args.in_fd, - Some(&mut file_offset), + &mut file_offset, args.count as usize, ) .map_err(|err| err.into_pyexception(vm))?; @@ -2609,11 +2267,11 @@ pub mod module { .map(|v| v.iter().map(|borrowed| &**borrowed).collect::>()); let trailers = trailers.as_deref(); - let (res, written) = 
nix::sys::sendfile::sendfile( + let (res, written) = rustpython_host_env::posix::sendfile( args.in_fd, args.out_fd, args.offset, - Some(count), + count, headers, trailers, ); @@ -2628,11 +2286,6 @@ pub mod module { Ok(vm.ctx.new_int(written as u64).into()) } - #[cfg(target_os = "linux")] - unsafe fn sys_getrandom(buf: *mut libc::c_void, buflen: usize, flags: u32) -> isize { - unsafe { libc::syscall(libc::SYS_getrandom, buf, buflen, flags as usize) as _ } - } - #[cfg(target_os = "linux")] #[pyfunction] fn getrandom(size: isize, flags: OptionalArg, vm: &VirtualMachine) -> PyResult> { @@ -2640,12 +2293,11 @@ pub mod module { .map_err(|_| vm.new_os_error(format!("Invalid argument for size: {size}")))?; let mut buf = Vec::with_capacity(size); unsafe { - let len = sys_getrandom( + let len = rustpython_host_env::posix::getrandom( buf.as_mut_ptr() as *mut libc::c_void, size, flags.unwrap_or(0), ) - .try_into() .map_err(|_| vm.new_last_os_error())?; buf.set_len(len); } @@ -2671,8 +2323,11 @@ pub mod module { #[pymodule(sub)] mod posix_sched { use crate::{ - AsObject, Py, PyObjectRef, PyResult, VirtualMachine, builtins::PyTupleRef, - convert::ToPyObject, function::FuncArgs, types::PyStructSequence, + AsObject, Py, PyObjectRef, PyResult, VirtualMachine, + builtins::PyTupleRef, + convert::{IntoPyException, ToPyObject}, + function::FuncArgs, + types::PyStructSequence, }; #[derive(FromArgs)] @@ -2750,12 +2405,7 @@ mod posix_sched { #[pyfunction] fn sched_getscheduler(pid: libc::pid_t, vm: &VirtualMachine) -> PyResult { - let policy = unsafe { libc::sched_getscheduler(pid) }; - if policy == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(policy) - } + rustpython_host_env::posix::sched_getscheduler(pid).map_err(|err| err.into_pyexception(vm)) } #[cfg(not(target_env = "musl"))] @@ -2773,23 +2423,14 @@ mod posix_sched { #[pyfunction] fn sched_setscheduler(args: SchedSetschedulerArgs, vm: &VirtualMachine) -> PyResult { let libc_sched_param = 
convert_sched_param(&args.sched_param, vm)?; - let policy = unsafe { libc::sched_setscheduler(args.pid, args.policy, &libc_sched_param) }; - if policy == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(policy) - } + rustpython_host_env::posix::sched_setscheduler(args.pid, args.policy, &libc_sched_param) + .map_err(|err| err.into_pyexception(vm)) } #[pyfunction] fn sched_getparam(pid: libc::pid_t, vm: &VirtualMachine) -> PyResult { - let param = unsafe { - let mut param = core::mem::MaybeUninit::uninit(); - if -1 == libc::sched_getparam(pid, param.as_mut_ptr()) { - return Err(vm.new_last_errno_error()); - } - param.assume_init() - }; + let param = rustpython_host_env::posix::sched_getparam(pid) + .map_err(|err| err.into_pyexception(vm))?; Ok(PySchedParam::from_data( SchedParamData { sched_priority: param.sched_priority.to_pyobject(vm), @@ -2811,11 +2452,7 @@ mod posix_sched { #[pyfunction] fn sched_setparam(args: SchedSetParamArgs, vm: &VirtualMachine) -> PyResult { let libc_sched_param = convert_sched_param(&args.sched_param, vm)?; - let ret = unsafe { libc::sched_setparam(args.pid, &libc_sched_param) }; - if ret == -1 { - Err(vm.new_last_errno_error()) - } else { - Ok(ret) - } + rustpython_host_env::posix::sched_setparam(args.pid, &libc_sched_param) + .map_err(|err| err.into_pyexception(vm)) } } diff --git a/crates/vm/src/stdlib/pwd.rs b/crates/vm/src/stdlib/pwd.rs index b898625906f..e2f987ce019 100644 --- a/crates/vm/src/stdlib/pwd.rs +++ b/crates/vm/src/stdlib/pwd.rs @@ -11,7 +11,7 @@ mod pwd { exceptions, types::PyStructSequence, }; - use nix::unistd::{self, User}; + use rustpython_host_env::pwd as host_pwd; #[cfg(not(target_os = "android"))] use crate::{PyObjectRef, convert::ToPyObject}; @@ -34,26 +34,16 @@ mod pwd { #[pyclass(with(PyStructSequence))] impl PyPasswd {} - impl From for PasswdData { - fn from(user: User) -> Self { - // this is just a pain... 
- let cstr_lossy = |s: alloc::ffi::CString| { - s.into_string() - .unwrap_or_else(|e| e.into_cstring().to_string_lossy().into_owned()) - }; - let pathbuf_lossy = |p: std::path::PathBuf| { - p.into_os_string() - .into_string() - .unwrap_or_else(|s| s.to_string_lossy().into_owned()) - }; + impl From for PasswdData { + fn from(user: host_pwd::Passwd) -> Self { Self { pw_name: user.name, - pw_passwd: cstr_lossy(user.passwd), - pw_uid: user.uid.as_raw(), - pw_gid: user.gid.as_raw(), - pw_gecos: cstr_lossy(user.gecos), - pw_dir: pathbuf_lossy(user.dir), - pw_shell: pathbuf_lossy(user.shell), + pw_passwd: user.passwd, + pw_uid: user.uid, + pw_gid: user.gid, + pw_gecos: user.gecos, + pw_dir: user.dir, + pw_shell: user.shell, } } } @@ -64,7 +54,7 @@ mod pwd { if pw_name.contains('\0') { return Err(exceptions::cstring_error(vm)); } - let user = User::from_name(name.as_str()).ok().flatten(); + let user = host_pwd::getpwnam(name.as_str()); let user = user.ok_or_else(|| { vm.new_key_error( vm.ctx @@ -77,11 +67,9 @@ mod pwd { #[pyfunction] fn getpwuid(uid: PyIntRef, vm: &VirtualMachine) -> PyResult { - let uid_t = libc::uid_t::try_from(uid.as_bigint()) - .map(unistd::Uid::from_raw) - .ok(); + let uid_t = libc::uid_t::try_from(uid.as_bigint()).ok(); let user = uid_t - .map(User::from_uid) + .map(host_pwd::getpwuid) .transpose() .map_err(|err| err.into_pyexception(vm))? .flatten(); @@ -99,19 +87,10 @@ mod pwd { #[cfg(not(target_os = "android"))] #[pyfunction] fn getpwall(vm: &VirtualMachine) -> Vec { - // setpwent, getpwent, etc are not thread safe. 
Could use fgetpwent_r, but this is easier - static GETPWALL: parking_lot::Mutex<()> = parking_lot::Mutex::new(()); - let _guard = GETPWALL.lock(); - let mut list = Vec::new(); - - unsafe { libc::setpwent() }; - while let Some(ptr) = core::ptr::NonNull::new(unsafe { libc::getpwent() }) { - let user = User::from(unsafe { ptr.as_ref() }); - let passwd = PasswdData::from(user).to_pyobject(vm); - list.push(passwd); - } - unsafe { libc::endpwent() }; - - list + host_pwd::getpwall() + .into_iter() + .map(PasswdData::from) + .map(|passwd| passwd.to_pyobject(vm)) + .collect() } } diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 02beec87bc5..5e3ba265eaa 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -57,15 +57,6 @@ pub mod sys { io::{IsTerminal, Read, Write}, }; - #[cfg(windows)] - use windows_sys::Win32::{ - Foundation::MAX_PATH, - Storage::FileSystem::{ - GetFileVersionInfoSizeW, GetFileVersionInfoW, VS_FIXEDFILEINFO, VerQueryValueW, - }, - System::LibraryLoader::{GetModuleFileNameW, GetModuleHandleW}, - }; - // Rust target triple (e.g., "x86_64-unknown-linux-gnu") pub(crate) const RUST_MULTIARCH: &str = env!("RUSTPYTHON_TARGET_TRIPLE"); @@ -658,7 +649,7 @@ pub mod sys { fn _git(vm: &VirtualMachine) -> PyTupleRef { vm.new_tuple(( ascii!("RustPython"), - version::get_git_identifier(), + version::GIT_IDENTIFIER, version::GIT_REVISION, )) } @@ -667,7 +658,8 @@ pub mod sys { fn implementation(vm: &VirtualMachine) -> PyRef { const NAME: &str = "rustpython"; - let cache_tag = format!("{NAME}-{}_{}", version::MAJOR_IMPL, version::MINOR_IMPL); + // cache tag uses 'cpython' because our compiler is cpython compatible + let cache_tag = format!("cpython-{}{}", version::MAJOR, version::MINOR); let ctx = &vm.ctx; py_namespace!(vm, { "name" => ctx.new_str(NAME), @@ -715,17 +707,13 @@ pub mod sys { vm.ctx.none() } - #[pyattr] - fn version(_vm: &VirtualMachine) -> String { - version::get_version() - } + #[pyattr(name = "version")] 
+ const VERSION: &str = version::RUSTPYTHON_VERSION; + // Note: This is Python DLL version in CPython, but we arbitrary fill it for compatibility #[cfg(windows)] - #[pyattr] - fn winver(_vm: &VirtualMachine) -> String { - // Note: This is Python DLL version in CPython, but we arbitrary fill it for compatibility - version::get_winver_number() - } + #[pyattr(name = "winver")] + const WINVER: &str = version::WINVER; #[pyattr] fn _xoptions(vm: &VirtualMachine) -> PyDictRef { @@ -767,11 +755,6 @@ pub mod sys { Ok(()) } - #[pyfunction] - fn audit(_args: FuncArgs) { - // TODO: sys.audit implementation - } - #[pyfunction] const fn _is_gil_enabled() -> bool { false // RustPython has no GIL (like free-threaded Python) @@ -985,25 +968,40 @@ pub mod sys { #[pyfunction] fn _getframe(offset: OptionalArg, vm: &VirtualMachine) -> PyResult { let offset = offset.into_option().unwrap_or(0); - let frames = vm.frames.borrow(); - if offset >= frames.len() { - return Err(vm.new_value_error("call stack is not deep enough")); + let frame_ref = { + let frames = vm.frames.borrow(); + if offset >= frames.len() { + return Err(vm.new_value_error("call stack is not deep enough")); + } + + let idx = frames.len() - offset - 1; + // SAFETY: the FrameRef is alive on the call stack while it's in the Vec + let py: &crate::Py = unsafe { frames[idx].as_ref() }; + py.to_owned() + }; + + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("sys._getframe"), frame_ref.to_owned()), vm)?; } - let idx = frames.len() - offset - 1; - // SAFETY: the FrameRef is alive on the call stack while it's in the Vec - let py: &crate::Py = unsafe { frames[idx].as_ref() }; - Ok(py.to_owned()) + + Ok(frame_ref) } #[pyfunction] - fn _getframemodulename(depth: OptionalArg, vm: &VirtualMachine) -> PyObjectRef { + fn _getframemodulename( + depth: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { let depth = depth.into_option().unwrap_or(0); + if let Ok(audit) = vm.sys_module.get_attr("audit", 
vm) { + audit.call((vm.ctx.new_str("sys._getframemodulename"), depth), vm)?; + } // Get the frame at the specified depth let func_obj = { let frames = vm.frames.borrow(); if depth >= frames.len() { - return vm.ctx.none(); + return Ok(vm.ctx.none()); } let idx = frames.len() - depth - 1; // SAFETY: the FrameRef is alive on the call stack while it's in the Vec @@ -1012,7 +1010,7 @@ pub mod sys { }; // If the frame has a function object, return its __module__ attribute - if let Some(func_obj) = func_obj { + Ok(if let Some(func_obj) = func_obj { func_obj .get_attr(identifier!(vm, __module__), vm) .unwrap_or_else( @@ -1021,7 +1019,7 @@ pub mod sys { ) } else { vm.ctx.none() - } + }) } /// Return a dictionary mapping each thread's identifier to the topmost stack frame @@ -1082,119 +1080,22 @@ pub mod sys { vm.trace_func.borrow().clone() } - #[cfg(windows)] - fn get_kernel32_version() -> std::io::Result<(u32, u32, u32)> { - use crate::host_env::windows::ToWideString; - unsafe { - // Create a wide string for "kernel32.dll" - let module_name: Vec = std::ffi::OsStr::new("kernel32.dll").to_wide_with_nul(); - let h_kernel32 = GetModuleHandleW(module_name.as_ptr()); - if h_kernel32.is_null() { - return Err(std::io::Error::last_os_error()); - } - - // Prepare a buffer for the module file path - let mut kernel32_path = [0u16; MAX_PATH as usize]; - let len = GetModuleFileNameW( - h_kernel32, - kernel32_path.as_mut_ptr(), - kernel32_path.len() as u32, - ); - if len == 0 { - return Err(std::io::Error::last_os_error()); - } - - // Get the size of the version information block - let ver_block_size = - GetFileVersionInfoSizeW(kernel32_path.as_ptr(), core::ptr::null_mut()); - if ver_block_size == 0 { - return Err(std::io::Error::last_os_error()); - } - - // Allocate a buffer to hold the version information - let mut ver_block = vec![0u8; ver_block_size as usize]; - if GetFileVersionInfoW( - kernel32_path.as_ptr(), - 0, - ver_block_size, - ver_block.as_mut_ptr() as *mut _, - ) == 0 - { - 
return Err(std::io::Error::last_os_error()); - } - - // Prepare an empty sub-block string (L"") as required by VerQueryValueW - let sub_block: Vec = std::ffi::OsStr::new("").to_wide_with_nul(); - - let mut ffi_ptr: *mut VS_FIXEDFILEINFO = core::ptr::null_mut(); - let mut ffi_len: u32 = 0; - if VerQueryValueW( - ver_block.as_ptr() as *const _, - sub_block.as_ptr(), - &mut ffi_ptr as *mut *mut VS_FIXEDFILEINFO as *mut *mut _, - &mut ffi_len as *mut u32, - ) == 0 - || ffi_ptr.is_null() - { - return Err(std::io::Error::last_os_error()); - } - - // Extract the version numbers from the VS_FIXEDFILEINFO structure. - let ffi = *ffi_ptr; - let real_major = (ffi.dwProductVersionMS >> 16) & 0xFFFF; - let real_minor = ffi.dwProductVersionMS & 0xFFFF; - let real_build = (ffi.dwProductVersionLS >> 16) & 0xFFFF; - - Ok((real_major, real_minor, real_build)) - } - } - #[cfg(windows)] #[pyfunction] fn getwindowsversion(vm: &VirtualMachine) -> PyResult { - use std::ffi::OsString; - use std::os::windows::ffi::OsStringExt; - use windows_sys::Win32::System::SystemInformation::{ - GetVersionExW, OSVERSIONINFOEXW, OSVERSIONINFOW, - }; - - let mut version: OSVERSIONINFOEXW = unsafe { core::mem::zeroed() }; - version.dwOSVersionInfoSize = core::mem::size_of::() as u32; - let result = unsafe { - let os_vi = &mut version as *mut OSVERSIONINFOEXW as *mut OSVERSIONINFOW; - // SAFETY: GetVersionExW accepts a pointer of OSVERSIONINFOW, but windows-sys crate's type currently doesn't allow to do so. 
- // https://docs.microsoft.com/en-us/windows/win32/api/sysinfoapi/nf-sysinfoapi-getversionexw#parameters - GetVersionExW(os_vi) - }; - - if result == 0 { - return Err(vm.new_os_error("failed to get windows version".to_owned())); - } - - let service_pack = { - let (last, _) = version - .szCSDVersion - .iter() - .take_while(|&x| x != &0) - .enumerate() - .last() - .unwrap_or((0, &0)); - let sp = OsString::from_wide(&version.szCSDVersion[..last]); - sp.into_string() - .map_err(|_| vm.new_os_error("service pack is not ASCII".to_owned()))? - }; - let real_version = get_kernel32_version().map_err(|e| vm.new_os_error(e.to_string()))?; + let version = crate::host_env::windows::get_windows_version() + .map_err(|e| vm.new_os_error(e.to_string()))?; let winver = WindowsVersionData { - major: real_version.0, - minor: real_version.1, - build: real_version.2, - platform: version.dwPlatformId, - service_pack, - service_pack_major: version.wServicePackMajor, - service_pack_minor: version.wServicePackMinor, - suite_mask: version.wSuiteMask, - product_type: version.wProductType, - platform_version: (real_version.0, real_version.1, real_version.2), // TODO Provide accurate version, like CPython impl + major: version.major, + minor: version.minor, + build: version.build, + platform: version.platform, + service_pack: version.service_pack, + service_pack_major: version.service_pack_major, + service_pack_minor: version.service_pack_minor, + suite_mask: version.suite_mask, + product_type: version.product_type, + platform_version: (version.major, version.minor, version.build), // TODO Provide accurate version, like CPython impl }; Ok(PyWindowsVersion::from_data(winver, vm)) } @@ -1830,6 +1731,60 @@ pub mod sys { #[pyclass(with(PyStructSequence))] impl PyUnraisableHookArgs {} + + pub(crate) fn run_audit_hooks( + event: PyStrRef, + args: &PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + let hooks = vm.audit_hooks.borrow().clone(); + + if hooks.is_empty() { + return Ok(()); + } + 
+ for hook in hooks { + hook.call((event.clone(), args.clone()), vm)?; + } + + Ok(()) + } + + #[pyfunction] + fn audit(event: PyStrRef, args: PosArgs, vm: &VirtualMachine) -> PyResult<()> { + if vm.audit_hooks.borrow().is_empty() { + return Ok(()); + } + + let args_tup = vm.ctx.new_tuple(args.into_vec()).into(); + run_audit_hooks(event, &args_tup, vm) + } + + #[pyfunction] + fn addaudithook(hook: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let hooks = vm.audit_hooks.borrow().clone(); + + if hooks.is_empty() { + vm.audit_hooks.borrow_mut().push(hook); + return Ok(()); + } + + let args: PyObjectRef = vm.ctx.new_tuple(vec![]).into(); + let event: PyObjectRef = vm.ctx.new_str("sys.addaudithook").into(); + + for existing_hook in hooks { + let Err(exc) = existing_hook.call((event.clone(), args.clone()), vm) else { + continue; + }; + if exc.class().fast_issubclass(vm.ctx.exceptions.runtime_error) { + return Ok(()); + } + return Err(exc); + } + + vm.audit_hooks.borrow_mut().push(hook); + Ok(()) + } } pub(crate) fn init_module(vm: &VirtualMachine, module: &Py, builtins: &Py) { diff --git a/crates/vm/src/stdlib/sys/monitoring.rs b/crates/vm/src/stdlib/sys/monitoring.rs index bf113c7937f..6e61692507d 100644 --- a/crates/vm/src/stdlib/sys/monitoring.rs +++ b/crates/vm/src/stdlib/sys/monitoring.rs @@ -598,6 +598,16 @@ fn register_callback( let tool = check_valid_tool(tool_id, vm)?; let event_id = parse_single_event(event, vm)?; + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call( + ( + vm.ctx.new_str("sys.monitoring.register_callback"), + func.clone(), + ), + vm, + )?; + } + let mut state = vm.state.monitoring.lock(); let prev = state .callbacks diff --git a/crates/vm/src/stdlib/time.rs b/crates/vm/src/stdlib/time.rs index 31e41a89b08..a56dffe08c7 100644 --- a/crates/vm/src/stdlib/time.rs +++ b/crates/vm/src/stdlib/time.rs @@ -8,33 +8,23 @@ pub use decl::time; pub(crate) use decl::module_def; -#[cfg(not(target_env = "msvc"))] -#[cfg(not(target_arch 
= "wasm32"))] -unsafe extern "C" { - #[cfg(not(target_os = "freebsd"))] - #[link_name = "daylight"] - static c_daylight: core::ffi::c_int; - // pub static dstbias: std::ffi::c_int; - #[link_name = "timezone"] - static c_timezone: core::ffi::c_long; - #[link_name = "tzname"] - static c_tzname: [*const core::ffi::c_char; 2]; - #[link_name = "tzset"] - fn c_tzset(); -} - #[pymodule(name = "time", with(#[cfg(any(unix, windows))] platform))] -pub mod decl { +mod decl { + #![allow(unreachable_pub)] + + #[cfg(any(unix, windows))] + use crate::builtins::PyBaseExceptionRef; use crate::{ AsObject, Py, PyObjectRef, PyResult, VirtualMachine, - builtins::{PyBaseExceptionRef, PyStrRef, PyTypeRef}, + builtins::{PyStrRef, PyTypeRef}, function::{Either, FuncArgs, OptionalArg}, types::{PyStructSequence, struct_sequence_new}, }; #[cfg(any(unix, windows))] - use crate::{common::wtf8::Wtf8Buf, convert::ToPyObject}; - #[cfg(unix)] - use alloc::ffi::CString; + use crate::{ + common::wtf8::Wtf8Buf, + convert::{ToPyException, ToPyObject}, + }; #[cfg(not(any(unix, windows)))] use chrono::{ DateTime, Datelike, TimeZone, Timelike, @@ -44,19 +34,6 @@ pub mod decl { #[cfg(any(unix, windows))] use rustpython_host_env::time::asctime_from_tm; use rustpython_host_env::time::{self as host_time}; - #[cfg(target_env = "msvc")] - #[cfg(not(target_arch = "wasm32"))] - use windows_sys::Win32::System::Time::TIME_ZONE_INFORMATION; - - #[cfg(windows)] - unsafe extern "C" { - fn wcsftime( - s: *mut libc::wchar_t, - max: libc::size_t, - format: *const libc::wchar_t, - tm: *const libc::tm, - ) -> libc::size_t; - } #[allow(dead_code)] pub(super) const SEC_TO_MS: i64 = host_time::SEC_TO_MS; @@ -97,6 +74,10 @@ pub mod decl { #[pyfunction] fn sleep(seconds: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("time.sleep"), seconds.clone()), vm)?; + } + let seconds_type_name = seconds.class().name().to_owned(); let dur = 
seconds.try_into_value::(vm).map_err(|e| { if e.class().is(vm.ctx.exceptions.value_error) @@ -125,18 +106,17 @@ pub mod decl { if remaining.is_zero() { break; } - let ts = nix::sys::time::TimeSpec::from(remaining); - let (res, err) = vm.allow_threads(|| { - let r = unsafe { libc::nanosleep(ts.as_ref(), core::ptr::null_mut()) }; - (r, nix::Error::last_raw()) - }); - if res == 0 { - break; - } - if err != libc::EINTR { - return Err( - vm.new_os_error(format!("nanosleep: {}", nix::Error::from_raw(err))) - ); + let sleep_result = vm.allow_threads(|| host_time::nanosleep(remaining)); + match sleep_result { + Ok(()) => break, + Err(err) if err.raw_os_error() == Some(libc::EINTR) => {} + Err(err) => { + let errno = err.raw_os_error().unwrap_or(0); + return Err(vm.new_os_error(format!( + "nanosleep: {}", + host_time::nix_errno_display(errno) + ))); + } } // EINTR: run signal handlers, then retry with remaining time vm.check_signals()?; @@ -217,7 +197,7 @@ pub mod decl { #[cfg(target_env = "msvc")] #[cfg(not(target_arch = "wasm32"))] - pub(super) fn get_tz_info() -> TIME_ZONE_INFORMATION { + pub(super) fn get_tz_info() -> host_time::WindowsTimeZoneInfo { host_time::get_tz_info() } @@ -231,7 +211,7 @@ pub mod decl { #[pyattr] fn altzone(_vm: &VirtualMachine) -> core::ffi::c_long { // TODO: RUSTPYTHON; Add support for using the C altzone - unsafe { super::c_timezone - 3600 } + crate::host_env::time::tz::timezone() - 3600 } #[cfg(target_env = "msvc")] @@ -240,14 +220,14 @@ pub mod decl { fn altzone(_vm: &VirtualMachine) -> i32 { let info = get_tz_info(); // https://users.rust-lang.org/t/accessing-tzname-and-similar-constants-in-windows/125771/3 - (info.Bias + info.StandardBias) * 60 - 3600 + (info.bias + info.standard_bias) * 60 - 3600 } #[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] #[pyattr] fn timezone(_vm: &VirtualMachine) -> core::ffi::c_long { - unsafe { super::c_timezone } + crate::host_env::time::tz::timezone() } #[cfg(target_env = "msvc")] @@ 
-256,7 +236,7 @@ pub mod decl { fn timezone(_vm: &VirtualMachine) -> i32 { let info = get_tz_info(); // https://users.rust-lang.org/t/accessing-tzname-and-similar-constants-in-windows/125771/3 - (info.Bias + info.StandardBias) * 60 + (info.bias + info.standard_bias) * 60 } #[cfg(not(target_os = "freebsd"))] @@ -264,7 +244,7 @@ pub mod decl { #[cfg(not(target_arch = "wasm32"))] #[pyattr] fn daylight(_vm: &VirtualMachine) -> core::ffi::c_int { - unsafe { super::c_daylight } + crate::host_env::time::tz::daylight() } #[cfg(target_env = "msvc")] @@ -273,7 +253,7 @@ pub mod decl { fn daylight(_vm: &VirtualMachine) -> i32 { let info = get_tz_info(); // https://users.rust-lang.org/t/accessing-tzname-and-similar-constants-in-windows/125771/3 - (info.StandardBias != info.DaylightBias) as i32 + (info.standard_bias != info.daylight_bias) as i32 } #[cfg(not(target_env = "msvc"))] @@ -281,13 +261,7 @@ pub mod decl { #[pyattr] fn tzname(vm: &VirtualMachine) -> crate::builtins::PyTupleRef { use crate::builtins::tuple::IntoPyTuple; - - unsafe fn to_str(s: *const core::ffi::c_char) -> String { - unsafe { core::ffi::CStr::from_ptr(s) } - .to_string_lossy() - .into_owned() - } - unsafe { (to_str(super::c_tzname[0]), to_str(super::c_tzname[1])) }.into_pytuple(vm) + crate::host_env::time::tz::tzname_strings().into_pytuple(vm) } #[cfg(target_env = "msvc")] @@ -296,13 +270,7 @@ pub mod decl { fn tzname(vm: &VirtualMachine) -> crate::builtins::PyTupleRef { use crate::builtins::tuple::IntoPyTuple; let info = get_tz_info(); - let standard = widestring::decode_utf16_lossy(info.StandardName) - .take_while(|&c| c != '\0') - .collect::(); - let daylight = widestring::decode_utf16_lossy(info.DaylightName) - .take_while(|&c| c != '\0') - .collect::(); - let tz_name = (&*standard, &*daylight); + let tz_name = (&*info.standard_name, &*info.daylight_name); tz_name.into_pytuple(vm) } @@ -337,19 +305,12 @@ pub mod decl { } } - #[cfg(any(unix, windows))] - struct CheckedTm { - tm: libc::tm, - 
#[cfg(unix)] - zone: Option, - } - #[cfg(any(unix, windows))] fn checked_tm_from_struct_time( t: &StructTimeData, vm: &VirtualMachine, func_name: &'static str, - ) -> PyResult { + ) -> PyResult { let invalid_tuple = || vm.new_type_error(format!("{func_name}(): illegal time tuple argument")); let classify_err = |e: PyBaseExceptionRef| { @@ -400,45 +361,6 @@ pub mod decl { .try_into_value(vm) .map_err(classify_err)?; - let mut tm: libc::tm = unsafe { core::mem::zeroed() }; - tm.tm_year = year - 1900; - tm.tm_mon = tm_mon; - tm.tm_mday = tm_mday; - tm.tm_hour = tm_hour; - tm.tm_min = tm_min; - tm.tm_sec = tm_sec; - tm.tm_wday = tm_wday; - tm.tm_yday = tm_yday; - tm.tm_isdst = tm_isdst; - - if tm.tm_mon == -1 { - tm.tm_mon = 0; - } else if tm.tm_mon < 0 || tm.tm_mon > 11 { - return Err(vm.new_value_error("month out of range")); - } - if tm.tm_mday == 0 { - tm.tm_mday = 1; - } else if tm.tm_mday < 0 || tm.tm_mday > 31 { - return Err(vm.new_value_error("day of month out of range")); - } - if tm.tm_hour < 0 || tm.tm_hour > 23 { - return Err(vm.new_value_error("hour out of range")); - } - if tm.tm_min < 0 || tm.tm_min > 59 { - return Err(vm.new_value_error("minute out of range")); - } - if tm.tm_sec < 0 || tm.tm_sec > 61 { - return Err(vm.new_value_error("seconds out of range")); - } - if tm.tm_wday < 0 { - return Err(vm.new_value_error("day of week out of range")); - } - if tm.tm_yday == -1 { - tm.tm_yday = 0; - } else if tm.tm_yday < 0 || tm.tm_yday > 365 { - return Err(vm.new_value_error("day of year out of range")); - } - #[cfg(unix)] { use crate::builtins::PyUtf8StrRef; @@ -450,28 +372,47 @@ pub mod decl { .clone() .try_into_value(vm) .map_err(|_| invalid_tuple())?; + Some(zone.as_str().to_owned()) + }; + let gmtoff = if t.tm_gmtoff.is(&vm.ctx.none) { + None + } else { Some( - CString::new(zone.as_str()) - .map_err(|_| vm.new_value_error("embedded null character"))?, + t.tm_gmtoff + .clone() + .try_into_value::(vm) + .map_err(classify_err)?, ) }; - if let Some(zone) = 
&zone { - tm.tm_zone = zone.as_ptr().cast_mut(); - } - if !t.tm_gmtoff.is(&vm.ctx.none) { - let gmtoff: i64 = t - .tm_gmtoff - .clone() - .try_into_value(vm) - .map_err(classify_err)?; - tm.tm_gmtoff = gmtoff as _; - } - - Ok(CheckedTm { tm, zone }) + host_time::checked_tm_from_parts(host_time::CheckedTmParts { + year: year.into(), + tm_mon, + tm_mday, + tm_hour, + tm_min, + tm_sec, + tm_wday, + tm_yday, + tm_isdst, + zone, + gmtoff, + }) + .map_err(|err| err.to_pyexception(vm)) } #[cfg(windows)] { - Ok(CheckedTm { tm }) + host_time::checked_tm_from_parts(host_time::CheckedTmParts { + year: year.into(), + tm_mon, + tm_mday, + tm_hour, + tm_min, + tm_sec, + tm_wday, + tm_yday, + tm_isdst, + }) + .map_err(|err| err.to_pyexception(vm)) } } @@ -605,7 +546,11 @@ pub mod decl { } #[cfg(any(unix, windows))] - fn strftime_crt(format: &PyStrRef, checked_tm: CheckedTm, vm: &VirtualMachine) -> PyResult { + fn strftime_crt( + format: &PyStrRef, + checked_tm: host_time::CheckedTm, + vm: &VirtualMachine, + ) -> PyResult { #[cfg(unix)] let _keep_zone_alive = &checked_tm.zone; let mut tm = checked_tm.tm; @@ -620,62 +565,14 @@ pub mod decl { } } - #[cfg(unix)] - fn strftime_ascii(fmt: &str, tm: &libc::tm, vm: &VirtualMachine) -> PyResult { - let fmt_c = - CString::new(fmt).map_err(|_| vm.new_value_error("embedded null character"))?; - let mut size = 1024usize; - let max_scale = 256usize.saturating_mul(fmt.len().max(1)); - loop { - let mut out = vec![0u8; size]; - let written = unsafe { - libc::strftime( - out.as_mut_ptr().cast(), - out.len(), - fmt_c.as_ptr(), - tm as *const libc::tm, - ) - }; - if written > 0 || size >= max_scale { - return Ok(String::from_utf8_lossy(&out[..written]).into_owned()); - } - size = size.saturating_mul(2); - } - } - - #[cfg(windows)] - fn strftime_ascii(fmt: &str, tm: &libc::tm, vm: &VirtualMachine) -> PyResult { - if fmt.contains('\0') { - return Err(vm.new_value_error("embedded null character")); - } - // Use wcsftime for proper Unicode output (e.g. 
%Z timezone names) - let fmt_wide: Vec = fmt.encode_utf16().chain(core::iter::once(0)).collect(); - let mut size = 1024usize; - let max_scale = 256usize.saturating_mul(fmt.len().max(1)); - loop { - let mut out = vec![0u16; size]; - let written = unsafe { - rustpython_host_env::suppress_iph!(wcsftime( - out.as_mut_ptr(), - out.len(), - fmt_wide.as_ptr(), - tm as *const libc::tm, - )) - }; - if written > 0 || size >= max_scale { - return Ok(String::from_utf16_lossy(&out[..written])); - } - size = size.saturating_mul(2); - } - } - let mut out = Wtf8Buf::new(); let mut ascii = String::new(); for codepoint in format.as_wtf8().code_points() { if codepoint.to_u32() == 0 { if !ascii.is_empty() { - let part = strftime_ascii(&ascii, &tm, vm)?; + let part = host_time::strftime_ascii(&ascii, &tm) + .map_err(|_| vm.new_value_error("embedded null character"))?; out.extend(part.chars()); ascii.clear(); } @@ -690,14 +587,16 @@ pub mod decl { } if !ascii.is_empty() { - let part = strftime_ascii(&ascii, &tm, vm)?; + let part = host_time::strftime_ascii(&ascii, &tm) + .map_err(|_| vm.new_value_error("embedded null character"))?; out.extend(part.chars()); ascii.clear(); } out.push(codepoint); } if !ascii.is_empty() { - let part = strftime_ascii(&ascii, &tm, vm)?; + let part = host_time::strftime_ascii(&ascii, &tm) + .map_err(|_| vm.new_value_error("embedded null character"))?; out.extend(part.chars()); } Ok(out.to_pyobject(vm)) @@ -786,18 +685,9 @@ pub mod decl { #[cfg(all(target_arch = "wasm32", target_os = "emscripten"))] fn get_process_time(vm: &VirtualMachine) -> PyResult { - let t: libc::tms = unsafe { - let mut t = core::mem::MaybeUninit::uninit(); - if libc::times(t.as_mut_ptr()) == -1 { - return Err(vm.new_os_error("Failed to get clock time".to_owned())); - } - t.assume_init() - }; - let freq = unsafe { libc::sysconf(libc::_SC_CLK_TCK) }; - - Ok(Duration::from_nanos( - time_muldiv(t.tms_utime, SEC_TO_NS, freq) + time_muldiv(t.tms_stime, SEC_TO_NS, freq), - )) + let times = 
host_time::process_times() + .map_err(|_| vm.new_os_error("Failed to get clock time".to_owned()))?; + Ok(Duration::from_secs_f64(times.user + times.system)) } #[cfg(not(any( @@ -943,31 +833,32 @@ pub mod decl { return Err(vm.new_overflow_error("year out of range")); } - let mut tm: libc::tm = unsafe { core::mem::zeroed() }; - tm.tm_sec = t.tm_sec.clone().try_into_value(vm).map_err(classify_err)?; - tm.tm_min = t.tm_min.clone().try_into_value(vm).map_err(classify_err)?; - tm.tm_hour = t.tm_hour.clone().try_into_value(vm).map_err(classify_err)?; - tm.tm_mday = t.tm_mday.clone().try_into_value(vm).map_err(classify_err)?; - tm.tm_mon = t - .tm_mon - .clone() - .try_into_value::(vm) - .map_err(classify_err)? - - 1; - tm.tm_year = year - 1900; - tm.tm_wday = -1; - tm.tm_yday = t - .tm_yday - .clone() - .try_into_value::(vm) - .map_err(classify_err)? - - 1; - tm.tm_isdst = t - .tm_isdst - .clone() - .try_into_value(vm) - .map_err(classify_err)?; - Ok(tm) + host_time::mktime_tm_from_parts(host_time::MktimeTmParts { + year, + tm_sec: t.tm_sec.clone().try_into_value(vm).map_err(classify_err)?, + tm_min: t.tm_min.clone().try_into_value(vm).map_err(classify_err)?, + tm_hour: t.tm_hour.clone().try_into_value(vm).map_err(classify_err)?, + tm_mday: t.tm_mday.clone().try_into_value(vm).map_err(classify_err)?, + tm_mon: t + .tm_mon + .clone() + .try_into_value::(vm) + .map_err(classify_err)?, + tm_yday: t + .tm_yday + .clone() + .try_into_value::(vm) + .map_err(classify_err)?, + tm_isdst: t + .tm_isdst + .clone() + .try_into_value(vm) + .map_err(classify_err)?, + }) + .map_err(|err| match err { + host_time::CheckedTmError::YearOutOfRange => vm.new_overflow_error("year out of range"), + _ => vm.new_type_error("mktime(): illegal time tuple argument"), + }) } #[cfg(any(unix, windows))] @@ -1009,9 +900,7 @@ pub mod decl { ) -> PyResult<()> { #[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] - unsafe { - super::c_tzset() - }; + crate::host_env::time::tz::tzset(); 
__module_exec(vm, module); Ok(()) @@ -1030,9 +919,14 @@ mod platform { convert::IntoPyException, }; use core::time::Duration; - #[cfg_attr(target_env = "musl", allow(deprecated))] - use libc::time_t; - use nix::{sys::time::TimeSpec, time::ClockId}; + #[cfg(any( + target_os = "illumos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "solaris", + ))] + use rustpython_host_env::resource as host_resource; + use rustpython_host_env::time::{self as host_time, ClockId}; #[cfg(target_os = "solaris")] #[pyattr] @@ -1095,40 +989,33 @@ mod platform { } } - #[cfg_attr(target_env = "musl", allow(deprecated))] - pub(super) fn current_time_t() -> time_t { - unsafe { libc::time(core::ptr::null_mut()) } + pub(super) fn current_time_t() -> host_time::TimeT { + host_time::current_time_t() } - #[cfg_attr(target_env = "musl", allow(deprecated))] pub(super) fn gmtime_from_timestamp( - when: time_t, + when: host_time::TimeT, vm: &VirtualMachine, ) -> PyResult { - let mut out = core::mem::MaybeUninit::::uninit(); - let ret = unsafe { libc::gmtime_r(&when, out.as_mut_ptr()) }; - if ret.is_null() { + let Some(tm) = host_time::gmtime_from_timestamp(when) else { return Err(vm.new_overflow_error("timestamp out of range for platform time_t")); - } - Ok(struct_time_from_tm(vm, unsafe { out.assume_init() })) + }; + Ok(struct_time_from_tm(vm, tm)) } - #[cfg_attr(target_env = "musl", allow(deprecated))] pub(super) fn localtime_from_timestamp( - when: time_t, + when: host_time::TimeT, vm: &VirtualMachine, ) -> PyResult { - let mut out = core::mem::MaybeUninit::::uninit(); - let ret = unsafe { libc::localtime_r(&when, out.as_mut_ptr()) }; - if ret.is_null() { + let Some(tm) = host_time::localtime_from_timestamp(when) else { return Err(vm.new_overflow_error("timestamp out of range for platform time_t")); - } - Ok(struct_time_from_tm(vm, unsafe { out.assume_init() })) + }; + Ok(struct_time_from_tm(vm, tm)) } pub(super) fn unix_mktime(t: &StructTimeData, vm: &VirtualMachine) -> PyResult 
{ let mut tm = super::decl::tm_from_struct_time(t, vm)?; - let timestamp = unsafe { libc::mktime(&mut tm) }; + let timestamp = host_time::mktime(&mut tm); if timestamp == -1 && tm.tm_wday == -1 { return Err(vm.new_overflow_error("mktime argument out of range")); } @@ -1136,8 +1023,7 @@ mod platform { } fn get_clock_time(clk_id: ClockId, vm: &VirtualMachine) -> PyResult { - let ts = nix::time::clock_gettime(clk_id).map_err(|e| e.into_pyexception(vm))?; - Ok(ts.into()) + rustpython_host_env::time::clock_gettime(clk_id).map_err(|e| e.into_pyexception(vm)) } #[pyfunction] @@ -1153,23 +1039,8 @@ mod platform { #[cfg(not(target_os = "redox"))] #[pyfunction] fn clock_getres(clk_id: ClockId, vm: &VirtualMachine) -> PyResult { - let ts = nix::time::clock_getres(clk_id).map_err(|e| e.into_pyexception(vm))?; - Ok(Duration::from(ts).as_secs_f64()) - } - - #[cfg(not(target_os = "redox"))] - #[cfg(not(target_vendor = "apple"))] - fn set_clock_time(clk_id: ClockId, timespec: TimeSpec, vm: &VirtualMachine) -> PyResult<()> { - nix::time::clock_settime(clk_id, timespec).map_err(|e| e.into_pyexception(vm)) - } - - #[cfg(not(target_os = "redox"))] - #[cfg(target_os = "macos")] - fn set_clock_time(clk_id: ClockId, timespec: TimeSpec, vm: &VirtualMachine) -> PyResult<()> { - // idk why nix disables clock_settime on macos - let ret = unsafe { libc::clock_settime(clk_id.as_raw(), timespec.as_ref()) }; - nix::Error::result(ret) - .map(drop) + rustpython_host_env::time::clock_getres(clk_id) + .map(|d| d.as_secs_f64()) .map_err(|e| e.into_pyexception(vm)) } @@ -1177,7 +1048,7 @@ mod platform { #[cfg(any(not(target_vendor = "apple"), target_os = "macos"))] #[pyfunction] fn clock_settime(clk_id: ClockId, time: Duration, vm: &VirtualMachine) -> PyResult<()> { - set_clock_time(clk_id, time.into(), vm) + rustpython_host_env::time::clock_settime(clk_id, time).map_err(|e| e.into_pyexception(vm)) } #[cfg(not(target_os = "redox"))] @@ -1185,8 +1056,8 @@ mod platform { #[cfg_attr(target_env = "musl", 
allow(deprecated))] #[pyfunction] fn clock_settime_ns(clk_id: ClockId, time: libc::time_t, vm: &VirtualMachine) -> PyResult<()> { - let ts = Duration::from_nanos(time as _).into(); - set_clock_time(clk_id, ts, vm) + rustpython_host_env::time::clock_settime(clk_id, Duration::from_nanos(time as _)) + .map_err(|e| e.into_pyexception(vm)) } // Requires all CLOCK constants available and clock_getres @@ -1271,7 +1142,8 @@ mod platform { #[cfg(target_os = "solaris")] pub(super) fn get_thread_time(vm: &VirtualMachine) -> PyResult { - Ok(Duration::from_nanos(unsafe { libc::gethrvtime() })) + let _ = vm; + Ok(host_time::gethrvtime_duration()) } #[cfg(not(any( @@ -1291,7 +1163,6 @@ mod platform { target_os = "openbsd", ))] pub(super) fn get_process_time(vm: &VirtualMachine) -> PyResult { - use nix::sys::resource::{UsageWho, getrusage}; fn from_timeval(tv: libc::timeval, vm: &VirtualMachine) -> PyResult { (|tv: libc::timeval| { let t = tv.tv_sec.checked_mul(SEC_TO_NS)?; @@ -1300,9 +1171,9 @@ mod platform { })(tv) .ok_or_else(|| vm.new_overflow_error("timestamp too large to convert to i64")) } - let ru = getrusage(UsageWho::RUSAGE_SELF).map_err(|e| e.into_pyexception(vm))?; - let utime = from_timeval(ru.user_time().into(), vm)?; - let stime = from_timeval(ru.system_time().into(), vm)?; + let ru = host_resource::getrusage(libc::RUSAGE_SELF).map_err(|e| e.into_pyexception(vm))?; + let utime = from_timeval(ru.ru_utime, vm)?; + let stime = from_timeval(ru.ru_stime, vm)?; Ok(Duration::from_nanos((utime + stime) as u64)) } @@ -1317,19 +1188,7 @@ mod platform { builtins::{PyNamespace, PyUtf8StrRef}, }; use core::time::Duration; - use windows_sys::Win32::{ - Foundation::FILETIME, - System::Performance::{QueryPerformanceCounter, QueryPerformanceFrequency}, - System::SystemInformation::{GetSystemTimeAdjustment, GetTickCount64}, - System::Threading::{GetCurrentProcess, GetCurrentThread, GetProcessTimes, GetThreadTimes}, - }; - - unsafe extern "C" { - fn _gmtime64_s(tm: *mut libc::tm, 
time: *const libc::time_t) -> libc::c_int; - fn _localtime64_s(tm: *mut libc::tm, time: *const libc::time_t) -> libc::c_int; - #[link_name = "_mktime64"] - fn c_mktime(tm: *mut libc::tm) -> libc::time_t; - } + use rustpython_host_env::time as host_time; fn struct_time_from_tm( vm: &VirtualMachine, @@ -1352,80 +1211,51 @@ mod platform { } } - #[cfg_attr(target_env = "musl", allow(deprecated))] - pub(super) fn current_time_t() -> libc::time_t { - unsafe { libc::time(core::ptr::null_mut()) } + pub(super) fn current_time_t() -> host_time::TimeT { + host_time::current_time_t() } - #[cfg_attr(target_env = "musl", allow(deprecated))] pub(super) fn gmtime_from_timestamp( - when: libc::time_t, + when: host_time::TimeT, vm: &VirtualMachine, ) -> PyResult { - let mut out = core::mem::MaybeUninit::::uninit(); - let err = unsafe { _gmtime64_s(out.as_mut_ptr(), &when) }; - if err != 0 { - return Err(vm.new_overflow_error("timestamp out of range for platform time_t")); - } - Ok(struct_time_from_tm( - vm, - unsafe { out.assume_init() }, - "UTC", - 0, - )) + let tm = host_time::gmtime_from_timestamp(when) + .ok_or_else(|| vm.new_overflow_error("timestamp out of range for platform time_t"))?; + Ok(struct_time_from_tm(vm, tm, "UTC", 0)) } - #[cfg_attr(target_env = "musl", allow(deprecated))] pub(super) fn localtime_from_timestamp( - when: libc::time_t, + when: host_time::TimeT, vm: &VirtualMachine, ) -> PyResult { - let mut out = core::mem::MaybeUninit::::uninit(); - let err = unsafe { _localtime64_s(out.as_mut_ptr(), &when) }; - if err != 0 { - return Err(vm.new_overflow_error("timestamp out of range for platform time_t")); - } - let tm = unsafe { out.assume_init() }; + let tm = host_time::localtime_from_timestamp(when) + .ok_or_else(|| vm.new_overflow_error("timestamp out of range for platform time_t"))?; // Get timezone info from Windows API let info = get_tz_info(); let (bias, name) = if tm.tm_isdst > 0 { - (info.DaylightBias, &info.DaylightName) + (info.daylight_bias, 
&info.daylight_name) } else { - (info.StandardBias, &info.StandardName) + (info.standard_bias, &info.standard_name) }; - let zone = widestring::decode_utf16_lossy(name.iter().copied()) - .take_while(|&c| c != '\0') - .collect::(); - #[allow(clippy::unnecessary_cast, reason = "info.Bias is not always i32")] - let gmtoff = -((info.Bias + bias) as i32) * 60; + let gmtoff = -(info.bias + bias) * 60; - Ok(struct_time_from_tm(vm, tm, &zone, gmtoff)) + Ok(struct_time_from_tm(vm, tm, name, gmtoff)) } pub(super) fn win_mktime(t: &StructTimeData, vm: &VirtualMachine) -> PyResult { let mut tm = super::decl::tm_from_struct_time(t, vm)?; - let timestamp = unsafe { rustpython_host_env::suppress_iph!(c_mktime(&mut tm)) }; + let timestamp = host_time::mktime(&mut tm); if timestamp == -1 && tm.tm_wday == -1 { return Err(vm.new_overflow_error("mktime argument out of range")); } Ok(timestamp as f64) } - fn u64_from_filetime(time: FILETIME) -> u64 { - let large: [u32; 2] = [time.dwLowDateTime, time.dwHighDateTime]; - unsafe { core::mem::transmute(large) } - } - fn win_perf_counter_frequency(vm: &VirtualMachine) -> PyResult { - let frequency = unsafe { - let mut freq = core::mem::MaybeUninit::uninit(); - if QueryPerformanceFrequency(freq.as_mut_ptr()) == 0 { - return Err(vm.new_last_os_error()); - } - freq.assume_init() - }; + let frequency = + host_time::query_performance_frequency().ok_or_else(|| vm.new_last_os_error())?; if frequency < 1 { Err(vm.new_runtime_error("invalid QueryPerformanceFrequency")) @@ -1446,11 +1276,7 @@ mod platform { } pub(super) fn get_perf_time(vm: &VirtualMachine) -> PyResult { - let ticks = unsafe { - let mut performance_count = core::mem::MaybeUninit::uninit(); - QueryPerformanceCounter(performance_count.as_mut_ptr()); - performance_count.assume_init() - }; + let ticks = host_time::query_performance_counter(); Ok(Duration::from_nanos(time_muldiv( ticks, @@ -1460,25 +1286,11 @@ mod platform { } fn get_system_time_adjustment(vm: &VirtualMachine) -> PyResult 
{ - let mut _time_adjustment = core::mem::MaybeUninit::uninit(); - let mut time_increment = core::mem::MaybeUninit::uninit(); - let mut _is_time_adjustment_disabled = core::mem::MaybeUninit::uninit(); - let time_increment = unsafe { - if GetSystemTimeAdjustment( - _time_adjustment.as_mut_ptr(), - time_increment.as_mut_ptr(), - _is_time_adjustment_disabled.as_mut_ptr(), - ) == 0 - { - return Err(vm.new_last_os_error()); - } - time_increment.assume_init() - }; - Ok(time_increment) + host_time::get_system_time_adjustment().ok_or_else(|| vm.new_last_os_error()) } pub(super) fn get_monotonic_time(vm: &VirtualMachine) -> PyResult { - let ticks = unsafe { GetTickCount64() }; + let ticks = host_time::tick_count64(); Ok(Duration::from_nanos( (ticks as i64) @@ -1523,52 +1335,14 @@ mod platform { } pub(super) fn get_thread_time(vm: &VirtualMachine) -> PyResult { - let (kernel_time, user_time) = unsafe { - let mut _creation_time = core::mem::MaybeUninit::uninit(); - let mut _exit_time = core::mem::MaybeUninit::uninit(); - let mut kernel_time = core::mem::MaybeUninit::uninit(); - let mut user_time = core::mem::MaybeUninit::uninit(); - - let thread = GetCurrentThread(); - if GetThreadTimes( - thread, - _creation_time.as_mut_ptr(), - _exit_time.as_mut_ptr(), - kernel_time.as_mut_ptr(), - user_time.as_mut_ptr(), - ) == 0 - { - return Err(vm.new_os_error("Failed to get clock time".to_owned())); - } - (kernel_time.assume_init(), user_time.assume_init()) - }; - let k_time = u64_from_filetime(kernel_time); - let u_time = u64_from_filetime(user_time); - Ok(Duration::from_nanos((k_time + u_time) * 100)) + let total = host_time::get_thread_time_100ns() + .ok_or_else(|| vm.new_os_error("Failed to get clock time".to_owned()))?; + Ok(Duration::from_nanos(total * 100)) } pub(super) fn get_process_time(vm: &VirtualMachine) -> PyResult { - let (kernel_time, user_time) = unsafe { - let mut _creation_time = core::mem::MaybeUninit::uninit(); - let mut _exit_time = 
core::mem::MaybeUninit::uninit(); - let mut kernel_time = core::mem::MaybeUninit::uninit(); - let mut user_time = core::mem::MaybeUninit::uninit(); - - let process = GetCurrentProcess(); - if GetProcessTimes( - process, - _creation_time.as_mut_ptr(), - _exit_time.as_mut_ptr(), - kernel_time.as_mut_ptr(), - user_time.as_mut_ptr(), - ) == 0 - { - return Err(vm.new_os_error("Failed to get clock time".to_owned())); - } - (kernel_time.assume_init(), user_time.assume_init()) - }; - let k_time = u64_from_filetime(kernel_time); - let u_time = u64_from_filetime(user_time); - Ok(Duration::from_nanos((k_time + u_time) * 100)) + let total = host_time::get_process_time_100ns() + .ok_or_else(|| vm.new_os_error("Failed to get clock time".to_owned()))?; + Ok(Duration::from_nanos(total * 100)) } } diff --git a/crates/vm/src/stdlib/winreg.rs b/crates/vm/src/stdlib/winreg.rs index 4e0725141d8..468767e9d38 100644 --- a/crates/vm/src/stdlib/winreg.rs +++ b/crates/vm/src/stdlib/winreg.rs @@ -7,7 +7,7 @@ pub(crate) use winreg::module_def; mod winreg { use crate::builtins::{PyInt, PyStr, PyTuple, PyTypeRef}; use crate::common::hash::PyHash; - use crate::convert::TryFromObject; + use crate::convert::{ToPyException, TryFromObject}; use crate::function::FuncArgs; use crate::host_env::windows::ToWideString; use crate::object::AsObject; @@ -18,35 +18,20 @@ mod winreg { use crossbeam_utils::atomic::AtomicCell; use malachite_bigint::Sign; use num_traits::ToPrimitive; - use windows_sys::Win32::Foundation::{self, ERROR_MORE_DATA}; - use windows_sys::Win32::System::Registry; + use rustpython_host_env::winreg as host_winreg; /// Atomic HKEY handle type for lock-free thread-safe access - type AtomicHKEY = AtomicCell; - - /// Convert byte slice to UTF-16 slice (zero-copy when aligned) - fn bytes_as_wide_slice(bytes: &[u8]) -> &[u16] { - // SAFETY: Windows Registry API returns properly aligned UTF-16 data. 
- // align_to handles any edge cases safely by returning empty prefix/suffix - // if alignment doesn't match. - let (prefix, u16_slice, suffix) = unsafe { bytes.align_to::() }; - debug_assert!( - prefix.is_empty() && suffix.is_empty(), - "Registry data should be u16-aligned" - ); - u16_slice - } + type AtomicHKEY = AtomicCell; fn os_error_from_windows_code( vm: &VirtualMachine, code: i32, ) -> crate::PyRef { - use crate::convert::ToPyException; std::io::Error::from_raw_os_error(code).to_pyexception(vm) } /// Wrapper type for HKEY that can be created from PyHkey or int - struct HKEYArg(Registry::HKEY); + struct HKEYArg(host_winreg::HKEY); impl TryFromObject for HKEYArg { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { @@ -56,21 +41,20 @@ mod winreg { } // Then try int let handle = usize::try_from_object(vm, obj)?; - Ok(Self(handle as Registry::HKEY)) + Ok(Self(handle as host_winreg::HKEY)) } } // access rights #[pyattr] - pub(super) use windows_sys::Win32::System::Registry::{ + pub(super) use host_winreg::{ KEY_ALL_ACCESS, KEY_CREATE_LINK, KEY_CREATE_SUB_KEY, KEY_ENUMERATE_SUB_KEYS, KEY_EXECUTE, KEY_NOTIFY, KEY_QUERY_VALUE, KEY_READ, KEY_SET_VALUE, KEY_WOW64_32KEY, KEY_WOW64_64KEY, KEY_WRITE, }; - // value types #[pyattr] - pub(super) use windows_sys::Win32::System::Registry::{ + pub(super) use host_winreg::{ REG_BINARY, REG_CREATED_NEW_KEY, REG_DWORD, REG_DWORD_BIG_ENDIAN, REG_DWORD_LITTLE_ENDIAN, REG_EXPAND_SZ, REG_FULL_RESOURCE_DESCRIPTOR, REG_LINK, REG_MULTI_SZ, REG_NONE, REG_NOTIFY_CHANGE_ATTRIBUTES, REG_NOTIFY_CHANGE_LAST_SET, REG_NOTIFY_CHANGE_NAME, @@ -80,25 +64,14 @@ mod winreg { REG_RESOURCE_REQUIREMENTS_LIST, REG_SZ, REG_WHOLE_HIVE_VOLATILE, }; - // Additional constants not in windows-sys #[pyattr] - const REG_REFRESH_HIVE: u32 = 0x00000002; + const REG_REFRESH_HIVE: u32 = host_winreg::REG_REFRESH_HIVE; #[pyattr] - const REG_NO_LAZY_FLUSH: u32 = 0x00000004; - // REG_LEGAL_OPTION is a mask of all option flags + const 
REG_NO_LAZY_FLUSH: u32 = host_winreg::REG_NO_LAZY_FLUSH; #[pyattr] - const REG_LEGAL_OPTION: u32 = Registry::REG_OPTION_RESERVED - | Registry::REG_OPTION_NON_VOLATILE - | Registry::REG_OPTION_VOLATILE - | Registry::REG_OPTION_CREATE_LINK - | Registry::REG_OPTION_BACKUP_RESTORE - | Registry::REG_OPTION_OPEN_LINK; - // REG_LEGAL_CHANGE_FILTER is a mask of all notify flags + const REG_LEGAL_OPTION: u32 = host_winreg::REG_LEGAL_OPTION; #[pyattr] - const REG_LEGAL_CHANGE_FILTER: u32 = Registry::REG_NOTIFY_CHANGE_NAME - | Registry::REG_NOTIFY_CHANGE_ATTRIBUTES - | Registry::REG_NOTIFY_CHANGE_LAST_SET - | Registry::REG_NOTIFY_CHANGE_SECURITY; + const REG_LEGAL_CHANGE_FILTER: u32 = host_winreg::REG_LEGAL_CHANGE_FILTER; // error is an alias for OSError (for backwards compatibility) #[pyattr] @@ -108,37 +81,37 @@ mod winreg { #[pyattr(once)] fn HKEY_CLASSES_ROOT(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_CLASSES_ROOT).into_ref(&vm.ctx) + PyHkey::new(host_winreg::HKEY_CLASSES_ROOT).into_ref(&vm.ctx) } #[pyattr(once)] fn HKEY_CURRENT_USER(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_CURRENT_USER).into_ref(&vm.ctx) + PyHkey::new(host_winreg::HKEY_CURRENT_USER).into_ref(&vm.ctx) } #[pyattr(once)] fn HKEY_LOCAL_MACHINE(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_LOCAL_MACHINE).into_ref(&vm.ctx) + PyHkey::new(host_winreg::HKEY_LOCAL_MACHINE).into_ref(&vm.ctx) } #[pyattr(once)] fn HKEY_USERS(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_USERS).into_ref(&vm.ctx) + PyHkey::new(host_winreg::HKEY_USERS).into_ref(&vm.ctx) } #[pyattr(once)] fn HKEY_PERFORMANCE_DATA(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_PERFORMANCE_DATA).into_ref(&vm.ctx) + PyHkey::new(host_winreg::HKEY_PERFORMANCE_DATA).into_ref(&vm.ctx) } #[pyattr(once)] fn HKEY_CURRENT_CONFIG(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_CURRENT_CONFIG).into_ref(&vm.ctx) + 
PyHkey::new(host_winreg::HKEY_CURRENT_CONFIG).into_ref(&vm.ctx) } #[pyattr(once)] fn HKEY_DYN_DATA(vm: &VirtualMachine) -> PyRef { - PyHkey::new(Registry::HKEY_DYN_DATA).into_ref(&vm.ctx) + PyHkey::new(host_winreg::HKEY_DYN_DATA).into_ref(&vm.ctx) } #[pyattr] @@ -152,7 +125,7 @@ mod winreg { unsafe impl Sync for PyHkey {} impl PyHkey { - fn new(hkey: Registry::HKEY) -> Self { + fn new(hkey: host_winreg::HKEY) -> Self { Self { hkey: AtomicHKEY::new(hkey), } @@ -186,7 +159,7 @@ mod winreg { if old_hkey.is_null() { return Ok(()); } - let res = unsafe { Registry::RegCloseKey(old_hkey) }; + let res = host_winreg::close_key(old_hkey); if res == 0 { Ok(()) } else { @@ -225,7 +198,7 @@ mod winreg { fn drop(&mut self) { let hkey = self.hkey.swap(core::ptr::null_mut()); if !hkey.is_null() { - unsafe { Registry::RegCloseKey(hkey) }; + host_winreg::close_key(hkey); } } } @@ -237,7 +210,7 @@ mod winreg { } } - pub(super) const HKEY_ERR_MSG: &str = "bad operand type"; + const HKEY_ERR_MSG: &str = "bad operand type"; impl AsNumber for PyHkey { fn as_number() -> &'static PyNumberMethods { @@ -282,45 +255,31 @@ mod winreg { key: PyRef, vm: &VirtualMachine, ) -> PyResult { - if let Some(computer_name) = computer_name { - let mut ret_key = core::ptr::null_mut(); - let wide_computer_name = computer_name.to_wide_with_nul(); - let res = unsafe { - Registry::RegConnectRegistryW( - wide_computer_name.as_ptr(), - key.hkey.load(), - &mut ret_key, - ) - }; - if res == 0 { - Ok(PyHkey::new(ret_key)) - } else { - Err(vm.new_os_error(format!("error code: {res}"))) - } + let wide_computer_name = computer_name.map(|n| n.to_wide_cstring()); + let mut ret_key = core::ptr::null_mut(); + let res = unsafe { + host_winreg::connect_registry( + wide_computer_name.as_deref(), + key.hkey.load(), + &mut ret_key, + ) + }; + if res == 0 { + Ok(PyHkey::new(ret_key)) } else { - let mut ret_key = core::ptr::null_mut(); - let res = unsafe { - Registry::RegConnectRegistryW(core::ptr::null_mut(), key.hkey.load(), 
&mut ret_key) - }; - if res == 0 { - Ok(PyHkey::new(ret_key)) - } else { - Err(vm.new_os_error(format!("error code: {res}"))) - } + Err(os_error_from_windows_code(vm, res as i32)) } } #[pyfunction] fn CreateKey(key: PyRef, sub_key: String, vm: &VirtualMachine) -> PyResult { - let wide_sub_key = sub_key.to_wide_with_nul(); + let wide_sub_key = sub_key.to_wide_cstring(); let mut out_key = core::ptr::null_mut(); - let res = unsafe { - Registry::RegCreateKeyW(key.hkey.load(), wide_sub_key.as_ptr(), &mut out_key) - }; + let res = unsafe { host_winreg::create_key(key.hkey.load(), &wide_sub_key, &mut out_key) }; if res == 0 { Ok(PyHkey::new(out_key)) } else { - Err(vm.new_os_error(format!("error code: {res}"))) + Err(os_error_from_windows_code(vm, res as i32)) } } @@ -332,22 +291,22 @@ mod winreg { sub_key: String, #[pyarg(any, default = 0)] reserved: u32, - #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_WRITE)] + #[pyarg(any, default = host_winreg::KEY_WRITE)] access: u32, } #[pyfunction] fn CreateKeyEx(args: CreateKeyExArgs, vm: &VirtualMachine) -> PyResult { - let wide_sub_key = args.sub_key.to_wide_with_nul(); - let mut res: Registry::HKEY = core::ptr::null_mut(); + let wide_sub_key = args.sub_key.to_wide_cstring(); + let mut res: host_winreg::HKEY = core::ptr::null_mut(); let err = unsafe { let key = args.key.hkey.load(); - Registry::RegCreateKeyExW( + host_winreg::create_key_ex( key, - wide_sub_key.as_ptr(), + &wide_sub_key, args.reserved, - core::ptr::null(), - Registry::REG_OPTION_NON_VOLATILE, + core::ptr::null_mut(), + host_winreg::REG_OPTION_NON_VOLATILE, args.access, core::ptr::null(), &mut res, @@ -371,8 +330,8 @@ mod winreg { #[pyfunction] fn DeleteKey(key: PyRef, sub_key: String, vm: &VirtualMachine) -> PyResult<()> { - let wide_sub_key = sub_key.to_wide_with_nul(); - let res = unsafe { Registry::RegDeleteKeyW(key.hkey.load(), wide_sub_key.as_ptr()) }; + let wide_sub_key = sub_key.to_wide_cstring(); + let res = unsafe { 
host_winreg::delete_key(key.hkey.load(), &wide_sub_key) }; if res == 0 { Ok(()) } else { @@ -382,11 +341,8 @@ mod winreg { #[pyfunction] fn DeleteValue(key: PyRef, value: Option, vm: &VirtualMachine) -> PyResult<()> { - let wide_value = value.map(|v| v.to_wide_with_nul()); - let value_ptr = wide_value - .as_ref() - .map_or(core::ptr::null(), |v| v.as_ptr()); - let res = unsafe { Registry::RegDeleteValueW(key.hkey.load(), value_ptr) }; + let wide_value = value.map(|v| v.to_wide_cstring()); + let res = unsafe { host_winreg::delete_value(key.hkey.load(), wide_value.as_deref()) }; if res == 0 { Ok(()) } else { @@ -400,7 +356,7 @@ mod winreg { key: PyRef, #[pyarg(any)] sub_key: String, - #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_WOW64_64KEY)] + #[pyarg(any, default = host_winreg::KEY_WOW64_64KEY)] access: u32, #[pyarg(any, default = 0)] reserved: u32, @@ -408,11 +364,11 @@ mod winreg { #[pyfunction] fn DeleteKeyEx(args: DeleteKeyExArgs, vm: &VirtualMachine) -> PyResult<()> { - let wide_sub_key = args.sub_key.to_wide_with_nul(); + let wide_sub_key = args.sub_key.to_wide_cstring(); let res = unsafe { - Registry::RegDeleteKeyExW( + host_winreg::delete_key_ex( args.key.hkey.load(), - wide_sub_key.as_ptr(), + &wide_sub_key, args.access, args.reserved, ) @@ -435,16 +391,7 @@ mod winreg { let mut tmpbuf = [0u16; 257]; let mut len = tmpbuf.len() as u32; let res = unsafe { - Registry::RegEnumKeyExW( - key.hkey.load(), - index as u32, - tmpbuf.as_mut_ptr(), - &mut len, - core::ptr::null_mut(), - core::ptr::null_mut(), - core::ptr::null_mut(), - core::ptr::null_mut(), - ) + host_winreg::enum_key_ex(key.hkey.load(), index as u32, tmpbuf.as_mut_ptr(), &mut len) }; if res != 0 { return Err(os_error_from_windows_code(vm, res as i32)); @@ -458,21 +405,14 @@ mod winreg { // Query registry for the required buffer sizes. 
let mut ret_value_size: u32 = 0; let mut ret_data_size: u32 = 0; - let hkey: Registry::HKEY = hkey.hkey.load(); + let hkey: host_winreg::HKEY = hkey.hkey.load(); let rc = unsafe { - Registry::RegQueryInfoKeyW( + host_winreg::query_info_key( hkey, ptr::null_mut(), ptr::null_mut(), - ptr::null_mut(), - ptr::null_mut(), - ptr::null_mut(), - ptr::null_mut(), - ptr::null_mut(), &mut ret_value_size as *mut u32, &mut ret_data_size as *mut u32, - ptr::null_mut(), - ptr::null_mut(), ) }; if rc != 0 { @@ -495,18 +435,17 @@ mod winreg { let mut current_data_size = ret_data_size; let mut reg_type: u32 = 0; let rc = unsafe { - Registry::RegEnumValueW( + host_winreg::enum_value( hkey, index, ret_value_buf.as_mut_ptr(), &mut current_value_size as *mut u32, - ptr::null_mut(), &mut reg_type as *mut u32, ret_data_buf.as_mut_ptr(), &mut current_data_size as *mut u32, ) }; - if rc == ERROR_MORE_DATA { + if rc == host_winreg::ERROR_MORE_DATA { // Double the buffer sizes. buf_data_size *= 2; buf_value_size *= 2; @@ -547,7 +486,7 @@ mod winreg { #[pyfunction] fn FlushKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { - let res = unsafe { Registry::RegFlushKey(key.hkey.load()) }; + let res = host_winreg::flush_key(key.hkey.load()); if res == 0 { Ok(()) } else { @@ -562,10 +501,9 @@ mod winreg { file_name: String, vm: &VirtualMachine, ) -> PyResult<()> { - let sub_key = sub_key.to_wide_with_nul(); - let file_name = file_name.to_wide_with_nul(); - let res = - unsafe { Registry::RegLoadKeyW(key.hkey.load(), sub_key.as_ptr(), file_name.as_ptr()) }; + let sub_key = sub_key.to_wide_cstring(); + let file_name = file_name.to_wide_cstring(); + let res = unsafe { host_winreg::load_key(key.hkey.load(), &sub_key, &file_name) }; if res == 0 { Ok(()) } else { @@ -581,24 +519,18 @@ mod winreg { sub_key: String, #[pyarg(any, default = 0)] reserved: u32, - #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_READ)] + #[pyarg(any, default = host_winreg::KEY_READ)] access: u32, } 
#[pyfunction] #[pyfunction(name = "OpenKeyEx")] fn OpenKey(args: OpenKeyArgs, vm: &VirtualMachine) -> PyResult { - let wide_sub_key = args.sub_key.to_wide_with_nul(); - let mut res: Registry::HKEY = core::ptr::null_mut(); + let wide_sub_key = args.sub_key.to_wide_cstring(); + let mut res: host_winreg::HKEY = core::ptr::null_mut(); let err = unsafe { let key = args.key.hkey.load(); - Registry::RegOpenKeyExW( - key, - wide_sub_key.as_ptr(), - args.reserved, - args.access, - &mut res, - ) + host_winreg::open_key_ex(key, &wide_sub_key, args.reserved, args.access, &mut res) }; if err == 0 { Ok(PyHkey { @@ -613,35 +545,12 @@ mod winreg { #[pyfunction] fn QueryInfoKey(key: HKEYArg, vm: &VirtualMachine) -> PyResult> { let key = key.0; - let mut lpcsubkeys: u32 = 0; - let mut lpcvalues: u32 = 0; - let mut lpftlastwritetime: Foundation::FILETIME = unsafe { core::mem::zeroed() }; - let err = unsafe { - Registry::RegQueryInfoKeyW( - key, - core::ptr::null_mut(), - core::ptr::null_mut(), - 0 as _, - &mut lpcsubkeys, - core::ptr::null_mut(), - core::ptr::null_mut(), - &mut lpcvalues, - core::ptr::null_mut(), - core::ptr::null_mut(), - core::ptr::null_mut(), - &mut lpftlastwritetime, - ) - }; - - if err != 0 { - return Err(vm.new_os_error(format!("error code: {err}"))); - } - let l: u64 = (lpftlastwritetime.dwHighDateTime as u64) << 32 - | lpftlastwritetime.dwLowDateTime as u64; + let info = host_winreg::query_info_key_full(key) + .map_err(|err| vm.new_os_error(format!("error code: {err}")))?; let tup: Vec = vec![ - vm.ctx.new_int(lpcsubkeys).into(), - vm.ctx.new_int(lpcvalues).into(), - vm.ctx.new_int(l).into(), + vm.ctx.new_int(info.sub_keys).into(), + vm.ctx.new_int(info.values).into(), + vm.ctx.new_int(info.last_write_time).into(), ]; Ok(vm.ctx.new_tuple(tup)) } @@ -650,157 +559,31 @@ mod winreg { fn QueryValue(key: HKEYArg, sub_key: Option, vm: &VirtualMachine) -> PyResult { let hkey = key.0; - if hkey == Registry::HKEY_PERFORMANCE_DATA { + if hkey == 
host_winreg::HKEY_PERFORMANCE_DATA { return Err(os_error_from_windows_code( vm, - Foundation::ERROR_INVALID_HANDLE as i32, + host_winreg::ERROR_INVALID_HANDLE as i32, )); } - // Open subkey if provided and non-empty - let child_key = if let Some(ref sk) = sub_key { - if !sk.is_empty() { - let wide_sub_key = sk.to_wide_with_nul(); - let mut out_key = core::ptr::null_mut(); - let res = unsafe { - Registry::RegOpenKeyExW( - hkey, - wide_sub_key.as_ptr(), - 0, - Registry::KEY_QUERY_VALUE, - &mut out_key, - ) - }; - if res != 0 { - return Err(os_error_from_windows_code(vm, res as i32)); - } - Some(out_key) - } else { - None - } - } else { - None - }; - - let target_key = child_key.unwrap_or(hkey); - let mut buf_size: u32 = 256; - let mut buffer: Vec = vec![0; buf_size as usize]; - let mut reg_type: u32 = 0; - - // Loop to handle ERROR_MORE_DATA - let result = loop { - let mut size = buf_size; - let res = unsafe { - Registry::RegQueryValueExW( - target_key, - core::ptr::null(), // NULL value name for default value - core::ptr::null_mut(), - &mut reg_type, - buffer.as_mut_ptr(), - &mut size, - ) - }; - if res == ERROR_MORE_DATA { - buf_size *= 2; - buffer.resize(buf_size as usize, 0); - continue; - } - if res == Foundation::ERROR_FILE_NOT_FOUND { - // Return empty string if there's no default value - break Ok(String::new()); - } - if res != 0 { - break Err(os_error_from_windows_code(vm, res as i32)); - } - if reg_type != Registry::REG_SZ { - break Err(os_error_from_windows_code( - vm, - Foundation::ERROR_INVALID_DATA as i32, - )); - } - - // Convert UTF-16 to String - let u16_slice = bytes_as_wide_slice(&buffer[..size as usize]); - let len = u16_slice - .iter() - .position(|&c| c == 0) - .unwrap_or(u16_slice.len()); - break String::from_utf16(&u16_slice[..len]) - .map_err(|e| vm.new_value_error(format!("UTF16 error: {e}"))); - }; - - // Close child key if we opened one - if let Some(ck) = child_key { - unsafe { Registry::RegCloseKey(ck) }; - } - - result + 
host_winreg::query_default_value(hkey, sub_key.as_deref().map(std::ffi::OsStr::new)) + .map_err(|err| err.to_pyexception(vm)) } #[pyfunction] fn QueryValueEx(key: HKEYArg, name: String, vm: &VirtualMachine) -> PyResult> { let hkey = key.0; - let wide_name = name.to_wide_with_nul(); - let mut buf_size: u32 = 0; - let res = unsafe { - Registry::RegQueryValueExW( - hkey, - wide_name.as_ptr(), - core::ptr::null_mut(), - core::ptr::null_mut(), - core::ptr::null_mut(), - &mut buf_size, - ) - }; - // Handle ERROR_MORE_DATA by using a default buffer size - if res == ERROR_MORE_DATA || buf_size == 0 { - buf_size = 256; - } else if res != 0 { - return Err(os_error_from_windows_code(vm, res as i32)); - } - - let mut ret_buf = vec![0u8; buf_size as usize]; - let mut typ = 0; - let mut ret_size: u32; - - // Loop to handle ERROR_MORE_DATA - loop { - ret_size = buf_size; - let res = unsafe { - Registry::RegQueryValueExW( - hkey, - wide_name.as_ptr(), - core::ptr::null_mut(), - &mut typ, - ret_buf.as_mut_ptr(), - &mut ret_size, - ) - }; - - if res != ERROR_MORE_DATA { - if res != 0 { - return Err(os_error_from_windows_code(vm, res as i32)); - } - break; - } - - // Double buffer size and retry - buf_size *= 2; - ret_buf.resize(buf_size as usize, 0); - } - - // Only pass the bytes actually returned by the API - let obj = reg_to_py(vm, &ret_buf[..ret_size as usize], typ)?; + let (ret_buf, typ) = host_winreg::query_value_bytes(hkey, std::ffi::OsStr::new(&name)) + .map_err(|err| os_error_from_windows_code(vm, err as i32))?; + let obj = reg_to_py(vm, &ret_buf, typ)?; // Return tuple (value, type) Ok(vm.ctx.new_tuple(vec![obj, vm.ctx.new_int(typ).into()])) } #[pyfunction] fn SaveKey(key: PyRef, file_name: String, vm: &VirtualMachine) -> PyResult<()> { - let file_name = file_name.to_wide_with_nul(); - let res = unsafe { - Registry::RegSaveKeyW(key.hkey.load(), file_name.as_ptr(), core::ptr::null_mut()) - }; + let file_name = file_name.to_wide_cstring(); + let res = unsafe { 
host_winreg::save_key(key.hkey.load(), &file_name) }; if res == 0 { Ok(()) } else { @@ -816,61 +599,24 @@ mod winreg { value: String, vm: &VirtualMachine, ) -> PyResult<()> { - if typ != Registry::REG_SZ { + if typ != host_winreg::REG_SZ { return Err(vm.new_type_error("type must be winreg.REG_SZ")); } let hkey = key.hkey.load(); - if hkey == Registry::HKEY_PERFORMANCE_DATA { + if hkey == host_winreg::HKEY_PERFORMANCE_DATA { return Err(os_error_from_windows_code( vm, - Foundation::ERROR_INVALID_HANDLE as i32, + host_winreg::ERROR_INVALID_HANDLE as i32, )); } - // Create subkey if sub_key is non-empty - let child_key = if !sub_key.is_empty() { - let wide_sub_key = sub_key.to_wide_with_nul(); - let mut out_key = core::ptr::null_mut(); - let res = unsafe { - Registry::RegCreateKeyExW( - hkey, - wide_sub_key.as_ptr(), - 0, - core::ptr::null(), - 0, - Registry::KEY_SET_VALUE, - core::ptr::null(), - &mut out_key, - core::ptr::null_mut(), - ) - }; - if res != 0 { - return Err(os_error_from_windows_code(vm, res as i32)); - } - Some(out_key) - } else { - None - }; - - let target_key = child_key.unwrap_or(hkey); - // Convert value to UTF-16 for Wide API - let wide_value = value.to_wide_with_nul(); - let res = unsafe { - Registry::RegSetValueExW( - target_key, - core::ptr::null(), // value name is NULL - 0, - typ, - wide_value.as_ptr() as *const u8, - (wide_value.len() * 2) as u32, // byte count - ) - }; - - // Close child key if we created one - if let Some(ck) = child_key { - unsafe { Registry::RegCloseKey(ck) }; - } + let res = host_winreg::set_default_value( + hkey, + std::ffi::OsStr::new(&sub_key), + typ, + std::ffi::OsStr::new(&value), + ); if res == 0 { Ok(()) @@ -897,7 +643,7 @@ mod winreg { Ok(vm.ctx.new_int(val).into()) } REG_SZ | REG_EXPAND_SZ => { - let u16_slice = bytes_as_wide_slice(ret_data); + let u16_slice = host_winreg::bytes_as_wide_slice(ret_data); // Only use characters up to the first NUL. 
let len = u16_slice .iter() @@ -911,7 +657,7 @@ mod winreg { if ret_data.is_empty() { Ok(vm.ctx.new_list(vec![]).into()) } else { - let u16_slice = bytes_as_wide_slice(ret_data); + let u16_slice = host_winreg::bytes_as_wide_slice(ret_data); let u16_count = u16_slice.len(); // Remove trailing null if present (like countStrings) @@ -1047,17 +793,15 @@ mod winreg { value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - let wide_value_name = value_name.as_deref().map(|s| s.to_wide_with_nul()); - let value_name_ptr = wide_value_name - .as_deref() - .map_or(core::ptr::null(), |s| s.as_ptr()); + let wide_value_name = value_name.as_deref().map(|s| s.to_wide_cstring()); let reg_value = py2reg(value, typ, vm)?; let (ptr, len) = match ®_value { Some(v) => (v.as_ptr(), v.len() as u32), None => (core::ptr::null(), 0), }; - let res = - unsafe { Registry::RegSetValueExW(key.hkey.load(), value_name_ptr, 0, typ, ptr, len) }; + let res = unsafe { + host_winreg::set_value_ex(key.hkey.load(), wide_value_name.as_deref(), typ, ptr, len) + }; if res != 0 { return Err(os_error_from_windows_code(vm, res as i32)); } @@ -1066,7 +810,7 @@ mod winreg { #[pyfunction] fn DisableReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { - let res = unsafe { Registry::RegDisableReflectionKey(key.hkey.load()) }; + let res = host_winreg::disable_reflection_key(key.hkey.load()); if res == 0 { Ok(()) } else { @@ -1076,7 +820,7 @@ mod winreg { #[pyfunction] fn EnableReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { - let res = unsafe { Registry::RegEnableReflectionKey(key.hkey.load()) }; + let res = host_winreg::enable_reflection_key(key.hkey.load()); if res == 0 { Ok(()) } else { @@ -1087,7 +831,7 @@ mod winreg { #[pyfunction] fn QueryReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult { let mut result: i32 = 0; - let res = unsafe { Registry::RegQueryReflectionKey(key.hkey.load(), &mut result) }; + let res = unsafe { 
host_winreg::query_reflection_key(key.hkey.load(), &mut result) }; if res == 0 { Ok(result != 0) } else { @@ -1097,34 +841,7 @@ mod winreg { #[pyfunction] fn ExpandEnvironmentStrings(i: String, vm: &VirtualMachine) -> PyResult { - let wide_input = i.to_wide_with_nul(); - - // First call with size=0 to get required buffer size - let required_size = unsafe { - windows_sys::Win32::System::Environment::ExpandEnvironmentStringsW( - wide_input.as_ptr(), - core::ptr::null_mut(), - 0, - ) - }; - if required_size == 0 { - return Err(vm.new_os_error("ExpandEnvironmentStringsW failed".to_string())); - } - - // Allocate buffer with exact size and expand - let mut out = vec![0u16; required_size as usize]; - let r = unsafe { - windows_sys::Win32::System::Environment::ExpandEnvironmentStringsW( - wide_input.as_ptr(), - out.as_mut_ptr(), - required_size, - ) - }; - if r == 0 { - return Err(vm.new_os_error("ExpandEnvironmentStringsW failed".to_string())); - } - - let len = out.iter().position(|&c| c == 0).unwrap_or(out.len()); - String::from_utf16(&out[..len]).map_err(|e| vm.new_value_error(format!("UTF16 error: {e}"))) + host_winreg::expand_environment_strings(std::ffi::OsStr::new(&i)) + .map_err(|err| err.to_pyexception(vm)) } } diff --git a/crates/vm/src/stdlib/winsound.rs b/crates/vm/src/stdlib/winsound.rs index 359d3967a99..e8214a8fd19 100644 --- a/crates/vm/src/stdlib/winsound.rs +++ b/crates/vm/src/stdlib/winsound.rs @@ -3,18 +3,6 @@ pub(crate) use winsound::module_def; -mod win32 { - #[link(name = "winmm")] - unsafe extern "system" { - pub(super) fn PlaySoundW(pszSound: *const u16, hmod: isize, fdwSound: u32) -> i32; - } - - unsafe extern "system" { - pub(super) fn Beep(dwFreq: u32, dwDuration: u32) -> i32; - pub(super) fn MessageBeep(uType: u32) -> i32; - } -} - #[pymodule] mod winsound { use crate::builtins::{PyBytes, PyStr}; @@ -22,6 +10,7 @@ mod winsound { use crate::host_env::windows::ToWideString; use crate::protocol::PyBuffer; use crate::{AsObject, PyObjectRef, 
PyResult, VirtualMachine}; + use rustpython_host_env::winsound::{PlaySoundSource, play_sound}; // PlaySound flags #[pyattr] @@ -79,32 +68,34 @@ mod winsound { flags: i32, } + fn map_play_err( + vm: &VirtualMachine, + ) -> impl FnOnce( + rustpython_host_env::winsound::PlaySoundError, + ) -> crate::builtins::PyBaseExceptionRef + + '_ { + use rustpython_host_env::winsound::PlaySoundError::*; + |err| match err { + MemoryAsyncRejected => vm.new_runtime_error("Cannot play asynchronously from memory"), + MemoryFlagWithoutBuffer | CallFailed => vm.new_runtime_error("Failed to play sound"), + } + } + #[pyfunction] fn PlaySound(args: PlaySoundArgs, vm: &VirtualMachine) -> PyResult<()> { let sound = args.sound; let flags = args.flags as u32; if vm.is_none(&sound) { - let ok = unsafe { super::win32::PlaySoundW(core::ptr::null(), 0, flags) }; - if ok == 0 { - return Err(vm.new_runtime_error("Failed to play sound")); - } - return Ok(()); + return play_sound(PlaySoundSource::Stop, flags).map_err(map_play_err(vm)); } if flags & SND_MEMORY != 0 { - if flags & SND_ASYNC != 0 { - return Err(vm.new_runtime_error("Cannot play asynchronously from memory")); - } let buffer = PyBuffer::try_from_borrowed_object(vm, &sound)?; let buf = buffer .as_contiguous() .ok_or_else(|| vm.new_type_error("a bytes-like object is required, not 'str'"))?; - let ok = unsafe { super::win32::PlaySoundW(buf.as_ptr() as *const u16, 0, flags) }; - if ok == 0 { - return Err(vm.new_runtime_error("Failed to play sound")); - } - return Ok(()); + return play_sound(PlaySoundSource::Memory(&buf), flags).map_err(map_play_err(vm)); } if sound.downcastable::() { @@ -157,11 +148,9 @@ mod winsound { } let wide = path.to_wide_with_nul(); - let ok = unsafe { super::win32::PlaySoundW(wide.as_ptr(), 0, flags) }; - if ok == 0 { - return Err(vm.new_runtime_error("Failed to play sound")); - } - Ok(()) + let wide_cstr = widestring::WideCStr::from_slice_truncate(&wide) + .map_err(|_| vm.new_value_error("embedded null character"))?; 
+ play_sound(PlaySoundSource::Name(wide_cstr), flags).map_err(map_play_err(vm)) } #[derive(FromArgs)] @@ -178,11 +167,11 @@ mod winsound { return Err(vm.new_value_error("frequency must be in 37 thru 32767")); } - let ok = unsafe { super::win32::Beep(args.frequency as u32, args.duration as u32) }; - if ok == 0 { - return Err(vm.new_runtime_error("Failed to beep")); + if rustpython_host_env::winsound::beep(args.frequency as u32, args.duration as u32) { + Ok(()) + } else { + Err(vm.new_runtime_error("Failed to beep")) } - Ok(()) } #[derive(FromArgs)] @@ -193,10 +182,6 @@ mod winsound { #[pyfunction] fn MessageBeep(args: MessageBeepArgs, vm: &VirtualMachine) -> PyResult<()> { - let ok = unsafe { super::win32::MessageBeep(args.r#type) }; - if ok == 0 { - return Err(std::io::Error::last_os_error().into_pyexception(vm)); - } - Ok(()) + rustpython_host_env::winsound::message_beep(args.r#type).map_err(|e| e.into_pyexception(vm)) } } diff --git a/crates/vm/src/types/slot_defs.rs b/crates/vm/src/types/slot_defs.rs index 6ba5d1d6453..956bc6fa4d1 100644 --- a/crates/vm/src/types/slot_defs.rs +++ b/crates/vm/src/types/slot_defs.rs @@ -1501,7 +1501,7 @@ mod tests { use super::*; #[test] - fn test_find_by_name() { + fn find_by_name() { // __len__ appears in both sequence and mapping let len_defs: Vec<_> = find_slot_defs_by_name("__len__").collect(); assert_eq!(len_defs.len(), 2); @@ -1516,7 +1516,7 @@ mod tests { } #[test] - fn test_slot_op() { + fn slot_op() { // Test comparison ops assert_eq!(SlotOp::Lt.as_compare_op(), Some(PyComparisonOp::Lt)); assert_eq!(SlotOp::Eq.as_compare_op(), Some(PyComparisonOp::Eq)); diff --git a/crates/vm/src/version.rs b/crates/vm/src/version.rs index 29ed3c2aa4a..05eb12a2942 100644 --- a/crates/vm/src/version.rs +++ b/crates/vm/src/version.rs @@ -1,149 +1,76 @@ -//! Several function to retrieve version information. - -use chrono::{Local, prelude::DateTime}; -use core::time::Duration; -use std::time::UNIX_EPOCH; +//! Version info constants. +//! 
+//! Most of the constants are auto calculated at compile time. The main exception is the +//! target CPython version. This is defined and updated in `build.rs`. + +macro_rules! parse_consts { + ($name: ident, $var: literal) => { + pub const $name: usize = match usize::from_str_radix(env!($var), 10) { + Ok(v) => v, + Err(_) => panic!(concat!("Compile with Cargo to get '", $var, "'")), + }; + }; +} -// = 3.14.0alpha -pub const MAJOR: usize = 3; -pub const MINOR: usize = 14; -pub const MICRO: usize = 0; -pub const RELEASELEVEL: &str = "alpha"; -pub const RELEASELEVEL_N: usize = 0xA; -pub const SERIAL: usize = 0; +// CPython target version info +parse_consts!(MAJOR, "MAJOR_CPY"); +parse_consts!(MINOR, "MINOR_CPY"); +parse_consts!(MICRO, "MICRO_CPY"); +pub const RELEASELEVEL: &str = env!("RELEASE_LEVEL_CPY"); +parse_consts!(RELEASELEVEL_N, "RELEASE_LEVEL_N_CPY"); +parse_consts!(SERIAL, "SERIAL_CPY"); pub const VERSION_HEX: usize = (MAJOR << 24) | (MINOR << 16) | (MICRO << 8) | (RELEASELEVEL_N << 4) | SERIAL; +#[cfg(windows)] +pub const WINVER: &str = env!("WINVER_CPY"); + pub const GIT_REVISION: &str = env!("RUSTPYTHON_GIT_HASH"); -const GIT_TAG: &str = env!("RUSTPYTHON_GIT_TAG"); -const GIT_BRANCH: &str = env!("RUSTPYTHON_GIT_BRANCH"); +pub const GIT_IDENTIFIER: &str = env!("RUSTPYTHON_GIT_IDENTIFIER"); +// const GIT_TAG: &str = env!("RUSTPYTHON_GIT_TAG"); +// const GIT_BRANCH: &str = env!("RUSTPYTHON_GIT_BRANCH"); // RustPython version -pub const MAJOR_IMPL: usize = match usize::from_str_radix(env!("CARGO_PKG_VERSION_MAJOR"), 10) { - Ok(v) => v, - Err(_) => panic!("Compile with Cargo to get 'CARGO_PKG_VERSION_MAJOR'"), -}; -pub const MINOR_IMPL: usize = match usize::from_str_radix(env!("CARGO_PKG_VERSION_MINOR"), 10) { - Ok(v) => v, - Err(_) => panic!("Compile with Cargo to get 'CARGO_PKG_VERSION_MINOR'"), -}; -pub const MICRO_IMPL: usize = match usize::from_str_radix(env!("CARGO_PKG_VERSION_PATCH"), 10) { - Ok(v) => v, - Err(_) => panic!("Compile with Cargo to get 
'CARGO_PKG_VERSION_PATCH'"), -}; +parse_consts!(MAJOR_IMPL, "CARGO_PKG_VERSION_MAJOR"); +parse_consts!(MINOR_IMPL, "CARGO_PKG_VERSION_MINOR"); +parse_consts!(MICRO_IMPL, "CARGO_PKG_VERSION_PATCH"); pub const RELEASELEVEL_IMPL: &str = env!("RUSTPYTHON_RELEASE_LEVEL"); -pub const SERIAL_IMPL: usize = match usize::from_str_radix(env!("RUSTPYTHON_RELEASE_SERIAL"), 10) { - Ok(v) => v, - Err(_) => panic!("Compile with Cargo to get 'RUSTPYTHON_RELEASE_SERIAL'"), -}; +parse_consts!(RELEASELEVEL_N_IMPL, "RUSTPYTHON_RELEASE_LEVEL_N"); +parse_consts!(SERIAL_IMPL, "RUSTPYTHON_RELEASE_SERIAL"); pub const VERSION_HEX_IMPL: usize = (MAJOR_IMPL << 24) | (MINOR_IMPL << 16) | (MICRO_IMPL << 8) - | (RELEASELEVEL_N << 4) + | (RELEASELEVEL_N_IMPL << 4) | SERIAL_IMPL; -#[must_use] -pub fn get_version() -> String { - // Windows: include MSC v. for compatibility with ctypes.util.find_library - // MSC v.1929 = VS 2019, version 14+ makes find_msvcrt() return None - let msc_info = cfg_select! { - windows => {{ - let arch = if cfg!(target_pointer_width = "64") { - "64 bit (AMD64)" - } else { - "32 bit (Intel)" - }; - // Include both RustPython identifier and MSC v. 
for compatibility - format!(" MSC v.1929 {arch}",) - }}, - _ => String::new(), - }; - - format!( - "{:.80} ({:.80}) \n[RustPython {} with {:.80}{}]", // \n is PyPy convention - get_version_number(), - get_build_info(), - env!("CARGO_PKG_VERSION"), - COMPILER, - msc_info, - ) -} - -#[must_use] -pub fn get_version_number() -> String { - format!("{MAJOR}.{MINOR}.{MICRO}{RELEASELEVEL}") -} - -#[must_use] -pub fn get_winver_number() -> String { - format!("{MAJOR}.{MINOR}") -} - -const COMPILER: &str = env!("RUSTC_VERSION"); +pub const RUSTPYTHON_BUILD_INFO: &str = env!("RUSTPYTHON_BUILD_INFO"); +pub const RUSTPYTHON_VERSION: &str = const { + const LEFT: &str = env!("RUSTPYTHON_VERSION_LEFT"); + const RIGHT: &str = env!("RUSTPYTHON_VERSION_RIGHT"); + const LEN: usize = LEFT.len() + RIGHT.len() + 1; -#[must_use] -pub fn get_build_info() -> String { - // See: https://reproducible-builds.org/docs/timestamps/ - let separator = if GIT_REVISION.is_empty() { "" } else { ":" }; - let git_identifier = get_git_identifier(); + const fn concat() -> [u8; LEN] { + let mut bytes_temp = [0u8; LEN]; - format!( - "{id}{sep}{revision}, {date:.20}, {time:.9}", - id = if git_identifier.is_empty() { - "default" - } else { - git_identifier - }, - sep = separator, - revision = GIT_REVISION, - date = get_git_date(), - time = get_git_time(), - ) -} + let (left, _) = bytes_temp.split_at_mut(LEFT.len()); + left.copy_from_slice(LEFT.as_bytes()); + let (_, right) = bytes_temp.split_at_mut(LEFT.len() + 1); + right.copy_from_slice(RIGHT.as_bytes()); + bytes_temp[LEFT.len()] = b'\n'; -#[must_use] -pub const fn get_git_identifier() -> &'static str { - if GIT_TAG.is_empty() || GIT_TAG.eq_ignore_ascii_case("undefined") { - GIT_BRANCH - } else { - GIT_TAG + bytes_temp } -} - -fn get_git_timestamp_datetime() -> DateTime { - let timestamp = option_env!("RUSTPYTHON_GIT_TIMESTAMP").unwrap_or_default(); - let timestamp = timestamp.parse::().unwrap_or_default(); - - let datetime = UNIX_EPOCH + 
Duration::from_secs(timestamp); - - datetime.into() -} - -#[must_use] -pub fn get_git_date() -> String { - let datetime = get_git_timestamp_datetime(); - - datetime.format("%b %e %Y").to_string() -} - -#[must_use] -pub fn get_git_time() -> String { - let datetime = get_git_timestamp_datetime(); - - datetime.format("%H:%M:%S").to_string() -} -#[must_use] -pub fn get_git_datetime() -> String { - let date = get_git_date(); - let time = get_git_time(); - - format!("{date} {time}") -} + const BUF: [u8; LEN] = concat(); + match str::from_utf8(&BUF) { + Ok(v) => v, + Err(_) => unreachable!(), + } +}; // Must be aligned to Lib/importlib/_bootstrap_external.py -// Bumped to 2994 for new CommonConstant discriminants (BuiltinList, BuiltinSet) -pub const PYC_MAGIC_NUMBER: u16 = 2994; +// Matches CPython 3.14 (Include/internal/pycore_magic_number.h). +pub const PYC_MAGIC_NUMBER: u16 = 3627; // CPython format: magic_number | ('\r' << 16) | ('\n' << 24) // This protects against text-mode file reads diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs index 8d20e750706..4d16c5d8075 100644 --- a/crates/vm/src/vm/context.rs +++ b/crates/vm/src/vm/context.rs @@ -26,6 +26,7 @@ use crate::{ object::{Py, PyObjectPayload, PyObjectRef, PyPayload, PyRef}, types::{PyTypeFlags, PyTypeSlots, TypeZoo}, }; +use core::ffi::{CStr, c_void}; use malachite_bigint::BigInt; use num_complex::Complex64; use num_traits::ToPrimitive; @@ -754,10 +755,11 @@ impl Context { pub fn new_capsule( &self, - ptr: *mut core::ffi::c_void, + ptr: *mut c_void, + name: Option<&'static CStr>, destructor: Option, ) -> PyRef { - PyCapsule::new(ptr, destructor).into_ref(self) + PyCapsule::new(ptr, name, destructor).into_ref(self) } } diff --git a/crates/vm/src/vm/interpreter.rs b/crates/vm/src/vm/interpreter.rs index 505986acae0..ecdfea01a29 100644 --- a/crates/vm/src/vm/interpreter.rs +++ b/crates/vm/src/vm/interpreter.rs @@ -95,8 +95,11 @@ where } as usize); // Initialize frozen modules (core + 
user-provided) - let mut frozen: std::collections::HashMap<&'static str, FrozenModule, ahash::RandomState> = - core_frozen_inits().collect(); + let mut frozen: std::collections::HashMap< + &'static str, + FrozenModule, + rapidhash::quality::RandomState, + > = core_frozen_inits().collect(); frozen.extend(frozen_modules); // Create PyGlobalState @@ -570,7 +573,7 @@ mod tests { use malachite_bigint::ToBigInt; #[test] - fn test_add_py_integers() { + fn add_py_integers() { Interpreter::without_stdlib(Default::default()).enter(|vm| { let a: PyObjectRef = vm.ctx.new_int(33_i32).into(); let b: PyObjectRef = vm.ctx.new_int(12_i32).into(); @@ -581,7 +584,7 @@ mod tests { } #[test] - fn test_multiply_str() { + fn multiply_str() { Interpreter::without_stdlib(Default::default()).enter(|vm| { let a = vm.new_pyobj(crate::common::ascii!("Hello ")); let b = vm.new_pyobj(4_i32); diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index 44842a866a0..7775c34e053 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -48,11 +48,6 @@ use core::{ sync::atomic::{AtomicBool, AtomicU64, Ordering}, }; use crossbeam_utils::atomic::AtomicCell; -#[cfg(unix)] -use nix::{ - sys::signal::{SaFlags, SigAction, SigSet, Signal::SIGINT, kill, sigaction}, - unistd::getpid, -}; use std::{ collections::{HashMap, HashSet}, ffi::{OsStr, OsString}, @@ -105,6 +100,7 @@ pub struct VirtualMachine { /// Current running asyncio task for this thread pub asyncio_running_task: RefCell>, pub(crate) callable_cache: CallableCache, + pub(crate) audit_hooks: RefCell>, } /// Non-owning frame pointer for the frames stack. 
@@ -577,9 +573,7 @@ pub(super) fn stw_trace(msg: core::fmt::Arguments<'_>) { crate::stdlib::_thread::get_ident(), msg ); - unsafe { - let _ = libc::write(libc::STDERR_FILENO, out.buf.as_ptr().cast(), out.len); - } + crate::host_env::io::write_stderr_raw(&out.buf[..out.len]); } } @@ -595,7 +589,7 @@ pub(crate) struct CallableCache { pub struct PyGlobalState { pub config: PyConfig, pub module_defs: BTreeMap<&'static str, &'static builtins::PyModuleDef>, - pub frozen: HashMap<&'static str, FrozenModule, ahash::RandomState>, + pub frozen: HashMap<&'static str, FrozenModule, rapidhash::quality::RandomState>, pub stacksize: AtomicCell, pub thread_count: AtomicCell, pub hash_secret: HashSecret, @@ -757,6 +751,7 @@ impl VirtualMachine { asyncio_running_loop: RefCell::new(None), asyncio_running_task: RefCell::new(None), callable_cache: CallableCache::default(), + audit_hooks: RefCell::new(vec![]), }; if vm.state.hash_secret.hash_str("") @@ -1446,19 +1441,7 @@ impl VirtualMachine { /// Returns (base, top) where base is the lowest address and top is the highest. #[cfg(all(not(miri), not(target_env = "musl"), windows))] fn get_stack_bounds() -> (usize, usize) { - use windows_sys::Win32::System::Threading::{ - GetCurrentThreadStackLimits, SetThreadStackGuarantee, - }; - let mut low: usize = 0; - let mut high: usize = 0; - unsafe { - GetCurrentThreadStackLimits(&mut low as *mut usize, &mut high as *mut usize); - // Add the guaranteed stack space (reserved for exception handling) - let mut guarantee: u32 = 0; - SetThreadStackGuarantee(&mut guarantee); - low += guarantee as usize; - } - (low, high) + crate::host_env::windows::current_thread_stack_bounds() } /// Get stack boundaries on non-Windows platforms. @@ -2167,15 +2150,10 @@ impl VirtualMachine { self.print_exception(exc); cfg_select! 
{ unix => { - let action = SigAction::new( - nix::sys::signal::SigHandler::SigDfl, - SaFlags::SA_ONSTACK, - SigSet::empty(), - ); - let result = unsafe { sigaction(SIGINT, &action) }; - if result.is_ok() { + if crate::host_env::signal::set_sigint_default_onstack().is_ok() { self.flush_std(); - kill(getpid(), SIGINT).expect("Expect to be killed."); + crate::host_env::signal::send_sigint_to_self() + .expect("Expect to be killed."); } (libc::SIGINT as u32) + 128 @@ -2288,52 +2266,57 @@ pub fn resolve_frozen_alias(name: &str) -> &str { } } -#[test] -fn test_nested_frozen() { - use rustpython_vm as vm; - - vm::Interpreter::builder(Default::default()) - .add_frozen_modules(rustpython_vm::py_freeze!( - dir = "../../../../extra_tests/snippets" - )) - .build() - .enter(|vm| { - let scope = vm.new_scope_with_builtins(); - - let source = "from dir_module.dir_module_inner import value2"; - let code_obj = vm - .compile(source, vm::compiler::Mode::Exec, "".to_owned()) - .map_err(|err| vm.new_syntax_error(&err, Some(source))) - .unwrap(); - - if let Err(e) = vm.run_code_obj(code_obj, scope) { - vm.print_exception(e); - panic!(); - } - }) -} - -#[test] -fn frozen_origname_matches() { - use rustpython_vm as vm; - - vm::Interpreter::builder(Default::default()) - .build() - .enter(|vm| { - let check = |name, expected| { - let module = import::import_frozen(vm, name).unwrap(); - let origname: PyStrRef = module - .get_attr("__origname__", vm) - .unwrap() - .try_into_value(vm) +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn nested_frozen() { + use rustpython_vm as vm; + + vm::Interpreter::builder(Default::default()) + .add_frozen_modules(rustpython_vm::py_freeze!( + dir = "../../../../extra_tests/snippets" + )) + .build() + .enter(|vm| { + let scope = vm.new_scope_with_builtins(); + + let source = "from dir_module.dir_module_inner import value2"; + let code_obj = vm + .compile(source, vm::compiler::Mode::Exec, "".to_owned()) + .map_err(|err| vm.new_syntax_error(&err, 
Some(source))) .unwrap(); - assert_eq!(origname.as_wtf8(), expected); - }; - check("_frozen_importlib", "importlib._bootstrap"); - check( - "_frozen_importlib_external", - "importlib._bootstrap_external", - ); - }); + if let Err(e) = vm.run_code_obj(code_obj, scope) { + vm.print_exception(e); + panic!(); + } + }) + } + + #[test] + fn frozen_origname_matches() { + use rustpython_vm as vm; + + vm::Interpreter::builder(Default::default()) + .build() + .enter(|vm| { + let check = |name, expected| { + let module = import::import_frozen(vm, name).unwrap(); + let origname: PyStrRef = module + .get_attr("__origname__", vm) + .unwrap() + .try_into_value(vm) + .unwrap(); + assert_eq!(origname.as_wtf8(), expected); + }; + + check("_frozen_importlib", "importlib._bootstrap"); + check( + "_frozen_importlib_external", + "importlib._bootstrap_external", + ); + }); + } } diff --git a/crates/vm/src/vm/python_run.rs b/crates/vm/src/vm/python_run.rs index 26000b5e11f..ccb295465ec 100644 --- a/crates/vm/src/vm/python_run.rs +++ b/crates/vm/src/vm/python_run.rs @@ -174,3 +174,60 @@ mod file_run { Ok(crate::import::check_pyc_magic_number_bytes(&buf)) } } + +#[cfg(test)] +mod tests { + use crate::object::AsObject; + use rustpython_vm::Interpreter; + + fn interpreter() -> Interpreter { + Interpreter::without_stdlib(Default::default()) + } + + #[test] + fn block_expr_return_const() { + interpreter().enter(|vm| { + let scope = vm.new_scope_with_builtins(); + let value = vm.unwrap_pyresult(vm.run_block_expr(scope, "1")); + let value = vm.unwrap_pyresult(value.try_int(vm)); + let value: u32 = vm.unwrap_pyresult(value.try_to_primitive(vm)); + assert_eq!(value, 1); + }) + } + + #[test] + fn block_expr_return_nonconst() { + interpreter().enter(|vm| { + let scope = vm.new_scope_with_builtins(); + vm.unwrap_pyresult(scope.globals.set_item("x", vm.new_pyobj(3), vm)); + let value = vm.unwrap_pyresult(vm.run_block_expr(scope, "2 + x")); + let value = vm.unwrap_pyresult(value.try_int(vm)); + let 
value: u32 = vm.unwrap_pyresult(value.try_to_primitive(vm)); + assert_eq!(value, 5); + }) + } + + #[test] + fn block_expr_return_function_def() { + interpreter().enter(|vm| { + let scope = vm.new_scope_with_builtins(); + let value = + vm.unwrap_pyresult(vm.run_block_expr(scope.clone(), "def f():\n return 7")); + vm.unwrap_pyresult(scope.globals.set_item("returned", value, vm)); + let value = vm.unwrap_pyresult(vm.run_block_expr(scope, "returned is f")); + assert!(value.is(&vm.ctx.true_value)); + }) + } + + #[test] + fn block_expr_return_class_def() { + interpreter().enter(|vm| { + let scope = vm.new_scope_with_builtins(); + let value = + vm.unwrap_pyresult(vm.run_block_expr(scope.clone(), "class C:\n value = 11")); + vm.unwrap_pyresult(scope.globals.set_item("returned", value, vm)); + let value = vm.unwrap_pyresult(vm.run_block_expr(scope, "returned is C")); + assert!(value.is(&vm.ctx.true_value)); + }) + } +} diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs index a81396bff34..f89cf818d2c 100644 --- a/crates/vm/src/vm/thread.rs +++ b/crates/vm/src/vm/thread.rs @@ -730,6 +730,7 @@ impl VirtualMachine { asyncio_running_loop: RefCell::new(None), asyncio_running_task: RefCell::new(None), callable_cache: self.callable_cache.clone(), + audit_hooks: RefCell::new(vec![]), }; ThreadedVirtualMachine { vm } } diff --git a/crates/vm/src/vm/vm_new.rs b/crates/vm/src/vm/vm_new.rs index 63de652ac7a..3b50c25695f 100644 --- a/crates/vm/src/vm/vm_new.rs +++ b/crates/vm/src/vm/vm_new.rs @@ -1,3 +1,14 @@ +#[cfg(feature = "parser")] +use ruff_python_ast::token::TokenKind; + +use ruff_python_parser::{InterpolatedStringErrorType, LexicalErrorType, ParseErrorType}; + +use rustpython_common::wtf8::Wtf8Buf; +use rustpython_compiler_core::SourceLocation; + +#[cfg(feature = "parser")] +use rustpython_compiler::{CompileError, ParseError}; + use crate::{ AsObject, Py, PyObject, PyObjectRef, PyRef, PyResult, builtins::{ @@ -13,8 +24,6 @@ use crate::{ scope::Scope, 
vm::VirtualMachine, }; -use rustpython_common::wtf8::Wtf8Buf; -use rustpython_compiler_core::SourceLocation; macro_rules! define_exception_fn { ( @@ -25,14 +34,237 @@ macro_rules! define_exception_fn { stringify!($python_repr), " object.\nUseful for raising errors from python functions implemented in rust." )] - pub fn $fn_name(&self, msg: impl Into) -> PyBaseExceptionRef - { + pub fn $fn_name(&self, msg: impl Into) -> $crate::builtins::PyBaseExceptionRef { let err = self.ctx.exceptions.$attr.to_owned(); self.new_exception_msg(err, msg.into()) } }; } +#[derive(Clone, Debug)] +struct SyntaxErrorInfo { + msg: String, + narrow_caret: bool, +} + +impl SyntaxErrorInfo { + #[must_use] + const fn new(msg: String, narrow_caret: bool) -> Self { + Self { msg, narrow_caret } + } + + fn with_msg(&mut self, msg: &str) { + self.msg = msg.into(); + } + + #[cfg(feature = "parser")] + const fn with_narrow_caret(&mut self, narrow_caret: bool) { + self.narrow_caret = narrow_caret; + } + + #[cfg(feature = "parser")] + #[must_use] + const fn handle_expected_token(expected: &TokenKind, found: &TokenKind) -> &'static str { + match (*expected, *found) { + (TokenKind::Colon, TokenKind::Newline) => "expected ':'", + + (TokenKind::Lpar, _) => "expected '('", + + (TokenKind::Else, y) if !matches!(y, TokenKind::Colon) => { + "expected 'else' after 'if' expression" + } + + _ => "invalid syntax", + } + } + + #[cfg(feature = "parser")] + fn analyze_compile_error(&mut self, compile_error: &CompileError) { + let CompileError::Parse(ParseError { + error, location, .. 
+ }) = compile_error + else { + return; + }; + + let msg = match error { + ParseErrorType::FStringError(InterpolatedStringErrorType::UnterminatedString) + | ParseErrorType::Lexical(LexicalErrorType::FStringError( + InterpolatedStringErrorType::UnterminatedString, + )) => "unterminated f-string literal".into(), + + ParseErrorType::FStringError( + InterpolatedStringErrorType::UnterminatedTripleQuotedString, + ) + | ParseErrorType::Lexical(LexicalErrorType::FStringError( + InterpolatedStringErrorType::UnterminatedTripleQuotedString, + )) => "unterminated triple-quoted f-string literal".into(), + + ParseErrorType::FStringError(_) + | ParseErrorType::Lexical(LexicalErrorType::FStringError(_)) => { + // Replace backticks with single quotes to match CPython's error messages + format!("invalid syntax: {}", self.msg.replace('`', "'")) + } + + ParseErrorType::UnexpectedExpressionToken => format!("invalid syntax: {}", self.msg), + + ParseErrorType::ExpectedToken { expected, found } => { + Self::handle_expected_token(expected, found).into() + } + + ParseErrorType::InvalidStarredExpressionUsage => { + self.with_narrow_caret(true); + "invalid syntax".into() + } + + ParseErrorType::InvalidDeleteTarget => "invalid syntax".into(), + + ParseErrorType::Lexical(LexicalErrorType::LineContinuationError) => { + "unexpected character after line continuation character".into() + } + + ParseErrorType::Lexical(LexicalErrorType::UnclosedStringError) => { + format!( + "unterminated string literal (detected at line {})", + location.line + ) + } + + ParseErrorType::EmptyTypeParams => "Type parameter list cannot be empty".into(), + + ParseErrorType::InvalidStarPatternUsage => { + self.with_narrow_caret(true); + "cannot use starred expression here".into() + } + + ParseErrorType::ExpectedKeywordParam => "named arguments must follow bare *".into(), + + ParseErrorType::EmptyImportNames => "Expected one or more names after 'import'".into(), + + ParseErrorType::DuplicateKeywordArgumentError(arg_name) => 
{ + format!("keyword argument repeated: {arg_name}") + } + + ParseErrorType::UnparenthesizedGeneratorExpression => { + "Generator expression must be parenthesized".into() + } + + ParseErrorType::NonDefaultParamAfterDefaultParam => { + "parameter without a default follows parameter with a default".into() + } + + ParseErrorType::VarParameterWithDefault => { + "var-positional argument cannot have default value".into() + } + + ParseErrorType::PositionalAfterKeywordArgument => { + "positional argument follows keyword argument".into() + } + + ParseErrorType::PositionalAfterKeywordUnpacking => { + "positional argument follows keyword argument unpacking".into() + } + + ParseErrorType::InvalidArgumentUnpackingOrder => { + "iterable argument unpacking follows keyword argument unpacking".into() + } + + ParseErrorType::ParamAfterVarKeywordParam => { + "arguments cannot follow var-keyword argument".into() + } + + ParseErrorType::Lexical(LexicalErrorType::UnrecognizedToken { .. }) + | ParseErrorType::SimpleStatementsOnSameLine + | ParseErrorType::SimpleAndCompoundStatementOnSameLine + | ParseErrorType::ExpectedExpression => "invalid syntax".into(), + + ParseErrorType::OtherError(s) + if s.starts_with("Expected an identifier, but found a keyword") => + { + "invalid syntax".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case( + "bytes literal cannot be mixed with non-bytes literals", + ) => + { + "cannot mix bytes and nonbytes literals".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("positional patterns cannot follow keyword patterns") => + { + "positional patterns follow keyword patterns".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("boolean 'not' expression cannot be used here") => + { + "'not' after an operator must be parenthesized".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("trailing comma not allowed") => + { + "trailing comma not allowed without surrounding 
parentheses".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case( + "multiple exception types must be parenthesized when using `as`", + ) => + { + "multiple exception types must be parenthesized when using 'as'".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case( + "position-only parameter separator not allowed as first parameter", + ) => + { + "at least one argument must precede /".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("only one '/' separator allowed") => + { + "/ may appear only once".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("'/' parameter must appear before '*' parameter") => + { + "/ must be ahead of *".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("expected `except` or `finally` after `try` block") => + { + "expected 'except' or 'finally' block".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case("only one '*' parameter allowed") => + { + "* argument may appear only once".into() + } + + ParseErrorType::OtherError(s) + if s.eq_ignore_ascii_case( + r#"cannot have both 'except' and 'except*' on the same 'try'"#, + ) => + { + r#"cannot have both 'except' and 'except*' on the same 'try'"#.into() + } + + _ => return, + }; + + self.with_msg(&msg); + } +} + /// Collection of object creation helpers impl VirtualMachine { /// Create a new python object @@ -83,7 +315,7 @@ impl VirtualMachine { let def = self .ctx .new_method_def(name, f, PyMethodFlags::empty(), None); - def.build_function(self) + def.build_function(self, None) } pub fn new_method( @@ -492,135 +724,31 @@ impl VirtualMachine { Some(line + "\n") } - let statement = if let Some(source) = source { - get_statement(source, error.location()) - } else { - None - }; + let statement = source.and_then(|src| get_statement(src, error.location())); let mut msg = error.to_string(); if let Some(msg) = msg.get_mut(..1) { msg.make_ascii_lowercase(); } - let mut 
narrow_caret = false; - match error { - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: - ruff_python_parser::ParseErrorType::FStringError( - ruff_python_parser::InterpolatedStringErrorType::UnterminatedString, - ) - | ruff_python_parser::ParseErrorType::Lexical( - ruff_python_parser::LexicalErrorType::FStringError( - ruff_python_parser::InterpolatedStringErrorType::UnterminatedString, - ), - ), - .. - }) => { - msg = "unterminated f-string literal".to_owned(); - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: - ruff_python_parser::ParseErrorType::FStringError( - ruff_python_parser::InterpolatedStringErrorType::UnterminatedTripleQuotedString, - ) - | ruff_python_parser::ParseErrorType::Lexical( - ruff_python_parser::LexicalErrorType::FStringError( - ruff_python_parser::InterpolatedStringErrorType::UnterminatedTripleQuotedString, - ), - ), - .. - }) => { - msg = "unterminated triple-quoted f-string literal".to_owned(); - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: - ruff_python_parser::ParseErrorType::FStringError(_) - | ruff_python_parser::ParseErrorType::Lexical( - ruff_python_parser::LexicalErrorType::FStringError(_), - ), - .. - }) => { - // Replace backticks with single quotes to match CPython's error messages - msg = msg.replace('`', "'"); - msg.insert_str(0, "invalid syntax: "); - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: ruff_python_parser::ParseErrorType::UnexpectedExpressionToken, - .. - }) => msg.insert_str(0, "invalid syntax: "), - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: - ruff_python_parser::ParseErrorType::Lexical( - ruff_python_parser::LexicalErrorType::UnrecognizedToken { .. 
}, - ) - | ruff_python_parser::ParseErrorType::SimpleStatementsOnSameLine - | ruff_python_parser::ParseErrorType::SimpleAndCompoundStatementOnSameLine - | ruff_python_parser::ParseErrorType::ExpectedToken { .. } - | ruff_python_parser::ParseErrorType::ExpectedExpression, - .. - }) => { - msg = "invalid syntax".to_owned(); - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: ruff_python_parser::ParseErrorType::InvalidStarredExpressionUsage, - .. - }) => { - msg = "invalid syntax".to_owned(); - narrow_caret = true; - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: ruff_python_parser::ParseErrorType::InvalidDeleteTarget, - .. - }) => { - msg = "invalid syntax".to_owned(); - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: - ruff_python_parser::ParseErrorType::Lexical( - ruff_python_parser::LexicalErrorType::LineContinuationError, - ), - .. - }) => { - msg = "unexpected character after line continuation".to_owned(); - } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: - ruff_python_parser::ParseErrorType::Lexical( - ruff_python_parser::LexicalErrorType::UnclosedStringError, - ), - .. - }) => { - msg = "unterminated string".to_owned(); + + cfg_select! { + feature = "parser" => { + let mut syntax_error_info = SyntaxErrorInfo::new(msg, false); + syntax_error_info.analyze_compile_error(error); } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: ruff_python_parser::ParseErrorType::OtherError(s), - .. 
- }) if s.eq_ignore_ascii_case("bytes literal cannot be mixed with non-bytes literals") => { - msg = "cannot mix bytes and nonbytes literals".to_owned(); + _ => { + let syntax_error_info = SyntaxErrorInfo::new(msg, false); } - #[cfg(feature = "parser")] - crate::compiler::CompileError::Parse(rustpython_compiler::ParseError { - error: ruff_python_parser::ParseErrorType::OtherError(s), - .. - }) if s.starts_with("Expected an identifier, but found a keyword") => { - msg = "invalid syntax".to_owned(); - } - _ => {} - } + }; + if syntax_error_type.is(self.ctx.exceptions.tab_error) { - msg = "inconsistent use of tabs and spaces in indentation".to_owned(); + syntax_error_info.with_msg("inconsistent use of tabs and spaces in indentation"); } + + let SyntaxErrorInfo { msg, narrow_caret } = syntax_error_info; + let syntax_error = self.new_exception_msg(syntax_error_type, msg.into()); + let (lineno, offset) = error.python_location(); let lineno = self.ctx.new_int(lineno); let offset = self.ctx.new_int(offset); diff --git a/crates/vm/src/windows.rs b/crates/vm/src/windows.rs index 1d858310ff9..aed2b3fd689 100644 --- a/crates/vm/src/windows.rs +++ b/crates/vm/src/windows.rs @@ -1,18 +1,12 @@ -use crate::host_env::fileutils::{ - StatStruct, - windows::{FILE_INFO_BY_NAME_CLASS, get_file_information_by_name}, -}; use crate::{ PyObjectRef, PyResult, TryFromObject, VirtualMachine, convert::{ToPyObject, ToPyResult}, }; -use rustpython_host_env::windows::ToWideString; -use std::ffi::OsStr; -use windows_sys::Win32::Foundation::{HANDLE, INVALID_HANDLE_VALUE}; +use rustpython_host_env::nt as host_nt; /// Windows HANDLE wrapper for Python interop #[derive(Clone, Copy)] -pub struct WinHandle(pub HANDLE); +pub struct WinHandle(pub host_nt::Handle); pub(crate) trait WindowsSysResultValue { type Ok: ToPyObject; @@ -22,11 +16,11 @@ pub(crate) trait WindowsSysResultValue { fn into_ok(self) -> Self::Ok; } -impl WindowsSysResultValue for HANDLE { +impl WindowsSysResultValue for host_nt::Handle { 
type Ok = WinHandle; fn is_err(&self) -> bool { - *self == INVALID_HANDLE_VALUE + host_nt::is_invalid_handle(*self) } fn into_ok(self) -> Self::Ok { @@ -73,7 +67,7 @@ type HandleInt = isize; impl TryFromObject for WinHandle { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let handle = HandleInt::try_from_object(vm, obj)?; - Ok(Self(handle as HANDLE)) + Ok(Self(handle as host_nt::Handle)) } } @@ -82,465 +76,3 @@ impl ToPyObject for WinHandle { (self.0 as HandleInt).to_pyobject(vm) } } - -pub fn init_winsock() { - static WSA_INIT: parking_lot::Once = parking_lot::Once::new(); - WSA_INIT.call_once(|| unsafe { - let mut wsa_data = core::mem::MaybeUninit::uninit(); - let _ = windows_sys::Win32::Networking::WinSock::WSAStartup(0x0101, wsa_data.as_mut_ptr()); - }) -} - -// win32_xstat in cpython -pub fn win32_xstat(path: &OsStr, traverse: bool) -> std::io::Result { - let mut result = win32_xstat_impl(path, traverse)?; - // ctime is only deprecated from 3.12, so we copy birthtime across - result.st_ctime = result.st_birthtime; - result.st_ctime_nsec = result.st_birthtime_nsec; - Ok(result) -} - -fn is_reparse_tag_name_surrogate(tag: u32) -> bool { - (tag & 0x20000000) > 0 -} - -// Constants -const IO_REPARSE_TAG_SYMLINK: u32 = 0xA000000C; -const S_IFMT: u16 = libc::S_IFMT as u16; -const S_IFDIR: u16 = libc::S_IFDIR as u16; -const S_IFREG: u16 = libc::S_IFREG as u16; -const S_IFCHR: u16 = libc::S_IFCHR as u16; -const S_IFLNK: u16 = crate::host_env::fileutils::windows::S_IFLNK as u16; -const S_IFIFO: u16 = crate::host_env::fileutils::windows::S_IFIFO as u16; - -/// FILE_ATTRIBUTE_TAG_INFO structure for GetFileInformationByHandleEx -#[repr(C)] -#[derive(Default)] -struct FileAttributeTagInfo { - file_attributes: u32, - reparse_tag: u32, -} - -/// Ported from attributes_to_mode (fileutils.c) -fn attributes_to_mode(attr: u32) -> u16 { - use windows_sys::Win32::Storage::FileSystem::{ - FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_READONLY, - }; - let mut m: 
u16 = 0; - if attr & FILE_ATTRIBUTE_DIRECTORY != 0 { - m |= S_IFDIR | 0o111; // IFEXEC for user,group,other - } else { - m |= S_IFREG; - } - if attr & FILE_ATTRIBUTE_READONLY != 0 { - m |= 0o444; - } else { - m |= 0o666; - } - m -} - -/// Ported from _Py_attribute_data_to_stat (fileutils.c) -/// Converts BY_HANDLE_FILE_INFORMATION to StatStruct -fn attribute_data_to_stat( - info: &windows_sys::Win32::Storage::FileSystem::BY_HANDLE_FILE_INFORMATION, - reparse_tag: u32, - basic_info: Option<&windows_sys::Win32::Storage::FileSystem::FILE_BASIC_INFO>, - id_info: Option<&windows_sys::Win32::Storage::FileSystem::FILE_ID_INFO>, -) -> StatStruct { - use crate::host_env::fileutils::windows::SECS_BETWEEN_EPOCHS; - use windows_sys::Win32::Storage::FileSystem::FILE_ATTRIBUTE_REPARSE_POINT; - - let mut st_mode = attributes_to_mode(info.dwFileAttributes); - let st_size = ((info.nFileSizeHigh as u64) << 32) | (info.nFileSizeLow as u64); - let st_dev = id_info.map_or(info.dwVolumeSerialNumber, |id| id.VolumeSerialNumber as u32); - let st_nlink = info.nNumberOfLinks as i32; - - // Convert FILETIME/LARGE_INTEGER to (time_t, nsec) - let filetime_to_time = |ft_low: u32, ft_high: u32| -> (libc::time_t, i32) { - let ticks = ((ft_high as i64) << 32) | (ft_low as i64); - let nsec = ((ticks % 10_000_000) * 100) as i32; - let sec = (ticks / 10_000_000 - SECS_BETWEEN_EPOCHS) as libc::time_t; - (sec, nsec) - }; - - let large_integer_to_time = |li: i64| -> (libc::time_t, i32) { - let nsec = ((li % 10_000_000) * 100) as i32; - let sec = (li / 10_000_000 - SECS_BETWEEN_EPOCHS) as libc::time_t; - (sec, nsec) - }; - - let (st_birthtime, st_birthtime_nsec); - let (st_mtime, st_mtime_nsec); - let (st_atime, st_atime_nsec); - - if let Some(bi) = basic_info { - (st_birthtime, st_birthtime_nsec) = large_integer_to_time(bi.CreationTime); - (st_mtime, st_mtime_nsec) = large_integer_to_time(bi.LastWriteTime); - (st_atime, st_atime_nsec) = large_integer_to_time(bi.LastAccessTime); - } else { - 
(st_birthtime, st_birthtime_nsec) = filetime_to_time( - info.ftCreationTime.dwLowDateTime, - info.ftCreationTime.dwHighDateTime, - ); - (st_mtime, st_mtime_nsec) = filetime_to_time( - info.ftLastWriteTime.dwLowDateTime, - info.ftLastWriteTime.dwHighDateTime, - ); - (st_atime, st_atime_nsec) = filetime_to_time( - info.ftLastAccessTime.dwLowDateTime, - info.ftLastAccessTime.dwHighDateTime, - ); - } - - // Get file ID from id_info or fallback to file index - let (st_ino, st_ino_high) = if let Some(id) = id_info { - // FILE_ID_INFO.FileId is FILE_ID_128 which is [u8; 16] - let bytes = id.FileId.Identifier; - let low = u64::from_le_bytes(bytes[0..8].try_into().unwrap()); - let high = u64::from_le_bytes(bytes[8..16].try_into().unwrap()); - (low, high) - } else { - let ino = ((info.nFileIndexHigh as u64) << 32) | (info.nFileIndexLow as u64); - (ino, 0u64) - }; - - // Set symlink mode if applicable - if info.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 - && reparse_tag == IO_REPARSE_TAG_SYMLINK - { - st_mode = (st_mode & !S_IFMT) | S_IFLNK; - } - - StatStruct { - st_dev, - st_ino, - st_ino_high, - st_mode, - st_nlink, - st_uid: 0, - st_gid: 0, - st_rdev: 0, - st_size, - st_atime, - st_atime_nsec, - st_mtime, - st_mtime_nsec, - st_ctime: 0, // Will be set by caller - st_ctime_nsec: 0, - st_birthtime, - st_birthtime_nsec, - st_file_attributes: info.dwFileAttributes, - st_reparse_tag: reparse_tag, - } -} - -/// Get file info using FindFirstFileW (fallback when CreateFileW fails) -/// Ported from attributes_from_dir -fn attributes_from_dir( - path: &OsStr, -) -> std::io::Result<( - windows_sys::Win32::Storage::FileSystem::BY_HANDLE_FILE_INFORMATION, - u32, -)> { - use windows_sys::Win32::Storage::FileSystem::{ - BY_HANDLE_FILE_INFORMATION, FILE_ATTRIBUTE_REPARSE_POINT, FindClose, FindFirstFileW, - WIN32_FIND_DATAW, - }; - - let wide: Vec = path.to_wide_with_nul(); - let mut find_data: WIN32_FIND_DATAW = unsafe { core::mem::zeroed() }; - - let handle = unsafe { 
FindFirstFileW(wide.as_ptr(), &mut find_data) }; - if handle == INVALID_HANDLE_VALUE { - return Err(std::io::Error::last_os_error()); - } - unsafe { FindClose(handle) }; - - let mut info: BY_HANDLE_FILE_INFORMATION = unsafe { core::mem::zeroed() }; - info.dwFileAttributes = find_data.dwFileAttributes; - info.ftCreationTime = find_data.ftCreationTime; - info.ftLastAccessTime = find_data.ftLastAccessTime; - info.ftLastWriteTime = find_data.ftLastWriteTime; - info.nFileSizeHigh = find_data.nFileSizeHigh; - info.nFileSizeLow = find_data.nFileSizeLow; - - let reparse_tag = if find_data.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 { - find_data.dwReserved0 - } else { - 0 - }; - - Ok((info, reparse_tag)) -} - -/// Ported from win32_xstat_slow_impl -fn win32_xstat_slow_impl(path: &OsStr, traverse: bool) -> std::io::Result { - use windows_sys::Win32::{ - Foundation::{ - CloseHandle, ERROR_ACCESS_DENIED, ERROR_CANT_ACCESS_FILE, ERROR_INVALID_FUNCTION, - ERROR_INVALID_PARAMETER, ERROR_NOT_SUPPORTED, ERROR_SHARING_VIOLATION, GENERIC_READ, - INVALID_HANDLE_VALUE, - }, - Storage::FileSystem::{ - BY_HANDLE_FILE_INFORMATION, CreateFileW, FILE_ATTRIBUTE_DIRECTORY, - FILE_ATTRIBUTE_NORMAL, FILE_ATTRIBUTE_REPARSE_POINT, FILE_BASIC_INFO, - FILE_FLAG_BACKUP_SEMANTICS, FILE_FLAG_OPEN_REPARSE_POINT, FILE_ID_INFO, - FILE_READ_ATTRIBUTES, FILE_SHARE_READ, FILE_SHARE_WRITE, FILE_TYPE_CHAR, - FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_UNKNOWN, FileAttributeTagInfo, FileBasicInfo, - FileIdInfo, GetFileAttributesW, GetFileInformationByHandle, - GetFileInformationByHandleEx, GetFileType, INVALID_FILE_ATTRIBUTES, OPEN_EXISTING, - }, - }; - - let wide: Vec = path.to_wide_with_nul(); - - let access = FILE_READ_ATTRIBUTES; - let mut flags = FILE_FLAG_BACKUP_SEMANTICS; - if !traverse { - flags |= FILE_FLAG_OPEN_REPARSE_POINT; - } - - let mut h_file = unsafe { - CreateFileW( - wide.as_ptr(), - access, - 0, - core::ptr::null(), - OPEN_EXISTING, - flags, - core::ptr::null_mut(), - ) - }; - - let 
mut file_info: BY_HANDLE_FILE_INFORMATION = unsafe { core::mem::zeroed() }; - let mut tag_info = FileAttributeTagInfo::default(); - let mut is_unhandled_tag = false; - - if h_file == INVALID_HANDLE_VALUE { - let error = std::io::Error::last_os_error(); - let error_code = error.raw_os_error().unwrap_or(0) as u32; - - match error_code { - ERROR_ACCESS_DENIED | ERROR_SHARING_VIOLATION => { - // Try reading the parent directory using FindFirstFileW - let (info, reparse_tag) = attributes_from_dir(path)?; - file_info = info; - tag_info.reparse_tag = reparse_tag; - - if file_info.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 - && (traverse || !is_reparse_tag_name_surrogate(tag_info.reparse_tag)) - { - return Err(error); - } - // h_file remains INVALID_HANDLE_VALUE, we'll use file_info from FindFirstFileW - } - ERROR_INVALID_PARAMETER => { - // Retry with GENERIC_READ (needed for \\.\con) - h_file = unsafe { - CreateFileW( - wide.as_ptr(), - access | GENERIC_READ, - FILE_SHARE_READ | FILE_SHARE_WRITE, - core::ptr::null(), - OPEN_EXISTING, - flags, - core::ptr::null_mut(), - ) - }; - if h_file == INVALID_HANDLE_VALUE { - return Err(error); - } - } - ERROR_CANT_ACCESS_FILE if traverse => { - // bpo37834: open unhandled reparse points if traverse fails - is_unhandled_tag = true; - h_file = unsafe { - CreateFileW( - wide.as_ptr(), - access, - 0, - core::ptr::null(), - OPEN_EXISTING, - flags | FILE_FLAG_OPEN_REPARSE_POINT, - core::ptr::null_mut(), - ) - }; - if h_file == INVALID_HANDLE_VALUE { - return Err(error); - } - } - _ => return Err(error), - } - } - - // Scope for handle cleanup - let result = (|| -> std::io::Result { - if h_file != INVALID_HANDLE_VALUE { - // Handle types other than files on disk - let file_type = unsafe { GetFileType(h_file) }; - if file_type != FILE_TYPE_DISK { - if file_type == FILE_TYPE_UNKNOWN { - let err = std::io::Error::last_os_error(); - if err.raw_os_error().unwrap_or(0) != 0 { - return Err(err); - } - } - let file_attributes = unsafe 
{ GetFileAttributesW(wide.as_ptr()) }; - let mut st_mode: u16 = 0; - if file_attributes != INVALID_FILE_ATTRIBUTES - && file_attributes & FILE_ATTRIBUTE_DIRECTORY != 0 - { - st_mode = S_IFDIR; - } else if file_type == FILE_TYPE_CHAR { - st_mode = S_IFCHR; - } else if file_type == FILE_TYPE_PIPE { - st_mode = S_IFIFO; - } - return Ok(StatStruct { - st_mode, - ..Default::default() - }); - } - - // Query the reparse tag - if !traverse || is_unhandled_tag { - let mut local_tag_info: FileAttributeTagInfo = unsafe { core::mem::zeroed() }; - let ret = unsafe { - GetFileInformationByHandleEx( - h_file, - FileAttributeTagInfo, - &mut local_tag_info as *mut _ as *mut _, - core::mem::size_of::() as u32, - ) - }; - if ret == 0 { - let err_code = - std::io::Error::last_os_error().raw_os_error().unwrap_or(0) as u32; - match err_code { - ERROR_INVALID_PARAMETER | ERROR_INVALID_FUNCTION | ERROR_NOT_SUPPORTED => { - local_tag_info.file_attributes = FILE_ATTRIBUTE_NORMAL; - local_tag_info.reparse_tag = 0; - } - _ => return Err(std::io::Error::last_os_error()), - } - } else if local_tag_info.file_attributes & FILE_ATTRIBUTE_REPARSE_POINT != 0 { - if is_reparse_tag_name_surrogate(local_tag_info.reparse_tag) { - if is_unhandled_tag { - return Err(std::io::Error::from_raw_os_error( - ERROR_CANT_ACCESS_FILE as i32, - )); - } - // This is a symlink, keep the tag info - } else if !is_unhandled_tag { - // Traverse a non-link reparse point - unsafe { CloseHandle(h_file) }; - return win32_xstat_slow_impl(path, true); - } - } - tag_info = local_tag_info; - } - - // Get file information - let ret = unsafe { GetFileInformationByHandle(h_file, &mut file_info) }; - if ret == 0 { - let err_code = std::io::Error::last_os_error().raw_os_error().unwrap_or(0) as u32; - match err_code { - ERROR_INVALID_PARAMETER | ERROR_INVALID_FUNCTION | ERROR_NOT_SUPPORTED => { - // Volumes and physical disks are block devices - return Ok(StatStruct { - st_mode: 0x6000, // S_IFBLK - ..Default::default() - }); - } - _ 
=> return Err(std::io::Error::last_os_error()), - } - } - - // Get FILE_BASIC_INFO - let mut basic_info: FILE_BASIC_INFO = unsafe { core::mem::zeroed() }; - let has_basic_info = unsafe { - GetFileInformationByHandleEx( - h_file, - FileBasicInfo, - &mut basic_info as *mut _ as *mut _, - core::mem::size_of::() as u32, - ) - } != 0; - - // Get FILE_ID_INFO (optional) - let mut id_info: FILE_ID_INFO = unsafe { core::mem::zeroed() }; - let has_id_info = unsafe { - GetFileInformationByHandleEx( - h_file, - FileIdInfo, - &mut id_info as *mut _ as *mut _, - core::mem::size_of::() as u32, - ) - } != 0; - - let mut result = attribute_data_to_stat( - &file_info, - tag_info.reparse_tag, - if has_basic_info { - Some(&basic_info) - } else { - None - }, - if has_id_info { Some(&id_info) } else { None }, - ); - result.update_st_mode_from_path(path, file_info.dwFileAttributes); - Ok(result) - } else { - // We got file_info from attributes_from_dir - let mut result = attribute_data_to_stat(&file_info, tag_info.reparse_tag, None, None); - result.update_st_mode_from_path(path, file_info.dwFileAttributes); - Ok(result) - } - })(); - - // Cleanup - if h_file != INVALID_HANDLE_VALUE { - unsafe { CloseHandle(h_file) }; - } - - result -} - -fn win32_xstat_impl(path: &OsStr, traverse: bool) -> std::io::Result { - use windows_sys::Win32::{Foundation, Storage::FileSystem::FILE_ATTRIBUTE_REPARSE_POINT}; - - let stat_info = - get_file_information_by_name(path, FILE_INFO_BY_NAME_CLASS::FileStatBasicByNameInfo); - match stat_info { - Ok(stat_info) => { - if (stat_info.FileAttributes & FILE_ATTRIBUTE_REPARSE_POINT == 0) - || (!traverse && is_reparse_tag_name_surrogate(stat_info.ReparseTag)) - { - let mut result = - crate::host_env::fileutils::windows::stat_basic_info_to_stat(&stat_info); - // If st_ino is 0, fall through to slow path to get proper file ID - if result.st_ino != 0 || result.st_ino_high != 0 { - result.update_st_mode_from_path(path, stat_info.FileAttributes); - return Ok(result); - } 
- } - } - Err(e) => { - if let Some(errno) = e.raw_os_error() - && matches!( - errno as u32, - Foundation::ERROR_FILE_NOT_FOUND - | Foundation::ERROR_PATH_NOT_FOUND - | Foundation::ERROR_NOT_READY - | Foundation::ERROR_BAD_NET_NAME - ) - { - return Err(e); - } - } - } - - // Fallback to slow implementation - win32_xstat_slow_impl(path, traverse) -} diff --git a/examples/custom_tls_providers.rs b/examples/custom_tls_providers.rs new file mode 100644 index 00000000000..1f382fc1b84 --- /dev/null +++ b/examples/custom_tls_providers.rs @@ -0,0 +1,65 @@ +//! Example project to demonstrate how to set a custom rustls provider for RustPython. + +// spell-checker: ignore graviola + +use std::env; + +use rustls::crypto::ring; +use rustpython_pylib::FROZEN_STDLIB; +use rustpython_stdlib::{ssl::providers::CryptoExt, stdlib_module_defs}; +use rustpython_vm::Interpreter; + +const SCRIPT: &str = r#" +import urllib.request + +with urllib.request.urlopen("https://python.org") as response: + assert response.status == 200 +"#; + +fn main() { + let provider = env::args() + .skip(1) + .find_map(|arg| match &*arg { + "--ring" => Some("ring"), + "--graviola" => Some("graviola"), + _ => None, + }) + .unwrap_or("ring"); + + match provider { + "ring" => { + let ext = CryptoExt { + all_cipher_suites: Some(ring::ALL_CIPHER_SUITES), + all_kx_groups: Some(ring::ALL_KX_GROUPS), + any_supported_key: Some(ring::sign::any_supported_type), + ticketer: ring::Ticketer::new, + }; + CryptoExt::set_provider(ring::default_provider(), ext).unwrap(); + println!("Using ring for cryptography"); + } + "graviola" => { + let ext = CryptoExt { + all_cipher_suites: Some(rustls_graviola::suites::ALL_CIPHER_SUITES), + all_kx_groups: Some(rustls_graviola::kx::ALL_KX_GROUPS), + any_supported_key: None, + ticketer: rustls_graviola::Ticketer::new, + }; + CryptoExt::set_provider(rustls_graviola::default_provider(), ext).unwrap(); + println!("Using Graviola for cryptography"); + } + unsupported => panic!("Unsupported 
provider: {unsupported}"), + } + + let builder = Interpreter::builder(Default::default()); + let defs = stdlib_module_defs(&builder.ctx); + let result = builder + .add_native_modules(&defs) + .add_frozen_modules(FROZEN_STDLIB) + .build() + .run(|vm| { + let scope = vm.new_scope_with_builtins(); + vm.run_block_expr(scope, SCRIPT).map(|_| ()) + }); + + assert_eq!(0, result); +} diff --git a/extra_tests/snippets/stdlib_select.py b/extra_tests/snippets/stdlib_select.py index d27bb82b1c3..5263bc344f6 100644 --- a/extra_tests/snippets/stdlib_select.py +++ b/extra_tests/snippets/stdlib_select.py @@ -4,6 +4,8 @@ from testutils import assert_raises +TOO_MANY_SELECT_FDS = 4096 + class Nope: pass @@ -42,3 +44,36 @@ def fileno(self): assert recvr in rres assert sendr in wres + +# Too many descriptors for select.select() +if sys.platform != "win32": + import resource + + soft_max_fds, hard_max_fds = resource.getrlimit(resource.RLIMIT_NOFILE) + if soft_max_fds != resource.RLIM_INFINITY: + # 100 additional fds should be enough for interpreter needs + need_fds = TOO_MANY_SELECT_FDS + 100 + + soft_max_fds = max(soft_max_fds, need_fds) + if hard_max_fds != resource.RLIM_INFINITY: + assert hard_max_fds >= soft_max_fds, ( + "Not enough file descriptors for this test" + ) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_max_fds, hard_max_fds)) +sockets = [s for _ in range(TOO_MANY_SELECT_FDS // 2) for s in socket.socketpair()] +assert_raises(ValueError, select.select, sockets, [], [], 0) +if sys.platform != "win32": + # Try to overflow descriptor bit mask on *nix with a single item + max_fd = -1 + max_fd_sock = None + sockets.reverse() + for sock in sockets: + if sock.fileno() > max_fd: + max_fd = sock.fileno() + max_fd_sock = sock + assert_raises(ValueError, select.select, [max_fd_sock], [], [], 0) +del sockets +a, b = socket.socketpair() +# CPython disallows this on *nix systems too. 
+assert_raises(ValueError, select.select, [a] * TOO_MANY_SELECT_FDS, [], [], 0) +del a, b diff --git a/extra_tests/snippets/stdlib_sqlite.py b/extra_tests/snippets/stdlib_sqlite.py index f2e02b48cf1..83fc7db94f0 100644 --- a/extra_tests/snippets/stdlib_sqlite.py +++ b/extra_tests/snippets/stdlib_sqlite.py @@ -53,3 +53,20 @@ def finalize(self): cx.create_aggregate("aggtxt", 1, AggrText) cur.execute("select aggtxt(key) from foo") assert cur.fetchone()[0] == "341011" + +# Blob extended-slice assignment with negative step +# Guard: CPython 3.11 has a SystemError bug with negative-step Blob slicing; +# this test only runs on RustPython where the fix is being validated. +# TODO: remove this once https://github.com/python/cpython/pull/150450 is released and RustPython CI uses it. +import sys + +if sys.implementation.name == "rustpython": + cx.execute("CREATE TABLE blobtest(b BLOB)") + data = b"this blob data string is exactly fifty bytes long!" + cx.execute("INSERT INTO blobtest(b) VALUES (?)", (data,)) + blob = cx.blobopen("blobtest", "b", 1) + blob[9:0:-2] = b"12345" # writes to indices 9, 7, 5, 3, 1 + actual = cx.execute("select b from blobtest").fetchone()[0] + expected = b"t5i4 3l2b1" + data[10:] + assert actual == expected, f"got {actual!r}, expected {expected!r}" + blob.close() diff --git a/extra_tests/snippets/stdlib_ssl_short_recv.py b/extra_tests/snippets/stdlib_ssl_short_recv.py new file mode 100644 index 00000000000..4ec36e5b7e0 --- /dev/null +++ b/extra_tests/snippets/stdlib_ssl_short_recv.py @@ -0,0 +1,88 @@ +import os +import socket +import ssl +import sys +import threading + +if sys.implementation.name.lower() != "rustpython": + print("Ignored: stdlib_ssl_short_recv (RustPython only)") + raise SystemExit + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CERTFILE = os.path.join(ROOT_DIR, "Lib/test/certdata/keycert.pem") +DATA = b"x" * 128 + +orig_recv = socket.socket.recv +client_sockname = None +recv_n = {} + + 
+def new_recv(sock, bufsize, flags=0): + sockname = sock.getsockname() + if sockname not in recv_n: + recv_n[sockname] = 0 + + bufsize = 1 + + if flags & socket.MSG_PEEK == 0: + recv_n[sockname] += 1 + return orig_recv(sock, bufsize, flags) + + +socket.socket.recv = new_recv + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(("127.0.0.1", 0)) +listener.listen(1) +addr, port = listener.getsockname() +server_errors = [] + + +def server(): + try: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(CERTFILE) + + sock, _ = listener.accept() + sock.settimeout(5.0) + + ssock = server_context.wrap_socket(sock, server_side=True) + try: + ssock.sendall(DATA) + finally: + ssock.close() + except BaseException as exc: + server_errors.append(exc) + finally: + listener.close() + + +thread = threading.Thread(target=server) +thread.start() + +raw = socket.create_connection((addr, port), timeout=5.0) +client_sockname = raw.getsockname() +raw.settimeout(5.0) + +client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +client_context.check_hostname = False +client_context.verify_mode = ssl.CERT_NONE + +client = client_context.wrap_socket(raw, server_hostname=None) +try: + chunks = [] + while sum(len(chunk) for chunk in chunks) < len(DATA): + chunk = client.recv(20000) + if not chunk: + break + chunks.append(chunk) +finally: + client.close() + +thread.join(10.0) +assert not thread.is_alive(), "server thread did not stop" +assert not server_errors, server_errors +assert b"".join(chunks) == DATA +assert len(recv_n) == 2 +assert all(n > 100 for n in recv_n.values()) diff --git a/extra_tests/snippets/stdlib_threading.py b/extra_tests/snippets/stdlib_threading.py index f35d7e9d085..cb989e1fd33 100644 --- a/extra_tests/snippets/stdlib_threading.py +++ b/extra_tests/snippets/stdlib_threading.py @@ -1,6 +1,7 @@ import multiprocessing import os import threading 
+import time def import_in_thread(module_name): @@ -62,6 +63,48 @@ def start_fork_process_after_thread(): assert process.exitcode == 0, process.exitcode +def thread_join_ordering(): + output = [] + + def thread_function(name): + output.append((name, 0)) + time.sleep(2.0) + output.append((name, 1)) + + output.append((0, 0)) + x = threading.Thread(target=thread_function, args=(1,)) + output.append((0, 1)) + x.start() + output.append((0, 2)) + x.join() + output.append((0, 3)) + + assert len(output) == 6, output + # CPython has [(1, 0), (0, 2)] for the middle 2, but we have [(0, 2), (1, 0)] + # TODO: maybe fix this, if it turns out to be a problem? + # assert output == [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (0, 3)] + + +def thread_exit_without_join(): + # Regression for https://github.com/RustPython/RustPython/issues/7813: + # a thread started without ``.join()`` must exit cleanly even when the + # captured target callable drops during teardown (which can fire + # weakref callbacks that re-enter the VM). 
+ output = [] + + def runner(): + output.append("runner done") + + threading.Thread(target=runner).start() + time.sleep(1) + output.append("main done") + assert "runner done" in output, output + assert "main done" in output, output + + +thread_join_ordering() +thread_exit_without_join() + import_in_thread("functools") import_in_thread("tempfile") import_in_thread("multiprocessing.connection") diff --git a/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py b/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py new file mode 100644 index 00000000000..18ac9ab010a --- /dev/null +++ b/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py @@ -0,0 +1,246 @@ +import os +import socket +import ssl +import sys +import threading +import time +import urllib.request + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CERTFILE = os.path.join(ROOT_DIR, "Lib/test/certdata/keycert.pem") +BODY = b"x" * 407_676 + +# TLS record body sizes observed from https://crates.io/api/v1/crates/tokio. 
+TLS_RECORD_BODY_SIZES = [ + 2855, + 281, + 53, + 218, + 1095, + 1395, + 1395, + 483, + 1395, + 1395, + 1395, + 1395, + 48, + 1360, + 1354, + 1395, + 1395, + 1395, + 1367, + 1395, + 1395, + 1395, + 1395, + 1326, + 1395, + 1395, + 1395, + 47, + 1395, + 1395, + 1395, + 1395, + 95, + 1395, + 1332, + 1287, + 1388, + 1395, + 1395, + 1374, + 1395, + 1380, + 794, + 791, + 1395, + 1381, + 1395, + 1395, + 1395, + 1333, + 1395, + 1395, + 1395, + 1395, + 1395, + 1395, + 965, + 16401, + 3914, + 2526, + 1041, + 8209, + 9233, + 16401, + 11650, + 10262, + 7486, + 3468, + 692, + 1041, + 16401, + 12242, + 9466, + 1041, + 8209, + 9233, + 8209, + 9233, + 16401, + 1041, + 8209, + 9233, + 6161, + 2065, + 9233, + 16401, + 16358, + 10806, + 1041, + 8209, + 16401, + 3914, + 16401, + 16401, + 3089, + 9233, + 4642, + 478, + 8209, + 3140, + 1752, + 9233, + 8209, + 8209, + 16401, + 16064, + 14676, + 13288, + 2065, + 16401, + 1041, + 8209, + 16401, + 1041, + 6374, + 1007, +] + +server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +server_context.load_cert_chain(CERTFILE) +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(("127.0.0.1", 0)) +listener.listen(1) +addr, port = listener.getsockname() +server_errors = [] +finished = False + + +def guard_timeout(): + time.sleep(20) + if not finished: + print( + "stdlib_urllib_https_misaligned_recv.py timed out", + file=sys.stderr, + flush=True, + ) + os.abort() + + +threading.Thread(target=guard_timeout, daemon=True).start() + + +def drain_outgoing(outgoing, conn): + while True: + try: + data = outgoing.read() + except ssl.SSLWantReadError: + return + if not data: + return + conn.sendall(data) + + +def run_server(): + try: + conn, _ = listener.accept() + conn.settimeout(5.0) + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + tls = server_context.wrap_bio(incoming, outgoing, 
server_side=True) + + while True: + try: + tls.do_handshake() + break + except ssl.SSLWantReadError: + drain_outgoing(outgoing, conn) + incoming.write(conn.recv(65536)) + except ssl.SSLWantWriteError: + pass + drain_outgoing(outgoing, conn) + + request = b"" + while b"\r\n\r\n" not in request: + try: + request += tls.read(65536) + except ssl.SSLWantReadError: + drain_outgoing(outgoing, conn) + incoming.write(conn.recv(65536)) + drain_outgoing(outgoing, conn) + + response = ( + b"HTTP/1.1 200 OK\r\n" + b"Connection: close\r\n" + + b"Content-Length: " + + str(len(BODY)).encode() + + b"\r\n" + + b"Content-Type: application/json\r\n" + + b"\r\n" + + BODY + ) + plaintext_sizes = [max(1, n - 17) for n in TLS_RECORD_BODY_SIZES] + pos = 0 + while pos < len(response): + size = plaintext_sizes.pop(0) if plaintext_sizes else 16384 + end = min(len(response), pos + size) + while pos < end: + try: + pos += tls.write(response[pos:end]) + except ssl.SSLWantWriteError: + pass + drain_outgoing(outgoing, conn) + conn.close() + except BaseException as exc: + server_errors.append(exc) + finally: + listener.close() + + +thread = threading.Thread(target=run_server) +thread.start() + +client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +client_context.check_hostname = False +client_context.verify_mode = ssl.CERT_NONE +opener = urllib.request.build_opener( + urllib.request.ProxyHandler({}), + urllib.request.HTTPSHandler(context=client_context), +) +try: + with opener.open(f"https://{addr}:{port}/", timeout=5.0) as response: + body = response.read() + + thread.join(10.0) + assert not thread.is_alive(), "server thread did not stop" + assert not server_errors, server_errors + assert body == BODY +finally: + finished = True diff --git a/extra_tests/snippets/test_threading.py b/extra_tests/snippets/test_threading.py deleted file mode 100644 index 4d7c29f5093..00000000000 --- a/extra_tests/snippets/test_threading.py +++ /dev/null @@ -1,24 +0,0 @@ -import threading -import time - -output = 
[] - - -def thread_function(name): - output.append((name, 0)) - time.sleep(2.0) - output.append((name, 1)) - - -output.append((0, 0)) -x = threading.Thread(target=thread_function, args=(1,)) -output.append((0, 1)) -x.start() -output.append((0, 2)) -x.join() -output.append((0, 3)) - -assert len(output) == 6, output -# CPython has [(1, 0), (0, 2)] for the middle 2, but we have [(0, 2), (1, 0)] -# TODO: maybe fix this, if it turns out to be a problem? -# assert output == [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (0, 3)] diff --git a/host_env_proposal.md b/host_env_proposal.md new file mode 100644 index 00000000000..fd32df1abc9 --- /dev/null +++ b/host_env_proposal.md @@ -0,0 +1,494 @@ +# Plan: Create `rustpython-host_env` crate + +## Context + +RustPython controls host OS access via the `host_env` feature flag, enforced by `#[cfg(feature = "host_env")]` scattered across hundreds of locations. If a `cfg` is forgotten, host code leaks into sandbox builds silently. + +By isolating host OS API wrappers into a dedicated crate, **the crate boundary itself becomes the sandbox guarantee**. Key constraint: this crate has **zero Python runtime dependency**. All Python-level bindings must be added by the consumer (vm/stdlib). + +## Current State + +### Already Python-free host abstractions in `crates/common/src/`: +- `os.rs` — errno handling, exit_code, winerror_to_errno, OsStr ffi conversions +- `crt_fd.rs` — CRT file descriptor abstraction (Owned/Borrowed types, open/read/write/close) +- `fileutils.rs` — fstat, fopen, Windows StatStruct +- `windows.rs` — ToWideString, FromWideString traits +- `macros.rs` — `suppress_iph!` macro (MSVC invalid parameter handler suppression) + +### Pure host functions embedded in vm/stdlib modules: + +These files mix Python bindings with pure host API calls. 
The host parts should be extracted: + +**`vm/src/stdlib/posix.rs`** (2908 lines): +- `set_inheritable(fd, inheritable)` — pure nix fcntl wrapper +- `getgroups_impl()` — pure libc/nix wrapper +- `get_right_permission()`, `get_permissions()` — pure permission logic +- 400+ libc constant re-exports (`#[pyattr] use libc::*`) + +**`vm/src/stdlib/nt.rs`** (2301 lines): +- `win32_hchmod()`, `win32_lchmod()`, `fchmod_impl()` — pure Windows API calls (currently return PyResult, should return io::Result) +- Spawn mode constants, `O_*` flags + +**`vm/src/stdlib/_signal.rs`** (729 lines): +- `timeval_to_double()`, `double_to_timeval()`, `itimerval_to_tuple()` — pure math +- 30+ signal/timer constants + +**`vm/src/stdlib/time.rs`** (1616 lines): +- `asctime_from_tm()` — pure string formatting +- `get_tz_info()` — pure Windows API +- Time unit constants (`SEC_TO_MS`, `MS_TO_US`, etc.) +- `duration_since_system_now()` — host clock access (currently takes vm, can return io::Result instead) + +**`vm/src/stdlib/msvcrt.rs`**: +- `getch()`, `getwch()`, `getche()`, `getwche()`, `kbhit()`, `setmode_binary()` — all pure host +- Locking constants (`LK_UNLCK`, `LK_LOCK`, etc.) + +**`vm/src/stdlib/_winapi.rs`** (2180 lines): +- `GetACP()`, `GetCurrentProcess()`, `GetLastError()`, `GetVersion()` — pure host +- 100+ Windows API constants + +**`vm/src/stdlib/os.rs`** (2395 lines): +- `fs_metadata()` — pure `std::fs` wrapper +- libc flag constants (`O_APPEND`, `O_CREAT`, etc.) + +## Dependency Graph (After) + +``` +rustpython-host_env (NEW — zero Python dep, independent of common) +├── Dependencies: libc, nix (unix), windows-sys (win), widestring (win), rustpython-wtf8 +├── From common: os, crt_fd, fileutils, windows, macros +└── Extracted from vm/stdlib: posix, nt, signal, time, msvcrt, winapi, socket, mmap, ... + +rustpython-common (NO host_env dependency — pure algorithmic code only) +└── cformat, float_ops, hash, int, str, encodings, etc. 
+ +rustpython-vm +├── rustpython-common +├── rustpython-host_env (optional, feature = "host_env") +├── libc (retained for type definitions & constants used inline in #[pyattr]) +└── Python bindings call host_env for actual OS operations + +rustpython-stdlib +├── rustpython-vm, rustpython-common +├── rustpython-host_env (optional, feature = "host_env") +└── libc, nix, socket2, memmap2 (retained for now — future migration target) +``` + +`common` and `host_env` are fully independent — no dependency in either direction. + +## Phase 1: Create the crate and move modules from common + +Create `crates/host_env/`, **move** host modules from common, and update common to re-export. + +### New files: + +**`crates/host_env/Cargo.toml`:** +```toml +[package] +name = "rustpython-host_env" +description = "Host OS API abstractions for RustPython (zero Python dependency)" +version.workspace = true +edition.workspace = true + +[dependencies] +rustpython-wtf8 = { workspace = true } +libc = { workspace = true } +num-traits = { workspace = true } +cfg-if = { workspace = true } + +[target.'cfg(unix)'.dependencies] +nix = { workspace = true } + +[target.'cfg(windows)'.dependencies] +widestring = { workspace = true } +windows-sys = { workspace = true, features = [ + "Win32_Foundation", + "Win32_Globalization", + "Win32_Networking_WinSock", + "Win32_Storage_FileSystem", + "Win32_System_Console", + "Win32_System_Ioctl", + "Win32_System_LibraryLoader", + "Win32_System_SystemServices", + "Win32_System_Time", +] } +``` + +**`crates/host_env/src/lib.rs`:** +```rust +#[macro_use] +mod macros; +pub use macros::*; + +pub mod os; + +#[cfg(any(unix, windows, target_os = "wasi"))] +pub mod crt_fd; + +#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] +pub mod fileutils; + +#[cfg(windows)] +pub mod windows; + +// New modules — extracted from vm/stdlib (Phase 2) +#[cfg(unix)] +pub mod posix; +#[cfg(windows)] +pub mod nt; +pub mod signal; +pub mod time; +#[cfg(windows)] +pub mod msvcrt; 
+#[cfg(windows)] +pub mod winapi; +``` + +**Modules moved from common**: `os.rs`, `crt_fd.rs`, `fileutils.rs`, `windows.rs`, `macros.rs` + +### Modified files: + +**`Cargo.toml` (workspace root):** +- Add `"crates/host_env"` to `[workspace.members]` +- Add `rustpython-host_env = { path = "crates/host_env" }` to `[workspace.dependencies]` + +**`crates/common/Cargo.toml`:** +- Remove `nix`, `windows-sys`, `widestring` from direct dependencies +- Keep `libc` for type definitions (`wchar_t` in `str.rs`) +- No `host_env` feature or dependency — common stays purely algorithmic + +**`crates/common/src/lib.rs`:** +- Remove `pub mod os`, `pub mod crt_fd`, `pub mod fileutils`, `pub mod windows` declarations +- Remove `#[macro_use] mod macros` and `suppress_iph!` macro (moved to host_env) +- Delete the source files: `os.rs`, `crt_fd.rs`, `fileutils.rs`, `windows.rs`, `macros.rs` + +**`crates/vm/Cargo.toml`:** +```toml +[features] +host_env = ["rustpython-host_env"] + +[dependencies] +rustpython-host_env = { workspace = true, optional = true } +``` + +**`crates/stdlib/Cargo.toml`:** +```toml +[features] +host_env = ["rustpython-vm/host_env", "rustpython-host_env"] + +[dependencies] +rustpython-host_env = { workspace = true, optional = true } +``` + +### Verification: +```bash +cargo check -p rustpython-host_env +cargo test +cargo check -p rustpython-vm --no-default-features --features compiler,gc # sandbox build +``` + +## Phase 2: Extract host functions from vm/stdlib modules + +Extract pure host API functions and constants from vm's stdlib modules into new modules within `host_env`. 
+ +### New modules in `crates/host_env/src/`: + +**`posix.rs`** — extracted from `vm/src/stdlib/posix.rs`: +```rust +use std::os::fd::BorrowedFd; + +pub fn set_inheritable(fd: BorrowedFd<'_>, inheritable: bool) -> nix::Result<()> { + use nix::fcntl; + let flags = fcntl::FdFlag::from_bits_truncate(fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFD)?); + let mut new_flags = flags; + new_flags.set(fcntl::FdFlag::FD_CLOEXEC, !inheritable); + if flags != new_flags { + fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFD(new_flags))?; + } + Ok(()) +} + +pub fn getgroups() -> nix::Result> { ... } +pub fn get_right_permission(mode: u32, file_owner: Uid, file_group: Gid) -> nix::Result { ... } +``` + +**`nt.rs`** — extracted from `vm/src/stdlib/nt.rs`: +```rust +pub fn win32_hchmod(handle: HANDLE, mode: u32) -> io::Result<()> { ... } +pub fn win32_lchmod(path: &OsStr, mode: u32) -> io::Result<()> { ... } +``` + +**`signal.rs`** — extracted from `vm/src/stdlib/_signal.rs`: +```rust +pub fn timeval_to_double(tv: &libc::timeval) -> f64 { ... } +pub fn double_to_timeval(val: f64) -> libc::timeval { ... } +pub fn itimerval_to_tuple(it: &libc::itimerval) -> (f64, f64) { ... } +``` + +**`time.rs`** — extracted from `vm/src/stdlib/time.rs`: +```rust +pub const SEC_TO_MS: i64 = 1000; +pub const MS_TO_US: i64 = 1000; +// ... + +pub fn asctime_from_tm(tm: &libc::tm) -> String { ... } +pub fn duration_since_system_now() -> io::Result { ... } + +#[cfg(windows)] +pub fn get_tz_info() -> TIME_ZONE_INFORMATION { ... } +``` + +**`msvcrt.rs`** — extracted from `vm/src/stdlib/msvcrt.rs`: +```rust +pub fn getch() -> Vec { ... } +pub fn getwch() -> String { ... } +pub fn kbhit() -> i32 { ... } +pub fn setmode_binary(fd: crt_fd::Borrowed<'_>) { ... } + +pub const LK_UNLCK: i32 = 0; +pub const LK_LOCK: i32 = 1; +// ... +``` + +**`winapi.rs`** — extracted from `vm/src/stdlib/_winapi.rs`: +```rust +pub fn get_acp() -> u32 { ... } +pub fn get_current_process() -> HANDLE { ... } +pub fn get_last_error() -> u32 { ... 
} +pub fn get_version() -> u32 { ... } +// + Windows API constants +``` + +### Modified vm/stdlib files: + +Each file is updated to call `rustpython_host_env::` instead of inlining the host calls: + +```rust +// BEFORE (vm/src/stdlib/posix.rs) +pub fn set_inheritable(fd: BorrowedFd<'_>, inheritable: bool) -> nix::Result<()> { + use nix::fcntl; + // ... 10 lines of nix API calls +} + +// AFTER (vm/src/stdlib/posix.rs) +pub use rustpython_host_env::posix::set_inheritable; +``` + +## Phase 3: vm/stdlib import migration + +All `common::os`, `common::crt_fd`, `common::fileutils`, `common::windows` imports must be updated to `rustpython_host_env::`. + +### Import migration targets (vm) — ~20 files: + +| File | Current | New | +|------|---------|-----| +| `ospath.rs` | `rustpython_common::crt_fd` | `rustpython_host_env::crt_fd` | +| `stdlib/os.rs` | `common::crt_fd`, `common::os::*` | `rustpython_host_env::` | +| `stdlib/nt.rs` | `common::windows::*`, `common::crt_fd::*` | `rustpython_host_env::` | +| `stdlib/_io.rs` | `common::crt_fd::Offset`, `common::fileutils::fstat` | `rustpython_host_env::` | +| `stdlib/_signal.rs` | `common::crt_fd::*`, `common::fileutils::fstat` | `rustpython_host_env::` | +| `stdlib/posix.rs` | `common::os::*`, `common::crt_fd::Offset` | `rustpython_host_env::` | +| `stdlib/_ctypes/function.rs` | `rustpython_common::os::get_errno` | `rustpython_host_env::os::` | +| `stdlib/_codecs.rs` | `common::windows::ToWideString` | `rustpython_host_env::windows::` | +| `stdlib/sys.rs`, `winreg.rs`, `winsound.rs` | `common::windows::ToWideString` | `rustpython_host_env::windows::` | +| `windows.rs` | `rustpython_common::windows::ToWideString` | `rustpython_host_env::windows::` | +| `exceptions.rs` | `common::os::ErrorExt`, `common::os::winerror_to_errno` | `rustpython_host_env::os::` | + +### Import migration targets (stdlib) — ~7 files: + +| File | Current | New | +|------|---------|-----| +| `socket.rs` | `common::os::ErrorExt`, `common::os::errno_io_error` 
| `rustpython_host_env::os::` | +| `mmap.rs` | `rustpython_common::crt_fd` | `rustpython_host_env::crt_fd` | +| `faulthandler.rs` | `rustpython_common::os::{get_errno, set_errno}` | `rustpython_host_env::os::` | +| `posixshmem.rs` | `common::os::errno_io_error` | `rustpython_host_env::os::` | +| `termios.rs` | `common::os::ErrorExt` | `rustpython_host_env::os::` | +| `overlapped.rs` | `crate::vm::common::os::winerror_to_errno` | `rustpython_host_env::os::` | +| `openssl.rs` | `rustpython_common::fileutils::fopen` | `rustpython_host_env::fileutils::` | + +### External consumers: + +| File | Current | New | +|------|---------|-----| +| `src/lib.rs` | `rustpython_vm::common::os::exit_code` | `rustpython_host_env::os::exit_code` | +| `examples/*.rs` | `vm::common::os::exit_code` | Keep via re-export | + +## Phase 4 (Future): Extract host functions from stdlib modules + +Same pattern as Phase 2, but for `crates/stdlib/src/` modules. These modules heavily use `libc`, `nix`, `socket2`, `memmap2` directly. Extract the pure host layer into `host_env`. + +**Target modules and what goes into host_env:** + +| stdlib module | host_env module | What to extract | +|---------------|----------------|-----------------| +| `socket.rs` (3498 lines) | `host_env::socket` | Socket creation, bind, connect, address conversion, cmsg helpers, poll wrappers. Re-export `socket2` types. | +| `mmap.rs` (1625 lines) | `host_env::mmap` | mmap/munmap wrappers, madvise, msync. Re-export `memmap2` types. | +| `select.rs` (745 lines) | `host_env::select` | select/poll/epoll/kqueue wrappers via libc/nix. | +| `posixsubprocess.rs` (537 lines) | `host_env::subprocess` | fork_exec, pipe, dup2, close-on-exec logic. | +| `multiprocessing.rs` (1152 lines) | `host_env::multiprocessing` | Semaphore operations (sem_open/wait/post/unlink via libc). | +| `fcntl.rs` (220 lines) | `host_env::fcntl` | fcntl, ioctl, flock wrappers. 
| +| `faulthandler.rs` (1333 lines) | `host_env::faulthandler` | Signal handler registration, stack dump via libc write. | +| `locale.rs` (332 lines) | `host_env::locale` | strcoll, strxfrm, setlocale wrappers. | +| `resource.rs` (194 lines) | `host_env::resource` | getrusage, getrlimit, setrlimit wrappers. | +| `grp.rs` (103 lines) | `host_env::grp` | getgrent/setgrent/endgrent, Group lookup via nix. | +| `syslog.rs` (148 lines) | `host_env::syslog` | openlog, syslog, closelog, setlogmask wrappers. | +| `posixshmem.rs` (52 lines) | `host_env::shm` | shm_open, shm_unlink wrappers. | +| `termios.rs` (280 lines) | `host_env::termios` | Terminal attribute get/set via termios crate. | + +After this, `nix`, `socket2`, `memmap2`, `rustix` are removed from stdlib's direct dependencies. Only `host_env` provides them. + +## Phase 5: Lint enforcement + +Five layers of enforcement, from strongest to lightest: + +### Layer 1: Crate boundary (compile-time, absolute) + +The strongest guarantee. If a crate doesn't list `rustpython-host_env` in its `[dependencies]`, it physically cannot call any host_env function. This is already enforced by Rust's module system. + +**Pure crates (no host_env dependency allowed):** +- `rustpython-common` +- `rustpython-compiler`, `rustpython-compiler-core`, `rustpython-compiler-source` +- `rustpython-codegen` +- `rustpython-literal` +- `rustpython-sre_engine` +- `rustpython-wtf8` +- `rustpython-derive`, `rustpython-derive-impl` + +CI check: +```bash +# Verify pure crates don't depend on host_env +for crate in common compiler compiler-core compiler-source codegen literal sre_engine wtf8 derive derive-impl; do + if rg 'rustpython-host_env' "crates/$crate/Cargo.toml"; then + echo "ERROR: $crate should not depend on host_env" + exit 1 + fi +done +``` + +### Layer 2: clippy disallowed_methods (compile-time, configurable) + +Block direct host API usage in vm/stdlib. Force all host access through `host_env`. 
+ +**Workspace-level `clippy.toml`** (project root): +```toml +disallowed-methods = [ + # Filesystem + { path = "std::fs::read", reason = "use rustpython_host_env for host filesystem access" }, + { path = "std::fs::write", reason = "use rustpython_host_env" }, + { path = "std::fs::read_to_string", reason = "use rustpython_host_env" }, + { path = "std::fs::read_dir", reason = "use rustpython_host_env" }, + { path = "std::fs::create_dir", reason = "use rustpython_host_env" }, + { path = "std::fs::create_dir_all", reason = "use rustpython_host_env" }, + { path = "std::fs::remove_file", reason = "use rustpython_host_env" }, + { path = "std::fs::remove_dir", reason = "use rustpython_host_env" }, + { path = "std::fs::metadata", reason = "use rustpython_host_env" }, + { path = "std::fs::symlink_metadata", reason = "use rustpython_host_env" }, + { path = "std::fs::canonicalize", reason = "use rustpython_host_env" }, + { path = "std::fs::File::open", reason = "use rustpython_host_env" }, + { path = "std::fs::File::create", reason = "use rustpython_host_env" }, + { path = "std::fs::OpenOptions::open", reason = "use rustpython_host_env" }, + + # Environment + { path = "std::env::var", reason = "use rustpython_host_env" }, + { path = "std::env::var_os", reason = "use rustpython_host_env" }, + { path = "std::env::set_var", reason = "use rustpython_host_env" }, + { path = "std::env::remove_var", reason = "use rustpython_host_env" }, + { path = "std::env::vars", reason = "use rustpython_host_env" }, + { path = "std::env::vars_os", reason = "use rustpython_host_env" }, + { path = "std::env::current_dir", reason = "use rustpython_host_env" }, + { path = "std::env::set_current_dir", reason = "use rustpython_host_env" }, + { path = "std::env::temp_dir", reason = "use rustpython_host_env" }, + + # Process + { path = "std::process::Command::new", reason = "use rustpython_host_env" }, + { path = "std::process::exit", reason = "use rustpython_host_env" }, + { path = 
"std::process::abort", reason = "use rustpython_host_env" }, + { path = "std::process::id", reason = "use rustpython_host_env" }, + + # Network + { path = "std::net::TcpStream::connect", reason = "use rustpython_host_env" }, + { path = "std::net::TcpListener::bind", reason = "use rustpython_host_env" }, + { path = "std::net::UdpSocket::bind", reason = "use rustpython_host_env" }, +] +``` + +**`crates/host_env/clippy.toml`** (overrides — host_env is allowed to use everything): +```toml +disallowed-methods = [] +``` + +Clippy resolves `clippy.toml` by walking up from the crate directory, so `host_env`'s local config takes precedence over the workspace root. + +**Workspace `Cargo.toml`:** +```toml +[workspace.lints.clippy] +disallowed_methods = "deny" +``` + +### Layer 3: Sandbox build verification (CI) + +Build without `host_env` feature to catch any code that accidentally compiles without the feature gate: + +```bash +cargo check -p rustpython-vm --no-default-features --features compiler,gc +cargo check -p rustpython-stdlib --no-default-features --features compiler +``` + +### Layer 4: Whitelist-based module audit (CI script) + +Maintain a whitelist of modules in vm/stdlib that are known to NOT use host_env. Any change that adds a `rustpython_host_env` import to a whitelisted module triggers CI failure. 
+ +```bash +# .ci/host_env_whitelist.txt — modules that must stay host-free +# vm modules: +crates/vm/src/stdlib/_abc.rs +crates/vm/src/stdlib/_collections.rs +crates/vm/src/stdlib/_functools.rs +crates/vm/src/stdlib/_operator.rs +crates/vm/src/stdlib/_sre.rs +crates/vm/src/stdlib/_stat.rs +crates/vm/src/stdlib/_string.rs +crates/vm/src/stdlib/errno.rs +crates/vm/src/stdlib/gc.rs +crates/vm/src/stdlib/itertools.rs +crates/vm/src/stdlib/marshal.rs + +# Check: +while IFS= read -r file; do + if rg 'rustpython_host_env' "$file" 2>/dev/null; then + echo "ERROR: $file is whitelisted as host-free but imports host_env" + exit 1 + fi +done < .ci/host_env_whitelist.txt +``` + +The inverse is also useful — list all files that ARE allowed to use host_env, and reject any new file that uses it without being on the list. This catches accidental host API usage in new modules. + +### Layer 5: `#![no_std]` for pure crates + +After removing host modules from `common`, it could potentially become `#![no_std]` unconditionally (it already has `#![cfg_attr(not(feature = "std"), no_std)]`). This is the strongest possible guarantee — no `std::fs`, `std::env`, `std::net`, `std::process` available at all. 
+ +Candidate crates for unconditional `#![no_std]`: +- `rustpython-literal` +- `rustpython-wtf8` +- `rustpython-compiler-source` + +### Summary of enforcement layers + +| Layer | What it catches | Strength | Cost | +|-------|----------------|----------|------| +| Crate boundary | Missing host_env dependency | Absolute — compile error | Zero — automatic | +| clippy disallowed_methods | Direct std::fs/env/net usage | Strong — clippy deny | Low — clippy.toml config | +| Sandbox build | Missing `#[cfg(feature = "host_env")]` | Strong — compile error | Low — CI job | +| Module whitelist | Unintended host_env usage in pure modules | Medium — CI script | Low — maintain whitelist | +| `#![no_std]` | Any std usage in pure crates | Absolute — compile error | Medium — may need refactoring | + +## Risk Assessment + +| Risk | Level | Mitigation | +|------|-------|------------| +| Target modules have Python type dependencies | **Low** | Verified: only `libc`, `nix`, `windows-sys`, `rustpython-wtf8` | +| Internal cross-references break on move | **Low** | `crt_fd`, `os`, `fileutils`, `windows` all move together; `crate::` paths stay valid | +| `suppress_iph!` macro `$crate` resolution | **Medium** | `$crate` automatically resolves to new crate; `__macro_private` moves alongside | +| Breaking external consumers | **Medium** | Clean break — consumers must update `common::os` to `host_env::os`. No re-export shim. | +| Scope of Phase 2 extraction | **Medium** | Start with clearly pure functions; mixed functions can be migrated incrementally | diff --git a/scripts/generate_opcode_metadata.py b/scripts/generate_opcode_metadata.py deleted file mode 100644 index 5b9a0b10be9..00000000000 --- a/scripts/generate_opcode_metadata.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Generate Lib/_opcode_metadata.py for RustPython bytecode. - -This file generates opcode metadata that is compatible with CPython 3.13. 
-""" - -import itertools -import pathlib -import re -import typing - -ROOT = pathlib.Path(__file__).parents[1] -BYTECODE_FILE = ( - ROOT / "crates" / "compiler-core" / "src" / "bytecode" / "instructions.rs" -) -OPCODE_METADATA_FILE = ROOT / "Lib" / "_opcode_metadata.py" - - -# Opcodes that needs to be first, regardless of their opcode ID. -PRIORITY_OPMAP = { - "CACHE", - "RESERVED", - "RESUME", - "INSTRUMENTED_LINE", - "ENTER_EXECUTOR", -} - - -def to_snake_case(s: str) -> str: - res = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", s) - return re.sub(r"(\D)(\d+)$", r"\1_\2", res).upper() - - -class Opcode(typing.NamedTuple): - rust_name: str - id: int - have_oparg: bool - - @property - def is_instrumented(self): - return self.cpython_name.startswith("INSTRUMENTED_") - - @property - def cpython_name(self): - return to_snake_case(self.rust_name) - - @classmethod - def from_str(cls, text: str): - # Split on commas that are followed by a newline + an uppercase letter (new entry) - entries = re.split(r",\s*\n\s*(?=[A-Z])", text) - for entry in entries: - entry = entry.strip() - if not entry: - continue - have_oparg = "Arg<" in entry # Hacky but works - rust_name = re.match(r"(\w+)", entry).group(1) - id_num = re.findall(r"= (\d+)", entry)[0] - yield cls(rust_name=rust_name, id=int(id_num), have_oparg=have_oparg) - - def __lt__(self, other: typing.Self) -> bool: - sprio, oprio = ( - opcode.cpython_name not in PRIORITY_OPMAP for opcode in (self, other) - ) - return (sprio, self.id) < (oprio, other.id) - - -def extract_enum_body(text: str, name: str) -> str: - # Find the start of the enum block - start_match = re.search(rf"enum\s+{name}\s*\{{", text) - if not start_match: - return None - - # Manually track brace depth from that point - depth = 0 - start = start_match.end() - 1 # position of opening '{' - for i, ch in enumerate(text[start:], start): - if ch == "{": - depth += 1 - elif ch == "}": - depth -= 1 - if depth == 0: - # Return only the inner content (excluding outer 
braces) - return text[start + 1 : i] - - -def build_deopts(text: str) -> dict[str, list[str]]: - raw_body = re.search(r"fn deopt\(self\)(.*)", text, re.DOTALL).group(1) - match_start = raw_body.find("match self") - if match_start == -1: - raise ValueError("Could not detect a match statement in deopt method") - - brace_depth = 0 - block_start = None - block_end = None - - for i, ch in enumerate(raw_body[match_start:], match_start): - if ch == "{": - brace_depth += 1 - if block_start is None: - block_start = i + 1 - elif ch == "}": - brace_depth -= 1 - if brace_depth == 0: - block_end = i - break - - match_body = raw_body[block_start:block_end] - - arm_pattern = re.compile( - r"((?:Self::\w+\s*\|\s*)*Self::\w+)\s*=>\s*(?:\{\s*)?Self::(\w+)", re.DOTALL - ) - variants_pattern = re.compile(r"Self::(\w+)") - - deopts = {} - for hit in arm_pattern.finditer(match_body): - raw_variants = hit.group(1) - opcode = hit.group(2) - - variants = variants_pattern.findall(raw_variants) - - key = to_snake_case(opcode) - value = [to_snake_case(variant) for variant in variants] - deopts[key] = value - - return deopts - - -contents = BYTECODE_FILE.read_text(encoding="utf-8") - -deopts = build_deopts(contents) - -enum_body = "\n".join( - extract_enum_body(contents, enum_name) - for enum_name in ("Instruction", "PseudoInstruction") -) -opcodes = list(Opcode.from_str(enum_body)) - -have_oparg = min(opcode.id for opcode in opcodes if opcode.have_oparg) - 1 -min_instrumented = min(opcode.id for opcode in opcodes if opcode.is_instrumented) - -# Generate the output file -output = """# This file is generated by scripts/generate_opcode_metadata.py -# for RustPython bytecode format (CPython 3.14 compatible opcode numbers). -# Do not edit! 
-""" - -output += "\n_specializations = {\n" - -for key, lst in deopts.items(): - output += f' "{key}": [\n' - for item in lst: - output += f' "{item}",\n' - output += " ],\n" - -output += "}\n" - -specialized = set(itertools.chain.from_iterable(deopts.values())) -output += "\n_specialized_opmap = {\n" -for opcode in sorted(opcodes, key=lambda op: op.cpython_name): - cpython_name = opcode.cpython_name - if cpython_name not in specialized: - continue - - output += f" '{cpython_name}': {opcode.id},\n" - -output += "}\n" - -output += "\nopmap = {\n" - -for opcode in sorted(opcodes): - cpython_name = opcode.cpython_name - if cpython_name in specialized: - continue - - output += f" '{cpython_name}': {opcode.id},\n" - -output += "}\n" - -output += f""" -HAVE_ARGUMENT = {have_oparg} -MIN_INSTRUMENTED_OPCODE = {min_instrumented} -""" - -OPCODE_METADATA_FILE.write_text(output, encoding="utf-8") diff --git a/scripts/update_lib/deps.py b/scripts/update_lib/deps.py index 49ddf9a8730..72374aa6be1 100644 --- a/scripts/update_lib/deps.py +++ b/scripts/update_lib/deps.py @@ -320,9 +320,12 @@ def clear_import_graph_caches() -> None: "pickle": { "hard_deps": ["_compat_pickle.py"], "test": [ + "picklecommon.py", "test_pickle.py", "test_picklebuffer.py", "test_pickletools.py", + "test_xpickle.py", + "xpickle_worker.py", ], }, "re": { diff --git a/src/interpreter.rs b/src/interpreter.rs index 9060f3a7bec..230192d1e21 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -14,6 +14,8 @@ impl InterpreterBuilderExt for InterpreterBuilder { fn init_stdlib(self) -> Self { let defs = rustpython_stdlib::stdlib_module_defs(&self.ctx); let builder = self.add_native_modules(&defs); + #[cfg(all(feature = "ssl-rustls-aws-lc", not(target_arch = "wasm32")))] + let builder = builder.init_hook(install_default_tls_provider); cfg_select! 
{ feature = "freeze-stdlib" => { @@ -26,6 +28,20 @@ impl InterpreterBuilderExt for InterpreterBuilder { } } +#[cfg(all(feature = "ssl-rustls-aws-lc", not(target_arch = "wasm32")))] +fn install_default_tls_provider(_vm: &mut crate::VirtualMachine) { + use rustls::crypto::aws_lc_rs; + use rustpython_stdlib::ssl::providers::CryptoExt; + + let ext = CryptoExt { + all_cipher_suites: Some(aws_lc_rs::ALL_CIPHER_SUITES), + all_kx_groups: Some(aws_lc_rs::ALL_KX_GROUPS), + any_supported_key: Some(aws_lc_rs::sign::any_supported_type), + ticketer: aws_lc_rs::Ticketer::new, + }; + let _ = CryptoExt::set_provider(aws_lc_rs::default_provider(), ext); +} + /// Set stdlib_dir for frozen standard library #[cfg(all(feature = "stdlib", feature = "freeze-stdlib"))] fn set_frozen_stdlib_dir(vm: &mut crate::VirtualMachine) { diff --git a/src/lib.rs b/src/lib.rs index a8384244cfa..14deb2972d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ pub use shell::run_shell; not(any(feature = "ssl-rustls", feature = "ssl-openssl")) ))] compile_error!( - "Feature \"ssl\" is now enabled by either \"ssl-rustls\" or \"ssl-openssl\" to be enabled. Do not manually pass \"ssl\" feature. To enable ssl-openssl, use --no-default-features to disable ssl-rustls" + "Feature \"ssl\" is now enabled by either \"ssl-rustls\" or \"ssl-openssl\". Do not manually pass \"ssl\" feature. To enable ssl-openssl, use --no-default-features to disable ssl-rustls*" ); /// The main cli of the `rustpython` interpreter. This function will return `std::process::ExitCode` diff --git a/src/settings.rs b/src/settings.rs index 5233cf98d49..25398cd8fc4 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -199,7 +199,7 @@ fn help(parser: lexopt::Parser) -> ! { } fn version() -> ! 
{ - println!("Python {}", rustpython_vm::version::get_version()); + println!("Python {}", rustpython_vm::version::RUSTPYTHON_VERSION); std::process::exit(0); } diff --git a/tools/opcode_metadata/conf.toml b/tools/opcode_metadata/conf.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tools/opcode_metadata/cpython.py b/tools/opcode_metadata/cpython.py new file mode 100644 index 00000000000..56743a88842 --- /dev/null +++ b/tools/opcode_metadata/cpython.py @@ -0,0 +1,27 @@ +import os +import pathlib +import sys + +try: + CPYTHON_ROOT = pathlib.Path(os.environ["CPYTHON_ROOT"]).expanduser().resolve() +except KeyError: + raise ValueError("Missing environment variable 'CPYTHON_ROOT'") + +CPYTHON_TOOLS_LIB = CPYTHON_ROOT / "Tools" / "cases_generator" + + +if (path := CPYTHON_TOOLS_LIB.as_posix()) not in sys.path: + sys.path.append(path) + + +from analyzer import SKIP_PROPERTIES, Analysis, Family, Properties, analyze_files +from stack import get_stack_effect + + +def get_analysis() -> Analysis: + from generators_common import DEFAULT_INPUT + + analysis = analyze_files([DEFAULT_INPUT]) + # Our separation is done at the enum definition + analysis.instructions |= analysis.pseudos + return analysis diff --git a/tools/opcode_metadata/generate_py_opcode_metadata.py b/tools/opcode_metadata/generate_py_opcode_metadata.py new file mode 100644 index 00000000000..08f350e4b39 --- /dev/null +++ b/tools/opcode_metadata/generate_py_opcode_metadata.py @@ -0,0 +1,107 @@ +""" +Generate Lib/_opcode_metadata.py for RustPython bytecode. + +This file generates opcode metadata that is compatible with CPython 3.14. +""" + +import functools +import io +import itertools +import operator +import pathlib +import typing + +from opcodes import OpcodeInfo +from utils import DEFAULT_INPUT, ROOT, get_conf, to_pascal_case, to_upper_snake_case + +OUT_FILE = ROOT / "Lib/_opcode_metadata.py" + + +# Opcodes that need to be first, regardless of their opcode ID. 
+PRIORITY_OPMAP = { + "CACHE", + "RESERVED", + "RESUME", + "INSTRUMENTED_LINE", + "ENTER_EXECUTOR", +} + +INDENT = " " * 4 +INDENT2 = INDENT * 2 + + +def main(): + override_conf = get_conf() + inp = DEFAULT_INPUT.read_text() + + infos = tuple(OpcodeInfo.iter_infos(inp, override_conf)) + opcodes = tuple(itertools.chain.from_iterable(info.opcodes for info in infos)) + + script_path = pathlib.Path(__file__).resolve().relative_to(ROOT).as_posix() + out = io.StringIO() + + out.write( + f""" +# This file is generated by {script_path} +# for RustPython bytecode format (CPython 3.14 compatible opcode numbers). +# Do not edit! +""".lstrip() + ) + + # _specializations + out.write("\n") + out.write("_specializations = {\n") + + deopts = functools.reduce(operator.ior, map(operator.attrgetter("deopts"), infos)) + + for key, lst in deopts.items(): + key = to_upper_snake_case(key) + out.write(f'{INDENT}"{key}": [\n') + + for item in map(to_upper_snake_case, lst): + out.write(f'{INDENT2}"{item}",\n') + + out.write(f"{INDENT}],\n") + + out.write("}\n") + + # _specialized_opmap + out.write("\n") + out.write("_specialized_opmap = {\n") + + specialized = set(itertools.chain.from_iterable(deopts.values())) + for opcode in sorted(opcodes, key=lambda op: op.cpython_name): + if opcode.rust_name not in specialized: + continue + out.write(f"{INDENT}'{opcode.cpython_name}': {opcode.id},\n") + + out.write("}\n") + + # opmap + out.write("\n") + out.write("opmap = {\n") + + key = lambda op: (op.cpython_name not in PRIORITY_OPMAP, op.id) + for opcode in sorted(opcodes, key=key): + if opcode.rust_name in specialized: + continue + + out.write(f"{INDENT}'{opcode.cpython_name}': {opcode.id},\n") + + out.write("}\n") + + # min + out.write("\n") + have_argument = min(opcode.id for opcode in opcodes if opcode.have_argument) - 1 + out.write(f"HAVE_ARGUMENT = {have_argument}\n") + + min_instrumented = min(opcode.id for opcode in opcodes if opcode.is_instrumented) + out.write(f"MIN_INSTRUMENTED_OPCODE = 
{min_instrumented}\n") + + # write output + generated = out.getvalue() + OUT_FILE.write_text(generated) + + +if __name__ == "__main__": + main() diff --git a/tools/opcode_metadata/generate_rs_opcode_metadata.py b/tools/opcode_metadata/generate_rs_opcode_metadata.py new file mode 100644 index 00000000000..df2476c5e08 --- /dev/null +++ b/tools/opcode_metadata/generate_rs_opcode_metadata.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python +from __future__ import annotations + +import collections +import dataclasses +import io +import os +import pathlib +import subprocess +import sys +import typing + +import tomllib +from cpython import Analysis, get_analysis, get_stack_effect +from opcodes import OpcodeInfo +from utils import DEFAULT_INPUT, ROOT, get_conf, to_pascal_case + +OUT_FILE = ROOT / "crates/compiler-core/src/bytecode/opcode_metadata.rs" + + +@dataclasses.dataclass(frozen=True, slots=True) +class OpcodeGen: + info: OpcodeDef + + @property + def fn_as_info_size(self) -> str: + return f""" + /// Returns [`Self`] as [`{self.size}`]. 
+ #[must_use] + pub const fn as_{self.size}(self) -> {self.size} {{ + self.as_numeric() + }} + """ + + @property + def fn_try_from_numeric(self) -> str: + return f""" + pub const fn try_from_{self.size}( + value: {self.size}, + ) -> Result {{ + Self::try_from_numeric(value) + }} + """ + + @property + def fn_has_arg(self) -> str: + return self.gen_fn_has_attr("has_arg", "oparg", "HAS_ARG_FLAG") + + @property + def fn_has_const(self) -> str: + return self.gen_fn_has_attr("has_const", "uses_co_consts", "HAS_CONST_FLAG") + + @property + def fn_has_name(self) -> str: + return self.gen_fn_has_attr("has_name", "uses_co_names", "HAS_NAME_FLAG") + + @property + def fn_has_jump(self) -> str: + return self.gen_fn_has_attr("has_jump", "jumps", "HAS_JUMP_FLAG") + + @property + def fn_has_free(self) -> str: + return self.gen_fn_has_attr("has_free", "has_free", "HAS_FREE_FLAG") + + @property + def fn_has_local(self) -> str: + return self.gen_fn_has_attr("has_local", "uses_locals", "HAS_LOCAL_FLAG") + + @property + def fn_has_eval_break(self) -> str: + return self.gen_fn_has_attr( + "has_eval_break", "eval_breaker", "HAS_EVAL_BREAK_FLAG" + ) + + @property + def fn_is_instrumented(self) -> str: + arms = "|".join( + f"Self::{opcode.rust_name}" for opcode in self if opcode.is_instrumented + ) + + arms = arms.strip() + if arms: + inner = f"matches!(self, {arms})" + else: + inner = "false" + + return f""" + #[must_use] + pub const fn is_instrumented(self) -> bool {{ + {inner} + }} + """ + + @property + def fn_to_base(self) -> str: + arms = ",\n".join( + f"Self::{iname} => Self::{name}" + for name, iname in self.instrumented_mapping.items() + ) + + arms = arms.strip() + if not arms: + inner = "None" + else: + inner = f""" + Some(match self {{ + {arms}, + _ => return None, + + }}) + """ + + return f""" + #[must_use] + pub const fn to_base(self) -> Option {{ + {inner} + }} + """ + + @property + def fn_to_instrumented(self) -> str: + arms = ",\n".join( + f"Self::{name} => Self::{iname}" + 
for name, iname in self.instrumented_mapping.items() + ) + + arms = arms.strip() + if not arms: + inner = "None" + else: + inner = f""" + Some(match self {{ + {arms}, + _ => return None, + + }}) + """ + + return f""" + #[must_use] + pub const fn to_instrumented(self) -> Option {{ + {inner} + }} + """ + + @property + def fn_deopt(self) -> str: + arms = "" + for target, specialized in self.info.deopts.items(): + ops = "|".join(f"Self::{op}" for op in specialized) + arms += f"{ops} => Self::{target},\n" + + arms = arms.strip() + + if not arms: + inner = "None" + else: + inner = f""" + Some(match self {{ + {arms} + _ => return None, + }}) + """ + + return f""" + #[must_use] + pub const fn deopt(self) -> Option {{ + {inner} + }} + """ + + @property + def fn_cache_entries(self) -> str: + arms = "" + for opcode in self: + name = opcode.rust_name + if opcode.is_instrumented: + continue + if getattr(opcode, "family", None) and (opcode.family.name != name): + continue + + try: + size = opcode.cache_entry + except AttributeError: + continue + + if size > 1: + arms += f"Self::{name} => {size - 1},\n" + + arms = arms.strip() + if not arms: + inner = "0" + else: + inner = f""" + match self.deoptimize() {{ + {arms} + _ => 0, + }} + """ + + return f""" + #[must_use] + pub const fn cache_entries(self) -> usize {{ + {inner} + }} + """ + + @property + def fn_stack_effect_info(self) -> str: + oparg_used = False + arms = "" + for opcode in self: + name = opcode.rust_name + + popped = opcode.stack_effect_popped + pushed = opcode.stack_effect_pushed + + pushed_comment = "" + popped_comment = "" + + if popped != opcode.cpy_popped: + popped_comment = f"// TODO: Differs from CPython `{opcode.cpy_popped}`" + + if pushed != opcode.cpy_pushed: + pushed_comment = f"// TODO: Differs from CPython `{opcode.cpy_pushed}`" + + oparg_used = oparg_used or any("oparg" in expr for expr in (pushed, popped)) + + arms += f""" + Self::{name} => ( + {pushed}, {pushed_comment} + {popped}, {popped_comment} + ), 
+ """.strip() + + arms = arms.strip() + + oparg_arg = "_oparg" + oparg_cast = "" + if oparg_used: + oparg_arg = "oparg" + oparg_cast = f""" + // Reason for converting {oparg_arg} to i32 is because of expressions like `1 + (oparg -1)` + // that causes underflow errors. + let oparg = i32::try_from({oparg_arg}).expect("{oparg_arg} does not fit in an `i32`"); + """ + + return f""" + #[must_use] + pub fn stack_effect_info(&self, {oparg_arg}: u32) -> StackEffect {{ + {oparg_cast} + + let (pushed, popped) = match self {{ + {arms} + }}; + + debug_assert!(u32::try_from(pushed).is_ok()); + debug_assert!(u32::try_from(popped).is_ok()); + + StackEffect::new(pushed as u32, popped as u32) + }} + """ + + def gen(self) -> str: + methods = "\n\n".join( + getattr(self, attr).strip() + for attr in sorted(dir(self)) + if attr.startswith("fn_") + ) + + impls = "\n\n".join( + getattr(self, attr).strip() + for attr in sorted(dir(self)) + if attr.startswith("impl_") + ) + + return f""" + impl super::{self.info.enum_name} {{ + {methods} + }} + + {impls} + """ + + def gen_fn_has_attr(self, fn_name: str, properties_attr: str, doc_flag: str) -> str: + arms = "|".join( + f"Self::{opcode.rust_name}" + for opcode in self + if getattr(opcode.properties, properties_attr) + ) + + if arms: + inner = f"matches!(self, {arms})" + else: + inner = "false" + + return f""" + /// Does this opcode have '{doc_flag}' set. 
+ #[must_use] + pub const fn {fn_name}(self) -> bool {{ + {inner} + }} + """ + + @property + def instrumented_mapping(self) -> dict[str, str]: + names, inames = set(), set() + for opcode in self: + name = opcode.rust_name + if opcode.is_instrumented: + inames.add(name) + else: + names.add(name) + + res = {} + for iname in sorted(inames): + name = iname.removeprefix("Instrumented") + if name not in names: + continue + + res[name] = iname + + return res + + @property + def size(self) -> str: + return self.info.size + + def __iter__(self): + yield from self.info.opcodes + + +def rustfmt(code: str) -> str: + return subprocess.check_output(["rustfmt", "--emit=stdout"], input=code, text=True) + + +def main(): + override_conf = get_conf() + inp = DEFAULT_INPUT.read_text() + opcode_infos = OpcodeInfo.iter_infos(inp, override_conf) + + outfile = io.StringIO() + + for info in opcode_infos: + gen = OpcodeGen(info).gen() + outfile.write(gen) + + generated = outfile.getvalue() + + script_path = pathlib.Path(__file__).resolve().relative_to(ROOT).as_posix() + + output = rustfmt( + f""" +// This file is generated by {script_path} +// Do not edit! 
+ +use crate::{{ + bytecode::instruction::StackEffect, + marshal::MarshalError, +}}; + +{generated} + """ + ) + + OUT_FILE.write_text(output) + + +if __name__ == "__main__": + main() diff --git a/tools/opcode_metadata/opcodes.py b/tools/opcode_metadata/opcodes.py new file mode 100644 index 00000000000..c1b19801761 --- /dev/null +++ b/tools/opcode_metadata/opcodes.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import collections +import dataclasses +import re +import typing +import warnings + +import utils +from cpython import SKIP_PROPERTIES, Family, Properties, get_analysis, get_stack_effect +from utils import SKIP_OVERRIDE, Override, OverrideConfs, StackEffect, to_pascal_case + +if typing.TYPE_CHECKING: + from collections.abc import Iterable + + +@dataclasses.dataclass(frozen=True, slots=True) +class OpcodeInfo: + enum_name: str + size: str + opcodes: tuple[Opcode, ...] + + @property + def deopts(self) -> dict[str, list[str]]: + analysis = get_analysis() + names = {opcode.rust_name for opcode in self} + + res = collections.defaultdict(list) + for family in analysis.families.values(): + family_name = to_pascal_case(family.name) + if family_name not in names: + continue + + for member in family.members: + member_name = to_pascal_case(member.name) + if member.name == family_name: + continue + + res[family_name].append(member_name) + + return dict(res) + + def __iter__(self): + yield from self.opcodes + + @classmethod + def iter_infos( + cls, text: str, override_confs: OverrideConfs + ) -> Iterable[typing.Self]: + for block_match in re.finditer( + r"define_opcodes!\s*\((.+?)\);", text, re.DOTALL + ): + block = block_match.group(1).strip() + + size = re.search(r"#\[repr\((\w+)\)\]", block).group(1) + enum_name = re.search( + r"#\[repr\(\w+\)\]\s*pub\s+enum\s+(\w+)\s*;", block + ).group(1) + + second_enum_match = re.search(r"pub\s+enum\s+(\w+)\s*\{", block, re.DOTALL) + entries = utils.extract_enum_body(block, second_enum_match.end() - 1) + + opcodes = 
tuple(sorted(iter_opcodes(entries, override_confs))) + + yield cls(enum_name, size, opcodes) + + +def iter_opcodes(text: str, override_confs: OverrideConfs) -> Iterable[Opcode]: + analysis = get_analysis() + # Split on commas that are followed by a newline + an uppercase letter (new entry) + entries = map(str.strip, re.split(r",\s*\n\s*(?=[A-Z])", text)) + for entry in entries: + if not entry: + continue + + opcode = Opcode.from_str(entry) + + rust_name = opcode.rust_name + override = override_confs.get(rust_name, SKIP_OVERRIDE) + + cpython_name = opcode.cpython_name + + kwargs = {} + if instr := analysis.instructions.get(cpython_name): + kwargs["properties"] = instr.properties + kwargs["family"] = getattr(instr, "family", None) + kwargs["cache_entry"] = getattr(instr, "size", -1) + + stack = get_stack_effect(instr) + + popped = (-stack.base_offset).to_c() + pushed = (stack.logical_sp - stack.base_offset).to_c() + kwargs["stack_effect"] = StackEffect(popped=popped, pushed=pushed) + elif override == SKIP_OVERRIDE: + warnings.warn( + f"Could not get instruction metadata for {rust_name}" + " from CPython or override conf" + ) + + yield dataclasses.replace(opcode, override=override, **kwargs) + + +@dataclasses.dataclass(frozen=True, slots=True) +class Opcode: + rust_name: str + id: int + have_argument: bool = False + cache_entry: int = 0 + stack_effect: StackEffect | None = None + properties: Properties = dataclasses.field(default_factory=lambda: SKIP_PROPERTIES) + family: Family | None = None + override: Override = dataclasses.field(default_factory=Override) + + @property + def is_instrumented(self) -> bool: + if (res := self.override.is_instrumented) is not None: + return res + + return self.cpython_name.startswith("INSTRUMENTED_") + + @property + def cpython_name(self): + return utils.to_upper_snake_case(self.rust_name) + + @property + def cpy_popped(self) -> str | None: + return getattr(self.stack_effect, "popped", None) + + @property + def cpy_pushed(self) -> str 
| None: + return getattr(self.stack_effect, "pushed", None) + + @property + def stack_effect_popped(self) -> str: + ove_popped = self.override.stack_effect.popped + + if (ove_popped is None) and (self.cpy_popped is None): + raise ValueError(f"{self.rust_name} is missing popped stack_effect") + + return ove_popped or self.cpy_popped + + @property + def stack_effect_pushed(self) -> str: + ove_pushed = self.override.stack_effect.pushed + + if (ove_pushed is None) and (self.cpy_pushed is None): + raise ValueError(f"{self.rust_name} is missing pushed stack_effect") + + return ove_pushed or self.cpy_pushed + + @classmethod + def from_str(cls, entry: str) -> typing.Self: + rust_name = re.match(r"(\w+)", entry).group(1) + id_num = re.findall(r"= (\d+)", entry)[0] + have_argument = "Arg<" in entry + return cls(rust_name, int(id_num), have_argument=have_argument) + + def __lt__(self, other: typing.Self) -> bool: + return self.id < other.id diff --git a/tools/opcode_metadata/utils.py b/tools/opcode_metadata/utils.py new file mode 100644 index 00000000000..bc3a9ace8d5 --- /dev/null +++ b/tools/opcode_metadata/utils.py @@ -0,0 +1,96 @@ +import dataclasses +import pathlib +import re +import sys + +import tomllib + +ROOT = pathlib.Path(__file__).parents[2].resolve() +DEFAULT_INPUT = ROOT / "crates/compiler-core/src/bytecode/instruction.rs" +DEFAULT_CONF = pathlib.Path(__file__).parent / "conf.toml" + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class StackEffect: + pushed: str | None = None + popped: str | None = None + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class Override: + is_instrumented: bool | None = None + stack_effect: StackEffect = dataclasses.field(default_factory=StackEffect) + + +type OverrideConfs = dict[str, Override] + +SKIP_STACK_EFFECT = StackEffect() +SKIP_OVERRIDE = Override() + + +def get_conf(path: pathlib.Path = DEFAULT_CONF) -> OverrideConfs: + data = path.read_text(encoding="utf-8") + conf = 
tomllib.loads(data) + for k, v in conf.items(): + v["stack_effect"] = StackEffect(**v.get("stack_effect", {})) + conf[k] = Override(**v) + + return conf + + +def to_pascal_case(s: str) -> str: + return s.title().replace("_", "") + + +def to_upper_snake_case(s: str) -> str: + """ + Convert a PascalCase string to UPPER_SNAKE_CASE. + + Parameters + ---------- + s : str + Pascal cased string to convert. + + Returns + ------- + str + Uppercased snake case string. + + Examples + -------- + >>> to_upper_snake_case("LoadAttr") + 'LOAD_ATTR' + >>> to_upper_snake_case("CallIntrinsic1") + 'CALL_INTRINSIC_1' + """ + res = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", s) + return re.sub(r"(\D)(\d+)$", r"\1_\2", res).upper() + + +def extract_enum_body(text: str, start: int) -> str: + """ + Extract a Rust enum body from raw Rust source code. + + Parameters + ---------- + text : str + Rust source code containing the enum body. + start : int + Offset to start searching from. + + Returns + ------- + str + Extracted enum body. 
+ """ + assert text[start] == "{" + depth = 0 + for i, ch in enumerate(text[start:], start): + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return text[start + 1 : i].strip() # exclude the outer braces + + raise ValueError("Could not find end to enum body") diff --git a/wasm/demo/package-lock.json b/wasm/demo/package-lock.json index b9d5a9e42e9..14744312ac0 100644 --- a/wasm/demo/package-lock.json +++ b/wasm/demo/package-lock.json @@ -24,7 +24,7 @@ "serve": "^14.2.6", "webpack": "^5.105.0", "webpack-cli": "^6.0.1", - "webpack-dev-server": "^5.2.1" + "webpack-dev-server": "^5.2.4" } }, "node_modules/@codemirror/autocomplete": { @@ -299,6 +299,188 @@ "integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==", "license": "MIT" }, + "node_modules/@noble/hashes": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.4.0.tgz", + "integrity": "sha512-V1JJ1WTRUqHHrOSh597hURcMqVKVGL/ea3kv0gSnEdsEZ0/+VyPghM1lMNGc00z7CIQorSvbKpuJkxvuHbvdbg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + } + }, + "node_modules/@peculiar/asn1-cms": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-cms/-/asn1-cms-2.7.0.tgz", + "integrity": "sha512-hew63shtzzvBcSHbhm+cyAmKe6AIfinT9hzEqSPjDC6opTTMKmTkQ0gHuN2KsWlvqiKw1S/fS94fhag/FJkioQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": "^2.7.0", + "@peculiar/asn1-x509-attr": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-csr": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-csr/-/asn1-csr-2.7.0.tgz", + "integrity": "sha512-VVsAyGqErT9D1SY4aEqozThXMVI+ssVRiv2DDeYuvpBKLIgZ3hYs3Ay3u/VSoKq6ESFi9cf6rf3IOOzfwh7oMA==", + "dev": true, + "license": "MIT", + "dependencies": { + 
"@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-ecc": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-ecc/-/asn1-ecc-2.7.0.tgz", + "integrity": "sha512-n7KEs/Q/wrB415cxy4fHOBhegp4NdJ15fkJPwcB/3/8iNBQC2L/N7SChJPKDJPZGYH0jD4Tg4/0vnHmwghnbKw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-pfx": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-pfx/-/asn1-pfx-2.7.0.tgz", + "integrity": "sha512-V/nrlQVmhg7lYAsM7E13UDL5erAwFv6kCIVFqNaMIHSVi7dngcT839JkRTkQBqznMG98l2XjxYk74ZztAohZzA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-cms": "^2.7.0", + "@peculiar/asn1-pkcs8": "^2.7.0", + "@peculiar/asn1-rsa": "^2.7.0", + "@peculiar/asn1-schema": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-pkcs8": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-pkcs8/-/asn1-pkcs8-2.7.0.tgz", + "integrity": "sha512-9GTl1nE8Mx1kTZ+7QyYatDyKsm34QcWRBFkY1iPvWC3X4Dona5s/tlLiQsx5WzVdZqiMBZNYT0buyw4/vbhnjw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-pkcs9": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-pkcs9/-/asn1-pkcs9-2.7.0.tgz", + "integrity": "sha512-Bh7m+OuIaSEllPQcSd9OSp93F4ROWH7sbITWV8MI+8dwsjE5111/87VxiWVvYFKyww3vp39geLv9ENqhwWHcew==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-cms": "^2.7.0", + "@peculiar/asn1-pfx": "^2.7.0", + "@peculiar/asn1-pkcs8": "^2.7.0", + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": 
"^2.7.0", + "@peculiar/asn1-x509-attr": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-rsa": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-rsa/-/asn1-rsa-2.7.0.tgz", + "integrity": "sha512-/qvENQrXyTZURjMqSeofHul0JJt2sNSzSwk36pl2olkHbaioMQgrASDZAlHXl0xUlnVbHj0uGgOrBMTb5x2aJQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-schema": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-schema/-/asn1-schema-2.7.0.tgz", + "integrity": "sha512-W8ZfWzLmQnrcky+eh3tni4IozMdqBDiHWU0N+vve/UGjMaUs8c0L7A2oEdkBXS8rTpWDpK/aoI3DG/L/hxmxPg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/utils": "^2.0.2", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-x509": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-x509/-/asn1-x509-2.7.0.tgz", + "integrity": "sha512-mUn9RRrkGDnG4ALfunDmzyRW5dg+sWCj/pfnCCqEHYbkGxEpvUt6iVJv8Yw1cyp6SWZ26ZE5oSmI5SqEaen15g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/utils": "^2.0.2", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/asn1-x509-attr": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/@peculiar/asn1-x509-attr/-/asn1-x509-attr-2.7.0.tgz", + "integrity": "sha512-NS8e7SOgXipkzUPLF/sce7ukpMpWjhxYsH0n6Y+bHYo4TTxOb95Zv7hqwSuL212mj5YxovjdOKQOgH1As3E94w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-schema": "^2.7.0", + "@peculiar/asn1-x509": "^2.7.0", + "asn1js": "^3.0.6", + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/utils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/@peculiar/utils/-/utils-2.0.3.tgz", + "integrity": 
"sha512-+oL3HPFRIZ1St2K50lWCXiioIgSoxzz7R1J3uF6neO2yl1sgmpgY6XXJH4BdpoDkMWznQTeYF6oWNDZLCdQ4eQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "tslib": "^2.8.1" + } + }, + "node_modules/@peculiar/x509": { + "version": "1.14.3", + "resolved": "https://registry.npmjs.org/@peculiar/x509/-/x509-1.14.3.tgz", + "integrity": "sha512-C2Xj8FZ0uHWeCXXqX5B4/gVFQmtSkiuOolzAgutjTfseNOHT3pUjljDZsTSxXFGgio54bCzVFqmEOUrIVk8RDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@peculiar/asn1-cms": "^2.6.0", + "@peculiar/asn1-csr": "^2.6.0", + "@peculiar/asn1-ecc": "^2.6.0", + "@peculiar/asn1-pkcs9": "^2.6.0", + "@peculiar/asn1-rsa": "^2.6.0", + "@peculiar/asn1-schema": "^2.6.0", + "@peculiar/asn1-x509": "^2.6.0", + "pvtsutils": "^1.3.6", + "reflect-metadata": "^0.2.2", + "tslib": "^2.8.1", + "tsyringe": "^4.10.0" + }, + "engines": { + "node": ">=20.0.0" + } + }, "node_modules/@types/body-parser": { "version": "1.19.5", "resolved": "https://registry.npmjs.org/@types/body-parser/-/body-parser-1.19.5.tgz", @@ -371,16 +553,16 @@ "license": "MIT" }, "node_modules/@types/express": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.21.tgz", - "integrity": "sha512-ejlPM315qwLpaQlQDTjPdsUFSc6ZsP4AN6AlWnogPjQ7CVi7PYF3YVz+CY3jE2pwYf7E/7HlDAN0rV2GxTG0HQ==", + "version": "4.17.25", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.25.tgz", + "integrity": "sha512-dVd04UKsfpINUnK0yBoYHDF3xu7xVH4BuDotC/xGuycx4CgbP48X/KF/586bcObxT0HENHXEU8Nqtu6NR+eKhw==", "dev": true, "license": "MIT", "dependencies": { "@types/body-parser": "*", "@types/express-serve-static-core": "^4.17.33", "@types/qs": "*", - "@types/serve-static": "*" + "@types/serve-static": "^1" } }, "node_modules/@types/express-serve-static-core": { @@ -457,16 +639,6 @@ "undici-types": "~6.20.0" } }, - "node_modules/@types/node-forge": { - "version": "1.3.11", - "resolved": "https://registry.npmjs.org/@types/node-forge/-/node-forge-1.3.11.tgz", - 
"integrity": "sha512-FQx220y22OKNTqaByeBGqHWYz4cl94tpcxeFdvBo3wjG6XPBuZ0BNgNZRV5J5TFmmcsJ4IzsLkmGRiQbnYsBEQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/node": "*" - } - }, "node_modules/@types/qs": { "version": "6.9.18", "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.18.tgz", @@ -1002,6 +1174,21 @@ "dev": true, "license": "MIT" }, + "node_modules/asn1js": { + "version": "3.0.10", + "resolved": "https://registry.npmjs.org/asn1js/-/asn1js-3.0.10.tgz", + "integrity": "sha512-S2s3aOytiKdFRdulw2qPE51MzjzVOisppcVv7jVFR+Kw0kxwvFrDcYA0h7Ndqbmj0HkMIXYWaoj7fli8kgx1eg==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "pvtsutils": "^1.3.6", + "pvutils": "^1.1.5", + "tslib": "^2.8.1" + }, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -1040,24 +1227,24 @@ } }, "node_modules/body-parser": { - "version": "1.20.3", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", - "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", + "version": "1.20.5", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.5.tgz", + "integrity": "sha512-3grm+/2tUOvu2cjJkvsIxrv/wVpfXQW4PsQHYm7yk4vfpu7Ekl6nEsYBoJUL6qDwZUx8wUhQ8tR2qz+ad9c9OA==", "dev": true, "license": "MIT", "dependencies": { - "bytes": "3.1.2", + "bytes": "~3.1.2", "content-type": "~1.0.5", "debug": "2.6.9", "depd": "2.0.0", - "destroy": "1.2.0", - "http-errors": "2.0.0", - "iconv-lite": "0.4.24", - "on-finished": "2.4.1", - "qs": "6.13.0", - "raw-body": "2.5.2", + "destroy": "~1.2.0", + "http-errors": "~2.0.1", + "iconv-lite": "~0.4.24", + "on-finished": "~2.4.1", + "qs": "~6.15.1", + "raw-body": "~2.5.3", "type-is": "~1.6.18", - "unpipe": "1.0.0" + "unpipe": "~1.0.0" }, "engines": { "node": ">= 0.8", @@ -1074,6 +1261,37 @@ "node": ">= 0.8" } }, 
+ "node_modules/body-parser/node_modules/http-errors": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.1.tgz", + "integrity": "sha512-4FbRdAX+bSdmo4AUFuS0WNiPz8NgFt+r8ThgNWmlrjQjt1Q7ZR9+zTlce2859x4KSXrwIsaeTqDoKQmtP8pLmQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "depd": "~2.0.0", + "inherits": "~2.0.4", + "setprototypeof": "~1.2.0", + "statuses": "~2.0.2", + "toidentifier": "~1.0.1" + }, + "engines": { + "node": ">= 0.8" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/body-parser/node_modules/statuses": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", + "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/bonjour-service": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.3.0.tgz", @@ -1219,6 +1437,16 @@ "node": ">= 0.8" } }, + "node_modules/bytestreamjs": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/bytestreamjs/-/bytestreamjs-2.0.1.tgz", + "integrity": "sha512-U1Z/ob71V/bXfVABvNr/Kumf5VyeQRBEm6Txb0PQ6S7V5GpBM3w4Cbqz/xPDicR5tN0uvDifng8C+5qECeGwyQ==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=6.0.0" + } + }, "node_modules/call-bind-apply-helpers": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", @@ -2226,15 +2454,15 @@ } }, "node_modules/express": { - "version": "4.22.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.22.1.tgz", - "integrity": "sha512-F2X8g9P1X7uCPZMA3MVf9wcTqlyNp7IhH5qPCI0izhaOIYXaW9L535tGA3qmjRzpH+bZczqq7hVKxTR4NWnu+g==", + "version": "4.22.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.22.2.tgz", + 
"integrity": "sha512-IuL+Elrou2ZvCFHs18/CIzy2Nzvo25nZ1/D2eIZlz7c+QUayAcYoiM2BthCjs+EBHVpjYjcuLDAiCWgeIX3X1Q==", "dev": true, "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "~1.20.3", + "body-parser": "~1.20.5", "content-disposition": "~0.5.4", "content-type": "~1.0.4", "cookie": "~0.7.1", @@ -2253,7 +2481,7 @@ "parseurl": "~1.3.3", "path-to-regexp": "~0.1.12", "proxy-addr": "~2.0.7", - "qs": "~6.14.0", + "qs": "~6.15.1", "range-parser": "~1.2.1", "safe-buffer": "5.2.1", "send": "~0.19.0", @@ -2292,22 +2520,6 @@ "dev": true, "license": "MIT" }, - "node_modules/express/node_modules/qs": { - "version": "6.14.1", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", - "integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "side-channel": "^1.1.0" - }, - "engines": { - "node": ">=0.6" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/express/node_modules/range-parser": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", @@ -3538,16 +3750,6 @@ "tslib": "^2.0.3" } }, - "node_modules/node-forge": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.4.0.tgz", - "integrity": "sha512-LarFH0+6VfriEhqMMcLX2F7SwSXeWwnEAJEsYm5QKWchiVYVvJyV9v7UDvUv+w5HO23ZpQTXDv/GxdDdMyOuoQ==", - "dev": true, - "license": "(BSD-3-Clause OR GPL-2.0)", - "engines": { - "node": ">= 6.13.0" - } - }, "node_modules/node-releases": { "version": "2.0.27", "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.27.tgz", @@ -3848,6 +4050,24 @@ "node": ">=8" } }, + "node_modules/pkijs": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/pkijs/-/pkijs-3.4.0.tgz", + "integrity": 
"sha512-emEcLuomt2j03vxD54giVB4SxTjnsqkU692xZOZXHDVoYyypEm+b3jpiTcc+Cf+myooc+/Ly0z01jqeNHVgJGw==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "@noble/hashes": "1.4.0", + "asn1js": "^3.0.6", + "bytestreamjs": "^2.0.1", + "pvtsutils": "^1.3.6", + "pvutils": "^1.1.3", + "tslib": "^2.8.1" + }, + "engines": { + "node": ">=16.0.0" + } + }, "node_modules/postcss": { "version": "8.5.10", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.10.tgz", @@ -4003,14 +4223,34 @@ "node": ">= 0.10" } }, + "node_modules/pvtsutils": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/pvtsutils/-/pvtsutils-1.3.6.tgz", + "integrity": "sha512-PLgQXQ6H2FWCaeRak8vvk1GW462lMxB5s3Jm673N82zI4vqtVUPuZdffdZbPDFRoU8kAhItWFtPCWiPpp4/EDg==", + "dev": true, + "license": "MIT", + "dependencies": { + "tslib": "^2.8.1" + } + }, + "node_modules/pvutils": { + "version": "1.1.5", + "resolved": "https://registry.npmjs.org/pvutils/-/pvutils-1.1.5.tgz", + "integrity": "sha512-KTqnxsgGiQ6ZAzZCVlJH5eOjSnvlyEgx1m8bkRJfOhmGRqfo5KLvmAlACQkrjEtOQ4B7wF9TdSLIs9O90MX9xA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=16.0.0" + } + }, "node_modules/qs": { - "version": "6.13.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", - "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", + "version": "6.15.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.15.2.tgz", + "integrity": "sha512-Rzq0KEyX/w/tEybncDgdkZrJgVUsUMk3xjh3t5bv3S1HTAtg+uOYt72+ZfwiQwKdysThkTBdL/rTi6HDmX9Ddw==", "dev": true, "license": "BSD-3-Clause", "dependencies": { - "side-channel": "^1.0.6" + "side-channel": "^1.1.0" }, "engines": { "node": ">=0.6" @@ -4030,16 +4270,16 @@ } }, "node_modules/raw-body": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", - "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", + 
"version": "2.5.3", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.3.tgz", + "integrity": "sha512-s4VSOf6yN0rvbRZGxs8Om5CWj6seneMwK3oDb4lWDH0UPhWcxwOWw5+qk24bxq87szX1ydrwylIOp2uG1ojUpA==", "dev": true, "license": "MIT", "dependencies": { - "bytes": "3.1.2", - "http-errors": "2.0.0", - "iconv-lite": "0.4.24", - "unpipe": "1.0.0" + "bytes": "~3.1.2", + "http-errors": "~2.0.1", + "iconv-lite": "~0.4.24", + "unpipe": "~1.0.0" }, "engines": { "node": ">= 0.8" @@ -4055,6 +4295,37 @@ "node": ">= 0.8" } }, + "node_modules/raw-body/node_modules/http-errors": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.1.tgz", + "integrity": "sha512-4FbRdAX+bSdmo4AUFuS0WNiPz8NgFt+r8ThgNWmlrjQjt1Q7ZR9+zTlce2859x4KSXrwIsaeTqDoKQmtP8pLmQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "depd": "~2.0.0", + "inherits": "~2.0.4", + "setprototypeof": "~1.2.0", + "statuses": "~2.0.2", + "toidentifier": "~1.0.1" + }, + "engines": { + "node": ">= 0.8" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/raw-body/node_modules/statuses": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", + "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/rc": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", @@ -4112,6 +4383,13 @@ "node": ">= 10.13.0" } }, + "node_modules/reflect-metadata": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/reflect-metadata/-/reflect-metadata-0.2.2.tgz", + "integrity": "sha512-urBwgfrvVP/eAyXx4hluJivBKzuEbSQs9rKWCrCkbSxNv8mxPcUZKeuoF3Uy4mJl3Lwprp6yy5/39VWigZ4K6Q==", + "dev": true, + "license": "Apache-2.0" + }, "node_modules/registry-auth-token": { "version": "3.3.2", "resolved": 
"https://registry.npmjs.org/registry-auth-token/-/registry-auth-token-3.3.2.tgz", @@ -4286,17 +4564,17 @@ "license": "MIT" }, "node_modules/selfsigned": { - "version": "2.4.1", - "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.4.1.tgz", - "integrity": "sha512-th5B4L2U+eGLq1TVh7zNRGBapioSORUeymIydxgFpwww9d2qyKvtuPU2jJuHvYAwwqi2Y596QBL3eEqcPEYL8Q==", + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-5.5.0.tgz", + "integrity": "sha512-ftnu3TW4+3eBfLRFnDEkzGxSF/10BJBkaLJuBHZX0kiPS7bRdlpZGu6YGt4KngMkdTwJE6MbjavFpqHvqVt+Ew==", "dev": true, "license": "MIT", "dependencies": { - "@types/node-forge": "^1.3.0", - "node-forge": "^1" + "@peculiar/x509": "^1.14.2", + "pkijs": "^3.3.3" }, "engines": { - "node": ">=10" + "node": ">=18" } }, "node_modules/semver": { @@ -5084,6 +5362,26 @@ "dev": true, "license": "0BSD" }, + "node_modules/tsyringe": { + "version": "4.10.0", + "resolved": "https://registry.npmjs.org/tsyringe/-/tsyringe-4.10.0.tgz", + "integrity": "sha512-axr3IdNuVIxnaK5XGEUFTu3YmAQ6lllgrvqfEoR16g/HGnYY/6We4oWENtAnzK6/LpJ2ur9PAb80RBt7/U4ugw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tslib": "^1.9.3" + }, + "engines": { + "node": ">= 6.0.0" + } + }, + "node_modules/tsyringe/node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", + "dev": true, + "license": "0BSD" + }, "node_modules/type-fest": { "version": "2.19.0", "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-2.19.0.tgz", @@ -5393,15 +5691,15 @@ } }, "node_modules/webpack-dev-server": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", - "integrity": "sha512-ml/0HIj9NLpVKOMq+SuBPLHcmbG+TGIjXRHsYfZwocUBIqEvws8NnS/V9AFQ5FKP+tgn5adwVwRrTEpGL33QFQ==", + "version": "5.2.4", + "resolved": 
"https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.4.tgz", + "integrity": "sha512-GqDPGZN9bRqKBTkp4aWkobDDHMsrXKoGSdOH56smIri8qR0JG8gfL8/v/f/OZR3/OKXjG8uwJbFVhKm/FNU/UA==", "dev": true, "license": "MIT", "dependencies": { "@types/bonjour": "^3.5.13", "@types/connect-history-api-fallback": "^1.5.4", - "@types/express": "^4.17.21", + "@types/express": "^4.17.25", "@types/express-serve-static-core": "^4.17.21", "@types/serve-index": "^1.9.4", "@types/serve-static": "^1.15.5", @@ -5411,17 +5709,17 @@ "bonjour-service": "^1.2.1", "chokidar": "^3.6.0", "colorette": "^2.0.10", - "compression": "^1.7.4", + "compression": "^1.8.1", "connect-history-api-fallback": "^2.0.0", - "express": "^4.21.2", + "express": "^4.22.1", "graceful-fs": "^4.2.6", - "http-proxy-middleware": "^2.0.7", + "http-proxy-middleware": "^2.0.9", "ipaddr.js": "^2.1.0", "launch-editor": "^2.6.1", "open": "^10.0.3", "p-retry": "^6.2.0", "schema-utils": "^4.2.0", - "selfsigned": "^2.4.1", + "selfsigned": "^5.5.0", "serve-index": "^1.9.1", "sockjs": "^0.3.24", "spdy": "^4.0.2", diff --git a/wasm/demo/package.json b/wasm/demo/package.json index a815b90c70b..0b22c24ea50 100644 --- a/wasm/demo/package.json +++ b/wasm/demo/package.json @@ -19,7 +19,7 @@ "serve": "^14.2.6", "webpack": "^5.105.0", "webpack-cli": "^6.0.1", - "webpack-dev-server": "^5.2.1" + "webpack-dev-server": "^5.2.4" }, "scripts": { "dev": "webpack serve",