diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..651d3ae
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,160 @@
+name: CI
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+
+permissions:
+  contents: read
+
+jobs:
+  clippy:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Install Rust
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9 # stable
+        with:
+          toolchain: stable
+          components: clippy
+
+      - name: Cache cargo
+        uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
+        with:
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+            target
+          key: ubuntu-latest-cargo-clippy-${{ hashFiles('**/Cargo.lock') }}
+          restore-keys: ubuntu-latest-cargo-clippy-
+
+      - name: Run clippy
+        run: cargo clippy --workspace --all-features -- -D warnings
+
+  test:
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Install Rust
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9 # stable
+        with:
+          toolchain: stable
+
+      - name: Cache cargo
+        uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
+        with:
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+            target
+          key: ${{ matrix.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
+          restore-keys: ${{ matrix.os }}-cargo-
+
+      - name: Run tests
+        run: cargo test --workspace --all-features
+
+  dispatch-tests:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Setup Node.js
+        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
+        with:
+          node-version: '20'
+
+      - name: Run npm dispatch tests
+        run: node --test npm/socket-patch/bin/socket-patch.test.mjs
+
+      - name: Setup Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
+        with:
+          python-version: '3.12'
+
+      - name: Run pypi dispatch tests
+        run: python pypi/socket-patch/test_dispatch.py
+
+  e2e:
+    needs: test
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - os: ubuntu-latest
+            suite: e2e_npm
+          - os: ubuntu-latest
+            suite: e2e_pypi
+          - os: ubuntu-latest
+            suite: e2e_cargo
+          - os: ubuntu-latest
+            suite: e2e_golang
+          - os: ubuntu-latest
+            suite: e2e_maven
+          - os: ubuntu-latest
+            suite: e2e_gem
+          - os: ubuntu-latest
+            suite: e2e_composer
+          - os: ubuntu-latest
+            suite: e2e_nuget
+          - os: macos-latest
+            suite: e2e_npm
+          - os: macos-latest
+            suite: e2e_pypi
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Install Rust
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9 # stable
+        with:
+          toolchain: stable
+
+      - name: Cache cargo
+        uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
+        with:
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+            target
+          key: ${{ matrix.os }}-cargo-e2e-${{ hashFiles('**/Cargo.lock') }}
+          restore-keys: ${{ matrix.os }}-cargo-e2e-
+
+      - name: Setup Node.js
+        if: matrix.suite == 'e2e_npm'
+        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
+        with:
+          node-version: 20
+
+      - name: Setup Python
+        if: matrix.suite == 'e2e_pypi'
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
+        with:
+          python-version: "3.12"
+
+      - name: Setup Ruby
+        if: matrix.suite == 'e2e_gem'
+        uses: ruby/setup-ruby@319994f95fa847cf3fb3cd3dbe89f6dcde9f178f # v1.295.0
+        with:
+          ruby-version: '3.2'
+          bundler-cache: false
+
+      - name: Run e2e tests
+        run: cargo test -p socket-patch-cli --all-features --test ${{ matrix.suite }} -- --ignored
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
deleted file mode 100644
index f84cf50..0000000
--- a/.github/workflows/publish.yml
+++ /dev/null
@@ -1,33 +0,0 @@
-name: 📦 Publish
-
-on:
-  workflow_dispatch:
-    inputs:
-      dist-tag:
-        description: 'npm dist-tag (latest, next, beta, canary, backport, etc.)'
-        required: false
-        default: 'latest'
-        type: string
-      debug:
-        description: 'Enable debug output'
-        required: false
-        default: '0'
-        type: string
-        options:
-          - '0'
-          - '1'
-
-permissions:
-  contents: write
-  id-token: write
-
-jobs:
-  publish:
-    uses: SocketDev/socket-registry/.github/workflows/provenance.yml@63ad52562c1f2d007a1833b2b22cffc3001e1cc2 # main
-    with:
-      debug: ${{ inputs.debug }}
-      dist-tag: ${{ inputs.dist-tag }}
-      package-name: '@socketsecurity/socket-patch'
-      publish-script: 'publish:ci'
-      setup-script: 'pnpm run build'
-      use-trusted-publishing: true
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..8f12316
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,357 @@
+name: Release
+
+on:
+  workflow_dispatch:
+    inputs:
+      dry-run:
+        description: 'Dry run (build only, skip publish)'
+        type: boolean
+        default: false
+
+permissions: {}
+
+jobs:
+  version:
+    runs-on: ubuntu-latest
+    outputs:
+      version: ${{ steps.read.outputs.VERSION }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Read version from Cargo.toml
+        id: read
+        run: |
+          VERSION=$(grep '^version = ' Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/')
+          echo "VERSION=$VERSION" >> "$GITHUB_OUTPUT"
+          echo "Release version: $VERSION"
+
+      - name: Check tag does not exist
+        run: |
+          VERSION="${{ steps.read.outputs.VERSION }}"
+          if git rev-parse "v${VERSION}" >/dev/null 2>&1; then
+            echo "::error::Tag v${VERSION} already exists. Bump the version in a PR first."
+            exit 1
+          fi
+
+  tag:
+    needs: version
+    if: ${{ !inputs.dry-run }}
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+
+      - name: Create and push tag
+        run: |
+          TAG="v${{ needs.version.outputs.version }}"
+          git tag "$TAG"
+          git push origin "$TAG"
+
+  build:
+    needs: [version, tag]
+    if: ${{ always() && needs.version.result == 'success' && (needs.tag.result == 'success' || needs.tag.result == 'skipped') }}
+    strategy:
+      matrix:
+        include:
+          - target: aarch64-apple-darwin
+            runner: macos-14
+            archive: tar.gz
+            build-tool: cargo
+          - target: x86_64-apple-darwin
+            runner: macos-14
+            archive: tar.gz
+            build-tool: cargo
+          - target: x86_64-unknown-linux-gnu
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: x86_64-unknown-linux-musl
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: aarch64-unknown-linux-gnu
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: aarch64-unknown-linux-musl
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: x86_64-pc-windows-msvc
+            runner: windows-latest
+            archive: zip
+            build-tool: cargo
+          - target: i686-pc-windows-msvc
+            runner: windows-latest
+            archive: zip
+            build-tool: cargo
+          - target: aarch64-pc-windows-msvc
+            runner: windows-latest
+            archive: zip
+            build-tool: cargo
+          - target: aarch64-linux-android
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: arm-unknown-linux-gnueabihf
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: arm-unknown-linux-musleabihf
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: i686-unknown-linux-gnu
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+          - target: i686-unknown-linux-musl
+            runner: ubuntu-latest
+            archive: tar.gz
+            build-tool: cross
+    runs-on: ${{ matrix.runner }}
+    permissions:
+      contents: read
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Install Rust
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9 # stable
+        with:
+          toolchain: stable
+          targets: ${{ matrix.target }}
+
+      - name: Install cross
+        if: matrix.build-tool == 'cross'
+        run: cargo install cross --git https://github.com/cross-rs/cross
+
+      - name: Build (cargo)
+        if: matrix.build-tool == 'cargo'
+        run: cargo build --release --target ${{ matrix.target }}
+
+      - name: Build (cross)
+        if: matrix.build-tool == 'cross'
+        run: cross build --release --target ${{ matrix.target }}
+
+      - name: Package (unix)
+        if: matrix.archive == 'tar.gz'
+        run: |
+          cd target/${{ matrix.target }}/release
+          tar czf ../../../socket-patch-${{ matrix.target }}.tar.gz socket-patch
+          cd ../../..
+
+      - name: Package (windows)
+        if: matrix.archive == 'zip'
+        shell: pwsh
+        run: |
+          Compress-Archive -Path "target/${{ matrix.target }}/release/socket-patch.exe" -DestinationPath "socket-patch-${{ matrix.target }}.zip"
+
+      - name: Upload artifact (tar.gz)
+        if: matrix.archive == 'tar.gz'
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
+        with:
+          name: socket-patch-${{ matrix.target }}
+          path: socket-patch-${{ matrix.target }}.tar.gz
+
+      - name: Upload artifact (zip)
+        if: matrix.archive == 'zip'
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
+        with:
+          name: socket-patch-${{ matrix.target }}
+          path: socket-patch-${{ matrix.target }}.zip
+
+  github-release:
+    needs: [version, build]
+    if: ${{ !inputs.dry-run }}
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    steps:
+      - name: Download all artifacts
+        uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
+        with:
+          path: artifacts
+          merge-multiple: true
+
+      - name: Create GitHub Release
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          TAG="v${{ needs.version.outputs.version }}"
+          gh release create "$TAG" \
+            --repo "$GITHUB_REPOSITORY" \
+            --generate-notes \
+            artifacts/*
+
+  cargo-publish:
+    needs: [version, build]
+    if: ${{ !inputs.dry-run }}
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      id-token: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Install Rust
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9 # stable
+        with:
+          toolchain: stable
+
+      - name: Authenticate with crates.io
+        id: crates-io-auth
+        uses: rust-lang/crates-io-auth-action@b7e9a28eded4986ec6b1fa40eeee8f8f165559ec # v1.0.3
+
+      - name: Publish socket-patch-core
+        run: cargo publish -p socket-patch-core
+        env:
+          CARGO_REGISTRY_TOKEN: ${{ steps.crates-io-auth.outputs.token }}
+
+      - name: Wait for crates.io index update
+        run: sleep 30
+
+      - name: Copy README for CLI crate
+        run: cp README.md crates/socket-patch-cli/README.md
+
+      - name: Publish socket-patch-cli
+        run: cargo publish -p socket-patch-cli
+        env:
+          CARGO_REGISTRY_TOKEN: ${{ steps.crates-io-auth.outputs.token }}
+
+  npm-publish:
+    needs: [version, build]
+    if: ${{ !inputs.dry-run }}
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      id-token: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Configure git for HTTPS
+        run: git config --global url."https://github.com/".insteadOf "ssh://git@github.com/"
+
+      - name: Download all artifacts
+        uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
+        with:
+          path: artifacts
+          merge-multiple: true
+
+      - name: Setup Node.js
+        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
+        with:
+          node-version: '22.22.1'
+          registry-url: 'https://registry.npmjs.org'
+
+      - name: Update npm for trusted publishing
+        run: npm install -g npm@latest
+
+      - name: Stage binaries into platform packages
+        run: |
+          # Unix platforms: extract binary into each platform package directory
+          stage_unix() {
+            local artifact="$1" pkg_dir="$2"
+            tar xzf "artifacts/${artifact}.tar.gz" -C "${pkg_dir}/"
+          }
+
+          # Windows platforms: extract .exe into each platform package directory
+          stage_win() {
+            local artifact="$1" pkg_dir="$2"
+            unzip -o "artifacts/${artifact}.zip" -d "${pkg_dir}/"
+          }
+
+          stage_unix socket-patch-aarch64-apple-darwin npm/socket-patch-darwin-arm64
+          stage_unix socket-patch-x86_64-apple-darwin npm/socket-patch-darwin-x64
+          stage_unix socket-patch-x86_64-unknown-linux-gnu npm/socket-patch-linux-x64-gnu
+          stage_unix socket-patch-x86_64-unknown-linux-musl npm/socket-patch-linux-x64-musl
+          stage_unix socket-patch-aarch64-unknown-linux-gnu npm/socket-patch-linux-arm64-gnu
+          stage_unix socket-patch-aarch64-unknown-linux-musl npm/socket-patch-linux-arm64-musl
+          stage_unix socket-patch-arm-unknown-linux-gnueabihf npm/socket-patch-linux-arm-gnu
+          stage_unix socket-patch-arm-unknown-linux-musleabihf npm/socket-patch-linux-arm-musl
+          stage_unix socket-patch-i686-unknown-linux-gnu npm/socket-patch-linux-ia32-gnu
+          stage_unix socket-patch-i686-unknown-linux-musl npm/socket-patch-linux-ia32-musl
+          stage_unix socket-patch-aarch64-linux-android npm/socket-patch-android-arm64
+
+          stage_win socket-patch-x86_64-pc-windows-msvc npm/socket-patch-win32-x64
+          stage_win socket-patch-i686-pc-windows-msvc npm/socket-patch-win32-ia32
+          stage_win socket-patch-aarch64-pc-windows-msvc npm/socket-patch-win32-arm64
+
+      - name: Publish platform packages
+        run: |
+          for pkg_dir in npm/socket-patch-*/; do
+            echo "Publishing ${pkg_dir}..."
+            npm publish "./${pkg_dir}" --provenance --access public || {
+              if npm view "@socketsecurity/$(basename "$pkg_dir")@${{ needs.version.outputs.version }}" version >/dev/null 2>&1; then
+                echo "Already published, skipping."
+              else
+                exit 1
+              fi
+            }
+          done
+
+      - name: Wait for npm registry propagation
+        run: sleep 30
+
+      - name: Copy README for npm package
+        run: cp README.md npm/socket-patch/README.md
+
+      - name: Publish main package
+        run: |
+          npm publish ./npm/socket-patch --provenance --access public || {
+            if npm view "@socketsecurity/socket-patch@${{ needs.version.outputs.version }}" version >/dev/null 2>&1; then
+              echo "Already published, skipping."
+            else
+              exit 1
+            fi
+          }
+
+  pypi-publish:
+    needs: [version, build]
+    if: ${{ !inputs.dry-run }}
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      id-token: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          persist-credentials: false
+
+      - name: Download all artifacts
+        uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
+        with:
+          path: artifacts
+          merge-multiple: true
+
+      - name: Setup Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
+        with:
+          python-version: '3.12'
+
+      - name: Copy README for PyPI package
+        run: cp README.md pypi/socket-patch/README.md
+
+      - name: Build platform wheels
+        run: |
+          VERSION="${{ needs.version.outputs.version }}"
+          python scripts/build-pypi-wheels.py --version "$VERSION" --artifacts artifacts --dist dist
+
+      - name: Publish to PyPI
+        uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
+        with:
+          packages-dir: dist/
diff --git a/.gitignore b/.gitignore
index 9a5aced..1ab7447 100644
--- a/.gitignore
+++ b/.gitignore
@@ -137,3 +137,14 @@ dist
 # Vite logs files
 vite.config.js.timestamp-*
 vite.config.ts.timestamp-*
+
+# Rust
+target/
+
+# npm binaries (populated at publish time)
+npm/socket-patch/bin/socket-patch-*
+
+# READMEs copied at publish time
+crates/socket-patch-cli/README.md
+npm/socket-patch/README.md
+pypi/socket-patch/README.md
diff --git a/.oxlintrc.json b/.oxlintrc.json
deleted file mode 100644
index 3415f15..0000000
--- a/.oxlintrc.json
+++ /dev/null
@@ -1,64 +0,0 @@
-{
-  "$schema": "./node_modules/oxlint/configuration_schema.json",
-  "ignorePatterns": ["**/dist/", "**/node_modules/", "**/.git/"],
-  "plugins": ["typescript", "oxc", "promise", "import"],
-  "categories": {
-    "correctness": "warn",
-    "perf": "warn",
-    "suspicious": "warn"
-  },
-  "rules": {
-    "no-unused-vars": "allow",
-    "no-new-array": "allow",
-    "no-empty-file": "allow",
-    "no-await-in-loop": "allow",
-    "consistent-function-scoping": "allow",
-    "no-new": "allow",
-    "no-extraneous-class": "allow",
-    "no-array-index-key": "allow",
-    "no-unsafe-optional-chaining": "allow",
-    "no-promise-in-callback": "allow",
-    "no-callback-in-promise": "allow",
-    "consistent-type-imports": "deny",
-    "no-empty-named-blocks": "allow",
-    "no-unnecessary-parameter-property-assignment": "allow",
-    "no-unneeded-ternary": "allow",
-    "no-eq-null": "allow",
-    "max-lines-per-function": "allow",
-    "max-depth": "allow",
-    "no-magic-numbers": "allow",
-    "no-unassigned-import": "allow",
-    "promise/always-return": "allow",
-    "no-unassigned-vars": "deny",
-    "typescript/no-floating-promises": "deny",
-    "typescript/no-misused-promises": "deny",
-    "typescript/return-await": "allow",
-    "typescript/await-thenable": "allow",
-    "typescript/consistent-type-imports": "allow",
-    "typescript/no-base-to-string": "allow",
-    "typescript/no-duplicate-type-constituents": "allow",
-    "typescript/no-for-in-array": "allow",
-    "typescript/no-meaningless-void-operator": "allow",
-    "typescript/no-misused-spread": "allow",
-    "typescript/no-redundant-type-constituents": "allow",
-    "typescript/no-unnecessary-boolean-literal-compare": "allow",
-    "typescript/no-unnecessary-template-expression": "allow",
-    "typescript/no-unnecessary-type-arguments": "allow",
-    "typescript/no-unnecessary-type-assertion": "allow",
-    "typescript/no-unsafe-enum-comparison": "allow",
-    "typescript/no-unsafe-type-assertion": "allow",
-    "typescript/require-array-sort-compare": "allow",
"typescript/restrict-template-expressions": "allow", - "typescript/triple-slash-reference": "allow", - "typescript/unbound-method": "allow" - }, - "overrides": [ - { - "files": ["**/*.test.ts", "**/*.test.js", "**/*.spec.ts", "**/*.spec.js"], - "rules": { - "typescript/no-floating-promises": "allow", - "typescript/no-misused-promises": "allow" - } - } - ] -} diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..75484b5 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,2117 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror 1.0.69", + "zeroize", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = 
"encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "r-efi 5.3.0", + "wasip2", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + 
+[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", 
+ "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket-patch-cli" +version = "2.1.4" +dependencies = [ + "clap", + "dialoguer", + "hex", + "indicatif", + "regex", + "serde", + "serde_json", + "sha2", + "socket-patch-core", + "tempfile", + "tokio", + "uuid", +] + +[[package]] +name = "socket-patch-core" +version = "2.1.4" +dependencies = [ + "hex", + "once_cell", + "regex", + "reqwest", + "serde", + "serde_json", + "sha2", + "tempfile", + "thiserror 2.0.18", + "tokio", + "uuid", + "walkdir", +] + +[[package]] +name = "socket2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tempfile" +version = "3.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + 
"once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "uuid" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + 
"wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" +dependencies = [ + "cfg-if", + "futures-util", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" 
+version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..7ad7565 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,32 @@ +[workspace] +members = ["crates/socket-patch-core", "crates/socket-patch-cli"] +resolver = "2" + +[workspace.package] +version = "2.1.4" +edition = "2021" +license = "MIT" +repository = "https://github.com/SocketDev/socket-patch" + 
+[workspace.dependencies] +socket-patch-core = { path = "crates/socket-patch-core", version = "2.1.4" } +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +sha2 = "0.10" +hex = "0.4" +reqwest = { version = "0.12", features = ["rustls-tls", "json"], default-features = false } +tokio = { version = "1", features = ["full"] } +thiserror = "2" +walkdir = "2" +uuid = { version = "1", features = ["v4"] } +dialoguer = "0.11" +indicatif = "0.17" +tempfile = "3" +regex = "1" +once_cell = "1" + +[profile.release] +strip = true +lto = true +opt-level = "s" diff --git a/EDGE_CASES.md b/EDGE_CASES.md deleted file mode 100644 index 4bda7ea..0000000 --- a/EDGE_CASES.md +++ /dev/null @@ -1,464 +0,0 @@ -# Socket-Patch Setup Command: Edge Case Analysis - -This document provides a comprehensive analysis of all edge cases handled by the `socket-patch setup` command. - -## Detection Logic - -The setup command detects if a postinstall script is already configured by checking if the string contains `'socket-patch apply'`. This substring match is intentionally lenient to recognize various valid formats. - -## Edge Cases - -### 1. No scripts field at all - -**Input:** -```json -{ - "name": "test", - "version": "1.0.0" -} -``` - -**Behavior:** ✅ Creates scripts field and adds postinstall - -**Output:** -```json -{ - "name": "test", - "version": "1.0.0", - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply" - } -} -``` - ---- - -### 2. Scripts field exists but no postinstall - -**Input:** -```json -{ - "scripts": { - "test": "jest", - "build": "tsc" - } -} -``` - -**Behavior:** ✅ Adds postinstall to existing scripts object - -**Output:** -```json -{ - "scripts": { - "test": "jest", - "build": "tsc", - "postinstall": "npx @socketsecurity/socket-patch apply" - } -} -``` - ---- - -### 2a. Postinstall is null - -**Input:** -```json -{ - "scripts": { - "postinstall": null - } -} -``` - -**Behavior:** ✅ Treats as missing, adds socket-patch command - -**Output:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply" - } -} -``` - ---- - -### 2b. Postinstall is empty string - -**Input:** -```json -{ - "scripts": { - "postinstall": "" - } -} -``` - -**Behavior:** ✅ Replaces empty string with socket-patch command - -**Output:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply" - } -} -``` - ---- - -### 2c. Postinstall is whitespace only - -**Input:** -```json -{ - "scripts": { - "postinstall": " \n\t " - } -} -``` - -**Behavior:** ✅ Treats as empty, adds socket-patch command - -**Output:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply" - } -} -``` - ---- - -### 3. Postinstall exists but missing socket-patch setup - -**Input:** -```json -{ - "scripts": { - "postinstall": "echo 'Running postinstall tasks'" - } -} -``` - -**Behavior:** ✅ Prepends socket-patch before existing script - -**Output:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply && echo 'Running postinstall tasks'" - } -} -``` - -**Rationale:** Socket-patch runs first to apply security patches before other setup tasks. Uses `&&` to ensure existing script only runs if patching succeeds. - ---- - -### 4a. 
socket-patch apply without npx - -**Input:** -```json -{ - "scripts": { - "postinstall": "socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Valid if socket-patch is installed as a dependency. The substring `'socket-patch apply'` is present. - ---- - -### 4b. npx socket-patch apply (without @socketsecurity/) - -**Input:** -```json -{ - "scripts": { - "postinstall": "npx socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Valid format. The substring `'socket-patch apply'` is present. - ---- - -### 4c. Canonical format: npx @socketsecurity/socket-patch apply - -**Input:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** This is the recommended canonical format. - ---- - -### 4d. pnpm socket-patch apply - -**Input:** -```json -{ - "scripts": { - "postinstall": "pnpm socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Valid format for pnpm users. The substring `'socket-patch apply'` is present. - ---- - -### 4e. yarn socket-patch apply - -**Input:** -```json -{ - "scripts": { - "postinstall": "yarn socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Valid format for yarn users. The substring `'socket-patch apply'` is present. - ---- - -### 4f. node_modules/.bin/socket-patch apply (direct path) - -**Input:** -```json -{ - "scripts": { - "postinstall": "node_modules/.bin/socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Valid format using direct path. The substring `'socket-patch apply'` is present. - ---- - -### 4g. socket apply (main Socket CLI - DIFFERENT command) - -**Input:** -```json -{ - "scripts": { - "postinstall": "socket apply" - } -} -``` - -**Behavior:** ⚠️ NOT recognized as configured, adds socket-patch - -**Output:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply && socket apply" - } -} -``` - -**Rationale:** `socket apply` is a DIFFERENT command from the main Socket CLI. The substring `'socket-patch apply'` is NOT present. Socket-patch should be added separately. - ---- - -### 4h. socket-patch list (wrong subcommand) - -**Input:** -```json -{ - "scripts": { - "postinstall": "socket-patch list" - } -} -``` - -**Behavior:** ⚠️ NOT recognized as configured, adds socket-patch apply - -**Output:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply && socket-patch list" - } -} -``` - -**Rationale:** `socket-patch list` is a different subcommand. The substring `'socket-patch apply'` is NOT present (missing "apply"). Socket-patch apply should be added. - ---- - -### 4i. socket-patch apply with flags - -**Input:** -```json -{ - "scripts": { - "postinstall": "npx @socketsecurity/socket-patch apply --silent" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Valid format with flags. The substring `'socket-patch apply'` is present. - ---- - -### 4j. socket-patch apply in middle of script chain - -**Input:** -```json -{ - "scripts": { - "postinstall": "echo start && socket-patch apply && echo done" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Socket-patch is already in the chain. The substring `'socket-patch apply'` is present. - ---- - -### 4k. 
socket-patch apply at end of chain - -**Input:** -```json -{ - "scripts": { - "postinstall": "npm run prepare && socket-patch apply" - } -} -``` - -**Behavior:** ✅ Recognized as configured, no changes - -**Rationale:** Socket-patch is already present. The substring `'socket-patch apply'` is present. - -**Note:** While this is recognized, it's not ideal since patches won't be applied before the prepare script runs. However, we don't modify it to avoid breaking existing setups. - ---- - -### 5. Postinstall with invalid data types - -#### 5a. Number instead of string - -**Input:** -```json -{ - "scripts": { - "postinstall": 123 - } -} -``` - -**Behavior:** ✅ Treated as not configured, adds socket-patch - -**Rationale:** Invalid type is coerced or ignored. Setup adds proper string command. - -#### 5b. Array instead of string - -**Input:** -```json -{ - "scripts": { - "postinstall": ["echo", "hello"] - } -} -``` - -**Behavior:** ✅ Treated as not configured, adds socket-patch - -**Rationale:** Invalid type. Setup adds proper string command. - -#### 5c. Object instead of string - -**Input:** -```json -{ - "scripts": { - "postinstall": { "command": "echo hello" } - } -} -``` - -**Behavior:** ✅ Treated as not configured, adds socket-patch - -**Rationale:** Invalid type. Setup adds proper string command. - ---- - -### 6. Malformed JSON - -**Input:** -``` -{ name: "test", invalid json } -``` - -**Behavior:** ❌ Throws error: "Invalid package.json: failed to parse JSON" - -**Rationale:** Cannot process malformed JSON. User must fix the JSON first. - ---- - -## Summary Table - -| Scenario | Contains `'socket-patch apply'`? | Behavior | -|----------|----------------------------------|----------| -| No scripts field | ❌ | Add scripts + postinstall | -| Scripts exists, no postinstall | ❌ | Add postinstall | -| Postinstall is null/undefined/empty | ❌ | Add socket-patch command | -| Postinstall has other command | ❌ | Prepend socket-patch | -| `socket-patch apply` | ✅ | Skip (already configured) | -| `npx socket-patch apply` | ✅ | Skip (already configured) | -| `npx @socketsecurity/socket-patch apply` | ✅ | Skip (already configured) | -| `pnpm/yarn socket-patch apply` | ✅ | Skip (already configured) | -| `node_modules/.bin/socket-patch apply` | ✅ | Skip (already configured) | -| `socket-patch apply --flags` | ✅ | Skip (already configured) | -| In script chain with `socket-patch apply` | ✅ | Skip (already configured) | -| `socket apply` (main CLI) | ❌ | Add socket-patch apply | -| `socket-patch list` (wrong subcommand) | ❌ | Add socket-patch apply | -| Invalid data types | ❌ | Add socket-patch command | -| Malformed JSON | N/A | Throw error | - -## Testing - -All edge cases are tested in: -- **Unit tests:** `submodules/socket-patch/src/package-json/detect.test.ts` -- **E2E tests:** `workspaces/api-v0/e2e-tests/tests/59_socket-patch-setup.js` - -Run tests: -```bash -# Unit tests -cd submodules/socket-patch -npm test - -# E2E tests -pnpm --filter @socketsecurity/api-v0 run test e2e-tests/tests/59_socket-patch-setup.js -``` diff --git a/README.md b/README.md index 5f41bdd..2243335 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,46 @@ # Socket Patch CLI -Apply security patches to npm dependencies without waiting for upstream fixes. +Apply security patches to npm and Python dependencies without waiting for upstream fixes. 
## Installation +### One-line install (recommended) + +```bash +curl -fsSL https://raw.githubusercontent.com/SocketDev/socket-patch/main/scripts/install.sh | sh +``` + +Detects your platform (macOS/Linux, x64/ARM64), downloads the latest binary, and installs to `/usr/local/bin` or `~/.local/bin`. Use `sudo sh` instead of `sh` if `/usr/local/bin` requires root. + +
+### Manual download
+
+Download a prebuilt binary from the [latest release](https://github.com/SocketDev/socket-patch/releases/latest):
+
+```bash
+# macOS (Apple Silicon)
+curl -fsSL https://github.com/SocketDev/socket-patch/releases/latest/download/socket-patch-aarch64-apple-darwin.tar.gz | tar xz
+
+# macOS (Intel)
+curl -fsSL https://github.com/SocketDev/socket-patch/releases/latest/download/socket-patch-x86_64-apple-darwin.tar.gz | tar xz
+
+# Linux (x86_64)
+curl -fsSL https://github.com/SocketDev/socket-patch/releases/latest/download/socket-patch-x86_64-unknown-linux-musl.tar.gz | tar xz
+
+# Linux (ARM64)
+curl -fsSL https://github.com/SocketDev/socket-patch/releases/latest/download/socket-patch-aarch64-unknown-linux-gnu.tar.gz | tar xz
+```
+
+Then move the binary onto your `PATH`:
+
+```bash
+sudo mv socket-patch /usr/local/bin/
+```
+
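+A quick sanity check that the binary resolves on your `PATH` (this assumes the release build exposes clap's conventional `--version` flag):
+
+```bash
+command -v socket-patch
+socket-patch --version
+```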
+
+### npm
+
 ```bash
 npx @socketsecurity/socket-patch
 ```
@@ -14,83 +51,237 @@ Or install globally:
 npm install -g @socketsecurity/socket-patch
 ```
 
+### pip
+
+```bash
+pip install socket-patch
+```
+
+### Cargo
+
+```bash
+cargo install socket-patch-cli
+```
+
+By default, this builds with npm and PyPI support. For additional ecosystems:
+
+```bash
+cargo install socket-patch-cli --features cargo,golang,maven,composer,nuget
+```
+
+## Quick Start
+
+You can pass a patch UUID directly to `socket-patch` as a shortcut:
+
+```bash
+socket-patch 550e8400-e29b-41d4-a716-446655440000
+# equivalent to: socket-patch get 550e8400-e29b-41d4-a716-446655440000
+```
+
 ## Commands
 
+All commands support `--json` for structured JSON output and `--cwd <dir>` to set the working directory (default: `.`). Every JSON response includes a `"status"` field (`"success"`, `"error"`, `"no_manifest"`, etc.) for reliable programmatic consumption.
+
+### `get`
+
+Get security patches from the Socket API and apply them. Accepts a UUID, CVE ID, GHSA ID, PURL, or package name. The identifier type is auto-detected but can be forced with a flag.
+
+Alias: `download`
+
+**Usage:**
+```bash
+socket-patch get <identifier> [options]
+```
+
+**Options:**
+| Flag | Description |
+|------|-------------|
+| `--org <slug>` | Organization slug (required when using `SOCKET_API_TOKEN`) |
+| `--id` | Force identifier to be treated as a UUID |
+| `--cve` | Force identifier to be treated as a CVE ID |
+| `--ghsa` | Force identifier to be treated as a GHSA ID |
+| `-p, --package` | Force identifier to be treated as a package name |
+| `-y, --yes` | Skip confirmation prompt for multiple patches |
+| `--save-only` | Download patch without applying it (alias: `--no-apply`) |
+| `--one-off` | Apply patch immediately without saving to `.socket` folder |
+| `-g, --global` | Apply to globally installed packages |
+| `--global-prefix <path>` | Custom path to global `node_modules` |
+| `--json` | Output results as JSON |
+| `--api-token <token>` | Socket API token (overrides `SOCKET_API_TOKEN`) |
+| `--api-url <url>` | Socket API URL (overrides `SOCKET_API_URL`) |
+| `--cwd <dir>` | Working directory (default: `.`) |
+
+**Examples:**
+```bash
+# Get patch by UUID
+socket-patch get 550e8400-e29b-41d4-a716-446655440000
+
+# Get patch by CVE
+socket-patch get CVE-2024-12345
+
+# Get patch by GHSA
+socket-patch get GHSA-xxxx-yyyy-zzzz
+
+# Get patch by package name (fuzzy matches installed packages)
+socket-patch get lodash
+
+# Download only, don't apply
+socket-patch get CVE-2024-12345 --save-only
+
+# Apply to global packages
+socket-patch get lodash -g
+
+# JSON output for scripting
+socket-patch get CVE-2024-12345 --json -y
+```
+
+### `scan`
+
+Scan installed packages for available security patches.
+
+**Usage:**
+```bash
+socket-patch scan [options]
+```
+
+**Options:**
+| Flag | Description |
+|------|-------------|
+| `--org <slug>` | Organization slug |
+| `--json` | Output results as JSON |
+| `--ecosystems <list>` | Restrict to specific ecosystems (comma-separated, e.g. `npm,pypi`) |
+| `-g, --global` | Scan globally installed packages |
+| `--global-prefix <path>` | Custom path to global `node_modules` |
+| `--batch-size <n>` | Packages per API request (default: `100`) |
+| `--api-token <token>` | Socket API token (overrides `SOCKET_API_TOKEN`) |
+| `--api-url <url>` | Socket API URL (overrides `SOCKET_API_URL`) |
+| `--cwd <dir>` | Working directory (default: `.`) |
+
+**Examples:**
+```bash
+# Scan local project
+socket-patch scan
+
+# Scan with JSON output
+socket-patch scan --json
+
+# Scan only npm packages
+socket-patch scan --ecosystems npm
+
+# Scan global packages
+socket-patch scan -g
+```
+
 ### `apply`
 
-Apply security patches from manifest.
+Apply security patches from the local manifest.
 
 **Usage:**
 ```bash
-npx @socketsecurity/socket-patch apply [options]
+socket-patch apply [options]
 ```
 
 **Options:**
-- `--cwd` - Working directory (default: current directory)
-- `-d, --dry-run` - Verify patches without modifying files
-- `-s, --silent` - Only output errors
-- `-m, --manifest-path` - Path to manifest (default: `.socket/manifest.json`)
+| Flag | Description |
+|------|-------------|
+| `-d, --dry-run` | Verify patches without modifying files |
+| `-s, --silent` | Only output errors |
+| `-f, --force` | Skip pre-application hash verification (apply even if package version differs) |
+| `-m, --manifest-path <path>` | Path to manifest (default: `.socket/manifest.json`) |
+| `--offline` | Do not download missing blobs; fail if any are missing |
+| `-g, --global` | Apply to globally installed packages |
+| `--global-prefix <path>` | Custom path to global `node_modules` |
+| `--ecosystems <list>` | Restrict to specific ecosystems (comma-separated, e.g. `npm,pypi`) |
+| `--json` | Output results as JSON |
+| `-v, --verbose` | Show detailed per-file verification information |
+| `--cwd <dir>` | Working directory (default: `.`) |
 
 **Examples:**
 ```bash
 # Apply patches
-npx @socketsecurity/socket-patch apply
+socket-patch apply
 
 # Dry run
-npx @socketsecurity/socket-patch apply --dry-run
+socket-patch apply --dry-run
 
-# Custom manifest
-npx @socketsecurity/socket-patch apply -m /path/to/manifest.json
+# Apply only npm patches
+socket-patch apply --ecosystems npm
+
+# Apply in offline mode
+socket-patch apply --offline
+
+# JSON output for CI/CD
+socket-patch apply --json
 ```
 
-### `download`
+### `rollback`
 
-Download patch from Socket API.
+Roll back patches to restore original files. If no identifier is given, all patches are rolled back.
 
 **Usage:**
 ```bash
-npx @socketsecurity/socket-patch download --uuid <uuid> --org <org> [options]
+socket-patch rollback [identifier] [options]
 ```
 
 **Options:**
-- `--uuid` - Patch UUID (required)
-- `--org` - Organization slug (required)
-- `--api-token` - API token (or use `SOCKET_API_TOKEN` env var)
-- `--api-url` - API URL (default: `https://api.socket.dev`)
-- `--cwd` - Working directory
-- `-m, --manifest-path` - Path to manifest
+| Flag | Description |
+|------|-------------|
+| `-d, --dry-run` | Verify rollback without modifying files |
+| `-s, --silent` | Only output errors |
+| `-m, --manifest-path <path>` | Path to manifest (default: `.socket/manifest.json`) |
+| `--offline` | Do not download missing blobs; fail if any are missing |
+| `-g, --global` | Roll back globally installed packages |
+| `--global-prefix <path>` | Custom path to global `node_modules` |
+| `--one-off` | Roll back by fetching original files from the API (no manifest required) |
+| `--ecosystems <list>` | Restrict to specific ecosystems (comma-separated) |
+| `--json` | Output results as JSON |
+| `-v, --verbose` | Show detailed per-file verification information |
+| `--org <slug>` | Organization slug |
+| `--api-token <token>` | Socket API token (overrides `SOCKET_API_TOKEN`) |
+| `--api-url <url>` | Socket API URL (overrides `SOCKET_API_URL`) |
+| `--cwd <dir>` | Working directory (default: `.`) |
 
 **Examples:**
 ```bash
-# Download patch
-export SOCKET_API_TOKEN="your-token"
-npx @socketsecurity/socket-patch download --uuid "550e8400-e29b-41d4-a716-446655440000" --org "my-org"
+# Roll back all patches
+socket-patch rollback
+
+# Roll back a specific package
+socket-patch rollback "pkg:npm/lodash@4.17.20"
+
+# Roll back by UUID
+socket-patch rollback 550e8400-e29b-41d4-a716-446655440000
 
-# With explicit token
-npx @socketsecurity/socket-patch download --uuid "..." --org "my-org" --api-token "token"
+# Dry run
+socket-patch rollback --dry-run
+
+# JSON output
+socket-patch rollback --json
 ```
 
 ### `list`
 
-List patches in manifest.
+List all patches in the local manifest.
 
 **Usage:**
 ```bash
-npx @socketsecurity/socket-patch list [options]
+socket-patch list [options]
 ```
 
 **Options:**
-- `--cwd` - Working directory
-- `-m, --manifest-path` - Path to manifest
-- `--json` - Output as JSON
+| Flag | Description |
+|------|-------------|
+| `--json` | Output as JSON |
+| `-m, --manifest-path <path>` | Path to manifest (default: `.socket/manifest.json`) |
+| `--cwd <dir>` | Working directory (default: `.`) |
 
 **Examples:**
 ```bash
 # List patches
-npx @socketsecurity/socket-patch list
+socket-patch list
 
 # JSON output
-npx @socketsecurity/socket-patch list --json
+socket-patch list --json
 ```
 
 **Sample Output:**
@@ -111,29 +302,133 @@ Package: pkg:npm/lodash@4.17.20
 
 ### `remove`
 
-Remove patch from manifest.
+Remove a patch from the manifest (rolls back files first by default).
 
 **Usage:**
 ```bash
-npx @socketsecurity/socket-patch remove <identifier> [options]
+socket-patch remove <identifier> [options]
 ```
 
 **Arguments:**
 - `identifier` - Package PURL (e.g., `pkg:npm/package@version`) or patch UUID
 
 **Options:**
-- `--cwd` - Working directory
-- `-m, --manifest-path` - Path to manifest
+| Flag | Description |
+|------|-------------|
+| `--skip-rollback` | Only update manifest, do not restore original files |
+| `-g, --global` | Remove from globally installed packages |
+| `--global-prefix <path>` | Custom path to global `node_modules` |
+| `--json` | Output results as JSON |
+| `-m, --manifest-path <path>` | Path to manifest (default: `.socket/manifest.json`) |
+| `--cwd <dir>` | Working directory (default: `.`) |
 
 **Examples:**
 ```bash
 # Remove by PURL
-npx @socketsecurity/socket-patch remove "pkg:npm/lodash@4.17.20"
+socket-patch remove "pkg:npm/lodash@4.17.20"
 
 # Remove by UUID
-npx @socketsecurity/socket-patch remove "550e8400-e29b-41d4-a716-446655440000"
+socket-patch remove 550e8400-e29b-41d4-a716-446655440000
+
+# Remove without rolling back files
+socket-patch remove "pkg:npm/lodash@4.17.20" --skip-rollback
+
+# JSON output
+socket-patch remove "pkg:npm/lodash@4.17.20" --json
+```
+
+### `setup`
+
+Configure `package.json` postinstall scripts to automatically apply patches after `npm install`.
+
+**Usage:**
+```bash
+socket-patch setup [options]
+```
+
+**Options:**
+| Flag | Description |
+|------|-------------|
+| `-d, --dry-run` | Preview changes without modifying files |
+| `-y, --yes` | Skip confirmation prompt |
+| `--json` | Output results as JSON |
+| `--cwd <dir>` | Working directory (default: `.`) |
+
+**Examples:**
+```bash
+# Interactive setup
+socket-patch setup
+
+# Non-interactive
+socket-patch setup -y
+
+# Preview changes
+socket-patch setup --dry-run
+
+# JSON output for scripting
+socket-patch setup --json -y
+```
+
+### `repair`
+
+Download missing blobs and clean up unused blobs.
+
+Alias: `gc`
+
+**Usage:**
+```bash
+socket-patch repair [options]
+```
+
+**Options:**
+| Flag | Description |
+|------|-------------|
+| `-d, --dry-run` | Show what would be done without doing it |
+| `--offline` | Skip network operations (cleanup only) |
+| `--download-only` | Only download missing blobs, do not clean up |
+| `--json` | Output results as JSON |
+| `-m, --manifest-path <path>` | Path to manifest (default: `.socket/manifest.json`) |
+| `--cwd <dir>` | Working directory (default: `.`) |
+
+**Examples:**
+```bash
+# Repair (download missing + clean up unused)
+socket-patch repair
+
+# Cleanup only, no downloads
+socket-patch repair --offline
+
+# Download missing blobs only
+socket-patch repair --download-only
+
+# JSON output
+socket-patch repair --json
+```
+
+## Scripting & CI/CD
+
+All commands support `--json` for machine-readable output. JSON responses always include a `"status"` field for easy error detection:
+
+```bash
+# Check for available patches in CI
+result=$(socket-patch scan --json --ecosystems npm)
+patches=$(echo "$result" | jq '.totalPatches')
+
+# Apply patches and check result
+socket-patch apply --json | jq '.status'
+# "success", "partial_failure", "no_manifest", or "error"
+```
+
+When stdin is not a TTY (e.g., in CI pipelines), interactive prompts auto-proceed instead of blocking. Progress indicators and ANSI colors are automatically suppressed when output is piped.
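+
+For a CI gate, the `"status"` field can drive the build result directly. A minimal sketch (it assumes `jq` is available on the runner; the status values are the ones listed above):
+
+```bash
+#!/usr/bin/env sh
+set -eu
+
+# Apply any patches recorded in the manifest; tolerate projects that
+# have no .socket folder, and fail the build on anything else.
+status=$(socket-patch apply --json | jq -r '.status')
+case "$status" in
+  success|no_manifest) echo "socket-patch: ok ($status)" ;;
+  *) echo "socket-patch: apply failed ($status)" >&2; exit 1 ;;
+esac
+```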
+
+## Environment Variables
+
+| Variable | Description |
+|----------|-------------|
+| `SOCKET_API_TOKEN` | API authentication token |
+| `SOCKET_ORG_SLUG` | Default organization slug |
+| `SOCKET_API_URL` | API base URL (default: `https://api.socket.dev`) |
+
 ## Manifest Format
 
 Downloaded patches are stored in `.socket/manifest.json`:
 
@@ -157,10 +452,22 @@ Downloaded patches are stored in `.socket/manifest.json`:
         "severity": "high",
         "description": "Detailed description"
       }
-    }
+    },
+    "description": "Patch description",
+    "license": "MIT",
+    "tier": "free"
   }
 }
 }
 ```
 
-Patched file contents are in `.socket/blobs/` (named by git SHA256 hash).
+Patched file contents are stored in `.socket/blobs/`, named by their git-style SHA-256 hash.
+
+## Supported Platforms
+
+| Platform | Architecture |
+|----------|-------------|
+| macOS | ARM64 (Apple Silicon), x86_64 (Intel) |
+| Linux | x86_64, ARM64, ARMv7, i686 |
+| Windows | x86_64, ARM64, i686 |
+| Android | ARM64 |
diff --git a/biome.json b/biome.json
deleted file mode 100644
index f6b3f5a..0000000
--- a/biome.json
+++ /dev/null
@@ -1,47 +0,0 @@
-{
-  "$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
-  "files": {
-    "includes": ["**", "!.git", "!dist", "!node_modules"]
-  },
-  "formatter": {
-    "enabled": true,
-    "formatWithErrors": false,
-    "indentStyle": "space",
-    "indentWidth": 2,
-    "lineEnding": "lf",
-    "lineWidth": 80
-  },
-  "linter": {
-    "enabled": false,
-    "rules": {
-      "style": {
-        "noParameterAssign": "error",
-        "useAsConstAssertion": "error",
-        "useDefaultParameterLast": "error",
-        "useEnumInitializers": "error",
-        "useSelfClosingElements": "error",
-        "useSingleVarDeclarator": "error",
-        "noUnusedTemplateLiteral": "error",
-        "useNumberNamespace": "error",
-        "noInferrableTypes": "error",
-        "noUselessElse": "error"
-      }
-    }
-  },
-  "javascript": {
-    "formatter": {
-      "arrowParentheses": "asNeeded",
-      "semicolons": "asNeeded",
-      "quoteStyle": "single",
-      "jsxQuoteStyle": "single",
-      "trailingCommas": "all"
-    }
-  },
-  "json": {
-    "formatter": {
-      "trailingCommas": "none",
-      "indentStyle": "space",
-      "indentWidth": 2
-    }
-  }
-}
diff --git a/crates/socket-patch-cli/Cargo.toml b/crates/socket-patch-cli/Cargo.toml
new file mode 100644
index 0000000..8911c79
--- /dev/null
+++ b/crates/socket-patch-cli/Cargo.toml
@@ -0,0 +1,36 @@
+[package]
+name = "socket-patch-cli"
+description = "CLI binary for socket-patch: apply, rollback, get, scan security patches"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+repository.workspace = true
+readme = "README.md"
+
+[[bin]]
+name = "socket-patch"
+path = "src/main.rs"
+
+[dependencies]
+socket-patch-core = { workspace = true }
+clap = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+tokio = { workspace = true }
+dialoguer = { workspace = true }
+indicatif = { workspace = true }
+uuid = { workspace = true }
+regex = { workspace = true }
+tempfile = { workspace = true }
+
+[features]
+default = []
+cargo = ["socket-patch-core/cargo"]
+golang = ["socket-patch-core/golang"]
+maven = ["socket-patch-core/maven"]
+composer = ["socket-patch-core/composer"]
+nuget = ["socket-patch-core/nuget"]
+
+[dev-dependencies]
+sha2 = { workspace = true }
+hex = { workspace = true }
diff --git a/crates/socket-patch-cli/src/commands/apply.rs b/crates/socket-patch-cli/src/commands/apply.rs
new file mode 100644
index 0000000..c065db4
--- /dev/null
+++ b/crates/socket-patch-cli/src/commands/apply.rs
@@ -0,0 +1,486 @@
+use clap::Args;
+use socket_patch_core::api::blob_fetcher::{
+    fetch_missing_blobs, format_fetch_result, get_missing_blobs,
+};
+use socket_patch_core::api::client::get_api_client_from_env;
+use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH;
+use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem};
+use socket_patch_core::manifest::operations::read_manifest;
+use socket_patch_core::patch::apply::{apply_package_patch, verify_file_patch, ApplyResult, VerifyStatus};
+use socket_patch_core::utils::cleanup_blobs::{cleanup_unused_blobs, format_cleanup_result};
+use socket_patch_core::utils::purl::strip_purl_qualifiers;
+use socket_patch_core::utils::telemetry::{track_patch_applied, track_patch_apply_failed};
+use std::collections::{HashMap, HashSet};
+use std::path::{Path, PathBuf};
+
+use crate::ecosystem_dispatch::{find_packages_for_purls, partition_purls};
+
+#[derive(Args)]
+pub struct ApplyArgs {
+    /// Working directory
+    #[arg(long, default_value = ".")]
+    pub cwd: PathBuf,
+
+    /// Verify patches can be applied without modifying files
+    #[arg(short = 'd', long = "dry-run", default_value_t = false)]
+    pub dry_run: bool,
+
+    /// Only output errors
+    #[arg(short = 's', long, default_value_t = false)]
+    pub silent: bool,
+
+    /// Path to patch manifest file
+    #[arg(short = 'm', long = "manifest-path", default_value = DEFAULT_PATCH_MANIFEST_PATH)]
+    pub manifest_path: String,
+
+    /// Do not download missing blobs, fail if any are missing
+    #[arg(long, default_value_t = false)]
+    pub offline: bool,
+
+    /// Apply patches to globally installed npm packages
+    #[arg(short = 'g', long, default_value_t = false)]
+    pub global: bool,
+
+    /// Custom path to global node_modules
+    #[arg(long = "global-prefix")]
+    pub global_prefix: Option<PathBuf>,
+
+    /// Restrict patching to specific ecosystems
+    #[arg(long, value_delimiter = ',')]
+    pub ecosystems: Option<Vec<String>>,
+
+    /// Skip pre-application hash verification (apply even if package version differs)
+    #[arg(short = 'f', long, default_value_t = false)]
+    pub force: bool,
+
+    /// Output results as JSON
+    #[arg(long, default_value_t = false)]
+    pub json: bool,
+
+    /// Show detailed per-file verification information
+    #[arg(short = 'v', long, default_value_t = false)]
+    pub verbose: bool,
+}
+
+fn verify_status_str(status: &VerifyStatus) -> &'static str {
+    match status {
+        VerifyStatus::Ready => "ready",
+        VerifyStatus::AlreadyPatched => "already_patched",
+        VerifyStatus::HashMismatch => "hash_mismatch",
+        VerifyStatus::NotFound => "not_found",
+    }
+}
+
+fn result_to_json(result: &ApplyResult) -> serde_json::Value {
+    serde_json::json!({
+        "purl": result.package_key,
+        "path": result.package_path,
+        "success": result.success,
+        "error": result.error,
+        "filesPatched": result.files_patched,
+        "filesVerified": result.files_verified.iter().map(|f| {
+            serde_json::json!({
+                "file": f.file,
+                "status": verify_status_str(&f.status),
+                "message": f.message,
+                "currentHash": f.current_hash,
+                "expectedHash": f.expected_hash,
+                "targetHash": f.target_hash,
+            })
+        }).collect::<Vec<_>>(),
+    })
+}
+
+pub async fn run(args: ApplyArgs) -> i32 {
+    let (telemetry_client, _) = get_api_client_from_env(None).await;
+    let api_token = telemetry_client.api_token().cloned();
+    let org_slug = telemetry_client.org_slug().cloned();
+
+    let manifest_path = if Path::new(&args.manifest_path).is_absolute() {
+        PathBuf::from(&args.manifest_path)
+    } else {
+        args.cwd.join(&args.manifest_path)
+    };
+
+    // Check if manifest exists - exit successfully if no .socket folder is set up
+    if tokio::fs::metadata(&manifest_path).await.is_err() {
+        if args.json {
+            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
+                "status": "no_manifest",
+                "patchesApplied": 0,
+                "alreadyPatched": 0,
+                "failed": 0,
+                "dryRun": args.dry_run,
+                "results": [],
+            })).unwrap());
+        } else if !args.silent {
+            println!("No .socket folder found, skipping patch application.");
+        }
+        return 0;
+    }
+
+    match apply_patches_inner(&args, &manifest_path).await {
+        Ok((success, results, unmatched)) => {
+            let patched_count = results
+                .iter()
+                .filter(|r| r.success && !r.files_patched.is_empty())
+                .count();
+            let already_patched_count = results
+                .iter()
+                .filter(|r| {
+                    r.files_verified
+                        .iter()
+                        .all(|f| f.status == VerifyStatus::AlreadyPatched)
+                })
+                .count();
+            let failed_count = results.iter().filter(|r| !r.success).count();
+
+            if args.json {
+                println!("{}", serde_json::to_string_pretty(&serde_json::json!({
+                    "status": if success { "success" } else { "partial_failure" },
+                    "patchesApplied": patched_count,
+                    "alreadyPatched": already_patched_count,
+                    "failed": failed_count,
+                    "unmatchedPatches": unmatched.len(),
+                    "unmatchedPurls": unmatched,
+                    "dryRun": args.dry_run,
+                    "results": results.iter().map(result_to_json).collect::<Vec<_>>(),
+                })).unwrap());
+            } else if !args.silent && !results.is_empty() {
+                let patched: Vec<_> = results.iter().filter(|r| r.success).collect();
+                let already_patched: Vec<_> = results
+                    .iter()
+                    .filter(|r| {
+                        r.files_verified
+                            .iter()
+                            .all(|f| f.status == VerifyStatus::AlreadyPatched)
+                    })
+                    .collect();
+
+                if args.dry_run {
+                    println!("\nPatch verification complete:");
+                    println!("  {} package(s) can be patched", patched.len());
+                    if !already_patched.is_empty() {
+                        println!("  {} package(s) already patched", already_patched.len());
+                    }
+                } else {
+                    println!("\nPatched packages:");
+                    for result in &patched {
+                        if !result.files_patched.is_empty() {
+                            println!("  {}", result.package_key);
+                        } else if result.files_verified.iter().all(|f| {
+                            f.status == VerifyStatus::AlreadyPatched
+                        }) {
+                            println!("  {} (already patched)", result.package_key);
+                        }
+                    }
+                }
+
+                if args.verbose {
+                    println!("\nDetailed verification:");
+                    for result in &results {
+                        println!("  {}:", result.package_key);
+                        for f in &result.files_verified {
+                            let status_str = match f.status {
+                                VerifyStatus::Ready => "ready",
+                                VerifyStatus::AlreadyPatched => "already patched",
+                                VerifyStatus::HashMismatch => "hash mismatch",
+                                VerifyStatus::NotFound => "not found",
+                            };
+                            println!("    {} [{}]", f.file, status_str);
+                            if let Some(ref msg) = f.message {
+                                println!("      message: {msg}");
+                            }
+                            if let Some(ref h) = f.current_hash {
+                                println!("      current: {h}");
+                            }
+                            if let Some(ref h) = f.expected_hash {
+                                println!("      expected: {h}");
+                            }
+                            if let Some(ref h) = f.target_hash {
+                                println!("      target: {h}");
+                            }
+                        }
+                    }
+                }
+            }
+
+            // Track telemetry
+            if success {
+                track_patch_applied(patched_count, args.dry_run, api_token.as_deref(), org_slug.as_deref()).await;
+            } else {
+                track_patch_apply_failed("One or more patches failed to apply", args.dry_run, api_token.as_deref(), org_slug.as_deref()).await;
+            }
+
+            if success { 0 } else { 1 }
+        }
+        Err(e) => {
+            track_patch_apply_failed(&e, args.dry_run, api_token.as_deref(), org_slug.as_deref()).await;
+            if args.json {
+                println!("{}", serde_json::to_string_pretty(&serde_json::json!({
+                    "status": "error",
+                    "error": e,
+                    "patchesApplied": 0,
+                    "alreadyPatched": 0,
+                    "failed": 0,
+                    "dryRun": args.dry_run,
+                    "results": [],
+                })).unwrap());
+            } else if !args.silent {
+                eprintln!("Error: {e}");
+            }
+            1
+        }
+    }
+}
+
+async fn apply_patches_inner(
+    args: &ApplyArgs,
+    manifest_path: &Path,
+) -> Result<(bool, Vec<ApplyResult>, Vec<String>), String> {
+    let manifest = read_manifest(manifest_path)
+        .await
+        .map_err(|e| e.to_string())?
+        .ok_or_else(|| "Invalid manifest".to_string())?;
+
+    let socket_dir = manifest_path.parent().unwrap();
+    let blobs_path = socket_dir.join("blobs");
+    tokio::fs::create_dir_all(&blobs_path)
+        .await
+        .map_err(|e| e.to_string())?;
+
+    // Check for and download missing blobs
+    let missing_blobs = get_missing_blobs(&manifest, &blobs_path).await;
+    if !missing_blobs.is_empty() {
+        if args.offline {
+            if !args.silent && !args.json {
+                eprintln!(
+                    "Error: {} blob(s) are missing and --offline mode is enabled.",
+                    missing_blobs.len()
+                );
+                eprintln!("Run \"socket-patch repair\" to download missing blobs.");
+            }
+            return Ok((false, Vec::new(), Vec::new()));
+        }
+
+        if !args.silent && !args.json {
+            println!("Downloading {} missing blob(s)...", missing_blobs.len());
+        }
+
+        let (client, _) = get_api_client_from_env(None).await;
+        let fetch_result = fetch_missing_blobs(&manifest, &blobs_path, &client, None).await;
+
+        if !args.silent && !args.json {
+            println!("{}", format_fetch_result(&fetch_result));
+        }
+
+        if fetch_result.failed > 0 {
+            if !args.silent && !args.json {
+                eprintln!("Some blobs could not be downloaded. Cannot apply patches.");
+            }
+            return Ok((false, Vec::new(), Vec::new()));
+        }
+    }
+
+    // Partition manifest PURLs by ecosystem
+    let manifest_purls: Vec<String> = manifest.patches.keys().cloned().collect();
+    let partitioned =
+        partition_purls(&manifest_purls, args.ecosystems.as_deref());
+
+    let target_manifest_purls: HashSet<String> = partitioned
+        .values()
+        .flat_map(|purls| purls.iter().cloned())
+        .collect();
+
+    let crawler_options = CrawlerOptions {
+        cwd: args.cwd.clone(),
+        global: args.global,
+        global_prefix: args.global_prefix.clone(),
+        batch_size: 100,
+    };
+
+    let all_packages =
+        find_packages_for_purls(&partitioned, &crawler_options, args.silent || args.json).await;
+
+    let has_any_purls = !partitioned.is_empty();
+
+    if all_packages.is_empty() && !has_any_purls {
+        if !args.silent && !args.json {
+            if args.global || args.global_prefix.is_some() {
+                eprintln!("No global packages found");
+            } else {
+                eprintln!("No package directories found");
+            }
+        }
+        return Ok((false, Vec::new(), Vec::new()));
+    }
+
+    if all_packages.is_empty() {
+        if !args.silent && !args.json {
+            eprintln!("Warning: No packages found that match available patches");
+            eprintln!(
+                "  {} targeted manifest patch(es) were in scope, but no matching packages were found on disk.",
+                target_manifest_purls.len()
+            );
+            eprintln!("  Check that packages are installed and --cwd points to the right directory.");
+        }
+        let unmatched: Vec<String> = target_manifest_purls.iter().cloned().collect();
+        return Ok((false, Vec::new(), unmatched));
+    }
+
+    // Apply patches
+    let mut results: Vec<ApplyResult> = Vec::new();
+    let mut has_errors = false;
+
+    // Group pypi PURLs by base (for variant matching with qualifiers)
+    let mut pypi_qualified_groups: HashMap<String, Vec<String>> = HashMap::new();
+    if let Some(pypi_purls) = partitioned.get(&Ecosystem::Pypi) {
+        for purl in pypi_purls {
+            let base = strip_purl_qualifiers(purl).to_string();
+            pypi_qualified_groups
+                .entry(base)
+                .or_default()
+                .push(purl.clone());
+        }
+    }
+
+    let mut applied_base_purls: HashSet<String> = HashSet::new();
+    let mut matched_manifest_purls: HashSet<String> = HashSet::new();
+
+    for (purl, pkg_path) in &all_packages {
+        if Ecosystem::from_purl(purl) == Some(Ecosystem::Pypi) {
+            let base_purl = strip_purl_qualifiers(purl).to_string();
+            if applied_base_purls.contains(&base_purl) {
+                continue;
+            }
+
+            let variants = pypi_qualified_groups
+                .get(&base_purl)
+                .cloned()
+                .unwrap_or_else(|| vec![base_purl.clone()]);
+            let mut applied = false;
+
+            for variant_purl in &variants {
+                let patch = match manifest.patches.get(variant_purl) {
+                    Some(p) => p,
+                    None => continue,
+                };
+
+                // Check first file hash match (skip when --force)
+                if !args.force {
+                    if let Some((file_name, file_info)) = patch.files.iter().next() {
+                        let verify = verify_file_patch(pkg_path, file_name, file_info).await;
+                        if verify.status == VerifyStatus::HashMismatch {
+                            continue;
+                        }
+                    }
+                }
+
+                let result = apply_package_patch(
+                    variant_purl,
+                    pkg_path,
+                    &patch.files,
+                    &blobs_path,
+                    args.dry_run,
+                    args.force,
+                )
+                .await;
+
+                if result.success {
+                    applied = true;
+                    applied_base_purls.insert(base_purl.clone());
+                    results.push(result);
+                    matched_manifest_purls.insert(variant_purl.clone());
+                    break;
+                } else {
+                    results.push(result);
+                }
+            }
+
+            if !applied {
+                has_errors = true;
+                if !args.silent && !args.json {
+                    eprintln!("Failed to patch {base_purl}: no matching variant found");
+                }
+            }
+        } else {
+            // npm PURLs: direct lookup
+            let patch = match manifest.patches.get(purl) {
+                Some(p) => p,
+                None => continue,
+            };
+
+            let result = apply_package_patch(
+                purl,
+                pkg_path,
+                &patch.files,
+                &blobs_path,
+                args.dry_run,
+                args.force,
+            )
+            .await;
+
+            if !result.success {
+                has_errors = true;
+                if !args.silent && !args.json {
+                    eprintln!(
+                        "Failed to patch {}: {}",
+                        purl,
+                        result.error.as_deref().unwrap_or("unknown error")
+                    );
+                }
+            }
+            results.push(result);
+            matched_manifest_purls.insert(purl.clone());
+        }
+    }
+
+    // Check if targeted manifest entries had no matches
+    let unmatched: Vec<String> = target_manifest_purls
+        .iter()
+        .filter(|p| !matched_manifest_purls.contains(*p))
+        .cloned()
+        .collect();
+
+    if !unmatched.is_empty() && !args.silent && !args.json {
+        eprintln!("\nWarning: {} manifest patch(es) had no matching installed package:", unmatched.len());
+        for purl in &unmatched {
+            eprintln!("  - {}", purl);
+        }
+    }
+
+    if !target_manifest_purls.is_empty() && matched_manifest_purls.is_empty() && !all_packages.is_empty() {
+        if !args.silent && !args.json {
+            eprintln!("Warning: None of the targeted manifest patches matched installed packages.");
+        }
+        has_errors = true;
+    }
+
+    // Post-apply summary
+    if !args.silent && !args.json {
+        let applied_count = results.iter().filter(|r| r.success && !r.files_patched.is_empty()).count();
+        let already_count = results.iter().filter(|r| {
+            r.files_verified.iter().all(|f| f.status == VerifyStatus::AlreadyPatched)
+        }).count();
+        println!(
+            "\nSummary: {}/{} targeted patches applied, {} already patched, {} not found on disk",
+            applied_count,
+            target_manifest_purls.len(),
+            already_count,
+            unmatched.len()
+        );
+    }
+
+    // Clean up unused blobs
+    if !args.silent && !args.json {
+        if let Ok(cleanup_result) = cleanup_unused_blobs(&manifest, &blobs_path, args.dry_run).await {
+            if cleanup_result.blobs_removed > 0 {
+                println!("\n{}", format_cleanup_result(&cleanup_result, args.dry_run));
+            }
+        }
+    }
+
+    Ok((!has_errors, results, unmatched))
+}
diff --git a/crates/socket-patch-cli/src/commands/get.rs b/crates/socket-patch-cli/src/commands/get.rs
new file mode 100644
index 0000000..624b454
--- /dev/null
+++ b/crates/socket-patch-cli/src/commands/get.rs
@@ -0,0 +1,1259 @@
regex::Regex; +use socket_patch_core::api::client::get_api_client_from_env; +use socket_patch_core::api::types::{PatchSearchResult, SearchResponse}; +use socket_patch_core::crawlers::CrawlerOptions; +use socket_patch_core::manifest::operations::{read_manifest, write_manifest}; +use socket_patch_core::manifest::schema::{ + PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo, +}; +use socket_patch_core::utils::fuzzy_match::fuzzy_match_packages; +use socket_patch_core::utils::purl::is_purl; +use std::collections::HashMap; +use std::fmt; +use std::path::PathBuf; + +use crate::ecosystem_dispatch::crawl_all_ecosystems; +use crate::output::{confirm, select_one, SelectError}; + +#[derive(Args)] +pub struct GetArgs { + /// Patch identifier (UUID, CVE ID, GHSA ID, PURL, or package name) + pub identifier: String, + + /// Organization slug + #[arg(long)] + pub org: Option, + + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Force identifier to be treated as a patch UUID + #[arg(long, default_value_t = false)] + pub id: bool, + + /// Force identifier to be treated as a CVE ID + #[arg(long, default_value_t = false)] + pub cve: bool, + + /// Force identifier to be treated as a GHSA ID + #[arg(long, default_value_t = false)] + pub ghsa: bool, + + /// Force identifier to be treated as a package name + #[arg(short = 'p', long = "package", default_value_t = false)] + pub package: bool, + + /// Skip confirmation prompt for multiple patches + #[arg(short = 'y', long, default_value_t = false)] + pub yes: bool, + + /// Socket API URL (overrides SOCKET_API_URL env var) + #[arg(long = "api-url")] + pub api_url: Option, + + /// Socket API token (overrides SOCKET_API_TOKEN env var) + #[arg(long = "api-token")] + pub api_token: Option, + + /// Download patch without applying it + #[arg(long = "save-only", alias = "no-apply", default_value_t = false)] + pub save_only: bool, + + /// Apply patch to globally installed npm packages + #[arg(short = 'g', long, default_value_t = false)] + pub global: bool, + + /// Custom path to global node_modules + #[arg(long = "global-prefix")] + pub global_prefix: Option, + + /// Apply patch immediately without saving to .socket folder + #[arg(long = "one-off", default_value_t = false)] + pub one_off: bool, + + /// Output results as JSON + #[arg(long, default_value_t = false)] + pub json: bool, +} + +#[derive(Debug, PartialEq)] +enum IdentifierType { + Uuid, + Cve, + Ghsa, + Purl, + Package, +} + +impl fmt::Display for IdentifierType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + IdentifierType::Uuid => write!(f, "UUID"), + IdentifierType::Cve => write!(f, "CVE"), + IdentifierType::Ghsa => write!(f, "GHSA"), + IdentifierType::Purl => write!(f, "PURL"), + IdentifierType::Package => write!(f, "package name"), + } + } +} + +fn detect_identifier_type(identifier: &str) -> Option { + let uuid_re = Regex::new(r"(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$").unwrap(); + let cve_re = Regex::new(r"(?i)^CVE-\d{4}-\d+$").unwrap(); + let ghsa_re = Regex::new(r"(?i)^GHSA-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}$").unwrap(); + + if uuid_re.is_match(identifier) { + Some(IdentifierType::Uuid) + } else if cve_re.is_match(identifier) { + Some(IdentifierType::Cve) + } else if ghsa_re.is_match(identifier) { + Some(IdentifierType::Ghsa) + } else if is_purl(identifier) { + Some(IdentifierType::Purl) + } else { + None + } +} + +/// Select one patch per PURL from available patches. 
+ +/// Select one patch per PURL from available patches. +/// +/// - Paid users: auto-select the most recent paid patch per PURL. +/// - Free users with one patch: auto-select it. +/// - Free users with multiple patches: interactive selection via dialoguer. +/// - JSON mode with multiple free patches: returns an error with options list. +/// +/// Returns `Ok(selected_patches)` or `Err(exit_code)` if selection fails. +pub fn select_patches( + patches: &[PatchSearchResult], + can_access_paid: bool, + is_json: bool, +) -> Result<Vec<PatchSearchResult>, i32> { + // Group accessible patches by PURL + let mut by_purl: HashMap<String, Vec<&PatchSearchResult>> = HashMap::new(); + for p in patches { + if p.tier == "free" || can_access_paid { + by_purl.entry(p.purl.clone()).or_default().push(p); + } + } + + let mut selected = Vec::new(); + + for (purl, mut group) in by_purl { + // Sort by published_at descending (most recent first) + group.sort_by(|a, b| b.published_at.cmp(&a.published_at)); + + if can_access_paid { + // Paid user: prefer most recent paid patch, fallback to most recent free + let choice = group + .iter() + .find(|p| p.tier == "paid") + .or_else(|| group.first()) + .unwrap(); + selected.push((*choice).clone()); + } else if group.len() == 1 { + selected.push(group[0].clone()); + } else { + // Free user with multiple patches: interactive selection + let options: Vec<String> = group + .iter() + .map(|p| { + let vuln_summary: Vec<String> = p + .vulnerabilities + .iter() + .map(|(id, v)| { + if v.cves.is_empty() { + id.clone() + } else { + v.cves.join(", ") + } + }) + .collect(); + let vulns = if vuln_summary.is_empty() { + String::new() + } else { + format!(" (fixes: {})", vuln_summary.join(", ")) + }; + let desc = if p.description.len() > 60 { + format!("{}...", &p.description[..57]) + } else { + p.description.clone() + }; + format!("{} [{}]{} - {}", p.uuid, p.tier, vulns, desc) + }) + .collect(); + + match select_one( + &format!("Multiple patches available for {purl}. Select one:"), + &options, + is_json, + ) { + Ok(idx) => { + selected.push(group[idx].clone()); + } + Err(SelectError::JsonModeNeedsExplicit) => { + let options_json: Vec<serde_json::Value> = group + .iter() + .map(|p| { + let vulns: Vec<serde_json::Value> = p + .vulnerabilities + .iter() + .map(|(id, v)| { + serde_json::json!({ + "id": id, + "cves": v.cves, + "severity": v.severity, + "summary": v.summary, + }) + }) + .collect(); + serde_json::json!({ + "uuid": p.uuid, + "tier": p.tier, + "published_at": p.published_at, + "description": p.description, + "vulnerabilities": vulns, + }) + }) + .collect(); + println!( + "{}", + serde_json::to_string_pretty(&serde_json::json!({ + "status": "selection_required", + "error": format!("Multiple patches available for {purl}. Specify --id to select one."), + "purl": purl, + "options": options_json, + })) + .unwrap() + ); + return Err(1); + } + Err(SelectError::Cancelled) => { + eprintln!("Selection cancelled."); + return Err(0); + } + } + } + } + + Ok(selected) +} + +/// Download parameters shared between get and scan commands. +#[allow(dead_code)] +pub struct DownloadParams { + pub cwd: PathBuf, + pub org: Option<String>, + pub save_only: bool, + pub one_off: bool, + pub global: bool, + pub global_prefix: Option<String>, + pub json: bool, + pub silent: bool, +}
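// Minimal sketch of the per-PURL tie-break in select_patches, over a simplified
// stand-in type (the real PatchSearchResult carries more fields). Newest first;
// paid users prefer the newest paid patch and fall back to the newest overall:
#[derive(Clone)]
struct PatchChoice {
    tier: &'static str,
    published_at: &'static str, // RFC 3339 timestamps sort correctly as strings
    uuid: &'static str,
}

#[allow(dead_code)]
fn pick_one(group: &mut Vec<PatchChoice>, can_access_paid: bool) -> PatchChoice {
    // Assumes a non-empty group, as guaranteed by the grouping step above.
    group.sort_by(|a, b| b.published_at.cmp(a.published_at));
    if can_access_paid {
        group.iter().find(|p| p.tier == "paid").unwrap_or(&group[0]).clone()
    } else {
        group[0].clone()
    }
}
// e.g. a free user holding two free patches gets the most recently published one,
// while a paid user holding one free and one older paid patch gets the paid one.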
+ +/// Download and apply a set of selected patches. +/// +/// Used by both `get` and `scan` commands. Returns (exit_code, json_result). +pub async fn download_and_apply_patches( + selected: &[PatchSearchResult], + params: &DownloadParams, +) -> (i32, serde_json::Value) { + let (api_client, _) = get_api_client_from_env(params.org.as_deref()).await; + let effective_org: Option<&str> = None; + + let socket_dir = params.cwd.join(".socket"); + let blobs_dir = socket_dir.join("blobs"); + let manifest_path = socket_dir.join("manifest.json"); + + if let Err(e) = tokio::fs::create_dir_all(&socket_dir).await { + let err = format!("Failed to create .socket directory: {}", e); + if params.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": &err, + })).unwrap()); + } else { + eprintln!("Error: {}", &err); + } + return (1, serde_json::json!({"status": "error", "error": err})); + } + if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await { + let err = format!("Failed to create blobs directory: {}", e); + if params.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": &err, + })).unwrap()); + } else { + eprintln!("Error: {}", &err); + } + return (1, serde_json::json!({"status": "error", "error": err})); + } + + let mut manifest = match read_manifest(&manifest_path).await { + Ok(Some(m)) => m, + _ => PatchManifest::new(), + }; + + if !params.json && !params.silent { + eprintln!("\nDownloading {} patch(es)...", selected.len()); + } + + let mut patches_added = 0; + let mut patches_skipped = 0; + let mut patches_failed = 0; + let mut downloaded_patches: Vec<serde_json::Value> = Vec::new(); + let mut updates: Vec<String> = Vec::new(); + + for search_result in selected { + // Check for updates: existing patch with different UUID + if let Some(existing) = manifest.patches.get(&search_result.purl) { + if existing.uuid != search_result.uuid { + updates.push(search_result.purl.clone()); + if !params.json && !params.silent { + eprintln!( + " [update] {} (replacing {})", + search_result.purl, + &existing.uuid[..8] + ); + } + } + } + + match api_client + .fetch_patch(effective_org, &search_result.uuid) + .await + { + Ok(Some(patch)) => { + // Check if already in manifest with same UUID + if manifest + .patches + .get(&patch.purl) + .is_some_and(|p| p.uuid == patch.uuid) + { + if !params.json && !params.silent { + eprintln!(" [skip] {} (already in manifest)", patch.purl); + } + downloaded_patches.push(serde_json::json!({ + "purl": patch.purl, + "uuid": patch.uuid, + "action": "skipped", + })); + patches_skipped += 1; + continue; + } + + // Save blob contents + let mut patch_failed = false; + let mut files = HashMap::new(); + for (file_path, file_info) in &patch.files { + if let (Some(ref before), Some(ref after)) = + (&file_info.before_hash, &file_info.after_hash) + { + files.insert( + file_path.clone(), + PatchFileInfo { + before_hash: before.clone(), + after_hash: after.clone(), + }, + ); + } + + if let (Some(ref blob_content), Some(ref after_hash)) = + (&file_info.blob_content, &file_info.after_hash) + { + match base64_decode(blob_content) { + Ok(decoded) => { + let blob_path = blobs_dir.join(after_hash); + if let Err(e) = tokio::fs::write(&blob_path, &decoded).await { + if !params.json && !params.silent { + eprintln!(" [error] Failed to write blob for {}: {}", file_path, e); + } + patch_failed = true; + break; + } + } + Err(e) => { + if !params.json && !params.silent { + eprintln!(" [error] Failed to decode blob for {}: {}", file_path, e); + } + patch_failed = true; + break; + } + } + } + + // Also store beforeHash blob if present (needed for rollback) + if let
(Some(ref before_blob), Some(ref before_hash)) = + (&file_info.before_blob_content, &file_info.before_hash) + { + match base64_decode(before_blob) { + Ok(decoded) => { + if let Err(e) = tokio::fs::write(blobs_dir.join(before_hash), &decoded).await { + if !params.json && !params.silent { + eprintln!(" [error] Failed to write before-blob for {}: {}", file_path, e); + } + patch_failed = true; + break; + } + } + Err(e) => { + if !params.json && !params.silent { + eprintln!(" [error] Failed to decode before-blob for {}: {}", file_path, e); + } + patch_failed = true; + break; + } + } + } + } + + if patch_failed { + patches_failed += 1; + downloaded_patches.push(serde_json::json!({ + "purl": patch.purl, + "uuid": patch.uuid, + "action": "failed", + "error": "Blob decode or write failed", + })); + continue; + } + + let vulnerabilities: HashMap<String, VulnerabilityInfo> = patch + .vulnerabilities + .iter() + .map(|(id, v)| { + ( + id.clone(), + VulnerabilityInfo { + cves: v.cves.clone(), + summary: v.summary.clone(), + severity: v.severity.clone(), + description: v.description.clone(), + }, + ) + }) + .collect(); + + manifest.patches.insert( + patch.purl.clone(), + PatchRecord { + uuid: patch.uuid.clone(), + exported_at: patch.published_at.clone(), + files, + vulnerabilities, + description: patch.description.clone(), + license: patch.license.clone(), + tier: patch.tier.clone(), + }, + ); + + if !params.json && !params.silent { + eprintln!(" [add] {}", patch.purl); + } + downloaded_patches.push(serde_json::json!({ + "purl": patch.purl, + "uuid": patch.uuid, + "action": "added", + })); + patches_added += 1; + } + Ok(None) => { + if !params.json && !params.silent { + eprintln!(" [fail] {} (could not fetch details)", search_result.purl); + } + downloaded_patches.push(serde_json::json!({ + "purl": search_result.purl, + "uuid": search_result.uuid, + "action": "failed", + "error": "could not fetch details", + })); + patches_failed += 1; + } + Err(e) => { + if !params.json && !params.silent { + eprintln!(" [fail] {} ({e})", search_result.purl); + } + downloaded_patches.push(serde_json::json!({ + "purl": search_result.purl, + "uuid": search_result.uuid, + "action": "failed", + "error": e.to_string(), + })); + patches_failed += 1; + } + } + } + + // Write manifest + if let Err(e) = write_manifest(&manifest_path, &manifest).await { + let err_json = serde_json::json!({ + "status": "error", + "error": format!("Error writing manifest: {e}"), + }); + if params.json { + println!("{}", serde_json::to_string_pretty(&err_json).unwrap()); + } else { + eprintln!("Error writing manifest: {e}"); + } + return (1, err_json); + } + + if !params.json && !params.silent { + eprintln!("\nPatches saved to {}", manifest_path.display()); + eprintln!(" Added: {patches_added}"); + if patches_skipped > 0 { + eprintln!(" Skipped: {patches_skipped}"); + } + if patches_failed > 0 { + eprintln!(" Failed: {patches_failed}"); + } + if !updates.is_empty() { + eprintln!(" Updated: {}", updates.len()); + } + } + + // Auto-apply unless --save-only + let mut apply_succeeded = false; + if !params.save_only && patches_added > 0 { + if !params.json && !params.silent { + eprintln!("\nApplying patches..."); + } + let apply_args = super::apply::ApplyArgs { + cwd: params.cwd.clone(), + dry_run: false, + silent: params.json || params.silent, + manifest_path: manifest_path.display().to_string(), + offline: false, + global: params.global, + global_prefix: params.global_prefix.clone(), + ecosystems: None, + force: false, + json: false, + verbose: false, + }; + let code =
super::apply::run(apply_args).await; + apply_succeeded = code == 0; + if code != 0 && !params.json && !params.silent { + eprintln!("\nSome patches could not be applied."); + } + } + + let result_json = serde_json::json!({ + "status": if patches_failed > 0 { "partial_failure" } else { "success" }, + "found": selected.len(), + "downloaded": patches_added, + "skipped": patches_skipped, + "failed": patches_failed, + "applied": if apply_succeeded { patches_added } else { 0 }, + "updated": updates.len(), + "patches": downloaded_patches, + }); + + let exit_code = if patches_failed > 0 || (!apply_succeeded && patches_added > 0 && !params.save_only) { 1 } else { 0 }; + (exit_code, result_json) +} + +pub async fn run(args: GetArgs) -> i32 { + // Validate flags + let type_flags = [args.id, args.cve, args.ghsa, args.package] + .iter() + .filter(|&&f| f) + .count(); + if type_flags > 1 { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Only one of --id, --cve, --ghsa, or --package can be specified", + })).unwrap()); + } else { + eprintln!("Error: Only one of --id, --cve, --ghsa, or --package can be specified"); + } + return 1; + } + if args.one_off && args.save_only { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "--one-off and --save-only cannot be used together", + })).unwrap()); + } else { + eprintln!("Error: --one-off and --save-only cannot be used together"); + } + return 1; + } + + // Override env vars + if let Some(ref url) = args.api_url { + std::env::set_var("SOCKET_API_URL", url); + } + if let Some(ref token) = args.api_token { + std::env::set_var("SOCKET_API_TOKEN", token); + } + + let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()).await; + + // org slug is already stored in the client + let effective_org_slug: Option<&str> = None; + + // Determine identifier type + let id_type = if args.id { + IdentifierType::Uuid + } else if args.cve { + IdentifierType::Cve + } else if args.ghsa { + IdentifierType::Ghsa + } else if args.package { + IdentifierType::Package + } else { + match detect_identifier_type(&args.identifier) { + Some(t) => t, + None => { + if !args.json { + println!("Treating \"{}\" as a package name search", args.identifier); + } + IdentifierType::Package + } + } + }; + + // Handle UUID: fetch and download directly + if id_type == IdentifierType::Uuid { + if !args.json { + println!("Fetching patch by UUID: {}", args.identifier); + } + match api_client + .fetch_patch(effective_org_slug, &args.identifier) + .await + { + Ok(Some(patch)) => { + if patch.tier == "paid" && use_public_proxy { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "paid_required", + "found": 1, + "downloaded": 0, + "applied": 0, + "patches": [{ + "purl": patch.purl, + "uuid": patch.uuid, + "tier": "paid", + }], + })).unwrap()); + } else { + println!("\nThis patch requires a paid subscription to download."); + println!("\n Patch: {}", patch.purl); + println!(" Tier: paid"); + println!("\n Upgrade at: https://socket.dev/pricing\n"); + } + return 0; + } + + // Save to manifest + return save_and_apply_patch(&args, &patch.purl, &patch.uuid, effective_org_slug) + .await; + } + Ok(None) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "not_found", + "found": 0, + "downloaded": 0, + "applied": 0, + "patches": [], + })).unwrap()); + } else { + println!("No 
patch found with UUID: {}", args.identifier); + } + return 0; + } + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + } + } + + // For CVE/GHSA/PURL/package, search first + let search_response: SearchResponse = match id_type { + IdentifierType::Cve => { + if !args.json { + println!("Searching patches for CVE: {}", args.identifier); + } + match api_client + .search_patches_by_cve(effective_org_slug, &args.identifier) + .await + { + Ok(r) => r, + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + } + } + IdentifierType::Ghsa => { + if !args.json { + println!("Searching patches for GHSA: {}", args.identifier); + } + match api_client + .search_patches_by_ghsa(effective_org_slug, &args.identifier) + .await + { + Ok(r) => r, + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + } + } + IdentifierType::Purl => { + if !args.json { + println!("Searching patches for PURL: {}", args.identifier); + } + match api_client + .search_patches_by_package(effective_org_slug, &args.identifier) + .await + { + Ok(r) => r, + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + } + } + IdentifierType::Package => { + if !args.json { + println!("Enumerating packages..."); + } + let crawler_options = CrawlerOptions { + cwd: args.cwd.clone(), + global: args.global, + global_prefix: args.global_prefix.clone(), + batch_size: 100, + }; + let (all_packages, _) = crawl_all_ecosystems(&crawler_options).await; + + if all_packages.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "no_packages", + "found": 0, + "downloaded": 0, + "applied": 0, + "patches": [], + })).unwrap()); + } else if args.global { + println!("No global packages found."); + } else { + #[allow(unused_mut)] + let mut install_cmds = String::from("npm/yarn/pnpm/pip"); + #[cfg(feature = "cargo")] + install_cmds.push_str("/cargo"); + #[cfg(feature = "golang")] + install_cmds.push_str("/go"); + #[cfg(feature = "maven")] + install_cmds.push_str("/mvn"); + #[cfg(feature = "composer")] + install_cmds.push_str("/composer"); + println!("No packages found. 
Run {install_cmds} install first."); + } + return 0; + } + + if !args.json { + println!("Found {} packages", all_packages.len()); + } + + let matches = fuzzy_match_packages(&args.identifier, &all_packages, 20); + + if matches.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "no_match", + "found": 0, + "downloaded": 0, + "applied": 0, + "patches": [], + })).unwrap()); + } else { + println!("No packages matching \"{}\" found.", args.identifier); + } + return 0; + } + + if !args.json { + println!( + "Found {} matching package(s), checking for available patches...", + matches.len() + ); + } + + // Search for patches for the best match + let best_match = &matches[0]; + match api_client + .search_patches_by_package(effective_org_slug, &best_match.purl) + .await + { + Ok(r) => r, + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + } + } + _ => unreachable!(), + }; + + if search_response.patches.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "not_found", + "found": 0, + "downloaded": 0, + "applied": 0, + "patches": [], + })).unwrap()); + } else { + println!( + "No patches found for {}: {}", + id_type, args.identifier + ); + } + return 0; + } + + if !args.json { + display_search_results(&search_response.patches, search_response.can_access_paid_patches); + } + + // Filter accessible patches + let accessible: Vec<_> = search_response + .patches + .iter() + .filter(|p| p.tier == "free" || search_response.can_access_paid_patches) + .cloned() + .collect(); + + if accessible.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "paid_required", + "found": search_response.patches.len(), + "downloaded": 0, + "applied": 0, + "patches": search_response.patches.iter().map(|p| serde_json::json!({ + "purl": p.purl, + "uuid": p.uuid, + "tier": p.tier, + })).collect::<Vec<_>>(), + })).unwrap()); + } else { + println!("\nAll available patches require a paid subscription."); + println!("\n Upgrade at: https://socket.dev/pricing\n"); + } + return 0; + } + + // Smart patch selection: pick one patch per PURL + let selected = match select_patches( + &accessible, + search_response.can_access_paid_patches, + args.json, + ) { + Ok(s) => s, + Err(code) => return code, + }; + + if selected.is_empty() { + if !args.json { + println!("No patches selected."); + } + return 0; + } + + // Confirm before downloading (default YES) + let prompt = format!("Download {} patch(es)?", selected.len()); + if !confirm(&prompt, true, args.yes, args.json) { + if !args.json { + println!("Download cancelled."); + } + return 0; + } + + // Download and apply + let params = DownloadParams { + cwd: args.cwd.clone(), + org: args.org.clone(), + save_only: args.save_only, + one_off: args.one_off, + global: args.global, + global_prefix: args.global_prefix.clone(), + json: args.json, + silent: false, + }; + + let (code, result_json) = download_and_apply_patches(&selected, &params).await; + + if args.json { + println!("{}", serde_json::to_string_pretty(&result_json).unwrap()); + } + + code +} + +fn display_search_results(patches: &[PatchSearchResult], can_access_paid: bool) { + println!("\nFound patches:\n"); + + for (i, patch) in patches.iter().enumerate() { + let tier_label = if patch.tier == "paid" { + " [PAID]" + } else {
+ " [FREE]" + }; + let access_label = if patch.tier == "paid" && !can_access_paid { + " (no access)" + } else { + "" + }; + + println!(" {}. {}{}{}", i + 1, patch.purl, tier_label, access_label); + println!(" UUID: {}", patch.uuid); + if !patch.description.is_empty() { + let desc = if patch.description.len() > 80 { + format!("{}...", &patch.description[..77]) + } else { + patch.description.clone() + }; + println!(" Description: {desc}"); + } + + let vuln_ids: Vec<_> = patch.vulnerabilities.keys().collect(); + if !vuln_ids.is_empty() { + let vuln_summary: Vec = patch + .vulnerabilities + .iter() + .map(|(id, vuln)| { + let cves = if vuln.cves.is_empty() { + id.to_string() + } else { + vuln.cves.join(", ") + }; + format!("{cves} ({})", vuln.severity) + }) + .collect(); + println!(" Fixes: {}", vuln_summary.join(", ")); + } + println!(); + } +} + +async fn save_and_apply_patch( + args: &GetArgs, + _purl: &str, + uuid: &str, + _org_slug: Option<&str>, +) -> i32 { + // For UUID mode, fetch and save + let (api_client, _) = get_api_client_from_env(args.org.as_deref()).await; + let effective_org: Option<&str> = None; // org slug is already stored in the client + + let patch = match api_client.fetch_patch(effective_org, uuid).await { + Ok(Some(p)) => p, + Ok(None) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "not_found", + "found": 0, + "downloaded": 0, + "applied": 0, + "patches": [], + })).unwrap()); + } else { + println!("No patch found with UUID: {uuid}"); + } + return 0; + } + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + }; + + let socket_dir = args.cwd.join(".socket"); + let blobs_dir = socket_dir.join("blobs"); + let manifest_path = socket_dir.join("manifest.json"); + + if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": format!("Failed to create blobs directory: {}", e), + })).unwrap()); + } else { + eprintln!("Error: Failed to create blobs directory: {}", e); + } + return 1; + } + + let mut manifest = match read_manifest(&manifest_path).await { + Ok(Some(m)) => m, + _ => PatchManifest::new(), + }; + + // Build and save patch record + let mut blob_failed = false; + let mut files = HashMap::new(); + for (file_path, file_info) in &patch.files { + if let Some(ref after) = file_info.after_hash { + files.insert( + file_path.clone(), + PatchFileInfo { + before_hash: file_info + .before_hash + .clone() + .unwrap_or_default(), + after_hash: after.clone(), + }, + ); + } + if let (Some(ref blob_content), Some(ref after_hash)) = + (&file_info.blob_content, &file_info.after_hash) + { + match base64_decode(blob_content) { + Ok(decoded) => { + if let Err(e) = tokio::fs::write(blobs_dir.join(after_hash), &decoded).await { + if !args.json { + eprintln!(" [error] Failed to write blob for {}: {}", file_path, e); + } + blob_failed = true; + break; + } + } + Err(e) => { + if !args.json { + eprintln!(" [error] Failed to decode blob for {}: {}", file_path, e); + } + blob_failed = true; + break; + } + } + } + // Also store beforeHash blob if present (needed for rollback) + if let (Some(ref before_blob), Some(ref before_hash)) = + (&file_info.before_blob_content, &file_info.before_hash) + { + match base64_decode(before_blob) { + Ok(decoded) => { + if let Err(e) = 
+async fn save_and_apply_patch( + args: &GetArgs, + _purl: &str, + uuid: &str, + _org_slug: Option<&str>, +) -> i32 { + // For UUID mode, fetch and save + let (api_client, _) = get_api_client_from_env(args.org.as_deref()).await; + let effective_org: Option<&str> = None; // org slug is already stored in the client + + let patch = match api_client.fetch_patch(effective_org, uuid).await { + Ok(Some(p)) => p, + Ok(None) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "not_found", + "found": 0, + "downloaded": 0, + "applied": 0, + "patches": [], + })).unwrap()); + } else { + println!("No patch found with UUID: {uuid}"); + } + return 0; + } + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + return 1; + } + }; + + let socket_dir = args.cwd.join(".socket"); + let blobs_dir = socket_dir.join("blobs"); + let manifest_path = socket_dir.join("manifest.json"); + + if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": format!("Failed to create blobs directory: {}", e), + })).unwrap()); + } else { + eprintln!("Error: Failed to create blobs directory: {}", e); + } + return 1; + } + + let mut manifest = match read_manifest(&manifest_path).await { + Ok(Some(m)) => m, + _ => PatchManifest::new(), + }; + + // Build and save patch record + let mut blob_failed = false; + let mut files = HashMap::new(); + for (file_path, file_info) in &patch.files { + if let Some(ref after) = file_info.after_hash { + files.insert( + file_path.clone(), + PatchFileInfo { + before_hash: file_info + .before_hash + .clone() + .unwrap_or_default(), + after_hash: after.clone(), + }, + ); + } + if let (Some(ref blob_content), Some(ref after_hash)) = + (&file_info.blob_content, &file_info.after_hash) + { + match base64_decode(blob_content) { + Ok(decoded) => { + if let Err(e) = tokio::fs::write(blobs_dir.join(after_hash), &decoded).await { + if !args.json { + eprintln!(" [error] Failed to write blob for {}: {}", file_path, e); + } + blob_failed = true; + break; + } + } + Err(e) => { + if !args.json { + eprintln!(" [error] Failed to decode blob for {}: {}", file_path, e); + } + blob_failed = true; + break; + } + } + } + // Also store beforeHash blob if present (needed for rollback) + if let (Some(ref before_blob), Some(ref before_hash)) = + (&file_info.before_blob_content, &file_info.before_hash) + { + match base64_decode(before_blob) { + Ok(decoded) => { + if let Err(e) = tokio::fs::write(blobs_dir.join(before_hash), &decoded).await { + if !args.json { + eprintln!(" [error] Failed to write before-blob for {}: {}", file_path, e); + } + blob_failed = true; + break; + } + } + Err(e) => { + if !args.json { + eprintln!(" [error] Failed to decode before-blob for {}: {}", file_path, e); + } + blob_failed = true; + break; + } + } + } + } + + if blob_failed { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "found": 1, + "downloaded": 0, + "applied": 0, + "error": "Blob decode or write failed", + "patches": [{ + "purl": patch.purl, + "uuid": patch.uuid, + "action": "failed", + "error": "Blob decode or write failed", + }], + })).unwrap()); + } else { + eprintln!("Error: Blob decode or write failed for patch {}", patch.purl); + } + return 1; + } + + let vulnerabilities: HashMap<String, VulnerabilityInfo> = patch + .vulnerabilities + .iter() + .map(|(id, v)| { + ( + id.clone(), + VulnerabilityInfo { + cves: v.cves.clone(), + summary: v.summary.clone(), + severity: v.severity.clone(), + description: v.description.clone(), + }, + ) + }) + .collect(); + + let added = manifest + .patches + .get(&patch.purl) + .is_none_or(|p| p.uuid != patch.uuid); + + manifest.patches.insert( + patch.purl.clone(), + PatchRecord { + uuid: patch.uuid.clone(), + exported_at: patch.published_at.clone(), + files, + vulnerabilities, + description: patch.description.clone(), + license: patch.license.clone(), + tier: patch.tier.clone(), + }, + ); + + if let Err(e) = write_manifest(&manifest_path, &manifest).await { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": format!("Error writing manifest: {e}"), + })).unwrap()); + } else { + eprintln!("Error writing manifest: {e}"); + } + return 1; + } + + if !args.json { + println!("\nPatch saved to {}", manifest_path.display()); + if added { + println!(" Added: 1"); + } else { + println!(" Skipped: 1 (already exists)"); + } + } + + let mut apply_succeeded = false; + if !args.save_only && added { + if !args.json { + println!("\nApplying patches..."); + } + let apply_args = super::apply::ApplyArgs { + cwd: args.cwd.clone(), + dry_run: false, + silent: args.json, + manifest_path: manifest_path.display().to_string(), + offline: false, + global: args.global, + global_prefix: args.global_prefix.clone(), + ecosystems: None, + force: false, + json: false, + verbose: false, + }; + let code = super::apply::run(apply_args).await; + apply_succeeded = code == 0; + if code != 0 && !args.json { + eprintln!("\nSome patches could not be applied."); + } + } + + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "success", + "found": 1, + "downloaded": if added { 1 } else { 0 }, + "applied": if apply_succeeded { 1 } else { 0 }, + "patches": [{ + "purl": patch.purl, + "uuid": patch.uuid, + "action": if added { "added" } else { "skipped" }, + }], + })).unwrap()); + } + + if !apply_succeeded && added && !args.save_only { 1 } else { 0 } +}
+fn base64_decode(input: &str) -> Result<Vec<u8>, String> { + let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut table = [255u8; 256]; + for (i, &c) in chars.iter().enumerate() { + table[c as usize] = i as u8; + } + + let input = input.as_bytes(); + let mut output = Vec::with_capacity(input.len() * 3 / 4); + + let mut buf = 0u32; + let mut bits = 0u32; + + for &b in input { + if b == b'=' || b == b'\n' || b == b'\r' { + continue; + } + let val = table[b as usize]; + if val == 255 { + return Err(format!("Invalid base64 character: {}", b as char)); + } + buf = (buf << 6) | val as u32; + bits += 6; + if bits >= 8 { + bits -= 8; + output.push((buf >> bits) as u8); + buf &= (1 << bits) - 1; + } + } + + Ok(output) +}
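// Sketch of the decoder in use (hypothetical test). The accumulator folds six bits
// per character and emits a byte once eight or more are buffered; '=' padding, CR,
// and LF are skipped, so padded and unpadded input both decode:
#[test]
fn base64_decode_sketch() {
    assert_eq!(base64_decode("aGVsbG8=").unwrap(), b"hello");
    assert_eq!(base64_decode("aGVsbG8").unwrap(), b"hello"); // padding is optional
    assert!(base64_decode("not base64!").is_err()); // ' ' and '!' are rejected
}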
diff --git a/crates/socket-patch-cli/src/commands/list.rs b/crates/socket-patch-cli/src/commands/list.rs new file mode 100644 index 0000000..8dc00a6 --- /dev/null +++ b/crates/socket-patch-cli/src/commands/list.rs @@ -0,0 +1,143 @@ +use clap::Args; +use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH; +use socket_patch_core::manifest::operations::read_manifest; +use std::path::{Path, PathBuf}; + +#[derive(Args)] +pub struct ListArgs { + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Path to patch manifest file + #[arg(short = 'm', long = "manifest-path", default_value = DEFAULT_PATCH_MANIFEST_PATH)] + pub manifest_path: String, + + /// Output as JSON + #[arg(long, default_value_t = false)] + pub json: bool, +} + +pub async fn run(args: ListArgs) -> i32 { + let manifest_path = if Path::new(&args.manifest_path).is_absolute() { + PathBuf::from(&args.manifest_path) + } else { + args.cwd.join(&args.manifest_path) + }; + + // Check if manifest exists + if tokio::fs::metadata(&manifest_path).await.is_err() { + if args.json { + println!( + "{}", + serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Manifest not found", + "path": manifest_path.display().to_string() + })).unwrap() + ); + } else { + eprintln!("Manifest not found at {}", manifest_path.display()); + } + return 1; + } + + match read_manifest(&manifest_path).await { + Ok(Some(manifest)) => { + let patch_entries: Vec<_> = manifest.patches.iter().collect(); + + if patch_entries.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ "status": "success", "patches": [] })).unwrap()); + } else { + println!("No patches found in manifest."); + } + return 0; + } + + if args.json { + let json_output = serde_json::json!({ + "status": "success", + "patches": patch_entries.iter().map(|(purl, patch)| { + serde_json::json!({ + "purl": purl, + "uuid": patch.uuid, + "exportedAt": patch.exported_at, + "tier": patch.tier, + "license": patch.license, + "description": patch.description, + "files": patch.files.keys().collect::<Vec<_>>(), + "vulnerabilities": patch.vulnerabilities.iter().map(|(id, vuln)| { + serde_json::json!({ + "id": id, + "cves": vuln.cves, + "summary": vuln.summary, + "severity": vuln.severity, + "description": vuln.description, + }) + }).collect::<Vec<_>>(), + }) + }).collect::<Vec<_>>() + }); + println!("{}", serde_json::to_string_pretty(&json_output).unwrap()); + } else { + println!("Found {} patch(es):\n", patch_entries.len()); + + for (purl, patch) in &patch_entries { + println!("Package: {purl}"); + println!(" UUID: {}", patch.uuid); + println!(" Tier: {}", patch.tier); + println!(" License: {}", patch.license); + println!(" Exported: {}", patch.exported_at); + + if !patch.description.is_empty() { + println!(" Description: {}", patch.description); + } + + let vuln_entries: Vec<_> = patch.vulnerabilities.iter().collect(); + if !vuln_entries.is_empty() { + println!(" Vulnerabilities ({}):", vuln_entries.len()); + for (id, vuln) in &vuln_entries { + let cve_list = if vuln.cves.is_empty() { + String::new() + } else { + format!(" ({})", vuln.cves.join(", ")) + }; + println!(" - {id}{cve_list}"); + println!(" Severity: {}", vuln.severity); + println!(" Summary: {}", vuln.summary); + } + } + + let file_list: Vec<_> = patch.files.keys().collect(); + if !file_list.is_empty() { + println!(" Files patched ({}):", file_list.len()); + for file_path in &file_list { + println!(" - {file_path}"); + } + } + + println!(); + } + } + + 0 + } + Ok(None) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ "status": "error", "error": "Invalid manifest" })).unwrap()); + } else { + eprintln!("Error: Invalid manifest at {}", manifest_path.display()); + } + 1 + } + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ "status": "error", "error": e.to_string() })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + 1 + } + } +} diff --git a/crates/socket-patch-cli/src/commands/mod.rs b/crates/socket-patch-cli/src/commands/mod.rs new file mode 100644 index 0000000..499366f --- /dev/null +++ b/crates/socket-patch-cli/src/commands/mod.rs @@ -0,0 +1,8 @@ +pub mod apply; +pub mod get; +pub mod list; +pub mod remove; +pub mod repair; +pub mod rollback; +pub mod scan; +pub mod setup; diff --git a/crates/socket-patch-cli/src/commands/remove.rs b/crates/socket-patch-cli/src/commands/remove.rs new file mode 100644 index 0000000..8acff80 --- /dev/null +++ b/crates/socket-patch-cli/src/commands/remove.rs @@ -0,0 +1,349 @@ +use clap::Args; +use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH; +use socket_patch_core::manifest::operations::{read_manifest, write_manifest}; +use socket_patch_core::manifest::schema::PatchManifest; +use socket_patch_core::utils::cleanup_blobs::{cleanup_unused_blobs, format_cleanup_result}; +use socket_patch_core::utils::telemetry::{track_patch_removed, track_patch_remove_failed}; +use std::path::{Path, PathBuf}; + +use super::rollback::rollback_patches; +use crate::output::confirm; + +#[derive(Args)] +pub struct RemoveArgs { + /// Package PURL or patch UUID + pub identifier: String, + + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Path to patch manifest file + #[arg(short = 'm', long = "manifest-path", default_value = DEFAULT_PATCH_MANIFEST_PATH)] + pub manifest_path: String, + + /// Skip rolling back files before removing (only update manifest) + #[arg(long = "skip-rollback", default_value_t = false)] + pub skip_rollback: bool, + + /// Skip confirmation prompts + #[arg(short = 'y', long, default_value_t = false)] + pub yes: bool, + + /// Remove patches from globally installed npm packages + #[arg(short = 'g', long, default_value_t = false)] + pub global: bool, + + /// Custom path to global node_modules + #[arg(long = "global-prefix")] + pub global_prefix: Option<String>, + + /// Output results as JSON + #[arg(long, default_value_t = false)] + pub json: bool, +} + +pub async fn run(args: RemoveArgs) -> i32 { + let (telemetry_client, _) = + socket_patch_core::api::client::get_api_client_from_env(None).await; + let api_token = telemetry_client.api_token().cloned(); + let org_slug = telemetry_client.org_slug().cloned(); + + let manifest_path = if Path::new(&args.manifest_path).is_absolute() { + PathBuf::from(&args.manifest_path) + } else { + args.cwd.join(&args.manifest_path) + }; + + if tokio::fs::metadata(&manifest_path).await.is_err() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Manifest not found", + "path": manifest_path.display().to_string(), + })).unwrap()); + } else { + eprintln!("Manifest not found at {}", manifest_path.display()); + } + return 1; + } + + // Read manifest to show what
will be removed and confirm + let manifest = match read_manifest(&manifest_path).await { + Ok(Some(m)) => m, + Ok(None) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Invalid manifest", + })).unwrap()); + } else { + eprintln!("Invalid manifest at {}", manifest_path.display()); + } + return 1; + } + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e.to_string(), + })).unwrap()); + } else { + eprintln!("Error reading manifest: {e}"); + } + return 1; + } + }; + + // Find matching patches to show what will be removed + let matching: Vec<(&String, &socket_patch_core::manifest::schema::PatchRecord)> = + if args.identifier.starts_with("pkg:") { + manifest + .patches + .iter() + .filter(|(purl, _)| *purl == &args.identifier) + .collect() + } else { + manifest + .patches + .iter() + .filter(|(_, patch)| patch.uuid == args.identifier) + .collect() + }; + + if matching.is_empty() { + track_patch_remove_failed( + &format!("No patch found matching identifier: {}", args.identifier), + api_token.as_deref(), + org_slug.as_deref(), + ) + .await; + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "not_found", + "error": format!("No patch found matching identifier: {}", args.identifier), + "removed": 0, + "purls": [], + })).unwrap()); + } else { + eprintln!( + "No patch found matching identifier: {}", + args.identifier + ); + } + return 1; + } + + // Show what will be removed and confirm + if !args.json { + eprintln!("The following patch(es) will be removed:"); + for (purl, patch) in &matching { + let file_count = patch.files.len(); + eprintln!(" - {} (UUID: {}, {} file(s))", purl, &patch.uuid[..8], file_count); + } + eprintln!(); + } + + let prompt = format!( + "Remove {} patch(es) and rollback files?", + matching.len() + ); + if !confirm(&prompt, true, args.yes, args.json) { + if !args.json { + println!("Removal cancelled."); + } + return 0; + } + + // First, rollback the patch if not skipped + let mut rollback_count = 0; + if !args.skip_rollback { + if !args.json { + println!("Rolling back patch before removal..."); + } + match rollback_patches( + &args.cwd, + &manifest_path, + Some(&args.identifier), + false, + args.json, // silent when JSON + false, + args.global, + args.global_prefix.clone(), + None, + ) + .await + { + Ok((success, results)) => { + if !success { + track_patch_remove_failed( + "Rollback failed during patch removal", + api_token.as_deref(), + org_slug.as_deref(), + ) + .await; + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Rollback failed during patch removal. Use --skip-rollback to remove from manifest without restoring files.", + })).unwrap()); + } else { + eprintln!("\nRollback failed. 
Use --skip-rollback to remove from manifest without restoring files."); + } + return 1; + } + + rollback_count = results + .iter() + .filter(|r| r.success && !r.files_rolled_back.is_empty()) + .count(); + let already_original = results + .iter() + .filter(|r| { + r.success + && r.files_verified.iter().all(|f| { + f.status + == socket_patch_core::patch::rollback::VerifyRollbackStatus::AlreadyOriginal + }) + }) + .count(); + + if !args.json { + if rollback_count > 0 { + println!("Rolled back {rollback_count} package(s)"); + } + if already_original > 0 { + println!("{already_original} package(s) already in original state"); + } + if results.is_empty() { + println!("No packages found to rollback (not installed)"); + } + println!(); + } + } + Err(e) => { + track_patch_remove_failed(&e, api_token.as_deref(), org_slug.as_deref()).await; + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": format!("Error during rollback: {e}. Use --skip-rollback to remove from manifest without restoring files."), + })).unwrap()); + } else { + eprintln!("Error during rollback: {e}"); + eprintln!("\nRollback failed. Use --skip-rollback to remove from manifest without restoring files."); + } + return 1; + } + } + } + + // Now remove from manifest + match remove_patch_from_manifest(&args.identifier, &manifest_path).await { + Ok((removed, manifest)) => { + if removed.is_empty() { + track_patch_remove_failed( + &format!("No patch found matching identifier: {}", args.identifier), + api_token.as_deref(), + org_slug.as_deref(), + ) + .await; + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "not_found", + "error": format!("No patch found matching identifier: {}", args.identifier), + "removed": 0, + "purls": [], + })).unwrap()); + } else { + eprintln!( + "No patch found matching identifier: {}", + args.identifier + ); + } + return 1; + } + + if !args.json { + println!("Removed {} patch(es) from manifest:", removed.len()); + for purl in &removed { + println!(" - {purl}"); + } + println!("\nManifest updated at {}", manifest_path.display()); + } + + // Clean up unused blobs + let socket_dir = manifest_path.parent().unwrap(); + let blobs_path = socket_dir.join("blobs"); + let mut blobs_removed = 0; + if let Ok(cleanup_result) = cleanup_unused_blobs(&manifest, &blobs_path, false).await { + blobs_removed = cleanup_result.blobs_removed; + if !args.json && cleanup_result.blobs_removed > 0 { + println!("\n{}", format_cleanup_result(&cleanup_result, false)); + } + } + + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "success", + "removed": removed.len(), + "rolledBack": rollback_count, + "blobsCleaned": blobs_removed, + "purls": removed, + })).unwrap()); + } + + track_patch_removed(removed.len(), api_token.as_deref(), org_slug.as_deref()).await; + 0 + } + Err(e) => { + track_patch_remove_failed(&e, api_token.as_deref(), org_slug.as_deref()).await; + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e, + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + 1 + } + } +}
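// Usage sketch for the helper defined below (hypothetical; assumes a manifest file
// already exists at manifest_path). A "pkg:" identifier removes exactly that PURL's
// entry, while a UUID may remove several PURLs that share the same patch:
#[allow(dead_code)]
async fn remove_example(manifest_path: &Path) -> Result<(), String> {
    let (removed, _manifest) =
        remove_patch_from_manifest("pkg:npm/lodash@4.17.21", manifest_path).await?;
    for purl in &removed {
        eprintln!("removed {purl}");
    }
    Ok(())
}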
+async fn remove_patch_from_manifest( + identifier: &str, + manifest_path: &Path, +) -> Result<(Vec<String>, PatchManifest), String> { + let mut manifest = read_manifest(manifest_path) + .await + .map_err(|e| e.to_string())? + .ok_or_else(|| "Invalid manifest".to_string())?; + + let mut removed = Vec::new(); + + if identifier.starts_with("pkg:") { + if manifest.patches.remove(identifier).is_some() { + removed.push(identifier.to_string()); + } + } else { + let purls_to_remove: Vec<String> = manifest + .patches + .iter() + .filter(|(_, patch)| patch.uuid == identifier) + .map(|(purl, _)| purl.clone()) + .collect(); + + for purl in purls_to_remove { + manifest.patches.remove(&purl); + removed.push(purl); + } + } + + if !removed.is_empty() { + write_manifest(manifest_path, &manifest) + .await + .map_err(|e| e.to_string())?; + } + + Ok((removed, manifest)) +} diff --git a/crates/socket-patch-cli/src/commands/repair.rs b/crates/socket-patch-cli/src/commands/repair.rs new file mode 100644 index 0000000..33d7d04 --- /dev/null +++ b/crates/socket-patch-cli/src/commands/repair.rs @@ -0,0 +1,193 @@ +use clap::Args; +use socket_patch_core::api::blob_fetcher::{ + fetch_missing_blobs, format_fetch_result, get_missing_blobs, +}; +use socket_patch_core::api::client::get_api_client_from_env; +use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH; +use socket_patch_core::manifest::operations::read_manifest; +use socket_patch_core::utils::cleanup_blobs::{cleanup_unused_blobs, format_cleanup_result}; +use std::path::{Path, PathBuf}; + +#[derive(Args)] +pub struct RepairArgs { + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Path to patch manifest file + #[arg(short = 'm', long = "manifest-path", default_value = DEFAULT_PATCH_MANIFEST_PATH)] + pub manifest_path: String, + + /// Show what would be done without actually doing it + #[arg(short = 'd', long = "dry-run", default_value_t = false)] + pub dry_run: bool, + + /// Skip network operations (cleanup only) + #[arg(long, default_value_t = false)] + pub offline: bool, + + /// Only download missing blobs, do not clean up + #[arg(long = "download-only", default_value_t = false)] + pub download_only: bool, + + /// Output results as JSON + #[arg(long, default_value_t = false)] + pub json: bool, +} + +pub async fn run(args: RepairArgs) -> i32 { + let manifest_path = if Path::new(&args.manifest_path).is_absolute() { + PathBuf::from(&args.manifest_path) + } else { + args.cwd.join(&args.manifest_path) + }; + + if tokio::fs::metadata(&manifest_path).await.is_err() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Manifest not found", + "path": manifest_path.display().to_string(), + })).unwrap()); + } else { + eprintln!("Manifest not found at {}", manifest_path.display()); + } + return 1; + } + + match repair_inner(&args, &manifest_path).await { + Ok(result) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&result).unwrap()); + } + 0 + } + Err(e) => { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e, + })).unwrap()); + } else { + eprintln!("Error: {e}"); + } + 1 + } + } +} + +async fn repair_inner(args: &RepairArgs, manifest_path: &Path) -> Result<serde_json::Value, String> { + let manifest = read_manifest(manifest_path) + .await + .map_err(|e| e.to_string())?
+ .ok_or_else(|| "Invalid manifest".to_string())?; + + let socket_dir = manifest_path.parent().unwrap(); + let blobs_path = socket_dir.join("blobs"); + + let missing_count; + let mut downloaded_count = 0usize; + let mut download_failed_count = 0usize; + let mut blobs_cleaned = 0usize; + let mut blobs_checked = 0usize; + + // Step 1: Check for and download missing blobs + if !args.offline { + let missing_blobs = get_missing_blobs(&manifest, &blobs_path).await; + missing_count = missing_blobs.len(); + + if !missing_blobs.is_empty() { + if !args.json { + println!("Found {} missing blob(s)", missing_blobs.len()); + } + + if args.dry_run { + if !args.json { + println!("\nDry run - would download:"); + for hash in missing_blobs.iter().take(10) { + println!(" - {}...", &hash[..12.min(hash.len())]); + } + if missing_blobs.len() > 10 { + println!(" ... and {} more", missing_blobs.len() - 10); + } + } + } else { + if !args.json { + println!("\nDownloading missing blobs..."); + } + let (client, _) = get_api_client_from_env(None).await; + let fetch_result = fetch_missing_blobs(&manifest, &blobs_path, &client, None).await; + downloaded_count = fetch_result.downloaded; + download_failed_count = fetch_result.failed; + if !args.json { + println!("{}", format_fetch_result(&fetch_result)); + } + } + } else if !args.json { + println!("All blobs are present locally."); + } + } else { + let missing_blobs = get_missing_blobs(&manifest, &blobs_path).await; + missing_count = missing_blobs.len(); + if !missing_blobs.is_empty() { + if !args.json { + println!( + "Warning: {} blob(s) are missing (offline mode - not downloading)", + missing_blobs.len() + ); + for hash in missing_blobs.iter().take(5) { + println!(" - {}...", &hash[..12.min(hash.len())]); + } + if missing_blobs.len() > 5 { + println!(" ... 
and {} more", missing_blobs.len() - 5); + } + } + } else if !args.json { + println!("All blobs are present locally."); + } + } + + // Step 2: Clean up unused blobs + if !args.download_only { + if !args.json { + println!(); + } + match cleanup_unused_blobs(&manifest, &blobs_path, args.dry_run).await { + Ok(cleanup_result) => { + blobs_checked = cleanup_result.blobs_checked; + blobs_cleaned = cleanup_result.blobs_removed; + if !args.json { + if cleanup_result.blobs_checked == 0 { + println!("No blobs directory found, nothing to clean up."); + } else if cleanup_result.blobs_removed == 0 { + println!( + "Checked {} blob(s), all are in use.", + cleanup_result.blobs_checked + ); + } else { + println!("{}", format_cleanup_result(&cleanup_result, args.dry_run)); + } + } + } + Err(e) => { + if !args.json { + eprintln!("Warning: cleanup failed: {e}"); + } + } + } + } + + if !args.dry_run && !args.json { + println!("\nRepair complete."); + } + + Ok(serde_json::json!({ + "status": "success", + "dryRun": args.dry_run, + "missingBlobs": missing_count, + "downloaded": downloaded_count, + "downloadFailed": download_failed_count, + "blobsChecked": blobs_checked, + "blobsCleaned": blobs_cleaned, + })) +} diff --git a/crates/socket-patch-cli/src/commands/rollback.rs b/crates/socket-patch-cli/src/commands/rollback.rs new file mode 100644 index 0000000..93bf96f --- /dev/null +++ b/crates/socket-patch-cli/src/commands/rollback.rs @@ -0,0 +1,533 @@ +use clap::Args; +use socket_patch_core::api::blob_fetcher::{ + fetch_blobs_by_hash, format_fetch_result, +}; +use socket_patch_core::api::client::get_api_client_from_env; +use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH; +use socket_patch_core::crawlers::CrawlerOptions; +use socket_patch_core::manifest::operations::read_manifest; +use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord}; +use socket_patch_core::patch::rollback::{rollback_package_patch, RollbackResult, VerifyRollbackStatus}; +use socket_patch_core::utils::telemetry::{track_patch_rolled_back, track_patch_rollback_failed}; +use std::collections::HashSet; +use std::path::{Path, PathBuf}; + +use crate::ecosystem_dispatch::{find_packages_for_rollback, partition_purls}; + +#[derive(Args)] +pub struct RollbackArgs { + /// Package PURL or patch UUID to rollback. Omit to rollback all patches. 
+ pub identifier: Option<String>, + + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Verify rollback can be performed without modifying files + #[arg(short = 'd', long = "dry-run", default_value_t = false)] + pub dry_run: bool, + + /// Only output errors + #[arg(short = 's', long, default_value_t = false)] + pub silent: bool, + + /// Path to patch manifest file + #[arg(short = 'm', long = "manifest-path", default_value = DEFAULT_PATCH_MANIFEST_PATH)] + pub manifest_path: String, + + /// Do not download missing blobs, fail if any are missing + #[arg(long, default_value_t = false)] + pub offline: bool, + + /// Rollback patches from globally installed npm packages + #[arg(short = 'g', long, default_value_t = false)] + pub global: bool, + + /// Custom path to global node_modules + #[arg(long = "global-prefix")] + pub global_prefix: Option<String>, + + /// Rollback a patch by fetching beforeHash blobs from API (no manifest required) + #[arg(long = "one-off", default_value_t = false)] + pub one_off: bool, + + /// Organization slug + #[arg(long)] + pub org: Option<String>, + + /// Socket API URL (overrides SOCKET_API_URL env var) + #[arg(long = "api-url")] + pub api_url: Option<String>, + + /// Socket API token (overrides SOCKET_API_TOKEN env var) + #[arg(long = "api-token")] + pub api_token: Option<String>, + + /// Restrict rollback to specific ecosystems + #[arg(long, value_delimiter = ',')] + pub ecosystems: Option<Vec<String>>, + + /// Output results as JSON + #[arg(long, default_value_t = false)] + pub json: bool, + + /// Show detailed per-file verification information + #[arg(short = 'v', long, default_value_t = false)] + pub verbose: bool, +} + +struct PatchToRollback { + purl: String, + patch: PatchRecord, +} + +fn find_patches_to_rollback( + manifest: &PatchManifest, + identifier: Option<&str>, +) -> Vec<PatchToRollback> { + match identifier { + None => manifest + .patches + .iter() + .map(|(purl, patch)| PatchToRollback { + purl: purl.clone(), + patch: patch.clone(), + }) + .collect(), + Some(id) => { + let mut patches = Vec::new(); + if id.starts_with("pkg:") { + if let Some(patch) = manifest.patches.get(id) { + patches.push(PatchToRollback { + purl: id.to_string(), + patch: patch.clone(), + }); + } + } else { + for (purl, patch) in &manifest.patches { + if patch.uuid == id { + patches.push(PatchToRollback { + purl: purl.clone(), + patch: patch.clone(), + }); + } + } + } + patches + } + } +} + +fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> { + let mut blobs = HashSet::new(); + for patch in manifest.patches.values() { + for file_info in patch.files.values() { + blobs.insert(file_info.before_hash.clone()); + } + } + blobs +} + +async fn get_missing_before_blobs( + manifest: &PatchManifest, + blobs_path: &Path, +) -> HashSet<String> { + let before_blobs = get_before_hash_blobs(manifest); + let mut missing = HashSet::new(); + for hash in before_blobs { + let blob_path = blobs_path.join(&hash); + if tokio::fs::metadata(&blob_path).await.is_err() { + missing.insert(hash); + } + } + missing +} + +fn verify_rollback_status_str(status: &VerifyRollbackStatus) -> &'static str { + match status { + VerifyRollbackStatus::Ready => "ready", + VerifyRollbackStatus::AlreadyOriginal => "already_original", + VerifyRollbackStatus::HashMismatch => "hash_mismatch", + VerifyRollbackStatus::NotFound => "not_found", + VerifyRollbackStatus::MissingBlob => "missing_blob", + } +}
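// Sketch of the content-addressed blob check above (hypothetical helper): every
// before_hash in the manifest must exist as a file named after its hash under the
// blobs directory, and anything absent is reported so it can be fetched later:
#[allow(dead_code)]
async fn report_missing_before_blobs(manifest: &PatchManifest, blobs_path: &Path) {
    let missing = get_missing_before_blobs(manifest, blobs_path).await;
    for hash in &missing {
        eprintln!("missing before-blob: {hash}");
    }
}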
"success": result.success, + "error": result.error, + "filesRolledBack": result.files_rolled_back, + "filesVerified": result.files_verified.iter().map(|f| { + serde_json::json!({ + "file": f.file, + "status": verify_rollback_status_str(&f.status), + "message": f.message, + "currentHash": f.current_hash, + "expectedHash": f.expected_hash, + "targetHash": f.target_hash, + }) + }).collect::>(), + }) +} + +pub async fn run(args: RollbackArgs) -> i32 { + // Override env vars if CLI options provided (before building client) + if let Some(ref url) = args.api_url { + std::env::set_var("SOCKET_API_URL", url); + } + if let Some(ref token) = args.api_token { + std::env::set_var("SOCKET_API_TOKEN", token); + } + + let (telemetry_client, _) = get_api_client_from_env(args.org.as_deref()).await; + let api_token = telemetry_client.api_token().cloned(); + let org_slug = telemetry_client.org_slug().cloned(); + + // Validate one-off requires identifier + if args.one_off && args.identifier.is_none() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "--one-off requires an identifier (UUID or PURL)", + })).unwrap()); + } else { + eprintln!("Error: --one-off requires an identifier (UUID or PURL)"); + } + return 1; + } + + // Handle one-off mode + if args.one_off { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "One-off rollback mode is not yet implemented", + })).unwrap()); + } else { + eprintln!("One-off rollback mode: fetching patch data..."); + } + return 1; + } + + let manifest_path = if Path::new(&args.manifest_path).is_absolute() { + PathBuf::from(&args.manifest_path) + } else { + args.cwd.join(&args.manifest_path) + }; + + if tokio::fs::metadata(&manifest_path).await.is_err() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": "Manifest not found", + "path": manifest_path.display().to_string(), + })).unwrap()); + } else if !args.silent { + eprintln!("Manifest not found at {}", manifest_path.display()); + } + return 1; + } + + match rollback_patches_inner(&args, &manifest_path).await { + Ok((success, results)) => { + let rolled_back_count = results + .iter() + .filter(|r| r.success && !r.files_rolled_back.is_empty()) + .count(); + let already_original_count = results + .iter() + .filter(|r| { + r.success + && r.files_verified.iter().all(|f| { + f.status == VerifyRollbackStatus::AlreadyOriginal + }) + }) + .count(); + let failed_count = results.iter().filter(|r| !r.success).count(); + + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": if success { "success" } else { "partial_failure" }, + "rolledBack": rolled_back_count, + "alreadyOriginal": already_original_count, + "failed": failed_count, + "dryRun": args.dry_run, + "results": results.iter().map(result_to_json).collect::>(), + })).unwrap()); + } else if !args.silent && !results.is_empty() { + let rolled_back: Vec<_> = results + .iter() + .filter(|r| r.success && !r.files_rolled_back.is_empty()) + .collect(); + let already_original: Vec<_> = results + .iter() + .filter(|r| { + r.success + && r.files_verified.iter().all(|f| { + f.status == VerifyRollbackStatus::AlreadyOriginal + }) + }) + .collect(); + let failed: Vec<_> = results.iter().filter(|r| !r.success).collect(); + + if args.dry_run { + println!("\nRollback verification complete:"); + let can_rollback = results.iter().filter(|r| r.success).count(); + 
println!(" {can_rollback} package(s) can be rolled back"); + if !already_original.is_empty() { + println!( + " {} package(s) already in original state", + already_original.len() + ); + } + if !failed.is_empty() { + println!(" {} package(s) cannot be rolled back", failed.len()); + } + } else { + if !rolled_back.is_empty() || !already_original.is_empty() { + println!("\nRolled back packages:"); + for result in &rolled_back { + println!(" {}", result.package_key); + } + for result in &already_original { + println!(" {} (already original)", result.package_key); + } + } + if !failed.is_empty() { + println!("\nFailed to rollback:"); + for result in &failed { + println!( + " {}: {}", + result.package_key, + result.error.as_deref().unwrap_or("unknown error") + ); + } + } + } + + if args.verbose { + println!("\nDetailed verification:"); + for result in &results { + println!(" {}:", result.package_key); + for f in &result.files_verified { + let status_str = match f.status { + VerifyRollbackStatus::Ready => "ready", + VerifyRollbackStatus::AlreadyOriginal => "already original", + VerifyRollbackStatus::HashMismatch => "hash mismatch", + VerifyRollbackStatus::NotFound => "not found", + VerifyRollbackStatus::MissingBlob => "missing blob", + }; + println!(" {} [{}]", f.file, status_str); + if let Some(ref msg) = f.message { + println!(" message: {msg}"); + } + if let Some(ref h) = f.current_hash { + println!(" current: {h}"); + } + if let Some(ref h) = f.expected_hash { + println!(" expected: {h}"); + } + if let Some(ref h) = f.target_hash { + println!(" target: {h}"); + } + } + } + } + } + + if success { + track_patch_rolled_back(rolled_back_count, api_token.as_deref(), org_slug.as_deref()).await; + } else { + track_patch_rollback_failed("One or more rollbacks failed", api_token.as_deref(), org_slug.as_deref()).await; + } + + if success { 0 } else { 1 } + } + Err(e) => { + track_patch_rollback_failed(&e, api_token.as_deref(), org_slug.as_deref()).await; + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "error", + "error": e, + "rolledBack": 0, + "alreadyOriginal": 0, + "failed": 0, + "dryRun": args.dry_run, + "results": [], + })).unwrap()); + } else if !args.silent { + eprintln!("Error: {e}"); + } + 1 + } + } +} + +async fn rollback_patches_inner( + args: &RollbackArgs, + manifest_path: &Path, +) -> Result<(bool, Vec), String> { + let manifest = read_manifest(manifest_path) + .await + .map_err(|e| e.to_string())? 
+ .ok_or_else(|| "Invalid manifest".to_string())?; + + let socket_dir = manifest_path.parent().unwrap(); + let blobs_path = socket_dir.join("blobs"); + tokio::fs::create_dir_all(&blobs_path) + .await + .map_err(|e| e.to_string())?; + + let patches_to_rollback = + find_patches_to_rollback(&manifest, args.identifier.as_deref()); + + if patches_to_rollback.is_empty() { + if args.identifier.is_some() { + return Err(format!( + "No patch found matching identifier: {}", + args.identifier.as_deref().unwrap() + )); + } + if !args.silent && !args.json { + println!("No patches found in manifest"); + } + return Ok((true, Vec::new())); + } + + // Create filtered manifest + let filtered_manifest = PatchManifest { + patches: patches_to_rollback + .iter() + .map(|p| (p.purl.clone(), p.patch.clone())) + .collect(), + }; + + // Check for missing beforeHash blobs + let missing_blobs = get_missing_before_blobs(&filtered_manifest, &blobs_path).await; + if !missing_blobs.is_empty() { + if args.offline { + if !args.silent && !args.json { + eprintln!( + "Error: {} blob(s) are missing and --offline mode is enabled.", + missing_blobs.len() + ); + eprintln!("Run \"socket-patch repair\" to download missing blobs."); + } + return Ok((false, Vec::new())); + } + + if !args.silent && !args.json { + println!("Downloading {} missing blob(s)...", missing_blobs.len()); + } + + let (client, _) = get_api_client_from_env(None).await; + let fetch_result = fetch_blobs_by_hash(&missing_blobs, &blobs_path, &client, None).await; + + if !args.silent && !args.json { + println!("{}", format_fetch_result(&fetch_result)); + } + + let still_missing = get_missing_before_blobs(&filtered_manifest, &blobs_path).await; + if !still_missing.is_empty() { + if !args.silent && !args.json { + eprintln!( + "{} blob(s) could not be downloaded. 
Cannot rollback.", + still_missing.len() + ); + } + return Ok((false, Vec::new())); + } + } + + // Partition PURLs by ecosystem + let rollback_purls: Vec = patches_to_rollback.iter().map(|p| p.purl.clone()).collect(); + let partitioned = + partition_purls(&rollback_purls, args.ecosystems.as_deref()); + + let crawler_options = CrawlerOptions { + cwd: args.cwd.clone(), + global: args.global, + global_prefix: args.global_prefix.clone(), + batch_size: 100, + }; + + let all_packages = + find_packages_for_rollback(&partitioned, &crawler_options, args.silent || args.json).await; + + if all_packages.is_empty() { + if !args.silent && !args.json { + println!("No packages found that match patches to rollback"); + } + return Ok((true, Vec::new())); + } + + // Rollback patches + let mut results: Vec = Vec::new(); + let mut has_errors = false; + + for (purl, pkg_path) in &all_packages { + let patch = match filtered_manifest.patches.get(purl) { + Some(p) => p, + None => continue, + }; + + let result = rollback_package_patch( + purl, + pkg_path, + &patch.files, + &blobs_path, + args.dry_run, + ) + .await; + + if !result.success { + has_errors = true; + if !args.silent && !args.json { + eprintln!( + "Failed to rollback {}: {}", + purl, + result.error.as_deref().unwrap_or("unknown error") + ); + } + } + results.push(result); + } + + Ok((!has_errors, results)) +} + +// Export for use by remove command +#[allow(clippy::too_many_arguments)] +pub async fn rollback_patches( + cwd: &Path, + manifest_path: &Path, + identifier: Option<&str>, + dry_run: bool, + silent: bool, + offline: bool, + global: bool, + global_prefix: Option, + ecosystems: Option>, +) -> Result<(bool, Vec), String> { + let args = RollbackArgs { + identifier: identifier.map(String::from), + cwd: cwd.to_path_buf(), + dry_run, + silent, + manifest_path: manifest_path.display().to_string(), + offline, + global, + global_prefix, + one_off: false, + org: None, + api_url: None, + api_token: None, + ecosystems, + json: false, + verbose: false, + }; + rollback_patches_inner(&args, manifest_path).await +} diff --git a/crates/socket-patch-cli/src/commands/scan.rs b/crates/socket-patch-cli/src/commands/scan.rs new file mode 100644 index 0000000..bb1079a --- /dev/null +++ b/crates/socket-patch-cli/src/commands/scan.rs @@ -0,0 +1,578 @@ +use clap::Args; +use socket_patch_core::api::client::get_api_client_from_env; +use socket_patch_core::api::types::{BatchPackagePatches, PatchSearchResult}; +use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem}; +use socket_patch_core::manifest::operations::read_manifest; +use std::collections::HashSet; +use std::path::PathBuf; + +use crate::ecosystem_dispatch::crawl_all_ecosystems; +use crate::output::{color, confirm, format_severity, stderr_is_tty, stdout_is_tty}; + +use super::get::{download_and_apply_patches, select_patches, DownloadParams}; + +const DEFAULT_BATCH_SIZE: usize = 100; + +#[derive(Args)] +pub struct ScanArgs { + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Organization slug + #[arg(long)] + pub org: Option, + + /// Output results as JSON + #[arg(long, default_value_t = false)] + pub json: bool, + + /// Skip confirmation prompts + #[arg(short = 'y', long, default_value_t = false)] + pub yes: bool, + + /// Scan globally installed npm packages + #[arg(short = 'g', long, default_value_t = false)] + pub global: bool, + + /// Custom path to global node_modules + #[arg(long = "global-prefix")] + pub global_prefix: Option, + + /// Number of packages to query per API 
request + #[arg(long = "batch-size", default_value_t = DEFAULT_BATCH_SIZE)] + pub batch_size: usize, + + /// Socket API URL (overrides SOCKET_API_URL env var) + #[arg(long = "api-url")] + pub api_url: Option, + + /// Socket API token (overrides SOCKET_API_TOKEN env var) + #[arg(long = "api-token")] + pub api_token: Option, + + /// Restrict scanning to specific ecosystems (comma-separated: npm,pypi,cargo,maven) + #[arg(long, value_delimiter = ',')] + pub ecosystems: Option>, +} + +pub async fn run(args: ScanArgs) -> i32 { + // Override env vars if CLI options provided + if let Some(ref url) = args.api_url { + std::env::set_var("SOCKET_API_URL", url); + } + if let Some(ref token) = args.api_token { + std::env::set_var("SOCKET_API_TOKEN", token); + } + + let (api_client, _use_public_proxy) = get_api_client_from_env(args.org.as_deref()).await; + + // org slug is already stored in the client + let effective_org_slug: Option<&str> = None; + + let crawler_options = CrawlerOptions { + cwd: args.cwd.clone(), + global: args.global, + global_prefix: args.global_prefix.clone(), + batch_size: args.batch_size, + }; + + let scan_target = if args.global || args.global_prefix.is_some() { + "global packages" + } else { + "packages" + }; + + let show_progress = !args.json && stderr_is_tty(); + + if show_progress { + eprint!("Scanning {scan_target}..."); + } + + // Crawl packages + let (all_crawled, eco_counts) = crawl_all_ecosystems(&crawler_options).await; + + // Filter by --ecosystems if provided + let filtered_crawled: Vec<_> = if let Some(ref allowed) = args.ecosystems { + all_crawled + .into_iter() + .filter(|pkg| { + if let Some(eco) = Ecosystem::from_purl(&pkg.purl) { + allowed.iter().any(|a| a == eco.cli_name()) + } else { + false + } + }) + .collect() + } else { + all_crawled + }; + + let all_purls: Vec = filtered_crawled.iter().map(|p| p.purl.clone()).collect(); + let package_count = all_purls.len(); + + if package_count == 0 { + if show_progress { + eprintln!(); + } + if args.json { + println!( + "{}", + serde_json::to_string_pretty(&serde_json::json!({ + "status": "success", + "scannedPackages": 0, + "packagesWithPatches": 0, + "totalPatches": 0, + "freePatches": 0, + "paidPatches": 0, + "canAccessPaidPatches": false, + "packages": [], + })) + .unwrap() + ); + } else if args.global || args.global_prefix.is_some() { + println!("No global packages found."); + } else { + #[allow(unused_mut)] + let mut install_cmds = String::from("npm/yarn/pnpm/pip"); + #[cfg(feature = "cargo")] + install_cmds.push_str("/cargo"); + #[cfg(feature = "golang")] + install_cmds.push_str("/go"); + #[cfg(feature = "maven")] + install_cmds.push_str("/mvn"); + #[cfg(feature = "composer")] + install_cmds.push_str("/composer"); + println!("No packages found. 
Run {install_cmds} install first."); + } + return 0; + } + + // Build ecosystem summary + let mut eco_parts = Vec::new(); + for eco in Ecosystem::all() { + let count = if args.ecosystems.is_some() { + // When filtering, count the filtered packages + filtered_crawled.iter().filter(|p| Ecosystem::from_purl(&p.purl) == Some(*eco)).count() + } else { + eco_counts.get(eco).copied().unwrap_or(0) + }; + if count > 0 { + eco_parts.push(format!("{count} {}", eco.display_name())); + } + } + let eco_summary = if eco_parts.is_empty() { + String::new() + } else { + format!(" ({})", eco_parts.join(", ")) + }; + + if !args.json { + if show_progress { + eprintln!("\rFound {package_count} packages{eco_summary}"); + } else { + eprintln!("Found {package_count} packages{eco_summary}"); + } + } + + // Query API in batches + let mut all_packages_with_patches: Vec = Vec::new(); + let mut can_access_paid_patches = false; + let total_batches = all_purls.len().div_ceil(args.batch_size); + + if show_progress { + eprint!("Querying API for patches... (batch 1/{total_batches})"); + } + + for (batch_idx, chunk) in all_purls.chunks(args.batch_size).enumerate() { + if show_progress { + eprint!( + "\rQuerying API for patches... (batch {}/{})", + batch_idx + 1, + total_batches + ); + } + + let purls: Vec = chunk.to_vec(); + match api_client + .search_patches_batch(effective_org_slug, &purls) + .await + { + Ok(response) => { + if response.can_access_paid_patches { + can_access_paid_patches = true; + } + for pkg in response.packages { + if !pkg.patches.is_empty() { + all_packages_with_patches.push(pkg); + } + } + } + Err(e) => { + if !args.json { + eprintln!("\nError querying batch {}: {e}", batch_idx + 1); + } + } + } + } + + let total_patches_found: usize = all_packages_with_patches + .iter() + .map(|p| p.patches.len()) + .sum(); + + if !args.json { + if total_patches_found > 0 { + if show_progress { + eprintln!( + "\rFound {total_patches_found} patches for {} packages", + all_packages_with_patches.len() + ); + } else { + eprintln!( + "Found {total_patches_found} patches for {} packages", + all_packages_with_patches.len() + ); + } + } else if show_progress { + eprintln!("\rAPI query complete"); + } else { + eprintln!("API query complete"); + } + } + + // Calculate patch counts + let mut free_patches = 0usize; + let mut paid_patches = 0usize; + for pkg in &all_packages_with_patches { + for patch in &pkg.patches { + if patch.tier == "free" { + free_patches += 1; + } else { + paid_patches += 1; + } + } + } + let total_patches = free_patches + paid_patches; + + if args.json { + let result = serde_json::json!({ + "status": "success", + "scannedPackages": package_count, + "packagesWithPatches": all_packages_with_patches.len(), + "totalPatches": total_patches, + "freePatches": free_patches, + "paidPatches": paid_patches, + "canAccessPaidPatches": can_access_paid_patches, + "packages": all_packages_with_patches, + }); + println!("{}", serde_json::to_string_pretty(&result).unwrap()); + return 0; + } + + let use_color = stdout_is_tty(); + + if all_packages_with_patches.is_empty() { + println!("\nNo patches available for installed packages."); + return 0; + } + + // Check manifest for existing patches (update detection) + let manifest_path = args.cwd.join(".socket").join("manifest.json"); + let existing_manifest = read_manifest(&manifest_path).await.ok().flatten(); + let mut updates_available = 0usize; + + // Print table + println!("\n{}", "=".repeat(100)); + println!( + "{} {} {} VULNERABILITIES", + "PACKAGE".to_string() + &" 
".repeat(33), + "PATCHES".to_string() + " ", + "SEVERITY".to_string() + &" ".repeat(8), + ); + println!("{}", "=".repeat(100)); + + for pkg in &all_packages_with_patches { + let max_purl_len = 40; + let display_purl = if pkg.purl.len() > max_purl_len { + format!("{}...", &pkg.purl[..max_purl_len - 3]) + } else { + pkg.purl.clone() + }; + + let pkg_free = pkg.patches.iter().filter(|p| p.tier == "free").count(); + let pkg_paid = pkg.patches.iter().filter(|p| p.tier == "paid").count(); + + let count_str = if pkg_paid > 0 { + if can_access_paid_patches { + format!("{}+{}", pkg_free, pkg_paid) + } else { + format!("{}+{}", pkg_free, color(&pkg_paid.to_string(), "33", use_color)) + } + } else { + format!("{}", pkg_free) + }; + + // Get highest severity + let severity = pkg + .patches + .iter() + .filter_map(|p| p.severity.as_deref()) + .min_by_key(|s| severity_order(s)) + .unwrap_or("unknown"); + + // Collect vuln IDs + let mut all_cves = HashSet::new(); + let mut all_ghsas = HashSet::new(); + for patch in &pkg.patches { + for cve in &patch.cve_ids { + all_cves.insert(cve.clone()); + } + for ghsa in &patch.ghsa_ids { + all_ghsas.insert(ghsa.clone()); + } + } + let vuln_ids: Vec<_> = all_cves.into_iter().chain(all_ghsas).collect(); + let vuln_str = if vuln_ids.len() > 2 { + format!( + "{} (+{})", + vuln_ids[..2].join(", "), + vuln_ids.len() - 2 + ) + } else if vuln_ids.is_empty() { + "-".to_string() + } else { + vuln_ids.join(", ") + }; + + // Check for updates + let has_update = if let Some(ref manifest) = existing_manifest { + if let Some(existing) = manifest.patches.get(&pkg.purl) { + // If any patch in the batch has a different UUID than what's in manifest, update available + pkg.patches.iter().any(|p| p.uuid != existing.uuid) + } else { + false + } + } else { + false + }; + if has_update { + updates_available += 1; + } + + let update_marker = if has_update { + color(" [UPDATE]", "33", use_color) + } else { + String::new() + }; + + println!( + "{:<40} {:>8} {:<16} {}{}", + display_purl, + count_str, + format_severity(severity, use_color), + vuln_str, + update_marker, + ); + } + + println!("{}", "=".repeat(100)); + + // Summary + if can_access_paid_patches { + println!( + "\nSummary: {} package(s) with {} available patch(es)", + all_packages_with_patches.len(), + total_patches, + ); + } else { + println!( + "\nSummary: {} package(s) with {} free patch(es)", + all_packages_with_patches.len(), + free_patches, + ); + if paid_patches > 0 { + println!( + "{}", + color( + &format!(" + {} additional patch(es) available with paid subscription", paid_patches), + "33", + use_color, + ), + ); + println!( + "\nUpgrade to Socket's paid plan to access all patches: https://socket.dev/pricing" + ); + } + } + + if updates_available > 0 { + println!( + "\n{}", + color( + &format!("{updates_available} package(s) have newer patches available."), + "33", + use_color, + ), + ); + } + + // Count downloadable patches + let downloadable_count = if can_access_paid_patches { + all_packages_with_patches.len() + } else { + all_packages_with_patches + .iter() + .filter(|pkg| pkg.patches.iter().any(|p| p.tier == "free")) + .count() + }; + + if downloadable_count == 0 { + println!("\nNo downloadable patches (paid subscription required)."); + return 0; + } + + // Fetch full PatchSearchResult for each package that has patches + if show_progress { + eprint!("\nFetching patch details..."); + } + + let mut all_search_results: Vec = Vec::new(); + for (i, pkg) in all_packages_with_patches.iter().enumerate() { + if show_progress { + 
eprint!( + "\rFetching patch details... ({}/{})", + i + 1, + all_packages_with_patches.len() + ); + } + match api_client + .search_patches_by_package(effective_org_slug, &pkg.purl) + .await + { + Ok(response) => { + all_search_results.extend(response.patches); + } + Err(e) => { + eprintln!("\n Warning: could not fetch details for {}: {e}", pkg.purl); + } + } + } + + if show_progress { + eprintln!(); + } + + if all_search_results.is_empty() { + eprintln!("Could not fetch patch details."); + return 1; + } + + // Smart selection + let selected: Vec = + match select_patches(&all_search_results, can_access_paid_patches, false) { + Ok(s) => s, + Err(code) => return code, + }; + + if selected.is_empty() { + println!("No patches selected."); + return 0; + } + + // Display detailed summary of selected patches before confirming + println!("\nPatches to apply:\n"); + for patch in &selected { + // Collect CVE/GHSA IDs and highest severity from vulnerabilities + let mut vuln_ids: Vec = Vec::new(); + let mut highest_severity: Option<&str> = None; + for (id, vuln) in &patch.vulnerabilities { + if vuln.cves.is_empty() { + vuln_ids.push(id.clone()); + } else { + for cve in &vuln.cves { + vuln_ids.push(cve.clone()); + } + } + let sev = vuln.severity.as_str(); + if highest_severity + .is_none_or(|cur| severity_order(sev) < severity_order(cur)) + { + highest_severity = Some(sev); + } + } + + let sev_display = highest_severity.unwrap_or("unknown"); + let sev_colored = format_severity(sev_display, use_color); + + let desc = if patch.description.len() > 72 { + format!("{}...", &patch.description[..69]) + } else { + patch.description.clone() + }; + + println!( + " {} [{}] {}", + patch.purl, + patch.tier.to_uppercase(), + sev_colored, + ); + if !vuln_ids.is_empty() { + println!(" Fixes: {}", vuln_ids.join(", ")); + } + // Show per-vulnerability summaries + for vuln in patch.vulnerabilities.values() { + if !vuln.summary.is_empty() { + let summary = if vuln.summary.len() > 76 { + format!("{}...", &vuln.summary[..73]) + } else { + vuln.summary.clone() + }; + let cve_label = if vuln.cves.is_empty() { + String::new() + } else { + format!("{}: ", vuln.cves.join(", ")) + }; + println!(" - {cve_label}{summary}"); + } + } + if !desc.is_empty() { + println!(" {desc}"); + } + println!(); + } + + // Prompt to download + let prompt = format!("Download and apply {} patch(es)?", selected.len()); + if !confirm(&prompt, true, args.yes, args.json) { + println!("\nTo apply a patch, run:"); + println!(" socket-patch get "); + println!(" socket-patch get "); + return 0; + } + + // Download and apply + let params = DownloadParams { + cwd: args.cwd.clone(), + org: args.org.clone(), + save_only: false, + one_off: false, + global: args.global, + global_prefix: args.global_prefix.clone(), + json: false, + silent: false, + }; + + let (code, _) = download_and_apply_patches(&selected, ¶ms).await; + code +} + +fn severity_order(s: &str) -> u8 { + match s.to_lowercase().as_str() { + "critical" => 0, + "high" => 1, + "medium" => 2, + "low" => 3, + _ => 4, + } +} diff --git a/crates/socket-patch-cli/src/commands/setup.rs b/crates/socket-patch-cli/src/commands/setup.rs new file mode 100644 index 0000000..cb7b9f5 --- /dev/null +++ b/crates/socket-patch-cli/src/commands/setup.rs @@ -0,0 +1,287 @@ +use clap::Args; +use socket_patch_core::package_json::detect::PackageManager; +use socket_patch_core::package_json::find::{ + detect_package_manager, find_package_json_files, WorkspaceType, +}; +use 
socket_patch_core::package_json::update::{update_package_json, UpdateStatus}; +use std::io::{self, Write}; +use std::path::{Path, PathBuf}; + +use crate::output::stdin_is_tty; + +#[derive(Args)] +pub struct SetupArgs { + /// Working directory + #[arg(long, default_value = ".")] + pub cwd: PathBuf, + + /// Preview changes without modifying files + #[arg(short = 'd', long = "dry-run", default_value_t = false)] + pub dry_run: bool, + + /// Skip confirmation prompt + #[arg(short = 'y', long, default_value_t = false)] + pub yes: bool, + + /// Output results as JSON + #[arg(long, default_value_t = false)] + pub json: bool, +} + +pub async fn run(args: SetupArgs) -> i32 { + if !args.json { + println!("Searching for package.json files..."); + } + + let find_result = find_package_json_files(&args.cwd).await; + + // For pnpm monorepos, only update root package.json. + // pnpm runs root postinstall on `pnpm install`, so workspace-level + // postinstall scripts are unnecessary. Individual workspaces may not + // have `@socketsecurity/socket-patch` as a dependency, causing + // `npx @socketsecurity/socket-patch apply` to fail due to pnpm's + // strict module isolation. + let package_json_files = match find_result.workspace_type { + WorkspaceType::Pnpm => find_result + .files + .into_iter() + .filter(|loc| loc.is_root) + .collect(), + _ => find_result.files, + }; + + if package_json_files.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "no_files", + "updated": 0, + "alreadyConfigured": 0, + "errors": 0, + "files": [], + })).unwrap()); + } else { + println!("No package.json files found"); + } + return 0; + } + + // Detect package manager from lockfiles in the project root. + let pm = detect_package_manager(&args.cwd).await; + + if !args.json { + println!("Found {} package.json file(s)", package_json_files.len()); + if pm == PackageManager::Pnpm { + println!("Detected pnpm project (using pnpm dlx)"); + } + } + + // Preview changes (always preview first) + let mut preview_results = Vec::new(); + for loc in &package_json_files { + let result = update_package_json(&loc.path, true, pm).await; + preview_results.push(result); + } + + // Display preview + let to_update: Vec<_> = preview_results + .iter() + .filter(|r| r.status == UpdateStatus::Updated) + .collect(); + let already_configured: Vec<_> = preview_results + .iter() + .filter(|r| r.status == UpdateStatus::AlreadyConfigured) + .collect(); + let errors: Vec<_> = preview_results + .iter() + .filter(|r| r.status == UpdateStatus::Error) + .collect(); + + if !args.json { + println!("\nPackage.json files to be updated:\n"); + + if !to_update.is_empty() { + println!("Will update:"); + for result in &to_update { + let rel_path = pathdiff(&result.path, &args.cwd); + println!(" + {rel_path}"); + if result.old_script.is_empty() { + println!(" postinstall: (no script)"); + } else { + println!(" postinstall: \"{}\"", result.old_script); + } + println!(" -> postinstall: \"{}\"", result.new_script); + if result.old_dependencies_script.is_empty() { + println!(" dependencies: (no script)"); + } else { + println!(" dependencies: \"{}\"", result.old_dependencies_script); + } + println!( + " -> dependencies: \"{}\"", + result.new_dependencies_script + ); + } + println!(); + } + + if !already_configured.is_empty() { + println!("Already configured (will skip):"); + for result in &already_configured { + let rel_path = pathdiff(&result.path, &args.cwd); + println!(" = {rel_path}"); + } + println!(); + } + + if 
!errors.is_empty() { + println!("Errors:"); + for result in &errors { + let rel_path = pathdiff(&result.path, &args.cwd); + println!( + " ! {}: {}", + rel_path, + result.error.as_deref().unwrap_or("unknown error") + ); + } + println!(); + } + } + + if to_update.is_empty() { + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "already_configured", + "updated": 0, + "alreadyConfigured": already_configured.len(), + "errors": errors.len(), + "files": preview_results.iter().map(|r| { + serde_json::json!({ + "path": r.path, + "status": match r.status { + UpdateStatus::Updated => "updated", + UpdateStatus::AlreadyConfigured => "already_configured", + UpdateStatus::Error => "error", + }, + "error": r.error, + }) + }).collect::>(), + })).unwrap()); + } else { + println!("All package.json files are already configured with socket-patch!"); + } + return 0; + } + + // If not dry-run, ask for confirmation + if !args.dry_run { + if !args.yes && !args.json { + if !stdin_is_tty() { + // Non-interactive: default to yes with warning + eprintln!("Non-interactive mode detected, proceeding automatically."); + } else { + print!("Proceed with these changes? (y/N): "); + io::stdout().flush().unwrap(); + let mut answer = String::new(); + io::stdin().read_line(&mut answer).unwrap(); + let answer = answer.trim().to_lowercase(); + if answer != "y" && answer != "yes" { + println!("Aborted"); + return 0; + } + } + } + + if !args.json { + println!("\nApplying changes..."); + } + let mut results = Vec::new(); + for loc in &package_json_files { + let result = update_package_json(&loc.path, false, pm).await; + results.push(result); + } + + let updated = results.iter().filter(|r| r.status == UpdateStatus::Updated).count(); + let already = results.iter().filter(|r| r.status == UpdateStatus::AlreadyConfigured).count(); + let errs = results.iter().filter(|r| r.status == UpdateStatus::Error).count(); + + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": if errs > 0 { "partial_failure" } else { "success" }, + "updated": updated, + "alreadyConfigured": already, + "errors": errs, + "packageManager": match pm { + PackageManager::Npm => "npm", + PackageManager::Pnpm => "pnpm", + }, + "files": results.iter().map(|r| { + serde_json::json!({ + "path": r.path, + "status": match r.status { + UpdateStatus::Updated => "updated", + UpdateStatus::AlreadyConfigured => "already_configured", + UpdateStatus::Error => "error", + }, + "error": r.error, + }) + }).collect::>(), + })).unwrap()); + } else { + println!("\nSummary:"); + println!(" {updated} file(s) updated"); + println!(" {already} file(s) already configured"); + if errs > 0 { + println!(" {errs} error(s)"); + } + } + + if errs > 0 { 1 } else { 0 } + } else { + let updated = preview_results.iter().filter(|r| r.status == UpdateStatus::Updated).count(); + let already = preview_results.iter().filter(|r| r.status == UpdateStatus::AlreadyConfigured).count(); + let errs = preview_results.iter().filter(|r| r.status == UpdateStatus::Error).count(); + + if args.json { + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "status": "dry_run", + "wouldUpdate": updated, + "alreadyConfigured": already, + "errors": errs, + "dryRun": true, + "packageManager": match pm { + PackageManager::Npm => "npm", + PackageManager::Pnpm => "pnpm", + }, + "files": preview_results.iter().map(|r| { + serde_json::json!({ + "path": r.path, + "status": match r.status { + UpdateStatus::Updated => "updated", + 
UpdateStatus::AlreadyConfigured => "already_configured",
+                        UpdateStatus::Error => "error",
+                    },
+                    "oldScript": r.old_script,
+                    "newScript": r.new_script,
+                    "oldDependenciesScript": r.old_dependencies_script,
+                    "newDependenciesScript": r.new_dependencies_script,
+                    "error": r.error,
+                })
+            }).collect::<Vec<_>>(),
+        })).unwrap());
+        } else {
+            println!("\nSummary:");
+            println!(" {updated} file(s) would be updated");
+            println!(" {already} file(s) already configured");
+            if errs > 0 {
+                println!(" {errs} error(s)");
+            }
+        }
+        0
+    }
+}
+
+fn pathdiff(path: &str, base: &Path) -> String {
+    let p = Path::new(path);
+    p.strip_prefix(base)
+        .map(|r| r.display().to_string())
+        .unwrap_or_else(|_| path.to_string())
+}
diff --git a/crates/socket-patch-cli/src/ecosystem_dispatch.rs b/crates/socket-patch-cli/src/ecosystem_dispatch.rs
new file mode 100644
index 0000000..2c499d9
--- /dev/null
+++ b/crates/socket-patch-cli/src/ecosystem_dispatch.rs
@@ -0,0 +1,706 @@
+use socket_patch_core::crawlers::{
+    CrawledPackage, CrawlerOptions, Ecosystem, NpmCrawler, PythonCrawler,
+};
+use socket_patch_core::utils::purl::strip_purl_qualifiers;
+use std::collections::{HashMap, HashSet};
+use std::path::PathBuf;
+
+#[cfg(feature = "cargo")]
+use socket_patch_core::crawlers::CargoCrawler;
+use socket_patch_core::crawlers::RubyCrawler;
+#[cfg(feature = "golang")]
+use socket_patch_core::crawlers::GoCrawler;
+#[cfg(feature = "maven")]
+use socket_patch_core::crawlers::MavenCrawler;
+#[cfg(feature = "composer")]
+use socket_patch_core::crawlers::ComposerCrawler;
+#[cfg(feature = "nuget")]
+use socket_patch_core::crawlers::NuGetCrawler;
+
+/// Partition PURLs by ecosystem, filtering by the `--ecosystems` flag if set.
+pub fn partition_purls(
+    purls: &[String],
+    allowed_ecosystems: Option<&[String]>,
+) -> HashMap<Ecosystem, Vec<String>> {
+    let mut map: HashMap<Ecosystem, Vec<String>> = HashMap::new();
+
+    for purl in purls {
+        if let Some(eco) = Ecosystem::from_purl(purl) {
+            if let Some(allowed) = allowed_ecosystems {
+                if !allowed.iter().any(|a| a == eco.cli_name()) {
+                    continue;
+                }
+            }
+            map.entry(eco).or_default().push(purl.clone());
+        }
+    }
+
+    map
+}
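+
+// A minimal usage sketch for `partition_purls` (test-only illustration; the
+// sample PURLs, and the assumptions that `Ecosystem::from_purl` recognizes
+// them and that `cli_name()` returns "npm", are not taken from this diff):
+#[cfg(test)]
+mod partition_purls_tests {
+    use super::*;
+
+    #[test]
+    fn splits_by_ecosystem_and_honors_filter() {
+        let purls = vec![
+            "pkg:npm/lodash@4.17.21".to_string(),
+            "pkg:pypi/flask@2.0.0".to_string(),
+        ];
+        // Without a filter, each recognized ecosystem gets its own bucket.
+        let all = partition_purls(&purls, None);
+        assert_eq!(all.len(), 2);
+        // With `--ecosystems npm`, non-npm PURLs are dropped entirely.
+        let npm_only = partition_purls(&purls, Some(&["npm".to_string()]));
+        assert_eq!(npm_only.len(), 1);
+        assert!(npm_only.contains_key(&Ecosystem::Npm));
+    }
+}
+
+/// For each ecosystem in the partitioned map, create the crawler, discover
+/// source paths, and look up the given PURLs. Returns a unified
+/// `purl -> path` map.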
+pub async fn find_packages_for_purls( + partitioned: &HashMap>, + options: &CrawlerOptions, + silent: bool, +) -> HashMap { + let mut all_packages: HashMap = HashMap::new(); + + // npm + if let Some(npm_purls) = partitioned.get(&Ecosystem::Npm) { + if !npm_purls.is_empty() { + let npm_crawler = NpmCrawler; + match npm_crawler.get_node_modules_paths(options).await { + Ok(nm_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = nm_paths.first() { + println!("Using global npm packages at: {}", first.display()); + } + } + for nm_path in &nm_paths { + match npm_crawler.find_by_purls(nm_path, npm_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", nm_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find npm packages: {e}"); + } + } + } + } + } + + // pypi — deduplicate by base PURL (stripping qualifiers) + if let Some(pypi_purls) = partitioned.get(&Ecosystem::Pypi) { + if !pypi_purls.is_empty() { + let python_crawler = PythonCrawler; + let base_pypi_purls: Vec = pypi_purls + .iter() + .map(|p| strip_purl_qualifiers(p).to_string()) + .collect::>() + .into_iter() + .collect(); + + match python_crawler.get_site_packages_paths(options).await { + Ok(sp_paths) => { + for sp_path in &sp_paths { + match python_crawler.find_by_purls(sp_path, &base_pypi_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", sp_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Python packages: {e}"); + } + } + } + } + } + + // cargo + #[cfg(feature = "cargo")] + if let Some(cargo_purls) = partitioned.get(&Ecosystem::Cargo) { + if !cargo_purls.is_empty() { + let cargo_crawler = CargoCrawler; + match cargo_crawler.get_crate_source_paths(options).await { + Ok(src_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = src_paths.first() { + println!("Using cargo crate sources at: {}", first.display()); + } + } + for src_path in &src_paths { + match cargo_crawler.find_by_purls(src_path, cargo_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", src_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Cargo crates: {e}"); + } + } + } + } + } + + // gem + if let Some(gem_purls) = partitioned.get(&Ecosystem::Gem) { + if !gem_purls.is_empty() { + let ruby_crawler = RubyCrawler; + match ruby_crawler.get_gem_paths(options).await { + Ok(gem_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = gem_paths.first() { + println!("Using ruby gem paths at: {}", first.display()); + } + } + for gem_path in &gem_paths { + match ruby_crawler.find_by_purls(gem_path, gem_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", gem_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Ruby gems: {e}"); + } + } + } + } + } + + // golang + #[cfg(feature = 
"golang")] + if let Some(golang_purls) = partitioned.get(&Ecosystem::Golang) { + if !golang_purls.is_empty() { + let go_crawler = GoCrawler; + match go_crawler.get_module_cache_paths(options).await { + Ok(cache_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = cache_paths.first() { + println!("Using Go module cache at: {}", first.display()); + } + } + for cache_path in &cache_paths { + match go_crawler.find_by_purls(cache_path, golang_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", cache_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Go modules: {e}"); + } + } + } + } + } + + // maven + #[cfg(feature = "maven")] + if let Some(maven_purls) = partitioned.get(&Ecosystem::Maven) { + if !maven_purls.is_empty() { + let maven_crawler = MavenCrawler; + match maven_crawler.get_maven_repo_paths(options).await { + Ok(repo_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = repo_paths.first() { + println!("Using Maven repository at: {}", first.display()); + } + } + for repo_path in &repo_paths { + match maven_crawler.find_by_purls(repo_path, maven_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", repo_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Maven packages: {e}"); + } + } + } + } + } + + // composer + #[cfg(feature = "composer")] + if let Some(composer_purls) = partitioned.get(&Ecosystem::Composer) { + if !composer_purls.is_empty() { + let composer_crawler = ComposerCrawler; + match composer_crawler.get_vendor_paths(options).await { + Ok(vendor_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = vendor_paths.first() { + println!("Using PHP vendor packages at: {}", first.display()); + } + } + for vendor_path in &vendor_paths { + match composer_crawler.find_by_purls(vendor_path, composer_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", vendor_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find PHP packages: {e}"); + } + } + } + } + } + + // nuget + #[cfg(feature = "nuget")] + if let Some(nuget_purls) = partitioned.get(&Ecosystem::Nuget) { + if !nuget_purls.is_empty() { + let nuget_crawler = NuGetCrawler; + match nuget_crawler.get_nuget_package_paths(options).await { + Ok(pkg_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = pkg_paths.first() { + println!("Using NuGet packages at: {}", first.display()); + } + } + for pkg_path in &pkg_paths { + match nuget_crawler.find_by_purls(pkg_path, nuget_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", pkg_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find NuGet packages: {e}"); + } + } + } + } + } + + all_packages +} + +/// Crawl all enabled ecosystems and return all packages 
plus per-ecosystem counts.
+pub async fn crawl_all_ecosystems(
+    options: &CrawlerOptions,
+) -> (Vec<CrawledPackage>, HashMap<Ecosystem, usize>) {
+    let mut all_packages = Vec::new();
+    let mut counts: HashMap<Ecosystem, usize> = HashMap::new();
+
+    let npm_crawler = NpmCrawler;
+    let npm_packages = npm_crawler.crawl_all(options).await;
+    counts.insert(Ecosystem::Npm, npm_packages.len());
+    all_packages.extend(npm_packages);
+
+    let python_crawler = PythonCrawler;
+    let python_packages = python_crawler.crawl_all(options).await;
+    counts.insert(Ecosystem::Pypi, python_packages.len());
+    all_packages.extend(python_packages);
+
+    #[cfg(feature = "cargo")]
+    {
+        let cargo_crawler = CargoCrawler;
+        let cargo_packages = cargo_crawler.crawl_all(options).await;
+        counts.insert(Ecosystem::Cargo, cargo_packages.len());
+        all_packages.extend(cargo_packages);
+    }
+
+    {
+        let ruby_crawler = RubyCrawler;
+        let gem_packages = ruby_crawler.crawl_all(options).await;
+        counts.insert(Ecosystem::Gem, gem_packages.len());
+        all_packages.extend(gem_packages);
+    }
+
+    #[cfg(feature = "golang")]
+    {
+        let go_crawler = GoCrawler;
+        let go_packages = go_crawler.crawl_all(options).await;
+        counts.insert(Ecosystem::Golang, go_packages.len());
+        all_packages.extend(go_packages);
+    }
+
+    #[cfg(feature = "maven")]
+    {
+        let maven_crawler = MavenCrawler;
+        let maven_packages = maven_crawler.crawl_all(options).await;
+        counts.insert(Ecosystem::Maven, maven_packages.len());
+        all_packages.extend(maven_packages);
+    }
+
+    #[cfg(feature = "composer")]
+    {
+        let composer_crawler = ComposerCrawler;
+        let composer_packages = composer_crawler.crawl_all(options).await;
+        counts.insert(Ecosystem::Composer, composer_packages.len());
+        all_packages.extend(composer_packages);
+    }
+
+    #[cfg(feature = "nuget")]
+    {
+        let nuget_crawler = NuGetCrawler;
+        let nuget_packages = nuget_crawler.crawl_all(options).await;
+        counts.insert(Ecosystem::Nuget, nuget_packages.len());
+        all_packages.extend(nuget_packages);
+    }
+
+    (all_packages, counts)
+}
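+
+// A small sketch of the remapping invariant the rollback variant below relies
+// on (the qualifier value is an illustrative assumption, not taken from a real
+// manifest): stripping qualifiers yields the base PURL the crawler reports.
+#[cfg(test)]
+mod purl_remap_tests {
+    use super::*;
+
+    #[test]
+    fn qualified_pypi_purl_strips_to_base() {
+        let qualified = "pkg:pypi/flask@2.0.0?artifact_id=flask-2.0.0.tar.gz";
+        assert_eq!(strip_purl_qualifiers(qualified), "pkg:pypi/flask@2.0.0");
+    }
+}
+
+/// Variant of `find_packages_for_purls` for rollback, which needs to remap
+/// pypi qualified PURLs (with `?artifact_id=...`) to the base PURL found
+/// by the crawler.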
+pub async fn find_packages_for_rollback( + partitioned: &HashMap>, + options: &CrawlerOptions, + silent: bool, +) -> HashMap { + let mut all_packages: HashMap = HashMap::new(); + + // npm + if let Some(npm_purls) = partitioned.get(&Ecosystem::Npm) { + if !npm_purls.is_empty() { + let npm_crawler = NpmCrawler; + match npm_crawler.get_node_modules_paths(options).await { + Ok(nm_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = nm_paths.first() { + println!("Using global npm packages at: {}", first.display()); + } + } + for nm_path in &nm_paths { + match npm_crawler.find_by_purls(nm_path, npm_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", nm_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find npm packages: {e}"); + } + } + } + } + } + + // pypi — remap qualified PURLs to found base PURLs + if let Some(pypi_purls) = partitioned.get(&Ecosystem::Pypi) { + if !pypi_purls.is_empty() { + let python_crawler = PythonCrawler; + let base_pypi_purls: Vec = pypi_purls + .iter() + .map(|p| strip_purl_qualifiers(p).to_string()) + .collect::>() + .into_iter() + .collect(); + + if let Ok(sp_paths) = python_crawler.get_site_packages_paths(options).await { + for sp_path in &sp_paths { + match python_crawler.find_by_purls(sp_path, &base_pypi_purls).await { + Ok(packages) => { + for (base_purl, pkg) in packages { + for qualified_purl in pypi_purls { + if strip_purl_qualifiers(qualified_purl) == base_purl + && !all_packages.contains_key(qualified_purl) + { + all_packages + .insert(qualified_purl.clone(), pkg.path.clone()); + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", sp_path.display(), e); + } + } + } + } + } + } + } + + // cargo + #[cfg(feature = "cargo")] + if let Some(cargo_purls) = partitioned.get(&Ecosystem::Cargo) { + if !cargo_purls.is_empty() { + let cargo_crawler = CargoCrawler; + match cargo_crawler.get_crate_source_paths(options).await { + Ok(src_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = src_paths.first() { + println!("Using cargo crate sources at: {}", first.display()); + } + } + for src_path in &src_paths { + match cargo_crawler.find_by_purls(src_path, cargo_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", src_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Cargo crates: {e}"); + } + } + } + } + } + + // gem + if let Some(gem_purls) = partitioned.get(&Ecosystem::Gem) { + if !gem_purls.is_empty() { + let ruby_crawler = RubyCrawler; + match ruby_crawler.get_gem_paths(options).await { + Ok(gem_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = gem_paths.first() { + println!("Using ruby gem paths at: {}", first.display()); + } + } + for gem_path in &gem_paths { + match ruby_crawler.find_by_purls(gem_path, gem_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", gem_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find 
Ruby gems: {e}"); + } + } + } + } + } + + // golang + #[cfg(feature = "golang")] + if let Some(golang_purls) = partitioned.get(&Ecosystem::Golang) { + if !golang_purls.is_empty() { + let go_crawler = GoCrawler; + match go_crawler.get_module_cache_paths(options).await { + Ok(cache_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = cache_paths.first() { + println!("Using Go module cache at: {}", first.display()); + } + } + for cache_path in &cache_paths { + match go_crawler.find_by_purls(cache_path, golang_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", cache_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Go modules: {e}"); + } + } + } + } + } + + // maven + #[cfg(feature = "maven")] + if let Some(maven_purls) = partitioned.get(&Ecosystem::Maven) { + if !maven_purls.is_empty() { + let maven_crawler = MavenCrawler; + match maven_crawler.get_maven_repo_paths(options).await { + Ok(repo_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = repo_paths.first() { + println!("Using Maven repository at: {}", first.display()); + } + } + for repo_path in &repo_paths { + match maven_crawler.find_by_purls(repo_path, maven_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", repo_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Maven packages: {e}"); + } + } + } + } + } + + // composer + #[cfg(feature = "composer")] + if let Some(composer_purls) = partitioned.get(&Ecosystem::Composer) { + if !composer_purls.is_empty() { + let composer_crawler = ComposerCrawler; + match composer_crawler.get_vendor_paths(options).await { + Ok(vendor_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = vendor_paths.first() { + println!("Using PHP vendor packages at: {}", first.display()); + } + } + for vendor_path in &vendor_paths { + match composer_crawler.find_by_purls(vendor_path, composer_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", vendor_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find PHP packages: {e}"); + } + } + } + } + } + + // nuget + #[cfg(feature = "nuget")] + if let Some(nuget_purls) = partitioned.get(&Ecosystem::Nuget) { + if !nuget_purls.is_empty() { + let nuget_crawler = NuGetCrawler; + match nuget_crawler.get_nuget_package_paths(options).await { + Ok(pkg_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = pkg_paths.first() { + println!("Using NuGet packages at: {}", first.display()); + } + } + for pkg_path in &pkg_paths { + match nuget_crawler.find_by_purls(pkg_path, nuget_purls).await { + Ok(packages) => { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + Err(e) => { + if !silent { + eprintln!("Warning: Failed to scan {}: {}", pkg_path.display(), e); + } + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find NuGet packages: {e}"); + } + } + } + } + } + + 
all_packages
+}
diff --git a/crates/socket-patch-cli/src/main.rs b/crates/socket-patch-cli/src/main.rs
new file mode 100644
index 0000000..e278400
--- /dev/null
+++ b/crates/socket-patch-cli/src/main.rs
@@ -0,0 +1,94 @@
+mod commands;
+mod ecosystem_dispatch;
+mod output;
+
+use clap::{Parser, Subcommand};
+
+#[derive(Parser)]
+#[command(
+    name = "socket-patch",
+    about = "CLI tool for applying security patches to dependencies",
+    version,
+    propagate_version = true
+)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand)]
+enum Commands {
+    /// Apply security patches to dependencies
+    Apply(commands::apply::ApplyArgs),
+
+    /// Rollback patches to restore original files
+    Rollback(commands::rollback::RollbackArgs),
+
+    /// Get security patches from Socket API and apply them
+    #[command(visible_alias = "download")]
+    Get(commands::get::GetArgs),
+
+    /// Scan installed packages for available security patches
+    Scan(commands::scan::ScanArgs),
+
+    /// List all patches in the local manifest
+    List(commands::list::ListArgs),
+
+    /// Remove a patch from the manifest by PURL or UUID (rolls back files first)
+    Remove(commands::remove::RemoveArgs),
+
+    /// Configure package.json postinstall scripts to apply patches
+    Setup(commands::setup::SetupArgs),
+
+    /// Download missing blobs and clean up unused blobs
+    #[command(visible_alias = "gc")]
+    Repair(commands::repair::RepairArgs),
+}
+
+/// Check whether `s` looks like a UUID (8-4-4-4-12 hex pattern).
+fn looks_like_uuid(s: &str) -> bool {
+    let parts: Vec<&str> = s.split('-').collect();
+    if parts.len() != 5 {
+        return false;
+    }
+    let expected = [8, 4, 4, 4, 12];
+    parts
+        .iter()
+        .zip(expected.iter())
+        .all(|(p, &len)| p.len() == len && p.chars().all(|c| c.is_ascii_hexdigit()))
+}
+
+#[tokio::main]
+async fn main() {
+    let cli = match Cli::try_parse() {
+        Ok(cli) => cli,
+        Err(err) => {
+            // If parsing failed, check whether the user passed a bare UUID
+            // (e.g. `socket-patch 80630680-...`) and retry as `get ...`.
+            let args: Vec<String> = std::env::args().collect();
+            if args.len() >= 2 && looks_like_uuid(&args[1]) {
+                let mut new_args = vec![args[0].clone(), "get".into()];
+                new_args.extend_from_slice(&args[1..]);
+                match Cli::try_parse_from(&new_args) {
+                    Ok(cli) => cli,
+                    Err(_) => err.exit(),
+                }
+            } else {
+                err.exit()
+            }
+        }
+    };
+
+    let exit_code = match cli.command {
+        Commands::Apply(args) => commands::apply::run(args).await,
+        Commands::Rollback(args) => commands::rollback::run(args).await,
+        Commands::Get(args) => commands::get::run(args).await,
+        Commands::Scan(args) => commands::scan::run(args).await,
+        Commands::List(args) => commands::list::run(args).await,
+        Commands::Remove(args) => commands::remove::run(args).await,
+        Commands::Setup(args) => commands::setup::run(args).await,
+        Commands::Repair(args) => commands::repair::run(args).await,
+    };
+
+    std::process::exit(exit_code);
+}
diff --git a/crates/socket-patch-cli/src/output.rs b/crates/socket-patch-cli/src/output.rs
new file mode 100644
index 0000000..c92f1ae
--- /dev/null
+++ b/crates/socket-patch-cli/src/output.rs
@@ -0,0 +1,97 @@
+use std::io::{self, IsTerminal, Write};
+
+/// Check if stdout is a terminal (for ANSI color output).
+pub fn stdout_is_tty() -> bool {
+    std::io::stdout().is_terminal()
+}
+
+/// Check if stderr is a terminal (for progress output).
+pub fn stderr_is_tty() -> bool {
+    std::io::stderr().is_terminal()
+}
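+
+// Usage sketch (mirrors the gating in scan.rs; the `args.json` field is that
+// command's flag, shown here for context): progress is written to stderr, and
+// only when stderr is a real terminal, so piped stdout stays machine-readable.
+//
+//     let show_progress = !args.json && stderr_is_tty();
+//     if show_progress {
+//         eprint!("Scanning packages...");
+//     }
+
+/// Check if stdin is a terminal (for interactive prompts).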
+pub fn stdin_is_tty() -> bool {
+    std::io::stdin().is_terminal()
+}
+
+/// Format a severity string with optional ANSI colors.
+pub fn format_severity(s: &str, use_color: bool) -> String {
+    if !use_color {
+        return s.to_string();
+    }
+    match s.to_lowercase().as_str() {
+        "critical" => format!("\x1b[31m{s}\x1b[0m"),
+        "high" => format!("\x1b[91m{s}\x1b[0m"),
+        "medium" => format!("\x1b[33m{s}\x1b[0m"),
+        "low" => format!("\x1b[36m{s}\x1b[0m"),
+        _ => s.to_string(),
+    }
+}
+
+/// Wrap text in ANSI color codes if use_color is true.
+pub fn color(text: &str, code: &str, use_color: bool) -> String {
+    if use_color {
+        format!("\x1b[{code}m{text}\x1b[0m")
+    } else {
+        text.to_string()
+    }
+}
+
+/// Error type for interactive selection.
+pub enum SelectError {
+    /// User cancelled the selection.
+    Cancelled,
+    /// JSON mode requires explicit selection (e.g. via --id).
+    JsonModeNeedsExplicit,
+}
+
+/// Prompt the user for a yes/no confirmation.
+///
+/// - `skip_prompt` (from `-y` flag) or `is_json`: return `default_yes` immediately.
+/// - Non-TTY stdin: return `default_yes` with a stderr warning.
+/// - Interactive: print prompt to stderr, read line; empty = `default_yes`.
+pub fn confirm(prompt: &str, default_yes: bool, skip_prompt: bool, is_json: bool) -> bool {
+    if skip_prompt || is_json {
+        return default_yes;
+    }
+    if !stdin_is_tty() {
+        eprintln!("Non-interactive mode detected, proceeding with default.");
+        return default_yes;
+    }
+    let hint = if default_yes { "[Y/n]" } else { "[y/N]" };
+    eprint!("{prompt} {hint} ");
+    io::stderr().flush().unwrap();
+    let mut answer = String::new();
+    io::stdin().read_line(&mut answer).unwrap();
+    let answer = answer.trim().to_lowercase();
+    if answer.is_empty() {
+        return default_yes;
+    }
+    answer == "y" || answer == "yes"
+}
+
+/// Prompt the user to select one option from a list using dialoguer.
+///
+/// - `is_json`: return `Err(SelectError::JsonModeNeedsExplicit)`.
+/// - Non-TTY: auto-select first option with stderr warning.
+/// - Interactive: use `dialoguer::Select` on stderr.
+pub fn select_one(prompt: &str, options: &[String], is_json: bool) -> Result<usize, SelectError> {
+    if is_json {
+        return Err(SelectError::JsonModeNeedsExplicit);
+    }
+    if !stdin_is_tty() {
+        eprintln!("Non-interactive mode: auto-selecting first option.");
+        return Ok(0);
+    }
+    let selection = dialoguer::Select::with_theme(&dialoguer::theme::ColorfulTheme::default())
+        .with_prompt(prompt)
+        .items(options)
+        .default(0)
+        .interact_opt()
+        .map_err(|_| SelectError::Cancelled)?;
+    match selection {
+        Some(idx) => Ok(idx),
+        None => Err(SelectError::Cancelled),
+    }
+}
diff --git a/crates/socket-patch-cli/tests/e2e_cargo.rs b/crates/socket-patch-cli/tests/e2e_cargo.rs
new file mode 100644
index 0000000..c4be5bb
--- /dev/null
+++ b/crates/socket-patch-cli/tests/e2e_cargo.rs
@@ -0,0 +1,112 @@
+#![cfg(feature = "cargo")]
+//! End-to-end tests for the Cargo/Rust crate patching lifecycle.
+//!
+//! These tests exercise crawling against a temporary directory with a fake
+//! Cargo registry layout. They do **not** require network access or a real
+//! Cargo installation.
+//!
+//! # Running
+//! ```sh
+//! cargo test -p socket-patch-cli --features cargo --test e2e_cargo
+//! 
``` + +use std::path::PathBuf; +use std::process::{Command, Output}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn binary() -> PathBuf { + env!("CARGO_BIN_EXE_socket-patch").into() +} + +fn run(args: &[&str], cwd: &std::path::Path) -> Output { + Command::new(binary()) + .args(args) + .current_dir(cwd) + .env("CARGO_HOME", cwd.join(".cargo")) + .output() + .expect("Failed to run socket-patch binary") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Verify that `socket-patch scan` discovers crates in a fake registry layout. +#[test] +fn scan_discovers_fake_registry_crates() { + let dir = tempfile::tempdir().unwrap(); + + // Set up a fake CARGO_HOME/registry/src/index.crates.io-xxx/ structure + let index_dir = dir + .path() + .join(".cargo") + .join("registry") + .join("src") + .join("index.crates.io-test"); + + // Create serde-1.0.200 + let serde_dir = index_dir.join("serde-1.0.200"); + std::fs::create_dir_all(&serde_dir).unwrap(); + std::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .unwrap(); + + // Create tokio-1.38.0 + let tokio_dir = index_dir.join("tokio-1.38.0"); + std::fs::create_dir_all(&tokio_dir).unwrap(); + std::fs::write( + tokio_dir.join("Cargo.toml"), + "[package]\nname = \"tokio\"\nversion = \"1.38.0\"\n", + ) + .unwrap(); + + // Run scan (will fail to connect to API, but we just check discovery) + let output = run(&["scan", "--cwd", dir.path().to_str().unwrap()], dir.path()); + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + let combined = format!("{stdout}{stderr}"); + + // Should discover the crates (output mentions "Found X packages") + assert!( + combined.contains("Found") || combined.contains("packages"), + "Expected scan to discover crate packages, got:\n{combined}" + ); +} + +/// Verify that `socket-patch scan` discovers crates in a vendor layout. +#[test] +fn scan_discovers_vendor_crates() { + let dir = tempfile::tempdir().unwrap(); + + // Set up vendor directory + let vendor_dir = dir.path().join("vendor"); + + let serde_dir = vendor_dir.join("serde"); + std::fs::create_dir_all(&serde_dir).unwrap(); + std::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .unwrap(); + + // Run scan with JSON output to avoid API calls + let output = run( + &["scan", "--json", "--cwd", dir.path().to_str().unwrap()], + dir.path(), + ); + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + // JSON output should show scannedPackages >= 1 (the vendor crate) + // or at minimum the scan should report finding packages + let combined = format!("{stdout}{stderr}"); + assert!( + combined.contains("scannedPackages") || combined.contains("Found"), + "Expected scan output, got:\n{combined}" + ); +} diff --git a/crates/socket-patch-cli/tests/e2e_composer.rs b/crates/socket-patch-cli/tests/e2e_composer.rs new file mode 100644 index 0000000..6ceb8ab --- /dev/null +++ b/crates/socket-patch-cli/tests/e2e_composer.rs @@ -0,0 +1,125 @@ +#![cfg(feature = "composer")] +//! End-to-end tests for the Composer/PHP package patching lifecycle. +//! +//! 
These tests exercise crawling against a temporary directory with a fake +//! Composer vendor layout. They do **not** require network access or a real +//! PHP/Composer installation. +//! +//! # Running +//! ```sh +//! cargo test -p socket-patch-cli --features composer --test e2e_composer +//! ``` + +use std::path::PathBuf; +use std::process::{Command, Output}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn binary() -> PathBuf { + env!("CARGO_BIN_EXE_socket-patch").into() +} + +fn run(args: &[&str], cwd: &std::path::Path) -> Output { + Command::new(binary()) + .args(args) + .current_dir(cwd) + .output() + .expect("Failed to run socket-patch binary") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Verify that `socket-patch scan` discovers packages via Composer 2 installed.json. +#[test] +fn scan_discovers_composer2_packages() { + let dir = tempfile::tempdir().unwrap(); + let project_dir = dir.path().join("project"); + std::fs::create_dir_all(&project_dir).unwrap(); + + // Create composer.json so local mode activates + std::fs::write( + project_dir.join("composer.json"), + r#"{"require": {"monolog/monolog": "^3.0"}}"#, + ) + .unwrap(); + + // Set up vendor directory with installed.json (Composer 2 format) + let vendor_dir = project_dir.join("vendor"); + let composer_dir = vendor_dir.join("composer"); + std::fs::create_dir_all(&composer_dir).unwrap(); + + // Create Composer 2 installed.json with packages array + std::fs::write( + composer_dir.join("installed.json"), + r#"{"packages": [ + {"name": "monolog/monolog", "version": "3.5.0"}, + {"name": "symfony/console", "version": "6.4.1"} + ]}"#, + ) + .unwrap(); + + // Create the actual vendor directories for the packages + std::fs::create_dir_all(vendor_dir.join("monolog").join("monolog")).unwrap(); + std::fs::create_dir_all(vendor_dir.join("symfony").join("console")).unwrap(); + + let output = run( + &["scan", "--cwd", project_dir.to_str().unwrap()], + &project_dir, + ); + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + let combined = format!("{stdout}{stderr}"); + + assert!( + combined.contains("Found") || combined.contains("packages"), + "Expected scan to discover Composer packages, got:\n{combined}" + ); +} + +/// Verify that `socket-patch scan` discovers packages via Composer 1 installed.json (flat array). 
+#[test]
+fn scan_discovers_composer1_packages() {
+    let dir = tempfile::tempdir().unwrap();
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+
+    // Create composer.lock so local mode activates
+    std::fs::write(
+        project_dir.join("composer.lock"),
+        r#"{"packages": []}"#,
+    )
+    .unwrap();
+
+    // Set up vendor directory with Composer 1 installed.json (flat array)
+    let vendor_dir = project_dir.join("vendor");
+    let composer_dir = vendor_dir.join("composer");
+    std::fs::create_dir_all(&composer_dir).unwrap();
+
+    std::fs::write(
+        composer_dir.join("installed.json"),
+        r#"[
+            {"name": "guzzlehttp/guzzle", "version": "7.8.1"}
+        ]"#,
+    )
+    .unwrap();
+
+    // Create the actual vendor directory for the package
+    std::fs::create_dir_all(vendor_dir.join("guzzlehttp").join("guzzle")).unwrap();
+
+    let output = run(
+        &["scan", "--json", "--cwd", project_dir.to_str().unwrap()],
+        &project_dir,
+    );
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("scannedPackages") || combined.contains("Found"),
+        "Expected scan output, got:\n{combined}"
+    );
+}
diff --git a/crates/socket-patch-cli/tests/e2e_gem.rs b/crates/socket-patch-cli/tests/e2e_gem.rs
new file mode 100644
index 0000000..e46fb9d
--- /dev/null
+++ b/crates/socket-patch-cli/tests/e2e_gem.rs
@@ -0,0 +1,448 @@
+//! End-to-end tests for the RubyGems patch lifecycle.
+//!
+//! Non-ignored tests exercise crawling against a temporary directory with fake
+//! gem layouts. They do **not** require network access or a real Ruby
+//! installation.
+//!
+//! Ignored tests exercise the full CLI against the real Socket API, using the
+//! **activestorage@5.2.0** patch (UUID `4bf7fe0b-dc57-4ea8-945f-bc4a04c47a15`),
+//! which fixes CVE-2022-21831 (code injection).
+//!
+//! # Running
+//! ```sh
+//! # Scan tests (no network needed)
+//! cargo test -p socket-patch-cli --test e2e_gem
+//!
+//! # Full lifecycle (needs bundler + network)
+//! cargo test -p socket-patch-cli --test e2e_gem -- --ignored
+//! ```
+
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::process::{Command, Output};
+
+use sha2::{Digest, Sha256};
+
+// ---------------------------------------------------------------------------
+// Constants
+// ---------------------------------------------------------------------------
+
+const GEM_UUID: &str = "4bf7fe0b-dc57-4ea8-945f-bc4a04c47a15";
+const GEM_PURL: &str = "pkg:gem/activestorage@5.2.0";
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+fn binary() -> PathBuf {
+    env!("CARGO_BIN_EXE_socket-patch").into()
+}
+
+fn has_command(cmd: &str) -> bool {
+    Command::new(cmd)
+        .arg("--version")
+        .stdout(std::process::Stdio::null())
+        .stderr(std::process::Stdio::null())
+        .status()
+        .is_ok()
+}
+
+/// Compute Git SHA-256: `SHA256("blob <len>\0" + content)`.
+fn git_sha256(content: &[u8]) -> String {
+    let header = format!("blob {}\0", content.len());
+    let mut hasher = Sha256::new();
+    hasher.update(header.as_bytes());
+    hasher.update(content);
+    hex::encode(hasher.finalize())
+}
+
+fn git_sha256_file(path: &Path) -> String {
+    let content = std::fs::read(path).unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
+    git_sha256(&content)
+}
+
+fn run(cwd: &Path, args: &[&str]) -> (i32, String, String) {
+    let out: Output = Command::new(binary())
+        .args(args)
+        .current_dir(cwd)
+        .env_remove("SOCKET_API_TOKEN")
+        .output()
+        .expect("failed to execute socket-patch binary");
+
+    let code = out.status.code().unwrap_or(-1);
+    let stdout = String::from_utf8_lossy(&out.stdout).to_string();
+    let stderr = String::from_utf8_lossy(&out.stderr).to_string();
+    (code, stdout, stderr)
+}
+
+fn assert_run_ok(cwd: &Path, args: &[&str], context: &str) -> (String, String) {
+    let (code, stdout, stderr) = run(cwd, args);
+    assert_eq!(
+        code, 0,
+        "{context} failed (exit {code}).\nstdout:\n{stdout}\nstderr:\n{stderr}"
+    );
+    (stdout, stderr)
+}
+
+fn bundle_run(cwd: &Path, args: &[&str]) {
+    let out = Command::new("bundle")
+        .args(args)
+        .current_dir(cwd)
+        .output()
+        .expect("failed to run bundle");
+    assert!(
+        out.status.success(),
+        "bundle {args:?} failed (exit {:?}).\nstdout:\n{}\nstderr:\n{}",
+        out.status.code(),
+        String::from_utf8_lossy(&out.stdout),
+        String::from_utf8_lossy(&out.stderr),
+    );
+}
+
+/// Write a minimal Gemfile that installs activestorage 5.2.0.
+fn write_gemfile(cwd: &Path) {
+    std::fs::write(
+        cwd.join("Gemfile"),
+        "source 'https://rubygems.org'\ngem 'activestorage', '5.2.0'\n",
+    )
+    .expect("write Gemfile");
+}
+
+/// Locate the gem install directory under vendor/bundle/ruby/*/gems/activestorage-5.2.0.
+fn find_gem_dir(cwd: &Path) -> PathBuf {
+    let ruby_dir = cwd.join("vendor/bundle/ruby");
+    for entry in std::fs::read_dir(&ruby_dir).expect("read vendor/bundle/ruby") {
+        let entry = entry.unwrap();
+        let gem_dir = entry.path().join("gems").join("activestorage-5.2.0");
+        if gem_dir.exists() {
+            return gem_dir;
+        }
+    }
+    panic!(
+        "could not find activestorage-5.2.0 gem dir under {}",
+        ruby_dir.display()
+    );
+}
+
+/// Read the manifest and return the files map for the gem patch.
+fn read_patch_files(manifest_path: &Path) -> serde_json::Value {
+    let manifest: serde_json::Value =
+        serde_json::from_str(&std::fs::read_to_string(manifest_path).unwrap()).unwrap();
+    let patch = &manifest["patches"][GEM_PURL];
+    assert!(patch.is_object(), "manifest should contain {GEM_PURL}");
+    patch["files"].clone()
+}
+
+/// Record hashes of all files in the gem dir that will be patched.
+fn record_original_hashes(gem_dir: &Path, files: &serde_json::Value) -> HashMap<String, String> {
+    let mut hashes = HashMap::new();
+    for (rel_path, _) in files.as_object().expect("files object") {
+        let full_path = gem_dir.join(rel_path);
+        let hash = if full_path.exists() {
+            git_sha256_file(&full_path)
+        } else {
+            String::new()
+        };
+        hashes.insert(rel_path.clone(), hash);
+    }
+    hashes
+}
+
+/// Verify all patched files match their afterHash from the manifest.
+fn assert_after_hashes(gem_dir: &Path, files: &serde_json::Value) {
+    for (rel_path, info) in files.as_object().expect("files object") {
+        let after_hash = info["afterHash"]
+            .as_str()
+            .expect("afterHash should be a string");
+        let full_path = gem_dir.join(rel_path);
+        assert!(
+            full_path.exists(),
+            "patched file should exist: {}",
+            full_path.display()
+        );
+        assert_eq!(
+            git_sha256_file(&full_path),
+            after_hash,
+            "hash mismatch for {rel_path} after patching"
+        );
+    }
+}
+
+/// Verify all patched files match their beforeHash (or are removed if new).
+fn assert_before_hashes(gem_dir: &Path, files: &serde_json::Value) {
+    for (rel_path, info) in files.as_object().expect("files object") {
+        let before_hash = info["beforeHash"].as_str().unwrap_or("");
+        let full_path = gem_dir.join(rel_path);
+        if before_hash.is_empty() {
+            assert!(
+                !full_path.exists(),
+                "new file {rel_path} should be removed after rollback"
+            );
+        } else {
+            assert_eq!(
+                git_sha256_file(&full_path),
+                before_hash,
+                "{rel_path} should match beforeHash"
+            );
+        }
+    }
+}
+
+/// Verify files match the originally recorded hashes.
+fn assert_original_hashes(gem_dir: &Path, original_hashes: &HashMap<String, String>) {
+    for (rel_path, orig_hash) in original_hashes {
+        if orig_hash.is_empty() {
+            continue;
+        }
+        let full_path = gem_dir.join(rel_path);
+        if full_path.exists() {
+            assert_eq!(
+                git_sha256_file(&full_path),
+                *orig_hash,
+                "{rel_path} should match original hash"
+            );
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Scan tests (no network needed)
+// ---------------------------------------------------------------------------
+
+/// Verify that `socket-patch scan` discovers gems in a vendor/bundle layout.
+#[test]
+fn scan_discovers_vendored_gems() {
+    let dir = tempfile::tempdir().unwrap();
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+
+    // Create Gemfile so local mode activates
+    std::fs::write(project_dir.join("Gemfile"), "source 'https://rubygems.org'\n").unwrap();
+
+    // Set up vendor/bundle/ruby/<version>/gems/ layout
+    let gems_dir = project_dir
+        .join("vendor")
+        .join("bundle")
+        .join("ruby")
+        .join("3.2.0")
+        .join("gems");
+
+    // Create rails-7.1.0 with lib/ marker
+    let rails_dir = gems_dir.join("rails-7.1.0");
+    std::fs::create_dir_all(rails_dir.join("lib")).unwrap();
+
+    // Create nokogiri-1.15.4 with lib/ marker
+    let nokogiri_dir = gems_dir.join("nokogiri-1.15.4");
+    std::fs::create_dir_all(nokogiri_dir.join("lib")).unwrap();
+
+    let output = Command::new(binary())
+        .args(["scan", "--cwd", project_dir.to_str().unwrap()])
+        .current_dir(&project_dir)
+        .output()
+        .expect("Failed to run socket-patch binary");
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("Found") || combined.contains("packages"),
+        "Expected scan to discover vendored gems, got:\n{combined}"
+    );
+}
+
+/// Verify that `socket-patch scan` discovers gems with gemspec markers.
+#[test]
+fn scan_discovers_gems_with_gemspec() {
+    let dir = tempfile::tempdir().unwrap();
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+
+    // Create Gemfile.lock so local mode activates
+    std::fs::write(project_dir.join("Gemfile.lock"), "GEM\n  specs:\n").unwrap();
+
+    // Set up vendor/bundle/ruby/<version>/gems/ layout
+    let gems_dir = project_dir
+        .join("vendor")
+        .join("bundle")
+        .join("ruby")
+        .join("3.1.0")
+        .join("gems");
+
+    // Create net-http-0.4.1 with .gemspec marker (no lib/)
+    let net_http_dir = gems_dir.join("net-http-0.4.1");
+    std::fs::create_dir_all(&net_http_dir).unwrap();
+    std::fs::write(net_http_dir.join("net-http.gemspec"), "# gemspec\n").unwrap();
+
+    let output = Command::new(binary())
+        .args(["scan", "--json", "--cwd", project_dir.to_str().unwrap()])
+        .current_dir(&project_dir)
+        .output()
+        .expect("Failed to run socket-patch binary");
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("scannedPackages") || combined.contains("Found"),
+        "Expected scan output, got:\n{combined}"
+    );
+}
+
+// ---------------------------------------------------------------------------
+// Lifecycle tests (need bundler + network)
+// ---------------------------------------------------------------------------
+
+/// Full lifecycle: get -> list (verify CVE-2022-21831) -> rollback -> apply -> remove.
+#[test]
+#[ignore]
+fn test_gem_full_lifecycle() {
+    if !has_command("bundle") {
+        eprintln!("SKIP: bundle not found on PATH");
+        return;
+    }
+
+    let dir = tempfile::tempdir().unwrap();
+    let cwd = dir.path();
+
+    // -- Setup: create project and install activestorage@5.2.0 ----------------
+    write_gemfile(cwd);
+    bundle_run(cwd, &["install", "--path", "vendor/bundle"]);
+
+    let gem_dir = find_gem_dir(cwd);
+
+    // -- GET: download + apply patch ------------------------------------------
+    assert_run_ok(cwd, &["get", GEM_UUID], "get");
+
+    let manifest_path = cwd.join(".socket/manifest.json");
+    assert!(manifest_path.exists(), ".socket/manifest.json should exist after get");
+
+    let manifest: serde_json::Value =
+        serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap();
+    let patch = &manifest["patches"][GEM_PURL];
+    assert!(patch.is_object(), "manifest should contain {GEM_PURL}");
+    assert_eq!(patch["uuid"].as_str().unwrap(), GEM_UUID);
+
+    let files = &patch["files"];
+    assert!(
+        files.as_object().map_or(false, |f| !f.is_empty()),
+        "patch should modify at least one file"
+    );
+
+    // Files should now be patched — verify against afterHash from manifest.
+ assert_after_hashes(&gem_dir, files); + + // -- LIST: verify JSON output --------------------------------------------- + let (stdout, _) = assert_run_ok(cwd, &["list", "--json"], "list --json"); + let list: serde_json::Value = serde_json::from_str(&stdout).unwrap(); + let patches = list["patches"].as_array().expect("patches should be an array"); + assert_eq!(patches.len(), 1); + assert_eq!(patches[0]["uuid"].as_str().unwrap(), GEM_UUID); + assert_eq!(patches[0]["purl"].as_str().unwrap(), GEM_PURL); + + let vulns = patches[0]["vulnerabilities"] + .as_array() + .expect("vulnerabilities array"); + assert!(!vulns.is_empty(), "patch should report at least one vulnerability"); + + let has_cve = vulns.iter().any(|v| { + v["cves"] + .as_array() + .map_or(false, |cves| cves.iter().any(|c| c == "CVE-2022-21831")) + }); + assert!(has_cve, "vulnerability list should include CVE-2022-21831"); + + // -- ROLLBACK: restore original files ------------------------------------- + assert_run_ok(cwd, &["rollback"], "rollback"); + assert_before_hashes(&gem_dir, files); + + // -- APPLY: re-apply from manifest ---------------------------------------- + assert_run_ok(cwd, &["apply"], "apply"); + assert_after_hashes(&gem_dir, files); + + // -- REMOVE: rollback + remove from manifest ------------------------------ + assert_run_ok(cwd, &["remove", GEM_UUID], "remove"); + assert_before_hashes(&gem_dir, files); + + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + assert!( + manifest["patches"].as_object().unwrap().is_empty(), + "manifest should be empty after remove" + ); +} + +/// `get --no-apply` + `apply --dry-run` should not modify files. +#[test] +#[ignore] +fn test_gem_dry_run() { + if !has_command("bundle") { + eprintln!("SKIP: bundle not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + write_gemfile(cwd); + bundle_run(cwd, &["install", "--path", "vendor/bundle"]); + + let gem_dir = find_gem_dir(cwd); + + // Download without applying. + assert_run_ok(cwd, &["get", GEM_UUID, "--no-apply"], "get --no-apply"); + + // Read manifest to get file list and expected hashes. + let manifest_path = cwd.join(".socket/manifest.json"); + let files = read_patch_files(&manifest_path); + let original_hashes = record_original_hashes(&gem_dir, &files); + + // Files should still be original (not patched). + assert_original_hashes(&gem_dir, &original_hashes); + + // Dry-run should succeed but leave files untouched. + assert_run_ok(cwd, &["apply", "--dry-run"], "apply --dry-run"); + assert_original_hashes(&gem_dir, &original_hashes); + + // Real apply should work. + assert_run_ok(cwd, &["apply"], "apply"); + assert_after_hashes(&gem_dir, &files); +} + +/// `get --save-only` should save the patch to the manifest without applying. +#[test] +#[ignore] +fn test_gem_save_only() { + if !has_command("bundle") { + eprintln!("SKIP: bundle not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + write_gemfile(cwd); + bundle_run(cwd, &["install", "--path", "vendor/bundle"]); + + let gem_dir = find_gem_dir(cwd); + + // Download with --save-only. + assert_run_ok(cwd, &["get", GEM_UUID, "--save-only"], "get --save-only"); + + // Read manifest to get file list and expected hashes. 
+ let manifest_path = cwd.join(".socket/manifest.json"); + let files = read_patch_files(&manifest_path); + let original_hashes = record_original_hashes(&gem_dir, &files); + + // Files should still be original (not patched). + assert_original_hashes(&gem_dir, &original_hashes); + + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + let patch = &manifest["patches"][GEM_PURL]; + assert!(patch.is_object(), "manifest should contain {GEM_PURL}"); + assert_eq!(patch["uuid"].as_str().unwrap(), GEM_UUID); + + // Real apply should work. + assert_run_ok(cwd, &["apply"], "apply"); + assert_after_hashes(&gem_dir, &files); +} diff --git a/crates/socket-patch-cli/tests/e2e_golang.rs b/crates/socket-patch-cli/tests/e2e_golang.rs new file mode 100644 index 0000000..a0a76af --- /dev/null +++ b/crates/socket-patch-cli/tests/e2e_golang.rs @@ -0,0 +1,120 @@ +#![cfg(feature = "golang")] +//! End-to-end tests for the Go module patching lifecycle. +//! +//! These tests exercise crawling against a temporary directory with a fake +//! Go module cache layout. They do **not** require network access or a real +//! Go installation. +//! +//! # Running +//! ```sh +//! cargo test -p socket-patch-cli --features golang --test e2e_golang +//! ``` + +use std::path::PathBuf; +use std::process::{Command, Output}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn binary() -> PathBuf { + env!("CARGO_BIN_EXE_socket-patch").into() +} + +fn run(args: &[&str], cwd: &std::path::Path, gomodcache: &std::path::Path) -> Output { + Command::new(binary()) + .args(args) + .current_dir(cwd) + .env("GOMODCACHE", gomodcache) + .output() + .expect("Failed to run socket-patch binary") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Verify that `socket-patch scan` discovers Go modules in a fake module cache. +#[test] +fn scan_discovers_go_modules() { + let dir = tempfile::tempdir().unwrap(); + let cache_dir = dir.path().join("gomodcache"); + + // Create fake module: github.com/gin-gonic/gin@v1.9.1 + let gin_dir = cache_dir + .join("github.com") + .join("gin-gonic") + .join("gin@v1.9.1"); + std::fs::create_dir_all(&gin_dir).unwrap(); + std::fs::write( + gin_dir.join("go.mod"), + "module github.com/gin-gonic/gin\n\ngo 1.21\n", + ) + .unwrap(); + + // Create fake module: golang.org/x/text@v0.14.0 + let text_dir = cache_dir.join("golang.org").join("x").join("text@v0.14.0"); + std::fs::create_dir_all(&text_dir).unwrap(); + std::fs::write( + text_dir.join("go.mod"), + "module golang.org/x/text\n\ngo 1.21\n", + ) + .unwrap(); + + // Create a go.mod in the project directory so local mode activates + std::fs::write( + dir.path().join("go.mod"), + "module example.com/myproject\n\ngo 1.21\n", + ) + .unwrap(); + + let output = run( + &["scan", "--cwd", dir.path().to_str().unwrap()], + dir.path(), + &cache_dir, + ); + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + let combined = format!("{stdout}{stderr}"); + + assert!( + combined.contains("Found") || combined.contains("packages"), + "Expected scan to discover Go module packages, got:\n{combined}" + ); +} + +/// Verify that `socket-patch scan` discovers case-encoded Go modules. 
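+///
+/// Go's module cache escapes each uppercase letter as `!` followed by its
+/// lowercase form, so for example:
+///
+/// ```text
+/// github.com/Azure/azure-sdk-for-go  ->  github.com/!azure/azure-sdk-for-go
+/// ```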
+#[test] +fn scan_discovers_case_encoded_modules() { + let dir = tempfile::tempdir().unwrap(); + let cache_dir = dir.path().join("gomodcache"); + + // Create case-encoded module: github.com/!azure/azure-sdk-for-go@v1.0.0 + // (represents github.com/Azure/azure-sdk-for-go) + let azure_dir = cache_dir + .join("github.com") + .join("!azure") + .join("azure-sdk-for-go@v1.0.0"); + std::fs::create_dir_all(&azure_dir).unwrap(); + + // Create a go.mod in the project directory + std::fs::write( + dir.path().join("go.mod"), + "module example.com/myproject\n\ngo 1.21\n", + ) + .unwrap(); + + let output = run( + &["scan", "--json", "--cwd", dir.path().to_str().unwrap()], + dir.path(), + &cache_dir, + ); + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined = format!("{stdout}{stderr}"); + + assert!( + combined.contains("scannedPackages") || combined.contains("Found"), + "Expected scan output, got:\n{combined}" + ); +} diff --git a/crates/socket-patch-cli/tests/e2e_maven.rs b/crates/socket-patch-cli/tests/e2e_maven.rs new file mode 100644 index 0000000..0dd5699 --- /dev/null +++ b/crates/socket-patch-cli/tests/e2e_maven.rs @@ -0,0 +1,154 @@ +#![cfg(feature = "maven")] +//! End-to-end tests for the Maven/Java package patching lifecycle. +//! +//! These tests exercise crawling against a temporary directory with a fake +//! Maven local repository layout. They do **not** require network access or a +//! real Maven/Java installation. +//! +//! # Running +//! ```sh +//! cargo test -p socket-patch-cli --features maven --test e2e_maven +//! ``` + +use std::path::PathBuf; +use std::process::{Command, Output}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn binary() -> PathBuf { + env!("CARGO_BIN_EXE_socket-patch").into() +} + +fn run(args: &[&str], cwd: &std::path::Path, m2_repo: &std::path::Path) -> Output { + Command::new(binary()) + .args(args) + .current_dir(cwd) + .env("MAVEN_REPO_LOCAL", m2_repo) + .output() + .expect("Failed to run socket-patch binary") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Verify that `socket-patch scan` discovers artifacts in a fake Maven local repo. 
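+///
+/// A Maven local repository keeps one directory per version, with the POM
+/// named after the artifact and version:
+///
+/// ```text
+/// <repo>/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.pom
+/// ```
+///
+/// The test fabricates exactly this layout under a temporary directory.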
+#[test]
+fn scan_discovers_maven_artifacts() {
+    let dir = tempfile::tempdir().unwrap();
+
+    // Set up a fake Maven local repository
+    let m2_repo = dir.path().join("m2-repo");
+
+    // Create commons-lang3 3.12.0
+    let lang_dir = m2_repo
+        .join("org")
+        .join("apache")
+        .join("commons")
+        .join("commons-lang3")
+        .join("3.12.0");
+    std::fs::create_dir_all(&lang_dir).unwrap();
+    std::fs::write(
+        lang_dir.join("commons-lang3-3.12.0.pom"),
+        r#"<project>
+    <groupId>org.apache.commons</groupId>
+    <artifactId>commons-lang3</artifactId>
+    <version>3.12.0</version>
+</project>"#,
+    )
+    .unwrap();
+
+    // Create guava 32.1.2-jre
+    let guava_dir = m2_repo
+        .join("com")
+        .join("google")
+        .join("guava")
+        .join("guava")
+        .join("32.1.2-jre");
+    std::fs::create_dir_all(&guava_dir).unwrap();
+    std::fs::write(
+        guava_dir.join("guava-32.1.2-jre.pom"),
+        r#"<project>
+    <groupId>com.google.guava</groupId>
+    <artifactId>guava</artifactId>
+    <version>32.1.2-jre</version>
+</project>"#,
+    )
+    .unwrap();
+
+    // Create a pom.xml in the project directory so local mode activates
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+    std::fs::write(
+        project_dir.join("pom.xml"),
+        r#"<project><modelVersion>4.0.0</modelVersion></project>"#,
+    )
+    .unwrap();
+
+    let output = run(
+        &["scan", "--cwd", project_dir.to_str().unwrap()],
+        &project_dir,
+        &m2_repo,
+    );
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("Found") || combined.contains("packages"),
+        "Expected scan to discover Maven artifacts, got:\n{combined}"
+    );
+}
+
+/// Verify that `socket-patch scan` discovers Gradle project artifacts.
+#[test]
+fn scan_discovers_gradle_project_artifacts() {
+    let dir = tempfile::tempdir().unwrap();
+
+    // Set up a fake Maven local repository
+    let m2_repo = dir.path().join("m2-repo");
+
+    // Create a single artifact
+    let jackson_dir = m2_repo
+        .join("com")
+        .join("fasterxml")
+        .join("jackson")
+        .join("core")
+        .join("jackson-core")
+        .join("2.15.0");
+    std::fs::create_dir_all(&jackson_dir).unwrap();
+    std::fs::write(
+        jackson_dir.join("jackson-core-2.15.0.pom"),
+        r#"<project>
+    <groupId>com.fasterxml.jackson.core</groupId>
+    <artifactId>jackson-core</artifactId>
+    <version>2.15.0</version>
+</project>"#,
+    )
+    .unwrap();
+
+    // Create a build.gradle in the project directory (Gradle project)
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+    std::fs::write(
+        project_dir.join("build.gradle"),
+        "plugins { id 'java' }\n",
+    )
+    .unwrap();
+
+    let output = run(
+        &["scan", "--json", "--cwd", project_dir.to_str().unwrap()],
+        &project_dir,
+        &m2_repo,
+    );
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("scannedPackages") || combined.contains("Found"),
+        "Expected scan output, got:\n{combined}"
+    );
+}
diff --git a/crates/socket-patch-cli/tests/e2e_npm.rs b/crates/socket-patch-cli/tests/e2e_npm.rs
new file mode 100644
index 0000000..376c495
--- /dev/null
+++ b/crates/socket-patch-cli/tests/e2e_npm.rs
@@ -0,0 +1,547 @@
+//! End-to-end tests for the npm patch lifecycle.
+//!
+//! These tests exercise the full CLI against the real Socket API, using the
+//! **minimist@1.2.2** patch (UUID `80630680-4da6-45f9-bba8-b888e0ffd58c`),
+//! which fixes CVE-2021-44906 (Prototype Pollution).
+//!
+//! # Prerequisites
+//! - `npm` on PATH
+//! - Network access to `patches-api.socket.dev` and `registry.npmjs.org`
+//!
+//! # Running
+//! ```sh
+//! cargo test -p socket-patch-cli --test e2e_npm -- --ignored
+//! ```
+
+use std::path::{Path, PathBuf};
+use std::process::{Command, Output};
+
+use sha2::{Digest, Sha256};
+
+// ---------------------------------------------------------------------------
+// Constants
+// ---------------------------------------------------------------------------
+
+const NPM_UUID: &str = "80630680-4da6-45f9-bba8-b888e0ffd58c";
+#[allow(dead_code)]
+const NPM_PURL: &str = "pkg:npm/minimist@1.2.2";
+
+/// Git SHA-256 of the *unpatched* `index.js` shipped with minimist 1.2.2.
+const BEFORE_HASH: &str = "311f1e893e6eac502693fad8617dcf5353a043ccc0f7b4ba9fe385e838b67a10";
+
+/// Git SHA-256 of the *patched* `index.js` after the security fix.
+const AFTER_HASH: &str = "043f04d19e884aa5f8371428718d2a3f27a0d231afe77a2620ac6312f80aaa28";
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+fn binary() -> PathBuf {
+    env!("CARGO_BIN_EXE_socket-patch").into()
+}
+
+fn has_command(cmd: &str) -> bool {
+    Command::new(cmd)
+        .arg("--version")
+        .stdout(std::process::Stdio::null())
+        .stderr(std::process::Stdio::null())
+        .status()
+        .is_ok()
+}
+
+/// Compute Git SHA-256: `SHA256("blob <len>\0" + content)`.
+fn git_sha256(content: &[u8]) -> String {
+    let header = format!("blob {}\0", content.len());
+    let mut hasher = Sha256::new();
+    hasher.update(header.as_bytes());
+    hasher.update(content);
+    hex::encode(hasher.finalize())
+}
+
+fn git_sha256_file(path: &Path) -> String {
+    let content = std::fs::read(path).unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
+    git_sha256(&content)
+}
+
+/// Run the CLI binary with the given args, setting `cwd` as the working dir.
+/// Returns `(exit_code, stdout, stderr)`.
+fn run(cwd: &Path, args: &[&str]) -> (i32, String, String) {
+    let out: Output = Command::new(binary())
+        .args(args)
+        .current_dir(cwd)
+        .env_remove("SOCKET_API_TOKEN") // force public proxy (free-tier)
+        .output()
+        .expect("failed to execute socket-patch binary");
+
+    let code = out.status.code().unwrap_or(-1);
+    let stdout = String::from_utf8_lossy(&out.stdout).to_string();
+    let stderr = String::from_utf8_lossy(&out.stderr).to_string();
+    (code, stdout, stderr)
+}
+
+fn assert_run_ok(cwd: &Path, args: &[&str], context: &str) -> (String, String) {
+    let (code, stdout, stderr) = run(cwd, args);
+    assert_eq!(
+        code, 0,
+        "{context} failed (exit {code}).\nstdout:\n{stdout}\nstderr:\n{stderr}"
+    );
+    (stdout, stderr)
+}
+
+fn npm_run(cwd: &Path, args: &[&str]) {
+    let out = Command::new("npm")
+        .args(args)
+        .current_dir(cwd)
+        .output()
+        .expect("failed to run npm");
+    assert!(
+        out.status.success(),
+        "npm {args:?} failed (exit {:?}).\nstdout:\n{}\nstderr:\n{}",
+        out.status.code(),
+        String::from_utf8_lossy(&out.stdout),
+        String::from_utf8_lossy(&out.stderr),
+    );
+}
+
+/// Write a minimal package.json (avoids `npm init -y` which rejects temp dir
+/// names that start with `.` or contain invalid characters).
+fn write_package_json(cwd: &Path) {
+    std::fs::write(
+        cwd.join("package.json"),
+        r#"{"name":"e2e-test","version":"0.0.0","private":true}"#,
+    )
+    .expect("write package.json");
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+/// Full lifecycle: get → verify → list → rollback → apply → remove.
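+///
+/// Each step is verified by comparing the Git SHA-256 of
+/// `node_modules/minimist/index.js` against the `BEFORE_HASH` / `AFTER_HASH`
+/// constants above.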
+#[test] +#[ignore] +fn test_npm_full_lifecycle() { + if !has_command("npm") { + eprintln!("SKIP: npm not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + // -- Setup: create a project and install minimist@1.2.2 ---------------- + write_package_json(cwd); + npm_run(cwd, &["install", "minimist@1.2.2"]); + + let index_js = cwd.join("node_modules/minimist/index.js"); + assert!(index_js.exists(), "minimist/index.js must exist after npm install"); + + // Confirm the original file matches the expected before-hash. + assert_eq!( + git_sha256_file(&index_js), + BEFORE_HASH, + "freshly installed index.js should have the expected beforeHash" + ); + + // -- GET: download + apply patch --------------------------------------- + assert_run_ok(cwd, &["get", NPM_UUID], "get"); + + // Manifest should exist and contain the patch. + let manifest_path = cwd.join(".socket/manifest.json"); + assert!(manifest_path.exists(), ".socket/manifest.json should exist after get"); + + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + let patch = &manifest["patches"][NPM_PURL]; + assert!(patch.is_object(), "manifest should contain {NPM_PURL}"); + assert_eq!(patch["uuid"].as_str().unwrap(), NPM_UUID); + + // The file should now be patched. + assert_eq!( + git_sha256_file(&index_js), + AFTER_HASH, + "index.js should match afterHash after get" + ); + + // -- LIST: verify JSON output ------------------------------------------ + let (stdout, _) = assert_run_ok(cwd, &["list", "--json"], "list --json"); + let list: serde_json::Value = serde_json::from_str(&stdout).unwrap(); + let patches = list["patches"].as_array().expect("patches should be an array"); + assert_eq!(patches.len(), 1); + assert_eq!(patches[0]["uuid"].as_str().unwrap(), NPM_UUID); + assert_eq!(patches[0]["purl"].as_str().unwrap(), NPM_PURL); + + let vulns = patches[0]["vulnerabilities"] + .as_array() + .expect("vulnerabilities array"); + assert!(!vulns.is_empty(), "patch should report at least one vulnerability"); + + // Verify the vulnerability details match CVE-2021-44906 + let has_cve = vulns.iter().any(|v| { + v["cves"] + .as_array() + .map_or(false, |cves| cves.iter().any(|c| c == "CVE-2021-44906")) + }); + assert!(has_cve, "vulnerability list should include CVE-2021-44906"); + + // -- ROLLBACK: restore original file ----------------------------------- + assert_run_ok(cwd, &["rollback"], "rollback"); + + assert_eq!( + git_sha256_file(&index_js), + BEFORE_HASH, + "index.js should match beforeHash after rollback" + ); + + // -- APPLY: re-apply from manifest ------------------------------------ + assert_run_ok(cwd, &["apply"], "apply"); + + assert_eq!( + git_sha256_file(&index_js), + AFTER_HASH, + "index.js should match afterHash after re-apply" + ); + + // -- REMOVE: rollback + remove from manifest --------------------------- + assert_run_ok(cwd, &["remove", NPM_UUID], "remove"); + + // File should be back to original. + assert_eq!( + git_sha256_file(&index_js), + BEFORE_HASH, + "index.js should match beforeHash after remove" + ); + + // Manifest should have no patches left. + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + assert!( + manifest["patches"].as_object().unwrap().is_empty(), + "manifest should be empty after remove" + ); +} + +/// `apply --dry-run` should not modify files on disk. 
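+///
+/// The patch is first fetched with `get --no-apply`, so the manifest exists
+/// but `index.js` still matches `BEFORE_HASH` when the dry run executes.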
+#[test]
+#[ignore]
+fn test_npm_dry_run() {
+    if !has_command("npm") {
+        eprintln!("SKIP: npm not found on PATH");
+        return;
+    }
+
+    let dir = tempfile::tempdir().unwrap();
+    let cwd = dir.path();
+
+    write_package_json(cwd);
+    npm_run(cwd, &["install", "minimist@1.2.2"]);
+
+    let index_js = cwd.join("node_modules/minimist/index.js");
+    assert_eq!(git_sha256_file(&index_js), BEFORE_HASH);
+
+    // Download the patch *without* applying.
+    assert_run_ok(cwd, &["get", NPM_UUID, "--no-apply"], "get --no-apply");
+
+    // File should still be original.
+    assert_eq!(
+        git_sha256_file(&index_js),
+        BEFORE_HASH,
+        "file should not change after get --no-apply"
+    );
+
+    // Dry-run should succeed but leave file untouched.
+    assert_run_ok(cwd, &["apply", "--dry-run"], "apply --dry-run");
+
+    assert_eq!(
+        git_sha256_file(&index_js),
+        BEFORE_HASH,
+        "file should not change after apply --dry-run"
+    );
+
+    // Real apply should work.
+    assert_run_ok(cwd, &["apply"], "apply");
+
+    assert_eq!(
+        git_sha256_file(&index_js),
+        AFTER_HASH,
+        "file should match afterHash after real apply"
+    );
+}
+
+/// Global lifecycle: scan → get → list → rollback → apply → remove using `-g --global-prefix`.
+#[test]
+#[ignore]
+fn test_npm_global_lifecycle() {
+    if !has_command("npm") {
+        eprintln!("SKIP: npm not found on PATH");
+        return;
+    }
+
+    let global_dir = tempfile::tempdir().unwrap();
+    let cwd_dir = tempfile::tempdir().unwrap();
+    let cwd = cwd_dir.path();
+
+    // -- Setup: install minimist@1.2.2 globally into a temp prefix ----------
+    let out = Command::new("npm")
+        .args(["install", "-g", "--prefix", global_dir.path().to_str().unwrap(), "minimist@1.2.2"])
+        .output()
+        .expect("failed to run npm install -g");
+    assert!(
+        out.status.success(),
+        "npm install -g failed.\nstdout:\n{}\nstderr:\n{}",
+        String::from_utf8_lossy(&out.stdout),
+        String::from_utf8_lossy(&out.stderr),
+    );
+
+    // On Unix, npm -g --prefix puts packages under <prefix>/lib/node_modules/
+    // On Windows, it's <prefix>/node_modules/
+    let nm_path = if cfg!(windows) {
+        global_dir.path().join("node_modules")
+    } else {
+        global_dir.path().join("lib/node_modules")
+    };
+
+    let index_js = nm_path.join("minimist/index.js");
+    assert!(
+        index_js.exists(),
+        "minimist/index.js must exist after global install at {}",
+        index_js.display()
+    );
+    assert_eq!(
+        git_sha256_file(&index_js),
+        BEFORE_HASH,
+        "globally installed index.js should have the expected beforeHash"
+    );
+
+    let nm_str = nm_path.to_str().unwrap();
+
+    // -- SCAN: verify scan -g finds the package ------------------------------
+    let (stdout, _) = assert_run_ok(
+        cwd,
+        &["scan", "-g", "--global-prefix", nm_str, "--json"],
+        "scan -g --json",
+    );
+    let scan: serde_json::Value = serde_json::from_str(&stdout).unwrap();
+    let scanned = scan["scannedPackages"]
+        .as_u64()
+        .expect("scannedPackages should be a number");
+    assert!(scanned >= 1, "scan should find at least 1 package, got {scanned}");
+
+    // -- GET: download + apply patch globally --------------------------------
+    assert_run_ok(
+        cwd,
+        &["get", NPM_UUID, "-g", "--global-prefix", nm_str],
+        "get -g",
+    );
+
+    let manifest_path = cwd.join(".socket/manifest.json");
+    assert!(manifest_path.exists(), "manifest should exist after get");
+    assert_eq!(
+        git_sha256_file(&index_js),
+        AFTER_HASH,
+        "index.js should match afterHash after global get"
+    );
+
+    // -- LIST: verify patch in output ----------------------------------------
+    let (stdout, _) = assert_run_ok(cwd, &["list", "--json"], "list --json");
+    let list: serde_json::Value =
serde_json::from_str(&stdout).unwrap(); + let patches = list["patches"].as_array().expect("patches array"); + assert_eq!(patches.len(), 1); + assert_eq!(patches[0]["uuid"].as_str().unwrap(), NPM_UUID); + + // -- ROLLBACK: restore original file globally ---------------------------- + assert_run_ok( + cwd, + &["rollback", "-g", "--global-prefix", nm_str], + "rollback -g", + ); + assert_eq!( + git_sha256_file(&index_js), + BEFORE_HASH, + "index.js should match beforeHash after global rollback" + ); + + // -- APPLY: re-apply from manifest globally ------------------------------ + assert_run_ok( + cwd, + &["apply", "-g", "--global-prefix", nm_str], + "apply -g", + ); + assert_eq!( + git_sha256_file(&index_js), + AFTER_HASH, + "index.js should match afterHash after global apply" + ); + + // -- REMOVE: rollback + remove from manifest globally -------------------- + assert_run_ok( + cwd, + &["remove", NPM_UUID, "-g", "--global-prefix", nm_str], + "remove -g", + ); + assert_eq!( + git_sha256_file(&index_js), + BEFORE_HASH, + "index.js should match beforeHash after global remove" + ); + + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + assert!( + manifest["patches"].as_object().unwrap().is_empty(), + "manifest should be empty after global remove" + ); +} + +/// `get --save-only` should save the patch to the manifest without applying. +#[test] +#[ignore] +fn test_npm_save_only() { + if !has_command("npm") { + eprintln!("SKIP: npm not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + write_package_json(cwd); + npm_run(cwd, &["install", "minimist@1.2.2"]); + + let index_js = cwd.join("node_modules/minimist/index.js"); + assert_eq!(git_sha256_file(&index_js), BEFORE_HASH); + + // Download with --save-only (new name for --no-apply). + assert_run_ok(cwd, &["get", NPM_UUID, "--save-only"], "get --save-only"); + + // File should still be original. + assert_eq!( + git_sha256_file(&index_js), + BEFORE_HASH, + "file should not change after get --save-only" + ); + + // Manifest should exist with the patch. + let manifest_path = cwd.join(".socket/manifest.json"); + assert!(manifest_path.exists(), "manifest should exist after get --save-only"); + + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + let patch = &manifest["patches"][NPM_PURL]; + assert!(patch.is_object(), "manifest should contain {NPM_PURL}"); + assert_eq!(patch["uuid"].as_str().unwrap(), NPM_UUID); + + // Real apply should work. + assert_run_ok(cwd, &["apply"], "apply"); + assert_eq!( + git_sha256_file(&index_js), + AFTER_HASH, + "file should match afterHash after apply" + ); +} + +/// `apply --force` should apply patches even when file hashes don't match. +#[test] +#[ignore] +fn test_npm_apply_force() { + if !has_command("npm") { + eprintln!("SKIP: npm not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + write_package_json(cwd); + npm_run(cwd, &["install", "minimist@1.2.2"]); + + let index_js = cwd.join("node_modules/minimist/index.js"); + assert_eq!(git_sha256_file(&index_js), BEFORE_HASH); + + // Save the patch without applying. + assert_run_ok(cwd, &["get", NPM_UUID, "--save-only"], "get --save-only"); + + // Corrupt the file to create a hash mismatch (keep same version so PURL matches). 
+    std::fs::write(&index_js, b"// corrupted content\n").unwrap();
+    assert_ne!(
+        git_sha256_file(&index_js),
+        BEFORE_HASH,
+        "corrupted file should have a different hash"
+    );
+
+    // Normal apply should fail due to hash mismatch.
+    let (code, _stdout, _stderr) = run(cwd, &["apply"]);
+    assert_ne!(code, 0, "apply without --force should fail on hash mismatch");
+
+    // Apply with --force should succeed.
+    assert_run_ok(cwd, &["apply", "--force"], "apply --force");
+
+    assert_eq!(
+        git_sha256_file(&index_js),
+        AFTER_HASH,
+        "index.js should match afterHash after apply --force"
+    );
+}
+
+/// macOS auto-discovery: `scan -g --json` without `--global-prefix` uses real path probing.
+#[cfg(target_os = "macos")]
+#[test]
+#[ignore]
+fn test_npm_macos_global_auto_discovery() {
+    if !has_command("npm") {
+        eprintln!("SKIP: npm not found on PATH");
+        return;
+    }
+
+    let cwd_dir = tempfile::tempdir().unwrap();
+    let cwd = cwd_dir.path();
+
+    // Run scan -g without --global-prefix to exercise macOS auto-discovery
+    let (code, stdout, stderr) = run(cwd, &["scan", "-g", "--json"]);
+
+    // Should complete without error (exit 0)
+    assert_eq!(
+        code, 0,
+        "scan -g --json failed (exit {code}).\nstdout:\n{stdout}\nstderr:\n{stderr}"
+    );
+
+    // Output should be valid JSON with scannedPackages field
+    let scan: serde_json::Value = serde_json::from_str(&stdout)
+        .unwrap_or_else(|e| panic!("invalid JSON from scan -g: {e}\nstdout:\n{stdout}"));
+    assert!(
+        scan["scannedPackages"].is_u64(),
+        "scannedPackages should be a number, got: {}",
+        scan["scannedPackages"]
+    );
+}
+
+/// UUID shortcut: `socket-patch <uuid>` should behave like `socket-patch get <uuid>`.
+#[test]
+#[ignore]
+fn test_npm_uuid_shortcut() {
+    if !has_command("npm") {
+        eprintln!("SKIP: npm not found on PATH");
+        return;
+    }
+
+    let dir = tempfile::tempdir().unwrap();
+    let cwd = dir.path();
+
+    write_package_json(cwd);
+    npm_run(cwd, &["install", "minimist@1.2.2"]);
+
+    let index_js = cwd.join("node_modules/minimist/index.js");
+    assert_eq!(git_sha256_file(&index_js), BEFORE_HASH);
+
+    // Run with bare UUID (no "get" subcommand).
+    assert_run_ok(cwd, &[NPM_UUID], "uuid shortcut");
+
+    assert_eq!(
+        git_sha256_file(&index_js),
+        AFTER_HASH,
+        "index.js should match afterHash after UUID shortcut"
+    );
+
+    let manifest_path = cwd.join(".socket/manifest.json");
+    assert!(manifest_path.exists(), "manifest should exist after UUID shortcut");
+}
diff --git a/crates/socket-patch-cli/tests/e2e_nuget.rs b/crates/socket-patch-cli/tests/e2e_nuget.rs
new file mode 100644
index 0000000..fd98550
--- /dev/null
+++ b/crates/socket-patch-cli/tests/e2e_nuget.rs
@@ -0,0 +1,116 @@
+#![cfg(feature = "nuget")]
+//! End-to-end tests for the NuGet/.NET package patching lifecycle.
+//!
+//! These tests exercise crawling against a temporary directory with fake
+//! NuGet package layouts. They do **not** require network access or a real
+//! .NET installation.
+//!
+//! # Running
+//! ```sh
+//! cargo test -p socket-patch-cli --features nuget --test e2e_nuget
+//! ```
+
+use std::path::PathBuf;
+use std::process::{Command, Output};
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+fn binary() -> PathBuf {
+    env!("CARGO_BIN_EXE_socket-patch").into()
+}
+
+fn run(args: &[&str], cwd: &std::path::Path, nuget_packages: &std::path::Path) -> Output {
+    Command::new(binary())
+        .args(args)
+        .current_dir(cwd)
+        .env("NUGET_PACKAGES", nuget_packages)
+        .output()
+        .expect("Failed to run socket-patch binary")
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+/// Verify that `socket-patch scan` discovers packages in a fake global cache layout.
+#[test]
+fn scan_discovers_global_cache_packages() {
+    let dir = tempfile::tempdir().unwrap();
+
+    // Set up a fake global NuGet cache: <id>/<version>/ with <id>.nuspec
+    let nuget_cache = dir.path().join("nuget-cache");
+
+    let nj_dir = nuget_cache.join("newtonsoft.json").join("13.0.3");
+    std::fs::create_dir_all(nj_dir.join("lib")).unwrap();
+    std::fs::write(
+        nj_dir.join("newtonsoft.json.nuspec"),
+        r#"<package><metadata><id>Newtonsoft.Json</id><version>13.0.3</version></metadata></package>"#,
+    )
+    .unwrap();
+
+    let stj_dir = nuget_cache.join("system.text.json").join("8.0.0");
+    std::fs::create_dir_all(&stj_dir).unwrap();
+    std::fs::write(
+        stj_dir.join("system.text.json.nuspec"),
+        r#"<package><metadata><id>System.Text.Json</id><version>8.0.0</version></metadata></package>"#,
+    )
+    .unwrap();
+
+    // Create a .csproj so it's recognized as a .NET project
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+    std::fs::write(project_dir.join("MyApp.csproj"), "").unwrap();
+
+    let output = run(
+        &["scan", "--cwd", project_dir.to_str().unwrap()],
+        &project_dir,
+        &nuget_cache,
+    );
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("Found") || combined.contains("packages"),
+        "Expected scan to discover NuGet packages, got:\n{combined}"
+    );
+}
+
+/// Verify that `socket-patch scan` discovers packages in a fake legacy packages/ layout.
+#[test]
+fn scan_discovers_legacy_packages() {
+    let dir = tempfile::tempdir().unwrap();
+    let project_dir = dir.path().join("project");
+    std::fs::create_dir_all(&project_dir).unwrap();
+
+    // Create a .csproj
+    std::fs::write(project_dir.join("MyApp.csproj"), "").unwrap();
+
+    // Set up legacy packages/ directory
+    let packages_dir = project_dir.join("packages");
+
+    let nj_dir = packages_dir.join("Newtonsoft.Json.13.0.3");
+    std::fs::create_dir_all(nj_dir.join("lib")).unwrap();
+    std::fs::write(
+        nj_dir.join("Newtonsoft.Json.nuspec"),
+        r#"<package><metadata><id>Newtonsoft.Json</id><version>13.0.3</version></metadata></package>"#,
+    )
+    .unwrap();
+
+    // Use the packages dir itself as NUGET_PACKAGES (though legacy is found via cwd)
+    let output = run(
+        &["scan", "--cwd", project_dir.to_str().unwrap()],
+        &project_dir,
+        &packages_dir,
+    );
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let combined = format!("{stdout}{stderr}");
+
+    assert!(
+        combined.contains("Found") || combined.contains("packages"),
+        "Expected scan to discover legacy NuGet packages, got:\n{combined}"
+    );
+}
diff --git a/crates/socket-patch-cli/tests/e2e_pypi.rs b/crates/socket-patch-cli/tests/e2e_pypi.rs
new file mode 100644
index 0000000..ac27baa
--- /dev/null
+++ b/crates/socket-patch-cli/tests/e2e_pypi.rs
@@ -0,0 +1,641 @@
+//! End-to-end tests for the PyPI patch lifecycle.
+//!
+//! These tests exercise the full CLI against the real Socket API, using the
+//! **pydantic-ai@0.0.36** patch (UUID `725a5343-52ec-4290-b7ce-e1cec55878e1`),
+//! which fixes CVE-2026-25580 (SSRF in URL Download Handling).
+//!
+//! # Prerequisites
+//! - `python3` on PATH (with `venv` and `pip` modules)
+//! - Network access to `patches-api.socket.dev` and `pypi.org`
+//!
+//! # Running
+//! ```sh
+//! cargo test -p socket-patch-cli --test e2e_pypi -- --ignored
+//! ```
+
+use std::path::{Path, PathBuf};
+use std::process::{Command, Output};
+
+use sha2::{Digest, Sha256};
+
+// ---------------------------------------------------------------------------
+// Constants
+// ---------------------------------------------------------------------------
+
+const PYPI_UUID: &str = "725a5343-52ec-4290-b7ce-e1cec55878e1";
+const PYPI_PURL_PREFIX: &str = "pkg:pypi/pydantic-ai@0.0.36";
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+fn binary() -> PathBuf {
+    env!("CARGO_BIN_EXE_socket-patch").into()
+}
+
+fn has_python3() -> bool {
+    Command::new("python3")
+        .arg("--version")
+        .stdout(std::process::Stdio::null())
+        .stderr(std::process::Stdio::null())
+        .status()
+        .map(|s| s.success())
+        .unwrap_or(false)
+}
+
+/// Compute Git SHA-256: `SHA256("blob <len>\0" + content)`.
+fn git_sha256(content: &[u8]) -> String {
+    let header = format!("blob {}\0", content.len());
+    let mut hasher = Sha256::new();
+    hasher.update(header.as_bytes());
+    hasher.update(content);
+    hex::encode(hasher.finalize())
+}
+
+fn git_sha256_file(path: &Path) -> String {
+    let content = std::fs::read(path).unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
+    git_sha256(&content)
+}
+
+/// Run the CLI binary with the given args, setting `cwd` as the working dir.
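+/// Returns `(exit_code, stdout, stderr)`. `SOCKET_API_TOKEN` is removed from
+/// the environment so the CLI falls back to the public (free-tier) proxy.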
+fn run(cwd: &Path, args: &[&str]) -> (i32, String, String) { + let out: Output = Command::new(binary()) + .args(args) + .current_dir(cwd) + .env_remove("SOCKET_API_TOKEN") // force public proxy (free-tier) + .output() + .expect("failed to execute socket-patch binary"); + + let code = out.status.code().unwrap_or(-1); + let stdout = String::from_utf8_lossy(&out.stdout).to_string(); + let stderr = String::from_utf8_lossy(&out.stderr).to_string(); + (code, stdout, stderr) +} + +fn assert_run_ok(cwd: &Path, args: &[&str], context: &str) -> (String, String) { + let (code, stdout, stderr) = run(cwd, args); + assert_eq!( + code, 0, + "{context} failed (exit {code}).\nstdout:\n{stdout}\nstderr:\n{stderr}" + ); + (stdout, stderr) +} + +/// Find the `site-packages` directory inside a venv. +/// +/// On Unix: `.venv/lib/python3.X/site-packages` +/// On Windows: `.venv/Lib/site-packages` +fn find_site_packages(cwd: &Path) -> PathBuf { + let venv = cwd.join(".venv"); + if cfg!(windows) { + let sp = venv.join("Lib").join("site-packages"); + assert!(sp.exists(), "site-packages not found at {}", sp.display()); + return sp; + } + // Unix: glob for python3.* directory + let lib = venv.join("lib"); + for entry in std::fs::read_dir(&lib).expect("read .venv/lib") { + let entry = entry.unwrap(); + let name = entry.file_name(); + let name = name.to_string_lossy(); + if name.starts_with("python3.") { + let sp = entry.path().join("site-packages"); + if sp.exists() { + return sp; + } + } + } + panic!("site-packages not found under {}", lib.display()); +} + +/// Create a venv and install pydantic-ai (without transitive deps for speed). +fn setup_venv(cwd: &Path) { + let status = Command::new("python3") + .args(["-m", "venv", ".venv"]) + .current_dir(cwd) + .status() + .expect("failed to create venv"); + assert!(status.success(), "python3 -m venv failed"); + + let pip = if cfg!(windows) { + cwd.join(".venv/Scripts/pip") + } else { + cwd.join(".venv/bin/pip") + }; + + // Install both the meta-package (for dist-info that matches the PURL) + // and the slim package (for the actual Python source files). + // --no-deps keeps the install fast by skipping transitive dependencies. + let out = Command::new(&pip) + .args([ + "install", + "--no-deps", + "--disable-pip-version-check", + "pydantic-ai==0.0.36", + "pydantic-ai-slim==0.0.36", + ]) + .current_dir(cwd) + .output() + .expect("failed to run pip install"); + assert!( + out.status.success(), + "pip install failed.\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&out.stdout), + String::from_utf8_lossy(&out.stderr), + ); +} + +/// Read the manifest and return the files map for the pydantic-ai patch. +/// Returns `(purl, files)` where files is `{ relative_path: { beforeHash, afterHash } }`. +fn read_patch_files(manifest_path: &Path) -> (String, serde_json::Value) { + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(manifest_path).unwrap()).unwrap(); + + let patches = manifest["patches"].as_object().expect("patches object"); + let (purl, patch) = patches + .iter() + .find(|(k, _)| k.starts_with(PYPI_PURL_PREFIX)) + .unwrap_or_else(|| panic!("no patch matching {PYPI_PURL_PREFIX} in manifest")); + + (purl.clone(), patch["files"].clone()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Full lifecycle: get → verify hashes → list → rollback → apply → remove. 
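+///
+/// Unlike the npm tests, the expected hashes are not hard-coded: `get`
+/// writes them into `.socket/manifest.json`, and `read_patch_files` reads
+/// the `beforeHash`/`afterHash` pairs back for verification.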
+#[test] +#[ignore] +fn test_pypi_full_lifecycle() { + if !has_python3() { + eprintln!("SKIP: python3 not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + // -- Setup: create venv and install pydantic-ai@0.0.36 ---------------- + setup_venv(cwd); + + let site_packages = find_site_packages(cwd); + assert!( + site_packages.join("pydantic_ai").exists(), + "pydantic_ai package should be installed in site-packages" + ); + + // Record original hashes of all files that will be patched. + // We'll compare against these after rollback. + let files_to_check = [ + "pydantic_ai/messages.py", + "pydantic_ai/models/__init__.py", + "pydantic_ai/models/anthropic.py", + "pydantic_ai/models/gemini.py", + "pydantic_ai/models/openai.py", + ]; + let original_hashes: Vec<(String, String)> = files_to_check + .iter() + .map(|f| { + let path = site_packages.join(f); + let hash = if path.exists() { + git_sha256_file(&path) + } else { + String::new() // file doesn't exist yet (e.g., _ssrf.py) + }; + (f.to_string(), hash) + }) + .collect(); + + // -- GET: download + apply patch --------------------------------------- + assert_run_ok(cwd, &["get", PYPI_UUID], "get"); + + let manifest_path = cwd.join(".socket/manifest.json"); + assert!(manifest_path.exists(), ".socket/manifest.json should exist after get"); + + // Parse the manifest to get file hashes from the API. + let (purl, files_value) = read_patch_files(&manifest_path); + assert!( + purl.starts_with(PYPI_PURL_PREFIX), + "purl should start with {PYPI_PURL_PREFIX}, got {purl}" + ); + + let files = files_value.as_object().expect("files should be an object"); + assert!(!files.is_empty(), "patch should modify at least one file"); + + // Verify every file's hash matches the afterHash from the manifest. + for (rel_path, info) in files { + let after_hash = info["afterHash"] + .as_str() + .expect("afterHash should be a string"); + let full_path = site_packages.join(rel_path); + assert!( + full_path.exists(), + "patched file should exist: {}", + full_path.display() + ); + assert_eq!( + git_sha256_file(&full_path), + after_hash, + "hash mismatch for {rel_path} after get" + ); + } + + // -- LIST: verify JSON output ------------------------------------------ + let (stdout, _) = assert_run_ok(cwd, &["list", "--json"], "list --json"); + let list: serde_json::Value = serde_json::from_str(&stdout).unwrap(); + let patches = list["patches"].as_array().expect("patches array"); + assert_eq!(patches.len(), 1, "should have exactly one patch"); + assert_eq!(patches[0]["uuid"].as_str().unwrap(), PYPI_UUID); + + // Verify vulnerability + let vulns = patches[0]["vulnerabilities"] + .as_array() + .expect("vulnerabilities array"); + assert!(!vulns.is_empty(), "should have vulnerability info"); + let has_cve = vulns.iter().any(|v| { + v["cves"] + .as_array() + .map_or(false, |cves| cves.iter().any(|c| c == "CVE-2026-25580")) + }); + assert!(has_cve, "vulnerability list should include CVE-2026-25580"); + + // -- ROLLBACK: restore original files ---------------------------------- + assert_run_ok(cwd, &["rollback"], "rollback"); + + // Verify files are restored to their original state. + for (rel_path, info) in files { + let before_hash = info["beforeHash"].as_str().unwrap_or(""); + let full_path = site_packages.join(rel_path); + + if before_hash.is_empty() { + // New file — should be deleted after rollback. 
+ assert!( + !full_path.exists(), + "new file {rel_path} should be removed after rollback" + ); + } else { + // Existing file — hash should match beforeHash. + assert_eq!( + git_sha256_file(&full_path), + before_hash, + "{rel_path} should match beforeHash after rollback" + ); + } + } + + // Also verify against our originally recorded hashes. + for (rel_path, orig_hash) in &original_hashes { + if orig_hash.is_empty() { + continue; // file didn't exist before + } + let full_path = site_packages.join(rel_path); + if full_path.exists() { + assert_eq!( + git_sha256_file(&full_path), + *orig_hash, + "{rel_path} should match original hash after rollback" + ); + } + } + + // -- APPLY: re-apply from manifest ------------------------------------ + assert_run_ok(cwd, &["apply"], "apply"); + + for (rel_path, info) in files { + let after_hash = info["afterHash"] + .as_str() + .expect("afterHash should be a string"); + let full_path = site_packages.join(rel_path); + assert_eq!( + git_sha256_file(&full_path), + after_hash, + "{rel_path} should match afterHash after re-apply" + ); + } + + // -- REMOVE: rollback + remove from manifest --------------------------- + assert_run_ok(cwd, &["remove", PYPI_UUID], "remove"); + + // Manifest should be empty. + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + assert!( + manifest["patches"].as_object().unwrap().is_empty(), + "manifest should be empty after remove" + ); +} + +/// `apply --dry-run` should not modify files on disk. +#[test] +#[ignore] +fn test_pypi_dry_run() { + if !has_python3() { + eprintln!("SKIP: python3 not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + setup_venv(cwd); + + let site_packages = find_site_packages(cwd); + + // Record original hashes. + let messages_py = site_packages.join("pydantic_ai/messages.py"); + assert!(messages_py.exists()); + let original_hash = git_sha256_file(&messages_py); + + // Download without applying. + assert_run_ok(cwd, &["get", PYPI_UUID, "--no-apply"], "get --no-apply"); + + // File should be unchanged. + assert_eq!( + git_sha256_file(&messages_py), + original_hash, + "file should not change after get --no-apply" + ); + + // Dry-run should leave file untouched. + assert_run_ok(cwd, &["apply", "--dry-run"], "apply --dry-run"); + assert_eq!( + git_sha256_file(&messages_py), + original_hash, + "file should not change after apply --dry-run" + ); + + // Real apply should work. + assert_run_ok(cwd, &["apply"], "apply"); + + // Read afterHash from manifest to verify. + let manifest_path = cwd.join(".socket/manifest.json"); + let (_, files_value) = read_patch_files(&manifest_path); + let files = files_value.as_object().unwrap(); + let after_hash = files["pydantic_ai/messages.py"]["afterHash"] + .as_str() + .unwrap(); + assert_eq!( + git_sha256_file(&messages_py), + after_hash, + "file should match afterHash after real apply" + ); +} + +/// Global lifecycle: scan → get → rollback → apply → remove using `-g --global-prefix`. 
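+///
+/// The "global" site is simulated with `pip install --target` into a
+/// temporary directory, whose path is then passed via `--global-prefix`.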
+#[test] +#[ignore] +fn test_pypi_global_lifecycle() { + if !has_python3() { + eprintln!("SKIP: python3 not found on PATH"); + return; + } + + let global_dir = tempfile::tempdir().unwrap(); + let cwd_dir = tempfile::tempdir().unwrap(); + let cwd = cwd_dir.path(); + + // -- Setup: pip install --target into global_dir ------------------------- + let out = Command::new("python3") + .args([ + "-m", + "pip", + "install", + "--target", + global_dir.path().to_str().unwrap(), + "--no-deps", + "--disable-pip-version-check", + "pydantic-ai==0.0.36", + "pydantic-ai-slim==0.0.36", + ]) + .output() + .expect("failed to run pip install --target"); + assert!( + out.status.success(), + "pip install --target failed.\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&out.stdout), + String::from_utf8_lossy(&out.stderr), + ); + + assert!( + global_dir.path().join("pydantic_ai").exists(), + "pydantic_ai package should be installed in global_dir" + ); + + let gp_str = global_dir.path().to_str().unwrap(); + + // -- SCAN: verify scan -g finds the package ------------------------------ + let (stdout, _) = assert_run_ok( + cwd, + &["scan", "-g", "--global-prefix", gp_str, "--json"], + "scan -g --json", + ); + let scan: serde_json::Value = serde_json::from_str(&stdout).unwrap(); + let scanned = scan["scannedPackages"] + .as_u64() + .expect("scannedPackages should be a number"); + assert!(scanned >= 1, "scan should find at least 1 package, got {scanned}"); + + // -- GET: download + apply patch globally -------------------------------- + assert_run_ok( + cwd, + &["get", PYPI_UUID, "-g", "--global-prefix", gp_str], + "get -g", + ); + + let manifest_path = cwd.join(".socket/manifest.json"); + assert!(manifest_path.exists(), "manifest should exist after get"); + + let (_, files_value) = read_patch_files(&manifest_path); + let files = files_value.as_object().expect("files object"); + + // Verify every patched file matches afterHash. 
+ for (rel_path, info) in files { + let after_hash = info["afterHash"].as_str().expect("afterHash"); + let full_path = global_dir.path().join(rel_path); + assert!(full_path.exists(), "patched file should exist: {}", full_path.display()); + assert_eq!( + git_sha256_file(&full_path), + after_hash, + "{rel_path} should match afterHash after global get" + ); + } + + // -- ROLLBACK: restore original files globally --------------------------- + assert_run_ok( + cwd, + &["rollback", "-g", "--global-prefix", gp_str], + "rollback -g", + ); + + for (rel_path, info) in files { + let before_hash = info["beforeHash"].as_str().unwrap_or(""); + let full_path = global_dir.path().join(rel_path); + if before_hash.is_empty() { + assert!( + !full_path.exists(), + "new file {rel_path} should be removed after global rollback" + ); + } else { + assert_eq!( + git_sha256_file(&full_path), + before_hash, + "{rel_path} should match beforeHash after global rollback" + ); + } + } + + // -- APPLY: re-apply from manifest globally ------------------------------ + assert_run_ok( + cwd, + &["apply", "-g", "--global-prefix", gp_str], + "apply -g", + ); + + for (rel_path, info) in files { + let after_hash = info["afterHash"].as_str().expect("afterHash"); + let full_path = global_dir.path().join(rel_path); + assert_eq!( + git_sha256_file(&full_path), + after_hash, + "{rel_path} should match afterHash after global apply" + ); + } + + // -- REMOVE: rollback + remove from manifest globally -------------------- + assert_run_ok( + cwd, + &["remove", PYPI_UUID, "-g", "--global-prefix", gp_str], + "remove -g", + ); + + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + assert!( + manifest["patches"].as_object().unwrap().is_empty(), + "manifest should be empty after global remove" + ); +} + +/// `get --save-only` should save the patch to the manifest without applying. +#[test] +#[ignore] +fn test_pypi_save_only() { + if !has_python3() { + eprintln!("SKIP: python3 not found on PATH"); + return; + } + + let dir = tempfile::tempdir().unwrap(); + let cwd = dir.path(); + + setup_venv(cwd); + + let site_packages = find_site_packages(cwd); + let messages_py = site_packages.join("pydantic_ai/messages.py"); + assert!(messages_py.exists()); + let original_hash = git_sha256_file(&messages_py); + + // Download with --save-only. + assert_run_ok(cwd, &["get", PYPI_UUID, "--save-only"], "get --save-only"); + + // File should be unchanged. + assert_eq!( + git_sha256_file(&messages_py), + original_hash, + "file should not change after get --save-only" + ); + + // Manifest should exist with the patch. + let manifest_path = cwd.join(".socket/manifest.json"); + assert!(manifest_path.exists(), "manifest should exist after get --save-only"); + + let (purl, _) = read_patch_files(&manifest_path); + assert!( + purl.starts_with(PYPI_PURL_PREFIX), + "manifest should contain a pydantic-ai patch" + ); + + // Real apply should work. + assert_run_ok(cwd, &["apply"], "apply"); + + let (_, files_value) = read_patch_files(&manifest_path); + let files = files_value.as_object().unwrap(); + let after_hash = files["pydantic_ai/messages.py"]["afterHash"] + .as_str() + .unwrap(); + assert_eq!( + git_sha256_file(&messages_py), + after_hash, + "file should match afterHash after apply" + ); +} + +/// macOS auto-discovery: `scan -g --json` without `--global-prefix` uses real path probing. 
+#[cfg(target_os = "macos")]
+#[test]
+#[ignore]
+fn test_pypi_macos_global_auto_discovery() {
+    if !has_python3() {
+        eprintln!("SKIP: python3 not found on PATH");
+        return;
+    }
+
+    let cwd_dir = tempfile::tempdir().unwrap();
+    let cwd = cwd_dir.path();
+
+    // Run scan -g without --global-prefix to exercise macOS auto-discovery
+    let (code, stdout, stderr) = run(cwd, &["scan", "-g", "--json"]);
+
+    // Should complete without error (exit 0)
+    assert_eq!(
+        code, 0,
+        "scan -g --json failed (exit {code}).\nstdout:\n{stdout}\nstderr:\n{stderr}"
+    );
+
+    // Output should be valid JSON with scannedPackages field
+    let scan: serde_json::Value = serde_json::from_str(&stdout)
+        .unwrap_or_else(|e| panic!("invalid JSON from scan -g: {e}\nstdout:\n{stdout}"));
+    assert!(
+        scan["scannedPackages"].is_u64(),
+        "scannedPackages should be a number, got: {}",
+        scan["scannedPackages"]
+    );
+}
+
+/// UUID shortcut: `socket-patch <uuid>` should behave like `socket-patch get <uuid>`.
+#[test]
+#[ignore]
+fn test_pypi_uuid_shortcut() {
+    if !has_python3() {
+        eprintln!("SKIP: python3 not found on PATH");
+        return;
+    }
+
+    let dir = tempfile::tempdir().unwrap();
+    let cwd = dir.path();
+
+    setup_venv(cwd);
+
+    let site_packages = find_site_packages(cwd);
+    assert!(site_packages.join("pydantic_ai").exists());
+
+    // Run with bare UUID (no "get" subcommand).
+    assert_run_ok(cwd, &[PYPI_UUID], "uuid shortcut");
+
+    let manifest_path = cwd.join(".socket/manifest.json");
+    assert!(manifest_path.exists(), "manifest should exist after UUID shortcut");
+
+    let (_, files_value) = read_patch_files(&manifest_path);
+    let files = files_value.as_object().expect("files object");
+
+    for (rel_path, info) in files {
+        let after_hash = info["afterHash"].as_str().expect("afterHash");
+        let full_path = site_packages.join(rel_path);
+        assert_eq!(
+            git_sha256_file(&full_path),
+            after_hash,
+            "{rel_path} should match afterHash after UUID shortcut"
+        );
+    }
+}
diff --git a/crates/socket-patch-core/Cargo.toml b/crates/socket-patch-core/Cargo.toml
new file mode 100644
index 0000000..c081348
--- /dev/null
+++ b/crates/socket-patch-core/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "socket-patch-core"
+description = "Core library for socket-patch: manifest, hash, crawlers, patch engine, API client"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+repository.workspace = true
+readme = "README.md"
+
+[dependencies]
+serde = { workspace = true }
+serde_json = { workspace = true }
+sha2 = { workspace = true }
+hex = { workspace = true }
+reqwest = { workspace = true }
+tokio = { workspace = true }
+thiserror = { workspace = true }
+walkdir = { workspace = true }
+uuid = { workspace = true }
+regex = { workspace = true }
+once_cell = { workspace = true }
+
+[features]
+default = []
+cargo = []
+golang = []
+maven = []
+composer = []
+nuget = []
+
+[dev-dependencies]
+tempfile = { workspace = true }
+tokio = { version = "1", features = ["full", "test-util"] }
diff --git a/crates/socket-patch-core/README.md b/crates/socket-patch-core/README.md
new file mode 100644
index 0000000..a365fb0
--- /dev/null
+++ b/crates/socket-patch-core/README.md
@@ -0,0 +1,23 @@
+# socket-patch-core
+
+Core library for [socket-patch](https://github.com/SocketDev/socket-patch) — a CLI tool that applies security patches to npm and Python dependencies (plus Cargo, Go, Maven, Ruby, Composer, and NuGet via feature flags) without waiting for upstream fixes.
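+
+For a taste of the library API, here is a minimal, illustrative sketch that
+queries the public patch proxy for free patches. It assumes the `api` module
+layout under `src/api/`, a `tokio` runtime on the consumer side, and a
+placeholder PURL:
+
+```rust
+use socket_patch_core::api::client::{ApiClient, ApiClientOptions};
+
+#[tokio::main]
+async fn main() {
+    // Anonymous client against the public proxy (free patches only).
+    let client = ApiClient::new(ApiClientOptions {
+        api_url: "https://patches-api.socket.dev".into(),
+        api_token: None,
+        use_public_proxy: true,
+        org_slug: None,
+    });
+
+    // Look up patches for a package PURL (placeholder shown).
+    let resp = client
+        .search_patches_by_package(None, "pkg:npm/lodash@4.17.21")
+        .await
+        .expect("patch search failed");
+    for patch in resp.patches {
+        println!("{} ({})", patch.uuid, patch.tier);
+    }
+}
+```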
+
+## What this crate provides
+
+- **Manifest management** — read, write, and validate `.socket/manifest.json` patch manifests
+- **Patch engine** — apply and rollback file-level patches using git SHA-256 content hashes
+- **Crawlers** — discover installed packages across npm, PyPI, and Ruby gems (default), plus Cargo, Go, Maven, Composer, and NuGet (via feature flags)
+- **API client** — fetch patches from the Socket API
+- **Utilities** — PURL parsing, blob storage, hash verification, fuzzy matching
+
+## Usage
+
+This crate is used internally by the [`socket-patch-cli`](https://crates.io/crates/socket-patch-cli) binary. If you need the CLI, install that instead:
+
+```bash
+cargo install socket-patch-cli
+```
+
+## License
+
+MIT
diff --git a/crates/socket-patch-core/src/api/blob_fetcher.rs b/crates/socket-patch-core/src/api/blob_fetcher.rs
new file mode 100644
index 0000000..7309070
--- /dev/null
+++ b/crates/socket-patch-core/src/api/blob_fetcher.rs
@@ -0,0 +1,548 @@
+use std::collections::HashSet;
+use std::path::{Path, PathBuf};
+
+use crate::api::client::ApiClient;
+use crate::manifest::operations::get_after_hash_blobs;
+use crate::manifest::schema::PatchManifest;
+
+/// Result of fetching a single blob.
+#[derive(Debug, Clone)]
+pub struct BlobFetchResult {
+    pub hash: String,
+    pub success: bool,
+    pub error: Option<String>,
+}
+
+/// Aggregate result of a blob-fetch operation.
+#[derive(Debug, Clone)]
+pub struct FetchMissingBlobsResult {
+    pub total: usize,
+    pub downloaded: usize,
+    pub failed: usize,
+    pub skipped: usize,
+    pub results: Vec<BlobFetchResult>,
+}
+
+/// Progress callback signature.
+///
+/// Called with `(hash, one_based_index, total)` for each blob.
+pub type OnProgress = Box<dyn Fn(&str, usize, usize)>;
+
+// ── Public API ──────────────────────────────────────────────────────
+
+/// Determine which `afterHash` blobs referenced in the manifest are
+/// missing from disk.
+///
+/// Only checks `afterHash` blobs because those are the patched file
+/// contents needed for applying patches. `beforeHash` blobs are
+/// downloaded on-demand during rollback.
+pub async fn get_missing_blobs(
+    manifest: &PatchManifest,
+    blobs_path: &Path,
+) -> HashSet<String> {
+    let after_hash_blobs = get_after_hash_blobs(manifest);
+    let mut missing = HashSet::new();
+
+    for hash in after_hash_blobs {
+        let blob_path = blobs_path.join(&hash);
+        if tokio::fs::metadata(&blob_path).await.is_err() {
+            missing.insert(hash);
+        }
+    }
+
+    missing
+}
+
+/// Download all missing `afterHash` blobs referenced in the manifest.
+///
+/// Creates the `blobs_path` directory if it does not exist.
+///
+/// # Arguments
+///
+/// * `manifest` – Patch manifest whose `afterHash` blobs to check.
+/// * `blobs_path` – Directory where blob files are stored (one file per
+///   hash).
+/// * `client` – [`ApiClient`] used to fetch blobs from the server.
+/// * `on_progress` – Optional callback invoked before each download with
+///   `(hash, 1-based index, total)`.
+pub async fn fetch_missing_blobs(
+    manifest: &PatchManifest,
+    blobs_path: &Path,
+    client: &ApiClient,
+    on_progress: Option<&OnProgress>,
+) -> FetchMissingBlobsResult {
+    let missing = get_missing_blobs(manifest, blobs_path).await;
+
+    if missing.is_empty() {
+        return FetchMissingBlobsResult {
+            total: 0,
+            downloaded: 0,
+            failed: 0,
+            skipped: 0,
+            results: Vec::new(),
+        };
+    }
+
+    // Ensure blobs directory exists
+    if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
+        // If we cannot create the directory, every blob will fail.
+        let results: Vec<BlobFetchResult> = missing
+            .iter()
+            .map(|h| BlobFetchResult {
+                hash: h.clone(),
+                success: false,
+                error: Some(format!("Cannot create blobs directory: {}", e)),
+            })
+            .collect();
+        let failed = results.len();
+        return FetchMissingBlobsResult {
+            total: failed,
+            downloaded: 0,
+            failed,
+            skipped: 0,
+            results,
+        };
+    }
+
+    let hashes: Vec<String> = missing.into_iter().collect();
+    download_hashes(&hashes, blobs_path, client, on_progress).await
+}
+
+/// Download specific blobs identified by their hashes.
+///
+/// Useful for fetching `beforeHash` blobs during rollback, where only a
+/// subset of hashes is required.
+///
+/// Blobs that already exist on disk are skipped (counted in `skipped`).
+pub async fn fetch_blobs_by_hash(
+    hashes: &HashSet<String>,
+    blobs_path: &Path,
+    client: &ApiClient,
+    on_progress: Option<&OnProgress>,
+) -> FetchMissingBlobsResult {
+    if hashes.is_empty() {
+        return FetchMissingBlobsResult {
+            total: 0,
+            downloaded: 0,
+            failed: 0,
+            skipped: 0,
+            results: Vec::new(),
+        };
+    }
+
+    // Ensure blobs directory exists
+    if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
+        let results: Vec<BlobFetchResult> = hashes
+            .iter()
+            .map(|h| BlobFetchResult {
+                hash: h.clone(),
+                success: false,
+                error: Some(format!("Cannot create blobs directory: {}", e)),
+            })
+            .collect();
+        let failed = results.len();
+        return FetchMissingBlobsResult {
+            total: failed,
+            downloaded: 0,
+            failed,
+            skipped: 0,
+            results,
+        };
+    }
+
+    // Filter out hashes that already exist on disk
+    let mut to_download: Vec<String> = Vec::new();
+    let mut skipped: usize = 0;
+    let mut results: Vec<BlobFetchResult> = Vec::new();
+
+    for hash in hashes {
+        let blob_path = blobs_path.join(hash);
+        if tokio::fs::metadata(&blob_path).await.is_ok() {
+            skipped += 1;
+            results.push(BlobFetchResult {
+                hash: hash.clone(),
+                success: true,
+                error: None,
+            });
+        } else {
+            to_download.push(hash.clone());
+        }
+    }
+
+    if to_download.is_empty() {
+        return FetchMissingBlobsResult {
+            total: hashes.len(),
+            downloaded: 0,
+            failed: 0,
+            skipped,
+            results,
+        };
+    }
+
+    let download_result =
+        download_hashes(&to_download, blobs_path, client, on_progress).await;
+
+    FetchMissingBlobsResult {
+        total: hashes.len(),
+        downloaded: download_result.downloaded,
+        failed: download_result.failed,
+        skipped,
+        results: {
+            let mut combined = results;
+            combined.extend(download_result.results);
+            combined
+        },
+    }
+}
+
+/// Format a [`FetchMissingBlobsResult`] as a human-readable string.
+pub fn format_fetch_result(result: &FetchMissingBlobsResult) -> String {
+    if result.total == 0 {
+        return "All blobs are present locally.".to_string();
+    }
+
+    let mut lines: Vec<String> = Vec::new();
+
+    if result.downloaded > 0 {
+        lines.push(format!("Downloaded {} blob(s)", result.downloaded));
+    }
+
+    if result.failed > 0 {
+        lines.push(format!("Failed to download {} blob(s)", result.failed));
+
+        let failed_results: Vec<&BlobFetchResult> =
+            result.results.iter().filter(|r| !r.success).collect();
+
+        for r in failed_results.iter().take(5) {
+            let short_hash = if r.hash.len() >= 12 {
+                &r.hash[..12]
+            } else {
+                &r.hash
+            };
+            let err = r.error.as_deref().unwrap_or("unknown error");
+            lines.push(format!("  - {}...: {}", short_hash, err));
+        }
+
+        if failed_results.len() > 5 {
+            lines.push(format!("  ... and {} more", failed_results.len() - 5));
+        }
+    }
+
+    lines.join("\n")
+}
+
+// ── Internal helpers ────────────────────────────────────────────────
+
+/// Download a list of blob hashes sequentially, writing each to
+/// `blobs_path/<hash>`.
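+///
+/// Each downloaded blob is re-hashed (git SHA-256) and compared against the
+/// requested hash before it is written, so a corrupted or substituted
+/// download is recorded as a failure rather than poisoning the blob store.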
+async fn download_hashes(
+    hashes: &[String],
+    blobs_path: &Path,
+    client: &ApiClient,
+    on_progress: Option<&OnProgress>,
+) -> FetchMissingBlobsResult {
+    let total = hashes.len();
+    let mut downloaded: usize = 0;
+    let mut failed: usize = 0;
+    let mut results: Vec<BlobFetchResult> = Vec::with_capacity(total);
+
+    for (i, hash) in hashes.iter().enumerate() {
+        if let Some(ref cb) = on_progress {
+            cb(hash, i + 1, total);
+        }
+
+        match client.fetch_blob(hash).await {
+            Ok(Some(data)) => {
+                // Verify content hash matches expected hash before writing
+                let actual_hash = crate::hash::git_sha256::compute_git_sha256_from_bytes(&data);
+                if actual_hash != *hash {
+                    results.push(BlobFetchResult {
+                        hash: hash.clone(),
+                        success: false,
+                        error: Some(format!(
+                            "Content hash mismatch: expected {}, got {}",
+                            hash, actual_hash
+                        )),
+                    });
+                    failed += 1;
+                    continue;
+                }
+
+                let blob_path: PathBuf = blobs_path.join(hash);
+                match tokio::fs::write(&blob_path, &data).await {
+                    Ok(()) => {
+                        results.push(BlobFetchResult {
+                            hash: hash.clone(),
+                            success: true,
+                            error: None,
+                        });
+                        downloaded += 1;
+                    }
+                    Err(e) => {
+                        results.push(BlobFetchResult {
+                            hash: hash.clone(),
+                            success: false,
+                            error: Some(format!("Failed to write blob to disk: {}", e)),
+                        });
+                        failed += 1;
+                    }
+                }
+            }
+            Ok(None) => {
+                results.push(BlobFetchResult {
+                    hash: hash.clone(),
+                    success: false,
+                    error: Some("Blob not found on server".to_string()),
+                });
+                failed += 1;
+            }
+            Err(e) => {
+                results.push(BlobFetchResult {
+                    hash: hash.clone(),
+                    success: false,
+                    error: Some(e.to_string()),
+                });
+                failed += 1;
+            }
+        }
+    }
+
+    FetchMissingBlobsResult {
+        total,
+        downloaded,
+        failed,
+        skipped: 0,
+        results,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord};
+    use std::collections::HashMap;
+
+    fn make_manifest_with_hashes(after_hashes: &[&str]) -> PatchManifest {
+        let mut files = HashMap::new();
+        for (i, ah) in after_hashes.iter().enumerate() {
+            files.insert(
+                format!("package/file{}.js", i),
+                PatchFileInfo {
+                    before_hash: format!(
+                        "before{}{}",
+                        "0".repeat(58),
+                        format!("{:06}", i)
+                    ),
+                    after_hash: ah.to_string(),
+                },
+            );
+        }
+
+        let mut patches = HashMap::new();
+        patches.insert(
+            "pkg:npm/test@1.0.0".to_string(),
+            PatchRecord {
+                uuid: "test-uuid".to_string(),
+                exported_at: "2024-01-01T00:00:00Z".to_string(),
+                files,
+                vulnerabilities: HashMap::new(),
+                description: "test".to_string(),
+                license: "MIT".to_string(),
+                tier: "free".to_string(),
+            },
+        );
+
+        PatchManifest { patches }
+    }
+
+    #[tokio::test]
+    async fn test_get_missing_blobs_all_missing() {
+        let dir = tempfile::tempdir().unwrap();
+        let blobs_path = dir.path().join("blobs");
+        tokio::fs::create_dir_all(&blobs_path).await.unwrap();
+
+        let h1 = "a".repeat(64);
+        let h2 = "b".repeat(64);
+        let manifest = make_manifest_with_hashes(&[&h1, &h2]);
+
+        let missing = get_missing_blobs(&manifest, &blobs_path).await;
+        assert_eq!(missing.len(), 2);
+        assert!(missing.contains(&h1));
+        assert!(missing.contains(&h2));
+    }
+
+    #[tokio::test]
+    async fn test_get_missing_blobs_some_present() {
+        let dir = tempfile::tempdir().unwrap();
+        let blobs_path = dir.path().join("blobs");
+        tokio::fs::create_dir_all(&blobs_path).await.unwrap();
+
+        let h1 = "a".repeat(64);
+        let h2 = "b".repeat(64);
+
+        // Write h1 to disk so it is NOT missing
+        tokio::fs::write(blobs_path.join(&h1), b"data").await.unwrap();
+
+        let manifest = make_manifest_with_hashes(&[&h1, &h2]);
+        let missing = get_missing_blobs(&manifest, &blobs_path).await;
+        assert_eq!(missing.len(), 1);
+        assert!(missing.contains(&h2));
+        assert!(!missing.contains(&h1));
+    }
+
+    #[tokio::test]
+    async fn test_get_missing_blobs_empty_manifest() {
+        let dir = tempfile::tempdir().unwrap();
+        let blobs_path = dir.path().join("blobs");
+        tokio::fs::create_dir_all(&blobs_path).await.unwrap();
+
+        let manifest = PatchManifest::new();
+        let missing = get_missing_blobs(&manifest, &blobs_path).await;
+        assert!(missing.is_empty());
+    }
+
+    #[test]
+    fn test_format_fetch_result_all_present() {
+        let result = FetchMissingBlobsResult {
+            total: 0,
+            downloaded: 0,
+            failed: 0,
+            skipped: 0,
+            results: Vec::new(),
+        };
+        assert_eq!(format_fetch_result(&result), "All blobs are present locally.");
+    }
+
+    #[test]
+    fn test_format_fetch_result_some_downloaded() {
+        let result = FetchMissingBlobsResult {
+            total: 3,
+            downloaded: 2,
+            failed: 1,
+            skipped: 0,
+            results: vec![
+                BlobFetchResult {
+                    hash: "a".repeat(64),
+                    success: true,
+                    error: None,
+                },
+                BlobFetchResult {
+                    hash: "b".repeat(64),
+                    success: true,
+                    error: None,
+                },
+                BlobFetchResult {
+                    hash: "c".repeat(64),
+                    success: false,
+                    error: Some("Blob not found on server".to_string()),
+                },
+            ],
+        };
+        let output = format_fetch_result(&result);
+        assert!(output.contains("Downloaded 2 blob(s)"));
+        assert!(output.contains("Failed to download 1 blob(s)"));
+        assert!(output.contains("cccccccccccc..."));
+        assert!(output.contains("Blob not found on server"));
+    }
+
+    #[test]
+    fn test_format_fetch_result_truncates_at_5() {
+        let results: Vec<BlobFetchResult> = (0..8)
+            .map(|i| BlobFetchResult {
+                hash: format!("{:0>64}", i),
+                success: false,
+                error: Some(format!("error {}", i)),
+            })
+            .collect();
+
+        let result = FetchMissingBlobsResult {
+            total: 8,
+            downloaded: 0,
+            failed: 8,
+            skipped: 0,
+            results,
+        };
+        let output = format_fetch_result(&result);
+        assert!(output.contains("... and 3 more"));
and 3 more")); + } + + // ── Group 8: format edge cases ─────────────────────────────────── + + #[test] + fn test_format_only_downloaded() { + let result = FetchMissingBlobsResult { + total: 3, + downloaded: 3, + failed: 0, + skipped: 0, + results: vec![ + BlobFetchResult { hash: "a".repeat(64), success: true, error: None }, + BlobFetchResult { hash: "b".repeat(64), success: true, error: None }, + BlobFetchResult { hash: "c".repeat(64), success: true, error: None }, + ], + }; + let output = format_fetch_result(&result); + assert!(output.contains("Downloaded 3 blob(s)")); + assert!(!output.contains("Failed")); + } + + #[test] + fn test_format_short_hash() { + let result = FetchMissingBlobsResult { + total: 1, + downloaded: 0, + failed: 1, + skipped: 0, + results: vec![BlobFetchResult { + hash: "abc".into(), + success: false, + error: Some("not found".into()), + }], + }; + let output = format_fetch_result(&result); + // Hash is < 12 chars, should show full hash + assert!(output.contains("abc...")); + } + + #[test] + fn test_format_error_none() { + let result = FetchMissingBlobsResult { + total: 1, + downloaded: 0, + failed: 1, + skipped: 0, + results: vec![BlobFetchResult { + hash: "d".repeat(64), + success: false, + error: None, + }], + }; + let output = format_fetch_result(&result); + assert!(output.contains("unknown error")); + } + + #[test] + fn test_format_only_failed() { + let result = FetchMissingBlobsResult { + total: 2, + downloaded: 0, + failed: 2, + skipped: 0, + results: vec![ + BlobFetchResult { + hash: "a".repeat(64), + success: false, + error: Some("timeout".into()), + }, + BlobFetchResult { + hash: "b".repeat(64), + success: false, + error: Some("timeout".into()), + }, + ], + }; + let output = format_fetch_result(&result); + assert!(!output.contains("Downloaded")); + assert!(output.contains("Failed to download 2 blob(s)")); + } +} diff --git a/crates/socket-patch-core/src/api/client.rs b/crates/socket-patch-core/src/api/client.rs new file mode 100644 index 0000000..b2dc8f2 --- /dev/null +++ b/crates/socket-patch-core/src/api/client.rs @@ -0,0 +1,1013 @@ +use std::collections::HashSet; + +use reqwest::header::{self, HeaderMap, HeaderValue}; +use reqwest::StatusCode; +use serde::Serialize; + +use crate::api::types::*; +use crate::constants::{ + DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT as USER_AGENT_VALUE, +}; + +/// Check if debug mode is enabled via SOCKET_PATCH_DEBUG env. +fn is_debug_enabled() -> bool { + match std::env::var("SOCKET_PATCH_DEBUG") { + Ok(val) => val == "1" || val == "true", + Err(_) => false, + } +} + +/// Log debug messages when debug mode is enabled. +fn debug_log(message: &str) { + if is_debug_enabled() { + eprintln!("[socket-patch debug] {}", message); + } +} + +/// Severity order for sorting (most severe = lowest number). +fn get_severity_order(severity: Option<&str>) -> u8 { + match severity.map(|s| s.to_lowercase()).as_deref() { + Some("critical") => 0, + Some("high") => 1, + Some("medium") => 2, + Some("low") => 3, + _ => 4, + } +} + +/// Options for constructing an [`ApiClient`]. +#[derive(Debug, Clone)] +pub struct ApiClientOptions { + pub api_url: String, + pub api_token: Option, + /// When true, the client will use the public patch API proxy + /// which only provides access to free patches without authentication. + pub use_public_proxy: bool, + /// Organization slug for authenticated API access. + /// Required when using authenticated API (not public proxy). 
+    pub org_slug: Option<String>,
+}
+
+/// HTTP client for the Socket Patch API.
+///
+/// Supports both the authenticated Socket API (`api.socket.dev`) and the
+/// public proxy (`patches-api.socket.dev`) which serves free patches
+/// without authentication.
+#[derive(Debug, Clone)]
+pub struct ApiClient {
+    client: reqwest::Client,
+    api_url: String,
+    api_token: Option<String>,
+    use_public_proxy: bool,
+    org_slug: Option<String>,
+}
+
+/// Body payload for the batch search POST endpoint.
+#[derive(Serialize)]
+struct BatchSearchBody {
+    components: Vec<BatchComponent>,
+}
+
+#[derive(Serialize)]
+struct BatchComponent {
+    purl: String,
+}
+
+impl ApiClient {
+    /// Create a new API client from the given options.
+    ///
+    /// Constructs a `reqwest::Client` with proper default headers
+    /// (User-Agent, Accept, and optionally Authorization).
+    pub fn new(options: ApiClientOptions) -> Self {
+        let api_url = options.api_url.trim_end_matches('/').to_string();
+
+        let mut default_headers = HeaderMap::new();
+        default_headers.insert(
+            header::USER_AGENT,
+            HeaderValue::from_static(USER_AGENT_VALUE),
+        );
+        default_headers.insert(
+            header::ACCEPT,
+            HeaderValue::from_static("application/json"),
+        );
+
+        if let Some(ref token) = options.api_token {
+            if let Ok(hv) = HeaderValue::from_str(&format!("Bearer {}", token)) {
+                default_headers.insert(header::AUTHORIZATION, hv);
+            }
+        }
+
+        let client = reqwest::Client::builder()
+            .default_headers(default_headers)
+            .build()
+            .expect("failed to build reqwest client");
+
+        Self {
+            client,
+            api_url,
+            api_token: options.api_token,
+            use_public_proxy: options.use_public_proxy,
+            org_slug: options.org_slug,
+        }
+    }
+
+    /// Returns the API token, if set.
+    pub fn api_token(&self) -> Option<&String> {
+        self.api_token.as_ref()
+    }
+
+    /// Returns the org slug, if set.
+    pub fn org_slug(&self) -> Option<&String> {
+        self.org_slug.as_ref()
+    }
+
+    // ── Internal helpers ──────────────────────────────────────────────
+
+    /// Internal GET that deserialises JSON. Returns `Ok(None)` on 404.
+    async fn get_json<T: serde::de::DeserializeOwned>(
+        &self,
+        path: &str,
+    ) -> Result<Option<T>, ApiError> {
+        let url = format!("{}{}", self.api_url, path);
+        debug_log(&format!("GET {}", url));
+
+        let resp = self
+            .client
+            .get(&url)
+            .send()
+            .await
+            .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
+
+        Self::handle_json_response(resp, self.use_public_proxy).await
+    }
+
+    /// Internal POST that deserialises JSON. Returns `Ok(None)` on 404.
+    async fn post_json<T: serde::de::DeserializeOwned, B: Serialize>(
+        &self,
+        path: &str,
+        body: &B,
+    ) -> Result<Option<T>, ApiError> {
+        let url = format!("{}{}", self.api_url, path);
+        debug_log(&format!("POST {}", url));
+
+        let resp = self
+            .client
+            .post(&url)
+            .header(header::CONTENT_TYPE, "application/json")
+            .json(body)
+            .send()
+            .await
+            .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
+
+        Self::handle_json_response(resp, self.use_public_proxy).await
+    }
+
+    /// Map an HTTP response to `Ok(Some(T))`, `Ok(None)` (404), or `Err`.
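+    ///
+    /// Concretely: 401 becomes [`ApiError::Unauthorized`], 403 becomes
+    /// [`ApiError::Forbidden`] (with a proxy-aware message), 429 becomes
+    /// [`ApiError::RateLimited`], and any other non-OK status becomes
+    /// [`ApiError::Other`] carrying the response body.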
+    async fn handle_json_response<T: serde::de::DeserializeOwned>(
+        resp: reqwest::Response,
+        use_public_proxy: bool,
+    ) -> Result<Option<T>, ApiError> {
+        let status = resp.status();
+
+        match status {
+            StatusCode::OK => {
+                let body = resp
+                    .json::<T>()
+                    .await
+                    .map_err(|e| ApiError::Parse(format!("Failed to parse response: {}", e)))?;
+                Ok(Some(body))
+            }
+            StatusCode::NOT_FOUND => Ok(None),
+            StatusCode::UNAUTHORIZED => {
+                Err(ApiError::Unauthorized("Unauthorized: Invalid API token".into()))
+            }
+            StatusCode::FORBIDDEN => {
+                let msg = if use_public_proxy {
+                    "Forbidden: This patch is only available to paid subscribers. \
+                     Sign up at https://socket.dev to access paid patches."
+                } else {
+                    "Forbidden: Access denied. This may be a paid patch or \
+                     you may not have access to this organization."
+                };
+                Err(ApiError::Forbidden(msg.into()))
+            }
+            StatusCode::TOO_MANY_REQUESTS => {
+                Err(ApiError::RateLimited(
+                    "Rate limit exceeded. Please try again later.".into(),
+                ))
+            }
+            _ => {
+                let text = resp.text().await.unwrap_or_default();
+                Err(ApiError::Other(format!(
+                    "API request failed with status {}: {}",
+                    status.as_u16(),
+                    text
+                )))
+            }
+        }
+    }
+
+    // ── Public API methods ────────────────────────────────────────────
+
+    /// Fetch a patch by UUID (full details with blob content).
+    ///
+    /// Returns `Ok(None)` when the patch is not found (404).
+    pub async fn fetch_patch(
+        &self,
+        org_slug: Option<&str>,
+        uuid: &str,
+    ) -> Result<Option<PatchResponse>, ApiError> {
+        let path = if self.use_public_proxy {
+            format!("/patch/view/{}", uuid)
+        } else {
+            let slug = org_slug
+                .or(self.org_slug.as_deref())
+                .unwrap_or("default");
+            format!("/v0/orgs/{}/patches/view/{}", slug, uuid)
+        };
+        self.get_json(&path).await
+    }
+
+    /// Search patches by CVE ID.
+    pub async fn search_patches_by_cve(
+        &self,
+        org_slug: Option<&str>,
+        cve_id: &str,
+    ) -> Result<SearchResponse, ApiError> {
+        let encoded = urlencoding_encode(cve_id);
+        let path = if self.use_public_proxy {
+            format!("/patch/by-cve/{}", encoded)
+        } else {
+            let slug = org_slug
+                .or(self.org_slug.as_deref())
+                .unwrap_or("default");
+            format!("/v0/orgs/{}/patches/by-cve/{}", slug, encoded)
+        };
+        let result = self.get_json::<SearchResponse>(&path).await?;
+        Ok(result.unwrap_or_else(|| SearchResponse {
+            patches: Vec::new(),
+            can_access_paid_patches: false,
+        }))
+    }
+
+    /// Search patches by GHSA ID.
+    pub async fn search_patches_by_ghsa(
+        &self,
+        org_slug: Option<&str>,
+        ghsa_id: &str,
+    ) -> Result<SearchResponse, ApiError> {
+        let encoded = urlencoding_encode(ghsa_id);
+        let path = if self.use_public_proxy {
+            format!("/patch/by-ghsa/{}", encoded)
+        } else {
+            let slug = org_slug
+                .or(self.org_slug.as_deref())
+                .unwrap_or("default");
+            format!("/v0/orgs/{}/patches/by-ghsa/{}", slug, encoded)
+        };
+        let result = self.get_json::<SearchResponse>(&path).await?;
+        Ok(result.unwrap_or_else(|| SearchResponse {
+            patches: Vec::new(),
+            can_access_paid_patches: false,
+        }))
+    }
+
+    /// Search patches by package PURL.
+    ///
+    /// The PURL must be a valid Package URL starting with `pkg:`.
+    /// Examples: `pkg:npm/lodash@4.17.21`, `pkg:pypi/django@3.2.0`
+    pub async fn search_patches_by_package(
+        &self,
+        org_slug: Option<&str>,
+        purl: &str,
+    ) -> Result<SearchResponse, ApiError> {
+        let encoded = urlencoding_encode(purl);
+        let path = if self.use_public_proxy {
+            format!("/patch/by-package/{}", encoded)
+        } else {
+            let slug = org_slug
+                .or(self.org_slug.as_deref())
+                .unwrap_or("default");
+            format!("/v0/orgs/{}/patches/by-package/{}", slug, encoded)
+        };
+        let result = self.get_json::<SearchResponse>(&path).await?;
+        Ok(result.unwrap_or_else(|| SearchResponse {
+            patches: Vec::new(),
+            can_access_paid_patches: false,
+        }))
+    }
+
+    /// Search patches for multiple packages (batch).
+    ///
+    /// For authenticated API, uses the POST `/patches/batch` endpoint.
+    /// For the public proxy (which cannot cache POST bodies on CDN), falls
+    /// back to individual GET requests per PURL with a concurrency limit of
+    /// 10.
+    ///
+    /// Maximum 500 PURLs per request.
+    pub async fn search_patches_batch(
+        &self,
+        org_slug: Option<&str>,
+        purls: &[String],
+    ) -> Result<BatchSearchResponse, ApiError> {
+        if !self.use_public_proxy {
+            let slug = org_slug
+                .or(self.org_slug.as_deref())
+                .unwrap_or("default");
+            let path = format!("/v0/orgs/{}/patches/batch", slug);
+            let body = BatchSearchBody {
+                components: purls
+                    .iter()
+                    .map(|p| BatchComponent { purl: p.clone() })
+                    .collect(),
+            };
+            let result = self.post_json::<BatchSearchResponse, _>(&path, &body).await?;
+            return Ok(result.unwrap_or_else(|| BatchSearchResponse {
+                packages: Vec::new(),
+                can_access_paid_patches: false,
+            }));
+        }
+
+        // Public proxy: fall back to individual per-package GET requests
+        self.search_patches_batch_via_individual_queries(purls).await
+    }
+
+    /// Internal: fall back to individual GET requests per PURL when the
+    /// batch endpoint is not available (public proxy mode).
+    ///
+    /// Processes PURLs in batches of `CONCURRENCY_LIMIT` to avoid
+    /// overwhelming the server while remaining efficient.
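+    ///
+    /// Per-PURL failures are logged via `debug_log` and skipped, so a single
+    /// failing lookup does not fail the whole batch; packages with no
+    /// patches are omitted from the result.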
+    async fn search_patches_batch_via_individual_queries(
+        &self,
+        purls: &[String],
+    ) -> Result<BatchSearchResponse, ApiError> {
+        const CONCURRENCY_LIMIT: usize = 10;
+
+        let mut packages: Vec<BatchPackagePatches> = Vec::new();
+        let mut can_access_paid_patches = false;
+
+        // Collect all (purl, response) pairs
+        let mut all_results: Vec<(String, Option<SearchResponse>)> = Vec::new();
+
+        for chunk in purls.chunks(CONCURRENCY_LIMIT) {
+            // Use tokio::JoinSet for concurrent execution within each chunk
+            let mut join_set = tokio::task::JoinSet::new();
+
+            for purl in chunk {
+                let purl = purl.clone();
+                let client = self.clone();
+                join_set.spawn(async move {
+                    let resp = client.search_patches_by_package(None, &purl).await;
+                    match resp {
+                        Ok(r) => (purl, Some(r)),
+                        Err(e) => {
+                            debug_log(&format!("Error fetching patches for {}: {}", purl, e));
+                            (purl, None)
+                        }
+                    }
+                });
+            }
+
+            while let Some(result) = join_set.join_next().await {
+                match result {
+                    Ok(pair) => all_results.push(pair),
+                    Err(e) => {
+                        debug_log(&format!("Task join error: {}", e));
+                    }
+                }
+            }
+        }
+
+        // Convert individual SearchResponse results to BatchSearchResponse format
+        for (purl, response) in all_results {
+            let response = match response {
+                Some(r) if !r.patches.is_empty() => r,
+                _ => continue,
+            };
+
+            if response.can_access_paid_patches {
+                can_access_paid_patches = true;
+            }
+
+            let batch_patches: Vec<BatchPatchInfo> = response
+                .patches
+                .into_iter()
+                .map(convert_search_result_to_batch_info)
+                .collect();
+
+            packages.push(BatchPackagePatches {
+                purl,
+                patches: batch_patches,
+            });
+        }
+
+        Ok(BatchSearchResponse {
+            packages,
+            can_access_paid_patches,
+        })
+    }
+
+    /// Fetch organizations accessible to the current API token.
+    pub async fn fetch_organizations(
+        &self,
+    ) -> Result<Vec<OrganizationInfo>, ApiError> {
+        let path = "/v0/organizations";
+        match self
+            .get_json::<OrganizationsResponse>(path)
+            .await?
+        {
+            Some(resp) => Ok(resp.organizations.into_values().collect()),
+            None => Ok(Vec::new()),
+        }
+    }
+
+    /// Resolve the org slug from the API token by querying `/v0/organizations`.
+    ///
+    /// If there is exactly one org, returns its slug.
+    /// If there are multiple, picks the first and prints a warning.
+    /// If there are none, returns an error.
+    pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
+        let orgs = self.fetch_organizations().await?;
+        match orgs.len() {
+            0 => Err(ApiError::Other(
+                "No organizations found for this API token.".into(),
+            )),
+            1 => Ok(orgs.into_iter().next().unwrap().slug),
+            _ => {
+                let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
+                let first = orgs[0].slug.clone();
+                eprintln!(
+                    "Multiple organizations found: {}. Using \"{}\". \
+                     Pass --org to select a different one.",
+                    slugs.join(", "),
+                    first
+                );
+                Ok(first)
+            }
+        }
+    }
+
+    /// Fetch a blob by its SHA-256 hash.
+    ///
+    /// Returns the raw binary content, or `Ok(None)` if not found.
+    /// Uses the authenticated endpoint when token and org slug are
+    /// available, otherwise falls back to the public proxy.
+    pub async fn fetch_blob(&self, hash: &str) -> Result<Option<Vec<u8>>, ApiError> {
+        // Validate hash format: SHA-256 = 64 hex characters
+        if !is_valid_sha256_hex(hash) {
+            return Err(ApiError::InvalidHash(format!(
+                "Invalid hash format: {}. 
Expected SHA256 hash (64 hex characters).", + hash + ))); + } + + let (url, use_auth) = + if self.api_token.is_some() && self.org_slug.is_some() && !self.use_public_proxy { + // Authenticated endpoint + let slug = self.org_slug.as_deref().unwrap(); + let u = format!("{}/v0/orgs/{}/patches/blob/{}", self.api_url, slug, hash); + (u, true) + } else { + // Public proxy + let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL") + .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string()); + let u = format!("{}/patch/blob/{}", proxy_url.trim_end_matches('/'), hash); + (u, false) + }; + + debug_log(&format!("GET blob {}", url)); + + // Build the request. When fetching from the public proxy (different + // base URL than self.api_url), we use a plain client without auth + // headers to avoid leaking credentials to the proxy. + let resp = if use_auth { + self.client + .get(&url) + .header(header::ACCEPT, "application/octet-stream") + .send() + .await + } else { + let mut headers = HeaderMap::new(); + headers.insert( + header::USER_AGENT, + HeaderValue::from_static(USER_AGENT_VALUE), + ); + headers.insert( + header::ACCEPT, + HeaderValue::from_static("application/octet-stream"), + ); + + let plain_client = reqwest::Client::builder() + .default_headers(headers) + .build() + .expect("failed to build plain reqwest client"); + + plain_client.get(&url).send().await + }; + + let resp = resp.map_err(|e| { + ApiError::Network(format!("Network error fetching blob {}: {}", hash, e)) + })?; + + let status = resp.status(); + + match status { + StatusCode::OK => { + let bytes = resp.bytes().await.map_err(|e| { + ApiError::Network(format!("Error reading blob body for {}: {}", hash, e)) + })?; + Ok(Some(bytes.to_vec())) + } + StatusCode::NOT_FOUND => Ok(None), + _ => { + let text = resp.text().await.unwrap_or_default(); + Err(ApiError::Other(format!( + "Failed to fetch blob {}: status {} - {}", + hash, + status.as_u16(), + text, + ))) + } + } + } +} + +// ── Free functions ──────────────────────────────────────────────────── + +/// Get an API client configured from environment variables. +/// +/// If `SOCKET_API_TOKEN` is not set, the client will use the public patch +/// API proxy which provides free access to free-tier patches without +/// authentication. +/// +/// When `SOCKET_API_TOKEN` is set but no org slug is provided (neither via +/// argument nor `SOCKET_ORG_SLUG` env var), the function will attempt to +/// auto-resolve the org slug by querying `GET /v0/organizations`. +/// +/// # Environment variables +/// +/// | Variable | Purpose | +/// |---|---| +/// | `SOCKET_API_URL` | Override the API URL (default `https://api.socket.dev`) | +/// | `SOCKET_API_TOKEN` | API token for authenticated access | +/// | `SOCKET_PATCH_PROXY_URL` | Override the public proxy URL (default `https://patches-api.socket.dev`) | +/// | `SOCKET_ORG_SLUG` | Organization slug | +/// +/// Returns `(client, use_public_proxy)`. +pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) { + let api_token = std::env::var("SOCKET_API_TOKEN") + .ok() + .filter(|t| !t.is_empty()); + let resolved_org_slug = org_slug + .map(String::from) + .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok()); + + if api_token.is_none() { + let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL") + .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string()); + eprintln!( + "No SOCKET_API_TOKEN set. Using public patch API proxy (free patches only)." 
+        );
+        let client = ApiClient::new(ApiClientOptions {
+            api_url: proxy_url,
+            api_token: None,
+            use_public_proxy: true,
+            org_slug: None,
+        });
+        return (client, true);
+    }
+
+    let api_url =
+        std::env::var("SOCKET_API_URL").unwrap_or_else(|_| DEFAULT_SOCKET_API_URL.to_string());
+
+    // Auto-resolve org slug if not provided
+    let final_org_slug = if resolved_org_slug.is_some() {
+        resolved_org_slug
+    } else {
+        let temp_client = ApiClient::new(ApiClientOptions {
+            api_url: api_url.clone(),
+            api_token: api_token.clone(),
+            use_public_proxy: false,
+            org_slug: None,
+        });
+        match temp_client.resolve_org_slug().await {
+            Ok(slug) => Some(slug),
+            Err(e) => {
+                eprintln!("Warning: Could not auto-detect organization: {e}");
+                None
+            }
+        }
+    };
+
+    let client = ApiClient::new(ApiClientOptions {
+        api_url,
+        api_token,
+        use_public_proxy: false,
+        org_slug: final_org_slug,
+    });
+    (client, false)
+}
+
+// ── Helpers ─────────────────────────────────────────────────────────
+
+/// Percent-encode a string for use in URL path segments.
+fn urlencoding_encode(input: &str) -> String {
+    // Encode everything that is not unreserved per RFC 3986.
+    let mut out = String::with_capacity(input.len());
+    for byte in input.bytes() {
+        match byte {
+            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
+                out.push(byte as char)
+            }
+            _ => {
+                out.push('%');
+                out.push_str(&format!("{:02X}", byte));
+            }
+        }
+    }
+    out
+}
+
+/// Truncate a string to at most `max_chars` characters, appending "..." if truncated.
+/// Unlike byte slicing (`&s[..n]`), this is safe for multi-byte UTF-8 characters.
+fn truncate_to_chars(s: &str, max_chars: usize) -> String {
+    if s.chars().count() <= max_chars {
+        return s.to_string();
+    }
+    let truncated: String = s.chars().take(max_chars).collect();
+    format!("{}...", truncated)
+}
+
+/// Validate that a string is a 64-character hex string (SHA-256).
+fn is_valid_sha256_hex(s: &str) -> bool {
+    s.len() == 64 && s.bytes().all(|b| b.is_ascii_hexdigit())
+}
+
+/// Convert a `PatchSearchResult` into a `BatchPatchInfo`, extracting
+/// CVE/GHSA IDs and computing the highest severity.
+fn convert_search_result_to_batch_info(patch: PatchSearchResult) -> BatchPatchInfo {
+    let mut cve_ids: Vec<String> = Vec::new();
+    let mut ghsa_ids: Vec<String> = Vec::new();
+    let mut highest_severity: Option<String> = None;
+    let mut title = String::new();
+
+    let mut seen_cves: HashSet<String> = HashSet::new();
+
+    for (ghsa_id, vuln) in &patch.vulnerabilities {
+        ghsa_ids.push(ghsa_id.clone());
+
+        for cve in &vuln.cves {
+            if seen_cves.insert(cve.clone()) {
+                cve_ids.push(cve.clone());
+            }
+        }
+
+        // Track highest severity (lower order number = higher severity)
+        let current_order = get_severity_order(highest_severity.as_deref());
+        let vuln_order = get_severity_order(Some(&vuln.severity));
+        if vuln_order < current_order {
+            highest_severity = Some(vuln.severity.clone());
+        }
+
+        // Use first non-empty summary as title
+        if title.is_empty() && !vuln.summary.is_empty() {
+            title = truncate_to_chars(&vuln.summary, 97);
+        }
+    }
+
+    // Use description as fallback title
+    if title.is_empty() && !patch.description.is_empty() {
+        title = truncate_to_chars(&patch.description, 97);
+    }
+
+    cve_ids.sort();
+    ghsa_ids.sort();
+
+    BatchPatchInfo {
+        uuid: patch.uuid,
+        purl: patch.purl,
+        tier: patch.tier,
+        cve_ids,
+        ghsa_ids,
+        severity: highest_severity,
+        title,
+    }
+}
+
+// ── Error type ───────────────────────────────────────────────────────
+
+/// Errors returned by [`ApiClient`] methods.
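+///
+/// Every variant wraps a human-readable message; see `handle_json_response`
+/// and `fetch_blob` for how HTTP statuses map onto these variants.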
+#[derive(Debug, thiserror::Error)]
+pub enum ApiError {
+    #[error("{0}")]
+    Network(String),
+
+    #[error("{0}")]
+    Parse(String),
+
+    #[error("{0}")]
+    Unauthorized(String),
+
+    #[error("{0}")]
+    Forbidden(String),
+
+    #[error("{0}")]
+    RateLimited(String),
+
+    #[error("{0}")]
+    InvalidHash(String),
+
+    #[error("{0}")]
+    Other(String),
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::collections::HashMap;
+
+    #[test]
+    fn test_urlencoding_basic() {
+        assert_eq!(urlencoding_encode("hello"), "hello");
+        assert_eq!(urlencoding_encode("a b"), "a%20b");
+        assert_eq!(
+            urlencoding_encode("pkg:npm/lodash@4.17.21"),
+            "pkg%3Anpm%2Flodash%404.17.21"
+        );
+    }
+
+    #[test]
+    fn test_is_valid_sha256_hex() {
+        let valid = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
+        assert!(is_valid_sha256_hex(valid));
+
+        // Too short
+        assert!(!is_valid_sha256_hex("abcdef"));
+        // Non-hex
+        assert!(!is_valid_sha256_hex(
+            "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
+        ));
+    }
+
+    #[test]
+    fn test_severity_order() {
+        assert!(get_severity_order(Some("critical")) < get_severity_order(Some("high")));
+        assert!(get_severity_order(Some("high")) < get_severity_order(Some("medium")));
+        assert!(get_severity_order(Some("medium")) < get_severity_order(Some("low")));
+        assert!(get_severity_order(Some("low")) < get_severity_order(None));
+        assert_eq!(get_severity_order(Some("unknown")), get_severity_order(None));
+    }
+
+    #[test]
+    fn test_convert_search_result_to_batch_info() {
+        let mut vulns = HashMap::new();
+        vulns.insert(
+            "GHSA-1234-5678-9abc".to_string(),
+            VulnerabilityResponse {
+                cves: vec!["CVE-2024-0001".into()],
+                summary: "Test vulnerability".into(),
+                severity: "high".into(),
+                description: "A test vuln".into(),
+            },
+        );
+
+        let patch = PatchSearchResult {
+            uuid: "uuid-1".into(),
+            purl: "pkg:npm/test@1.0.0".into(),
+            published_at: "2024-01-01".into(),
+            description: "A patch".into(),
+            license: "MIT".into(),
+            tier: "free".into(),
+            vulnerabilities: vulns,
+        };
+
+        let info = convert_search_result_to_batch_info(patch);
+        assert_eq!(info.uuid, "uuid-1");
+        assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
+        assert_eq!(info.ghsa_ids, vec!["GHSA-1234-5678-9abc"]);
+        assert_eq!(info.severity, Some("high".into()));
+        assert_eq!(info.title, "Test vulnerability");
+    }
+
+    #[tokio::test]
+    async fn test_get_api_client_from_env_no_token() {
+        // Clear token to ensure public proxy mode
+        std::env::remove_var("SOCKET_API_TOKEN");
+        let (client, is_public) = get_api_client_from_env(None).await;
+        assert!(is_public);
+        assert!(client.use_public_proxy);
+    }
+
+    // ── Group 6: convert_search_result_to_batch_info edge cases ──────
+
+    fn make_vuln(summary: &str, severity: &str, cves: Vec<&str>) -> VulnerabilityResponse {
+        VulnerabilityResponse {
+            cves: cves.into_iter().map(String::from).collect(),
+            summary: summary.into(),
+            severity: severity.into(),
+            description: "desc".into(),
+        }
+    }
+
+    fn make_patch(
+        vulns: HashMap<String, VulnerabilityResponse>,
+        description: &str,
+    ) -> PatchSearchResult {
+        PatchSearchResult {
+            uuid: "uuid-1".into(),
+            purl: "pkg:npm/test@1.0.0".into(),
+            published_at: "2024-01-01".into(),
+            description: description.into(),
+            license: "MIT".into(),
+            tier: "free".into(),
+            vulnerabilities: vulns,
+        }
+    }
+
+    #[test]
+    fn test_convert_no_vulnerabilities() {
+        let patch = make_patch(HashMap::new(), "A patch description");
+        let info = convert_search_result_to_batch_info(patch);
+        assert!(info.cve_ids.is_empty());
+        assert!(info.ghsa_ids.is_empty());
+        
assert_eq!(info.title, "A patch description"); + assert!(info.severity.is_none()); + } + + #[test] + fn test_convert_multiple_vulns_picks_highest_severity() { + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-1111".into(), + make_vuln("Medium vuln", "medium", vec!["CVE-2024-0001"]), + ); + vulns.insert( + "GHSA-2222".into(), + make_vuln("Critical vuln", "critical", vec!["CVE-2024-0002"]), + ); + let patch = make_patch(vulns, "desc"); + let info = convert_search_result_to_batch_info(patch); + assert_eq!(info.severity, Some("critical".into())); + } + + #[test] + fn test_convert_duplicate_cves_deduplicated() { + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-1111".into(), + make_vuln("Vuln A", "high", vec!["CVE-2024-0001"]), + ); + vulns.insert( + "GHSA-2222".into(), + make_vuln("Vuln B", "high", vec!["CVE-2024-0001"]), + ); + let patch = make_patch(vulns, "desc"); + let info = convert_search_result_to_batch_info(patch); + // Same CVE in both vulns should only appear once + let cve_count = info.cve_ids.iter().filter(|c| *c == "CVE-2024-0001").count(); + assert_eq!(cve_count, 1); + } + + #[test] + fn test_convert_title_truncated_at_100() { + let long_summary = "x".repeat(150); + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-1111".into(), + make_vuln(&long_summary, "high", vec![]), + ); + let patch = make_patch(vulns, "desc"); + let info = convert_search_result_to_batch_info(patch); + // Should be 97 chars + "..." = 100 chars + assert_eq!(info.title.len(), 100); + assert!(info.title.ends_with("...")); + } + + #[test] + fn test_convert_title_unicode_truncation() { + // Create a summary with multi-byte chars that would panic with byte slicing + // Each emoji is 4 bytes, so 30 emojis = 120 bytes but only 30 chars + let emoji_summary = "\u{1F600}".repeat(30); + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-1111".into(), + make_vuln(&emoji_summary, "high", vec![]), + ); + let patch = make_patch(vulns, "desc"); + // This should NOT panic (validates the UTF-8 truncation fix) + let info = convert_search_result_to_batch_info(patch); + assert!(!info.title.is_empty()); + + // Also test with description fallback + let patch2 = make_patch(HashMap::new(), &"\u{1F600}".repeat(120)); + let info2 = convert_search_result_to_batch_info(patch2); + assert!(info2.title.ends_with("...")); + } + + #[test] + fn test_convert_title_falls_back_to_description() { + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-1111".into(), + make_vuln("", "high", vec![]), + ); + let patch = make_patch(vulns, "Fallback desc"); + let info = convert_search_result_to_batch_info(patch); + assert_eq!(info.title, "Fallback desc"); + } + + #[test] + fn test_convert_empty_summary_and_description() { + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-1111".into(), + make_vuln("", "high", vec![]), + ); + let patch = make_patch(vulns, ""); + let info = convert_search_result_to_batch_info(patch); + assert!(info.title.is_empty()); + } + + #[test] + fn test_convert_cves_and_ghsas_sorted() { + let mut vulns = HashMap::new(); + vulns.insert( + "GHSA-cccc".into(), + make_vuln("V1", "high", vec!["CVE-2024-0003"]), + ); + vulns.insert( + "GHSA-aaaa".into(), + make_vuln("V2", "high", vec!["CVE-2024-0001"]), + ); + vulns.insert( + "GHSA-bbbb".into(), + make_vuln("V3", "high", vec!["CVE-2024-0002"]), + ); + let patch = make_patch(vulns, "desc"); + let info = convert_search_result_to_batch_info(patch); + // Both should be sorted alphabetically + let mut sorted_cves = info.cve_ids.clone(); + 
sorted_cves.sort();
+        assert_eq!(info.cve_ids, sorted_cves);
+        let mut sorted_ghsas = info.ghsa_ids.clone();
+        sorted_ghsas.sort();
+        assert_eq!(info.ghsa_ids, sorted_ghsas);
+    }
+
+    // ── Group 7: urlencoding + SHA256 edge cases ─────────────────────
+
+    #[test]
+    fn test_urlencoding_unicode() {
+        // Multi-byte UTF-8: 'é' = 0xC3 0xA9
+        let encoded = urlencoding_encode("café");
+        assert_eq!(encoded, "caf%C3%A9");
+    }
+
+    #[test]
+    fn test_urlencoding_empty() {
+        assert_eq!(urlencoding_encode(""), "");
+    }
+
+    #[test]
+    fn test_urlencoding_all_safe_chars() {
+        // Unreserved chars should pass through
+        let safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
+        assert_eq!(urlencoding_encode(safe), safe);
+    }
+
+    #[test]
+    fn test_urlencoding_slash_and_at() {
+        assert_eq!(urlencoding_encode("/"), "%2F");
+        assert_eq!(urlencoding_encode("@"), "%40");
+    }
+
+    #[test]
+    fn test_sha256_uppercase_valid() {
+        let upper = "ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789";
+        assert!(is_valid_sha256_hex(upper));
+    }
+
+    #[test]
+    fn test_sha256_65_chars_invalid() {
+        let too_long = "a".repeat(65);
+        assert!(!is_valid_sha256_hex(&too_long));
+    }
+
+    #[test]
+    fn test_sha256_63_chars_invalid() {
+        let too_short = "a".repeat(63);
+        assert!(!is_valid_sha256_hex(&too_short));
+    }
+
+    #[test]
+    fn test_sha256_empty_invalid() {
+        assert!(!is_valid_sha256_hex(""));
+    }
+
+    #[test]
+    fn test_sha256_mixed_case_valid() {
+        let mixed = "aAbBcCdDeEfF0123456789aAbBcCdDeEfF0123456789aAbBcCdDeEfF01234567";
+        assert_eq!(mixed.len(), 64);
+        assert!(is_valid_sha256_hex(mixed));
+    }
+}
diff --git a/crates/socket-patch-core/src/api/mod.rs b/crates/socket-patch-core/src/api/mod.rs
new file mode 100644
index 0000000..a0a9feb
--- /dev/null
+++ b/crates/socket-patch-core/src/api/mod.rs
@@ -0,0 +1,6 @@
+pub mod blob_fetcher;
+pub mod client;
+pub mod types;
+
+pub use client::ApiClient;
+pub use types::*;
diff --git a/crates/socket-patch-core/src/api/types.rs b/crates/socket-patch-core/src/api/types.rs
new file mode 100644
index 0000000..f09c31d
--- /dev/null
+++ b/crates/socket-patch-core/src/api/types.rs
@@ -0,0 +1,247 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+/// Organization info returned by the `/v0/organizations` endpoint.
+#[derive(Debug, Clone, Deserialize)]
+pub struct OrganizationInfo {
+    pub id: String,
+    pub name: Option<String>,
+    pub image: Option<String>,
+    pub plan: String,
+    pub slug: String,
+}
+
+/// Response from `GET /v0/organizations`.
+#[derive(Debug, Clone, Deserialize)]
+pub struct OrganizationsResponse {
+    pub organizations: HashMap<String, OrganizationInfo>,
+}
+
+/// Full patch response with blob content (from view endpoint).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PatchResponse {
+    pub uuid: String,
+    pub purl: String,
+    pub published_at: String,
+    pub files: HashMap<String, PatchFileResponse>,
+    pub vulnerabilities: HashMap<String, VulnerabilityResponse>,
+    pub description: String,
+    pub license: String,
+    pub tier: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PatchFileResponse {
+    pub before_hash: Option<String>,
+    pub after_hash: Option<String>,
+    pub socket_blob: Option<String>,
+    pub blob_content: Option<String>,
+    pub before_blob_content: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct VulnerabilityResponse {
+    pub cves: Vec<String>,
+    pub summary: String,
+    pub severity: String,
+    pub description: String,
+}
+
+/// Lightweight search result.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PatchSearchResult {
+    pub uuid: String,
+    pub purl: String,
+    pub published_at: String,
+    pub description: String,
+    pub license: String,
+    pub tier: String,
+    pub vulnerabilities: HashMap<String, VulnerabilityResponse>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SearchResponse {
+    pub patches: Vec<PatchSearchResult>,
+    pub can_access_paid_patches: bool,
+}
+
+/// Minimal patch info from batch search.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct BatchPatchInfo {
+    pub uuid: String,
+    pub purl: String,
+    pub tier: String,
+    pub cve_ids: Vec<String>,
+    pub ghsa_ids: Vec<String>,
+    pub severity: Option<String>,
+    pub title: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct BatchPackagePatches {
+    pub purl: String,
+    pub patches: Vec<BatchPatchInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct BatchSearchResponse {
+    pub packages: Vec<BatchPackagePatches>,
+    pub can_access_paid_patches: bool,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_patch_response_camel_case() {
+        let pr = PatchResponse {
+            uuid: "u1".into(),
+            purl: "pkg:npm/x@1".into(),
+            published_at: "2024-01-01".into(),
+            files: HashMap::new(),
+            vulnerabilities: HashMap::new(),
+            description: "desc".into(),
+            license: "MIT".into(),
+            tier: "free".into(),
+        };
+        let json = serde_json::to_string(&pr).unwrap();
+        assert!(json.contains("publishedAt"));
+        assert!(!json.contains("published_at"));
+    }
+
+    #[test]
+    fn test_patch_response_deserialize() {
+        let json = r#"{
+            "uuid": "u1",
+            "purl": "pkg:npm/x@1",
+            "publishedAt": "2024-01-01",
+            "files": {},
+            "vulnerabilities": {},
+            "description": "A patch",
+            "license": "MIT",
+            "tier": "free"
+        }"#;
+        let pr: PatchResponse = serde_json::from_str(json).unwrap();
+        assert_eq!(pr.uuid, "u1");
+        assert_eq!(pr.published_at, "2024-01-01");
+    }
+
+    #[test]
+    fn test_patch_file_response_optional_fields() {
+        let pfr = PatchFileResponse {
+            before_hash: None,
+            after_hash: None,
+            socket_blob: None,
+            blob_content: None,
+            before_blob_content: None,
+        };
+        let json = serde_json::to_string(&pfr).unwrap();
+        let back: PatchFileResponse = serde_json::from_str(&json).unwrap();
+        assert!(back.before_hash.is_none());
+        assert!(back.after_hash.is_none());
+        assert!(back.socket_blob.is_none());
+        assert!(back.blob_content.is_none());
+        assert!(back.before_blob_content.is_none());
+        // Verify camelCase field names
+        assert!(json.contains("beforeHash"));
+        assert!(json.contains("afterHash"));
+        assert!(json.contains("socketBlob"));
+        assert!(json.contains("blobContent"));
+        assert!(json.contains("beforeBlobContent"));
+    }
+
+    #[test]
+    fn test_search_response_camel_case() {
+        let sr = SearchResponse {
+            patches: Vec::new(),
+            can_access_paid_patches: true,
+        };
+        let json = serde_json::to_string(&sr).unwrap();
+        assert!(json.contains("canAccessPaidPatches"));
+        assert!(!json.contains("can_access_paid_patches"));
+    }
+
+    #[test]
+    fn test_batch_search_response_roundtrip() {
+        let bsr = BatchSearchResponse {
+            packages: vec![BatchPackagePatches {
+                purl: "pkg:npm/x@1".into(),
+                patches: vec![BatchPatchInfo {
+                    uuid: "u1".into(),
+                    purl: "pkg:npm/x@1".into(),
+                    tier: "free".into(),
+                    cve_ids: vec!["CVE-2024-0001".into()],
+                    ghsa_ids: vec!["GHSA-1111-2222-3333".into()],
+                    severity: Some("high".into()),
+                    title: "Test".into(),
+                }],
+            }],
+            can_access_paid_patches: false,
+        };
+        let json =
serde_json::to_string(&bsr).unwrap(); + let back: BatchSearchResponse = serde_json::from_str(&json).unwrap(); + assert_eq!(back.packages.len(), 1); + assert_eq!(back.packages[0].patches.len(), 1); + assert!(!back.can_access_paid_patches); + } + + #[test] + fn test_batch_patch_info_camel_case() { + let bpi = BatchPatchInfo { + uuid: "u1".into(), + purl: "pkg:npm/x@1".into(), + tier: "free".into(), + cve_ids: vec!["CVE-2024-0001".into()], + ghsa_ids: vec!["GHSA-1111-2222-3333".into()], + severity: Some("high".into()), + title: "Test".into(), + }; + let json = serde_json::to_string(&bpi).unwrap(); + assert!(json.contains("cveIds")); + assert!(json.contains("ghsaIds")); + assert!(!json.contains("cve_ids")); + assert!(!json.contains("ghsa_ids")); + } + + #[test] + fn test_vulnerability_response_no_rename() { + // VulnerabilityResponse does NOT have rename_all, so fields are snake_case + let vr = VulnerabilityResponse { + cves: vec!["CVE-2024-0001".into()], + summary: "Test".into(), + severity: "high".into(), + description: "A vulnerability".into(), + }; + let json = serde_json::to_string(&vr).unwrap(); + // Without rename_all, field names stay as-is (already lowercase single-word) + assert!(json.contains("\"cves\"")); + assert!(json.contains("\"summary\"")); + assert!(json.contains("\"severity\"")); + assert!(json.contains("\"description\"")); + } + + #[test] + fn test_patch_search_result_roundtrip() { + let psr = PatchSearchResult { + uuid: "u1".into(), + purl: "pkg:npm/test@1.0.0".into(), + published_at: "2024-06-15".into(), + description: "A test patch".into(), + license: "MIT".into(), + tier: "free".into(), + vulnerabilities: HashMap::new(), + }; + let json = serde_json::to_string(&psr).unwrap(); + let back: PatchSearchResult = serde_json::from_str(&json).unwrap(); + assert_eq!(back.uuid, "u1"); + assert_eq!(back.published_at, "2024-06-15"); + assert!(json.contains("publishedAt")); + } +} diff --git a/crates/socket-patch-core/src/constants.rs b/crates/socket-patch-core/src/constants.rs new file mode 100644 index 0000000..1418427 --- /dev/null +++ b/crates/socket-patch-core/src/constants.rs @@ -0,0 +1,17 @@ +/// Default path for the patch manifest file relative to the project root. +pub const DEFAULT_PATCH_MANIFEST_PATH: &str = ".socket/manifest.json"; + +/// Default folder for storing patched file blobs. +pub const DEFAULT_BLOB_FOLDER: &str = ".socket/blob"; + +/// Default Socket directory. +pub const DEFAULT_SOCKET_DIR: &str = ".socket"; + +/// Default public patch API URL for free patches (no auth required). +pub const DEFAULT_PATCH_API_PROXY_URL: &str = "https://patches-api.socket.dev"; + +/// Default Socket API URL for authenticated access. +pub const DEFAULT_SOCKET_API_URL: &str = "https://api.socket.dev"; + +/// User-Agent header value for API requests. +pub const USER_AGENT: &str = "SocketPatchCLI/1.0"; diff --git a/crates/socket-patch-core/src/crawlers/cargo_crawler.rs b/crates/socket-patch-core/src/crawlers/cargo_crawler.rs new file mode 100644 index 0000000..05bdfa1 --- /dev/null +++ b/crates/socket-patch-core/src/crawlers/cargo_crawler.rs @@ -0,0 +1,654 @@ +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; + +use super::types::{CrawledPackage, CrawlerOptions}; + +// --------------------------------------------------------------------------- +// Cargo.toml minimal parser +// --------------------------------------------------------------------------- + +/// Parse `name` and `version` from a `Cargo.toml` `[package]` section. 
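+///
+/// For example (mirroring the unit tests below), given:
+/// ```toml
+/// [package]
+/// name = "serde"
+/// version = "1.0.200"
+/// ```
+/// this returns `Some(("serde", "1.0.200"))`.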
+///
+/// Uses a simple line-based parser — no TOML crate dependency.
+/// Handles `name = "..."` and `version = "..."` within the `[package]` table.
+/// Returns `None` if `version.workspace = true` or fields are missing.
+pub fn parse_cargo_toml_name_version(content: &str) -> Option<(String, String)> {
+    let mut in_package = false;
+    let mut name: Option<String> = None;
+    let mut version: Option<String> = None;
+
+    for line in content.lines() {
+        let trimmed = line.trim();
+
+        // Skip comments and empty lines
+        if trimmed.starts_with('#') || trimmed.is_empty() {
+            continue;
+        }
+
+        // Track table headers
+        if trimmed.starts_with('[') {
+            if trimmed == "[package]" {
+                in_package = true;
+            } else {
+                // We left the [package] section
+                if in_package {
+                    break;
+                }
+            }
+            continue;
+        }
+
+        if !in_package {
+            continue;
+        }
+
+        if let Some(val) = extract_string_value(trimmed, "name") {
+            name = Some(val);
+        } else if let Some(val) = extract_string_value(trimmed, "version") {
+            version = Some(val);
+        } else if trimmed.starts_with("version") && trimmed.contains("workspace") {
+            // version.workspace = true — cannot determine version from this file
+            return None;
+        }
+
+        if name.is_some() && version.is_some() {
+            break;
+        }
+    }
+
+    match (name, version) {
+        (Some(n), Some(v)) if !n.is_empty() && !v.is_empty() => Some((n, v)),
+        _ => None,
+    }
+}
+
+/// Extract a quoted string value from a `key = "value"` line.
+fn extract_string_value(line: &str, key: &str) -> Option<String> {
+    let rest = line.strip_prefix(key)?;
+    let rest = rest.trim_start();
+    let rest = rest.strip_prefix('=')?;
+    let rest = rest.trim_start();
+    let rest = rest.strip_prefix('"')?;
+    let end = rest.find('"')?;
+    Some(rest[..end].to_string())
+}
+
+// ---------------------------------------------------------------------------
+// CargoCrawler
+// ---------------------------------------------------------------------------
+
+/// Cargo/Rust ecosystem crawler for discovering crates in the local
+/// vendor directory or the Cargo registry cache (`$CARGO_HOME/registry/src/`).
+pub struct CargoCrawler;
+
+impl CargoCrawler {
+    /// Create a new `CargoCrawler`.
+    pub fn new() -> Self {
+        Self
+    }
+
+    // ------------------------------------------------------------------
+    // Public API
+    // ------------------------------------------------------------------
+
+    /// Get crate source paths based on options.
+    ///
+    /// In local mode, checks `<cwd>/vendor/` first, then falls back to
+    /// `$CARGO_HOME/registry/src/` index directories — but only if the
+    /// `cwd` actually contains a `Cargo.toml` or `Cargo.lock` (i.e. is a
+    /// Rust project). This prevents scanning the global cargo registry
+    /// when patching a non-Rust project.
+    ///
+    /// In global mode, returns `$CARGO_HOME/registry/src/` index directories
+    /// (or the `--global-prefix` override).
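+    ///
+    /// A usage sketch (illustrative; the `use` paths are assumed from the
+    /// file layout and are not part of this patch):
+    /// ```ignore
+    /// use socket_patch_core::crawlers::cargo_crawler::CargoCrawler;
+    /// use socket_patch_core::crawlers::types::CrawlerOptions;
+    ///
+    /// let crawler = CargoCrawler::new();
+    /// let options = CrawlerOptions {
+    ///     cwd: std::path::PathBuf::from("."),
+    ///     global: false,
+    ///     global_prefix: None,
+    ///     batch_size: 100,
+    /// };
+    /// // `./vendor` wins if it exists; otherwise the registry cache is used,
+    /// // but only when `Cargo.toml` or `Cargo.lock` is present in `cwd`.
+    /// let paths = crawler.get_crate_source_paths(&options).await?;
+    /// ```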
+ pub async fn get_crate_source_paths( + &self, + options: &CrawlerOptions, + ) -> Result, std::io::Error> { + if options.global || options.global_prefix.is_some() { + if let Some(ref custom) = options.global_prefix { + return Ok(vec![custom.clone()]); + } + return Ok(Self::get_registry_src_paths().await); + } + + // Local mode: check vendor first + let vendor_dir = options.cwd.join("vendor"); + if is_dir(&vendor_dir).await { + return Ok(vec![vendor_dir]); + } + + // Only fall back to global registry if this looks like a Cargo project + let has_cargo_toml = tokio::fs::metadata(options.cwd.join("Cargo.toml")) + .await + .is_ok(); + let has_cargo_lock = tokio::fs::metadata(options.cwd.join("Cargo.lock")) + .await + .is_ok(); + + if has_cargo_toml || has_cargo_lock { + return Ok(Self::get_registry_src_paths().await); + } + + // Not a Cargo project — return empty + Ok(Vec::new()) + } + + /// Crawl all discovered crate source directories and return every + /// package found. + pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec { + let mut packages = Vec::new(); + let mut seen = HashSet::new(); + + let src_paths = self.get_crate_source_paths(options).await.unwrap_or_default(); + + for src_path in &src_paths { + let found = self.scan_crate_source(src_path, &mut seen).await; + packages.extend(found); + } + + packages + } + + /// Find specific packages by PURL inside a single crate source directory. + /// + /// Supports two layouts: + /// - **Registry**: `-/Cargo.toml` + /// - **Vendor**: `/Cargo.toml` (version verified from file contents) + pub async fn find_by_purls( + &self, + src_path: &Path, + purls: &[String], + ) -> Result, std::io::Error> { + let mut result: HashMap = HashMap::new(); + + for purl in purls { + if let Some((name, version)) = crate::utils::purl::parse_cargo_purl(purl) { + // Try registry layout: -/ + let registry_dir = src_path.join(format!("{name}-{version}")); + if self + .verify_crate_at_path(®istry_dir, name, version) + .await + { + result.insert( + purl.clone(), + CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: None, + purl: purl.clone(), + path: registry_dir, + }, + ); + continue; + } + + // Try vendor layout: / + let vendor_dir = src_path.join(name); + if self + .verify_crate_at_path(&vendor_dir, name, version) + .await + { + result.insert( + purl.clone(), + CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: None, + purl: purl.clone(), + path: vendor_dir, + }, + ); + } + } + } + + Ok(result) + } + + // ------------------------------------------------------------------ + // Private helpers + // ------------------------------------------------------------------ + + /// List subdirectories of `$CARGO_HOME/registry/src/`. + /// + /// Each subdirectory corresponds to a registry index + /// (e.g. `index.crates.io-6f17d22bba15001f/`). 
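+    ///
+    /// Crates unpack one level below each index, so a dependency ends up at a
+    /// path like `registry/src/index.crates.io-6f17d22bba15001f/serde-1.0.200/`
+    /// (crate and version illustrative).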
+ async fn get_registry_src_paths() -> Vec { + let cargo_home = Self::cargo_home(); + let registry_src = cargo_home.join("registry").join("src"); + + let mut paths = Vec::new(); + + let mut entries = match tokio::fs::read_dir(®istry_src).await { + Ok(rd) => rd, + Err(_) => return paths, + }; + + while let Ok(Some(entry)) = entries.next_entry().await { + let ft = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + if ft.is_dir() { + paths.push(registry_src.join(entry.file_name())); + } + } + + paths + } + + /// Scan a crate source directory (either a registry index directory or + /// a vendor directory) and return all valid crate packages found. + async fn scan_crate_source( + &self, + src_path: &Path, + seen: &mut HashSet, + ) -> Vec { + let mut results = Vec::new(); + + let mut entries = match tokio::fs::read_dir(src_path).await { + Ok(rd) => rd, + Err(_) => return results, + }; + + let mut entry_list = Vec::new(); + while let Ok(Some(entry)) = entries.next_entry().await { + entry_list.push(entry); + } + + for entry in entry_list { + let ft = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + if !ft.is_dir() { + continue; + } + + let dir_name = entry.file_name(); + let dir_name_str = dir_name.to_string_lossy(); + + // Skip hidden directories + if dir_name_str.starts_with('.') { + continue; + } + + let crate_path = src_path.join(&*dir_name_str); + if let Some(pkg) = + self.read_crate_cargo_toml(&crate_path, &dir_name_str, seen).await + { + results.push(pkg); + } + } + + results + } + + /// Read `Cargo.toml` from a crate directory, returning a `CrawledPackage` + /// if valid. Falls back to parsing name+version from the directory name + /// when the Cargo.toml has `version.workspace = true`. + async fn read_crate_cargo_toml( + &self, + crate_path: &Path, + dir_name: &str, + seen: &mut HashSet, + ) -> Option { + let cargo_toml_path = crate_path.join("Cargo.toml"); + let content = tokio::fs::read_to_string(&cargo_toml_path).await.ok()?; + + let (name, version) = match parse_cargo_toml_name_version(&content) { + Some(nv) => nv, + None => { + // Fallback: parse directory name as - + Self::parse_dir_name_version(dir_name)? + } + }; + + let purl = crate::utils::purl::build_cargo_purl(&name, &version); + + if seen.contains(&purl) { + return None; + } + seen.insert(purl.clone()); + + Some(CrawledPackage { + name, + version, + namespace: None, + purl, + path: crate_path.to_path_buf(), + }) + } + + /// Verify that a crate directory contains a Cargo.toml with the expected + /// name and version. + async fn verify_crate_at_path(&self, path: &Path, name: &str, version: &str) -> bool { + let cargo_toml_path = path.join("Cargo.toml"); + let content = match tokio::fs::read_to_string(&cargo_toml_path).await { + Ok(c) => c, + Err(_) => return false, + }; + + match parse_cargo_toml_name_version(&content) { + Some((n, v)) => n == name && v == version, + None => { + // Fallback: check directory name + let dir_name = path + .file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_default(); + if let Some((parsed_name, parsed_version)) = + Self::parse_dir_name_version(&dir_name) + { + parsed_name == name && parsed_version == version + } else { + false + } + } + } + } + + /// Parse a registry directory name into (name, version). + /// + /// Registry directories follow the pattern `-`, + /// where the version is the last `-`-separated component that starts with + /// a digit (handles crate names with hyphens like `serde-json`). 
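+    ///
+    /// For example (mirroring the unit tests below):
+    /// ```text
+    /// serde-1.0.200       -> ("serde", "1.0.200")
+    /// serde-json-1.0.120  -> ("serde-json", "1.0.120")
+    /// no-version-here     -> None
+    /// ```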
+ fn parse_dir_name_version(dir_name: &str) -> Option<(String, String)> { + // Find the last '-' followed by a digit + let mut split_idx = None; + for (i, _) in dir_name.match_indices('-') { + if dir_name[i + 1..].starts_with(|c: char| c.is_ascii_digit()) { + split_idx = Some(i); + } + } + let idx = split_idx?; + let name = &dir_name[..idx]; + let version = &dir_name[idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name.to_string(), version.to_string())) + } + + /// Get `CARGO_HOME`, defaulting to `$HOME/.cargo`. + fn cargo_home() -> PathBuf { + if let Ok(cargo_home) = std::env::var("CARGO_HOME") { + return PathBuf::from(cargo_home); + } + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_else(|_| "~".to_string()); + PathBuf::from(home).join(".cargo") + } +} + +impl Default for CargoCrawler { + fn default() -> Self { + Self::new() + } +} + +/// Check whether a path is a directory. +async fn is_dir(path: &Path) -> bool { + tokio::fs::metadata(path) + .await + .map(|m| m.is_dir()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_cargo_toml_basic() { + let content = r#" +[package] +name = "serde" +version = "1.0.200" +edition = "2021" +"#; + let (name, version) = parse_cargo_toml_name_version(content).unwrap(); + assert_eq!(name, "serde"); + assert_eq!(version, "1.0.200"); + } + + #[test] + fn test_parse_cargo_toml_with_comments() { + let content = r#" +# This is a comment +[package] +name = "tokio" # inline comment ignored since we stop at first " +version = "1.38.0" +"#; + let (name, version) = parse_cargo_toml_name_version(content).unwrap(); + assert_eq!(name, "tokio"); + assert_eq!(version, "1.38.0"); + } + + #[test] + fn test_parse_cargo_toml_workspace_version() { + let content = r#" +[package] +name = "my-crate" +version.workspace = true +"#; + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_cargo_toml_missing_fields() { + let content = r#" +[package] +name = "incomplete" +"#; + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_cargo_toml_no_package_section() { + let content = r#" +[dependencies] +serde = "1.0" +"#; + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_cargo_toml_stops_at_next_section() { + let content = r#" +[package] +name = "foo" + +[dependencies] +version = "fake" +"#; + // Should not find version since it's under [dependencies] + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_dir_name_version() { + assert_eq!( + CargoCrawler::parse_dir_name_version("serde-1.0.200"), + Some(("serde".to_string(), "1.0.200".to_string())) + ); + assert_eq!( + CargoCrawler::parse_dir_name_version("serde-json-1.0.120"), + Some(("serde-json".to_string(), "1.0.120".to_string())) + ); + assert_eq!( + CargoCrawler::parse_dir_name_version("tokio-1.38.0"), + Some(("tokio".to_string(), "1.38.0".to_string())) + ); + assert!(CargoCrawler::parse_dir_name_version("no-version-here").is_none()); + assert!(CargoCrawler::parse_dir_name_version("noversion").is_none()); + } + + #[tokio::test] + async fn test_find_by_purls_registry_layout() { + let dir = tempfile::tempdir().unwrap(); + let serde_dir = dir.path().join("serde-1.0.200"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + 
.unwrap(); + + let crawler = CargoCrawler::new(); + let purls = vec![ + "pkg:cargo/serde@1.0.200".to_string(), + "pkg:cargo/tokio@1.38.0".to_string(), + ]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:cargo/serde@1.0.200")); + assert!(!result.contains_key("pkg:cargo/tokio@1.38.0")); + } + + #[tokio::test] + async fn test_find_by_purls_vendor_layout() { + let dir = tempfile::tempdir().unwrap(); + let serde_dir = dir.path().join("serde"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let purls = vec!["pkg:cargo/serde@1.0.200".to_string()]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:cargo/serde@1.0.200")); + } + + #[tokio::test] + async fn test_crawl_all_tempdir() { + let dir = tempfile::tempdir().unwrap(); + + // Create fake crate directories + let serde_dir = dir.path().join("serde-1.0.200"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let tokio_dir = dir.path().join("tokio-1.38.0"); + tokio::fs::create_dir_all(&tokio_dir).await.unwrap(); + tokio::fs::write( + tokio_dir.join("Cargo.toml"), + "[package]\nname = \"tokio\"\nversion = \"1.38.0\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 2); + + let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect(); + assert!(purls.contains("pkg:cargo/serde@1.0.200")); + assert!(purls.contains("pkg:cargo/tokio@1.38.0")); + } + + #[tokio::test] + async fn test_crawl_all_deduplication() { + let dir = tempfile::tempdir().unwrap(); + + // Create two directories that would resolve to the same PURL + let dir1 = dir.path().join("serde-1.0.200"); + tokio::fs::create_dir_all(&dir1).await.unwrap(); + tokio::fs::write( + dir1.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + // This would be found if we scan the parent twice + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].purl, "pkg:cargo/serde@1.0.200"); + } + + #[tokio::test] + async fn test_crawl_workspace_version_fallback() { + let dir = tempfile::tempdir().unwrap(); + + // Create a crate with workspace version — should fall back to dir name parsing + let crate_dir = dir.path().join("my-crate-0.5.0"); + tokio::fs::create_dir_all(&crate_dir).await.unwrap(); + tokio::fs::write( + crate_dir.join("Cargo.toml"), + "[package]\nname = \"my-crate\"\nversion.workspace = true\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = 
crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].purl, "pkg:cargo/my-crate@0.5.0"); + } + + #[tokio::test] + async fn test_vendor_layout_via_get_crate_source_paths() { + let dir = tempfile::tempdir().unwrap(); + let vendor = dir.path().join("vendor"); + tokio::fs::create_dir_all(&vendor).await.unwrap(); + + let serde_dir = vendor.join("serde"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: None, + batch_size: 100, + }; + + let paths = crawler.get_crate_source_paths(&options).await.unwrap(); + assert_eq!(paths.len(), 1); + assert_eq!(paths[0], vendor); + } +} diff --git a/crates/socket-patch-core/src/crawlers/composer_crawler.rs b/crates/socket-patch-core/src/crawlers/composer_crawler.rs new file mode 100644 index 0000000..a9b504e --- /dev/null +++ b/crates/socket-patch-core/src/crawlers/composer_crawler.rs @@ -0,0 +1,466 @@ +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; + +use serde::Deserialize; + +use super::types::{CrawledPackage, CrawlerOptions}; + +/// PHP/Composer ecosystem crawler for discovering packages in Composer +/// vendor directories. +pub struct ComposerCrawler; + +/// Composer 2 installed.json format: `{"packages": [...]}` +#[derive(Deserialize)] +struct InstalledJsonV2 { + packages: Vec, +} + +/// A single package entry from installed.json. +#[derive(Deserialize)] +struct ComposerPackageEntry { + name: String, + version: String, +} + +impl ComposerCrawler { + /// Create a new `ComposerCrawler`. + pub fn new() -> Self { + Self + } + + // ------------------------------------------------------------------ + // Public API + // ------------------------------------------------------------------ + + /// Get vendor paths based on options. + /// + /// In global mode, checks `$COMPOSER_HOME/vendor/` (env var, command + /// fallback, or platform defaults). + /// + /// In local mode, checks `/vendor/` but only if the directory + /// contains `composer/installed.json` and the cwd looks like a PHP + /// project (`composer.json` or `composer.lock` present). + pub async fn get_vendor_paths( + &self, + options: &CrawlerOptions, + ) -> Result, std::io::Error> { + if options.global || options.global_prefix.is_some() { + if let Some(ref custom) = options.global_prefix { + return Ok(vec![custom.clone()]); + } + return Ok(Self::get_global_vendor_paths().await); + } + + // Local mode + let vendor_dir = options.cwd.join("vendor"); + let installed_json = vendor_dir.join("composer").join("installed.json"); + + if !is_dir(&vendor_dir).await || !is_file(&installed_json).await { + return Ok(Vec::new()); + } + + // Only return if this looks like a PHP project + let has_composer_json = is_file(&options.cwd.join("composer.json")).await; + let has_composer_lock = is_file(&options.cwd.join("composer.lock")).await; + + if has_composer_json || has_composer_lock { + Ok(vec![vendor_dir]) + } else { + Ok(Vec::new()) + } + } + + /// Crawl all discovered vendor paths and return every package found. 
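+    ///
+    /// For example (mirroring the unit tests below), an `installed.json`
+    /// entry `{"name": "monolog/monolog", "version": "3.5.0"}` becomes a
+    /// `CrawledPackage` with purl `pkg:composer/monolog/monolog@3.5.0`
+    /// rooted at `vendor/monolog/monolog/`.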
+ pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec { + let mut packages = Vec::new(); + let mut seen = HashSet::new(); + + let vendor_paths = self.get_vendor_paths(options).await.unwrap_or_default(); + + for vendor_path in &vendor_paths { + let entries = read_installed_json(vendor_path).await; + for entry in entries { + if let Some((namespace, name)) = entry.name.split_once('/') { + let purl = + crate::utils::purl::build_composer_purl(namespace, name, &entry.version); + + if seen.contains(&purl) { + continue; + } + seen.insert(purl.clone()); + + let pkg_path = vendor_path.join(namespace).join(name); + + packages.push(CrawledPackage { + name: name.to_string(), + version: entry.version, + namespace: Some(namespace.to_string()), + purl, + path: pkg_path, + }); + } + } + } + + packages + } + + /// Find specific packages by PURL inside a single vendor directory. + pub async fn find_by_purls( + &self, + vendor_path: &Path, + purls: &[String], + ) -> Result, std::io::Error> { + let mut result: HashMap = HashMap::new(); + + // Build a name -> version lookup from installed.json + let entries = read_installed_json(vendor_path).await; + let installed: HashMap = entries + .into_iter() + .map(|e| (e.name, e.version)) + .collect(); + + for purl in purls { + if let Some(((namespace, name), version)) = + crate::utils::purl::parse_composer_purl(purl) + { + let full_name = format!("{namespace}/{name}"); + let pkg_dir = vendor_path.join(namespace).join(name); + + if !is_dir(&pkg_dir).await { + continue; + } + + // Verify version matches installed.json + if let Some(installed_version) = installed.get(&full_name) { + if installed_version == version { + result.insert( + purl.clone(), + CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: Some(namespace.to_string()), + purl: purl.clone(), + path: pkg_dir, + }, + ); + } + } + } + } + + Ok(result) + } + + // ------------------------------------------------------------------ + // Private helpers + // ------------------------------------------------------------------ + + /// Get global Composer vendor paths. + async fn get_global_vendor_paths() -> Vec { + let mut paths = Vec::new(); + + if let Some(composer_home) = get_composer_home().await { + let vendor_dir = composer_home.join("vendor"); + if is_dir(&vendor_dir).await { + paths.push(vendor_dir); + } + } + + paths + } +} + +impl Default for ComposerCrawler { + fn default() -> Self { + Self::new() + } +} + +/// Get the Composer home directory. +/// +/// Checks `$COMPOSER_HOME`, then runs `composer global config home`, +/// then falls back to platform defaults. 
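+///
+/// A sketch of the resolution order (paths illustrative, taken from the
+/// fallbacks below):
+/// ```text
+/// 1. $COMPOSER_HOME                (env var, if the directory exists)
+/// 2. composer global config home   (command output)
+/// 3. ~/.composer, then ~/.config/composer
+/// ```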
+async fn get_composer_home() -> Option<PathBuf> {
+    // Check env var first
+    if let Ok(home) = std::env::var("COMPOSER_HOME") {
+        let path = PathBuf::from(home);
+        if is_dir(&path).await {
+            return Some(path);
+        }
+    }
+
+    // Try `composer global config home`
+    if let Ok(output) = std::process::Command::new("composer")
+        .args(["global", "config", "home"])
+        .output()
+    {
+        if output.status.success() {
+            let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
+            if !stdout.is_empty() {
+                let path = PathBuf::from(&stdout);
+                if is_dir(&path).await {
+                    return Some(path);
+                }
+            }
+        }
+    }
+
+    // Platform defaults
+    let home_dir = std::env::var("HOME")
+        .or_else(|_| std::env::var("USERPROFILE"))
+        .ok()?;
+    let home = PathBuf::from(home_dir);
+
+    let candidates = [
+        home.join(".composer"),
+        home.join(".config").join("composer"),
+    ];
+
+    for candidate in &candidates {
+        if is_dir(candidate).await {
+            return Some(candidate.clone());
+        }
+    }
+
+    None
+}
+
+/// Read and parse `vendor/composer/installed.json`.
+///
+/// Supports both Composer 1 (flat JSON array) and Composer 2 (`{"packages": [...]}`) formats.
+async fn read_installed_json(vendor_path: &Path) -> Vec<ComposerPackageEntry> {
+    let installed_path = vendor_path.join("composer").join("installed.json");
+
+    let content = match tokio::fs::read_to_string(&installed_path).await {
+        Ok(c) => c,
+        Err(_) => return Vec::new(),
+    };
+
+    // Try Composer 2 format first (object with packages key)
+    if let Ok(v2) = serde_json::from_str::<InstalledJsonV2>(&content) {
+        return v2.packages;
+    }
+
+    // Fall back to Composer 1 format (flat array)
+    if let Ok(v1) = serde_json::from_str::<Vec<ComposerPackageEntry>>(&content) {
+        return v1;
+    }
+
+    Vec::new()
+}
+
+/// Check whether a path is a directory.
+async fn is_dir(path: &Path) -> bool {
+    tokio::fs::metadata(path)
+        .await
+        .map(|m| m.is_dir())
+        .unwrap_or(false)
+}
+
+/// Check whether a path is a file.
+async fn is_file(path: &Path) -> bool { + tokio::fs::metadata(path) + .await + .map(|m| m.is_file()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_crawl_all_composer() { + let dir = tempfile::tempdir().unwrap(); + let vendor_dir = dir.path().join("vendor"); + + // Create installed.json (v2 format) + let composer_dir = vendor_dir.join("composer"); + tokio::fs::create_dir_all(&composer_dir).await.unwrap(); + tokio::fs::write( + composer_dir.join("installed.json"), + r#"{"packages": [ + {"name": "monolog/monolog", "version": "3.5.0"}, + {"name": "symfony/console", "version": "6.4.1"} + ]}"#, + ) + .await + .unwrap(); + + // Create package directories + tokio::fs::create_dir_all(vendor_dir.join("monolog").join("monolog")) + .await + .unwrap(); + tokio::fs::create_dir_all(vendor_dir.join("symfony").join("console")) + .await + .unwrap(); + + // Create composer.json so it's recognized as a PHP project + tokio::fs::write(dir.path().join("composer.json"), "{}") + .await + .unwrap(); + + let crawler = ComposerCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: None, + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 2); + + let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect(); + assert!(purls.contains("pkg:composer/monolog/monolog@3.5.0")); + assert!(purls.contains("pkg:composer/symfony/console@6.4.1")); + + // Verify namespace is set + let monolog = packages.iter().find(|p| p.name == "monolog").unwrap(); + assert_eq!(monolog.namespace, Some("monolog".to_string())); + } + + #[tokio::test] + async fn test_find_by_purls_composer() { + let dir = tempfile::tempdir().unwrap(); + let vendor_dir = dir.path().join("vendor"); + + // Create installed.json + let composer_dir = vendor_dir.join("composer"); + tokio::fs::create_dir_all(&composer_dir).await.unwrap(); + tokio::fs::write( + composer_dir.join("installed.json"), + r#"{"packages": [ + {"name": "monolog/monolog", "version": "3.5.0"}, + {"name": "symfony/console", "version": "6.4.1"} + ]}"#, + ) + .await + .unwrap(); + + // Create package directories + tokio::fs::create_dir_all(vendor_dir.join("monolog").join("monolog")) + .await + .unwrap(); + tokio::fs::create_dir_all(vendor_dir.join("symfony").join("console")) + .await + .unwrap(); + + let crawler = ComposerCrawler::new(); + let purls = vec![ + "pkg:composer/monolog/monolog@3.5.0".to_string(), + "pkg:composer/symfony/console@6.4.1".to_string(), + "pkg:composer/guzzle/guzzle@7.0.0".to_string(), // not installed + ]; + let result = crawler.find_by_purls(&vendor_dir, &purls).await.unwrap(); + + assert_eq!(result.len(), 2); + assert!(result.contains_key("pkg:composer/monolog/monolog@3.5.0")); + assert!(result.contains_key("pkg:composer/symfony/console@6.4.1")); + assert!(!result.contains_key("pkg:composer/guzzle/guzzle@7.0.0")); + } + + #[tokio::test] + async fn test_installed_json_v1_format() { + let dir = tempfile::tempdir().unwrap(); + let vendor_dir = dir.path(); + + // Create installed.json in Composer 1 format (flat array) + let composer_dir = vendor_dir.join("composer"); + tokio::fs::create_dir_all(&composer_dir).await.unwrap(); + tokio::fs::write( + composer_dir.join("installed.json"), + r#"[ + {"name": "monolog/monolog", "version": "2.9.1"}, + {"name": "psr/log", "version": "3.0.0"} + ]"#, + ) + .await + .unwrap(); + + let entries = read_installed_json(vendor_dir).await; + assert_eq!(entries.len(), 
2); + assert_eq!(entries[0].name, "monolog/monolog"); + assert_eq!(entries[0].version, "2.9.1"); + assert_eq!(entries[1].name, "psr/log"); + assert_eq!(entries[1].version, "3.0.0"); + } + + #[tokio::test] + async fn test_installed_json_v2_format() { + let dir = tempfile::tempdir().unwrap(); + let vendor_dir = dir.path(); + + // Create installed.json in Composer 2 format + let composer_dir = vendor_dir.join("composer"); + tokio::fs::create_dir_all(&composer_dir).await.unwrap(); + tokio::fs::write( + composer_dir.join("installed.json"), + r#"{"packages": [ + {"name": "symfony/console", "version": "v6.4.1"}, + {"name": "symfony/string", "version": "v6.4.0"} + ]}"#, + ) + .await + .unwrap(); + + let entries = read_installed_json(vendor_dir).await; + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].name, "symfony/console"); + assert_eq!(entries[0].version, "v6.4.1"); + } + + #[tokio::test] + async fn test_non_php_project_returns_empty() { + let dir = tempfile::tempdir().unwrap(); + + // Create vendor dir with installed.json but no composer.json/lock + let vendor_dir = dir.path().join("vendor"); + let composer_dir = vendor_dir.join("composer"); + tokio::fs::create_dir_all(&composer_dir).await.unwrap(); + tokio::fs::write( + composer_dir.join("installed.json"), + r#"{"packages": [{"name": "foo/bar", "version": "1.0.0"}]}"#, + ) + .await + .unwrap(); + + let crawler = ComposerCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: None, + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert!(packages.is_empty()); + } + + #[tokio::test] + async fn test_find_by_purls_version_mismatch() { + let dir = tempfile::tempdir().unwrap(); + let vendor_dir = dir.path().join("vendor"); + + let composer_dir = vendor_dir.join("composer"); + tokio::fs::create_dir_all(&composer_dir).await.unwrap(); + tokio::fs::write( + composer_dir.join("installed.json"), + r#"{"packages": [{"name": "monolog/monolog", "version": "3.5.0"}]}"#, + ) + .await + .unwrap(); + + tokio::fs::create_dir_all(vendor_dir.join("monolog").join("monolog")) + .await + .unwrap(); + + let crawler = ComposerCrawler::new(); + // Request a different version than installed + let purls = vec!["pkg:composer/monolog/monolog@2.0.0".to_string()]; + let result = crawler.find_by_purls(&vendor_dir, &purls).await.unwrap(); + + assert!(result.is_empty()); + } +} diff --git a/crates/socket-patch-core/src/crawlers/go_crawler.rs b/crates/socket-patch-core/src/crawlers/go_crawler.rs new file mode 100644 index 0000000..c4f8682 --- /dev/null +++ b/crates/socket-patch-core/src/crawlers/go_crawler.rs @@ -0,0 +1,628 @@ +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; + +use super::types::{CrawledPackage, CrawlerOptions}; + +// --------------------------------------------------------------------------- +// Case-encoding helpers +// --------------------------------------------------------------------------- + +/// Encode a Go module path for the filesystem. +/// +/// Go's module cache uses case-encoding: uppercase letters are replaced +/// with `!` followed by the lowercase letter. 
+/// e.g., `"github.com/Azure/azure-sdk"` -> `"github.com/!azure/azure-sdk"` +pub fn encode_module_path(path: &str) -> String { + let mut encoded = String::with_capacity(path.len()); + for ch in path.chars() { + if ch.is_ascii_uppercase() { + encoded.push('!'); + encoded.push(ch.to_ascii_lowercase()); + } else { + encoded.push(ch); + } + } + encoded +} + +/// Decode a case-encoded Go module path. +/// +/// Reverses the encoding: `!` followed by a lowercase letter becomes the +/// uppercase letter. +/// e.g., `"github.com/!azure/azure-sdk"` -> `"github.com/Azure/azure-sdk"` +pub fn decode_module_path(encoded: &str) -> String { + let mut decoded = String::with_capacity(encoded.len()); + let mut chars = encoded.chars(); + while let Some(ch) = chars.next() { + if ch == '!' { + if let Some(next) = chars.next() { + decoded.push(next.to_ascii_uppercase()); + } + } else { + decoded.push(ch); + } + } + decoded +} + +/// Parse the `module` directive from a go.mod file. +/// +/// Returns the module path, e.g., `"github.com/gin-gonic/gin"`. +pub fn parse_go_mod_module(content: &str) -> Option { + for line in content.lines() { + let trimmed = line.trim(); + if let Some(rest) = trimmed.strip_prefix("module") { + let rest = rest.trim(); + // Handle quoted module paths + if rest.starts_with('"') && rest.ends_with('"') && rest.len() >= 2 { + return Some(rest[1..rest.len() - 1].to_string()); + } + // Unquoted module path + if !rest.is_empty() { + return Some(rest.to_string()); + } + } + } + None +} + +// --------------------------------------------------------------------------- +// GoCrawler +// --------------------------------------------------------------------------- + +/// Go module ecosystem crawler for discovering modules in the Go module cache +/// (`$GOMODCACHE` or `$GOPATH/pkg/mod/`). +pub struct GoCrawler; + +impl GoCrawler { + /// Create a new `GoCrawler`. + pub fn new() -> Self { + Self + } + + // ------------------------------------------------------------------ + // Public API + // ------------------------------------------------------------------ + + /// Get the Go module cache paths. + /// + /// In global mode (or with `--global-prefix`), returns the module cache + /// directory directly. + /// + /// In local mode, only returns the cache path if the cwd contains a + /// `go.mod` or `go.sum` file (i.e., is a Go project). + pub async fn get_module_cache_paths( + &self, + options: &CrawlerOptions, + ) -> Result, std::io::Error> { + if options.global || options.global_prefix.is_some() { + if let Some(ref custom) = options.global_prefix { + return Ok(vec![custom.clone()]); + } + return Ok(Self::get_gomodcache().map_or_else(Vec::new, |p| vec![p])); + } + + // Local mode: only scan if this looks like a Go project + let has_go_mod = tokio::fs::metadata(options.cwd.join("go.mod")) + .await + .is_ok(); + let has_go_sum = tokio::fs::metadata(options.cwd.join("go.sum")) + .await + .is_ok(); + + if has_go_mod || has_go_sum { + return Ok(Self::get_gomodcache().map_or_else(Vec::new, |p| vec![p])); + } + + // Not a Go project — return empty + Ok(Vec::new()) + } + + /// Crawl the Go module cache and return all discovered packages. 
+ pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec { + let mut packages = Vec::new(); + let mut seen = HashSet::new(); + + let cache_paths = self + .get_module_cache_paths(options) + .await + .unwrap_or_default(); + + for cache_path in &cache_paths { + let found = self.scan_module_cache(cache_path, &mut seen).await; + packages.extend(found); + } + + packages + } + + /// Find specific packages by PURL in the module cache. + pub async fn find_by_purls( + &self, + cache_path: &Path, + purls: &[String], + ) -> Result, std::io::Error> { + let mut result: HashMap = HashMap::new(); + + for purl in purls { + if let Some((module_path, version)) = crate::utils::purl::parse_golang_purl(purl) { + // Encode the module path for the filesystem + let encoded = encode_module_path(module_path); + + // Go module cache layout: @/ + let module_dir = cache_path.join(format!("{encoded}@{version}")); + + if is_dir(&module_dir).await { + // Split module_path into namespace and name + let (namespace, name) = split_module_path(module_path); + + result.insert( + purl.clone(), + CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: Some(namespace.to_string()), + purl: purl.clone(), + path: module_dir, + }, + ); + } + } + } + + Ok(result) + } + + // ------------------------------------------------------------------ + // Private helpers + // ------------------------------------------------------------------ + + /// Get `GOMODCACHE`, falling back to `$GOPATH/pkg/mod/` or `$HOME/go/pkg/mod/`. + fn get_gomodcache() -> Option { + if let Ok(cache) = std::env::var("GOMODCACHE") { + let p = PathBuf::from(cache); + if !p.as_os_str().is_empty() { + return Some(p); + } + } + if let Ok(gopath) = std::env::var("GOPATH") { + let p = PathBuf::from(gopath); + if !p.as_os_str().is_empty() { + return Some(p.join("pkg").join("mod")); + } + } + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .ok()?; + Some(PathBuf::from(home).join("go").join("pkg").join("mod")) + } + + /// Recursively scan the module cache directory tree. + /// + /// Go module cache has a hierarchical structure: + /// `/github.com/user/project@v1.0.0/` + /// + /// We walk the tree looking for directories whose name contains `@` + /// (the version separator), which marks a versioned module. 
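+    ///
+    /// Illustrative cache shape (module and version hypothetical):
+    /// ```text
+    /// $GOMODCACHE/
+    ///   cache/                   # metadata, skipped
+    ///   github.com/gin-gonic/
+    ///     gin@v1.9.1/            # versioned module, recorded
+    /// ```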
+ async fn scan_module_cache( + &self, + cache_path: &Path, + seen: &mut HashSet, + ) -> Vec { + let mut results = Vec::new(); + self.scan_dir_recursive(cache_path, cache_path, seen, &mut results) + .await; + results + } + + fn scan_dir_recursive<'a>( + &'a self, + base_path: &'a Path, + current_path: &'a Path, + seen: &'a mut HashSet, + results: &'a mut Vec, + ) -> std::pin::Pin + 'a>> { + Box::pin(async move { + let mut entries = match tokio::fs::read_dir(current_path).await { + Ok(rd) => rd, + Err(_) => return, + }; + + let mut entry_list = Vec::new(); + while let Ok(Some(entry)) = entries.next_entry().await { + entry_list.push(entry); + } + + for entry in entry_list { + let ft = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + if !ft.is_dir() { + continue; + } + + let dir_name = entry.file_name(); + let dir_name_str = dir_name.to_string_lossy(); + + // Skip hidden directories and the cache metadata directory + if dir_name_str.starts_with('.') || dir_name_str == "cache" { + continue; + } + + let full_path = current_path.join(&*dir_name_str); + + // Check if this directory has `@` in its name (versioned module) + if dir_name_str.contains('@') { + if let Some(pkg) = + self.parse_versioned_dir(base_path, &full_path, &dir_name_str, seen) + { + results.push(pkg); + } + } else { + // Recurse into subdirectories + self.scan_dir_recursive(base_path, &full_path, seen, results) + .await; + } + } + }) + } + + /// Parse a versioned directory (containing `@`) into a `CrawledPackage`. + fn parse_versioned_dir( + &self, + base_path: &Path, + dir_path: &Path, + _dir_name: &str, + seen: &mut HashSet, + ) -> Option { + // Get the relative path from the cache root. + // Normalize to forward slashes so PURLs are correct on Windows. + let rel_path = dir_path.strip_prefix(base_path).ok()?; + let rel_str = rel_path.to_string_lossy().replace('\\', "/"); + + // Find the last `@` to split module path and version + let at_idx = rel_str.rfind('@')?; + let encoded_module_path = &rel_str[..at_idx]; + let version = &rel_str[at_idx + 1..]; + + if encoded_module_path.is_empty() || version.is_empty() { + return None; + } + + // Decode case-encoded path + let module_path = decode_module_path(encoded_module_path); + + let purl = crate::utils::purl::build_golang_purl(&module_path, version); + + if seen.contains(&purl) { + return None; + } + seen.insert(purl.clone()); + + let (namespace, name) = split_module_path(&module_path); + + Some(CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: Some(namespace.to_string()), + purl, + path: dir_path.to_path_buf(), + }) + } +} + +impl Default for GoCrawler { + fn default() -> Self { + Self::new() + } +} + +/// Split a module path into (namespace, name). +/// +/// e.g., `"github.com/gin-gonic/gin"` -> `("github.com/gin-gonic", "gin")` +/// e.g., `"golang.org/x/text"` -> `("golang.org/x", "text")` +fn split_module_path(module_path: &str) -> (&str, &str) { + match module_path.rfind('/') { + Some(idx) => (&module_path[..idx], &module_path[idx + 1..]), + None => ("", module_path), + } +} + +/// Check whether a path is a directory. 
+async fn is_dir(path: &Path) -> bool { + tokio::fs::metadata(path) + .await + .map(|m| m.is_dir()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_module_path_no_uppercase() { + assert_eq!( + encode_module_path("github.com/gin-gonic/gin"), + "github.com/gin-gonic/gin" + ); + } + + #[test] + fn test_encode_module_path_with_uppercase() { + assert_eq!( + encode_module_path("github.com/Azure/azure-sdk-for-go"), + "github.com/!azure/azure-sdk-for-go" + ); + } + + #[test] + fn test_encode_module_path_multiple_uppercase() { + assert_eq!( + encode_module_path("github.com/BurntSushi/toml"), + "github.com/!burnt!sushi/toml" + ); + } + + #[test] + fn test_decode_module_path_no_encoding() { + assert_eq!( + decode_module_path("github.com/gin-gonic/gin"), + "github.com/gin-gonic/gin" + ); + } + + #[test] + fn test_decode_module_path_with_encoding() { + assert_eq!( + decode_module_path("github.com/!azure/azure-sdk-for-go"), + "github.com/Azure/azure-sdk-for-go" + ); + } + + #[test] + fn test_encode_decode_roundtrip() { + let original = "github.com/Azure/azure-sdk-for-go"; + assert_eq!(decode_module_path(&encode_module_path(original)), original); + + let original2 = "github.com/BurntSushi/toml"; + assert_eq!( + decode_module_path(&encode_module_path(original2)), + original2 + ); + + let original3 = "github.com/gin-gonic/gin"; + assert_eq!( + decode_module_path(&encode_module_path(original3)), + original3 + ); + } + + #[test] + fn test_parse_go_mod_module_basic() { + let content = "module github.com/gin-gonic/gin\n\ngo 1.21\n"; + assert_eq!( + parse_go_mod_module(content), + Some("github.com/gin-gonic/gin".to_string()) + ); + } + + #[test] + fn test_parse_go_mod_module_quoted() { + let content = "module \"github.com/gin-gonic/gin\"\n\ngo 1.21\n"; + assert_eq!( + parse_go_mod_module(content), + Some("github.com/gin-gonic/gin".to_string()) + ); + } + + #[test] + fn test_parse_go_mod_module_missing() { + let content = "go 1.21\n\nrequire (\n\tgithub.com/gin-gonic/gin v1.9.1\n)\n"; + assert_eq!(parse_go_mod_module(content), None); + } + + #[test] + fn test_split_module_path() { + let (ns, name) = split_module_path("github.com/gin-gonic/gin"); + assert_eq!(ns, "github.com/gin-gonic"); + assert_eq!(name, "gin"); + + let (ns, name) = split_module_path("golang.org/x/text"); + assert_eq!(ns, "golang.org/x"); + assert_eq!(name, "text"); + + let (ns, name) = split_module_path("gopkg.in/yaml.v3"); + assert_eq!(ns, "gopkg.in"); + assert_eq!(name, "yaml.v3"); + } + + #[tokio::test] + async fn test_find_by_purls_basic() { + let dir = tempfile::tempdir().unwrap(); + + // Create a fake module directory: github.com/gin-gonic/gin@v1.9.1 + let module_dir = dir + .path() + .join("github.com") + .join("gin-gonic") + .join("gin@v1.9.1"); + tokio::fs::create_dir_all(&module_dir).await.unwrap(); + + let crawler = GoCrawler::new(); + let purls = vec![ + "pkg:golang/github.com/gin-gonic/gin@v1.9.1".to_string(), + "pkg:golang/github.com/missing/pkg@v0.1.0".to_string(), + ]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:golang/github.com/gin-gonic/gin@v1.9.1")); + assert!(!result.contains_key("pkg:golang/github.com/missing/pkg@v0.1.0")); + + let pkg = &result["pkg:golang/github.com/gin-gonic/gin@v1.9.1"]; + assert_eq!(pkg.name, "gin"); + assert_eq!(pkg.version, "v1.9.1"); + assert_eq!(pkg.namespace, Some("github.com/gin-gonic".to_string())); + } + + #[tokio::test] + async fn 
test_find_by_purls_case_encoded() { + let dir = tempfile::tempdir().unwrap(); + + // Create a case-encoded module directory + let module_dir = dir + .path() + .join("github.com") + .join("!azure") + .join("azure-sdk-for-go@v1.0.0"); + tokio::fs::create_dir_all(&module_dir).await.unwrap(); + + let crawler = GoCrawler::new(); + let purls = vec!["pkg:golang/github.com/Azure/azure-sdk-for-go@v1.0.0".to_string()]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + let pkg = &result["pkg:golang/github.com/Azure/azure-sdk-for-go@v1.0.0"]; + assert_eq!(pkg.name, "azure-sdk-for-go"); + assert_eq!(pkg.namespace, Some("github.com/Azure".to_string())); + } + + #[tokio::test] + async fn test_crawl_all_tempdir() { + let dir = tempfile::tempdir().unwrap(); + + // Create fake module directories + let gin_dir = dir + .path() + .join("github.com") + .join("gin-gonic") + .join("gin@v1.9.1"); + tokio::fs::create_dir_all(&gin_dir).await.unwrap(); + + let text_dir = dir.path().join("golang.org").join("x").join("text@v0.14.0"); + tokio::fs::create_dir_all(&text_dir).await.unwrap(); + + let crawler = GoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 2); + + let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect(); + assert!(purls.contains("pkg:golang/github.com/gin-gonic/gin@v1.9.1")); + assert!(purls.contains("pkg:golang/golang.org/x/text@v0.14.0")); + } + + #[tokio::test] + async fn test_crawl_all_deduplication() { + let dir = tempfile::tempdir().unwrap(); + + // Create a single module + let gin_dir = dir + .path() + .join("github.com") + .join("gin-gonic") + .join("gin@v1.9.1"); + tokio::fs::create_dir_all(&gin_dir).await.unwrap(); + + let crawler = GoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!( + packages[0].purl, + "pkg:golang/github.com/gin-gonic/gin@v1.9.1" + ); + } + + #[tokio::test] + async fn test_crawl_all_skips_cache_dir() { + let dir = tempfile::tempdir().unwrap(); + + // Create a real module + let gin_dir = dir + .path() + .join("github.com") + .join("gin-gonic") + .join("gin@v1.9.1"); + tokio::fs::create_dir_all(&gin_dir).await.unwrap(); + + // Create a "cache" dir (should be skipped) + let cache_dir = dir.path().join("cache").join("download").join("sumdb"); + tokio::fs::create_dir_all(&cache_dir).await.unwrap(); + + let crawler = GoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + } + + #[tokio::test] + async fn test_local_mode_no_go_mod_returns_empty() { + let dir = tempfile::tempdir().unwrap(); + + // No go.mod or go.sum in cwd + let crawler = GoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: None, + batch_size: 100, + }; + + let paths = crawler.get_module_cache_paths(&options).await.unwrap(); + assert!(paths.is_empty()); + } + + #[tokio::test] + async fn test_crawl_case_encoded_modules() { + let dir = 
tempfile::tempdir().unwrap();
+
+        // Create case-encoded module
+        let azure_dir = dir
+            .path()
+            .join("github.com")
+            .join("!azure")
+            .join("azure-sdk-for-go@v1.0.0");
+        tokio::fs::create_dir_all(&azure_dir).await.unwrap();
+
+        let crawler = GoCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: Some(dir.path().to_path_buf()),
+            batch_size: 100,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 1);
+        assert_eq!(
+            packages[0].purl,
+            "pkg:golang/github.com/Azure/azure-sdk-for-go@v1.0.0"
+        );
+        assert_eq!(packages[0].name, "azure-sdk-for-go");
+        assert_eq!(
+            packages[0].namespace,
+            Some("github.com/Azure".to_string())
+        );
+    }
+}
diff --git a/crates/socket-patch-core/src/crawlers/maven_crawler.rs b/crates/socket-patch-core/src/crawlers/maven_crawler.rs
new file mode 100644
index 0000000..c78c875
--- /dev/null
+++ b/crates/socket-patch-core/src/crawlers/maven_crawler.rs
@@ -0,0 +1,821 @@
+use std::collections::{HashMap, HashSet};
+use std::path::{Path, PathBuf};
+
+use super::types::{CrawledPackage, CrawlerOptions};
+
+// ---------------------------------------------------------------------------
+// POM XML minimal parser
+// ---------------------------------------------------------------------------
+
+/// Extract the text value between `<element>` and `</element>` on a single line.
+fn extract_xml_value(line: &str, element: &str) -> Option<String> {
+    let open = format!("<{element}>");
+    let close = format!("</{element}>");
+    let start = line.find(&open)?;
+    let value_start = start + open.len();
+    let end = line[value_start..].find(&close)?;
+    let value = line[value_start..value_start + end].trim().to_string();
+    if value.is_empty() {
+        None
+    } else {
+        Some(value)
+    }
+}
+
+/// Parse `groupId`, `artifactId`, and `version` from a POM XML file.
+///
+/// Uses a simple line-based parser — no XML crate dependency.
+/// Tracks nesting depth to skip `<dependencies>`, `<build>`, `<profiles>`, etc.
+/// Extracts top-level `<groupId>`, `<artifactId>`, `<version>` from `<project>`.
+/// Falls back to the `<parent>` block for groupId if missing at top level.
+/// Returns `None` for property references (`${...}`).
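+///
+/// For example (mirroring the unit tests below), this POM yields
+/// `("org.apache.commons", "commons-lang3", "3.12.0")`:
+/// ```xml
+/// <project>
+///   <modelVersion>4.0.0</modelVersion>
+///   <groupId>org.apache.commons</groupId>
+///   <artifactId>commons-lang3</artifactId>
+///   <version>3.12.0</version>
+/// </project>
+/// ```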
+pub fn parse_pom_group_artifact_version(content: &str) -> Option<(String, String, String)> {
+    let mut group_id: Option<String> = None;
+    let mut artifact_id: Option<String> = None;
+    let mut version: Option<String> = None;
+    let mut parent_group_id: Option<String> = None;
+
+    let mut in_parent = false;
+    let mut skip_depth: u32 = 0;
+
+    let skip_sections = [
+        "dependencies",
+        "build",
+        "profiles",
+        "reporting",
+        "dependencyManagement",
+        "pluginManagement",
+        "modules",
+        "distributionManagement",
+        "repositories",
+        "pluginRepositories",
+    ];
+
+    for line in content.lines() {
+        let trimmed = line.trim();
+
+        // Check for skip section open/close
+        for section in &skip_sections {
+            let open_tag = format!("<{section}");
+            let close_tag = format!("</{section}>");
+            if trimmed.contains(&open_tag) && !trimmed.contains(&close_tag) {
+                skip_depth += 1;
+            }
+            if trimmed.contains(&close_tag) {
+                skip_depth = skip_depth.saturating_sub(1);
+            }
+        }
+
+        if skip_depth > 0 {
+            continue;
+        }
+
+        // Track the <parent> section
+        if trimmed.contains("<parent>") {
+            in_parent = true;
+            continue;
+        }
+        if trimmed.contains("</parent>") {
+            in_parent = false;
+            continue;
+        }
+
+        if in_parent {
+            if parent_group_id.is_none() {
+                if let Some(val) = extract_xml_value(trimmed, "groupId") {
+                    if val.contains("${") {
+                        // Property reference in parent — skip
+                    } else {
+                        parent_group_id = Some(val);
+                    }
+                }
+            }
+            continue;
+        }
+
+        // Extract top-level coordinates
+        if group_id.is_none() {
+            if let Some(val) = extract_xml_value(trimmed, "groupId") {
+                if val.contains("${") {
+                    return None;
+                }
+                group_id = Some(val);
+            }
+        }
+        if artifact_id.is_none() {
+            if let Some(val) = extract_xml_value(trimmed, "artifactId") {
+                if val.contains("${") {
+                    return None;
+                }
+                artifact_id = Some(val);
+            }
+        }
+        if version.is_none() {
+            if let Some(val) = extract_xml_value(trimmed, "version") {
+                if val.contains("${") {
+                    return None;
+                }
+                version = Some(val);
+            }
+        }
+    }
+
+    // Fall back to parent groupId
+    let final_group_id = group_id.or(parent_group_id)?;
+    let final_artifact_id = artifact_id?;
+    let final_version = version?;
+
+    if final_group_id.is_empty() || final_artifact_id.is_empty() || final_version.is_empty() {
+        return None;
+    }
+
+    Some((final_group_id, final_artifact_id, final_version))
+}
+
+// ---------------------------------------------------------------------------
+// Path coordinate helpers
+// ---------------------------------------------------------------------------
+
+/// Convert a Maven groupId to a path segment (e.g. `org.apache.commons` -> `org/apache/commons`).
+fn group_id_to_path(group_id: &str) -> String {
+    group_id.replace('.', "/")
+}
+
+/// Convert a path segment back to a Maven groupId (e.g. `org/apache/commons` -> `org.apache.commons`).
+#[allow(dead_code)]
+fn path_to_group_id(path: &str) -> String {
+    path.replace('/', ".")
+}
+
+/// Extract Maven coordinates from a directory path relative to the repository root.
+///
+/// The Maven repository layout is: `<groupId-as-path>/<artifactId>/<version>/`
+/// e.g.
`org/apache/commons/commons-lang3/3.12.0/` +fn parse_path_coordinates(version_dir: &Path, repo_root: &Path) -> Option<(String, String, String)> { + let rel = version_dir.strip_prefix(repo_root).ok()?; + let components: Vec<&str> = rel + .components() + .filter_map(|c| c.as_os_str().to_str()) + .collect(); + + if components.len() < 3 { + return None; + } + + let version = components[components.len() - 1].to_string(); + let artifact_id = components[components.len() - 2].to_string(); + let group_parts = &components[..components.len() - 2]; + let group_id = group_parts.join("."); + + if group_id.is_empty() || artifact_id.is_empty() || version.is_empty() { + return None; + } + + Some((group_id, artifact_id, version)) +} + +// --------------------------------------------------------------------------- +// MavenCrawler +// --------------------------------------------------------------------------- + +/// Maven/Java ecosystem crawler for discovering packages in the local +/// Maven repository (`~/.m2/repository/`). +pub struct MavenCrawler; + +impl MavenCrawler { + /// Create a new `MavenCrawler`. + pub fn new() -> Self { + Self + } + + // ------------------------------------------------------------------ + // Public API + // ------------------------------------------------------------------ + + /// Get Maven repository paths based on options. + /// + /// In global mode, returns `~/.m2/repository/` (respects `$M2_HOME`, + /// `$MAVEN_REPO_LOCAL`, `--global-prefix`). + /// + /// In local mode, only returns the Maven repo if the cwd contains + /// `pom.xml`, `build.gradle`, `build.gradle.kts`, `settings.gradle`, + /// or `settings.gradle.kts` (prevents scanning for non-Java projects). + pub async fn get_maven_repo_paths( + &self, + options: &CrawlerOptions, + ) -> Result, std::io::Error> { + if options.global || options.global_prefix.is_some() { + if let Some(ref custom) = options.global_prefix { + return Ok(vec![custom.clone()]); + } + let repo = Self::m2_repo_path(); + if is_dir(&repo).await { + return Ok(vec![repo]); + } + return Ok(Vec::new()); + } + + // Local mode: only return Maven repo if this looks like a Java/Maven/Gradle project + let java_markers = [ + "pom.xml", + "build.gradle", + "build.gradle.kts", + "settings.gradle", + "settings.gradle.kts", + ]; + + let mut is_java_project = false; + for marker in &java_markers { + if tokio::fs::metadata(options.cwd.join(marker)).await.is_ok() { + is_java_project = true; + break; + } + } + + if !is_java_project { + return Ok(Vec::new()); + } + + let repo = Self::m2_repo_path(); + if is_dir(&repo).await { + Ok(vec![repo]) + } else { + Ok(Vec::new()) + } + } + + /// Crawl all discovered Maven repository paths and return every + /// package found. + pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec { + let mut packages = Vec::new(); + let mut seen = HashSet::new(); + + let repo_paths = self.get_maven_repo_paths(options).await.unwrap_or_default(); + + for repo_path in &repo_paths { + let found = self.scan_maven_repo(repo_path, &mut seen); + packages.extend(found); + } + + packages + } + + /// Find specific packages by PURL inside a single Maven repository path. + /// + /// For each PURL, constructs the expected path: + /// `src_path / groupId.replace('.', '/') / artifactId / version /` + /// and verifies by checking for a `.pom` file. 
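+    ///
+    /// For example (coordinates illustrative),
+    /// `pkg:maven/org.apache.commons/commons-lang3@3.12.0` resolves to
+    /// `<repo>/org/apache/commons/commons-lang3/3.12.0/`.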
+    pub async fn find_by_purls(
+        &self,
+        src_path: &Path,
+        purls: &[String],
+    ) -> Result<HashMap<String, CrawledPackage>, std::io::Error> {
+        let mut result: HashMap<String, CrawledPackage> = HashMap::new();
+
+        for purl in purls {
+            if let Some((group_id, artifact_id, version)) =
+                crate::utils::purl::parse_maven_purl(purl)
+            {
+                let expected_path = src_path
+                    .join(group_id_to_path(group_id))
+                    .join(artifact_id)
+                    .join(version);
+
+                if self
+                    .verify_maven_at_path(&expected_path, group_id, artifact_id, version)
+                    .await
+                {
+                    result.insert(
+                        purl.clone(),
+                        CrawledPackage {
+                            name: artifact_id.to_string(),
+                            version: version.to_string(),
+                            namespace: Some(group_id.to_string()),
+                            purl: purl.clone(),
+                            path: expected_path,
+                        },
+                    );
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers
+    // ------------------------------------------------------------------
+
+    /// Get the Maven local repository path.
+    ///
+    /// Checks `$MAVEN_REPO_LOCAL`, `$M2_HOME/repository`, `$HOME/.m2/repository`.
+    fn m2_repo_path() -> PathBuf {
+        if let Ok(repo_local) = std::env::var("MAVEN_REPO_LOCAL") {
+            return PathBuf::from(repo_local);
+        }
+        if let Ok(m2_home) = std::env::var("M2_HOME") {
+            return PathBuf::from(m2_home).join("repository");
+        }
+        let home = std::env::var("HOME")
+            .or_else(|_| std::env::var("USERPROFILE"))
+            .unwrap_or_else(|_| "~".to_string());
+        PathBuf::from(home).join(".m2").join("repository")
+    }
+
+    /// Scan a Maven repository directory and return all valid packages found.
+    ///
+    /// Uses `walkdir` to recursively find `.pom` files, then extracts
+    /// coordinates from the POM content or falls back to directory path parsing.
+    fn scan_maven_repo(
+        &self,
+        repo_path: &Path,
+        seen: &mut HashSet<String>,
+    ) -> Vec<CrawledPackage> {
+        let mut results = Vec::new();
+
+        for entry in walkdir::WalkDir::new(repo_path)
+            .follow_links(false)
+            .into_iter()
+            .filter_map(|e| e.ok())
+        {
+            if !entry.file_type().is_file() {
+                continue;
+            }
+            let path = entry.path();
+            if path.extension().is_none_or(|ext| ext != "pom") {
+                continue;
+            }
+
+            let version_dir = match path.parent() {
+                Some(p) => p,
+                None => continue,
+            };
+
+            // Try POM parsing first, fall back to directory path parsing
+            let coords = std::fs::read_to_string(path)
+                .ok()
+                .and_then(|content| parse_pom_group_artifact_version(&content))
+                .or_else(|| parse_path_coordinates(version_dir, repo_path));
+
+            if let Some((group_id, artifact_id, version)) = coords {
+                let purl =
+                    crate::utils::purl::build_maven_purl(&group_id, &artifact_id, &version);
+                if seen.insert(purl.clone()) {
+                    results.push(CrawledPackage {
+                        name: artifact_id,
+                        version,
+                        namespace: Some(group_id),
+                        purl,
+                        path: version_dir.to_path_buf(),
+                    });
+                }
+            }
+        }
+
+        results
+    }
+
+    /// Verify that a Maven package directory contains a `.pom` file
+    /// with the expected coordinates.
+    async fn verify_maven_at_path(
+        &self,
+        path: &Path,
+        _group_id: &str,
+        _artifact_id: &str,
+        _version: &str,
+    ) -> bool {
+        // The path already encodes the coordinates (groupId/artifactId/version),
+        // so we just need to verify a .pom file exists here.
+        self.has_pom_file(path).await
+    }
+
+    /// Check if a directory contains at least one `.pom` file.
+    async fn has_pom_file(&self, path: &Path) -> bool {
+        if !is_dir(path).await {
+            return false;
+        }
+
+        let mut entries = match tokio::fs::read_dir(path).await {
+            Ok(rd) => rd,
+            Err(_) => return false,
+        };
+
+        while let Ok(Some(entry)) = entries.next_entry().await {
+            if let Some(name) = entry.file_name().to_str() {
+                if name.ends_with(".pom") {
+                    return true;
+                }
+            }
+        }
+
+        false
+    }
+
+    /// Find and parse the first `.pom` file in a directory.
+    #[allow(dead_code)]
+    async fn read_pom_in_dir(dir: &Path) -> Option<(String, String, String)> {
+        let mut entries = tokio::fs::read_dir(dir).await.ok()?;
+        while let Ok(Some(entry)) = entries.next_entry().await {
+            if let Some(name) = entry.file_name().to_str() {
+                if name.ends_with(".pom") {
+                    let content = tokio::fs::read_to_string(entry.path()).await.ok()?;
+                    return parse_pom_group_artifact_version(&content);
+                }
+            }
+        }
+        None
+    }
+}
+
+impl Default for MavenCrawler {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Check whether a path is a directory.
+async fn is_dir(path: &Path) -> bool {
+    tokio::fs::metadata(path)
+        .await
+        .map(|m| m.is_dir())
+        .unwrap_or(false)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // ---- POM parsing tests ----
+
+    #[test]
+    fn test_parse_pom_basic() {
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <modelVersion>4.0.0</modelVersion>
+    <groupId>org.apache.commons</groupId>
+    <artifactId>commons-lang3</artifactId>
+    <version>3.12.0</version>
+</project>"#;
+        let (g, a, v) = parse_pom_group_artifact_version(content).unwrap();
+        assert_eq!(g, "org.apache.commons");
+        assert_eq!(a, "commons-lang3");
+        assert_eq!(v, "3.12.0");
+    }
+
+    #[test]
+    fn test_parse_pom_with_parent_group() {
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <parent>
+        <groupId>org.apache</groupId>
+        <artifactId>apache</artifactId>
+        <version>30</version>
+    </parent>
+    <artifactId>commons-lang3</artifactId>
+    <version>3.12.0</version>
+</project>"#;
+        let (g, a, v) = parse_pom_group_artifact_version(content).unwrap();
+        assert_eq!(g, "org.apache");
+        assert_eq!(a, "commons-lang3");
+        assert_eq!(v, "3.12.0");
+    }
+
+    #[test]
+    fn test_parse_pom_skips_dependencies() {
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <groupId>com.example</groupId>
+    <artifactId>my-app</artifactId>
+    <version>1.0.0</version>
+    <dependencies>
+        <dependency>
+            <groupId>org.other</groupId>
+            <artifactId>other-lib</artifactId>
+            <version>2.0.0</version>
+        </dependency>
+    </dependencies>
+</project>"#;
+        let (g, a, v) = parse_pom_group_artifact_version(content).unwrap();
+        assert_eq!(g, "com.example");
+        assert_eq!(a, "my-app");
+        assert_eq!(v, "1.0.0");
+    }
+
+    #[test]
+    fn test_parse_pom_property_reference_returns_none() {
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <groupId>com.example</groupId>
+    <artifactId>my-app</artifactId>
+    <version>${project.version}</version>
+</project>"#;
+        assert!(parse_pom_group_artifact_version(content).is_none());
+    }
+
+    #[test]
+    fn test_parse_pom_missing_version_returns_none() {
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <groupId>com.example</groupId>
+    <artifactId>my-app</artifactId>
+</project>"#;
+        assert!(parse_pom_group_artifact_version(content).is_none());
+    }
+
+    #[test]
+    fn test_parse_pom_group_id_from_parent_and_top_level() {
+        // When both project and parent have groupId, use project-level
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <parent>
+        <groupId>org.parent</groupId>
+    </parent>
+    <groupId>org.child</groupId>
+    <artifactId>my-lib</artifactId>
+    <version>2.0.0</version>
+</project>"#;
+        let (g, a, v) = parse_pom_group_artifact_version(content).unwrap();
+        assert_eq!(g, "org.child");
+        assert_eq!(a, "my-lib");
+        assert_eq!(v, "2.0.0");
+    }
+
+    #[test]
+    fn test_parse_pom_skips_build_section() {
+        let content = r#"<?xml version="1.0"?>
+<project>
+    <groupId>com.example</groupId>
+    <artifactId>my-app</artifactId>
+    <version>1.0.0</version>
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <version>3.11.0</version>
+            </plugin>
+        </plugins>
+    </build>
+</project>"#;
+        let (g, a, v) = parse_pom_group_artifact_version(content).unwrap();
+        assert_eq!(g, "com.example");
+        assert_eq!(a, "my-app");
+        assert_eq!(v, "1.0.0");
+    }
+
+    // ---- extract_xml_value tests ----
+
+    #[test]
+    fn test_extract_xml_value() {
+        assert_eq!(
+            extract_xml_value("    <groupId>org.apache</groupId>", "groupId"),
+            Some("org.apache".to_string())
+        );
+        assert_eq!(
+            extract_xml_value("    <version>1.0.0</version>", "version"),
+            Some("1.0.0".to_string())
+        );
+        assert_eq!(
+            extract_xml_value("    <version>foo</version>", "groupId"),
+            None
+        );
"groupId"), None); + assert_eq!(extract_xml_value(" ", "groupId"), None); + } + + // ---- group_id_to_path / path_to_group_id tests ---- + + #[test] + fn test_group_id_to_path() { + assert_eq!(group_id_to_path("org.apache.commons"), "org/apache/commons"); + assert_eq!(group_id_to_path("com.google.guava"), "com/google/guava"); + assert_eq!(group_id_to_path("single"), "single"); + } + + #[test] + fn test_path_to_group_id() { + assert_eq!(path_to_group_id("org/apache/commons"), "org.apache.commons"); + assert_eq!(path_to_group_id("com/google/guava"), "com.google.guava"); + } + + // ---- parse_path_coordinates tests ---- + + #[test] + fn test_parse_path_coordinates() { + let repo = Path::new("/home/user/.m2/repository"); + let version_dir = Path::new("/home/user/.m2/repository/org/apache/commons/commons-lang3/3.12.0"); + let (g, a, v) = parse_path_coordinates(version_dir, repo).unwrap(); + assert_eq!(g, "org.apache.commons"); + assert_eq!(a, "commons-lang3"); + assert_eq!(v, "3.12.0"); + } + + #[test] + fn test_parse_path_coordinates_short_path() { + let repo = Path::new("/repo"); + let version_dir = Path::new("/repo/foo/bar"); + // Only 2 components — not enough (need at least groupId/artifactId/version) + assert!(parse_path_coordinates(version_dir, repo).is_none()); + } + + // ---- find_by_purls tests ---- + + #[tokio::test] + async fn test_find_by_purls_maven() { + let dir = tempfile::tempdir().unwrap(); + + // Create Maven repo layout: org/apache/commons/commons-lang3/3.12.0/ + let pkg_dir = dir.path() + .join("org") + .join("apache") + .join("commons") + .join("commons-lang3") + .join("3.12.0"); + tokio::fs::create_dir_all(&pkg_dir).await.unwrap(); + tokio::fs::write( + pkg_dir.join("commons-lang3-3.12.0.pom"), + r#" + org.apache.commons + commons-lang3 + 3.12.0 +"#, + ) + .await + .unwrap(); + + let crawler = MavenCrawler::new(); + let purls = vec![ + "pkg:maven/org.apache.commons/commons-lang3@3.12.0".to_string(), + "pkg:maven/com.google.guava/guava@32.1.3-jre".to_string(), + ]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:maven/org.apache.commons/commons-lang3@3.12.0")); + assert!(!result.contains_key("pkg:maven/com.google.guava/guava@32.1.3-jre")); + + let pkg = &result["pkg:maven/org.apache.commons/commons-lang3@3.12.0"]; + assert_eq!(pkg.name, "commons-lang3"); + assert_eq!(pkg.version, "3.12.0"); + assert_eq!(pkg.namespace, Some("org.apache.commons".to_string())); + } + + // ---- crawl_all tests ---- + + #[tokio::test] + async fn test_crawl_all_maven() { + let dir = tempfile::tempdir().unwrap(); + + // Create two Maven packages + let pkg1_dir = dir.path() + .join("org") + .join("apache") + .join("commons") + .join("commons-lang3") + .join("3.12.0"); + tokio::fs::create_dir_all(&pkg1_dir).await.unwrap(); + tokio::fs::write( + pkg1_dir.join("commons-lang3-3.12.0.pom"), + r#" + org.apache.commons + commons-lang3 + 3.12.0 +"#, + ) + .await + .unwrap(); + + let pkg2_dir = dir.path() + .join("com") + .join("google") + .join("guava") + .join("guava") + .join("32.1.3-jre"); + tokio::fs::create_dir_all(&pkg2_dir).await.unwrap(); + tokio::fs::write( + pkg2_dir.join("guava-32.1.3-jre.pom"), + r#" + com.google.guava + guava + 32.1.3-jre +"#, + ) + .await + .unwrap(); + + let crawler = MavenCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = 
+        assert_eq!(packages.len(), 2);
+
+        let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect();
+        assert!(purls.contains("pkg:maven/org.apache.commons/commons-lang3@3.12.0"));
+        assert!(purls.contains("pkg:maven/com.google.guava/guava@32.1.3-jre"));
+    }
+
+    #[tokio::test]
+    async fn test_crawl_all_deduplication() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create one package
+        let pkg_dir = dir.path()
+            .join("com")
+            .join("example")
+            .join("my-lib")
+            .join("1.0.0");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+        tokio::fs::write(
+            pkg_dir.join("my-lib-1.0.0.pom"),
+            r#"<project>
+    <groupId>com.example</groupId>
+    <artifactId>my-lib</artifactId>
+    <version>1.0.0</version>
+</project>"#,
+        )
+        .await
+        .unwrap();
+
+        let crawler = MavenCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: Some(dir.path().to_path_buf()),
+            batch_size: 100,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 1);
+        assert_eq!(packages[0].purl, "pkg:maven/com.example/my-lib@1.0.0");
+    }
+
+    #[tokio::test]
+    async fn test_crawl_fallback_to_path_parsing() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create package with POM that has property references (can't parse)
+        let pkg_dir = dir.path()
+            .join("com")
+            .join("example")
+            .join("my-lib")
+            .join("2.0.0");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+        tokio::fs::write(
+            pkg_dir.join("my-lib-2.0.0.pom"),
+            r#"<project>
+    <groupId>com.example</groupId>
+    <artifactId>my-lib</artifactId>
+    <version>${project.version}</version>
+</project>"#,
+        )
+        .await
+        .unwrap();
+
+        let crawler = MavenCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: Some(dir.path().to_path_buf()),
+            batch_size: 100,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 1);
+        assert_eq!(packages[0].purl, "pkg:maven/com.example/my-lib@2.0.0");
+        assert_eq!(packages[0].name, "my-lib");
+        assert_eq!(packages[0].namespace, Some("com.example".to_string()));
+    }
+
+    #[tokio::test]
+    async fn test_get_maven_repo_paths_not_java_project() {
+        let dir = tempfile::tempdir().unwrap();
+        // No pom.xml or build.gradle — should return empty
+        let crawler = MavenCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: None,
+            batch_size: 100,
+        };
+
+        let paths = crawler.get_maven_repo_paths(&options).await.unwrap();
+        assert!(paths.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_get_maven_repo_paths_with_global_prefix() {
+        let dir = tempfile::tempdir().unwrap();
+        let crawler = MavenCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: Some(dir.path().to_path_buf()),
+            batch_size: 100,
+        };
+
+        let paths = crawler.get_maven_repo_paths(&options).await.unwrap();
+        assert_eq!(paths.len(), 1);
+        assert_eq!(paths[0], dir.path().to_path_buf());
+    }
+}
diff --git a/crates/socket-patch-core/src/crawlers/mod.rs b/crates/socket-patch-core/src/crawlers/mod.rs
new file mode 100644
index 0000000..5ec0788
--- /dev/null
+++ b/crates/socket-patch-core/src/crawlers/mod.rs
@@ -0,0 +1,29 @@
+pub mod npm_crawler;
+pub mod python_crawler;
+pub mod types;
+#[cfg(feature = "cargo")]
+pub mod cargo_crawler;
+pub mod ruby_crawler;
+#[cfg(feature = "golang")]
+pub mod go_crawler;
+#[cfg(feature = "maven")]
+pub mod maven_crawler;
+#[cfg(feature = "composer")]
+pub mod composer_crawler;
+#[cfg(feature = "nuget")]
+pub mod nuget_crawler;
+
+pub use npm_crawler::NpmCrawler;
+pub use python_crawler::PythonCrawler;
+pub use types::*;
+#[cfg(feature = "cargo")]
+pub use cargo_crawler::CargoCrawler;
+pub use ruby_crawler::RubyCrawler;
+#[cfg(feature = "golang")]
+pub use go_crawler::GoCrawler;
+#[cfg(feature = "maven")]
+pub use maven_crawler::MavenCrawler;
+#[cfg(feature = "composer")]
+pub use composer_crawler::ComposerCrawler;
+#[cfg(feature = "nuget")]
+pub use nuget_crawler::NuGetCrawler;
diff --git a/crates/socket-patch-core/src/crawlers/npm_crawler.rs b/crates/socket-patch-core/src/crawlers/npm_crawler.rs
new file mode 100644
index 0000000..e081acd
--- /dev/null
+++ b/crates/socket-patch-core/src/crawlers/npm_crawler.rs
@@ -0,0 +1,954 @@
+use std::collections::{HashMap, HashSet};
+use std::path::{Path, PathBuf};
+use std::process::Command;
+
+use serde::Deserialize;
+
+use super::types::{CrawledPackage, CrawlerOptions};
+
+/// Default batch size for crawling.
+#[cfg(test)]
+const DEFAULT_BATCH_SIZE: usize = 100;
+
+/// Directories to skip when searching for workspace node_modules.
+const SKIP_DIRS: &[&str] = &[
+    "dist",
+    "build",
+    "coverage",
+    "tmp",
+    "temp",
+    "__pycache__",
+    "vendor",
+];
+
+// ---------------------------------------------------------------------------
+// Helper: read and parse package.json
+// ---------------------------------------------------------------------------
+
+/// Minimal fields we need from package.json.
+#[derive(Deserialize)]
+struct PackageJsonPartial {
+    name: Option<String>,
+    version: Option<String>,
+}
+
+/// Read and parse a `package.json` file, returning `(name, version)` if valid.
+pub async fn read_package_json(pkg_json_path: &Path) -> Option<(String, String)> {
+    let content = tokio::fs::read_to_string(pkg_json_path).await.ok()?;
+    let pkg: PackageJsonPartial = serde_json::from_str(&content).ok()?;
+    let name = pkg.name?;
+    let version = pkg.version?;
+    if name.is_empty() || version.is_empty() {
+        return None;
+    }
+    Some((name, version))
+}
+
+// ---------------------------------------------------------------------------
+// Helper: parse package name into (namespace, name)
+// ---------------------------------------------------------------------------
+
+/// Parse a full npm package name into optional namespace and bare name.
+///
+/// Examples:
+/// - `"@types/node"` -> `(Some("@types"), "node")`
+/// - `"lodash"` -> `(None, "lodash")`
+pub fn parse_package_name(full_name: &str) -> (Option<String>, String) {
+    if full_name.starts_with('@') {
+        if let Some(slash_idx) = full_name.find('/') {
+            let namespace = full_name[..slash_idx].to_string();
+            let name = full_name[slash_idx + 1..].to_string();
+            return (Some(namespace), name);
+        }
+    }
+    (None, full_name.to_string())
+}
+
+// ---------------------------------------------------------------------------
+// Helper: build PURL
+// ---------------------------------------------------------------------------
+
+/// Build a PURL string for an npm package.
+pub fn build_npm_purl(namespace: Option<&str>, name: &str, version: &str) -> String {
+    match namespace {
+        Some(ns) => format!("pkg:npm/{ns}/{name}@{version}"),
+        None => format!("pkg:npm/{name}@{version}"),
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Global prefix detection helpers
+// ---------------------------------------------------------------------------
+
+/// Get the npm global `node_modules` path via `npm root -g`.
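+///
+/// A typical result (machine-dependent) is, e.g., `/usr/local/lib/node_modules`
+/// or `~/.nvm/versions/node/v20.1.0/lib/node_modules`.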
+pub fn get_npm_global_prefix() -> Result<String, String> {
+    let output = Command::new("npm")
+        .args(["root", "-g"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .map_err(|e| format!("Failed to run `npm root -g`: {e}"))?;
+
+    if !output.status.success() {
+        return Err(
+            "Failed to determine npm global prefix. Ensure npm is installed and in PATH."
+                .to_string(),
+        );
+    }
+
+    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
+}
+
+/// Get the yarn global `node_modules` path via `yarn global dir`.
+pub fn get_yarn_global_prefix() -> Option<String> {
+    let output = Command::new("yarn")
+        .args(["global", "dir"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .ok()?;
+
+    if !output.status.success() {
+        return None;
+    }
+
+    let dir = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if dir.is_empty() {
+        return None;
+    }
+    Some(PathBuf::from(dir).join("node_modules").to_string_lossy().to_string())
+}
+
+/// Get the pnpm global `node_modules` path via `pnpm root -g`.
+pub fn get_pnpm_global_prefix() -> Option<String> {
+    let output = Command::new("pnpm")
+        .args(["root", "-g"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .ok()?;
+
+    if !output.status.success() {
+        return None;
+    }
+
+    let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if path.is_empty() {
+        return None;
+    }
+    Some(path)
+}
+
+/// Get the bun global `node_modules` path via `bun pm bin -g`.
+pub fn get_bun_global_prefix() -> Option<String> {
+    let output = Command::new("bun")
+        .args(["pm", "bin", "-g"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .ok()?;
+
+    if !output.status.success() {
+        return None;
+    }
+
+    let bin_path = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if bin_path.is_empty() {
+        return None;
+    }
+
+    let bun_root = PathBuf::from(&bin_path);
+    let bun_root = bun_root.parent()?;
+    Some(
+        bun_root
+            .join("install")
+            .join("global")
+            .join("node_modules")
+            .to_string_lossy()
+            .to_string(),
+    )
+}
+
+// ---------------------------------------------------------------------------
+// Helpers: synchronous wildcard directory resolver
+// ---------------------------------------------------------------------------
+
+/// Resolve a path with `"*"` wildcard segments synchronously.
+///
+/// Each segment is either a literal directory name or `"*"` which matches any
+/// directory entry. Symlinks are followed via `std::fs::metadata`.
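+///
+/// For example, `find_node_dirs_sync(nvm_root, &["*", "lib", "node_modules"])`
+/// expands to every `<version>/lib/node_modules` directory under the nvm root.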
+fn find_node_dirs_sync(base: &Path, segments: &[&str]) -> Vec<PathBuf> {
+    if !base.is_dir() {
+        return Vec::new();
+    }
+    if segments.is_empty() {
+        return vec![base.to_path_buf()];
+    }
+
+    let first = segments[0];
+    let rest = &segments[1..];
+
+    if first == "*" {
+        let mut results = Vec::new();
+        if let Ok(entries) = std::fs::read_dir(base) {
+            for entry in entries.flatten() {
+                // Follow symlinks: stat the full path with std::fs::metadata(),
+                // since DirEntry::metadata() does not traverse symlinks.
+                let is_dir = std::fs::metadata(entry.path())
+                    .map(|m| m.is_dir())
+                    .unwrap_or(false);
+                if is_dir {
+                    results.extend(find_node_dirs_sync(&base.join(entry.file_name()), rest));
+                }
+            }
+        }
+        results
+    } else {
+        find_node_dirs_sync(&base.join(first), rest)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// NpmCrawler
+// ---------------------------------------------------------------------------
+
+/// NPM ecosystem crawler for discovering packages in `node_modules`.
+pub struct NpmCrawler;
+
+impl NpmCrawler {
+    /// Create a new `NpmCrawler`.
+    pub fn new() -> Self {
+        Self
+    }
+
+    // ------------------------------------------------------------------
+    // Public API
+    // ------------------------------------------------------------------
+
+    /// Get `node_modules` paths based on options.
+    ///
+    /// In global mode returns well-known global paths; in local mode walks
+    /// the project tree looking for `node_modules` directories (including
+    /// workspace packages).
+    pub async fn get_node_modules_paths(&self, options: &CrawlerOptions) -> Result<Vec<PathBuf>, std::io::Error> {
+        if options.global || options.global_prefix.is_some() {
+            if let Some(ref custom) = options.global_prefix {
+                return Ok(vec![custom.clone()]);
+            }
+            return Ok(self.get_global_node_modules_paths());
+        }
+
+        Ok(self.find_local_node_modules_dirs(&options.cwd).await)
+    }
+
+    /// Crawl all discovered `node_modules` and return every package found.
+    pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec<CrawledPackage> {
+        let mut packages = Vec::new();
+        let mut seen = HashSet::new();
+
+        let nm_paths = self.get_node_modules_paths(options).await.unwrap_or_default();
+
+        for nm_path in &nm_paths {
+            let found = self.scan_node_modules(nm_path, &mut seen).await;
+            packages.extend(found);
+        }
+
+        packages
+    }
+
+    /// Find specific packages by PURL inside a single `node_modules` tree.
+    ///
+    /// This is an efficient O(n) lookup where n = number of PURLs: we parse
+    /// each PURL to derive the expected directory path, then do a direct stat
+    /// + `package.json` read.
+    pub async fn find_by_purls(
+        &self,
+        node_modules_path: &Path,
+        purls: &[String],
+    ) -> Result<HashMap<String, CrawledPackage>, std::io::Error> {
+        let mut result: HashMap<String, CrawledPackage> = HashMap::new();
+
+        // Parse each PURL to extract the directory key and expected version.
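+        // e.g. "pkg:npm/@types/node@20.0.0" -> dir_key "@types/node", version "20.0.0";
+        //      "pkg:npm/lodash@4.17.21"     -> dir_key "lodash",      version "4.17.21".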
+        struct Target {
+            namespace: Option<String>,
+            name: String,
+            version: String,
+            #[allow(dead_code)]
+            purl: String,
+            dir_key: String,
+        }
+
+        let purl_set: HashSet<&str> = purls.iter().map(|s| s.as_str()).collect();
+        let mut targets: Vec<Target> = Vec::new();
+
+        for purl in purls {
+            if let Some((ns, name, version)) = Self::parse_purl_components(purl) {
+                let dir_key = match &ns {
+                    Some(ns_str) => format!("{ns_str}/{name}"),
+                    None => name.clone(),
+                };
+                targets.push(Target {
+                    namespace: ns,
+                    name,
+                    version,
+                    purl: purl.clone(),
+                    dir_key,
+                });
+            }
+        }
+
+        for target in &targets {
+            let pkg_path = node_modules_path.join(&target.dir_key);
+            let pkg_json_path = pkg_path.join("package.json");
+
+            if let Some((_, version)) = read_package_json(&pkg_json_path).await {
+                if version == target.version {
+                    let purl = build_npm_purl(
+                        target.namespace.as_deref(),
+                        &target.name,
+                        &version,
+                    );
+                    if purl_set.contains(purl.as_str()) {
+                        result.insert(
+                            purl.clone(),
+                            CrawledPackage {
+                                name: target.name.clone(),
+                                version,
+                                namespace: target.namespace.clone(),
+                                purl,
+                                path: pkg_path.clone(),
+                            },
+                        );
+                    }
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers – global paths
+    // ------------------------------------------------------------------
+
+    /// Collect global `node_modules` paths from all known package managers.
+    fn get_global_node_modules_paths(&self) -> Vec<PathBuf> {
+        let mut seen = HashSet::new();
+        let mut paths = Vec::new();
+
+        let mut add = |p: PathBuf| {
+            if p.is_dir() && seen.insert(p.clone()) {
+                paths.push(p);
+            }
+        };
+
+        if let Ok(npm_path) = get_npm_global_prefix() {
+            add(PathBuf::from(npm_path));
+        }
+        if let Some(pnpm_path) = get_pnpm_global_prefix() {
+            add(PathBuf::from(pnpm_path));
+        }
+        if let Some(yarn_path) = get_yarn_global_prefix() {
+            add(PathBuf::from(yarn_path));
+        }
+        if let Some(bun_path) = get_bun_global_prefix() {
+            add(PathBuf::from(bun_path));
+        }
+
+        // macOS-specific fallback paths
+        if cfg!(target_os = "macos") {
+            let home = std::env::var("HOME").unwrap_or_default();
+
+            // Homebrew Apple Silicon
+            add(PathBuf::from("/opt/homebrew/lib/node_modules"));
+            // Homebrew Intel / default npm
+            add(PathBuf::from("/usr/local/lib/node_modules"));
+
+            if !home.is_empty() {
+                // nvm
+                for p in find_node_dirs_sync(
+                    &PathBuf::from(&home).join(".nvm/versions/node"),
+                    &["*", "lib", "node_modules"],
+                ) {
+                    add(p);
+                }
+                // volta
+                for p in find_node_dirs_sync(
+                    &PathBuf::from(&home).join(".volta/tools/image/node"),
+                    &["*", "lib", "node_modules"],
+                ) {
+                    add(p);
+                }
+                // fnm
+                for p in find_node_dirs_sync(
+                    &PathBuf::from(&home).join(".fnm/node-versions"),
+                    &["*", "installation", "lib", "node_modules"],
+                ) {
+                    add(p);
+                }
+            }
+        }
+
+        paths
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers – local node_modules discovery
+    // ------------------------------------------------------------------
+
+    /// Find `node_modules` directories within the project root.
+    /// Recursively searches for workspace `node_modules` but stays within the
+    /// project.
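+    ///
+    /// For example, in a workspace layout with `repo/node_modules` plus
+    /// `repo/packages/app/node_modules`, both directories are returned.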
+    async fn find_local_node_modules_dirs(&self, start_path: &Path) -> Vec<PathBuf> {
+        let mut results = Vec::new();
+
+        // Direct node_modules in start_path
+        let direct = start_path.join("node_modules");
+        if is_dir(&direct).await {
+            results.push(direct);
+        }
+
+        // Recursively search for workspace node_modules
+        Self::find_workspace_node_modules(start_path, &mut results).await;
+
+        results
+    }
+
+    /// Recursively find `node_modules` in subdirectories (for monorepos / workspaces).
+    /// Skips symlinks, hidden dirs, and well-known non-workspace dirs.
+    fn find_workspace_node_modules<'a>(
+        dir: &'a Path,
+        results: &'a mut Vec<PathBuf>,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'a>> {
+        Box::pin(async move {
+            let mut entries = match tokio::fs::read_dir(dir).await {
+                Ok(rd) => rd,
+                Err(_) => return,
+            };
+
+            let mut entry_list = Vec::new();
+            while let Ok(Some(entry)) = entries.next_entry().await {
+                entry_list.push(entry);
+            }
+
+            for entry in entry_list {
+                let file_type = match entry.file_type().await {
+                    Ok(ft) => ft,
+                    Err(_) => continue,
+                };
+
+                if !file_type.is_dir() {
+                    continue;
+                }
+
+                let name = entry.file_name();
+                let name_str = name.to_string_lossy();
+
+                // Skip node_modules, hidden dirs, and well-known build dirs
+                if name_str == "node_modules"
+                    || name_str.starts_with('.')
+                    || SKIP_DIRS.contains(&name_str.as_ref())
+                {
+                    continue;
+                }
+
+                let full_path = dir.join(&name);
+
+                // Check if this subdirectory has its own node_modules
+                let sub_nm = full_path.join("node_modules");
+                if is_dir(&sub_nm).await {
+                    results.push(sub_nm);
+                }
+
+                // Recurse
+                Self::find_workspace_node_modules(&full_path, results).await;
+            }
+        })
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers – scanning
+    // ------------------------------------------------------------------
+
+    /// Scan a `node_modules` directory, returning all valid packages found.
+    async fn scan_node_modules(
+        &self,
+        node_modules_path: &Path,
+        seen: &mut HashSet<String>,
+    ) -> Vec<CrawledPackage> {
+        let mut results = Vec::new();
+
+        let mut entries = match tokio::fs::read_dir(node_modules_path).await {
+            Ok(rd) => rd,
+            Err(_) => return results,
+        };
+
+        let mut entry_list = Vec::new();
+        while let Ok(Some(entry)) = entries.next_entry().await {
+            entry_list.push(entry);
+        }
+
+        for entry in entry_list {
+            let name = entry.file_name();
+            let name_str = name.to_string_lossy().to_string();
+
+            // Skip hidden files and node_modules
+            if name_str.starts_with('.') || name_str == "node_modules" {
+                continue;
+            }
+
+            let file_type = match entry.file_type().await {
+                Ok(ft) => ft,
+                Err(_) => continue,
+            };
+
+            // Allow both directories and symlinks (pnpm uses symlinks)
+            if !file_type.is_dir() && !file_type.is_symlink() {
+                continue;
+            }
+
+            let entry_path = node_modules_path.join(&name_str);
+
+            if name_str.starts_with('@') {
+                // Scoped packages
+                let scoped =
+                    Self::scan_scoped_packages(&entry_path, seen).await;
+                results.extend(scoped);
+            } else {
+                // Regular package
+                if let Some(pkg) = Self::check_package(&entry_path, seen).await {
+                    results.push(pkg);
+                }
+                // Nested node_modules only for real directories (not symlinks)
+                if file_type.is_dir() {
+                    let nested =
+                        Self::scan_nested_node_modules(&entry_path, seen).await;
+                    results.extend(nested);
+                }
+            }
+        }
+
+        results
+    }
+
+    /// Scan a scoped packages directory (`@scope/`).
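+    /// e.g. `node_modules/@types/`, whose children (such as `node`) are
+    /// themselves package directories.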
+    fn scan_scoped_packages<'a>(
+        scope_path: &'a Path,
+        seen: &'a mut HashSet<String>,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Vec<CrawledPackage>> + 'a>> {
+        Box::pin(async move {
+            let mut results = Vec::new();
+
+            let mut entries = match tokio::fs::read_dir(scope_path).await {
+                Ok(rd) => rd,
+                Err(_) => return results,
+            };
+
+            let mut entry_list = Vec::new();
+            while let Ok(Some(entry)) = entries.next_entry().await {
+                entry_list.push(entry);
+            }
+
+            for entry in entry_list {
+                let name = entry.file_name();
+                let name_str = name.to_string_lossy().to_string();
+
+                if name_str.starts_with('.') {
+                    continue;
+                }
+
+                let file_type = match entry.file_type().await {
+                    Ok(ft) => ft,
+                    Err(_) => continue,
+                };
+
+                if !file_type.is_dir() && !file_type.is_symlink() {
+                    continue;
+                }
+
+                let pkg_path = scope_path.join(&name_str);
+                if let Some(pkg) = Self::check_package(&pkg_path, seen).await {
+                    results.push(pkg);
+                }
+
+                // Nested node_modules only for real directories
+                if file_type.is_dir() {
+                    let nested =
+                        Self::scan_nested_node_modules(&pkg_path, seen).await;
+                    results.extend(nested);
+                }
+            }
+
+            results
+        })
+    }
+
+    /// Scan nested `node_modules` inside a package (if it exists).
+    fn scan_nested_node_modules<'a>(
+        pkg_path: &'a Path,
+        seen: &'a mut HashSet<String>,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Vec<CrawledPackage>> + 'a>> {
+        Box::pin(async move {
+            let nested_nm = pkg_path.join("node_modules");
+
+            let mut entries = match tokio::fs::read_dir(&nested_nm).await {
+                Ok(rd) => rd,
+                Err(_) => return Vec::new(),
+            };
+
+            let mut results = Vec::new();
+
+            let mut entry_list = Vec::new();
+            while let Ok(Some(entry)) = entries.next_entry().await {
+                entry_list.push(entry);
+            }
+
+            for entry in entry_list {
+                let name = entry.file_name();
+                let name_str = name.to_string_lossy().to_string();
+
+                if name_str.starts_with('.') || name_str == "node_modules" {
+                    continue;
+                }
+
+                let file_type = match entry.file_type().await {
+                    Ok(ft) => ft,
+                    Err(_) => continue,
+                };
+
+                if !file_type.is_dir() && !file_type.is_symlink() {
+                    continue;
+                }
+
+                let entry_path = nested_nm.join(&name_str);
+
+                if name_str.starts_with('@') {
+                    let scoped =
+                        Self::scan_scoped_packages(&entry_path, seen).await;
+                    results.extend(scoped);
+                } else {
+                    if let Some(pkg) = Self::check_package(&entry_path, seen).await {
+                        results.push(pkg);
+                    }
+                    // Recursively check deeper nested node_modules
+                    let deeper =
+                        Self::scan_nested_node_modules(&entry_path, seen).await;
+                    results.extend(deeper);
+                }
+            }
+
+            results
+        })
+    }
+
+    /// Check a package directory and return `CrawledPackage` if valid.
+    /// Deduplicates by PURL via the `seen` set.
+    async fn check_package(
+        pkg_path: &Path,
+        seen: &mut HashSet<String>,
+    ) -> Option<CrawledPackage> {
+        let pkg_json_path = pkg_path.join("package.json");
+        let (full_name, version) = read_package_json(&pkg_json_path).await?;
+        let (namespace, name) = parse_package_name(&full_name);
+        let purl = build_npm_purl(namespace.as_deref(), &name, &version);
+
+        if seen.contains(&purl) {
+            return None;
+        }
+        seen.insert(purl.clone());
+
+        Some(CrawledPackage {
+            name,
+            version,
+            namespace,
+            purl,
+            path: pkg_path.to_path_buf(),
+        })
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers – PURL parsing
+    // ------------------------------------------------------------------
+
+    /// Parse a PURL string to extract namespace, name, and version.
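+    ///
+    /// e.g. `pkg:npm/@types/node@20.0.0` -> `(Some("@types"), "node", "20.0.0")`,
+    /// `pkg:npm/lodash@4.17.21` -> `(None, "lodash", "4.17.21")`.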
+    fn parse_purl_components(purl: &str) -> Option<(Option<String>, String, String)> {
+        // Strip qualifiers
+        let base = match purl.find('?') {
+            Some(idx) => &purl[..idx],
+            None => purl,
+        };
+
+        let rest = base.strip_prefix("pkg:npm/")?;
+        let at_idx = rest.rfind('@')?;
+        let name_part = &rest[..at_idx];
+        let version = &rest[at_idx + 1..];
+
+        if name_part.is_empty() || version.is_empty() {
+            return None;
+        }
+
+        if name_part.starts_with('@') {
+            let slash_idx = name_part.find('/')?;
+            let namespace = name_part[..slash_idx].to_string();
+            let name = name_part[slash_idx + 1..].to_string();
+            if name.is_empty() {
+                return None;
+            }
+            Some((Some(namespace), name, version.to_string()))
+        } else {
+            Some((None, name_part.to_string(), version.to_string()))
+        }
+    }
+}
+
+impl Default for NpmCrawler {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Utility
+// ---------------------------------------------------------------------------
+
+/// Check whether a path is a directory (follows symlinks).
+async fn is_dir(path: &Path) -> bool {
+    tokio::fs::metadata(path)
+        .await
+        .map(|m| m.is_dir())
+        .unwrap_or(false)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_package_name_scoped() {
+        let (ns, name) = parse_package_name("@types/node");
+        assert_eq!(ns.as_deref(), Some("@types"));
+        assert_eq!(name, "node");
+    }
+
+    #[test]
+    fn test_parse_package_name_unscoped() {
+        let (ns, name) = parse_package_name("lodash");
+        assert!(ns.is_none());
+        assert_eq!(name, "lodash");
+    }
+
+    #[test]
+    fn test_build_npm_purl_scoped() {
+        assert_eq!(
+            build_npm_purl(Some("@types"), "node", "20.0.0"),
+            "pkg:npm/@types/node@20.0.0"
+        );
+    }
+
+    #[test]
+    fn test_build_npm_purl_unscoped() {
+        assert_eq!(
+            build_npm_purl(None, "lodash", "4.17.21"),
+            "pkg:npm/lodash@4.17.21"
+        );
+    }
+
+    #[test]
+    fn test_parse_purl_components_scoped() {
+        let (ns, name, ver) =
+            NpmCrawler::parse_purl_components("pkg:npm/@types/node@20.0.0").unwrap();
+        assert_eq!(ns.as_deref(), Some("@types"));
+        assert_eq!(name, "node");
+        assert_eq!(ver, "20.0.0");
+    }
+
+    #[test]
+    fn test_parse_purl_components_unscoped() {
+        let (ns, name, ver) =
+            NpmCrawler::parse_purl_components("pkg:npm/lodash@4.17.21").unwrap();
+        assert!(ns.is_none());
+        assert_eq!(name, "lodash");
+        assert_eq!(ver, "4.17.21");
+    }
+
+    #[test]
+    fn test_parse_purl_components_invalid() {
+        assert!(NpmCrawler::parse_purl_components("pkg:pypi/requests@2.0").is_none());
+        assert!(NpmCrawler::parse_purl_components("not-a-purl").is_none());
+    }
+
+    #[tokio::test]
+    async fn test_read_package_json_valid() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg_json = dir.path().join("package.json");
+        tokio::fs::write(
+            &pkg_json,
+            r#"{"name": "test-pkg", "version": "1.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        let result = read_package_json(&pkg_json).await;
+        assert!(result.is_some());
+        let (name, version) = result.unwrap();
+        assert_eq!(name, "test-pkg");
+        assert_eq!(version, "1.0.0");
+    }
+
+    #[tokio::test]
+    async fn test_read_package_json_missing() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg_json = dir.path().join("package.json");
+        assert!(read_package_json(&pkg_json).await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_read_package_json_invalid() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg_json = dir.path().join("package.json");
+        tokio::fs::write(&pkg_json, "not json").await.unwrap();
+        assert!(read_package_json(&pkg_json).await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_crawl_all_basic() {
+        let dir = tempfile::tempdir().unwrap();
+        let nm = dir.path().join("node_modules");
+        let pkg_dir = nm.join("foo");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+        tokio::fs::write(
+            pkg_dir.join("package.json"),
+            r#"{"name": "foo", "version": "1.2.3"}"#,
+        )
+        .await
+        .unwrap();
+
+        let crawler = NpmCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: None,
+            batch_size: DEFAULT_BATCH_SIZE,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 1);
+        assert_eq!(packages[0].name, "foo");
+        assert_eq!(packages[0].version, "1.2.3");
+        assert_eq!(packages[0].purl, "pkg:npm/foo@1.2.3");
+        assert!(packages[0].namespace.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_crawl_all_scoped() {
+        let dir = tempfile::tempdir().unwrap();
+        let nm = dir.path().join("node_modules");
+        let scope_dir = nm.join("@types").join("node");
+        tokio::fs::create_dir_all(&scope_dir).await.unwrap();
+        tokio::fs::write(
+            scope_dir.join("package.json"),
+            r#"{"name": "@types/node", "version": "20.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        let crawler = NpmCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: None,
+            batch_size: DEFAULT_BATCH_SIZE,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 1);
+        assert_eq!(packages[0].name, "node");
+        assert_eq!(packages[0].namespace.as_deref(), Some("@types"));
+        assert_eq!(packages[0].purl, "pkg:npm/@types/node@20.0.0");
+    }
+
+    #[test]
+    fn test_find_node_dirs_sync_wildcard() {
+        // Create an nvm-like layout: base/v18.0.0/lib/node_modules
+        let dir = tempfile::tempdir().unwrap();
+        let nm1 = dir.path().join("v18.0.0/lib/node_modules");
+        let nm2 = dir.path().join("v20.1.0/lib/node_modules");
+        std::fs::create_dir_all(&nm1).unwrap();
+        std::fs::create_dir_all(&nm2).unwrap();
+
+        let results = find_node_dirs_sync(dir.path(), &["*", "lib", "node_modules"]);
+        assert_eq!(results.len(), 2);
+        assert!(results.contains(&nm1));
+        assert!(results.contains(&nm2));
+    }
+
+    #[test]
+    fn test_find_node_dirs_sync_empty() {
+        // Non-existent base path should return empty
+        let results = find_node_dirs_sync(Path::new("/nonexistent/path/xyz"), &["*", "lib"]);
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_find_node_dirs_sync_literal() {
+        // All literal segments (no wildcard)
+        let dir = tempfile::tempdir().unwrap();
+        let target = dir.path().join("lib/node_modules");
+        std::fs::create_dir_all(&target).unwrap();
+
+        let results = find_node_dirs_sync(dir.path(), &["lib", "node_modules"]);
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0], target);
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_macos_get_global_node_modules_paths_no_panic() {
+        let crawler = NpmCrawler::new();
+        // Should not panic, even if no package managers are installed
+        let _paths = crawler.get_global_node_modules_paths();
+    }
+
+    #[tokio::test]
+    async fn test_find_by_purls() {
+        let dir = tempfile::tempdir().unwrap();
+        let nm = dir.path().join("node_modules");
+
+        // Create foo@1.0.0
+        let foo_dir = nm.join("foo");
+        tokio::fs::create_dir_all(&foo_dir).await.unwrap();
+        tokio::fs::write(
+            foo_dir.join("package.json"),
+            r#"{"name": "foo", "version": "1.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        // Create @types/node@20.0.0
+        let types_dir = nm.join("@types").join("node");
+        tokio::fs::create_dir_all(&types_dir).await.unwrap();
+        tokio::fs::write(
+            types_dir.join("package.json"),
+            r#"{"name": "@types/node", "version": "20.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        let crawler = NpmCrawler::new();
+        let purls = vec![
+            "pkg:npm/foo@1.0.0".to_string(),
+            "pkg:npm/@types/node@20.0.0".to_string(),
+            "pkg:npm/not-installed@0.0.1".to_string(),
+        ];
+
+        let result = crawler.find_by_purls(&nm, &purls).await.unwrap();
+
+        assert_eq!(result.len(), 2);
+        assert!(result.contains_key("pkg:npm/foo@1.0.0"));
+        assert!(result.contains_key("pkg:npm/@types/node@20.0.0"));
+        assert!(!result.contains_key("pkg:npm/not-installed@0.0.1"));
+    }
+}
diff --git a/crates/socket-patch-core/src/crawlers/nuget_crawler.rs b/crates/socket-patch-core/src/crawlers/nuget_crawler.rs
new file mode 100644
index 0000000..4932243
--- /dev/null
+++ b/crates/socket-patch-core/src/crawlers/nuget_crawler.rs
@@ -0,0 +1,802 @@
+use std::collections::{HashMap, HashSet};
+use std::path::{Path, PathBuf};
+
+use super::types::{CrawledPackage, CrawlerOptions};
+
+/// NuGet/.NET ecosystem crawler for discovering packages in global cache,
+/// legacy `packages/` folders, and `obj/` restore layouts.
+pub struct NuGetCrawler;
+
+impl NuGetCrawler {
+    /// Create a new `NuGetCrawler`.
+    pub fn new() -> Self {
+        Self
+    }
+
+    // ------------------------------------------------------------------
+    // Public API
+    // ------------------------------------------------------------------
+
+    /// Get NuGet package paths based on options.
+    ///
+    /// In global mode, returns the global NuGet packages folder
+    /// (`NUGET_PACKAGES` env var or `~/.nuget/packages/`).
+    ///
+    /// In local mode (in priority order):
+    /// 1. `<cwd>/packages/` folder (legacy packages.config layout)
+    /// 2. Global cache — but only if cwd contains a .NET project file
+    /// 3. Paths discovered from `obj/project.assets.json`
+    pub async fn get_nuget_package_paths(
+        &self,
+        options: &CrawlerOptions,
+    ) -> Result<Vec<PathBuf>, std::io::Error> {
+        if options.global || options.global_prefix.is_some() {
+            if let Some(ref custom) = options.global_prefix {
+                return Ok(vec![custom.clone()]);
+            }
+            let home = nuget_home();
+            if is_dir(&home).await {
+                return Ok(vec![home]);
+            }
+            return Ok(Vec::new());
+        }
+
+        let mut paths = Vec::new();
+        let mut seen = HashSet::new();
+
+        // 1. Check <cwd>/packages/ (legacy packages.config layout)
+        let packages_dir = options.cwd.join("packages");
+        if is_dir(&packages_dir).await && seen.insert(packages_dir.clone()) {
+            paths.push(packages_dir);
+        }
+
+        // 2. Fall back to global cache if this looks like a .NET project
+        if is_dotnet_project(&options.cwd).await {
+            let home = nuget_home();
+            if is_dir(&home).await && seen.insert(home.clone()) {
+                paths.push(home);
+            }
+        }
+
+        // 3. Check obj/ dirs for project.assets.json
+        let obj_paths = discover_paths_from_assets(&options.cwd).await;
+        for p in obj_paths {
+            if is_dir(&p).await && seen.insert(p.clone()) {
+                paths.push(p);
+            }
+        }
+
+        Ok(paths)
+    }
+
+    /// Crawl all discovered package paths and return every package found.
+    pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec<CrawledPackage> {
+        let mut packages = Vec::new();
+        let mut seen = HashSet::new();
+
+        let pkg_paths = self.get_nuget_package_paths(options).await.unwrap_or_default();
+
+        for pkg_path in &pkg_paths {
+            let found = self.scan_package_dir(pkg_path, &mut seen).await;
+            packages.extend(found);
+        }
+
+        packages
+    }
+
+    /// Find specific packages by PURL inside a single package directory.
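+    ///
+    /// For example, `pkg:nuget/Newtonsoft.Json@13.0.3` is checked at
+    /// `newtonsoft.json/13.0.3/` (global cache layout) and at
+    /// `Newtonsoft.Json.13.0.3/` (legacy layout).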
+    pub async fn find_by_purls(
+        &self,
+        pkg_path: &Path,
+        purls: &[String],
+    ) -> Result<HashMap<String, CrawledPackage>, std::io::Error> {
+        let mut result: HashMap<String, CrawledPackage> = HashMap::new();
+
+        for purl in purls {
+            if let Some((name, version)) = crate::utils::purl::parse_nuget_purl(purl) {
+                // Try global cache layout: <lowercase-name>/<version>/
+                let global_dir = pkg_path.join(name.to_lowercase()).join(version);
+                if self.verify_nuget_package(&global_dir).await {
+                    result.insert(
+                        purl.clone(),
+                        CrawledPackage {
+                            name: name.to_string(),
+                            version: version.to_string(),
+                            namespace: None,
+                            purl: purl.clone(),
+                            path: global_dir,
+                        },
+                    );
+                    continue;
+                }
+
+                // Try legacy layout: <Name>.<Version>/
+                let legacy_dir = pkg_path.join(format!("{name}.{version}"));
+                if self.verify_nuget_package(&legacy_dir).await {
+                    result.insert(
+                        purl.clone(),
+                        CrawledPackage {
+                            name: name.to_string(),
+                            version: version.to_string(),
+                            namespace: None,
+                            purl: purl.clone(),
+                            path: legacy_dir,
+                        },
+                    );
+                    continue;
+                }
+
+                // Try case-insensitive legacy scan (NuGet names are case-insensitive)
+                if let Some(found_dir) = self
+                    .find_legacy_dir_case_insensitive(pkg_path, name, version)
+                    .await
+                {
+                    result.insert(
+                        purl.clone(),
+                        CrawledPackage {
+                            name: name.to_string(),
+                            version: version.to_string(),
+                            namespace: None,
+                            purl: purl.clone(),
+                            path: found_dir,
+                        },
+                    );
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers
+    // ------------------------------------------------------------------
+
+    /// Scan a package directory and return all valid NuGet packages found.
+    ///
+    /// Handles both layouts:
+    /// - Global cache: `<name>/<version>/` with `.nuspec` inside
+    /// - Legacy packages/: `<Name>.<Version>/` with `.nuspec` inside
+    async fn scan_package_dir(
+        &self,
+        pkg_path: &Path,
+        seen: &mut HashSet<String>,
+    ) -> Vec<CrawledPackage> {
+        let mut results = Vec::new();
+
+        let mut entries = match tokio::fs::read_dir(pkg_path).await {
+            Ok(rd) => rd,
+            Err(_) => return results,
+        };
+
+        let mut entry_list = Vec::new();
+        while let Ok(Some(entry)) = entries.next_entry().await {
+            entry_list.push(entry);
+        }
+
+        for entry in entry_list {
+            let ft = match entry.file_type().await {
+                Ok(ft) => ft,
+                Err(_) => continue,
+            };
+            if !ft.is_dir() {
+                continue;
+            }
+
+            let dir_name = entry.file_name();
+            let dir_name_str = dir_name.to_string_lossy();
+
+            // Skip hidden directories
+            if dir_name_str.starts_with('.') {
+                continue;
+            }
+
+            let entry_path = pkg_path.join(&*dir_name_str);
+
+            // Try global cache layout: this directory is a package name,
+            // containing version subdirectories
+            if let Some(pkgs) = self
+                .scan_global_cache_package(&entry_path, &dir_name_str, seen)
+                .await
+            {
+                results.extend(pkgs);
+                continue;
+            }
+
+            // Try legacy layout: <Name>.<Version>/ directory
+            if let Some((name, version)) = parse_legacy_dir_name(&dir_name_str) {
+                if self.verify_nuget_package(&entry_path).await {
+                    let purl = crate::utils::purl::build_nuget_purl(&name, &version);
+                    if !seen.contains(&purl) {
+                        seen.insert(purl.clone());
+                        results.push(CrawledPackage {
+                            name,
+                            version,
+                            namespace: None,
+                            purl,
+                            path: entry_path,
+                        });
+                    }
+                }
+            }
+        }
+
+        results
+    }
+
+    /// Scan a global cache package directory (`<name>/`) for version subdirectories.
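+    /// e.g. `newtonsoft.json/` containing one subdirectory per cached
+    /// version, such as `13.0.3/`.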
+    async fn scan_global_cache_package(
+        &self,
+        name_dir: &Path,
+        name: &str,
+        seen: &mut HashSet<String>,
+    ) -> Option<Vec<CrawledPackage>> {
+        let mut version_entries = match tokio::fs::read_dir(name_dir).await {
+            Ok(rd) => rd,
+            Err(_) => return None,
+        };
+
+        let mut found_any = false;
+        let mut results = Vec::new();
+
+        while let Ok(Some(ver_entry)) = version_entries.next_entry().await {
+            let ft = match ver_entry.file_type().await {
+                Ok(ft) => ft,
+                Err(_) => continue,
+            };
+            if !ft.is_dir() {
+                continue;
+            }
+
+            let ver_name = ver_entry.file_name();
+            let ver_str = ver_name.to_string_lossy();
+            let ver_path = name_dir.join(&*ver_str);
+
+            if self.verify_nuget_package(&ver_path).await {
+                found_any = true;
+                let purl = crate::utils::purl::build_nuget_purl(name, &ver_str);
+                if !seen.contains(&purl) {
+                    seen.insert(purl.clone());
+                    results.push(CrawledPackage {
+                        name: name.to_string(),
+                        version: ver_str.to_string(),
+                        namespace: None,
+                        purl,
+                        path: ver_path,
+                    });
+                }
+            }
+        }
+
+        if found_any {
+            Some(results)
+        } else {
+            None
+        }
+    }
+
+    /// Verify that a directory looks like an installed NuGet package.
+    /// Checks for a `.nuspec` file or a `lib/` directory.
+    async fn verify_nuget_package(&self, path: &Path) -> bool {
+        if !is_dir(path).await {
+            return false;
+        }
+
+        // Check for lib/ directory
+        if is_dir(&path.join("lib")).await {
+            return true;
+        }
+
+        // Check for any .nuspec file
+        find_nuspec_in_dir(path).await.is_some()
+    }
+
+    /// Find a legacy package directory with case-insensitive matching.
+    async fn find_legacy_dir_case_insensitive(
+        &self,
+        pkg_path: &Path,
+        name: &str,
+        version: &str,
+    ) -> Option<PathBuf> {
+        let target = format!("{}.{}", name.to_lowercase(), version.to_lowercase());
+
+        let mut entries = tokio::fs::read_dir(pkg_path).await.ok()?;
+        while let Ok(Some(entry)) = entries.next_entry().await {
+            let dir_name = entry.file_name();
+            let dir_name_str = dir_name.to_string_lossy();
+            if dir_name_str.to_lowercase() == target {
+                let path = pkg_path.join(&*dir_name_str);
+                if self.verify_nuget_package(&path).await {
+                    return Some(path);
+                }
+            }
+        }
+
+        None
+    }
+}
+
+impl Default for NuGetCrawler {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Get the NuGet global packages folder.
+///
+/// Checks `NUGET_PACKAGES` env var, falls back to `~/.nuget/packages/`.
+fn nuget_home() -> PathBuf {
+    if let Ok(custom) = std::env::var("NUGET_PACKAGES") {
+        return PathBuf::from(custom);
+    }
+
+    let home = std::env::var("HOME")
+        .or_else(|_| std::env::var("USERPROFILE"))
+        .unwrap_or_else(|_| "~".to_string());
+    PathBuf::from(home).join(".nuget").join("packages")
+}
+
+/// Check if the cwd contains any .NET project indicators.
+async fn is_dotnet_project(cwd: &Path) -> bool {
+    let extensions = [".csproj", ".fsproj", ".vbproj", ".sln"];
+
+    let mut entries = match tokio::fs::read_dir(cwd).await {
+        Ok(rd) => rd,
+        Err(_) => return false,
+    };
+
+    while let Ok(Some(entry)) = entries.next_entry().await {
+        if let Some(name) = entry.file_name().to_str() {
+            for ext in &extensions {
+                if name.ends_with(ext) {
+                    return true;
+                }
+            }
+            if name == "NuGet.Config" || name == "nuget.config" {
+                return true;
+            }
+        }
+    }
+
+    false
+}
+
+/// Parse a legacy packages directory name into (name, version).
+///
+/// Legacy NuGet directories follow the pattern `<Name>.<Version>`, where
+/// the version starts at the first `.` followed by a digit-starting segment.
+fn parse_legacy_dir_name(dir_name: &str) -> Option<(String, String)> {
+    // Find the first '.' followed by a digit
+    let mut split_idx = None;
+    for (i, _) in dir_name.match_indices('.') {
+        if i + 1 < dir_name.len() && dir_name[i + 1..].starts_with(|c: char| c.is_ascii_digit()) {
+            split_idx = Some(i);
+            break;
+        }
+    }
+    let idx = split_idx?;
+    let name = &dir_name[..idx];
+    let version = &dir_name[idx + 1..];
+    if name.is_empty() || version.is_empty() {
+        return None;
+    }
+    Some((name.to_string(), version.to_string()))
+}
+
+/// Find a `.nuspec` file in a directory.
+async fn find_nuspec_in_dir(dir: &Path) -> Option<PathBuf> {
+    let mut entries = tokio::fs::read_dir(dir).await.ok()?;
+    while let Ok(Some(entry)) = entries.next_entry().await {
+        if let Some(name) = entry.file_name().to_str() {
+            if name.ends_with(".nuspec") {
+                return Some(dir.join(name));
+            }
+        }
+    }
+    None
+}
+
+/// Parse `<id>` and `<version>` from `.nuspec` XML content.
+///
+/// Uses simple string matching — the nuspec format typically has these
+/// elements on separate lines.
+pub fn parse_nuspec_id_version(content: &str) -> Option<(String, String)> {
+    let mut id = None;
+    let mut version = None;
+
+    for line in content.lines() {
+        let trimmed = line.trim();
+
+        if id.is_none() {
+            if let Some(value) = extract_xml_element(trimmed, "id") {
+                id = Some(value);
+            }
+        }
+
+        if version.is_none() {
+            if let Some(value) = extract_xml_element(trimmed, "version") {
+                version = Some(value);
+            }
+        }
+
+        if id.is_some() && version.is_some() {
+            break;
+        }
+    }
+
+    match (id, version) {
+        (Some(id), Some(version)) if !id.is_empty() && !version.is_empty() => {
+            Some((id, version))
+        }
+        _ => None,
+    }
+}
+
+/// Extract the text content of a simple XML element like `<tag>value</tag>`.
+fn extract_xml_element(line: &str, tag: &str) -> Option<String> {
+    let open = format!("<{tag}>");
+    let close = format!("</{tag}>");
+
+    let start = line.find(&open)?;
+    let after_open = start + open.len();
+    let end = line[after_open..].find(&close)?;
+    let value = &line[after_open..after_open + end];
+    let value = value.trim();
+    if value.is_empty() {
+        None
+    } else {
+        Some(value.to_string())
+    }
+}
+
+/// Discover additional package paths from `obj/project.assets.json` files.
+async fn discover_paths_from_assets(cwd: &Path) -> Vec<PathBuf> {
+    let mut paths = Vec::new();
+
+    // Look for obj/project.assets.json in cwd
+    let assets_path = cwd.join("obj").join("project.assets.json");
+    if let Some(pkg_folders) = parse_project_assets_package_folders(&assets_path).await {
+        for folder in pkg_folders {
+            paths.push(folder);
+        }
+    }
+
+    // Also check subdirectories one level deep for multi-project solutions
+    let mut entries = match tokio::fs::read_dir(cwd).await {
+        Ok(rd) => rd,
+        Err(_) => return paths,
+    };
+
+    while let Ok(Some(entry)) = entries.next_entry().await {
+        let ft = match entry.file_type().await {
+            Ok(ft) => ft,
+            Err(_) => continue,
+        };
+        if !ft.is_dir() {
+            continue;
+        }
+        let sub_assets = cwd.join(entry.file_name()).join("obj").join("project.assets.json");
+        if let Some(pkg_folders) = parse_project_assets_package_folders(&sub_assets).await {
+            for folder in pkg_folders {
+                paths.push(folder);
+            }
+        }
+    }
+
+    paths
+}
+
+/// Parse `project.assets.json` to extract the `packageFolders` keys.
+///
+/// The file is a JSON object with a `packageFolders` key containing
+/// folder paths as keys, e.g.: `{"packageFolders": {"/home/user/.nuget/packages/": {}}}`.
+async fn parse_project_assets_package_folders(path: &Path) -> Option<Vec<PathBuf>> {
+    let content = tokio::fs::read_to_string(path).await.ok()?;
+    let json: serde_json::Value = serde_json::from_str(&content).ok()?;
+    let folders = json.get("packageFolders")?.as_object()?;
+
+    let result: Vec<PathBuf> = folders.keys().map(PathBuf::from).collect();
+
+    if result.is_empty() {
+        None
+    } else {
+        Some(result)
+    }
+}
+
+/// Check whether a path is a directory.
+async fn is_dir(path: &Path) -> bool {
+    tokio::fs::metadata(path)
+        .await
+        .map(|m| m.is_dir())
+        .unwrap_or(false)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_legacy_dir_name() {
+        assert_eq!(
+            parse_legacy_dir_name("Newtonsoft.Json.13.0.3"),
+            Some(("Newtonsoft.Json".to_string(), "13.0.3".to_string()))
+        );
+        assert_eq!(
+            parse_legacy_dir_name("System.Text.Json.8.0.0"),
+            Some(("System.Text.Json".to_string(), "8.0.0".to_string()))
+        );
+        assert_eq!(
+            parse_legacy_dir_name("Microsoft.Extensions.Logging.8.0.0"),
+            Some((
+                "Microsoft.Extensions.Logging".to_string(),
+                "8.0.0".to_string()
+            ))
+        );
+        assert_eq!(
+            parse_legacy_dir_name("xunit.2.6.2"),
+            Some(("xunit".to_string(), "2.6.2".to_string()))
+        );
+        assert!(parse_legacy_dir_name("no-version-here").is_none());
+        assert!(parse_legacy_dir_name("justtext").is_none());
+    }
+
+    #[test]
+    fn test_parse_nuspec_id_version() {
+        let nuspec = r#"<?xml version="1.0"?>
+<package>
+    <metadata>
+        <id>Newtonsoft.Json</id>
+        <version>13.0.3</version>
+        <authors>James Newton-King</authors>
+    </metadata>
+</package>"#;
+        assert_eq!(
+            parse_nuspec_id_version(nuspec),
+            Some(("Newtonsoft.Json".to_string(), "13.0.3".to_string()))
+        );
+    }
+
+    #[test]
+    fn test_parse_nuspec_empty() {
+        assert!(parse_nuspec_id_version("").is_none());
+        assert!(parse_nuspec_id_version("<package><metadata></metadata></package>").is_none());
+    }
+
+    #[test]
+    fn test_extract_xml_element() {
+        assert_eq!(
+            extract_xml_element("    <id>Newtonsoft.Json</id>", "id"),
+            Some("Newtonsoft.Json".to_string())
+        );
+        assert_eq!(
+            extract_xml_element("    <version>13.0.3</version>", "version"),
+            Some("13.0.3".to_string())
+        );
+        assert_eq!(extract_xml_element("<id></id>", "id"), None);
+        assert_eq!(extract_xml_element("no tags here", "id"), None);
+    }
+
+    #[tokio::test]
+    async fn test_find_by_purls_global_cache_layout() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create global cache layout: <lowercase-name>/<version>/
+        let pkg_dir = dir.path().join("newtonsoft.json").join("13.0.3");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+        tokio::fs::write(
+            pkg_dir.join("newtonsoft.json.nuspec"),
+            r#"<package><metadata><id>Newtonsoft.Json</id><version>13.0.3</version></metadata></package>"#,
+        )
+        .await
+        .unwrap();
+
+        let crawler = NuGetCrawler::new();
+        let purls = vec![
+            "pkg:nuget/Newtonsoft.Json@13.0.3".to_string(),
+            "pkg:nuget/System.Text.Json@8.0.0".to_string(),
+        ];
+        let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap();
+
+        assert_eq!(result.len(), 1);
+        assert!(result.contains_key("pkg:nuget/Newtonsoft.Json@13.0.3"));
+        assert!(!result.contains_key("pkg:nuget/System.Text.Json@8.0.0"));
+    }
+
+    #[tokio::test]
+    async fn test_find_by_purls_legacy_layout() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create legacy layout: <Name>.<Version>/
+        let pkg_dir = dir.path().join("Newtonsoft.Json.13.0.3");
+        tokio::fs::create_dir_all(pkg_dir.join("lib")).await.unwrap();
+
+        let crawler = NuGetCrawler::new();
+        let purls = vec!["pkg:nuget/Newtonsoft.Json@13.0.3".to_string()];
+        let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap();
+
+        assert_eq!(result.len(), 1);
+        assert!(result.contains_key("pkg:nuget/Newtonsoft.Json@13.0.3"));
+    }
+
+    #[tokio::test]
+    async fn test_crawl_all_global_cache() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create global cache layout
+        let nj_dir = dir.path().join("newtonsoft.json").join("13.0.3");
+        tokio::fs::create_dir_all(nj_dir.join("lib")).await.unwrap();
+
+        let stj_dir = dir.path().join("system.text.json").join("8.0.0");
+        tokio::fs::create_dir_all(&stj_dir).await.unwrap();
+        tokio::fs::write(
+            stj_dir.join("system.text.json.nuspec"),
+            "<id>System.Text.Json</id><version>8.0.0</version>",
+        )
+        .await
+        .unwrap();
+
+        let crawler = NuGetCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: Some(dir.path().to_path_buf()),
+            batch_size: 100,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 2);
+
+        let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect();
+        assert!(purls.contains("pkg:nuget/newtonsoft.json@13.0.3"));
+        assert!(purls.contains("pkg:nuget/system.text.json@8.0.0"));
+    }
+
+    #[tokio::test]
+    async fn test_crawl_all_legacy_packages() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create legacy layout
+        let nj_dir = dir.path().join("Newtonsoft.Json.13.0.3");
+        tokio::fs::create_dir_all(nj_dir.join("lib")).await.unwrap();
+
+        let xunit_dir = dir.path().join("xunit.2.6.2");
+        tokio::fs::create_dir_all(&xunit_dir).await.unwrap();
+        tokio::fs::write(
+            xunit_dir.join("xunit.nuspec"),
+            "<id>xunit</id><version>2.6.2</version>",
+        )
+        .await
+        .unwrap();
+
+        let crawler = NuGetCrawler::new();
+        let options = CrawlerOptions {
+            cwd: dir.path().to_path_buf(),
+            global: false,
+            global_prefix: Some(dir.path().to_path_buf()),
+            batch_size: 100,
+        };
+
+        let packages = crawler.crawl_all(&options).await;
+        assert_eq!(packages.len(), 2);
+
+        let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect();
+        assert!(purls.contains("pkg:nuget/Newtonsoft.Json@13.0.3"));
+        assert!(purls.contains("pkg:nuget/xunit@2.6.2"));
+    }
+
+    #[tokio::test]
+    async fn test_is_dotnet_project() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // No .NET files — should return false
+        assert!(!super::is_dotnet_project(dir.path()).await);
+
+        // Add a .csproj file
+        tokio::fs::write(dir.path().join("MyApp.csproj"), "")
+            .await
+            .unwrap();
+        assert!(super::is_dotnet_project(dir.path()).await);
+    }
+
+    #[tokio::test]
+    async fn test_is_dotnet_project_sln() {
+        let dir = tempfile::tempdir().unwrap();
+        tokio::fs::write(dir.path().join("MySolution.sln"), "")
+            .await
+            .unwrap();
+        assert!(super::is_dotnet_project(dir.path()).await);
+    }
+
+    #[tokio::test]
+    async fn test_verify_nuget_package_with_nuspec() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg_dir = dir.path().join("testpkg");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+        tokio::fs::write(pkg_dir.join("test.nuspec"), "")
+            .await
+            .unwrap();
+
+        let crawler = NuGetCrawler::new();
+        assert!(crawler.verify_nuget_package(&pkg_dir).await);
+    }
+
+    #[tokio::test]
+    async fn test_verify_nuget_package_with_lib() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg_dir = dir.path().join("testpkg");
+        tokio::fs::create_dir_all(pkg_dir.join("lib")).await.unwrap();
+
+        let crawler = NuGetCrawler::new();
+        assert!(crawler.verify_nuget_package(&pkg_dir).await);
+    }
+
+    #[tokio::test]
+    async fn test_verify_nuget_package_empty_dir() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg_dir = dir.path().join("testpkg");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+
+        let crawler = NuGetCrawler::new();
+        assert!(!crawler.verify_nuget_package(&pkg_dir).await);
+    }
+
+    #[tokio::test]
+    async fn test_deduplication() {
+        let dir = tempfile::tempdir().unwrap();
+
+        // Create a single package
package + let pkg_dir = dir.path().join("newtonsoft.json").join("13.0.3"); + tokio::fs::create_dir_all(pkg_dir.join("lib")).await.unwrap(); + + let crawler = NuGetCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].purl, "pkg:nuget/newtonsoft.json@13.0.3"); + } + + #[tokio::test] + async fn test_project_assets_discovery() { + let dir = tempfile::tempdir().unwrap(); + + // Create obj/project.assets.json + let obj_dir = dir.path().join("obj"); + tokio::fs::create_dir_all(&obj_dir).await.unwrap(); + + let pkg_folder = dir.path().join("custom-packages"); + tokio::fs::create_dir_all(&pkg_folder).await.unwrap(); + + let assets_content = serde_json::json!({ + "packageFolders": { + pkg_folder.to_string_lossy().to_string(): {} + } + }); + tokio::fs::write( + obj_dir.join("project.assets.json"), + serde_json::to_string(&assets_content).unwrap(), + ) + .await + .unwrap(); + + let paths = discover_paths_from_assets(dir.path()).await; + assert_eq!(paths.len(), 1); + assert_eq!(paths[0], pkg_folder); + } + + #[tokio::test] + async fn test_nuget_home_env_var() { + // Test that NUGET_PACKAGES env var is respected + let custom = "/tmp/test-nuget-packages"; + std::env::set_var("NUGET_PACKAGES", custom); + let home = nuget_home(); + assert_eq!(home, PathBuf::from(custom)); + std::env::remove_var("NUGET_PACKAGES"); + } +} diff --git a/crates/socket-patch-core/src/crawlers/python_crawler.rs b/crates/socket-patch-core/src/crawlers/python_crawler.rs new file mode 100644 index 0000000..55fcfdd --- /dev/null +++ b/crates/socket-patch-core/src/crawlers/python_crawler.rs @@ -0,0 +1,875 @@ +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; + +use super::types::{CrawledPackage, CrawlerOptions}; + +// --------------------------------------------------------------------------- +// Python command discovery +// --------------------------------------------------------------------------- + +/// Find a working Python command on the system. +/// +/// Tries `python3`, `python`, and `py` (Windows launcher) in order, +/// returning the first one that responds to `--version`. +pub fn find_python_command() -> Option<&'static str> { + ["python3", "python", "py"].into_iter().find(|cmd| { + Command::new(cmd) + .args(["--version"]) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .is_ok() + }) +} + +/// Default batch size for crawling. +const _DEFAULT_BATCH_SIZE: usize = 100; + +// --------------------------------------------------------------------------- +// PEP 503 name canonicalization +// --------------------------------------------------------------------------- + +/// Canonicalize a Python package name per PEP 503. +/// +/// Lowercases, trims, and replaces runs of `[-_.]` with a single `-`. +pub fn canonicalize_pypi_name(name: &str) -> String { + let trimmed = name.trim().to_lowercase(); + let mut result = String::with_capacity(trimmed.len()); + let mut in_separator_run = false; + + for ch in trimmed.chars() { + if ch == '-' || ch == '_' || ch == '.' 
{
+            if !in_separator_run {
+                result.push('-');
+                in_separator_run = true;
+            }
+            // else: skip consecutive separators
+        } else {
+            in_separator_run = false;
+            result.push(ch);
+        }
+    }
+
+    result
+}
+
+// ---------------------------------------------------------------------------
+// Helpers: read Python metadata from dist-info
+// ---------------------------------------------------------------------------
+
+/// Read `Name` and `Version` from a `.dist-info/METADATA` file.
+pub async fn read_python_metadata(dist_info_path: &Path) -> Option<(String, String)> {
+    let metadata_path = dist_info_path.join("METADATA");
+    let content = tokio::fs::read_to_string(&metadata_path).await.ok()?;
+
+    let mut name: Option<String> = None;
+    let mut version: Option<String> = None;
+
+    for line in content.lines() {
+        if name.is_some() && version.is_some() {
+            break;
+        }
+        if let Some(rest) = line.strip_prefix("Name:") {
+            name = Some(rest.trim().to_string());
+        } else if let Some(rest) = line.strip_prefix("Version:") {
+            version = Some(rest.trim().to_string());
+        }
+        // Stop at first empty line (end of headers)
+        if line.trim().is_empty() && (name.is_some() || version.is_some()) {
+            break;
+        }
+    }
+
+    match (name, version) {
+        (Some(n), Some(v)) if !n.is_empty() && !v.is_empty() => Some((n, v)),
+        _ => None,
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Helpers: find Python directories with wildcard matching
+// ---------------------------------------------------------------------------
+
+/// Find directories matching a path pattern with wildcard segments.
+///
+/// Supported wildcards:
+/// - `"python3.*"` — matches directory entries starting with `python3.`
+/// - `"*"` — matches any directory entry
+///
+/// All other segments are treated as literal path components.
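+///
+/// Illustrative call (a sketch; the `/usr` base path is just an example host
+/// layout, not something this crate guarantees exists):
+///
+/// ```ignore
+/// let dirs = find_python_dirs(Path::new("/usr"), &["lib", "python3.*", "site-packages"]).await;
+/// // e.g. ["/usr/lib/python3.11/site-packages", "/usr/lib/python3.12/site-packages"]
+/// ```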
+pub async fn find_python_dirs(base_path: &Path, segments: &[&str]) -> Vec<PathBuf> {
+    let mut results = Vec::new();
+
+    // Check that base_path is a directory
+    match tokio::fs::metadata(base_path).await {
+        Ok(m) if m.is_dir() => {}
+        _ => return results,
+    }
+
+    if segments.is_empty() {
+        results.push(base_path.to_path_buf());
+        return results;
+    }
+
+    let first = segments[0];
+    let rest = &segments[1..];
+
+    if first == "python3.*" {
+        // Wildcard: list directory and match python3.X entries
+        if let Ok(mut entries) = tokio::fs::read_dir(base_path).await {
+            while let Ok(Some(entry)) = entries.next_entry().await {
+                let ft = match entry.file_type().await {
+                    Ok(ft) => ft,
+                    Err(_) => continue,
+                };
+                if !ft.is_dir() {
+                    continue;
+                }
+                let name = entry.file_name();
+                let name_str = name.to_string_lossy();
+                if name_str.starts_with("python3.") {
+                    let sub = Box::pin(find_python_dirs(
+                        &base_path.join(entry.file_name()),
+                        rest,
+                    ))
+                    .await;
+                    results.extend(sub);
+                }
+            }
+        }
+    } else if first == "*" {
+        // Generic wildcard: match any directory entry
+        if let Ok(mut entries) = tokio::fs::read_dir(base_path).await {
+            while let Ok(Some(entry)) = entries.next_entry().await {
+                let ft = match entry.file_type().await {
+                    Ok(ft) => ft,
+                    Err(_) => continue,
+                };
+                if !ft.is_dir() {
+                    continue;
+                }
+                let sub = Box::pin(find_python_dirs(
+                    &base_path.join(entry.file_name()),
+                    rest,
+                ))
+                .await;
+                results.extend(sub);
+            }
+        }
+    } else {
+        // Literal segment: just check if it exists
+        let sub =
+            Box::pin(find_python_dirs(&base_path.join(first), rest)).await;
+        results.extend(sub);
+    }
+
+    results
+}
+
+// ---------------------------------------------------------------------------
+// Helpers: site-packages discovery
+// ---------------------------------------------------------------------------
+
+/// Find `site-packages` (or `dist-packages`) directories under a base dir.
+///
+/// Handles both Windows (`Lib/site-packages`) and Unix
+/// (`lib/python3.X/site-packages`) layouts.
+pub async fn find_site_packages_under(
+    base_dir: &Path,
+    sub_dir_type: &str, // "site-packages" or "dist-packages"
+) -> Vec<PathBuf> {
+    if cfg!(windows) {
+        find_python_dirs(base_dir, &["Lib", sub_dir_type]).await
+    } else {
+        find_python_dirs(base_dir, &["lib", "python3.*", sub_dir_type]).await
+    }
+}
+
+/// Find local virtual environment `site-packages` directories.
+///
+/// Checks (in order):
+/// 1. `VIRTUAL_ENV` environment variable
+/// 2. `.venv` directory in `cwd`
+/// 3. `venv` directory in `cwd`
+pub async fn find_local_venv_site_packages(cwd: &Path) -> Vec<PathBuf> {
+    let mut results = Vec::new();
+
+    // 1. Check VIRTUAL_ENV env var
+    if let Ok(virtual_env) = std::env::var("VIRTUAL_ENV") {
+        let venv_path = PathBuf::from(&virtual_env);
+        let matches = find_site_packages_under(&venv_path, "site-packages").await;
+        results.extend(matches);
+        if !results.is_empty() {
+            return results;
+        }
+    }
+
+    // 2. Check .venv and venv in cwd
+    for venv_dir in &[".venv", "venv"] {
+        let venv_path = cwd.join(venv_dir);
+        let matches = find_site_packages_under(&venv_path, "site-packages").await;
+        results.extend(matches);
+    }
+
+    results
+}
+
+/// Get global/system Python `site-packages` directories.
+///
+/// Queries `python3` for site-packages paths, then checks well-known system
+/// locations including Homebrew, conda, uv tools, pip --user, etc.
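+///
+/// Illustrative use (the result set is entirely machine-dependent, so this is
+/// a sketch only):
+///
+/// ```ignore
+/// let site_packages = get_global_python_site_packages().await;
+/// for sp in &site_packages {
+///     println!("found site-packages at {}", sp.display());
+/// }
+/// ```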
+pub async fn get_global_python_site_packages() -> Vec<PathBuf> {
+    let mut results = Vec::new();
+    let mut seen = HashSet::new();
+
+    let add_path = |p: PathBuf, seen: &mut HashSet<PathBuf>, results: &mut Vec<PathBuf>| {
+        let resolved = if p.is_absolute() {
+            p
+        } else {
+            std::path::absolute(&p).unwrap_or(p)
+        };
+        if seen.insert(resolved.clone()) {
+            results.push(resolved);
+        }
+    };
+
+    // 1. Ask Python for site-packages
+    if let Some(python_cmd) = find_python_command() {
+        if let Ok(output) = Command::new(python_cmd)
+            .args([
+                "-c",
+                "import site; print('\\n'.join(site.getsitepackages())); print(site.getusersitepackages())",
+            ])
+            .stdin(Stdio::null())
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            .output()
+        {
+            if output.status.success() {
+                let stdout = String::from_utf8_lossy(&output.stdout);
+                for line in stdout.lines() {
+                    let p = line.trim();
+                    if !p.is_empty() {
+                        add_path(PathBuf::from(p), &mut seen, &mut results);
+                    }
+                }
+            }
+        }
+    }
+
+    // 2. Well-known system paths
+    let home_dir = std::env::var("HOME")
+        .or_else(|_| std::env::var("USERPROFILE"))
+        .unwrap_or_else(|_| "~".to_string());
+
+    // Helper to scan base/lib/python3.*/[dist|site]-packages
+    async fn scan_well_known(
+        base: &Path,
+        pkg_type: &str,
+        seen: &mut HashSet<PathBuf>,
+        results: &mut Vec<PathBuf>,
+    ) {
+        let matches = find_python_dirs(base, &["lib", "python3.*", pkg_type]).await;
+        for m in matches {
+            let resolved = if m.is_absolute() {
+                m
+            } else {
+                std::path::absolute(&m).unwrap_or(m)
+            };
+            if seen.insert(resolved.clone()) {
+                results.push(resolved);
+            }
+        }
+    }
+
+    if !cfg!(windows) {
+        // Debian/Ubuntu
+        scan_well_known(Path::new("/usr"), "dist-packages", &mut seen, &mut results).await;
+        scan_well_known(Path::new("/usr"), "site-packages", &mut seen, &mut results).await;
+        // Debian pip / most distros / macOS
+        scan_well_known(
+            Path::new("/usr/local"),
+            "dist-packages",
+            &mut seen,
+            &mut results,
+        )
+        .await;
+        scan_well_known(
+            Path::new("/usr/local"),
+            "site-packages",
+            &mut seen,
+            &mut results,
+        )
+        .await;
+        // pip --user on Unix
+        let user_local = PathBuf::from(&home_dir).join(".local");
+        scan_well_known(&user_local, "site-packages", &mut seen, &mut results).await;
+    }
+
+    // macOS-specific
+    if cfg!(target_os = "macos") {
+        scan_well_known(
+            Path::new("/opt/homebrew"),
+            "site-packages",
+            &mut seen,
+            &mut results,
+        )
+        .await;
+
+        // Python.org framework
+        let fw_matches = find_python_dirs(
+            Path::new("/Library/Frameworks/Python.framework/Versions"),
+            &["python3.*", "lib", "python3.*", "site-packages"],
+        )
+        .await;
+        for m in fw_matches {
+            add_path(m, &mut seen, &mut results);
+        }
+
+        let fw_matches2 = find_python_dirs(
+            Path::new("/Library/Frameworks/Python.framework"),
+            &["Versions", "*", "lib", "python3.*", "site-packages"],
+        )
+        .await;
+        for m in fw_matches2 {
+            add_path(m, &mut seen, &mut results);
+        }
+    }
+
+    // Windows-specific
+    if cfg!(windows) {
+        // pip --user on Windows: %APPDATA%\Python\PythonXY\site-packages
+        if let Ok(appdata) = std::env::var("APPDATA") {
+            let appdata_python = PathBuf::from(&appdata).join("Python");
+            if let Ok(mut entries) = tokio::fs::read_dir(&appdata_python).await {
+                while let Ok(Some(entry)) = entries.next_entry().await {
+                    let p = appdata_python.join(entry.file_name()).join("site-packages");
+                    if tokio::fs::metadata(&p).await.is_ok() {
+                        add_path(p, &mut seen, &mut results);
+                    }
+                }
+            }
+        }
+        // Common Windows Python install locations
+        for base in &["C:\\Python", "C:\\Program Files\\Python"] {
+            if let Ok(mut entries) = tokio::fs::read_dir(base).await {
+                while let Ok(Some(entry)) = entries.next_entry().await {
+                    let sp = PathBuf::from(base)
+                        .join(entry.file_name())
+                        .join("Lib")
+                        .join("site-packages");
+                    if tokio::fs::metadata(&sp).await.is_ok() {
+                        add_path(sp, &mut seen, &mut results);
+                    }
+                }
+            }
+        }
+        // Microsoft Store / python.org via LocalAppData
+        if let Ok(local) = std::env::var("LOCALAPPDATA") {
+            let programs_python = PathBuf::from(&local).join("Programs").join("Python");
+            if let Ok(mut entries) = tokio::fs::read_dir(&programs_python).await {
+                while let Ok(Some(entry)) = entries.next_entry().await {
+                    let sp = programs_python
+                        .join(entry.file_name())
+                        .join("Lib")
+                        .join("site-packages");
+                    if tokio::fs::metadata(&sp).await.is_ok() {
+                        add_path(sp, &mut seen, &mut results);
+                    }
+                }
+            }
+        }
+    }
+
+    // pyenv (works on macOS and Linux)
+    if !cfg!(windows) {
+        let pyenv_root = std::env::var("PYENV_ROOT")
+            .map(PathBuf::from)
+            .unwrap_or_else(|_| PathBuf::from(&home_dir).join(".pyenv"));
+        let pyenv_versions = pyenv_root.join("versions");
+        let pyenv_matches = find_python_dirs(
+            &pyenv_versions,
+            &["*", "lib", "python3.*", "site-packages"],
+        )
+        .await;
+        for m in pyenv_matches {
+            add_path(m, &mut seen, &mut results);
+        }
+    }
+
+    // Conda
+    let anaconda = PathBuf::from(&home_dir).join("anaconda3");
+    scan_well_known(&anaconda, "site-packages", &mut seen, &mut results).await;
+    let miniconda = PathBuf::from(&home_dir).join("miniconda3");
+    scan_well_known(&miniconda, "site-packages", &mut seen, &mut results).await;
+
+    // uv tools
+    if cfg!(target_os = "macos") {
+        let uv_base = PathBuf::from(&home_dir)
+            .join("Library")
+            .join("Application Support")
+            .join("uv")
+            .join("tools");
+        let uv_matches =
+            find_python_dirs(&uv_base, &["*", "lib", "python3.*", "site-packages"]).await;
+        for m in uv_matches {
+            add_path(m, &mut seen, &mut results);
+        }
+    } else if cfg!(windows) {
+        // %LOCALAPPDATA%\uv\tools
+        if let Ok(local) = std::env::var("LOCALAPPDATA") {
+            let uv_base = PathBuf::from(local).join("uv").join("tools");
+            let uv_matches =
+                find_python_dirs(&uv_base, &["*", "Lib", "site-packages"]).await;
+            for m in uv_matches {
+                add_path(m, &mut seen, &mut results);
+            }
+        }
+    } else {
+        let uv_base = PathBuf::from(&home_dir)
+            .join(".local")
+            .join("share")
+            .join("uv")
+            .join("tools");
+        let uv_matches =
+            find_python_dirs(&uv_base, &["*", "lib", "python3.*", "site-packages"]).await;
+        for m in uv_matches {
+            add_path(m, &mut seen, &mut results);
+        }
+    }
+
+    results
+}
+
+// ---------------------------------------------------------------------------
+// PythonCrawler
+// ---------------------------------------------------------------------------
+
+/// Python ecosystem crawler for discovering packages in `site-packages`.
+pub struct PythonCrawler;
+
+impl PythonCrawler {
+    /// Create a new `PythonCrawler`.
+    pub fn new() -> Self {
+        Self
+    }
+
+    /// Get `site-packages` paths based on options.
+    pub async fn get_site_packages_paths(
+        &self,
+        options: &CrawlerOptions,
+    ) -> Result<Vec<PathBuf>, std::io::Error> {
+        if options.global || options.global_prefix.is_some() {
+            if let Some(ref custom) = options.global_prefix {
+                return Ok(vec![custom.clone()]);
+            }
+            return Ok(get_global_python_site_packages().await);
+        }
+        Ok(find_local_venv_site_packages(&options.cwd).await)
+    }
+
+    /// Crawl all discovered `site-packages` and return every package found.
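+    ///
+    /// Deduplicates by PURL across all discovered paths. A usage sketch
+    /// (assuming default options rooted at the current directory):
+    ///
+    /// ```ignore
+    /// let crawler = PythonCrawler::new();
+    /// let packages = crawler.crawl_all(&CrawlerOptions::default()).await;
+    /// for pkg in &packages {
+    ///     println!("{} at {}", pkg.purl, pkg.path.display());
+    /// }
+    /// ```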
+    pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec<CrawledPackage> {
+        let mut packages = Vec::new();
+        let mut seen = HashSet::new();
+
+        let sp_paths = self.get_site_packages_paths(options).await.unwrap_or_default();
+
+        for sp_path in &sp_paths {
+            let found = self.scan_site_packages(sp_path, &mut seen).await;
+            packages.extend(found);
+        }
+
+        packages
+    }
+
+    /// Find specific packages by PURL.
+    ///
+    /// Accepts base PURLs (no qualifiers) — the caller should strip qualifiers
+    /// before calling.
+    pub async fn find_by_purls(
+        &self,
+        site_packages_path: &Path,
+        purls: &[String],
+    ) -> Result<HashMap<String, CrawledPackage>, std::io::Error> {
+        let mut result = HashMap::new();
+
+        // Build lookup: canonicalized-name@version -> purl
+        let mut purl_lookup: HashMap<String, &str> = HashMap::new();
+        for purl in purls {
+            if let Some((name, version)) = Self::parse_pypi_purl(purl) {
+                let key = format!("{}@{}", canonicalize_pypi_name(&name), version);
+                purl_lookup.insert(key, purl.as_str());
+            }
+        }
+
+        if purl_lookup.is_empty() {
+            return Ok(result);
+        }
+
+        // Scan all .dist-info dirs
+        let entries = match tokio::fs::read_dir(site_packages_path).await {
+            Ok(rd) => {
+                let mut entries = rd;
+                let mut v = Vec::new();
+                while let Ok(Some(entry)) = entries.next_entry().await {
+                    v.push(entry);
+                }
+                v
+            }
+            Err(_) => return Ok(result),
+        };
+
+        for entry in entries {
+            let name = entry.file_name();
+            let name_str = name.to_string_lossy();
+            if !name_str.ends_with(".dist-info") {
+                continue;
+            }
+
+            let dist_info_path = site_packages_path.join(&*name_str);
+            if let Some((raw_name, version)) = read_python_metadata(&dist_info_path).await {
+                let canon_name = canonicalize_pypi_name(&raw_name);
+                let key = format!("{canon_name}@{version}");
+
+                if let Some(&matched_purl) = purl_lookup.get(&key) {
+                    result.insert(
+                        matched_purl.to_string(),
+                        CrawledPackage {
+                            name: canon_name,
+                            version,
+                            namespace: None,
+                            purl: matched_purl.to_string(),
+                            path: site_packages_path.to_path_buf(),
+                        },
+                    );
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers
+    // ------------------------------------------------------------------
+
+    /// Scan a `site-packages` directory for `.dist-info` directories.
+    async fn scan_site_packages(
+        &self,
+        site_packages_path: &Path,
+        seen: &mut HashSet<String>,
+    ) -> Vec<CrawledPackage> {
+        let mut results = Vec::new();
+
+        let entries = match tokio::fs::read_dir(site_packages_path).await {
+            Ok(rd) => {
+                let mut entries = rd;
+                let mut v = Vec::new();
+                while let Ok(Some(entry)) = entries.next_entry().await {
+                    v.push(entry);
+                }
+                v
+            }
+            Err(_) => return results,
+        };
+
+        for entry in entries {
+            let name = entry.file_name();
+            let name_str = name.to_string_lossy();
+            if !name_str.ends_with(".dist-info") {
+                continue;
+            }
+
+            let dist_info_path = site_packages_path.join(&*name_str);
+            if let Some((raw_name, version)) = read_python_metadata(&dist_info_path).await {
+                let canon_name = canonicalize_pypi_name(&raw_name);
+                let purl = format!("pkg:pypi/{canon_name}@{version}");
+
+                if seen.contains(&purl) {
+                    continue;
+                }
+                seen.insert(purl.clone());
+
+                results.push(CrawledPackage {
+                    name: canon_name,
+                    version,
+                    namespace: None,
+                    purl,
+                    path: site_packages_path.to_path_buf(),
+                });
+            }
+        }
+
+        results
+    }
+
+    /// Parse a PyPI PURL string to extract name and version.
+    /// Strips qualifiers before parsing.
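+    ///
+    /// e.g. `pkg:pypi/requests@2.28.0?artifact_id=abc` parses to
+    /// `("requests", "2.28.0")`.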
+ fn parse_pypi_purl(purl: &str) -> Option<(String, String)> { + // Strip qualifiers + let base = match purl.find('?') { + Some(idx) => &purl[..idx], + None => purl, + }; + + let rest = base.strip_prefix("pkg:pypi/")?; + let at_idx = rest.rfind('@')?; + let name = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + + if name.is_empty() || version.is_empty() { + return None; + } + + Some((name.to_string(), version.to_string())) + } +} + +impl Default for PythonCrawler { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_canonicalize_pypi_name_basic() { + assert_eq!(canonicalize_pypi_name("Requests"), "requests"); + assert_eq!(canonicalize_pypi_name("my_package"), "my-package"); + assert_eq!(canonicalize_pypi_name("My.Package"), "my-package"); + assert_eq!(canonicalize_pypi_name("My-._Package"), "my-package"); + } + + #[test] + fn test_canonicalize_pypi_name_runs() { + // Runs of separators collapse to single - + assert_eq!(canonicalize_pypi_name("a__b"), "a-b"); + assert_eq!(canonicalize_pypi_name("a-.-b"), "a-b"); + assert_eq!(canonicalize_pypi_name("a_._-b"), "a-b"); + } + + #[test] + fn test_canonicalize_pypi_name_trim() { + assert_eq!(canonicalize_pypi_name(" requests "), "requests"); + } + + #[test] + fn test_parse_pypi_purl() { + let (name, ver) = PythonCrawler::parse_pypi_purl("pkg:pypi/requests@2.28.0").unwrap(); + assert_eq!(name, "requests"); + assert_eq!(ver, "2.28.0"); + } + + #[test] + fn test_parse_pypi_purl_with_qualifiers() { + let (name, ver) = + PythonCrawler::parse_pypi_purl("pkg:pypi/requests@2.28.0?artifact_id=abc").unwrap(); + assert_eq!(name, "requests"); + assert_eq!(ver, "2.28.0"); + } + + #[test] + fn test_parse_pypi_purl_invalid() { + assert!(PythonCrawler::parse_pypi_purl("pkg:npm/lodash@4.17.21").is_none()); + assert!(PythonCrawler::parse_pypi_purl("not-a-purl").is_none()); + } + + #[tokio::test] + async fn test_read_python_metadata_valid() { + let dir = tempfile::tempdir().unwrap(); + let dist_info = dir.path().join("requests-2.28.0.dist-info"); + tokio::fs::create_dir_all(&dist_info).await.unwrap(); + tokio::fs::write( + dist_info.join("METADATA"), + "Metadata-Version: 2.1\nName: Requests\nVersion: 2.28.0\n\nSome description", + ) + .await + .unwrap(); + + let result = read_python_metadata(&dist_info).await; + assert!(result.is_some()); + let (name, version) = result.unwrap(); + assert_eq!(name, "Requests"); + assert_eq!(version, "2.28.0"); + } + + #[tokio::test] + async fn test_read_python_metadata_missing() { + let dir = tempfile::tempdir().unwrap(); + let dist_info = dir.path().join("nonexistent.dist-info"); + assert!(read_python_metadata(&dist_info).await.is_none()); + } + + #[tokio::test] + async fn test_find_python_dirs_literal() { + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("lib").join("python3.11").join("site-packages"); + tokio::fs::create_dir_all(&target).await.unwrap(); + + let results = + find_python_dirs(dir.path(), &["lib", "python3.*", "site-packages"]).await; + assert_eq!(results.len(), 1); + assert_eq!(results[0], target); + } + + #[tokio::test] + async fn test_find_python_dirs_wildcard() { + let dir = tempfile::tempdir().unwrap(); + let sp1 = dir.path().join("lib").join("python3.10").join("site-packages"); + let sp2 = dir.path().join("lib").join("python3.11").join("site-packages"); + tokio::fs::create_dir_all(&sp1).await.unwrap(); + tokio::fs::create_dir_all(&sp2).await.unwrap(); + + // Also create a non-matching dir + let non_match = 
dir.path().join("lib").join("ruby3.0").join("site-packages"); + tokio::fs::create_dir_all(&non_match).await.unwrap(); + + let results = + find_python_dirs(dir.path(), &["lib", "python3.*", "site-packages"]).await; + assert_eq!(results.len(), 2); + } + + #[tokio::test] + async fn test_find_python_dirs_star_wildcard() { + let dir = tempfile::tempdir().unwrap(); + let sp1 = dir + .path() + .join("tools") + .join("mytool") + .join("lib") + .join("python3.11") + .join("site-packages"); + tokio::fs::create_dir_all(&sp1).await.unwrap(); + + let results = find_python_dirs( + dir.path(), + &["tools", "*", "lib", "python3.*", "site-packages"], + ) + .await; + assert_eq!(results.len(), 1); + assert_eq!(results[0], sp1); + } + + #[tokio::test] + async fn test_find_python_dirs_pyenv_layout() { + // Create a pyenv-like layout: versions/3.11.5/lib/python3.11/site-packages + let dir = tempfile::tempdir().unwrap(); + let sp1 = dir + .path() + .join("versions") + .join("3.11.5") + .join("lib") + .join("python3.11") + .join("site-packages"); + let sp2 = dir + .path() + .join("versions") + .join("3.12.0") + .join("lib") + .join("python3.12") + .join("site-packages"); + tokio::fs::create_dir_all(&sp1).await.unwrap(); + tokio::fs::create_dir_all(&sp2).await.unwrap(); + + let results = find_python_dirs( + &dir.path().join("versions"), + &["*", "lib", "python3.*", "site-packages"], + ) + .await; + assert_eq!(results.len(), 2); + assert!(results.contains(&sp1)); + assert!(results.contains(&sp2)); + } + + #[tokio::test] + async fn test_crawl_all_python() { + let dir = tempfile::tempdir().unwrap(); + let venv = dir.path().join(".venv"); + let sp = if cfg!(windows) { + venv.join("Lib").join("site-packages") + } else { + venv.join("lib").join("python3.11").join("site-packages") + }; + tokio::fs::create_dir_all(&sp).await.unwrap(); + + // Create a dist-info dir with METADATA + let dist_info = sp.join("requests-2.28.0.dist-info"); + tokio::fs::create_dir_all(&dist_info).await.unwrap(); + tokio::fs::write( + dist_info.join("METADATA"), + "Metadata-Version: 2.1\nName: Requests\nVersion: 2.28.0\n", + ) + .await + .unwrap(); + + let crawler = PythonCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: None, + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].name, "requests"); + assert_eq!(packages[0].version, "2.28.0"); + assert_eq!(packages[0].purl, "pkg:pypi/requests@2.28.0"); + assert!(packages[0].namespace.is_none()); + } + + #[test] + fn test_find_python_command() { + // On any platform with Python installed, this should return Some + // In CI environments, Python is typically available + let cmd = find_python_command(); + // We don't assert Some because Python may not be installed, + // but if it is, the command should be valid + if let Some(c) = cmd { + assert!( + ["python3", "python", "py"].contains(&c), + "unexpected command: {c}" + ); + } + } + + #[test] + fn test_home_dir_detection() { + // Verify the fallback chain works: HOME -> USERPROFILE -> "~" + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_else(|_| "~".to_string()); + // On any CI or dev machine, we should get a real path, not "~" + assert_ne!(home, "~", "expected a real home directory"); + assert!(!home.is_empty()); + } + + #[tokio::test] + async fn test_find_by_purls_python() { + let dir = tempfile::tempdir().unwrap(); + let sp = dir.path().to_path_buf(); + + // 
Create dist-info
+        let dist_info = sp.join("requests-2.28.0.dist-info");
+        tokio::fs::create_dir_all(&dist_info).await.unwrap();
+        tokio::fs::write(
+            dist_info.join("METADATA"),
+            "Metadata-Version: 2.1\nName: Requests\nVersion: 2.28.0\n",
+        )
+        .await
+        .unwrap();
+
+        let crawler = PythonCrawler::new();
+        let purls = vec![
+            "pkg:pypi/requests@2.28.0".to_string(),
+            "pkg:pypi/flask@3.0.0".to_string(),
+        ];
+
+        let result = crawler.find_by_purls(&sp, &purls).await.unwrap();
+        assert_eq!(result.len(), 1);
+        assert!(result.contains_key("pkg:pypi/requests@2.28.0"));
+        assert!(!result.contains_key("pkg:pypi/flask@3.0.0"));
+    }
+}
diff --git a/crates/socket-patch-core/src/crawlers/ruby_crawler.rs b/crates/socket-patch-core/src/crawlers/ruby_crawler.rs
new file mode 100644
index 0000000..893fde9
--- /dev/null
+++ b/crates/socket-patch-core/src/crawlers/ruby_crawler.rs
@@ -0,0 +1,517 @@
+use std::collections::{HashMap, HashSet};
+use std::path::{Path, PathBuf};
+
+use super::types::{CrawledPackage, CrawlerOptions};
+
+/// Ruby/RubyGems ecosystem crawler for discovering gems in Bundler vendor
+/// directories or global gem installation paths.
+pub struct RubyCrawler;
+
+impl RubyCrawler {
+    /// Create a new `RubyCrawler`.
+    pub fn new() -> Self {
+        Self
+    }
+
+    // ------------------------------------------------------------------
+    // Public API
+    // ------------------------------------------------------------------
+
+    /// Get gem installation paths based on options.
+    ///
+    /// In local mode, checks `vendor/bundle/ruby/*/gems/` first (Bundler
+    /// deployment layout), then falls back to querying `gem env gemdir`,
+    /// but only if `Gemfile` or `Gemfile.lock` exists in the cwd.
+    ///
+    /// In global mode, queries `gem env gemdir` and `gem env gempath`, plus
+    /// well-known fallback paths for rbenv, rvm, Homebrew, and system Ruby.
+    pub async fn get_gem_paths(
+        &self,
+        options: &CrawlerOptions,
+    ) -> Result<Vec<PathBuf>, std::io::Error> {
+        if options.global || options.global_prefix.is_some() {
+            if let Some(ref custom) = options.global_prefix {
+                return Ok(vec![custom.clone()]);
+            }
+            return Ok(Self::get_global_gem_paths().await);
+        }
+
+        // Local mode: check vendor/bundle first
+        let vendor_gems = Self::get_vendor_bundle_paths(&options.cwd).await;
+        if !vendor_gems.is_empty() {
+            return Ok(vendor_gems);
+        }
+
+        // Only fall back to global gem paths if this looks like a Ruby project
+        let has_gemfile = tokio::fs::metadata(options.cwd.join("Gemfile"))
+            .await
+            .is_ok();
+        let has_gemfile_lock = tokio::fs::metadata(options.cwd.join("Gemfile.lock"))
+            .await
+            .is_ok();
+
+        if has_gemfile || has_gemfile_lock {
+            // Try gem env gemdir
+            let mut paths = Vec::new();
+            if let Some(gemdir) = Self::run_gem_env("gemdir").await {
+                let gems_path = PathBuf::from(gemdir).join("gems");
+                if is_dir(&gems_path).await {
+                    paths.push(gems_path);
+                }
+            }
+            if !paths.is_empty() {
+                return Ok(paths);
+            }
+        }
+
+        // Not a Ruby project — return empty
+        Ok(Vec::new())
+    }
+
+    /// Crawl all discovered gem paths and return every package found.
+    pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec<CrawledPackage> {
+        let mut packages = Vec::new();
+        let mut seen = HashSet::new();
+
+        let gem_paths = self.get_gem_paths(options).await.unwrap_or_default();
+
+        for gem_path in &gem_paths {
+            let found = self.scan_gem_dir(gem_path, &mut seen).await;
+            packages.extend(found);
+        }
+
+        packages
+    }
+
+    /// Find specific packages by PURL inside a single gem directory.
+    ///
+    /// Gem directories follow the `<name>-<version>` pattern.
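+    ///
+    /// Illustrative lookup (the gems directory path here is hypothetical):
+    ///
+    /// ```ignore
+    /// let purls = vec!["pkg:gem/rails@7.1.0".to_string()];
+    /// let found = crawler
+    ///     .find_by_purls(Path::new("/usr/lib/ruby/gems/3.2.0/gems"), &purls)
+    ///     .await?;
+    /// ```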
+    pub async fn find_by_purls(
+        &self,
+        gem_path: &Path,
+        purls: &[String],
+    ) -> Result<HashMap<String, CrawledPackage>, std::io::Error> {
+        let mut result: HashMap<String, CrawledPackage> = HashMap::new();
+
+        for purl in purls {
+            if let Some((name, version)) = crate::utils::purl::parse_gem_purl(purl) {
+                let gem_dir = gem_path.join(format!("{name}-{version}"));
+                if self.verify_gem_at_path(&gem_dir).await {
+                    result.insert(
+                        purl.clone(),
+                        CrawledPackage {
+                            name: name.to_string(),
+                            version: version.to_string(),
+                            namespace: None,
+                            purl: purl.clone(),
+                            path: gem_dir,
+                        },
+                    );
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    // ------------------------------------------------------------------
+    // Private helpers
+    // ------------------------------------------------------------------
+
+    /// Find `vendor/bundle/ruby/*/gems/` directories.
+    async fn get_vendor_bundle_paths(cwd: &Path) -> Vec<PathBuf> {
+        let vendor_ruby = cwd.join("vendor").join("bundle").join("ruby");
+        let mut paths = Vec::new();
+
+        let mut entries = match tokio::fs::read_dir(&vendor_ruby).await {
+            Ok(rd) => rd,
+            Err(_) => return paths,
+        };
+
+        while let Ok(Some(entry)) = entries.next_entry().await {
+            let ft = match entry.file_type().await {
+                Ok(ft) => ft,
+                Err(_) => continue,
+            };
+            if ft.is_dir() {
+                let gems_dir = vendor_ruby.join(entry.file_name()).join("gems");
+                if is_dir(&gems_dir).await {
+                    paths.push(gems_dir);
+                }
+            }
+        }
+
+        paths
+    }
+
+    /// Get global gem paths by querying `gem env` and checking well-known locations.
+    async fn get_global_gem_paths() -> Vec<PathBuf> {
+        let mut paths = Vec::new();
+        let mut seen = HashSet::new();
+
+        // gem env gemdir
+        if let Some(gemdir) = Self::run_gem_env("gemdir").await {
+            let gems_path = PathBuf::from(gemdir).join("gems");
+            if is_dir(&gems_path).await && seen.insert(gems_path.clone()) {
+                paths.push(gems_path);
+            }
+        }
+
+        // gem env gempath (colon-separated)
+        if let Some(gempath) = Self::run_gem_env("gempath").await {
+            for segment in gempath.split(':') {
+                let segment = segment.trim();
+                if segment.is_empty() {
+                    continue;
+                }
+                let gems_path = PathBuf::from(segment).join("gems");
+                if is_dir(&gems_path).await && seen.insert(gems_path.clone()) {
+                    paths.push(gems_path);
+                }
+            }
+        }
+
+        // Fallback well-known paths
+        let home = std::env::var("HOME")
+            .or_else(|_| std::env::var("USERPROFILE"))
+            .unwrap_or_else(|_| "~".to_string());
+        let home = PathBuf::from(home);
+
+        let fallback_globs = [
+            home.join(".gem").join("ruby"),
+            home.join(".rbenv").join("versions"),
+            home.join(".rvm").join("gems"),
+        ];
+
+        for base in &fallback_globs {
+            if let Ok(mut entries) = tokio::fs::read_dir(base).await {
+                while let Ok(Some(entry)) = entries.next_entry().await {
+                    let ft = match entry.file_type().await {
+                        Ok(ft) => ft,
+                        Err(_) => continue,
+                    };
+                    if !ft.is_dir() {
+                        continue;
+                    }
+
+                    let entry_path = base.join(entry.file_name());
+
+                    // ~/.gem/ruby/*/gems/
+                    let gems_dir = entry_path.join("gems");
+                    if is_dir(&gems_dir).await && seen.insert(gems_dir.clone()) {
+                        paths.push(gems_dir);
+                        continue;
+                    }
+
+                    // ~/.rbenv/versions/*/lib/ruby/gems/*/gems/
+                    let lib_ruby_gems = entry_path.join("lib").join("ruby").join("gems");
+                    if let Ok(mut sub_entries) = tokio::fs::read_dir(&lib_ruby_gems).await {
+                        while let Ok(Some(sub_entry)) = sub_entries.next_entry().await {
+                            let gems_dir = lib_ruby_gems.join(sub_entry.file_name()).join("gems");
+                            if is_dir(&gems_dir).await && seen.insert(gems_dir.clone()) {
+                                paths.push(gems_dir);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        // System paths
+        let system_bases = [
+            PathBuf::from("/usr/lib/ruby/gems"),
PathBuf::from("/usr/local/lib/ruby/gems"), + PathBuf::from("/opt/homebrew/lib/ruby/gems"), + ]; + + for base in &system_bases { + if let Ok(mut entries) = tokio::fs::read_dir(base).await { + while let Ok(Some(entry)) = entries.next_entry().await { + let gems_dir = base.join(entry.file_name()).join("gems"); + if is_dir(&gems_dir).await && seen.insert(gems_dir.clone()) { + paths.push(gems_dir); + } + } + } + } + + paths + } + + /// Run `gem env ` and return the trimmed stdout. + async fn run_gem_env(key: &str) -> Option { + let output = std::process::Command::new("gem") + .args(["env", key]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + if stdout.is_empty() { + None + } else { + Some(stdout) + } + } + + /// Scan a gem directory and return all valid gem packages found. + async fn scan_gem_dir( + &self, + gem_path: &Path, + seen: &mut HashSet, + ) -> Vec { + let mut results = Vec::new(); + + let mut entries = match tokio::fs::read_dir(gem_path).await { + Ok(rd) => rd, + Err(_) => return results, + }; + + let mut entry_list = Vec::new(); + while let Ok(Some(entry)) = entries.next_entry().await { + entry_list.push(entry); + } + + for entry in entry_list { + let ft = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + if !ft.is_dir() { + continue; + } + + let dir_name = entry.file_name(); + let dir_name_str = dir_name.to_string_lossy(); + + // Skip hidden directories + if dir_name_str.starts_with('.') { + continue; + } + + let gem_dir = gem_path.join(&*dir_name_str); + + // Parse name-version from directory name + if let Some((name, version)) = Self::parse_dir_name_version(&dir_name_str) { + // Verify it looks like a gem (has .gemspec or lib/) + if !self.verify_gem_at_path(&gem_dir).await { + continue; + } + + let purl = crate::utils::purl::build_gem_purl(&name, &version); + + if seen.contains(&purl) { + continue; + } + seen.insert(purl.clone()); + + results.push(CrawledPackage { + name, + version, + namespace: None, + purl, + path: gem_dir, + }); + } + } + + results + } + + /// Verify that a directory looks like an installed gem. + /// Checks for a `.gemspec` file or a `lib/` directory. + async fn verify_gem_at_path(&self, path: &Path) -> bool { + if !is_dir(path).await { + return false; + } + + // Check for lib/ directory + if is_dir(&path.join("lib")).await { + return true; + } + + // Check for any .gemspec file + if let Ok(mut entries) = tokio::fs::read_dir(path).await { + while let Ok(Some(entry)) = entries.next_entry().await { + if let Some(name) = entry.file_name().to_str() { + if name.ends_with(".gemspec") { + return true; + } + } + } + } + + false + } + + /// Parse a gem directory name into (name, version). + /// + /// Gem directories follow the pattern `-`, where the + /// version is the last `-`-separated component that starts with a digit. + fn parse_dir_name_version(dir_name: &str) -> Option<(String, String)> { + // Find the last '-' followed by a digit + let mut split_idx = None; + for (i, _) in dir_name.match_indices('-') { + if dir_name[i + 1..].starts_with(|c: char| c.is_ascii_digit()) { + split_idx = Some(i); + } + } + let idx = split_idx?; + let name = &dir_name[..idx]; + let version = &dir_name[idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name.to_string(), version.to_string())) + } +} + +impl Default for RubyCrawler { + fn default() -> Self { + Self::new() + } +} + +/// Check whether a path is a directory. 
+async fn is_dir(path: &Path) -> bool { + tokio::fs::metadata(path) + .await + .map(|m| m.is_dir()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_gem_dir_name() { + assert_eq!( + RubyCrawler::parse_dir_name_version("rails-7.1.0"), + Some(("rails".to_string(), "7.1.0".to_string())) + ); + assert_eq!( + RubyCrawler::parse_dir_name_version("nokogiri-1.16.5"), + Some(("nokogiri".to_string(), "1.16.5".to_string())) + ); + assert_eq!( + RubyCrawler::parse_dir_name_version("activerecord-7.1.3.2"), + Some(("activerecord".to_string(), "7.1.3.2".to_string())) + ); + assert_eq!( + RubyCrawler::parse_dir_name_version("net-http-0.4.1"), + Some(("net-http".to_string(), "0.4.1".to_string())) + ); + assert!(RubyCrawler::parse_dir_name_version("no-version-here").is_none()); + assert!(RubyCrawler::parse_dir_name_version("noversion").is_none()); + } + + #[tokio::test] + async fn test_find_by_purls_gem() { + let dir = tempfile::tempdir().unwrap(); + let rails_dir = dir.path().join("rails-7.1.0"); + tokio::fs::create_dir_all(rails_dir.join("lib")).await.unwrap(); + + let crawler = RubyCrawler::new(); + let purls = vec![ + "pkg:gem/rails@7.1.0".to_string(), + "pkg:gem/nokogiri@1.16.5".to_string(), + ]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:gem/rails@7.1.0")); + assert!(!result.contains_key("pkg:gem/nokogiri@1.16.5")); + } + + #[tokio::test] + async fn test_crawl_all_gems() { + let dir = tempfile::tempdir().unwrap(); + + // Create fake gem directories with lib/ + let rails_dir = dir.path().join("rails-7.1.0"); + tokio::fs::create_dir_all(rails_dir.join("lib")).await.unwrap(); + + let nokogiri_dir = dir.path().join("nokogiri-1.16.5"); + tokio::fs::create_dir_all(nokogiri_dir.join("lib")).await.unwrap(); + + let crawler = RubyCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 2); + + let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect(); + assert!(purls.contains("pkg:gem/rails@7.1.0")); + assert!(purls.contains("pkg:gem/nokogiri@1.16.5")); + } + + #[tokio::test] + async fn test_get_gem_paths_with_vendor_bundle() { + let dir = tempfile::tempdir().unwrap(); + let vendor_gems = dir + .path() + .join("vendor") + .join("bundle") + .join("ruby") + .join("3.2.0") + .join("gems"); + tokio::fs::create_dir_all(&vendor_gems).await.unwrap(); + + let paths = RubyCrawler::get_vendor_bundle_paths(dir.path()).await; + assert_eq!(paths.len(), 1); + assert_eq!(paths[0], vendor_gems); + } + + #[tokio::test] + async fn test_deduplication() { + let dir = tempfile::tempdir().unwrap(); + + // Create a single gem directory + let rails_dir = dir.path().join("rails-7.1.0"); + tokio::fs::create_dir_all(rails_dir.join("lib")).await.unwrap(); + + let crawler = RubyCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].purl, "pkg:gem/rails@7.1.0"); + } + + #[tokio::test] + async fn test_verify_gem_with_gemspec() { + let dir = tempfile::tempdir().unwrap(); + let gem_dir = dir.path().join("rails-7.1.0"); + 
tokio::fs::create_dir_all(&gem_dir).await.unwrap();
+        tokio::fs::write(gem_dir.join("rails.gemspec"), "# gemspec")
+            .await
+            .unwrap();
+
+        let crawler = RubyCrawler::new();
+        assert!(crawler.verify_gem_at_path(&gem_dir).await);
+    }
+
+    #[tokio::test]
+    async fn test_verify_gem_empty_dir_fails() {
+        let dir = tempfile::tempdir().unwrap();
+        let gem_dir = dir.path().join("rails-7.1.0");
+        tokio::fs::create_dir_all(&gem_dir).await.unwrap();
+
+        let crawler = RubyCrawler::new();
+        assert!(!crawler.verify_gem_at_path(&gem_dir).await);
+    }
+}
diff --git a/crates/socket-patch-core/src/crawlers/types.rs b/crates/socket-patch-core/src/crawlers/types.rs
new file mode 100644
index 0000000..9bcdbdd
--- /dev/null
+++ b/crates/socket-patch-core/src/crawlers/types.rs
@@ -0,0 +1,347 @@
+use std::path::PathBuf;
+
+/// Identifies a supported package ecosystem.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Ecosystem {
+    Npm,
+    Pypi,
+    #[cfg(feature = "cargo")]
+    Cargo,
+    Gem,
+    #[cfg(feature = "golang")]
+    Golang,
+    #[cfg(feature = "maven")]
+    Maven,
+    #[cfg(feature = "composer")]
+    Composer,
+    #[cfg(feature = "nuget")]
+    Nuget,
+}
+
+impl Ecosystem {
+    /// All enabled ecosystems.
+    pub fn all() -> &'static [Ecosystem] {
+        &[
+            Ecosystem::Npm,
+            Ecosystem::Pypi,
+            #[cfg(feature = "cargo")]
+            Ecosystem::Cargo,
+            Ecosystem::Gem,
+            #[cfg(feature = "golang")]
+            Ecosystem::Golang,
+            #[cfg(feature = "maven")]
+            Ecosystem::Maven,
+            #[cfg(feature = "composer")]
+            Ecosystem::Composer,
+            #[cfg(feature = "nuget")]
+            Ecosystem::Nuget,
+        ]
+    }
+
+    /// Match a PURL string to its ecosystem.
+    pub fn from_purl(purl: &str) -> Option<Ecosystem> {
+        #[cfg(feature = "cargo")]
+        if purl.starts_with("pkg:cargo/") {
+            return Some(Ecosystem::Cargo);
+        }
+        if purl.starts_with("pkg:gem/") {
+            return Some(Ecosystem::Gem);
+        }
+        #[cfg(feature = "golang")]
+        if purl.starts_with("pkg:golang/") {
+            return Some(Ecosystem::Golang);
+        }
+        #[cfg(feature = "maven")]
+        if purl.starts_with("pkg:maven/") {
+            return Some(Ecosystem::Maven);
+        }
+        #[cfg(feature = "composer")]
+        if purl.starts_with("pkg:composer/") {
+            return Some(Ecosystem::Composer);
+        }
+        #[cfg(feature = "nuget")]
+        if purl.starts_with("pkg:nuget/") {
+            return Some(Ecosystem::Nuget);
+        }
+        if purl.starts_with("pkg:npm/") {
+            Some(Ecosystem::Npm)
+        } else if purl.starts_with("pkg:pypi/") {
+            Some(Ecosystem::Pypi)
+        } else {
+            None
+        }
+    }
+
+    /// The PURL prefix for this ecosystem (e.g. `"pkg:npm/"`).
+    pub fn purl_prefix(&self) -> &'static str {
+        match self {
+            Ecosystem::Npm => "pkg:npm/",
+            Ecosystem::Pypi => "pkg:pypi/",
+            #[cfg(feature = "cargo")]
+            Ecosystem::Cargo => "pkg:cargo/",
+            Ecosystem::Gem => "pkg:gem/",
+            #[cfg(feature = "golang")]
+            Ecosystem::Golang => "pkg:golang/",
+            #[cfg(feature = "maven")]
+            Ecosystem::Maven => "pkg:maven/",
+            #[cfg(feature = "composer")]
+            Ecosystem::Composer => "pkg:composer/",
+            #[cfg(feature = "nuget")]
+            Ecosystem::Nuget => "pkg:nuget/",
+        }
+    }
+
+    /// Name used in the `--ecosystems` CLI flag (e.g. `"npm"`, `"pypi"`, `"cargo"`).
+    pub fn cli_name(&self) -> &'static str {
+        match self {
+            Ecosystem::Npm => "npm",
+            Ecosystem::Pypi => "pypi",
+            #[cfg(feature = "cargo")]
+            Ecosystem::Cargo => "cargo",
+            Ecosystem::Gem => "gem",
+            #[cfg(feature = "golang")]
+            Ecosystem::Golang => "golang",
+            #[cfg(feature = "maven")]
+            Ecosystem::Maven => "maven",
+            #[cfg(feature = "composer")]
+            Ecosystem::Composer => "composer",
+            #[cfg(feature = "nuget")]
+            Ecosystem::Nuget => "nuget",
+        }
+    }
+
+    /// Human-readable name for user-facing messages.
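+    ///
+    /// e.g. `Ecosystem::Gem.display_name()` returns `"ruby"` and
+    /// `Ecosystem::Pypi.display_name()` returns `"python"`.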
+    pub fn display_name(&self) -> &'static str {
+        match self {
+            Ecosystem::Npm => "npm",
+            Ecosystem::Pypi => "python",
+            #[cfg(feature = "cargo")]
+            Ecosystem::Cargo => "cargo",
+            Ecosystem::Gem => "ruby",
+            #[cfg(feature = "golang")]
+            Ecosystem::Golang => "go",
+            #[cfg(feature = "maven")]
+            Ecosystem::Maven => "maven",
+            #[cfg(feature = "composer")]
+            Ecosystem::Composer => "php",
+            #[cfg(feature = "nuget")]
+            Ecosystem::Nuget => "nuget",
+        }
+    }
+}
+
+/// Represents a package discovered during crawling.
+#[derive(Debug, Clone)]
+pub struct CrawledPackage {
+    /// Package name (without scope).
+    pub name: String,
+    /// Package version.
+    pub version: String,
+    /// Package scope/namespace (e.g., "@types") - None for unscoped packages.
+    pub namespace: Option<String>,
+    /// Full PURL string (e.g., "pkg:npm/@types/node@20.0.0").
+    pub purl: String,
+    /// Absolute path to the package directory.
+    pub path: PathBuf,
+}
+
+/// Options for package crawling.
+#[derive(Debug, Clone)]
+pub struct CrawlerOptions {
+    /// Working directory to start from.
+    pub cwd: PathBuf,
+    /// Use global packages instead of local packages.
+    pub global: bool,
+    /// Custom path to global package directory (overrides auto-detection).
+    pub global_prefix: Option<PathBuf>,
+    /// Batch size for yielding packages (default: 100).
+    pub batch_size: usize,
+}
+
+impl Default for CrawlerOptions {
+    fn default() -> Self {
+        Self {
+            cwd: std::env::current_dir().unwrap_or_default(),
+            global: false,
+            global_prefix: None,
+            batch_size: 100,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_from_purl_npm() {
+        assert_eq!(
+            Ecosystem::from_purl("pkg:npm/lodash@4.17.21"),
+            Some(Ecosystem::Npm)
+        );
+        assert_eq!(
+            Ecosystem::from_purl("pkg:npm/@types/node@20.0.0"),
+            Some(Ecosystem::Npm)
+        );
+    }
+
+    #[test]
+    fn test_from_purl_pypi() {
+        assert_eq!(
+            Ecosystem::from_purl("pkg:pypi/requests@2.28.0"),
+            Some(Ecosystem::Pypi)
+        );
+    }
+
+    #[test]
+    fn test_from_purl_unknown() {
+        assert_eq!(Ecosystem::from_purl("pkg:unknown/foo@1.0"), None);
+        assert_eq!(Ecosystem::from_purl("not-a-purl"), None);
+    }
+
+    #[cfg(feature = "cargo")]
+    #[test]
+    fn test_from_purl_cargo() {
+        assert_eq!(
+            Ecosystem::from_purl("pkg:cargo/serde@1.0.200"),
+            Some(Ecosystem::Cargo)
+        );
+    }
+
+    #[test]
+    fn test_all_count() {
+        let all = Ecosystem::all();
+        #[allow(unused_mut)]
+        let mut expected = 3;
+        #[cfg(feature = "cargo")]
+        {
+            expected += 1;
+        }
+        #[cfg(feature = "golang")]
+        {
+            expected += 1;
+        }
+        #[cfg(feature = "maven")]
+        {
+            expected += 1;
+        }
+        #[cfg(feature = "composer")]
+        {
+            expected += 1;
+        }
+        #[cfg(feature = "nuget")]
+        {
+            expected += 1;
+        }
+        assert_eq!(all.len(), expected);
+    }
+
+    #[test]
+    fn test_cli_name() {
+        assert_eq!(Ecosystem::Npm.cli_name(), "npm");
+        assert_eq!(Ecosystem::Pypi.cli_name(), "pypi");
+    }
+
+    #[test]
+    fn test_display_name() {
+        assert_eq!(Ecosystem::Npm.display_name(), "npm");
+        assert_eq!(Ecosystem::Pypi.display_name(), "python");
+    }
+
+    #[test]
+    fn test_purl_prefix() {
+        assert_eq!(Ecosystem::Npm.purl_prefix(), "pkg:npm/");
+        assert_eq!(Ecosystem::Pypi.purl_prefix(), "pkg:pypi/");
+    }
+
+    #[cfg(feature = "cargo")]
+    #[test]
+    fn test_cargo_properties() {
+        assert_eq!(Ecosystem::Cargo.cli_name(), "cargo");
+        assert_eq!(Ecosystem::Cargo.display_name(), "cargo");
+        assert_eq!(Ecosystem::Cargo.purl_prefix(), "pkg:cargo/");
+    }
+
+    #[test]
+    fn test_from_purl_gem() {
+        assert_eq!(
+            Ecosystem::from_purl("pkg:gem/rails@7.1.0"),
+            Some(Ecosystem::Gem)
+        );
+    }
+
+    #[test]
+    fn
test_gem_properties() { + assert_eq!(Ecosystem::Gem.cli_name(), "gem"); + assert_eq!(Ecosystem::Gem.display_name(), "ruby"); + assert_eq!(Ecosystem::Gem.purl_prefix(), "pkg:gem/"); + } + + #[cfg(feature = "maven")] + #[test] + fn test_from_purl_maven() { + assert_eq!( + Ecosystem::from_purl("pkg:maven/org.apache.commons/commons-lang3@3.12.0"), + Some(Ecosystem::Maven) + ); + } + + #[cfg(feature = "maven")] + #[test] + fn test_maven_properties() { + assert_eq!(Ecosystem::Maven.cli_name(), "maven"); + assert_eq!(Ecosystem::Maven.display_name(), "maven"); + assert_eq!(Ecosystem::Maven.purl_prefix(), "pkg:maven/"); + } + + #[cfg(feature = "golang")] + #[test] + fn test_from_purl_golang() { + assert_eq!( + Ecosystem::from_purl("pkg:golang/github.com/gin-gonic/gin@v1.9.1"), + Some(Ecosystem::Golang) + ); + } + + #[cfg(feature = "golang")] + #[test] + fn test_golang_properties() { + assert_eq!(Ecosystem::Golang.cli_name(), "golang"); + assert_eq!(Ecosystem::Golang.display_name(), "go"); + assert_eq!(Ecosystem::Golang.purl_prefix(), "pkg:golang/"); + } + + #[cfg(feature = "composer")] + #[test] + fn test_from_purl_composer() { + assert_eq!( + Ecosystem::from_purl("pkg:composer/monolog/monolog@3.5.0"), + Some(Ecosystem::Composer) + ); + } + + #[cfg(feature = "composer")] + #[test] + fn test_composer_properties() { + assert_eq!(Ecosystem::Composer.cli_name(), "composer"); + assert_eq!(Ecosystem::Composer.display_name(), "php"); + assert_eq!(Ecosystem::Composer.purl_prefix(), "pkg:composer/"); + } + + #[cfg(feature = "nuget")] + #[test] + fn test_from_purl_nuget() { + assert_eq!( + Ecosystem::from_purl("pkg:nuget/Newtonsoft.Json@13.0.3"), + Some(Ecosystem::Nuget) + ); + } + + #[cfg(feature = "nuget")] + #[test] + fn test_nuget_properties() { + assert_eq!(Ecosystem::Nuget.cli_name(), "nuget"); + assert_eq!(Ecosystem::Nuget.display_name(), "nuget"); + assert_eq!(Ecosystem::Nuget.purl_prefix(), "pkg:nuget/"); + } +} diff --git a/crates/socket-patch-core/src/hash/git_sha256.rs b/crates/socket-patch-core/src/hash/git_sha256.rs new file mode 100644 index 0000000..b4ccd42 --- /dev/null +++ b/crates/socket-patch-core/src/hash/git_sha256.rs @@ -0,0 +1,89 @@ +use sha2::{Digest, Sha256}; +use std::io; +use tokio::io::AsyncReadExt; + +/// Compute Git-compatible SHA256 hash for a byte slice. +/// +/// Git hashes objects as: SHA256("blob \0" + content) +pub fn compute_git_sha256_from_bytes(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + let header = format!("blob {}\0", data.len()); + hasher.update(header.as_bytes()); + hasher.update(data); + hex::encode(hasher.finalize()) +} + +/// Compute Git-compatible SHA256 hash from an async reader with known size. +/// +/// This streams the content through the hasher without loading it all into memory. 
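+///
+/// Illustrative call (the file path is hypothetical):
+///
+/// ```ignore
+/// let file = tokio::fs::File::open("blob.bin").await?;
+/// let size = file.metadata().await?.len();
+/// let hash = compute_git_sha256_from_reader(size, file).await?;
+/// ```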
+pub async fn compute_git_sha256_from_reader<R: AsyncReadExt + Unpin>(
+    size: u64,
+    mut reader: R,
+) -> io::Result<String> {
+    let mut hasher = Sha256::new();
+    let header = format!("blob {}\0", size);
+    hasher.update(header.as_bytes());
+
+    let mut buf = [0u8; 8192];
+    loop {
+        let n = reader.read(&mut buf).await?;
+        if n == 0 {
+            break;
+        }
+        hasher.update(&buf[..n]);
+    }
+
+    Ok(hex::encode(hasher.finalize()))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_empty_content() {
+        let hash = compute_git_sha256_from_bytes(b"");
+        // SHA256("blob 0\0") - Git-compatible hash of empty content
+        assert_eq!(hash.len(), 64);
+        // Verify it's consistent
+        assert_eq!(hash, compute_git_sha256_from_bytes(b""));
+    }
+
+    #[test]
+    fn test_hello_world() {
+        let content = b"Hello, World!";
+        let hash = compute_git_sha256_from_bytes(content);
+        assert_eq!(hash.len(), 64);
+
+        // Manually compute expected: SHA256("blob 13\0Hello, World!")
+        use sha2::{Digest, Sha256};
+        let mut expected_hasher = Sha256::new();
+        expected_hasher.update(b"blob 13\0Hello, World!");
+        let expected = hex::encode(expected_hasher.finalize());
+        assert_eq!(hash, expected);
+    }
+
+    #[test]
+    fn test_known_vector() {
+        // Known test vector: SHA256("blob 0\0")
+        use sha2::{Digest, Sha256};
+        let mut hasher = Sha256::new();
+        hasher.update(b"blob 0\0");
+        let expected = hex::encode(hasher.finalize());
+        assert_eq!(compute_git_sha256_from_bytes(b""), expected);
+    }
+
+    #[tokio::test]
+    async fn test_async_reader_matches_sync() {
+        let content = b"test content for async hashing";
+        let sync_hash = compute_git_sha256_from_bytes(content);
+
+        let cursor = tokio::io::BufReader::new(&content[..]);
+        let async_hash =
+            compute_git_sha256_from_reader(content.len() as u64, cursor)
+                .await
+                .unwrap();
+
+        assert_eq!(sync_hash, async_hash);
+    }
+}
diff --git a/crates/socket-patch-core/src/hash/mod.rs b/crates/socket-patch-core/src/hash/mod.rs
new file mode 100644
index 0000000..45732e4
--- /dev/null
+++ b/crates/socket-patch-core/src/hash/mod.rs
@@ -0,0 +1,3 @@
+pub mod git_sha256;
+
+pub use git_sha256::*;
diff --git a/crates/socket-patch-core/src/lib.rs b/crates/socket-patch-core/src/lib.rs
new file mode 100644
index 0000000..3683364
--- /dev/null
+++ b/crates/socket-patch-core/src/lib.rs
@@ -0,0 +1,8 @@
+pub mod api;
+pub mod constants;
+pub mod crawlers;
+pub mod hash;
+pub mod manifest;
+pub mod package_json;
+pub mod patch;
+pub mod utils;
diff --git a/crates/socket-patch-core/src/manifest/mod.rs b/crates/socket-patch-core/src/manifest/mod.rs
new file mode 100644
index 0000000..39bd775
--- /dev/null
+++ b/crates/socket-patch-core/src/manifest/mod.rs
@@ -0,0 +1,5 @@
+pub mod operations;
+pub mod recovery;
+pub mod schema;
+
+pub use schema::*;
diff --git a/crates/socket-patch-core/src/manifest/operations.rs b/crates/socket-patch-core/src/manifest/operations.rs
new file mode 100644
index 0000000..30fae4a
--- /dev/null
+++ b/crates/socket-patch-core/src/manifest/operations.rs
@@ -0,0 +1,460 @@
+use std::collections::HashSet;
+use std::path::Path;
+
+use crate::manifest::schema::PatchManifest;
+
+/// Get all blob hashes referenced by a manifest (both beforeHash and afterHash).
+/// Used for garbage collection and validation.
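+///
+/// Sketch of a GC-style sweep (the blob-store iteration and `remove_blob`
+/// helper are hypothetical):
+///
+/// ```ignore
+/// let live = get_referenced_blobs(&manifest);
+/// for blob in on_disk_blobs {
+///     if !live.contains(&blob) {
+///         remove_blob(&blob);
+///     }
+/// }
+/// ```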
+pub fn get_referenced_blobs(manifest: &PatchManifest) -> HashSet<String> {
+    let mut blobs = HashSet::new();
+
+    for record in manifest.patches.values() {
+        for file_info in record.files.values() {
+            blobs.insert(file_info.before_hash.clone());
+            blobs.insert(file_info.after_hash.clone());
+        }
+    }
+
+    blobs
+}
+
+/// Get only afterHash blobs referenced by a manifest.
+/// Used for apply operations -- we only need the patched file content, not the original.
+/// This saves disk space since beforeHash blobs are not needed for applying patches.
+pub fn get_after_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
+    let mut blobs = HashSet::new();
+
+    for record in manifest.patches.values() {
+        for file_info in record.files.values() {
+            blobs.insert(file_info.after_hash.clone());
+        }
+    }
+
+    blobs
+}
+
+/// Get only beforeHash blobs referenced by a manifest.
+/// Used for rollback operations -- we need the original file content to restore.
+pub fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
+    let mut blobs = HashSet::new();
+
+    for record in manifest.patches.values() {
+        for file_info in record.files.values() {
+            blobs.insert(file_info.before_hash.clone());
+        }
+    }
+
+    blobs
+}
+
+/// Differences between two manifests.
+#[derive(Debug, Clone)]
+pub struct ManifestDiff {
+    /// PURLs present in new but not old.
+    pub added: HashSet<String>,
+    /// PURLs present in old but not new.
+    pub removed: HashSet<String>,
+    /// PURLs present in both but with different UUIDs.
+    pub modified: HashSet<String>,
+}
+
+/// Calculate differences between two manifests.
+/// Patches are compared by UUID: if the PURL exists in both manifests but the
+/// UUID changed, the patch is considered modified.
+pub fn diff_manifests(old_manifest: &PatchManifest, new_manifest: &PatchManifest) -> ManifestDiff {
+    let old_purls: HashSet<&String> = old_manifest.patches.keys().collect();
+    let new_purls: HashSet<&String> = new_manifest.patches.keys().collect();
+
+    let mut added = HashSet::new();
+    let mut removed = HashSet::new();
+    let mut modified = HashSet::new();
+
+    // Find added and modified
+    for purl in &new_purls {
+        if !old_purls.contains(purl) {
+            added.insert((*purl).clone());
+        } else {
+            let old_patch = &old_manifest.patches[*purl];
+            let new_patch = &new_manifest.patches[*purl];
+            if old_patch.uuid != new_patch.uuid {
+                modified.insert((*purl).clone());
+            }
+        }
+    }
+
+    // Find removed
+    for purl in &old_purls {
+        if !new_purls.contains(purl) {
+            removed.insert((*purl).clone());
+        }
+    }
+
+    ManifestDiff {
+        added,
+        removed,
+        modified,
+    }
+}
+
+/// Validate a parsed JSON value as a PatchManifest.
+/// Returns Ok(manifest) if valid, or Err(message) if invalid.
+pub fn validate_manifest(value: &serde_json::Value) -> Result<PatchManifest, String> {
+    serde_json::from_value::<PatchManifest>(value.clone())
+        .map_err(|e| format!("Invalid manifest: {}", e))
+}
+
+/// Read and parse a manifest from the filesystem.
+/// Returns Ok(None) if the file does not exist.
+/// Returns Err for I/O errors, JSON parse errors, or validation errors.
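+///
+/// Illustrative call (the manifest path is hypothetical):
+///
+/// ```ignore
+/// match read_manifest(".socket/manifest.json").await? {
+///     Some(manifest) => println!("{} patches", manifest.patches.len()),
+///     None => println!("no manifest on disk yet"),
+/// }
+/// ```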
+pub async fn read_manifest(path: impl AsRef<Path>) -> Result<Option<PatchManifest>, std::io::Error> {
+    let path = path.as_ref();
+
+    let content = match tokio::fs::read_to_string(path).await {
+        Ok(c) => c,
+        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
+        Err(e) => return Err(e), // FIX: propagate actual I/O error
+    };
+
+    let parsed: serde_json::Value = match serde_json::from_str(&content) {
+        Ok(v) => v,
+        Err(e) => return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            format!("Failed to parse manifest JSON: {}", e),
+        )),
+    };
+
+    match validate_manifest(&parsed) {
+        Ok(manifest) => Ok(Some(manifest)),
+        Err(e) => Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            e,
+        )),
+    }
+}
+
+/// Write a manifest to the filesystem with pretty-printed JSON.
+pub async fn write_manifest(
+    path: impl AsRef<Path>,
+    manifest: &PatchManifest,
+) -> Result<(), std::io::Error> {
+    let content = serde_json::to_string_pretty(manifest)
+        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+    tokio::fs::write(path, content).await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::manifest::schema::{PatchFileInfo, PatchRecord};
+    use std::collections::HashMap;
+
+    const TEST_UUID_1: &str = "11111111-1111-4111-8111-111111111111";
+    const TEST_UUID_2: &str = "22222222-2222-4222-8222-222222222222";
+
+    const BEFORE_HASH_1: &str =
+        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1111";
+    const AFTER_HASH_1: &str =
+        "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb1111";
+    const BEFORE_HASH_2: &str =
+        "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc2222";
+    const AFTER_HASH_2: &str =
+        "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd2222";
+    const BEFORE_HASH_3: &str =
+        "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee3333";
+    const AFTER_HASH_3: &str =
+        "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3333";
+
+    fn create_test_manifest() -> PatchManifest {
+        let mut patches = HashMap::new();
+
+        let mut files_a = HashMap::new();
+        files_a.insert(
+            "package/index.js".to_string(),
+            PatchFileInfo {
+                before_hash: BEFORE_HASH_1.to_string(),
+                after_hash: AFTER_HASH_1.to_string(),
+            },
+        );
+        files_a.insert(
+            "package/lib/utils.js".to_string(),
+            PatchFileInfo {
+                before_hash: BEFORE_HASH_2.to_string(),
+                after_hash: AFTER_HASH_2.to_string(),
+            },
+        );
+
+        patches.insert(
+            "pkg:npm/pkg-a@1.0.0".to_string(),
+            PatchRecord {
+                uuid: TEST_UUID_1.to_string(),
+                exported_at: "2024-01-01T00:00:00Z".to_string(),
+                files: files_a,
+                vulnerabilities: HashMap::new(),
+                description: "Test patch 1".to_string(),
+                license: "MIT".to_string(),
+                tier: "free".to_string(),
+            },
+        );
+
+        let mut files_b = HashMap::new();
+        files_b.insert(
+            "package/main.js".to_string(),
+            PatchFileInfo {
+                before_hash: BEFORE_HASH_3.to_string(),
+                after_hash: AFTER_HASH_3.to_string(),
+            },
+        );
+
+        patches.insert(
+            "pkg:npm/pkg-b@2.0.0".to_string(),
+            PatchRecord {
+                uuid: TEST_UUID_2.to_string(),
+                exported_at: "2024-01-01T00:00:00Z".to_string(),
+                files: files_b,
+                vulnerabilities: HashMap::new(),
+                description: "Test patch 2".to_string(),
+                license: "MIT".to_string(),
+                tier: "free".to_string(),
+            },
+        );
+
+        PatchManifest { patches }
+    }
+
+    #[test]
+    fn test_get_referenced_blobs_returns_all() {
+        let manifest = create_test_manifest();
+        let blobs = get_referenced_blobs(&manifest);
+
+        assert_eq!(blobs.len(), 6);
+        assert!(blobs.contains(BEFORE_HASH_1));
+        assert!(blobs.contains(AFTER_HASH_1));
+        assert!(blobs.contains(BEFORE_HASH_2));
+        assert!(blobs.contains(AFTER_HASH_2));
+        assert!(blobs.contains(BEFORE_HASH_3));
+        assert!(blobs.contains(AFTER_HASH_3));
+    }
+
+    #[test]
+    fn test_get_referenced_blobs_empty_manifest() {
+        let manifest = PatchManifest::new();
+        let blobs = get_referenced_blobs(&manifest);
+        assert_eq!(blobs.len(), 0);
+    }
+
+    #[test]
+    fn test_get_referenced_blobs_deduplicates() {
+        let mut files = HashMap::new();
+        files.insert(
+            "package/file1.js".to_string(),
+            PatchFileInfo {
+                before_hash: BEFORE_HASH_1.to_string(),
+                after_hash: AFTER_HASH_1.to_string(),
+            },
+        );
+        files.insert(
+            "package/file2.js".to_string(),
+            PatchFileInfo {
+                before_hash: BEFORE_HASH_1.to_string(), // same as file1
+                after_hash: AFTER_HASH_2.to_string(),
+            },
+        );
+
+        let mut patches = HashMap::new();
+        patches.insert(
+            "pkg:npm/pkg-a@1.0.0".to_string(),
+            PatchRecord {
+                uuid: TEST_UUID_1.to_string(),
+                exported_at: "2024-01-01T00:00:00Z".to_string(),
+                files,
+                vulnerabilities: HashMap::new(),
+                description: "Test".to_string(),
+                license: "MIT".to_string(),
+                tier: "free".to_string(),
+            },
+        );
+
+        let manifest = PatchManifest { patches };
+        let blobs = get_referenced_blobs(&manifest);
+        // 3 unique hashes, not 4
+        assert_eq!(blobs.len(), 3);
+    }
+
+    #[test]
+    fn test_get_after_hash_blobs() {
+        let manifest = create_test_manifest();
+        let blobs = get_after_hash_blobs(&manifest);
+
+        assert_eq!(blobs.len(), 3);
+        assert!(blobs.contains(AFTER_HASH_1));
+        assert!(blobs.contains(AFTER_HASH_2));
+        assert!(blobs.contains(AFTER_HASH_3));
+        assert!(!blobs.contains(BEFORE_HASH_1));
+        assert!(!blobs.contains(BEFORE_HASH_2));
+        assert!(!blobs.contains(BEFORE_HASH_3));
+    }
+
+    #[test]
+    fn test_get_after_hash_blobs_empty() {
+        let manifest = PatchManifest::new();
+        let blobs = get_after_hash_blobs(&manifest);
+        assert_eq!(blobs.len(), 0);
+    }
+
+    #[test]
+    fn test_get_before_hash_blobs() {
+        let manifest = create_test_manifest();
+        let blobs = get_before_hash_blobs(&manifest);
+
+        assert_eq!(blobs.len(), 3);
+        assert!(blobs.contains(BEFORE_HASH_1));
+        assert!(blobs.contains(BEFORE_HASH_2));
+        assert!(blobs.contains(BEFORE_HASH_3));
+        assert!(!blobs.contains(AFTER_HASH_1));
+        assert!(!blobs.contains(AFTER_HASH_2));
+        assert!(!blobs.contains(AFTER_HASH_3));
+    }
+
+    #[test]
+    fn test_get_before_hash_blobs_empty() {
+        let manifest = PatchManifest::new();
+        let blobs = get_before_hash_blobs(&manifest);
+        assert_eq!(blobs.len(), 0);
+    }
+
+    #[test]
+    fn test_after_plus_before_equals_all() {
+        let manifest = create_test_manifest();
+        let all_blobs = get_referenced_blobs(&manifest);
+        let after_blobs = get_after_hash_blobs(&manifest);
+        let before_blobs = get_before_hash_blobs(&manifest);
+
+        let union: HashSet<String> = after_blobs.union(&before_blobs).cloned().collect();
+        assert_eq!(union.len(), all_blobs.len());
+        for blob in &all_blobs {
+            assert!(union.contains(blob));
+        }
+    }
+
+    #[test]
+    fn test_diff_manifests_added() {
+        let old = PatchManifest::new();
+        let new_manifest = create_test_manifest();
+
+        let diff = diff_manifests(&old, &new_manifest);
+        assert_eq!(diff.added.len(), 2);
+        assert!(diff.added.contains("pkg:npm/pkg-a@1.0.0"));
+        assert!(diff.added.contains("pkg:npm/pkg-b@2.0.0"));
+        assert_eq!(diff.removed.len(), 0);
+        assert_eq!(diff.modified.len(), 0);
+    }
+
+    #[test]
+    fn test_diff_manifests_removed() {
+        let old = create_test_manifest();
+        let new_manifest = PatchManifest::new();
+
+        let diff = diff_manifests(&old, &new_manifest);
+        assert_eq!(diff.added.len(), 0);
+        assert_eq!(diff.removed.len(), 2);
+        assert!(diff.removed.contains("pkg:npm/pkg-a@1.0.0"));
+        assert!(diff.removed.contains("pkg:npm/pkg-b@2.0.0"));
+        assert_eq!(diff.modified.len(), 0);
+    }
+
+    #[test]
+    fn test_diff_manifests_modified() {
+        let old = create_test_manifest();
+        let mut new_manifest = create_test_manifest();
+        // Change UUID of pkg-a
+        new_manifest
+            .patches
+            .get_mut("pkg:npm/pkg-a@1.0.0")
+            .unwrap()
+            .uuid = "33333333-3333-4333-8333-333333333333".to_string();
+
+        let diff = diff_manifests(&old, &new_manifest);
+        assert_eq!(diff.added.len(), 0);
+        assert_eq!(diff.removed.len(), 0);
+        assert_eq!(diff.modified.len(), 1);
+        assert!(diff.modified.contains("pkg:npm/pkg-a@1.0.0"));
+    }
+
+    #[test]
+    fn test_diff_manifests_same() {
+        let old = create_test_manifest();
+        let new_manifest = create_test_manifest();
+
+        let diff = diff_manifests(&old, &new_manifest);
+        assert_eq!(diff.added.len(), 0);
+        assert_eq!(diff.removed.len(), 0);
+        assert_eq!(diff.modified.len(), 0);
+    }
+
+    #[test]
+    fn test_validate_manifest_valid() {
+        let json = serde_json::json!({
+            "patches": {
+                "pkg:npm/test@1.0.0": {
+                    "uuid": "11111111-1111-4111-8111-111111111111",
+                    "exportedAt": "2024-01-01T00:00:00Z",
+                    "files": {},
+                    "vulnerabilities": {},
+                    "description": "test",
+                    "license": "MIT",
+                    "tier": "free"
+                }
+            }
+        });
+
+        let result = validate_manifest(&json);
+        assert!(result.is_ok());
+        let manifest = result.unwrap();
+        assert_eq!(manifest.patches.len(), 1);
+    }
+
+    #[test]
+    fn test_validate_manifest_invalid() {
+        let json = serde_json::json!({
+            "patches": "not-an-object"
+        });
+
+        let result = validate_manifest(&json);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_validate_manifest_missing_fields() {
+        let json = serde_json::json!({
+            "patches": {
+                "pkg:npm/test@1.0.0": {
+                    "uuid": "test"
+                }
+            }
+        });
+
+        let result = validate_manifest(&json);
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_read_manifest_not_found() {
+        let result = read_manifest("/nonexistent/path/manifest.json").await;
+        assert!(result.is_ok());
+        assert!(result.unwrap().is_none());
+    }
+
+    #[tokio::test]
+    async fn test_write_and_read_manifest() {
+        let dir = tempfile::tempdir().unwrap();
+        let path = dir.path().join("manifest.json");
+
+        let manifest = create_test_manifest();
+        write_manifest(&path, &manifest).await.unwrap();
+
+        let read_back = read_manifest(&path).await.unwrap();
+        assert!(read_back.is_some());
+        let read_back = read_back.unwrap();
+        assert_eq!(read_back.patches.len(), 2);
+    }
+}
diff --git a/crates/socket-patch-core/src/manifest/recovery.rs b/crates/socket-patch-core/src/manifest/recovery.rs
new file mode 100644
index 0000000..e0fb498
--- /dev/null
+++ b/crates/socket-patch-core/src/manifest/recovery.rs
@@ -0,0 +1,543 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::pin::Pin;
+
+use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo};
+
+/// Result of manifest recovery operation.
+#[derive(Debug, Clone)]
+pub struct RecoveryResult {
+    pub manifest: PatchManifest,
+    pub repair_needed: bool,
+    pub invalid_patches: Vec<String>,
+    pub recovered_patches: Vec<String>,
+    pub discarded_patches: Vec<String>,
+}
+
+/// Patch data returned from an external source (e.g., database).
+#[derive(Debug, Clone)]
+pub struct PatchData {
+    pub uuid: String,
+    pub purl: String,
+    pub published_at: String,
+    pub files: HashMap<String, PatchDataFileInfo>,
+    pub vulnerabilities: HashMap<String, PatchDataVulnerability>,
+    pub description: String,
+    pub license: String,
+    pub tier: String,
+}
+
+/// File info from external patch data (hashes are optional).
+#[derive(Debug, Clone)]
+pub struct PatchDataFileInfo {
+    pub before_hash: Option<String>,
+    pub after_hash: Option<String>,
+}
+
+/// Vulnerability info from external patch data.
+#[derive(Debug, Clone)]
+pub struct PatchDataVulnerability {
+    pub cves: Vec<String>,
+    pub summary: String,
+    pub severity: String,
+    pub description: String,
+}
+
+/// Events emitted during recovery.
+#[derive(Debug, Clone)]
+pub enum RecoveryEvent {
+    CorruptedManifest,
+    InvalidPatch {
+        purl: String,
+        uuid: Option<String>,
+    },
+    RecoveredPatch {
+        purl: String,
+        uuid: String,
+    },
+    DiscardedPatchNotFound {
+        purl: String,
+        uuid: String,
+    },
+    DiscardedPatchPurlMismatch {
+        purl: String,
+        uuid: String,
+        db_purl: String,
+    },
+    DiscardedPatchNoUuid {
+        purl: String,
+    },
+    RecoveryError {
+        purl: String,
+        uuid: String,
+        error: String,
+    },
+}
+
+/// Type alias for the refetch callback.
+/// Takes (uuid, optional purl) and returns a future resolving to
+/// `Result<Option<PatchData>, String>`.
+pub type RefetchPatchFn = Box<
+    dyn Fn(String, Option<String>) -> Pin<Box<dyn Future<Output = Result<Option<PatchData>, String>> + Send>>
+        + Send
+        + Sync,
+>;
+
+/// Type alias for the recovery event callback.
+pub type OnRecoveryEventFn = Box<dyn Fn(RecoveryEvent) + Send + Sync>;
+
+/// Options for manifest recovery.
+#[derive(Default)]
+pub struct RecoveryOptions {
+    /// Optional function to refetch patch data from external source (e.g., database).
+    /// Should return patch data or None if not found.
+    pub refetch_patch: Option<RefetchPatchFn>,
+
+    /// Optional callback for logging recovery events.
+    pub on_recovery_event: Option<OnRecoveryEventFn>,
+}
+
+/// Recover and validate manifest with automatic repair of invalid patches.
+///
+/// This function attempts to parse and validate a manifest. If the manifest
+/// contains invalid patches, it will attempt to recover them using the provided
+/// refetch function. Patches that cannot be recovered are discarded.
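+///
+/// A sketch of wiring in a refetch callback; `db_lookup` is a hypothetical
+/// async lookup returning `Result<Option<PatchData>, String>`:
+/// ```ignore
+/// let options = RecoveryOptions {
+///     refetch_patch: Some(Box::new(|uuid, _purl| {
+///         Box::pin(async move { db_lookup(&uuid).await })
+///     })),
+///     on_recovery_event: Some(Box::new(|event| eprintln!("{:?}", event))),
+/// };
+/// let result = recover_manifest(&parsed, options).await;
+/// if result.repair_needed {
+///     // Persist result.manifest to replace the corrupted file.
+/// }
+/// ```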
+pub async fn recover_manifest(
+    parsed: &serde_json::Value,
+    options: RecoveryOptions,
+) -> RecoveryResult {
+    let RecoveryOptions {
+        refetch_patch,
+        on_recovery_event,
+    } = options;
+
+    let emit = |event: RecoveryEvent| {
+        if let Some(ref cb) = on_recovery_event {
+            cb(event);
+        }
+    };
+
+    // Try strict parse first (fast path for valid manifests)
+    if let Ok(manifest) = serde_json::from_value::<PatchManifest>(parsed.clone()) {
+        return RecoveryResult {
+            manifest,
+            repair_needed: false,
+            invalid_patches: vec![],
+            recovered_patches: vec![],
+            discarded_patches: vec![],
+        };
+    }
+
+    // Extract patches object with safety checks
+    let patches_obj = parsed
+        .as_object()
+        .and_then(|obj| obj.get("patches"))
+        .and_then(|p| p.as_object());
+
+    let patches_obj = match patches_obj {
+        Some(obj) => obj,
+        None => {
+            // Completely corrupted manifest
+            emit(RecoveryEvent::CorruptedManifest);
+            return RecoveryResult {
+                manifest: PatchManifest::new(),
+                repair_needed: true,
+                invalid_patches: vec![],
+                recovered_patches: vec![],
+                discarded_patches: vec![],
+            };
+        }
+    };
+
+    // Try to recover individual patches
+    let mut recovered_patches_map: HashMap<String, PatchRecord> = HashMap::new();
+    let mut invalid_patches: Vec<String> = Vec::new();
+    let mut recovered_patches: Vec<String> = Vec::new();
+    let mut discarded_patches: Vec<String> = Vec::new();
+
+    for (purl, patch_data) in patches_obj {
+        // Try to parse this individual patch
+        if let Ok(record) = serde_json::from_value::<PatchRecord>(patch_data.clone()) {
+            // Valid patch, keep it as-is
+            recovered_patches_map.insert(purl.clone(), record);
+        } else {
+            // Invalid patch, try to recover from external source
+            let uuid = patch_data
+                .as_object()
+                .and_then(|obj| obj.get("uuid"))
+                .and_then(|v| v.as_str())
+                .map(|s| s.to_string());
+
+            invalid_patches.push(purl.clone());
+            emit(RecoveryEvent::InvalidPatch {
+                purl: purl.clone(),
+                uuid: uuid.clone(),
+            });
+
+            if let (Some(ref uuid_str), Some(ref refetch)) = (&uuid, &refetch_patch) {
+                // Try to refetch from external source
+                match refetch(uuid_str.clone(), Some(purl.clone())).await {
+                    Ok(Some(patch_from_source)) => {
+                        if patch_from_source.purl == *purl {
+                            // Successfully recovered, reconstruct patch record
+                            let mut manifest_files: HashMap<String, PatchFileInfo> =
+                                HashMap::new();
+                            for (file_path, file_info) in &patch_from_source.files {
+                                if let (Some(before), Some(after)) =
+                                    (&file_info.before_hash, &file_info.after_hash)
+                                {
+                                    manifest_files.insert(
+                                        file_path.clone(),
+                                        PatchFileInfo {
+                                            before_hash: before.clone(),
+                                            after_hash: after.clone(),
+                                        },
+                                    );
+                                }
+                            }
+
+                            let mut vulns: HashMap<String, VulnerabilityInfo> = HashMap::new();
+                            for (vuln_id, vuln_data) in &patch_from_source.vulnerabilities {
+                                vulns.insert(
+                                    vuln_id.clone(),
+                                    VulnerabilityInfo {
+                                        cves: vuln_data.cves.clone(),
+                                        summary: vuln_data.summary.clone(),
+                                        severity: vuln_data.severity.clone(),
+                                        description: vuln_data.description.clone(),
+                                    },
+                                );
+                            }
+
+                            recovered_patches_map.insert(
+                                purl.clone(),
+                                PatchRecord {
+                                    uuid: patch_from_source.uuid.clone(),
+                                    exported_at: patch_from_source.published_at.clone(),
+                                    files: manifest_files,
+                                    vulnerabilities: vulns,
+                                    description: patch_from_source.description.clone(),
+                                    license: patch_from_source.license.clone(),
+                                    tier: patch_from_source.tier.clone(),
+                                },
+                            );
+
+                            recovered_patches.push(purl.clone());
+                            emit(RecoveryEvent::RecoveredPatch {
+                                purl: purl.clone(),
+                                uuid: uuid_str.clone(),
+                            });
+                        } else {
+                            // PURL mismatch - wrong package!
+ discarded_patches.push(purl.clone()); + emit(RecoveryEvent::DiscardedPatchPurlMismatch { + purl: purl.clone(), + uuid: uuid_str.clone(), + db_purl: patch_from_source.purl.clone(), + }); + } + } + Ok(None) => { + // Not found in external source (might be unpublished) + discarded_patches.push(purl.clone()); + emit(RecoveryEvent::DiscardedPatchNotFound { + purl: purl.clone(), + uuid: uuid_str.clone(), + }); + } + Err(error_msg) => { + // Error during recovery + discarded_patches.push(purl.clone()); + emit(RecoveryEvent::RecoveryError { + purl: purl.clone(), + uuid: uuid_str.clone(), + error: error_msg, + }); + } + } + } else { + // No UUID or no refetch function, can't recover + discarded_patches.push(purl.clone()); + if let Some(uuid) = uuid { + emit(RecoveryEvent::DiscardedPatchNotFound { + purl: purl.clone(), + uuid, + }); + } else { + emit(RecoveryEvent::DiscardedPatchNoUuid { + purl: purl.clone(), + }); + } + } + } + } + + let repair_needed = !invalid_patches.is_empty(); + + RecoveryResult { + manifest: PatchManifest { + patches: recovered_patches_map, + }, + repair_needed, + invalid_patches, + recovered_patches, + discarded_patches, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[tokio::test] + async fn test_valid_manifest_no_repair() { + let parsed = json!({ + "patches": { + "pkg:npm/test@1.0.0": { + "uuid": "11111111-1111-4111-8111-111111111111", + "exportedAt": "2024-01-01T00:00:00Z", + "files": {}, + "vulnerabilities": {}, + "description": "test", + "license": "MIT", + "tier": "free" + } + } + }); + + let result = recover_manifest(&parsed, RecoveryOptions::default()).await; + assert!(!result.repair_needed); + assert_eq!(result.manifest.patches.len(), 1); + assert!(result.invalid_patches.is_empty()); + assert!(result.recovered_patches.is_empty()); + assert!(result.discarded_patches.is_empty()); + } + + #[tokio::test] + async fn test_corrupted_manifest_no_patches_key() { + let parsed = json!({ + "something": "else" + }); + + let result = recover_manifest(&parsed, RecoveryOptions::default()).await; + assert!(result.repair_needed); + assert_eq!(result.manifest.patches.len(), 0); + } + + #[tokio::test] + async fn test_corrupted_manifest_patches_not_object() { + let parsed = json!({ + "patches": "not-an-object" + }); + + let result = recover_manifest(&parsed, RecoveryOptions::default()).await; + assert!(result.repair_needed); + assert_eq!(result.manifest.patches.len(), 0); + } + + #[tokio::test] + async fn test_invalid_patch_discarded_no_refetch() { + let parsed = json!({ + "patches": { + "pkg:npm/test@1.0.0": { + "uuid": "11111111-1111-4111-8111-111111111111" + // missing required fields + } + } + }); + + let result = recover_manifest(&parsed, RecoveryOptions::default()).await; + assert!(result.repair_needed); + assert_eq!(result.manifest.patches.len(), 0); + assert_eq!(result.invalid_patches.len(), 1); + assert_eq!(result.discarded_patches.len(), 1); + } + + #[tokio::test] + async fn test_invalid_patch_no_uuid_discarded() { + let parsed = json!({ + "patches": { + "pkg:npm/test@1.0.0": { + "garbage": true + } + } + }); + + + let events_clone = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); + let events_ref = events_clone.clone(); + + let options = RecoveryOptions { + refetch_patch: None, + on_recovery_event: Some(Box::new(move |event| { + events_ref.lock().unwrap().push(format!("{:?}", event)); + })), + }; + + let result = recover_manifest(&parsed, options).await; + assert!(result.repair_needed); + assert_eq!(result.discarded_patches.len(), 
1); + + let logged = events_clone.lock().unwrap(); + assert!(logged.iter().any(|e| e.contains("DiscardedPatchNoUuid"))); + } + + #[tokio::test] + async fn test_mix_valid_and_invalid_patches() { + let parsed = json!({ + "patches": { + "pkg:npm/good@1.0.0": { + "uuid": "11111111-1111-4111-8111-111111111111", + "exportedAt": "2024-01-01T00:00:00Z", + "files": {}, + "vulnerabilities": {}, + "description": "good patch", + "license": "MIT", + "tier": "free" + }, + "pkg:npm/bad@1.0.0": { + "uuid": "22222222-2222-4222-8222-222222222222" + // missing required fields + } + } + }); + + let result = recover_manifest(&parsed, RecoveryOptions::default()).await; + assert!(result.repair_needed); + assert_eq!(result.manifest.patches.len(), 1); + assert!(result.manifest.patches.contains_key("pkg:npm/good@1.0.0")); + assert_eq!(result.invalid_patches.len(), 1); + assert_eq!(result.discarded_patches.len(), 1); + } + + #[tokio::test] + async fn test_recovery_with_refetch_success() { + let parsed = json!({ + "patches": { + "pkg:npm/test@1.0.0": { + "uuid": "11111111-1111-4111-8111-111111111111" + // missing required fields + } + } + }); + + let options = RecoveryOptions { + refetch_patch: Some(Box::new(|_uuid, _purl| { + Box::pin(async { + Ok(Some(PatchData { + uuid: "11111111-1111-4111-8111-111111111111".to_string(), + purl: "pkg:npm/test@1.0.0".to_string(), + published_at: "2024-01-01T00:00:00Z".to_string(), + files: { + let mut m = HashMap::new(); + m.insert( + "package/index.js".to_string(), + PatchDataFileInfo { + before_hash: Some("aaa".to_string()), + after_hash: Some("bbb".to_string()), + }, + ); + m + }, + vulnerabilities: HashMap::new(), + description: "recovered".to_string(), + license: "MIT".to_string(), + tier: "free".to_string(), + })) + }) + })), + on_recovery_event: None, + }; + + let result = recover_manifest(&parsed, options).await; + assert!(result.repair_needed); + assert_eq!(result.manifest.patches.len(), 1); + assert_eq!(result.recovered_patches.len(), 1); + assert_eq!(result.discarded_patches.len(), 0); + + let record = result.manifest.patches.get("pkg:npm/test@1.0.0").unwrap(); + assert_eq!(record.description, "recovered"); + assert_eq!(record.files.len(), 1); + } + + #[tokio::test] + async fn test_recovery_with_purl_mismatch() { + let parsed = json!({ + "patches": { + "pkg:npm/test@1.0.0": { + "uuid": "11111111-1111-4111-8111-111111111111" + } + } + }); + + let options = RecoveryOptions { + refetch_patch: Some(Box::new(|_uuid, _purl| { + Box::pin(async { + Ok(Some(PatchData { + uuid: "11111111-1111-4111-8111-111111111111".to_string(), + purl: "pkg:npm/other@2.0.0".to_string(), // wrong purl + published_at: "2024-01-01T00:00:00Z".to_string(), + files: HashMap::new(), + vulnerabilities: HashMap::new(), + description: "wrong".to_string(), + license: "MIT".to_string(), + tier: "free".to_string(), + })) + }) + })), + on_recovery_event: None, + }; + + let result = recover_manifest(&parsed, options).await; + assert!(result.repair_needed); + assert_eq!(result.manifest.patches.len(), 0); + assert_eq!(result.discarded_patches.len(), 1); + } + + #[tokio::test] + async fn test_recovery_with_refetch_not_found() { + let parsed = json!({ + "patches": { + "pkg:npm/test@1.0.0": { + "uuid": "11111111-1111-4111-8111-111111111111" + } + } + }); + + let options = RecoveryOptions { + refetch_patch: Some(Box::new(|_uuid, _purl| { + Box::pin(async { Ok(None) }) + })), + on_recovery_event: None, + }; + + let result = recover_manifest(&parsed, options).await; + assert!(result.repair_needed); + 
assert_eq!(result.manifest.patches.len(), 0);
+        assert_eq!(result.discarded_patches.len(), 1);
+    }
+
+    #[tokio::test]
+    async fn test_recovery_with_refetch_error() {
+        let parsed = json!({
+            "patches": {
+                "pkg:npm/test@1.0.0": {
+                    "uuid": "11111111-1111-4111-8111-111111111111"
+                }
+            }
+        });
+
+        let options = RecoveryOptions {
+            refetch_patch: Some(Box::new(|_uuid, _purl| {
+                Box::pin(async { Err("network error".to_string()) })
+            })),
+            on_recovery_event: None,
+        };
+
+        let result = recover_manifest(&parsed, options).await;
+        assert!(result.repair_needed);
+        assert_eq!(result.manifest.patches.len(), 0);
+        assert_eq!(result.discarded_patches.len(), 1);
+    }
+}
diff --git a/crates/socket-patch-core/src/manifest/schema.rs b/crates/socket-patch-core/src/manifest/schema.rs
new file mode 100644
index 0000000..bfd7fe3
--- /dev/null
+++ b/crates/socket-patch-core/src/manifest/schema.rs
@@ -0,0 +1,152 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+/// Information about a vulnerability fixed by a patch.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct VulnerabilityInfo {
+    pub cves: Vec<String>,
+    pub summary: String,
+    pub severity: String,
+    pub description: String,
+}
+
+/// Hash information for a single patched file.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct PatchFileInfo {
+    pub before_hash: String,
+    pub after_hash: String,
+}
+
+/// A single patch record in the manifest.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct PatchRecord {
+    pub uuid: String,
+    pub exported_at: String,
+    /// Maps relative file path -> hash info.
+    pub files: HashMap<String, PatchFileInfo>,
+    /// Maps vulnerability ID (e.g., "GHSA-...") -> vulnerability info.
+    pub vulnerabilities: HashMap<String, VulnerabilityInfo>,
+    pub description: String,
+    pub license: String,
+    pub tier: String,
+}
+
+/// The top-level patch manifest structure.
+/// Stored as `.socket/manifest.json`.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct PatchManifest {
+    /// Maps package PURL (e.g., "pkg:npm/lodash@4.17.21") -> patch record.
+    pub patches: HashMap<String, PatchRecord>,
+}
+
+impl PatchManifest {
+    /// Create an empty manifest.
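+    ///
+    /// A trivial sketch:
+    /// ```ignore
+    /// let manifest = PatchManifest::new();
+    /// assert!(manifest.patches.is_empty());
+    /// ```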
+ pub fn new() -> Self { + Self { + patches: HashMap::new(), + } + } +} + +impl Default for PatchManifest { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_manifest_roundtrip() { + let manifest = PatchManifest::new(); + let json = serde_json::to_string_pretty(&manifest).unwrap(); + let parsed: PatchManifest = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.patches.len(), 0); + } + + #[test] + fn test_manifest_with_patch_roundtrip() { + let json = r#"{ + "patches": { + "pkg:npm/simplehttpserver@0.0.6": { + "uuid": "12345678-1234-1234-1234-123456789abc", + "exportedAt": "2024-01-15T10:00:00Z", + "files": { + "package/lib/server.js": { + "beforeHash": "aaaa000000000000000000000000000000000000000000000000000000000000", + "afterHash": "bbbb000000000000000000000000000000000000000000000000000000000000" + } + }, + "vulnerabilities": { + "GHSA-jrhj-2j3q-xf3v": { + "cves": ["CVE-2024-1234"], + "summary": "Path traversal vulnerability", + "severity": "high", + "description": "A path traversal vulnerability exists in simplehttpserver" + } + }, + "description": "Fix path traversal vulnerability", + "license": "MIT", + "tier": "free" + } + } +}"#; + + let manifest: PatchManifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.patches.len(), 1); + + let patch = manifest.patches.get("pkg:npm/simplehttpserver@0.0.6").unwrap(); + assert_eq!(patch.uuid, "12345678-1234-1234-1234-123456789abc"); + assert_eq!(patch.files.len(), 1); + assert_eq!(patch.vulnerabilities.len(), 1); + assert_eq!(patch.tier, "free"); + + let file_info = patch.files.get("package/lib/server.js").unwrap(); + assert_eq!( + file_info.before_hash, + "aaaa000000000000000000000000000000000000000000000000000000000000" + ); + + let vuln = patch.vulnerabilities.get("GHSA-jrhj-2j3q-xf3v").unwrap(); + assert_eq!(vuln.cves, vec!["CVE-2024-1234"]); + assert_eq!(vuln.severity, "high"); + + // Verify round-trip + let serialized = serde_json::to_string_pretty(&manifest).unwrap(); + let reparsed: PatchManifest = serde_json::from_str(&serialized).unwrap(); + assert_eq!(manifest, reparsed); + } + + #[test] + fn test_camel_case_serialization() { + let file_info = PatchFileInfo { + before_hash: "aaa".to_string(), + after_hash: "bbb".to_string(), + }; + let json = serde_json::to_string(&file_info).unwrap(); + assert!(json.contains("beforeHash")); + assert!(json.contains("afterHash")); + assert!(!json.contains("before_hash")); + assert!(!json.contains("after_hash")); + } + + #[test] + fn test_patch_record_camel_case() { + let record = PatchRecord { + uuid: "test-uuid".to_string(), + exported_at: "2024-01-01T00:00:00Z".to_string(), + files: HashMap::new(), + vulnerabilities: HashMap::new(), + description: "test".to_string(), + license: "MIT".to_string(), + tier: "free".to_string(), + }; + let json = serde_json::to_string(&record).unwrap(); + assert!(json.contains("exportedAt")); + assert!(!json.contains("exported_at")); + } +} diff --git a/crates/socket-patch-core/src/package_json/detect.rs b/crates/socket-patch-core/src/package_json/detect.rs new file mode 100644 index 0000000..e90f742 --- /dev/null +++ b/crates/socket-patch-core/src/package_json/detect.rs @@ -0,0 +1,461 @@ +/// Package manager type for selecting the correct command prefix. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum PackageManager { + Npm, + Pnpm, +} + +/// Get the socket-patch apply command for the given package manager. 
+fn socket_patch_command(pm: PackageManager) -> &'static str {
+    match pm {
+        PackageManager::Npm => "npx @socketsecurity/socket-patch apply --silent --ecosystems npm",
+        PackageManager::Pnpm => {
+            "pnpm dlx @socketsecurity/socket-patch apply --silent --ecosystems npm"
+        }
+    }
+}
+
+/// Legacy command patterns to detect existing configurations.
+const LEGACY_PATCH_PATTERNS: &[&str] = &[
+    "socket-patch apply",
+    "npx @socketsecurity/socket-patch apply",
+    "socket patch apply",
+];
+
+/// Check if a script string contains any known socket-patch apply pattern.
+fn script_is_configured(script: &str) -> bool {
+    LEGACY_PATCH_PATTERNS
+        .iter()
+        .any(|pattern| script.contains(pattern))
+}
+
+/// Status of setup script configuration (both postinstall and dependencies).
+#[derive(Debug, Clone)]
+pub struct ScriptSetupStatus {
+    pub postinstall_configured: bool,
+    pub postinstall_script: String,
+    pub dependencies_configured: bool,
+    pub dependencies_script: String,
+    pub needs_update: bool,
+}
+
+/// Check if package.json scripts are properly configured for socket-patch.
+/// Checks both the postinstall and dependencies lifecycle scripts.
+pub fn is_setup_configured(package_json: &serde_json::Value) -> ScriptSetupStatus {
+    let scripts = package_json.get("scripts");
+
+    let postinstall_script = scripts
+        .and_then(|s| s.get("postinstall"))
+        .and_then(|v| v.as_str())
+        .unwrap_or("")
+        .to_string();
+    let postinstall_configured = script_is_configured(&postinstall_script);
+
+    let dependencies_script = scripts
+        .and_then(|s| s.get("dependencies"))
+        .and_then(|v| v.as_str())
+        .unwrap_or("")
+        .to_string();
+    let dependencies_configured = script_is_configured(&dependencies_script);
+
+    ScriptSetupStatus {
+        postinstall_configured,
+        postinstall_script,
+        dependencies_configured,
+        dependencies_script,
+        needs_update: !postinstall_configured || !dependencies_configured,
+    }
+}
+
+/// Check if a package.json content string is properly configured.
+pub fn is_setup_configured_str(content: &str) -> ScriptSetupStatus {
+    match serde_json::from_str::<serde_json::Value>(content) {
+        Ok(val) => is_setup_configured(&val),
+        Err(_) => ScriptSetupStatus {
+            postinstall_configured: false,
+            postinstall_script: String::new(),
+            dependencies_configured: false,
+            dependencies_script: String::new(),
+            needs_update: true,
+        },
+    }
+}
+
+/// Generate an updated script that includes the socket-patch apply command.
+/// If already configured, returns unchanged. Otherwise prepends the command.
+pub fn generate_updated_script(current_script: &str, pm: PackageManager) -> String {
+    let command = socket_patch_command(pm);
+    let trimmed = current_script.trim();
+
+    // If empty, just add the socket-patch command.
+    if trimmed.is_empty() {
+        return command.to_string();
+    }
+
+    // If any socket-patch variant is already present, return unchanged.
+    if script_is_configured(trimmed) {
+        return trimmed.to_string();
+    }
+
+    // Prepend socket-patch command so it runs first.
+    format!("{command} && {trimmed}")
+}
+
+/// Update a package.json Value with socket-patch in both postinstall and
+/// dependencies scripts.
+/// Returns (modified, new_postinstall, new_dependencies).
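+///
+/// A usage sketch (hypothetical package contents):
+/// ```ignore
+/// let mut pkg = serde_json::json!({ "name": "demo" });
+/// let (modified, postinstall, dependencies) =
+///     update_package_json_object(&mut pkg, PackageManager::Npm);
+/// assert!(modified);
+/// ```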
+pub fn update_package_json_object( + package_json: &mut serde_json::Value, + pm: PackageManager, +) -> (bool, String, String) { + let status = is_setup_configured(package_json); + + if !status.needs_update { + return ( + false, + status.postinstall_script, + status.dependencies_script, + ); + } + + // Ensure scripts object exists + if package_json.get("scripts").is_none() { + package_json["scripts"] = serde_json::json!({}); + } + + let mut modified = false; + + let new_postinstall = if !status.postinstall_configured { + modified = true; + let s = generate_updated_script(&status.postinstall_script, pm); + package_json["scripts"]["postinstall"] = serde_json::Value::String(s.clone()); + s + } else { + status.postinstall_script + }; + + let new_dependencies = if !status.dependencies_configured { + modified = true; + let s = generate_updated_script(&status.dependencies_script, pm); + package_json["scripts"]["dependencies"] = serde_json::Value::String(s.clone()); + s + } else { + status.dependencies_script + }; + + (modified, new_postinstall, new_dependencies) +} + +/// Parse package.json content and update it with socket-patch scripts. +/// Returns (modified, new_content, old_postinstall, new_postinstall, +/// old_dependencies, new_dependencies). +pub fn update_package_json_content( + content: &str, + pm: PackageManager, +) -> Result<(bool, String, String, String, String, String), String> { + let mut package_json: serde_json::Value = + serde_json::from_str(content).map_err(|e| format!("Invalid package.json: {e}"))?; + + let status = is_setup_configured(&package_json); + + if !status.needs_update { + return Ok(( + false, + content.to_string(), + status.postinstall_script.clone(), + status.postinstall_script, + status.dependencies_script.clone(), + status.dependencies_script, + )); + } + + let old_postinstall = status.postinstall_script.clone(); + let old_dependencies = status.dependencies_script.clone(); + + let (_, new_postinstall, new_dependencies) = + update_package_json_object(&mut package_json, pm); + let new_content = serde_json::to_string_pretty(&package_json).unwrap() + "\n"; + + Ok(( + true, + new_content, + old_postinstall, + new_postinstall, + old_dependencies, + new_dependencies, + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── is_setup_configured ───────────────────────────────────────── + + #[test] + fn test_not_configured() { + let pkg: serde_json::Value = serde_json::json!({ + "name": "test", + "scripts": { + "build": "tsc" + } + }); + let status = is_setup_configured(&pkg); + assert!(!status.postinstall_configured); + assert!(!status.dependencies_configured); + assert!(status.needs_update); + } + + #[test] + fn test_postinstall_configured_dependencies_not() { + let pkg: serde_json::Value = serde_json::json!({ + "name": "test", + "scripts": { + "postinstall": "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + } + }); + let status = is_setup_configured(&pkg); + assert!(status.postinstall_configured); + assert!(!status.dependencies_configured); + assert!(status.needs_update); + } + + #[test] + fn test_both_configured() { + let pkg: serde_json::Value = serde_json::json!({ + "name": "test", + "scripts": { + "postinstall": "npx @socketsecurity/socket-patch apply --silent --ecosystems npm", + "dependencies": "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + } + }); + let status = is_setup_configured(&pkg); + assert!(status.postinstall_configured); + assert!(status.dependencies_configured); + assert!(!status.needs_update); + } + + 
#[test] + fn test_legacy_socket_patch_apply_recognized() { + let pkg: serde_json::Value = serde_json::json!({ + "scripts": { + "postinstall": "socket patch apply --silent --ecosystems npm", + "dependencies": "socket-patch apply" + } + }); + let status = is_setup_configured(&pkg); + assert!(status.postinstall_configured); + assert!(status.dependencies_configured); + assert!(!status.needs_update); + } + + #[test] + fn test_no_scripts() { + let pkg: serde_json::Value = serde_json::json!({"name": "test"}); + let status = is_setup_configured(&pkg); + assert!(!status.postinstall_configured); + assert!(status.postinstall_script.is_empty()); + assert!(!status.dependencies_configured); + assert!(status.dependencies_script.is_empty()); + } + + #[test] + fn test_no_postinstall() { + let pkg: serde_json::Value = serde_json::json!({ + "scripts": {"build": "tsc"} + }); + let status = is_setup_configured(&pkg); + assert!(!status.postinstall_configured); + assert!(status.postinstall_script.is_empty()); + } + + // ── is_setup_configured_str ───────────────────────────────────── + + #[test] + fn test_configured_str_invalid_json() { + let status = is_setup_configured_str("not json"); + assert!(!status.postinstall_configured); + assert!(status.needs_update); + } + + #[test] + fn test_configured_str_legacy_npx_pattern() { + let content = r#"{"scripts":{"postinstall":"npx @socketsecurity/socket-patch apply --silent"}}"#; + let status = is_setup_configured_str(content); + assert!(status.postinstall_configured); + } + + #[test] + fn test_configured_str_socket_dash_patch() { + let content = + r#"{"scripts":{"postinstall":"socket-patch apply --silent --ecosystems npm"}}"#; + let status = is_setup_configured_str(content); + assert!(status.postinstall_configured); + } + + #[test] + fn test_configured_str_pnpm_dlx_pattern() { + let content = r#"{"scripts":{"postinstall":"pnpm dlx @socketsecurity/socket-patch apply --silent --ecosystems npm"}}"#; + let status = is_setup_configured_str(content); + // "pnpm dlx @socketsecurity/socket-patch apply" contains "socket-patch apply" + assert!(status.postinstall_configured); + } + + // ── generate_updated_script ───────────────────────────────────── + + #[test] + fn test_generate_empty_npm() { + assert_eq!( + generate_updated_script("", PackageManager::Npm), + "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + ); + } + + #[test] + fn test_generate_empty_pnpm() { + assert_eq!( + generate_updated_script("", PackageManager::Pnpm), + "pnpm dlx @socketsecurity/socket-patch apply --silent --ecosystems npm" + ); + } + + #[test] + fn test_generate_prepend_npm() { + assert_eq!( + generate_updated_script("echo done", PackageManager::Npm), + "npx @socketsecurity/socket-patch apply --silent --ecosystems npm && echo done" + ); + } + + #[test] + fn test_generate_prepend_pnpm() { + assert_eq!( + generate_updated_script("echo done", PackageManager::Pnpm), + "pnpm dlx @socketsecurity/socket-patch apply --silent --ecosystems npm && echo done" + ); + } + + #[test] + fn test_generate_already_configured() { + let current = "socket-patch apply && echo done"; + assert_eq!( + generate_updated_script(current, PackageManager::Npm), + current + ); + } + + #[test] + fn test_generate_whitespace_only() { + let result = generate_updated_script(" \t ", PackageManager::Npm); + assert_eq!( + result, + "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + ); + } + + // ── update_package_json_object ────────────────────────────────── + + #[test] + fn 
test_update_object_creates_scripts() { + let mut pkg: serde_json::Value = serde_json::json!({"name": "test"}); + let (modified, new_postinstall, new_dependencies) = + update_package_json_object(&mut pkg, PackageManager::Npm); + assert!(modified); + assert!(new_postinstall.contains("npx @socketsecurity/socket-patch apply")); + assert!(new_dependencies.contains("npx @socketsecurity/socket-patch apply")); + assert!(pkg.get("scripts").is_some()); + assert!(pkg["scripts"]["postinstall"].is_string()); + assert!(pkg["scripts"]["dependencies"].is_string()); + } + + #[test] + fn test_update_object_creates_scripts_pnpm() { + let mut pkg: serde_json::Value = serde_json::json!({"name": "test"}); + let (modified, new_postinstall, new_dependencies) = + update_package_json_object(&mut pkg, PackageManager::Pnpm); + assert!(modified); + assert!(new_postinstall.contains("pnpm dlx @socketsecurity/socket-patch apply")); + assert!(new_dependencies.contains("pnpm dlx @socketsecurity/socket-patch apply")); + } + + #[test] + fn test_update_object_noop_when_both_configured() { + let mut pkg: serde_json::Value = serde_json::json!({ + "scripts": { + "postinstall": "npx @socketsecurity/socket-patch apply --silent --ecosystems npm", + "dependencies": "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + } + }); + let (modified, _, _) = update_package_json_object(&mut pkg, PackageManager::Npm); + assert!(!modified); + } + + #[test] + fn test_update_object_adds_dependencies_when_postinstall_exists() { + let mut pkg: serde_json::Value = serde_json::json!({ + "scripts": { + "postinstall": "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + } + }); + let (modified, _, new_dependencies) = + update_package_json_object(&mut pkg, PackageManager::Npm); + assert!(modified); + assert!(new_dependencies.contains("npx @socketsecurity/socket-patch apply")); + // postinstall should remain unchanged + assert_eq!( + pkg["scripts"]["postinstall"].as_str().unwrap(), + "npx @socketsecurity/socket-patch apply --silent --ecosystems npm" + ); + } + + // ── update_package_json_content ───────────────────────────────── + + #[test] + fn test_update_content_roundtrip_no_scripts() { + let content = r#"{"name": "test"}"#; + let (modified, new_content, old_pi, new_pi, old_dep, new_dep) = + update_package_json_content(content, PackageManager::Npm).unwrap(); + assert!(modified); + assert!(old_pi.is_empty()); + assert!(new_pi.contains("npx @socketsecurity/socket-patch apply")); + assert!(old_dep.is_empty()); + assert!(new_dep.contains("npx @socketsecurity/socket-patch apply")); + let parsed: serde_json::Value = serde_json::from_str(&new_content).unwrap(); + assert!(parsed["scripts"]["postinstall"].is_string()); + assert!(parsed["scripts"]["dependencies"].is_string()); + } + + #[test] + fn test_update_content_already_configured() { + let content = r#"{"scripts":{"postinstall":"socket patch apply --silent --ecosystems npm","dependencies":"socket patch apply --silent --ecosystems npm"}}"#; + let (modified, _, _, _, _, _) = + update_package_json_content(content, PackageManager::Npm).unwrap(); + assert!(!modified); + } + + #[test] + fn test_update_content_invalid_json() { + let result = update_package_json_content("not json", PackageManager::Npm); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Invalid package.json")); + } + + #[test] + fn test_update_content_pnpm() { + let content = r#"{"name": "test"}"#; + let (modified, new_content, _, new_pi, _, new_dep) = + update_package_json_content(content, 
PackageManager::Pnpm).unwrap();
+        assert!(modified);
+        assert!(new_pi.contains("pnpm dlx @socketsecurity/socket-patch apply"));
+        assert!(new_dep.contains("pnpm dlx @socketsecurity/socket-patch apply"));
+        let parsed: serde_json::Value = serde_json::from_str(&new_content).unwrap();
+        assert!(parsed["scripts"]["postinstall"]
+            .as_str()
+            .unwrap()
+            .contains("pnpm dlx"));
+        assert!(parsed["scripts"]["dependencies"]
+            .as_str()
+            .unwrap()
+            .contains("pnpm dlx"));
+    }
+}
diff --git a/crates/socket-patch-core/src/package_json/find.rs b/crates/socket-patch-core/src/package_json/find.rs
new file mode 100644
index 0000000..b72c5c7
--- /dev/null
+++ b/crates/socket-patch-core/src/package_json/find.rs
@@ -0,0 +1,688 @@
+use std::path::{Path, PathBuf};
+use tokio::fs;
+
+use super::detect::PackageManager;
+
+/// Detect the package manager based on lockfiles in the project root.
+/// Checks for pnpm-lock.yaml, pnpm-lock.yml, and pnpm-workspace.yaml.
+pub async fn detect_package_manager(start_path: &Path) -> PackageManager {
+    for name in &["pnpm-lock.yaml", "pnpm-lock.yml", "pnpm-workspace.yaml"] {
+        if fs::metadata(start_path.join(name)).await.is_ok() {
+            return PackageManager::Pnpm;
+        }
+    }
+    PackageManager::Npm
+}
+
+/// Workspace configuration type.
+#[derive(Debug, Clone)]
+pub enum WorkspaceType {
+    Npm,
+    Pnpm,
+    None,
+}
+
+/// Workspace configuration.
+#[derive(Debug, Clone)]
+pub struct WorkspaceConfig {
+    pub ws_type: WorkspaceType,
+    pub patterns: Vec<String>,
+}
+
+/// Location of a discovered package.json file.
+#[derive(Debug, Clone)]
+pub struct PackageJsonLocation {
+    pub path: PathBuf,
+    pub is_root: bool,
+    pub is_workspace: bool,
+    pub workspace_pattern: Option<String>,
+}
+
+/// Result of finding package.json files.
+#[derive(Debug)]
+pub struct PackageJsonFindResult {
+    pub files: Vec<PackageJsonLocation>,
+    pub workspace_type: WorkspaceType,
+}
+
+/// Find all package.json files, respecting workspace configurations.
+pub async fn find_package_json_files(
+    start_path: &Path,
+) -> PackageJsonFindResult {
+    let mut results = Vec::new();
+    let root_package_json = start_path.join("package.json");
+
+    let mut root_exists = false;
+    let mut workspace_config = WorkspaceConfig {
+        ws_type: WorkspaceType::None,
+        patterns: Vec::new(),
+    };
+
+    if fs::metadata(&root_package_json).await.is_ok() {
+        root_exists = true;
+        workspace_config = detect_workspaces(&root_package_json).await;
+        results.push(PackageJsonLocation {
+            path: root_package_json,
+            is_root: true,
+            is_workspace: false,
+            workspace_pattern: None,
+        });
+    }
+
+    match workspace_config.ws_type {
+        WorkspaceType::None => {
+            if root_exists {
+                let nested = find_nested_package_json_files(start_path).await;
+                results.extend(nested);
+            }
+        }
+        _ => {
+            let ws_packages =
+                find_workspace_packages(start_path, &workspace_config).await;
+            results.extend(ws_packages);
+        }
+    }
+
+    PackageJsonFindResult {
+        files: results,
+        workspace_type: workspace_config.ws_type,
+    }
+}
+
+/// Detect workspace configuration from package.json.
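+///
+/// A usage sketch (hypothetical path):
+/// ```ignore
+/// let config = detect_workspaces(Path::new("package.json")).await;
+/// if matches!(config.ws_type, WorkspaceType::Pnpm) {
+///     println!("pnpm workspace patterns: {:?}", config.patterns);
+/// }
+/// ```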
+pub async fn detect_workspaces(package_json_path: &Path) -> WorkspaceConfig {
+    let default = WorkspaceConfig {
+        ws_type: WorkspaceType::None,
+        patterns: Vec::new(),
+    };
+
+    let content = match fs::read_to_string(package_json_path).await {
+        Ok(c) => c,
+        Err(_) => return default,
+    };
+
+    let pkg: serde_json::Value = match serde_json::from_str(&content) {
+        Ok(v) => v,
+        Err(_) => return default,
+    };
+
+    // Check for pnpm workspaces first -- pnpm projects may also have
+    // "workspaces" in package.json for compatibility, but
+    // pnpm-workspace.yaml is the definitive signal.
+    let dir = package_json_path.parent().unwrap_or(Path::new("."));
+    let pnpm_workspace = dir.join("pnpm-workspace.yaml");
+    if let Ok(yaml_content) = fs::read_to_string(&pnpm_workspace).await {
+        let patterns = parse_pnpm_workspace_patterns(&yaml_content);
+        return WorkspaceConfig {
+            ws_type: WorkspaceType::Pnpm,
+            patterns,
+        };
+    }
+
+    // Check for npm/yarn workspaces
+    if let Some(workspaces) = pkg.get("workspaces") {
+        let patterns = if let Some(arr) = workspaces.as_array() {
+            arr.iter()
+                .filter_map(|v| v.as_str().map(String::from))
+                .collect()
+        } else if let Some(obj) = workspaces.as_object() {
+            obj.get("packages")
+                .and_then(|v| v.as_array())
+                .map(|arr| {
+                    arr.iter()
+                        .filter_map(|v| v.as_str().map(String::from))
+                        .collect()
+                })
+                .unwrap_or_default()
+        } else {
+            Vec::new()
+        };
+
+        return WorkspaceConfig {
+            ws_type: WorkspaceType::Npm,
+            patterns,
+        };
+    }
+
+    default
+}
+
+/// Simple parser for pnpm-workspace.yaml packages field.
+fn parse_pnpm_workspace_patterns(yaml_content: &str) -> Vec<String> {
+    let mut patterns = Vec::new();
+    let mut in_packages = false;
+
+    for line in yaml_content.lines() {
+        let trimmed = line.trim();
+
+        if trimmed == "packages:" {
+            in_packages = true;
+            continue;
+        }
+
+        if in_packages {
+            if !trimmed.is_empty()
+                && !trimmed.starts_with('-')
+                && !trimmed.starts_with('#')
+            {
+                break;
+            }
+
+            if let Some(rest) = trimmed.strip_prefix('-') {
+                let item = rest.trim().trim_matches('\'').trim_matches('"');
+                if !item.is_empty() {
+                    patterns.push(item.to_string());
+                }
+            }
+        }
+    }
+
+    patterns
+}
+
+/// Find workspace packages based on workspace patterns.
+async fn find_workspace_packages(
+    root_path: &Path,
+    config: &WorkspaceConfig,
+) -> Vec<PackageJsonLocation> {
+    let mut results = Vec::new();
+
+    for pattern in &config.patterns {
+        let packages = find_packages_matching_pattern(root_path, pattern).await;
+        for p in packages {
+            results.push(PackageJsonLocation {
+                path: p,
+                is_root: false,
+                is_workspace: true,
+                workspace_pattern: Some(pattern.clone()),
+            });
+        }
+    }
+
+    results
+}
+
+/// Find packages matching a workspace pattern.
+async fn find_packages_matching_pattern(
+    root_path: &Path,
+    pattern: &str,
+) -> Vec<PathBuf> {
+    let mut results = Vec::new();
+    let parts: Vec<&str> = pattern.split('/').collect();
+
+    if parts.len() == 2 && parts[1] == "*" {
+        let search_path = root_path.join(parts[0]);
+        search_one_level(&search_path, &mut results).await;
+    } else if parts.len() == 2 && parts[1] == "**" {
+        let search_path = root_path.join(parts[0]);
+        search_recursive(&search_path, &mut results).await;
+    } else {
+        let pkg_json = root_path.join(pattern).join("package.json");
+        if fs::metadata(&pkg_json).await.is_ok() {
+            results.push(pkg_json);
+        }
+    }
+
+    results
+}
+
+/// Search one level deep for package.json files.
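+///
+/// A sketch of intended use for a single-glob pattern like `packages/*`
+/// (hypothetical directory layout):
+/// ```ignore
+/// let mut found: Vec<PathBuf> = Vec::new();
+/// search_one_level(Path::new("packages"), &mut found).await;
+/// ```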
+async fn search_one_level(dir: &Path, results: &mut Vec<PathBuf>) {
+    let mut entries = match fs::read_dir(dir).await {
+        Ok(e) => e,
+        Err(_) => return,
+    };
+
+    while let Ok(Some(entry)) = entries.next_entry().await {
+        let ft = match entry.file_type().await {
+            Ok(ft) => ft,
+            Err(_) => continue,
+        };
+        if !ft.is_dir() {
+            continue;
+        }
+        let pkg_json = entry.path().join("package.json");
+        if fs::metadata(&pkg_json).await.is_ok() {
+            results.push(pkg_json);
+        }
+    }
+}
+
+/// Search recursively for package.json files.
+async fn search_recursive(dir: &Path, results: &mut Vec<PathBuf>) {
+    let mut entries = match fs::read_dir(dir).await {
+        Ok(e) => e,
+        Err(_) => return,
+    };
+
+    while let Ok(Some(entry)) = entries.next_entry().await {
+        let ft = match entry.file_type().await {
+            Ok(ft) => ft,
+            Err(_) => continue,
+        };
+        if !ft.is_dir() {
+            continue;
+        }
+
+        let name = entry.file_name();
+        let name_str = name.to_string_lossy();
+
+        // Skip hidden directories, node_modules, dist, build
+        if name_str.starts_with('.')
+            || name_str == "node_modules"
+            || name_str == "dist"
+            || name_str == "build"
+        {
+            continue;
+        }
+
+        let full_path = entry.path();
+        let pkg_json = full_path.join("package.json");
+        if fs::metadata(&pkg_json).await.is_ok() {
+            results.push(pkg_json);
+        }
+
+        Box::pin(search_recursive(&full_path, results)).await;
+    }
+}
+
+/// Find nested package.json files without workspace configuration.
+async fn find_nested_package_json_files(
+    start_path: &Path,
+) -> Vec<PackageJsonLocation> {
+    let mut results = Vec::new();
+    let root_pkg = start_path.join("package.json");
+    search_nested(start_path, &root_pkg, 0, &mut results).await;
+    results
+}
+
+async fn search_nested(
+    dir: &Path,
+    root_pkg: &Path,
+    depth: usize,
+    results: &mut Vec<PackageJsonLocation>,
+) {
+    if depth > 5 {
+        return;
+    }
+
+    let mut entries = match fs::read_dir(dir).await {
+        Ok(e) => e,
+        Err(_) => return,
+    };
+
+    while let Ok(Some(entry)) = entries.next_entry().await {
+        let ft = match entry.file_type().await {
+            Ok(ft) => ft,
+            Err(_) => continue,
+        };
+        if !ft.is_dir() {
+            continue;
+        }
+
+        let name = entry.file_name();
+        let name_str = name.to_string_lossy();
+
+        if name_str.starts_with('.')
+            || name_str == "node_modules"
+            || name_str == "dist"
+            || name_str == "build"
+        {
+            continue;
+        }
+
+        let full_path = entry.path();
+        let pkg_json = full_path.join("package.json");
+        if fs::metadata(&pkg_json).await.is_ok() && pkg_json != root_pkg {
+            results.push(PackageJsonLocation {
+                path: pkg_json,
+                is_root: false,
+                is_workspace: false,
+                workspace_pattern: None,
+            });
+        }
+
+        Box::pin(search_nested(&full_path, root_pkg, depth + 1, results)).await;
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // ── Group 1: parse_pnpm_workspace_patterns ───────────────────────
+
+    #[test]
+    fn test_parse_pnpm_basic() {
+        let yaml = "packages:\n - packages/*";
+        assert_eq!(parse_pnpm_workspace_patterns(yaml), vec!["packages/*"]);
+    }
+
+    #[test]
+    fn test_parse_pnpm_multiple_patterns() {
+        let yaml = "packages:\n - packages/*\n - apps/*\n - tools/*";
+        assert_eq!(
+            parse_pnpm_workspace_patterns(yaml),
+            vec!["packages/*", "apps/*", "tools/*"]
+        );
+    }
+
+    #[test]
+    fn test_parse_pnpm_quoted_patterns() {
+        let yaml = "packages:\n - 'packages/*'\n - \"apps/*\"";
+        assert_eq!(
+            parse_pnpm_workspace_patterns(yaml),
+            vec!["packages/*", "apps/*"]
+        );
+    }
+
+    #[test]
+    fn test_parse_pnpm_comments_interspersed() {
+        let yaml = "packages:\n # workspace packages\n - packages/*\n # apps\n - apps/*";
+        assert_eq!(
+            parse_pnpm_workspace_patterns(yaml),
+            vec!["packages/*",
"apps/*"] + ); + } + + #[test] + fn test_parse_pnpm_empty_content() { + assert!(parse_pnpm_workspace_patterns("").is_empty()); + } + + #[test] + fn test_parse_pnpm_no_packages_key() { + let yaml = "name: my-project\nversion: 1.0.0"; + assert!(parse_pnpm_workspace_patterns(yaml).is_empty()); + } + + #[test] + fn test_parse_pnpm_stops_at_next_section() { + let yaml = "packages:\n - packages/*\ncatalog:\n lodash: 4.17.21"; + assert_eq!(parse_pnpm_workspace_patterns(yaml), vec!["packages/*"]); + } + + #[test] + fn test_parse_pnpm_indented_key() { + // The parser uses `trimmed == "packages:"` so leading spaces should match + let yaml = " packages:\n - packages/*"; + assert_eq!(parse_pnpm_workspace_patterns(yaml), vec!["packages/*"]); + } + + #[test] + fn test_parse_pnpm_dash_only_line() { + let yaml = "packages:\n -\n - packages/*"; + // A bare "-" with no value should be skipped (empty after trim) + assert_eq!(parse_pnpm_workspace_patterns(yaml), vec!["packages/*"]); + } + + #[test] + fn test_parse_pnpm_glob_star_star() { + let yaml = "packages:\n - packages/**"; + assert_eq!(parse_pnpm_workspace_patterns(yaml), vec!["packages/**"]); + } + + // ── Group 2: workspace detection + file discovery ──────────────── + + #[tokio::test] + async fn test_detect_workspaces_npm_array() { + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("package.json"); + fs::write(&pkg, r#"{"workspaces": ["packages/*"]}"#) + .await + .unwrap(); + let config = detect_workspaces(&pkg).await; + assert!(matches!(config.ws_type, WorkspaceType::Npm)); + assert_eq!(config.patterns, vec!["packages/*"]); + } + + #[tokio::test] + async fn test_detect_workspaces_npm_object() { + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("package.json"); + fs::write( + &pkg, + r#"{"workspaces": {"packages": ["packages/*", "apps/*"]}}"#, + ) + .await + .unwrap(); + let config = detect_workspaces(&pkg).await; + assert!(matches!(config.ws_type, WorkspaceType::Npm)); + assert_eq!(config.patterns, vec!["packages/*", "apps/*"]); + } + + #[tokio::test] + async fn test_detect_workspaces_pnpm() { + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("package.json"); + fs::write(&pkg, r#"{"name": "root"}"#).await.unwrap(); + let pnpm = dir.path().join("pnpm-workspace.yaml"); + fs::write(&pnpm, "packages:\n - packages/*") + .await + .unwrap(); + let config = detect_workspaces(&pkg).await; + assert!(matches!(config.ws_type, WorkspaceType::Pnpm)); + assert_eq!(config.patterns, vec!["packages/*"]); + } + + #[tokio::test] + async fn test_detect_workspaces_pnpm_with_workspaces_field() { + // When both pnpm-workspace.yaml AND "workspaces" in package.json + // exist, pnpm should take priority + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("package.json"); + fs::write( + &pkg, + r#"{"name": "root", "workspaces": ["packages/*"]}"#, + ) + .await + .unwrap(); + let pnpm = dir.path().join("pnpm-workspace.yaml"); + fs::write(&pnpm, "packages:\n - workspaces/*") + .await + .unwrap(); + let config = detect_workspaces(&pkg).await; + assert!(matches!(config.ws_type, WorkspaceType::Pnpm)); + // Should use pnpm-workspace.yaml patterns, not package.json workspaces + assert_eq!(config.patterns, vec!["workspaces/*"]); + } + + #[tokio::test] + async fn test_detect_workspaces_none() { + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("package.json"); + fs::write(&pkg, r#"{"name": "root"}"#).await.unwrap(); + let config = detect_workspaces(&pkg).await; + 
assert!(matches!(config.ws_type, WorkspaceType::None)); + assert!(config.patterns.is_empty()); + } + + #[tokio::test] + async fn test_detect_workspaces_invalid_json() { + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("package.json"); + fs::write(&pkg, "not valid json!!!").await.unwrap(); + let config = detect_workspaces(&pkg).await; + assert!(matches!(config.ws_type, WorkspaceType::None)); + } + + #[tokio::test] + async fn test_detect_workspaces_file_not_found() { + let dir = tempfile::tempdir().unwrap(); + let pkg = dir.path().join("nonexistent.json"); + let config = detect_workspaces(&pkg).await; + assert!(matches!(config.ws_type, WorkspaceType::None)); + } + + #[tokio::test] + async fn test_find_no_root_package_json() { + let dir = tempfile::tempdir().unwrap(); + let result = find_package_json_files(dir.path()).await; + assert!(result.files.is_empty()); + } + + #[tokio::test] + async fn test_find_root_only() { + let dir = tempfile::tempdir().unwrap(); + fs::write(dir.path().join("package.json"), r#"{"name":"root"}"#) + .await + .unwrap(); + let result = find_package_json_files(dir.path()).await; + assert_eq!(result.files.len(), 1); + assert!(result.files[0].is_root); + } + + #[tokio::test] + async fn test_find_npm_workspaces() { + let dir = tempfile::tempdir().unwrap(); + fs::write( + dir.path().join("package.json"), + r#"{"workspaces": ["packages/*"]}"#, + ) + .await + .unwrap(); + let pkg_a = dir.path().join("packages").join("a"); + fs::create_dir_all(&pkg_a).await.unwrap(); + fs::write(pkg_a.join("package.json"), r#"{"name":"a"}"#) + .await + .unwrap(); + let result = find_package_json_files(dir.path()).await; + assert!(matches!(result.workspace_type, WorkspaceType::Npm)); + // root + workspace member + assert_eq!(result.files.len(), 2); + assert!(result.files[0].is_root); + assert!(result.files[1].is_workspace); + } + + #[tokio::test] + async fn test_find_pnpm_workspaces() { + let dir = tempfile::tempdir().unwrap(); + fs::write(dir.path().join("package.json"), r#"{"name":"root"}"#) + .await + .unwrap(); + fs::write( + dir.path().join("pnpm-workspace.yaml"), + "packages:\n - packages/*", + ) + .await + .unwrap(); + let pkg_a = dir.path().join("packages").join("a"); + fs::create_dir_all(&pkg_a).await.unwrap(); + fs::write(pkg_a.join("package.json"), r#"{"name":"a"}"#) + .await + .unwrap(); + let result = find_package_json_files(dir.path()).await; + assert!(matches!(result.workspace_type, WorkspaceType::Pnpm)); + // find_package_json_files still returns all files; + // filtering for pnpm is done by the caller (setup command) + assert_eq!(result.files.len(), 2); + assert!(result.files[0].is_root); + assert!(result.files[1].is_workspace); + } + + #[tokio::test] + async fn test_find_nested_skips_node_modules() { + let dir = tempfile::tempdir().unwrap(); + fs::write(dir.path().join("package.json"), r#"{"name":"root"}"#) + .await + .unwrap(); + let nm = dir.path().join("node_modules").join("lodash"); + fs::create_dir_all(&nm).await.unwrap(); + fs::write(nm.join("package.json"), r#"{"name":"lodash"}"#) + .await + .unwrap(); + let result = find_package_json_files(dir.path()).await; + // Only root, node_modules should be skipped + assert_eq!(result.files.len(), 1); + assert!(result.files[0].is_root); + } + + #[tokio::test] + async fn test_find_nested_depth_limit() { + let dir = tempfile::tempdir().unwrap(); + fs::write(dir.path().join("package.json"), r#"{"name":"root"}"#) + .await + .unwrap(); + // Create deeply nested package.json at depth 7 (> limit of 5) + let mut 
deep = dir.path().to_path_buf();
+        for i in 0..7 {
+            deep = deep.join(format!("level{}", i));
+        }
+        fs::create_dir_all(&deep).await.unwrap();
+        fs::write(deep.join("package.json"), r#"{"name":"deep"}"#)
+            .await
+            .unwrap();
+        let result = find_package_json_files(dir.path()).await;
+        // Only root (the deep one exceeds depth limit)
+        assert_eq!(result.files.len(), 1);
+    }
+
+    #[tokio::test]
+    async fn test_find_workspace_double_glob() {
+        let dir = tempfile::tempdir().unwrap();
+        fs::write(
+            dir.path().join("package.json"),
+            r#"{"workspaces": ["apps/**"]}"#,
+        )
+        .await
+        .unwrap();
+        let nested = dir.path().join("apps").join("web").join("client");
+        fs::create_dir_all(&nested).await.unwrap();
+        fs::write(nested.join("package.json"), r#"{"name":"client"}"#)
+            .await
+            .unwrap();
+        let result = find_package_json_files(dir.path()).await;
+        // root + recursively found workspace member
+        assert!(result.files.len() >= 2);
+    }
+
+    #[tokio::test]
+    async fn test_find_workspace_exact_path() {
+        let dir = tempfile::tempdir().unwrap();
+        fs::write(
+            dir.path().join("package.json"),
+            r#"{"workspaces": ["packages/core"]}"#,
+        )
+        .await
+        .unwrap();
+        let core = dir.path().join("packages").join("core");
+        fs::create_dir_all(&core).await.unwrap();
+        fs::write(core.join("package.json"), r#"{"name":"core"}"#)
+            .await
+            .unwrap();
+        let result = find_package_json_files(dir.path()).await;
+        assert_eq!(result.files.len(), 2);
+    }
+
+    // ── detect_package_manager ──────────────────────────────────────
+
+    #[tokio::test]
+    async fn test_detect_npm_by_default() {
+        let dir = tempfile::tempdir().unwrap();
+        let pm = detect_package_manager(dir.path()).await;
+        assert_eq!(pm, PackageManager::Npm);
+    }
+
+    #[tokio::test]
+    async fn test_detect_pnpm_lock_yaml() {
+        let dir = tempfile::tempdir().unwrap();
+        fs::write(dir.path().join("pnpm-lock.yaml"), "lockfileVersion: 9.0\n")
+            .await
+            .unwrap();
+        let pm = detect_package_manager(dir.path()).await;
+        assert_eq!(pm, PackageManager::Pnpm);
+    }
+
+    #[tokio::test]
+    async fn test_detect_pnpm_workspace_yaml() {
+        let dir = tempfile::tempdir().unwrap();
+        fs::write(
+            dir.path().join("pnpm-workspace.yaml"),
+            "packages:\n - packages/*",
+        )
+        .await
+        .unwrap();
+        let pm = detect_package_manager(dir.path()).await;
+        assert_eq!(pm, PackageManager::Pnpm);
+    }
+}
diff --git a/crates/socket-patch-core/src/package_json/mod.rs b/crates/socket-patch-core/src/package_json/mod.rs
new file mode 100644
index 0000000..7e1546a
--- /dev/null
+++ b/crates/socket-patch-core/src/package_json/mod.rs
@@ -0,0 +1,3 @@
+pub mod detect;
+pub mod find;
+pub mod update;
diff --git a/crates/socket-patch-core/src/package_json/update.rs b/crates/socket-patch-core/src/package_json/update.rs
new file mode 100644
index 0000000..f8b859a
--- /dev/null
+++ b/crates/socket-patch-core/src/package_json/update.rs
@@ -0,0 +1,255 @@
+use std::path::Path;
+use tokio::fs;
+
+use super::detect::{is_setup_configured_str, update_package_json_content, PackageManager};
+
+/// Result of updating a single package.json.
+#[derive(Debug, Clone)]
+pub struct UpdateResult {
+    pub path: String,
+    pub status: UpdateStatus,
+    pub old_script: String,
+    pub new_script: String,
+    pub old_dependencies_script: String,
+    pub new_dependencies_script: String,
+    pub error: Option<String>,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum UpdateStatus {
+    Updated,
+    AlreadyConfigured,
+    Error,
+}
+
+/// Update a single package.json file with socket-patch lifecycle scripts.
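+///
+/// A usage sketch (hypothetical path; `dry_run = true` leaves the file untouched):
+/// ```ignore
+/// let result = update_package_json(Path::new("package.json"), true, PackageManager::Npm).await;
+/// match result.status {
+///     UpdateStatus::Updated => println!("would add: {}", result.new_script),
+///     UpdateStatus::AlreadyConfigured => println!("nothing to do"),
+///     UpdateStatus::Error => eprintln!("{:?}", result.error),
+/// }
+/// ```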
+pub async fn update_package_json( + package_json_path: &Path, + dry_run: bool, + pm: PackageManager, +) -> UpdateResult { + let path_str = package_json_path.display().to_string(); + + let content = match fs::read_to_string(package_json_path).await { + Ok(c) => c, + Err(e) => { + return UpdateResult { + path: path_str, + status: UpdateStatus::Error, + old_script: String::new(), + new_script: String::new(), + old_dependencies_script: String::new(), + new_dependencies_script: String::new(), + error: Some(e.to_string()), + }; + } + }; + + let status = is_setup_configured_str(&content); + if !status.needs_update { + return UpdateResult { + path: path_str, + status: UpdateStatus::AlreadyConfigured, + old_script: status.postinstall_script.clone(), + new_script: status.postinstall_script, + old_dependencies_script: status.dependencies_script.clone(), + new_dependencies_script: status.dependencies_script, + error: None, + }; + } + + match update_package_json_content(&content, pm) { + Ok((modified, new_content, old_pi, new_pi, old_dep, new_dep)) => { + if !modified { + return UpdateResult { + path: path_str, + status: UpdateStatus::AlreadyConfigured, + old_script: old_pi, + new_script: new_pi, + old_dependencies_script: old_dep, + new_dependencies_script: new_dep, + error: None, + }; + } + + if !dry_run { + if let Err(e) = fs::write(package_json_path, &new_content).await { + return UpdateResult { + path: path_str, + status: UpdateStatus::Error, + old_script: old_pi, + new_script: new_pi, + old_dependencies_script: old_dep, + new_dependencies_script: new_dep, + error: Some(e.to_string()), + }; + } + } + + UpdateResult { + path: path_str, + status: UpdateStatus::Updated, + old_script: old_pi, + new_script: new_pi, + old_dependencies_script: old_dep, + new_dependencies_script: new_dep, + error: None, + } + } + Err(e) => UpdateResult { + path: path_str, + status: UpdateStatus::Error, + old_script: String::new(), + new_script: String::new(), + old_dependencies_script: String::new(), + new_dependencies_script: String::new(), + error: Some(e), + }, + } +} + +/// Update multiple package.json files. 
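+///
+/// Illustrative sketch (`ignore`d doctest; the paths are example values).
+/// Results come back in the same order as `paths`, so callers can zip them:
+///
+/// ```ignore
+/// use std::path::Path;
+///
+/// let paths = vec![Path::new("package.json"), Path::new("packages/a/package.json")];
+/// let results = update_multiple_package_jsons(&paths, false, PackageManager::Npm).await;
+/// let updated = results.iter().filter(|r| r.status == UpdateStatus::Updated).count();
+/// println!("updated {} of {} file(s)", updated, results.len());
+/// ```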
+pub async fn update_multiple_package_jsons(
+    paths: &[&Path],
+    dry_run: bool,
+    pm: PackageManager,
+) -> Vec<UpdateResult> {
+    let mut results = Vec::new();
+    for path in paths {
+        let result = update_package_json(path, dry_run, pm).await;
+        results.push(result);
+    }
+    results
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_update_file_not_found() {
+        let dir = tempfile::tempdir().unwrap();
+        let missing = dir.path().join("nonexistent.json");
+        let result = update_package_json(&missing, false, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::Error);
+        assert!(result.error.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_update_already_configured() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        fs::write(
+            &pkg,
+            r#"{"name":"test","scripts":{"postinstall":"npx @socketsecurity/socket-patch apply --silent --ecosystems npm","dependencies":"npx @socketsecurity/socket-patch apply --silent --ecosystems npm"}}"#,
+        )
+        .await
+        .unwrap();
+        let result = update_package_json(&pkg, false, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::AlreadyConfigured);
+    }
+
+    #[tokio::test]
+    async fn test_update_dry_run_does_not_write() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        let original = r#"{"name":"test","scripts":{"build":"tsc"}}"#;
+        fs::write(&pkg, original).await.unwrap();
+        let result = update_package_json(&pkg, true, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::Updated);
+        // File should remain unchanged
+        let content = fs::read_to_string(&pkg).await.unwrap();
+        assert_eq!(content, original);
+    }
+
+    #[tokio::test]
+    async fn test_update_writes_file() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        fs::write(&pkg, r#"{"name":"test","scripts":{"build":"tsc"}}"#)
+            .await
+            .unwrap();
+        let result = update_package_json(&pkg, false, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::Updated);
+        let content = fs::read_to_string(&pkg).await.unwrap();
+        assert!(content.contains("npx @socketsecurity/socket-patch apply"));
+        assert!(content.contains("postinstall"));
+        assert!(content.contains("dependencies"));
+    }
+
+    #[tokio::test]
+    async fn test_update_invalid_json() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        fs::write(&pkg, "not json!!!").await.unwrap();
+        let result = update_package_json(&pkg, false, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::Error);
+        assert!(result.error.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_update_no_scripts_key() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        fs::write(&pkg, r#"{"name":"x"}"#).await.unwrap();
+        let result = update_package_json(&pkg, false, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::Updated);
+        let content = fs::read_to_string(&pkg).await.unwrap();
+        assert!(content.contains("postinstall"));
+        assert!(content.contains("dependencies"));
+        assert!(content.contains("npx @socketsecurity/socket-patch apply"));
+    }
+
+    #[tokio::test]
+    async fn test_update_pnpm() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        fs::write(&pkg, r#"{"name":"x"}"#).await.unwrap();
+        let result = update_package_json(&pkg, false, PackageManager::Pnpm).await;
+        assert_eq!(result.status, UpdateStatus::Updated);
+        let content = fs::read_to_string(&pkg).await.unwrap();
+        assert!(content.contains("pnpm dlx @socketsecurity/socket-patch apply"));
+    }
+
+    #[tokio::test]
+    async fn test_update_adds_dependencies_when_postinstall_exists() {
+        let dir = tempfile::tempdir().unwrap();
+        let pkg = dir.path().join("package.json");
+        fs::write(
+            &pkg,
+            r#"{"name":"test","scripts":{"postinstall":"npx @socketsecurity/socket-patch apply --silent --ecosystems npm"}}"#,
+        )
+        .await
+        .unwrap();
+        let result = update_package_json(&pkg, false, PackageManager::Npm).await;
+        assert_eq!(result.status, UpdateStatus::Updated);
+        let content = fs::read_to_string(&pkg).await.unwrap();
+        assert!(content.contains("dependencies"));
+    }
+
+    #[tokio::test]
+    async fn test_update_multiple_mixed() {
+        let dir = tempfile::tempdir().unwrap();
+
+        let p1 = dir.path().join("a.json");
+        fs::write(&p1, r#"{"name":"a"}"#).await.unwrap();
+
+        let p2 = dir.path().join("b.json");
+        fs::write(
+            &p2,
+            r#"{"name":"b","scripts":{"postinstall":"npx @socketsecurity/socket-patch apply --silent --ecosystems npm","dependencies":"npx @socketsecurity/socket-patch apply --silent --ecosystems npm"}}"#,
+        )
+        .await
+        .unwrap();
+
+        let p3 = dir.path().join("c.json");
+        // Don't create p3 — file not found
+
+        let paths: Vec<&Path> = vec![p1.as_path(), p2.as_path(), p3.as_path()];
+        let results = update_multiple_package_jsons(&paths, false, PackageManager::Npm).await;
+        assert_eq!(results.len(), 3);
+        assert_eq!(results[0].status, UpdateStatus::Updated);
+        assert_eq!(results[1].status, UpdateStatus::AlreadyConfigured);
+        assert_eq!(results[2].status, UpdateStatus::Error);
+    }
+}
diff --git a/crates/socket-patch-core/src/patch/apply.rs b/crates/socket-patch-core/src/patch/apply.rs
new file mode 100644
index 0000000..2fcb28d
--- /dev/null
+++ b/crates/socket-patch-core/src/patch/apply.rs
@@ -0,0 +1,680 @@
+use std::collections::HashMap;
+use std::path::Path;
+
+use crate::manifest::schema::PatchFileInfo;
+use crate::patch::file_hash::compute_file_git_sha256;
+
+/// Status of a file patch verification.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum VerifyStatus {
+    /// File is ready to be patched (current hash matches beforeHash).
+    Ready,
+    /// File is already in the patched state (current hash matches afterHash).
+    AlreadyPatched,
+    /// File hash does not match either beforeHash or afterHash.
+    HashMismatch,
+    /// File was not found on disk.
+    NotFound,
+}
+
+/// Result of verifying whether a single file can be patched.
+#[derive(Debug, Clone)]
+pub struct VerifyResult {
+    pub file: String,
+    pub status: VerifyStatus,
+    pub message: Option<String>,
+    pub current_hash: Option<String>,
+    pub expected_hash: Option<String>,
+    pub target_hash: Option<String>,
+}
+
+/// Result of applying patches to a single package.
+#[derive(Debug, Clone)]
+pub struct ApplyResult {
+    pub package_key: String,
+    pub package_path: String,
+    pub success: bool,
+    pub files_verified: Vec<VerifyResult>,
+    pub files_patched: Vec<String>,
+    pub error: Option<String>,
+}
+
+/// Normalize file path by removing the "package/" prefix if present.
+/// Patch files come from the API with paths like "package/lib/file.js"
+/// but we need relative paths like "lib/file.js" for the actual package directory.
+pub fn normalize_file_path(file_name: &str) -> &str {
+    const PACKAGE_PREFIX: &str = "package/";
+    if let Some(stripped) = file_name.strip_prefix(PACKAGE_PREFIX) {
+        stripped
+    } else {
+        file_name
+    }
+}
+
+/// Verify a single file can be patched.
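+///
+/// Illustrative sketch (`ignore`d doctest; the hash strings are placeholders
+/// for real git-sha256 values):
+///
+/// ```ignore
+/// let info = PatchFileInfo {
+///     before_hash: "<git-sha256 of unpatched file>".to_string(),
+///     after_hash: "<git-sha256 of patched file>".to_string(),
+/// };
+/// // Paths from the API may carry a "package/" prefix; it is stripped internally.
+/// let result = verify_file_patch(pkg_path, "package/lib/server.js", &info).await;
+/// match result.status {
+///     VerifyStatus::Ready => { /* safe to write the afterHash blob */ }
+///     VerifyStatus::AlreadyPatched => { /* nothing to do */ }
+///     other => eprintln!("cannot patch: {:?}", other),
+/// }
+/// ```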
+pub async fn verify_file_patch( + pkg_path: &Path, + file_name: &str, + file_info: &PatchFileInfo, +) -> VerifyResult { + let normalized = normalize_file_path(file_name); + let filepath = pkg_path.join(normalized); + + let is_new_file = file_info.before_hash.is_empty(); + + // Check if file exists + if tokio::fs::metadata(&filepath).await.is_err() { + // New files (empty beforeHash) are expected to not exist yet. + if is_new_file { + return VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::Ready, + message: None, + current_hash: None, + expected_hash: None, + target_hash: Some(file_info.after_hash.clone()), + }; + } + return VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::NotFound, + message: Some("File not found".to_string()), + current_hash: None, + expected_hash: None, + target_hash: None, + }; + } + + // Compute current hash + let current_hash = match compute_file_git_sha256(&filepath).await { + Ok(h) => h, + Err(e) => { + return VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::NotFound, + message: Some(format!("Failed to hash file: {}", e)), + current_hash: None, + expected_hash: None, + target_hash: None, + }; + } + }; + + // Check if already patched + if current_hash == file_info.after_hash { + return VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::AlreadyPatched, + message: None, + current_hash: Some(current_hash), + expected_hash: None, + target_hash: None, + }; + } + + // New files (empty beforeHash) with existing content that doesn't match + // afterHash: treat as Ready (force overwrite). + if is_new_file { + return VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::Ready, + message: None, + current_hash: Some(current_hash), + expected_hash: None, + target_hash: Some(file_info.after_hash.clone()), + }; + } + + // Check if matches expected before hash + if current_hash != file_info.before_hash { + return VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::HashMismatch, + message: Some("File hash does not match expected value".to_string()), + current_hash: Some(current_hash), + expected_hash: Some(file_info.before_hash.clone()), + target_hash: Some(file_info.after_hash.clone()), + }; + } + + VerifyResult { + file: file_name.to_string(), + status: VerifyStatus::Ready, + message: None, + current_hash: Some(current_hash), + expected_hash: None, + target_hash: Some(file_info.after_hash.clone()), + } +} + +/// Apply a patch to a single file. +/// Writes the patched content and verifies the resulting hash. +pub async fn apply_file_patch( + pkg_path: &Path, + file_name: &str, + patched_content: &[u8], + expected_hash: &str, +) -> Result<(), std::io::Error> { + let normalized = normalize_file_path(file_name); + let filepath = pkg_path.join(normalized); + + // Create parent directories if needed (e.g., new files added by a patch) + if let Some(parent) = filepath.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + // Make file writable if it exists and is read-only (e.g. 
Go module cache)
+    #[cfg(unix)]
+    if let Ok(meta) = tokio::fs::metadata(&filepath).await {
+        use std::os::unix::fs::PermissionsExt;
+        let perms = meta.permissions();
+        if perms.readonly() {
+            let mode = perms.mode();
+            let mut new_perms = perms;
+            new_perms.set_mode(mode | 0o200);
+            tokio::fs::set_permissions(&filepath, new_perms).await?;
+        }
+    }
+
+    // Write the patched content
+    tokio::fs::write(&filepath, patched_content).await?;
+
+    // Verify the hash after writing
+    let verify_hash = compute_file_git_sha256(&filepath).await?;
+    if verify_hash != expected_hash {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            format!(
+                "Hash verification failed after patch. Expected: {}, Got: {}",
+                expected_hash, verify_hash
+            ),
+        ));
+    }
+
+    Ok(())
+}
+
+/// Verify and apply patches for a single package.
+///
+/// For each file in `files`, this function:
+/// 1. Verifies the file is ready to be patched (or already patched).
+/// 2. If not dry_run, reads the blob from `blobs_path` and writes it.
+/// 3. Returns a summary of what happened.
+pub async fn apply_package_patch(
+    package_key: &str,
+    pkg_path: &Path,
+    files: &HashMap<String, PatchFileInfo>,
+    blobs_path: &Path,
+    dry_run: bool,
+    force: bool,
+) -> ApplyResult {
+    let mut result = ApplyResult {
+        package_key: package_key.to_string(),
+        package_path: pkg_path.display().to_string(),
+        success: false,
+        files_verified: Vec::new(),
+        files_patched: Vec::new(),
+        error: None,
+    };
+
+    // First, verify all files
+    for (file_name, file_info) in files {
+        let mut verify_result = verify_file_patch(pkg_path, file_name, file_info).await;
+
+        if verify_result.status != VerifyStatus::Ready
+            && verify_result.status != VerifyStatus::AlreadyPatched
+        {
+            if force {
+                match verify_result.status {
+                    VerifyStatus::HashMismatch => {
+                        // Force: treat hash mismatch as ready
+                        verify_result.status = VerifyStatus::Ready;
+                    }
+                    VerifyStatus::NotFound => {
+                        // Force: skip files that don't exist (non-new files)
+                        result.files_verified.push(verify_result);
+                        continue;
+                    }
+                    _ => {}
+                }
+            } else {
+                let msg = verify_result
+                    .message
+                    .clone()
+                    .unwrap_or_else(|| format!("{:?}", verify_result.status));
+                result.error = Some(format!(
+                    "Cannot apply patch: {} - {}",
+                    verify_result.file, msg
+                ));
+                result.files_verified.push(verify_result);
+                return result;
+            }
+        }
+
+        result.files_verified.push(verify_result);
+    }
+
+    // Check if all files are already patched
+    let all_already_patched = result
+        .files_verified
+        .iter()
+        .all(|v| v.status == VerifyStatus::AlreadyPatched);
+
+    if all_already_patched {
+        result.success = true;
+        return result;
+    }
+
+    // Check if all files are either already patched or not found (force mode skip)
+    let all_done_or_skipped = result
+        .files_verified
+        .iter()
+        .all(|v| v.status == VerifyStatus::AlreadyPatched || v.status == VerifyStatus::NotFound);
+
+    if all_done_or_skipped {
+        // Some or all files were not found but skipped via --force
+        let not_found_count = result.files_verified.iter()
+            .filter(|v| v.status == VerifyStatus::NotFound)
+            .count();
+        result.success = true;
+        result.error = Some(format!(
+            "All patch files were skipped: {} not found on disk (--force)",
+            not_found_count
+        ));
+        return result;
+    }
+
+    // If dry run, stop here
+    if dry_run {
+        result.success = true;
+        return result;
+    }
+
+    // Apply patches to files that need it
+    for (file_name, file_info) in files {
+        let verify_result = result.files_verified.iter().find(|v| v.file == *file_name);
+        if let Some(vr) = verify_result {
+            if vr.status ==
VerifyStatus::AlreadyPatched + || vr.status == VerifyStatus::NotFound + { + continue; + } + } + + // Read patched content from blobs + let blob_path = blobs_path.join(&file_info.after_hash); + let patched_content = match tokio::fs::read(&blob_path).await { + Ok(content) => content, + Err(e) => { + result.error = Some(format!( + "Failed to read blob {}: {}", + file_info.after_hash, e + )); + return result; + } + }; + + // Apply the patch + if let Err(e) = apply_file_patch(pkg_path, file_name, &patched_content, &file_info.after_hash).await { + result.error = Some(e.to_string()); + return result; + } + + result.files_patched.push(file_name.clone()); + } + + result.success = true; + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hash::git_sha256::compute_git_sha256_from_bytes; + + #[test] + fn test_normalize_file_path_with_prefix() { + assert_eq!(normalize_file_path("package/lib/server.js"), "lib/server.js"); + } + + #[test] + fn test_normalize_file_path_without_prefix() { + assert_eq!(normalize_file_path("lib/server.js"), "lib/server.js"); + } + + #[test] + fn test_normalize_file_path_just_prefix() { + assert_eq!(normalize_file_path("package/"), ""); + } + + #[test] + fn test_normalize_file_path_package_not_prefix() { + // "package" without trailing "/" should NOT be stripped + assert_eq!(normalize_file_path("packagefoo/bar.js"), "packagefoo/bar.js"); + } + + #[tokio::test] + async fn test_verify_file_patch_not_found() { + let dir = tempfile::tempdir().unwrap(); + let file_info = PatchFileInfo { + before_hash: "aaa".to_string(), + after_hash: "bbb".to_string(), + }; + + let result = verify_file_patch(dir.path(), "nonexistent.js", &file_info).await; + assert_eq!(result.status, VerifyStatus::NotFound); + } + + #[tokio::test] + async fn test_verify_file_patch_ready() { + let dir = tempfile::tempdir().unwrap(); + let content = b"original content"; + let before_hash = compute_git_sha256_from_bytes(content); + let after_hash = "bbbbbbbb".to_string(); + + tokio::fs::write(dir.path().join("index.js"), content) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: before_hash.clone(), + after_hash, + }; + + let result = verify_file_patch(dir.path(), "index.js", &file_info).await; + assert_eq!(result.status, VerifyStatus::Ready); + assert_eq!(result.current_hash.unwrap(), before_hash); + } + + #[tokio::test] + async fn test_verify_file_patch_already_patched() { + let dir = tempfile::tempdir().unwrap(); + let content = b"patched content"; + let after_hash = compute_git_sha256_from_bytes(content); + + tokio::fs::write(dir.path().join("index.js"), content) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: "aaaa".to_string(), + after_hash: after_hash.clone(), + }; + + let result = verify_file_patch(dir.path(), "index.js", &file_info).await; + assert_eq!(result.status, VerifyStatus::AlreadyPatched); + } + + #[tokio::test] + async fn test_verify_file_patch_hash_mismatch() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("index.js"), b"something else") + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: "aaaa".to_string(), + after_hash: "bbbb".to_string(), + }; + + let result = verify_file_patch(dir.path(), "index.js", &file_info).await; + assert_eq!(result.status, VerifyStatus::HashMismatch); + } + + #[tokio::test] + async fn test_verify_with_package_prefix() { + let dir = tempfile::tempdir().unwrap(); + let content = b"original content"; + let before_hash = 
compute_git_sha256_from_bytes(content); + + // File is at lib/server.js but patch refers to package/lib/server.js + tokio::fs::create_dir_all(dir.path().join("lib")).await.unwrap(); + tokio::fs::write(dir.path().join("lib/server.js"), content) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: before_hash.clone(), + after_hash: "bbbb".to_string(), + }; + + let result = verify_file_patch(dir.path(), "package/lib/server.js", &file_info).await; + assert_eq!(result.status, VerifyStatus::Ready); + } + + #[tokio::test] + async fn test_apply_file_patch_success() { + let dir = tempfile::tempdir().unwrap(); + let original = b"original"; + let patched = b"patched content"; + let patched_hash = compute_git_sha256_from_bytes(patched); + + tokio::fs::write(dir.path().join("index.js"), original) + .await + .unwrap(); + + apply_file_patch(dir.path(), "index.js", patched, &patched_hash) + .await + .unwrap(); + + let written = tokio::fs::read(dir.path().join("index.js")).await.unwrap(); + assert_eq!(written, patched); + } + + #[tokio::test] + async fn test_apply_file_patch_hash_mismatch() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("index.js"), b"original") + .await + .unwrap(); + + let result = + apply_file_patch(dir.path(), "index.js", b"patched content", "wrong_hash").await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("Hash verification failed")); + } + + #[tokio::test] + async fn test_apply_package_patch_success() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let patched = b"patched content"; + let before_hash = compute_git_sha256_from_bytes(original); + let after_hash = compute_git_sha256_from_bytes(patched); + + // Write original file + tokio::fs::write(pkg_dir.path().join("index.js"), original) + .await + .unwrap(); + + // Write blob + tokio::fs::write(blobs_dir.path().join(&after_hash), patched) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash, + after_hash: after_hash.clone(), + }, + ); + + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, false) + .await; + + assert!(result.success); + assert_eq!(result.files_patched.len(), 1); + assert!(result.error.is_none()); + } + + #[tokio::test] + async fn test_apply_package_patch_dry_run() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let before_hash = compute_git_sha256_from_bytes(original); + + tokio::fs::write(pkg_dir.path().join("index.js"), original) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash, + after_hash: "bbbb".to_string(), + }, + ); + + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), true, false) + .await; + + assert!(result.success); + assert_eq!(result.files_patched.len(), 0); // dry run: nothing actually patched + + // File should still have original content + let content = tokio::fs::read(pkg_dir.path().join("index.js")).await.unwrap(); + assert_eq!(content, original); + } + + #[tokio::test] + async fn test_apply_package_patch_all_already_patched() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let patched = b"patched content"; + 
let after_hash = compute_git_sha256_from_bytes(patched); + + tokio::fs::write(pkg_dir.path().join("index.js"), patched) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash: "aaaa".to_string(), + after_hash, + }, + ); + + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, false) + .await; + + assert!(result.success); + assert_eq!(result.files_patched.len(), 0); + } + + #[tokio::test] + async fn test_apply_package_patch_hash_mismatch_blocks() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + tokio::fs::write(pkg_dir.path().join("index.js"), b"something unexpected") + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash: "aaaa".to_string(), + after_hash: "bbbb".to_string(), + }, + ); + + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, false) + .await; + + assert!(!result.success); + assert!(result.error.is_some()); + } + + #[tokio::test] + async fn test_apply_package_patch_force_hash_mismatch() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let patched = b"patched content"; + let after_hash = compute_git_sha256_from_bytes(patched); + + // Write a file whose hash does NOT match before_hash + tokio::fs::write(pkg_dir.path().join("index.js"), b"something unexpected") + .await + .unwrap(); + + // Write blob + tokio::fs::write(blobs_dir.path().join(&after_hash), patched) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash: "aaaa".to_string(), + after_hash: after_hash.clone(), + }, + ); + + // Without force: should fail + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, false) + .await; + assert!(!result.success); + + // Reset the file + tokio::fs::write(pkg_dir.path().join("index.js"), b"something unexpected") + .await + .unwrap(); + + // With force: should succeed + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, true) + .await; + assert!(result.success); + assert_eq!(result.files_patched.len(), 1); + + let written = tokio::fs::read(pkg_dir.path().join("index.js")).await.unwrap(); + assert_eq!(written, patched); + } + + #[tokio::test] + async fn test_apply_package_patch_force_not_found_skips() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let mut files = HashMap::new(); + files.insert( + "missing.js".to_string(), + PatchFileInfo { + before_hash: "aaaa".to_string(), + after_hash: "bbbb".to_string(), + }, + ); + + // Without force: should fail (NotFound for non-new file) + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, false) + .await; + assert!(!result.success); + + // With force: should succeed by skipping the missing file + let result = + apply_package_patch("pkg:npm/test@1.0.0", pkg_dir.path(), &files, blobs_dir.path(), false, true) + .await; + assert!(result.success); + assert_eq!(result.files_patched.len(), 0); + } +} diff --git a/crates/socket-patch-core/src/patch/file_hash.rs b/crates/socket-patch-core/src/patch/file_hash.rs new file mode 100644 index 0000000..a9dc362 --- /dev/null +++ 
b/crates/socket-patch-core/src/patch/file_hash.rs
@@ -0,0 +1,75 @@
+use std::path::Path;
+
+use crate::hash::git_sha256::compute_git_sha256_from_reader;
+
+/// Compute Git-compatible SHA256 hash of file contents using streaming.
+///
+/// Gets the file size first, then streams the file through the hasher
+/// without loading the entire file into memory.
+pub async fn compute_file_git_sha256(filepath: impl AsRef<Path>) -> Result<String, std::io::Error> {
+    let filepath = filepath.as_ref();
+
+    // Get file size first
+    let metadata = tokio::fs::metadata(filepath).await?;
+    let file_size = metadata.len();
+
+    // Open file for streaming read
+    let file = tokio::fs::File::open(filepath).await?;
+    let reader = tokio::io::BufReader::new(file);
+
+    compute_git_sha256_from_reader(file_size, reader).await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::hash::git_sha256::compute_git_sha256_from_bytes;
+
+    #[tokio::test]
+    async fn test_compute_file_git_sha256_matches_bytes() {
+        let dir = tempfile::tempdir().unwrap();
+        let file_path = dir.path().join("test.txt");
+
+        let content = b"Hello, World!";
+        tokio::fs::write(&file_path, content).await.unwrap();
+
+        let file_hash = compute_file_git_sha256(&file_path).await.unwrap();
+        let bytes_hash = compute_git_sha256_from_bytes(content);
+
+        assert_eq!(file_hash, bytes_hash);
+    }
+
+    #[tokio::test]
+    async fn test_compute_file_git_sha256_empty_file() {
+        let dir = tempfile::tempdir().unwrap();
+        let file_path = dir.path().join("empty.txt");
+
+        tokio::fs::write(&file_path, b"").await.unwrap();
+
+        let file_hash = compute_file_git_sha256(&file_path).await.unwrap();
+        let bytes_hash = compute_git_sha256_from_bytes(b"");
+
+        assert_eq!(file_hash, bytes_hash);
+    }
+
+    #[tokio::test]
+    async fn test_compute_file_git_sha256_not_found() {
+        let result = compute_file_git_sha256("/nonexistent/file.txt").await;
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_compute_file_git_sha256_large_content() {
+        let dir = tempfile::tempdir().unwrap();
+        let file_path = dir.path().join("large.bin");
+
+        // Create a file larger than the 8192 byte buffer
+        let content: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
+        tokio::fs::write(&file_path, &content).await.unwrap();
+
+        let file_hash = compute_file_git_sha256(&file_path).await.unwrap();
+        let bytes_hash = compute_git_sha256_from_bytes(&content);
+
+        assert_eq!(file_hash, bytes_hash);
+    }
+}
diff --git a/crates/socket-patch-core/src/patch/mod.rs b/crates/socket-patch-core/src/patch/mod.rs
new file mode 100644
index 0000000..e17bd8d
--- /dev/null
+++ b/crates/socket-patch-core/src/patch/mod.rs
@@ -0,0 +1,3 @@
+pub mod apply;
+pub mod file_hash;
+pub mod rollback;
diff --git a/crates/socket-patch-core/src/patch/rollback.rs b/crates/socket-patch-core/src/patch/rollback.rs
new file mode 100644
index 0000000..959a551
--- /dev/null
+++ b/crates/socket-patch-core/src/patch/rollback.rs
@@ -0,0 +1,669 @@
+use std::collections::HashMap;
+use std::path::Path;
+
+use crate::manifest::schema::PatchFileInfo;
+use crate::patch::file_hash::compute_file_git_sha256;
+
+/// Status of a file rollback verification.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum VerifyRollbackStatus {
+    /// File is ready to be rolled back (current hash matches afterHash).
+    Ready,
+    /// File is already in the original state (current hash matches beforeHash).
+    AlreadyOriginal,
+    /// File hash does not match the expected afterHash.
+    HashMismatch,
+    /// File was not found on disk.
+    NotFound,
+    /// The before-hash blob needed for rollback is missing from the blobs directory.
+    MissingBlob,
+}
+
+/// Result of verifying whether a single file can be rolled back.
+#[derive(Debug, Clone)]
+pub struct VerifyRollbackResult {
+    pub file: String,
+    pub status: VerifyRollbackStatus,
+    pub message: Option<String>,
+    pub current_hash: Option<String>,
+    pub expected_hash: Option<String>,
+    pub target_hash: Option<String>,
+}
+
+/// Result of rolling back patches for a single package.
+#[derive(Debug, Clone)]
+pub struct RollbackResult {
+    pub package_key: String,
+    pub package_path: String,
+    pub success: bool,
+    pub files_verified: Vec<VerifyRollbackResult>,
+    pub files_rolled_back: Vec<String>,
+    pub error: Option<String>,
+}
+
+/// Normalize file path by removing the "package/" prefix if present.
+fn normalize_file_path(file_name: &str) -> &str {
+    const PACKAGE_PREFIX: &str = "package/";
+    if let Some(stripped) = file_name.strip_prefix(PACKAGE_PREFIX) {
+        stripped
+    } else {
+        file_name
+    }
+}
+
+/// Verify a single file can be rolled back.
+///
+/// A file is ready for rollback if:
+/// 1. The file exists on disk.
+/// 2. The before-hash blob exists in the blobs directory.
+/// 3. Its current hash matches the afterHash (patched state).
+pub async fn verify_file_rollback(
+    pkg_path: &Path,
+    file_name: &str,
+    file_info: &PatchFileInfo,
+    blobs_path: &Path,
+) -> VerifyRollbackResult {
+    let normalized = normalize_file_path(file_name);
+    let filepath = pkg_path.join(normalized);
+
+    let is_new_file = file_info.before_hash.is_empty();
+
+    // For new files (empty beforeHash), rollback means deleting the file.
+    if is_new_file {
+        if tokio::fs::metadata(&filepath).await.is_err() {
+            // File already doesn't exist — already rolled back.
+            return VerifyRollbackResult {
+                file: file_name.to_string(),
+                status: VerifyRollbackStatus::AlreadyOriginal,
+                message: None,
+                current_hash: None,
+                expected_hash: None,
+                target_hash: None,
+            };
+        }
+        let current_hash = compute_file_git_sha256(&filepath).await.unwrap_or_default();
+        if current_hash == file_info.after_hash {
+            return VerifyRollbackResult {
+                file: file_name.to_string(),
+                status: VerifyRollbackStatus::Ready,
+                message: None,
+                current_hash: Some(current_hash),
+                expected_hash: None,
+                target_hash: None,
+            };
+        }
+        return VerifyRollbackResult {
+            file: file_name.to_string(),
+            status: VerifyRollbackStatus::HashMismatch,
+            message: Some(
+                "File has been modified after patching. Cannot safely rollback.".to_string(),
+            ),
+            current_hash: Some(current_hash),
+            expected_hash: Some(file_info.after_hash.clone()),
+            target_hash: None,
+        };
+    }
+
+    // Check if file exists
+    if tokio::fs::metadata(&filepath).await.is_err() {
+        return VerifyRollbackResult {
+            file: file_name.to_string(),
+            status: VerifyRollbackStatus::NotFound,
+            message: Some("File not found".to_string()),
+            current_hash: None,
+            expected_hash: None,
+            target_hash: None,
+        };
+    }
+
+    // Check if before blob exists (required for rollback)
+    let before_blob_path = blobs_path.join(&file_info.before_hash);
+    if tokio::fs::metadata(&before_blob_path).await.is_err() {
+        return VerifyRollbackResult {
+            file: file_name.to_string(),
+            status: VerifyRollbackStatus::MissingBlob,
+            message: Some(format!(
+                "Before blob not found: {}. 
Re-download the patch to enable rollback.", + file_info.before_hash + )), + current_hash: None, + expected_hash: None, + target_hash: Some(file_info.before_hash.clone()), + }; + } + + // Compute current hash + let current_hash = match compute_file_git_sha256(&filepath).await { + Ok(h) => h, + Err(e) => { + return VerifyRollbackResult { + file: file_name.to_string(), + status: VerifyRollbackStatus::NotFound, + message: Some(format!("Failed to hash file: {}", e)), + current_hash: None, + expected_hash: None, + target_hash: None, + }; + } + }; + + // Check if already in original state + if current_hash == file_info.before_hash { + return VerifyRollbackResult { + file: file_name.to_string(), + status: VerifyRollbackStatus::AlreadyOriginal, + message: None, + current_hash: Some(current_hash), + expected_hash: None, + target_hash: None, + }; + } + + // Check if matches expected patched hash (afterHash) + if current_hash != file_info.after_hash { + return VerifyRollbackResult { + file: file_name.to_string(), + status: VerifyRollbackStatus::HashMismatch, + message: Some( + "File has been modified after patching. Cannot safely rollback.".to_string(), + ), + current_hash: Some(current_hash), + expected_hash: Some(file_info.after_hash.clone()), + target_hash: Some(file_info.before_hash.clone()), + }; + } + + VerifyRollbackResult { + file: file_name.to_string(), + status: VerifyRollbackStatus::Ready, + message: None, + current_hash: Some(current_hash), + expected_hash: None, + target_hash: Some(file_info.before_hash.clone()), + } +} + +/// Rollback a single file to its original state. +/// Writes the original content and verifies the resulting hash. +pub async fn rollback_file_patch( + pkg_path: &Path, + file_name: &str, + original_content: &[u8], + expected_hash: &str, +) -> Result<(), std::io::Error> { + let normalized = normalize_file_path(file_name); + let filepath = pkg_path.join(normalized); + + // Make file writable if it is read-only (e.g. Go module cache) + #[cfg(unix)] + if let Ok(meta) = tokio::fs::metadata(&filepath).await { + use std::os::unix::fs::PermissionsExt; + let perms = meta.permissions(); + if perms.readonly() { + let mode = perms.mode(); + let mut new_perms = perms; + new_perms.set_mode(mode | 0o200); + tokio::fs::set_permissions(&filepath, new_perms).await?; + } + } + + // Write the original content + tokio::fs::write(&filepath, original_content).await?; + + // Verify the hash after writing + let verify_hash = compute_file_git_sha256(&filepath).await?; + if verify_hash != expected_hash { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "Hash verification failed after rollback. Expected: {}, Got: {}", + expected_hash, verify_hash + ), + )); + } + + Ok(()) +} + +/// Verify and rollback patches for a single package. +/// +/// For each file in `files`, this function: +/// 1. Verifies the file is ready to be rolled back (or already original). +/// 2. If not dry_run, reads the before-hash blob and writes it back. +/// 3. Returns a summary of what happened. 
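+///
+/// Illustrative sketch (`ignore`d doctest; the purl, paths, and `files` map
+/// are example values):
+///
+/// ```ignore
+/// let result = rollback_package_patch(
+///     "pkg:npm/test@1.0.0",
+///     pkg_path,
+///     &files, // HashMap<String, PatchFileInfo>
+///     blobs_path,
+///     false, // dry_run
+/// )
+/// .await;
+/// if !result.success {
+///     eprintln!("rollback failed: {:?}", result.error);
+/// }
+/// ```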
+pub async fn rollback_package_patch(
+    package_key: &str,
+    pkg_path: &Path,
+    files: &HashMap<String, PatchFileInfo>,
+    blobs_path: &Path,
+    dry_run: bool,
+) -> RollbackResult {
+    let mut result = RollbackResult {
+        package_key: package_key.to_string(),
+        package_path: pkg_path.display().to_string(),
+        success: false,
+        files_verified: Vec::new(),
+        files_rolled_back: Vec::new(),
+        error: None,
+    };
+
+    // First, verify all files
+    for (file_name, file_info) in files {
+        let verify_result =
+            verify_file_rollback(pkg_path, file_name, file_info, blobs_path).await;
+
+        // If any file has issues (not ready and not already original), we can't proceed
+        if verify_result.status != VerifyRollbackStatus::Ready
+            && verify_result.status != VerifyRollbackStatus::AlreadyOriginal
+        {
+            let msg = verify_result
+                .message
+                .clone()
+                .unwrap_or_else(|| format!("{:?}", verify_result.status));
+            result.error = Some(format!(
+                "Cannot rollback: {} - {}",
+                verify_result.file, msg
+            ));
+            result.files_verified.push(verify_result);
+            return result;
+        }
+
+        result.files_verified.push(verify_result);
+    }
+
+    // Check if all files are already in original state
+    let all_original = result
+        .files_verified
+        .iter()
+        .all(|v| v.status == VerifyRollbackStatus::AlreadyOriginal);
+    if all_original {
+        result.success = true;
+        return result;
+    }
+
+    // If dry run, stop here
+    if dry_run {
+        result.success = true;
+        return result;
+    }
+
+    // Rollback files that need it
+    for (file_name, file_info) in files {
+        let verify_result = result
+            .files_verified
+            .iter()
+            .find(|v| v.file == *file_name);
+        if let Some(vr) = verify_result {
+            if vr.status == VerifyRollbackStatus::AlreadyOriginal {
+                continue;
+            }
+        }
+
+        // New files (empty beforeHash): delete instead of restoring.
+        if file_info.before_hash.is_empty() {
+            let normalized = normalize_file_path(file_name);
+            let filepath = pkg_path.join(normalized);
+            if let Err(e) = tokio::fs::remove_file(&filepath).await {
+                result.error = Some(format!("Failed to delete {}: {}", file_name, e));
+                return result;
+            }
+            result.files_rolled_back.push(file_name.clone());
+            continue;
+        }
+
+        // Read original content from blobs
+        let blob_path = blobs_path.join(&file_info.before_hash);
+        let original_content = match tokio::fs::read(&blob_path).await {
+            Ok(content) => content,
+            Err(e) => {
+                result.error = Some(format!(
+                    "Failed to read blob {}: {}",
+                    file_info.before_hash, e
+                ));
+                return result;
+            }
+        };
+
+        // Rollback the file
+        if let Err(e) =
+            rollback_file_patch(pkg_path, file_name, &original_content, &file_info.before_hash)
+                .await
+        {
+            result.error = Some(e.to_string());
+            return result;
+        }
+
+        result.files_rolled_back.push(file_name.clone());
+    }
+
+    result.success = true;
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::hash::git_sha256::compute_git_sha256_from_bytes;
+
+    #[tokio::test]
+    async fn test_verify_file_rollback_not_found() {
+        let pkg_dir = tempfile::tempdir().unwrap();
+        let blobs_dir = tempfile::tempdir().unwrap();
+
+        let file_info = PatchFileInfo {
+            before_hash: "aaa".to_string(),
+            after_hash: "bbb".to_string(),
+        };
+
+        let result =
+            verify_file_rollback(pkg_dir.path(), "nonexistent.js", &file_info, blobs_dir.path())
+                .await;
+        assert_eq!(result.status, VerifyRollbackStatus::NotFound);
+    }
+
+    #[tokio::test]
+    async fn test_verify_file_rollback_missing_blob() {
+        let pkg_dir = tempfile::tempdir().unwrap();
+        let blobs_dir = tempfile::tempdir().unwrap();
+
+        let content = b"patched content";
+        
tokio::fs::write(pkg_dir.path().join("index.js"), content) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: "missing_blob_hash".to_string(), + after_hash: compute_git_sha256_from_bytes(content), + }; + + let result = + verify_file_rollback(pkg_dir.path(), "index.js", &file_info, blobs_dir.path()).await; + assert_eq!(result.status, VerifyRollbackStatus::MissingBlob); + assert!(result.message.unwrap().contains("Before blob not found")); + } + + #[tokio::test] + async fn test_verify_file_rollback_ready() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let patched = b"patched content"; + let before_hash = compute_git_sha256_from_bytes(original); + let after_hash = compute_git_sha256_from_bytes(patched); + + // File is in patched state + tokio::fs::write(pkg_dir.path().join("index.js"), patched) + .await + .unwrap(); + + // Before blob exists + tokio::fs::write(blobs_dir.path().join(&before_hash), original) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: before_hash.clone(), + after_hash: after_hash.clone(), + }; + + let result = + verify_file_rollback(pkg_dir.path(), "index.js", &file_info, blobs_dir.path()).await; + assert_eq!(result.status, VerifyRollbackStatus::Ready); + assert_eq!(result.current_hash.unwrap(), after_hash); + } + + #[tokio::test] + async fn test_verify_file_rollback_already_original() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let before_hash = compute_git_sha256_from_bytes(original); + + // File is already in original state + tokio::fs::write(pkg_dir.path().join("index.js"), original) + .await + .unwrap(); + + // Before blob exists + tokio::fs::write(blobs_dir.path().join(&before_hash), original) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash: before_hash.clone(), + after_hash: "bbbb".to_string(), + }; + + let result = + verify_file_rollback(pkg_dir.path(), "index.js", &file_info, blobs_dir.path()).await; + assert_eq!(result.status, VerifyRollbackStatus::AlreadyOriginal); + } + + #[tokio::test] + async fn test_verify_file_rollback_hash_mismatch() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let before_hash = compute_git_sha256_from_bytes(original); + + // File has been modified to something unexpected + tokio::fs::write(pkg_dir.path().join("index.js"), b"something unexpected") + .await + .unwrap(); + + // Before blob exists + tokio::fs::write(blobs_dir.path().join(&before_hash), original) + .await + .unwrap(); + + let file_info = PatchFileInfo { + before_hash, + after_hash: "expected_after_hash".to_string(), + }; + + let result = + verify_file_rollback(pkg_dir.path(), "index.js", &file_info, blobs_dir.path()).await; + assert_eq!(result.status, VerifyRollbackStatus::HashMismatch); + assert!(result + .message + .unwrap() + .contains("modified after patching")); + } + + #[tokio::test] + async fn test_rollback_file_patch_success() { + let dir = tempfile::tempdir().unwrap(); + let original = b"original content"; + let original_hash = compute_git_sha256_from_bytes(original); + + // File currently has patched content + tokio::fs::write(dir.path().join("index.js"), b"patched") + .await + .unwrap(); + + rollback_file_patch(dir.path(), "index.js", original, &original_hash) + .await + .unwrap(); + + let written = 
tokio::fs::read(dir.path().join("index.js")).await.unwrap(); + assert_eq!(written, original); + } + + #[tokio::test] + async fn test_rollback_file_patch_hash_mismatch() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("index.js"), b"patched") + .await + .unwrap(); + + let result = + rollback_file_patch(dir.path(), "index.js", b"original content", "wrong_hash").await; + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Hash verification failed")); + } + + #[tokio::test] + async fn test_rollback_package_patch_success() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let patched = b"patched content"; + let before_hash = compute_git_sha256_from_bytes(original); + let after_hash = compute_git_sha256_from_bytes(patched); + + // File is in patched state + tokio::fs::write(pkg_dir.path().join("index.js"), patched) + .await + .unwrap(); + + // Before blob exists + tokio::fs::write(blobs_dir.path().join(&before_hash), original) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash: before_hash.clone(), + after_hash, + }, + ); + + let result = rollback_package_patch( + "pkg:npm/test@1.0.0", + pkg_dir.path(), + &files, + blobs_dir.path(), + false, + ) + .await; + + assert!(result.success); + assert_eq!(result.files_rolled_back.len(), 1); + assert!(result.error.is_none()); + + // Verify file was restored + let content = tokio::fs::read(pkg_dir.path().join("index.js")).await.unwrap(); + assert_eq!(content, original); + } + + #[tokio::test] + async fn test_rollback_package_patch_dry_run() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let patched = b"patched content"; + let before_hash = compute_git_sha256_from_bytes(original); + let after_hash = compute_git_sha256_from_bytes(patched); + + tokio::fs::write(pkg_dir.path().join("index.js"), patched) + .await + .unwrap(); + tokio::fs::write(blobs_dir.path().join(&before_hash), original) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash, + after_hash, + }, + ); + + let result = rollback_package_patch( + "pkg:npm/test@1.0.0", + pkg_dir.path(), + &files, + blobs_dir.path(), + true, // dry run + ) + .await; + + assert!(result.success); + assert_eq!(result.files_rolled_back.len(), 0); // dry run + + // File should still be patched + let content = tokio::fs::read(pkg_dir.path().join("index.js")).await.unwrap(); + assert_eq!(content, patched); + } + + #[tokio::test] + async fn test_rollback_package_patch_all_original() { + let pkg_dir = tempfile::tempdir().unwrap(); + let blobs_dir = tempfile::tempdir().unwrap(); + + let original = b"original content"; + let before_hash = compute_git_sha256_from_bytes(original); + + // File is already original + tokio::fs::write(pkg_dir.path().join("index.js"), original) + .await + .unwrap(); + tokio::fs::write(blobs_dir.path().join(&before_hash), original) + .await + .unwrap(); + + let mut files = HashMap::new(); + files.insert( + "index.js".to_string(), + PatchFileInfo { + before_hash, + after_hash: "bbbb".to_string(), + }, + ); + + let result = rollback_package_patch( + "pkg:npm/test@1.0.0", + pkg_dir.path(), + &files, + blobs_dir.path(), + false, + ) + .await; + + assert!(result.success); + 
assert_eq!(result.files_rolled_back.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_rollback_package_patch_missing_blob_blocks() {
+        let pkg_dir = tempfile::tempdir().unwrap();
+        let blobs_dir = tempfile::tempdir().unwrap();
+
+        tokio::fs::write(pkg_dir.path().join("index.js"), b"patched content")
+            .await
+            .unwrap();
+
+        let mut files = HashMap::new();
+        files.insert(
+            "index.js".to_string(),
+            PatchFileInfo {
+                before_hash: "missing_hash".to_string(),
+                after_hash: "bbbb".to_string(),
+            },
+        );
+
+        let result = rollback_package_patch(
+            "pkg:npm/test@1.0.0",
+            pkg_dir.path(),
+            &files,
+            blobs_dir.path(),
+            false,
+        )
+        .await;
+
+        assert!(!result.success);
+        assert!(result.error.is_some());
+    }
+}
diff --git a/crates/socket-patch-core/src/utils/cleanup_blobs.rs b/crates/socket-patch-core/src/utils/cleanup_blobs.rs
new file mode 100644
index 0000000..0121cb8
--- /dev/null
+++ b/crates/socket-patch-core/src/utils/cleanup_blobs.rs
@@ -0,0 +1,419 @@
+use std::path::Path;
+
+use crate::manifest::operations::get_after_hash_blobs;
+use crate::manifest::schema::PatchManifest;
+
+/// Result of a blob cleanup operation.
+#[derive(Debug, Clone)]
+pub struct CleanupResult {
+    pub blobs_checked: usize,
+    pub blobs_removed: usize,
+    pub bytes_freed: u64,
+    pub removed_blobs: Vec<String>,
+}
+
+/// Cleans up unused blob files from the blobs directory.
+///
+/// Analyzes the manifest to determine which afterHash blobs are needed for applying patches,
+/// then removes any blob files that are not needed.
+///
+/// Note: beforeHash blobs are considered "unused" because they are downloaded on-demand
+/// during rollback operations. This saves disk space since beforeHash blobs are only
+/// needed for rollback, not for applying patches.
+pub async fn cleanup_unused_blobs(
+    manifest: &PatchManifest,
+    blobs_dir: &Path,
+    dry_run: bool,
+) -> Result<CleanupResult, std::io::Error> {
+    // Only keep afterHash blobs - beforeHash blobs are downloaded on-demand during rollback
+    let used_blobs = get_after_hash_blobs(manifest);
+
+    // Check if blobs directory exists
+    if tokio::fs::metadata(blobs_dir).await.is_err() {
+        // Blobs directory doesn't exist, nothing to clean up
+        return Ok(CleanupResult {
+            blobs_checked: 0,
+            blobs_removed: 0,
+            bytes_freed: 0,
+            removed_blobs: vec![],
+        });
+    }
+
+    // Read all files in the blobs directory
+    let mut read_dir = tokio::fs::read_dir(blobs_dir).await?;
+    let mut blob_entries = Vec::new();
+
+    while let Some(entry) = read_dir.next_entry().await? {
+        blob_entries.push(entry);
+    }
+
+    let mut result = CleanupResult {
+        blobs_checked: blob_entries.len(),
+        blobs_removed: 0,
+        bytes_freed: 0,
+        removed_blobs: vec![],
+    };
+
+    // Check each blob file
+    for entry in &blob_entries {
+        let file_name = entry.file_name();
+        let file_name_str = file_name.to_string_lossy().to_string();
+
+        // Skip hidden files and directories
+        if file_name_str.starts_with('.') {
+            continue;
+        }
+
+        let blob_path = blobs_dir.join(&file_name_str);
+
+        // Check if it's a file (not a directory)
+        let metadata = tokio::fs::metadata(&blob_path).await?;
+        if !metadata.is_file() {
+            continue;
+        }
+
+        // If this blob is not in use, remove it
+        if !used_blobs.contains(&file_name_str) {
+            result.blobs_removed += 1;
+            result.bytes_freed += metadata.len();
+            result.removed_blobs.push(file_name_str);
+
+            if !dry_run {
+                tokio::fs::remove_file(&blob_path).await?;
+            }
+        }
+    }
+
+    Ok(result)
+}
+
+/// Formats the cleanup result for human-readable output.
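+///
+/// Example mirroring the unit tests below (`ignore`d doctest):
+///
+/// ```ignore
+/// let result = CleanupResult {
+///     blobs_checked: 5,
+///     blobs_removed: 2,
+///     bytes_freed: 2048,
+///     removed_blobs: vec!["aaa".to_string(), "bbb".to_string()],
+/// };
+/// assert_eq!(
+///     format_cleanup_result(&result, false),
+///     "Removed 2 unused blob(s) (2.00 KB freed)"
+/// );
+/// ```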
+pub fn format_cleanup_result(result: &CleanupResult, dry_run: bool) -> String { + if result.blobs_checked == 0 { + return "No blobs directory found, nothing to clean up.".to_string(); + } + + if result.blobs_removed == 0 { + return format!( + "Checked {} blob(s), all are in use.", + result.blobs_checked + ); + } + + let action = if dry_run { "Would remove" } else { "Removed" }; + let bytes_formatted = format_bytes(result.bytes_freed); + + let mut output = format!( + "{} {} unused blob(s) ({} freed)", + action, result.blobs_removed, bytes_formatted + ); + + if dry_run && !result.removed_blobs.is_empty() { + output.push_str("\nUnused blobs:"); + for blob in &result.removed_blobs { + output.push_str(&format!("\n - {}", blob)); + } + } + + output +} + +/// Formats bytes into a human-readable string. +pub fn format_bytes(bytes: u64) -> String { + if bytes == 0 { + return "0 B".to_string(); + } + + const KB: u64 = 1024; + const MB: u64 = 1024 * 1024; + const GB: u64 = 1024 * 1024 * 1024; + + if bytes < KB { + format!("{} B", bytes) + } else if bytes < MB { + format!("{:.2} KB", bytes as f64 / KB as f64) + } else if bytes < GB { + format!("{:.2} MB", bytes as f64 / MB as f64) + } else { + format!("{:.2} GB", bytes as f64 / GB as f64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord}; + use std::collections::HashMap; + + const TEST_UUID: &str = "11111111-1111-4111-8111-111111111111"; + const BEFORE_HASH_1: &str = + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1111"; + const AFTER_HASH_1: &str = + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb1111"; + const BEFORE_HASH_2: &str = + "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc2222"; + const AFTER_HASH_2: &str = + "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd2222"; + const ORPHAN_HASH: &str = + "oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo"; + + fn create_test_manifest() -> PatchManifest { + let mut files = HashMap::new(); + files.insert( + "package/index.js".to_string(), + PatchFileInfo { + before_hash: BEFORE_HASH_1.to_string(), + after_hash: AFTER_HASH_1.to_string(), + }, + ); + files.insert( + "package/lib/utils.js".to_string(), + PatchFileInfo { + before_hash: BEFORE_HASH_2.to_string(), + after_hash: AFTER_HASH_2.to_string(), + }, + ); + + let mut patches = HashMap::new(); + patches.insert( + "pkg:npm/pkg-a@1.0.0".to_string(), + PatchRecord { + uuid: TEST_UUID.to_string(), + exported_at: "2024-01-01T00:00:00Z".to_string(), + files, + vulnerabilities: HashMap::new(), + description: "Test patch".to_string(), + license: "MIT".to_string(), + tier: "free".to_string(), + }, + ); + + PatchManifest { patches } + } + + #[tokio::test] + async fn test_cleanup_keeps_after_hash_removes_orphan() { + let dir = tempfile::tempdir().unwrap(); + let blobs_dir = dir.path().join("blobs"); + tokio::fs::create_dir_all(&blobs_dir).await.unwrap(); + + let manifest = create_test_manifest(); + + // Create blobs on disk + tokio::fs::write(blobs_dir.join(AFTER_HASH_1), "after content 1") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(AFTER_HASH_2), "after content 2") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(ORPHAN_HASH), "orphan content") + .await + .unwrap(); + + let result = cleanup_unused_blobs(&manifest, &blobs_dir, false) + .await + .unwrap(); + + // Should remove only the orphan blob + assert_eq!(result.blobs_removed, 1); + 
assert!(result.removed_blobs.contains(&ORPHAN_HASH.to_string())); + + // afterHash blobs should still exist + assert!(tokio::fs::metadata(blobs_dir.join(AFTER_HASH_1)) + .await + .is_ok()); + assert!(tokio::fs::metadata(blobs_dir.join(AFTER_HASH_2)) + .await + .is_ok()); + + // Orphan blob should be removed + assert!(tokio::fs::metadata(blobs_dir.join(ORPHAN_HASH)) + .await + .is_err()); + } + + #[tokio::test] + async fn test_cleanup_removes_before_hash_blobs() { + let dir = tempfile::tempdir().unwrap(); + let blobs_dir = dir.path().join("blobs"); + tokio::fs::create_dir_all(&blobs_dir).await.unwrap(); + + let manifest = create_test_manifest(); + + // Create both beforeHash and afterHash blobs + tokio::fs::write(blobs_dir.join(BEFORE_HASH_1), "before content 1") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(BEFORE_HASH_2), "before content 2") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(AFTER_HASH_1), "after content 1") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(AFTER_HASH_2), "after content 2") + .await + .unwrap(); + + let result = cleanup_unused_blobs(&manifest, &blobs_dir, false) + .await + .unwrap(); + + // Should remove the beforeHash blobs + assert_eq!(result.blobs_removed, 2); + assert!(result.removed_blobs.contains(&BEFORE_HASH_1.to_string())); + assert!(result.removed_blobs.contains(&BEFORE_HASH_2.to_string())); + + // afterHash blobs should still exist + assert!(tokio::fs::metadata(blobs_dir.join(AFTER_HASH_1)) + .await + .is_ok()); + assert!(tokio::fs::metadata(blobs_dir.join(AFTER_HASH_2)) + .await + .is_ok()); + + // beforeHash blobs should be removed + assert!(tokio::fs::metadata(blobs_dir.join(BEFORE_HASH_1)) + .await + .is_err()); + assert!(tokio::fs::metadata(blobs_dir.join(BEFORE_HASH_2)) + .await + .is_err()); + } + + #[tokio::test] + async fn test_cleanup_dry_run_does_not_delete() { + let dir = tempfile::tempdir().unwrap(); + let blobs_dir = dir.path().join("blobs"); + tokio::fs::create_dir_all(&blobs_dir).await.unwrap(); + + let manifest = create_test_manifest(); + + tokio::fs::write(blobs_dir.join(BEFORE_HASH_1), "before content 1") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(AFTER_HASH_1), "after content 1") + .await + .unwrap(); + + let result = cleanup_unused_blobs(&manifest, &blobs_dir, true) + .await + .unwrap(); + + // Should report beforeHash as would-be-removed + assert_eq!(result.blobs_removed, 1); + assert!(result.removed_blobs.contains(&BEFORE_HASH_1.to_string())); + + // But both blobs should still exist + assert!(tokio::fs::metadata(blobs_dir.join(BEFORE_HASH_1)) + .await + .is_ok()); + assert!(tokio::fs::metadata(blobs_dir.join(AFTER_HASH_1)) + .await + .is_ok()); + } + + #[tokio::test] + async fn test_cleanup_empty_manifest_removes_all() { + let dir = tempfile::tempdir().unwrap(); + let blobs_dir = dir.path().join("blobs"); + tokio::fs::create_dir_all(&blobs_dir).await.unwrap(); + + let manifest = PatchManifest::new(); + + tokio::fs::write(blobs_dir.join(AFTER_HASH_1), "content 1") + .await + .unwrap(); + tokio::fs::write(blobs_dir.join(BEFORE_HASH_1), "content 2") + .await + .unwrap(); + + let result = cleanup_unused_blobs(&manifest, &blobs_dir, false) + .await + .unwrap(); + + assert_eq!(result.blobs_removed, 2); + } + + #[tokio::test] + async fn test_cleanup_nonexistent_blobs_dir() { + let dir = tempfile::tempdir().unwrap(); + let non_existent = dir.path().join("non-existent"); + + let manifest = create_test_manifest(); + + let result = cleanup_unused_blobs(&manifest, &non_existent, false) + .await + 
.unwrap(); + + assert_eq!(result.blobs_checked, 0); + assert_eq!(result.blobs_removed, 0); + } + + #[test] + fn test_format_bytes() { + assert_eq!(format_bytes(0), "0 B"); + assert_eq!(format_bytes(500), "500 B"); + assert_eq!(format_bytes(1023), "1023 B"); + assert_eq!(format_bytes(1024), "1.00 KB"); + assert_eq!(format_bytes(1536), "1.50 KB"); + assert_eq!(format_bytes(1048576), "1.00 MB"); + assert_eq!(format_bytes(1073741824), "1.00 GB"); + } + + #[test] + fn test_format_cleanup_result_no_blobs_dir() { + let result = CleanupResult { + blobs_checked: 0, + blobs_removed: 0, + bytes_freed: 0, + removed_blobs: vec![], + }; + assert_eq!( + format_cleanup_result(&result, false), + "No blobs directory found, nothing to clean up." + ); + } + + #[test] + fn test_format_cleanup_result_all_in_use() { + let result = CleanupResult { + blobs_checked: 5, + blobs_removed: 0, + bytes_freed: 0, + removed_blobs: vec![], + }; + assert_eq!( + format_cleanup_result(&result, false), + "Checked 5 blob(s), all are in use." + ); + } + + #[test] + fn test_format_cleanup_result_removed() { + let result = CleanupResult { + blobs_checked: 5, + blobs_removed: 2, + bytes_freed: 2048, + removed_blobs: vec!["aaa".to_string(), "bbb".to_string()], + }; + assert_eq!( + format_cleanup_result(&result, false), + "Removed 2 unused blob(s) (2.00 KB freed)" + ); + } + + #[test] + fn test_format_cleanup_result_dry_run_lists_blobs() { + let result = CleanupResult { + blobs_checked: 5, + blobs_removed: 2, + bytes_freed: 2048, + removed_blobs: vec!["aaa".to_string(), "bbb".to_string()], + }; + let formatted = format_cleanup_result(&result, true); + assert!(formatted.starts_with("Would remove 2 unused blob(s)")); + assert!(formatted.contains("Unused blobs:")); + assert!(formatted.contains(" - aaa")); + assert!(formatted.contains(" - bbb")); + } +} diff --git a/crates/socket-patch-core/src/utils/enumerate.rs b/crates/socket-patch-core/src/utils/enumerate.rs new file mode 100644 index 0000000..6535766 --- /dev/null +++ b/crates/socket-patch-core/src/utils/enumerate.rs @@ -0,0 +1,109 @@ +use std::path::Path; + +use crate::crawlers::types::{CrawledPackage, CrawlerOptions}; +use crate::crawlers::NpmCrawler; + +/// Type alias for backward compatibility with the TypeScript codebase. +pub type EnumeratedPackage = CrawledPackage; + +/// Enumerate all packages in a `node_modules` directory. +/// +/// This is a convenience wrapper around `NpmCrawler::crawl_all` that creates +/// a crawler with default options rooted at the given `cwd`. 
+pub async fn enumerate_node_modules(cwd: &Path) -> Vec<EnumeratedPackage> {
+    let crawler = NpmCrawler::new();
+    let options = CrawlerOptions {
+        cwd: cwd.to_path_buf(),
+        global: false,
+        global_prefix: None,
+        batch_size: 100,
+    };
+    crawler.crawl_all(&options).await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_enumerate_empty_dir() {
+        let dir = tempfile::tempdir().unwrap();
+        let packages = enumerate_node_modules(dir.path()).await;
+        assert!(packages.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_enumerate_with_packages() {
+        let dir = tempfile::tempdir().unwrap();
+        let nm = dir.path().join("node_modules");
+
+        // Create a simple package
+        let pkg_dir = nm.join("test-pkg");
+        tokio::fs::create_dir_all(&pkg_dir).await.unwrap();
+        tokio::fs::write(
+            pkg_dir.join("package.json"),
+            r#"{"name": "test-pkg", "version": "1.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        // Create a scoped package
+        let scoped_dir = nm.join("@scope").join("my-lib");
+        tokio::fs::create_dir_all(&scoped_dir).await.unwrap();
+        tokio::fs::write(
+            scoped_dir.join("package.json"),
+            r#"{"name": "@scope/my-lib", "version": "2.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        let packages = enumerate_node_modules(dir.path()).await;
+        assert_eq!(packages.len(), 2);
+
+        let purls: Vec<&str> = packages.iter().map(|p| p.purl.as_str()).collect();
+        assert!(purls.contains(&"pkg:npm/test-pkg@1.0.0"));
+        assert!(purls.contains(&"pkg:npm/@scope/my-lib@2.0.0"));
+    }
+
+    #[tokio::test]
+    async fn test_enumerate_deduplicates() {
+        let dir = tempfile::tempdir().unwrap();
+        let nm = dir.path().join("node_modules");
+
+        // Create package at top level
+        let pkg1 = nm.join("foo");
+        tokio::fs::create_dir_all(&pkg1).await.unwrap();
+        tokio::fs::write(
+            pkg1.join("package.json"),
+            r#"{"name": "foo", "version": "1.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        // Create same package nested inside another
+        let pkg2 = nm.join("bar");
+        tokio::fs::create_dir_all(&pkg2).await.unwrap();
+        tokio::fs::write(
+            pkg2.join("package.json"),
+            r#"{"name": "bar", "version": "2.0.0"}"#,
+        )
+        .await
+        .unwrap();
+        let nested_foo = pkg2.join("node_modules").join("foo");
+        tokio::fs::create_dir_all(&nested_foo).await.unwrap();
+        tokio::fs::write(
+            nested_foo.join("package.json"),
+            r#"{"name": "foo", "version": "1.0.0"}"#,
+        )
+        .await
+        .unwrap();
+
+        let packages = enumerate_node_modules(dir.path()).await;
+        // foo@1.0.0 should be deduplicated
+        let foo_count = packages
+            .iter()
+            .filter(|p| p.purl == "pkg:npm/foo@1.0.0")
+            .count();
+        assert_eq!(foo_count, 1);
+    }
+}
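For orientation, here is a minimal sketch of how the wrapper above might be driven from an async caller. It assumes a Tokio runtime; the `list_installed_packages` helper is hypothetical, while `purl` and `path` are the `CrawledPackage` fields exercised by the tests above.

```rust
use std::path::Path;

// Hypothetical caller of the `enumerate_node_modules` wrapper added above.
async fn list_installed_packages(project_root: &Path) {
    let packages = enumerate_node_modules(project_root).await;
    for pkg in &packages {
        // Each crawled package carries its PURL and on-disk location.
        println!("{} -> {}", pkg.purl, pkg.path.display());
    }
}
```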
diff --git a/crates/socket-patch-core/src/utils/fuzzy_match.rs b/crates/socket-patch-core/src/utils/fuzzy_match.rs
new file mode 100644
index 0000000..e508fa4
--- /dev/null
+++ b/crates/socket-patch-core/src/utils/fuzzy_match.rs
@@ -0,0 +1,266 @@
+use crate::crawlers::types::CrawledPackage;
+
+// ---------------------------------------------------------------------------
+// MatchType enum
+// ---------------------------------------------------------------------------
+
+/// Match type for sorting results by relevance.
+///
+/// Lower numeric value = better match. The ordering is:
+/// 1. Exact match on full name (including namespace)
+/// 2. Exact match on package name only
+/// 3. Prefix match on full name
+/// 4. Prefix match on package name
+/// 5. Contains match on full name
+/// 6. Contains match on package name
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum MatchType {
+    /// Exact match on full name (including namespace).
+    ExactFull = 0,
+    /// Exact match on package name only.
+    ExactName = 1,
+    /// Query is a prefix of the full name.
+    PrefixFull = 2,
+    /// Query is a prefix of the package name.
+    PrefixName = 3,
+    /// Query is contained in the full name.
+    ContainsFull = 4,
+    /// Query is contained in the package name.
+    ContainsName = 5,
+}
+
+// ---------------------------------------------------------------------------
+// Internal match result
+// ---------------------------------------------------------------------------
+
+struct MatchResult {
+    package: CrawledPackage,
+    match_type: MatchType,
+}
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+/// Get the full display name for a package (including namespace if present).
+fn get_full_name(pkg: &CrawledPackage) -> String {
+    match &pkg.namespace {
+        Some(ns) => format!("{ns}/{}", pkg.name),
+        None => pkg.name.clone(),
+    }
+}
+
+/// Determine the match type for a package against a query.
+/// Returns `None` if there is no match.
+fn get_match_type(pkg: &CrawledPackage, query: &str) -> Option<MatchType> {
+    let lower_query = query.to_lowercase();
+    let full_name = get_full_name(pkg).to_lowercase();
+    let name = pkg.name.to_lowercase();
+
+    // Check exact matches
+    if full_name == lower_query {
+        return Some(MatchType::ExactFull);
+    }
+    if name == lower_query {
+        return Some(MatchType::ExactName);
+    }
+
+    // Check prefix matches
+    if full_name.starts_with(&lower_query) {
+        return Some(MatchType::PrefixFull);
+    }
+    if name.starts_with(&lower_query) {
+        return Some(MatchType::PrefixName);
+    }
+
+    // Check contains matches
+    if full_name.contains(&lower_query) {
+        return Some(MatchType::ContainsFull);
+    }
+    if name.contains(&lower_query) {
+        return Some(MatchType::ContainsName);
+    }
+
+    None
+}
+
+// ---------------------------------------------------------------------------
+// Public API
+// ---------------------------------------------------------------------------
+
+/// Fuzzy match packages against a query string.
+///
+/// Matches are sorted by relevance:
+/// 1. Exact match on full name (e.g., `"@types/node"` matches `"@types/node"`)
+/// 2. Exact match on package name (e.g., `"node"` matches `"@types/node"`)
+/// 3. Prefix match on full name
+/// 4. Prefix match on package name
+/// 5. Contains match on full name
+/// 6. Contains match on package name
+///
+/// Within the same match type, results are sorted alphabetically by full name.
+pub fn fuzzy_match_packages(
+    query: &str,
+    packages: &[CrawledPackage],
+    limit: usize,
+) -> Vec<CrawledPackage> {
+    let trimmed = query.trim();
+    if trimmed.is_empty() {
+        return Vec::new();
+    }
+
+    let mut matches: Vec<MatchResult> = Vec::new();
+
+    for pkg in packages {
+        if let Some(match_type) = get_match_type(pkg, trimmed) {
+            matches.push(MatchResult {
+                package: pkg.clone(),
+                match_type,
+            });
+        }
+    }
+
+    // Sort by match type (lower is better), then alphabetically by full name
+    matches.sort_by(|a, b| {
+        let type_cmp = a.match_type.cmp(&b.match_type);
+        if type_cmp != std::cmp::Ordering::Equal {
+            return type_cmp;
+        }
+        get_full_name(&a.package).cmp(&get_full_name(&b.package))
+    });
+
+    matches
+        .into_iter()
+        .take(limit)
+        .map(|m| m.package)
+        .collect()
+}
+
+/// Check if a string looks like a PURL.
+pub fn is_purl(s: &str) -> bool {
+    s.starts_with("pkg:")
+}
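To make the ordering concrete, a small sketch reusing the `make_pkg` helper from the test module below: an exact name match outranks a prefix match on the same query.

```rust
// "lodash" matches both packages, but ExactName sorts before PrefixFull.
let packages = vec![
    make_pkg("lodash", "4.17.21", None),
    make_pkg("lodash-es", "4.17.21", None),
];
let results = fuzzy_match_packages("lodash", &packages, 20);
assert_eq!(results[0].name, "lodash");
assert_eq!(results[1].name, "lodash-es");
```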
+/// Check if a string looks like a scoped npm package name.
+pub fn is_scoped_package(s: &str) -> bool {
+    s.starts_with('@') && s.contains('/')
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::path::PathBuf;
+
+    fn make_pkg(
+        name: &str,
+        version: &str,
+        namespace: Option<&str>,
+    ) -> CrawledPackage {
+        let ns = namespace.map(|s| s.to_string());
+        let purl = match &ns {
+            Some(n) => format!("pkg:npm/{n}/{name}@{version}"),
+            None => format!("pkg:npm/{name}@{version}"),
+        };
+        CrawledPackage {
+            name: name.to_string(),
+            version: version.to_string(),
+            namespace: ns,
+            purl,
+            path: PathBuf::from("/fake"),
+        }
+    }
+
+    #[test]
+    fn test_exact_full_name() {
+        let packages = vec![
+            make_pkg("node", "20.0.0", Some("@types")),
+            make_pkg("node-fetch", "3.0.0", None),
+        ];
+
+        let results = fuzzy_match_packages("@types/node", &packages, 20);
+        // "node-fetch" does NOT contain "@types/node", so only 1 result
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].name, "node"); // ExactFull
+        assert_eq!(results[0].namespace.as_deref(), Some("@types"));
+    }
+
+    #[test]
+    fn test_exact_name_only() {
+        let packages = vec![
+            make_pkg("node", "20.0.0", Some("@types")),
+            make_pkg("lodash", "4.17.21", None),
+        ];
+
+        let results = fuzzy_match_packages("node", &packages, 20);
+        assert_eq!(results[0].name, "node"); // ExactName
+    }
+
+    #[test]
+    fn test_prefix_match() {
+        let packages = vec![
+            make_pkg("lodash", "4.17.21", None),
+            make_pkg("lodash-es", "4.17.21", None),
+        ];
+
+        let results = fuzzy_match_packages("lodash", &packages, 20);
+        assert_eq!(results.len(), 2);
+        assert_eq!(results[0].name, "lodash"); // ExactName is better than PrefixName
+    }
+
+    #[test]
+    fn test_contains_match() {
+        let packages = vec![make_pkg("string-width", "5.0.0", None)];
+
+        let results = fuzzy_match_packages("width", &packages, 20);
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].name, "string-width");
+    }
+
+    #[test]
+    fn test_no_match() {
+        let packages = vec![make_pkg("lodash", "4.17.21", None)];
+
+        let results = fuzzy_match_packages("zzzzz", &packages, 20);
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_empty_query() {
+        let packages = vec![make_pkg("lodash", "4.17.21", None)];
+        assert!(fuzzy_match_packages("", &packages, 20).is_empty());
+        assert!(fuzzy_match_packages(" ", &packages, 20).is_empty());
+    }
+
+    #[test]
+    fn test_case_insensitive() {
+        let packages = vec![make_pkg("React", "18.0.0", None)];
+        let results = fuzzy_match_packages("react", &packages, 20);
+        assert_eq!(results.len(), 1);
+    }
+
+    #[test]
+    fn test_limit() {
+        let packages: Vec<CrawledPackage> = (0..50)
+            .map(|i| make_pkg(&format!("pkg-{i}"), "1.0.0", None))
+            .collect();
+
+        let results = fuzzy_match_packages("pkg", &packages, 10);
+        assert_eq!(results.len(), 10);
+    }
+
+    #[test]
+    fn test_is_purl() {
+        assert!(is_purl("pkg:npm/lodash@4.17.21"));
+        assert!(is_purl("pkg:pypi/requests@2.28.0"));
+        assert!(!is_purl("lodash"));
+        assert!(!is_purl("@types/node"));
+    }
+
+    #[test]
+    fn test_is_scoped_package() {
+        assert!(is_scoped_package("@types/node"));
+        assert!(is_scoped_package("@scope/pkg"));
+        assert!(!is_scoped_package("lodash"));
+        assert!(!is_scoped_package("@scope"));
+    }
+}
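Together, `is_purl` and `is_scoped_package` let a caller classify raw user input before deciding how to resolve it. A sketch under that assumption; the `PackageQuery` enum is hypothetical and not part of this patch.

```rust
// Hypothetical classification of a raw CLI argument.
enum PackageQuery<'a> {
    Purl(&'a str),
    ScopedName(&'a str),
    BareName(&'a str),
}

fn classify(input: &str) -> PackageQuery<'_> {
    if is_purl(input) {
        PackageQuery::Purl(input)
    } else if is_scoped_package(input) {
        PackageQuery::ScopedName(input)
    } else {
        PackageQuery::BareName(input)
    }
}
```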
diff --git a/crates/socket-patch-core/src/utils/global_packages.rs b/crates/socket-patch-core/src/utils/global_packages.rs
new file mode 100644
index 0000000..77653c3
--- /dev/null
+++ b/crates/socket-patch-core/src/utils/global_packages.rs
@@ -0,0 +1,186 @@
+use std::path::PathBuf;
+use std::process::Command;
+
+// ---------------------------------------------------------------------------
+// Individual package manager global prefix helpers
+// ---------------------------------------------------------------------------
+
+/// Get the npm global `node_modules` path using `npm root -g`.
+pub fn get_npm_global_prefix() -> Result<String, String> {
+    let output = Command::new("npm")
+        .args(["root", "-g"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .map_err(|e| format!("Failed to run `npm root -g`: {e}"))?;
+
+    if !output.status.success() {
+        return Err(
+            "Failed to determine npm global prefix. Ensure npm is installed and in PATH."
+                .to_string(),
+        );
+    }
+
+    let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if path.is_empty() {
+        return Err("npm root -g returned empty output".to_string());
+    }
+
+    Ok(path)
+}
+
+/// Get the yarn global `node_modules` path via `yarn global dir`.
+pub fn get_yarn_global_prefix() -> Option<String> {
+    let output = Command::new("yarn")
+        .args(["global", "dir"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .ok()?;
+
+    if !output.status.success() {
+        return None;
+    }
+
+    let dir = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if dir.is_empty() {
+        return None;
+    }
+
+    Some(
+        PathBuf::from(dir)
+            .join("node_modules")
+            .to_string_lossy()
+            .to_string(),
+    )
+}
+
+/// Get the pnpm global `node_modules` path via `pnpm root -g`.
+pub fn get_pnpm_global_prefix() -> Option<String> {
+    let output = Command::new("pnpm")
+        .args(["root", "-g"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .ok()?;
+
+    if !output.status.success() {
+        return None;
+    }
+
+    let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if path.is_empty() {
+        return None;
+    }
+
+    Some(path)
+}
+
+/// Get the bun global `node_modules` path via `bun pm bin -g`.
+pub fn get_bun_global_prefix() -> Option<String> {
+    let output = Command::new("bun")
+        .args(["pm", "bin", "-g"])
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .output()
+        .ok()?;
+
+    if !output.status.success() {
+        return None;
+    }
+
+    let bin_path = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if bin_path.is_empty() {
+        return None;
+    }
+
+    let bun_root = PathBuf::from(&bin_path);
+    let parent = bun_root.parent()?;
+
+    Some(
+        parent
+            .join("install")
+            .join("global")
+            .join("node_modules")
+            .to_string_lossy()
+            .to_string(),
+    )
+}
+
+// ---------------------------------------------------------------------------
+// Aggregation helpers
+// ---------------------------------------------------------------------------
+
+/// Get the global `node_modules` path, with support for a custom override.
+///
+/// If `custom` is `Some`, that value is returned directly. Otherwise, falls
+/// back to `get_npm_global_prefix()`.
+pub fn get_global_prefix(custom: Option<&str>) -> Result<String, String> {
+    if let Some(custom_path) = custom {
+        return Ok(custom_path.to_string());
+    }
+    get_npm_global_prefix()
+}
+
+/// Get all global `node_modules` paths for package lookup.
+///
+/// Returns paths from all detected package managers (npm, pnpm, yarn, bun).
+/// If `custom` is provided, only that path is returned.
+pub fn get_global_node_modules_paths(custom: Option<&str>) -> Vec<String> {
+    if let Some(custom_path) = custom {
+        return vec![custom_path.to_string()];
+    }
+
+    let mut paths = Vec::new();
+
+    if let Ok(npm_path) = get_npm_global_prefix() {
+        paths.push(npm_path);
+    }
+
+    if let Some(pnpm_path) = get_pnpm_global_prefix() {
+        paths.push(pnpm_path);
+    }
+
+    if let Some(yarn_path) = get_yarn_global_prefix() {
+        paths.push(yarn_path);
+    }
+
+    if let Some(bun_path) = get_bun_global_prefix() {
+        paths.push(bun_path);
+    }
+
+    paths
+}
+
+/// Check if a path is within a global `node_modules` directory.
+pub fn is_global_path(pkg_path: &str) -> bool {
+    let paths = get_global_node_modules_paths(None);
+    let normalized = PathBuf::from(pkg_path);
+    let normalized_str = normalized.to_string_lossy();
+
+    paths.iter().any(|global_path| {
+        let gp = PathBuf::from(global_path);
+        normalized_str.starts_with(&*gp.to_string_lossy())
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_get_global_prefix_custom() {
+        let result = get_global_prefix(Some("/custom/node_modules"));
+        assert_eq!(result.unwrap(), "/custom/node_modules");
+    }
+
+    #[test]
+    fn test_get_global_node_modules_paths_custom() {
+        let paths = get_global_node_modules_paths(Some("/my/custom/path"));
+        assert_eq!(paths, vec!["/my/custom/path".to_string()]);
+    }
+}
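A short sketch of the intended lookup flow when no custom override is configured; what it prints depends on which package managers are installed on the host.

```rust
// Collect every detected global node_modules root (npm, pnpm, yarn, bun),
// then confirm each root is itself recognized as a global path.
let roots = get_global_node_modules_paths(None);
for root in &roots {
    println!("global root: {root}");
}
assert!(roots.iter().all(|r| is_global_path(r)));
```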
diff --git a/crates/socket-patch-core/src/utils/mod.rs b/crates/socket-patch-core/src/utils/mod.rs
new file mode 100644
index 0000000..482e134
--- /dev/null
+++ b/crates/socket-patch-core/src/utils/mod.rs
@@ -0,0 +1,6 @@
+pub mod cleanup_blobs;
+pub mod enumerate;
+pub mod fuzzy_match;
+pub mod global_packages;
+pub mod purl;
+pub mod telemetry;
diff --git a/crates/socket-patch-core/src/utils/purl.rs b/crates/socket-patch-core/src/utils/purl.rs
new file mode 100644
index 0000000..0699eb6
--- /dev/null
+++ b/crates/socket-patch-core/src/utils/purl.rs
@@ -0,0 +1,763 @@
+/// Strip query string qualifiers from a PURL.
+///
+/// e.g., `"pkg:pypi/requests@2.28.0?artifact_id=abc"` -> `"pkg:pypi/requests@2.28.0"`
+pub fn strip_purl_qualifiers(purl: &str) -> &str {
+    match purl.find('?') {
+        Some(idx) => &purl[..idx],
+        None => purl,
+    }
+}
+
+/// Check if a PURL is a PyPI package.
+pub fn is_pypi_purl(purl: &str) -> bool {
+    purl.starts_with("pkg:pypi/")
+}
+
+/// Check if a PURL is an npm package.
+pub fn is_npm_purl(purl: &str) -> bool {
+    purl.starts_with("pkg:npm/")
+}
+
+/// Parse a PyPI PURL to extract name and version.
+///
+/// e.g., `"pkg:pypi/requests@2.28.0?artifact_id=abc"` -> `Some(("requests", "2.28.0"))`
+pub fn parse_pypi_purl(purl: &str) -> Option<(&str, &str)> {
+    let base = strip_purl_qualifiers(purl);
+    let rest = base.strip_prefix("pkg:pypi/")?;
+    let at_idx = rest.rfind('@')?;
+    let name = &rest[..at_idx];
+    let version = &rest[at_idx + 1..];
+    if name.is_empty() || version.is_empty() {
+        return None;
+    }
+    Some((name, version))
+}
+
+/// Parse an npm PURL to extract namespace, name, and version.
+/// +/// e.g., `"pkg:npm/@types/node@20.0.0"` -> `Some((Some("@types"), "node", "20.0.0"))` +/// e.g., `"pkg:npm/lodash@4.17.21"` -> `Some((None, "lodash", "4.17.21"))` +pub fn parse_npm_purl(purl: &str) -> Option<(Option<&str>, &str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:npm/")?; + + // Find the last @ that separates name from version + let at_idx = rest.rfind('@')?; + let name_part = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + + if name_part.is_empty() || version.is_empty() { + return None; + } + + // Check for scoped package (@scope/name) + if name_part.starts_with('@') { + let slash_idx = name_part.find('/')?; + let namespace = &name_part[..slash_idx]; + let name = &name_part[slash_idx + 1..]; + if name.is_empty() { + return None; + } + Some((Some(namespace), name, version)) + } else { + Some((None, name_part, version)) + } +} + +/// Check if a PURL is a Ruby gem. +pub fn is_gem_purl(purl: &str) -> bool { + purl.starts_with("pkg:gem/") +} + +/// Parse a gem PURL to extract name and version. +/// +/// e.g., `"pkg:gem/rails@7.1.0"` -> `Some(("rails", "7.1.0"))` +pub fn parse_gem_purl(purl: &str) -> Option<(&str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:gem/")?; + let at_idx = rest.rfind('@')?; + let name = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name, version)) +} + +/// Build a gem PURL from components. +pub fn build_gem_purl(name: &str, version: &str) -> String { + format!("pkg:gem/{name}@{version}") +} + +/// Check if a PURL is a Maven package. +#[cfg(feature = "maven")] +pub fn is_maven_purl(purl: &str) -> bool { + purl.starts_with("pkg:maven/") +} + +/// Parse a Maven PURL to extract groupId, artifactId, and version. +/// +/// e.g., `"pkg:maven/org.apache.commons/commons-lang3@3.12.0"` -> `Some(("org.apache.commons", "commons-lang3", "3.12.0"))` +#[cfg(feature = "maven")] +pub fn parse_maven_purl(purl: &str) -> Option<(&str, &str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:maven/")?; + let at_idx = rest.rfind('@')?; + let name_part = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + + if name_part.is_empty() || version.is_empty() { + return None; + } + + // Split groupId/artifactId + let slash_idx = name_part.find('/')?; + let group_id = &name_part[..slash_idx]; + let artifact_id = &name_part[slash_idx + 1..]; + + if group_id.is_empty() || artifact_id.is_empty() { + return None; + } + + Some((group_id, artifact_id, version)) +} + +/// Build a Maven PURL from components. +#[cfg(feature = "maven")] +pub fn build_maven_purl(group_id: &str, artifact_id: &str, version: &str) -> String { + format!("pkg:maven/{group_id}/{artifact_id}@{version}") +} + +/// Check if a PURL is a Go module. +#[cfg(feature = "golang")] +pub fn is_golang_purl(purl: &str) -> bool { + purl.starts_with("pkg:golang/") +} + +/// Parse a Go module PURL to extract module path and version. 
+/// +/// e.g., `"pkg:golang/github.com/gin-gonic/gin@v1.9.1"` -> `Some(("github.com/gin-gonic/gin", "v1.9.1"))` +#[cfg(feature = "golang")] +pub fn parse_golang_purl(purl: &str) -> Option<(&str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:golang/")?; + let at_idx = rest.rfind('@')?; + let module_path = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + if module_path.is_empty() || version.is_empty() { + return None; + } + Some((module_path, version)) +} + +/// Build a Go module PURL from components. +#[cfg(feature = "golang")] +pub fn build_golang_purl(module_path: &str, version: &str) -> String { + format!("pkg:golang/{module_path}@{version}") +} + +/// Check if a PURL is a Composer/PHP package. +#[cfg(feature = "composer")] +pub fn is_composer_purl(purl: &str) -> bool { + purl.starts_with("pkg:composer/") +} + +/// Parse a Composer PURL to extract namespace, name, and version. +/// +/// Composer packages always have a namespace (vendor). +/// e.g., `"pkg:composer/monolog/monolog@3.5.0"` -> `Some((("monolog", "monolog"), "3.5.0"))` +#[cfg(feature = "composer")] +pub fn parse_composer_purl(purl: &str) -> Option<((&str, &str), &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:composer/")?; + let at_idx = rest.rfind('@')?; + let name_part = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + + if name_part.is_empty() || version.is_empty() { + return None; + } + + // Split namespace/name + let slash_idx = name_part.find('/')?; + let namespace = &name_part[..slash_idx]; + let name = &name_part[slash_idx + 1..]; + + if namespace.is_empty() || name.is_empty() { + return None; + } + + Some(((namespace, name), version)) +} + +/// Build a Composer PURL from components. +#[cfg(feature = "composer")] +pub fn build_composer_purl(namespace: &str, name: &str, version: &str) -> String { + format!("pkg:composer/{namespace}/{name}@{version}") +} + +/// Check if a PURL is a NuGet/.NET package. +#[cfg(feature = "nuget")] +pub fn is_nuget_purl(purl: &str) -> bool { + purl.starts_with("pkg:nuget/") +} + +/// Parse a NuGet PURL to extract name and version. +/// +/// e.g., `"pkg:nuget/Newtonsoft.Json@13.0.3"` -> `Some(("Newtonsoft.Json", "13.0.3"))` +#[cfg(feature = "nuget")] +pub fn parse_nuget_purl(purl: &str) -> Option<(&str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:nuget/")?; + let at_idx = rest.rfind('@')?; + let name = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name, version)) +} + +/// Build a NuGet PURL from components. +#[cfg(feature = "nuget")] +pub fn build_nuget_purl(name: &str, version: &str) -> String { + format!("pkg:nuget/{name}@{version}") +} + +/// Check if a PURL is a Cargo/Rust crate. +#[cfg(feature = "cargo")] +pub fn is_cargo_purl(purl: &str) -> bool { + purl.starts_with("pkg:cargo/") +} + +/// Parse a Cargo PURL to extract name and version. +/// +/// e.g., `"pkg:cargo/serde@1.0.200"` -> `Some(("serde", "1.0.200"))` +#[cfg(feature = "cargo")] +pub fn parse_cargo_purl(purl: &str) -> Option<(&str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:cargo/")?; + let at_idx = rest.rfind('@')?; + let name = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name, version)) +} + +/// Build a Cargo PURL from components. 
+#[cfg(feature = "cargo")]
+pub fn build_cargo_purl(name: &str, version: &str) -> String {
+    format!("pkg:cargo/{name}@{version}")
+}
+
+/// Parse a PURL into ecosystem, package directory path, and version.
+/// Supports npm, pypi, and gem PURLs unconditionally, plus cargo, golang,
+/// maven, composer, and nuget PURLs behind their respective feature flags.
+pub fn parse_purl(purl: &str) -> Option<(&str, String, &str)> {
+    let base = strip_purl_qualifiers(purl);
+    if let Some(rest) = base.strip_prefix("pkg:npm/") {
+        let at_idx = rest.rfind('@')?;
+        let pkg_dir = &rest[..at_idx];
+        let version = &rest[at_idx + 1..];
+        if pkg_dir.is_empty() || version.is_empty() {
+            return None;
+        }
+        Some(("npm", pkg_dir.to_string(), version))
+    } else if let Some(rest) = base.strip_prefix("pkg:pypi/") {
+        let at_idx = rest.rfind('@')?;
+        let name = &rest[..at_idx];
+        let version = &rest[at_idx + 1..];
+        if name.is_empty() || version.is_empty() {
+            return None;
+        }
+        Some(("pypi", name.to_string(), version))
+    } else {
+        #[cfg(feature = "cargo")]
+        if let Some(rest) = base.strip_prefix("pkg:cargo/") {
+            let at_idx = rest.rfind('@')?;
+            let name = &rest[..at_idx];
+            let version = &rest[at_idx + 1..];
+            if name.is_empty() || version.is_empty() {
+                return None;
+            }
+            return Some(("cargo", name.to_string(), version));
+        }
+        #[cfg(feature = "golang")]
+        if let Some(rest) = base.strip_prefix("pkg:golang/") {
+            let at_idx = rest.rfind('@')?;
+            let module_path = &rest[..at_idx];
+            let version = &rest[at_idx + 1..];
+            if module_path.is_empty() || version.is_empty() {
+                return None;
+            }
+            return Some(("golang", module_path.to_string(), version));
+        }
+        if let Some(rest) = base.strip_prefix("pkg:gem/") {
+            let at_idx = rest.rfind('@')?;
+            let name = &rest[..at_idx];
+            let version = &rest[at_idx + 1..];
+            if name.is_empty() || version.is_empty() {
+                return None;
+            }
+            return Some(("gem", name.to_string(), version));
+        }
+        #[cfg(feature = "maven")]
+        if let Some(rest) = base.strip_prefix("pkg:maven/") {
+            let at_idx = rest.rfind('@')?;
+            let name_part = &rest[..at_idx];
+            let version = &rest[at_idx + 1..];
+            if name_part.is_empty() || version.is_empty() {
+                return None;
+            }
+            return Some(("maven", name_part.to_string(), version));
+        }
+        #[cfg(feature = "composer")]
+        if let Some(rest) = base.strip_prefix("pkg:composer/") {
+            let at_idx = rest.rfind('@')?;
+            let name_part = &rest[..at_idx];
+            let version = &rest[at_idx + 1..];
+            if name_part.is_empty() || version.is_empty() {
+                return None;
+            }
+            return Some(("composer", name_part.to_string(), version));
+        }
+        #[cfg(feature = "nuget")]
+        if let Some(rest) = base.strip_prefix("pkg:nuget/") {
+            let at_idx = rest.rfind('@')?;
+            let name = &rest[..at_idx];
+            let version = &rest[at_idx + 1..];
+            if name.is_empty() || version.is_empty() {
+                return None;
+            }
+            return Some(("nuget", name.to_string(), version));
+        }
+        None
+    }
+}
+
+/// Check if a string looks like a PURL.
+pub fn is_purl(s: &str) -> bool {
+    s.starts_with("pkg:")
+}
+
+/// Build an npm PURL from components.
+pub fn build_npm_purl(namespace: Option<&str>, name: &str, version: &str) -> String {
+    match namespace {
+        Some(ns) => format!("pkg:npm/{}/{name}@{version}", ns),
+        None => format!("pkg:npm/{name}@{version}"),
+    }
+}
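A brief sketch of the dispatcher in use. The ecosystem tags in the match arms are the ones `parse_purl` returns above; the `install_dir` layout is purely illustrative.

```rust
// Route a parsed PURL to ecosystem-specific handling.
if let Some((ecosystem, pkg_dir, version)) = parse_purl("pkg:npm/@types/node@20.0.0") {
    let install_dir = match ecosystem {
        "npm" => format!("node_modules/{pkg_dir}"),
        "pypi" => format!("site-packages/{pkg_dir}-{version}"),
        other => format!("vendor/{other}/{pkg_dir}"),
    };
    println!("{ecosystem} {pkg_dir}@{version} -> {install_dir}");
}
```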
+/// Build a PyPI PURL from components.
+pub fn build_pypi_purl(name: &str, version: &str) -> String {
+    format!("pkg:pypi/{name}@{version}")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_strip_qualifiers() {
+        assert_eq!(
+            strip_purl_qualifiers("pkg:pypi/requests@2.28.0?artifact_id=abc"),
+            "pkg:pypi/requests@2.28.0"
+        );
+        assert_eq!(
+            strip_purl_qualifiers("pkg:npm/lodash@4.17.21"),
+            "pkg:npm/lodash@4.17.21"
+        );
+    }
+
+    #[test]
+    fn test_is_pypi_purl() {
+        assert!(is_pypi_purl("pkg:pypi/requests@2.28.0"));
+        assert!(!is_pypi_purl("pkg:npm/lodash@4.17.21"));
+    }
+
+    #[test]
+    fn test_is_npm_purl() {
+        assert!(is_npm_purl("pkg:npm/lodash@4.17.21"));
+        assert!(!is_npm_purl("pkg:pypi/requests@2.28.0"));
+    }
+
+    #[test]
+    fn test_parse_pypi_purl() {
+        assert_eq!(
+            parse_pypi_purl("pkg:pypi/requests@2.28.0"),
+            Some(("requests", "2.28.0"))
+        );
+        assert_eq!(
+            parse_pypi_purl("pkg:pypi/requests@2.28.0?artifact_id=abc"),
+            Some(("requests", "2.28.0"))
+        );
+        assert_eq!(parse_pypi_purl("pkg:npm/lodash@4.17.21"), None);
+        assert_eq!(parse_pypi_purl("pkg:pypi/@2.28.0"), None);
+        assert_eq!(parse_pypi_purl("pkg:pypi/requests@"), None);
+    }
+
+    #[test]
+    fn test_parse_npm_purl() {
+        assert_eq!(
+            parse_npm_purl("pkg:npm/lodash@4.17.21"),
+            Some((None, "lodash", "4.17.21"))
+        );
+        assert_eq!(
+            parse_npm_purl("pkg:npm/@types/node@20.0.0"),
+            Some((Some("@types"), "node", "20.0.0"))
+        );
+        assert_eq!(parse_npm_purl("pkg:pypi/requests@2.28.0"), None);
+    }
+
+    #[test]
+    fn test_parse_purl() {
+        let (eco, dir, ver) = parse_purl("pkg:npm/lodash@4.17.21").unwrap();
+        assert_eq!(eco, "npm");
+        assert_eq!(dir, "lodash");
+        assert_eq!(ver, "4.17.21");
+
+        let (eco, dir, ver) = parse_purl("pkg:npm/@types/node@20.0.0").unwrap();
+        assert_eq!(eco, "npm");
+        assert_eq!(dir, "@types/node");
+        assert_eq!(ver, "20.0.0");
+
+        let (eco, dir, ver) = parse_purl("pkg:pypi/requests@2.28.0").unwrap();
+        assert_eq!(eco, "pypi");
+        assert_eq!(dir, "requests");
+        assert_eq!(ver, "2.28.0");
+    }
+
+    #[test]
+    fn test_is_purl() {
+        assert!(is_purl("pkg:npm/lodash@4.17.21"));
+        assert!(is_purl("pkg:pypi/requests@2.28.0"));
+        assert!(!is_purl("lodash"));
+        assert!(!is_purl("CVE-2024-1234"));
+    }
+
+    #[test]
+    fn test_build_npm_purl() {
+        assert_eq!(
+            build_npm_purl(None, "lodash", "4.17.21"),
+            "pkg:npm/lodash@4.17.21"
+        );
+        assert_eq!(
+            build_npm_purl(Some("@types"), "node", "20.0.0"),
+            "pkg:npm/@types/node@20.0.0"
+        );
+    }
+
+    #[test]
+    fn test_build_pypi_purl() {
+        assert_eq!(
+            build_pypi_purl("requests", "2.28.0"),
+            "pkg:pypi/requests@2.28.0"
+        );
+    }
+
+    #[cfg(feature = "cargo")]
+    #[test]
+    fn test_is_cargo_purl() {
+        assert!(is_cargo_purl("pkg:cargo/serde@1.0.200"));
+        assert!(!is_cargo_purl("pkg:npm/lodash@4.17.21"));
+        assert!(!is_cargo_purl("pkg:pypi/requests@2.28.0"));
+    }
+
+    #[cfg(feature = "cargo")]
+    #[test]
+    fn test_parse_cargo_purl() {
+        assert_eq!(
+            parse_cargo_purl("pkg:cargo/serde@1.0.200"),
+            Some(("serde", "1.0.200"))
+        );
+        assert_eq!(
+            parse_cargo_purl("pkg:cargo/serde_json@1.0.120"),
+            Some(("serde_json", "1.0.120"))
+        );
+        assert_eq!(parse_cargo_purl("pkg:npm/lodash@4.17.21"), None);
+        assert_eq!(parse_cargo_purl("pkg:cargo/@1.0.0"), None);
+        assert_eq!(parse_cargo_purl("pkg:cargo/serde@"), None);
+    }
+
+    #[cfg(feature = "cargo")]
+    #[test]
+    fn test_build_cargo_purl() {
+        assert_eq!(
+            build_cargo_purl("serde", "1.0.200"),
+            "pkg:cargo/serde@1.0.200"
+        );
+    }
+
+    #[cfg(feature = "cargo")]
+    #[test]
+    fn test_cargo_purl_round_trip() {
+        let purl = build_cargo_purl("tokio", "1.38.0");
+        let (name, version)
= parse_cargo_purl(&purl).unwrap(); + assert_eq!(name, "tokio"); + assert_eq!(version, "1.38.0"); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_parse_purl_cargo() { + let (eco, dir, ver) = parse_purl("pkg:cargo/serde@1.0.200").unwrap(); + assert_eq!(eco, "cargo"); + assert_eq!(dir, "serde"); + assert_eq!(ver, "1.0.200"); + } + + #[test] + fn test_is_gem_purl() { + assert!(is_gem_purl("pkg:gem/rails@7.1.0")); + assert!(!is_gem_purl("pkg:npm/lodash@4.17.21")); + assert!(!is_gem_purl("pkg:pypi/requests@2.28.0")); + } + + #[test] + fn test_parse_gem_purl() { + assert_eq!( + parse_gem_purl("pkg:gem/rails@7.1.0"), + Some(("rails", "7.1.0")) + ); + assert_eq!( + parse_gem_purl("pkg:gem/nokogiri@1.16.5"), + Some(("nokogiri", "1.16.5")) + ); + assert_eq!(parse_gem_purl("pkg:npm/lodash@4.17.21"), None); + assert_eq!(parse_gem_purl("pkg:gem/@1.0.0"), None); + assert_eq!(parse_gem_purl("pkg:gem/rails@"), None); + } + + #[test] + fn test_build_gem_purl() { + assert_eq!( + build_gem_purl("rails", "7.1.0"), + "pkg:gem/rails@7.1.0" + ); + } + + #[test] + fn test_gem_purl_round_trip() { + let purl = build_gem_purl("nokogiri", "1.16.5"); + let (name, version) = parse_gem_purl(&purl).unwrap(); + assert_eq!(name, "nokogiri"); + assert_eq!(version, "1.16.5"); + } + + #[test] + fn test_parse_purl_gem() { + let (eco, dir, ver) = parse_purl("pkg:gem/rails@7.1.0").unwrap(); + assert_eq!(eco, "gem"); + assert_eq!(dir, "rails"); + assert_eq!(ver, "7.1.0"); + } + + #[cfg(feature = "maven")] + #[test] + fn test_is_maven_purl() { + assert!(is_maven_purl("pkg:maven/org.apache.commons/commons-lang3@3.12.0")); + assert!(!is_maven_purl("pkg:npm/lodash@4.17.21")); + assert!(!is_maven_purl("pkg:pypi/requests@2.28.0")); + } + + #[cfg(feature = "maven")] + #[test] + fn test_parse_maven_purl() { + assert_eq!( + parse_maven_purl("pkg:maven/org.apache.commons/commons-lang3@3.12.0"), + Some(("org.apache.commons", "commons-lang3", "3.12.0")) + ); + assert_eq!( + parse_maven_purl("pkg:maven/com.google.guava/guava@32.1.3-jre"), + Some(("com.google.guava", "guava", "32.1.3-jre")) + ); + assert_eq!(parse_maven_purl("pkg:npm/lodash@4.17.21"), None); + assert_eq!(parse_maven_purl("pkg:maven/@3.12.0"), None); + assert_eq!(parse_maven_purl("pkg:maven/org.apache.commons/@3.12.0"), None); + assert_eq!(parse_maven_purl("pkg:maven/org.apache.commons/commons-lang3@"), None); + } + + #[cfg(feature = "maven")] + #[test] + fn test_build_maven_purl() { + assert_eq!( + build_maven_purl("org.apache.commons", "commons-lang3", "3.12.0"), + "pkg:maven/org.apache.commons/commons-lang3@3.12.0" + ); + } + + #[cfg(feature = "maven")] + #[test] + fn test_maven_purl_round_trip() { + let purl = build_maven_purl("com.google.guava", "guava", "32.1.3-jre"); + let (group_id, artifact_id, version) = parse_maven_purl(&purl).unwrap(); + assert_eq!(group_id, "com.google.guava"); + assert_eq!(artifact_id, "guava"); + assert_eq!(version, "32.1.3-jre"); + } + + #[cfg(feature = "maven")] + #[test] + fn test_parse_purl_maven() { + let (eco, dir, ver) = parse_purl("pkg:maven/org.apache.commons/commons-lang3@3.12.0").unwrap(); + assert_eq!(eco, "maven"); + assert_eq!(dir, "org.apache.commons/commons-lang3"); + assert_eq!(ver, "3.12.0"); + } + + #[cfg(feature = "golang")] + #[test] + fn test_is_golang_purl() { + assert!(is_golang_purl("pkg:golang/github.com/gin-gonic/gin@v1.9.1")); + assert!(!is_golang_purl("pkg:npm/lodash@4.17.21")); + assert!(!is_golang_purl("pkg:pypi/requests@2.28.0")); + } + + #[cfg(feature = "golang")] + #[test] + fn test_parse_golang_purl() { + 
assert_eq!( + parse_golang_purl("pkg:golang/github.com/gin-gonic/gin@v1.9.1"), + Some(("github.com/gin-gonic/gin", "v1.9.1")) + ); + assert_eq!( + parse_golang_purl("pkg:golang/golang.org/x/text@v0.14.0"), + Some(("golang.org/x/text", "v0.14.0")) + ); + assert_eq!(parse_golang_purl("pkg:npm/lodash@4.17.21"), None); + assert_eq!(parse_golang_purl("pkg:golang/@v1.0.0"), None); + assert_eq!(parse_golang_purl("pkg:golang/github.com/foo/bar@"), None); + } + + #[cfg(feature = "golang")] + #[test] + fn test_build_golang_purl() { + assert_eq!( + build_golang_purl("github.com/gin-gonic/gin", "v1.9.1"), + "pkg:golang/github.com/gin-gonic/gin@v1.9.1" + ); + } + + #[cfg(feature = "golang")] + #[test] + fn test_golang_purl_round_trip() { + let purl = build_golang_purl("golang.org/x/text", "v0.14.0"); + let (module_path, version) = parse_golang_purl(&purl).unwrap(); + assert_eq!(module_path, "golang.org/x/text"); + assert_eq!(version, "v0.14.0"); + } + + #[cfg(feature = "golang")] + #[test] + fn test_parse_purl_golang() { + let (eco, dir, ver) = parse_purl("pkg:golang/github.com/gin-gonic/gin@v1.9.1").unwrap(); + assert_eq!(eco, "golang"); + assert_eq!(dir, "github.com/gin-gonic/gin"); + assert_eq!(ver, "v1.9.1"); + } + + #[cfg(feature = "composer")] + #[test] + fn test_is_composer_purl() { + assert!(is_composer_purl("pkg:composer/monolog/monolog@3.5.0")); + assert!(!is_composer_purl("pkg:npm/lodash@4.17.21")); + assert!(!is_composer_purl("pkg:pypi/requests@2.28.0")); + } + + #[cfg(feature = "composer")] + #[test] + fn test_parse_composer_purl() { + assert_eq!( + parse_composer_purl("pkg:composer/monolog/monolog@3.5.0"), + Some((("monolog", "monolog"), "3.5.0")) + ); + assert_eq!( + parse_composer_purl("pkg:composer/symfony/console@6.4.1"), + Some((("symfony", "console"), "6.4.1")) + ); + assert_eq!(parse_composer_purl("pkg:npm/lodash@4.17.21"), None); + assert_eq!(parse_composer_purl("pkg:composer/@3.5.0"), None); + assert_eq!(parse_composer_purl("pkg:composer/monolog/@3.5.0"), None); + assert_eq!(parse_composer_purl("pkg:composer/monolog/monolog@"), None); + } + + #[cfg(feature = "composer")] + #[test] + fn test_build_composer_purl() { + assert_eq!( + build_composer_purl("monolog", "monolog", "3.5.0"), + "pkg:composer/monolog/monolog@3.5.0" + ); + } + + #[cfg(feature = "composer")] + #[test] + fn test_composer_purl_round_trip() { + let purl = build_composer_purl("symfony", "console", "6.4.1"); + let ((namespace, name), version) = parse_composer_purl(&purl).unwrap(); + assert_eq!(namespace, "symfony"); + assert_eq!(name, "console"); + assert_eq!(version, "6.4.1"); + } + + #[cfg(feature = "composer")] + #[test] + fn test_parse_purl_composer() { + let (eco, dir, ver) = parse_purl("pkg:composer/monolog/monolog@3.5.0").unwrap(); + assert_eq!(eco, "composer"); + assert_eq!(dir, "monolog/monolog"); + assert_eq!(ver, "3.5.0"); + } + + #[cfg(feature = "nuget")] + #[test] + fn test_is_nuget_purl() { + assert!(is_nuget_purl("pkg:nuget/Newtonsoft.Json@13.0.3")); + assert!(!is_nuget_purl("pkg:npm/lodash@4.17.21")); + assert!(!is_nuget_purl("pkg:pypi/requests@2.28.0")); + } + + #[cfg(feature = "nuget")] + #[test] + fn test_parse_nuget_purl() { + assert_eq!( + parse_nuget_purl("pkg:nuget/Newtonsoft.Json@13.0.3"), + Some(("Newtonsoft.Json", "13.0.3")) + ); + assert_eq!( + parse_nuget_purl("pkg:nuget/System.Text.Json@8.0.0"), + Some(("System.Text.Json", "8.0.0")) + ); + assert_eq!(parse_nuget_purl("pkg:npm/lodash@4.17.21"), None); + assert_eq!(parse_nuget_purl("pkg:nuget/@1.0.0"), None); + 
        assert_eq!(parse_nuget_purl("pkg:nuget/Newtonsoft.Json@"), None);
+    }
+
+    #[cfg(feature = "nuget")]
+    #[test]
+    fn test_build_nuget_purl() {
+        assert_eq!(
+            build_nuget_purl("Newtonsoft.Json", "13.0.3"),
+            "pkg:nuget/Newtonsoft.Json@13.0.3"
+        );
+    }
+
+    #[cfg(feature = "nuget")]
+    #[test]
+    fn test_nuget_purl_round_trip() {
+        let purl = build_nuget_purl("System.Text.Json", "8.0.0");
+        let (name, version) = parse_nuget_purl(&purl).unwrap();
+        assert_eq!(name, "System.Text.Json");
+        assert_eq!(version, "8.0.0");
+    }
+
+    #[cfg(feature = "nuget")]
+    #[test]
+    fn test_parse_purl_nuget() {
+        let (eco, dir, ver) = parse_purl("pkg:nuget/Newtonsoft.Json@13.0.3").unwrap();
+        assert_eq!(eco, "nuget");
+        assert_eq!(dir, "Newtonsoft.Json");
+        assert_eq!(ver, "13.0.3");
+    }
+}
diff --git a/crates/socket-patch-core/src/utils/telemetry.rs b/crates/socket-patch-core/src/utils/telemetry.rs
new file mode 100644
index 0000000..d0892a9
--- /dev/null
+++ b/crates/socket-patch-core/src/utils/telemetry.rs
@@ -0,0 +1,647 @@
+use std::collections::HashMap;
+
+use once_cell::sync::Lazy;
+use uuid::Uuid;
+
+use crate::constants::{DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT};
+
+// ---------------------------------------------------------------------------
+// Session ID — generated once per process invocation
+// ---------------------------------------------------------------------------
+
+/// Unique session ID for the current CLI invocation.
+/// Shared across all telemetry events in a single run.
+static SESSION_ID: Lazy<String> = Lazy::new(|| Uuid::new_v4().to_string());
+
+/// Package version — updated during build.
+const PACKAGE_VERSION: &str = "1.0.0";
+
+// ---------------------------------------------------------------------------
+// Types
+// ---------------------------------------------------------------------------
+
+/// Telemetry event types for the patch lifecycle.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum PatchTelemetryEventType {
+    PatchApplied,
+    PatchApplyFailed,
+    PatchRemoved,
+    PatchRemoveFailed,
+    PatchRolledBack,
+    PatchRollbackFailed,
+}
+
+impl PatchTelemetryEventType {
+    /// Return the wire-format string for this event type.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::PatchApplied => "patch_applied",
+            Self::PatchApplyFailed => "patch_apply_failed",
+            Self::PatchRemoved => "patch_removed",
+            Self::PatchRemoveFailed => "patch_remove_failed",
+            Self::PatchRolledBack => "patch_rolled_back",
+            Self::PatchRollbackFailed => "patch_rollback_failed",
+        }
+    }
+}
+
+/// Telemetry context describing the execution environment.
+#[derive(Debug, Clone, serde::Serialize)]
+pub struct PatchTelemetryContext {
+    pub version: String,
+    pub platform: String,
+    pub arch: String,
+    pub command: String,
+}
+
+/// Error details for telemetry events.
+#[derive(Debug, Clone, serde::Serialize)]
+pub struct PatchTelemetryError {
+    #[serde(rename = "type")]
+    pub error_type: String,
+    pub message: Option<String>,
+}
+
+/// Telemetry event structure for patch operations.
+#[derive(Debug, Clone, serde::Serialize)]
+pub struct PatchTelemetryEvent {
+    pub event_sender_created_at: String,
+    pub event_type: String,
+    pub context: PatchTelemetryContext,
+    pub session_id: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<HashMap<String, serde_json::Value>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<PatchTelemetryError>,
+}
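For reference, a sketch of an event being constructed and serialized by hand, assuming `serde_json` and the field names declared above; the concrete values are illustrative, and real callers go through `track_patch_event` below instead.

```rust
let event = PatchTelemetryEvent {
    event_sender_created_at: "2024-01-15T10:30:45.123Z".to_string(),
    event_type: "patch_applied".to_string(),
    context: PatchTelemetryContext {
        version: "1.0.0".to_string(),
        platform: "linux".to_string(),
        arch: "x86_64".to_string(),
        command: "apply".to_string(),
    },
    session_id: "11111111-1111-4111-8111-111111111111".to_string(),
    metadata: None,
    error: None,
};
// `metadata` and `error` are omitted from the JSON when `None`.
let json = serde_json::to_string(&event).expect("event serializes");
println!("{json}");
```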
+/// Options for tracking a patch event.
+pub struct TrackPatchEventOptions {
+    /// The type of event being tracked.
+    pub event_type: PatchTelemetryEventType,
+    /// The CLI command being executed (e.g., "apply", "remove", "rollback").
+    pub command: String,
+    /// Optional metadata to include with the event.
+    pub metadata: Option<HashMap<String, serde_json::Value>>,
+    /// Optional error information if the operation failed.
+    /// Tuple of (error_type, message).
+    pub error: Option<(String, String)>,
+    /// Optional API token for authenticated telemetry endpoint.
+    pub api_token: Option<String>,
+    /// Optional organization slug for authenticated telemetry endpoint.
+    pub org_slug: Option<String>,
+}
+
+// ---------------------------------------------------------------------------
+// Environment checks
+// ---------------------------------------------------------------------------
+
+/// Check if telemetry is disabled via environment variables.
+///
+/// Telemetry is disabled when:
+/// - `SOCKET_PATCH_TELEMETRY_DISABLED` is `"1"` or `"true"`
+/// - `VITEST` is `"true"` (test environment)
+pub fn is_telemetry_disabled() -> bool {
+    matches!(
+        std::env::var("SOCKET_PATCH_TELEMETRY_DISABLED")
+            .unwrap_or_default()
+            .as_str(),
+        "1" | "true"
+    ) || std::env::var("VITEST").unwrap_or_default() == "true"
+}
+
+/// Check if debug mode is enabled.
+fn is_debug_enabled() -> bool {
+    matches!(
+        std::env::var("SOCKET_PATCH_DEBUG")
+            .unwrap_or_default()
+            .as_str(),
+        "1" | "true"
+    )
+}
+
+/// Log debug messages when debug mode is enabled.
+fn debug_log(message: &str) {
+    if is_debug_enabled() {
+        eprintln!("[socket-patch telemetry] {message}");
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Build event
+// ---------------------------------------------------------------------------
+
+/// Build the telemetry context for the current environment.
+fn build_telemetry_context(command: &str) -> PatchTelemetryContext {
+    PatchTelemetryContext {
+        version: PACKAGE_VERSION.to_string(),
+        platform: std::env::consts::OS.to_string(),
+        arch: std::env::consts::ARCH.to_string(),
+        command: command.to_string(),
+    }
+}
+
+/// Sanitize an error message for telemetry.
+///
+/// Replaces the user's home directory path with `~` to avoid leaking
+/// sensitive file system information.
+pub fn sanitize_error_message(message: &str) -> String {
+    if let Some(home) = home_dir_string() {
+        if !home.is_empty() {
+            return message.replace(&home, "~");
+        }
+    }
+    message.to_string()
+}
+
+/// Get the home directory as a string.
+fn home_dir_string() -> Option<String> {
+    std::env::var("HOME")
+        .ok()
+        .or_else(|| std::env::var("USERPROFILE").ok())
+}
+
+/// Build a telemetry event from the given options.
+fn build_telemetry_event(options: &TrackPatchEventOptions) -> PatchTelemetryEvent {
+    let error = options.error.as_ref().map(|(error_type, message)| {
+        PatchTelemetryError {
+            error_type: error_type.clone(),
+            message: Some(sanitize_error_message(message)),
+        }
+    });
+
+    PatchTelemetryEvent {
+        event_sender_created_at: chrono_now_iso(),
+        event_type: options.event_type.as_str().to_string(),
+        context: build_telemetry_context(&options.command),
+        session_id: SESSION_ID.clone(),
+        metadata: options.metadata.clone(),
+        error,
+    }
+}
+
+/// Get the current time as an ISO 8601 string.
+fn chrono_now_iso() -> String { + let now = std::time::SystemTime::now(); + let duration = now + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default(); + let secs = duration.as_secs(); + + let days = secs / 86400; + let remaining = secs % 86400; + let hours = remaining / 3600; + let minutes = (remaining % 3600) / 60; + let seconds = remaining % 60; + let millis = duration.subsec_millis(); + + let (year, month, day) = days_to_ymd(days); + + format!( + "{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}.{millis:03}Z" + ) +} + +/// Convert days since Unix epoch to (year, month, day). +fn days_to_ymd(days: u64) -> (u64, u64, u64) { + // Adapted from Howard Hinnant's civil_from_days algorithm + let z = days as i64 + 719468; + let era = if z >= 0 { z } else { z - 146096 } / 146097; + let doe = (z - era * 146097) as u64; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe as i64 + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + let y = if m <= 2 { y + 1 } else { y }; + (y as u64, m, d) +} + +// --------------------------------------------------------------------------- +// Send event +// --------------------------------------------------------------------------- + +/// Send a telemetry event to the API. +/// +/// This is fire-and-forget: errors are logged in debug mode but never +/// propagated. Uses `reqwest` with a 5-second timeout. +async fn send_telemetry_event( + event: &PatchTelemetryEvent, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + let (url, use_auth) = match (api_token, org_slug) { + (Some(_token), Some(slug)) => { + let api_url = std::env::var("SOCKET_API_URL") + .unwrap_or_else(|_| DEFAULT_SOCKET_API_URL.to_string()); + (format!("{api_url}/v0/orgs/{slug}/telemetry"), true) + } + _ => { + let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL") + .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string()); + (format!("{proxy_url}/patch/telemetry"), false) + } + }; + + debug_log(&format!("Sending telemetry to {url}")); + + let client = match reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(5)) + .build() + { + Ok(c) => c, + Err(e) => { + debug_log(&format!("Failed to build HTTP client: {e}")); + return; + } + }; + + let mut request = client + .post(&url) + .header("Content-Type", "application/json") + .header("User-Agent", USER_AGENT); + + if use_auth { + if let Some(token) = api_token { + request = request.header("Authorization", format!("Bearer {token}")); + } + } + + match request.json(event).send().await { + Ok(response) => { + let status = response.status(); + if status.is_success() { + debug_log("Telemetry sent successfully"); + } else { + debug_log(&format!("Telemetry request returned status {status}")); + } + } + Err(e) => { + debug_log(&format!("Telemetry request failed: {e}")); + } + } +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Track a patch lifecycle event. +/// +/// This function is non-blocking and will never return errors. Telemetry +/// failures are logged in debug mode but do not affect CLI operation. +/// +/// If telemetry is disabled (via environment variables), the function returns +/// immediately. 
+pub async fn track_patch_event(options: TrackPatchEventOptions) { + if is_telemetry_disabled() { + debug_log("Telemetry is disabled, skipping event"); + return; + } + + let event = build_telemetry_event(&options); + send_telemetry_event( + &event, + options.api_token.as_deref(), + options.org_slug.as_deref(), + ) + .await; +} + +/// Fire-and-forget version of `track_patch_event` that spawns the request +/// on a background task so it never blocks the caller. +pub fn track_patch_event_fire_and_forget(options: TrackPatchEventOptions) { + if is_telemetry_disabled() { + debug_log("Telemetry is disabled, skipping event"); + return; + } + + let event = build_telemetry_event(&options); + let api_token = options.api_token.clone(); + let org_slug = options.org_slug.clone(); + + tokio::spawn(async move { + send_telemetry_event(&event, api_token.as_deref(), org_slug.as_deref()).await; + }); +} + +// --------------------------------------------------------------------------- +// Convenience functions +// +// These accept `Option<&str>` for api_token/org_slug to make call sites +// convenient (callers typically have `Option` and call `.as_deref()`). +// --------------------------------------------------------------------------- + +/// Track a successful patch application. +pub async fn track_patch_applied( + patches_count: usize, + dry_run: bool, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + let mut metadata = HashMap::new(); + metadata.insert( + "patches_count".to_string(), + serde_json::Value::Number(serde_json::Number::from(patches_count)), + ); + metadata.insert("dry_run".to_string(), serde_json::Value::Bool(dry_run)); + + track_patch_event(TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchApplied, + command: "apply".to_string(), + metadata: Some(metadata), + error: None, + api_token: api_token.map(|s| s.to_string()), + org_slug: org_slug.map(|s| s.to_string()), + }) + .await; +} + +/// Track a failed patch application. +/// +/// Accepts any `Display` type for the error (works with `&str`, `String`, +/// `anyhow::Error`, `std::io::Error`, etc.). +pub async fn track_patch_apply_failed( + error: impl std::fmt::Display, + dry_run: bool, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + let mut metadata = HashMap::new(); + metadata.insert("dry_run".to_string(), serde_json::Value::Bool(dry_run)); + + track_patch_event(TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchApplyFailed, + command: "apply".to_string(), + metadata: Some(metadata), + error: Some(("Error".to_string(), error.to_string())), + api_token: api_token.map(|s| s.to_string()), + org_slug: org_slug.map(|s| s.to_string()), + }) + .await; +} + +/// Track a successful patch removal. +pub async fn track_patch_removed( + removed_count: usize, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + let mut metadata = HashMap::new(); + metadata.insert( + "removed_count".to_string(), + serde_json::Value::Number(serde_json::Number::from(removed_count)), + ); + + track_patch_event(TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchRemoved, + command: "remove".to_string(), + metadata: Some(metadata), + error: None, + api_token: api_token.map(|s| s.to_string()), + org_slug: org_slug.map(|s| s.to_string()), + }) + .await; +} + +/// Track a failed patch removal. +/// +/// Accepts any `Display` type for the error. 
+pub async fn track_patch_remove_failed( + error: impl std::fmt::Display, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + track_patch_event(TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchRemoveFailed, + command: "remove".to_string(), + metadata: None, + error: Some(("Error".to_string(), error.to_string())), + api_token: api_token.map(|s| s.to_string()), + org_slug: org_slug.map(|s| s.to_string()), + }) + .await; +} + +/// Track a successful patch rollback. +pub async fn track_patch_rolled_back( + rolled_back_count: usize, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + let mut metadata = HashMap::new(); + metadata.insert( + "rolled_back_count".to_string(), + serde_json::Value::Number(serde_json::Number::from(rolled_back_count)), + ); + + track_patch_event(TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchRolledBack, + command: "rollback".to_string(), + metadata: Some(metadata), + error: None, + api_token: api_token.map(|s| s.to_string()), + org_slug: org_slug.map(|s| s.to_string()), + }) + .await; +} + +/// Track a failed patch rollback. +/// +/// Accepts any `Display` type for the error. +pub async fn track_patch_rollback_failed( + error: impl std::fmt::Display, + api_token: Option<&str>, + org_slug: Option<&str>, +) { + track_patch_event(TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchRollbackFailed, + command: "rollback".to_string(), + metadata: None, + error: Some(("Error".to_string(), error.to_string())), + api_token: api_token.map(|s| s.to_string()), + org_slug: org_slug.map(|s| s.to_string()), + }) + .await; +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Combined into a single test to avoid env-var races across parallel tests. + #[test] + fn test_is_telemetry_disabled() { + // Save originals + let orig_disabled = std::env::var("SOCKET_PATCH_TELEMETRY_DISABLED").ok(); + let orig_vitest = std::env::var("VITEST").ok(); + + // Default: not disabled + std::env::remove_var("SOCKET_PATCH_TELEMETRY_DISABLED"); + std::env::remove_var("VITEST"); + assert!(!is_telemetry_disabled()); + + // Disabled via "1" + std::env::set_var("SOCKET_PATCH_TELEMETRY_DISABLED", "1"); + assert!(is_telemetry_disabled()); + + // Disabled via "true" + std::env::set_var("SOCKET_PATCH_TELEMETRY_DISABLED", "true"); + assert!(is_telemetry_disabled()); + + // Restore originals + match orig_disabled { + Some(v) => std::env::set_var("SOCKET_PATCH_TELEMETRY_DISABLED", v), + None => std::env::remove_var("SOCKET_PATCH_TELEMETRY_DISABLED"), + } + match orig_vitest { + Some(v) => std::env::set_var("VITEST", v), + None => std::env::remove_var("VITEST"), + } + } + + #[test] + fn test_sanitize_error_message() { + let home = home_dir_string().unwrap_or_else(|| "/home/testuser".to_string()); + let msg = format!("Failed to read {home}/projects/secret/file.txt"); + let sanitized = sanitize_error_message(&msg); + assert!(sanitized.contains("~/projects/secret/file.txt")); + assert!(!sanitized.contains(&home)); + } + + #[test] + fn test_sanitize_error_message_no_home() { + let msg = "Some error without paths"; + assert_eq!(sanitize_error_message(msg), msg); + } + + #[test] + fn test_event_type_as_str() { + assert_eq!(PatchTelemetryEventType::PatchApplied.as_str(), "patch_applied"); + assert_eq!( + PatchTelemetryEventType::PatchApplyFailed.as_str(), + "patch_apply_failed" + ); + assert_eq!(PatchTelemetryEventType::PatchRemoved.as_str(), "patch_removed"); + assert_eq!( + PatchTelemetryEventType::PatchRemoveFailed.as_str(), + 
"patch_remove_failed" + ); + assert_eq!( + PatchTelemetryEventType::PatchRolledBack.as_str(), + "patch_rolled_back" + ); + assert_eq!( + PatchTelemetryEventType::PatchRollbackFailed.as_str(), + "patch_rollback_failed" + ); + } + + #[test] + fn test_build_telemetry_context() { + let ctx = build_telemetry_context("apply"); + assert_eq!(ctx.command, "apply"); + assert_eq!(ctx.version, PACKAGE_VERSION); + assert!(!ctx.platform.is_empty()); + assert!(!ctx.arch.is_empty()); + } + + #[test] + fn test_build_telemetry_event_basic() { + let options = TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchApplied, + command: "apply".to_string(), + metadata: None, + error: None, + api_token: None, + org_slug: None, + }; + + let event = build_telemetry_event(&options); + assert_eq!(event.event_type, "patch_applied"); + assert_eq!(event.context.command, "apply"); + assert!(!event.session_id.is_empty()); + assert!(!event.event_sender_created_at.is_empty()); + assert!(event.metadata.is_none()); + assert!(event.error.is_none()); + } + + #[test] + fn test_build_telemetry_event_with_metadata() { + let mut metadata = HashMap::new(); + metadata.insert( + "patches_count".to_string(), + serde_json::Value::Number(5.into()), + ); + + let options = TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchApplied, + command: "apply".to_string(), + metadata: Some(metadata), + error: None, + api_token: None, + org_slug: None, + }; + + let event = build_telemetry_event(&options); + assert!(event.metadata.is_some()); + let meta = event.metadata.unwrap(); + assert_eq!( + meta.get("patches_count").unwrap(), + &serde_json::Value::Number(5.into()) + ); + } + + #[test] + fn test_build_telemetry_event_with_error() { + let options = TrackPatchEventOptions { + event_type: PatchTelemetryEventType::PatchApplyFailed, + command: "apply".to_string(), + metadata: None, + error: Some(("IoError".to_string(), "file not found".to_string())), + api_token: None, + org_slug: None, + }; + + let event = build_telemetry_event(&options); + assert!(event.error.is_some()); + let err = event.error.unwrap(); + assert_eq!(err.error_type, "IoError"); + assert_eq!(err.message.unwrap(), "file not found"); + } + + #[test] + fn test_session_id_is_consistent() { + let id1 = SESSION_ID.clone(); + let id2 = SESSION_ID.clone(); + assert_eq!(id1, id2); + // Should be a valid UUID v4 format + assert_eq!(id1.len(), 36); + assert!(id1.contains('-')); + } + + #[test] + fn test_chrono_now_iso_format() { + let ts = chrono_now_iso(); + // Should look like "2024-01-15T10:30:45.123Z" + assert!(ts.ends_with('Z')); + assert!(ts.contains('T')); + assert!(ts.contains('-')); + assert!(ts.contains(':')); + assert_eq!(ts.len(), 24); // YYYY-MM-DDTHH:MM:SS.mmmZ + } + + #[test] + fn test_days_to_ymd_epoch() { + let (y, m, d) = days_to_ymd(0); + assert_eq!((y, m, d), (1970, 1, 1)); + } + + #[test] + fn test_days_to_ymd_known_date() { + // 2024-01-01 is day 19723 + let (y, m, d) = days_to_ymd(19723); + assert_eq!((y, m, d), (2024, 1, 1)); + } +} diff --git a/npm/socket-patch-android-arm64/package.json b/npm/socket-patch-android-arm64/package.json new file mode 100644 index 0000000..dd9ece7 --- /dev/null +++ b/npm/socket-patch-android-arm64/package.json @@ -0,0 +1,19 @@ +{ + "name": "@socketsecurity/socket-patch-android-arm64", + "version": "2.1.4", + "description": "socket-patch binary for Android ARM64", + "os": [ + "android" + ], + "cpu": [ + "arm64" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": 
"git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-darwin-arm64/package.json b/npm/socket-patch-darwin-arm64/package.json new file mode 100644 index 0000000..81e0715 --- /dev/null +++ b/npm/socket-patch-darwin-arm64/package.json @@ -0,0 +1,19 @@ +{ + "name": "@socketsecurity/socket-patch-darwin-arm64", + "version": "2.1.4", + "description": "socket-patch binary for macOS ARM64", + "os": [ + "darwin" + ], + "cpu": [ + "arm64" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-darwin-x64/package.json b/npm/socket-patch-darwin-x64/package.json new file mode 100644 index 0000000..9975af8 --- /dev/null +++ b/npm/socket-patch-darwin-x64/package.json @@ -0,0 +1,19 @@ +{ + "name": "@socketsecurity/socket-patch-darwin-x64", + "version": "2.1.4", + "description": "socket-patch binary for macOS x64", + "os": [ + "darwin" + ], + "cpu": [ + "x64" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-arm-gnu/package.json b/npm/socket-patch-linux-arm-gnu/package.json new file mode 100644 index 0000000..1b04eab --- /dev/null +++ b/npm/socket-patch-linux-arm-gnu/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-arm-gnu", + "version": "2.1.4", + "description": "socket-patch binary for Linux ARM (glibc)", + "os": [ + "linux" + ], + "cpu": [ + "arm" + ], + "libc": [ + "glibc" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-arm-musl/package.json b/npm/socket-patch-linux-arm-musl/package.json new file mode 100644 index 0000000..dfd42a0 --- /dev/null +++ b/npm/socket-patch-linux-arm-musl/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-arm-musl", + "version": "2.1.4", + "description": "socket-patch binary for Linux ARM (musl)", + "os": [ + "linux" + ], + "cpu": [ + "arm" + ], + "libc": [ + "musl" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-arm64-gnu/package.json b/npm/socket-patch-linux-arm64-gnu/package.json new file mode 100644 index 0000000..412eee6 --- /dev/null +++ b/npm/socket-patch-linux-arm64-gnu/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-arm64-gnu", + "version": "2.1.4", + "description": "socket-patch binary for Linux ARM64 (glibc)", + "os": [ + "linux" + ], + "cpu": [ + "arm64" + ], + "libc": [ + "glibc" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-arm64-musl/package.json b/npm/socket-patch-linux-arm64-musl/package.json new file mode 100644 index 0000000..9c95bad --- /dev/null +++ b/npm/socket-patch-linux-arm64-musl/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-arm64-musl", + "version": "2.1.4", + "description": "socket-patch binary for Linux ARM64 (musl)", + "os": [ + "linux" + ], + "cpu": [ + "arm64" + ], + "libc": [ + "musl" + ], + "publishConfig": { + "access": "public" + }, 
+ "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-ia32-gnu/package.json b/npm/socket-patch-linux-ia32-gnu/package.json new file mode 100644 index 0000000..450a198 --- /dev/null +++ b/npm/socket-patch-linux-ia32-gnu/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-ia32-gnu", + "version": "2.1.4", + "description": "socket-patch binary for Linux ia32 (glibc)", + "os": [ + "linux" + ], + "cpu": [ + "ia32" + ], + "libc": [ + "glibc" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-ia32-musl/package.json b/npm/socket-patch-linux-ia32-musl/package.json new file mode 100644 index 0000000..cd21732 --- /dev/null +++ b/npm/socket-patch-linux-ia32-musl/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-ia32-musl", + "version": "2.1.4", + "description": "socket-patch binary for Linux ia32 (musl)", + "os": [ + "linux" + ], + "cpu": [ + "ia32" + ], + "libc": [ + "musl" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-x64-gnu/package.json b/npm/socket-patch-linux-x64-gnu/package.json new file mode 100644 index 0000000..5cfc8c5 --- /dev/null +++ b/npm/socket-patch-linux-x64-gnu/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-x64-gnu", + "version": "2.1.4", + "description": "socket-patch binary for Linux x64 (glibc)", + "os": [ + "linux" + ], + "cpu": [ + "x64" + ], + "libc": [ + "glibc" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-linux-x64-musl/package.json b/npm/socket-patch-linux-x64-musl/package.json new file mode 100644 index 0000000..478885b --- /dev/null +++ b/npm/socket-patch-linux-x64-musl/package.json @@ -0,0 +1,22 @@ +{ + "name": "@socketsecurity/socket-patch-linux-x64-musl", + "version": "2.1.4", + "description": "socket-patch binary for Linux x64 (musl)", + "os": [ + "linux" + ], + "cpu": [ + "x64" + ], + "libc": [ + "musl" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-win32-arm64/package.json b/npm/socket-patch-win32-arm64/package.json new file mode 100644 index 0000000..a0a2b32 --- /dev/null +++ b/npm/socket-patch-win32-arm64/package.json @@ -0,0 +1,19 @@ +{ + "name": "@socketsecurity/socket-patch-win32-arm64", + "version": "2.1.4", + "description": "socket-patch binary for Windows ARM64", + "os": [ + "win32" + ], + "cpu": [ + "arm64" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-win32-ia32/package.json b/npm/socket-patch-win32-ia32/package.json new file mode 100644 index 0000000..5c2aade --- /dev/null +++ b/npm/socket-patch-win32-ia32/package.json @@ -0,0 +1,19 @@ +{ + "name": "@socketsecurity/socket-patch-win32-ia32", + "version": "2.1.4", + "description": "socket-patch binary for Windows ia32", + "os": [ + "win32" + ], + "cpu": [ + "ia32" + ], + 
"publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch-win32-x64/package.json b/npm/socket-patch-win32-x64/package.json new file mode 100644 index 0000000..054eff5 --- /dev/null +++ b/npm/socket-patch-win32-x64/package.json @@ -0,0 +1,19 @@ +{ + "name": "@socketsecurity/socket-patch-win32-x64", + "version": "2.1.4", + "description": "socket-patch binary for Windows x64", + "os": [ + "win32" + ], + "cpu": [ + "x64" + ], + "publishConfig": { + "access": "public" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + } +} diff --git a/npm/socket-patch/bin/socket-patch b/npm/socket-patch/bin/socket-patch new file mode 100755 index 0000000..76361f7 --- /dev/null +++ b/npm/socket-patch/bin/socket-patch @@ -0,0 +1,46 @@ +#!/usr/bin/env node +const { spawnSync } = require("child_process"); +const path = require("path"); + +const PLATFORMS = { + "darwin arm64": ["@socketsecurity/socket-patch-darwin-arm64"], + "darwin x64": ["@socketsecurity/socket-patch-darwin-x64"], + "linux x64": ["@socketsecurity/socket-patch-linux-x64-gnu", "@socketsecurity/socket-patch-linux-x64-musl"], + "linux arm64": ["@socketsecurity/socket-patch-linux-arm64-gnu", "@socketsecurity/socket-patch-linux-arm64-musl"], + "linux arm": ["@socketsecurity/socket-patch-linux-arm-gnu", "@socketsecurity/socket-patch-linux-arm-musl"], + "linux ia32": ["@socketsecurity/socket-patch-linux-ia32-gnu", "@socketsecurity/socket-patch-linux-ia32-musl"], + "win32 x64": ["@socketsecurity/socket-patch-win32-x64"], + "win32 ia32": ["@socketsecurity/socket-patch-win32-ia32"], + "win32 arm64": ["@socketsecurity/socket-patch-win32-arm64"], + "android arm64": ["@socketsecurity/socket-patch-android-arm64"], +}; + +const key = `${process.platform} ${process.arch}`; +const candidates = PLATFORMS[key]; +if (!candidates) { + console.error(`Unsupported platform: ${key}`); + process.exit(1); +} + +const exe = process.platform === "win32" ? "socket-patch.exe" : "socket-patch"; +let binPath; +for (const pkg of candidates) { + try { + const pkgDir = path.dirname(require.resolve(`${pkg}/package.json`)); + binPath = path.join(pkgDir, exe); + break; + } catch {} +} +if (!binPath) { + // Fallback: try local bin directory (for development or bundled installs) + const localBin = process.platform === "win32" + ? `socket-patch-${key.replace(" ", "-")}.exe` + : `socket-patch-${key.replace(" ", "-")}`; + binPath = path.join(__dirname, localBin); +} + +const result = spawnSync(binPath, process.argv.slice(2), { + stdio: "inherit", + env: process.env, +}); +process.exit(result.status ?? 
1);
diff --git a/npm/socket-patch/bin/socket-patch.test.mjs b/npm/socket-patch/bin/socket-patch.test.mjs
new file mode 100644
index 0000000..1894f11
--- /dev/null
+++ b/npm/socket-patch/bin/socket-patch.test.mjs
@@ -0,0 +1,73 @@
+import { describe, it } from "node:test";
+import assert from "node:assert/strict";
+import { readFileSync } from "node:fs";
+import { fileURLToPath } from "node:url";
+import { dirname, join } from "node:path";
+
+const __dirname = dirname(fileURLToPath(import.meta.url));
+const src = readFileSync(join(__dirname, "socket-patch"), "utf8");
+
+// Extract the PLATFORMS object from the source
+const match = src.match(/const PLATFORMS = \{([\s\S]*?)\};/);
+assert.ok(match, "PLATFORMS object not found in socket-patch");
+
+// Parse keys and array values from the object literal
+// Matches: "key": ["value1", "value2"] or "key": ["value1"]
+const entries = [];
+const entryRegex = /"([^"]+)":\s*\[([\s\S]*?)\]/g;
+let m;
+while ((m = entryRegex.exec(match[1])) !== null) {
+  const key = m[1];
+  const values = [...m[2].matchAll(/"([^"]+)"/g)].map(([, v]) => v);
+  entries.push([key, values]);
+}
+const PLATFORMS = Object.fromEntries(entries);
+
+const EXPECTED_KEYS = [
+  "darwin arm64",
+  "darwin x64",
+  "linux x64",
+  "linux arm64",
+  "linux arm",
+  "linux ia32",
+  "win32 x64",
+  "win32 ia32",
+  "win32 arm64",
+  "android arm64",
+];
+
+describe("npm platform dispatch", () => {
+  it("has all expected platform keys", () => {
+    for (const key of EXPECTED_KEYS) {
+      assert.ok(PLATFORMS[key], `missing platform key: ${key}`);
+    }
+  });
+
+  it("has no unexpected platform keys", () => {
+    for (const key of Object.keys(PLATFORMS)) {
+      assert.ok(EXPECTED_KEYS.includes(key), `unexpected platform key: ${key}`);
+    }
+  });
+
+  it("non-Linux package names follow @socketsecurity/socket-patch-<platform>-<arch> convention", () => {
+    for (const [key, candidates] of Object.entries(PLATFORMS)) {
+      if (key.startsWith("linux ")) continue;
+      const [platform, arch] = key.split(" ");
+      assert.equal(candidates.length, 1, `expected 1 candidate for ${key}`);
+      const expected = `@socketsecurity/socket-patch-${platform}-${arch}`;
+      assert.equal(candidates[0], expected, `package name mismatch for ${key}`);
+    }
+  });
+
+  it("Linux entries have both glibc and musl candidates", () => {
+    for (const [key, candidates] of Object.entries(PLATFORMS)) {
+      if (!key.startsWith("linux ")) continue;
+      const [, arch] = key.split(" ");
+      assert.equal(candidates.length, 2, `expected 2 candidates for ${key}`);
+      const gnuPkg = `@socketsecurity/socket-patch-linux-${arch}-gnu`;
+      const muslPkg = `@socketsecurity/socket-patch-linux-${arch}-musl`;
+      assert.equal(candidates[0], gnuPkg, `first candidate for ${key} should be gnu`);
+      assert.equal(candidates[1], muslPkg, `second candidate for ${key} should be musl`);
+    }
+  });
+});
diff --git a/npm/socket-patch/package.json b/npm/socket-patch/package.json
new file mode 100644
index 0000000..1be5c82
--- /dev/null
+++ b/npm/socket-patch/package.json
@@ -0,0 +1,60 @@
+{
+  "name": "@socketsecurity/socket-patch",
+  "version": "2.1.4",
+  "description": "CLI tool and schema library for applying security patches to dependencies",
+  "bin": {
+    "socket-patch": "bin/socket-patch"
+  },
+  "exports": {
+    "./schema": {
+      "types": "./dist/schema/manifest-schema.d.ts",
+      "import": "./dist/schema/manifest-schema.js",
+      "require": "./dist/schema/manifest-schema.js"
+    }
+  },
+  "publishConfig": {
+    "access": "public"
+  },
+  "scripts": {
+    "build": "tsc",
+    "test": "pnpm run build && node --test dist/**/*.test.js"
}, + "keywords": [ + "security", + "patch", + "cli", + "dependencies" + ], + "author": "Socket Security", + "license": "MIT", + "repository": { + "type": "git", + "url": "https://github.com/SocketDev/socket-patch" + }, + "engines": { + "node": ">=18.0.0" + }, + "dependencies": { + "zod": "^3.24.4" + }, + "devDependencies": { + "typescript": "^5.3.0", + "@types/node": "^20.0.0" + }, + "optionalDependencies": { + "@socketsecurity/socket-patch-android-arm64": "2.1.4", + "@socketsecurity/socket-patch-darwin-arm64": "2.1.4", + "@socketsecurity/socket-patch-darwin-x64": "2.1.4", + "@socketsecurity/socket-patch-linux-arm-gnu": "2.1.4", + "@socketsecurity/socket-patch-linux-arm-musl": "2.1.4", + "@socketsecurity/socket-patch-linux-arm64-gnu": "2.1.4", + "@socketsecurity/socket-patch-linux-arm64-musl": "2.1.4", + "@socketsecurity/socket-patch-linux-ia32-gnu": "2.1.4", + "@socketsecurity/socket-patch-linux-ia32-musl": "2.1.4", + "@socketsecurity/socket-patch-linux-x64-gnu": "2.1.4", + "@socketsecurity/socket-patch-linux-x64-musl": "2.1.4", + "@socketsecurity/socket-patch-win32-arm64": "2.1.4", + "@socketsecurity/socket-patch-win32-ia32": "2.1.4", + "@socketsecurity/socket-patch-win32-x64": "2.1.4" + } +} diff --git a/npm/socket-patch/src/schema/manifest-schema.test.ts b/npm/socket-patch/src/schema/manifest-schema.test.ts new file mode 100644 index 0000000..d87a0f5 --- /dev/null +++ b/npm/socket-patch/src/schema/manifest-schema.test.ts @@ -0,0 +1,131 @@ +import { describe, it } from 'node:test' +import * as assert from 'node:assert/strict' +import { PatchManifestSchema, PatchRecordSchema } from './manifest-schema.js' + +describe('PatchManifestSchema', () => { + it('should validate a well-formed manifest', () => { + const manifest = { + patches: { + 'npm:simplehttpserver@0.0.6': { + uuid: '550e8400-e29b-41d4-a716-446655440000', + exportedAt: '2024-01-01T00:00:00Z', + files: { + 'node_modules/simplehttpserver/index.js': { + beforeHash: 'abc123', + afterHash: 'def456', + }, + }, + vulnerabilities: { + 'GHSA-jrhj-2j3q-xf3v': { + cves: ['CVE-2024-0001'], + summary: 'Path traversal vulnerability', + severity: 'high', + description: 'Allows reading arbitrary files', + }, + }, + description: 'Fix path traversal', + license: 'MIT', + tier: 'free', + }, + }, + } + + const result = PatchManifestSchema.safeParse(manifest) + assert.ok(result.success, 'Valid manifest should parse successfully') + assert.equal( + Object.keys(result.data.patches).length, + 1, + 'Should have one patch entry', + ) + }) + + it('should validate a manifest with multiple patches', () => { + const manifest = { + patches: { + 'npm:pkg-a@1.0.0': { + uuid: '550e8400-e29b-41d4-a716-446655440001', + exportedAt: '2024-01-01T00:00:00Z', + files: { + 'node_modules/pkg-a/lib/index.js': { + beforeHash: 'aaa', + afterHash: 'bbb', + }, + }, + vulnerabilities: {}, + description: 'Patch A', + license: 'MIT', + tier: 'free', + }, + 'npm:pkg-b@2.0.0': { + uuid: '550e8400-e29b-41d4-a716-446655440002', + exportedAt: '2024-02-01T00:00:00Z', + files: { + 'node_modules/pkg-b/src/main.js': { + beforeHash: 'ccc', + afterHash: 'ddd', + }, + }, + vulnerabilities: { + 'GHSA-xxxx-yyyy-zzzz': { + cves: [], + summary: 'Some vuln', + severity: 'medium', + description: 'A medium severity vulnerability', + }, + }, + description: 'Patch B', + license: 'Apache-2.0', + tier: 'paid', + }, + }, + } + + const result = PatchManifestSchema.safeParse(manifest) + assert.ok(result.success, 'Multi-patch manifest should parse successfully') + 
assert.equal(Object.keys(result.data.patches).length, 2) + }) + + it('should validate an empty manifest', () => { + const manifest = { patches: {} } + const result = PatchManifestSchema.safeParse(manifest) + assert.ok(result.success, 'Empty patches should be valid') + }) + + it('should reject a manifest missing the patches field', () => { + const result = PatchManifestSchema.safeParse({}) + assert.ok(!result.success, 'Missing patches should fail') + }) + + it('should reject a manifest with invalid patch record', () => { + const manifest = { + patches: { + 'npm:bad@1.0.0': { + // missing uuid, exportedAt, files, vulnerabilities, description, license, tier + }, + }, + } + const result = PatchManifestSchema.safeParse(manifest) + assert.ok(!result.success, 'Invalid patch record should fail') + }) + + it('should reject a patch with invalid uuid', () => { + const record = { + uuid: 'not-a-valid-uuid', + exportedAt: '2024-01-01T00:00:00Z', + files: {}, + vulnerabilities: {}, + description: 'Test', + license: 'MIT', + tier: 'free', + } + const result = PatchRecordSchema.safeParse(record) + assert.ok(!result.success, 'Invalid UUID should fail') + }) + + it('should reject non-object input', () => { + assert.ok(!PatchManifestSchema.safeParse(null).success) + assert.ok(!PatchManifestSchema.safeParse('string').success) + assert.ok(!PatchManifestSchema.safeParse(42).success) + assert.ok(!PatchManifestSchema.safeParse([]).success) + }) +}) diff --git a/src/schema/manifest-schema.ts b/npm/socket-patch/src/schema/manifest-schema.ts similarity index 100% rename from src/schema/manifest-schema.ts rename to npm/socket-patch/src/schema/manifest-schema.ts diff --git a/npm/socket-patch/tsconfig.json b/npm/socket-patch/tsconfig.json new file mode 100644 index 0000000..6998287 --- /dev/null +++ b/npm/socket-patch/tsconfig.json @@ -0,0 +1,16 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "Node16", + "moduleResolution": "Node16", + "declaration": true, + "composite": true, + "outDir": "dist", + "rootDir": "src", + "strict": true, + "skipLibCheck": true, + "esModuleInterop": true, + "types": ["node"] + }, + "include": ["src"] +} diff --git a/package.json b/package.json deleted file mode 100644 index c80779c..0000000 --- a/package.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "name": "@socketsecurity/socket-patch", - "version": "0.1.0", - "packageManager": "pnpm@10.16.1", - "description": "CLI tool for applying security patches to dependencies", - "main": "dist/index.js", - "types": "dist/index.d.ts", - "bin": { - "socket-patch": "dist/cli.js" - }, - "exports": { - ".": { - "types": "./dist/index.d.ts", - "require": "./dist/index.js", - "import": "./dist/index.js" - }, - "./schema": { - "types": "./dist/schema/manifest-schema.d.ts", - "require": "./dist/schema/manifest-schema.js", - "import": "./dist/schema/manifest-schema.js" - }, - "./hash": { - "types": "./dist/hash/git-sha256.d.ts", - "require": "./dist/hash/git-sha256.js", - "import": "./dist/hash/git-sha256.js" - }, - "./patch": { - "types": "./dist/patch/apply.d.ts", - "require": "./dist/patch/apply.js", - "import": "./dist/patch/apply.js" - }, - "./manifest/operations": { - "types": "./dist/manifest/operations.d.ts", - "require": "./dist/manifest/operations.js", - "import": "./dist/manifest/operations.js" - }, - "./manifest/recovery": { - "types": "./dist/manifest/recovery.d.ts", - "require": "./dist/manifest/recovery.js", - "import": "./dist/manifest/recovery.js" - }, - "./constants": { - "types": "./dist/constants.d.ts", - "require": 
"./dist/constants.js", - "import": "./dist/constants.js" - }, - "./package-json": { - "types": "./dist/package-json/index.d.ts", - "require": "./dist/package-json/index.js", - "import": "./dist/package-json/index.js" - } - }, - "scripts": { - "build": "tsc", - "dev": "tsc --watch", - "patch": "node dist/cli.js", - "lint": "oxlint -c ./.oxlintrc.json --tsconfig ./tsconfig.json --deny-warnings", - "lint:fix": "pnpm run lint --fix && pnpm run lint:fix:fast", - "lint:fix:fast": "biome format --write", - "publish:ci": "npm publish --provenance --access public" - }, - "publishConfig": { - "access": "public", - "registry": "https://registry.npmjs.org/" - }, - "keywords": [ - "security", - "patch", - "cli", - "dependencies" - ], - "author": "Socket Security", - "license": "MIT", - "dependencies": { - "yargs": "^17.7.2", - "zod": "^3.24.4" - }, - "devDependencies": { - "@biomejs/biome": "^2.1.2", - "@types/node": "^20.0.0", - "@types/yargs": "^17.0.32", - "oxlint": "^1.15.0", - "typescript": "^5.3.0" - }, - "engines": { - "node": ">=18.0.0" - } -} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml deleted file mode 100644 index 4108d26..0000000 --- a/pnpm-lock.yaml +++ /dev/null @@ -1,365 +0,0 @@ -lockfileVersion: '9.0' - -settings: - autoInstallPeers: true - excludeLinksFromLockfile: false - -importers: - - .: - dependencies: - yargs: - specifier: ^17.7.2 - version: 17.7.2 - zod: - specifier: ^3.24.4 - version: 3.25.76 - devDependencies: - '@biomejs/biome': - specifier: ^2.1.2 - version: 2.3.5 - '@types/node': - specifier: ^20.0.0 - version: 20.19.25 - '@types/yargs': - specifier: ^17.0.32 - version: 17.0.35 - oxlint: - specifier: ^1.15.0 - version: 1.28.0 - typescript: - specifier: ^5.3.0 - version: 5.9.3 - -packages: - - '@biomejs/biome@2.3.5': - resolution: {integrity: sha512-HvLhNlIlBIbAV77VysRIBEwp55oM/QAjQEin74QQX9Xb259/XP/D5AGGnZMOyF1el4zcvlNYYR3AyTMUV3ILhg==} - engines: {node: '>=14.21.3'} - hasBin: true - - '@biomejs/cli-darwin-arm64@2.3.5': - resolution: {integrity: sha512-fLdTur8cJU33HxHUUsii3GLx/TR0BsfQx8FkeqIiW33cGMtUD56fAtrh+2Fx1uhiCsVZlFh6iLKUU3pniZREQw==} - engines: {node: '>=14.21.3'} - cpu: [arm64] - os: [darwin] - - '@biomejs/cli-darwin-x64@2.3.5': - resolution: {integrity: sha512-qpT8XDqeUlzrOW8zb4k3tjhT7rmvVRumhi2657I2aGcY4B+Ft5fNwDdZGACzn8zj7/K1fdWjgwYE3i2mSZ+vOA==} - engines: {node: '>=14.21.3'} - cpu: [x64] - os: [darwin] - - '@biomejs/cli-linux-arm64-musl@2.3.5': - resolution: {integrity: sha512-eGUG7+hcLgGnMNl1KHVZUYxahYAhC462jF/wQolqu4qso2MSk32Q+QrpN7eN4jAHAg7FUMIo897muIhK4hXhqg==} - engines: {node: '>=14.21.3'} - cpu: [arm64] - os: [linux] - - '@biomejs/cli-linux-arm64@2.3.5': - resolution: {integrity: sha512-u/pybjTBPGBHB66ku4pK1gj+Dxgx7/+Z0jAriZISPX1ocTO8aHh8x8e7Kb1rB4Ms0nA/SzjtNOVJ4exVavQBCw==} - engines: {node: '>=14.21.3'} - cpu: [arm64] - os: [linux] - - '@biomejs/cli-linux-x64-musl@2.3.5': - resolution: {integrity: sha512-awVuycTPpVTH/+WDVnEEYSf6nbCBHf/4wB3lquwT7puhNg8R4XvonWNZzUsfHZrCkjkLhFH/vCZK5jHatD9FEg==} - engines: {node: '>=14.21.3'} - cpu: [x64] - os: [linux] - - '@biomejs/cli-linux-x64@2.3.5': - resolution: {integrity: sha512-XrIVi9YAW6ye0CGQ+yax0gLfx+BFOtKaNX74n+xHWla6Cl6huUmcKNO7HPx7BiKnJUzrxXY1qYlm7xMvi08X4g==} - engines: {node: '>=14.21.3'} - cpu: [x64] - os: [linux] - - '@biomejs/cli-win32-arm64@2.3.5': - resolution: {integrity: sha512-DlBiMlBZZ9eIq4H7RimDSGsYcOtfOIfZOaI5CqsWiSlbTfqbPVfWtCf92wNzx8GNMbu1s7/g3ZZESr6+GwM/SA==} - engines: {node: '>=14.21.3'} - cpu: [arm64] - os: [win32] - - '@biomejs/cli-win32-x64@2.3.5': - resolution: {integrity: 
sha512-nUmR8gb6yvrKhtRgzwo/gDimPwnO5a4sCydf8ZS2kHIJhEmSmk+STsusr1LHTuM//wXppBawvSQi2xFXJCdgKQ==} - engines: {node: '>=14.21.3'} - cpu: [x64] - os: [win32] - - '@oxlint/darwin-arm64@1.28.0': - resolution: {integrity: sha512-H7J41/iKbgm7tTpdSnA/AtjEAhxyzNzCMKWtKU5wDuP2v39jrc3fasQEJruk6hj1YXPbJY4N+1nK/jE27GMGDQ==} - cpu: [arm64] - os: [darwin] - - '@oxlint/darwin-x64@1.28.0': - resolution: {integrity: sha512-bGsSDEwpyYzNc6FIwhTmbhSK7piREUjMlmWBt7eoR3ract0+RfhZYYG4se1Ngs+4WOFC0B3gbv23fyF+cnbGGQ==} - cpu: [x64] - os: [darwin] - - '@oxlint/linux-arm64-gnu@1.28.0': - resolution: {integrity: sha512-eNH/evMpV3xAA4jIS8dMLcGkM/LK0WEHM0RO9bxrHPAwfS72jhyPJtd0R7nZhvhG6U1bhn5jhoXbk1dn27XIAQ==} - cpu: [arm64] - os: [linux] - - '@oxlint/linux-arm64-musl@1.28.0': - resolution: {integrity: sha512-ickvpcekNeRLND3llndiZOtJBb6LDZqNnZICIDkovURkOIWPGJGmAxsHUOI6yW6iny9gLmIEIGl/c1b5nFk6Ag==} - cpu: [arm64] - os: [linux] - - '@oxlint/linux-x64-gnu@1.28.0': - resolution: {integrity: sha512-DkgAh4LQ8NR3DwTT7/LGMhaMau0RtZkih91Ez5Usk7H7SOxo1GDi84beE7it2Q+22cAzgY4hbw3c6svonQTjxg==} - cpu: [x64] - os: [linux] - - '@oxlint/linux-x64-musl@1.28.0': - resolution: {integrity: sha512-VBnMi3AJ2w5p/kgeyrjcGOKNY8RzZWWvlGHjCJwzqPgob4MXu6T+5Yrdi7EVJyIlouL8E3LYPYjmzB9NBi9gZw==} - cpu: [x64] - os: [linux] - - '@oxlint/win32-arm64@1.28.0': - resolution: {integrity: sha512-tomhIks+4dKs8axB+s4GXHy+ZWXhUgptf1XnG5cZg8CzRfX4JFX9k8l2fPUgFwytWnyyvZaaXLRPWGzoZ6yoHQ==} - cpu: [arm64] - os: [win32] - - '@oxlint/win32-x64@1.28.0': - resolution: {integrity: sha512-4+VO5P/UJ2nq9sj6kQToJxFy5cKs7dGIN2DiUSQ7cqyUi7EKYNQKe+98HFcDOjtm33jQOQnc4kw8Igya5KPozg==} - cpu: [x64] - os: [win32] - - '@types/node@20.19.25': - resolution: {integrity: sha512-ZsJzA5thDQMSQO788d7IocwwQbI8B5OPzmqNvpf3NY/+MHDAS759Wo0gd2WQeXYt5AAAQjzcrTVC6SKCuYgoCQ==} - - '@types/yargs-parser@21.0.3': - resolution: {integrity: sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==} - - '@types/yargs@17.0.35': - resolution: {integrity: sha512-qUHkeCyQFxMXg79wQfTtfndEC+N9ZZg76HJftDJp+qH2tV7Gj4OJi7l+PiWwJ+pWtW8GwSmqsDj/oymhrTWXjg==} - - ansi-regex@5.0.1: - resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==} - engines: {node: '>=8'} - - ansi-styles@4.3.0: - resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} - engines: {node: '>=8'} - - cliui@8.0.1: - resolution: {integrity: sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==} - engines: {node: '>=12'} - - color-convert@2.0.1: - resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} - engines: {node: '>=7.0.0'} - - color-name@1.1.4: - resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} - - emoji-regex@8.0.0: - resolution: {integrity: sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==} - - escalade@3.2.0: - resolution: {integrity: sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==} - engines: {node: '>=6'} - - get-caller-file@2.0.5: - resolution: {integrity: sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==} - engines: {node: 6.* || 8.* || >= 10.*} - - is-fullwidth-code-point@3.0.0: - resolution: {integrity: 
sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==} - engines: {node: '>=8'} - - oxlint@1.28.0: - resolution: {integrity: sha512-gE97d0BcIlTTSJrim395B49mIbQ9VO8ZVoHdWai7Svl+lEeUAyCLTN4d7piw1kcB8VfgTp1JFVlAvMPD9GewMA==} - engines: {node: ^20.19.0 || >=22.12.0} - hasBin: true - peerDependencies: - oxlint-tsgolint: '>=0.4.0' - peerDependenciesMeta: - oxlint-tsgolint: - optional: true - - require-directory@2.1.1: - resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==} - engines: {node: '>=0.10.0'} - - string-width@4.2.3: - resolution: {integrity: sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==} - engines: {node: '>=8'} - - strip-ansi@6.0.1: - resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==} - engines: {node: '>=8'} - - typescript@5.9.3: - resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} - engines: {node: '>=14.17'} - hasBin: true - - undici-types@6.21.0: - resolution: {integrity: sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==} - - wrap-ansi@7.0.0: - resolution: {integrity: sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==} - engines: {node: '>=10'} - - y18n@5.0.8: - resolution: {integrity: sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==} - engines: {node: '>=10'} - - yargs-parser@21.1.1: - resolution: {integrity: sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==} - engines: {node: '>=12'} - - yargs@17.7.2: - resolution: {integrity: sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==} - engines: {node: '>=12'} - - zod@3.25.76: - resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==} - -snapshots: - - '@biomejs/biome@2.3.5': - optionalDependencies: - '@biomejs/cli-darwin-arm64': 2.3.5 - '@biomejs/cli-darwin-x64': 2.3.5 - '@biomejs/cli-linux-arm64': 2.3.5 - '@biomejs/cli-linux-arm64-musl': 2.3.5 - '@biomejs/cli-linux-x64': 2.3.5 - '@biomejs/cli-linux-x64-musl': 2.3.5 - '@biomejs/cli-win32-arm64': 2.3.5 - '@biomejs/cli-win32-x64': 2.3.5 - - '@biomejs/cli-darwin-arm64@2.3.5': - optional: true - - '@biomejs/cli-darwin-x64@2.3.5': - optional: true - - '@biomejs/cli-linux-arm64-musl@2.3.5': - optional: true - - '@biomejs/cli-linux-arm64@2.3.5': - optional: true - - '@biomejs/cli-linux-x64-musl@2.3.5': - optional: true - - '@biomejs/cli-linux-x64@2.3.5': - optional: true - - '@biomejs/cli-win32-arm64@2.3.5': - optional: true - - '@biomejs/cli-win32-x64@2.3.5': - optional: true - - '@oxlint/darwin-arm64@1.28.0': - optional: true - - '@oxlint/darwin-x64@1.28.0': - optional: true - - '@oxlint/linux-arm64-gnu@1.28.0': - optional: true - - '@oxlint/linux-arm64-musl@1.28.0': - optional: true - - '@oxlint/linux-x64-gnu@1.28.0': - optional: true - - '@oxlint/linux-x64-musl@1.28.0': - optional: true - - '@oxlint/win32-arm64@1.28.0': - optional: true - - '@oxlint/win32-x64@1.28.0': - optional: true - - '@types/node@20.19.25': - dependencies: - undici-types: 6.21.0 - - '@types/yargs-parser@21.0.3': {} - - '@types/yargs@17.0.35': - dependencies: - '@types/yargs-parser': 21.0.3 - - ansi-regex@5.0.1: {} - - 
ansi-styles@4.3.0: - dependencies: - color-convert: 2.0.1 - - cliui@8.0.1: - dependencies: - string-width: 4.2.3 - strip-ansi: 6.0.1 - wrap-ansi: 7.0.0 - - color-convert@2.0.1: - dependencies: - color-name: 1.1.4 - - color-name@1.1.4: {} - - emoji-regex@8.0.0: {} - - escalade@3.2.0: {} - - get-caller-file@2.0.5: {} - - is-fullwidth-code-point@3.0.0: {} - - oxlint@1.28.0: - optionalDependencies: - '@oxlint/darwin-arm64': 1.28.0 - '@oxlint/darwin-x64': 1.28.0 - '@oxlint/linux-arm64-gnu': 1.28.0 - '@oxlint/linux-arm64-musl': 1.28.0 - '@oxlint/linux-x64-gnu': 1.28.0 - '@oxlint/linux-x64-musl': 1.28.0 - '@oxlint/win32-arm64': 1.28.0 - '@oxlint/win32-x64': 1.28.0 - - require-directory@2.1.1: {} - - string-width@4.2.3: - dependencies: - emoji-regex: 8.0.0 - is-fullwidth-code-point: 3.0.0 - strip-ansi: 6.0.1 - - strip-ansi@6.0.1: - dependencies: - ansi-regex: 5.0.1 - - typescript@5.9.3: {} - - undici-types@6.21.0: {} - - wrap-ansi@7.0.0: - dependencies: - ansi-styles: 4.3.0 - string-width: 4.2.3 - strip-ansi: 6.0.1 - - y18n@5.0.8: {} - - yargs-parser@21.1.1: {} - - yargs@17.7.2: - dependencies: - cliui: 8.0.1 - escalade: 3.2.0 - get-caller-file: 2.0.5 - require-directory: 2.1.1 - string-width: 4.2.3 - y18n: 5.0.8 - yargs-parser: 21.1.1 - - zod@3.25.76: {} diff --git a/pypi/socket-patch/pyproject.toml b/pypi/socket-patch/pyproject.toml new file mode 100644 index 0000000..8b2d70c --- /dev/null +++ b/pypi/socket-patch/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.build_meta" + +[project] +name = "socket-patch" +version = "2.1.4" +description = "CLI tool for applying security patches to dependencies" +readme = "README.md" +license = "MIT" +requires-python = ">=3.8" +authors = [ + { name = "Socket Security" } +] +keywords = ["security", "patch", "cli", "dependencies"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Topic :: Security", + "Topic :: Software Development :: Build Tools", +] + +[project.urls] +Homepage = "https://github.com/SocketDev/socket-patch" +Repository = "https://github.com/SocketDev/socket-patch" + +[project.scripts] +socket-patch = "socket_patch:main" + +[tool.setuptools.package-data] +socket_patch = ["bin/*"] diff --git a/pypi/socket-patch/socket_patch/__init__.py b/pypi/socket-patch/socket_patch/__init__.py new file mode 100644 index 0000000..bfcb9d2 --- /dev/null +++ b/pypi/socket-patch/socket_patch/__init__.py @@ -0,0 +1,22 @@ +import os +import sys +import subprocess + + +def main(): + bin_dir = os.path.join(os.path.dirname(__file__), "bin") + try: + entries = os.listdir(bin_dir) + except OSError: + entries = [] + bins = [e for e in entries if e.startswith("socket-patch")] + if len(bins) != 1: + print( + f"Expected exactly one socket-patch binary in {bin_dir}, found {len(bins)}", + file=sys.stderr, + ) + sys.exit(1) + bin_path = os.path.join(bin_dir, bins[0]) + if not os.access(bin_path, os.X_OK): + os.chmod(bin_path, os.stat(bin_path).st_mode | 0o111) + raise SystemExit(subprocess.call([bin_path] + sys.argv[1:])) diff --git a/pypi/socket-patch/socket_patch/bin/.gitkeep b/pypi/socket-patch/socket_patch/bin/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/pypi/socket-patch/test_dispatch.py b/pypi/socket-patch/test_dispatch.py new file mode 100644 index 0000000..07380f4 --- /dev/null +++ b/pypi/socket-patch/test_dispatch.py @@ -0,0 +1,128 @@ +import ast +import os +import stat +import sys +import 
tempfile +import textwrap +import unittest +from pathlib import Path +from unittest import mock + +# Import the module source for inspection +INIT_PATH = Path(__file__).parent / "socket_patch" / "__init__.py" +INIT_SRC = INIT_PATH.read_text() + + +class TestInitModule(unittest.TestCase): + """Test that __init__.py correctly finds and runs the single binary.""" + + def test_source_parses(self): + """Verify __init__.py is valid Python.""" + ast.parse(INIT_SRC) + + def test_main_defined(self): + """Verify main() function exists.""" + tree = ast.parse(INIT_SRC) + func_names = [ + node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef) + ] + self.assertIn("main", func_names) + + def test_dispatches_single_binary(self): + """main() should find the single binary in bin/ and call it.""" + with tempfile.TemporaryDirectory() as tmpdir: + bin_dir = os.path.join(tmpdir, "bin") + os.makedirs(bin_dir) + fake_bin = os.path.join(bin_dir, "socket-patch-test") + Path(fake_bin).write_text("#!/bin/sh\nexit 0\n") + os.chmod(fake_bin, os.stat(fake_bin).st_mode | stat.S_IEXEC) + + with mock.patch("socket_patch.os.path.dirname", return_value=tmpdir): + with mock.patch("socket_patch.subprocess.call", return_value=42) as mock_call: + with self.assertRaises(SystemExit) as cm: + import socket_patch + + socket_patch.main() + self.assertEqual(cm.exception.code, 42) + mock_call.assert_called_once() + called_args = mock_call.call_args[0][0] + self.assertEqual(called_args[0], fake_bin) + + def test_errors_on_no_binary(self): + """main() should exit with error if no binary found.""" + with tempfile.TemporaryDirectory() as tmpdir: + bin_dir = os.path.join(tmpdir, "bin") + os.makedirs(bin_dir) + + with mock.patch("socket_patch.os.path.dirname", return_value=tmpdir): + with self.assertRaises(SystemExit) as cm: + import socket_patch + + socket_patch.main() + self.assertEqual(cm.exception.code, 1) + + def test_errors_on_multiple_binaries(self): + """main() should exit with error if multiple binaries found.""" + with tempfile.TemporaryDirectory() as tmpdir: + bin_dir = os.path.join(tmpdir, "bin") + os.makedirs(bin_dir) + Path(os.path.join(bin_dir, "socket-patch-a")).touch() + Path(os.path.join(bin_dir, "socket-patch-b")).touch() + + with mock.patch("socket_patch.os.path.dirname", return_value=tmpdir): + with self.assertRaises(SystemExit) as cm: + import socket_patch + + socket_patch.main() + self.assertEqual(cm.exception.code, 1) + + def test_errors_on_missing_bin_dir(self): + """main() should exit with error if bin/ dir doesn't exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Don't create bin_dir + with mock.patch("socket_patch.os.path.dirname", return_value=tmpdir): + with self.assertRaises(SystemExit) as cm: + import socket_patch + + socket_patch.main() + self.assertEqual(cm.exception.code, 1) + + +class TestWheelBuilder(unittest.TestCase): + """Test the wheel builder script configuration.""" + + def test_wheel_builder_exists(self): + """Verify the wheel builder script exists.""" + script_path = Path(__file__).parent.parent.parent / "scripts" / "build-pypi-wheels.py" + self.assertTrue(script_path.exists(), f"Wheel builder script not found at {script_path}") + + def test_wheel_builder_parses(self): + """Verify the wheel builder script is valid Python.""" + script_path = Path(__file__).parent.parent.parent / "scripts" / "build-pypi-wheels.py" + ast.parse(script_path.read_text()) + + def test_wheel_builder_targets(self): + """Verify the wheel builder covers all expected targets.""" + script_path = 
Path(__file__).parent.parent.parent / "scripts" / "build-pypi-wheels.py" + src = script_path.read_text() + + expected_targets = [ + "aarch64-apple-darwin", + "x86_64-apple-darwin", + "x86_64-unknown-linux-musl", + "aarch64-unknown-linux-gnu", + "arm-unknown-linux-gnueabihf", + "i686-unknown-linux-gnu", + "x86_64-pc-windows-msvc", + "i686-pc-windows-msvc", + "aarch64-pc-windows-msvc", + ] + for target in expected_targets: + self.assertIn(target, src, f"Target {target} missing from wheel builder") + + # Android should NOT be in the targets + self.assertNotIn('"aarch64-linux-android"', src) + + +if __name__ == "__main__": + unittest.main() diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..292fe49 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "stable" diff --git a/scripts/build-pypi-wheels.py b/scripts/build-pypi-wheels.py new file mode 100755 index 0000000..ab9f03b --- /dev/null +++ b/scripts/build-pypi-wheels.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +"""Build platform-tagged PyPI wheels for socket-patch. + +Each wheel contains only the binary for a single platform, so users download +only the ~4 MB they need instead of ~40 MB for all platforms. +""" + +import argparse +import csv +import hashlib +import io +import os +import re +import stat +import subprocess +import sys +import tempfile +import zipfile +from base64 import urlsafe_b64encode +from pathlib import Path + +# Mapping from Rust target triple to: +# (wheel platform tag(s), archive extension, binary name inside archive) +# Android is omitted — no standard PyPI platform tag exists for it. +TARGETS = { + "aarch64-apple-darwin": { + "platform_tag": "macosx_11_0_arm64", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "x86_64-apple-darwin": { + "platform_tag": "macosx_10_12_x86_64", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "x86_64-unknown-linux-gnu": { + "platform_tag": "manylinux_2_17_x86_64.manylinux2014_x86_64", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "x86_64-unknown-linux-musl": { + "platform_tag": "musllinux_1_1_x86_64", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "aarch64-unknown-linux-gnu": { + "platform_tag": "manylinux_2_17_aarch64.manylinux2014_aarch64", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "aarch64-unknown-linux-musl": { + "platform_tag": "musllinux_1_1_aarch64", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "arm-unknown-linux-gnueabihf": { + "platform_tag": "manylinux_2_17_armv7l.manylinux2014_armv7l", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "arm-unknown-linux-musleabihf": { + "platform_tag": "musllinux_1_1_armv7l", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "i686-unknown-linux-gnu": { + "platform_tag": "manylinux_2_17_i686.manylinux2014_i686", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "i686-unknown-linux-musl": { + "platform_tag": "musllinux_1_1_i686", + "archive_ext": "tar.gz", + "binary_name": "socket-patch", + }, + "x86_64-pc-windows-msvc": { + "platform_tag": "win_amd64", + "archive_ext": "zip", + "binary_name": "socket-patch.exe", + }, + "i686-pc-windows-msvc": { + "platform_tag": "win32", + "archive_ext": "zip", + "binary_name": "socket-patch.exe", + }, + "aarch64-pc-windows-msvc": { + "platform_tag": "win_arm64", + "archive_ext": "zip", + "binary_name": "socket-patch.exe", + }, +} + +DIST_NAME = "socket_patch" +PKG_NAME = 
"socket-patch" + + +def sha256_digest(data: bytes) -> str: + """Return the URL-safe base64 SHA-256 digest for RECORD.""" + h = hashlib.sha256(data) + return "sha256=" + urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=") + + +def extract_binary(artifacts_dir: Path, target: str, info: dict) -> bytes: + """Extract the binary from the artifact archive and return its contents.""" + ext = info["archive_ext"] + archive_path = artifacts_dir / f"socket-patch-{target}.{ext}" + if not archive_path.exists(): + raise FileNotFoundError(f"Artifact not found: {archive_path}") + + binary_name = info["binary_name"] + if ext == "tar.gz": + import tarfile + + with tarfile.open(archive_path, "r:gz") as tf: + member = tf.getmember(binary_name) + f = tf.extractfile(member) + if f is None: + raise ValueError(f"Could not extract {binary_name} from {archive_path}") + return f.read() + elif ext == "zip": + with zipfile.ZipFile(archive_path, "r") as zf: + return zf.read(binary_name) + else: + raise ValueError(f"Unknown archive extension: {ext}") + + +def read_pyproject_metadata(pyproject_dir: Path) -> dict: + """Read metadata fields from pyproject.toml (simple parser, no toml dep).""" + pyproject_path = pyproject_dir / "pyproject.toml" + text = pyproject_path.read_text() + + def extract_field(name: str) -> str: + m = re.search(rf'^{name}\s*=\s*"(.*?)"', text, re.MULTILINE) + if not m: + raise ValueError(f"Could not find {name} in {pyproject_path}") + return m.group(1) + + readme_path = pyproject_dir / "README.md" + readme = readme_path.read_text() if readme_path.exists() else "" + + return { + "name": extract_field("name"), + "version": extract_field("version"), + "description": extract_field("description"), + "license": extract_field("license"), + "requires_python": extract_field("requires-python"), + "readme": readme, + } + + +def read_init_py(pyproject_dir: Path) -> bytes: + """Read the __init__.py file for inclusion in wheels.""" + init_path = pyproject_dir / "socket_patch" / "__init__.py" + return init_path.read_bytes() + + +def build_wheel( + target: str, + info: dict, + version: str, + metadata: dict, + init_py: bytes, + binary_data: bytes, + dist_dir: Path, +) -> Path: + """Build a single platform-tagged wheel and return the path.""" + platform_tag = info["platform_tag"] + binary_name = info["binary_name"] + + # Wheel filename: {name}-{version}-{python tag}-{abi tag}-{platform tag}.whl + wheel_name = f"{DIST_NAME}-{version}-py3-none-{platform_tag}.whl" + wheel_path = dist_dir / wheel_name + + dist_info = f"{DIST_NAME}-{version}.dist-info" + + # Build file entries: (archive_name, data, is_executable) + files = [] + + # __init__.py + files.append((f"{DIST_NAME}/__init__.py", init_py, False)) + + # Binary + files.append((f"{DIST_NAME}/bin/{binary_name}", binary_data, True)) + + # METADATA + metadata_header = ( + f"Metadata-Version: 2.1\n" + f"Name: {metadata['name']}\n" + f"Version: {version}\n" + f"Summary: {metadata['description']}\n" + f"License: {metadata['license']}\n" + f"Requires-Python: {metadata['requires_python']}\n" + ) + if metadata.get("readme"): + metadata_header += "Description-Content-Type: text/markdown\n" + metadata_header += f"\n{metadata['readme']}" + metadata_content = metadata_header.encode() + files.append((f"{dist_info}/METADATA", metadata_content, False)) + + # WHEEL + wheel_content = ( + f"Wheel-Version: 1.0\n" + f"Generator: build-pypi-wheels.py\n" + f"Root-Is-Purelib: false\n" + f"Tag: py3-none-{platform_tag}\n" + ).encode() + files.append((f"{dist_info}/WHEEL", wheel_content, 
False)) + + # entry_points.txt + entry_points_content = ( + "[console_scripts]\n" "socket-patch = socket_patch:main\n" + ).encode() + files.append((f"{dist_info}/entry_points.txt", entry_points_content, False)) + + # Build RECORD (must be last, references all other files) + record_lines = [] + for name, data, _ in files: + record_lines.append(f"{name},{sha256_digest(data)},{len(data)}") + # RECORD itself has no hash + record_name = f"{dist_info}/RECORD" + record_lines.append(f"{record_name},,") + record_content = "\n".join(record_lines).encode() + files.append((record_name, record_content, False)) + + # Write the zip + with zipfile.ZipFile(wheel_path, "w", zipfile.ZIP_DEFLATED) as zf: + for name, data, is_exec in files: + info_obj = zipfile.ZipInfo(name) + # Set external_attr for executable files (unix permissions) + if is_exec: + info_obj.external_attr = (stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH) << 16 + else: + info_obj.external_attr = (stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) << 16 + info_obj.compress_type = zipfile.ZIP_DEFLATED + zf.writestr(info_obj, data) + + return wheel_path + + +def main(): + parser = argparse.ArgumentParser( + description="Build platform-tagged PyPI wheels for socket-patch" + ) + parser.add_argument( + "--version", + required=True, + help="Package version (e.g., 1.5.0)", + ) + parser.add_argument( + "--artifacts", + required=True, + help="Directory containing build artifacts", + ) + parser.add_argument( + "--dist", + default="dist", + help="Output directory for wheels (default: dist)", + ) + parser.add_argument( + "--pyproject-dir", + default=None, + help="Directory containing pyproject.toml (default: pypi/socket-patch relative to script)", + ) + args = parser.parse_args() + + artifacts_dir = Path(args.artifacts) + dist_dir = Path(args.dist) + dist_dir.mkdir(parents=True, exist_ok=True) + + if args.pyproject_dir: + pyproject_dir = Path(args.pyproject_dir) + else: + pyproject_dir = Path(__file__).resolve().parent.parent / "pypi" / "socket-patch" + + metadata = read_pyproject_metadata(pyproject_dir) + init_py = read_init_py(pyproject_dir) + + built = [] + skipped = [] + + for target, info in TARGETS.items(): + archive_ext = info["archive_ext"] + archive_path = artifacts_dir / f"socket-patch-{target}.{archive_ext}" + if not archive_path.exists(): + skipped.append(target) + continue + + print(f"Building wheel for {target} ({info['platform_tag']})...") + binary_data = extract_binary(artifacts_dir, target, info) + wheel_path = build_wheel( + target=target, + info=info, + version=args.version, + metadata=metadata, + init_py=init_py, + binary_data=binary_data, + dist_dir=dist_dir, + ) + size_mb = wheel_path.stat().st_size / (1024 * 1024) + print(f" -> {wheel_path.name} ({size_mb:.1f} MB)") + built.append(wheel_path) + + print(f"\nBuilt {len(built)} wheel(s) in {dist_dir}/") + if skipped: + print(f"Skipped {len(skipped)} target(s) (artifact not found): {', '.join(skipped)}") + + if not built: + print("ERROR: No wheels were built!", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/install.sh b/scripts/install.sh new file mode 100755 index 0000000..41a1510 --- /dev/null +++ b/scripts/install.sh @@ -0,0 +1,100 @@ +#!/bin/sh +set -eu + +# Socket Patch installer +# Usage: curl -fsSL https://raw.githubusercontent.com/SocketDev/socket-patch/main/scripts/install.sh | sh + +REPO="SocketDev/socket-patch" +BINARY="socket-patch" + +# Detect platform +OS="$(uname -s)" +ARCH="$(uname -m)" + 
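+# Map the uname OS/arch pair to the Rust target triple used in release asset
+# names (for example, Darwin/arm64 resolves to aarch64-apple-darwin); on Linux
+# a libc probe below picks the musl or glibc build.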
+case "$OS" in + Darwin) + case "$ARCH" in + arm64) TARGET="aarch64-apple-darwin" ;; + x86_64) TARGET="x86_64-apple-darwin" ;; + *) echo "Error: unsupported architecture: $ARCH" >&2; exit 1 ;; + esac + ;; + Linux) + # Detect libc: musl or glibc + detect_libc() { + if ldd --version 2>&1 | grep -qi musl; then + echo "musl" + elif [ -e /lib/ld-musl-*.so.1 ] 2>/dev/null; then + echo "musl" + else + echo "gnu" + fi + } + LIBC="$(detect_libc)" + case "$ARCH" in + x86_64) + if [ "$LIBC" = "musl" ]; then TARGET="x86_64-unknown-linux-musl" + else TARGET="x86_64-unknown-linux-gnu"; fi ;; + aarch64) + if [ "$LIBC" = "musl" ]; then TARGET="aarch64-unknown-linux-musl" + else TARGET="aarch64-unknown-linux-gnu"; fi ;; + armv7l) + if [ "$LIBC" = "musl" ]; then TARGET="arm-unknown-linux-musleabihf" + else TARGET="arm-unknown-linux-gnueabihf"; fi ;; + i686) + if [ "$LIBC" = "musl" ]; then TARGET="i686-unknown-linux-musl" + else TARGET="i686-unknown-linux-gnu"; fi ;; + *) echo "Error: unsupported architecture: $ARCH" >&2; exit 1 ;; + esac + ;; + *) + echo "Error: unsupported OS: $OS" >&2 + exit 1 + ;; +esac + +# Detect downloader +if command -v curl >/dev/null 2>&1; then + download() { curl -fsSL -o "$1" "$2"; } +elif command -v wget >/dev/null 2>&1; then + download() { wget -qO "$1" "$2"; } +else + echo "Error: curl or wget is required" >&2 + exit 1 +fi + +# Pick install directory +if [ -w /usr/local/bin ]; then + INSTALL_DIR="/usr/local/bin" +else + INSTALL_DIR="${HOME}/.local/bin" + mkdir -p "$INSTALL_DIR" +fi + +# Create temp dir with cleanup +TMPDIR="$(mktemp -d)" +trap 'rm -rf "$TMPDIR"' EXIT + +# Download and extract +URL="https://github.com/${REPO}/releases/latest/download/${BINARY}-${TARGET}.tar.gz" +echo "Downloading ${BINARY} for ${TARGET}..." +download "$TMPDIR/${BINARY}.tar.gz" "$URL" +tar xzf "$TMPDIR/${BINARY}.tar.gz" -C "$TMPDIR" + +# Install +install -m 755 "$TMPDIR/${BINARY}" "${INSTALL_DIR}/${BINARY}" +echo "Installed ${BINARY} to ${INSTALL_DIR}/${BINARY}" + +# Print version +"${INSTALL_DIR}/${BINARY}" --version 2>/dev/null || true + +# Warn if not on PATH +case ":${PATH}:" in + *":${INSTALL_DIR}:"*) ;; + *) + echo "" + echo "Warning: ${INSTALL_DIR} is not on your PATH." + echo "Add it with:" + echo " export PATH=\"${INSTALL_DIR}:\$PATH\"" + ;; +esac diff --git a/scripts/version-sync.sh b/scripts/version-sync.sh new file mode 100755 index 0000000..f140612 --- /dev/null +++ b/scripts/version-sync.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +set -euo pipefail + +VERSION="${1:?Usage: version-sync.sh }" + +REPO_ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" + +# Update workspace Cargo.toml version +sed -i.bak "s/^version = \".*\"/version = \"$VERSION\"/" "$REPO_ROOT/Cargo.toml" +rm -f "$REPO_ROOT/Cargo.toml.bak" + +# Update socket-patch-core workspace dependency version (needed for cargo publish) +sed -i.bak "s/socket-patch-core = { path = \"crates\/socket-patch-core\", version = \".*\" }/socket-patch-core = { path = \"crates\/socket-patch-core\", version = \"$VERSION\" }/" "$REPO_ROOT/Cargo.toml" +rm -f "$REPO_ROOT/Cargo.toml.bak" + +# Update npm main package version and optionalDependencies versions +pkg_json="$REPO_ROOT/npm/socket-patch/package.json" +node -e " + const fs = require('fs'); + const pkg = JSON.parse(fs.readFileSync('$pkg_json', 'utf8')); + pkg.version = '$VERSION'; + if (pkg.optionalDependencies) { + for (const dep of Object.keys(pkg.optionalDependencies)) { + pkg.optionalDependencies[dep] = '$VERSION'; + } + } + fs.writeFileSync('$pkg_json', JSON.stringify(pkg, null, 2) + '\n'); +" + +# Update all per-platform npm package versions +for platform_dir in "$REPO_ROOT"/npm/socket-patch-*/; do + platform_pkg="$platform_dir/package.json" + if [ -f "$platform_pkg" ]; then + node -e " + const fs = require('fs'); + const pkg = JSON.parse(fs.readFileSync('$platform_pkg', 'utf8')); + pkg.version = '$VERSION'; + fs.writeFileSync('$platform_pkg', JSON.stringify(pkg, null, 2) + '\n'); + " + fi +done + +# Update PyPI package version +pyproject="$REPO_ROOT/pypi/socket-patch/pyproject.toml" +sed -i.bak "s/^version = \".*\"/version = \"$VERSION\"/" "$pyproject" +rm -f "$pyproject.bak" + +echo "Synced version to $VERSION" diff --git a/src/cli.ts b/src/cli.ts deleted file mode 100644 index e56181a..0000000 --- a/src/cli.ts +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env node - -import yargs from 'yargs' -import { hideBin } from 'yargs/helpers' -import { applyCommand } from './commands/apply.js' -import { downloadCommand } from './commands/download.js' -import { listCommand } from './commands/list.js' -import { removeCommand } from './commands/remove.js' -import { gcCommand } from './commands/gc.js' -import { setupCommand } from './commands/setup.js' - -async function main(): Promise { - await yargs(hideBin(process.argv)) - .scriptName('socket-patch') - .usage('$0 [options]') - .command(applyCommand) - .command(setupCommand) - .command(downloadCommand) - .command(listCommand) - .command(removeCommand) - .command(gcCommand) - .demandCommand(1, 'You must specify a command') - .help() - .alias('h', 'help') - .version() - .alias('v', 'version') - .strict() - .parse() -} - -main().catch((error: Error) => { - console.error('Error:', error.message) - process.exit(1) -}) diff --git a/src/commands/apply.ts b/src/commands/apply.ts deleted file mode 100644 index 6db7bbc..0000000 --- a/src/commands/apply.ts +++ /dev/null @@ -1,201 +0,0 @@ -import * as fs from 'fs/promises' -import * as path from 'path' -import type { CommandModule } from 'yargs' -import { - PatchManifestSchema, - DEFAULT_PATCH_MANIFEST_PATH, -} from '../schema/manifest-schema.js' -import { - findNodeModules, - findPackagesForPatches, - applyPackagePatch, -} from '../patch/apply.js' -import type { ApplyResult } from '../patch/apply.js' -import { - cleanupUnusedBlobs, - formatCleanupResult, -} from '../utils/cleanup-blobs.js' - -interface ApplyArgs { - cwd: string - 'dry-run': boolean - silent: boolean - 'manifest-path': string -} - -async function applyPatches( - cwd: string, - manifestPath: string, - dryRun: boolean, - silent: boolean, -): Promise<{ success: boolean; results: 
ApplyResult[] }> { - // Read and parse manifest - const manifestContent = await fs.readFile(manifestPath, 'utf-8') - const manifestData = JSON.parse(manifestContent) - const manifest = PatchManifestSchema.parse(manifestData) - - // Find .socket directory (contains blobs) - const socketDir = path.dirname(manifestPath) - const blobsPath = path.join(socketDir, 'blobs') - - // Verify blobs directory exists - try { - await fs.access(blobsPath) - } catch { - throw new Error(`Blobs directory not found at ${blobsPath}`) - } - - // Find all node_modules directories - const nodeModulesPaths = await findNodeModules(cwd) - - if (nodeModulesPaths.length === 0) { - if (!silent) { - console.error('No node_modules directories found') - } - return { success: false, results: [] } - } - - // Find all packages that need patching - const allPackages = new Map() - for (const nmPath of nodeModulesPaths) { - const packages = await findPackagesForPatches(nmPath, manifest) - for (const [purl, location] of packages) { - if (!allPackages.has(purl)) { - allPackages.set(purl, location.path) - } - } - } - - if (allPackages.size === 0) { - if (!silent) { - console.log('No packages found that match available patches') - } - return { success: true, results: [] } - } - - // Apply patches to each package - const results: ApplyResult[] = [] - let hasErrors = false - - for (const [purl, pkgPath] of allPackages) { - const patch = manifest.patches[purl] - if (!patch) continue - - const result = await applyPackagePatch( - purl, - pkgPath, - patch.files, - blobsPath, - dryRun, - ) - - results.push(result) - - if (!result.success) { - hasErrors = true - if (!silent) { - console.error(`Failed to patch ${purl}: ${result.error}`) - } - } - } - - // Clean up unused blobs after applying patches - if (!silent) { - const cleanupResult = await cleanupUnusedBlobs(manifest, blobsPath, dryRun) - if (cleanupResult.blobsRemoved > 0) { - console.log(`\n${formatCleanupResult(cleanupResult, dryRun)}`) - } - } - - return { success: !hasErrors, results } -} - -export const applyCommand: CommandModule<{}, ApplyArgs> = { - command: 'apply', - describe: 'Apply security patches to dependencies', - builder: yargs => { - return yargs - .option('cwd', { - describe: 'Working directory', - type: 'string', - default: process.cwd(), - }) - .option('dry-run', { - alias: 'd', - describe: 'Verify patches can be applied without modifying files', - type: 'boolean', - default: false, - }) - .option('silent', { - alias: 's', - describe: 'Only output errors', - type: 'boolean', - default: false, - }) - .option('manifest-path', { - alias: 'm', - describe: 'Path to patch manifest file', - type: 'string', - default: DEFAULT_PATCH_MANIFEST_PATH, - }) - }, - handler: async argv => { - try { - const manifestPath = path.isAbsolute(argv['manifest-path']) - ? 
argv['manifest-path'] - : path.join(argv.cwd, argv['manifest-path']) - - // Check if manifest exists - try { - await fs.access(manifestPath) - } catch { - if (!argv.silent) { - console.error(`Manifest not found at ${manifestPath}`) - } - process.exit(1) - } - - const { success, results } = await applyPatches( - argv.cwd, - manifestPath, - argv['dry-run'], - argv.silent, - ) - - // Print results if not silent - if (!argv.silent && results.length > 0) { - const patched = results.filter(r => r.success) - const alreadyPatched = results.filter(r => - r.filesVerified.every(f => f.status === 'already-patched'), - ) - - if (argv['dry-run']) { - console.log(`\nPatch verification complete:`) - console.log(` ${patched.length} package(s) can be patched`) - if (alreadyPatched.length > 0) { - console.log(` ${alreadyPatched.length} package(s) already patched`) - } - } else { - console.log(`\nPatched packages:`) - for (const result of patched) { - if (result.filesPatched.length > 0) { - console.log(` ${result.packageKey}`) - } else if ( - result.filesVerified.every(f => f.status === 'already-patched') - ) { - console.log(` ${result.packageKey} (already patched)`) - } - } - } - } - - process.exit(success ? 0 : 1) - } catch (err) { - if (!argv.silent) { - const errorMessage = err instanceof Error ? err.message : String(err) - console.error(`Error: ${errorMessage}`) - } - process.exit(1) - } - }, -} diff --git a/src/commands/download.ts b/src/commands/download.ts deleted file mode 100644 index 62ff352..0000000 --- a/src/commands/download.ts +++ /dev/null @@ -1,165 +0,0 @@ -import * as fs from 'fs/promises' -import * as path from 'path' -import type { CommandModule } from 'yargs' -import { PatchManifestSchema } from '../schema/manifest-schema.js' -import { getAPIClientFromEnv } from '../utils/api-client.js' -import { - cleanupUnusedBlobs, - formatCleanupResult, -} from '../utils/cleanup-blobs.js' - -interface DownloadArgs { - uuid: string - org: string - cwd: string - 'api-url'?: string - 'api-token'?: string -} - -async function downloadPatch( - uuid: string, - orgSlug: string, - cwd: string, - apiUrl?: string, - apiToken?: string, -): Promise { - // Override environment variables if CLI options are provided - if (apiUrl) { - process.env.SOCKET_API_URL = apiUrl - } - if (apiToken) { - process.env.SOCKET_API_TOKEN = apiToken - } - - // Get API client (will use env vars if not overridden) - const apiClient = getAPIClientFromEnv() - - console.log(`Fetching patch ${uuid} from ${orgSlug}...`) - - // Fetch patch from API - const patch = await apiClient.fetchPatch(orgSlug, uuid) - - if (!patch) { - throw new Error(`Patch with UUID ${uuid} not found`) - } - - console.log(`Downloaded patch for ${patch.purl}`) - - // Prepare .socket directory - const socketDir = path.join(cwd, '.socket') - const blobsDir = path.join(socketDir, 'blobs') - const manifestPath = path.join(socketDir, 'manifest.json') - - // Create directories - await fs.mkdir(socketDir, { recursive: true }) - await fs.mkdir(blobsDir, { recursive: true }) - - // Read existing manifest or create new one - let manifest: any - try { - const manifestContent = await fs.readFile(manifestPath, 'utf-8') - manifest = PatchManifestSchema.parse(JSON.parse(manifestContent)) - } catch { - // Create new manifest - manifest = { patches: {} } - } - - // Save blob contents - const files: Record = {} - for (const [filePath, fileInfo] of Object.entries(patch.files)) { - if (fileInfo.afterHash) { - files[filePath] = { - beforeHash: fileInfo.beforeHash, - afterHash: 
fileInfo.afterHash, - } - } - - // Save blob content if provided - if (fileInfo.blobContent && fileInfo.afterHash) { - const blobPath = path.join(blobsDir, fileInfo.afterHash) - const blobBuffer = Buffer.from(fileInfo.blobContent, 'base64') - await fs.writeFile(blobPath, blobBuffer) - console.log(` Saved blob: ${fileInfo.afterHash}`) - } - } - - // Add/update patch in manifest - manifest.patches[patch.purl] = { - uuid: patch.uuid, - exportedAt: patch.publishedAt, - files, - vulnerabilities: patch.vulnerabilities, - description: patch.description, - license: patch.license, - tier: patch.tier, - } - - // Write updated manifest - await fs.writeFile( - manifestPath, - JSON.stringify(manifest, null, 2) + '\n', - 'utf-8', - ) - - console.log(`\nPatch saved to ${manifestPath}`) - console.log(` PURL: ${patch.purl}`) - console.log(` UUID: ${patch.uuid}`) - console.log(` Files: ${Object.keys(files).length}`) - console.log(` Vulnerabilities: ${Object.keys(patch.vulnerabilities).length}`) - - // Clean up unused blobs - const cleanupResult = await cleanupUnusedBlobs(manifest, blobsDir, false) - if (cleanupResult.blobsRemoved > 0) { - console.log(`\n${formatCleanupResult(cleanupResult, false)}`) - } - - return true -} - -export const downloadCommand: CommandModule<{}, DownloadArgs> = { - command: 'download', - describe: 'Download a security patch from Socket API', - builder: yargs => { - return yargs - .option('uuid', { - describe: 'Patch UUID to download', - type: 'string', - demandOption: true, - }) - .option('org', { - describe: 'Organization slug', - type: 'string', - demandOption: true, - }) - .option('cwd', { - describe: 'Working directory', - type: 'string', - default: process.cwd(), - }) - .option('api-url', { - describe: 'Socket API URL (overrides SOCKET_API_URL env var)', - type: 'string', - }) - .option('api-token', { - describe: 'Socket API token (overrides SOCKET_API_TOKEN env var)', - type: 'string', - }) - }, - handler: async argv => { - try { - await downloadPatch( - argv.uuid, - argv.org, - argv.cwd, - argv['api-url'], - argv['api-token'], - ) - - process.exit(0) - } catch (err) { - const errorMessage = err instanceof Error ? 
err.message : String(err) - console.error(`Error: ${errorMessage}`) - process.exit(1) - } - }, -} diff --git a/src/commands/gc.ts b/src/commands/gc.ts deleted file mode 100644 index 2f56e7a..0000000 --- a/src/commands/gc.ts +++ /dev/null @@ -1,96 +0,0 @@ -import * as fs from 'fs/promises' -import * as path from 'path' -import type { CommandModule } from 'yargs' -import { - PatchManifestSchema, - DEFAULT_PATCH_MANIFEST_PATH, -} from '../schema/manifest-schema.js' -import { - cleanupUnusedBlobs, - formatCleanupResult, -} from '../utils/cleanup-blobs.js' - -interface GCArgs { - cwd: string - 'manifest-path': string - 'dry-run': boolean -} - -async function garbageCollect( - manifestPath: string, - dryRun: boolean, -): Promise { - // Read and parse manifest - const manifestContent = await fs.readFile(manifestPath, 'utf-8') - const manifestData = JSON.parse(manifestContent) - const manifest = PatchManifestSchema.parse(manifestData) - - // Find .socket directory (contains blobs) - const socketDir = path.dirname(manifestPath) - const blobsPath = path.join(socketDir, 'blobs') - - // Run cleanup - const cleanupResult = await cleanupUnusedBlobs(manifest, blobsPath, dryRun) - - // Display results - if (cleanupResult.blobsChecked === 0) { - console.log('No blobs directory found, nothing to clean up.') - } else if (cleanupResult.blobsRemoved === 0) { - console.log( - `Checked ${cleanupResult.blobsChecked} blob(s), all are in use.`, - ) - } else { - console.log(formatCleanupResult(cleanupResult, dryRun)) - - if (!dryRun) { - console.log('\nGarbage collection complete.') - } - } -} - -export const gcCommand: CommandModule<{}, GCArgs> = { - command: 'gc', - describe: 'Clean up unused blob files from .socket/blobs directory', - builder: yargs => { - return yargs - .option('cwd', { - describe: 'Working directory', - type: 'string', - default: process.cwd(), - }) - .option('manifest-path', { - alias: 'm', - describe: 'Path to patch manifest file', - type: 'string', - default: DEFAULT_PATCH_MANIFEST_PATH, - }) - .option('dry-run', { - alias: 'd', - describe: 'Show what would be removed without actually removing', - type: 'boolean', - default: false, - }) - }, - handler: async argv => { - try { - const manifestPath = path.isAbsolute(argv['manifest-path']) - ? argv['manifest-path'] - : path.join(argv.cwd, argv['manifest-path']) - - // Check if manifest exists - try { - await fs.access(manifestPath) - } catch { - console.error(`Manifest not found at ${manifestPath}`) - process.exit(1) - } - - await garbageCollect(manifestPath, argv['dry-run']) - process.exit(0) - } catch (err) { - const errorMessage = err instanceof Error ? 
err.message : String(err) - console.error(`Error: ${errorMessage}`) - process.exit(1) - } - }, -} diff --git a/src/commands/list.ts b/src/commands/list.ts deleted file mode 100644 index 0f876bc..0000000 --- a/src/commands/list.ts +++ /dev/null @@ -1,151 +0,0 @@ -import * as fs from 'fs/promises' -import * as path from 'path' -import type { CommandModule } from 'yargs' -import { - PatchManifestSchema, - DEFAULT_PATCH_MANIFEST_PATH, -} from '../schema/manifest-schema.js' - -interface ListArgs { - cwd: string - 'manifest-path': string - json: boolean -} - -async function listPatches( - manifestPath: string, - outputJson: boolean, -): Promise { - // Read and parse manifest - const manifestContent = await fs.readFile(manifestPath, 'utf-8') - const manifestData = JSON.parse(manifestContent) - const manifest = PatchManifestSchema.parse(manifestData) - - const patchEntries = Object.entries(manifest.patches) - - if (patchEntries.length === 0) { - if (outputJson) { - console.log(JSON.stringify({ patches: [] }, null, 2)) - } else { - console.log('No patches found in manifest.') - } - return - } - - if (outputJson) { - // Output as JSON for machine consumption - const jsonOutput = { - patches: patchEntries.map(([purl, patch]) => ({ - purl, - uuid: patch.uuid, - exportedAt: patch.exportedAt, - tier: patch.tier, - license: patch.license, - description: patch.description, - files: Object.keys(patch.files), - vulnerabilities: Object.entries(patch.vulnerabilities).map( - ([id, vuln]) => ({ - id, - cves: vuln.cves, - summary: vuln.summary, - severity: vuln.severity, - description: vuln.description, - }), - ), - })), - } - console.log(JSON.stringify(jsonOutput, null, 2)) - } else { - // Human-readable output - console.log(`Found ${patchEntries.length} patch(es):\n`) - - for (const [purl, patch] of patchEntries) { - console.log(`Package: ${purl}`) - console.log(` UUID: ${patch.uuid}`) - console.log(` Tier: ${patch.tier}`) - console.log(` License: ${patch.license}`) - console.log(` Exported: ${patch.exportedAt}`) - - if (patch.description) { - console.log(` Description: ${patch.description}`) - } - - // List vulnerabilities - const vulnEntries = Object.entries(patch.vulnerabilities) - if (vulnEntries.length > 0) { - console.log(` Vulnerabilities (${vulnEntries.length}):`) - for (const [id, vuln] of vulnEntries) { - const cveList = vuln.cves.length > 0 ? ` (${vuln.cves.join(', ')})` : '' - console.log(` - ${id}${cveList}`) - console.log(` Severity: ${vuln.severity}`) - console.log(` Summary: ${vuln.summary}`) - } - } - - // List files being patched - const fileList = Object.keys(patch.files) - if (fileList.length > 0) { - console.log(` Files patched (${fileList.length}):`) - for (const filePath of fileList) { - console.log(` - ${filePath}`) - } - } - - console.log('') // Empty line between patches - } - } -} - -export const listCommand: CommandModule<{}, ListArgs> = { - command: 'list', - describe: 'List all patches in the local manifest', - builder: yargs => { - return yargs - .option('cwd', { - describe: 'Working directory', - type: 'string', - default: process.cwd(), - }) - .option('manifest-path', { - alias: 'm', - describe: 'Path to patch manifest file', - type: 'string', - default: DEFAULT_PATCH_MANIFEST_PATH, - }) - .option('json', { - describe: 'Output as JSON', - type: 'boolean', - default: false, - }) - }, - handler: async argv => { - try { - const manifestPath = path.isAbsolute(argv['manifest-path']) - ? 
argv['manifest-path']
-        : path.join(argv.cwd, argv['manifest-path'])
-
-      // Check if manifest exists
-      try {
-        await fs.access(manifestPath)
-      } catch {
-        if (argv.json) {
-          console.log(JSON.stringify({ error: 'Manifest not found', path: manifestPath }, null, 2))
-        } else {
-          console.error(`Manifest not found at ${manifestPath}`)
-        }
-        process.exit(1)
-      }
-
-      await listPatches(manifestPath, argv.json)
-      process.exit(0)
-    } catch (err) {
-      const errorMessage = err instanceof Error ? err.message : String(err)
-      if (argv.json) {
-        console.log(JSON.stringify({ error: errorMessage }, null, 2))
-      } else {
-        console.error(`Error: ${errorMessage}`)
-      }
-      process.exit(1)
-    }
-  },
-}
diff --git a/src/commands/remove.ts b/src/commands/remove.ts
deleted file mode 100644
index 4be9b7b..0000000
--- a/src/commands/remove.ts
+++ /dev/null
@@ -1,131 +0,0 @@
-import * as fs from 'fs/promises'
-import * as path from 'path'
-import type { CommandModule } from 'yargs'
-import {
-  PatchManifestSchema,
-  DEFAULT_PATCH_MANIFEST_PATH,
-  type PatchManifest,
-} from '../schema/manifest-schema.js'
-import {
-  cleanupUnusedBlobs,
-  formatCleanupResult,
-} from '../utils/cleanup-blobs.js'
-
-interface RemoveArgs {
-  identifier: string
-  cwd: string
-  'manifest-path': string
-}
-
-async function removePatch(
-  identifier: string,
-  manifestPath: string,
-): Promise<{ removed: string[]; notFound: boolean; manifest: PatchManifest }> {
-  // Read and parse manifest
-  const manifestContent = await fs.readFile(manifestPath, 'utf-8')
-  const manifestData = JSON.parse(manifestContent)
-  const manifest = PatchManifestSchema.parse(manifestData)
-
-  const removed: string[] = []
-  let foundMatch = false
-
-  // Check if identifier is a PURL (contains "pkg:")
-  if (identifier.startsWith('pkg:')) {
-    // Remove by PURL
-    if (manifest.patches[identifier]) {
-      removed.push(identifier)
-      delete manifest.patches[identifier]
-      foundMatch = true
-    }
-  } else {
-    // Remove by UUID - search through all patches
-    for (const [purl, patch] of Object.entries(manifest.patches)) {
-      if (patch.uuid === identifier) {
-        removed.push(purl)
-        delete manifest.patches[purl]
-        foundMatch = true
-      }
-    }
-  }
-
-  if (foundMatch) {
-    // Write updated manifest
-    await fs.writeFile(
-      manifestPath,
-      JSON.stringify(manifest, null, 2) + '\n',
-      'utf-8',
-    )
-  }
-
-  return { removed, notFound: !foundMatch, manifest }
-}
-
-export const removeCommand: CommandModule<{}, RemoveArgs> = {
-  command: 'remove <identifier>',
-  describe: 'Remove a patch from the manifest by PURL or UUID',
-  builder: yargs => {
-    return yargs
-      .positional('identifier', {
-        describe: 'Package PURL (e.g., pkg:npm/package@version) or patch UUID',
-        type: 'string',
-        demandOption: true,
-      })
-      .option('cwd', {
-        describe: 'Working directory',
-        type: 'string',
-        default: process.cwd(),
-      })
-      .option('manifest-path', {
-        alias: 'm',
-        describe: 'Path to patch manifest file',
-        type: 'string',
-        default: DEFAULT_PATCH_MANIFEST_PATH,
-      })
-  },
-  handler: async argv => {
-    try {
-      const manifestPath = path.isAbsolute(argv['manifest-path'])
-        ?
argv['manifest-path'] - : path.join(argv.cwd, argv['manifest-path']) - - // Check if manifest exists - try { - await fs.access(manifestPath) - } catch { - console.error(`Manifest not found at ${manifestPath}`) - process.exit(1) - } - - const { removed, notFound, manifest } = await removePatch( - argv.identifier, - manifestPath, - ) - - if (notFound) { - console.error(`No patch found matching identifier: ${argv.identifier}`) - process.exit(1) - } - - console.log(`Removed ${removed.length} patch(es):`) - for (const purl of removed) { - console.log(` - ${purl}`) - } - - console.log(`\nManifest updated at ${manifestPath}`) - - // Clean up unused blobs after removing patches - const socketDir = path.dirname(manifestPath) - const blobsPath = path.join(socketDir, 'blobs') - const cleanupResult = await cleanupUnusedBlobs(manifest, blobsPath, false) - if (cleanupResult.blobsRemoved > 0) { - console.log(`\n${formatCleanupResult(cleanupResult, false)}`) - } - - process.exit(0) - } catch (err) { - const errorMessage = err instanceof Error ? err.message : String(err) - console.error(`Error: ${errorMessage}`) - process.exit(1) - } - }, -} diff --git a/src/commands/setup.ts b/src/commands/setup.ts deleted file mode 100644 index b47a4d1..0000000 --- a/src/commands/setup.ts +++ /dev/null @@ -1,185 +0,0 @@ -import * as path from 'path' -import * as readline from 'readline/promises' -import type { CommandModule } from 'yargs' -import { - findPackageJsonFiles, - updateMultiplePackageJsons, - type UpdateResult, -} from '../package-json/index.js' - -interface SetupArgs { - cwd: string - 'dry-run': boolean - yes: boolean -} - -/** - * Display a preview table of changes - */ -function displayPreview(results: UpdateResult[], cwd: string): void { - console.log('\nPackage.json files to be updated:\n') - - const toUpdate = results.filter(r => r.status === 'updated') - const alreadyConfigured = results.filter( - r => r.status === 'already-configured', - ) - const errors = results.filter(r => r.status === 'error') - - if (toUpdate.length > 0) { - console.log('Will update:') - for (const result of toUpdate) { - const relativePath = path.relative(cwd, result.path) - console.log(` ✓ ${relativePath}`) - if (result.oldScript) { - console.log(` Current: "${result.oldScript}"`) - } else { - console.log(` Current: (no postinstall script)`) - } - console.log(` New: "${result.newScript}"`) - } - console.log() - } - - if (alreadyConfigured.length > 0) { - console.log('Already configured (will skip):') - for (const result of alreadyConfigured) { - const relativePath = path.relative(cwd, result.path) - console.log(` ⊘ ${relativePath}`) - } - console.log() - } - - if (errors.length > 0) { - console.log('Errors:') - for (const result of errors) { - const relativePath = path.relative(cwd, result.path) - console.log(` ✗ ${relativePath}: ${result.error}`) - } - console.log() - } -} - -/** - * Display summary of changes made - */ -function displaySummary(results: UpdateResult[], dryRun: boolean): void { - const updated = results.filter(r => r.status === 'updated') - const alreadyConfigured = results.filter( - r => r.status === 'already-configured', - ) - const errors = results.filter(r => r.status === 'error') - - console.log('\nSummary:') - console.log( - ` ${updated.length} file(s) ${dryRun ? 
'would be updated' : 'updated'}`, - ) - console.log(` ${alreadyConfigured.length} file(s) already configured`) - if (errors.length > 0) { - console.log(` ${errors.length} error(s)`) - } -} - -/** - * Prompt user for confirmation - */ -async function promptConfirmation(): Promise { - const rl = readline.createInterface({ - input: process.stdin, - output: process.stdout, - }) - - try { - const answer = await rl.question('Proceed with these changes? (y/N): ') - return answer.toLowerCase() === 'y' || answer.toLowerCase() === 'yes' - } finally { - rl.close() - } -} - -export const setupCommand: CommandModule<{}, SetupArgs> = { - command: 'setup', - describe: 'Configure package.json postinstall scripts to apply patches', - builder: yargs => { - return yargs - .option('cwd', { - describe: 'Working directory', - type: 'string', - default: process.cwd(), - }) - .option('dry-run', { - alias: 'd', - describe: 'Preview changes without modifying files', - type: 'boolean', - default: false, - }) - .option('yes', { - alias: 'y', - describe: 'Skip confirmation prompt', - type: 'boolean', - default: false, - }) - }, - handler: async argv => { - try { - // Find all package.json files - console.log('Searching for package.json files...') - const packageJsonFiles = await findPackageJsonFiles(argv.cwd) - - if (packageJsonFiles.length === 0) { - console.log('No package.json files found') - process.exit(0) - } - - console.log(`Found ${packageJsonFiles.length} package.json file(s)`) - - // Preview changes (dry run to see what would change) - const previewResults = await updateMultiplePackageJsons( - packageJsonFiles.map(p => p.path), - true, // Always preview first - ) - - // Display preview - displayPreview(previewResults, argv.cwd) - - const toUpdate = previewResults.filter(r => r.status === 'updated') - - if (toUpdate.length === 0) { - console.log( - 'All package.json files are already configured with socket-patch!', - ) - process.exit(0) - } - - // If not dry-run, ask for confirmation (unless --yes) - if (!argv['dry-run']) { - if (!argv.yes) { - const confirmed = await promptConfirmation() - if (!confirmed) { - console.log('Aborted') - process.exit(0) - } - } - - // Apply changes - console.log('\nApplying changes...') - const results = await updateMultiplePackageJsons( - packageJsonFiles.map(p => p.path), - false, - ) - - displaySummary(results, false) - - const errors = results.filter(r => r.status === 'error') - process.exit(errors.length > 0 ? 1 : 0) - } else { - // Dry run mode - displaySummary(previewResults, true) - process.exit(0) - } - } catch (err) { - const errorMessage = err instanceof Error ? 
err.message : String(err)
-      console.error(`Error: ${errorMessage}`)
-      process.exit(1)
-    }
-  },
-}
diff --git a/src/constants.ts b/src/constants.ts
deleted file mode 100644
index 0b5cc60..0000000
--- a/src/constants.ts
+++ /dev/null
@@ -1,16 +0,0 @@
-/**
- * Standard paths and constants used throughout the socket-patch system
- */
-
-// Re-export from schema for convenience
-export { DEFAULT_PATCH_MANIFEST_PATH } from './schema/manifest-schema.js'
-
-/**
- * Default folder for storing patched file blobs
- */
-export const DEFAULT_BLOB_FOLDER = '.socket/blob'
-
-/**
- * Default Socket directory
- */
-export const DEFAULT_SOCKET_DIR = '.socket'
diff --git a/src/hash/git-sha256.ts b/src/hash/git-sha256.ts
deleted file mode 100644
index 729734f..0000000
--- a/src/hash/git-sha256.ts
+++ /dev/null
@@ -1,37 +0,0 @@
-import * as crypto from 'crypto'
-
-/**
- * Compute Git-compatible SHA256 hash for a buffer
- * @param buffer - Buffer or Uint8Array to hash
- * @returns Git-compatible SHA256 hash (hex string)
- */
-export function computeGitSHA256FromBuffer(
-  buffer: Buffer | Uint8Array,
-): string {
-  const gitHash = crypto.createHash('sha256')
-  const header = `blob ${buffer.length}\0`
-  gitHash.update(header)
-  gitHash.update(buffer)
-  return gitHash.digest('hex')
-}
-
-/**
- * Compute Git-compatible SHA256 hash from an async iterable of chunks
- * @param size - Total size of the file in bytes
- * @param chunks - Async iterable of Buffer or Uint8Array chunks
- * @returns Git-compatible SHA256 hash (hex string)
- */
-export async function computeGitSHA256FromChunks(
-  size: number,
-  chunks: AsyncIterable<Buffer | Uint8Array>,
-): Promise<string> {
-  const gitHash = crypto.createHash('sha256')
-  const header = `blob ${size}\0`
-  gitHash.update(header)
-
-  for await (const chunk of chunks) {
-    gitHash.update(chunk)
-  }
-
-  return gitHash.digest('hex')
-}
diff --git a/src/index.ts b/src/index.ts
deleted file mode 100644
index 3d25080..0000000
--- a/src/index.ts
+++ /dev/null
@@ -1,17 +0,0 @@
-export type { PatchInfo, ApplyOptions, PatchResult } from './types.js'
-export { formatPatchResult, log, error } from './utils.js'
-
-// Re-export schema and hash modules
-export * from './schema/manifest-schema.js'
-export * from './hash/git-sha256.js'
-
-// Re-export patch application utilities
-export * from './patch/file-hash.js'
-export * from './patch/apply.js'
-
-// Re-export manifest utilities
-export * from './manifest/operations.js'
-export * from './manifest/recovery.js'
-
-// Re-export constants
-export * from './constants.js'
diff --git a/src/manifest/operations.ts b/src/manifest/operations.ts
deleted file mode 100644
index a5dc0c9..0000000
--- a/src/manifest/operations.ts
+++ /dev/null
@@ -1,107 +0,0 @@
-import * as fs from 'fs/promises'
-import type { PatchManifest, PatchRecord } from '../schema/manifest-schema.js'
-import { PatchManifestSchema } from '../schema/manifest-schema.js'
-
-/**
- * Get all blob hashes referenced by a manifest
- * Used for garbage collection and validation
- */
-export function getReferencedBlobs(manifest: PatchManifest): Set<string> {
-  const blobs = new Set<string>()
-
-  for (const patchRecord of Object.values(manifest.patches)) {
-    const record = patchRecord as PatchRecord
-    for (const fileInfo of Object.values(record.files)) {
-      blobs.add(fileInfo.beforeHash)
-      blobs.add(fileInfo.afterHash)
-    }
-  }
-
-  return blobs
-}
-
-/**
- * Calculate differences between two manifests
- */
-export interface ManifestDiff {
-  added: Set<string> // PURLs
-  removed: Set<string>
-  modified: Set<string>
-}
-
-export function diffManifests(
oldManifest: PatchManifest, - newManifest: PatchManifest, -): ManifestDiff { - const oldPurls = new Set(Object.keys(oldManifest.patches)) - const newPurls = new Set(Object.keys(newManifest.patches)) - - const added = new Set() - const removed = new Set() - const modified = new Set() - - // Find added and modified - for (const purl of newPurls) { - if (!oldPurls.has(purl)) { - added.add(purl) - } else { - const oldPatch = oldManifest.patches[purl] as PatchRecord - const newPatch = newManifest.patches[purl] as PatchRecord - if (oldPatch.uuid !== newPatch.uuid) { - modified.add(purl) - } - } - } - - // Find removed - for (const purl of oldPurls) { - if (!newPurls.has(purl)) { - removed.add(purl) - } - } - - return { added, removed, modified } -} - -/** - * Validate a parsed manifest object - */ -export function validateManifest(parsed: unknown): { - success: boolean - manifest?: PatchManifest - error?: string -} { - const result = PatchManifestSchema.safeParse(parsed) - if (result.success) { - return { success: true, manifest: result.data } - } - return { - success: false, - error: result.error.message, - } -} - -/** - * Read and parse a manifest from the filesystem - */ -export async function readManifest(path: string): Promise { - try { - const content = await fs.readFile(path, 'utf-8') - const parsed = JSON.parse(content) - const result = validateManifest(parsed) - return result.success ? result.manifest! : null - } catch { - return null - } -} - -/** - * Write a manifest to the filesystem - */ -export async function writeManifest( - path: string, - manifest: PatchManifest, -): Promise { - const content = JSON.stringify(manifest, null, 2) - await fs.writeFile(path, content, 'utf-8') -} diff --git a/src/manifest/recovery.ts b/src/manifest/recovery.ts deleted file mode 100644 index af5a894..0000000 --- a/src/manifest/recovery.ts +++ /dev/null @@ -1,238 +0,0 @@ -import type { PatchManifest, PatchRecord } from '../schema/manifest-schema.js' -import { PatchManifestSchema, PatchRecordSchema } from '../schema/manifest-schema.js' - -/** - * Result of manifest recovery operation - */ -export interface RecoveryResult { - manifest: PatchManifest - repairNeeded: boolean - invalidPatches: string[] - recoveredPatches: string[] - discardedPatches: string[] -} - -/** - * Options for manifest recovery - */ -export interface RecoveryOptions { - /** - * Optional function to refetch patch data from external source (e.g., database) - * Should return patch data or null if not found - * @param uuid - The patch UUID - * @param purl - The package URL (for context/validation) - */ - refetchPatch?: (uuid: string, purl?: string) => Promise - - /** - * Optional callback for logging recovery events - */ - onRecoveryEvent?: (event: RecoveryEvent) => void -} - -/** - * Patch data returned from external source - */ -export interface PatchData { - uuid: string - purl: string - publishedAt: string - files: Record< - string, - { - beforeHash?: string - afterHash?: string - } - > - vulnerabilities: Record< - string, - { - cves: string[] - summary: string - severity: string - description: string - } - > - description: string - license: string - tier: string -} - -/** - * Events emitted during recovery - */ -export type RecoveryEvent = - | { type: 'corrupted_manifest' } - | { type: 'invalid_patch'; purl: string; uuid: string | null } - | { type: 'recovered_patch'; purl: string; uuid: string } - | { type: 'discarded_patch_not_found'; purl: string; uuid: string } - | { type: 'discarded_patch_purl_mismatch'; purl: string; uuid: 
string; dbPurl: string } - | { type: 'discarded_patch_no_uuid'; purl: string } - | { type: 'recovery_error'; purl: string; uuid: string; error: string } - -/** - * Recover and validate manifest with automatic repair of invalid patches - * - * This function attempts to parse and validate a manifest. If the manifest - * contains invalid patches, it will attempt to recover them using the provided - * refetch function. Patches that cannot be recovered are discarded. - * - * @param parsed - The parsed manifest object (may be invalid) - * @param options - Recovery options including refetch function and event callback - * @returns Recovery result with repaired manifest and statistics - */ -export async function recoverManifest( - parsed: unknown, - options: RecoveryOptions = {}, -): Promise { - const { refetchPatch, onRecoveryEvent } = options - - // Try strict parse first (fast path for valid manifests) - const strictResult = PatchManifestSchema.safeParse(parsed) - if (strictResult.success) { - return { - manifest: strictResult.data, - repairNeeded: false, - invalidPatches: [], - recoveredPatches: [], - discardedPatches: [], - } - } - - // Extract patches object with safety checks - const patchesObj = - parsed && - typeof parsed === 'object' && - 'patches' in parsed && - parsed.patches && - typeof parsed.patches === 'object' - ? (parsed.patches as Record) - : null - - if (!patchesObj) { - // Completely corrupted manifest - onRecoveryEvent?.({ type: 'corrupted_manifest' }) - return { - manifest: { patches: {} }, - repairNeeded: true, - invalidPatches: [], - recoveredPatches: [], - discardedPatches: [], - } - } - - // Try to recover individual patches - const recoveredPatchesMap: Record = {} - const invalidPatches: string[] = [] - const recoveredPatches: string[] = [] - const discardedPatches: string[] = [] - - for (const [purl, patchData] of Object.entries(patchesObj)) { - // Try to parse this individual patch - const patchResult = PatchRecordSchema.safeParse(patchData) - - if (patchResult.success) { - // Valid patch, keep it as-is - recoveredPatchesMap[purl] = patchResult.data - } else { - // Invalid patch, try to recover from external source - const uuid = - patchData && - typeof patchData === 'object' && - 'uuid' in patchData && - typeof patchData.uuid === 'string' - ? patchData.uuid - : null - - invalidPatches.push(purl) - onRecoveryEvent?.({ type: 'invalid_patch', purl, uuid }) - - if (uuid && refetchPatch) { - try { - // Try to refetch from external source - const patchFromSource = await refetchPatch(uuid, purl) - - if (patchFromSource && patchFromSource.purl === purl) { - // Successfully recovered, reconstruct patch record - const manifestFiles: Record< - string, - { beforeHash: string; afterHash: string } - > = {} - for (const [filePath, fileInfo] of Object.entries( - patchFromSource.files, - )) { - if (fileInfo.beforeHash && fileInfo.afterHash) { - manifestFiles[filePath] = { - beforeHash: fileInfo.beforeHash, - afterHash: fileInfo.afterHash, - } - } - } - - recoveredPatchesMap[purl] = { - uuid: patchFromSource.uuid, - exportedAt: patchFromSource.publishedAt, - files: manifestFiles, - vulnerabilities: patchFromSource.vulnerabilities, - description: patchFromSource.description, - license: patchFromSource.license, - tier: patchFromSource.tier, - } - - recoveredPatches.push(purl) - onRecoveryEvent?.({ type: 'recovered_patch', purl, uuid }) - } else if (patchFromSource && patchFromSource.purl !== purl) { - // PURL mismatch - wrong package! 
- discardedPatches.push(purl) - onRecoveryEvent?.({ - type: 'discarded_patch_purl_mismatch', - purl, - uuid, - dbPurl: patchFromSource.purl, - }) - } else { - // Not found in external source (might be unpublished) - discardedPatches.push(purl) - onRecoveryEvent?.({ - type: 'discarded_patch_not_found', - purl, - uuid, - }) - } - } catch (error: unknown) { - // Error during recovery - discardedPatches.push(purl) - const errorMessage = error instanceof Error ? error.message : String(error) - onRecoveryEvent?.({ - type: 'recovery_error', - purl, - uuid, - error: errorMessage, - }) - } - } else { - // No UUID or no refetch function, can't recover - discardedPatches.push(purl) - if (!uuid) { - onRecoveryEvent?.({ type: 'discarded_patch_no_uuid', purl }) - } else { - onRecoveryEvent?.({ - type: 'discarded_patch_not_found', - purl, - uuid, - }) - } - } - } - } - - const repairNeeded = invalidPatches.length > 0 - - return { - manifest: { patches: recoveredPatchesMap }, - repairNeeded, - invalidPatches, - recoveredPatches, - discardedPatches, - } -} diff --git a/src/package-json/detect.test.ts b/src/package-json/detect.test.ts deleted file mode 100644 index c6c5931..0000000 --- a/src/package-json/detect.test.ts +++ /dev/null @@ -1,557 +0,0 @@ -import { describe, it } from 'node:test' -import assert from 'node:assert/strict' -import { - isPostinstallConfigured, - generateUpdatedPostinstall, - updatePackageJsonContent, -} from './detect.js' - -describe('isPostinstallConfigured', () => { - describe('Edge Case 1: No scripts field at all', () => { - it('should detect as not configured when scripts field is missing', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, '') - }) - }) - - describe('Edge Case 2: Scripts field exists but no postinstall', () => { - it('should detect as not configured when postinstall is missing', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - test: 'jest', - build: 'tsc', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, '') - }) - - it('should detect as not configured when postinstall is null', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: null, - }, - } - - const result = isPostinstallConfigured(packageJson as any) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, '') - }) - - it('should detect as not configured when postinstall is undefined', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: undefined, - }, - } - - const result = isPostinstallConfigured(packageJson as any) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, '') - }) - - it('should detect as not configured when postinstall is empty string', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: '', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, '') - }) - - it('should detect as not configured when postinstall is whitespace only', () => { - 
const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: ' \t\n ', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, ' \t\n ') - }) - }) - - describe('Edge Case 3: Postinstall exists but missing socket-patch setup', () => { - it('should detect as not configured when postinstall has different command', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'echo "Running postinstall"', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, 'echo "Running postinstall"') - }) - - it('should detect as not configured with complex existing script', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'npm run build && npm run prepare && echo done', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - }) - }) - - describe('Edge Case 4: Postinstall has socket-patch but not exact format', () => { - it('should detect socket-patch apply without npx as configured', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'socket-patch apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal( - result.configured, - true, - 'socket-patch apply should be recognized', - ) - assert.equal(result.needsUpdate, false) - }) - - it('should detect npx socket-patch apply (without @socketsecurity/) as configured', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'npx socket-patch apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal( - result.configured, - true, - 'npx socket-patch apply should be recognized', - ) - assert.equal(result.needsUpdate, false) - }) - - it('should detect canonical format npx @socketsecurity/socket-patch apply', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'npx @socketsecurity/socket-patch apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, true) - assert.equal(result.needsUpdate, false) - }) - - it('should detect pnpm socket-patch apply as configured', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'pnpm socket-patch apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal( - result.configured, - true, - 'pnpm socket-patch apply should be recognized', - ) - assert.equal(result.needsUpdate, false) - }) - - it('should detect yarn socket-patch apply as configured', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'yarn socket-patch apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal( - result.configured, - true, - 'yarn socket-patch apply should be recognized', - ) - assert.equal(result.needsUpdate, false) - }) - - it('should detect node_modules/.bin/socket-patch apply as configured', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'node_modules/.bin/socket-patch apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - 
assert.equal(result.configured, true) - assert.equal(result.needsUpdate, false) - }) - - it('should NOT detect socket apply (main Socket CLI) as configured', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'socket apply', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal( - result.configured, - false, - 'socket apply (main CLI) should NOT be recognized', - ) - assert.equal(result.needsUpdate, true) - }) - - it('should NOT detect socket-patch without apply subcommand', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'socket-patch list', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal( - result.configured, - false, - 'socket-patch list should NOT be recognized', - ) - assert.equal(result.needsUpdate, true) - }) - - it('should detect socket-patch apply with additional flags', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'npx @socketsecurity/socket-patch apply --silent', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, true) - assert.equal(result.needsUpdate, false) - }) - - it('should detect socket-patch apply in complex script chain', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: - 'echo "Starting" && socket-patch apply && echo "Complete"', - }, - } - - const result = isPostinstallConfigured(packageJson) - - assert.equal(result.configured, true) - assert.equal(result.needsUpdate, false) - }) - }) - - describe('Edge Case 5: Invalid or malformed data', () => { - it('should handle malformed JSON gracefully', () => { - const malformedJson = '{ name: "test", invalid }' - - const result = isPostinstallConfigured(malformedJson) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - assert.equal(result.currentScript, '') - }) - - it('should handle non-string postinstall value', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 123, - }, - } - - const result = isPostinstallConfigured(packageJson as any) - - // Should coerce to string or handle gracefully - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - }) - - it('should handle array postinstall value', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: ['echo', 'hello'], - }, - } - - const result = isPostinstallConfigured(packageJson as any) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - }) - - it('should handle object postinstall value', () => { - const packageJson = { - name: 'test', - version: '1.0.0', - scripts: { - postinstall: { command: 'echo hello' }, - }, - } - - const result = isPostinstallConfigured(packageJson as any) - - assert.equal(result.configured, false) - assert.equal(result.needsUpdate, true) - }) - }) -}) - -describe('generateUpdatedPostinstall', () => { - it('should create command for empty string', () => { - const result = generateUpdatedPostinstall('') - assert.equal(result, 'npx @socketsecurity/socket-patch apply') - }) - - it('should create command for whitespace-only string', () => { - const result = generateUpdatedPostinstall(' \n\t ') - assert.equal(result, 'npx @socketsecurity/socket-patch apply') - }) - - it('should prepend to existing script', () => { - const result = generateUpdatedPostinstall('echo "Hello"') - 
assert.equal( - result, - 'npx @socketsecurity/socket-patch apply && echo "Hello"', - ) - }) - - it('should preserve existing script with socket-patch', () => { - const existing = 'socket-patch apply && echo "Done"' - const result = generateUpdatedPostinstall(existing) - assert.equal(result, existing, 'Should not modify if already present') - }) - - it('should preserve npx @socketsecurity/socket-patch apply', () => { - const existing = 'npx @socketsecurity/socket-patch apply' - const result = generateUpdatedPostinstall(existing) - assert.equal(result, existing) - }) - - it('should prepend to script with socket apply (main CLI)', () => { - const existing = 'socket apply' - const result = generateUpdatedPostinstall(existing) - assert.equal( - result, - 'npx @socketsecurity/socket-patch apply && socket apply', - 'Should add socket-patch even if socket apply is present', - ) - }) -}) - -describe('updatePackageJsonContent', () => { - it('should add scripts field when missing', () => { - const content = JSON.stringify({ - name: 'test', - version: '1.0.0', - }) - - const result = updatePackageJsonContent(content) - - assert.equal(result.modified, true) - const updated = JSON.parse(result.content) - assert.ok(updated.scripts) - assert.equal( - updated.scripts.postinstall, - 'npx @socketsecurity/socket-patch apply', - ) - }) - - it('should add postinstall to existing scripts', () => { - const content = JSON.stringify({ - name: 'test', - version: '1.0.0', - scripts: { - test: 'jest', - build: 'tsc', - }, - }) - - const result = updatePackageJsonContent(content) - - assert.equal(result.modified, true) - const updated = JSON.parse(result.content) - assert.equal( - updated.scripts.postinstall, - 'npx @socketsecurity/socket-patch apply', - ) - assert.equal(updated.scripts.test, 'jest', 'Should preserve other scripts') - assert.equal(updated.scripts.build, 'tsc', 'Should preserve other scripts') - }) - - it('should prepend to existing postinstall', () => { - const content = JSON.stringify({ - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'echo "Setup complete"', - }, - }) - - const result = updatePackageJsonContent(content) - - assert.equal(result.modified, true) - assert.equal(result.oldScript, 'echo "Setup complete"') - assert.equal( - result.newScript, - 'npx @socketsecurity/socket-patch apply && echo "Setup complete"', - ) - }) - - it('should not modify when already configured', () => { - const content = JSON.stringify({ - name: 'test', - version: '1.0.0', - scripts: { - postinstall: 'npx @socketsecurity/socket-patch apply', - }, - }) - - const result = updatePackageJsonContent(content) - - assert.equal(result.modified, false) - assert.equal(result.content, content) - }) - - it('should throw error for invalid JSON', () => { - const content = '{ invalid json }' - - assert.throws( - () => updatePackageJsonContent(content), - /Invalid package\.json/, - ) - }) - - it('should handle empty postinstall by replacing it', () => { - const content = JSON.stringify({ - name: 'test', - version: '1.0.0', - scripts: { - postinstall: '', - }, - }) - - const result = updatePackageJsonContent(content) - - assert.equal(result.modified, true) - const updated = JSON.parse(result.content) - assert.equal( - updated.scripts.postinstall, - 'npx @socketsecurity/socket-patch apply', - ) - }) - - it('should handle whitespace-only postinstall', () => { - const content = JSON.stringify({ - name: 'test', - version: '1.0.0', - scripts: { - postinstall: ' \n\t ', - }, - }) - - const result = 
updatePackageJsonContent(content) - - assert.equal(result.modified, true) - const updated = JSON.parse(result.content) - assert.equal( - updated.scripts.postinstall, - 'npx @socketsecurity/socket-patch apply', - ) - }) - - it('should preserve JSON formatting', () => { - const content = JSON.stringify( - { - name: 'test', - version: '1.0.0', - }, - null, - 2, - ) - - const result = updatePackageJsonContent(content) - - assert.equal(result.modified, true) - // Check that formatting is preserved (2 space indent) - assert.ok(result.content.includes(' "name"')) - assert.ok(result.content.includes(' "scripts"')) - }) -}) diff --git a/src/package-json/detect.ts b/src/package-json/detect.ts deleted file mode 100644 index 90548e7..0000000 --- a/src/package-json/detect.ts +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Shared logic for detecting and generating postinstall scripts - * Used by both CLI and GitHub bot - */ - -const SOCKET_PATCH_COMMAND = 'npx @socketsecurity/socket-patch apply' - -export interface PostinstallStatus { - configured: boolean - currentScript: string - needsUpdate: boolean -} - -/** - * Check if a postinstall script is properly configured for socket-patch - */ -export function isPostinstallConfigured( - packageJsonContent: string | Record, -): PostinstallStatus { - let packageJson: Record - - if (typeof packageJsonContent === 'string') { - try { - packageJson = JSON.parse(packageJsonContent) - } catch { - return { - configured: false, - currentScript: '', - needsUpdate: true, - } - } - } else { - packageJson = packageJsonContent - } - - const currentScript = packageJson.scripts?.postinstall || '' - - // Check if socket-patch apply is already present - const configured = currentScript.includes('socket-patch apply') - - return { - configured, - currentScript, - needsUpdate: !configured, - } -} - -/** - * Generate an updated postinstall script that includes socket-patch - */ -export function generateUpdatedPostinstall( - currentPostinstall: string, -): string { - const trimmed = currentPostinstall.trim() - - // If empty, just add the socket-patch command - if (!trimmed) { - return SOCKET_PATCH_COMMAND - } - - // If socket-patch is already present, return unchanged - if (trimmed.includes('socket-patch apply')) { - return trimmed - } - - // Prepend socket-patch command so it runs first, then existing script - // Using && ensures existing script only runs if patching succeeds - return `${SOCKET_PATCH_COMMAND} && ${trimmed}` -} - -/** - * Update a package.json object with the new postinstall script - * Returns the modified package.json and whether it was changed - */ -export function updatePackageJsonObject( - packageJson: Record, -): { modified: boolean; packageJson: Record } { - const status = isPostinstallConfigured(packageJson) - - if (!status.needsUpdate) { - return { modified: false, packageJson } - } - - // Ensure scripts object exists - if (!packageJson.scripts) { - packageJson.scripts = {} - } - - // Update postinstall script - const newPostinstall = generateUpdatedPostinstall(status.currentScript) - packageJson.scripts.postinstall = newPostinstall - - return { modified: true, packageJson } -} - -/** - * Parse package.json content and update it with socket-patch postinstall - * Returns the updated JSON string and metadata about the change - */ -export function updatePackageJsonContent( - content: string, -): { - modified: boolean - content: string - oldScript: string - newScript: string -} { - let packageJson: Record - - try { - packageJson = JSON.parse(content) - } catch { - 
throw new Error('Invalid package.json: failed to parse JSON') - } - - const status = isPostinstallConfigured(packageJson) - - if (!status.needsUpdate) { - return { - modified: false, - content, - oldScript: status.currentScript, - newScript: status.currentScript, - } - } - - // Update the package.json object - const { packageJson: updatedPackageJson } = - updatePackageJsonObject(packageJson) - - // Stringify with formatting - const newContent = JSON.stringify(updatedPackageJson, null, 2) + '\n' - const newScript = updatedPackageJson.scripts.postinstall - - return { - modified: true, - content: newContent, - oldScript: status.currentScript, - newScript, - } -} diff --git a/src/package-json/find.ts b/src/package-json/find.ts deleted file mode 100644 index 3b9a258..0000000 --- a/src/package-json/find.ts +++ /dev/null @@ -1,326 +0,0 @@ -import * as fs from 'fs/promises' -import * as path from 'path' - -export interface WorkspaceConfig { - type: 'npm' | 'yarn' | 'pnpm' | 'none' - patterns: string[] -} - -export interface PackageJsonLocation { - path: string - isRoot: boolean - isWorkspace: boolean - workspacePattern?: string -} - -/** - * Find all package.json files recursively, respecting workspace configurations - */ -export async function findPackageJsonFiles( - startPath: string, -): Promise { - const results: PackageJsonLocation[] = [] - const rootPackageJsonPath = path.join(startPath, 'package.json') - - // Check if root package.json exists - let rootExists = false - let workspaceConfig: WorkspaceConfig = { type: 'none', patterns: [] } - - try { - await fs.access(rootPackageJsonPath) - rootExists = true - - // Detect workspace configuration - workspaceConfig = await detectWorkspaces(rootPackageJsonPath) - - // Add root package.json - results.push({ - path: rootPackageJsonPath, - isRoot: true, - isWorkspace: false, - }) - } catch { - // No root package.json - } - - // If workspaces are configured, find all workspace package.json files - if (workspaceConfig.type !== 'none') { - const workspacePackages = await findWorkspacePackages( - startPath, - workspaceConfig, - ) - results.push(...workspacePackages) - } else if (rootExists) { - // No workspaces, just search for nested package.json files - const nestedPackages = await findNestedPackageJsonFiles(startPath) - results.push(...nestedPackages) - } - - return results -} - -/** - * Detect workspace configuration from package.json - */ -export async function detectWorkspaces( - packageJsonPath: string, -): Promise { - try { - const content = await fs.readFile(packageJsonPath, 'utf-8') - const packageJson = JSON.parse(content) - - // Check for npm/yarn workspaces - if (packageJson.workspaces) { - const patterns = Array.isArray(packageJson.workspaces) - ? 
packageJson.workspaces - : packageJson.workspaces.packages || [] - - return { - type: 'npm', // npm and yarn use same format - patterns, - } - } - - // Check for pnpm workspaces (pnpm-workspace.yaml) - const dir = path.dirname(packageJsonPath) - const pnpmWorkspacePath = path.join(dir, 'pnpm-workspace.yaml') - - try { - await fs.access(pnpmWorkspacePath) - // Parse pnpm-workspace.yaml (simple YAML parsing for packages field) - const yamlContent = await fs.readFile(pnpmWorkspacePath, 'utf-8') - const patterns = parsePnpmWorkspacePatterns(yamlContent) - - return { - type: 'pnpm', - patterns, - } - } catch { - // No pnpm workspace file - } - - return { type: 'none', patterns: [] } - } catch { - return { type: 'none', patterns: [] } - } -} - -/** - * Simple parser for pnpm-workspace.yaml packages field - */ -function parsePnpmWorkspacePatterns(yamlContent: string): string[] { - const patterns: string[] = [] - const lines = yamlContent.split('\n') - let inPackages = false - - for (const line of lines) { - const trimmed = line.trim() - - if (trimmed === 'packages:') { - inPackages = true - continue - } - - if (inPackages) { - // Stop at next top-level key - if (trimmed && !trimmed.startsWith('-') && !trimmed.startsWith('#')) { - break - } - - // Parse list item - const match = trimmed.match(/^-\s*['"]?([^'"]+)['"]?/) - if (match) { - patterns.push(match[1]) - } - } - } - - return patterns -} - -/** - * Find workspace packages based on workspace patterns - */ -async function findWorkspacePackages( - rootPath: string, - workspaceConfig: WorkspaceConfig, -): Promise { - const results: PackageJsonLocation[] = [] - - for (const pattern of workspaceConfig.patterns) { - // Handle glob patterns like "packages/*" or "apps/**" - const packages = await findPackagesMatchingPattern(rootPath, pattern) - results.push( - ...packages.map(p => ({ - path: p, - isRoot: false, - isWorkspace: true, - workspacePattern: pattern, - })), - ) - } - - return results -} - -/** - * Find packages matching a workspace pattern - * Supports basic glob patterns: *, ** - */ -async function findPackagesMatchingPattern( - rootPath: string, - pattern: string, -): Promise { - const results: string[] = [] - - // Convert glob pattern to regex-like logic - const parts = pattern.split('/') - const searchPath = path.join(rootPath, parts[0]) - - // If pattern is like "packages/*", search one level deep - if (parts.length === 2 && parts[1] === '*') { - await searchOneLevel(searchPath, results) - } - // If pattern is like "packages/**", search recursively - else if (parts.length === 2 && parts[1] === '**') { - await searchRecursive(searchPath, results) - } - // If pattern is just a directory name, check if it has package.json - else { - const packageJsonPath = path.join(rootPath, pattern, 'package.json') - try { - await fs.access(packageJsonPath) - results.push(packageJsonPath) - } catch { - // Not a valid package - } - } - - return results -} - -/** - * Search one level deep for package.json files - */ -async function searchOneLevel( - dir: string, - results: string[], -): Promise { - try { - const entries = await fs.readdir(dir, { withFileTypes: true }) - - for (const entry of entries) { - if (!entry.isDirectory()) continue - - const packageJsonPath = path.join(dir, entry.name, 'package.json') - try { - await fs.access(packageJsonPath) - results.push(packageJsonPath) - } catch { - // No package.json in this directory - } - } - } catch { - // Ignore permission errors or missing directories - } -} - -/** - * Search recursively for 
- */
-async function searchRecursive(
-  dir: string,
-  results: string[],
-): Promise<void> {
-  try {
-    const entries = await fs.readdir(dir, { withFileTypes: true })
-
-    for (const entry of entries) {
-      if (!entry.isDirectory()) continue
-
-      const fullPath = path.join(dir, entry.name)
-
-      // Skip hidden directories, node_modules, dist, build
-      if (
-        entry.name.startsWith('.') ||
-        entry.name === 'node_modules' ||
-        entry.name === 'dist' ||
-        entry.name === 'build'
-      ) {
-        continue
-      }
-
-      // Check for package.json at this level
-      const packageJsonPath = path.join(fullPath, 'package.json')
-      try {
-        await fs.access(packageJsonPath)
-        results.push(packageJsonPath)
-      } catch {
-        // No package.json at this level
-      }
-
-      // Recurse into subdirectories
-      await searchRecursive(fullPath, results)
-    }
-  } catch {
-    // Ignore permission errors
-  }
-}
-
-/**
- * Find nested package.json files without workspace configuration
- */
-async function findNestedPackageJsonFiles(
-  startPath: string,
-): Promise<PackageJsonLocation[]> {
-  const results: PackageJsonLocation[] = []
-
-  async function search(dir: string, depth: number): Promise<void> {
-    // Limit depth to avoid searching too deep
-    if (depth > 5) return
-
-    try {
-      const entries = await fs.readdir(dir, { withFileTypes: true })
-
-      for (const entry of entries) {
-        if (!entry.isDirectory()) continue
-
-        const fullPath = path.join(dir, entry.name)
-
-        // Skip hidden directories, node_modules, dist, build
-        if (
-          entry.name.startsWith('.') ||
-          entry.name === 'node_modules' ||
-          entry.name === 'dist' ||
-          entry.name === 'build'
-        ) {
-          continue
-        }
-
-        // Check for package.json at this level
-        const packageJsonPath = path.join(fullPath, 'package.json')
-        try {
-          await fs.access(packageJsonPath)
-          // Don't include the root package.json (already added)
-          if (packageJsonPath !== path.join(startPath, 'package.json')) {
-            results.push({
-              path: packageJsonPath,
-              isRoot: false,
-              isWorkspace: false,
-            })
-          }
-        } catch {
-          // No package.json at this level
-        }
-
-        // Recurse into subdirectories
-        await search(fullPath, depth + 1)
-      }
-    } catch {
-      // Ignore permission errors
-    }
-  }
-
-  await search(startPath, 0)
-  return results
-}
diff --git a/src/package-json/index.ts b/src/package-json/index.ts
deleted file mode 100644
index 34815b0..0000000
--- a/src/package-json/index.ts
+++ /dev/null
@@ -1,20 +0,0 @@
-export {
-  findPackageJsonFiles,
-  detectWorkspaces,
-  type WorkspaceConfig,
-  type PackageJsonLocation,
-} from './find.js'
-
-export {
-  isPostinstallConfigured,
-  generateUpdatedPostinstall,
-  updatePackageJsonObject,
-  updatePackageJsonContent,
-  type PostinstallStatus,
-} from './detect.js'
-
-export {
-  updatePackageJson,
-  updateMultiplePackageJsons,
-  type UpdateResult,
-} from './update.js'
diff --git a/src/package-json/update.ts b/src/package-json/update.ts
deleted file mode 100644
index 033b7de..0000000
--- a/src/package-json/update.ts
+++ /dev/null
@@ -1,88 +0,0 @@
-import * as fs from 'fs/promises'
-import {
-  isPostinstallConfigured,
-  updatePackageJsonContent,
-} from './detect.js'
-
-export interface UpdateResult {
-  path: string
-  status: 'updated' | 'already-configured' | 'error'
-  oldScript: string
-  newScript: string
-  error?: string
-}
-
-/**
- * Update a single package.json file with socket-patch postinstall script
- */
-export async function updatePackageJson(
-  packageJsonPath: string,
-  dryRun: boolean = false,
-): Promise<UpdateResult> {
-  try {
-    // Read current package.json
-    const content = await fs.readFile(packageJsonPath, 'utf-8')
-
-    // Check current status
-    const status = isPostinstallConfigured(content)
-
-    if (!status.needsUpdate) {
-      return {
-        path: packageJsonPath,
-        status: 'already-configured',
-        oldScript: status.currentScript,
-        newScript: status.currentScript,
-      }
-    }
-
-    // Generate updated content
-    const { modified, content: newContent, oldScript, newScript } =
-      updatePackageJsonContent(content)
-
-    if (!modified) {
-      return {
-        path: packageJsonPath,
-        status: 'already-configured',
-        oldScript,
-        newScript,
-      }
-    }
-
-    // Write updated content (unless dry run)
-    if (!dryRun) {
-      await fs.writeFile(packageJsonPath, newContent, 'utf-8')
-    }
-
-    return {
-      path: packageJsonPath,
-      status: 'updated',
-      oldScript,
-      newScript,
-    }
-  } catch (error) {
-    return {
-      path: packageJsonPath,
-      status: 'error',
-      oldScript: '',
-      newScript: '',
-      error: error instanceof Error ? error.message : String(error),
-    }
-  }
-}
-
-/**
- * Update multiple package.json files
- */
-export async function updateMultiplePackageJsons(
-  paths: string[],
-  dryRun: boolean = false,
-): Promise<UpdateResult[]> {
-  const results: UpdateResult[] = []
-
-  for (const path of paths) {
-    const result = await updatePackageJson(path, dryRun)
-    results.push(result)
-  }
-
-  return results
-}
diff --git a/src/patch/apply.ts b/src/patch/apply.ts
deleted file mode 100644
index a749d9f..0000000
--- a/src/patch/apply.ts
+++ /dev/null
@@ -1,314 +0,0 @@
-import * as fs from 'fs/promises'
-import * as path from 'path'
-import { computeFileGitSHA256 } from './file-hash.js'
-import type { PatchManifest } from '../schema/manifest-schema.js'
-
-export interface PatchFileInfo {
-  beforeHash: string
-  afterHash: string
-}
-
-export interface PackageLocation {
-  name: string
-  version: string
-  path: string
-}
-
-export interface VerifyResult {
-  file: string
-  status: 'ready' | 'already-patched' | 'hash-mismatch' | 'not-found'
-  message?: string
-  currentHash?: string
-  expectedHash?: string
-  targetHash?: string
-}
-
-export interface ApplyResult {
-  packageKey: string
-  packagePath: string
-  success: boolean
-  filesVerified: VerifyResult[]
-  filesPatched: string[]
-  error?: string
-}
-
-/**
- * Normalize file path by removing the 'package/' prefix if present
- * Patch files come from the API with paths like 'package/lib/file.js'
- * but we need relative paths like 'lib/file.js' for the actual package directory
- */
-function normalizeFilePath(fileName: string): string {
-  const packagePrefix = 'package/'
-  if (fileName.startsWith(packagePrefix)) {
-    return fileName.slice(packagePrefix.length)
-  }
-  return fileName
-}
-
-/**
- * Verify a single file can be patched
- */
-export async function verifyFilePatch(
-  packagePath: string,
-  fileName: string,
-  fileInfo: PatchFileInfo,
-): Promise<VerifyResult> {
-  const normalizedFileName = normalizeFilePath(fileName)
-  const filepath = path.join(packagePath, normalizedFileName)
-
-  // Check if file exists
-  try {
-    await fs.access(filepath)
-  } catch {
-    return {
-      file: fileName,
-      status: 'not-found',
-      message: 'File not found',
-    }
-  }
-
-  // Compute current hash
-  const currentHash = await computeFileGitSHA256(filepath)
-
-  // Check if already patched
-  if (currentHash === fileInfo.afterHash) {
-    return {
-      file: fileName,
-      status: 'already-patched',
-      currentHash,
-    }
-  }
-
-  // Check if matches expected before hash
-  if (currentHash !== fileInfo.beforeHash) {
-    return {
-      file: fileName,
-      status: 'hash-mismatch',
-      message: 'File hash does not match expected value',
-      currentHash,
-      expectedHash: fileInfo.beforeHash,
-      targetHash: fileInfo.afterHash,
-    }
-  }
-
-  return {
-    file: fileName,
-    status: 'ready',
-    currentHash,
-    targetHash: fileInfo.afterHash,
-  }
-}
-
-/**
- * Apply a patch to a single file
- */
-export async function applyFilePatch(
-  packagePath: string,
-  fileName: string,
-  patchedContent: Buffer,
-  expectedHash: string,
-): Promise<void> {
-  const normalizedFileName = normalizeFilePath(fileName)
-  const filepath = path.join(packagePath, normalizedFileName)
-
-  // Write the patched content
-  await fs.writeFile(filepath, patchedContent)
-
-  // Verify the hash after writing
-  const verifyHash = await computeFileGitSHA256(filepath)
-  if (verifyHash !== expectedHash) {
-    throw new Error(
-      `Hash verification failed after patch. Expected: ${expectedHash}, Got: ${verifyHash}`,
-    )
-  }
-}
-
-/**
- * Verify and apply patches for a single package
- */
-export async function applyPackagePatch(
-  packageKey: string,
-  packagePath: string,
-  files: Record<string, PatchFileInfo>,
-  blobsPath: string,
-  dryRun: boolean = false,
-): Promise<ApplyResult> {
-  const result: ApplyResult = {
-    packageKey,
-    packagePath,
-    success: false,
-    filesVerified: [],
-    filesPatched: [],
-  }
-
-  try {
-    // First, verify all files
-    for (const [fileName, fileInfo] of Object.entries(files)) {
-      const verifyResult = await verifyFilePatch(
-        packagePath,
-        fileName,
-        fileInfo,
-      )
-      result.filesVerified.push(verifyResult)
-
-      // If any file is not ready or already patched, we can't proceed
-      if (
-        verifyResult.status !== 'ready' &&
-        verifyResult.status !== 'already-patched'
-      ) {
-        result.error = `Cannot apply patch: ${verifyResult.file} - ${verifyResult.message || verifyResult.status}`
-        return result
-      }
-    }
-
-    // Check if all files are already patched
-    const allPatched = result.filesVerified.every(
-      v => v.status === 'already-patched',
-    )
-    if (allPatched) {
-      result.success = true
-      return result
-    }
-
-    // If dry run, stop here
-    if (dryRun) {
-      result.success = true
-      return result
-    }
-
-    // Apply patches to files that need it
-    for (const [fileName, fileInfo] of Object.entries(files)) {
-      const verifyResult = result.filesVerified.find(v => v.file === fileName)
-      if (verifyResult?.status === 'already-patched') {
-        continue
-      }
-
-      // Read patched content from blobs
-      const blobPath = path.join(blobsPath, fileInfo.afterHash)
-      const patchedContent = await fs.readFile(blobPath)
-
-      // Apply the patch
-      await applyFilePatch(
-        packagePath,
-        fileName,
-        patchedContent,
-        fileInfo.afterHash,
-      )
-      result.filesPatched.push(fileName)
-    }
-
-    result.success = true
-  } catch (error) {
-    result.error = error instanceof Error ? error.message : String(error)
-  }
-
-  return result
-}
-
-/**
- * Find all node_modules directories recursively
- */
-export async function findNodeModules(startPath: string): Promise<string[]> {
-  const results: string[] = []
-
-  async function search(dir: string): Promise<void> {
-    try {
-      const entries = await fs.readdir(dir, { withFileTypes: true })
-
-      for (const entry of entries) {
-        if (!entry.isDirectory()) continue
-
-        const fullPath = path.join(dir, entry.name)
-
-        if (entry.name === 'node_modules') {
-          results.push(fullPath)
-          // Don't recurse into nested node_modules
-          continue
-        }
-
-        // Skip hidden directories and common non-source directories
-        if (
-          entry.name.startsWith('.') ||
-          entry.name === 'dist' ||
-          entry.name === 'build'
-        ) {
-          continue
-        }
-
-        await search(fullPath)
-      }
-    } catch {
-      // Ignore permission errors
-    }
-  }
-
-  await search(startPath)
-  return results
-}
-
-/**
- * Find packages in node_modules that match the manifest
- */
-export async function findPackagesForPatches(
-  nodeModulesPath: string,
-  manifest: PatchManifest,
-): Promise<Map<string, PackageLocation>> {
-  const packages = new Map<string, PackageLocation>()
-
-  try {
-    const entries = await fs.readdir(nodeModulesPath, { withFileTypes: true })
-
-    for (const entry of entries) {
-      // Allow both directories and symlinks (pnpm uses symlinks)
-      if (!entry.isDirectory() && !entry.isSymbolicLink()) continue
-
-      const isScoped = entry.name.startsWith('@')
-      const dirPath = path.join(nodeModulesPath, entry.name)
-
-      if (isScoped) {
-        // Handle scoped packages
-        const scopedEntries = await fs.readdir(dirPath, { withFileTypes: true })
-        for (const scopedEntry of scopedEntries) {
-          // Allow both directories and symlinks (pnpm uses symlinks)
-          if (!scopedEntry.isDirectory() && !scopedEntry.isSymbolicLink()) continue
-
-          const pkgPath = path.join(dirPath, scopedEntry.name)
-          await checkPackage(pkgPath, manifest, packages)
-        }
-      } else {
-        // Handle non-scoped packages
-        await checkPackage(dirPath, manifest, packages)
-      }
-    }
-  } catch {
-    // Ignore errors reading node_modules
-  }
-
-  return packages
-}
-
-async function checkPackage(
-  pkgPath: string,
-  manifest: PatchManifest,
-  packages: Map<string, PackageLocation>,
-): Promise<void> {
-  try {
-    const pkgJsonPath = path.join(pkgPath, 'package.json')
-    const pkgJsonContent = await fs.readFile(pkgJsonPath, 'utf-8')
-    const pkgJson = JSON.parse(pkgJsonContent)
-
-    if (!pkgJson.name || !pkgJson.version) return
-
-    // Check if this package has a patch
-    const purl = `pkg:npm/${pkgJson.name}@${pkgJson.version}`
-    if (manifest.patches[purl]) {
-      packages.set(purl, {
-        name: pkgJson.name,
-        version: pkgJson.version,
-        path: pkgPath,
-      })
-    }
-  } catch {
-    // Ignore invalid package.json
-  }
-}
diff --git a/src/patch/file-hash.ts b/src/patch/file-hash.ts
deleted file mode 100644
index bae6a85..0000000
--- a/src/patch/file-hash.ts
+++ /dev/null
@@ -1,22 +0,0 @@
-import * as fs from 'fs'
-import * as fsp from 'fs/promises'
-import { computeGitSHA256FromChunks } from '../hash/git-sha256.js'
-
-/**
- * Compute Git-compatible SHA256 hash of file contents using streaming
- */
-export async function computeFileGitSHA256(filepath: string): Promise<string> {
-  // Get file size first
-  const stats = await fsp.stat(filepath)
-  const fileSize = stats.size
-
-  // Create async iterable from read stream
-  async function* readFileChunks() {
-    const stream = fs.createReadStream(filepath)
-    for await (const chunk of stream) {
-      yield chunk as Buffer
-    }
-  }
-
-  return computeGitSHA256FromChunks(fileSize, readFileChunks())
-}
diff --git a/src/types.ts b/src/types.ts
deleted file mode 100644
index f7ccfa8..0000000
--- a/src/types.ts
+++ /dev/null
@@ -1,20 +0,0 @@
-export interface PatchInfo {
-  packageName: string
-  version: string
-  patchPath: string
-  description?: string
-}
-
-export interface ApplyOptions {
-  dryRun: boolean
-  verbose: boolean
-  force: boolean
-}
-
-export interface PatchResult {
-  success: boolean
-  packageName: string
-  version: string
-  error?: string
-  filesModified?: string[]
-}
diff --git a/src/utils.ts b/src/utils.ts
deleted file mode 100644
index 872a82e..0000000
--- a/src/utils.ts
+++ /dev/null
@@ -1,23 +0,0 @@
-import type { PatchResult } from './types.js'
-
-export function formatPatchResult(result: PatchResult): string {
-  if (result.success) {
-    let message = `✓ Successfully patched ${result.packageName}@${result.version}`
-    if (result.filesModified && result.filesModified.length > 0) {
-      message += `\n  Modified files: ${result.filesModified.join(', ')}`
-    }
-    return message
-  } else {
-    return `✗ Failed to patch ${result.packageName}@${result.version}: ${result.error || 'Unknown error'}`
-  }
-}
-
-export function log(message: string, verbose: boolean = false): void {
-  if (verbose) {
-    console.log(`[socket-patch] ${message}`)
-  }
-}
-
-export function error(message: string): void {
-  console.error(`[socket-patch] ERROR: ${message}`)
-}
diff --git a/src/utils/api-client.ts b/src/utils/api-client.ts
deleted file mode 100644
index 3639929..0000000
--- a/src/utils/api-client.ts
+++ /dev/null
@@ -1,120 +0,0 @@
-import * as https from 'node:https'
-import * as http from 'node:http'
-
-export interface PatchResponse {
-  uuid: string
-  purl: string
-  publishedAt: string
-  files: Record<
-    string,
-    {
-      beforeHash?: string
-      afterHash?: string
-      socketBlob?: string
-      blobContent?: string
-    }
-  >
-  vulnerabilities: Record<
-    string,
-    {
-      cves: string[]
-      summary: string
-      severity: string
-      description: string
-    }
-  >
-  description: string
-  license: string
-  tier: 'free' | 'paid'
-}
-
-export interface APIClientOptions {
-  apiUrl: string
-  apiToken: string
-}
-
-export class APIClient {
-  private readonly apiUrl: string
-  private readonly apiToken: string
-
-  constructor(options: APIClientOptions) {
-    this.apiUrl = options.apiUrl.replace(/\/$/, '') // Remove trailing slash
-    this.apiToken = options.apiToken
-  }
-
-  async fetchPatch(
-    orgSlug: string,
-    uuid: string,
-  ): Promise<PatchResponse | null> {
-    const url = `${this.apiUrl}/v0/orgs/${orgSlug}/patches/view/${uuid}`
-
-    return new Promise((resolve, reject) => {
-      const urlObj = new URL(url)
-      const isHttps = urlObj.protocol === 'https:'
-      const httpModule = isHttps ? https : http
-
-      const options: https.RequestOptions = {
-        method: 'GET',
-        headers: {
-          Authorization: `Bearer ${this.apiToken}`,
-          Accept: 'application/json',
-        },
-      }
-
-      const req = httpModule.request(urlObj, options, res => {
-        let data = ''
-
-        res.on('data', chunk => {
-          data += chunk
-        })
-
-        res.on('end', () => {
-          if (res.statusCode === 200) {
-            try {
-              const parsed = JSON.parse(data)
-              resolve(parsed)
-            } catch (err) {
-              reject(new Error(`Failed to parse response: ${err}`))
-            }
-          } else if (res.statusCode === 404) {
-            resolve(null)
-          } else if (res.statusCode === 401) {
-            reject(new Error('Unauthorized: Invalid API token'))
-          } else if (res.statusCode === 403) {
-            reject(
-              new Error(
-                'Forbidden: Access denied. This may be a paid patch or you may not have access to this organization.',
-              ),
-            )
-          } else if (res.statusCode === 429) {
-            reject(new Error('Rate limit exceeded. Please try again later.'))
-          } else {
-            reject(
-              new Error(`API request failed with status ${res.statusCode}: ${data}`),
-            )
-          }
-        })
-      })
-
-      req.on('error', err => {
-        reject(new Error(`Network error: ${err.message}`))
-      })
-
-      req.end()
-    })
-  }
-}
-
-export function getAPIClientFromEnv(): APIClient {
-  const apiUrl =
-    process.env.SOCKET_API_URL || 'https://api.socket.dev'
-  const apiToken = process.env.SOCKET_API_TOKEN
-
-  if (!apiToken) {
-    throw new Error(
-      'SOCKET_API_TOKEN environment variable is required. Please set it to your Socket API token.',
-    )
-  }
-
-  return new APIClient({ apiUrl, apiToken })
-}
diff --git a/src/utils/cleanup-blobs.ts b/src/utils/cleanup-blobs.ts
deleted file mode 100644
index 0998034..0000000
--- a/src/utils/cleanup-blobs.ts
+++ /dev/null
@@ -1,120 +0,0 @@
-import * as fs from 'fs/promises'
-import * as path from 'path'
-import type { PatchManifest } from '../schema/manifest-schema.js'
-import { getReferencedBlobs } from '../manifest/operations.js'
-
-export interface CleanupResult {
-  blobsChecked: number
-  blobsRemoved: number
-  bytesFreed: number
-  removedBlobs: string[]
-}
-
-/**
- * Cleans up unused blob files from the .socket/blobs directory.
- * Analyzes the manifest to determine which blobs are still in use,
- * then removes any blob files that are not referenced.
- *
- * @param manifest - The patch manifest containing all active patches
- * @param blobsDir - Path to the .socket/blobs directory
- * @param dryRun - If true, only reports what would be deleted without actually deleting
- * @returns Statistics about the cleanup operation
- */
-export async function cleanupUnusedBlobs(
-  manifest: PatchManifest,
-  blobsDir: string,
-  dryRun: boolean = false,
-): Promise<CleanupResult> {
-  // Collect all blob hashes that are currently in use
-  const usedBlobs = getReferencedBlobs(manifest)
-
-  // Check if blobs directory exists
-  try {
-    await fs.access(blobsDir)
-  } catch {
-    // Blobs directory doesn't exist, nothing to clean up
-    return {
-      blobsChecked: 0,
-      blobsRemoved: 0,
-      bytesFreed: 0,
-      removedBlobs: [],
-    }
-  }
-
-  // Read all files in the blobs directory
-  const blobFiles = await fs.readdir(blobsDir)
-
-  const result: CleanupResult = {
-    blobsChecked: blobFiles.length,
-    blobsRemoved: 0,
-    bytesFreed: 0,
-    removedBlobs: [],
-  }
-
-  // Check each blob file
-  for (const blobFile of blobFiles) {
-    // Skip hidden files and directories
-    if (blobFile.startsWith('.')) {
-      continue
-    }
-
-    const blobPath = path.join(blobsDir, blobFile)
-
-    // Check if it's a file (not a directory)
-    const stats = await fs.stat(blobPath)
-    if (!stats.isFile()) {
-      continue
-    }
-
-    // If this blob is not in use, remove it
-    if (!usedBlobs.has(blobFile)) {
-      result.blobsRemoved++
-      result.bytesFreed += stats.size
-      result.removedBlobs.push(blobFile)
-
-      if (!dryRun) {
-        await fs.unlink(blobPath)
-      }
-    }
-  }
-
-  return result
-}
-
-/**
- * Formats the cleanup result for human-readable output
- */
-export function formatCleanupResult(result: CleanupResult, dryRun: boolean): string {
-  if (result.blobsChecked === 0) {
-    return 'No blobs directory found, nothing to clean up.'
-  }
-
-  if (result.blobsRemoved === 0) {
-    return `Checked ${result.blobsChecked} blob(s), all are in use.`
-  }
-
-  const action = dryRun ? 'Would remove' : 'Removed'
-  const bytesFormatted = formatBytes(result.bytesFreed)
-
-  let output = `${action} ${result.blobsRemoved} unused blob(s) (${bytesFormatted} freed)`
-
-  if (dryRun && result.removedBlobs.length > 0) {
-    output += '\nUnused blobs:'
-    for (const blob of result.removedBlobs) {
-      output += `\n  - ${blob}`
-    }
-  }
-
-  return output
-}
-
-/**
- * Formats bytes into a human-readable string
- */
-function formatBytes(bytes: number): string {
-  if (bytes === 0) return '0 B'
-  if (bytes < 1024) return `${bytes} B`
-  if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(2)} KB`
-  if (bytes < 1024 * 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(2)} MB`
-  return `${(bytes / (1024 * 1024 * 1024)).toFixed(2)} GB`
-}
diff --git a/tsconfig.json b/tsconfig.json
deleted file mode 100644
index 28d5490..0000000
--- a/tsconfig.json
+++ /dev/null
@@ -1,27 +0,0 @@
-{
-  "compilerOptions": {
-    "target": "ES2022",
-    "module": "CommonJS",
-    "lib": ["ES2022"],
-    "moduleResolution": "node",
-    "outDir": "./dist",
-    "rootDir": "./src",
-    "strict": true,
-    "esModuleInterop": true,
-    "skipLibCheck": true,
-    "forceConsistentCasingInFileNames": true,
-    "resolveJsonModule": true,
-    "declaration": true,
-    "declarationMap": true,
-    "sourceMap": true,
-    "noUnusedLocals": true,
-    "noUnusedParameters": true,
-    "noImplicitReturns": true,
-    "noFallthroughCasesInSwitch": true,
-    "allowSyntheticDefaultImports": true,
-    "composite": true,
-    "tsBuildInfoFile": "./dist/tsconfig.tsbuildinfo"
-  },
-  "include": ["src/**/*"],
-  "exclude": ["node_modules", "dist"]
-}