diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml deleted file mode 100644 index 7a96def1..00000000 --- a/.github/workflows/go.yml +++ /dev/null @@ -1,513 +0,0 @@ -name: AAARRR - -on: - push: - branches: [ n2 ] - paths: - - '**/*.go' - pull_request: - branches: [ n2 ] - paths: - - '**/*.go' - workflow_dispatch: - -jobs: - - build: - name: ๐Ÿงฌ Build - runs-on: ubuntu-latest - permissions: - id-token: write - contents: write # sbom attestation - attestations: write - packages: write # sbom - security-events: write # grype - outputs: - artifact-subjects: ${{ steps.post-build.outputs.subjects }} - sbom-info: ${{ steps.post-build.outputs.sbom-info }} - vcs-ver: ${{ steps.post-build.outputs.vcs-ver }} - publish: ${{ steps.post-build.outputs.publish }} - env: - FOUT: firestack.aar - FOUTDBG: firestack-debug.aar - - steps: - - name: ๐Ÿฅ Checkout code - uses: actions/checkout@v6 - - - name: ๐Ÿผ Setup go1.26+ - uses: actions/setup-go@v6 - with: - go-version: '>=1.26' - check-latest: true - - - name: ๐Ÿ—๏ธ Make - id: make - run: | - # 10 chars of the commit sha - VCSVER="${GITHUB_SHA:0:10}" - # GET triggers a build; api: docs.jitpack.io/api/#builds - curl https://jitpack.io/api/builds/com.github.celzero/firestack/${VCSVER} | jq || true - - # outputs firestack.aar and firestack-debug.aar; also see: "Obj" below - ./make-aar nogo debug - shell: bash - - - name: ๐Ÿงช Test - id: test - if: success() - run: | - go env - # go test -v -race -bench=. -benchtime=100ms ./... - echo "::notice::success" - - # ref: make-aar for output file names - # for provenance, better to upload as soon as possible - # github.com/actions/upload-artifact - - name: ๐Ÿš€ Upload - if: ${{ steps.test.outcome == 'success' }} - uses: actions/upload-artifact@v6 - with: - name: firestack-aar-${{ github.sha }} # must be unique - path: | - firestack.aar - firestack-debug.aar - build/intra/tun2socks-sources.jar - retention-days: 52 # 0-90; 90 is max - if-no-files-found: error # error, warn (default), ignore - compression-level: 9 # 0-9; 9 is max - - # github.com/microsoft/sbom-tool - # github.com/microsoft/sbom-tool/blob/f5f65011f2/docs/sbom-tool-arguments.md - - name: ๐Ÿงพ SBOM - id: sbom-gen - if: always() - run: | - # docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/exporting-a-software-bill-of-materials-for-your-repository#generating-a-software-bill-of-materials-from-github-actions - # github.com/marketplace/actions/spdx-dependency-submission-action - curl -Lo $RUNNER_TEMP/sbom-tool https://github.com/microsoft/sbom-tool/releases/latest/download/sbom-tool-linux-x64 - chmod +x $RUNNER_TEMP/sbom-tool - $RUNNER_TEMP/sbom-tool generate -b . -bc . -pn ${{ github.repository }} -pv 1.0.0 -ps Celzero -nsb https://sbom.rethinkdns.com/app -V Verbose - - - name: ๐Ÿ“ค SBOM upload - if: ${{ steps.sbom-gen.outcome == 'success' }} - id: sbom-upload - continue-on-error: true - uses: actions/upload-artifact@v6 - with: - name: "firestack-sbom-${{ github.sha }}" - path: _manifest/spdx_2.2 - compression-level: 9 - retention-days: 3 - - - name: ๐Ÿ“ SBOM submission - if: success() - continue-on-error: true - uses: advanced-security/spdx-dependency-submission-action@v0.1.1 - with: - filePath: "_manifest/spdx_2.2/" - - - name: ๐Ÿ’ฟ Obj - if: ${{ steps.make.outcome == 'success' }} - run: | - wget --tries=2 --waitretry=3 --no-dns-cache https://github.com/Zxilly/go-size-analyzer/releases/download/v1.0.8/go-size-analyzer_1.0.8_linux_amd64.deb -O gsa.deb - sudo dpkg -i gsa.deb - # s/tun2socks*.aar/firestack*.aar; see: make-aar - # - # Archive: firestack-debug.aar - # inflating: AndroidManifest.xml - # inflating: proguard.txt - # inflating: classes.jar - # inflating: jni/armeabi-v7a/libgojni.so - # inflating: jni/arm64-v8a/libgojni.so - # inflating: jni/x86/libgojni.so - # inflating: jni/x86_64/libgojni.so - # inflating: R.txt - # creating: res/ - # /usr/bin/jar - unzip firestack-debug.aar - which jar && jar tf ./classes.jar - gsa jni/arm64-v8a/*.so -f text --verbose - - #pip install sqlelf - - #sqlelf jni/arm64-v8a/libgojni.so --sql \ - # "SELECT mnemonic, COUNT(*) from elf_instructions GROUP BY mnemonic ORDER BY 2 DESC LIMIT 20" - - #sqlelf jni/arm64-v8a/libgojni.so --sql \ - # "SELECT * from elf_headers" - - # determine NEEDED entries - #sqlelf jni/arm64-v8a/libgojni.so --sql \ - # "SELECT elf_strings.path, elf_strings.value - # FROM elf_dynamic_entries - # INNER JOIN elf_strings ON elf_dynamic_entries.value = elf_strings.offset - # WHERE elf_dynamic_entries.tag = 'NEEDED'" - - # determine the largest functions - #sqlelf jni/arm64-v8a/libgojni.so --sql \ - # "SELECT name AS function_name, (high_pc - low_pc) AS function_size - # FROM dwarf_dies - # WHERE tag = 'DW_TAG_subprogram' - # ORDER BY function_size DESC - # LIMIT 50;" - readelf -l jni/arm64-v8a/*.so - - # from: cs.android.com/android/platform/superproject/main/+/main:system/extras/tools/check_elf_alignment.sh;drc=97bcb31779;l=87 - RED="\e[31m" - GREEN="\e[32m" - ENDCOLOR="\e[0m" - - unaligned_libs=() - - echo - echo "=== ELF alignment ===" - - matches="$(find jni/ -type f)" - IFS=$'\n' - for match in $matches; do - [[ $(file "${match}") == *"ELF"* ]] || continue - - readelf -l "${match}" - - res="$(objdump -p "${match}" | grep LOAD | awk '{ print $NF }' | head -1)" - if [[ $res =~ 2\*\*(1[4-9]|[2-9][0-9]|[1-9][0-9]{2,}) ]]; then - echo -e "${match}: ${GREEN}ALIGNED${ENDCOLOR} ($res)" - else - echo -e "${match}: ${RED}UNALIGNED${ENDCOLOR} ($res)" - unaligned_libs+=("${match}") - fi - done - - if [ ${#unaligned_libs[@]} -gt 0 ]; then - echo -e "${RED}Found ${#unaligned_libs[@]} unaligned libs (only arm64-v8a/x86_64 libs need to be aligned).${ENDCOLOR}" - elif [ -n "${dir_filename}" ]; then - echo -e "ELF Verification Successful" - fi - echo "=====================" - shell: bash - - - name: ๐Ÿ”ฎ Vet - if: ${{ steps.make.outcome == 'success' }} - run: | - # github.com/actions/setup-go/issues/27 - export PATH=${PATH}:`go env GOPATH`/bin - - # ref: go.dev/blog/gofix - # vet: fails: archive.is/XcDl6 - go vet ./... - - # staticcheck - # go install honnef.co/go/tools/cmd/staticcheck@latest - # staticcheck ./... - - # nil checks - go install go.uber.org/nilaway/cmd/nilaway@latest - nilaway ./... - - # or: github.com/imjasonh/govulncheck-action - go install golang.org/x/vuln/cmd/govulncheck@latest - govulncheck -show verbose -test ./... - - # TODO: store output to compare with subsequent runs - # security.googleblog.com/2023/09/capslock-what-is-your-code-really.html - # github.com/google/capslock/tree/4bb7636/docs - go install github.com/google/capslock/cmd/capslock@latest - # github.com/hasansino/go42/blob/3be871dcfe/.github/workflows/130-security.yaml#L199 - capslock -packages ./... -output v - shell: bash - - # TODO: pin actions deps: github.com/stacklok/frizbee - # github.com/marketplace/actions/anchore-container-scan - - name: ๐ŸŽž Grype report - if: ${{ steps.sbom-gen.outcome == 'success' }} - uses: anchore/scan-action@v7 - id: gr - with: - sbom: "${{ github.workspace }}/_manifest/spdx_2.2/manifest.spdx.json" - fail-build: false - # severity-cutoff: critical - output-format: "sarif" # or "table", "json" etc - - - name: ๐Ÿ“ป Grype to code-scanning - if: success() - continue-on-error: true - uses: github/codeql-action/upload-sarif@v4 - with: - sarif_file: ${{ steps.gr.outputs.sarif }} - - - name: ๐Ÿ›ธ After Make - if: ${{ steps.make.outcome == 'success' && steps.sbom-upload.outcome == 'success' }} - id: post-build - run: | - # grype output - if [[ -n "${GRYPE_SARIF}" && -f "${GRYPE_SARIF}" ]]; then - ls -ltr "${GRYPE_SARIF}" - else - echo "::notice::Grype report missing" - fi - - # 10 chars of the commit sha - VCSVER="${GITHUB_SHA:0:10}" - - subjects='[]' - for key in FOUT FOUTDBG; do - artifact="${!key}" - if [ ! -f "$artifact" ]; then - echo "::error::missing artifact ${artifact}" - exit 12 - fi - digest="sha256:$(sha256sum "$artifact" | awk '{print $1}')" - subjects=$(jq -c --arg name "$artifact" --arg digest "$digest" '. + [{ "name": $name, "digest": $digest }]' <<<"$subjects") - done - - sbom_path="${SBOM_PATH}/${SBOM_FNAME}" - if [ ! -f "$sbom_path" ]; then - echo "::error::missing SBOM at ${sbom_path}" - exit 13 - fi - sbom_digest="sha256:$(sha256sum "$sbom_path" | awk '{print $1}')" - if [ -z "${SBOM_ARTIFACT_ID}" ]; then - echo "::error::missing SBOM artifact id" - exit 14 - fi - sbominfo=$( - jq -c -n \ - --argjson subjects "$subjects" \ - --arg artifactId "$SBOM_ARTIFACT_ID" \ - --arg artifactName "$SBOM_ARTIFACT_NAME" \ - --arg path "$SBOM_FNAME" \ - --arg digest "$sbom_digest" \ - '{subjects:$subjects,artifactId:$artifactId,artifactName:$artifactName,path:$path,digest:$digest}' - ) - - # api: docs.jitpack.io/api/#builds - curl https://jitpack.io/api/builds/com.github.celzero/firestack/${VCSVER} | jq || true - # print sbominfo and subjects - echo "== Subjects ==" - jq . <<<${subjects} - echo "== SBOM Info ==" - jq . <<<${sbominfo} - - printf 'subjects=%s\n' "$subjects" >> "$GITHUB_OUTPUT" - printf 'sbom-info=%s\n' "$sbominfo" >> "$GITHUB_OUTPUT" - printf 'vcs-ver=%s\n' "$VCSVER" >> "$GITHUB_OUTPUT" - printf 'publish=%s\n' "$PUBLISH" >> "$GITHUB_OUTPUT" - shell: bash - env: - SBOM_PATH: _manifest/spdx_2.2/ - SBOM_FNAME: manifest.spdx.json - SBOM_ARTIFACT_ID: ${{ steps.sbom-upload.outputs.artifact-id }} - SBOM_ARTIFACT_NAME: ${{ format('firestack-sbom-{0}', github.sha) }} - GRYPE_SARIF: ${{ steps.gr.outputs.sarif }} - PUBLISH: ${{ github.event_name == 'workflow_dispatch' }} - - attestation: - name: ๐Ÿชช Artifact attestations - needs: build - if: ${{ needs.build.result == 'success' && needs.build.outputs.artifact-subjects != '' && needs.build.outputs.publish == 'true' }} - uses: ./.github/workflows/provenance.yml - with: - subjects: ${{ needs.build.outputs.artifact-subjects }} - sbom-info: ${{ needs.build.outputs.sbom-info }} - - publish: - name: ๐Ÿšš Publish - needs: - - build - - attestation - if: ${{ needs.build.result == 'success' && needs.attestation.result == 'success' && needs.build.outputs.publish == 'true' }} - uses: ./.github/workflows/publish.yml - with: - run_id: ${{ github.run_id }} - vcsver: ${{ needs.build.outputs.vcs-ver }} - artifact_subjects: ${{ needs.build.outputs.artifact-subjects }} - sbom_info: ${{ needs.build.outputs.sbom-info }} - secrets: inherit - - osv: - name: ๐Ÿ›ก๏ธ OSV scanner - runs-on: ubuntu-latest - permissions: - contents: read - packages: read - actions: read - security-events: write - - # github.com/hasansino/go42/blob/3be871dcfe/.github/workflows/140-security-extra.yaml#L102 - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - - - name: ๐Ÿ” Scan - id: osv-scan - continue-on-error: true - uses: google/osv-scanner-action/osv-scanner-action@v2.2.2 - with: - scan-args: --output=osv-scanner-results.json --format=json --all-vulns --recursive ./ - - - name: ๐Ÿงพ Report - if: ${{ steps.osv-scan.outcome != 'skipped' }} - continue-on-error: true - uses: google/osv-scanner-action/osv-reporter-action@v2.2.2 - with: - scan-args: |- - --output=osv-scanner-results.sarif - --new=osv-scanner-results.json - --gh-annotations=true - --fail-on-vuln=true - --all-vulns - - - name: โš–๏ธ Licenses - continue-on-error: true - run: | - # installing osv-scanner is ... expensive - # github.com/google/osv-scanner/blob/main/README.md - if ! command -v osv-scanner >/dev/null 2>&1; then - go install github.com/google/osv-scanner/v2/cmd/osv-scanner@latest - fi - osv-scanner --licenses . - - # also: github.com/oss-review-toolkit/ort - - # github.com/google/go-licenses?tab=readme-ov-file#build-tags - # go install github.com/google/go-licenses/v2@latest - # github.com/google/licenseclassifier/blob/e6a9bb99b5/license_type.go#L28 - # go-licenses check ./... --include_tests --allowed_licenses=notice,permissive,reciprocal,unencumbered - - - name: ๐Ÿš€ Upload - if: ${{ steps.osv-scan.outcome != 'skipped' }} - uses: actions/upload-artifact@v6 - with: - name: "osv-scanner-results-${{ github.sha }}" - path: osv-scanner-results.sarif - retention-days: 72 - compression-level: 9 - - - name: ๐Ÿ“ก OSV to code-scanning - if: ${{ steps.osv-scan.outcome != 'skipped' }} - continue-on-error: true - uses: github/codeql-action/upload-sarif@v4 - with: - sarif_file: osv-scanner-results.sarif - - # github.com/oss-review-toolkit/ort-ci-github-action - # github.com/oss-review-toolkit/ort - ort: - name: ๐ŸŒˆ ORT - permissions: - contents: read - packages: read - actions: read - security-events: write - runs-on: ubuntu-latest - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - - name: ๐Ÿญ Run - uses: oss-review-toolkit/ort-ci-github-action@main - with: - run: > - cache-dependencies, - metadata-labels, - analyzer, - advisor, - reporter, - upload-results, - upload-evaluation-result - - checker: - name: ๐Ÿ—ณ๏ธ Security checker - runs-on: ubuntu-latest - permissions: - security-events: write - id-token: write - env: - GO111MODULE: on - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - - name: ๐Ÿ“€ Gosec Scanner - uses: securego/gosec@master - with: - # github.com/securego/gosec/issues/1219 - # we let the report trigger content trigger a failure using the GitHub Security features. - args: '-no-fail -fmt sarif -out results.sarif ./...' - - name: ๐Ÿ“ก Upload to code-scanning - uses: github/codeql-action/upload-sarif@v4 - with: - sarif_file: results.sarif - - # from: github.com/golangci/golangci-lint-action - golangci-lint: - name: ๐Ÿงญ Lint - runs-on: ubuntu-latest - permissions: - # Required: allow read access to the content for analysis. - contents: read - # Optional: allow read access to pull request. Use with `only-new-issues` option. - pull-requests: read - # Optional: Allow write access to checks to allow the action to annotate code in the PR. - checks: write - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - - name: ๐Ÿผ Set up Go - uses: actions/setup-go@v6 - with: - go-version: '>=1.26' - cache: false - # github.com/tailscale/tailscale/blob/93324cc7b/.github/workflows/depaware.yml - # consolidated in: github.com/tailscale/tailscale/commit/4022796484 - - name: ๐Ÿ“ฆ Depaware - run: | - # Error: ../../../go/pkg/mod/github.com/tailscale/depaware@v0.0.0-20251001183927-9c2ad255ef3f/depaware/depaware.go:238:17: undefined: imports.VendorlessPath - # Error: Process completed with exit code 1. - # go run github.com/tailscale/depaware github.com/celzero/firestack/intra - # go run github.com/tailscale/depaware github.com/celzero/firestack/tunnel - - name: ๐Ÿ… Lint - uses: golangci/golangci-lint-action@v9.2.0 - with: - args: --config=.golangci.yml --issues-exit-code=0 - - name: ๐Ÿ“ก Staticheck - uses: dominikh/staticcheck-action@v1.4.0 - with: - version: "latest" - install-go: false - - codeql: - name: ๐Ÿงฉ CodeQL & Zizmor - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - with: - persist-credentials: false - - - name: ๐Ÿ› ๏ธ CodeQL Init - uses: github/codeql-action/init@v4 - with: - # github.com/github/codeql-action/blob/ecec1f88769/init/action.yml#L137 - debug: true - languages: "go" - # docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs - queries: security-extended,security-and-quality - - # also uploads to code-scanning - # github.com/github/codeql-action/blob/ecec1f88769/analyze/action.yml#L92 - # sarif_file is at steps.cqla.outputs.sarif-output - - name: ๐Ÿงฉ CodeQL Analysis - uses: github/codeql-action/analyze@v4 - with: - category: "/language:go" - - - name: ๐Ÿ’ค Zizmor - if: always() - uses: zizmorcore/zizmor-action@e639db99335bc9038abc0e066dfcd72e23d26fb4 # v0.3.0 - with: - advanced-security: true diff --git a/.github/workflows/provenance.yml b/.github/workflows/provenance.yml deleted file mode 100644 index cb4c06e1..00000000 --- a/.github/workflows/provenance.yml +++ /dev/null @@ -1,91 +0,0 @@ -name: Build Provenance - -on: - workflow_call: - inputs: - subjects: - description: | - JSON array of {"name": string, "digest": "sha256:"} pairs that - represent artifacts requiring provenance attestations. See - https://github.blog/enterprise-software/devsecops/enhance-build-security-and-reach-slsa-level-3-with-github-artifact-attestations/ - for background on why reusable workflows are required at SLSA Level 3. - required: true - type: string - sbom-info: - description: | - Optional JSON blob containing SBOM attestation metadata in the form - {"subjects": [...], "artifact": string, "path": string, "digest": string}. - required: false - type: string - -jobs: - - provenance: - name: ๐Ÿ” Build attestation - if: ${{ inputs.subjects != '' && inputs.subjects != '[]' }} - runs-on: ubuntu-latest - permissions: - id-token: write - attestations: write - contents: read - strategy: - matrix: - subject: ${{ fromJson(inputs.subjects) }} - - steps: - # attest provenance for uploaded artifacts - # github.com/actions/attest-build-provenance?tab=readme-ov-file#integration-with-actionsupload-artifact - # docs.github.com/en/actions/how-tos/secure-your-work/use-artifact-attestations/use-artifact-attestations#generating-build-provenance-for-binaries - # also: buildsec.github.io/frsca/docs/slsa/frsca-slsa/ - - name: โœ… Attest provenance - uses: actions/attest-build-provenance@v3 - with: - subject-name: ${{ matrix.subject.name }} - subject-digest: ${{ matrix.subject.digest }} - - sbom: - name: ๐Ÿ“ฆ SBOM attestation - if: ${{ inputs.sbom-info != '' }} - runs-on: ubuntu-latest - permissions: - id-token: write - attestations: write - contents: read - strategy: - matrix: - subject: ${{ fromJson(inputs.sbom-info).subjects }} - steps: - - name: โฌ‡๏ธ Download SBOM - uses: actions/download-artifact@v4 - with: - artifact-ids: ${{ fromJson(inputs.sbom-info).artifactId }} - path: sbom - - - name: ๐Ÿ” Verify SBOM - id: oksbom - shell: bash - env: - SBOM_PATH: ${{ format('sbom/{0}/{1}', fromJson(inputs.sbom-info).artifactName, fromJson(inputs.sbom-info).path) }} - SBOM_DIGEST: ${{ fromJson(inputs.sbom-info).digest }} - run: | - set -euo pipefail - ls -ltr sbom - - file="${SBOM_PATH}" - if [ ! -f "$file" ]; then - echo "missing SBOM file: $file" >&2 - exit 1 - fi - - digest="${SBOM_DIGEST#sha256:}" - echo "${digest} ${file}" | sha256sum -c - - printf 'sbom-path=%s\n' "${SBOM_PATH}" >> "$GITHUB_OUTPUT" - - # github.com/actions/attest-sbom - # docs.github.com/en/actions/how-tos/secure-your-work/use-artifact-attestations/use-artifact-attestations#generating-an-sbom-attestation-for-binaries - - name: ๐Ÿงถ Attest SBOM - uses: actions/attest-sbom@v3 - with: - subject-name: ${{ matrix.subject.name }} - subject-digest: ${{ matrix.subject.digest }} - sbom-path: ${{ steps.oksbom.outputs.sbom-path }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index c2685cd8..00000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,338 +0,0 @@ -name: Publisher - -on: - workflow_dispatch: - inputs: - build_run_id: - description: "Workflow URL or Run ID that produced build artifacts" - required: true - type: string - verify_attestation: - description: "Verify build provenance attestations before publishing" - required: false - type: boolean - default: false - workflow_call: - inputs: - # docs.github.com/en/actions/reference/workflows-and-actions/contexts - run_id: - description: "Workflow run id that produced signed artifacts" - required: true - type: string - vcsver: - description: "Short git version" - required: false - type: string - artifact_subjects: - description: "JSON array of artifact subjects with sha256 digests" - required: false - type: string - sbom_info: - description: "SBOM info JSON blob with subjects and digest" - required: false - type: string - -jobs: - - publish: - name: ๐Ÿšš Publish - runs-on: ubuntu-latest - env: - # docs.github.com/en/actions/reference/workflows-and-actions/contexts#github-context - GROUP_GITHUB: ${{ format('com.github.{0}', github.repository_owner) }} - GROUP_OSSRH: com.celzero - # project artifactId; see: pom.xml - ARTIFACT: firestack - REPO_GITHUB: github - # central.sonatype.org/pages/ossrh-eol - # or "central" - REPO_OSSRH: ossrh - # artefact type - PACK: aar - # final out from make-aar - FOUT: firestack.aar - FOUTDBG: firestack-debug.aar - # artifact classifier; full unused - CLASSFULL: full - CLASSDBG: debug - # artifact bytecode sources - SOURCES: build/intra/tun2socks-sources.jar - # POM for Maven Central - POM_OSSRH: ossrhpom.xml - DIST_DIR: dist - RUN_ID: ${{ inputs.run_id || inputs.build_run_id }} - VCSVER_INPUT: ${{ inputs.vcsver }} - # workflow input constants - ARTIFACT_SUBJECTS: ${{ inputs.artifact_subjects }} - SBOM_INFO: ${{ inputs.sbom_info }} - ARTIFACT_PATTERN: "firestack-aar-*" - SBOM_PATTERN: "firestack-sbom-*" - ARTIFACT_PREFIX: "firestack-aar-" - SBOM_PREFIX: "firestack-sbom-" - SBOM_MANIFEST: "manifest.spdx.json" - SBOM_PREDICATE: "https://spdx.dev/Document/v2.2" - - permissions: - contents: read - actions: read - attestations: read - packages: write - - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - with: - persist-credentials: false - - - name: ๐Ÿ“€ Metadata - id: runmeta - env: - RUN_ID_OG: ${{ env.RUN_ID }} - RUN_ID: ${{ env.RUN_ID }} - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - # Allow RUN_ID to be passed as a full GitHub Actions run URL. - # Example: https://github.com/celzero/firestack/actions/runs/20923345284 - # Or: https://github.com/celzero/firestack/actions/runs/23057719555/job/66975665220 - RUN_ID="${RUN_ID%%\?*}" # strip query - RUN_ID="${RUN_ID%%\#*}" # strip fragment - RUN_ID="${RUN_ID%/}" # strip trailing slash - case "$RUN_ID" in - *github.com/*/actions/runs/*) - RUN_ID="${RUN_ID#*actions/runs/}" - RUN_ID="${RUN_ID%%/*}" - ;; - esac - echo "::notice::Using Run ID: $RUN_ID (in: $RUN_ID_OG)" - # Export normalized run id for later steps. - printf 'run_id=%s\n' "$RUN_ID" >> "$GITHUB_OUTPUT" - if [ -n "${VCSVER_INPUT:-}" ]; then - printf 'vcsver=%s\n' "${VCSVER_INPUT}" >> "$GITHUB_OUTPUT" - exit 0 - fi - info=$(gh run view "$RUN_ID" --json headSha,headBranch,event,displayTitle) - echo "$info" | jq - sha=$(echo "$info" | jq -r '.headSha') - if [ -z "$sha" ] || [ "$sha" = "null" ]; then - echo "::error::unable to resolve head sha for run $RUN_ID" >&2 - exit 11 - fi - # git version (short commit sha) - printf 'sha=%s\n' "$sha" >> "$GITHUB_OUTPUT" - printf 'vcsver=%s\n' "${sha:0:10}" >> "$GITHUB_OUTPUT" - - - name: โฌ‡๏ธ Download artifacts - id: dlaar - uses: actions/download-artifact@v7 - with: - pattern: ${{ env.ARTIFACT_PATTERN }} - run-id: ${{ steps.runmeta.outputs.run_id }} - github-token: ${{ github.token }} - path: ${{ env.DIST_DIR }} - - - name: โฌ‡๏ธ Download SBOM artifact - id: dlsbom - uses: actions/download-artifact@v7 - with: - pattern: ${{ env.SBOM_PATTERN }} - run-id: ${{ steps.runmeta.outputs.run_id }} - github-token: ${{ github.token }} - path: ${{ env.DIST_DIR }} - - - name: ๐Ÿ” Verify build provenance - if: ${{ inputs.verify_attestation == true || github.event.event_name == 'workflow_call' }} - env: - REPO: ${{ github.repository }} - ART_DIR: ${{ steps.dlaar.outputs.download-path }} - GH_TOKEN: ${{ github.token }} - SHA: ${{ steps.runmeta.outputs.sha }} - run: | - set -xeuo pipefail - ls -ltr "${ART_DIR}/" - # need to go one dir further for download-artifact v4 not v7 - # ART_DIR="${ART_DIR}/${ARTIFACT_PREFIX}${SHA}" - # ls -ltr "${ART_DIR}/" - for file in "$ART_DIR/${FOUT}" "$ART_DIR/${FOUTDBG}"; do - if [ ! -f "$file" ]; then - echo "::error::missing artifact $file" >&2 - exit 12 - fi - gh attestation verify "$file" -R "$REPO" - done - - if [ -n "${ARTIFACT_SUBJECTS:-}" ]; then - jq -c '.[]' <<<"${ARTIFACT_SUBJECTS}" | while read -r subject; do - name=$(jq -r '.name' <<<"$subject") - digest=$(jq -r '.digest' <<<"$subject") - file="${ART_DIR}/${name##*/}" - if [ ! -f "$file" ]; then - echo "::error::missing artifact $file for digest check" >&2 - exit 13 - fi - want=${digest#sha256:} - got=$(sha256sum "$file" | awk '{print $1}') - if [ "$got" != "$want" ]; then - echo "::error::digest mismatch for $file (got $got, want $want)" >&2 - exit 14 - fi - done - fi - - - name: ๐Ÿ” Verify SBOM attestation - if: ${{ inputs.verify_attestation == true || github.event.event_name == 'workflow_call' }} - env: - REPO: ${{ github.repository }} - ART_DIR: ${{ steps.dlaar.outputs.download-path }} - SBOM_DIR: ${{ steps.dlsbom.outputs.download-path }} - GH_TOKEN: ${{ github.token }} - SHA: ${{ steps.runmeta.outputs.sha }} - run: | - # andrewlock.net/creating-sbom-attestations-in-github-actions/ - set -xeuo pipefail - ls -ltr "${SBOM_DIR}/" - # need to go one dir further for download-artifact v4 not v7 - # SBOM_DIR="${SBOM_DIR}/${SBOM_PREFIX}${SHA}" - # ls -ltr "${SBOM_DIR}/" - if [ -n "${SBOM_INFO:-}" ]; then - name=$(jq -r '.path' <<<"${SBOM_INFO}") - sbom_file="$SBOM_DIR/$(jq -r '.artifactName' <<<"${SBOM_INFO}")/${name}" - digest=$(jq -r '.digest' <<<"${SBOM_INFO}") - - # github.com/celzero/firestack/blob/86af89da10abe/.github/workflows/go.yml#L398 - jq -c '.subjects[]' <<<"$SBOM_INFO" | while read -r subject; do - name=$(jq -r '.name' <<<"$subject") - file="${ART_DIR}/${name##*/}" - if [ ! -f "$file" ]; then - echo "::error::missing SBOM subject artifact $file" - exit 14 - fi - gh attestation verify "$file" -R "$REPO" --predicate-type "$predicate" - echo "Verified SBOM subject artifact $file" - done - else - sbom_file=$(find "${SBOM_DIR}" -name "${SBOM_MANIFEST}" -print -quit) - digest=$(cat < $(find "${SBOM_DIR}" -name "${SBOM_MANIFEST}.sha256" -print -quit)) - - # iterate in ART_DIR for artifacts then verify sbom attestations - for file in "$ART_DIR/${FOUT}" "$ART_DIR/${FOUTDBG}"; do - if [ ! -f "$file" ]; then - echo "::error::missing artifact $file for SBOM attestation check" >&2 - exit 13 - fi - gh attestation verify "$file" -R "$REPO" --predicate-type "${SBOM_PREDICATE}" - echo "Verified SBOM attestation for artifact $file" - done - fi - - if [ -z "$sbom_file" ]; then - echo "::error::SBOM file not found in ${SBOM_DIR}/" >&2 - exit 15 - fi - - if [ -n "$digest" ] && [ "$digest" != "null" ]; then - want=${digest#sha256:} - got=$(sha256sum "$sbom_file" | awk '{print $1}') - if [ "$got" != "$want" ]; then - echo "::error::SBOM digest mismatch (got $got, want $want)" >&2 - exit 16 - fi - else - echo "No SBOM digest; skipping digest verification" >&2 - fi - - - name: ๐Ÿท๏ธ Setup for GitHub Packages - uses: actions/setup-java@v5 - with: - java-version: '17' - distribution: 'temurin' - - # docs.github.com/en/actions/tutorials/build-and-test-code/java-with-maven - # docs.github.com/en/actions/tutorials/publish-packages/publish-java-packages-with-maven#publishing-packages-to-github-packages - - name: ๐Ÿ˜บ Publish to GitHub Packages - shell: bash - env: - REPOSITORY: ${{ github.repository }} - GITHUB_ACTOR: ${{ github.actor }} - GITHUB_TOKEN: ${{ github.token }} - VCSVER: ${{ steps.runmeta.outputs.vcsver }} - run: | - echo "::notice::Publishing version ${VCSVER} to GitHub Packages" - # uploaded at: - # maven.pkg.github.com/celzero/firestack/com/github/celzero/firestack//firestack-.aar - # github.com/deelaa-marketplace/commons-workflow/blob/637dc111/flows/publish-api.yml#L49 - # github.com/markocto/cf-octopub/blob/bba2de2c/github/script/action.yaml#L118 - # publish both stripped and debug - mvn deploy:deploy-file \ - -DgroupId="${GROUP_GITHUB}" \ - -DartifactId="${ARTIFACT}" \ - -Dversion="$VCSVER" \ - -Dpackaging="${PACK}" \ - -Dfile="${DIST_DIR}/${FOUT}" \ - -Dfiles="${DIST_DIR}/${FOUTDBG}" \ - -Dtypes="${PACK}" \ - -Dclassifiers=${CLASSDBG} \ - -DrepositoryId="${REPO_GITHUB}" \ - -Dsources="${DIST_DIR}/${SOURCES}" \ - -Durl="https://maven.pkg.github.com/${REPOSITORY}" - - # central.sonatype.org/publish/publish-portal-api/#authentication-authorization - # github.com/slsa-framework/slsa-github-generator/blob/4876e96b8268/actions/maven/publish/action.yml#L49 - # docs.github.com/en/actions/tutorials/publish-packages/publish-java-packages-with-maven#publishing-packages-to-the-maven-central-repository-and-github-packages - - name: ๐Ÿ›๏ธ Setup for Maven Central - uses: actions/setup-java@v4 - with: - java-version: '17' - distribution: 'temurin' - server-id: ossrh - server-username: MAVEN_USERNAME - server-password: MAVEN_PASSWORD - gpg-private-key: ${{ secrets.OSSRH_CELZERO_GPG_PRIVATE_KEY }} - gpg-passphrase: ${{ secrets.OSSRH_CELZERO_GPG_PASSPHRASE }} - - - name: ๐Ÿ“ฆ Publish to Maven Central - shell: bash - env: - MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} - MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }} - MAVEN_NS: ${{ secrets.OSSRH_CELZERO_NAMESPACE }} - MAVEN_GPG_PASSPHRASE: ${{ secrets.OSSRH_CELZERO_GPG_PASSPHRASE }} - VCSVER: ${{ steps.runmeta.outputs.vcsver }} - run: | - echo "::notice::Publishing version ${VCSVER} to Maven Central" - mvn -f ${POM_OSSRH} versions:set -DnewVersion=${VCSVER} -DgenerateBackupPoms=false - # central.sonatype.org/publish/publish-portal-ossrh-staging-api/#getting-started-for-maven-api-like-plugins - # github.com/videolan/vlc-android/blob/c393dd0699/buildsystem/maven/deploy-to-mavencentral.sh#L119 - - mvn gpg:sign-and-deploy-file \ - -DgroupId="${GROUP_OSSRH}" \ - -DartifactId="${ARTIFACT}" \ - -Dversion="$VCSVER" \ - -Dpackaging="${PACK}" \ - -Dfile="${DIST_DIR}/${FOUT}" \ - -DrepositoryId="${REPO_OSSRH}" \ - -DpomFile=${POM_OSSRH} \ - -Dgpg.keyname=C3F3F4A160BB2CFFB5528699F19CE6642C40085C \ - -Dsources="${DIST_DIR}/${SOURCES}" \ - -Durl="https://ossrh-staging-api.central.sonatype.com/service/local/staging/deploy/maven2/" - - mvn gpg:sign-and-deploy-file \ - -DgroupId="${GROUP_OSSRH}" \ - -DartifactId="${ARTIFACT}" \ - -Dversion="$VCSVER" \ - -Dpackaging="${PACK}" \ - -Dfile="${DIST_DIR}/${FOUTDBG}" \ - -Dclassifier=${CLASSDBG} \ - -DrepositoryId="${REPO_OSSRH}" \ - -DgeneratePom=false \ - -Dgpg.keyname=C3F3F4A160BB2CFFB5528699F19CE6642C40085C \ - -Durl="https://ossrh-staging-api.central.sonatype.com/service/local/staging/deploy/maven2/" - - # central.sonatype.org/publish/publish-portal-api/#authentication-authorization - tok=$(printf "${MAVEN_USERNAME}:${MAVEN_PASSWORD}" | base64) - - # central.sonatype.org/publish/publish-portal-ossrh-staging-api/#1-modify-your-ci-script - # central.sonatype.org/publish/publish-portal-ossrh-staging-api/#post-to-manualuploaddefaultrepositorynamespace - # auth required for publishing_type=automatic - curl -D - -X POST -H "Authorization: Bearer ${tok}" \ - "https://ossrh-staging-api.central.sonatype.com/manual/upload/defaultRepository/${GROUP_OSSRH}?publishing_type=automatic" \ No newline at end of file diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml deleted file mode 100644 index 634cc4a8..00000000 --- a/.github/workflows/scorecard.yml +++ /dev/null @@ -1,72 +0,0 @@ -name: Scorecard - -on: - # For Branch-Protection check. Only the default branch is supported. See - # github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection - branch_protection_rule: - # To guarantee Maintained check is occasionally updated. See - # github.com/ossf/scorecard/blob/main/docs/checks.md#maintained - schedule: - - cron: '53 21 * * 0' - push: - branches: [ "n2" ] - paths: - - '**/*.go' - -# Declare default permissions as read only. -permissions: read-all - -jobs: - analysis: - name: ๐ŸŽฒ Scorecard analysis - runs-on: ubuntu-latest - permissions: - # Needed to upload the results to code-scanning dashboard. - security-events: write - # Needed to publish results and get a badge (see publish_results below). - id-token: write - # Uncomment the permissions below if installing in a private repository. - # contents: read - # actions: read - - # ref: github.com/ossf/scorecard/blob/main/.github/workflows/scorecard-analysis.yml - steps: - - name: ๐Ÿฅ Checkout - uses: actions/checkout@v6 - with: - persist-credentials: false - - - name: ๐ŸŽ Run analysis - uses: ossf/scorecard-action@v2.4.3 - with: - results_file: results.sarif - results_format: sarif - # (Optional) "write" PAT token. Uncomment the `repo_token` line below if: - # - you want to enable the Branch-Protection check on a *public* repository, or - # - you are installing Scorecard on a *private* repository - # To create the PAT, follow the steps in github.com/ossf/scorecard-action#authentication-with-pat. - # repo_token: ${{ secrets.SCORECARD_TOKEN }} - - # Public repositories: - # - Publish results to OpenSSF REST API for easy access by consumers - # - Allows the repository to include the Scorecard badge. - # - See github.com/ossf/scorecard-action#publishing-results. - # For private repositories: - # - `publish_results` will always be set to `false`, regardless - # of the value entered here. - publish_results: true - - # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF - # format to the repository Actions tab. - - name: ๐ŸŒ๏ธโ€โ™‚๏ธ Upload artifact - uses: actions/upload-artifact@v4 - with: - name: SARIF file - path: results.sarif - retention-days: 21 - - # Upload the results to GitHub's code scanning dashboard. - - name: โ›ณ๏ธ Upload to code-scanning - uses: github/codeql-action/upload-sarif@v4 - with: - sarif_file: results.sarif diff --git a/.github/workflows/xbom.yaml b/.github/workflows/xbom.yaml deleted file mode 100644 index dd53b593..00000000 --- a/.github/workflows/xbom.yaml +++ /dev/null @@ -1,50 +0,0 @@ -name: Cryptography BoM - -on: - workflow_dispatch: - -jobs: - - # github.com/advanced-security/cbom-action - build-matrix: - name: ๐Ÿš Repo analysis - runs-on: ubuntu-latest - outputs: - repositories: ${{ steps.rm.outputs.repositories }} - steps: - - name: ๐Ÿš€ Build analysis matrix - uses: advanced-security/cbom-action/build-matrix@v1 - id: rm - with: - repositoryNameWithOwner: ${{ github.repository }} - analyzeDependencies: true - minimumLanguageBytes: 0 - - run-cbom-action: - name: ๐Ÿ“œ ${{ fromJson(matrix.repository).nameWithOwner }} - ${{ fromJson(matrix.repository).language }} - runs-on: ubuntu-latest - needs: build-matrix - continue-on-error: true - strategy: - fail-fast: false - matrix: - repository: ${{ fromJSON(needs.build-matrix.outputs.repositories) }} - - steps: - - name: ๐Ÿ’ˆ CBOM run - uses: advanced-security/cbom-action/analyze@d5f28cfce2a516c74cae4ebb296a427eb51f62ec # 11 Dec 25 - with: - repositoryNameWithOwner: ${{ fromJson(matrix.repository).nameWithOwner }} - language: ${{ fromJson(matrix.repository).language }} - createCodeQLDatabaseIfRequired: true - uploadToCodeScanning: false - requestGitHubAnalysis: false - queryTimeout: 500 - - cbom-summary: - name: ๐ŸŒ‹ CBOM results - runs-on: ubuntu-latest - needs: run-cbom-action - steps: - - name: ๐Ÿ”… Summarize - uses: advanced-security/cbom-action/workflow-summary@d5f28cfce2a516c74cae4ebb296a427eb51f62ec # 11 Dec 25 diff --git a/.gitignore b/.gitignore index dd8c5e32..66fd13c9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,30 +1,15 @@ -/build -/bin -intra/split/example/example +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib -# General -.DS_Store -.AppleDouble -.LSOverride +# Test binary, built with `go test -c` +*.test -# Icon must end with two \r -Icon +# Output of the go coverage tool, specifically when used with LiteIDE +*.out -# Thumbnails -._* - -# Files that might appear in the root of a volume -.DocumentRevisions-V100 -.fseventsd -.Spotlight-V100 -.TemporaryItems -.Trashes -.VolumeIcon.icns -.com.apple.timemachine.donotpresent - -# Directories potentially created on remote AFP share -.AppleDB -.AppleDesktop -Network Trash Folder -Temporary Items -.apdisk +# Dependency directories (remove the comment below to include it) +# vendor/ diff --git a/.golangci.yml b/.golangci.yml deleted file mode 100644 index 6a3587e3..00000000 --- a/.golangci.yml +++ /dev/null @@ -1,59 +0,0 @@ -version: "2" -# from: github.com/uber-go/guide/blob/621006f/.golangci.yml -run: - # timeout: 5m - modules-download-mode: readonly - issues-exit-code: 0 - tests: false - -output: - show-stats: true - # print-issued-lines: true - # uniq-by-line: true - # sort-results: true - # print-linter-name: true - -linters: - enable: - - errcheck - - errchkjson - - gosec - - gocyclo - - mnd - - loggercheck - - err113 - - revive - - misspell - - govet - - goconst - - staticcheck - settings: - cyclop: - max-complexity: 5 - package-average: 0.5 - exclusions: - presets: - - comments - - common-false-positives - - legacy - - std-error-handling - generated: disable - rules: - - path: '(.+)_test\.go' - linters: - - cyclop - - gocyclo - - misspell - - goconst - -formatters: - enable: - - gci - - gofmt - - gofumpt - - goimports - -issues: - max-issues-per-linter: 0 - max-same-issues: 0 - uniq-by-line: true diff --git a/Makefile b/Makefile deleted file mode 100644 index f4030c59..00000000 --- a/Makefile +++ /dev/null @@ -1,108 +0,0 @@ -BUILDDIR=$(CURDIR)/build -TOOLSDIR=$(CURDIR)/tools -GOBIN=$(CURDIR)/bin -GOMOBILE=$(GOBIN)/gomobile -GOPATCHOVERLAY=$(GOBIN)/go-patch-overlay -IMPORT_PATH=github.com/celzero/firestack -ELECTRON_PATH=$(IMPORT_PATH)/outline/electron -XGO=$(GOBIN)/xgo -COMMIT_ID=$(shell git rev-parse --short HEAD) -DATESTR=$(shell date -u +'%Y%m%d%H%M%S') -XGO_LDFLAGS='-s -w -X main.version=$(COMMIT_ID)' -# github.com/xjasonlyu/tun2socks/blob/bf745d0e0/Makefile#L14 -LDFLAGS_DEBUG='-checklinkname=0 -buildid= -X $(IMPORT_PATH)/intra/core.Date=$(DATESTR) -X $(IMPORT_PATH)/intra/core.Commit=$(COMMIT_ID)' -# checklinkname to override runtime.secureMode; see: core/overreach.go -# github.com/golang/go/issues/69868 -LDFLAGS='-checklinkname=0 -w -s -buildid= -X $(IMPORT_PATH)/intra/core.Date=$(DATESTR) -X $(IMPORT_PATH)/intra/core.Commit=$(COMMIT_ID)' -CGO_LDFLAGS="$(CGO_LDFLAGS) -s -w -Wl,-z,max-page-size=16384" -# build overlay json via recipe -BUILD_OVERLAY=$(BUILDDIR)/overlay.json - -# github.com/golang/mobile/blob/a1d90793fc/cmd/gomobile/bind.go#L36 -GOBIND=bind -trimpath -v -x -a -javapkg com.celzero.firestack -# -work: keep the temporary directory for debugging -ANDROID23=-androidapi 23 -target=android -tags='android' -work -# -tags debuglog to enable runtime crash logging output with "<< begin log" prefix -# example debug log: github.com/golang/go/issues/69629#issuecomment-2389297820 -# build-time tags may be required in somecases -# github.com/golang/go/blob/e2fef50def98/src/runtime/HACKING.md?plain=1#L524 -# github.com/golang/go/blob/e2fef50def98/src/runtime/debuglog.go#L63-L64 -ANDROID23_DEBUG=-androidapi 23 -target=android -tags='android,debuglog' -work - -WINDOWS_BUILDDIR=$(BUILDDIR)/windows -LINUX_BUILDDIR=$(BUILDDIR)/linux - -# stack traces are not affected by ldflags -s -w: github.com/golang/go/issues/25035#issuecomment-495004689 -# trimpath: github.com/skycoin/skycoin/issues/719 -ANDROID_BUILD_CMD=env PATH=$(GOBIN):$(PATH) $(GOMOBILE) $(GOBIND) $(ANDROID23) \ - -overlay=$(BUILD_OVERLAY) -ldflags $(LDFLAGS) -gcflags='-trimpath' -# built without stripping dwarf/symbols -ANDROID_DEBUG_BUILD_CMD=env PATH=$(GOBIN):$(PATH) $(GOMOBILE) $(GOBIND) $(ANDROID23_DEBUG) \ - -overlay=$(BUILD_OVERLAY) -ldflags $(LDFLAGS_DEBUG) -# exported pkgs -INTRA_BUILD_CMD=$(IMPORT_PATH)/intra $(IMPORT_PATH)/intra/backend $(IMPORT_PATH)/intra/settings - -$(BUILDDIR)/intra/tun2socks.aar: $(GOMOBILE) $(BUILD_OVERLAY) - mkdir -p $(BUILDDIR)/intra - $(ANDROID_BUILD_CMD) -o $@ $(INTRA_BUILD_CMD) - -$(BUILDDIR)/intra/tun2socks-debug.aar: $(GOMOBILE) $(BUILD_OVERLAY) - env NDK_DEBUG=1 - mkdir -p $(BUILDDIR)/intra - $(ANDROID_DEBUG_BUILD_CMD) -o $@ $(INTRA_BUILD_CMD) - -$(BUILDDIR)/android/tun2socks.aar: $(GOMOBILE) $(BUILD_OVERLAY) - env NDK_DEBUG=1 - mkdir -p $(BUILDDIR)/android - $(ANDROID_BUILD_CMD) -o $@ $(IMPORT_PATH)/outline/android $(IMPORT_PATH)/outline/shadowsocks - -$(BUILD_OVERLAY): $(TOOLSDIR)/runtime_write_err_android.patch - mkdir -p $(BUILDDIR) - env PATH=$(GOBIN):$(PATH) $(GOPATCHOVERLAY) -overlay $(BUILDDIR) $(TOOLSDIR)/runtime_write_err_android.patch - -$(LINUX_BUILDDIR)/tun2socks: $(XGO) - $(XGO) -ldflags $(XGO_LDFLAGS) --targets=linux/amd64 -dest $(LINUX_BUILDDIR) $(ELECTRON_PATH) - mv $(LINUX_BUILDDIR)/electron-linux-amd64 $@ - -$(WINDOWS_BUILDDIR)/tun2socks.exe: $(XGO) - $(XGO) -ldflags $(XGO_LDFLAGS) --targets=windows/386 -dest $(WINDOWS_BUILDDIR) $(ELECTRON_PATH) - mv $(WINDOWS_BUILDDIR)/electron-windows-4.0-386.exe $@ - -# MACOSX_DEPLOYMENT_TARGET and -iosversion should match what outline-client supports. -$(BUILDDIR)/apple/Tun2socks.xcframework: $(GOMOBILE) - export MACOSX_DEPLOYMENT_TARGET=10.14; $(GOMOBILE) $(GOBIND) -iosversion=9.0 -target=ios,iossimulator,macos -o $@ -ldflags '-s -w' -bundleid org.outline.tun2socks $(IMPORT_PATH)/outline/apple $(IMPORT_PATH)/outline/shadowsocks - -go.mod: tools/tools.go - go mod tidy - touch go.mod - -$(GOMOBILE): go.mod - env GOBIN=$(GOBIN) go install golang.org/x/mobile/cmd/gomobile@latest - env GOBIN=$(GOBIN) go install github.com/felixge/go-patch-overlay@latest - env PATH=$(GOBIN):$(PATH) $(GOMOBILE) init - -$(XGO): go.mod - env GOBIN=$(GOBIN) go install github.com/crazy-max/xgo - -.PHONY: android intra linux apple windows clean clean-all - -all: android intra linux apple windows - -android: $(BUILDDIR)/android/tun2socks.aar - -intra: $(BUILDDIR)/intra/tun2socks.aar - -intradebug: $(BUILDDIR)/intra/tun2socks-debug.aar - -apple: $(BUILDDIR)/apple/Tun2socks.xcframework - -linux: $(LINUX_BUILDDIR)/tun2socks - -windows: $(WINDOWS_BUILDDIR)/tun2socks.exe - -clean: - rm -rf $(BUILDDIR) - go clean - -clean-all: clean - rm -rf $(GOBIN) diff --git a/README.md b/README.md index 26d06e34..2d9b3ca0 100644 --- a/README.md +++ b/README.md @@ -1,118 +1,2 @@ -# Firestack - -Firestack is a userspace TCP/UDP connection monitor, firewall, DNS resolver, and multi-hop [WireGuard](https://github.com/wireguard/wireguard-go) client for Android. - -Firestack is built specifically for [Rethink DNS + Firewall + VPN](https://github.com/celzero/rethink-app). [gVisor/netstack](https://github.com/google/gvisor/tree/go/pkg/tcpip) provides a SOCKS-like interface (similar to [badvpn's tun2socks](https://github.com/ambrop72/badvpn)) for TCP/UDP over a TUN device. - -Firestack is a hard-fork of Google's [outline-go-tun2socks](https://github.com/Jigsaw-Code/outline-go-tun2socks) project. - -## DNS - -Firestack supports DNS over HTTPS, DNS over TLS, Oblivious DNS over HTTPS, DNS over WireGuard / SOCKS5 / Tor, DNSCrypt, and plain old DNS upstreams. - -## WireGuard - -Firestack runs WireGuard in userspace. When running *multiple* WireGuard tunnels at once, only ICMP, DNS, TCP and UDP are forwarded through them. ARP / IGMP / SCTP / RTP and other IP protocols are *not* forwarded to WireGuard tunnels. - -Firestack supports multi-hop / multi-relay WireGuard, where multiple tunnels can be chained together, provided that the outer tunnel (hop/relay) can route to the inner tunnel's (exit) endpoint. - -[FOSS United](https://fossunited.org/grants) FLOSS/fund badge - -WireGuard integration was sponsored by [FOSS United](https://fossunited.org/grants); and Multi-hop / Multi-relay WireGuard by [FLOSS/fund](https://floss.fund/). - -## Releases - -[![SLSA 3](https://slsa.dev/images/gh-badge-level3.svg)](https://slsa.dev/spec/v1.2/build-track-basics#build-l3) [![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/celzero/firestack/badge)](https://securityscorecards.dev/viewer/?uri=github.com/celzero/firestack) [![OpenSSF Best Practices](https://www.bestpractices.dev/projects/11568/badge)](https://www.bestpractices.dev/projects/11568) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/celzero/firestack) - -Firestack is released as an Android Library (`aar`) and can be integrated into -your Android builds via [Jitpack](https://jitpack.io/#celzero/firestack) ([ref](https://github.com/celzero/rethink-app/commit/a6e2abca7)) or [Maven Central (OSSRH)](https://central.sonatype.com/artifact/com.celzero/firestack/overview). - -```gradle - // add this to your project's build.gradle - allprojects { - repositories { - ... - // if consuming from maven central - // ref: central.sonatype.org/consume - mavenCentral() - ... - // if consuming from jitpack - // ref: docs.jitpack.io/android/#installing - maven { url 'https://jitpack.io' } - ... - } - } - - // add the dep to your app's build.gradle - dependencies { - ... - // maven central (stripped) - implementation 'com.celzero:firestack:Tag@aar' - ... - // jitpack (stripped) - implementation 'com.github.celzero:firestack:Tag@aar' - // jitpack (debug symbols) - implementation 'com.github.celzero:firestack:Tag:debug@aar' - ... - } -``` - -## API - -The APIs aren't stable and hence left undocumented, but you can look at -Rethink DNS + Firewall + VPN codebase: ([GoVpnAdapter](https://github.com/celzero/rethink-app/blob/0c931d23d7/app/src/main/java/com/celzero/bravedns/net/go/GoVpnAdapter.kt#L113-L137), [BraveVpnService](https://github.com/celzero/rethink-app/blob/0c931d23d7/app/src/main/java/com/celzero/bravedns/service/BraveVPNService.kt#L5306-L5324)) to see how to integrate with Firestack on Android. - -## Build - -Firestack only supports Android. Instructions for other platforms are left as-is, but they may or may not work. - -### Prerequisites - -- macOS host (iOS, macOS) -- make -- Go >= 1.25 -- A C compiler (e.g.: clang, gcc) - -Firestack APIs are available only on Android builds for now. iOS and Linux support planned but nothing concrete yet. - -### Android - -- [sdkmanager](https://developer.android.com/studio/command-line/sdkmanager) - 1. Download the command line tools from [developer.android.com](https://developer.android.com/studio). - 1. Unzip the pacakge as `~/Android/Sdk/cmdline-tools/latest/`. Make sure `sdkmanager` is located at `~/Android/Sdk/cmdline-tools/latest/bin/sdkmanager` -- Android NDK 28+ - ```bash - # Install the NDK (exact NDK version obtained from `sdkmanager --list`) - ~/Android/Sdk/cmdline-tools/latest/bin/sdkmanager "platforms;android-36" "ndk;28.2.13676358" - # Set up the environment variables: - export ANDROID_NDK_HOME=~/Android/Sdk/ndk/28.2.13676358 ANDROID_HOME=~/Android/Sdk - ``` -- [gomobile](https://pkg.go.dev/golang.org/x/mobile/cmd/gobind) (installed as needed by `make`) - -### Apple (iOS and macOS) - -- Xcode -- [gomobile](https://pkg.go.dev/golang.org/x/mobile/cmd/gobind) (installed as needed by `make`) - -### Linux and Windows - -We build binaries for Linux and Windows from source without any custom integrations. -`xgo` and Docker are required to support cross-compilation. - -- [Docker](https://docs.docker.com/get-docker/) (for XGO) -- [xgo](https://github.com/crazy-max/xgo) (installed as needed by `make`) -- [ghcr.io/crazy-max/xgo Docker image](https://github.com/crazy-max/xgo/pkgs/container/xgo) (~6.8GB pulled by `xgo`). - -## Make - -``` -# creates build/intra/{tun2socks.aar,tun2socks-sources.jar} -make clean && make intra - -``` -If needed, you can extract the jni files into `build/android/jni` with: -```bash -unzip build/android/tun2socks.aar 'jni/*' -d build/android -``` +# firestack +Userspace firewall in go diff --git a/SECURITY.md b/SECURITY.md deleted file mode 100644 index f32ec9c1..00000000 --- a/SECURITY.md +++ /dev/null @@ -1,5 +0,0 @@ -This project follows a 90 day disclosure timeline. - -To report a security issue, please follow instructions at https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability#privately-reporting-a-security-vulnerability - -We will respond within 30 days. If the issue is confirmed as a vulnerability, we will open a Security Advisory and acknowledge your contributions. diff --git a/go.mod b/go.mod deleted file mode 100644 index c16dc2c6..00000000 --- a/go.mod +++ /dev/null @@ -1,53 +0,0 @@ -module github.com/celzero/firestack - -go 1.26 - -require ( - github.com/celzero/gotrie v0.0.0-20250314130138-a2756ab2f6bd - github.com/jedisct1/go-dnsstamps v0.0.0-20200621175006-302248eecc94 - github.com/jedisct1/xsecretbox v0.0.0-20190909160646-b731c21297f9 - github.com/k-sone/critbitgo v1.4.0 - github.com/miekg/dns v1.1.66 - golang.org/x/crypto v0.47.0 - golang.org/x/sys v0.40.0 -) - -require ( - github.com/Snawoot/opera-proxy v1.5.0 - github.com/cloudflare/odoh-go v1.0.0 - github.com/coder/websocket v1.8.14 - github.com/crazy-max/xgo v0.31.0 - github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a - github.com/noql-net/certpool v0.0.0-20240719060413-a5ed62ecc62a - github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f - github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 - go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 - golang.org/x/mobile v0.0.0-20260120165949-40bd9ace6ce4 - golang.org/x/net v0.49.0 - golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb - gvisor.dev/gvisor v0.0.0-20260220231412-fe30adbe8e25 -) - -require ( - git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6 // indirect - github.com/Snawoot/go-http-digest-auth-client v1.1.3 // indirect - github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect - github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect - github.com/cisco/go-hpke v0.0.0-20210215210317-01c430f1f302 // indirect - github.com/cisco/go-tls-syntax v0.0.0-20200617162716-46b0cfb76b9b // indirect - github.com/cloudflare/circl v1.6.3 // indirect - github.com/google/btree v1.1.2 // indirect - github.com/patrickmn/go-cache v2.1.0+incompatible // indirect - github.com/pkg/diff v0.0.0-20200914180035-5b29258ca4f7 // indirect - github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf // indirect - golang.org/x/exp v0.0.0-20241004190924-225e2abe05e6 // indirect - golang.org/x/mod v0.32.0 // indirect - golang.org/x/sync v0.19.0 // indirect - golang.org/x/text v0.33.0 // indirect - golang.org/x/time v0.12.0 // indirect - golang.org/x/tools v0.41.0 // indirect - golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect -) - -// TODO: remove all replaces -// replace golang.org/x/mobile v0.0.0-20250506005352-78cd7a343bde => github.com/ignoramous/mobile v0.0.0-20260119111959-bc2c8adf6210 diff --git a/go.sum b/go.sum deleted file mode 100644 index 6606033c..00000000 --- a/go.sum +++ /dev/null @@ -1,233 +0,0 @@ -git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6 h1:w8IZgCntCe0RuBJp+dENSMwEBl/k8saTgJ5hPca5IWw= -git.schwanenlied.me/yawning/x448.git v0.0.0-20170617130356-01b048fb03d6/go.mod h1:wQaGCqEu44ykB17jZHCevrgSVl3KJnwQBObUtrKU4uU= -github.com/AdguardTeam/dnsproxy v0.73.2/go.mod h1:zD5WfTctbRvYYk8PS39h6/OT84NTu6QxKbAiBN5PUcI= -github.com/AdguardTeam/golibs v0.29.0/go.mod h1:vjw1OVZG6BYyoqGRY88U4LCJLOMfhBFhU0UJBdaSAuQ= -github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= -github.com/Microsoft/hcsshim v0.9.12/go.mod h1:qAiPvMgZoM0wpkVg6qMdSEu+1VtI6/qHOOPkTGt8ftQ= -github.com/Snawoot/go-http-digest-auth-client v1.1.3 h1:Xd/SNBuIUJqotzmxRpbXovBJxmlVZOT19IZZdMdrJ0Q= -github.com/Snawoot/go-http-digest-auth-client v1.1.3/go.mod h1:WiwNiPXTRGyjTGpBtSQJlM2wDPRRPpFGhMkMWpV4uqg= -github.com/Snawoot/opera-proxy v1.5.0 h1:lip90ChPbZF7PQvZvY20v99oy7WsyqtWggquudbdjaQ= -github.com/Snawoot/opera-proxy v1.5.0/go.mod h1:XdZYuhfFCgjS9Sufye+p+4llKsu0e5j3ImMUesEuzgc= -github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= -github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= -github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= -github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us= -github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE= -github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A= -github.com/bazelbuild/rules_go v0.44.2/go.mod h1:Dhcz716Kqg1RHNWos+N6MlXNkjNP2EwZQ0LukRKJfMs= -github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= -github.com/celzero/gotrie v0.0.0-20250314130138-a2756ab2f6bd h1:jPvF+c8M0ACi/TNPYigCzY39avz5qYI4hYE+ygrt4Xc= -github.com/celzero/gotrie v0.0.0-20250314130138-a2756ab2f6bd/go.mod h1:Yq0X9rVFGSEfmT/V95Z2blOpJGwhZsoboAdTAzlK6uo= -github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= -github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM= -github.com/cisco/go-hpke v0.0.0-20210215210317-01c430f1f302 h1:unAbn7dpE8eeUfWRaOPl1qTfffhIcCNuKQuECGNGWtk= -github.com/cisco/go-hpke v0.0.0-20210215210317-01c430f1f302/go.mod h1:RSsoIHRMBe69FbF/fIbmWYa3rrC6vuPyC0MbNUpel3Q= -github.com/cisco/go-tls-syntax v0.0.0-20200617162716-46b0cfb76b9b h1:Ves2turKTX7zruivAcUOQg155xggcbv3suVdbKCBQNM= -github.com/cisco/go-tls-syntax v0.0.0-20200617162716-46b0cfb76b9b/go.mod h1:0AZAV7lYvynZQ5ErHlGMKH+4QYMyNCFd+AiL9MlrCYA= -github.com/cloudflare/circl v1.0.0/go.mod h1:MhjB3NEEhJbTOdLLq964NIUisXDxaE1WkQPUxtgZXiY= -github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= -github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= -github.com/cloudflare/odoh-go v1.0.0 h1:4ZRBHNFC0wefDpWKuSXDuw6SsEulP3QrS/rqG9RVCgo= -github.com/cloudflare/odoh-go v1.0.0/go.mod h1:J3Doz827YDYvz4hEmJU6q45hRFOqxUBL6NRUuEfjMxA= -github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= -github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/containerd/cgroups v1.0.4/go.mod h1:nLNQtsF7Sl2HxNebu77i1R0oDlhiTG+kO4JTrUzo6IA= -github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= -github.com/containerd/containerd v1.6.36/go.mod h1:gSufNaPbqri6ifEQ3eihFSXoGwqTENkqB7j//aEgE0s= -github.com/containerd/errdefs v0.1.0/go.mod h1:YgWiiHtLmSeBrvpw+UfPijzbLaB77mEG1WwJTDETIV0= -github.com/containerd/fifo v1.0.0/go.mod h1:ocF/ME1SX5b1AOlWi9r677YJmCPSwwWnQ9O123vzpE4= -github.com/containerd/go-runc v1.0.0/go.mod h1:cNU0ZbCgCQVZK4lgG3P+9tn9/PaJNmoDXPpoJhDR+Ok= -github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= -github.com/containerd/ttrpc v1.1.2/go.mod h1:XX4ZTnoOId4HklF4edwc4DcqskFZuvXB1Evzy5KFQpQ= -github.com/containerd/typeurl v1.0.2/go.mod h1:9trJWW2sRlGub4wZJRTW83VtbOLS6hwcDZXTn6oPz9s= -github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= -github.com/crazy-max/xgo v0.31.0 h1:JE2PKXBQ6cYpcMSVBuo7WmYy3elKiNHaLv9YXCSsRug= -github.com/crazy-max/xgo v0.31.0/go.mod h1:m/aqfKaN/cYzfw+Pzk7Mk0tkmShg3/rCS4Zdhdugi4o= -github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= -github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= -github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= -github.com/elazarl/goproxy/ext v0.0.0-20190711103511-473e67f1d7d2 h1:dWB6v3RcOy03t/bUadywsbyrQwCqZeNIEX6M1OtSZOM= -github.com/elazarl/goproxy/ext v0.0.0-20190711103511-473e67f1d7d2/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8= -github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gofrs/flock v0.8.0/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= -github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20241001023024-f4c0cfd0cf1d/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= -github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gnostic v0.5.5/go.mod h1:7+EbHbldMins07ALC74bsA81Ovc97DwqyJO1AENw9kA= -github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/jedisct1/go-dnsstamps v0.0.0-20200621175006-302248eecc94 h1:O5X61fl3p/dl+7hLDwDamJxRY6z/LwuH1XD+OyNNlxE= -github.com/jedisct1/go-dnsstamps v0.0.0-20200621175006-302248eecc94/go.mod h1:128Ik0lG+DBYL6zaSgN3icmzDASeQgkSy3+Sp10trLc= -github.com/jedisct1/xsecretbox v0.0.0-20190909160646-b731c21297f9 h1:nGfB2s9K0GyHuNkJmXkIjP+m7je6Q6gjirr+weAEtDo= -github.com/jedisct1/xsecretbox v0.0.0-20190909160646-b731c21297f9/go.mod h1:MipBKo+gZlzpd1JXA1OliuwvtQizlFeu4aMAyTLh8bo= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/k-sone/critbitgo v1.4.0 h1:l71cTyBGeh6X5ATh6Fibgw3+rtNT80BA0uNNWgkPrbE= -github.com/k-sone/critbitgo v1.4.0/go.mod h1:7E6pyoyADnFxlUBEKcnfS49b7SUAQGMK+OAp/UQvo0s= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0= -github.com/miekg/dns v1.1.51/go.mod h1:2Z9d3CP1LQWihRZUf29mQ19yDThaI4DAYzte2CaQW5c= -github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE= -github.com/miekg/dns v1.1.66/go.mod h1:jGFzBsSNbJw6z1HYut1RKBKHA9PBdxeHrZG8J+gC2WE= -github.com/moby/sys/capability v0.4.0/go.mod h1:4g9IK291rVkms3LKCDOoYlnV8xKwoDTpIrNEE35Wq0I= -github.com/moby/sys/mountinfo v0.6.2/go.mod h1:IJb6JQeOklcdMU9F5xQ8ZALD+CUr5VlGpwtX+VE0rpI= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/noql-net/certpool v0.0.0-20240719060413-a5ed62ecc62a h1:ApBZhPJkiXHcx3EjFC1wHEk3+eKsGcrF3++2pSAI8Q8= -github.com/noql-net/certpool v0.0.0-20240719060413-a5ed62ecc62a/go.mod h1:NuAP3INCprX/lHBPlvCa67RpZ7bfwwgak06w2j2L00o= -github.com/onsi/ginkgo/v2 v2.20.2/go.mod h1:K9gyxPIlb+aIvnZ8bd9Ak+YP18w3APlR+5coaZoE2ag= -github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= -github.com/opencontainers/runtime-spec v1.1.0-rc.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= -github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pkg/diff v0.0.0-20200914180035-5b29258ca4f7 h1:+/+DxvQaYifJ+grD4klzrS5y+KJXldn/2YTl5JG+vZ8= -github.com/pkg/diff v0.0.0-20200914180035-5b29258ca4f7/go.mod h1:zO8QMzTeZd5cpnIkz/Gn6iK0jDfGicM1nynOkkPIl28= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= -github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= -github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f h1:PDPGJtm9PFBLNudHGwkfUGp/FWvP+kXXJ0D1pB35F40= -github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= -github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf h1:7PflaKRtU4np/epFxRXlFhlzLXZzKFrH5/I4so5Ove0= -github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf/go.mod h1:CLUSJbazqETbaR+i0YAhXBICV9TrKH93pziccMhmhpM= -github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 h1:d/Wr/Vl/wiJHc3AHYbYs5I3PucJvRuw3SvbmlIRf+oM= -github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301/go.mod h1:ntmMHL/xPq1WLeKiw8p/eRATaae6PiVRNipHFJxI8PM= -github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 h1:lGdhQUN/cnWdSH3291CUuxSEqc+AsGTiDxPP3r2J0l4= -go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190909091759-094676da4a83/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20241004190924-225e2abe05e6 h1:1wqE9dj9NpSm04INVsJhhEUzhuDVjbcyKH91sVyPATw= -golang.org/x/exp v0.0.0-20241004190924-225e2abe05e6/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= -golang.org/x/exp/shiny v0.0.0-20251219203646-944ab1f22d93/go.mod h1:QqbL1+y9e9D0Su+B9umI12TlEFXxVNGTpUai4t0pvgI= -golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= -golang.org/x/mobile v0.0.0-20260120165949-40bd9ace6ce4 h1:C3JuLOLhdaE75vk5m7u18NvZciRk+lnO34xcXl3NPTU= -golang.org/x/mobile v0.0.0-20260120165949-40bd9ace6ce4/go.mod h1:yHJY0EGzMJ0i5ONrrhdpDSSnoyres5LO7D2hSIbJJ5I= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190909082730-f460065e899a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= -golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= -golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= -golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= -golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= -golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= -golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= -google.golang.org/api v0.249.0/go.mod h1:dGk9qyI0UYPwO/cjt2q06LG/EhUpwZGdAbYF14wHHrQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= -google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20260220231412-fe30adbe8e25 h1:Agld46dvzXICOfiXcRxZX6+S9sZpTGEdJcOjYx+EW4I= -gvisor.dev/gvisor v0.0.0-20260220231412-fe30adbe8e25/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= -k8s.io/api v0.23.16/go.mod h1:Fk/eWEGf3ZYZTCVLbsgzlxekG6AtnT3QItT3eOSyFRE= -k8s.io/apimachinery v0.23.16/go.mod h1:RMMUoABRwnjoljQXKJ86jT5FkTZPPnZsNv70cMsKIP0= -k8s.io/client-go v0.23.16/go.mod h1:CUfIIQL+hpzxnD9nxiVGb99BNTp00mPFp3Pk26sTFys= -k8s.io/klog/v2 v2.30.0/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= -k8s.io/kube-openapi v0.0.0-20211115234752-e816edb12b65/go.mod h1:sX9MT8g7NVZM5lVL/j8QyCCJe8YSMW30QvGZWaCIDIk= -k8s.io/utils v0.0.0-20211116205334-6203023598ed/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= -sigs.k8s.io/json v0.0.0-20211020170558-c049b76a60c6/go.mod h1:p4QtZmO4uMYipTQNzagwnNoseA6OxSUutVw05NhYDRs= -sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= -sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= diff --git a/intra/backend/core_iptree.go b/intra/backend/core_iptree.go deleted file mode 100644 index 904c2c10..00000000 --- a/intra/backend/core_iptree.go +++ /dev/null @@ -1,522 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -import ( - "errors" - "net" - "strings" - "sync" - - "github.com/celzero/firestack/intra/core" - "github.com/k-sone/critbitgo" -) - -// todo: use github.com/gaissmai/bart ? - -// A IpTree is a thread-safe trie that supports insertion, deletion, and route matching IP CIDRs. -type IpTree interface { - // Adds value v to the cidr route. - Add(cidr, v string) error - // Sets cidr route to v, overwriting any previous value. - Set(cidr, v string) error - // Removes value v, if found. - Esc(cidr, v string) bool - // Deletes cidr route. Returns true if cidr was found. - Del(cidr string) bool - // Gets the value of cidr or "" if cidr is not found. - Get(cidr string) (string, error) - // Returns true if the cidr route is found. - Has(cidr string) (bool, error) - // Returns csv of all routes matching cidr or "". - Routes(cidr string) string - // Returns csv of values of all routes matching cidr or "". - Values(cidr string) string - // Returns the route@csv(value) of any route matching cidr or "". - GetAny(cidr string) (string, error) - // Returns true if any route matches cidr. - HasAny(cidr string) (bool, error) - // Removes values like v ("*v*") for cidr. - EscLike(cidr, likev string) int32 - // Returns csv of all routes with any value like v matching cidr. - RoutesLike(cidr, likev string) string - // Returns csv of all routes with values like v for cidr. - ValuesLike(cidr, likev string) string - // Returns csv of all values like v for cidr. - GetLike(cidr, likev string) string - // Returns the longest route for cidr as "r1@csv(v)|r2@csv(v2)" or "". - GetAll(cidr string) (string, error) - // Deletes all routes matching cidr. Returns the number of routes deleted. - DelAll(cidr string) int32 - // Clears the trie. - Clear() - // Returns the number of routes. - Len() int -} - -type iptree struct { - sync.RWMutex - t *critbitgo.Net -} - -const ( - // Vsep is a values separator (csv) - Vsep = "," - // Ksep is a key separator (csv) - Ksep = "," - // Kdelim is a key@csv(v) delimiter - Kdelim = "@" - // KVsep is a k1:v1|k2:v2 separator - KVsep = "|" -) - -var ( - errValNotString = errors.New("values must be string") -) - -// NewIpTree returns a new IpTree. -func NewIpTree() IpTree { - return &iptree{t: critbitgo.NewNet()} -} - -func (c *iptree) Add(cidr string, v string) error { - return c.add(cidr, v) -} - -func (c *iptree) add(cidr string, v string) error { - x, err := c.get(cidr) - - if err != nil { - return err - } else if len(v) == 0 || x == v { - return nil - } else if len(x) == 0 { - return c.set(cidr, v) - } else if strings.Contains(x, v) { // is ~v in x? - cur := strings.SplitSeq(x, Vsep) - for val := range cur { - if val == v { // v is definitely in x - return nil - } - } // v is not in x, but something resembling ~v was. - return c.set(cidr, x+Vsep+v) - } - return c.set(cidr, x+Vsep+v) -} - -func (c *iptree) Set(cidr string, v string) error { - c.del(cidr) // delete any previous value - return c.add(cidr, v) -} - -func (c *iptree) set(cidr string, v string) error { - r, err := ip2cidr(cidr) - if err != nil { - return err - } - - c.Lock() - defer c.Unlock() - - return c.t.Add(r, v) -} - -func (c *iptree) Del(cidr string) bool { - return c.del(cidr) -} - -func (c *iptree) del(cidr string) bool { - r, err := ip2cidr(cidr) - if err != nil { - return false - } - - c.Lock() - defer c.Unlock() - - _, ok, err := c.t.Delete(r) - return ok && err == nil -} - -func (c *iptree) Esc(cidr string, v string) bool { - return c.esc(cidr, v) -} - -func (c *iptree) esc(cidr string, v string) bool { - if x, err := c.get(cidr); err != nil { - return false - } else if len(x) == 0 || len(v) == 0 { - return false - } else if x == v { - return c.del(cidr) - } else if strings.Contains(x, v) { - // remove all occurrences of v in csv x - old := strings.Split(x, Vsep) - cur := make([]string, 0, len(old)) - for _, val := range old { - if val != v { - cur = append(cur, val) - } - } - if len(cur) == 0 { - return c.del(cidr) - } - return c.set(cidr, strings.Join(cur, Vsep)) == nil - } - return false -} - -func (c *iptree) Has(cidr string) (bool, error) { - return c.has(cidr) -} - -func (c *iptree) has(cidr string) (bool, error) { - r, err := ip2cidr(cidr) - if err != nil { - return false, err - } - - c.RLock() - defer c.RUnlock() - - _, ok, err := c.t.Get(r) - return ok, err -} - -func (c *iptree) DelAll(cidr string) (n int32) { - return c.delAll(cidr) -} - -func (c *iptree) delAll(cidr string) (n int32) { - r, err := ip2cidr(cidr) - if r == nil || err != nil { - return - } - - c.Lock() - defer c.Unlock() - - keys := make([]*net.IPNet, 0) - c.t.WalkMatch(r, func(k *net.IPNet, _ any) bool { - keys = append(keys, k) - return true - }) - - for _, k := range keys { - if _, ok, err := c.t.Delete(k); ok && err == nil { - n++ - } - } - return -} - -func (c *iptree) HasAny(cidr string) (bool, error) { - return c.hasAny(cidr) -} - -func (c *iptree) hasAny(cidr string) (bool, error) { - r, err := ip2cidr(cidr) - if err != nil { - return false, err - } - - c.RLock() - defer c.RUnlock() - - m, _, err := c.t.Match(r) - return m != nil, err -} - -func (c *iptree) Get(cidr string) (string, error) { - r, err := c.get(cidr) - if err != nil { - return "", err - } - return r, nil // r may be empty -} - -func (c *iptree) get(cidr string) (v string, err error) { - r, err := ip2cidr(cidr) - if err != nil { - return "", err - } - - c.RLock() - defer c.RUnlock() - - s, ok, err := c.t.Get(r) - if ok && err == nil { - if v, ok = s.(string); !ok { - return "", errValNotString - } - } else { - return "", err // may be nil - } - return -} - -func (c *iptree) GetAny(cidr string) (string, error) { - r, err := c.getAny(cidr) - if err != nil { - return "", err - } - return r, nil // r may be empty -} - -func (c *iptree) getAny(cidr string) (rv string, err error) { - r, err := ip2cidr(cidr) - if err != nil { - return "", err - } - - c.RLock() - defer c.RUnlock() - - m, v, err := c.t.Match(r) - if err != nil { - return "", err - } - if m != nil { - rv = m.String() - } - if v != nil { - if s, ok := v.(string); ok { - rv = rv + Kdelim + s - } - } - return -} - -func (c *iptree) GetAll(cidr string) (string, error) { - r, err := c.getAll(cidr) - if err != nil { - return "", err - } - return r, nil // r may be empty -} - -func (c *iptree) getAll(cidr string) (rv string, err error) { - r, err := ip2cidr(cidr) - if err != nil { - return "", err - } - - c.RLock() - defer c.RUnlock() - - c.t.WalkMatch(r, func(k *net.IPNet, v any) bool { - if k == nil { - return true // next - } - rv = rv + k.String() - if v != nil { - if s, ok := v.(string); ok && len(s) > 0 { - rv = rv + Kdelim + s - } - } - rv = rv + KVsep - return true // next - }) - return strings.TrimRight(rv, KVsep), nil -} - -func (c *iptree) Routes(cidr string) string { - return c.routes(cidr) -} - -func (c *iptree) routes(cidr string) string { - r, err := ip2cidr(cidr) - if err != nil { - return "" - } - - c.RLock() - defer c.RUnlock() - - rt := make([]string, 0) - c.t.WalkMatch(r, func(k *net.IPNet, _ any) bool { - if k != nil { - rt = append(rt, k.String()) - } - return true // next - }) - return strings.Join(rt, Ksep) -} - -func (c *iptree) Values(cidr string) string { - return c.values(cidr) -} - -func (c *iptree) values(cidr string) string { - r, err := ip2cidr(cidr) - if err != nil { - return "" - } - - c.RLock() - defer c.RUnlock() - - vt := make([]string, 0) - c.t.WalkMatch(r, func(_ *net.IPNet, v any) bool { - if v != nil { - if s, ok := v.(string); ok && len(s) > 0 { - vt = append(vt, s) - } - } - return true // next - }) - return strings.Join(vt, Vsep) -} - -func (c *iptree) EscLike(cidr, like string) int32 { - return c.escLike(cidr, like) -} - -func (c *iptree) escLike(cidr, like string) int32 { - if x, err := c.get(cidr); err != nil { - return -1 // error - } else if len(x) == 0 { - return 0 - } else if len(like) == 0 { - return c.delAll(cidr) - } else if x == like { - if rmv := c.del(cidr); rmv { - return 1 - } - return 0 - } else if strings.Contains(x, like) { - // remove all occurrences of v in csv x - old := strings.Split(x, Vsep) - cur := make([]string, 0, len(old)) - n := int32(0) - for _, val := range old { - if !strings.HasPrefix(val, like) { - cur = append(cur, val) - } else { - n++ - } - } - if len(cur) == 0 { // no values left - _ = c.del(cidr) - } else if len(cur) != len(old) { // no change; n == 0 - _ = c.set(cidr, strings.Join(cur, Vsep)) - } - return n - } - return 0 // not found -} - -func (c *iptree) GetLike(cidr, like string) string { - return c.getLike(cidr, like) -} - -func (c *iptree) getLike(cidr, like string) string { - if x, err := c.get(cidr); err != nil { - return "" // error - } else if len(x) == 0 { - return "" - } else if len(like) == 0 || x == like { - return x // match all - } else if strings.Contains(x, like) { - // grab all occurrences of v in csv x - all := strings.Split(x, Vsep) - grab := make([]string, 0, len(all)) - for _, val := range all { - if strings.HasPrefix(val, like) { - grab = append(grab, val) - } - } - return strings.Join(grab, Vsep) - } - return "" // not found -} - -func (c *iptree) RoutesLike(cidr, like string) string { - return c.routesLike(cidr, like) -} - -func (c *iptree) routesLike(cidr, like string) string { - r, err := ip2cidr(cidr) - if err != nil { - return "" - } - - c.RLock() - defer c.RUnlock() - - rt := make([]string, 0) - c.t.WalkMatch(r, func(k *net.IPNet, v any) bool { - if v == nil { - return true // next - } - if s, ok := v.(string); ok && len(s) > 0 { - if !strings.Contains(s, like) { - return true // next - } - // grab all occurrences of v in csv s - for val := range strings.SplitSeq(s, Vsep) { - if strings.HasPrefix(val, like) { - rt = append(rt, val) - } - } - } - return true // next - }) - return strings.Join(rt, Ksep) -} - -func (c *iptree) ValuesLike(cidr, like string) string { - return c.valuesLike(cidr, like) -} - -func (c *iptree) valuesLike(cidr, like string) string { - r, err := ip2cidr(cidr) - if err != nil { - return "" - } - - c.RLock() - defer c.RUnlock() - - vt := make([]string, 0) - c.t.WalkMatch(r, func(k *net.IPNet, v any) bool { - if v == nil { - return true // next - } - if s, ok := v.(string); ok && len(s) > 0 { - if !strings.Contains(s, like) { - return true // next - } - // grab all occurrences of v in csv s - for val := range strings.SplitSeq(s, Vsep) { - if strings.HasPrefix(val, like) { - vt = append(vt, val) - } - } - } - return true // next - }) - return strings.Join(vt, Vsep) -} - -func (c *iptree) Clear() { - c.Lock() - defer c.Unlock() - - c.t.Clear() -} - -func (c *iptree) Len() int { - c.RLock() - defer c.RUnlock() - - return c.t.Size() -} - -func ip2cidr(ippOrCidr string) (*net.IPNet, error) { - return core.IP2Cidr(ippOrCidr) -} diff --git a/intra/backend/core_iptree_test.go b/intra/backend/core_iptree_test.go deleted file mode 100644 index 19f946da..00000000 --- a/intra/backend/core_iptree_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package backend - -import ( - "testing" - - ll "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -func Test192(tst *testing.T) { - log := tst.Log - t := NewIpTree() - t.Add("192.0.0.0/8", "app192:443") - t.Add("192.0.0.0/8", "dup192:443") - t.Add("192.1.0.0/16", "app192:80") - t.Add("1.1.1.0/24", "*:80") - t.Add("192.1.0.0/16", "app1921:80") - t.Set("192.2.0.0/16", "app1922:0") - t.Set("192.2.0.0/16", "app1922:unset") - t.Add("192.1.1.1/32", "app192111:0") - t.Add("0.0.0.0/0", "test0000") - t.Add("192.0.0.0/8", "app192:443") - - g8, err := t.Get("192.0.0.0/8") - ko(tst, err) - log("g8", g8) // app192:443 dup192:443 - - g16, err := t.Get("192.1.0.0/16") - rmv := t.Esc("1.1.0.0/16", "test16.2") // false - g16any, err1 := t.GetAny("192.1.0.0/16") - ko(tst, err) - ko(tst, err1) - log("g16", g16, "g16any", g16any, "esc?", rmv) - - g32any, err := t.GetAny("192.1.1.2/32") - ko(tst, err) - log("g32any", g32any) - - gall, err := t.GetAll("192.1.1.1/32") - ko(tst, err) - log("gall", gall) - - route := t.Routes("192.1.0.0/16") - rlike := t.RoutesLike("192.1.0.0/16", ":80") - val := t.Values("192.1.0.0/16") - vlike := t.ValuesLike("192.1.0.0/16", ":80") - vlike2 := t.ValuesLike("192.1.0.0/16", "app192:80") - log("val", val) - log("route", route) - log("vlike", vlike, "vlike(1app):", vlike2) - log("rlike", rlike) -} - -func TestUn(tst *testing.T) { - ll.SetLevel(ll.VVERBOSE) - settings.Debug = true - - trie := NewRadixTree() - trie.Add("fritz.box") // exact domain - trie.Add(".lan") // subdomain ending with .lan - trie.Add(".sub.tld") // subdomain ending with .sub.tld - - noma1 := trie.HasAny("test.fritz.box") // no subdomain matches - yma1 := trie.HasAny("fritz.box") // exact match for fritz.box - yma2 := trie.HasAny("test.lan") // subdomain match for .lan - yma3 := trie.HasAny("mu.st.sub.tld") // subdomain match for sub.tld - - ll.V("no: %t, yes: [%t %t %t]", noma1, yma1, yma2, yma3) -} - -func ko(tst *testing.T, err error) { - if err != nil { - tst.Fatal(err) - } -} diff --git a/intra/backend/core_radixtree.go b/intra/backend/core_radixtree.go deleted file mode 100644 index 3e3733e3..00000000 --- a/intra/backend/core_radixtree.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -import ( - "sync" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - "github.com/k-sone/critbitgo" -) - -// A RadixTree is a thread-safe trie that supports insertion, deletion, and prefix matching. -type RadixTree interface { - // Adds k to the trie. Returns true if k was not already in the trie. - Add(k string) bool - // Sets k to v in the trie, overwriting any previous value. - Set(k, v string) - // Deletes k from the trie. Returns true if k was in the trie. - Del(k string) bool - // Gets the value of k from the trie or "" if k is not in the trie. - Get(k string) string - // Returns true if k is in the trie. - Has(k string) bool - // Returns the value of the longest prefix of k in the trie or "". - GetAny(prefix string) string - // Returns true if any key in the trie has the prefix. - HasAny(prefix string) bool - // Deletes all keys in the trie with the prefix. Returns the number of keys deleted. - DelAll(prefix string) int32 - // Clears the trie. - Clear() - // Returns the number of keys in the trie. - Len() int -} - -type radix struct { - sync.RWMutex - t *critbitgo.Trie -} - -func NewRadixTree() RadixTree { - return &radix{t: critbitgo.NewTrie()} -} - -func reversed(s string) (b []byte) { - return []byte(xdns.StringReverse(s)) -} - -func (c *radix) Add(k string) bool { - return c.add(k) -} - -func (c *radix) add(k string) bool { - c.Lock() - defer c.Unlock() - - return c.t.Insert(reversed(k), "") -} - -func (c *radix) Set(k, v string) { - c.set(k, v) -} - -func (c *radix) set(k, v string) { - c.Lock() - defer c.Unlock() - - c.t.Set(reversed(k), v) -} - -func (c *radix) Del(k string) bool { - return c.del(k) -} - -func (c *radix) del(k string) bool { - c.Lock() - defer c.Unlock() - - _, ok := c.t.Delete(reversed(k)) - return ok -} - -func (c *radix) Has(k string) bool { - return c.has(k) -} - -func (c *radix) has(k string) bool { - c.RLock() - defer c.RUnlock() - - return c.t.Contains(reversed(k)) -} - -func (c *radix) DelAll(prefix string) (n int32) { - return c.delAll(prefix) -} - -func (c *radix) delAll(prefix string) (n int32) { - c.Lock() - defer c.Unlock() - - keys := make([][]byte, 10) - c.t.Allprefixed(reversed(prefix), func(k []byte, v any) bool { - keys = append(keys, k) - return true - }) - - for _, k := range keys { - if _, ok := c.t.Delete(k); ok { - n++ - } - } - return -} - -func (c *radix) HasAny(prefix string) bool { - return c.hasAny(prefix) -} - -func (c *radix) hasAny(prefix string) bool { - return c.getMatch(prefix) != nil -} - -func (c *radix) Get(k string) (v string) { - return c.get(k) -} - -func (c *radix) get(k string) (v string) { - c.RLock() - defer c.RUnlock() - - s, ok := c.t.Get(reversed(k)) - if ok { - v, _ = s.(string) - } - return -} - -func (c *radix) GetAny(prefix string) (v string) { - return c.getAny(prefix) -} - -func (c *radix) getAny(prefix string) (v string) { - if s := c.getMatch(prefix); s != nil { - v = *s - } - return -} - -func (c *radix) getMatch(str string) *string { - c.RLock() - defer c.RUnlock() - - rev := reversed(str) - var v any // value - var s string // string(v) - var ok bool // rev(str) found? - var match []byte - - if match, v, ok = c.t.LongestPrefix(rev); ok { - // test: log.VV("radix: get: one: %s: %v %v %v", str, match, v, ok) - if ok = len(match) == len(rev); ok { - // full match (xyz.ipvonly.arpa); same as c.Get() - s, ok = v.(string) - // test: log.VV("radix: two: %s: %s %s %t", str, s, rev, ok) - } else if ok = len(match) < len(rev) && rev[len(match)-1] == '.'; ok { - // partial match upto a subdomain (.ipvonly.arpa); note the trailing dot - s, ok = v.(string) - // test: log.VV("radix: three: %s: %s %v %t", str, s, rev, ok) - } - // test: log.VV("radix: get: four: %s: %s [%d %d] %t", str, s, len(rev), len(match), ok) - // partial match (ipvonly.arpa) but not a subdomain/wildcard, discard - } else { // no match - // log.V("radix: get: no prefix match for %s", str) - return nil - } - - log.V("radix: get: partial or full %s => %s; rev %s; match %s; ok? %t", str, s, rev, match, ok) - - if !ok { - return nil - } - return &s -} - -func (c *radix) Clear() { - c.Lock() - defer c.Unlock() - - c.t.Clear() -} - -func (c *radix) Len() int { - c.RLock() - defer c.RUnlock() - - return c.t.Size() -} diff --git a/intra/backend/dnsx.go b/intra/backend/dnsx.go deleted file mode 100644 index b0db0031..00000000 --- a/intra/backend/dnsx.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -const ( // see dnsx/transport.go - // DNS transport types - DOH = "DNS-over-HTTPS" - DNSCrypt = "DNSCrypt" - DNS53 = "DNS" - DOT = "DNS-over-TLS" - ODOH = "Oblivious DNS-over-HTTPS" - - CT = "Cache" // cached transport prefix - - // special singleton DNS transports (IDs) - - // Multiple DoH, DoT, ODoH resolvers - Plus = "Plus" - // Go determined default resolver (built-in) - Goos = "Goos" - // network/os provided dns (init using intra.SetSystemDNS) - System = "System" - // mdns; never cached! (built-in) - Local = "mdns" - // default (fallback) dns, used in place of special transports when unavailable - // (init using intra.AddDefaultTransport) - Default = "Default" - // client preferred dns - Preferred = "Preferred" - // synthesizes answers from presets (built-in) - Preset = "Preset" - // synthesizes A/AAAA from a single fixed IP (built-in) - Fixed = "Fixed" - // a transport that bypasses local blocklists (dnsx.SetRdnsLocal); - // if not set, Default is used - BlockFree = "BlockFree" - // all queries are blocked, answers never cached (built-in) - BlockAll = "BlockAll" - // Bootstrap DNS (built-in); encapsulates Default, if set; or Goos, otherwise. - Bootstrap = "Bootstrap" - // Application-level gateway - Alg = "Alg" - // dnscrypt.Proxy as a DNS Transport - DcProxy = "DcProxy" - // dns resolver for dns resolvers and for firestack (built-in) - // delegates queries to Bootstrap. - IpMapper = "IpMapper" - - // dns request origin indicators - - // DNS request originated internally - OriginInternal = "self" - - // DNS request originated from tunnel read - OriginTunnel = "tunnel" -) - -const ( // from dnsx/queryerror.go - // Start: Transaction started - Start = iota - // Complete : Transaction completed successfully - Complete - // SendFailed : Failed to send query - SendFailed - // NoResponse : Got no response - NoResponse - // BadQuery : Malformed input - BadQuery - // BadResponse : Response was invalid - BadResponse - // InternalError : This should never happen - InternalError - // TransportError: Transport has issues - TransportError - // ClientError: Client has issues - ClientError - // Paused: Transport is paused - Paused - // DEnd: Transport stopped - DEnd -) - -const ( // from: dnsx/rethinkdns.go - EB32 = iota - EB64 -) - -// DNSTransport exports necessary methods from dnsx.Transport -type DNSTransport interface { - // uniquely identifies this transport - ID() string - // one of DNS53, DOH, DNSCrypt, System - Type() string - // Median round-trip time for this transport, in millis. - P50() int64 - // Return the server host address used to initialize this transport. - GetAddr() string - // Return the proxy (relay) always used by this transport. - // Returns nil if there isn't any. - GetRelay() Proxy - // State of the transport after previous query (see: queryerror.go) - Status() int -} - -type DNSTransportMult interface { - DNSTransportProvider - // Add adds a transport to this multi-transport. - Add(t DNSTransport) bool - // Remove removes a transport from this multi-transport. - Remove(id string) bool - // Refresh re-registers transports and returns a csv of active ones. - Refresh() (string, error) - // LiveTransports returns a csv of active transports. - LiveTransports() string -} - -type DNSTransportProvider interface { - // Get returns a transport from this multi-transport. - Get(id string) (DNSTransport, error) -} - -type RDNS interface { - // SetStamp sets the rethinkdns blockstamp. - SetStamp(string) error - // GetStamp returns the current rethinkdns blockstamp. - GetStamp() (string, error) - // StampToNames returns csv group:names of blocklists in the given stamp s. - StampToNames(s string) (string, error) - // FlagsToStamp returns a blockstamp for given csv blocklist-ids, if valid. - FlagsToStamp(csv string, enctyp int) (string, error) - // StampToFlags retruns csv blocklist-ids given a valid blockstamp s. - StampToFlags(s string) (string, error) -} - -type RDNSResolver interface { - // SetRdnsLocal sets the local rdns resolver. - SetRdnsLocal(trie, rank, conf, filetag string) error - // GetRdnsLocal returns the local rdns resolver. - GetRdnsLocal() (RDNS, error) - // SetRdnsRemote sets the remote rdns resolver. - SetRdnsRemote(filetag string) error - // GetRdnsRemote returns the remote rdns resolver. - GetRdnsRemote() (RDNS, error) - // Translate enables or disables ALG and fixed responses. - Translate(alg, fix bool) -} - -type DNSTransportMultProvider interface { - // GetMult returns a multi-transport by id. - GetMult(id string) (DNSTransportMult, error) -} - -type DNSResolver interface { - DNSTransportMult - DNSTransportMultProvider - RDNSResolver -} - -type ResolverListener interface { - // OnDNSAdded is called when a new DNS transport with id is added. - OnDNSAdded(id string) - // OnDNSRemoved is called when a DNS transport with id is removed, except - // when the transport is stopped, then OnDNSStopped is called instead. - OnDNSRemoved(id string) - // OnDNSStopped is called when all DNS transports are stopped. Note: - // OnDNSRemoved is not called for each transport even if they are - // being removed and not just stopped. - OnDNSStopped() -} diff --git a/intra/backend/dnsx_listener.go b/intra/backend/dnsx_listener.go deleted file mode 100644 index 8979ff00..00000000 --- a/intra/backend/dnsx_listener.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -import "fmt" - -// DNSSummary is a summary of a DNS transaction, reported when it is complete. -type DNSSummary struct { - // dnscrypt, dns53, doh, odoh, dot, preset, fixed, etc. - Type string - // DNS Transport ID - ID string - // owner uid that sent this request. May be empty. - UID string - // Response (or failure) latency in seconds - Latency float64 - // Queried domain name - QName string - // Query type: A, AAAA, SVCB, HTTPS, etc. May be 0. - QType int - // CSV of all DNS aliases/names in the answer section (ex: CNAMEs) - Targets string - // Was this response returned from cache? - Cached bool - // DNS Response data, ex: a csv of ips for A, AAAA. - RData string - // DNS Response code - RCode int - // DNS Response TTL - RTtl int - // DNS Server (ip, ip:port, host, host:port) - Server string - // Proxy or a relay server address - PID string - // Relay server PID hops over, if any - RPID string - // Transport status (Start, Complete, SendFailed, NoResponse, BadQuery, BadResponse, etc) - Status int - // CSV of Rethink DNS+ blocklists (local or remote) names (if used). - Blocklists string - // Actual target (domain name) that was blocked (could be a CNAME or HTTPS/SVCB alias) by Blocklists - BlockedTarget string - // True if any among upstream transports (primary or secondary) returned blocked ans. - // Only valid for A/AAAA queries. Unspecified IPs are considered as "blocked ans". - UpstreamBlocks bool - // True if DNSSEC OK bit is set. - DO bool - // True if DNSSEC validation was successful. - AD bool - // Diag message from Transport, if any. Typically, "no error" - Msg string - // Region of the Rethink DNS+ server (if used) - Region string -} - -type DNSOpts struct { - // uid of the app (or the stub resolver) that sent this query. - // May be ANDROID, DNS, MDNS etc instead of the actual app. - UID string - // csv of proxy ids to use for this query. Not all transports are proxied. - // For instance, dnsx.System, dnsx.Local, dnsx.Goos, dnsx.Preset, dnsx.Default - // are never proxied. - PIDCSV string - // csv of ips to answer for this query; incl unspecified ips, if any. - // applicable only for A/AAAA queries. - // if set, query bypasses on-device blocklists. - IPCSV string - // primary transport ids to use for this query. - // dictated by user preferences (dnsx.Preferred, dnsx.System etc) or - // or user set rules (dnsx.BlockAll, dnsx.BlockFree, dnsx.Fixed etc) - TIDCSV string - // secondary transport ids to use for this query. - // usually, user-set DNS (dnsx.Preferred or dnsx.System) when primary is - // dnsx.BlockFree or dnsx.Fixed. Mostly, left unset. - TIDSECCSV string - // If set, query bypasses on-device blocklists only, independent of whether TIDCSV - // has dnsx.BlockFree or not. The difference is, dnsx.BlockFree if pointing to a - // non-blocking resolver (like one.one.one.one or dns.google) - // will bypass both on-device & upstream blocklists. - NOBLOCK bool -} - -// String implements fmt.Stringer. -func (s *DNSSummary) String() string { - if s == nil { - return "" - } - return fmt.Sprintf("type: %s, id: %s, latency: %f, qname: %s, rdata: %s, rcode: %d, rttl: %d, server: %s, relay: %s, status: %d, blocklists: %s, msg: %s, loc: %s", - s.Type, s.ID, s.Latency, s.QName, s.RData, s.RCode, s.RTtl, s.Server, s.PID, s.Status, s.Blocklists, s.Msg, s.Region) -} - -// DNSListener receives Summaries. -type DNSListener interface { - ResolverListener - // OnQuery is called when a DNS query is received. The listener - // can return a DNSOpts to specify how the query should be handled. - OnQuery(who, uid, domain string, qtyp int) *DNSOpts - // OnUpstreamAnswer is called before an upstream DNS answer (not blocked by firestack) is sent to the OS. - // The listener may return DNSOpts to specify if another upstream should override that answer. - // Another round of OnQuery is NOT called in this case, and OnResponse is called once after processing - // DNSOpts returned by OnUpstreamAnswer if it has a non-empty TIDCSV (overriding the original TIDCSV). - OnUpstreamAnswer(smm *DNSSummary, unmodifiedipcsv string) *DNSOpts - // OnResponse is called when a DNS response is received. May be called twice for the same query, - // for instance, when different options are requested through OnUpstreamAnswer. - OnResponse(*DNSSummary) -} - -// args (string, string, int) is OK with cgo? -// github.com/golang/go/issues/46893#issuecomment-868749896 diff --git a/intra/backend/ipn_pipkeygen.go b/intra/backend/ipn_pipkeygen.go deleted file mode 100644 index c6f6c76a..00000000 --- a/intra/backend/ipn_pipkeygen.go +++ /dev/null @@ -1,641 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -import ( - "bytes" - "crypto" - "crypto/hmac" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "encoding/base64" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "math/big" - "strings" - "sync" - - "github.com/celzero/firestack/intra/core" - brsa "github.com/celzero/firestack/intra/core/brsa" - "github.com/celzero/firestack/intra/log" - // "github.com/cloudflare/circl/blindsign/blindrsa" -) - -const ( - delim = ":" - msgsize = 32 // min msg size in bytes; >= pipkey.c.prefixLen - cidsize = 32 // client identifier size in bytes - tokensize = 32 // token size in bytes - bidsize = 32 // blind id size in bytes; hmac(msg, pubkey) - blindsize = 256 // blinded message size in bytes - rsizemax = 256 // max blinding factor size in bytes - saltsize = 48 // salt size in bytes; see: hashfn - hashfn = crypto.SHA384 // 48 byte hash fn for RSA-PSS -) - -var ( - errEmptyPipKeyState = errors.New("pipkey: empty pip key state") - errTokenCreat = errors.New("pipkey: cannot create token") -) - -type PipKeyProvider interface { - // Msg returns the PipMsg that this PipKeyProvider is associated with. - // Never nil. - Msg() *PipMsg - // Bid uniquely identifies a blinded PipKeyProvider. - // PipKeyProviders created from same blinded PipKeyState have the same identity. - // If this PipKeyProvider is not yet blinded, it returns nil. - Bid() string - // Blind generates id:blindMsg:blindingFactor:salt:msg - // id is a 64 byte hmac tying blindMsg to the public key - // blindMsg is a 256 byte blinded message - // blindingFactor is upto 256 byte random blinding factor - // salt is 48 bytes random salt (see: hashfn) - // msg is a 32 byte random message (see: msgsize) - Blind() (*PipKeyState, error) - // Finalize calculates actual signature for given blingSig blind signature. - Finalize(blingSig string) (*PipKey, error) -} - -// Strx exists for gomobile type export support for PipMsg & PipToken -type Strx struct { - s string -} - -// PipToken is a 32 byte random token for bespoke auth. -type PipToken Strx - -// PipMsg is a 64 byte hex encoded string that contains: -// - first 32 bytes as message (random) -// - next 32 bytes as client identifier (random) -type PipMsg Strx - -func (s *Strx) S() string { - return s.s -} - -// gomobile does not generate funcs inherited from Strx -func (s *PipToken) S() string { - return s.s -} - -// gomobile does not generate funcs inherited from Strx -func (s *PipMsg) S() string { - return s.s -} - -// AsPipMsg typecast m to PipMsg. -// m must be a 64 bytes hex string -// (32b for msg + 32b for opaque-id). -// Returns nil if the string m is nil or not a valid PipMsg. -func AsPipMsg(m string) *PipMsg { - p := (PipMsg)(Strx{s: m}) - if !p.ok() { - return nil - } - return &p -} - -func NewPipMsgWith(tok *PipToken) *PipMsg { - if tok == nil { - return nil - } - msg := token() - if len(msg) != 2*msgsize { - log.E("pipkey: new: invalid msg size; want %d, got %d", 2*msgsize, len(msg)) - return nil - } - return pipmsgof(msg + tok.s) -} - -// go.dev/play/p/hPFgE9s9tMP -// go.dev/play/p/OTMIv7FLtVs -func pipmsgof(m string) *PipMsg { - // 2 chars per byte in hex - if sz := 2 * (msgsize + cidsize); len(m) < sz { - log.E("pipkey: fromv: invalid msg size; want %d, got %d", sz, len(m)) - return nil - } - // m is a 64 byte hex encoded string + tok is a 64 byte - return AsPipMsg(m) -} - -func (p *PipMsg) v() string { - if p == nil { - return "" - } - return p.s -} - -func (p *PipMsg) ok() bool { - return p != nil && len(p.s) >= 2*(msgsize+cidsize) -} - -func (p *PipMsg) msg() []byte { - if p == nil || !p.ok() { - log.E("pipkey: msg: invalid; got %d", len(p.s)) - return nil - } - // first 32 bytes are the message - return hex2byte(p.s[:2*msgsize]) -} - -func (p *PipMsg) cid() []byte { - if p == nil || !p.ok() { - log.E("pipkey: cid: invalid; got %d", len(p.s)) - return nil - } - // next 32 bytes are the client identifier - return hex2byte(p.s[2*msgsize : 2*(msgsize+cidsize)]) -} - -// Opaque returns the client id part of the PipMsg as hex string. -func (p *PipMsg) Opaque() *PipToken { - if p == nil || !p.ok() { - log.E("pipkey: opaque: invalid; got %d", len(p.s)) - return nil - } - tok, err := asPipToken(p.s[2*(msgsize) : 2*(msgsize+cidsize)]) - if err != nil { - log.E("pipkey: opaque conv: %v", err) - return nil - } - return tok -} - -// Rotate creates a new PipMsg with the same opaque identifier but a different msg. -func (p *PipMsg) Rotate() (new *PipMsg) { - return NewPipMsgWith(p.Opaque()) -} - -type PipKey struct { - // hex encoded 64 byte msg+cid (random) - Msg *PipMsg - // hex encoded 256 byte sig (unblinded signature) - Sig string - // hex encoded 32 byte sha256(sig) (msg signature hash) - SigHash string -} - -func (p *PipKey) V() string { - if p == nil { - return "" - } - - if !p.Msg.ok() { - return "" - } - - // msg+cid:sig:sigHash - return strings.Join([]string{ - p.Msg.v(), - p.Sig, - p.SigHash, - }, delim) -} - -func PipKeyFrom(state string) (*PipKey, error) { - if len(state) <= 0 { - return nil, errEmptyPipKeyState - } - - // msg:sig:sigHash - parts := strings.Split(state, delim) - if len(parts) != 3 { - log.E("pipkey: fromv: expected 3 parts, got %d", len(parts)) - return nil, brsa.ErrInvalidMessageLength - } - - msg := pipmsgof(parts[0]) - if msg == nil || !msg.ok() { - return nil, brsa.ErrInvalidMessageLength - } - - return &PipKey{ - Msg: msg, - Sig: parts[1], - SigHash: parts[2], - }, nil -} - -type PipKeyState struct { - // hex encoded 64 byte id that identifies BlindMsg. - Bid string - // hex encoded 256 byte blind(Msg) - BlindMsg string - // hex encoded blinding factor (up to 256 bytes) - R string - // hex encoded 48 byte salt (random) - Salt string - // hex encoded 32 byte (client) msg (usually, random) - // concatenated with 32 byte (client identifier) token - Msg *PipMsg -} - -func newPipKeyState(id, blindMsg, r, salt, msg string) *PipKeyState { - return &PipKeyState{ - Bid: id, - BlindMsg: blindMsg, - R: r, - Salt: salt, - Msg: pipmsgof(msg), - } -} - -func NewPipKeyStateFrom(state string) (*PipKeyState, error) { - if len(state) <= 0 { - return nil, errEmptyPipKeyState - } - - // id:blindMsg:r:salt:msg - parts := strings.Split(state, delim) - if len(parts) == 1 { - // if there's only one part, it's the message - return &PipKeyState{ - Msg: pipmsgof(parts[0]), - }, nil - } else if len(parts) == 5 { - return &PipKeyState{ - Bid: parts[0], - BlindMsg: parts[1], - R: parts[2], - Salt: parts[3], - Msg: pipmsgof(parts[4]), - }, nil - - } - - log.E("pipkey: fromv: expected either 1 or 5 parts, got %d", len(parts)) - return nil, brsa.ErrInvalidMessageLength -} - -func (p *PipKeyState) V() string { - if p == nil { - return "" - } - - return p.v() -} - -func (p *PipKeyState) v() string { - if p == nil { - return "" - } - - if len(p.BlindMsg) != blindsize { - return p.Msg.v() - } - - return strings.Join([]string{ - p.Bid, - p.BlindMsg, - p.R, - p.Salt, - p.Msg.v(), - }, - delim, - ) -} - -// { -// kty: "RSA", -// alg: "PS384", -// n: "lSFviqAqSHpPOtVgm7...", -// e: "AQAB", -// key_ops: [ "verify" ], -// ext: true -// } -// -// github.com/serverless-proxy/serverless-proxy/blob/5d209e85/src/webcrypto/blindrsa.js#L6-L15 -type pubKeyJwk struct { - Kty string `json:"kty"` // key type: RSA - Alg string `json:"alg,omitempty"` // algorithm: PS384 - N string `json:"n"` // modulus (2048 bits) - E string `json:"e"` // exponent - KeyOps []string `json:"key_ops"` // key operations: verify - Ext bool `json:"ext"` // extractable: true -} - -// pkgen is a struct that implements the PipKeyProvider interface. -type pkgen struct { - mu sync.Mutex // protects all fields below - - pubkey *rsa.PublicKey - - v *brsa.Verifier - state *brsa.State - c *brsa.Client - - bid []byte // 64 bytes id derived from hmac(m=blindMsg, k=pubkey) - cid []byte // 32 bytes client identifier (token); not used in this impl - msg []byte // min 32 bytes random msg specific to this key - blindMsg []byte // 256 bytes blindMsg derived from msg, r, salt -} - -var _ PipKeyProvider = (*pkgen)(nil) - -// NewPipKeyProvider creates a new PipKeyProvider instance. -// pubjwk: JWK string of the public key of the RSA-PSS signer (for which modulus must be 2048 bits, and hash-fn must be SHA384). -// msgOrExistingState: if empty, a new PipKeyProvider is created with a random message, if not empty, it's the state of an existing PipKey. -// Typically, msgOrExistingState is got from PipKeyState.V() -func NewPipKeyProvider(pubjwk []byte, msgOrExistingState string) (PipKeyProvider, error) { - return newPipKey(pubjwk, msgOrExistingState, false) -} - -// NewPipKeyProviderFromMsg creates a new PipKeyProvider instance from a JWK and a msg hex string. -// Generating Blind() for the same msg with the same JWK will NOT result in the same PipKeyState. -// To restore a previous state, use NewPipKeyProvider() with the PipKeyState.V() string. -func NewPipKeyProviderFromMsg(pubjwk []byte, msg string) (PipKeyProvider, error) { - return newPipKey(pubjwk, msg, true) -} - -func newPipKey(bjwk []byte, msgOrExistingState string, msgOnly bool) (PipKeyProvider, error) { - jwk := &pubKeyJwk{} - err := json.Unmarshal(bjwk, jwk) - if err != nil { - return nil, fmt.Errorf("cannot unmarshal public key: %v", err) - } - // base64 decode modulus and exponent into a big.Int - n, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return nil, fmt.Errorf("cannot decode key modulus: %v", err) - } - bn := big.NewInt(0) - bn.SetBytes(n) - // base64 decode exponent into an int - e, err := base64.RawURLEncoding.DecodeString(jwk.E) - if err != nil { - return nil, fmt.Errorf("cannot decode key exponent: %v", err) - } - be := big.NewInt(0) - be.SetBytes(e) - // create rsa.PublicKey - pub := &rsa.PublicKey{ - N: bn, - E: int(be.Int64()), // may overflow on 32-bit - } - // brsa.SHA384PSSDeterministic does not prepend random 32 bytes prefix to k.msg, - // whilst brsa.SHA384PSSRandomized does. ref: brsa.Prepare() which is unused here. - c, err1 := brsa.NewClient(brsa.SHA384PSSDeterministic, pub) - v, err2 := brsa.NewVerifier(brsa.SHA384PSSDeterministic, pub) - if err1 != nil || err2 != nil { - err := core.JoinErr(err1, err2) - log.E("pipkey: new: sha384-pss-det verifier err %v", err) - return nil, err - } - k := &pkgen{ - pubkey: pub, - v: &v, - c: &c, - } - if msgOrExistingState != "" { - // id : blindMsg : r : salt : msg+cid - parts := strings.Split(msgOrExistingState, delim) - if len(parts) == 1 { // if there's only one part, it's the message - pipmsg := pipmsgof(parts[0]) - if pipmsg == nil || !pipmsg.ok() { - log.E("pipkey: new: invalid msg; got %d", len(parts[0])) - return nil, brsa.ErrInvalidMessageLength - } - k.msg = pipmsg.msg() - k.cid = pipmsg.cid() - return k, nil - } - if msgOnly || len(parts) != 5 { - // if there's more than one part, it's the state - // and so we at least 4 parts - return nil, brsa.ErrInvalidMessageLength - } - k.bid = hex2byte(parts[0]) // unique id; hmac(msg, pubkey) - k.blindMsg = hex2byte(parts[1]) - r := hex2BigInt(parts[2]) // blinding factor - rInv, err := modInv(r, k.pubkey.N) - if err != nil { - log.E("pipkey: new: invalid r/rInv; %v", err) - return nil, err - } - salt := hex2byte(parts[3]) - pipmsg := pipmsgof(parts[4]) - if pipmsg == nil || !pipmsg.ok() { - log.E("pipkey: new: invalid msg; got %d", len(parts[0])) - return nil, brsa.ErrInvalidMessageLength - } - k.msg = pipmsg.msg() - k.cid = pipmsg.cid() - // no need to k.c.Prepare() if SHA384PSSDeterministic - if bmsg, state, err := k.c.FixedBlind(k.msg, salt, r, rInv); err != nil { - return nil, err - } else { - k.state = &state - if !bytes.Equal(k.blindMsg, bmsg) { // sanity check - log.E("pipkey: new: invalid blindMsg: got(%s) != want(%s)", - byte2hex(k.blindMsg), byte2hex(bmsg)) - return nil, brsa.ErrInvalidBlind - } - } - } else { - if k.msg, err = brand(msgsize); err == nil { - k.cid, err = brand(cidsize) - } - if err != nil { - log.E("pipkey: new: gen err, %v", err) - return nil, err - } - // k.c.Prepare() is unused for SHA384PSSDeterministic - } - return k, nil -} - -// Msg implements PipKeyProvider. -func (k *pkgen) Msg() *PipMsg { - if k == nil { - log.E("pipkey: msg: nil PipKeyProvider") - return nil - } - pipmsg := pipmsgof(byte2hex(k.msg) + byte2hex(k.cid)) - if pipmsg == nil || !pipmsg.ok() { - log.E("pipkey: msg: invalid PipMsg; got %d", len(k.msg)+len(k.cid)) - return nil - } - return pipmsg -} - -// Bid implements PipKeyProvider. -func (k *pkgen) Bid() string { - k.mu.Lock() - defer k.mu.Unlock() - - if k.bid == nil { - log.E("pipkey: who: not blinded") - return "" - } - - if len(k.bid) != bidsize { - log.E("pipkey: who: invalid size %d; expected: %d", - len(k.bid), bidsize) - return "" - } - - return byte2hex(k.bid) -} - -// Blind implements PipKeyProvider. -func (k *pkgen) Blind() (*PipKeyState, error) { - k.mu.Lock() - defer k.mu.Unlock() - - if k.state != nil { - log.E("pipkey: blind: already blinded") - return nil, brsa.ErrInvalidBlind - } - - blindMsg, verifierState, err := k.c.Blind(rand.Reader, k.msg) - if err != nil { - log.E("pipkey: blind: %v", err) - return nil, err - } - - r := verifierState.Factor() - salt := verifierState.Salt() // nil for SHA384PSSZeroDeterministic/SHA384PSSZeroRandomized - - if r == nil { - log.E("pipkey: blind: invalid r") - return nil, brsa.ErrInvalidBlind - } - - k.blindMsg = blindMsg - k.bid = hmac256(k.blindMsg, k.pubkey.N.Bytes()) // must match with server-side impl - k.state = &verifierState - - if len(k.bid) != bidsize || len(k.blindMsg) != blindsize || len(r.Bytes()) > rsizemax || len(salt) != saltsize || len(k.msg) != msgsize || len(k.cid) != cidsize { - log.E("pipkey: blind: invalid state; id %d, blindMsg %d, r %d, salt %d, msg %d+%d", - len(k.bid), len(k.blindMsg), len(r.Bytes()), len(salt), len(k.msg), len(k.cid)) - return nil, brsa.ErrUnexpectedSize - } - - // existing state; id : blindMsg : r : salt : msg+cid - return newPipKeyState( - byte2hex(k.bid), - byte2hex(blindMsg), - bigInt2hex(r), - byte2hex(salt), - k.Msg().v(), // msg + cid - ), nil -} - -// Finalize implements PipKeyProvider. -func (k *pkgen) Finalize(blindSig string) (*PipKey, error) { - return k.finalize(blindSig) -} - -func (k *pkgen) finalize(blindSig string) (*PipKey, error) { - k.mu.Lock() - defer k.mu.Unlock() - - if k.state == nil { - log.E("pipkey: finalize: not blinded") - return nil, brsa.ErrInvalidBlind - } - var sigbytes []byte - // unblind using r and salt - sigbytes, err := k.c.Finalize(*k.state, hex2byte(blindSig)) - if err != nil { - log.E("pipkey: finalize: %v", err) - return nil, err - } - // verify the unblinded sig using the public key - err = k.v.Verify(k.msg, sigbytes) - if err != nil { - log.E("pipkey: finalize: verify: %v", err) - return nil, err - } - hashedsigbytes := sha256sum(sigbytes) - - return &PipKey{ - Msg: k.Msg(), - Sig: byte2hex(sigbytes), - SigHash: byte2hex(hashedsigbytes), - }, nil -} - -// Token gnerates a 32 byte random as hex (auths dataplane ops) -func Token() (*PipToken, error) { - return asPipToken(token()) -} - -func asPipToken(tok string) (*PipToken, error) { - if len(tok) != 2*tokensize { - return nil, errTokenCreat - } - // StrOf interns the string - return (*PipToken)(&Strx{s: tok}), nil -} - -func token() string { - tok, err := brand(tokensize) - if err != nil { - log.W("pipkey: no token; err: %v", err) - return "" - } - return byte2hex(tok) -} - -func brand(sz int) ([]byte, error) { - r := make([]byte, sz) - _, err := rand.Read(r) - return r, err -} - -// hex2byte returns the byte slice represented by the hex string s. -func hex2byte(s string) []byte { - b, err := hex.DecodeString(s) - if err != nil { - log.E("piph2: hex2byte: err %v", err) - } - return b -} - -// byte2hex returns the hex representation of the byte slice b. -func byte2hex(b []byte) string { - return hex.EncodeToString(b) -} - -func hex2BigInt(s string) *big.Int { - b, err := hex.DecodeString(s) - if err != nil { - log.E("piph2: hex2BigInt: err %v", err) - } - return new(big.Int).SetBytes(b) -} - -func bigInt2hex(b *big.Int) (h string) { - return hex.EncodeToString(b.Bytes()) -} - -// sha256sum returns the SHA256 digest (32 byte) of the message m. -func sha256sum(m []byte) []byte { - digest := sha256.Sum256(m) - return digest[:] -} - -// hmac256 returns the HMAC-SHA256 (32 byte) of the message m using the key k. -func hmac256(m, k []byte) []byte { - mac := hmac.New(sha256.New, k) - mac.Write(m) - return mac.Sum(nil) -} - -func modInv(g *big.Int, n *big.Int) (z *big.Int, err error) { - z = new(big.Int).ModInverse(g, n) - if z == nil { - err = brsa.ErrInvalidBlind - } - return -} diff --git a/intra/backend/ipn_proxies.go b/intra/backend/ipn_proxies.go deleted file mode 100644 index ef17037c..00000000 --- a/intra/backend/ipn_proxies.go +++ /dev/null @@ -1,413 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -import "fmt" - -const ( // see ipn/proxies.go - // IDs for default proxies - - // blocks all traffic (built-in) - Block = "Block" - // may send traffic out via underlying network (built-in) - // see: tun2socks.Loopback; alias for dnsx.NetNoProxy - Base = "Base" - // always sends traffic out via underlying network (built-in) - // see: Controller.Protect; alias for dnsx.NetExitProxy - Exit = "Exit" - // proxies incoming connections (built-in) - Ingress = "Ingress" - // Auto uses ipn.Exit or any of the RPN proxies (built-in) - Auto = "Auto" - // RPN Win proxy (must be registered by Rpn.RegisterWin) - RpnWin = WG + "y" + RPN - // Alias for RPN Win - RpnPro = RpnWin - // RPN WebSockets (unused) - RpnWs = PIPWS + RPN - // RPN HTTP/2 (unused) - RpnH2 = PIPH2 + RPN - // RPN Exit hopping over NAT64 (built-in) - Rpn64 = NAT64 + RPN - // Orbot: Base Tor-as-a-SOCKS5 proxy - OrbotS5 = "OrbotSocks5" - // Orbot: Base Tor-as-a-HTTP/1.1 proxy - OrbotH1 = "OrbotHttp1" - // Global: HTTP/1.1 proxy if required by underlying network. - GlobalH1 = "GlobalHttp1" - - // type of proxies - - // SOCKS5 proxy type - SOCKS5 = "socks5" - // HTTP/1.1 proxy type - HTTP1 = "http1" - // WireGuard-as-a-proxy type and prefix - WG = "wg" - // No proxy (uses underlying network), ex: Base, Block, Ingress - NOOP = "noop" - // Egress, ex: Exit - INTERNET = "net" - // WireGuard-as-a-proxy w/ UDP GRO/GSO prefix (experimental) - WGFAST = "gsro" - // PIP: HTTP/2 proxy prefix (unused) - PIPH2 = "piph2" - // PIP: WebSockets proxy prefix (unused) - PIPWS = "pipws" - // A NAT64 router (prefix) - NAT64 = "nat64" - // Rethink Proxy Network (suffix) - RPN = "rpn" - - // status of proxies - - // proxy paused until resumed; will not dial - TPU = 3 - // proxy UP but not responding - TNT = 2 - // proxy idle - TZZ = 1 - // proxy UP but not yet OK - TUP = 0 - // proxy OK - TOK = -1 - // proxy OK but erroring out - TKO = -2 - // proxy stopped - END = -3 -) - -// RpnOps carries options that control the behaviour of Update and RegisterWin. -// Fields are unexported; use the Set* methods to configure. -type RpnOps struct { - rotateCreds bool // force a new WireGuard keypair on the next Update - permaCreds bool // use permanent WireGuard (local) addresses if available - forceFetchServers bool // force server-list refresh on the next Update - newPort uint16 // fixed WireGuard port; 0 = random per wsRandomPort() -} - -func NewRpnOps() *RpnOps { - return &RpnOps{} -} - -func (o *RpnOps) String() string { - return fmt.Sprintf("rotate: %t; perma: %t; forceFetchServers: %t; port: %d", - o.rotateCreds, o.permaCreds, o.forceFetchServers, o.newPort) -} - -// SetRotateCreds forces generation of a new WireGuard keypair on the next Update. -// Note: Rotate and Perma are mutually exclusive; setting one to true will set the other to false. -func (o *RpnOps) SetRotateCreds(v bool) { o.rotateCreds = v; o.permaCreds = false } - -// SetPermaCreds enables/disables using the permanent WG credential set in Conf(). -// Note: Rotate and Perma are mutually exclusive; setting one to true will set the other to false. -func (o *RpnOps) SetPermaCreds(v bool) { o.permaCreds = v; o.rotateCreds = false } - -// SetForceFetchServers forces the server-list refresh on the next Update. -func (o *RpnOps) SetForceFetchServers(v bool) { o.forceFetchServers = v } - -// SetPort pins a specific WireGuard port; 0 means random (default). -func (o *RpnOps) SetPort(port int32) { - if port >= 0 && port <= 65535 { - o.newPort = uint16(port) - } -} - -// Rotate reports whether a new WG keypair should be generated. -func (o RpnOps) Rotate() bool { return o.rotateCreds } - -// Perma reports whether permanent WG credentials should be used in Conf(). -func (o RpnOps) Perma() bool { return o.permaCreds } - -// FetchServers reports whether the server-list fetch should be forced. -func (o RpnOps) FetchServers() bool { return o.forceFetchServers } - -// Port returns the pinned WireGuard port, or 0 if none is set. -func (o RpnOps) Port() uint16 { - return o.newPort -} - -type Rpn interface { - // EntitlementFrom returns the RpnEntitlement represented by entitlementOrStateJson. - // `did` is the device identifier to use for this entitlement, if applicable; and `rpnProviderID` is the RPN provider for this entitlement, if applicable. - // `rpnProviderID` is the RPN provider to use with this entitlement (ex: RpnWin, etc). - EntitlementFrom(entitlementOrStateJson []byte, rpnProviderID, did string) (RpnEntitlement, error) - // RegisterWin registers (or re-registers) a Windscribe account. - // ops may be nil to use default behaviour. - RegisterWin(entitlementOrStateJson []byte, did string, ops *RpnOps) (json []byte, err error) - // UnregisterWin unregisters a Windscribe installation. - UnregisterWin() bool - // TestWin connects to the Windscribe gateway and returns its IP if reachable. - TestWin() (ips string, errs error) - // TestExit64 connects to public NAT64 endpoints and returns reachable ones. - TestExit64() (ips string, errs error) - // Win returns a Windscribe WireGuard proxy. - Win() (wg RpnProxy, err error) - // Pip returns a RpnWs proxy. - Pip() (ws RpnProxy, err error) - // Exit64 returns a Exit proxy hopping over preset publicly-available - // NAT64 proxies. - Exit64() (nat64 RpnProxy, err error) -} - -type Proxy interface { - // ID returns the ID of this proxy. - ID() string - // Type returns the type of this proxy. - Type() string - // Returns x.Router. - Router() Router - // Client returns a client that uses this proxy. - Client() Client - // GetAddr returns the address of this proxy. - GetAddr() string - // DNS returns the ip:port or doh/dot url or dnscrypt stamp for this proxy. - DNS() string - // Status returns the status of this proxy. - Status() int - // Ping pings this proxy. - Ping() bool - // Pause pauses this proxy. - Pause() bool - // Resume resumes this proxy. - Resume() bool - // Stop stops this proxy. - Stop() error - // Refresh re-registers this proxy, if necessary. - Refresh() error -} - -type RpnProxy interface { - Proxy - RpnAcc - // Fork adds proxy for country code, cc. - Fork(cc string) (Proxy, error) - // Redo re-forks the main proxy and all its kids. - Redo() (err error) - // PingAll pings the main proxy and all its kids. - PingAll() (csvpids string, err error) - // Purge removes proxy for country code, cc. - Purge(cc string) bool - // Get returns proxy for country code, cc. - Get(cc string) (Proxy, error) - // Kids returns csv of forked proxy PIDs, excluding this one. - Kids() (csvpids string) -} - -// RpnAcc represents an account with RPN provider. -type RpnAcc interface { - // Who returns identifier for this account; may be empty. - Who() string - // State returns the state (as json) of the account. - State() ([]byte, error) - // Ops returns the RpnOps that control the behaviour of Update. - Ops() *RpnOps - // Created returns the time (unix millis) currently active account was created. - Created() int64 - // Updated returns the time (unix millis) currently active account was updated. - Updated() int64 - // Expires returns the time (unix millis) currently active account expires. - Expires() int64 - // Locations returns RpnServers encapsulating this proxy's worldwide server presence. - Locations() (RpnServers, error) - // Update updates the account creating new state; ops may be nil to retain current RpnOps. - Update(ops *RpnOps) (newstate []byte, err error) -} - -// RpnEntitlement represents access to a proxy service. -type RpnEntitlement interface { - // ProviderID is RPN provider for this entitlement. - ProviderID() string - // Cid is the Client identifier. - CID() string - // DID is the Device identifier, if any. - DID() string - // Token is the entitlement token, if any. - Token() string - // Expiry is ISO 8601 string of the expiry time of this entitlement, if any. - Expiry() string - // "valid", "invalid", "banned", "expired", "unknown" - Status() string - // AllowRestore returns true if this entitlement can be transferred around for restores. - AllowRestore() bool - // Test is set if this entitlement is valid only in the test domain. - Test() bool - // Json returns entitlement (but not the state) as json. - Json() ([]byte, error) -} - -type Proxies interface { - // Underlay creates a [NOOP] proxy (that always connects over underlying network), - // but one that uses a custom Controller. - // This proxy is not tracked (APIs like GetProxy won't return these). - Underlay(id string, c Controller) Proxy - // Add adds a proxy to this multi-transport. - // "id" is a free-form unique identifier for this proxy, except: - // "id" for WireGuard proxies must be prefixed with [WG] - // "url" is WireGuard UAPI configuration. - // For HTTP1 and SOCKS5 proxies, "url" must be of the form: - // scheme://usr:pwd@domain.tld:port/p/a/t/h?q&u=e&r=y#f,r - // where scheme is "http" or "socks5", usr and/or pwd are optional - // port is the port number, and domain.tld could also be ip address. - AddProxy(id, url string) (Proxy, error) - // Remove removes a transport from this multi-transport. - RemoveProxy(id string) bool - // GetProxy returns a transport from this multi-transport. - GetProxy(id string) (Proxy, error) - // TestHop returns empty diag if origin can hop to via, - // otherwise returns a diagnosis of why it couldn't. - // Only WireGuard via & origin are supported, for now. - TestHop(via, origin string) (diag string) - // Hop chains two proxies in the order of origin dialing through via. - // Only WireGuard via & origin are supported, for now. - Hop(via, origin string) error - // Router returns a lowest common denomination router for this multi-transport. - Router() Router - // RPN returns the Rethink Proxy Network api. - Rpn() Rpn - // Refresh re-registers proxies and returns a csv of active ones. - RefreshProxies() string -} - -type Router interface { - // IP4 returns true if this router supports IPv4. - IP4() (y bool) - // IP6 returns true if this router supports IPv6. - IP6() (y bool) - // MTU returns the MTU of this router. - MTU() (mtu int, err error) - // Stats returns the stats of this router. - Stat() (s *RouterStats) - // Via returns the gateway for this router, if any. - Via() (gw Proxy, err error) - // Reaches returns true if any host:port or ip:port is dialable. - Reaches(hostportOrIPPortCsv string) (y bool) - // Contains returns true if this router can route ipprefix. - Contains(ipprefix string) (y bool) -} - -type Client interface { - // IP4 returns information about this client's remote IPv4. - IP4() (*IPMetadata, error) - // IP6 returns information about this client's remote IPv6. - IP6() (*IPMetadata, error) - // TODO: Move Reaches here? - // TODO: Fetch(method, url, headers, body) (status, headers, body, err) -} - -// ProxyListener is a listener for proxy events. -type ProxyListener interface { - // OnProxyAdded is called when a proxy is added. - OnProxyAdded(id string) - // OnProxyRemoved is called when a proxy is removed except when all - // proxies are stopped, in which case OnProxiesStopped is called. - OnProxyRemoved(id string) - // OnProxyStopped is called when a proxy is stopped instead of being - // removed (that is, this callback is not called in all proxy stop scenarios). - // A stopped proxy, if added again, is replaced/updated instead; and subsequently, - // the onProxyAdded callback is invoked. - OnProxyStopped(id string) - // OnProxiesStopped is called when all proxies are stopped. - // Note: OnProxyRemoved is not called for each proxy, even - // if they are removed instead of being merely "stopped". - OnProxiesStopped() -} - -// RouterStats lists interesting stats of a Router. -type RouterStats struct { - // addresses (csv) of the router - Addrs string - // bytes received - Rx int64 - // bytes transmitted - Tx int64 - // receive error count - ErrRx int64 - // transmit error count - ErrTx int64 - // last (most recent) receive in millis - LastTx int64 - // last (most recent) transmit in millis - LastRx int64 - // last non-wg (connect/dial) error - LastErr string - // last wg recv (read) error - LastRxErr string - // last wg send (write) error - LastTxErr string - // last successful receive in millis - LastGoodRx int64 - // last successful transmit in millis - LastGoodTx int64 - // last (most recent) handshake or ping or connect millis - LastOK int64 - // last refresh time in millis - LastRefresh int64 - // uptime in millis - Since int64 - // Current proxy status - Status string - // Current reason for Status - StatusReason string - // Extra is extra info about this router - Extra string -} - -type RpnServers interface { - // Get returns the RpnServer at index i; errors if i is out of bounds. - Get(i int) (*RpnServer, error) - // Len returns the number of RpnServers. - Len() int - // Json returns the RpnServer struct as JSON bytes. - Json() ([]byte, error) -} - -type RpnServer struct { - // Name of the server, if any. - Name string - // CSV of IP:Port and/or Domain:Port - Addrs string - // Country code of the location. - CC string - // City name of the location. - City string - // Key for RpnProxy.Fork() to get an RpnProxy instance for this RpnServer. - Key string - // Load score of this server (lower is better) - Load int32 - // Link speed in Mbps (higher is better). - Link int32 - // Number of active servers in this CC+City. - Count int32 - // Premium - Premium bool -} - -type IPMetadata struct { - // Proxy ID used to fetch this IP metadata. - ID string - // Provider that provided this IP metadata. - ProviderURL string - // IP address, never empty. - IP string - // ASN number, may be empty. - ASN string - // ASN organization name, may be empty. - ASNOrg string - // ASN domain, may be empty. - ASNDom string - // Country code, may be empty. - CC string - // City name, may be empty. - City string - // Address, may be empty. - Addr string - // Latitude, may be zero. - Lat float64 - // Longitude, may be zero. - Lon float64 -} diff --git a/intra/backend/ipn_wgkeygen.go b/intra/backend/ipn_wgkeygen.go deleted file mode 100644 index 9a5ec8b2..00000000 --- a/intra/backend/ipn_wgkeygen.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package backend - -import ( - "crypto/ed25519" - "crypto/rand" - "crypto/subtle" - "encoding/base64" - "encoding/hex" - "errors" - "fmt" - - "golang.org/x/crypto/curve25519" -) - -// from: github.com/WireGuard/wireguard-windows/blob/dcc0eb72a/conf/parser.go#L121 - -const klen = ed25519.SeedSize - -type ( - eckey [klen]byte -) - -var _ WgKey = (*eckey)(nil) - -type WgKey interface { - // IsZero returns true if the key is all zeros. - IsZero() bool - // Base64 returns the key as a base64-encoded string. - Base64() string - // Hex returns the key as a hex-encoded string. - Hex() string - // Mult returns the key multiplied by the basepoint (curve25519). - Mult() WgKey -} - -func (k *eckey) Hex() string { - return hex.EncodeToString(k[:]) -} - -func (k *eckey) Base64() string { - return base64.StdEncoding.EncodeToString(k[:]) -} - -func (k *eckey) IsZero() bool { - var zeros eckey - return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 -} - -func (k *eckey) Mult() WgKey { - var p [klen]byte - curve25519.ScalarBaseMult(&p, (*[klen]byte)(k)) - return (*eckey)(&p) -} - -func newPresharedKey() (*eckey, error) { - var k [klen]byte - _, err := rand.Read(k[:]) - if err != nil { - return nil, err - } - return (*eckey)(&k), nil -} - -func NewWgPrivateKey() (WgKey, error) { - k, err := newPresharedKey() - if err != nil { - return nil, err - } - k[0] &= 248 - k[31] = (k[31] & 127) | 64 - return k, nil -} - -func NewWgPrivateKeyFrom(k [klen]byte) WgKey { - k[0] &= 248 - k[31] = (k[31] & 127) | 64 - return (*eckey)(&k) -} - -func parseKeyBase64(s string) (*eckey, error) { - k, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, fmt.Errorf("invalid key: %v", err) - } - if len(k) != klen { - return nil, errors.New("keys must decode to exactly 32 bytes") - } - var key eckey - copy(key[:], k) - return &key, nil -} - -func NewWgPrivateKeyOf(b64 string) (WgKey, error) { - return parseKeyBase64(b64) -} diff --git a/intra/backend/ipn_wgkeygen_test.go b/intra/backend/ipn_wgkeygen_test.go deleted file mode 100644 index 6ef6c1f2..00000000 --- a/intra/backend/ipn_wgkeygen_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package backend - -import ( - "testing" - - "github.com/celzero/firestack/intra/log" -) - -// create a new private key and prints corres pubkey -func TestGenKeypair(t *testing.T) { - sk, err := NewWgPrivateKey() - if err != nil { - t.Error("failed to generate private key: ", err) - } else { - pk := sk.Mult() - t.Log("pub: ", pk.Base64(), "sk: ", sk.Base64()) - } -} - -func TestRadixSearch(t *testing.T) { - log.SetLevel(log.VERBOSE) - const goog = "google.com" - const wildgoog = ".google.com" - const mailgoog = "mail.google.com" - const dnsgoog = "dns.google.com" - const pgoog = "prefix" + goog - - r := NewRadixTree() - r.Set(goog, "goog") - r.Set(wildgoog, "wildgoog") - r.Set(mailgoog, "mailgoog") - - v0 := r.Get(goog) // goog if r.Set(goog, "goog") is uncommented; empty otherwise - v1 := r.Get(wildgoog) - v2 := r.Get(mailgoog) - v3 := r.Get(dnsgoog) // empty - v4 := r.Get(pgoog) // empty regardless of r.Set(goog, "goog") - - t.Log("v0?: ", v0, "\tv1: ", v1, "\tv2: ", v2, "\tv3?: ", v3, "\tv4?: ", v4) - - w0 := r.GetAny(goog) // goog if r.Set(goog, "goog") is uncommented; wildgoog otherwise - w1 := r.GetAny(wildgoog) // wildgoog - w2 := r.GetAny(mailgoog) // mailgoog - w3 := r.GetAny(dnsgoog) // wildgoog - w4 := r.GetAny(pgoog) // empty regardless of r.Set(goog, "goog") - - t.Log("w0: ", w0, "\tw1: ", w1, "\tw2: ", w2, "\tw3: ", w3, "\tw4: ", w4) -} diff --git a/intra/backend/netstat.go b/intra/backend/netstat.go deleted file mode 100644 index 8481459c..00000000 --- a/intra/backend/netstat.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -// NICStat is a collection of network interface statistics for the current tunnel. -type NICStat struct { - // bytes received - Rx string - // packets received - RxPkts int64 - // bytes sent - Tx string - // packets sent - TxPkts int64 - // invalid packets - Invalid int64 - // unknown l4 packets - L4Unknown int64 - // unknown l3 packets - L3Unknown int64 - // l4 drops - L4Drops int64 - // drops - Drops int64 -} - -type TUNStat struct { - Open bool - Up bool - Mtu int32 - Sid int - EpStats string - PcapMode string -} - -type NICInfo struct { - Name string - HwAddr string - Addrs string - Mtu int32 - Up bool - Running bool - Promisc bool - Lo bool - Arp int32 - Forwarding4 bool - Forwarding6 bool -} - -// IPFwdStat is a collection of IP forwarding statistics for the current tunnel. -type IPFwdStat struct { - // errors - Errs int64 - // unreachable - Unrch int64 - // no route - NoRoute int64 - // no endpoint - NoHop int64 - // packet too big - PTB int64 - // TTL timeouts - Timeouts int64 - // drops - Drops int64 -} - -// IPStat is a collection of IP statistics for the current tunnel. -type IPStat struct { - // invalid destination addresses - InvalidDst int64 - // invalid source addresses - InvalidSrc int64 - // invalid fragments - InvalidFrag int64 - // invalid packets - InvalidPkt int64 - // packet errors - Errs int64 - // packets received from l2 - Rcv int64 - // packets sent to l4 - Snd int64 - // packet receive errors from l2 - ErrRcv int64 - // packet send errors to l4 - ErrSnd int64 -} - -// ICMPStat is a collection of ICMP statistics for the current tunnel. -type ICMPStat struct { - Rcv4 int64 // ICMPv4 messages received - Rcv6 int64 // ICMPv6 messages received - Snd4 int64 // ICMPv4 messages sent - Snd6 int64 // ICMPv6 messages sent - UnrchRcv4 int64 // ICMPv4 unreachable received - UnrchRcv6 int64 // ICMPv6 unreachable received - UnrchSnd4 int64 // ICMPv4 unreachable sent - UnrchSnd6 int64 // ICMPv6 unreachable sent - Invalid4 int64 // ICMPv4 invalid messages - Invalid6 int64 // ICMPv6 invalid messages - TimeoutSnd4 int64 // ICMPv4 TTL timeouts sent - TimeoutSnd6 int64 // ICMPv6 TTL timeouts sent - TimeoutRcv4 int64 // ICMPv4 TTL timeouts received - TimeoutRcv6 int64 // ICMPv6 TTL timeouts received - Drops4 int64 // ICMPv4 messages dropped - Drops6 int64 // ICMPv6 messages dropped -} - -// TCPStat is a collection of TCP statistics for the current tunnel. -type TCPStat struct { - Active int64 // connecting - Passive int64 // listening - Est int64 // current established - EstClo int64 // established but closed - EstRst int64 // established but RST - EstTo int64 // established but timeout - Con int64 // current connected - ConFail int64 // failed connect attempts - PortFail int64 // failed port reservations - SynDrop int64 // syns dropped - AckDrop int64 // acks dropped - ErrChecksum int64 // bad checksums - ErrRcv int64 // invalid recv segments - ErrSnd int64 // segment send errors - Rcv int64 // segments received - Snd int64 // segments sent - Retrans int64 // retransmissions - Timeouts int64 // connection timeouts - Drops int64 // drops by max inflight threshold -} - -// UDPStat is a collection of UDP statistics for the current tunnel. -type UDPStat struct { - ErrChecksum int64 // bad checksums - ErrRcv int64 // recv errors - ErrSnd int64 // send errors - Snd int64 // packets sent - Rcv int64 // packets received - PortFail int64 // unknown port - Drops int64 // rcv buffer errors -} - -type RDNSInfo struct { - Open bool - Debug bool - Recording bool - - Looping bool - Slowdown bool - NewWireGuard string - Transparency bool - HappyEyeballs bool - PanicTest bool - FatalTest bool - SystemDNSForUndelegated bool - DefaultDNSAsFallback bool - SetUserAgent bool - OwnTunFd bool - PortForward bool - EIMEIF string - - Dialer4 bool - Dialer6 bool - DialerOpts string - TunMode string - - DNSPreferred string - DNSDefault string - DNSSystem string - DNS string - ALG string - - ProxiesHas4 bool - ProxiesHas6 bool - ProxyLastOK string - ProxySince string - ProxyStatus string - Proxies string - - AutoMode string - AutoDialsParallel bool - - LinkMTU string - - OpenConnsTCP string - OpenConnsUDP string - OpenConnsICMP string -} - -// ref: github.com/google/gops/blob/35c854fb84a/agent/agent.go -type GoStat struct { - Alloc string // bytes allocated and not yet freed - TotalAlloc string // total bytes allocated in aggregate - Sys string // bytes obtained from system - Lookups int64 // number of pointer lookups - Mallocs int64 // number of mallocs - Frees int64 // number of frees - - HeapAlloc string // bytes allocated on heap - HeapSys string // heap obtained from system - HeapIdle string // bytes in idle spans - HeapInuse string // bytes in non-idle span - HeapReleased string // bytes released to the OS - HeapObjects int64 // total number of allocated objects - - StackInuse string // bytes used by stack allocator - StackSys string // bytes obtained from system for stack allocator - MSpanInuse string // mspan allocs - MSpanSys string // bytes obtained from system for mspan structures - MCacheInuse string // mcache structures - MCacheSys string // bytes obtained from system for mcache structures - BuckHashSys string // bytes used by the profiling bucket hash table - - EnableGC bool // GC enabled - DebugGC bool // GC debug - GCSys string // bytes used for garbage collection system metadata - OtherSys string // bytes used for off-heap allocations - NextGC string // target heap size of the next GC - LastGC string // last run in heap - PauseSecs int64 // total STW pause time - NumGC int32 // number of GC runs - NumForcedGC int32 // number of forced GC runs - GCCPUFraction string // fraction of CPU time used by GC - - NumGoroutine int64 // number of goroutines - NumCgo int64 // number of cgo calls - NumCPU int64 // number of CPUs - - Trac string // gotraceback - Pers string // personality - Args string // command line arguments - Env string // environment variables -} - -type GoMetrics struct { - M string -} - -// NetStat is a collection of network engine statistics. -type NetStat struct { - NICSt NICStat - TUNSt TUNStat - NICIn NICInfo - IPSt IPStat - FWDSt IPFwdStat - ICMPSt ICMPStat - TCPSt TCPStat - UDPSt UDPStat - RDNSIn RDNSInfo - GOSt GoStat - GOMet GoMetrics -} - -// NIC returns the network interface statistics. -func (n *NetStat) NIC() *NICStat { return &n.NICSt } - -// NICI returns the network interface info. -func (n *NetStat) NICINFO() *NICInfo { return &n.NICIn } - -// TUN returns the internal tunnel statistics. -func (n *NetStat) TUN() *TUNStat { return &n.TUNSt } - -// IP returns the IP statistics. -func (n *NetStat) IP() *IPStat { return &n.IPSt } - -// FWD returns the IP forwarding statistics. -func (n *NetStat) FWD() *IPFwdStat { return &n.FWDSt } - -// ICMP returns the ICMP statistics. -func (n *NetStat) ICMP() *ICMPStat { return &n.ICMPSt } - -// TCP returns the TCP statistics. -func (n *NetStat) TCP() *TCPStat { return &n.TCPSt } - -// UDP returns the UDP statistics. -func (n *NetStat) UDP() *UDPStat { return &n.UDPSt } - -// RDNS returns the RDNS settings / info. -func (n *NetStat) RDNSINFO() *RDNSInfo { return &n.RDNSIn } - -// GO returns the Go runtime statistics. -func (n *NetStat) GO() *GoStat { return &n.GOSt } - -// GO2 returns the GO runtime metrics on g, m, & p. -func (n *NetStat) GO2() *GoMetrics { return &n.GOMet } diff --git a/intra/backend/protect.go b/intra/backend/protect.go deleted file mode 100644 index 75bf3d01..00000000 --- a/intra/backend/protect.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -const ( // see protect/protect.go - UidSelf = "rethink" - UidSystem = "system" - Localhost = "localhost" -) - -type Console interface { - Log(int32, string) - LogFD(readAfterDup int) bool - CrashFD(readUntilEOF int) bool -} - -// Controller provides a way to bind and protect socket file descriptors. -type Controller interface { - // Bind4 binds fd to any internet-capable IPv4 interface. - Bind4(who, addrport string, fd int) - // Bind6 binds fd to any internet-capable IPv6 interface. - // also: github.com/lwip-tcpip/lwip/blob/239918c/src/core/ipv6/ip6.c#L68 - Bind6(who, addrport string, fd int) - // Protect marks fd as protected. - Protect(who string, fd int) -} - -type Protector interface { - // UIP returns ip (network-order byte) to bind given a local/remote ipp (ip:port). - // (unused) - UIP(ipp string) []byte -} diff --git a/intra/backend/rnet_services.go b/intra/backend/rnet_services.go deleted file mode 100644 index 84fa555b..00000000 --- a/intra/backend/rnet_services.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package backend - -const ( - // type of services - - // SOCKS5 - SVCSOCKS5 = "svcsocks5" - // HTTP - SVCHTTP = "svchttp" - // SOCKS5 with forwarding proxy - PXSOCKS5 = "pxsocks5" - // HTTP with forwarding proxy - PXHTTP = "pxhttp" - - // status of proxies - - // svc UP - SUP = 0 - // svc OK - SOK = 1 - // svc not OK - SKO = -1 - // svc stopped - SOP = -2 -) - -type Server interface { - // Sets the proxy as the next hop. - Hop(p Proxy) error - // ID returns the ID of the server. - ID() string - // Start starts the server. - Start() error - // Type returns the type of the server. - Type() string - // Addr returns the address of the server. - GetAddr() string - // Status returns the status of the server. - Status() int - // Stop stops the server. - Stop() error - // Refresh re-registers the server. - Refresh() error -} - -type Services interface { - // Add adds a server. - AddServer(id, url string) (Server, error) - // Bridge bridges or unbridges server with proxy. - Bridge(serverid, proxyid string) error - // Remove removes a server. - RemoveServer(id string) (ok bool) - // RemoveAll removes all servers. - RemoveAll() - // Get returns a Server. - GetServer(id string) (Server, error) - // Refresh re-registers servces and returns a csv of active ones. - RefreshServers() (active string) -} - -type ServerSummary struct { - // http1, socks5, etc. - Type string - // Server ID. - SID string - // Proxy ID (hop) that handled egress, if any. - PID string - // Connection id - CID string - // Total uploaded (bytes). - Tx int64 - // Total downloaded (bytes). - Rx int64 - // Conn open duration (millis). - Duration int64 - // Error messages, if any. - Msg string -} - -// ServerListener receives Server events. -type ServerListener interface { - // SvcRoute decides how to forward an incoming connection over service (sid). - SvcRoute(sid, pid, network, sipport, dipport string) *Tab - // OnSvcComplete reports summary after a connection closes. - OnSvcComplete(*ServerSummary) -} - -type Tab struct { - // CID is the ID of this connection. - CID string - // Block is true if this connection should be blocked. - Block bool -} diff --git a/intra/bootstrap.go b/intra/bootstrap.go deleted file mode 100644 index 7798da9a..00000000 --- a/intra/bootstrap.go +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "context" - "errors" - "fmt" - "net/netip" - "net/url" - "strings" - "sync" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dns53" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/doh" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -const ( - // bootid marks the transports as "Protected" with internal - // hostname to ip cache (in ipmap.go via dnsx.RegisterAddrs) - bootid = dnsx.Bootstrap - // protected hostnames are only used by dnsx.DNS53 transport. - protectedHostname = protect.UidSelf // or protect.UidSystem - // special hostname is used only by dnsx.Goos transport. - builtinHostname = protect.Localhost -) - -var ( - errDefaultTransportType = errors.New("unknown default transport type") - errDefaultTransportNotReady = errors.New("default transport not ready") - errCannotStart = errors.New("missing proxies or controller") -) - -var ( - localip4 = "127.0.0.1" - localip6 = "::1" -) - -// DefaultDNS is the resolver used by all dialers. -type DefaultDNS interface { - x.DNSTransport - kickstart(px ipn.ProxyProvider) error - reinit(typ, ipOrUrl, ips string) error -} - -type bootstrap struct { - ctx context.Context - proxies ipn.ProxyProvider // never nil if underlying transport is set - - mu sync.RWMutex // protects following fields: - tr dnsx.Transport // the underlying transport - typ string // DOH or DNS53 - ipports string // never empty for DNS53 - url string // never empty for DOH - hostname string // never empty -} - -var _ DefaultDNS = (*bootstrap)(nil) -var _ dnsx.Transport = (*bootstrap)(nil) - -// NewDefaultDNS creates a new DefaultDNS resolver of type typ. For typ DOH, -// url scheme is http or https; for typ DNS53, url is ipport or csv(ipport). -// ips is a csv of ipports for typ DOH, and nil for typ DNS53. -func NewDefaultDNS(typ, url, ips string) (DefaultDNS, error) { - b := new(bootstrap) - b.ctx = context.TODO() - - if err := b.reinit(typ, url, ips); err != nil { - return nil, err - } - - // context.AfterFunc(b.ctx, func() { b.Stop() }) - log.I("dns: default: %s new %s %s %s", typ, url, b.hostname, ips) - - return b, nil -} - -// NewBuiltinDefaultDNS creates a new DefaultDNS resolver of type dnsx.DNS53. -// It may either use OS provided or network provided DNS resolver, & will not -// work if the tunnel is in "Loopback" mode; create w/ NewDefaultDNS instead. -func NewBuiltinDefaultDNS() (DefaultDNS, error) { - b := new(bootstrap) - b.ctx = context.TODO() - - if err := b.reinit("", "", ""); err != nil { - return nil, err - } - - // context.AfterFunc(b.ctx, func() { b.Stop() }) - log.I("dns: default: built-in") - - return b, nil -} - -func (b *bootstrap) newDefaultDohTransportLocked() (dnsx.Transport, error) { - ips := strings.Split(b.ipports, ",") - if len(b.url) > 0 && len(ips) > 0 { - return doh.NewTransport(b.ctx, bootid, b.url, ips, b.proxies) - } - return nil, errCannotStart -} - -func (b *bootstrap) newDefaultTransportLocked() (dnsx.Transport, error) { - if ipcsv := b.ipports; len(ipcsv) > 0 { - return dns53.NewTransportFromHostname(b.ctx, bootid, b.hostname, ipcsv, b.proxies) - } - return nil, errCannotStart -} - -func (b *bootstrap) reinit(trtype, ippOrUrl, ipcsv string) error { - b.mu.Lock() - defer b.mu.Unlock() - - if len(ippOrUrl) <= 0 && len(ipcsv) <= 0 { - b.url = "" - b.hostname = builtinHostname // use Goos - b.ipports = localip4 + "," + localip6 - b.typ = dnsx.DNS53 // ignore trtype - } else if trtype == dnsx.DOH { - if len(ippOrUrl) <= 0 { - log.E("dns: default: reinit: empty url! ips? %s", ipcsv) - return dnsx.ErrNotDefaultTransport - } - - // note: plain ip4 address is a valid url; ex: 1.2.3.4 - if parsed, err := url.Parse(ippOrUrl); err != nil { // ippOrUrl is a url? - log.E("dns: default: reinit: not %s url %s", trtype, ippOrUrl) - return dnsx.ErrNotDefaultTransport - } else if len(ipcsv) <= 0 { - // if ips are empty, bootstrap will be stuck in a catch-22 where - // dialers.New(...) in doh calls back into ipmapper to resolve the hostname - // in ippOrUrl, which calls into Default DNS (aka bootstrap) via resolver - // to resolve the hostname in ippOrUrl. - log.E("dns: default: reinit: doh: empty ips %s", ipcsv) - return dnsx.ErrNotDefaultTransport - } else { - b.url = ippOrUrl - b.hostname = parsed.Hostname() - b.ipports = ipcsv // should never be empty - b.typ = dnsx.DOH - } - } else { // ippOrUrl is an ipport? - if trtype != dnsx.DNS53 { - log.E("dns: default: reinit: ipport %s; %s != %s", ippOrUrl, trtype, dnsx.DNS53) - return dnsx.ErrNotDefaultTransport - } - if len(ippOrUrl) <= 0 { - log.I("dns: default: reinit: empty ipport %s; using: ", ippOrUrl, ipcsv) - ippOrUrl = ipcsv - } - if len(ippOrUrl) <= 0 { - log.E("dns: default: reinit: empty url! ips? %s", ipcsv) - return dnsx.ErrNotDefaultTransport - } - - // may be set to localhost (in which case it is equivalent to x.Goos) - // when no other system resolver could be determined - if strings.HasPrefix(ippOrUrl, builtinHostname) { - // note: this is not goos; for goos, trtype, ippOrUrl, ipcsv are all empty! - log.I("dns: default: reinit: loopback %s", ippOrUrl) - ippOrUrl = localip4 + "," + localip6 // see also dns53/ipmapper.go - } - ips := strings.Split(ippOrUrl, ",") - if len(ips) <= 0 { - log.E("dns: default: reinit: empty ipport %s", ippOrUrl) - return dnsx.ErrNotDefaultTransport - } - first := ips[0] - // todo: tests just the first ipport; test all? - if _, err := xdns.DnsIPPort(first); err != nil { - return err - } else { - b.url = "" - b.hostname = protectedHostname // override all incoming hostnames - b.ipports = ippOrUrl // always ipaddrs or csv(ipaddrs), never empty - b.typ = dnsx.DNS53 - } - } - - log.I("dns: default: %s reinit %s %s w/ %s", trtype, b.url, b.hostname, b.ipports) - - // if proxies is set, restart to create new transport - if b.proxies != nil { - return b.recreateLocked() - } - return nil -} - -func (b *bootstrap) recreateLocked() error { - return b.kickstartLocked(b.proxies) // restart with new proxies -} - -func (b *bootstrap) kickstart(px ipn.ProxyProvider) error { - b.mu.Lock() - defer b.mu.Unlock() - - return b.kickstartLocked(px) -} - -func (b *bootstrap) kickstartLocked(px ipn.ProxyProvider) error { - if px == nil { - return errCannotStart - } - - b.proxies = px - useGoos := b.hostname == builtinHostname - - var tr dnsx.Transport - var err error - switch b.typ { - case dnsx.DNS53: - if useGoos { - tr, err = dns53.NewGoosTransport(b.ctx, px) - } else { - tr, err = b.newDefaultTransportLocked() - } - case dnsx.DOH: - tr, err = b.newDefaultDohTransportLocked() - default: - err = errDefaultTransportType - } - - if prev := b.tr; prev != nil { - core.Gx1("dns.bootstrap.stop", stopTransport, prev) // stop after new transport is ready - log.I("dns: default: removing %s %s[%s]; using %s %s", - b.typ, b.hostname, b.IPPorts(), typstr(tr), ippstr(tr)) - } - - // always override previous transport with (new) tr; even if nil - b.tr = tr - - if err != nil { - log.E("dns: default: start; err %v", err) - return err - } else if b.tr == nil { - log.W("dns: default: start; nil transport %s %s", b.typ, b.hostname) - return errCannotStart - } - - log.I("dns: default: start; %s with %s[%s]; ok? %t", - b.typ, b.hostname, b.GetAddr(), len(b.ipports) > 0) - return nil -} - -func (*bootstrap) ID() string { - // never assume underlying transport's identity - return dnsx.Default -} - -func (b *bootstrap) Type() string { - return b.typ // DOH or DNS53 -} - -func (b *bootstrap) Query(network string, q *dns.Msg, smm *x.DNSSummary) (*dns.Msg, error) { - smm.ID = dnsx.Default - smm.Type = b.typ - smm.UID = protect.UidSelf - if tr := b.tr; tr != nil { - if settings.Debug { - log.V("dns: default: %s query? %t", network, q != nil) - } - return dnsx.Req(tr, network, q, smm) - } - smm.Status = dnsx.TransportError // InternalError? - smm.Msg = strings.Join([]string{smm.Msg, errDefaultTransportNotReady.Error()}, ";") - return nil, errDefaultTransportNotReady -} - -func (b *bootstrap) P50() int64 { - if tr := b.tr; tr != nil { - return tr.P50() - } - return 0 -} - -func (b *bootstrap) GetAddr() string { - if tr := b.tr; tr != nil { - return tr.GetAddr() - } - return dnsx.NoDNS -} - -func (b *bootstrap) GetRelay() x.Proxy { - return nil -} - -func (b *bootstrap) IPPorts() []netip.AddrPort { - if tr := b.tr; tr != nil { - return tr.IPPorts() - } - return dnsx.NoIPPort -} - -func (b *bootstrap) Status() int { - if tr := b.tr; tr != nil { - return tr.Status() - } - return dnsx.ClientError // see also: dnsx/plus.go -} - -func (b *bootstrap) Stop() error { - log.I("dns: default: stopping %s %s", b.typ, b.hostname) - if tr := b.tr; tr != nil { - return tr.Stop() - } - return nil -} - -func typstr(tr dnsx.Transport) string { - if tr == nil { - return "" - } - return tr.Type() -} - -func ippstr(tr dnsx.Transport) string { - if tr == nil { - return "" - } - return fmt.Sprintf("%v", tr.IPPorts()) -} - -func stopTransport(t dnsx.Transport) { - if t != nil { - _ = t.Stop() - } -} diff --git a/intra/common.go b/intra/common.go deleted file mode 100644 index e9715fcc..00000000 --- a/intra/common.go +++ /dev/null @@ -1,821 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "context" - "fmt" - "math/rand" - "net" - "net/netip" - "runtime/debug" - "strconv" - "strings" - "sync" - "time" - - "slices" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "github.com/celzero/firestack/intra/netstat" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" -) - -const ( - smmchSize = 256 // some comfortably high number - UNKNOWN_UID = core.UNKNOWN_UID - UNKNOWN_UID_STR = core.UNKNOWN_UID_STR - // ANDROID_UID = core.ANDROID_UID - ANDROID_UID_STR = core.ANDROID_UID_STR - SELF_UID = protect.UidSelf - UNSUPPORTED_NETWORK = core.UNSUPPORTED_NETWORK -) - -const ( - HDLOK = iota - HDLEND -) - -var ( - anyaddr4 = netip.IPv4Unspecified() - anyaddr6 = netip.IPv6Unspecified() -) - -// immediate is the wait time before sending a summary to the listener. -var immediate = time.Duration(0) - -// zeroListener is a no-op implementation of SocketListener. -type zeroListener struct{} - -var _ SocketListener = (*zeroListener)(nil) - -func (*zeroListener) Preflow(_, _ int32, _, _ string) *PreMark { return nil } -func (*zeroListener) Flow(_, _ int32, _, _, _, _, _, _ string) *Mark { return nil } -func (*zeroListener) Inflow(_, _ int32, _, _ string) *Mark { return nil } -func (*zeroListener) PostFlow(*Mark) {} -func (*zeroListener) OnSocketClosed(*SocketSummary) {} - -var nooplistener = new(zeroListener) - -type baseHandler struct { - proto string // tcp, udp, icmp - ctx context.Context - - resolver dnsx.Resolver // dns resolver to forward queries to - prox ipn.ProxyProvider // proxy provider - smmch chan *SocketSummary - listener SocketListener // listener for socket summaries - - conntracker core.ConnMapper // connid -> [local,remote] - fwtracker *core.ExpMap[string, string] // uid+dst(domainOrIP) -> blockSecs - - once sync.Once - - // fields below are mutable - - status *core.Volatile[int] // status of this handler -} - -var _ netstack.GBaseConnHandler = (*baseHandler)(nil) - -func newBaseHandler(pctx context.Context, proto string, r dnsx.Resolver, px ipn.ProxyProvider, l SocketListener) *baseHandler { - h := &baseHandler{ - ctx: pctx, - proto: proto, - resolver: r, - prox: px, - smmch: make(chan *SocketSummary, smmchSize), - listener: l, - fwtracker: core.NewExpiringMap[string, string](pctx), - conntracker: core.NewConnMap(), - status: core.NewVolatile(HDLOK), - } - context.AfterFunc(pctx, h.End) - return h -} - -// to is the local address, from is the remote address. -func (h *baseHandler) onInflow(to, from netip.AddrPort) (fm *Mark) { - blockmode := settings.BlockMode.Load() - fm = optionsBlock - // BlockModeNone returns false, BlockModeSink returns true - if blockmode == settings.BlockModeSink { - return // blocks everything - } // else: BlockModeNone|BlockModeFilter|BlockModeFilterProc - - uid := UNKNOWN_UID // todo: uid only known on egress? - nn := ntoa(h.proto) // -1 unsupported - - // inflow does not go through nat/alg/dns/proxy - fm, ok := core.Grx(h.proto+".inflow", func(_ context.Context) (*Mark, error) { - return h.listener.Inflow(nn, int32(uid), to.String(), from.String()), nil - }, onInFlowTimeout) - - if !ok || fm == nil { - fm = optionsBlock // fail-safe: block everything - log.E("com: %s: inFlow: timeout %v <= %v", h.proto, to, from) - return - } - // todo: assert fm.PID == ipn.Ingress or ipn.Block - return -} - -// onFlow calls listener.Flow to determine egress rules and routes; thread-safe. -func (h *baseHandler) onFlow(localaddr, target netip.AddrPort) (fm *Mark, undidAlg bool, ips, doms string) { - blockmode := settings.BlockMode.Load() - fm = optionsBlock - // BlockModeNone returns false, BlockModeSink returns true - if blockmode == settings.BlockModeSink { - return // blocks - } else if blockmode == settings.BlockModeNone { - fm = optionsExit - } // else: BlockModeFilter|BlockModeFilterProc - - // Implicit: BlockModeFilter or BlockModeFilterProc - uid := UNKNOWN_UID - preuid := UNKNOWN_UID_STR - if blockmode == settings.BlockModeFilterProc { - procEntry := netstat.FindProcNetEntry(h.proto, localaddr, target) - if procEntry != nil { - uid = procEntry.UserID - preuid = strconv.Itoa(uid) - } else { - // TODO: ipmapper wouldn't call into LookupFor (instead call into LocalLookup) - // when preuid (uid) is UNKNOWN_UID which is not what we want with ResolveFor - // call made below; that is, we want ResolveFor to call into ipmapper to then - // in to LookupFor so dnsx.Fixed would be the chosen transport? - // uid = ANDROID_UID - // preuid = ANDROID_UID_STR - } - } - - network := h.proto - if network == "icmp" && target.Addr().Is6() { - network = "icmp6" - } - var proto int32 = ntoa(network) // -1 unsupported - - src := localaddr.String() - dst := target.String() - - var pdoms, blocklists string - - pre, _ := core.Grx(h.proto+".preflow", func(_ context.Context) (*PreMark, error) { - return h.listener.Preflow(proto, int32(uid), src, dst), nil - }, onPreFlowTimeout) - - hasPre := pre != nil - if hasPre && pre != nil /*nilaway*/ && len(pre.UID) > 0 { - if c, cerr := strconv.Atoi(pre.UID); cerr == nil { - uid = c - // alg.go, for its per-uid cache, uses a special UID to denote - // our own process (protect.UidSelf) and so make sure to use it. - if pre.IsUidSelf { - preuid = SELF_UID - } else { - preuid = pre.UID - } - } else { - log.E("com: %s: onFlow: preflow: uid %s is not a number; use %s; err? %v", h.proto, pre.UID, uid, cerr) - } - } - - if settings.Debug { - log.VV("com: %s: onFlow: preflow: has? %t, preuid: %s for %s => %s", h.proto, hasPre, preuid, src, dst) - } - - // alg happens after nat64, and so, alg knows nat-ed ips and un-nats them; - // that is, real ips ("ips") are already un-nated where applicable. - undidAlg, ips, doms, pdoms, blocklists = h.undoAlg(target.Addr(), preuid) - hasOldIPs := len(ips) > 0 - if undidAlg && !hasOldIPs { - hasNewIPs := false - if hasPre { - for d := range strings.SplitSeq(doms, ",") { - nodomain := len(d) <= 0 - if nodomain && settings.Debug { - logwif(len(d) <= 0)("com: %s: onFlow: preflow: %v from %v => %v for %s; nodomain? %t", - h.proto, doms, src, target, preuid, nodomain) - continue - } - newips, err := dialers.ResolveFor(d, preuid) - hasNewIPs = err == nil && len(newips) > 0 - logwif(!hasNewIPs)("com: %s: onFlow: preflow: resolved alg domain %s? %t; new ips %v for %s => %s; preuid: %s", - h.proto, d, hasNewIPs, newips, src, dst, preuid) - if hasNewIPs { // already unalg'd by ipmapper - // _, ips, doms, pdoms, blocklists = h.undoAlg(target.Addr()) - ips = dnsx.Netip2Csv(newips) - break - } // else: either no known transport or preflow failed - } - } // else: either no known transport or preflow failed - - if !hasPre || !hasNewIPs { - log.E("com: %s: onFlow: alg, preflow? %t, ips? %t for %s; pre: %v; block!", - h.proto, hasPre, hasNewIPs, doms, pre) - // either optionsBase (BlockModeNone) or optionsBlock - return fm, undidAlg, "", "" - } // else: if we've got target and/or old ips, dial them - } else { - if settings.Debug { - log.D("com: %s: onFlow: noalg? %t or hasips? %t for %s => %s; preuid %s", - h.proto, !undidAlg, hasOldIPs, src, dst, preuid) - } - } - - if settings.Debug && (len(ips) <= 0 || len(doms) <= 0) { - log.D("com: %s: onFlow: no realips(%s) or domains(%s + %s), for src=%s dst=%s; preuid=%s; alg? %t", - h.proto, ips, doms, pdoms, localaddr, target, preuid, undidAlg) - } - - fm, ok := core.Grx(h.proto+".flow", func(_ context.Context) (*Mark, error) { - return h.listener.Flow(proto, int32(uid), src, dst, ips, doms, pdoms, blocklists), nil - }, onFlowTimeout) - - loopback := settings.Loopingback.Load() - - if fm == nil || !ok { // zeroListener returns nil - log.W("com: %s: onFlow: empty res or on flow timeout %t; block!", h.proto, ok) - fm = optionsBlock - } else if len(fm.PIDCSV) <= 0 { - fm.PIDCSV = ipn.Exit - if preuid == SELF_UID { - if loopback { - fm.PIDCSV = ipn.Base - } else if h.prox.AutoActive() { - fm.PIDCSV = ipn.Auto - } - } - log.E("com: %s: onFlow: missing proxyid for preuid %s (%s => %s) from kt (alg: %v + %v); %s!", - h.proto, preuid, src, dst, ips, doms, fm.PIDCSV) - } - // in loopback mode, user may have setup SELF_UID to be routed out via remote proxies. - // in other cases, routing Rethink via remote proxies is probably a bug. - if !loopback && preuid == SELF_UID && !ipn.IsAnyLocalProxy(strings.Split(fm.PIDCSV, ",")...) { - egress := ipn.Exit - if h.resolver.IsDnsAddr(target) { - egress = ipn.Base // see: udp.go:dnsOverride - } - log.W("com: %s: onFlow: preflow: preuid %s (%s => %s) is rethink (loopback? %t)! override %s to %s!", - h.proto, preuid, src, dst, loopback, fm.PIDCSV, egress) - fm.PIDCSV = egress - } - - return -} - -// forward copies data between local and remote, and tracks the connection. -// local, wired to TUN via netstack, is either gonet.TCPConn or gonet.UDPConn. -// remote, wired to egress, is wrapped in rwext; but the underlying conn may -// be *net.TCPConn, *net.UDPConn, *demuxconn, or dialers.retrier|splitter etc. -// It also sends a summary to the listener when done. Always called in a goroutine. -func (h *baseHandler) forward(local, remote net.Conn, smm *SocketSummary) { - cid := smm.ID - uid := smm.UID - pid := smm.PID - via := strings.Join([]string{smm.Proto, smm.PID, smm.RPID, smm.ID}, ":") - - tup := conn2str(local, remote) - - h.conntracker.Track(cid, uid, pid, local, remote) - defer h.conntracker.Untrack(cid) - - isrwext := false - didSet := false - timeoutsecs := 0 - // enable core.Pipe (sendfile/zero-copy) optimizations on TCP if - // read & write deadlines are not set (as in rwext is effectively - // a no-op) by unwrapping the underlying remote conn from rwext. - if r, ok := remote.(rwext); ok { - isrwext = true - if timeoutsecs, didSet = r.SetTimeout(); didSet || timeoutsecs <= 0 { - remote = r.Unwrap() // c may be *net.TCPConn or *demuxconn or *dialers.retrier|splitter - } - } - log.I("com: %s: forward: new conn %s rwext? %t (%T), optset? %t (%ds); %s for %s", - h.proto, via, isrwext, remote, didSet, timeoutsecs, tup, uid) - - uploadch := make(chan ioinfo) - - go upload(via, local, remote, uploadch) - dbytes, derr := download(via, local, remote) - - upload := <-uploadch - - // remote conn could be dialed in to some proxy; and so, - // its remote addr may not be the same as smm.Target - smm.Rx = dbytes - smm.Tx = upload.bytes - - h.queueSummary(smm.done(derr, upload.err)) -} - -func (h *baseHandler) queueSummary(s *SocketSummary) { - if s == nil { - return - } - - // go.dev/play/p/AXDdhcMu2w_k - // even though channel done is always closed before ch, we still - // see panic from the select statement writing to ch; and hence - // the need to have this nested select statement. - - // log.VV("com: %s: queueSummary: %x %x %s", h.proto, h.smmch, h.ctx, s.ID) - select { - case <-h.ctx.Done(): - if settings.Debug { - log.D("%s: queueSummary: end: %s", h.proto, s) - } - default: - select { - case <-h.ctx.Done(): - case h.smmch <- s: - default: - log.W("com: %s: sendSummary: dropped: %s", h.proto, s) - } - } -} - -// must be called from a goroutine; loops reading from ch until done is closed. -func (h *baseHandler) processSummaries() { - defer core.Recover(core.DontExit, "c.sendSummary") - - for { - select { - case <-h.ctx.Done(): - return - case s := <-h.smmch: - if s != nil && len(s.ID) > 0 { - h.sendSummary(s, immediate) - } - } - } -} - -func (h *baseHandler) sendSummary(s *SocketSummary, after time.Duration) { - defer core.Recover(core.DontExit, "c.sendNotif: "+s.ID) - - if after > 0 { - // sleep a bit to avoid scenario where kotlin-land - // hasn't yet had the chance to persist info about - // this conn (cid) to meaninfully process its summary - time.Sleep(after) - } - - if settings.Debug { - log.VV("com: %s: end? sendNotif: %s", h.proto, s) - } - h.listener.OnSocketClosed(s) // s.Duration may be uninitialized (zero) -} - -// OpenConns implements netstack.GBaseConnHandler -func (h *baseHandler) OpenConns() string { - return fmt.Sprintf("%d | %s", h.conntracker.Len()/2, h.conntracker.String()) -} - -// CloseConns implements netstack.GBaseConnHandler -func (h *baseHandler) CloseConns(cidsOrPidsOrUids []string) (closedCids []string) { - if len(cidsOrPidsOrUids) <= 0 { - closedCids = h.conntracker.Clear() - } else { - closedCids = h.conntracker.UntrackBatch(cidsOrPidsOrUids) - } - - log.I("com: %s: conns closed %d/%d", h.proto, len(closedCids), len(cidsOrPidsOrUids)) - return closedCids -} - -// aux is usually dst domains, ip, ip:port -func (h *baseHandler) flowID(uid string, aux ...string) (fid string) { - if len(uid) <= 0 { // uid may be empty - uid = UNKNOWN_UID_STR - } // or: uid may be unknown - fid = uid - for _, v := range aux { - if len(v) > 0 { // choose the first non-empty aux - return fid + v - } - } - return -} - -func (h *baseHandler) stall(flowid string) (secs uint32) { - if n := h.fwtracker.Get(flowid); n <= 0 { - secs = 0 // no stall - } else if n > 30 { - secs = 30 // max up to 30s - } else if n < 5 { - secs = (rand.Uint32() % 5) + 1 // up to 5s - } else { - secs = n - } - // track uid->target for n secs, or 30s if n is 0 - life30s := ((29 + secs) % 30) + 1 - newlife := time.Duration(life30s) * time.Second - h.fwtracker.Set(flowid, newlife) - if secs > 0 { - w := time.Duration(secs) * time.Second - time.Sleep(w) - } - return -} - -func (h *baseHandler) isDNS(addr netip.AddrPort) bool { - return addr.IsValid() && h.resolver.IsDnsAddr(addr) -} - -func (h *baseHandler) dnsOverride(conn net.Conn, uid string, smm *SocketSummary) bool { - // addr with zone information removed; see: netip.ParseAddrPort which h.resolver relies on - // addr2 := &net.TCPAddr{IP: addr.IP, Port: addr.Port} - // conn closed by the resolver - core.Gx(h.proto+".dns", func() { - // SocketSummary is not meant to be used by the listener; x.DNSSummary is - // but call into PostFlow & OnSocketClosed anyway, to avoid ambiguities - // on which sockets / sessions are still active. - rx, tx, errs := h.resolver.Serve(h.proto, conn, uid) - smm.Rx = rx - smm.Tx = tx - // smm.Rtt - // smm.Target = DNS resolver? - h.listener.OnSocketClosed(smm.done(errs...)) - }) - return true -} - -// End implements netstack.GBaseConnHandler -func (h *baseHandler) End() { - h.once.Do(func() { - h.CloseConns(nil) - h.status.Store(HDLEND) - close(h.smmch) // close listener chan - log.I("com: %s: handler end %x %x", h.proto, h.ctx, h.smmch) - }) -} - -// TODO: Propagate TCP RST using local.Abort(), on appropriate errors. -func upload(id string, local, remote net.Conn, ioch chan<- ioinfo) { - debug.SetPanicOnFault(true) - defer core.Recover(core.Exit11, "c.upload."+id) - defer core.CloseOp(local, core.CopR) - defer core.CloseOp(remote, core.CopW) - defer close(ioch) - - n, err := core.Pipe(remote, local) - - if settings.Debug { - log.D("com: %s upload(%d) done(%v) b/w %s", - id, n, err, conn2str(local, remote)) - } - ioch <- ioinfo{n, err} -} - -func download(id string, local, remote net.Conn) (n int64, err error) { - defer core.CloseOp(local, core.CopW) - defer core.CloseOp(remote, core.CopR) - - n, err = core.Pipe(local, remote) - - if settings.Debug { - log.D("com: %s download(%d) done(%v) b/w %s", - id, n, err, conn2str(local, remote)) - } - return -} - -func oneRealIPPort(realips []netip.Addr, origipp netip.AddrPort, maybeIncludeOrig bool) netip.AddrPort { - if len(realips) <= 0 { - return origipp - } - if first := makeIPPorts(realips, origipp, maybeIncludeOrig, 1); len(first) > 0 { - return first[0] - } - return origipp -} - -func makeAnyAddrPort(origipp netip.AddrPort) netip.AddrPort { - if !origipp.IsValid() { - return origipp - } - if origipp.Addr().Is4() { - return netip.AddrPortFrom(anyaddr4, origipp.Port()) - } - return netip.AddrPortFrom(anyaddr6, origipp.Port()) -} - -// makeIPPorts returns a slice of valid, non-zero at most cap AddrPorts. -// The first element may be origipp AddrPort, if realips is empty or contains only unspecified IPs. -// or maybeIncludeOrig is true and origipp's IP family is included in dialer's current config. -func makeIPPorts(ips []netip.Addr, origipp netip.AddrPort, maybeIncludeOrig bool, cap int) []netip.AddrPort { - use4 := dialers.Use4() - use6 := dialers.Use6() - orig4 := origipp.Addr().Is4() - orig6 := origipp.Addr().Is6() - - if use4 && use6 { - // happy-eyeballs from clients should take care of dialing both - // families when both v4 and v6 routes are available. - use4 = orig4 - use6 = orig6 - } - - if cap <= 0 || cap > len(ips) { - cap = len(ips) - } - - origip := origipp.Addr() - origport := origipp.Port() - willIncludeOrig := maybeIncludeOrig && ((use4 && orig4) || (use6 && orig6)) - r := make([]netip.AddrPort, 0, cap) - // override alg-ip with the first real-ip - for _, v := range ips { // may contain unspecifed ips - if len(r) >= cap { - break - } - if v == origip && willIncludeOrig { - // skip duplicate of origipp which will be included later - continue - } - if v.IsValid() && !v.IsUnspecified() { - r = append(r, netip.AddrPortFrom(v, origport)) - } // else: discard ip - } - - if settings.Debug { - log.VV("com: makeIPPorts(v4? %t, v6? %t) for %v; tot: %d; in: %v, out: %v", - use4, use6, origipp, len(ips), ips, r) - } - - if len(r) > 0 { - s := core.ShuffleInPlace(r) - if willIncludeOrig { - s = append([]netip.AddrPort{origipp}, s...) - } - return s - } - return []netip.AddrPort{origipp} -} - -// algip may or may not be an actual alg ip. -// returned realips may be incoming algip itself or translated from algip, -// depending on whether alg is enabled (ref: undidAlg). -func (h *baseHandler) undoAlg(algip netip.Addr, uid string) (undidAlg bool, realips, domains, probableDomains, blocklists string) { - const forcePTR = true // force PTR (realip => algans) translation? - anyTransport := dnsx.NoDNS - r := h.resolver - gw := r.Gateway() - - ipok := !algip.IsUnspecified() && algip.IsValid() - didForce := false - hasreal := false - if ipok && gw != nil { - domains, didForce = gw.PTR(algip, uid, anyTransport, !forcePTR) // does NAT (algip => algans) translation - if !didForce && len(domains) <= 0 { - probableDomains, _ = gw.PTR(algip, uid, anyTransport, forcePTR) - } - var ips []netip.Addr - // ips will contain the incoming "algip" arg, in cases where alg is NOT enabled. - ips, undidAlg = gw.X(algip, uid) - realips = dnsx.Netip2Csv(ips) - hasreal = len(realips) > 0 - blocklists = gw.RDNSBL(algip) - } - // pick up corresponding domains from dialer's ipmap cache if none from gw.PTR - if ipok && len(domains) <= 0 && len(probableDomains) <= 0 { - if hosts := dialers.Ptr(algip); len(hosts) > 0 { - probableDomains = strings.Join(hosts, ",") - } - if uid == SELF_UID { - domains = probableDomains - probableDomains = "" - } - } - - logwif(!hasreal)("com: %s: alg: undoAlg: for [%s] (gw? %t ok? %t, force? %t, withForce? %t) %s => %v (for %s + %s / block: %s)", - h.proto, uid, gw != nil, undidAlg, didForce, forcePTR, algip, realips, domains, probableDomains, blocklists) - return -} - -func filterFamilyForDialingWithFailSafe(ipcsv string) (included []netip.Addr, excluded []netip.Addr, excludedIsIncluded bool) { - included, excluded, excludedIsIncluded = filterFamilyForDialing(ipcsv) - if (!excludedIsIncluded || len(excluded) > 0) && len(included) > 0 { - // if not falling back, then include one excluded ip as a fail-safe - included = append(included, core.ChooseOne(excluded)) - } - return included, excluded, excludedIsIncluded -} - -// filterFamilyForDialing filters out invalid IPs and IPs that are not -// of the family that the dialer is configured to use. -func filterFamilyForDialing(ipcsv string) (included []netip.Addr, excluded []netip.Addr, excludedIsIncluded bool) { - if len(ipcsv) <= 0 { - return - } - ips := dnsx.Csv2Netip(ipcsv) - // assume ipv4 is available on ipv6-only network by the way of - // any of the 4to6 mechanisms like 464Xlat, DNS64/NAT64, Teredo etc. - use4 := dialers.Use4() - use6 := dialers.Use6() - invalids := 0 - var filtered, unfiltered []netip.Addr - for _, ip := range ips { - if !ip.IsValid() { - invalids++ - continue - } - // TODO: always include unspecified IPs as it is used by the client to make block/no-block decisions? - // The above is not true anymore? - if use4 && ip.Is4() || use6 && ip.Is6() { - filtered = append(filtered, ip) - } else if ip.IsValid() && !ip.IsUnspecified() { - unfiltered = append(unfiltered, ip) - } - } - logger := log.VV - // fail open: if no ipv4 then fallback to ipv6, and vice-versa. - if len(filtered) <= 0 { - excludedIsIncluded = true - filtered = unfiltered - unfiltered = nil - logger = log.W - } - logger("com: filterFamily(v4? %t, v6? %t, fallback? %t): filtered: %d/%d; in: %v, out: %v, ignored: %v + %d", - use4, use6, excludedIsIncluded, len(filtered), len(ips), ips, filtered, unfiltered, invalids) - return filtered, unfiltered, excludedIsIncluded -} - -// returns conn-id, user-id, flow-id. -// all four values may be empty. -func (h *baseHandler) judge(decision *Mark, aux ...string) (cid, uid, fid string, pids []string) { - if decision == nil { - return - } - - if len(decision.UID) > 0 { - uid = decision.UID - } else { - uid = UNKNOWN_UID_STR - } - cid = decision.CID - fid = h.flowID(uid, aux...) - - if len(decision.PIDCSV) > 0 { - for v := range strings.SplitSeq(decision.PIDCSV, ",") { - if v == ipn.Block { // block overrides all other pids - pids = []string{ipn.Block} - return - } - v = strings.TrimSpace(v) - if len(v) > 0 { - pids = append(pids, v) - } - } - } - - return -} - -func (h *baseHandler) maybeReplaceDest(res *Mark, target *netip.AddrPort) { - if len(res.IP) <= 0 { - return - } else if resip, err := netip.ParseAddr(res.IP); resip.IsValid() && err == nil { - // if res.IP is set, then use it as the target - if settings.Debug { - log.D("%s: proxy: %s %s target instead of %s", - h.proto, res.CID, resip, target) - } - *target = netip.AddrPortFrom(resip, target.Port()) - } -} - -func conn2str(a net.Conn, b net.Conn) string { - ar := a.RemoteAddr() - br := b.RemoteAddr() - al := a.LocalAddr() - bl := b.LocalAddr() - // may empty out? go.dev/blog/unique - // (a footnote about interning strings) - s := core.UniqStr(fmt.Sprintf("a(%v->%v) => b(%v<-%v)", al, ar, bl, br)) - return s -} - -func clos(c ...core.MinConn) { - core.CloseConn(c...) -} - -func ntoa(n string) int32 { - switch n { - case "udp", "udp6", "udp4": - return 17 - case "tcp", "tcp6", "tcp4": - return 6 - case "icmp", "icmp4": - return 1 - case "icmp6": - return 58 - } - return UNSUPPORTED_NETWORK -} - -func isAnyBlockPid(pids []string) bool { - return containsPid(pids, ipn.Block) -} - -func isAnyBasePid(pids []string) bool { - return containsPid(pids, ipn.Base) -} - -func containsPid(pids []string, pid string) bool { - return slices.Contains(pids, pid) -} - -func extendc(c net.Conn, r time.Duration, w time.Duration) { - if c != nil { - if r == w { - extend(c, r) - } else { - extendr(c, r) - extendw(c, w) - } - } -} - -func extend(c core.MinConn, t time.Duration) { - if c != nil && core.IsNotNil(c) { - if t.Milliseconds() <= 0 { - _ = c.SetDeadline(time.Time{}) - } else { - _ = c.SetDeadline(time.Now().Add(t)) - } - } -} - -func extendr(c core.MinConn, t time.Duration) { - if c != nil && core.IsNotNil(c) { - if t.Milliseconds() <= 0 { - _ = c.SetDeadline(time.Time{}) - } else { - _ = c.SetReadDeadline(time.Now().Add(t)) - } - } -} - -func extendw(c core.MinConn, t time.Duration) { - if c != nil && core.IsNotNil(c) { - if t.Milliseconds() <= 0 { - _ = c.SetDeadline(time.Time{}) - } else { - _ = c.SetWriteDeadline(time.Now().Add(t)) - } - } -} - -func anyaddrFor(ipp netip.AddrPort) (proto, anyaddr string) { - return ipn.AnyAddrForUDP(ipp) -} - -func logev(err error) log.LogFn { - f := log.E - if err == nil { - f = log.VV - } - return f -} - -func logei(err error) log.LogFn { - f := log.E - if err == nil { - f = log.I - } - return f -} - -func logwif(cond bool) log.LogFn { - if cond { - return log.W - } - return log.VV -} - -func logiif(cond bool) log.LogFn { - if cond { - return log.I - } - return log.VV -} - -func pidstr(p ipn.Proxy) string { - if p == nil { - return "" - } - return p.ID() -} diff --git a/intra/core/async.go b/intra/core/async.go deleted file mode 100644 index e8770f3b..00000000 --- a/intra/core/async.go +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "context" - "errors" - "runtime/debug" - "strconv" - "time" -) - -// Go runs f in a goroutine and recovers from any panics. -func Go(who string, f func()) { - go func() { - debug.SetPanicOnFault(true) - defer Recover(DontExit, who) - - f() - }() -} - -// Go1 runs f(arg) in a goroutine and recovers from any panics. -func Go1[T any](who string, f func(T), arg T) { - go func() { - debug.SetPanicOnFault(true) - defer Recover(DontExit, who) - - f(arg) - }() -} - -// Go2 runs f(arg0,arg1) in a goroutine and recovers from any panics. -func Go2[T0 any, T1 any](who string, f func(T0, T1), a0 T0, a1 T1) { - go func() { - debug.SetPanicOnFault(true) - defer Recover(DontExit, who) - - f(a0, a1) - }() -} - -// Gg runs f in a goroutine, recovers from any panics if any; -// then calls cb in a separate goroutine, and recovers from any panics. -func Gg(who string, f func(), cb func()) { - go func() { - debug.SetPanicOnFault(true) - defer RecoverFn(who, cb) - - f() - }() -} - -// Gx runs f in a goroutine and exits the process if f panics. -func Gx(who string, f func()) { - go func() { - debug.SetPanicOnFault(true) - defer Recover(Exit11, who) - - f() - }() -} - -// Gx1 runs f in a goroutine and exits the process if f panics. -func Gx1[T any](who string, f func(T), arg T) { - go func() { - debug.SetPanicOnFault(true) - defer Recover(Exit11, who) - - f(arg) - }() -} - -// Gif runs f in a goroutine if cond is true. -func Gif(cond bool, who string, f func()) { - if cond { - Go(who, f) - } -} - -// Grx runs work function f in a goroutine, blocking until it returns or timesout. -func Grx[T any](who string, f WorkCtx[T], d time.Duration) (zz T, completed bool) { - ch := make(chan T, 1) // non-blocking - - ctx, cancel := context.WithTimeout(context.Background(), d) - defer cancel() - - // go.dev/play/p/VtWYJrxhXz6 - go func() { - debug.SetPanicOnFault(true) - defer Recover(Exit11, who) - defer close(ch) - - out, _ := f(ctx) // TODO: log error? - ch <- out - }() - - select { - case out := <-ch: - return out, true - case <-ctx.Done(): // timeout - } - return zz, false -} - -// Gxe runs f in a goroutine, ignores returned error, and exits on panics. -func Gxe(who string, f func() error) { - go func() { - debug.SetPanicOnFault(true) - defer Recover(Exit11, who) - - _ = f() - }() -} - -// errPanic returns an error indicating that the function at index i panicked. -func errPanic(who string) error { - return errors.New(who + " fn panicked") -} - -// Race runs all the functions in fs concurrently and returns the first non-error result. -// Returned values are the result, the index of the function that returned the result, and any errors. -// If all functions return an error, the accumulation of it is returned. -// Panicking functions are considered as returning an error. -// If the timeout is reached, errTimeout is returned. -// Note that, zero value result could be returned if at least one function returns that without any error. -// go.dev/play/p/GVW-dXcZORr -func Race[T any](who string, timeout time.Duration, fs ...WorkCtx[T]) (zz T, fidx int, errs error) { - type res struct { - t T - err error - i int - } - - ch := make(chan *res, len(fs)) - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - for i, f := range fs { - i, f := i, f - fid := who + ".race." + strconv.Itoa(i) - Gg(fid, func() { - out, err := f(ctx) - select { - case <-ctx.Done(): // discard out, err - case ch <- &res{out, err, i}: - } - }, func() { - select { - case <-ctx.Done(): // discard out, err - case ch <- &res{zz, errPanic(fid), i}: - } - }) - } - -loop: - for range fs { - select { - case r := <-ch: - if r.err != nil { - errs = JoinErr(errs, r.err) - } else { - return r.t, r.i, r.err - } - case <-ctx.Done(): - // if one of WorkCtx functions times out, it - // means the rest have also lost the race. - // break out of the loop and return errTimeout. - errs = JoinErr(errs, errTimeout) - break loop - } - } - return // zz -} - -func First[T any](who string, overallTimeout time.Duration, fs ...WorkCtx[T]) (zz T, idx int) { - timeoutPerFn := overallTimeout / time.Duration(len(fs)) - for i, f := range fs { - // unneeded in go1.23+ i, f := i, f - fid := who + ".all." + strconv.Itoa(i) - if x, ok := Grx(fid, f, timeoutPerFn); ok { - return x, i - } - } - return zz, -1 -} - -func All[T any](who string, timeout time.Duration, fs ...WorkCtx[T]) ([]T, []error) { - type res struct { - fidx int // index of the function in fs - t T - err error - } - - ch := make(chan *res, len(fs)) - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - for i, f := range fs { - //unneeded in go1.23+ i, f := i, f - fid := who + ".all." + strconv.Itoa(i) - Gg(fid, func() { - out, err := f(ctx) - select { - case <-ctx.Done(): // timeout - ch <- &res{fidx: i, err: errTimeout} - case ch <- &res{i, out, err}: - } - }, func() { - select { - case <-ctx.Done(): // timeout - ch <- &res{fidx: i, err: errTimeout} - case ch <- &res{fidx: i, err: errPanic(fid)}: - } - }) - } - - results := make([]T, len(fs)) - errs := make([]error, len(fs)) - - for range len(fs) { - r := <-ch - results[r.fidx] = r.t - errs[r.fidx] = r.err - } - return results, errs -} - -func Periodic(id string, pctx context.Context, d time.Duration, f func()) context.Context { - ctx, done := context.WithCancel(pctx) - Go("periodic."+id, func() { - t := time.NewTicker(d) - defer t.Stop() - defer done() - - for { - select { - case <-pctx.Done(): - return - case <-t.C: - f() - } - } - }) - return ctx -} - -// SigFin runs f in a goroutine and returns a channel that is closed when f returns. -func SigFin(id string, f func()) <-chan struct{} { - done := make(chan struct{}) - Go("sigfin."+id, func() { - defer close(done) - f() - }) - return done -} - -func Await(f func(), until time.Duration) (awaited bool) { - done := make(chan struct{}) - Go("await", func() { - defer close(done) - f() - }) - - select { - case <-time.After(until): - return false - case <-done: - return true - } -} - -func Await1[T any](f func() T, until time.Duration) (v T, gotV bool) { - done := make(chan struct{}) - Go("await", func() { - defer close(done) - v = f() - }) - - select { - case <-time.After(until): - return v, false - case <-done: - return v, true - } -} - -func EitherOr(either <-chan struct{}, or Callback, until time.Duration) (esc bool) { - select { - case <-time.Tick(until): - if or != nil { - or() - } - return false - case <-either: - return true - } -} diff --git a/intra/core/barrier.go b/intra/core/barrier.go deleted file mode 100644 index c13909fd..00000000 --- a/intra/core/barrier.go +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package core - -import ( - "context" - "errors" - "fmt" - "sync" - "sync/atomic" - "time" -) - -const ( - // Anew is the value returned by Barrier.Do when the function was - // executed and its results are stored in the Barrier. - Anew = iota - // Shared is the value returned by Barrier.Do when the function's - // results are already stored in the Barrier. - Shared -) - -var ( - errTimeout = errors.New("core: timeout") - errNoFruitOfLabour = errors.New("core: Work did not yield results") -) - -// Work is the type of the function to memoize. -type Callback func() -type Work[T any] func() (T, error) -type Work1[T any] func(T) (T, error) -type WorkCtx[T any] func(context.Context) (T, error) - -// V is an in-flight or completed Barrier.Do V -type V[T any, K comparable] struct { - wg sync.WaitGroup - dob time.Time - Val T - Err error - N atomic.Uint32 -} - -func (v *V[t, k]) String() string { - if v == nil { - return "core.V: " - } - return fmt.Sprintf("v: %v // n: %d; exp: %s // err: %v", v.Val, v.N.Load(), v.dob, v.Err) -} - -func (v *V[t, k]) E() string { - if v == nil { - return "core.V: " - } else if ve := v.Err; ve == nil { - return "core.V: no error" - } else { - return ve.Error() - } -} - -func (v *V[t, k]) id() string { - return fmt.Sprintf("%p", v) -} - -// Barrier represents a class of work and forms a namespace in -// which units of work can be executed with duplicate suppression. -type Barrier[T any, K comparable] struct { - mu sync.Mutex // protects m - m map[K]*V[T, K] // caches in-flight and completed Vs - - ttl time.Duration // time-to-live for completed Vs in m - neg time.Duration // time-to-live for errored Vs in m - to time.Duration // timeout for Do(), Do1(), Go() - - lastscrub time.Time // last scrub time -} - -func NewKeyedBarrier[T any, K comparable](ttl time.Duration) *Barrier[T, K] { - return NewBarrier2[T, K](ttl, ttl/5) -} - -// NewBarrier returns a new Barrier with the given time-to-live for -// completed Vs. -func NewBarrier[T any](ttl time.Duration) *Barrier[T, string] { - return NewBarrier2[T, string](ttl, ttl/5) -} - -// NewBarrier2 returns a new Barrier with the time-to-lives for -// completed Vs (ttl) and errored Vs (neg). -func NewBarrier2[T any, K comparable](ttl, neg time.Duration) *Barrier[T, K] { - return &Barrier[T, K]{ - m: make(map[K]*V[T, K]), - ttl: ttl, - neg: max(1*time.Second /*min neg*/, neg), - to: ttl, - lastscrub: time.Now(), - } -} - -func (ba *Barrier[T, K]) maybeScrubLocked() { - now := time.Now() - if now.Sub(ba.lastscrub) < reapthreshold { - return - } - ba.lastscrub = now - - Go("ba.scrub", func() { - ba.mu.Lock() - defer ba.mu.Unlock() - - i := 0 - for k, v := range ba.m { - if i > maxreapiter { - break - } - ttl := ba.ttl - if v.Err != nil { - ttl = ba.neg - } - if time.Since(v.dob.Add(ttl)) > 0 { - delete(ba.m, k) - } - i++ - } - }) -} - -func (ba *Barrier[T, K]) getLocked(k K) (v *V[T, K], ok bool) { - defer ba.maybeScrubLocked() - - v, ok = ba.m[k] - if v != nil { - ttl := ba.ttl - if v.Err != nil { - ttl = ba.neg - } - if time.Since(v.dob.Add(ttl)) > 0 { - delete(ba.m, k) - return nil, false - } - } - return v, ok -} - -func (ba *Barrier[T, K]) addLocked(k K) *V[T, K] { - v := new(V[T, K]) - v.wg.Add(1) - v.dob = time.Now() - ba.m[k] = v - return v -} - -// DoIt is like Do but returns from once as-is. -func (ba *Barrier[T, K]) DoIt(k K, once Work[T]) (zz T, err error) { - v, _ := ba.Do(k, once) - if v == nil || v.Err != nil { - if v == nil { // unlikely - return zz, errNoFruitOfLabour - } - return v.Val, v.Err - } - return v.Val, nil -} - -// Do executes and returns the results of the given function, making -// sure that only one execution is in-flight for a given key at a -// time. If a duplicate comes in, the duplicate caller waits for the -// original to complete and receives the same results. -func (ba *Barrier[T, K]) Do(k K, once Work[T]) (*V[T, K], int) { - ba.mu.Lock() - c, _ := ba.getLocked(k) - if c != nil { - ba.mu.Unlock() - - c.N.Add(1) // register presence - c.wg.Wait() // wait for the in-flight req to complete - return c, Shared - } - c = ba.addLocked(k) - ba.mu.Unlock() - - if _, completed := Grx("ba.do."+c.id(), func(_ context.Context) (*V[T, K], error) { - c.Val, c.Err = once() - return c, c.Err - }, ba.to); !completed { - c.Err = JoinErr(c.Err, errTimeout) - } - - c.wg.Done() // unblock all waiters - return c, Anew -} - -// Do1 is like Do but for Work1 with one arg. -func (ba *Barrier[T, K]) Do1(k K, once Work1[T], arg T) (*V[T, K], int) { - ba.mu.Lock() - c, _ := ba.getLocked(k) - if c != nil { - ba.mu.Unlock() - - c.N.Add(1) // register presence - c.wg.Wait() // wait for the in-flight req to complete - return c, Shared - } - c = ba.addLocked(k) - ba.mu.Unlock() - - if _, completed := Grx("ba.do1."+c.id(), func(_ context.Context) (*V[T, K], error) { - c.Val, c.Err = once(arg) - return c, c.Err - }, ba.to); !completed { - c.Err = JoinErr(c.Err, errTimeout) - } - - c.wg.Done() // unblock all waiters - return c, Anew -} - -// untested -func (ba *Barrier[T, K]) Go(k K, once Work[T]) <-chan *V[T, K] { - ch := make(chan *V[T, K]) - - Go("ba.go", func() { - defer close(ch) - - ba.mu.Lock() - c, _ := ba.getLocked(k) - if c != nil { - ba.mu.Unlock() - - c.N.Add(1) // register presence - c.wg.Wait() // wait for the in-flight req to complete - ch <- c - return - } - c = ba.addLocked(k) - ba.mu.Unlock() - - c.Val, c.Err = once() - - c.wg.Done() // unblock all waiters - ch <- c - }) - - return ch -} diff --git a/intra/core/brsa/brsa.go b/intra/core/brsa/brsa.go deleted file mode 100644 index 0e68ca9c..00000000 --- a/intra/core/brsa/brsa.go +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// BSD-3-Clause License -// -// Copyright (c) 2019 Cloudflare. All rights reserved. -// Copyright (c) 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// from: https://github.com/cloudflare/circl/tree/v1.3.7/blindsign - -// Package blindrsa implements the RSA Blind Signature Protocol as defined in [RFC9474]. -// -// The RSA Blind Signature protocol, and its variant RSABSSA -// (RSA Blind Signature Scheme with Appendix) is a two-party protocol -// between a Client and Server where they interact to compute -// -// sig = Sign(sk, input_msg), -// -// where `input_msg = Prepare(msg)` is a prepared version of a private -// message `msg` provided by the Client, and `sk` is the private signing -// key provided by the server. -// -// # Supported Variants -// -// This package is compliant with the [RFC-9474] document -// and supports the following variants: -// - RSABSSA-SHA384-PSS-Deterministic -// - RSABSSA-SHA384-PSSZERO-Deterministic -// - RSABSSA-SHA384-PSS-Randomized -// - RSABSSA-SHA384-PSSZERO-Randomized -// -// [RFC-9474]: https://www.rfc-editor.org/info/rfc9474 -package blindrsa - -import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "errors" - "io" - "math/big" -) - -type Variant int - -const ( - SHA384PSSRandomized Variant = iota // RSABSSA-SHA384_PSS_Randomized - SHA384PSSDeterministic // RSABSSA-SHA384_PSS_Deterministic -) - -func (v Variant) String() string { - switch v { - case SHA384PSSRandomized, SHA384PSSDeterministic: - return "RSABSSA-SHA384-PSS-Randomized" - default: - return "invalid RSABSSA variant" - } -} - -// Client is a type that implements the client side of the blind RSA -// protocol, described in https://www.rfc-editor.org/rfc/rfc9474.html#name-rsabssa-variants -type Client struct { - v Verifier - prefixLen int -} - -func NewClient(v Variant, pk *rsa.PublicKey) (Client, error) { - verif, err := NewVerifier(v, pk) - if err != nil { - return Client{}, err - } - var prefixLen int - switch v { - case SHA384PSSDeterministic: - prefixLen = 0 - case SHA384PSSRandomized: - prefixLen = 32 - default: - return Client{}, ErrInvalidVariant - } - - return Client{verif, prefixLen}, nil -} - -type State struct { - // The hashed and encoded message being signed - encodedMsg []byte - // Blinding factor produced by the Verifier - r *big.Int - // Inverse of the blinding factor produced by the Verifier - rInv *big.Int - // Salt used in the encoding of the message - salt []byte -} - -func (s State) Salt() []byte { return s.salt } -func (s State) Factor() *big.Int { return s.r } - -// Prepare is the process by which the message to be signed and -// verified is prepared for input to the blind signing protocol. -func (c Client) Prepare(random io.Reader, message []byte) ([]byte, error) { - if random == nil { - return nil, ErrInvalidRandomness - } - - prefix := make([]byte, c.prefixLen) - _, err := io.ReadFull(random, prefix) - if err != nil { - return nil, err - } - - return append(append([]byte{}, prefix...), message...), nil -} - -// Blind initializes the blind RSA protocol using an input message and source of randomness. -// This function fails if randomness was not provided. -func (c Client) Blind(random io.Reader, preparedMessage []byte) (blindedMsg []byte, state State, err error) { - if random == nil { - return nil, State{}, ErrInvalidRandomness - } - - salt := make([]byte, c.v.SaltLength) - _, err = io.ReadFull(random, salt) - if err != nil { - return nil, State{}, err - } - - r, rInv, err := GenerateBlindingFactor(random, c.v.pk.N) - if err != nil { - return nil, State{}, err - } - - return c.FixedBlind(preparedMessage, salt, r, rInv) -} - -func (c Client) FixedBlind(message, salt []byte, r, rInv *big.Int) (blindedMsg []byte, state State, err error) { - encodedMsg, err := EncodeMessageEMSAPSS(message, c.v.pk.N, c.v.Hash.New(), salt) - if err != nil { - return nil, State{}, err - } - - m := new(big.Int).SetBytes(encodedMsg) - - bigE := big.NewInt(int64(c.v.pk.E)) - x := new(big.Int).Exp(r, bigE, c.v.pk.N) - z := new(big.Int).Set(m) - z.Mul(z, x) - z.Mod(z, c.v.pk.N) - - kLen := (c.v.pk.N.BitLen() + 7) / 8 - blindedMsg = make([]byte, kLen) - z.FillBytes(blindedMsg) - - return blindedMsg, State{encodedMsg, r, rInv, salt}, nil -} - -func (c Client) Finalize(state State, blindedSig []byte) ([]byte, error) { - kLen := (c.v.pk.N.BitLen() + 7) / 8 - if len(blindedSig) != kLen { - return nil, ErrUnexpectedSize - } - - z := new(big.Int).SetBytes(blindedSig) - s := new(big.Int).Set(state.rInv) - s.Mul(s, z) - s.Mod(s, c.v.pk.N) - - sig := make([]byte, kLen) - s.FillBytes(sig) - - err := VerifyBlindSignature(NewBigPublicKey(c.v.pk), state.encodedMsg, sig) - if err != nil { - return nil, err - } - - return sig, nil -} - -// Verify verifies the input (message, signature) pair and produces an error upon failure. -func (c Client) Verify(message, signature []byte) error { return c.v.Verify(message, signature) } - -type Verifier struct { - // Public key of the Signer - pk *rsa.PublicKey - rsa.PSSOptions -} - -func NewVerifier(v Variant, pk *rsa.PublicKey) (Verifier, error) { - switch v { - case SHA384PSSRandomized, SHA384PSSDeterministic: - return Verifier{pk, rsa.PSSOptions{Hash: crypto.SHA384, SaltLength: crypto.SHA384.Size()}}, nil - default: - return Verifier{}, ErrInvalidVariant - } -} - -// Verify verifies the input (message, signature) pair and produces an error upon failure. -func (v Verifier) Verify(message, signature []byte) error { - return VerifyMessageSignature(message, signature, v.SaltLength, NewBigPublicKey(v.pk), v.Hash) -} - -// Signer structure represents the signing server in the blind RSA protocol. -// It carries the raw RSA private key used for signing blinded messages. -type Signer struct { - // An RSA private key - sk *rsa.PrivateKey -} - -// NewSigner creates a new Signer for the blind RSA protocol using an RSA private key. -func NewSigner(sk *rsa.PrivateKey) Signer { - return Signer{ - sk: sk, - } -} - -// BlindSign blindly computes the RSA operation using the Signer's private key on the blinded -// message input, if it's of valid length, and returns an error should the function fail. -// -// See the specification for more details: -// https://www.rfc-editor.org/rfc/rfc9474.html#name-blindsign -func (signer Signer) BlindSign(data []byte) ([]byte, error) { - kLen := (signer.sk.N.BitLen() + 7) / 8 - if len(data) != kLen { - return nil, ErrUnexpectedSize - } - - m := new(big.Int).SetBytes(data) - if m.Cmp(signer.sk.N) > 0 { - return nil, ErrInvalidMessageLength - } - - s, err := DecryptAndCheck(rand.Reader, NewBigPrivateKey(signer.sk), m) - if err != nil { - return nil, err - } - - blindSig := make([]byte, kLen) - s.FillBytes(blindSig) - - return blindSig, nil -} - -var ( - // ErrInvalidVariant is the error used if the variant request does not exist. - ErrInvalidVariant = errors.New("blindsign/blindrsa: invalid variant requested") - - // ErrUnexpectedSize is the error used if the size of a parameter does not match its expected value. - ErrUnexpectedSize = errors.New("blindsign/blindrsa: unexpected input size") - - // ErrInvalidMessageLength is the error used if the size of a protocol message does not match its expected value. - ErrInvalidMessageLength = errors.New("blindsign/blindrsa: invalid message length") - - // ErrInvalidBlind is the error used if the blind generated by the Verifier fails. - ErrInvalidBlind = errors.New("blindsign/blindrsa: invalid blind") - - // ErrInvalidRandomness is the error used if caller did not provide randomness to the Blind() function. - ErrInvalidRandomness = errors.New("blindsign/blindrsa: invalid random parameter") - - // ErrUnsupportedHashFunction is the error used if the specified hash is not supported. - ErrUnsupportedHashFunction = errors.New("blindsign/blindrsa: unsupported hash function") -) diff --git a/intra/core/brsa/common.go b/intra/core/brsa/common.go deleted file mode 100644 index f6f46d22..00000000 --- a/intra/core/brsa/common.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// BSD-3-Clause License -// -// Copyright (c) 2019 Cloudflare. All rights reserved. -// Copyright (c) 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// from: https://github.com/cloudflare/circl/tree/v1.3.7/blindsign - -package blindrsa - -import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/sha512" - "crypto/subtle" - "errors" - "hash" - "io" - "math/big" -) - -// ConvertHashFunction converts a crypto.Hash function to an equivalent hash.Hash type. -func ConvertHashFunction(hash crypto.Hash) hash.Hash { - switch hash { - case crypto.SHA256: - return sha256.New() - case crypto.SHA384: - return sha512.New384() - case crypto.SHA512: - return sha512.New() - default: - panic(ErrUnsupportedHashFunction) - } -} - -// EncodeMessageEMSAPSS hashes the input message and then encodes it using PSS encoding. -func EncodeMessageEMSAPSS(message []byte, N *big.Int, hash hash.Hash, salt []byte) ([]byte, error) { - hash.Reset() // Ensure the hash state is cleared - hash.Write(message) - digest := hash.Sum(nil) - hash.Reset() - emBits := N.BitLen() - 1 - encodedMsg, err := emsaPSSEncode(digest[:], emBits, salt, hash) - return encodedMsg, err -} - -// GenerateBlindingFactor generates a blinding factor and its multiplicative inverse -// to use for RSA blinding. -func GenerateBlindingFactor(random io.Reader, N *big.Int) (*big.Int, *big.Int, error) { - randReader := random - if randReader == nil { - randReader = rand.Reader - } - r, err := rand.Int(randReader, N) - if err != nil { - return nil, nil, err - } - - if r.Sign() == 0 { - r.SetInt64(1) - } - rInv := new(big.Int).ModInverse(r, N) - if rInv == nil { - return nil, nil, ErrInvalidBlind - } - - return r, rInv, nil -} - -// VerifyMessageSignature verifies the input message signature against the expected public key -func VerifyMessageSignature(message, signature []byte, saltLength int, pk *BigPublicKey, hash crypto.Hash) error { - h := ConvertHashFunction(hash) - h.Write(message) - digest := h.Sum(nil) - - err := verifyPSS(pk, hash, digest, signature, &rsa.PSSOptions{ - Hash: hash, - SaltLength: saltLength, - }) - return err -} - -// DecryptAndCheck checks that the private key operation is consistent (fault attack detection). -func DecryptAndCheck(random io.Reader, priv *BigPrivateKey, c *big.Int) (m *big.Int, err error) { - m, err = decrypt(random, priv, c) - if err != nil { - return nil, err - } - - // In order to defend against errors in the CRT computation, m^e is - // calculated, which should match the original ciphertext. - check := encrypt(new(big.Int), priv.Pk.N, priv.Pk.E, m) - if c.Cmp(check) != 0 { - return nil, errors.New("rsa: internal error") - } - return m, nil -} - -// VerifyBlindSignature verifies the signature of the hashed and encoded message against the input public key. -func VerifyBlindSignature(pub *BigPublicKey, hashed, sig []byte) error { - m := new(big.Int).SetBytes(hashed) - bigSig := new(big.Int).SetBytes(sig) - - c := encrypt(new(big.Int), pub.N, pub.E, bigSig) - if subtle.ConstantTimeCompare(m.Bytes(), c.Bytes()) == 1 { - return nil - } else { - return rsa.ErrVerification - } -} - -func saltLength(opts *rsa.PSSOptions) int { - if opts == nil { - return rsa.PSSSaltLengthAuto - } - return opts.SaltLength -} - -func verifyPSS(pub *BigPublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *rsa.PSSOptions) error { - if len(sig) != pub.Size() { - return rsa.ErrVerification - } - s := new(big.Int).SetBytes(sig) - m := encrypt(new(big.Int), pub.N, pub.E, s) - emBits := pub.N.BitLen() - 1 - emLen := (emBits + 7) / 8 - if m.BitLen() > emLen*8 { - return rsa.ErrVerification - } - em := m.FillBytes(make([]byte, emLen)) - return emsaPSSVerify(digest, em, emBits, saltLength(opts), hash.New()) -} diff --git a/intra/core/brsa/keys.go b/intra/core/brsa/keys.go deleted file mode 100644 index d10b8620..00000000 --- a/intra/core/brsa/keys.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// BSD-3-Clause License -// -// Copyright (c) 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// from: https://github.com/cloudflare/circl/tree/v1.3.7/blindsign - -package blindrsa - -import ( - "crypto/rsa" - "math/big" -) - -// BigPublicKey is the same as an rsa.PublicKey struct, except the public -// key is represented as a big integer as opposed to an int. For the partially -// blind scheme, this is required since the public key will typically be -// any value in the RSA group. -type BigPublicKey struct { - N *big.Int - E *big.Int -} - -// Size returns the size of the public key. -func (pub *BigPublicKey) Size() int { - return (pub.N.BitLen() + 7) / 8 -} - -// Marshal encodes the public key exponent (e). -func (pub *BigPublicKey) Marshal() []byte { - buf := make([]byte, (pub.E.BitLen()+7)/8) - pub.E.FillBytes(buf) - return buf -} - -// NewBigPublicKey creates a BigPublicKey from a rsa.PublicKey. -func NewBigPublicKey(pk *rsa.PublicKey) *BigPublicKey { - return &BigPublicKey{ - N: pk.N, - E: new(big.Int).SetInt64(int64(pk.E)), - } -} - -// CustomPublicKey is similar to rsa.PrivateKey, containing information needed -// for a private key used in the partially blind signature protocol. -type BigPrivateKey struct { - Pk *BigPublicKey - D *big.Int - P *big.Int - Q *big.Int -} - -// NewBigPrivateKey creates a BigPrivateKey from a rsa.PrivateKey. -func NewBigPrivateKey(sk *rsa.PrivateKey) *BigPrivateKey { - return &BigPrivateKey{ - Pk: &BigPublicKey{ - N: sk.N, - E: new(big.Int).SetInt64(int64(sk.PublicKey.E)), - }, - D: sk.D, - P: sk.Primes[0], - Q: sk.Primes[1], - } -} diff --git a/intra/core/brsa/pss.go b/intra/core/brsa/pss.go deleted file mode 100644 index dfd89b9a..00000000 --- a/intra/core/brsa/pss.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// BSD-3-Clause License -// -// Copyright (c) 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// from: https://github.com/cloudflare/circl/tree/v1.3.7/blindsign - -package blindrsa - -// This file implements the RSASSA-PSS signature scheme according to RFC 8017. - -import ( - "bytes" - "crypto/rsa" - "errors" - "hash" -) - -// Per RFC 8017, Section 9.1 -// -// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc -// -// where -// -// DB = PS || 0x01 || salt -// -// and PS can be empty so -// -// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2 -// - -func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { - // See RFC 8017, Section 9.1.1. - - hLen := hash.Size() - sLen := len(salt) - emLen := (emBits + 7) / 8 - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - - if len(mHash) != hLen { - return nil, errors.New("crypto/rsa: input must be hashed with given hash") - } - - // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. - - if emLen < hLen+sLen+2 { - return nil, errors.New("crypto/rsa: key size too small for PSS signature") - } - - em := make([]byte, emLen) - psLen := emLen - sLen - hLen - 2 - db := em[:psLen+1+sLen] - h := em[psLen+1+sLen : emLen-1] - - // 4. Generate a random octet string salt of length sLen; if sLen = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; - // - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length hLen. - - var prefix [8]byte - - hash.Write(prefix[:]) - hash.Write(mHash) - hash.Write(salt) - - h = hash.Sum(h[:0]) - hash.Reset() - - // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - - db[psLen] = 0x01 - copy(db[psLen+1:], salt) - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - - mgf1XOR(db, hash, h) - - // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in - // maskedDB to zero. - - db[0] &= 0xff >> (8*emLen - emBits) - - // 12. Let EM = maskedDB || H || 0xbc. - em[emLen-1] = 0xbc - - // 13. Output EM. - return em, nil -} - -func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { - // See RFC 8017, Section 9.1.2. - - hLen := hash.Size() - if sLen == rsa.PSSSaltLengthEqualsHash { - sLen = hLen - } - emLen := (emBits + 7) / 8 - if emLen != len(em) { - return errors.New("rsa: internal error: inconsistent length") - } - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" - // and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if hLen != len(mHash) { - return rsa.ErrVerification - } - - // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. - if emLen < hLen+sLen+2 { - return rsa.ErrVerification - } - - // 4. If the rightmost octet of EM does not have hexadecimal value - // 0xbc, output "inconsistent" and stop. - if em[emLen-1] != 0xbc { - return rsa.ErrVerification - } - - // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and - // let H be the next hLen octets. - db := em[:emLen-hLen-1] - h := em[emLen-hLen-1 : emLen-1] - - // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in - // maskedDB are not all equal to zero, output "inconsistent" and - // stop. - var bitMask byte = 0xff >> (8*emLen - emBits) - if em[0] & ^bitMask != 0 { - return rsa.ErrVerification - } - - // 7. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 8. Let DB = maskedDB \xor dbMask. - mgf1XOR(db, hash, h) - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= bitMask - - // If we don't know the salt length, look for the 0x01 delimiter. - if sLen == rsa.PSSSaltLengthAuto { - psLen := bytes.IndexByte(db, 0x01) - if psLen < 0 { - return rsa.ErrVerification - } - sLen = len(db) - psLen - 1 - } - - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - psLen := emLen - hLen - sLen - 2 - for _, e := range db[:psLen] { - if e != 0x00 { - return rsa.ErrVerification - } - } - if db[psLen] != 0x01 { - return rsa.ErrVerification - } - - // 11. Let salt be the last sLen octets of DB. - salt := db[len(db)-sLen:] - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - var prefix [8]byte - hash.Write(prefix[:]) - hash.Write(mHash) - hash.Write(salt) - - h0 := hash.Sum(nil) - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if !bytes.Equal(h0, h) { // TODO: constant time? - return rsa.ErrVerification - } - return nil -} diff --git a/intra/core/brsa/rsa.go b/intra/core/brsa/rsa.go deleted file mode 100644 index 013329df..00000000 --- a/intra/core/brsa/rsa.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// BSD-3-Clause License -// -// Copyright (c) 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// from: https://github.com/cloudflare/circl/tree/v1.3.7/blindsign - -package blindrsa - -import ( - "crypto/rand" - "crypto/rsa" - "hash" - "io" - "math/big" -) - -var ( - bigZero = big.NewInt(0) - bigOne = big.NewInt(1) -) - -// incCounter increments a four byte, big-endian counter. -func incCounter(c *[4]byte) { - if c[3]++; c[3] != 0 { - return - } - if c[2]++; c[2] != 0 { - return - } - if c[1]++; c[1] != 0 { - return - } - c[0]++ -} - -// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function -// specified in PKCS #1 v2.1. -func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { - var counter [4]byte - var digest []byte - - done := 0 - for done < len(out) { - hash.Write(seed) - hash.Write(counter[0:4]) - digest = hash.Sum(digest[:0]) - hash.Reset() - - for i := 0; i < len(digest) && done < len(out); i++ { - out[done] ^= digest[i] - done++ - } - incCounter(&counter) - } -} - -func encrypt(c *big.Int, N *big.Int, e *big.Int, m *big.Int) *big.Int { - c.Exp(m, e, N) - return c -} - -// decrypt performs an RSA decryption, resulting in a plaintext integer. If a -// random source is given, RSA blinding is used. -func decrypt(random io.Reader, priv *BigPrivateKey, c *big.Int) (m *big.Int, err error) { - // TODO(agl): can we get away with reusing blinds? - if c.Cmp(priv.Pk.N) > 0 { - return nil, rsa.ErrDecryption - } - if priv.Pk.N.Sign() == 0 { - return nil, rsa.ErrDecryption - } - - var ir *big.Int - if random != nil { - // Blinding enabled. Blinding involves multiplying c by r^e. - // Then the decryption operation performs (m^e * r^e)^d mod n - // which equals mr mod n. The factor of r can then be removed - // by multiplying by the multiplicative inverse of r. - - var r *big.Int - ir = new(big.Int) - for { - r, err = rand.Int(random, priv.Pk.N) - if err != nil { - return nil, err - } - if r.Cmp(bigZero) == 0 { - r = bigOne - } - ok := ir.ModInverse(r, priv.Pk.N) - if ok != nil { - break - } - } - rpowe := new(big.Int).Exp(r, priv.Pk.E, priv.Pk.N) // N != 0 - cCopy := new(big.Int).Set(c) - cCopy.Mul(cCopy, rpowe) - cCopy.Mod(cCopy, priv.Pk.N) - c = cCopy - } - - m = new(big.Int).Exp(c, priv.D, priv.Pk.N) - - if ir != nil { - // Unblind. - m.Mul(m, ir) - m.Mod(m, priv.Pk.N) - } - - return m, nil -} diff --git a/intra/core/buf.go b/intra/core/buf.go deleted file mode 100644 index e8c7aa84..00000000 --- a/intra/core/buf.go +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -// from: github.com/eycorsican/go-tun2socks/blob/301549c435/core/buffer_pool.go - -import ( - "math/bits" - "sync" -) - -var slabs [totalSlabs]*sync.Pool // read-only after init - -const ( - totalSlabs = 7 // total slab types - minSlabExponent = 11 // 2^11 = 2048 bytes - - // B524288 is slab of size 512k - B524288 = 512 * 1024 - // B65536 is slab of size 64k - B65536 = 64 * 1024 - // B32768 is slab of size 32k - B32768 = 32 * 1024 - // B16384 is slab of size 16k - B16384 = 16 * 1024 - // B8192 is slab of size 8k - B8192 = 8 * 1024 - // B4096 is slab of size 4k - B4096 = 4 * 1024 - // B2048 is slab of size 2k; also the min - B2048 = 2 * 1024 - // BMAX is the largest pooled slab size - BMAX = B524288 -) - -// pointers to slices: archive.is/BhHuQ -// deal only in pointers to byte-array -// github.com/golang/example/blob/9fd7daa/slog-handler-guide/README.md#speed - -// AllocRegion returns a truncated byte slice at least size big -func AllocRegion(size int) *[]byte { - if slab := slabof(size); slab != nil { - if ptr, _ := slab.Get().(*[]byte); ptr != nil { - return ptr - } - } - b := make([]byte, 0, size) - return &b -} - -// Alloc returns a truncated byte slice of size 4096 -func Alloc() *[]byte { - return AllocRegion(B4096) -} - -// Alloc16 returns a truncated byte slice of size 16384 -func Alloc16() *[]byte { - return AllocRegion(B16384) -} - -// LOB returns a truncated byte slice of size 524288 -func LOB() *[]byte { - return AllocRegion(B524288) -} - -// Recycle returns the byte slices to the pool -func Recycle(b *[]byte) bool { - // some buffer pool impl extend len until cap (github.com/v2fly/v2ray-core/blob/0c5abc7e53a/common/bytespool/pool.go#L63) - // arr := *b.slice - // arr[:cap(arr)] - // ---- - // Other impls truncate the slice to 0 len (github.com/golang/example/blob/9fd7daa/slog-handler-guide/README.md#speed) - // (*b.slice) := (*b.slice)[:0] - - // ref: go.dev/play/p/ywM_j-IvVH6 - if slab := slabfor(b); slab != nil { - *b = (*b)[:0] - slab.Put(b) - return true - } - return false -} - -// github.com/v2fly/v2ray-core/blob/0c5abc7e53a/common/bytespool/pool.go#L63 -func init() { - slabs[k(B2048)] = newpool(B2048) - slabs[k(B4096)] = newpool(B4096) - slabs[k(B8192)] = newpool(B8192) - slabs[k(B16384)] = newpool(B16384) - slabs[k(B32768)] = newpool(B32768) - slabs[k(B65536)] = newpool(B65536) - slabs[k(B524288)] = newpool(B524288) -} - -// slabfor returns a sync.Pool that byte b can be recycled to. -func slabfor(b *[]byte) *sync.Pool { - sz := cap(*b) - return slabof(sz) -} - -// slabof returns the sync.Pool that vends byte slices of size sz. -func slabof(sz int) (p *sync.Pool) { - if sz > BMAX { - // do not store larger regions - } else if sz >= B524288 { // min 512k - p = slabs[k(B524288)] - } else if sz >= B65536 { // min 64k - p = slabs[k(B65536)] - } else if sz >= B32768 { // min 32k - p = slabs[k(B32768)] - } else if sz >= B16384 { // min 16k - p = slabs[k(B16384)] - } else if sz >= B8192 { // min 8k - p = slabs[k(B8192)] - } else if sz >= B4096 { // min 4k - p = slabs[k(B4096)] - } else { // min 2k - p = slabs[k(B2048)] - } - return -} - -// newpool returns a new sync.Pool of byte slices with minimum capacity, size. -func newpool(size int) *sync.Pool { - return &sync.Pool{ - New: func() any { - b := make([]byte, 0, size) - return &b - }, - } -} - -func k(i uint32) int { - slot := log2(i) - minSlabExponent - if slot < 0 { - return 0 - } - return min(slot, totalSlabs-1) -} - -func log2(powerof2 uint32) int { - return bits.TrailingZeros32(powerof2) -} diff --git a/intra/core/byt.go b/intra/core/byt.go deleted file mode 100644 index bdeaa24c..00000000 --- a/intra/core/byt.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import "io" - -type ByteWriter struct { - b *[]byte // pooled byte slice -} - -var _ io.WriteCloser = (*ByteWriter)(nil) - -func (w *ByteWriter) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - bptr := w.b - if bptr == nil { - bptr = AllocRegion(len(p)) - w.b = bptr - } - // TODO: copy when cap(*bptr) < len(*bptr)+len(p)? - // append may grow the slice beyond original capacity - // and so, it may get recycled to a higher slab on Close - *bptr = append(*bptr, p...) - // w.b and bptr point to same slice & contents; *w.b == *bptr - // go.dev/play/p/RJjoAXBsXy3 - return len(p), nil -} - -func (w *ByteWriter) Close() error { - if bptr := w.b; bptr != nil { - // may recycle to a higher slab (see: Write) - Recycle(bptr) - w.b = nil - } - return nil -} - -func (w *ByteWriter) Bytes() []byte { - if bptr := w.b; bptr != nil { - return *bptr - } - return nil -} - -func (w *ByteWriter) Copy() []byte { - if b := w.b; b != nil { - c := make([]byte, len(*b)) - copy(c, *b) - return c - } - return nil -} - -func (w *ByteWriter) Dup() ByteWriter { - if b := w.b; b != nil { - b2 := AllocRegion(len(*b)) - copy(*b2, *b) - return ByteWriter{b: b2} - } - return ByteWriter{} -} - -func (w *ByteWriter) Len() int { - if b := w.b; b != nil { - return len(*b) - } - return 0 -} - -func (w *ByteWriter) Reset() { - if b := w.b; b != nil { - *b = (*b)[:0] - } -} diff --git a/intra/core/closer.go b/intra/core/closer.go deleted file mode 100644 index a1c5c342..00000000 --- a/intra/core/closer.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "io" - "net" - "os" - "syscall" - - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" -) - -type CloserOp int - -const ( - CopR CloserOp = iota // close read - CopW // close write - CopRW // close read write - CopAny // close -) - -func CloseFile(f *os.File) { - if f != nil { - _ = f.Close() - } -} - -func CloseFD(fd int) { - _ = syscall.Close(fd) -} - -// CloseUDP closes c. -func CloseUDP(c *net.UDPConn) { - if c != nil { - _ = c.Close() - } -} - -// CloseTCP closes c. -func CloseTCP(c *net.TCPConn) { - if c != nil { - _ = c.Close() - } -} - -// CloseTCPRead closes the read end of r. -func CloseTCPRead(r TCPConn) { - if r != nil && IsNotNil(r) { - // avoid expensive reflection: - // groups.google.com/g/golang-nuts/c/wnH302gBa4I - switch x := r.(type) { - case *net.TCPConn: - if x != nil { - _ = x.CloseRead() - } - case *gonet.TCPConn: - if x != nil { - _ = x.CloseRead() - } - default: - if IsNotNil(r) { - _ = r.CloseRead() - } - } - } -} - -// CloseTCPWrite closes the write end of w. -func CloseTCPWrite(w TCPConn) { - if w != nil && IsNotNil(w) { - switch x := w.(type) { - case *net.TCPConn: - if x != nil { - _ = x.CloseWrite() - } - case *gonet.TCPConn: - if x != nil { - _ = x.CloseWrite() - } - default: - if IsNotNil(w) { - _ = w.CloseWrite() - } - } - } -} - -// CloseConn closes cs. -func CloseConn(cs ...MinConn) { - for _, c := range cs { - if c == nil || IsNil(c) { - continue - } - switch x := c.(type) { - case *net.TCPConn: - if x != nil { - _ = x.Close() - } - case *net.UDPConn: - if x != nil { - _ = x.Close() - } - case *gonet.TCPConn: - if x != nil { - _ = x.Close() - } - case *gonet.UDPConn: - if x != nil { - _ = x.Close() - } - default: - if IsNotNil(c) { - _ = c.Close() - } - } - } -} - -// Close closes cs. -func Close(cs ...io.Closer) { - for _, c := range cs { - CloseOp(c, CopAny) - } -} - -// CloseOp closes op on c. -func CloseOp(c io.Closer, op CloserOp) { - if c == nil || IsNil(c) { - return - } - switch x := c.(type) { - case TCPConn: - if op == CopR { - CloseTCPRead(x) - } else if op == CopW { - CloseTCPWrite(x) - } else { // == "rw" or "any" - CloseConn(x) - } - // some udp conns (ex: demuxconn) may conform to DuplexCloser - case DuplexCloser: - if op == CopR { - _ = x.CloseRead() - } else if op == CopW { - _ = x.CloseWrite() - } else { // == "rw" or "any" - _ = x.Close() - } - case *net.UDPConn: - CloseUDP(x) - case UDPConn: - CloseConn(x) - case *net.TCPListener: - if x != nil { - _ = x.Close() - } - case *io.PipeReader: - if x != nil { - _ = x.Close() - } - case *io.PipeWriter: - if x != nil { - _ = x.Close() - } - case *os.File: - CloseFile(x) - case io.Closer: // ex: net.PacketConn - if IsNotNil(c) { - _ = c.Close() - } - } -} diff --git a/intra/core/connmap.go b/intra/core/connmap.go deleted file mode 100644 index 8960d692..00000000 --- a/intra/core/connmap.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "crypto/tls" - "fmt" - "net" - "strings" - "sync" - "time" - "unique" - - "slices" - - "github.com/celzero/firestack/intra/log" -) - -var globalTLSSessionCache = tls.NewLRUClientSessionCache(1024 * 2) - -func TlsSessionCache() tls.ClientSessionCache { - return globalTLSSessionCache -} - -type ConnMapper interface { - // Clear untracks all conns. - Clear() []string - // Track maps x[] to cid. - Track(cid, uid, pid string, x ...MinConn) int - // Get returns a conn mapped to connection id, cid. - Get(cid string) []MinConn - // GetAll returns all conns mapped to pid or uid. - GetAll(uidOrPid string) []MinConn - // Untrack closes all conns with connection id, cid. - Untrack(cid string) int - // UntrackBatch untracks one cid at a time. - UntrackBatch(cidsOrUidsOrPids []string) []string - // Len returns the number of tracked conns. - Len() int - // String returns a string repr of all tracked conns. - String() string -} - -type connstat struct { - c []MinConn - t time.Time - - uid, pid string -} - -type cm struct { - sync.RWMutex - tracc map[string]connstat // cid -> conns - tracp map[string][]string // pid -> cid - tracu map[string][]string // uid -> cid - sz int -} - -var _ ConnMapper = (*cm)(nil) - -func NewConnMap() *cm { - return &cm{ - tracc: make(map[string]connstat), - tracp: make(map[string][]string), - tracu: make(map[string][]string), - } -} - -func (h *cm) Track(cid, uid, pid string, conns ...MinConn) (n int) { - h.Lock() - defer h.Unlock() - - n = h.addLocked(cid, uid, pid, conns) - - log.D("connmap: track: %d/%d conns for %s+%s+%s", n, h.sz, cid, uid, pid) - return -} - -func (h *cm) Untrack(cid string) (n int) { - h.Lock() - defer h.Unlock() - - n = h.delLocked(cid) - log.D("connmap: untrack: %d/%d conns for %s", n, h.sz, cid) - return -} - -func (h *cm) addLocked(cid, uid, pid string, conns []MinConn) (n int) { - if v, ok := h.tracc[cid]; !ok { - h.tracc[cid] = connstat{conns, time.Now(), uid, pid} - n = len(conns) - } else { // should not happen? - // TODO: append uid and pid? - v.c = append(v.c, conns...) - n = len(v.c) - h.tracc[cid] = v - } - h.addByPidLocked(pid, cid) - h.addByUidLocked(uid, cid) - - h.sz += len(conns) - return -} - -func (h *cm) addByPidLocked(pid, cid string) { - if len(pid) <= 0 || len(cid) <= 0 { - return - } - h.tracp[pid] = CopyUniq(h.tracp[pid], []string{cid}) -} - -func (h *cm) addByUidLocked(uid, cid string) { - if len(uid) <= 0 || len(cid) <= 0 { - return - } - h.tracu[uid] = CopyUniq(h.tracu[uid], []string{cid}) -} - -func (h *cm) getLocked(cid string) *connstat { - if v, ok := h.tracc[cid]; ok { - return &v - } - return nil -} - -func (h *cm) getByUidLocked(uid string) []string { - if v, ok := h.tracu[uid]; ok { - return v - } - return nil -} - -func (h *cm) getByPidLocked(pid string) []string { - if v, ok := h.tracp[pid]; ok { - return v - } - return nil -} - -func (h *cm) delLocked(id string) (n int) { - if v, ok := h.tracc[id]; ok { - defer delete(h.tracc, id) - defer h.delByPidLocked(v.pid, id) - defer h.delByUidLocked(v.uid, id) - - n = len(v.c) - CloseConn(v.c...) - - h.sz -= n - // id maybe pid or uid - } else if cidsByUid := h.getByUidLocked(id); len(cidsByUid) > 0 { - return len(h.untrackBatchLocked(cidsByUid)) - } else if cidsByPid := h.getByPidLocked(id); len(cidsByPid) > 0 { - return len(h.untrackBatchLocked(cidsByPid)) - } else { - log.VV("connmap: untrack: id not tracked %s", id) - } - - return -} - -func (h *cm) delByPidLocked(pid, cid string) (deleted []string) { - if len(pid) <= 0 { - return - } - cids := h.tracp[pid] - if len(cid) <= 0 { // delete all - deleted = cids - delete(h.tracp, pid) - return - } - for i, id := range cids { - if id == cid { - deleted = append(deleted, id) - if rem := slices.Delete(cids, i, i+1); len(rem) <= 0 { - delete(h.tracp, pid) - break - } else { - h.tracp[pid] = rem - } - } - } - return -} - -func (h *cm) delByUidLocked(uid, cid string) (deleted []string) { - if len(uid) <= 0 { - return - } - cids := h.tracu[uid] - if len(cid) <= 0 { // delete all - deleted = cids - delete(h.tracu, uid) - return - } - for i, id := range cids { - if id == cid { - deleted = append(deleted, id) - if rem := slices.Delete(cids, i, i+1); len(rem) <= 0 { - delete(h.tracp, uid) - break - } else { - h.tracp[uid] = rem - } - } - } - return -} - -func (h *cm) UntrackBatch(cidsOrUidsOrPids []string) (closedCids []string) { - h.Lock() - defer h.Unlock() - - return h.untrackBatchLocked(cidsOrUidsOrPids) -} - -func (h *cm) untrackBatchLocked(cidsOrUidsOrPids []string) (out []string) { - processed := 0 - n := 0 - out = make([]string, 0, len(cidsOrUidsOrPids)) - for _, id := range cidsOrUidsOrPids { - connsclosed := h.delLocked(id) - if connsclosed > 0 { - out = append(out, id) - n += connsclosed - } - processed++ - } - log.D("connmap: untrack: %d batches of %d/%d (conns/cids)", processed, n, len(out)) - return -} - -func (h *cm) Get(cid string) (conns []MinConn) { - h.RLock() - defer h.RUnlock() - - if cs := h.getLocked(cid); cs != nil { - return cs.c - } - return nil -} - -func (h *cm) GetAll(uidOrPid string) (conns []MinConn) { - h.RLock() - defer h.RUnlock() - - if len(uidOrPid) <= 0 { - return - } - cidsByPid := h.getByPidLocked(uidOrPid) - cidsByUid := h.getByUidLocked(uidOrPid) - for _, cid := range append(cidsByPid, cidsByUid...) { - if cs := h.getLocked(cid); cs != nil { - conns = append(conns, cs.c...) - } - } - return -} - -func (h *cm) String() string { - h.RLock() - defer h.RUnlock() - - var s strings.Builder - for id, cs := range h.tracc { - s.WriteString(id) - s.WriteString(cs.String()) - s.WriteString("\n") - } - return s.String() -} - -func (h *cm) Clear() (cids []string) { - h.Lock() - defer h.Unlock() - - cids = make([]string, 0, len(h.tracc)) - for k, v := range h.tracc { - CloseConn(v.c...) - cids = append(cids, k) - } - clear(h.tracc) - clear(h.tracu) - clear(h.tracp) - sz := h.sz - closed := len(cids) - h.sz = 0 - log.D("connmap: clear: closed %d/%d conns", closed, sz) - return -} - -func (h *cm) Len() int { - h.RLock() - defer h.RUnlock() - - return h.sz -} - -func (c *connstat) String() string { - if c == nil { - return "" - } - return fmt.Sprintf(":%s+%s:%s:%d[%s]", - c.pid, c.uid, FmtTimeAsPeriod(c.t), len(c.c), conn2str(c.c...)) -} - -func minconn2str(c ...MinConn) (csv string) { - if len(c) == 0 { - return "" - } - - s := make([]string, 0, len(c)) - for _, v := range c { - if v == nil || IsNil(v) { - continue - } - laddr := v.LocalAddr() - if cc, ok := v.(net.Conn); ok { - raddr := cc.RemoteAddr() - s = append(s, fmt.Sprintf("%s=>%s", laddr, raddr)) - } else if laddr != nil { // nilaway - s = append(s, laddr.String()) // net.PacketConn - } - } - return strings.Join(s, ",") -} - -// use unique.Handle to handle conn2str -func conn2str(c ...MinConn) string { - return UniqStr(minconn2str(c...)) -} - -func UniqStringer(s fmt.Stringer) string { - return UniqStr(s.String()) -} - -// not always optimal: go.dev/blog/unique -// (a footnote about interning strings) -func UniqStr(s string) string { - h := unique.Make(s) - return h.Value() -} diff --git a/intra/core/connpool.go b/intra/core/connpool.go deleted file mode 100644 index 117fb358..00000000 --- a/intra/core/connpool.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/celzero/firestack/intra/log" - "github.com/miekg/dns" - "golang.org/x/sys/unix" -) - -const pooluseread = false // never used; for documentation only -const poolcapacity = 8 // default capacity -const poolmaxattempts = poolcapacity / 2 // max attempts to retrieve a conn from pool -const Nobody = uintptr(0) // nobody -const poolmaxidle = 8 * time.Minute // close unused pooled conns after this period -const poolfreshttl = 1 * time.Minute // considered fresh if less than this period -const poolscrubinterval = poolmaxidle / 3 // interval between subsequent scrubs - -// go.dev/play/p/ig2Zpk-LTSv -var ( - kaidle = int(poolmaxidle / 5 / time.Second) // 8m / 5 => 96s - kainterval = int(poolmaxidle / 10 / time.Second) // 8m / 10 => 48s -) - -var ( - errUnexpectedRead = errors.New("pool: unexpected read") - errNotSyscallConn = errors.New("core: not a syscall.Conn") - errAttemptsExceeded = errors.New("pool: max attempts exceeded") -) - -type superpool[T comparable] struct { - quit context.CancelFunc - pool *ConnPool[T] -} - -type MultConnPool[T comparable] struct { - ctx context.Context - mu sync.RWMutex - m map[T]*superpool[T] - scrubtime time.Time -} - -// NewMultConnPool creates a new multi connection-pool. -func NewMultConnPool[T comparable](ctx context.Context) *MultConnPool[T] { - return &MultConnPool[T]{ - ctx: ctx, - m: make(map[T]*superpool[T]), - scrubtime: time.Now(), - } -} - -// scrub closes and removes old conns from all conn pools. -func (m *MultConnPool[T]) scrub() { - now := time.Now() - if now.Sub(m.scrubtime) <= poolscrubinterval { // too soon - return - } - m.scrubtime = now - - select { - case <-m.ctx.Done(): - return - default: - } - - Go("superpool.scrub", func() { - m.mu.Lock() - defer m.mu.Unlock() - - var n, nclosed, nquit, nscrubbed int - n = len(m.m) - for id, super := range m.m { - if super.pool.closed.Load() { - nclosed++ - delete(m.m, id) - } else if super.pool.empty() { - nquit++ - super.quit() - delete(m.m, id) - } else { - nscrubbed++ - Go("pool.scrub", super.pool.scrub) - } - } - - log.D("pool: scrubbed: %d, closed: %d, quit: %d, total: %d", - nscrubbed, nclosed, nquit, n) - }) -} - -// Get returns a conn from the pool[id], if available. -func (m *MultConnPool[T]) Get(id T) net.Conn { - if IsZero(id) { - return nil - } - - m.mu.RLock() - super := m.m[id] - m.mu.RUnlock() - - if super != nil { - return super.pool.Get() - } - return nil -} - -// Put puts conn back in the pool[id]. -func (m *MultConnPool[T]) Put(id T, conn net.Conn) (ok bool) { - if IsZero(id) || IsNil(conn) { - return false - } - - m.mu.RLock() // read lock - super := m.m[id] - m.mu.RUnlock() - - if super == nil { - m.mu.Lock() // double check with write lock - if super = m.m[id]; super == nil { - child, sigstop := context.WithCancel(m.ctx) - super = &superpool[T]{sigstop, NewConnPool(child, id)} - m.m[id] = super - } - m.mu.Unlock() - } - - m.scrub() - return super.pool.Put(conn) -} - -type agingconn struct { - c net.Conn // pooled conn - sc syscall.RawConn // raw conn; may be nil - dob time.Time // induction time - str string // local and remote addrs -} - -// newAgingConn creates a new agingconn. -// if c is a PoolableConn, it is used to check for readability. -// if not, c is checked for freshness. -func newAgingConn(c net.Conn) agingconn { - if IsNil(c) { - return agingconn{} - } - - var sc PoolableConn - - s := conn2str(c) - if sc, _ = c.(PoolableConn); sc != nil { - // ok - } else if dc, _ := c.(*dns.Conn); dc != nil { - if tc, _ := dc.Conn.(*tls.Conn); tc != nil { - if sc, _ = tc.NetConn().(PoolableConn); sc == nil { - log.VV("pool: dnsconn != sysconn: %T for %s", tc.NetConn(), s) - } // else: ok - } else if sc, _ = dc.Conn.(PoolableConn); sc == nil { - log.VV("pool: dnsconn != sysconn: %T for %s", dc.Conn, s) - } // else: ok - } else if tc, _ := c.(*tls.Conn); tc != nil { - if sc, _ = tc.NetConn().(PoolableConn); sc == nil { - log.VV("pool: tlsconn != sysconn: %T for %s", tc.NetConn(), s) - } // else: ok - } // sc is nil - - var raw syscall.RawConn - var err error - if sc != nil { // confirm syscall.Conn works - raw, err = sc.SyscallConn() - if err != nil { - log.VV("pool: sysconn %T for %s; err %v", c, s, err) - raw = nil - } - } - return agingconn{c, raw /* may be nil */, time.Now(), s} -} - -// github.com/redis/go-redis/blob/d9eeed13/internal/pool/pool.go -type ConnPool[T comparable] struct { - ctx context.Context - id T - p chan agingconn // never closed - closed atomic.Bool -} - -// NewConnPool creates a new conn pool with preset capacity and ttl. -func NewConnPool[T comparable](ctx context.Context, id T) *ConnPool[T] { - c := &ConnPool[T]{ - ctx: ctx, - id: id, - p: make(chan agingconn, poolcapacity), - } - - context.AfterFunc(ctx, c.clean) - return c -} - -// Get returns a conn from the pool, if available, within 3 seconds. -func (c *ConnPool[T]) Get() (zz net.Conn) { - if c.closed.Load() { - return - } - - if len(c.p) == 0 { - return - } - - pooled, complete := Grx("pool.get", func(ctx context.Context) (zz net.Conn, err error) { - i := 0 - for i < poolmaxattempts { - i++ - select { - case aconn := <-c.p: - // if readable/fresh, return conn regardless of its freshness - if aconn.ok() { - aconn.keepalive(false) - return aconn.c, nil - } - (&aconn).close() - case <-ctx.Done(): - return // signal stop - default: - return // empty - } - } - return nil, errAttemptsExceeded // maxattempts exceeded - }, timeout) - - empty := IsNil(pooled) // or maxattempts exceeded - timedout := !complete - logevif(timedout || empty)("pool: %v get: empty? %t, timedout? %t", - c.id, empty, timedout) - - return pooled -} - -// Put puts conn back in the pool. -// Put takes ownership of the conn regardless of the return value. -func (c *ConnPool[T]) Put(conn net.Conn) (ok bool) { - defer func() { - if !ok { - CloseConn(conn) - } - }() - - if c.closed.Load() { - return - } - if c.full() { - return - } - - aconn := newAgingConn(conn) - if !aconn.ok() { - return - } - - aconn.resetDeadline() - - select { - case c.p <- aconn: - aconn.keepalive(true) - return true - case <-c.ctx.Done(): // stop - return - default: // pool full - return - } -} - -// empty returns true if pool is empty. -func (c *ConnPool[T]) empty() bool { - return len(c.p) == 0 -} - -// full returns true if pool is full. -func (c *ConnPool[T]) full() bool { - return len(c.p) > poolcapacity -} - -// clean closes all conns in the pool. -func (c *ConnPool[T]) clean() { - // todo: defer close(c.p) - - ok := c.closed.CompareAndSwap(false, true) - log.I("pool: %v closed? %t", c.id, ok) - for { - select { - case aconn := <-c.p: - (&aconn).close() - default: - return - } - } -} - -// scrub closes and removes old conns from the pool. -func (c *ConnPool[T]) scrub() { - if c.closed.Load() { - return - } - - staged := make([]agingconn, 0) - defer func() { - for _, aconn := range staged { - kept := false - select { - case <-c.ctx.Done(): // close conn; fallthrough - default: - select { - case c.p <- aconn: // put it back in - kept = true - case <-c.ctx.Done(): // close conn; fallthrough - default: // pool full - } - } - if !kept { - (&aconn).close() - } - } - }() - - for { - select { - case aconn := <-c.p: - if aconn.old() || !aconn.ok() { - (&aconn).close() - } else { - staged = append(staged, aconn) - } // next - case <-c.ctx.Done(): // closed - return - default: // empty - return - } - } -} - -// old returns true if conn must be closed, -// or it might end up far longer than desired -// (ex: with long keepalives draining power). -func (a agingconn) old() bool { - return time.Since(a.dob) > poolmaxidle -} - -// ok returns true if a is readable or fresh. -func (a agingconn) ok() bool { - if a.sc != nil { // if sysconn, check readability - return a.readable() - } - return a.fresh() // else: check freshness -} - -// fresh returns true if conn is recent enough. -func (a agingconn) fresh() bool { - return time.Since(a.dob) < poolfreshttl -} - -// close closes the conn. -func (a *agingconn) close() { - a.dob = time.Time{} - CloseConn(a.c) -} - -// github.com/golang/go/issues/15735 -func (a agingconn) readable() bool { - err := a.canread() - - logev(err)("pool: %s sysconn? %T readable? %t; err? %v", - a.str, a.c, err == nil, err) - return err == nil -} - -// keepalive sets tcp keepalive, if y is true. -// If y is false, it disables keepalive. -func (a agingconn) keepalive(y bool) bool { - if y { - cleardeadline(a.c) // reset any previous timeout - return SetKeepAliveConfigSockOpt(a.c, kaidle, kainterval) - } else { - if c, ok := a.c.(KeepAliveConn); ok { - return c.SetKeepAlive(false) == nil - } - return false - } -} - -// github.com/go-sql-driver/mysql/blob/f20b28636/conncheck.go -// github.com/redis/go-redis/blob/cc9bcb0c0/internal/pool/conn_check.go -func (a agingconn) canread() error { - sc := a.sc - if sc == nil { - return errNotSyscallConn - } - - var checkErr error - var ctlErr error - - if pooluseread { // stackoverflow.com/q/12741386 - ctlErr = sc.Read(func(fd uintptr) bool { - // 0 byte reads do not work to detect readability: - // see: go-review.googlesource.com/c/go/+/23227 - // pitfalls: github.com/redis/go-redis/issues/3137 - var buf [1]byte - n, err := syscall.Read(int(fd), buf[:]) - switch { - case n == 0 && err == nil: - checkErr = io.EOF - case n > 0: - // conn is supposed to be idle - checkErr = errUnexpectedRead - case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: - checkErr = nil - default: - checkErr = err - } - return true - }) - } else { - ctlErr = sc.Control(func(fd uintptr) { - fds := []unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN | unix.POLLERR}, - } - n, err := unix.Poll(fds, 0) - if err != nil { - checkErr = fmt.Errorf("pool: poll: err: %v", err) - } - if n > 0 { - checkErr = fmt.Errorf("pool: poll: sz: %d (must be 0), errno: %v", - n, fds[0].Revents) - } - }) - } - return JoinErr(ctlErr, checkErr) // may return nil -} - -func (a agingconn) resetDeadline() { - a.c.SetDeadline(time.Time{}) -} - -func logev(err error) log.LogFn { - return logevif(err != nil) -} - -func logevif(e bool) log.LogFn { - if e { - return log.E - } - return log.VV -} diff --git a/intra/core/cp.go b/intra/core/cp.go deleted file mode 100644 index 777143c9..00000000 --- a/intra/core/cp.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "errors" - "io" -) - -var ( - errNoPipe = errors.New("pipe: src or dst nil") - errNoStream = errors.New("stream: reader or writer nil") - - // errInvalidWrite means that a write returned an impossible count. - errInvalidWrite = errors.New("invalid write result") -) - -// Pipe reads data from src to dst, and returns the number of bytes copied. -// Prefers src.WriteTo(dst) and dst.ReadFrom(src) if available. -// Otherwise, it uses core.Stream. -func Pipe(dst io.Writer, src io.Reader) (int64, error) { - if IsNil(src) || IsNil(dst) { - return 0, errNoPipe - } - - // when "uploading" src is wired to TUN via netstack as either - // gonet.TCPConn or gonet.UDPConn; reverse when "downloading". - - // Retrier conns have specific entry-points; make sure those - // get priority over regular copy. - if x, ok := src.(WriteRetrierConn); ok { - return x.WriteTo(dst) // may be called on "downloading" - } else if x, ok := dst.(ReadRetrierConn); ok { - return x.ReadFrom(src) // may be called on "uploading" - } - - // Prefer WriteTo/ReadFrom if available as they are zero-copy. - // also: github.com/acln0/zerocopy - if x, ok := src.(io.WriterTo); ok { - return x.WriteTo(dst) - } else if x, ok := dst.(io.ReaderFrom); ok { - return x.ReadFrom(src) - } - return Stream(dst, src) -} - -// Stream reads data from src in to dst until error, and returns the no. of bytes read. -// Internally, it bypasses io.ReaderFrom and io.WriterTo but uses io.CopyBuffer, -// recycling buffers from a global pool. -func Stream(dst io.Writer, src io.Reader) (written int64, err error) { - if IsNil(src) || IsNil(dst) { - return 0, errNoStream - } - - // TODO: writerNoReadFrom and readerNoWriteTo - bptr := Alloc16() // TLS record size? - buf := *bptr - buf = buf[:cap(buf)] - defer func() { - *bptr = buf - Recycle(bptr) - }() - // implementation from: io.CopyBuffer - // laid out here since "hiding" ReadFrom/WriteTo funcs - // did not work as expected and led to recursive calls. - for { - nr, er := src.Read(buf) - if nr > 0 { - nw, ew := dst.Write(buf[0:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - ew = errInvalidWrite - } - } - written += int64(nw) - if ew != nil { - err = ew - break - } - if nr != nw { - err = io.ErrShortWrite - break - } - } - if er != nil { - if er != io.EOF { - err = er - } - break - } - } - return written, err -} - -// ref: github.com/golang/go/issues/58808 -// from: go-review.googlesource.com/c/go/+/472475/20/src/net/net.go - -// noReadFrom can be embedded alongside another type to -// hide the ReadFrom method of that other type. -// type noReadFrom struct{} - -// ReadFrom hides another ReadFrom method. -// It should never be called. -// func (noReadFrom) ReadFrom(io.Reader) (int64, error) { -// panic("noReadFrom: hidden func; should not be called") -// } - -// noWriteTo can be embedded alongside another type to -// hide the WriterTo method of that other type. -// type noWriteTo struct{} - -// func (noWriteTo) WriteTo(io.Writer) (int64, error) { -// panic("noWriteTo: hidden func; should not be called") -// } - -// noReadFromWriter implements all the methods of io.Writer other -// than ReadFrom. This is used to permit ReadFrom to call io.Copy -// without leading to a recursive call to ReadFrom. -// type noReadFromWriter struct { -// noReadFrom -// io.Writer -// } - -// noWriteToReader implements all the methods of io.Reader other -// than WriteTo. This is used to permit WriteTo to call io.Copy -// without leading to a recursive call to WriteTo. -// type noWriteToReader struct { -// noWriteTo -// io.Reader -// } diff --git a/intra/core/dontpanic.go b/intra/core/dontpanic.go deleted file mode 100644 index 64514ebc..00000000 --- a/intra/core/dontpanic.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// Copyright (c) HashiCorp, Inc. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "bytes" - "fmt" - "os" - "runtime/trace" - "sync" - "time" - - "github.com/celzero/firestack/intra/log" -) - -// from: github.com/hashicorp/terraform/blob/325d18262e/internal/logging/panic.go#L36-L64 - -type Finally func() - -type ExitCode int - -func (e ExitCode) int() int { - return int(e) -} - -// An exit code of 11 keeps us out of the way of the detailed exitcodes -// from plan, and also happens to be the same code as SIGSEGV which is -// roughly the same type of condition that causes most panics. -const Exit11 ExitCode = 11 - -// DontExit is a special code that can be passed to Recover to indicate that -// the process should not exit after recovering from a panic. -const DontExit ExitCode = 0 - -// In case multiple goroutines panic concurrently, ensure only one of them -// is able to print the panic message and exit the process. -var _pmu sync.RWMutex - -// protects recorder.WriteTo. -var _rmu sync.Mutex - -var parentCallerDepthAt = 1 - -var recorder *trace.FlightRecorder = trace.NewFlightRecorder(trace.FlightRecorderConfig{ - MinAge: 10 * time.Second, -}) - -// fn is called in a separate goroutine, if a panic is recovered. -// RecoverFn must be called as a defered function, and must be the first -// defer called at the start of a new goroutine. -func RecoverFn(aux string, fn Finally) (didpanic bool) { - recovered := recover() - didpanic = recovered != nil - if !didpanic { // nothing to recover from - return false - } - - defer Gif(didpanic, "fin."+aux, fn) - - msg := fmt.Sprintf("%s [%d] %v", aux, DontExit, recovered) - log.E2(parentCallerDepthAt+1, msg) - - recorderToConsole() - applog(DontExit, msg) - return didpanic -} - -func Recording() bool { - return recorder.Enabled() -} - -func Record(start bool) (recording bool, err error) { - recording = recorder.Enabled() - if start { - if !recording { - err = recorder.Start() - recording = err == nil - } - } else { - if recording { - go recorder.Stop() - recording = false - } - } - return -} - -func recorderToConsole() (logged bool) { - logged, _ = DumpRecorder(true /* onConsole */) - return -} - -// Logs flight recorder to console if onConsole is true. -// The returned value b contains recorded bytes when got is true. -func DumpRecorder(onConsole bool) (got bool, b bytes.Buffer) { - if !recorder.Enabled() { - return - } - - _rmu.Lock() - defer _rmu.Unlock() - - n, _ := recorder.WriteTo(&b) - - if got = n > 0; got && onConsole { - log.R( /*console*/ true, b.String()) - } - - return got, b -} - -// Recover must be called as a defered function, and must be the first -// defer called at the start of a new goroutine. -func Recover(code ExitCode, aux any) (didpanic bool) { - recovered := recover() - didpanic = recovered != nil - if !didpanic { // nothing to recover from - return false - } - - msg := fmt.Sprintf("%s [%d] %v [%s]", aux, code, recovered, stamp()) - log.E2(parentCallerDepthAt, msg) - - recorderToConsole() - applog(code, msg) - return didpanic -} - -func applog(code ExitCode, msg string) { - // have all managed goroutines checkin here, and prevent them from exiting - // if there's a panic in progress. While this can't lock the entire runtime - // to block progress, we can prevent some cases where firestack may return - // early before the panic has been printed out. - if code == DontExit { - // many "dontexit" goroutines can safely run concurrently. - _pmu.RLock() - defer _pmu.RUnlock() - } else { - defer os.Exit(Exit11.int()) - // upto one goroutine panicking should exit the process. - _pmu.Lock() - defer _pmu.Unlock() - } - - bptr := LOB() - b := *bptr - b = b[:cap(b)] - defer func() { - *bptr = b - Recycle(bptr) - }() - log.C(msg, b) -} diff --git a/intra/core/err.go b/intra/core/err.go deleted file mode 100644 index 1e8e8809..00000000 --- a/intra/core/err.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "errors" - "strings" -) - -func OneErr(errs ...error) error { - // or: cmp.Or(errs...) - for _, err := range errs { - if err != nil { - return err - } - } - return nil -} - -func JoinErr(errs ...error) error { - return joinErr(false /*uniq*/, errs...) -} - -func UniqErr(errs ...error) error { - return joinErr(true /*uniq*/, errs...) -} - -// must always return the interface "error" and not -// a *errMult, as client code checks for err == nil -// for a nil *errMult returned from here is always false. -// ie, if *errMult was the return type, then the "error" -// interface returned by JoinErr and UniqErr is not "nil" -// even if *errMult returned by joinErr was "nil". -// see also: IsNil() and IsNotNil() -func joinErr(uniq bool, errs ...error) error { - if len(errs) <= 0 { - return nil - } - - var all []error - var m map[error]struct{} - - if false { // unjoin? - for _, err := range errs { - if err == nil { - continue - } - var merr *errMult - if errors.As(err, &merr) { - all = append(all, merr.Unwrap()...) - } - } - } - - if uniq { - m = make(map[error]struct{}, len(errs)) - } - for _, err := range errs { - if err == nil { - continue - } - haserr := false - if m != nil { // uniq - if _, haserr = m[err]; !haserr { - for k := range m { - if haserr = errors.Is(k, err); haserr { - break - } - } - } - m[err] = struct{}{} - } - if !haserr { - all = append(all, err) - } - } - if len(all) <= 0 { - return nil - } - - return &errMult{errs: all, sep: " | "} -} - -type errMult struct { - errs []error - sep string -} - -func (e *errMult) Error() string { - if e == nil { - return "{nil}" - } - if len(e.errs) <= 0 { - return "" - } else if len(e.errs) == 1 { - return e.errs[0].Error() - } - - b := strings.Builder{} - for i, err := range e.errs { - if i != 0 { // except for first entry, add separator - _, _ = b.WriteString(e.sep) - } - _, _ = b.WriteString(err.Error()) - } - return b.String() -} - -func (e *errMult) Unwrap() []error { - if e == nil { - return nil - } - return e.errs -} diff --git a/intra/core/expiringmap.go b/intra/core/expiringmap.go deleted file mode 100644 index e05d7845..00000000 --- a/intra/core/expiringmap.go +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "context" - "sync" - "time" -) - -var ( - reapthreshold = 5 * time.Minute - maxreapiter = 50 - sizethreshold = 100 - lifetime = 0 * time.Millisecond -) - -type val[V any] struct { - v V - expiry time.Time - hits uint32 -} - -// ExpMap holds expiring keys and read hits. -type ExpMap[P comparable, Q any] struct { - sync.Mutex // guards ExpMap. - ctx context.Context - m map[P]*val[Q] - sigreap chan struct{} - lastreap time.Time - minlife time.Duration -} - -// NewExpiringMap returns a new ExpMap with min lifetime of 0. -func NewExpiringMap[P comparable, Q any](ctx context.Context) *ExpMap[P, Q] { - return NewExpiringMapLifetime[P, Q](ctx, lifetime) -} - -func NewExpiringMapLifetime[P comparable, Q any](ctx context.Context, min time.Duration) *ExpMap[P, Q] { - m := &ExpMap[P, Q]{ - ctx: ctx, - m: make(map[P]*val[Q]), - sigreap: make(chan struct{}), - lastreap: time.Now(), - minlife: min, - } - Gx1("expm.reaper", m.reaper, ctx) - // test: go.dev/play/p/EYq_STKvugb - return m -} - -// Get returns the number of hits for the given key. -func (m *ExpMap[P, Q]) Get(key P) uint32 { - if done(m.ctx) { - return 0 - } - - n := time.Now() - - m.Lock() - defer m.Unlock() - - v, ok := m.m[key] - if !ok { - v = &val[Q]{ - expiry: n, - } - m.m[key] = v - } else if n.After(v.expiry) { - v.hits = 0 - } else { - v.hits++ - } - return v.hits -} - -// Set sets the expiry for the given key and returns the number of hits. -// expiry must be greater than the minimum lifetime. -func (m *ExpMap[P, Q]) Set(key P, expiry time.Duration) uint32 { - if done(m.ctx) { - return 0 - } - - if expiry < m.minlife { - expiry = m.minlife - } - - n := time.Now().Add(expiry) - - m.Lock() - defer m.Unlock() - - v, ok := m.m[key] - if v == nil || !ok { // add new val - v = &val[Q]{ - expiry: n, - } - m.m[key] = v - } else if n.After(v.expiry) { // update expiry - v.expiry = n - } // else: no change - - select { - case m.sigreap <- struct{}{}: - default: - } - - var zz Q - v.v = zz - return v.hits -} - -// Set sets the (value, expiry) for the given key and returns the number of hits. -// expiry must be greater than the minimum lifetime. -func (m *ExpMap[P, Q]) K(key P, value Q, expiry time.Duration) uint32 { - if done(m.ctx) { - return 0 - } - - if expiry < m.minlife { - expiry = m.minlife - } - - n := time.Now().Add(expiry) - - m.Lock() - defer m.Unlock() - - v, ok := m.m[key] - if v == nil || !ok { // add new val - v = &val[Q]{ - expiry: n, - } - m.m[key] = v - } else if n.After(v.expiry) { // update expiry - v.expiry = n - } // else: no change - - select { - case m.sigreap <- struct{}{}: - default: - } - - v.v = value - return v.hits -} - -func (m *ExpMap[P, Q]) V(key P) (zz Q, fresh bool) { - if done(m.ctx) { - return // zz, false - } - - m.Lock() - defer m.Unlock() - - now := time.Now() - if v, ok := m.m[key]; ok && v != nil { - return v.v, now.Before(v.expiry) - } - return // zz, false -} - -func (m *ExpMap[P, Q]) Alive(key P) bool { - if done(m.ctx) { - return false - } - - m.Lock() - defer m.Unlock() - - now := time.Now() - if v, ok := m.m[key]; ok && v != nil { - return now.Before(v.expiry) - } - return false -} - -// Delete deletes the given key. -func (m *ExpMap[P, Q]) Delete(key P) { - m.Lock() - defer m.Unlock() - - delete(m.m, key) -} - -// Len returns the number of keys, which may or may not have expired. -func (m *ExpMap[P, Q]) Len() int { - m.Lock() - defer m.Unlock() - - return len(m.m) -} - -// Clear deletes all keys and returns the number of keys deleted. -func (m *ExpMap[P, Q]) Clear() int { - m.Lock() - defer m.Unlock() - - l := len(m.m) - clear(m.m) - return l -} - -// reaper deletes expired keys. -// Must always be called from a goroutine. -func (m *ExpMap[P, Q]) reaper(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case <-m.sigreap: - } - - m.Lock() - - l := len(m.m) - if l < sizethreshold { - m.Unlock() - continue - } - - now := time.Now() - treap := m.lastreap.Add(reapthreshold) - // if last reap was reap-threshold minutes ago... - if now.Sub(treap) <= 0 { - m.Unlock() - continue - } - m.lastreap = now - // reap up to maxreapiter entries - i := 0 - for k, v := range m.m { - i += 1 - if now.Sub(v.expiry) > 0 { - delete(m.m, k) - } - if i > maxreapiter { - break - } - } - m.Unlock() - } -} - -// done returns true if the context is done. -func done(ctx context.Context) bool { - select { - case <-ctx.Done(): - return true - default: - } - return false -} diff --git a/intra/core/expiringsieve.go b/intra/core/expiringsieve.go deleted file mode 100644 index 6601a358..00000000 --- a/intra/core/expiringsieve.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "context" - "sync" - "time" -) - -// Sieve2K is a map of expiring maps. The outer map is keyed to K1, -// while the inner expiring maps are keyed to K2. -type Sieve2K[K1, K2 comparable, V any] struct { - ctx context.Context - mu sync.RWMutex // protects m, c - m map[K1]*Sieve[K2, V] - d map[K1]context.CancelFunc - life time.Duration -} - -// NewSieve2K returns a new Sieve2K with keys expiring after lifetime. -func NewSieve2K[K1, K2 comparable, V any](ctx context.Context, dur time.Duration) *Sieve2K[K1, K2, V] { - return &Sieve2K[K1, K2, V]{ - ctx: ctx, - m: make(map[K1]*Sieve[K2, V]), - d: make(map[K1]context.CancelFunc), - life: dur, - } -} - -// Sieve is a thread-safe map with expiring keys. -type Sieve[K comparable, V any] struct { - c *ExpMap[K, V] -} - -// NewSieve returns a new Sieve with keys expiring after lifetime. -func NewSieve[K comparable, V any](ctx context.Context, dur time.Duration) *Sieve[K, V] { - return &Sieve[K, V]{ - c: NewExpiringMapLifetime[K, V](ctx, dur), - } -} - -// Get returns the value associated with the given key, -// and a boolean indicating whether the key was found. -func (s *Sieve[K, V]) Get(k K) (V, bool) { - return s.c.V(k) -} - -// Put adds an element to the sieve with the given key and value. -func (s *Sieve[K, V]) Put(k K, v V) (replaced bool) { - return s.c.K(k, v, s.c.minlife) > 0 -} - -// Del removes the element with the given key from the sieve. -func (s *Sieve[K, V]) Del(k K) { - s.c.Delete(k) -} - -// Len returns the number of elements in the sieve. -func (s *Sieve[K, V]) Len() int { - if s == nil || s.c == nil { - return 0 - } - - return s.c.Len() -} - -// Clear removes all elements from the sieve. -func (s *Sieve[K, V]) Clear() int { - if s == nil || s.c == nil { - return 0 - } - return s.c.Clear() -} - -// Get returns the value associated with the given key, -// and a boolean indicating whether the key was found. -func (s *Sieve2K[K1, K2, V]) Get(k1 K1, k2 K2) (zz V, ok bool) { - s.mu.RLock() - inn := s.m[k1] - s.mu.RUnlock() - - if inn != nil { - return inn.Get(k2) - } - return -} - -// Put adds an element to the sieve with the given key and value. -func (s *Sieve2K[K1, K2, V]) Put(k1 K1, k2 K2, v V) (replaced bool) { - s.mu.RLock() - inn := s.m[k1] - s.mu.RUnlock() - - if inn == nil { - s.mu.Lock() - inn = s.m[k1] - if inn == nil { - ctx, done := context.WithCancel(s.ctx) - inn = NewSieve[K2, V](ctx, s.life) - s.m[k1] = inn - s.d[k1] = done - } - s.mu.Unlock() - } - - return inn.Put(k2, v) -} - -// Del removes the element with the given key from the sieve. -func (s *Sieve2K[K1, K2, V]) Del(k1 K1, k2 K2) { - s.mu.RLock() - inn := s.m[k1] - if inn != nil { - inn.Del(k2) - } - empty := inn.Len() == 0 // inn may be nil - s.mu.RUnlock() - - if empty { - s.mu.Lock() - inn = s.m[k1] // inn may be nil - done := s.d[k1] // done may be nil - if inn.Len() == 0 { - delete(s.m, k1) - delete(s.d, k1) - if done != nil { - done() - } - } - s.mu.Unlock() - } -} - -// Len returns the number of elements in the sieve. -func (s *Sieve2K[K1, K2, V]) Len() (n int) { - s.mu.RLock() - defer s.mu.RUnlock() - - for _, inn := range s.m { - n += inn.Len() - } - return -} - -// Clear removes all elements from the sieve. -func (s *Sieve2K[K1, K2, V]) Clear() (n int) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, inn := range s.m { - n += inn.Clear() - } - for _, done := range s.d { - done() - } - - clear(s.m) - clear(s.d) - return -} diff --git a/intra/core/fmtstr.go b/intra/core/fmtstr.go deleted file mode 100644 index fa66b613..00000000 --- a/intra/core/fmtstr.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "fmt" - "time" -) - -var units = []string{"b", "kb", "mb", "gb"} - -// from: github.com/google/gops/blob/35c854fb84/agent/agent.go -func FmtBytes(val uint64) string { - var i int - var target uint64 - for i = range units { - target = 1 << uint(10*(i+1)) - if val < target { - break - } - } - if i > 0 { - return fmt.Sprintf("%0.2f%s", float64(val)/(float64(target)/1024), units[i]) - } - return fmt.Sprintf("%d bytes", val) -} - -func FmtTimeNs(ns uint64) string { - return time.Now().Add(-time.Duration(ns)).Format(time.TimeOnly) -} - -func Nano2Sec(ns uint64) int64 { - return int64((time.Duration(ns) * time.Nanosecond).Seconds()) -} - -func FmtTimeAsPeriod(t time.Time) string { - return FmtPeriod(time.Since(t)) -} - -func FmtPeriod(d time.Duration) string { - p := "" - if d < 0 { - p = "-" - } - d = d.Abs() - if d < time.Microsecond { - return p + fmt.Sprintf("%dns", d.Nanoseconds()) - } else if d < time.Millisecond { - return p + fmt.Sprintf("%dยตs", d.Microseconds()) - } else if d < time.Second { - return p + fmt.Sprintf("%dms", d.Milliseconds()) - } else if d < time.Minute { - return p + fmt.Sprintf("%fs", d.Seconds()) - } else if d < time.Hour { - return p + fmt.Sprintf("%dm %ds", int64(d.Minutes()), int64(d.Seconds())%60) - } else if d < 24*time.Hour { - return p + fmt.Sprintf("%dh %dm %ds", int64(d.Hours()), int64(d.Minutes())%60, int64(d.Seconds())%60) - } else { - return p + fmt.Sprintf("%dd %dh %dm %ds", int64(d.Hours()/24), int64(d.Hours())%24, int64(d.Minutes())%60, int64(d.Seconds())%60) - } -} - -func FmtUnixMillisAsPeriod(ms int64) string { - return FmtTimeAsPeriod(time.UnixMilli(ms)) -} - -func FmtSecs(s int64) string { - return FmtPeriod(time.Duration(s) * time.Second) -} - -func FmtNanos(ns float64) string { - return FmtPeriod(time.Duration(ns) * time.Nanosecond) -} - -func FmtMillis(ms int64) string { - return FmtPeriod(time.Duration(ms) * time.Millisecond) -} - -func FmtUnixMillisAsTimestamp(ms int64) string { - return time.UnixMilli(ms).Format(time.Stamp) -} - -func FmtUnixEpochAsPeriod(secs int64) string { - return FmtTimeAsPeriod(time.Unix(secs, 0)) -} diff --git a/intra/core/hangover.go b/intra/core/hangover.go deleted file mode 100644 index a3cf0593..00000000 --- a/intra/core/hangover.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "time" -) - -var zerotime = time.Time{} - -type Hangover struct { - start *Volatile[time.Time] -} - -func NewHangover() *Hangover { - return &Hangover{start: NewVolatile(zerotime)} -} - -func (h *Hangover) Note() { - s := h.start.Load() - if s.IsZero() { - h.start.Cas(s, time.Now()) - } // else: already started -} - -func (h *Hangover) Break() { - s := h.start.Load() - if !s.IsZero() { - h.start.Cas(s, zerotime) - } // else: already stopped -} - -func (h *Hangover) Within(d time.Duration) bool { - s := h.start.Load() - if s.IsZero() { - return true - } - return time.Since(s) <= d -} - -func (h *Hangover) Exceeds(d time.Duration) bool { - return !h.Within(d) -} diff --git a/intra/core/ip.go b/intra/core/ip.go deleted file mode 100644 index 97cbd982..00000000 --- a/intra/core/ip.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT - -// from: https://github.com/bepass-org/warp-plus/blob/19ac233cc/iputils/iputils.go - -package core - -import ( - "errors" - "fmt" - "math/big" - "math/rand" - "net" - "net/netip" - "time" -) - -// RandomIPFromPrefix returns a random IP from the provided CIDR prefix. -// Supports IPv4 and IPv6. Does not support mapped inputs. -func RandomIPFromPrefix(cidr netip.Prefix) (netip.Addr, error) { - startingAddress := cidr.Masked().Addr() - if startingAddress.Is4In6() { - return netip.Addr{}, errors.New("mapped v4 addresses not supported") - } - - prefixLen := cidr.Bits() - if prefixLen == -1 { - return netip.Addr{}, fmt.Errorf("invalid cidr: %s", cidr) - } - - // Initialise rand number generator - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - - // Find the bit length of the Host portion of the provided CIDR - // prefix - hostLen := big.NewInt(int64(startingAddress.BitLen() - prefixLen)) - - // Find the max value for our random number - max := new(big.Int).Exp(big.NewInt(2), hostLen, nil) - - // Generate the random number - randInt := new(big.Int).Rand(rng, max) - - // Get the first address in the CIDR prefix in 16-bytes form - startingAddress16 := startingAddress.As16() - - // Convert the first address into a decimal number - startingAddressInt := new(big.Int).SetBytes(startingAddress16[:]) - - // Add the random number to the decimal form of the starting address - // to get a random address in the desired range - randomAddressInt := new(big.Int).Add(startingAddressInt, randInt) - - // Convert the random address from decimal form back into netip.Addr - randomAddress, ok := netip.AddrFromSlice(randomAddressInt.FillBytes(make([]byte, 16))) - if !ok { - return netip.Addr{}, fmt.Errorf("failed to generate random IP from CIDR: %s", cidr) - } - - // Unmap any mapped v4 addresses before return - return randomAddress.Unmap(), nil -} - -func IP2Cidr(ippOrCidr string) (*net.IPNet, error) { - var ipaddr netip.Addr - if _, ipnet, err := net.ParseCIDR(ippOrCidr); err == nil { - return ipnet, err - } else { - if ipp, err1 := netip.ParseAddrPort(ippOrCidr); err1 == nil { - ipaddr = ipp.Addr() - } else if ip, err2 := netip.ParseAddr(ippOrCidr); err2 == nil { - ipaddr = ip - } else { - return nil, fmt.Errorf("ip2cidr: errs: cidr %v / ipp %v / ip %v", err, err1, err2) - } - ip := ipaddr.AsSlice() - mask := net.CIDRMask(ipaddr.BitLen(), ipaddr.BitLen()) - return &net.IPNet{IP: ip, Mask: mask}, nil - } -} - -func IP2Cidr2(ippOrCidr string) (zz netip.Prefix, err error) { - var ipaddr netip.Addr - if prefix, err := netip.ParsePrefix(ippOrCidr); err == nil { - return prefix, err - } else { - if ipp, err1 := netip.ParseAddrPort(ippOrCidr); err1 == nil { - ipaddr = ipp.Addr() - } else if ip, err2 := netip.ParseAddr(ippOrCidr); err2 == nil { - ipaddr = ip - } else { - return zz, fmt.Errorf("ip2cidr2: errs: cidr %v / ipp %v / ip %v", err, err1, err2) - } - return netip.PrefixFrom(ipaddr, ipaddr.BitLen()), nil - } -} diff --git a/intra/core/overreach.go b/intra/core/overreach.go deleted file mode 100644 index ab2dd4a8..00000000 --- a/intra/core/overreach.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2026 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "os" - _ "unsafe" // for go:linkname -) - -// pushing / pulling symbols work provided -// -ldflags="checklinkname=0" - -//go:linkname secureMode runtime.secureMode -var secureMode bool - -func init() { - // github.com/golang/go/issues/69868 - // Unfortunately, Android apps have AT_SECURE set - // (read bytes in /proc/self/auxv on non-rooted Androids). - // This means, on Go runtime fatal / throws and a few kinds of panics, - // only one line is output to logcat (Android's stderr) which makes it - // hard to tell just what went wrong. Android, does use unwinder for - // native apps, and the Android RunTime has its own unwinder; - // both of which traceback seemingly oblivious to AT_SECURE. - // Perhaps, there's security benefits to the Go runtime being this rigid - // about GOTRACEBACK, but for goos.IsAndroid (and for apps with uid > 10000), - // using AT_SECURE to determine "setuid-like" protections appears pointless. - secureMode = false -} - -func SecureMode(new bool) (prev bool) { - prev = secureMode - secureMode = new - return prev -} - -// RuntimeSecureMode reports whether the Go runtime is in secure mode. -// github.com/golang/go/blob/e2fef50def98/src/runtime/os_linux.go#L296 -func RuntimeSecureMode() (them, us bool) { - return runtime_isSecureMode(), secureMode -} - -// RuntimeGotraceback returns the current GOTRACEBACK settings. -// github.com/golang/go/blob/e2fef50def98/src/runtime/runtime1.go#L38 -func RuntimeGotraceback() (l int32, all, crash bool) { - return runtime_gotraceback() -} - -// RuntimeFinishDebugVarsSetup resets internal runtime debug variables -// by re-reading GODEBUG & GOTRACEBACK env vars. -// github.com/golang/go/blob/e2fef50def98/src/runtime/runtime1.go#L462 -func RuntimeFinishDebugVarsSetup() { - runtime_finishDebugVarsSetup() -} - -// RuntimeEnviron returns the Go runtime's cached environment vars. -// github.com/golang/go/blob/e2fef50def98/src/runtime/runtime1.go#L98 -func RuntimeEnviron() []string { - return runtime_environ() -} - -// SetRuntimeEnviron sets / adds a key-value pair in the Go runtime's -// cached environment vars. -func SetRuntimeEnviron(key, val string) (didSet bool, err error) { - envs := runtime_environ() - kv := key + "=" - for i, e := range envs { - if len(e) >= len(kv) && e[:len(kv)] == kv { - envs[i] = kv + val - err = os.Setenv(key, val) - didSet = true - break - } - } - return -} - -// GetRuntimeEnviron gets a value from the Go runtime's cached -// environment vars. -func GetRuntimeEnviron(key string) (val string, found bool) { - envs := runtime_environ() - kv := key + "=" - for _, e := range envs { - if len(e) >= len(kv) && e[:len(kv)] == kv { - return e[len(kv):], true - } - } - return -} - -// RuntimeWtf uses runtime.writeHeader and emits s to logd. -// github.com/golang/go/blob/e2fef50def98/src/runtime/write_err_android.go#L39 -func RuntimeWtf(s string) { - runtime_wtf([]byte(s)) -} - -//go:linkname runtime_environ runtime.environ -func runtime_environ() []string - -//go:linkname runtime_finishDebugVarsSetup runtime.finishDebugVarsSetup -func runtime_finishDebugVarsSetup() - -//go:linkname runtime_isSecureMode runtime.isSecureMode -func runtime_isSecureMode() bool - -//go:linkname runtime_gotraceback runtime.gotraceback -func runtime_gotraceback() (int32, bool, bool) - -//go:linkname runtime_wtf runtime.writeErr -func runtime_wtf(b []byte) - -// ld.lld: error: relocation R_X86_64_PC32 cannot be used against symbol 'runtime.time_now'; recompile with -fPIC -// runtime.time_now => time.now -// ref: github.com/ulule/limiter/blob/f0ada6cb8fa4dc55a734de737c7b4a3f35c86ae1/internal/fasttime/fasttime.go#L12 diff --git a/intra/core/p2est.go b/intra/core/p2est.go deleted file mode 100644 index 3ae9c820..00000000 --- a/intra/core/p2est.go +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "context" - "math" - "slices" - "sync" -) - -// from: github.com/celzero/rethink-app/main/app/src/main/java/com/celzero/bravedns/util/P2QuantileEstimation.kt -// details: aakinshin.net/posts/p2-quantile-estimator/ -// orig impl: github.com/AndreyAkinshin/perfolizer p2.cs -type p2 struct { - mu sync.RWMutex - ctx context.Context - p float64 // percentile - u int64 // sample size - mid int64 // u / 2 - n []int64 // marker positions - ns []float64 // desired marker positions - q []float64 // marker heights - count int64 // total sampled so far - addc chan float64 // add sample -} - -// P2QuantileEstimator is an interface for the P2 quantile estimator. -type P2QuantileEstimator interface { - // Add a sample to the estimator. - Add(float64) - // Get the estimation for p. - Get() int64 - // Get the percentile, p. - P() float64 -} - -var _ P2QuantileEstimator = (*p2)(nil) - -// NewP50Estimator returns a new P50 (median) estimator. -func NewP50Estimator(ctx context.Context) *p2 { - // calibrate: go.dev/play/p/Ry1i61XqzgB - // 31 worked best amid wild latency fluctuations - // using 11 for lower overhead; 5 is the default - return NewP2QuantileEstimator(ctx, 11, 0.5) -} - -// NewP90Estimator returns a new estimator with percentile p. -func NewP2QuantileEstimator(ctx context.Context, samples int64, probability float64) *p2 { - // total samples, typically 5; higher sample size improves accuracy for - // lower percentiles (p50) at the expense of computational cost; - // for higher percentiles (p90+), even sample size as low as 5 works fine. - mid := int64(math.Floor(float64(samples) / 2.0)) - p := &p2{ - ctx: ctx, - p: probability, - u: samples, - mid: mid, - n: make([]int64, samples), - ns: make([]float64, samples), - q: make([]float64, samples), - count: 0, - addc: make(chan float64, samples), - } - Gx("p2.est", p.run) - return p -} - -// P returns the percentile, p. -func (est *p2) P() float64 { - return est.p -} - -// Add a sample to the estimator. -// www.cse.wustl.edu/~jain/papers/ftp/psqr.pdf (p. 1078) -func (est *p2) Add(x float64) { - select { - case est.addc <- x: - case <-est.ctx.Done(): - default: - } -} - -func (est *p2) run() { - for { - select { - case x := <-est.addc: - est.add(x) - case <-est.ctx.Done(): - return - } - } -} - -func (est *p2) add(x float64) { - est.mu.Lock() - defer est.mu.Unlock() - - defer func() { - est.count += 1 - }() - - if est.count < est.u { - est.q[est.count] = x - - if est.count+1 == est.u { - slices.Sort(est.q) - - t := est.u - 1 // 0 index - for i := int64(0); i <= t; i++ { - est.n[i] = i - } - - // divide p into mid no of equal segments - // p => 0.5, u = 11, t = 10, mid = 5; pmid => 0.1 - pmid := est.p / float64(est.mid) - for i := int64(0); i <= est.mid; i++ { - density := pmid * float64(i) - est.ns[i] = density * float64(t) - } - - rem := t - est.mid // the rest - s := 1.0 - est.p // left-over probability - // divide q into rem no of equal segments - // q => 0.5, u = 10, mid = 5, rem = 5; smid => 0.5 - smid := s / float64(rem) - for i := int64(1); i <= rem; i++ { - // assign i-th portion of smid to dns[mid+i] - // [mid+1] => .6, [mid+2] => .7, [mid+3] => .8, - // [mid+4] => .9, [mid+5] => 1 - density := (smid * float64(i)) + est.p - // assign t-th portion of dns[mid+i] to ns[mid+i] - // [mid+1] => 6, [mid+2] => 7, [mid+3] => 8, - // [mid+4] => 9, [mid+5] => 10 - est.ns[est.mid+i] = density * float64(t) - } - } - return - } - - var k int64 - if x < est.q[0] { - est.q[0] = x // update min - k = 0 - } else if x > est.q[est.u-1] { - est.q[est.u-1] = x // update max - k = est.u - 2 - } else { - k = est.u - 2 - for i := int64(1); i <= est.u-2; i++ { - if x < est.q[i] { - k = i - 1 - break - } - } - } - - for i := k + 1; i < est.u; i++ { - est.n[i]++ - } - - // go.dev/play/p/wL0hHYIB5DT - // for i := 0; i < est.u; i++ { - // est.ns[i] += est.dns[i] - // } - - // go.dev/play/p/yY23exf-KXh - factor := float64(est.count) / float64(est.count-1) - for i := int64(0); i < est.u; i++ { // update desired marker positions - est.ns[i] *= factor - } - - for i := int64(1); i < est.u-1; i++ { // update intermediatories - d := est.ns[i] - float64(est.n[i]) - - if (d >= 1 && est.n[i+1]-est.n[i] > 1) || (d <= -1 && est.n[i-1]-est.n[i] < -1) { - dInt := sign2int(d) - qs := est.parabolicLocked(i, float64(dInt)) - if est.q[i-1] < qs && qs < est.q[i+1] { - est.q[i] = qs - } else { - est.q[i] = est.linearLocked(i, dInt) - } - est.n[i] += dInt - } - } -} - -// parabolicLocked computes the parabolic estimate. -func (est *p2) parabolicLocked(i int64, d float64) float64 { - qi := est.q[i] - qij := est.q[i+1] - qih := est.q[i-1] - ni := float64(est.n[i]) - nij := float64(est.n[i+1]) - nih := float64(est.n[i-1]) - return qi + - (d/(nij-nih))* - (((ni-nih+d)*(qij-qi)/(nij-ni))+ - ((nij-ni-d)*(qi-qih)/(ni-nih))) -} - -// linearLocked computes the linear estimate. -func (est *p2) linearLocked(i int64, d int64) float64 { - df := float64(d) - qi := est.q[i] - qd := est.q[i+d] - ni := float64(est.n[i]) - nd := float64(est.n[i+d]) - return qi + (df*(qd-qi))/(nd-ni) -} - -// Get the estimation for p. -func (est *p2) Get() int64 { - est.mu.RLock() - defer est.mu.RUnlock() - - c := est.count - - if c == 0 { - return 0 - } - - if c > est.u { - ms := est.q[est.mid] * 1000 - return int64(ms) - } - - slices.Sort(est.q[:c]) // go.dev/play/p/sCIM4AB1t6n - index := int(float64(c-1) * est.p) - ms := est.q[index] * 1000 - return int64(ms) -} - -// sign2int returns the sign of the float64 as an int. -func sign2int(d float64) int64 { - if d < 0 { - return -1 - } else if d > 0 { - return 1 - } else { - return 0 - } -} diff --git a/intra/core/pcap.go b/intra/core/pcap.go deleted file mode 100644 index 3b5a0b65..00000000 --- a/intra/core/pcap.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "encoding" - "encoding/binary" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// from: github.com/google/gvisor/blob/596e8d22/pkg/tcpip/link/sniffer/pcap.go - -type PcapHeader struct { - MagicNumber uint32 - VersionMajor uint16 - VersionMinor uint16 - Thiszone int32 - Sigfigs uint32 - Snaplen uint32 - Network uint32 -} - -type PcapPacket struct { - Timestamp time.Time - Packet *stack.PacketBuffer - MaxCaptureLen int -} - -var _ encoding.BinaryMarshaler = (*PcapPacket)(nil) - -func (p *PcapPacket) MarshalBinary() ([]byte, error) { - pkt := TrimmedClone(p.Packet) - defer pkt.DecRef() - packetSize := pkt.Size() - captureLen := p.MaxCaptureLen - if packetSize < captureLen { - captureLen = packetSize - } - b := make([]byte, 16+captureLen) - binary.LittleEndian.PutUint32(b[0:4], uint32(p.Timestamp.Unix())) - binary.LittleEndian.PutUint32(b[4:8], uint32(p.Timestamp.Nanosecond()/1000)) - binary.LittleEndian.PutUint32(b[8:12], uint32(captureLen)) - binary.LittleEndian.PutUint32(b[12:16], uint32(packetSize)) - w := tcpip.SliceWriter(b[16:]) - for _, v := range pkt.AsSlices() { - if captureLen == 0 { - break - } - if len(v) > captureLen { - v = v[:captureLen] - } - n, err := w.Write(v) - if err != nil { - panic(err) - } - captureLen -= n - } - return b, nil -} - -// trimmedClone clones the packet buffer to not modify the original. It trims -// anything before the network header. -func TrimmedClone(pkt *stack.PacketBuffer) *stack.PacketBuffer { - // We don't clone the original packet buffer so that the new packet buffer - // does not have any of its headers set. - // - // We trim the link headers from the cloned buffer as the sniffer doesn't - // handle link headers. - buf := pkt.ToBuffer() - buf.TrimFront(int64(len(pkt.VirtioNetHeader().Slice()))) - buf.TrimFront(int64(len(pkt.LinkHeader().Slice()))) - return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf}) -} diff --git a/intra/core/ping.go b/intra/core/ping.go deleted file mode 100644 index 3634ca8a..00000000 --- a/intra/core/ping.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "bytes" - "crypto/rand" - "errors" - "fmt" - mrand "math/rand/v2" - "net" - "net/netip" - "time" - - "github.com/celzero/firestack/intra/log" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" -) - -var ( - errNotICMPEchoReply = errors.New("ping: expecting echo reply") - errPacketConnNotNetConn = errors.New("net.PacketConn is neither net.Conn nor syscall.Conn") -) - -const ( - payloadSize = 16 // bytes - padlen = 0 // bytes - ttl = 64 - timeout = 3 * time.Second - protocolICMP = 1 - protocolIPv6ICMP = 58 -) - -// from: github.com/go-ping/ping/blob/caaf2b72ea5/ping.go -func Ping(pc net.PacketConn, ipp netip.AddrPort) (ok bool, rtt time.Duration, err error) { - v4 := ipp.Addr().Is4() - seq := 1 // todo: seq? - var typ icmp.Type = ipv4.ICMPTypeEcho - if !v4 { - typ = ipv6.ICMPTypeEchoRequest - } - proto := protocolICMP - if !v4 { - proto = protocolIPv6ICMP - } - - var tslen int - var data []byte - data, tslen, err = payload() - if err != nil { - return - } - msgid := mrand.IntN(65535) - msg := &icmp.Message{ - Type: typ, - Code: 0, - Body: &icmp.Echo{ - ID: msgid, - Seq: seq, - Data: data, - }, - } - var pkt []byte - pkt, err = msg.Marshal(nil) - if err != nil { - return - } - - pkt, _, err = Echo(pc, pkt, net.UDPAddrFromAddrPort(ipp), v4) - - if err != nil { - return - } - - var m *icmp.Message - if m, err = icmp.ParseMessage(proto, pkt); err != nil { - return - } - - if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply { - err = errNotICMPEchoReply - return - } - - end := time.Now() - switch reply := m.Body.(type) { - case *icmp.Echo: - // IDs will never match for userspace icmp - // github.com/go-ping/ping/blob/caaf2b72e/utils_linux.go#L13 - // github.com/tailscale/tailscale/blob/43138c7a5c/cmd/stunstamp/stunstamp_linux.go#L77 - // if reply.ID != msgid { - // return fmt.Errorf("icmp: reply from [%v/%v] id %d; want %d", - // ipp, from, reply.ID, msgid) - // } - - if len(reply.Data) < len(data) { - err = fmt.Errorf("icmp: insufficient reply data; %d != %d", len(reply.Data), len(data)) - return - } - - start := bytesToTime(reply.Data[:tslen]) - // TODO: ref kernel timestamping - // github.com/tailscale/tailscale/blob/43138c7a5c/cmd/stunstamp/stunstamp_linux.go#L279 - rtt = end.Sub(start) - ok = true - default: - err = fmt.Errorf("icmp: err reply type: '%T' '%v'", pkt, pkt) - } - return -} - -func Echo(pc net.PacketConn, pkt []byte, dst net.Addr, v4 bool) (reply []byte, from net.Addr, err error) { - var n int - - if ttlerr := setttl(pc, v4); ttlerr != nil { - log.D("core: icmp: setttl failed: %v", ttlerr) - } - - n, err = pc.WriteTo(pkt, dst) - logev(err)("core: icmp: egress: write(=> %v) ping; done %d/%d; err? %v", - dst, n, len(pkt), err) - if err != nil { - // TODO: unreachable reply? - return - } - - extend(pc) - n, from, err = pc.ReadFrom(pkt) - reply = pkt[:n] // trunc - - logev(err)("core: icmp: ingress: read(<= %v / %v) ping done; done %d; err? %v", - dst, from, n, err) - // TODO: on err, unreachable reply? - return -} - -func timeToBytes(t time.Time) []byte { - nsec := t.UnixNano() - b := make([]byte, 8) - for i := range uint8(8) { - b[i] = byte((nsec >> ((7 - i) * 8)) & 0xff) - } - return b -} - -func bytesToTime(b []byte) time.Time { - var nsec int64 - maxiter := uint8(8) - for i := range maxiter { - nsec += int64(b[i]) << ((7 - i) * 8) - } - return time.Unix(nsec/1000000000, nsec%1000000000) -} - -func setttl(c MinConn, v4 bool) (err error) { - if c, ok := c.(ControlConn); ok { - raw, ctlErr := c.SyscallConn() - if ctlErr != nil || raw == nil { - return OneErr(ctlErr, errNotSyscallConn) - } - var ttlErr error - ctlErr = raw.Control(func(fd uintptr) { - if v4 { - // err1 := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTTL, 1) - ttlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TTL, ttl) - } else { - // err1 := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVHOPLIMIT, 1) - ttlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_UNICAST_HOPS, ttl) - } - }) - return JoinErr(ctlErr, ttlErr) - } - - var raw4 *ipv4.PacketConn - var raw6 *ipv6.PacketConn - switch x := c.(type) { - case *icmp.PacketConn: - if v4 { - raw4 = x.IPv4PacketConn() - } else { - raw6 = x.IPv6PacketConn() - } - case *ipv4.PacketConn: - raw4 = x - case *ipv6.PacketConn: - raw6 = x - case net.PacketConn: - if _, ok := x.(net.Conn); ok { - // ipv[4|6].NewPacketConn panics if the - // passed net.PacketConn is not net.Conn - if v4 { - raw4 = ipv4.NewPacketConn(x) - } else { - raw6 = ipv6.NewPacketConn(x) - } - } // eventually returns error errPacketConnNotNetConn - default: // eventually returns error errPacketConnNotNetConn - } - - if raw4 != nil { - err1 := raw4.SetControlMessage(ipv4.FlagTTL, true) - err2 := raw4.SetTTL(ttl) - return JoinErr(err1, err2) - } - if raw6 != nil { - err1 := raw6.SetControlMessage(ipv6.FlagHopLimit, true) - err2 := raw6.SetHopLimit(ttl) - return JoinErr(err1, err2) - } - return errPacketConnNotNetConn -} - -func extend(c MinConn) { - if c != nil { - _ = c.SetDeadline(time.Now().Add(timeout)) - } -} - -func cleardeadline(c MinConn) { - if c != nil { - _ = c.SetDeadline(time.Time{}) - } -} - -func payload() (t []byte, tslen int, err error) { - randomPayload := make([]byte, payloadSize) - _, err = rand.Read(randomPayload[:]) - if err != nil { - return - } - ts := timeToBytes(time.Now()) - tslen = len(ts) - t = append(ts, randomPayload...) - if padlen > 0 { - t = append(t, bytes.Repeat([]byte{1}, padlen)...) - } - return -} diff --git a/intra/core/proto.go b/intra/core/proto.go deleted file mode 100644 index 919d0725..00000000 --- a/intra/core/proto.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "io" - "net" - "syscall" - "time" -) - -// from: github.com/eycorsican/go-tun2socks/blob/301549c435/core/conn.go#LL3C9-L3C9 - -// ref: cs.android.com/android/platform/superproject/+/android-latest-release:system/core/libcutils/include/private/android_filesystem_config.h;drc=e999f05f34e91a3a313ba7dd77bcf52b58a0841e -const ( - UNKNOWN_UID = -1 - UNKNOWN_UID_STR = "-1" - // ANDROID_UID = 0 - ANDROID_UID_STR = "0" - DNS_UID_STR = "1051" - UNSUPPORTED_NETWORK = -1 -) - -// TCPConn abstracts a TCP connection coming from TUN. This connection -// should be handled by a registered TCP proxy handler. -type TCPConn interface { - DuplexCloser - - // RemoteAddr returns the destination network address. - RemoteAddr() net.Addr - // LocalAddr returns the local client network address. - LocalAddr() net.Addr - - // confirms to protect.Conn - Write([]byte) (int, error) - Read([]byte) (int, error) - - // Implements MinConn and net.Conn - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} - -// UDPConn abstracts a UDP connection coming from TUN. This connection -// should be handled by a registered UDP proxy handler. -type UDPConn interface { - io.Closer - - // LocalAddr returns the local client network address. - LocalAddr() net.Addr - // RemoteAddr returns the destination network address. - RemoteAddr() net.Addr - - // confirms to protect.Conn - Write([]byte) (int, error) - Read([]byte) (int, error) - - // confirms to net.PacketConn - WriteTo([]byte, net.Addr) (int, error) - ReadFrom([]byte) (int, net.Addr, error) - - // Implements MinConn, net.Conn, and net.PacketConn - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} - -type DuplexCloser interface { - io.Closer - CloseRead() error - CloseWrite() error -} - -// DuplexConn represents a bidirectional stream socket. -type DuplexConn interface { - TCPConn - DuplexCloser - PoolableConn - KeepAliveConn -} - -// so it can be used by dialers/retrier.go -type ReadRetrierConn io.ReaderFrom - -type WriteRetrierConn io.WriterTo - -type RetrierConn interface { - ReadRetrierConn - WriteRetrierConn -} - -// so it can be pooled by ConnPool. -type PoolableConn syscall.Conn -type ControlConn = PoolableConn - -// KeepAliveConn supports keep-alive probes. -type KeepAliveConn interface { - SetKeepAlive(bool) error -} - -type ICMPConn interface { - ControlConn // see: ping.go:setttl - net.PacketConn -} - -// MinConn is a minimal connection interface that is -// a subset of both net.Conn and net.PacketConn. -type MinConn interface { - io.Closer - - LocalAddr() net.Addr - - // Doc copied from net.Conn: - // SetDeadline sets the read and write deadlines associated - // with the connection. It is equivalent to calling both - // SetReadDeadline and SetWriteDeadline. - // - // A deadline is an absolute time after which I/O operations - // fail instead of blocking. The deadline applies to all future - // and pending I/O, not just the immediately following call to - // Read or Write. After a deadline has been exceeded, the - // connection can be refreshed by setting a deadline in the future. - // - // If the deadline is exceeded a call to Read or Write or to other - // I/O methods will return an error that wraps os.ErrDeadlineExceeded. - // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). - // The error's Timeout method will return true, but note that there - // are other possible errors for which the Timeout method will - // return true even if the deadline has not been exceeded. - // - // An idle timeout can be implemented by repeatedly extending - // the deadline after successful Read or Write calls. - // - // A zero value for t means I/O operations will not time out. - SetDeadline(t time.Time) error - - // Doc copied from net.Conn: - // SetReadDeadline sets the deadline for future Read calls - // and any currently-blocked Read call. - // A zero value for t means Read will not time out. - SetReadDeadline(t time.Time) error - - // Doc copied from net.Conn: - // SetWriteDeadline sets the deadline for future Write calls - // and any currently-blocked Write call. - // Even if write times out, it may return n > 0, indicating that - // some of the data was successfully written. - // A zero value for t means Write will not time out. - SetWriteDeadline(t time.Time) error -} diff --git a/intra/core/rollwg.go b/intra/core/rollwg.go deleted file mode 100644 index 58b79026..00000000 --- a/intra/core/rollwg.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import "sync" - -// RollingWaitGroup is like sync.WaitGroup but rolls over to a new internal -// WaitGroup after each Wait() call. This allows reuse of the same RollingWaitGroup -// for multiple wait cycles, without having to create a new instance each time. -// Each Add() call returns a generation number that must be passed to the -// corresponding Done() call. This ensures that Done() calls from a previous -// generation do not affect the current generation. -type RollingWaitGroup struct { - mu sync.Mutex - generation uint - deltas [2]int - wgs [2]*sync.WaitGroup -} - -// Add adds non-negative delta to the current generation WaitGroup counter. -func (m *RollingWaitGroup) Add(delta uint16) bool { - d := int(delta) - if d < 0 { - return false - } - - m.mu.Lock() - defer m.mu.Unlock() - g := m.generation % 2 - if wg := m.wgs[g]; wg == nil { - m.wgs[g] = &sync.WaitGroup{} - } - m.wgs[g].Add(d) - m.deltas[g] += d - return true -} - -// Done decrements the WaitGroup counter for the current generation by one. -func (m *RollingWaitGroup) Done() { - m.mu.Lock() - g := m.generation % 2 - if m.deltas[g] <= 0 { - m.wgs[g] = nil // should already be nil, but just in case - m.mu.Unlock() - return - } - wg := m.wgs[g] - if wg == nil { - m.deltas[g] = 0 // should already be 0, but just in case - m.mu.Unlock() - return - } - - m.deltas[g] = m.deltas[g] - 1 - if m.deltas[g] == 0 { - m.wgs[g] = nil - m.generation++ - } - m.mu.Unlock() - - wg.Done() -} - -// Wait blocks until the WaitGroup counter for the current generation is zero. -// It then rolls over to a new generation. It is safe to call Wait concurrently -// with Add and Done. -func (m *RollingWaitGroup) Wait() { - m.mu.Lock() - g := m.generation % 2 - wg := m.wgs[g] - d := m.deltas[g] - if wg == nil || d <= 0 { - m.wgs[g] = nil // should already be nil, but just in case - m.deltas[g] = 0 // should already be 0, but just in case - m.mu.Unlock() - return - } - m.mu.Unlock() - - wg.Wait() -} - -// WouldWait returns true if the WaitGroup counter for the current generation -// is non-zero. -func (m *RollingWaitGroup) WouldWait() bool { - m.mu.Lock() - defer m.mu.Unlock() - g := m.generation % 2 - delta := m.deltas[g] - return delta > 0 -} diff --git a/intra/core/runmet.go b/intra/core/runmet.go deleted file mode 100644 index 06d4eea7..00000000 --- a/intra/core/runmet.go +++ /dev/null @@ -1,425 +0,0 @@ -// Copyright (c) 2026 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "fmt" - "math" - "runtime/metrics" - "strings" - "sync" - "time" -) - -// from pkg.go.dev/runtime/metrics#Read - -const ( - // ex: /cgo/go-to-c-calls:calls - MetCgo = "/cgo" - // ex: /cpu/classes/gc/mark/assist:cpu-seconds - MetCPU = "/cpu" - // ex: /gc/scan/globals:bytes - MetGC = "/gc" - // ex: /godebug/non-default-behavior/zipinsecurepath:events - MetDbg = "/godebug" - // ex: /memory/classes/total:bytes - MetMem = "/memory" - // ex: /sched/threads/total:threads - MetSched = "/sched" - // ex: /sync/mutex/wait/total:seconds - MetSync = "/sync" - - MetGCPauses = MetGC + "/pauses" - // prefix - MetGCHeap = MetGC + "/heap" - // histogram of scheduler latencies (e.g. time to schedule a goroutine) - MetSchedLatencies = MetSched + "/latencies" - MetSchedPausesStoppingGc = MetSched + "/pauses/stopping/gc" - MetSchedPausesStoppingOther = MetSched + "/pauses/stopping/other" - MetSchedPausesTotalGc = MetSched + "/pauses/total/gc" - MetSchedPausesTotalOther = MetSched + "/pauses/total/other" - - memoizationThreshold = 10 * time.Second -) - -type metricUnit = int - -const ( - unitUnknown = metricUnit(iota) - unitTemporal - unitCount - unitPercent - unitBytes - unitGcTime -) - -var ( - descs = metrics.All() - allsamples = make([]metrics.Sample, len(descs)) - - sb strings.Builder - lastcall time.Time - mu sync.Mutex // protects allsamples, last, sb -) - -func init() { - for i := range allsamples { - allsamples[i].Name = descs[i].Name - } - sb.Grow(len(allsamples) * 100) -} - -func Metrics() string { - mu.Lock() - defer mu.Unlock() - - if !lastcall.IsZero() && time.Since(lastcall) < memoizationThreshold { - return sb.String() - } - - lastcall = time.Now() - - sb.Reset() - sb.WriteString("\n") - - metrics.Read(allsamples) - - for _, sample := range allsamples { - name, value := sample.Name, sample.Value - - // skip debug - if strings.HasPrefix(name, MetCgo) || strings.HasPrefix(name, MetDbg) { - continue - } - - unit := "" - u := unitUnknown - namesplit := strings.Split(name, ":") - if len(namesplit) >= 2 { - name = namesplit[0] - unit = namesplit[1] - } - - if unit == "cpu-seconds" || unit == "seconds" { - u = unitTemporal - } else if unit == "count" || - unit == "cleanups" || - unit == "calls" || - unit == "gc-cycles" || - unit == "finalizers" || - unit == "objects" || - unit == "events" || - unit == "threads" || - unit == "goroutines" { - u = unitCount - } else if unit == "percent" { - u = unitPercent - } else if unit == "bytes" { - u = unitBytes - } else if unit == "gc-cycle" { - u = unitGcTime - } - - switch value.Kind() { - case metrics.KindUint64: - s := fmt.Sprintf("%s: %s\n", name, unit4int(value.Uint64(), u)) - sb.WriteString(s) - case metrics.KindFloat64: - s := fmt.Sprintf("%s: %s\n", name, unit4float(value.Float64(), u)) - sb.WriteString(s) - case metrics.KindFloat64Histogram: - sb.WriteString("-----------\n") - if strings.HasPrefix(name, MetGC) { - s := fmt.Sprintf("%s: hist(%s)", name, histo2str(value.Float64Histogram(), u, '\n')) - sb.WriteString(s) - if strings.HasPrefix(name, MetGCHeap) { - s := fmt.Sprintf("%s: percentiles(%s)", name, histo2Ps(value.Float64Histogram(), u, '\n')) - sb.WriteString(s) - s = fmt.Sprintf("\n%s: dist(%s)", name, histo2Ms(value.Float64Histogram(), u, '\n')) - sb.WriteString(s) - } - } else if strings.HasPrefix(name, MetSched) { - if strings.HasPrefix(name, MetSchedLatencies) || - strings.HasPrefix(name, MetSchedPausesStoppingGc) || - strings.HasPrefix(name, MetSchedPausesStoppingOther) || - strings.HasPrefix(name, MetSchedPausesTotalGc) || - strings.HasPrefix(name, MetSchedPausesTotalOther) { - s := fmt.Sprintf("%s: percentiles(%s)", name, histo2Ps(value.Float64Histogram(), u, '\n')) - sb.WriteString(s) - s = fmt.Sprintf("\n%s: dist(%s)", name, histo2Ms(value.Float64Histogram(), u, '\n')) - sb.WriteString(s) - } - } else { - s := fmt.Sprintf("%s: hist(%s)", name, histo2str(value.Float64Histogram(), u, '\n')) - sb.WriteString(s) - } - sb.WriteString("-----------\n") - case metrics.KindBad: - fallthrough - default: - // This may happen as new metrics get added. - s := fmt.Sprintf("%s: Unknown(%v)\n", name, value.Kind()) - sb.WriteString(s) - } - } - return sb.String() -} - -func unit4int(v uint64, u metricUnit) string { - switch u { - case unitTemporal: - return FmtSecs(int64(v)) // may wrap? - case unitPercent: - return fmt.Sprintf("%d%%", v) - case unitBytes: - return FmtBytes(v) - case unitGcTime: - return fmt.Sprintf("%d", v) - case unitCount: - return FmtWithSep(v, '_') - default: - return fmt.Sprintf("%d", v) - } -} - -// FmtWithSep formats a uint64 with one byte separators every 3 digits (e.g. 1234567 -> "1,234,567"). -func FmtWithSep(v uint64, sep byte) string { - s := fmt.Sprintf("%d", v) - n := len(s) - if n <= 3 { - return s - } - // pre-allocate exact size: n digits + (n-1)/3 separators - b := make([]byte, n+(n-1)/3) - for i, j, k := n-1, len(b)-1, 0; i >= 0; i, k = i-1, k+1 { - if k > 0 && k%3 == 0 { - b[j] = sep - j-- - } - b[j] = s[i] - j-- - } - return string(b) -} - -func histo2str(h *metrics.Float64Histogram, u metricUnit, sep byte) string { - var sb strings.Builder - sb.Grow(20 * len(h.Buckets)) - for i, b := range h.Buckets { - if i >= len(h.Buckets)-1 { - break - } - if h.Counts[i] == 0 { - continue - } - sb.WriteByte(sep) - s := fmt.Sprintf("%s:%s", unit4float(b, u), unit4int(h.Counts[i], unitCount)) - sb.WriteString(s) - } - return sb.String() -} - -// histo2Ps returns a string of percentiles (p10 ... p99.99) for the histogram, -// considering only buckets with non-zero counts. -func histo2Ps(h *metrics.Float64Histogram, u metricUnit, sep byte) string { - type bkt struct { - mid float64 - count uint64 - } - - var buckets []bkt - for i := 0; i < len(h.Counts); i++ { - if h.Counts[i] == 0 { - continue - } - lo, hi := h.Buckets[i], h.Buckets[i+1] - var mid float64 - switch { - case math.IsInf(lo, -1): - mid = hi - case math.IsInf(hi, 1): - mid = lo - default: - mid = (lo + hi) / 2 - } - buckets = append(buckets, bkt{mid, h.Counts[i]}) - } - if len(buckets) == 0 { - return "" - } - - var total uint64 - for _, b := range buckets { - total += b.count - } - - percentiles := []float64{10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 99, 99.9, 99.99} - - var sb strings.Builder - sb.Grow(len(percentiles) * 24) - - pi := 0 - var cumulative uint64 - for _, b := range buckets { - cumulative += b.count - for pi < len(percentiles) { - threshold := uint64(math.Ceil(percentiles[pi] / 100.0 * float64(total))) - if cumulative >= threshold { - sb.WriteByte(sep) - sb.WriteString(fmt.Sprintf("p%g=%s", percentiles[pi], unit4float(b.mid, u))) - pi++ - } else { - break - } - } - if pi >= len(percentiles) { - break - } - } - // fill any trailing percentiles with the last bucket's midpoint - last := buckets[len(buckets)-1] - for ; pi < len(percentiles); pi++ { - sb.WriteByte(sep) - sb.WriteString(fmt.Sprintf("p%g=%s", percentiles[pi], unit4float(last.mid, u))) - } - return sb.String() -} - -// histo2Ms returns a string with mean, median, mode, avg, min, max, variance, -// and standard deviation for the histogram, considering only buckets with -// non-zero counts. -// - mean / avg: weighted arithmetic mean (mid-point weighted by count) -// - median: mid-point of the bucket that crosses the 50th percentile -// - mode: mid-point of the bucket with the highest count -// - min / max: lower / upper boundary of the first / last non-zero bucket -// - variance / std: population variance and standard deviation (weighted) -func histo2Ms(h *metrics.Float64Histogram, u metricUnit, sepb byte) string { - type bkt struct { - lo, hi, mid float64 - count uint64 - } - - var buckets []bkt - for i := 0; i < len(h.Counts); i++ { - if h.Counts[i] == 0 { - continue - } - lo, hi := h.Buckets[i], h.Buckets[i+1] - var mid float64 - switch { - case math.IsInf(lo, -1): - mid = hi - case math.IsInf(hi, 1): - mid = lo - default: - mid = (lo + hi) / 2 - } - buckets = append(buckets, bkt{lo, hi, mid, h.Counts[i]}) - } - if len(buckets) == 0 { - return "" - } - - var total uint64 - for _, b := range buckets { - total += b.count - } - - // weighted mean (accounts for observation frequency per bucket) - var weightedSum float64 - for _, b := range buckets { - weightedSum += b.mid * float64(b.count) - } - mean := weightedSum / float64(total) - - // avg: unweighted mean of non-zero bucket midpoints - var midSum float64 - for _, b := range buckets { - midSum += b.mid - } - avg := midSum / float64(len(buckets)) - - // median: midpoint of the bucket crossing the 50th percentile - median := buckets[len(buckets)-1].mid - var cumulative uint64 - for _, b := range buckets { - cumulative += b.count - if cumulative >= uint64(math.Ceil(0.5*float64(total))) { - median = b.mid - break - } - } - - // mode: midpoint of the bucket with the highest count - mode := buckets[0].mid - maxCount := buckets[0].count - for _, b := range buckets[1:] { - if b.count > maxCount { - maxCount = b.count - mode = b.mid - } - } - - // min: lower bound of first non-zero bucket (use mid if -Inf) - minVal := buckets[0].lo - if math.IsInf(minVal, -1) { - minVal = buckets[0].mid - } - // max: upper bound of last non-zero bucket (use mid if +Inf) - maxVal := buckets[len(buckets)-1].hi - if math.IsInf(maxVal, 1) { - maxVal = buckets[len(buckets)-1].mid - } - - // population variance and standard deviation (weighted) - var varianceSum float64 - for _, b := range buckets { - diff := b.mid - mean - varianceSum += float64(b.count) * diff * diff - } - variance := varianceSum / float64(total) - stddev := math.Sqrt(variance) - sep := string(sepb) - - return fmt.Sprintf( - "%smean=%s%smedian=%s%smode=%s%savg=%s%smin=%s%smax=%s%svar=%s%sstd=%s%s", - sep, - unit4float(mean, u), - sep, - unit4float(median, u), - sep, - unit4float(mode, u), - sep, - unit4float(avg, u), - sep, - unit4float(minVal, u), - sep, - unit4float(maxVal, u), - sep, - unit4float(variance, u), - sep, - unit4float(stddev, u), - sep, - ) -} - -func unit4float(v float64, u metricUnit) string { - switch u { - case unitTemporal: - return FmtNanos(v) - case unitPercent: - return fmt.Sprintf("%.2f%%", v) - case unitBytes: - return FmtBytes(uint64(v)) - case unitGcTime: - return fmt.Sprintf("%.2f", v) - case unitCount: - return FmtWithSep(uint64(v), '_') - default: - return fmt.Sprintf("%.3g", v) - } -} diff --git a/intra/core/sched.go b/intra/core/sched.go deleted file mode 100644 index 662a2b4f..00000000 --- a/intra/core/sched.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "context" - "errors" - "sync" - "time" -) - -var errNewJob = errors.New("sched: replaced by newer job") -var errClearJob = errors.New("sched: job cleared") - -type CheckIn func(t time.Time) -type Job func() error -type JobShift func(next CheckIn) error - -type ctl struct { - who context.Context - cancel context.CancelCauseFunc -} - -type Scheduler struct { - ctx context.Context - - mu sync.Mutex - jobctl map[string]*ctl -} - -func NewScheduler(ctx context.Context) *Scheduler { - return &Scheduler{ - ctx: ctx, - jobctl: make(map[string]*ctl), - } -} - -// Retry runs f at time t, retries times on errs, with retryCount*period between retries. -func (s *Scheduler) Retry(id string, t time.Time, f Job, retries uint16, multiplier time.Duration) context.Context { - retrierCtx, retrierDone := context.WithCancelCause(s.ctx) - - Go("retrier."+id, func() { - var errs error - - defer func() { - retrierDone(errs) - }() - - for i := uint16(0); i < retries; i++ { - select { - case <-retrierCtx.Done(): - errs = JoinErr(errs, context.Cause(retrierCtx)) - return // cancelled - default: - } - - ctx := s.At(id, t, f) // do f at t - - <-ctx.Done() // await f - - next := time.Duration(i+1) * multiplier - t = time.Now().Add(next) - - if err := context.Cause(ctx); err == nil { - errs = nil - return // ok - } else if errors.Is(err, errNewJob) { - errs = JoinErr(errs, err) - return // new job replaced this one - } else if errors.Is(err, errClearJob) { - errs = JoinErr(errs, err) - return // job cleared - } else { - errs = JoinErr(errs, err) - continue // retry - } - } - }) - - return retrierCtx -} - -// Clear cancels the job with id. -func (s *Scheduler) Clear(ids ...string) int { - s.mu.Lock() - defer s.mu.Unlock() - - n := 0 - if len(ids) <= 0 { // clear all - for _, c := range s.jobctl { - if c != nil { - n++ - c.cancel(errClearJob) - } - } - clear(s.jobctl) - } else { - for _, id := range ids { - if c := s.jobctl[id]; c != nil { - n++ - c.cancel(errClearJob) - delete(s.jobctl, id) - } - } - } - return n -} - -func (s *Scheduler) Shift(id string, t time.Time, f JobShift) context.Context { - return s.At(id, t, func() error { - return f(func(at time.Time) { - s.Shift(id, at, f) - }) - }) -} - -// At runs f at time t; accepts a Context to cancel it. -func (s *Scheduler) At(id string, t time.Time, f Job) context.Context { - s.mu.Lock() - ctx, done := context.WithCancelCause(s.ctx) - if c := s.jobctl[id]; c != nil { - c.cancel(errNewJob) // dispose existing job - } - s.jobctl[id] = &ctl{who: ctx, cancel: done} - s.mu.Unlock() - - Go("at."+id, func() { - var cause error - - defer func() { - s.mu.Lock() - if c := s.jobctl[id]; c != nil && c.who == ctx { - delete(s.jobctl, id) - cause = JoinErr(cause, errNewJob) - } - s.mu.Unlock() - done(cause) - }() - - select { - case <-s.ctx.Done(): - cause = context.Cause(s.ctx) - return - case <-time.After(time.Until(t)): - cause = f() - return - } - }) - return ctx -} diff --git a/intra/core/sigcond.go b/intra/core/sigcond.go deleted file mode 100644 index edb2f629..00000000 --- a/intra/core/sigcond.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "sync/atomic" - "time" -) - -// A "signalable boolean". This is like sync.Cond but light-weight and -// only allows a single state transition: false -> true. Once signalled, -// it stays signalled forever. Waiters wait until the condition is true. -// This is useful for one-time events like "the connection is closed". -// -// It is safe for multiple goroutines to call Signal concurrently; only -// one will succeed and the others will be no-ops. -// -// It is safe for multiple goroutines to call Wait concurrently; all -// will be woken when the condition is signalled. -// -// It is safe for one goroutine to call Signal while other goroutines -// are calling Wait. -// -// It is not safe to reuse a SigCond after signalling it. -type SigCond struct { - c chan struct{} // always unbuffered - b atomic.Bool -} - -func NewSigCond() *SigCond { - return &SigCond{ - c: make(chan struct{}), - } -} - -// Cond returns true if the signal has been fired. -func (sc *SigCond) Cond() bool { - return sc.b.Load() -} - -// Wait waits until the condition is true. -// If the condition is already true, it returns immediately. -func (sc *SigCond) Wait() { - if sc.b.Load() { - return - } - <-sc.c -} - -// TryWait waits until signalled, or until max time has elapsed. -// It returns true if signal fires, false if timeout elapsed. -// If already signalled, it returns immediately. -func (sc *SigCond) TryWait(timeout time.Duration) (fired bool) { - if sc.b.Load() { - return true - } - select { - case <-sc.c: - return true - case <-time.After(timeout): - return false - } -} - -// Signal sets the condition to true and wakes all waiters, if any. -func (sc *SigCond) Signal() (fired bool) { - if sc.b.Swap(true) { // already true - return false - } - close(sc.c) - return true -} diff --git a/intra/core/slices.go b/intra/core/slices.go deleted file mode 100644 index ff1efd31..00000000 --- a/intra/core/slices.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "math/rand" - "slices" -) - -// flattens and returns a (stable) copy with dups removed, if any. -// go.dev/play/p/OzJs4s6XvQe -func CopyUniq[T comparable](a ...[]T) (out []T) { - out = make([]T, 0) - if len(a) <= 0 { - return - } - acc := make(map[T]struct{}, 0) - for _, x := range a { - for _, xx := range x { - if _, ok := acc[xx]; ok { - continue - } - // maintain incoming order - out = append(out, xx) - acc[xx] = struct{}{} - } - } - return -} - -type TestFn[T any] func(T) bool - -func FilterLeft[T any](arr []T, test TestFn[T]) (out []T) { - out = make([]T, 0) - for _, x := range arr { - if test(x) { - out = append(out, x) - } - } - return out -} - -func ShuffleInPlace[T any](c []T) []T { - if len(c) <= 1 { - return c - } - rand.Shuffle(len(c), func(i, j int) { - c[i], c[j] = c[j], c[i] - }) - return c -} - -func ChooseOne[T any](c []T) (zz T) { - if len(c) <= 0 { - return zz - } - return c[rand.Intn(len(c))] -} - -func FirstOf[T any](c []T) (zz T) { - if len(c) <= 0 { - return zz - } - return c[0] -} - -// sorts arr x in ascending order. less(a, b) < 0 when a < b, -// less(a, b) > 0 when a > b, and less(a, b) == 0 when a == b. -func Sort[T any](arr []T, less func(a, b T) int) []T { - slices.SortStableFunc(arr, less) - return arr -} - -func Map[T, U any](arr []T, transform func(T) U) (out []U) { - out = make([]U, 0) - for _, x := range arr { - out = append(out, transform(x)) - } - return out -} - -// WithoutElem returns arr (may be a copy) removing all occurrences of elem. -func WithoutElem[T comparable](arr []T, elem T) (out []T) { - if !slices.Contains(arr, elem) { - return arr - } - - out = make([]T, 0) - for _, x := range arr { - if x == elem { - continue - } - out = append(out, x) - } - return out -} - -// WithElem returns arr with elem added to it. -func WithElem[T comparable](s []T, add T) []T { - if len(s) <= 0 { - return []T{add} - } - if slices.Contains(s, add) { - return s - } - return append(s, add) -} - -func WithoutNils[T any](arr []T) (out []T) { - for _, x := range arr { - if IsNil(x) { - continue - } - out = append(out, x) - } - return out -} - -func IsAny[T any](arr []T, test TestFn[T]) bool { - return slices.ContainsFunc(arr, test) -} - -func IsAll[T any](arr []T, test TestFn[T]) bool { - for _, x := range arr { - if !test(x) { - return false - } - } - return true -} diff --git a/intra/core/sockopt.go b/intra/core/sockopt.go deleted file mode 100644 index 24a3075c..00000000 --- a/intra/core/sockopt.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "net" - "syscall" - - "github.com/celzero/firestack/intra/log" - "golang.org/x/sys/unix" -) - -// github.com/tailscale/tailscale/blob/65fe0ba7b5/cmd/derper/derper.go#L75-L78 -// blog.cloudflare.com/when-tcp-sockets-refuse-to-die/ -// shorter count / interval for faster drops -const ( - defaultIdle = 600 // in seconds - defaultCount = 4 // unacknowledged probes - defaultInterval = 5 // in seconds - usrTimeoutMillis = 1000*defaultIdle + (defaultInterval * defaultCount) -) - -var ( - kacfg = net.KeepAliveConfig{ - Enable: true, - Idle: defaultIdle, - Count: defaultCount, - Interval: defaultInterval, - } -) - -func SetKeepAliveConfig(c MinConn) bool { - if tc, ok := c.(*net.TCPConn); ok { - return tc.SetKeepAliveConfig(kacfg) == nil - } - return false -} - -func SetTimeoutSockOpt(c MinConn, timeoutms int) bool { - if tc, ok := c.(PoolableConn); ok { - id := conn2str(c) - rawConn, err := tc.SyscallConn() - if err != nil || rawConn == nil { - return false - } - ok := true - err = rawConn.Control(func(fd uintptr) { - sock := int(fd) - // code.googlesource.com/google-api-go-client/+/master/transport/grpc/dial_socketopt.go#30 - if err := unix.SetsockoptInt(sock, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, timeoutms); err != nil { - log.D("core: sockopt: set TCP_USER_TIMEOUT %s (%d) failed: %dms, %v", id, sock, timeoutms, err) - ok = false - } - }) - if err != nil { - log.E("core: sockopt: %s RawConn.Control() err: %v", id, err) - ok = false - } - return ok - } - return false -} - -func DisableKeepAlive(c MinConn) (done bool) { - if sc, ok := c.(PoolableConn); ok { - raw, err := sc.SyscallConn() - if raw == nil || err != nil { - return - } - err = raw.Control(func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, 0) - }) - return err == nil - } - return -} - -// SetKeepAliveConfigSockOpt sets for a TCP connection, SO_KEEPALIVE, -// TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_KEEPCNT, TCP_USER_TIMEOUT. -// args is optional, and should be in the order of idle, interval, count. -func SetKeepAliveConfigSockOpt(c MinConn, args ...int) (ok bool) { - switch pc := c.(type) { - case *net.UDPConn: - return - case PoolableConn: - id := conn2str(c) - - rawConn, err := pc.SyscallConn() - if err != nil || rawConn == nil { - ok = false - return ok - } - - idle := defaultIdle // secs - interval := defaultInterval // secs - count := defaultCount - if len(args) >= 1 && args[0] > 0 { - idle = args[0] - } - if len(args) >= 2 && args[1] > 0 { - interval = args[1] - } - if len(args) >= 3 && args[2] > 0 { - count = args[2] - } - usertimeoutms := idle*1000 + (interval * count) // millis - - ok = true - err = rawConn.Control(func(fd uintptr) { - sock := int(fd) - if err := syscall.SetsockoptInt(sock, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(true)); err != nil { - log.V("core: sockopt: set SO_KEEPALIVE %s (%d) failed: %v", id, sock, err) - ok = false - } - if err := syscall.SetsockoptInt(sock, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, idle); err != nil { - log.V("core: sockopt: set TCP_KEEPIDLE %s (%d) failed: %ds, %v", id, sock, idle, err) - ok = false - } - if err := syscall.SetsockoptInt(sock, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, interval); err != nil { - log.V("core: sockopt: set TCP_KEEPINTVL %s (%d) failed: %ds, %v", id, sock, interval, err) - ok = false - } - if err := syscall.SetsockoptInt(sock, syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT, count); err != nil { - log.V("core: sockopt: set TCP_KEEPCNT %s (%d) failed: #%d, %v", id, sock, count, err) - ok = false - } - // code.googlesource.com/google-api-go-client/+/master/transport/grpc/dial_socketopt.go#30 - if err := unix.SetsockoptInt(sock, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, usertimeoutms); err != nil { - log.V("core: sockopt: set TCP_USER_TIMEOUT %s (%d) failed: %dms, %v", id, sock, usertimeoutms, err) - ok = false - } - }) - if err != nil { - log.E("core: sockopt: %s RawConn.Control() err: %v", id, err) - ok = false - } - } - return ok -} - -func boolint(b bool) int { - if b { - return 1 - } - return 0 -} diff --git a/intra/core/typ.go b/intra/core/typ.go deleted file mode 100644 index ec09afe5..00000000 --- a/intra/core/typ.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "fmt" - "reflect" -) - -func Loc(x any) uintptr { - if x == nil { - return 0 - } - v := reflect.ValueOf(x) - k := v.Kind() - switch k { - // [Chan], [Func], [Map], [Pointer], [Slice], [String] or [UnsafePointer] - case reflect.Pointer, reflect.UnsafePointer, reflect.String, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice: - return v.Pointer() - } - return 0 -} - -func LocStr(x any) string { - return fmt.Sprintf("%x", Loc(x)) -} - -// may panic or return false if x is not addressable -func IsNotNil(x any) bool { - return !IsNil(x) -} - -// IsNil reports whether x is nil if its Chan, Func, Map, -// Pointer, UnsafePointer, Interface, and Slice; -// may panic or return false if x is not addressable -func IsNil(x any) bool { - // from: stackoverflow.com/a/76595928 - if x == nil { - return true - } - v := reflect.ValueOf(x) - k := v.Kind() - switch k { - case reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice: - return v.IsNil() - } - return false -} - -func TypeEq(a, b any) bool { - if IsNil(a) { - return false - } else if IsNil(b) { - return false - } - return reflect.TypeOf(a) == reflect.TypeOf(b) -} - -func LocEq(a, b any) bool { - loca := Loc(a) - locb := Loc(b) - return loca > 0 && locb > 0 && loca == locb -} - -func IsZero(x any) bool { - if IsNil(x) { - return true - } - v := reflect.ValueOf(x) - // panics if x == nil: go.dev/play/p/jcJzdHF0JCq - return v.IsZero() -} diff --git a/intra/core/undelegated.go b/intra/core/undelegated.go deleted file mode 100644 index ace01f5b..00000000 --- a/intra/core/undelegated.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -var UndelegatedDomains = []string{ - "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - "0.in-addr.arpa", - "1", - "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - "10.in-addr.arpa", - "100.100.in-addr.arpa", - "100.51.198.in-addr.arpa", - "101.100.in-addr.arpa", - "102.100.in-addr.arpa", - "103.100.in-addr.arpa", - "104.100.in-addr.arpa", - "105.100.in-addr.arpa", - "106.100.in-addr.arpa", - "107.100.in-addr.arpa", - "108.100.in-addr.arpa", - "109.100.in-addr.arpa", - "110.100.in-addr.arpa", - "111.100.in-addr.arpa", - "112.100.in-addr.arpa", - "113.0.203.in-addr.arpa", - "113.100.in-addr.arpa", - "114.100.in-addr.arpa", - "115.100.in-addr.arpa", - "116.100.in-addr.arpa", - "117.100.in-addr.arpa", - "118.100.in-addr.arpa", - "119.100.in-addr.arpa", - "120.100.in-addr.arpa", - "121.100.in-addr.arpa", - "122.100.in-addr.arpa", - "123.100.in-addr.arpa", - "124.100.in-addr.arpa", - "125.100.in-addr.arpa", - "126.100.in-addr.arpa", - "127.100.in-addr.arpa", - "127.in-addr.arpa", - "16.172.in-addr.arpa", - "168.192.in-addr.arpa", - "17.172.in-addr.arpa", - "18.172.in-addr.arpa", - "19.172.in-addr.arpa", - "2.0.192.in-addr.arpa", - "20.172.in-addr.arpa", - "21.172.in-addr.arpa", - "22.172.in-addr.arpa", - "23.172.in-addr.arpa", - "24.172.in-addr.arpa", - "25.172.in-addr.arpa", - "254.169.in-addr.arpa", - "255.255.255.255.in-addr.arpa", - "26.172.in-addr.arpa", - "27.172.in-addr.arpa", - "28.172.in-addr.arpa", - "29.172.in-addr.arpa", - "30.172.in-addr.arpa", - "31.172.in-addr.arpa", - "64.100.in-addr.arpa", - "65.100.in-addr.arpa", - "66.100.in-addr.arpa", - "67.100.in-addr.arpa", - "68.100.in-addr.arpa", - "69.100.in-addr.arpa", - "70.100.in-addr.arpa", - "71.100.in-addr.arpa", - "72.100.in-addr.arpa", - "73.100.in-addr.arpa", - "74.100.in-addr.arpa", - "75.100.in-addr.arpa", - "76.100.in-addr.arpa", - "77.100.in-addr.arpa", - "78.100.in-addr.arpa", - "79.100.in-addr.arpa", - "8.b.d.0.1.0.0.2.ip6.arpa", - "8.e.f.ip6.arpa", - "80.100.in-addr.arpa", - "81.100.in-addr.arpa", - "82.100.in-addr.arpa", - "83.100.in-addr.arpa", - "84.100.in-addr.arpa", - "85.100.in-addr.arpa", - "86.100.in-addr.arpa", - "87.100.in-addr.arpa", - "88.100.in-addr.arpa", - "89.100.in-addr.arpa", - "9.e.f.ip6.arpa", - "90.100.in-addr.arpa", - "91.100.in-addr.arpa", - "92.100.in-addr.arpa", - "93.100.in-addr.arpa", - "94.100.in-addr.arpa", - "95.100.in-addr.arpa", - "96.100.in-addr.arpa", - "97.100.in-addr.arpa", - "98.100.in-addr.arpa", - "99.100.in-addr.arpa", - "a.e.f.ip6.arpa", - ".airdream", - ".api", - "b.e.f.ip6.arpa", - ".bbrouter", - ".belkin", - ".bind", - ".blinkap", - ".corp", - "d.f.ip6.arpa", - ".davolink", - ".dearmyrouter", - ".dhcp", - ".dlink", - ".domain", - ".envoy", - ".example", - "fritz.box", // github.com/celzero/rethink-app/issues/1298 - "f.f.ip6.arpa", - ".grp", - ".gw==", - ".home", - ".hub", - ".internal", - ".intra", - ".intranet", - ".invalid", - ".ksyun", - ".lan", - ".loc", - ".local", - ".localdomain", - ".localhost", - ".localnet", - ".modem", - ".mynet", - ".myrouter", - ".novalocal", - // "onion", github.com/celzero/rethink-app/issues/1259 - ".openstacklocal", - ".priv", - ".private", - ".prv", - ".router", - ".telus", - ".test", - ".totolink", - ".wlan_ap", - ".workgroup", - ".zghjccbob3n0", -} diff --git a/intra/core/version.go b/intra/core/version.go deleted file mode 100644 index 9317f386..00000000 --- a/intra/core/version.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -// ref: github.com/xjasonlyu/tun2socks/blob/bf745d0e0e5d/internal/version/version.go#L1 -import ( - "fmt" - "runtime" - "runtime/debug" - "strings" -) - -var buildinfo, _ = debug.ReadBuildInfo() - -var ( - // Commit set at link time by git rev-parse --short HEAD - Commit string - // Date set at link time by date -u +'%Y%m%d%H%M%S' - Date string -) - -func Version() string { - return fmt.Sprintf("%s (%s/%s@%s)", stamp(), runtime.GOOS, runtime.GOARCH, runtime.Version()) -} - -func stamp() string { - path := "" - v := "v" + Date + "-" + Commit - if buildinfo != nil { // github.com/golang/go/issues/50603 - path = buildinfo.Main.Path + "@" - if len(buildinfo.Main.Version) > 0 && !strings.Contains(buildinfo.Main.Version, "devel") { - v = "v" + buildinfo.Main.Version - } - } - return path + v -} - -func BuildInfo() string { - if buildinfo == nil { - return "unknown" - } - return buildinfo.String() -} diff --git a/intra/core/volatile.go b/intra/core/volatile.go deleted file mode 100644 index d045156d..00000000 --- a/intra/core/volatile.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -package core - -import "sync/atomic" - -// go.dev/play/p/bdaGAB_xsLN - -// Volatile is a non-panicking, non-atomic atomic.Value. -type Volatile[T any] atomic.Value - -// NewVolatile returns a new Volatile with the value t. -// Panics if t is nil. -func NewVolatile[T any](t T) *Volatile[T] { - v := NewZeroVolatile[T]() - // safe to call Store but not any other func on v from this ctor. - v.Store(t) - return v -} - -// NewVolatile returns a new uninitialized Volatile. -func NewZeroVolatile[T any]() *Volatile[T] { - // do not call into any func on the returned instance from this ctor. - return new(Volatile[T]) -} - -func (a *Volatile[T]) safeLoad() (t T) { - if a == nil { - return // zz - } - aa := (*atomic.Value)(a) - if x := aa.Load(); x != nil && !IsNil(x) { - t, _ = x.(T) - } - return -} - -// Load returns the value of a. May return zero value. -// This func is atomic. -func (a *Volatile[T]) Load() (t T) { - if a == nil { - return // zz - } - return a.safeLoad() -} - -// Store stores the value t; creates a new Volatile[T] if t is nil. -// If a is nil, does nothing. This func is not atomic. -func (a *Volatile[T]) Store(t T) { - if a == nil { - return // zz - } - a.safeStore(a.Load(), t) -} - -// safeStore stores new in a, iff old & new are of the same concrete type. -// If old & new are not of the same concrete type, it creates a Volatile with new. -// If new is nil, sets a to NewZeroVolatile[T]. -// If a is nil, does nothing. This func is not atomic. -func (a *Volatile[T]) safeStore(old, new T) { - if a == nil { - return - } - if IsNil(new) { // nothing to store - *a = *NewZeroVolatile[T]() - return - } - - // old may be a diff concrete type than new - if IsNil(old) || !TypeEq(old, new) { - *a = *NewZeroVolatile[T]() - } else if LocEq(old, new) { - return // old is same as new; no-op - } - aa := (*atomic.Value)(a) - aa.Store(new) // new is a not nil -} - -// Cas compares and swaps the value of a with new, returns true if the value was swapped. -// If new is nil, returns true; and sets a to NewZeroVolatile[T] non-atomically. -// If a is nil or old & new are not of same concrete type, returns false. -func (a *Volatile[T]) Cas(old, new T) (ok bool) { - if a == nil { - return - } - if IsNil(new) { - *a = *NewZeroVolatile[T]() - return true - } - if !TypeEq(old, new) || LocEq(old, new) { - return - } - - aa := (*atomic.Value)(a) - return aa.CompareAndSwap(old, new) -} - -// Swap assigns new and returns the old value, atomically. -// If a is nil, it returns zero value. -// If new is nil, returns old value; and sets a to NewZeroVolatile[T]. -// If old & new are not of the same concrete type, it panics. -func (a *Volatile[T]) Swap(new T) (old T) { - if a == nil { - return // zz - } - if IsNil(new) { - old = a.safeLoad() - - *a = *NewZeroVolatile[T]() - return - } - if LocEq(old, new) { - return old - } - - aa := (*atomic.Value)(a) - old, _ = aa.Swap(new).(T) - return old -} - -// Tango retrieves old value and loads in new non-atomically. -// If a is nil, returns zero value. -// If new is nil, returns zero value; and sets a to NewZeroVolatile[T]. -// old & new need not be the same concrete type. This func is not atomic. -func (a *Volatile[T]) Tango(new T) (old T) { - if a == nil { - return // zz - } - - defer a.safeStore(old, new) - - return a.safeLoad() -} diff --git a/intra/core/volatileflow.go b/intra/core/volatileflow.go deleted file mode 100644 index 0606033d..00000000 --- a/intra/core/volatileflow.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -package core - -import ( - "context" - "slices" - "sync" - "time" -) - -type FlowFunc[T any] func(v T) - -type FlowOn[T any] struct { - ctx context.Context - f *FlowFunc[T] -} - -func (f FlowOn[T]) flow(v T) (flowed bool) { - on := f.f - if on == nil { - return false - } - select { - case <-f.ctx.Done(): - default: - (*on)(v) - return true - } - return false -} - -func (f FlowOn[T]) obsolete() bool { - select { - case <-f.ctx.Done(): - return true - default: - } - return false -} - -type Flow[T any] struct { - ctx context.Context - - v *Volatile[T] - c chan T - - fmu sync.RWMutex - o []FlowOn[T] -} - -func NewForeverFlow[T any](v T) *Flow[T] { - return NewFlowFor(context.Background(), NewVolatile(v)) -} - -func NewFlowFor[T any](ctx context.Context, v *Volatile[T]) *Flow[T] { - if v == nil || ctx == nil { - return nil - } - - f := &Flow[T]{ - v: v, - o: make([]FlowOn[T], 0), - c: make(chan T), - ctx: ctx, - } - Gx("core.flow", f.stream) - return f -} - -func (f *Flow[T]) stream() { - // TODO: defer close(f.c); see pub() - for { - select { - case <-f.ctx.Done(): - return - case v := <-f.c: - Gx("flow.stream", func() { - notflowing := make(map[FlowOn[T]]struct{}, 0) - for _, o := range f.observers() { - if ok := o.flow(v); !ok { - notflowing[o] = struct{}{} - } - } - f.removeFinallys(notflowing) - }) - case <-time.Tick(3 * time.Hour): - Gx("flow.stream.tick", func() { - notflowing := make(map[FlowOn[T]]struct{}) - for _, o := range f.observers() { - if o.obsolete() { - notflowing[o] = struct{}{} - } - } - f.removeFinallys(notflowing) - }) - } - } -} - -func (f *Flow[T]) removeFinallys(obsolete map[FlowOn[T]]struct{}) { - obssz := len(obsolete) - if obssz <= 0 { - return - } - - f.fmu.Lock() - defer f.fmu.Unlock() - - cursz := len(f.o) - if cursz <= 0 { - return - } - - flowing := make([]FlowOn[T], 0, cursz) - for _, o := range f.o { - if _, ok := obsolete[o]; ok { - continue - } - flowing = append(flowing, o) - } - f.o = flowing -} - -func (f *Flow[T]) observers() []FlowOn[T] { - f.fmu.RLock() - defer f.fmu.RUnlock() - return slices.Clone(f.o) -} - -func (f *Flow[T]) pub(v T) { - select { - case <-f.ctx.Done(): - return - default: - select { - case <-f.ctx.Done(): - return - case f.c <- v: // f.c never closed - } - } -} - -// On (is a hot flow) which immediately calls o (in a separate goroutine) -// and later calls o on changes to the underlying Volatile variable. -func (f *Flow[T]) On(until context.Context, o FlowFunc[T]) { - f.fmu.Lock() - defer f.fmu.Unlock() - on := FlowOn[T]{until, &o} - f.o = append(f.o, on) - Gx("flow.on", func() { on.flow(f.v.Load()) }) -} - -func (f *Flow[T]) Store(v T) { - defer f.pub(v) - f.v.Store(v) -} - -func (f *Flow[T]) Load() T { - return f.v.Load() -} - -func (f *Flow[T]) Swap(new T) (old T) { - defer f.pub(new) - return f.v.Swap(new) -} - -func (f *Flow[T]) Tango(new T) (old T) { - defer func() { - if !LocEq(new, old) { - f.pub(new) - } - }() - - return f.v.Tango(new) -} - -func (f *Flow[T]) CompareAndSwap(old, new T) bool { - return f.Cas(old, new) -} - -func (f *Flow[T]) Cas(old, new T) (success bool) { - defer func() { - if success { - f.pub(new) - } - }() - - return f.v.Cas(old, new) -} diff --git a/intra/core/weakref.go b/intra/core/weakref.go deleted file mode 100644 index 579f3142..00000000 --- a/intra/core/weakref.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package core - -import ( - "errors" - "sync" - "weak" -) - -var errNoCreat = errors.New("weak: create fn nil") - -type reffactory[V any] func() *V -type reftest[V any] func(*V) bool - -func refpass[V any](_ *V) bool { return true } - -type WeakRef[V any] struct { - mu sync.RWMutex - weak weak.Pointer[V] - creat reffactory[V] - test reftest[V] -} - -func NewWeakRef[V any](creat reffactory[V], test reftest[V]) (*WeakRef[V], error) { - if creat == nil { - return nil, errNoCreat - } - if test == nil { - test = refpass - } - - return &WeakRef[V]{ - creat: creat, - test: test, - }, nil -} - -func (w *WeakRef[V]) load() (v *V, valid bool) { - defer func() { // test without lock held - valid = v != nil && w.test(v) - }() - - w.mu.RLock() - defer w.mu.RUnlock() - v = w.weak.Value() - return -} - -func (w *WeakRef[V]) storeLocked() (v *V) { - v = w.creat() - w.weak = weak.Make(v) - return -} - -func (w *WeakRef[V]) loadOrStore() (v *V, valid bool) { - if v, valid = w.load(); valid { - return - } - - defer func() { // test without lock held - valid = v != nil && w.test(v) - }() - - w.mu.Lock() - defer w.mu.Unlock() - if v = w.weak.Value(); v == nil { // gc won - v = w.storeLocked() // new v - } // else: use existing v - return -} - -func (w *WeakRef[V]) Ref() (v *V, valid bool) { - return w.loadOrStore() -} - -func (w *WeakRef[V]) Get() (zz V, valid bool) { - v, valid := w.loadOrStore() - if v == nil || IsNil(v) { - return zz, false - } - return *v, valid -} - -func (w *WeakRef[V]) Load() (v V) { - v, _ = w.Get() - return -} diff --git a/intra/core/wire/header.go b/intra/core/wire/header.go deleted file mode 100644 index f0182811..00000000 --- a/intra/core/wire/header.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "errors" - "math" -) - -const igmpHeaderLength = 8 -const tcpHeaderLength = 20 -const sctpHeaderLength = 12 - -// maxPacketLength is the largest length that all headers support. -// IPv4 headers using uint16 for this forces an upper bound of 64KB. -const maxPacketLength = math.MaxUint16 - -var ( - // errSmallBuffer is returned when Marshal receives a buffer - // too small to contain the header to marshal. - errSmallBuffer = errors.New("buffer too small") - // errLargePacket is returned when Marshal receives a payload - // larger than the maximum representable size in header - // fields. - errLargePacket = errors.New("packet too large") -) - -// Header is a packet header capable of marshaling itself into a byte -// buffer. -type Header interface { - // Len returns the length of the marshaled packet. - Len() int - // Marshal serializes the header into buf, which must be at - // least Len() bytes long. Implementations of Marshal assume - // that bytes after the first Len() are payload bytes for the - // purpose of computing length and checksum fields. Marshal - // implementations must not allocate memory. - Marshal(buf []byte) error -} - -// HeaderChecksummer is implemented by Header implementations that -// need to do a checksum over their payloads. -type HeaderChecksummer interface { - Header - - // WriteCheck writes the correct checksum into buf, which should - // be be the already-marshalled header and payload. - WriteChecksum(buf []byte) -} - -// Generate generates a new packet with the given Header and -// payload. This function allocates memory, see Header.Marshal for an -// allocation-free option. -func Generate(h Header, payload []byte) []byte { - hlen := h.Len() - buf := make([]byte, hlen+len(payload)) - - copy(buf[hlen:], payload) - h.Marshal(buf) - - if hc, ok := h.(HeaderChecksummer); ok { - hc.WriteChecksum(buf) - } - - return buf -} diff --git a/intra/core/wire/icmp4.go b/intra/core/wire/icmp4.go deleted file mode 100644 index de0ce08f..00000000 --- a/intra/core/wire/icmp4.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" -) - -// icmp4HeaderLength is the size of the ICMPv4 packet header, not -// including the outer IP layer or the variable "response data" -// trailer. -const icmp4HeaderLength = 4 - -// ICMP4Type is an ICMPv4 type, as specified in -// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml -type ICMP4Type uint8 - -const ( - ICMP4EchoReply ICMP4Type = 0x00 - ICMP4EchoRequest ICMP4Type = 0x08 - ICMP4Unreachable ICMP4Type = 0x03 - ICMP4TimeExceeded ICMP4Type = 0x0b - ICMP4ParamProblem ICMP4Type = 0x12 -) - -func (t ICMP4Type) String() string { - switch t { - case ICMP4EchoReply: - return "EchoReply" - case ICMP4EchoRequest: - return "EchoRequest" - case ICMP4Unreachable: - return "Unreachable" - case ICMP4TimeExceeded: - return "TimeExceeded" - case ICMP4ParamProblem: - return "ParamProblem" - default: - return "Unknown" - } -} - -// ICMP4Code is an ICMPv4 code, as specified in -// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml -type ICMP4Code uint8 - -const ( - ICMP4NoCode ICMP4Code = 0 - ICMP4HostUnreachable ICMP4Code = 1 -) - -// ICMP4Header is an IPv4+ICMPv4 header. -type ICMP4Header struct { - IP4Header - Type ICMP4Type - Code ICMP4Code -} - -// Len implements Header. -func (h ICMP4Header) Len() int { - return h.IP4Header.Len() + icmp4HeaderLength -} - -// Marshal implements Header. -func (h ICMP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ICMPv4 - - buf[20] = uint8(h.Type) - buf[21] = uint8(h.Code) - - h.IP4Header.Marshal(buf) - - binary.BigEndian.PutUint16(buf[22:24], ip4Checksum(buf)) - - return nil -} - -// ToResponse implements Header. TODO: it doesn't implement it -// correctly, instead it statically generates an ICMP Echo Reply -// packet. -func (h *ICMP4Header) ToResponse() { - // TODO: this doesn't implement ToResponse correctly, as it - // assumes the ICMP request type. - h.Type = ICMP4EchoReply - h.Code = ICMP4NoCode - h.IP4Header.ToResponse() -} diff --git a/intra/core/wire/icmp6.go b/intra/core/wire/icmp6.go deleted file mode 100644 index bbbb18ef..00000000 --- a/intra/core/wire/icmp6.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" -) - -// icmp6HeaderLength is the size of the ICMPv6 packet header, not -// including the outer IP layer or the variable "response data" -// trailer. -const icmp6HeaderLength = 4 - -// ICMP6Type is an ICMPv6 type, as specified in -// https://www.iana.org/assignments/icmpv6-parameters/icmpv6-parameters.xhtml -type ICMP6Type uint8 - -const ( - ICMP6Unreachable ICMP6Type = 1 - ICMP6PacketTooBig ICMP6Type = 2 - ICMP6TimeExceeded ICMP6Type = 3 - ICMP6ParamProblem ICMP6Type = 4 - ICMP6EchoRequest ICMP6Type = 128 - ICMP6EchoReply ICMP6Type = 129 -) - -func (t ICMP6Type) String() string { - switch t { - case ICMP6Unreachable: - return "Unreachable" - case ICMP6PacketTooBig: - return "PacketTooBig" - case ICMP6TimeExceeded: - return "TimeExceeded" - case ICMP6ParamProblem: - return "ParamProblem" - case ICMP6EchoRequest: - return "EchoRequest" - case ICMP6EchoReply: - return "EchoReply" - default: - return "Unknown " + string(t) - } -} - -// ICMP6Code is an ICMPv6 code, as specified in -// https://www.iana.org/assignments/icmpv6-parameters/icmpv6-parameters.xhtml -type ICMP6Code uint8 - -const ( - ICMP6NoCode ICMP6Code = 0 - ICMP6NoRoute ICMP6Code = 0 // code 0: no route to destination -) - -// ICMP6Header is an IPv4+ICMPv4 header. -type ICMP6Header struct { - IP6Header - Type ICMP6Type - Code ICMP6Code -} - -// Len implements Header. -func (h ICMP6Header) Len() int { - return h.IP6Header.Len() + icmp6HeaderLength -} - -// Marshal implements Header. -func (h ICMP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ICMPv6 - - h.IP6Header.Marshal(buf) - - const o = IP6HeaderLength // start offset of ICMPv6 header - buf[o+0] = uint8(h.Type) - buf[o+1] = uint8(h.Code) - buf[o+2] = 0 // checksum, to be filled in later - buf[o+3] = 0 // checksum, to be filled in later - return nil -} - -// ToResponse implements Header. TODO: it doesn't implement it -// correctly, instead it statically generates an ICMP Echo Reply -// packet. -func (h *ICMP6Header) ToResponse() { - // TODO: this doesn't implement ToResponse correctly, as it - // assumes the ICMP request type. - h.Type = ICMP6EchoReply - h.Code = ICMP6NoCode - h.IP6Header.ToResponse() -} - -// WriteChecksum implements HeaderChecksummer, writing just the checksum bytes -// into the otherwise fully marshaled ICMP6 packet p (which should include the -// IPv6 header, ICMPv6 header, and payload). -func (h ICMP6Header) WriteChecksum(p []byte) { - const payOff = IP6HeaderLength + icmp6HeaderLength - xsum := icmp6Checksum(p[IP6HeaderLength:payOff], h.Src.As16(), h.Dst.As16(), p[payOff:]) - binary.BigEndian.PutUint16(p[IP6HeaderLength+2:], xsum) -} - -// Adapted from gVisor: - -// icmp6Checksum calculates the ICMP checksum over the provided ICMPv6 -// header (without the IPv6 header), IPv6 src/dst addresses and the -// payload. -// -// The header's existing checksum must be zeroed. -func icmp6Checksum(header []byte, src, dst [16]byte, payload []byte) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := checksumBytes(src[:], 0) - xsum = checksumBytes(dst[:], xsum) - - var scratch [4]byte - binary.BigEndian.PutUint32(scratch[:], uint32(len(header)+len(payload))) - xsum = checksumBytes(scratch[:], xsum) - xsum = checksumBytes(append(scratch[:0], 0, 0, 0, uint8(ICMPv6)), xsum) - xsum = checksumBytes(payload, xsum) - - var hdrz [icmp6HeaderLength]byte - copy(hdrz[:], header) - // Zero out the header. - hdrz[2] = 0 - hdrz[3] = 0 - xsum = ^checksumBytes(hdrz[:], xsum) - return xsum -} - -// checksumCombine combines the two uint16 to form their -// checksum. This is done by adding them and the carry. -// -// Note that checksum a must have been computed on an even number of -// bytes. -func checksumCombine(a, b uint16) uint16 { - v := uint32(a) + uint32(b) - return uint16(v + v>>16) -} - -// checksumBytes calculates the checksum (as defined in RFC 1071) of -// the bytes in buf. -// -// The initial checksum must have been computed on an even number of bytes. -func checksumBytes(buf []byte, initial uint16) uint16 { - if len(buf) <= 0 { - return initial - } - - v := uint32(initial) - - odd := len(buf)%2 == 1 - if odd { - v += uint32(buf[0]) - buf = buf[1:] - } - - n := len(buf) - odd = n&1 != 0 - if odd { - n-- - v += uint32(buf[n]) << 8 - } - - for i := 0; i < n; i += 2 { - v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) - } - - return checksumCombine(uint16(v), uint16(v>>16)) -} diff --git a/intra/core/wire/ip4.go b/intra/core/wire/ip4.go deleted file mode 100644 index d0c7cb0d..00000000 --- a/intra/core/wire/ip4.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" - "errors" - "net/netip" -) - -// IP4HeaderLength is the length of an IPv4 header with no IP options. -const IP4HeaderLength = 20 - -const IP4SrcAddrOffset = 12 - -// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP -// pseudo-header is smaller than the real IPv4 header. -const ip4PseudoHeaderOffset = 8 - -// IP4Header represents an IPv4 packet header. -type IP4Header struct { - IPProto Proto - IPID uint16 - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP4Header) Len() int { - return IP4HeaderLength -} - -var errWrongFamily = errors.New("wrong address family for src/dst IP") - -// Marshal implements Header. -func (h IP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - if !h.Src.Is4() || !h.Dst.Is4() { - return errWrongFamily - } - - buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL - buf[1] = 0x00 // DSCP + ECN - binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length - binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID - binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset - buf[8] = 64 // TTL - buf[9] = uint8(h.IPProto) // Inner protocol - // Blank checksum. This is necessary even though we overwrite - // it later, because the checksum computation runs over these - // bytes and expects them to be zero. - binary.BigEndian.PutUint16(buf[10:12], 0) - src := h.Src.As4() - dst := h.Dst.As4() - copy(buf[12:16], src[:]) - copy(buf[16:20], dst[:]) - - binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum - - return nil -} - -// ToResponse implements Header. -func (h *IP4Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = ^h.IPID -} - -// ip4Checksum computes an IPv4 checksum, as specified in -// https://tools.ietf.org/html/rfc1071 -func ip4Checksum(b []byte) uint16 { - var ac uint32 - i := 0 - n := len(b) - for n >= 2 { - ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint32(b[i]) << 8 - } - for (ac >> 16) > 0 { - ac = (ac >> 16) + (ac & 0xffff) - } - return uint16(^ac) -} - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. The pseudo-header starts -// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP -// header, while leaving enough space in buf for a full IPv4 header. -func (h IP4Header) marshalPseudo(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - length := len(buf) - h.Len() - src, dst := h.Src.As4(), h.Dst.As4() - copy(buf[8:12], src[:]) - copy(buf[12:16], dst[:]) - buf[16] = 0x0 - buf[17] = uint8(h.IPProto) - binary.BigEndian.PutUint16(buf[18:20], uint16(length)) - return nil -} diff --git a/intra/core/wire/ip6.go b/intra/core/wire/ip6.go deleted file mode 100644 index 85b7ddca..00000000 --- a/intra/core/wire/ip6.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" - "net/netip" -) - -// IP6HeaderLength is the length of an IPv6 header with no IP options. -const IP6HeaderLength = 40 - -const IP6SrcAddrOffset = 9 - -// IP6Header represents an IPv6 packet header. -type IP6Header struct { - IPProto Proto - IPID uint32 // only lower 20 bits used - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP6Header) Len() int { - return IP6HeaderLength -} - -// Marshal implements Header. -func (h IP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) - buf[0] = 0x60 - binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-IP6HeaderLength)) // Total length - buf[6] = uint8(h.IPProto) // Inner protocol - buf[7] = 64 // TTL - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[8:24], src[:]) - copy(buf[24:40], dst[:]) - - return nil -} - -// ToResponse implements Header. -func (h *IP6Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = (^h.IPID) & 0x000FFFFF -} - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. -func (h IP6Header) marshalPseudo(buf []byte, proto Proto) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[:16], src[:]) - copy(buf[16:32], dst[:]) - binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) - buf[36] = 0 - buf[37] = 0 - buf[38] = 0 - buf[39] = byte(proto) // NextProto - return nil -} diff --git a/intra/core/wire/parsed.go b/intra/core/wire/parsed.go deleted file mode 100644 index 17fe3d17..00000000 --- a/intra/core/wire/parsed.go +++ /dev/null @@ -1,587 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" - "fmt" - "net" - "net/netip" - "strings" - "sync" -) - -const unknown = UnknownProto - -const MinTCPHeaderSize = 20 - -// RFC1858: prevent overlapping fragment attacks. -const minFragBlks = (60 + MinTCPHeaderSize) / 8 // max IPv4 header + basic TCP header in fragment blocks (8 bytes each) - -type TCPFlag uint8 - -// CaptureMeta contains metadata that is used when debugging. -type CaptureMeta struct { - DidSNAT bool // SNAT was performed & the address was updated. - OriginalSrc netip.AddrPort // The source address before SNAT was performed. - DidDNAT bool // DNAT was performed & the address was updated. - OriginalDst netip.AddrPort // The destination address before DNAT was performed. -} - -const ( - TCPFlagsOffset = 13 - - TCPFin TCPFlag = 0x01 - TCPSyn TCPFlag = 0x02 - TCPRst TCPFlag = 0x04 - TCPPsh TCPFlag = 0x08 - TCPAck TCPFlag = 0x10 - TCPUrg TCPFlag = 0x20 - TCPECNEcho TCPFlag = 0x40 - TCPCWR TCPFlag = 0x80 - TCPSynAck TCPFlag = TCPSyn | TCPAck - TCPECNBits TCPFlag = TCPECNEcho | TCPCWR -) - -type ParsedPool sync.Pool - -// Pool holds a pool of Parsed structs for use in filtering. -var Pool = ParsedPool{New: func() any { return new(Parsed) }} - -func (p *ParsedPool) Get() *Parsed { - pp := (*sync.Pool)(p) - return pp.Get().(*Parsed) -} - -func (p *ParsedPool) Put(parsed *Parsed) { - pp := (*sync.Pool)(p) - pp.Put(parsed) -} - -// Parsed is a minimal decoding of a packet suitable for use in filters. -type Parsed struct { - // b is the byte buffer that this decodes. - b []byte - // subofs is the offset of IP subprotocol. - subofs int - // dataofs is the offset of IP subprotocol payload. - dataofs int - // length is the total length of the packet. - // This is not the same as len(b) because b can have trailing zeros. - length int - // truncated indicates if the packet was truncated. - trunc bool - - // IPVersion is the IP protocol version of the packet (4 or - // 6), or 0 if the packet doesn't look like IPv4 or IPv6. - IPVersion uint8 - // IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0. - IPProto Proto - // Src is the source address. Family matches IPVersion. Port is - // valid iff IPProto == TCP || IPProto == UDP || IPProto == SCTP. - Src netip.AddrPort - // Dst is the destination address. Family matches IPVersion. Port is - // valid iff IPProto == TCP || IPProto == UDP || IPProto == SCTP. - Dst netip.AddrPort - // TCPFlags is the packet's TCP flag bits. Valid iff IPProto == TCP. - TCPFlags TCPFlag - - // CaptureMeta contains metadata that is used when debugging. - CaptureMeta CaptureMeta -} - -func (p *Parsed) String() string { - if p.IPVersion != 4 && p.IPVersion != 6 { - return "Unknown{???}" - } - - // max is the maximum reasonable length of the string we are constructing. - // It's OK to overshoot, as the temp buffer is allocated on the stack. - const max = len("ICMPv6{[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535 > [ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535}") - b := make([]byte, 0, max) - b = append(b, p.IPProto.String()...) - b = append(b, '{') - b = p.Src.AppendTo(b) - b = append(b, ' ', '>', ' ') - b = p.Dst.AppendTo(b) - b = append(b, '}') - return string(b) -} - -func (q *Parsed) Decode(b []byte) { - q.decode(b, false /*truncated*/) -} - -func (q *Parsed) DecodeTrunc(b []byte, trunc bool) { - q.decode(b, trunc /*truncated*/) -} - -// Decode extracts data from the packet in b into q. -// It performs extremely simple packet decoding for basic IPv4 and IPv6 packet types. -// It extracts only the subprotocol id, IP addresses, and (if any) ports, -// and shouldn't need any memory allocation. -func (q *Parsed) decode(b []byte, trunc bool) { - q.b = b - q.trunc = trunc - q.CaptureMeta = CaptureMeta{} // Clear any capture metadata if it exists. - - if len(b) < 1 { - q.IPVersion = 0 - q.IPProto = unknown - return - } - - q.IPVersion = b[0] >> 4 - switch q.IPVersion { - case 4: - q.decode4(b) - case 6: - q.decode6(b) - default: - q.IPVersion = 0 - q.IPProto = unknown - } -} - -func (q *Parsed) decode4(b []byte) { - if len(b) < IP4HeaderLength { - q.IPVersion = 0 - q.IPProto = unknown - return - } - - // Check that it's IPv4. - q.IPProto = Proto(b[9]) - q.length = int(binary.BigEndian.Uint16(b[2:4])) - if !q.trunc && len(b) < q.length { - // Packet was cut off before full IPv4 length. - q.IPProto = unknown - return - } - - // If it's valid IPv4, then the IP addresses are valid - q.Src = withIP(q.Src, netip.AddrFrom4([4]byte{b[12], b[13], b[14], b[15]})) - q.Dst = withIP(q.Dst, netip.AddrFrom4([4]byte{b[16], b[17], b[18], b[19]})) - - q.subofs = int((b[0] & 0x0F) << 2) - if q.subofs > q.length { - // next-proto starts beyond end of packet. - q.IPProto = unknown - return - } - sub := b[q.subofs:] - sub = sub[:len(sub):len(sub)] // help the compiler do bounds check elimination - - // We don't care much about IP fragmentation, except insofar as it's - // used for firewall bypass attacks. The trick is make the first - // fragment of a TCP or UDP packet so short that it doesn't fit - // the TCP or UDP header, so we can't read the port, in hope that - // it'll sneak past. Then subsequent fragments fill it in, but we're - // missing the first part of the header, so we can't read that either. - // - // A "perfectly correct" implementation would have to reassemble - // fragments before deciding what to do. But the truth is there's - // zero reason to send such a short first fragment, so we can treat - // it as Unknown. We can also treat any subsequent fragment that starts - // at such a low offset as Unknown. - fragFlags := binary.BigEndian.Uint16(b[6:8]) - moreFrags := (fragFlags & 0x2000) != 0 - fragOfs := fragFlags & 0x1FFF - - if fragOfs == 0 { - // This is the first fragment - // Every protocol below MUST check that it has at least one entire - // transport header in order to protect against fragment confusion. - switch q.IPProto { - case ICMPv4: - if len(sub) < icmp4HeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, 0) - q.Dst = withPort(q.Dst, 0) - q.dataofs = q.subofs + icmp4HeaderLength - return - case IGMP: - if len(sub) < igmpHeaderLength { - q.IPProto = unknown - return - } - // Keep IPProto, but don't parse anything else - // out. - return - case TCP: - if len(sub) < tcpHeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, binary.BigEndian.Uint16(sub[0:2])) - q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) - q.TCPFlags = TCPFlag(sub[TCPFlagsOffset]) - headerLength := (sub[12] & 0xF0) >> 2 - q.dataofs = q.subofs + int(headerLength) - return - case UDP: - if len(sub) < udpHeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, binary.BigEndian.Uint16(sub[0:2])) - q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) - q.dataofs = q.subofs + udpHeaderLength - return - case SCTP: - if len(sub) < sctpHeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, binary.BigEndian.Uint16(sub[0:2])) - q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) - return - case TSMP: - // Strictly disallow fragmented TSMP - if moreFrags { - q.IPProto = unknown - return - } - // Inter-tailscale messages. - q.dataofs = q.subofs - return - case Fragment: - // An IPProto value of 0xff (our Fragment constant for internal use) - // should never actually be used in the wild; if we see it, - // something's suspicious and we map it back to zero (unknown). - q.IPProto = unknown - } - } else { - // This is a fragment other than the first one. - if fragOfs < minFragBlks { - // disallow fragment offsets that are potentially inside of a - // transport header. This is notably asymmetric with the - // first-packet limit, that may allow a first-packet that requires a - // shorter offset than this limit, but without state to tie this - // to the first fragment we can not allow shorter packets. - q.IPProto = unknown - return - } - // otherwise, we have to permit the fragment to slide through. - // Second and later fragments don't have sub-headers. - // Ideally, we would drop fragments that we can't identify, - // but that would require statefulness. Anyway, receivers' - // kernels know to drop fragments where the initial fragment - // doesn't arrive. - q.IPProto = Fragment - return - } -} - -func (q *Parsed) decode6(b []byte) { - if len(b) < IP6HeaderLength { - q.IPVersion = 0 - q.IPProto = unknown - return - } - - q.IPProto = Proto(b[6]) - q.length = int(binary.BigEndian.Uint16(b[4:6])) + IP6HeaderLength - if !q.trunc && len(b) < q.length { - // Packet was cut off before the full IPv6 length. - q.IPProto = unknown - return - } - - // okay to ignore `ok` here, because IPs pulled from packets are - // always well-formed stdlib IPs. - srcIP, _ := netip.AddrFromSlice(net.IP(b[8:24])) - dstIP, _ := netip.AddrFromSlice(net.IP(b[24:40])) - q.Src = withIP(q.Src, srcIP) - q.Dst = withIP(q.Dst, dstIP) - - // We don't support any IPv6 extension headers. Don't try to - // be clever. Therefore, the IP subprotocol always starts at - // byte 40. - // - // Note that this means we don't support fragmentation in - // IPv6. This is fine, because IPv6 strongly mandates that you - // should not fragment, which makes fragmentation on the open - // internet extremely uncommon. - // - // This also means we don't support IPSec headers (AH/ESP), or - // IPv6 jumbo frames. Those will get marked Unknown and - // dropped. - q.subofs = IP6HeaderLength - sub := b[q.subofs:] - sub = sub[:len(sub):len(sub)] // help the compiler do bounds check elimination - - switch q.IPProto { - case ICMPv6: - if len(sub) < icmp6HeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, 0) - q.Dst = withPort(q.Dst, 0) - q.dataofs = q.subofs + icmp6HeaderLength - case TCP: - if len(sub) < tcpHeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, binary.BigEndian.Uint16(sub[0:2])) - q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) - q.TCPFlags = TCPFlag(sub[13]) - headerLength := (sub[12] & 0xF0) >> 2 - q.dataofs = q.subofs + int(headerLength) - return - case UDP: - if len(sub) < udpHeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, binary.BigEndian.Uint16(sub[0:2])) - q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) - q.dataofs = q.subofs + udpHeaderLength - case SCTP: - if len(sub) < sctpHeaderLength { - q.IPProto = unknown - return - } - q.Src = withPort(q.Src, binary.BigEndian.Uint16(sub[0:2])) - q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) - return - case TSMP: - // Inter-tailscale messages. - q.dataofs = q.subofs - return - case Fragment: - // An IPProto value of 0xff (our Fragment constant for internal use) - // should never actually be used in the wild; if we see it, - // something's suspicious and we map it back to zero (unknown). - q.IPProto = unknown - return - } -} - -func (q *Parsed) IP4Header() IP4Header { - if q.IPVersion != 4 { - return IP4Header{} - } - ipid := binary.BigEndian.Uint16(q.b[4:6]) - return IP4Header{ - IPID: ipid, - IPProto: q.IPProto, - Src: q.Src.Addr(), - Dst: q.Dst.Addr(), - } -} - -func (q *Parsed) IP6Header() IP6Header { - if q.IPVersion != 6 { - return IP6Header{} - } - ipid := (binary.BigEndian.Uint32(q.b[:4]) << 12) >> 12 - return IP6Header{ - IPID: ipid, - IPProto: q.IPProto, - Src: q.Src.Addr(), - Dst: q.Dst.Addr(), - } -} - -func (q *Parsed) ICMPHeaderString() string { - switch q.IPProto { - case ICMPv4: - return q.ICMP4Header().Stringer() - case ICMPv6: - return q.ICMP6Header().Stringer() - } - return "ICMP" + string(q.IPVersion) + "{???}" -} - -func (q *Parsed) ICMP4Header() ICMP4Header { - return ICMP4Header{ - IP4Header: q.IP4Header(), - Type: ICMP4Type(q.b[q.subofs+0]), - Code: ICMP4Code(q.b[q.subofs+1]), - } -} - -func (h ICMP4Header) Stringer() string { - return fmt.Sprintf("%v", h) -} - -func (h ICMP6Header) Stringer() string { - return fmt.Sprintf("%v", h) -} - -func (q *Parsed) ICMP6Header() ICMP6Header { - return ICMP6Header{ - IP6Header: q.IP6Header(), - Type: ICMP6Type(q.b[q.subofs+0]), - Code: ICMP6Code(q.b[q.subofs+1]), - } -} - -func (q *Parsed) UDP4Header() UDP4Header { - return UDP4Header{ - IP4Header: q.IP4Header(), - SrcPort: q.Src.Port(), - DstPort: q.Dst.Port(), - } -} - -// Buffer returns the entire packet buffer. -// This is a read-only view; that is, q retains the ownership of the buffer. -func (q *Parsed) Buffer() []byte { - return q.b -} - -// Payload returns the payload of the IP subprotocol section. -// This is a read-only view; that is, q retains the ownership of the buffer. -func (q *Parsed) Payload() ([]byte, bool) { - // If the packet is truncated, return nothing instead of crashing. - if q.dataofs > len(q.b) { - return nil, q.trunc - } - if q.length > len(q.b) { - if q.trunc { - return q.b[q.dataofs:], true - } - return nil, q.trunc - } - - return q.b[q.dataofs:q.length], false -} - -// Transport returns the transport header and payload (IP subprotocol, such as TCP or UDP). -// This is a read-only view; that is, p retains the ownership of the buffer. -func (p *Parsed) Transport() []byte { - return p.b[p.subofs:] -} - -func (p *Parsed) HasTransportData() bool { - return p.subofs < len(p.b) -} - -// IsTCPSyn reports whether q is a TCP SYN packet, -// without ACK set. (i.e. the first packet in a new connection) -func (q *Parsed) IsTCPSyn() bool { - return (q.TCPFlags & TCPSynAck) == TCPSyn -} - -// IsError reports whether q is an ICMP "Error" packet. -func (q *Parsed) IsError() bool { - switch q.IPProto { - case ICMPv4: - if len(q.b) < q.subofs+8 { - return false - } - t := ICMP4Type(q.b[q.subofs]) - return t == ICMP4Unreachable || t == ICMP4TimeExceeded || t == ICMP4ParamProblem - case ICMPv6: - if len(q.b) < q.subofs+8 { - return false - } - t := ICMP6Type(q.b[q.subofs]) - return t == ICMP6Unreachable || t == ICMP6PacketTooBig || t == ICMP6TimeExceeded || t == ICMP6ParamProblem - default: - return false - } -} - -// IsEchoRequest reports whether q is an ICMP Echo Request. -func (q *Parsed) IsEchoRequest() bool { - switch q.IPProto { - case ICMPv4: - return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode - case ICMPv6: - return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoRequest && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode - default: - return false - } -} - -// IsEchoResponse reports whether q is an IPv4 ICMP Echo Response. -func (q *Parsed) IsEchoResponse() bool { - switch q.IPProto { - case ICMPv4: - return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode - case ICMPv6: - return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoReply && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode - default: - return false - } -} - -// EchoIDSeq extracts the identifier/sequence bytes from an ICMP Echo response, -// and returns them as a uint32, used to lookup internally routed ICMP echo -// responses. This function is intentionally lightweight as it is called on -// every incoming ICMP packet. -func (q *Parsed) EchoIDSeq() uint32 { - switch q.IPProto { - case ICMPv4: - offset := IP4HeaderLength + icmp4HeaderLength - if len(q.b) < offset+4 { - return 0 - } - return binary.LittleEndian.Uint32(q.b[offset:]) - case ICMPv6: - offset := IP6HeaderLength + icmp6HeaderLength - if len(q.b) < offset+4 { - return 0 - } - return binary.LittleEndian.Uint32(q.b[offset:]) - default: - return 0 - } -} - -func Hexdump(b []byte) string { - out := new(strings.Builder) - for i := 0; i < len(b); i += 16 { - if i > 0 { - fmt.Fprintf(out, "\n") - } - fmt.Fprintf(out, " %04x ", i) - j := 0 - for ; j < 16 && i+j < len(b); j++ { - if j == 8 { - fmt.Fprintf(out, " ") - } - fmt.Fprintf(out, "%02x ", b[i+j]) - } - for ; j < 16; j++ { - if j == 8 { - fmt.Fprintf(out, " ") - } - fmt.Fprintf(out, " ") - } - fmt.Fprintf(out, " ") - for j = 0; j < 16 && i+j < len(b); j++ { - if b[i+j] >= 32 && b[i+j] < 128 { - fmt.Fprintf(out, "%c", b[i+j]) - } else { - fmt.Fprintf(out, ".") - } - } - } - return out.String() -} - -func withIP(ap netip.AddrPort, ip netip.Addr) netip.AddrPort { - return netip.AddrPortFrom(ip, ap.Port()) -} - -func withPort(ap netip.AddrPort, port uint16) netip.AddrPort { - return netip.AddrPortFrom(ap.Addr(), port) -} diff --git a/intra/core/wire/parsed_test.go b/intra/core/wire/parsed_test.go deleted file mode 100644 index 3a51b28c..00000000 --- a/intra/core/wire/parsed_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" - "encoding/hex" - "net/netip" - "testing" -) - -func TestParsedDecodeICMPv6EchoRequest(t *testing.T) { - // IPv6 ICMP echo request packet (no fragmentation, truncated) - hexpkt := "600eb68300403a40fd66f83ac65000000000000000000001260647000000000000000000681084e58000c964116c0002832a586900000000f15b030000000000101112131415161718191a1b1c1d1e1f" - - pkt, err := hex.DecodeString(hexpkt) - if err != nil { - t.Fatalf("hex decode failed: %v", err) - } - - expectedLen := int(binary.BigEndian.Uint16(pkt[4:6])) + IP6HeaderLength - if len(pkt) < expectedLen { - pkt = append(pkt, make([]byte, expectedLen-len(pkt))...) - } - - var p Parsed - p.DecodeTrunc(pkt, true) - - if p.IPVersion != Version6 { - t.Fatalf("IPVersion got %d, want %d", p.IPVersion, Version6) - } - if p.IPProto != ICMPv6 { - t.Fatalf("IPProto got %v, want %v", p.IPProto, ICMPv6) - } - - wantSrc := netip.MustParseAddr("fd66:f83a:c650::1") - wantDst := netip.MustParseAddr("2606:4700::6810:84e5") - - if p.Src.Addr() != wantSrc { - t.Fatalf("Src got %v, want %v", p.Src.Addr(), wantSrc) - } - if p.Dst.Addr() != wantDst { - t.Fatalf("Dst got %v, want %v", p.Dst.Addr(), wantDst) - } - if p.Src.Port() != 0 || p.Dst.Port() != 0 { - t.Fatalf("expected zero ports for ICMPv6, got src=%d dst=%d", p.Src.Port(), p.Dst.Port()) - } - - if !p.IsEchoRequest() { - t.Fatalf("expected packet to be ICMPv6 echo request") - } - t.Log(p.ICMPHeaderString()) -} diff --git a/intra/core/wire/proto.go b/intra/core/wire/proto.go deleted file mode 100644 index 8df1ad71..00000000 --- a/intra/core/wire/proto.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "fmt" -) - -// Version describes the IP address version. -type Version uint8 - -// Valid Version values. -const ( - Version4 = 4 - Version6 = 6 -) - -func (p Version) String() string { - switch p { - case Version4: - return "IPv4" - case Version6: - return "IPv6" - default: - return fmt.Sprintf("Version-%d", int(p)) - } -} - -// Proto is an IP subprotocol as defined by the IANA protocol -// numbers list -// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), -// or the special values Unknown or Fragment. -type Proto uint8 - -const ( - // Unknown represents an unknown or unsupported protocol; it's - // deliberately the zero value. Strictly speaking the zero - // value is IPv6 hop-by-hop extensions, but we don't support - // those, so this is still technically correct. - UnknownProto Proto = 0x00 - - // Values from the IANA registry. - ICMPv4 Proto = 0x01 - IGMP Proto = 0x02 - ICMPv6 Proto = 0x3a - TCP Proto = 0x06 - UDP Proto = 0x11 - DCCP Proto = 0x21 - GRE Proto = 0x2f - SCTP Proto = 0x84 - - // TSMP is the Tailscale Message Protocol (our ICMP-ish - // thing), an IP protocol used only between Tailscale nodes - // (still encrypted by WireGuard) that communicates why things - // failed, etc. - // - // Proto number 99 is reserved for "any private encryption - // scheme". We never accept these from the host OS stack nor - // send them to the host network stack. It's only used between - // nodes. - TSMP Proto = 99 - - // Fragment represents any non-first IP fragment, for which we - // don't have the sub-protocol header (and therefore can't - // figure out what the sub-protocol is). - // - // 0xFF is reserved in the IANA registry, so we steal it for - // internal use. - Fragment Proto = 0xFF -) - -// Deprecated: use MarshalText instead. -func (p Proto) String() string { - switch p { - case UnknownProto: - return "Unknown" - case Fragment: - return "Frag" - case ICMPv4: - return "ICMPv4" - case IGMP: - return "IGMP" - case ICMPv6: - return "ICMPv6" - case UDP: - return "UDP" - case TCP: - return "TCP" - case SCTP: - return "SCTP" - case TSMP: - return "TSMP" - case GRE: - return "GRE" - case DCCP: - return "DCCP" - default: - return fmt.Sprintf("IPProto-%d", int(p)) - } -} - -// Prefer names from -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// unless otherwise noted. -var ( - // PreferredNames is the set of protocol names that re produced by - // MarshalText, and are the preferred representation. - PreferredNames = map[Proto]string{ - 51: "ah", - DCCP: "dccp", - 8: "egp", - 50: "esp", - 47: "gre", - ICMPv4: "icmp", - IGMP: "igmp", - 9: "igp", - 4: "ipv4", - ICMPv6: "ipv6-icmp", - SCTP: "sctp", - TCP: "tcp", - UDP: "udp", - } - - // AcceptedNames is the set of protocol names that are accepted by - // UnmarshalText. - AcceptedNames = map[string]Proto{ - "ah": 51, - "dccp": DCCP, - "egp": 8, - "esp": 50, - "gre": 47, - "icmp": ICMPv4, - "icmpv4": ICMPv4, - "icmpv6": ICMPv6, - "igmp": IGMP, - "igp": 9, - "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" - "ipv4": 4, - "ipv6-icmp": ICMPv6, - "sctp": SCTP, - "tcp": TCP, - "tsmp": TSMP, - "udp": UDP, - } -) diff --git a/intra/core/wire/udp.go b/intra/core/wire/udp.go deleted file mode 100644 index e2259b92..00000000 --- a/intra/core/wire/udp.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" -) - -// udpHeaderLength is the size of the UDP packet header, not including -// the outer IP header. -const udpHeaderLength = 8 - -// UDP4Header is an IPv4+UDP header. -type UDP4Header struct { - IP4Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP4Header) Len() int { - return h.IP4Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = UDP - - length := len(buf) - h.IP4Header.Len() - binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) - binary.BigEndian.PutUint16(buf[22:24], h.DstPort) - binary.BigEndian.PutUint16(buf[24:26], uint16(length)) - binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP4Header.marshalPseudo(buf) - binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) - - h.IP4Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP4Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP4Header.ToResponse() -} - -// UDP6Header is an IPv6+UDP header. -type UDP6Header struct { - IP6Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP6Header) Len() int { - return h.IP6Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = UDP - - length := len(buf) - h.IP6Header.Len() - binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) - binary.BigEndian.PutUint16(buf[42:44], h.DstPort) - binary.BigEndian.PutUint16(buf[44:46], uint16(length)) - binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP6Header.marshalPseudo(buf, UDP) - binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) - - h.IP6Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP6Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP6Header.ToResponse() -} diff --git a/intra/core/wire/xsum.go b/intra/core/wire/xsum.go deleted file mode 100644 index 6911430b..00000000 --- a/intra/core/wire/xsum.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: BSD-3-Clause -// Copyright (c) Tailscale Inc & AUTHORS - -package wire - -import ( - "encoding/binary" - "net/netip" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// UpdateSrcAddr updates the source address in the packet buffer (e.g. during -// SNAT). It also updates the checksum. Currently (2023-09-22) only TCP/UDP/ICMP -// is supported. It panics if provided with an address in a different -// family to the parsed packet. -func UpdateSrcAddr(q *Parsed, src netip.Addr) { - if src.Is6() && q.IPVersion != 6 { - panic("UpdateSrcAddr: cannot write IPv6 address to v4 packet") - } else if src.Is4() && q.IPVersion != 4 { - panic("UpdateSrcAddr: cannot write IPv4 address to v6 packet") - } - q.CaptureMeta.DidSNAT = true - q.CaptureMeta.OriginalSrc = q.Src - - old := q.Src.Addr() - q.Src = netip.AddrPortFrom(src, q.Src.Port()) - - b := q.Buffer() - if src.Is6() { - v6 := src.As16() - copy(b[8:24], v6[:]) - updateV6PacketChecksums(q, old, src) - } else { - v4 := src.As4() - copy(b[12:16], v4[:]) - updateV4PacketChecksums(q, old, src) - } -} - -// UpdateDstAddr updates the destination address in the packet buffer (e.g. during -// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP -// is supported. It panics if provided with an address in a different -// family to the parsed packet. -func UpdateDstAddr(q *Parsed, dst netip.Addr) { - if dst.Is6() && q.IPVersion != 6 { - panic("UpdateDstAddr: cannot write IPv6 address to v4 packet") - } else if dst.Is4() && q.IPVersion != 4 { - panic("UpdateDstAddr: cannot write IPv4 address to v6 packet") - } - q.CaptureMeta.DidDNAT = true - q.CaptureMeta.OriginalDst = q.Dst - - old := q.Dst.Addr() - q.Dst = netip.AddrPortFrom(dst, q.Dst.Port()) - - b := q.Buffer() - if dst.Is6() { - v6 := dst.As16() - copy(b[24:40], v6[:]) - updateV6PacketChecksums(q, old, dst) - } else { - v4 := dst.As4() - copy(b[16:20], v4[:]) - updateV4PacketChecksums(q, old, dst) - } -} - -// updateV4PacketChecksums updates the checksums in the packet buffer. -// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported. -// p is modified in place. -// If p.IPProto is unknown, only the IP header checksum is updated. -func updateV4PacketChecksums(p *Parsed, old, new netip.Addr) { - if len(p.Buffer()) < 12 { - // Not enough space for an IPv4 header. - return - } - o4, n4 := old.As4(), new.As4() - - // First update the checksum in the IP header. - updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:]) - - // Now update the transport layer checksums, where applicable. - tr := p.Transport() - switch p.IPProto { - case UDP, DCCP: - if len(tr) < header.UDPMinimumSize { - // Not enough space for a UDP header. - return - } - updateV4Checksum(tr[6:8], o4[:], n4[:]) - case TCP: - if len(tr) < header.TCPMinimumSize { - // Not enough space for a TCP header. - return - } - updateV4Checksum(tr[16:18], o4[:], n4[:]) - case GRE: - if len(tr) < 6 { - // Not enough space for a GRE header. - return - } - if tr[0] == 1 { // checksum present - updateV4Checksum(tr[4:6], o4[:], n4[:]) - } - case SCTP, ICMPv4: - // No transport layer update required. - } -} - -// updateV6PacketChecksums updates the checksums in the packet buffer. -// p is modified in place. -// If p.IPProto is unknown, no checksums are updated. -func updateV6PacketChecksums(p *Parsed, old, new netip.Addr) { - if len(p.Buffer()) < 40 { - // Not enough space for an IPv6 header. - return - } - o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice()) - - // Now update the transport layer checksums, where applicable. - tr := p.Transport() - switch p.IPProto { - case ICMPv6: - if len(tr) < header.ICMPv6MinimumSize { - return - } - header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6) - case UDP, DCCP: - if len(tr) < header.UDPMinimumSize { - return - } - header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) - case TCP: - if len(tr) < header.TCPMinimumSize { - return - } - header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) - case SCTP: - // No transport layer update required. - } -} - -// updateV4Checksum calculates and updates the checksum in the packet buffer for -// a change between old and new. The oldSum must point to the 16-bit checksum -// field in the packet buffer that holds the old checksum value, it will be -// updated in place. -// -// The old and new must be the same length, and must be an even number of bytes. -func updateV4Checksum(oldSum, old, new []byte) { - if len(old) != len(new) { - panic("old and new must be the same length") - } - if len(old)%2 != 0 { - panic("old and new must be of even length") - } - /* - RFC 1624 - Given the following notation: - - HC - old checksum in header - C - one's complement sum of old header - HC' - new checksum in header - C' - one's complement sum of new header - m - old value of a 16-bit field - m' - new value of a 16-bit field - - HC' = ~(C + (-m) + m') -- [Eqn. 3] - HC' = ~(~HC + ~m + m') - - This can be simplified to: - HC' = ~(C + ~m + m') -- [Eqn. 3] - HC' = ~C' - C' = C + ~m + m' - */ - - c := uint32(^binary.BigEndian.Uint16(oldSum)) - - cPrime := c - for len(new) > 0 { - mNot := uint32(^binary.BigEndian.Uint16(old[:2])) - mPrime := uint32(binary.BigEndian.Uint16(new[:2])) - cPrime += mPrime + mNot - new, old = new[2:], old[2:] - } - - // Account for overflows by adding the carry bits back into the sum. - for (cPrime >> 16) > 0 { - cPrime = cPrime&0xFFFF + cPrime>>16 - } - hcPrime := ^uint16(cPrime) - binary.BigEndian.PutUint16(oldSum, hcPrime) -} diff --git a/intra/depaware.txt b/intra/depaware.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/intra/dialers/cdial.go b/intra/dialers/cdial.go deleted file mode 100644 index fcf52db3..00000000 --- a/intra/dialers/cdial.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "errors" - "net" - "net/netip" - "strconv" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect/ipmap" - "github.com/celzero/firestack/intra/settings" -) - -const dialRetryTimeout = 35 * time.Second - -var errRetryTimeout = errors.New("dialers: retry timeout") - -func reorderIPs(ips []netip.Addr, alwaysExclude netip.Addr) ([]netip.Addr, bool) { - failingopen := true - use4 := Use4() - use6 := Use6() - - front := make([]netip.Addr, 0, len(ips)) - back := make([]netip.Addr, 0, len(ips)) - for _, ip := range ips { - if ip.Compare(alwaysExclude) == 0 || !ip.IsValid() { - continue - } else if use4 && ip.Is4() { - front = append(front, ip) - } else if use6 && ip.Is6() { - front = append(front, ip) - } else { - back = append(back, ip) - } - } - if len(front) <= 0 { - // if all ips are filtered out, fail open and return unfiltered - return back, failingopen - } - if len(back) > 0 { - // sample one unfiltered ip in an ironic case that it works - // but the filtered out ones don't. this can happen in scenarios - // where tunnel's ipProto is IP4 but the underlying network is IP6: - // that is, IP6 is filtered out even though it might have worked. - front = append(front, back...) - } - return front, !failingopen -} - -func commondial[D rdials, C rconns](d D, network, addr string, connect dialFn[D, C]) (C, error) { - return commondial2(d, network, "", addr, connect) -} - -func commondial2[D rdials, C rconns](d D, network, laddr, raddr string, connect dialFn[D, C]) (C, error) { - start := time.Now() - - local, lerr := netip.ParseAddrPort(laddr) // okay if local is invalid - domain, portstr, err := net.SplitHostPort(raddr) - - if settings.Debug { - log.D("commondial: dialing (host:port) %s=>%s; errs? %v %v", - laddr, raddr, lerr, err) - } - - if err != nil { - return nil, err - } - - // cannot dial into a wildcard address - // while, listen is unsupported - if len(domain) == 0 { - return nil, net.InvalidAddrError(raddr) - } - port, err := strconv.Atoi(portstr) - if err != nil { - return nil, err - } - - var conn C - var errs error - ips := ipm.Get(domain) - dontretry := ips.OneIPOnly() // just one IP, no retries possible - confirmed := ips.Confirmed() // may be zeroaddr - confirmedIPOK := ipok(confirmed) - - defer func() { - dur := time.Since(start) - if settings.Debug { - log.D("commondial: duration: %s; addr %s; confirmed? %s, sz: %d", - core.FmtPeriod(dur), raddr, confirmed, ips.Size()) - } - }() - - // One the TODO is fixed, change ipn/proxy.go:Reaches to rely on this behaviour - // TODO: confirmedIPOK must be used depending on network type "tcp4", "udp4", "tcp6", "udp6" etc - if confirmedIPOK { - remote := netip.AddrPortFrom(confirmed, uint16(port)) - if settings.Debug { - log.V("commondial: dialing confirmed ip %s for %s", confirmed, remote) - } - conn, err = connect(d, network, local, remote) - // nilaway: tx.socks5 returns nil conn even if err == nil - if conn == nil { - err = core.OneErr(err, errNoConn) - } - if err == nil { - if settings.Debug { - log.V("commondial: ip %s works for %s", confirmed, remote) - } - return conn, nil - } - errs = core.JoinErr(errs, err) - ips.Disconfirm(confirmed) - logwd(err)("rdial: commondial: confirmed %s for %s failed; err %v", - confirmed, remote, err) - } - - if dontretry { - if !confirmedIPOK { - log.E("commondial: ip %s not ok for %s", confirmed, raddr) - errs = core.JoinErr(errs, errNoIps) - } - return nil, errs - } - - ipset := ips.Addrs() - // One the TODO is fixed, change ipn/proxy.go:Reaches to rely on this behaviour - // TODO: maybeFilter should consider incoming network types "tcp4", "udp4", "tcp6", "udp6" etc - ordered, failingopen := reorderIPs(ipset, confirmed) - if len(ordered) <= 0 || failingopen { - var renewed bool - if ips, renewed = renew(domain, ips); renewed { - ipset = ips.Addrs() - ordered, failingopen = reorderIPs(ipset, confirmed) - } - log.D("commondial: renew ips for %s; renewed? %t, failingopen? %t", raddr, renewed, failingopen) - } - log.D("commondial: trying all ips %d/%d %v for %s, failingopen? %t", - len(ordered), len(ipset), ordered, raddr, failingopen) - for _, ip := range ordered { - end := time.Since(start) - if end > dialRetryTimeout { - errs = core.JoinErr(errs, errRetryTimeout) - log.D("commondial: timeout %s for %s", end, raddr) - break - } - if ipok(ip) { - remote := netip.AddrPortFrom(ip, uint16(port)) - conn, err = connect(d, network, local, remote) - // nilaway: tx.socks5 returns nil conn even if err == nil - if conn == nil { - err = core.OneErr(err, errNoConn) - } - if err == nil { - confirm(ips, ip) - log.I("commondial: ip %s works for %s", ip, remote) - return conn, nil - } - errs = core.JoinErr(errs, err) - logwd(err)("commondial: ip %s for %s failed; err %v", ip, remote, err) - } else { - log.W("commondial: ip %s not ok for %s", ip, raddr) - } - } - - if len(ipset) <= 0 { - errs = errNoIps - } - - return nil, errs -} - -func clos(c ...core.MinConn) { - core.CloseConn(c...) -} - -func confirm(ips *ipmap.IPSet, ip netip.Addr) { - if ips != nil && ipok(ip) { - ips.Confirm(ip) - } -} - -func ipok(ip netip.Addr) bool { - return ip.IsValid() && !ip.IsUnspecified() -} - -func logwd(err error) log.LogFn { - if err != nil { - return log.W - } - return log.D -} diff --git a/intra/dialers/direct_split.go b/intra/dialers/direct_split.go deleted file mode 100644 index 702124e5..00000000 --- a/intra/dialers/direct_split.go +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dialers - -import ( - "io" - "net" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -type splitter struct { - conn *net.TCPConn - strat int32 // settings.Split* constant - - used *core.SigCond // Signalled after the first write. -} - -var _ core.DuplexConn = (*splitter)(nil) -var _ core.RetrierConn = (*splitter)(nil) - -// Write implements core.DuplexConn. -func (s *splitter) Write(b []byte) (n int, err error) { - if s.used.Cond() { - // after the first write, there is no special write behavior. - return s.conn.Write(b) - } - if s.used.Signal() { // first writer splits - n, err = s.writeSplit(b) - return n, err - } - // if `used` is already swapped or set, then the split has already been done. - return s.conn.Write(b) -} - -func (s *splitter) writeSplit(b []byte) (n int, err error) { - w := s.conn - switch s.strat { - case settings.SplitTCP: - n, err = writeTCPSplit(w, b) - case settings.SplitTCPOrTLS: - n, err = writeTCPOrTLSSplit(w, b) - default: - log.W("split: unknown dial strategy: %d", s.strat) - n, err = w.Write(b) - } - return -} - -// ReadFrom reads from reader and writes to s. -// Usually, ReadFrom is executing the "upload" phase of egressing conn (r). -func (s *splitter) ReadFrom(reader io.Reader) (bytes int64, err error) { - start := time.Now() - if !s.used.Cond() { - // This is the first write on this socket. - // Use copyOnce(), which calls Write(), to get Write's splitting behavior for - // the first segment. - if bytes, err = copyOnce(s, reader); err != nil { - return - } - } - // wait for first write (split) to complete if concurrent - s.used.Wait() - elapsed := time.Since(start) - - var b int64 - b, err = s.conn.ReadFrom(reader) - bytes += b - - logeif(err)("split: readfrom: done %s<=%s; sz: %d; dur: %s, wait: %s; err: %v", - laddr(s.conn), raddr(s.conn), bytes, core.FmtTimeAsPeriod(start), core.FmtPeriod(elapsed), err) - return -} - -// WriteTo reads from s and writes to w. -// Usually, WriteTo is executing the "download" phase of egressing conn (w). -func (s *splitter) WriteTo(w io.Writer) (bytes int64, err error) { - start := time.Now() - waited := s.used.TryWait(uploadTimeoutForDownload) - elapsed := time.Since(start) - - bytes, err = s.conn.WriteTo(w) - - logeif(err)("split: writeto: done %s=>%s; sz: %d; dur: %s, wait: %s (%t); err: %v", - laddr(s.conn), raddr(s.conn), bytes, core.FmtTimeAsPeriod(start), core.FmtPeriod(elapsed), waited, err) - return -} - -// Read implements core.DuplexConn. -func (s *splitter) Read(b []byte) (int, error) { return s.conn.Read(b) } - -// LocalAddr implements core.DuplexConn. -func (s *splitter) LocalAddr() net.Addr { return laddr(s.conn) } - -// RemoteAddr implements core.DuplexConn. -func (s *splitter) RemoteAddr() net.Addr { return raddr(s.conn) } - -func (s *splitter) SetDeadline(t time.Time) error { - if c := s.conn; c != nil { - return c.SetDeadline(t) - } - return nil // no-op -} - -// SetReadDeadline implements core.DuplexConn. -func (s *splitter) SetReadDeadline(t time.Time) error { - if c := s.conn; c != nil { - return c.SetReadDeadline(t) - } - return nil // no-op -} - -// SetWriteDeadline implements core.DuplexConn. -func (s *splitter) SetWriteDeadline(t time.Time) error { - if c := s.conn; c != nil { - return c.SetWriteDeadline(t) - } - return nil // no-op -} - -// Close implements core.DuplexConn. -func (s *splitter) Close() error { core.CloseTCP(s.conn); return nil } - -// CloseRead implements core.DuplexConn. -func (s *splitter) CloseRead() error { core.CloseTCPRead(s.conn); return nil } - -// CloseWrite implements core.DuplexConn. -func (s *splitter) CloseWrite() error { core.CloseTCPWrite(s.conn); return nil } - -// SyscallConn implements core.DuplexConn. -func (s *splitter) SyscallConn() (syscall.RawConn, error) { - if c := s.conn; c != nil { - return c.SyscallConn() - } - return nil, syscall.EINVAL -} - -// SetKeepAlive implements core.DuplexConn. -func (s *splitter) SetKeepAlive(y bool) error { - if c := s.conn; c != nil { - return c.SetKeepAlive(y) - } - return nil // no-op -} diff --git a/intra/dialers/dns.go b/intra/dialers/dns.go deleted file mode 100644 index 6bdd6e02..00000000 --- a/intra/dialers/dns.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "context" - "net/netip" - "net/url" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -// ResolveFor resolves nom to IPs using transport designated for given uid. -func ResolveFor(nom string, uid string) ([]netip.Addr, error) { - // ipm.LookupNetIP itself has a short-term cache (ipmapper.go:battl) - // and since TIDs are specified, the ipmap cache is not used. - return ipm.LookupNetIPFor(context.Background(), "ip", nom, uid) -} - -// Resolve resolves hostname to IP addresses, bypassing cache. -// If resolution fails, entries from the cache are returned, if any. -func Resolve(hostname string, tids ...string) (addrs []netip.Addr, err error) { - ctx := context.Background() - // both lookups may return addrs = nil, err = nil - // (see: ipmapper.go:queryIP2 and protect.NeverResolve) - // ipm.LookupNetIPxxx itself has a short-term cache (ipmapper.go:battl) - addrs, err = ipm.LookupNetIPOn(ctx, "ip", hostname, tids...) - - if len(addrs) <= 0 { // check cache - if addrs = CachedAddrs(hostname); len(addrs) > 0 { - return addrs, nil - } // else: no cached addrs - // even if ipmapper lookups return no addrs, raw ipset - // may have seed addrs; which when empty, error out. - err = core.OneErr(err, errNoIps) - } - return addrs, err -} - -func ResolveForUrl(s string) []netip.Addr { - u, err := url.Parse(s) // works if s is mere hostname; ex: example.com - if err != nil { - return For(s) // fallback on hostOrIP - } - return For(u.Hostname()) -} - -// SampleHosts returns a slice of random hosts, of size n, for the given ipver. -// ipver is one of "v4", "v6", or "" (for both). -func SampleHosts(n uint8, ipver string) []string { - return ipm.ReverseGetMany(n, ipver) -} - -// SampleIPs returns a slice of random IPs, of size n, for the given ipver. -// ipver is one of "v4", "v6", or "" (for both). -func SampleIPs(n uint8, ipver string) []netip.Addr { - return ipm.GetMany(n, ipver) -} - -// ECH returns the ECH config, if any, for the given hostname. -// The query is resolved using IPMapper's default resolver. -func ECH(hostname string) ([]byte, error) { - q, err := xdns.Question(hostname, dns.TypeHTTPS) - if err != nil { - return nil, err - } - res, err := ipm.LocalLookup(q) - if err != nil { - return nil, err - } - ans := &dns.Msg{} - if err = ans.Unpack(res); err != nil { - return nil, err - } - for _, a := range ans.Answer { - if rr, ok := a.(*dns.HTTPS); ok { - for i, kv := range rr.Value { - if kv.Key() == dns.SVCB_ECHCONFIG { - if v, ok := rr.Value[i].(*dns.SVCBECHConfig); ok { - return v.ECH, nil - } // else: unlikely - } // else: not ech config - } // done iter https rr - } // else: not https rr - } // done iter answers - return nil, errNoEch -} - -// Query sends a DNS query to the Default DNS and -// returns the answer. -func Query(msg *dns.Msg, tids ...string) (*dns.Msg, error) { - q, err := msg.Pack() - if err != nil { - return nil, err - } - - r, err := ipm.Lookup(q, protect.UidSelf, tids...) - if err != nil { - return nil, err - } - - ans := &dns.Msg{} - if err = ans.Unpack(r); err != nil { - return nil, err - } - return ans, nil -} - -// QueryFor forward a DNS request for uid (if set) -// or to chosen transport, tid (if uid is not set). -func QueryFor(msg *dns.Msg, uid, tid string) (*dns.Msg, error) { - q, qerr := msg.Pack() - if qerr != nil { - return nil, qerr - } - - // uid may be core.UNKNOWN_UID_STR - r, rerr := ipm.Lookup(q, uid, tid) - - if rerr != nil { - return nil, rerr - } - - ans := &dns.Msg{} - if aerr := ans.Unpack(r); aerr != nil { - return nil, aerr - } - return ans, nil -} diff --git a/intra/dialers/example/main.go b/intra/dialers/example/main.go deleted file mode 100644 index 762e2159..00000000 --- a/intra/dialers/example/main.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2020 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "crypto/tls" - "flag" - "fmt" - "log" - "net" - "os" - - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/protect" -) - -func main() { - flag.Usage = func() { - _, _ = fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [-sni=SNI] destination\n", os.Args[0]) - _, _ = fmt.Fprintln(flag.CommandLine.Output(), "This tool attempts a TLS connection to the "+ - "destination (port 443), with and without splitting. If the SNI is specified, it "+ - "overrides the destination, which can be an IP address.") - flag.PrintDefaults() - } - - sni := flag.String("sni", "", "Server name override") - flag.Parse() - destination := flag.Arg(0) - if destination == "" { - flag.Usage() - return - } - - addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(destination, "443")) - if err != nil { - log.Fatalf("Couldn't resolve destination: %v", err) - } - - if *sni == "" { - *sni = destination - } - tlsConfig := &tls.Config{ServerName: *sni, MinVersion: tls.VersionTLS12} - - log.Println("Trying direct connection") - conn, err := net.DialTCP(addr.Network(), nil, addr) - if err != nil { - log.Fatalf("Could not establish a TCP connection: %v", err) - } - tlsConn := tls.Client(conn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - log.Printf("Direct TLS handshake failed: %v", err) - } else { - log.Printf("Direct TLS succeeded") - } - - log.Println("Trying split connection") - d := protect.MakeNsRDial("test", context.TODO(), nil) - splitConn, err := dialers.SplitDial(d, "tcp", addr.String()) - if err != nil { - log.Fatalf("Could not establish a splitting socket: %v", err) - } - tlsConn2 := tls.Client(splitConn, tlsConfig) - err = tlsConn2.Handshake() - if err != nil { - log.Printf("Split TLS handshake failed: %v", err) - } else { - log.Printf("Split TLS succeeded") - } -} diff --git a/intra/dialers/ips.go b/intra/dialers/ips.go deleted file mode 100644 index bf82e21f..00000000 --- a/intra/dialers/ips.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "net" - "net/netip" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/protect/ipmap" -) - -const ( - errNilConn = net.UnknownNetworkError("nil connection") - errNoConn = net.UnknownNetworkError("no connection") - errNoSysConn = net.UnknownNetworkError("no sys connection") - errNoDesyncConn = net.UnknownNetworkError("no desync connection") - errTLSHandshake = net.UnknownNetworkError("tls handshake may be failed") - errNoIps = net.UnknownNetworkError("no ips") - errNoEch = net.UnknownNetworkError("no ech") - errNoDialer = net.UnknownNetworkError("no dialer") - errNoRetrier = net.UnknownNetworkError("no retrier") - errNoListener = net.UnknownNetworkError("no listener") -) - -var ipm ipmap.IPMap = ipmap.NewIPMap() - -// Resolves hostOrIP, and re-seeds it if existing is non-empty. -// hostOrIP may be host:port, or ip:port, or host, or ip. -func renew(hostOrIP string, existing *ipmap.IPSet) (cur *ipmap.IPSet, ok bool) { - // will never be able to resolve protected hosts (UidSelf, UidRethink), - // and so, keep existing as-is (we do not want to use NewProtected and - // race against dnsx.RegisterAddrs or other clients updating UidSelf or - // UidRethink as changes come in from kotlinland intra.Bridge) - if protect.NeverResolve(hostOrIP) { - cur = existing.Reset() - } else if existing.Protected() { - // if protected, preserve seed addrs; then resolve hostOrIP - NewProtected(hostOrIP, existing.Seed()) - cur = ipm.Add(hostOrIP) - // fallthrough - } else if existing.Empty() { - // if empty, discard seed, re-resolve hostOrIP; oft times, ipset is - // empty when its ips have been disconfirmed beyond some threshold - cur = ipm.Add(hostOrIP) - if cur.Empty() { - // if still empty, fallback on seed addrs - cur, _ = New(hostOrIP, existing.Seed()) - } // else: fallthrough - } else { - // if non-empty, renew hostOrIP with seed addrs - // existing may be of typ IPAddr, in which case - // existing.Seed() would be empty, and hostOrIP - // should be a valid IP or IP:Port. - New(hostOrIP, existing.Seed()) - cur = ipm.Add(hostOrIP) - } - if cur == nil { // can never happen as Add/New/NewProtected return a non-nil ipset - return nil, false - } - return cur, !cur.Empty() -} - -// New re-seeds hostOrIP with a new set of ips. -// hostOrIP may be host:port, or ip:port, or host, or ip. -// ipps may be ip or ip:port. -func New(hostOrIP string, ipps []string) (*ipmap.IPSet, bool) { - ips := ipm.MakeIPSet(hostOrIP, ipps, ipmap.AutoType) - return ips, !ips.Empty() -} - -// hostOrIP may be host:port, or ip:port, or host, or ip. -func NewProtected(hostOrIP string, ipps []string) (*ipmap.IPSet, bool) { - ips := ipm.MakeIPSet(hostOrIP, ipps, ipmap.Protected) - return ips, !ips.Empty() -} - -// For returns addresses for hostOrIP from cache, resolving them if missing. -// Underlying cache relies on Disconfirm() to remove unreachable IP addrs; -// if not called, these entries may go stale. Use Resolve() to bypass cache. -// hostOrIP may be host:port, or ip:port, or host, or ip. -func For(hostOrIP string) []netip.Addr { - ipset := ipm.Get(hostOrIP) - if ipset != nil { - return ipset.Addrs() - } - return nil -} - -// Ptr returns hostnames from the ipmap cache, given an IP address. -func Ptr(ip netip.Addr) []string { - return ipm.ReverseGet(ip) -} - -func Confirmed(hostOrIP string) (zz netip.Addr) { - if ipset := ipm.GetAny(hostOrIP); ipset != nil { - return ipset.Confirmed() - } - return -} - -// CachedAddrs returns addresses for hostOrIP from cache. Use Resolve() to bypass cache. -func CachedAddrs(hostOrIP string) []netip.Addr { - ipset := ipm.GetAny(hostOrIP) - if ipset != nil || !ipset.Empty() { - return ipset.Addrs() - } - return nil -} - -// Mapper is a hostname to IP (a/aaaa) resolver for the network engine; may be nil. -func Mapper(m ipmap.IPMapper) { - log.I("dialers: ips: mapper ok? %t", m != nil) - // usually set once per tunnel disconnect/reconnect - ipm.With(m) -} - -func Clear() { - // do not need to handle panics w/ core.Recover - ipm.Clear() // does not clear UidSelf, UidSystem (protected) -} - -// Confirm3 marks addr as preferred for hostOrIP -func Confirm3(hostOrIP string, addr net.Addr) bool { - return Confirm2(hostOrIP, addr.String()) -} - -func Confirm(hostOrIP string, addr netip.Addr) bool { - if ipok(addr) { // confirms ONLY valid ips - ips := ipm.GetAny(hostOrIP) - ips.Confirm(addr) - return ips != nil - } - return false -} - -func Confirm2(hostOrIP string, addr string) bool { - return Confirm(hostOrIP, ipof(addr)) -} - -// Disconfirm3 unmarks addr as preferred for hostOrIP -func Disconfirm3(hostOrIP string, addr net.Addr) bool { - return Disconfirm2(hostOrIP, addr.String()) -} - -// Disconfirm unmarks addr as preferred for hostOrIP -func Disconfirm(hostOrIP string, addr netip.Addr) bool { - ips := ipm.GetAny(hostOrIP) - if ips != nil { - return ips.Disconfirm(addr) // disconfirms ANY ip (invalid/unspecified) - } // not ok - return false -} - -// Disconfirm2 unmarks addr as preferred for hostOrIP -func Disconfirm2(hostOrIP string, addr string) bool { - return Disconfirm(hostOrIP, ipof(addr)) -} - -func ipof(addr string) (zz netip.Addr) { - if ipp, err := netip.ParseAddrPort(addr); err == nil { - return ipp.Addr() - } else if ip, err := netip.ParseAddr(addr); err == nil { - return ip - } - return -} diff --git a/intra/dialers/link.go b/intra/dialers/link.go deleted file mode 100644 index 04af2fb3..00000000 --- a/intra/dialers/link.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -var ipProto *core.Volatile[string] = core.NewVolatile(settings.IP46) - -func Use4() bool { - d := true // by default, use4 - switch x := ipProto.Load(); x { - case settings.IP6: - return false - case settings.IP4: - fallthrough - case settings.IP46: - return true - default: - log.W("dialers: use4: invalid protos %s; default: %s", x, d) - return d - } -} - -func Use6() bool { - d := false // by default, use4 instead - switch x := ipProto.Load(); x { - case settings.IP4: - return false - case settings.IP6: - fallthrough - case settings.IP46: - return true - default: - log.W("dialers: use6: invalid protos %s; default: %s", x, d) - return d - } -} - -// p must be one of settings.IP4, settings.IP6, or settings.IP46 -func IPProtos(ippro string) (diff bool) { - switch ippro { - case settings.IP4: - fallthrough - case settings.IP6: - fallthrough - case settings.IP46: - diff = ipProto.Swap(ippro) != ippro - default: - log.D("dialers: ips: invalid protos %s; use existing: %s", ippro, ipProto.Load()) - return - } - log.I("dialers: ips: protos set to %s; diff? %t", ippro, diff) - return -} diff --git a/intra/dialers/op.go b/intra/dialers/op.go deleted file mode 100644 index f1005e50..00000000 --- a/intra/dialers/op.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "encoding/binary" - "io" - "math/rand" - "net" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -// Copy one buffer from src to dst, using dst.Write. -func copyOnce(dst io.Writer, src io.Reader) (int64, error) { - // A buffer large enough to hold any ordinary first write - // without introducing extra splitting. - bptr := core.Alloc16() - buf := *bptr - buf = buf[:cap(buf)] - defer func() { - *bptr = buf - core.Recycle(bptr) - }() - - n, err := src.Read(buf) // src: netstack; downstream conn - - if err != nil { - log.W("op: copyOnce: read %d/%d; err %v", n, len(buf), err) - return 0, err - } - - wn, err := dst.Write(buf[:n]) // dst: retrier; upstream conn - - logeif(err)("op: copyOnce: rw %d/%d; err %v", n, wn, err) - return int64(n), err -} - -func getTLSClientHelloRecordLen(h []byte) (uint16, bool) { - if len(h) < 5 { - return 0, false - } - - const ( - TYPE_HANDSHAKE byte = 22 - VERSION_TLS10 uint16 = 0x0301 - VERSION_TLS11 uint16 = 0x0302 - VERSION_TLS12 uint16 = 0x0303 - VERSION_TLS13 uint16 = 0x0304 - ) - - if h[0] != TYPE_HANDSHAKE { - return 0, false - } - - ver := binary.BigEndian.Uint16(h[1:3]) - if ver != VERSION_TLS10 && ver != VERSION_TLS11 && - ver != VERSION_TLS12 && ver != VERSION_TLS13 { - return 0, false - } - - return binary.BigEndian.Uint16(h[3:5]), true -} - -func writeTCPSplit(w net.Conn, hello []byte) (n int, err error) { - var p, q int - to := raddr(w) - from := laddr(w) - - first, second := splitHello(hello) - - if p, err = w.Write(first); err != nil { - log.E("op: splits: TCP1 %s (%d): err %v", to, len(first), err) - return p, err - } else if q, err = w.Write(second); err != nil { - log.E("op: splits: TCP2 %s (%d): err %v", to, len(second), err) - return p + q, err - } - if settings.Debug { - log.D("op: splits: %s=>%s; TCP: %d/%d,%d/%d", from, to, p, len(first), q, len(second)) - } - - return p + q, nil -} - -// upb-syssec.github.io/blog/2023/record-fragmentation/ -// from: github.com/Jigsaw-Code/Intra/blob/27637e0ed497/Android/app/src/go/intra/split/retrier.go#L245 -func writeTCPOrTLSSplit(w net.Conn, hello []byte) (n int, err error) { - to := raddr(w) - from := laddr(w) - - if len(hello) <= 1 { - n, err = w.Write(hello) - if settings.Debug { - log.D("op: splits: %s=>%s; len(hello) <= 1; n: %d; err: %v", from, to, n, err) - } - return - } - - const ( - MIN_SPLIT int = 6 - MAX_SPLIT int = 64 - ) - - // random number in the range [MIN_SPLIT, MAX_SPLIT] - // splitLen includes 5 bytes of TLS header - splitLen := MIN_SPLIT + rand.Intn(MAX_SPLIT+1-MIN_SPLIT) - limit := len(hello) / 2 - if splitLen > limit { - splitLen = limit - } - - recordLen, ok := getTLSClientHelloRecordLen(hello) - recordSplit1Len := splitLen - 5 - recordSplit2Len := recordLen - uint16(recordSplit1Len) - if !ok || recordSplit1Len <= 0 || recordSplit1Len >= int(recordLen) { - // TCP split if hello is not a valid TLS Client Hello, or cannot be fragmented - return writeTCPSplit(w, hello) - } - - bptr := core.AllocRegion(len(hello)) - parcel := *bptr - parcel = parcel[:cap(parcel)] - defer func() { - *bptr = parcel - core.Recycle(bptr) - }() - // TLS record layout: - // +-------------+ 0 - // | RecordType | - // +-------------+ 1 - // | Protocol | - // | Version | - // +-------------+ 3 - // | Record | - // | Length | - // +-------------+ 5 - // | Message | - // | Data | - // +-------------+ Message Length + 5 - // - // RecordType := invalid(0) | handshake(22) | application_data(23) | ... - // LegacyRecordVersion := 0x0301 ("TLS 1.0") | 0x0302 ("TLS 1.1") | 0x0303 ("TLS 1.2") - // 0 < Message Length (of handshake) โ‰ค 2^14 - // 0 โ‰ค Message Length (of application_data) โ‰ค 2^14 - // - // datatracker.ietf.org/doc/html/rfc8446#section-5.1 - // see: github.com/Jigsaw-Code/outline-sdk/blob/19f51846/transport/tlsfrag/tls.go#L24 - - // do not modify hello in-place as it "updates" the underlying buffer - // (go.dev/play/p/CffJ3XziU5u) which breaks the io.Writer.Write contract. - - // 1. copy the split which includes the record header - // 2. write len(message data) of this split from [3:5] - p := copy(parcel, hello[:splitLen]) - binary.BigEndian.PutUint16(parcel[3:5], uint16(recordSplit1Len)) - n, err = w.Write(parcel[:p]) - if err != nil { - log.E("op: Splits: %s=>%s; TLS1 %d/%d; n: %d; err: %v", from, to, splitLen, len(hello), n, err) - return - } - - // 3. copy the rest of the message data + trailing space for the 5 byte record header - // 4. write the original record header from [0:5] - // 5. write len(message data) of this split from [3:5] - q := copy(parcel, hello[splitLen-5:]) - aux := copy(parcel, hello[:5]) // repeated - binary.BigEndian.PutUint16(parcel[3:5], recordSplit2Len) - m, err := w.Write(parcel[:q]) - // discount repeated 5-byte header from total bytes - n += max(m-aux, 0) - - logeif(err)("op: splits: %s=>%s; TLS2 %d/%d; n: %d, m: %d; err: %v", - from, to, splitLen, len(hello), n, m, err) - // if n > len(hello); return len(hello) to avoid confusion with the callers - // that expect bytes written to be equal to the length of the input buffer. - // splits: [:f29]:55476=>[:f5e]:443; TLS2 51/2048; n: 2053; err: - // F c.upload: [11] runtime error: slice bounds out of range [:2053] with capacity 2048 - // from: dialers.(*retrier).sendCopyHello - return min(n, len(hello)), err -} - -// splitHello splits the TLS client hello message into two. -func splitHello(hello []byte) ([]byte, []byte) { - if len(hello) == 0 { - return hello, hello - } - const ( - min int = 32 - max int = 64 - ) - - // Random number in the range [MIN_SPLIT, MAX_SPLIT] - s := min + rand.Intn(max+1-min) - limit := len(hello) / 2 - if s > limit { - s = limit - } - return hello[:s], hello[s:] -} - -// laddr returns the local address of the connection. -func laddr(c net.Conn) net.Addr { - if c != nil && core.IsNotNil(c) { - return c.LocalAddr() - } - return zeroNetAddr{} -} - -func raddr(c net.Conn) net.Addr { - if c != nil && core.IsNotNil(c) { - return c.RemoteAddr() - } - return zeroNetAddr{} -} - -func logeif(e error) log.LogFn { - if e != nil { - return log.E - } else { - return log.D - } -} - -func logeor(e error, d log.LogFn) log.LogFn { - if e != nil { - return log.E - } - return d -} - -func logedcond(x bool) log.LogFn { - if x { - return log.E - } else { - return log.D - } -} diff --git a/intra/dialers/pdial.go b/intra/dialers/pdial.go deleted file mode 100644 index 59820d84..00000000 --- a/intra/dialers/pdial.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "net" - "net/netip" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "golang.org/x/net/proxy" -) - -// todo: dial bound to the local address if specified -func proxyConnect(d *proxy.Dialer, proto string, local, remote netip.AddrPort) (net.Conn, error) { - if d == nil { // unlikely - log.E("pdial: proxyConnect: nil dialer") - return nil, errNoDialer - } else if !ipok(remote.Addr()) { - log.E("pdial: proxyConnect: invalid ip", remote) - return nil, errNoIps - } - - return (*d).Dial(proto, remote.String()) -} - -// ProxyDial tries to connect to addr using d -func ProxyDial(d proxy.Dialer, network, addr string) (net.Conn, error) { - if d == nil || core.IsNil(d) { - log.E("pdial: ProxyDial: nil dialer") - return nil, errNoDialer - } - return unPtr(commondial(&d, network, addr, adaptProxyDial(proxyConnect))) -} - -// ProxyDials tries to connect to addr using each dialer in dd -func ProxyDials(dd []proxy.Dialer, network, addr string) (c net.Conn, errs error) { - start := time.Now() - tot := len(dd) - for i, d := range dd { - if time.Since(start) > dialRetryTimeout { - errs = core.JoinErr(errs, errRetryTimeout) - break - } - conn, err := ProxyDial(d, network, addr) - if conn == nil && err == nil { - errs = core.JoinErr(errs, errNoConn) - } else if err != nil { - clos(conn) - log.W("pdial: trying %s dialer of %d / %d to %s", network, i, tot, addr) - errs = core.JoinErr(errs, err) - } else if conn != nil { - c = conn - errs = nil - return - } - } - log.W("pdial: no dialer (sz: %d) could connect to %s", tot, addr) - return nil, core.OneErr(errs, errNoDialer) -} diff --git a/intra/dialers/rdial.go b/intra/dialers/rdial.go deleted file mode 100644 index 242fd66b..00000000 --- a/intra/dialers/rdial.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "crypto/tls" - "errors" - "net" - "net/netip" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" -) - -func netConnect2(d *protect.RDialer, proto string, laddr, raddr netip.AddrPort) (net.Conn, error) { - if d == nil { - log.E("rdial: netConnect: nil dialer") - return nil, errNoDialer - } else if !ipok(raddr.Addr()) { - log.E("rdial: netConnect: invalid ip", raddr) - return nil, errNoIps - } - - if laddr.IsValid() { - return (*d).DialBind(proto, laddr.String(), raddr.String()) - } else { - return (*d).Dial(proto, raddr.String()) - } -} - -// ipConnect dials into ip:port using the provided dialer and returns a net.Conn -// net.Conn is guaranteed to be either net.UDPConn or net.TCPConn -func ipConnect(d *protect.RDial, proto string, laddr, raddr netip.AddrPort) (net.Conn, error) { - if d == nil { - log.E("rdial: ipConnect: nil dialer") - return nil, errNoDialer - } else if !ipok(raddr.Addr()) { - log.E("rdial: ipConnect: invalid ip", raddr) - return nil, errNoIps - } - - if laddr.IsValid() { - switch proto { - case "tcp", "tcp4", "tcp6": - return d.DialTCP(proto, net.TCPAddrFromAddrPort(laddr), net.TCPAddrFromAddrPort(raddr)) - case "udp", "udp4", "udp6": - return d.DialUDP(proto, net.UDPAddrFromAddrPort(laddr), net.UDPAddrFromAddrPort(raddr)) - default: - return d.DialBind(proto, laddr.String(), raddr.String()) - } - } else { - switch proto { - case "tcp", "tcp4", "tcp6": - return d.DialTCP(proto, nil, net.TCPAddrFromAddrPort(raddr)) - case "udp", "udp4", "udp6": - return d.DialUDP(proto, nil, net.UDPAddrFromAddrPort(raddr)) - default: - return d.Dial(proto, raddr.String()) - } - } -} - -func doSplit(ipp netip.AddrPort) bool { - ip := ipp.Addr() - port := ipp.Port() - // HTTPS or DoT - return !ip.IsPrivate() && (port == 443 || port == 853) -} - -func splitIpConnect(d *protect.RDial, proto string, laddr, raddr netip.AddrPort) (net.Conn, error) { - if d == nil { - log.E("rdial: splitIpConnect: nil dialer") - return nil, errNoDialer - } else if !ipok(raddr.Addr()) { - log.E("rdial: splitIpConnect: invalid ip", raddr) - return nil, errNoIps - } - - if laddr.IsValid() { - switch proto { - case "tcp", "tcp4", "tcp6": - remote := net.TCPAddrFromAddrPort(raddr) - local := net.TCPAddrFromAddrPort(laddr) - if doSplit(raddr) { - return DialWithSplitRetry(d, local, remote) - } - return d.DialTCP(proto, local, remote) - case "udp", "udp4", "udp6": - remote := net.UDPAddrFromAddrPort(raddr) - local := net.UDPAddrFromAddrPort(laddr) - return d.DialUDP(proto, local, remote) - default: - return d.DialBind(proto, laddr.String(), raddr.String()) - } - } else { - switch proto { - case "tcp", "tcp4", "tcp6": - tcpaddr := net.TCPAddrFromAddrPort(raddr) - if doSplit(raddr) { - return DialWithSplitRetry(d, nil, tcpaddr) - } - return d.DialTCP(proto, nil, tcpaddr) - case "udp", "udp4", "udp6": - return d.DialUDP(proto, nil, net.UDPAddrFromAddrPort(raddr)) - default: - return d.Dial(proto, raddr.String()) - } - } -} - -// ListenPacket listens on for UDP connections on the local address using d. -// Returned net.Conn is guaranteed to be a *net.UDPConn. -func ListenPacket(d *protect.RDial, network, local string) (net.PacketConn, error) { - if d == nil { - log.E("rdial: ListenPacket: nil dialer") - return nil, errNoListener - } - return d.AnnounceUDP(network, local) -} - -// Listen listens on for TCP connections on the local address using d. -func Listen(d *protect.RDial, network, local string) (net.Listener, error) { - if d == nil { - log.E("rdial: Listen: nil dialer") - return nil, errNoListener - } - return d.AcceptTCP(network, local) -} - -// Probe sends and accepts ICMP packets on local addr using d over a net.PacketConn. -func Probe(d *protect.RDial, network, local string) (net.PacketConn, error) { - // commondial does not handle unspecified ips well; see: ipmap.go & ipok() - // return unPtr(commondial(d, network, addr, adaptp(icmpListen))) - return d.ProbeICMP(network, local) -} - -// Dial dials into addr using the provided dialer and returns a net.Conn, -// which is guaranteed to be either net.UDPConn or net.TCPConn -func Dial(d *protect.RDial, network, addr string) (net.Conn, error) { - return unPtr(commondial(d, network, addr, adaptRDial(ipConnect))) -} - -// SplitDial dials into addr splitting the first segment to two if the -// first connection is unsuccessful, using settings.DialStrategy. -// Returns a net.Conn, which may not be net.UDPConn or net.TCPConn. -func SplitDial(d *protect.RDial, network, addr string) (net.Conn, error) { - return unPtr(commondial(d, network, addr, adaptRDial(splitIpConnect))) -} - -func DialBind(d *protect.RDial, network, local, remote string) (net.Conn, error) { - return unPtr(commondial2(d, network, local, remote, adaptRDial(ipConnect))) -} - -func SplitDialBind(d *protect.RDial, network, local, remote string) (net.Conn, error) { - return unPtr(commondial2(d, network, local, remote, adaptRDial(splitIpConnect))) -} - -// DialWithTls dials into addr using the provided dialer and returns a tls.Conn -func DialWithTls(d protect.RDialer, cfg *tls.Config, network, addr string) (net.Conn, error) { - return dialtls(&d, cfg, network, "", addr, adaptRDialer(netConnect2)) -} - -// DialWithTls dials into addr using the provided dialer and returns a tls.Conn -func DialBindWithTls(d protect.RDialer, cfg *tls.Config, network, local, remote string) (net.Conn, error) { - return dialtls(&d, cfg, network, local, remote, adaptRDialer(netConnect2)) -} - -func dialtls[D rdials](d D, cfg *tls.Config, network, local, remote string, how dialFn[D, *net.Conn]) (net.Conn, error) { - c, err := unPtr(commondial2(d, network, local, remote, how)) - if err != nil { - clos(c) - return nil, err - } - - tlsconn, err := tlsHello(c, cfg, remote) - - if eerr := new(tls.ECHRejectionError); errors.As(err, &eerr) { - clos(tlsconn) - - manual := false - ech := eerr.RetryConfigList - if len(ech) <= 0 { - ech, _ = ECH(cfg.ServerName) - manual = true - } - log.I("rdial: tls: ech rejected; new? %d / manual? %t, err: %v", - len(ech), manual, eerr) - if len(ech) > 0 { // retry with new ech - cfg.EncryptedClientHelloConfigList = ech - c, err = unPtr(commondial2(d, network, local, remote, how)) - if err != nil { - clos(c) - return nil, err - } - tlsconn, err = tlsHello(c, cfg, remote) - } - } - if err != nil { - clos(tlsconn) - tlsconn = nil - } - return tlsconn, err -} - -func tlsHello(c net.Conn, cfg *tls.Config, addr string) (*tls.Conn, error) { - if c == nil || core.IsNil(c) { - return nil, errNilConn - } - switch c := c.(type) { - case *tls.Conn: - return c, nil - } - - tlsconn := tls.Client(c, ensureSni(cfg, addr)) - err := tlsconn.Handshake() - - if err != nil { - clos(tlsconn) - } - return tlsconn, err -} - -func ensureSni(cfg *tls.Config, addr string) *tls.Config { - if cfg == nil { - cfg = &tls.Config{ - ServerName: sni(addr), - MinVersion: tls.VersionTLS12, - } - } else if len(cfg.ServerName) <= 0 { - cfg.ServerName = sni(addr) - } - return cfg -} - -func sni(addr string) string { - host, _, err := net.SplitHostPort(addr) - if err != nil { - log.W("rdial: sni %s, err: %v", addr, err) - host = addr // may be ip - } - return host -} diff --git a/intra/dialers/retrier.go b/intra/dialers/retrier.go deleted file mode 100644 index 7ae3061c..00000000 --- a/intra/dialers/retrier.go +++ /dev/null @@ -1,928 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dialers - -import ( - "context" - "fmt" - "io" - "math" - "net" - "net/netip" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" -) - -type zeroNetAddr struct{} - -func (zeroNetAddr) Network() string { return "no" } -func (zeroNetAddr) String() string { return "none" } - -const ( - maxRetryCount = 3 - maxEmptyReads = 3 - - minExpectedTLSRead = 16 // minimum bytes expected in a TLS read - tlsPort = 443 - - // maximum timeout in seconds for reads to complete before retrying - cielRetryReadTimeoutSec = 9 - // minimum timeout in milliseconds for reads to complete before retrying - floorRetryReadTimeoutMillis = 1000 - // max timeout (for ReadFrom, before WriteTo) for desync to complete - uploadTimeoutForDownload = 3 * time.Second -) - -// ippPins maintains a limited-time mapping between ip:port addresses and dialer IDs. -// TODO: invalidate cache on network changes. -// TODO: with context.TODO, expmap's reaper goroutine will leak. -var ippPins = core.NewSieve[netip.AddrPort, string](context.TODO(), desync_cache_ttl) - -// retrier implements the DuplexConn interface and must -// be typecastable to *net.TCPConn (see: xdial.DialTCP) -// inheritance: go.dev/play/p/mMiQgXsPM7Y -type retrier struct { - dialers []protect.RDialer - dialerOpts settings.DialerOpts - nextDialerIdx int - currentStrat int32 - multidial bool - - rport uint16 // raddr port - raddr net.Addr - laddr net.Addr // laddr may be nil; TCPAddr.IP may be nil. - - // Flags indicating whether the caller has called CloseRead and CloseWrite. - readDone atomic.Bool - writeDone atomic.Bool - - // mu is a lock that guards conn, retryCount, tee, timeout, - // retryErr, retryDone, readDeadline, and writeDeadline. - // After retryDoneCh is closed, these values will not be - // modified again so locking is no longer required for reads. - mu sync.Mutex - - // the current underlying connection. It is only modified by the reader - // thread, so the reader functions may access it without acquiring a lock. - // nb: if embedding TCPConn; override its WriteTo instead of just ReadFrom - // as io.Copy prefers WriteTo over ReadFrom; or use core.Pipe - conn protect.Conn - - // External read and write deadlines. These need to be stored here so that - // they can be re-applied in the event of a retry. - readDeadline time.Time - writeDeadline time.Time - // Time to wait between the 1st write & the 1st read before triggering a retry. - timeout time.Duration - // tee is the contents written before the first read. It is initially empty, - // and is cleared when the first byte is received. - tee []byte - // retryWriteErr is set to the error from the last retry, if any. - retryWriteErr error - // tracks the number of retries attempted. - retryCount uint8 - // must be set to 1 or more, never 0. - maxRetries uint8 - // Flag indicating when retry is finished or unnecessary. - retryDone *core.SigCond -} - -var _ core.DuplexConn = (*retrier)(nil) -var _ core.RetrierConn = (*retrier)(nil) - -var _ core.DuplexConn = (*net.TCPConn)(nil) - -// retryCompleted returns true if the retry is complete or unnecessary. -func (r *retrier) retryCompleted() bool { return r.retryDone.Cond() } - -func (r *retrier) canRetry() bool { - return r.dialerOpts.Retry != settings.RetryNever -} - -// TODO: make sure "Auto" works as intended for netdev.vger.kernel.narkive.com and norskkalender.no -// Given rtt of a successful socket connection (SYN sent - SYNACK received), -// returns a timeout for replies to the first segment sent on this socket. -func calcTimeout(rtt time.Duration, spread uint16) time.Duration { - spread = max(1, spread) // avoid div by zero - ciel := time.Duration(max(1, (cielRetryReadTimeoutSec/spread))) * time.Second // ciel at least 1secs - floor := time.Duration(min(300, (floorRetryReadTimeoutMillis/spread))) * time.Millisecond // floor is at most 1secs - - // Lower values trigger an unnecessary retry that make connections slower or fail (like nytimes.com) - // However, overly long timeouts make retry slower. - return max(rtt*2, ciel) + min(2*rtt, floor) -} - -// DialWithSplitRetry returns a TCP connection that transparently retries by -// splitting the initial upstream segment if the socket closes without receiving a -// reply. Like net.Conn, it is intended for two-threaded use, with one thread calling -// Read and CloseRead, and another calling Write, ReadFrom, and CloseWrite. -// `dialer` will be used to establish the connection. -// `addr` is the destination. -func DialWithSplitRetry(d *protect.RDial, laddr, raddr *net.TCPAddr) (*retrier, error) { - r := &retrier{ - dialers: []protect.RDialer{d}, - dialerOpts: settings.GetDialerOpts(), - laddr: laddr, // may be nil - raddr: raddr, // must not be nil - maxRetries: maxRetryCount, - retryDone: core.NewSigCond(), - } - - r.mu.Lock() - defer r.mu.Unlock() - - if err := r.dialLocked(); err != nil { - return nil, err - } - return r, nil -} - -func dialerOptsForMultiDialers() settings.DialerOpts { - // see: dialStratLocked - return settings.DialerOpts{ - Strat: settings.SplitNever, - Retry: settings.RetryWithSplit, - } -} - -func reprioritize(ds []protect.RDialer, ipp netip.AddrPort) []protect.RDialer { - // reprioritize the dialers based on the IP:port pair - if !ipp.IsValid() { - return ds - } - if len(ds) <= 1 { - return ds - } - id, ok := ippPins.Get(ipp) - if !ok || len(id) <= 0 { - return ds - } - for i, d := range ds { - if d.ID() == id { - ds[i], ds[0] = ds[0], ds[i] - break - } - } - return ds -} - -func DialAny(ds []protect.RDialer, laddr, raddr net.Addr) (*retrier, error) { - if len(ds) <= 0 { - return nil, errNoDialer - } - - remote := asAddrPort(raddr) - r := &retrier{ - dialers: reprioritize(ds, remote), - dialerOpts: dialerOptsForMultiDialers(), - multidial: true, - maxRetries: uint8(len(ds)), - laddr: laddr, // may be nil - raddr: raddr, // must not be nil - rport: remote.Port(), - retryDone: core.NewSigCond(), - } - - r.mu.Lock() - defer r.mu.Unlock() - - if err := r.dialLocked(); err != nil { - return nil, err - } - return r, nil -} - -// SycallConn implements core.DuplexConn. -func (r *retrier) SyscallConn() (syscall.RawConn, error) { - r.mu.Lock() - c := r.conn - r.mu.Unlock() - if sc, ok := c.(syscall.Conn); ok { - return sc.SyscallConn() - } - log.W("retrier: not a syscall.Conn: %T", c) - return nil, syscall.EINVAL -} - -// SetKeepAlive implements core.DuplexConn. -func (r *retrier) SetKeepAlive(y bool) error { - r.mu.Lock() - c := r.conn - r.mu.Unlock() - if c, ok := c.(core.KeepAliveConn); ok { - return c.SetKeepAlive(y) - } - log.W("retrier: not a net.Conn: %T", c) - return syscall.EINVAL -} - -func (r *retrier) dialStratLocked() (strat int32, err error) { - if r.multidial { // multidialing retrier does not follow strategies - // see: dialerOptsForMultiDialers - return settings.SplitNever, nil - } - - auto := r.dialerOpts.Strat == settings.SplitAuto - retryStrat := r.dialerOpts.Retry - split := r.dialerOpts.Strat != settings.SplitNever - - switch retryStrat { - case settings.RetryNever: - if r.retryCount >= 1 { - err = errNoRetrier // retry not allowed - return - } - split = split && r.retryCount == 0 // split at 1st attempt - case settings.RetryWithSplit: - split = split && r.retryCount >= 1 // split after 1st attempt - case settings.RetryAfterSplit: - split = split && r.retryCount == 0 // split at 1st attempt - if auto { - // split at all attempts except the last - // (also see "auto" cond block below) - split = split && r.retryCount < r.maxRetries - } - } - - if !split { - strat = settings.SplitNever - } else if auto { - // auto always attempts TCP split first as TLS splits - // as not all TLS servers play nice with split TLS records. - attemptCycle := r.retryCount % r.maxRetries - switch retryStrat { - case settings.RetryNever: - // only one attempt allowed; neither retried nor split - strat = settings.SplitTCP - case settings.RetryWithSplit: // "auto" retry - // if retrying (retryCount > 0), always split - if attemptCycle == 1 { - strat = settings.SplitTCP - } else if attemptCycle == 2 { - strat = settings.SplitTCPOrTLS - } else { // split is either true or false - strat = settings.SplitDesync - } - case settings.RetryAfterSplit: - // split for the first two attempts - if attemptCycle == 0 { - strat = settings.SplitTCP - } else if attemptCycle == 1 { - strat = settings.SplitTCPOrTLS - } else { - // split is false when retryCount >= r.maxRetries, - // and so, the strat here does not matter - strat = settings.SplitTCP - } - } - } else { - strat = r.dialerOpts.Strat - } - - return -} - -func (r *retrier) dialerID() string { - di := 0 - if r.multidial { - di = min(max(di, r.nextDialerIdx-1), len(r.dialers)-1) - } - return r.dialers[di].ID() -} - -// dialLocked establishes a new connection to r.raddr and closes existing, if any. -// Sets r.conn on non-errors and timeout as calculated from round-trip time. -func (r *retrier) dialLocked() error { - clos(r.conn) // close existing connection, if any - - strat, err := r.dialStratLocked() - if err != nil { - return err - } - r.currentStrat = strat - - spreadTimeoutOver := int(r.maxRetries) - int(r.retryCount) - if nosplit := strat == settings.SplitNever; nosplit { - spreadTimeoutOver = 0 // no spread if no split - } - - begin := time.Now() - c, err := r.doDialLocked(strat) - rtt := time.Since(begin) - - if c != nil && core.IsNotNil(c) { // c may be deep nil - r.conn = c - } else { - r.conn = nil - } - - if r.canRetry() { - // final retry gets maximum possible timeout - r.timeout = calcTimeout(rtt, uint16(spreadTimeoutOver)) - } else { - // if retries are disabled, then do not aggressively timeout - // as there's nothing else for the retrier to do. - r.timeout = 0 - } - - logeif(err)("retrier: dial(%s) %s=>%s; strat: %d+%d (mult? %d %T), rtt: %s / to: %s; err? %v", - r.dialerID(), laddr(c), r.raddr, strat, r.dialerOpts.Retry, len(r.dialers), c, core.FmtPeriod(rtt), core.FmtPeriod(r.timeout), err) - - return err -} - -func (r *retrier) isSplitStrat() bool { - if r.multidial { // multidial do not follow strats - return false - } - return r.currentStrat == settings.SplitTCP || r.currentStrat == settings.SplitTCPOrTLS -} - -// dialStrat returns a core.DuplexConn to r.raddr using a specified strategy, strat, -// which is one of the settings.Split* constants. -func (r *retrier) doDialLocked(dialStrat int32) (protect.Conn, error) { - if r.multidial { - var errs error - if r.nextDialerIdx >= len(r.dialers) && r.retryCount < r.maxRetries { - r.nextDialerIdx = 0 - log.D("retrier: mult: %s: reset dialer index; retry # %d / %d", - r.dialerID(), r.retryCount, r.maxRetries) - } - for r.nextDialerIdx < len(r.dialers) { - d := r.dialers[r.nextDialerIdx] - c, err := protect.Dial(d, r.laddr, r.raddr) - logeif(err)("retrier: mult: #%d/%d dial(%s: %s) %s=>%s; err? %v", - r.nextDialerIdx, len(r.dialers), d.ID(), r.dialerOpts, laddr(c), r.raddr, err) - - r.nextDialerIdx++ // incr regardless of err - - if err == nil { - return c, nil - } - clos(c) - errs = core.JoinErr(errs, err) - } - return nil, core.OneErr(errs, errNoDialer) - } - - di := r.dialers[0] // always use the first dialer when not multidialing - - network := r.raddr.Network() - if isTCP := strings.HasPrefix(network, "tcp"); !isTCP { - return protect.Dial(di, r.laddr, r.raddr) - } - - // r.laddr may be nil or laddr.IP may be nil. - switch dialStrat { - case settings.SplitNever: - return protect.Dial(di, r.laddr, r.raddr) - case settings.SplitDesync: - return dialWithSplitAndDesync(di, r.laddr, r.raddr) - case settings.SplitTCP, settings.SplitTCPOrTLS: - fallthrough - default: - } - tc, terr := protect.DialTCP(di, network, r.laddr, r.raddr) - if terr != nil || tc == nil { - return nil, core.JoinErr(terr, errNilConn) - } - // todo: assert strat must be tcp or tls? - return &splitter{conn: tc, strat: dialStrat, used: core.NewSigCond()}, nil -} - -// retryWriteReadLocked closes the current connection, dials a new one, and writes -// the first segment after splitting according to specified dial strategy. -// Returns an error if the dial fails or if the splits could not be written. -func (r *retrier) retryWriteReadLocked(buf []byte) (int, error) { - // r.dialLocked also closes provisional socket - err := r.dialLocked() // errs on dial strat = no retries, too - newConn := r.conn - if err != nil || newConn == nil { - return 0, core.OneErr(err, errNoConn) - } - - var nw int - nw, r.retryWriteErr = newConn.Write(r.tee) - logeif(r.retryWriteErr)("retrier: retryLocked: strat(%s, mult? %d %T) %s=>%s; write? %d/%d; err? %v", - r.dialerID(), len(r.dialers), newConn, laddr(newConn), r.raddr, nw, len(r.tee), r.retryWriteErr) - if r.retryWriteErr != nil { - return 0, r.retryWriteErr - } - - // while we were creating the new socket, the caller might have called CloseRead - // or CloseWrite on the old socket. Copy that state to the new socket. - // CloseRead and CloseWrite are idempotent, so this is safe even if the user's - // action actually affected the new socket. - readdone := r.readDone.Load() - writedone := r.writeDone.Load() - if readdone { - core.CloseOp(newConn, core.CopR) - } - if writedone { - core.CloseOp(newConn, core.CopW) - } - - logedcond(readdone || writedone)("retrier: retryLocked: done! strat(%s; mult? %d %T) %s=>%s; write? %d/%d; closed r/w? %t/%t; rtt: %s, deadline r/w: %v/%v", - r.dialerID(), len(r.dialers), newConn, laddr(newConn), r.raddr, nw, len(r.tee), readdone, writedone, core.FmtPeriod(r.timeout), core.FmtTimeAsPeriod(r.readDeadline), core.FmtTimeAsPeriod(r.writeDeadline)) - - // all of buf was written to c - // require a response within a short timeout on r.conn (same as newConn) - newConn.SetReadDeadline(time.Now().Add(r.readTimeoutLocked())) - return newConn.Read(buf) -} - -// CloseRead closes r.conn for reads, and the read flag. -func (r *retrier) CloseRead() error { - r.readDone.Store(true) - r.mu.Lock() - defer r.mu.Unlock() - core.CloseOp(r.conn, core.CopR) - return nil -} - -// Read data from r.conn into buf ("download" from remote to local). -func (r *retrier) Read(buf []byte) (n int, err error) { - note := log.VV - redoForTls := false - - r.mu.Lock() - c := r.conn // r.conn may be provisional or final connection - r.mu.Unlock() - - if c != nil { - for reads := range maxEmptyReads { - n, err = c.Read(buf) - if n == 0 && err == nil { // no data and no error - note("retrier: read: %s: no data #%d; retrying [%s<=%s], b: 0/%d", - r.dialerID(), reads, laddr(c), r.raddr, len(buf)) - continue // nothing yet to retry; on to next read - } // else: check if retry is needed (c == nil or err != nil) - break - } - if n == 0 && err == nil { - err = io.ErrNoProgress - } - redoForTls = r.isSplitStrat() && r.rport == tlsPort && n < minExpectedTLSRead - if err == nil && redoForTls { - err = errTLSHandshake - } - logeor(err, note)("retrier: read: %s: [%s<=%s]; (rtt: %s / read: %s / redo? %t); b: %d/%d (tee: %d); err: %v", - r.dialerID(), laddr(c), r.raddr, core.FmtPeriod(r.timeout), core.FmtTimeAsPeriod(r.readDeadline), redoForTls, n, len(buf), len(r.tee), err) - } // else: needs retry as c == nil - - note = log.D - - // must enter this block at least once (even if c != nil) - // as it resets read timeout and teed write buffer - if !r.retryCompleted() { - r.mu.Lock() - defer r.mu.Unlock() - - if !r.retryCompleted() { - note = log.I - defer r.retryDone.Signal() // signal completion (success or not) - - retryReadErr := err - // retry on errs like timeouts or connection resets - for (c == nil || redoForTls || retryReadErr != nil) && (r.canRetry() && r.retryCount < r.maxRetries) { - r.retryCount++ - - n, retryReadErr = r.retryWriteReadLocked(buf) - - redoForTls = r.isSplitStrat() && r.rport == tlsPort && n < minExpectedTLSRead - if retryReadErr == nil && redoForTls { - err = errTLSHandshake - } - - c = r.conn // re-assign c to newConn, if any; may be nil - if c == nil || retryReadErr != nil { - retryReadErr = core.OneErr(retryReadErr, errNoConn) - err = core.JoinErr(err, retryReadErr) - } else { - retryReadErr = nil // break - err = nil // return no error - redoForTls = false - } - logeor(retryReadErr, note)("retrier: read: %s: #%d + (mult? %d %T / c: %d): [%s<=%s]; t: %s; redo? %t; b:%d/%d; err? %v", - r.dialerID(), r.retryCount, len(r.dialers), c, r.nextDialerIdx, laddr(c), r.raddr, core.FmtPeriod(r.timeout), redoForTls, n, len(buf), retryReadErr) - } - if c != nil { - // caller might have set read or write deadlines before the retry; - // if not, clear any deadlines set by the retrier - _ = c.SetReadDeadline(r.readDeadline) - _ = c.SetWriteDeadline(r.writeDeadline) - } - logeor(err, note)("retrier: read: %s: #%d + (mult? %d / %d) [%s<=%s]; rshortt: %s / rfullt: %s; b: %d/%d; err? %v", - r.dialerID(), r.retryCount, len(r.dialers), r.nextDialerIdx, laddr(c), r.raddr, core.FmtPeriod(r.timeout), core.FmtTimeAsPeriod(r.readDeadline), n, len(buf), err) - r.tee = nil // discard teed data - return - } - logeor(err, note)("retrier: read: %s already retried! (conn? %t) [%s<=%s]; t: %s; b: %d/%d; err? %v", - r.dialerID(), c != nil, laddr(c), r.raddr, core.FmtPeriod(r.timeout), n, len(buf), err) - } // else: just one read is enough; no retry needed - if c == nil { // retry completed but no conn - cerr := log.EE("retrier: read: %s: no conn! [<=%s]; t: %s; b: %d/%d", - r.dialerID(), r.raddr, core.FmtPeriod(r.timeout), n, len(buf)) - err = core.JoinErr(err, cerr, errNilConn) - } - return -} - -func (r *retrier) teedFirstWrite(b []byte) (n int, firstWrite, didAttemptWrite bool, readWait time.Duration, src net.Addr, err error) { - r.mu.Lock() - defer r.mu.Unlock() - - firstWrite = len(r.tee) <= 0 - - c := r.conn - if c == nil { - err = errNilConn - log.E("retrier: send: %s: tee [] => %s, no conn; sz(%d)", - r.dialerID(), r.raddr, len(b)) - return - } - - src = laddr(c) - if !r.retryCompleted() { // may be first write - _ = c.SetWriteDeadline(r.writeDeadline) - - n, err = c.Write(b) - - // capture first write, aka "hello" - r.tee = append(r.tee, b...) - didAttemptWrite = true - readWait = r.readTimeoutLocked() - // all of b was written to r.tee if not to c - // require a response or another write within a short timeout. - c.SetReadDeadline(time.Now().Add(readWait)) - } - - return -} - -func (r *retrier) readTimeoutLocked() time.Duration { - if r.timeout > 0 { - return r.timeout - } - if r.readDeadline.IsZero() { - // 2501h 59m 59s 25ms: a comfortably high duration in nanos - return math.MaxInt64 >> 10 - } - return time.Until(r.readDeadline) -} - -// Write data in b to retrier's underlying conn, r.conn ("upload" from local to remote). -func (r *retrier) Write(b []byte) (int, error) { - start := time.Now() - // Double-checked locking pattern. This avoids lock acquisition on - // every packet after retry completes, while also ensuring that r.tee is - // empty at steady-state. - if !r.retryCompleted() { - // todo: what if sentAndCopied is false and err != nil? - n, first, sentAndCopied, waitForRead, src, err := r.teedFirstWrite(b) - - note := log.D - if sentAndCopied { - note = log.I - } - - logeor(err, note)("retrier: write: %s: (first? %t, sent? %t) [%v=>%s]; t: %s; b: %d/%d (tee: %d); after: %s; write-err? %v", - r.dialerID(), first, sentAndCopied, src, r.raddr, core.FmtPeriod(r.timeout), n, len(b), len(r.tee), core.FmtTimeAsPeriod(start), err) - - if sentAndCopied { - // if Write() does not wait for <-retryDoneCh in absence of errors, - // it is possible that ReadFrom() => copyOnce() is called before retryDoneCh - // is closed, resulting in two Write() calls, and r.tee containing buffers - // the size of two Writes() - if err == nil { - return n, nil - } // write failed, wait for retry to complete - - start := time.Now() - // write error on the provisional socket should be handled - // by the retry procedure. Block until we have a final socket (which will - // already have replayed r.tee), and retry. - // ie, wait until first write is done on the final socket. - maxUntil := max(waitForRead, waitForRead*time.Duration(r.maxRetries)) - if r.multidial { - maxUntil = max(maxUntil, maxUntil*time.Duration(len(r.dialers))) - } - if !r.retryDone.TryWait(maxUntil) { // timed out waiting for retry completion - rerr := log.EE("retrier: write: %s: 1st write timed-out waiting for %s [calc-rtt: %s] 1st read b/w [%s=>%s], mult: %d, b: %d/%d, err: %v", - r.dialerID(), core.FmtPeriod(maxUntil), core.FmtPeriod(r.timeout), src, r.raddr, len(r.dialers), n, len(b), err) - return n, core.JoinErr(err, rerr, errRetryTimeout) - } - - r.mu.Lock() - defer r.mu.Unlock() - - // r.conn may be nil or closed by the time we get here - finalConn := r.conn - noconn := finalConn == nil || core.IsNil(finalConn) - if r.retryWriteErr != nil || noconn { // check if retried writes also failed - if noconn { - err = core.JoinErr(err, errNilConn) - } - werr := log.EE("retrier: write: %s: retry failed (conn? %t) [%s=>%s] b: %d/%d (tee: %d) in %s; old => new: %v => %v", - r.dialerID(), !noconn, laddr(r.conn), r.raddr, n, len(b), len(r.tee), core.FmtTimeAsPeriod(start), err, r.retryWriteErr) - return n, core.JoinErr(err, r.retryWriteErr, werr) // pass on the og error, too - } - - // if len(leftover) > 0 { - // m, err = newConn.Write(leftover) - // return n + m, err - // } - - // retry write succeeded, nil error - // ie, all of b was written to r.tee which was replayed - return len(b), nil - } // not sent by teedFirstWrite; do a normal write - } - - r.mu.Lock() - c := r.conn // retry has completed, so r.conn is final and may not need locking? - r.mu.Unlock() - if c == nil { - cerr := log.EE("retrier: write: %s: [] => %s (b: %d, tee: %d), not retrying, but no conn; after: %s", - r.dialerID(), r.raddr, len(b), len(r.tee), core.FmtTimeAsPeriod(start)) - return 0, core.JoinErr(cerr, errNilConn) - } - - n, err := c.Write(b) - if err != nil { - err = log.EE("retrier: write: %s: [%s=>%s]; b: %d/%d (retried? %t); after: %s; err? %v", - r.dialerID(), laddr(c), r.raddr, n, len(b), r.retryCompleted(), core.FmtTimeAsPeriod(start), err) - } - return n, err -} - -// WriteTo writes data to writer via r.conn.WriteTo, after (as needed) -// retries are done; before which reads are delegated to copyOnce. -// Usually, WriteTo is executing the "download" phase of egressing conn (w). -func (r *retrier) WriteTo(w io.Writer) (bytes int64, err error) { - start := time.Now() - copies := 0 - // TODO: skip copyOnce if r.multidial set or if strat is SplitNever? - for !r.retryCompleted() { // should iter only once - b, e := copyOnce(w, r) - copies++ - bytes += b - err = e - done := r.retryCompleted() - logeif(e)("retrier: writerto: %s: copyOnce #%d (done? %t) %s<=%s; sz: %d/%d; err: %v", - r.dialerID(), copies, done, laddr(r.conn), r.raddr, b, bytes, e) - if e != nil { - return bytes, e - } // TODO: return after first copyOnce if strat is RetryNever? - } - - wait := r.timeout - // _ = r.retryDone.TryWait(wait) // 0 when no need to wait for retry - - r.mu.Lock() // may block if retry in progress - c := r.conn - tee := len(r.tee) - r.mu.Unlock() - - logedcond(c == nil)("retrier: writerto: %s: (conn? %t) [%s <= %s], tee(%d), waited %s/%s", - r.dialerID(), c != nil, laddr(c), r.raddr, tee, core.FmtTimeAsPeriod(start), core.FmtPeriod(wait)) - - if c == nil { - return bytes, io.ErrUnexpectedEOF - } - - // check if writer is a Unix Domain Socket (splice) or tcp.WriteTo - // won't be optimized (as of go1.25). - _, canOptimizeWriteTo := w.(*net.UnixConn) - - optimizedWriteTo := canOptimizeWriteTo - var b int64 - if canOptimizeWriteTo { - switch x := c.(type) { - case *net.TCPConn: - r.SetDeadline(time.Time{}) - b, err = x.WriteTo(w) - bytes += b - case *splitter: - r.SetDeadline(time.Time{}) - b, err = x.WriteTo(w) - bytes += b - case io.WriterTo: - r.SetDeadline(time.Time{}) - b, err = x.WriteTo(w) - bytes += b - default: // net.UDPConn, net.PacketConn etc? - optimizedWriteTo = false - } - } - - if !optimizedWriteTo { - // write to w from c until EOF - b, err = core.Stream(w, c) - bytes += b - } - - msg := fmt.Sprintf("retrier: writerto: %s: (can? %t / optimized? %t for %T) %s<=%s done; sz: %d; after: %s; err: %v", - r.dialerID(), canOptimizeWriteTo, optimizedWriteTo, c, laddr(c), r.raddr, bytes, core.FmtTimeAsPeriod(start), err) - - if err != nil { - err = log.EE(msg) - } else { - log.V(msg) - } - - return bytes, err -} - -// ReadFrom reads data from reader via r.conn.ReadFrom, after (as needed) -// retries are done; before which reads are delegated to copyOnce. -// Usually, ReadFrom is executing the "upload" phase of egressing conn (r). -func (r *retrier) ReadFrom(reader io.Reader) (bytes int64, err error) { - start := time.Now() - copies := 0 - // TODO: skip copyOnce if r.multidial set or if strat is SplitNever? - for !r.retryCompleted() { // should iter only once - b, e := copyOnce(r, reader) - copies++ - bytes += b - err = e - done := r.retryCompleted() - logeif(e)("retrier: readfrom: %s: copyOnce #%d (done? %t) %s<=%s; sz: %d/%d; err: %v", - r.dialerID(), copies, done, laddr(r.conn), r.raddr, b, bytes, e) - if e != nil { - return bytes, e - } - // TODO: return after first copyOnce if strat is RetryNever? - } - - r.mu.Lock() - c := r.conn // reader thread does not need the mutex? - tee := len(r.tee) - r.mu.Unlock() - - if c == nil { - log.E("retrier: readfrom: %s: [] <= %s, no conn; after# %d: sz(%d) tee(%d); after %s", - r.dialerID(), r.raddr, copies, bytes, tee, core.FmtTimeAsPeriod(start)) - return bytes, io.ErrUnexpectedEOF - } - - pinned := false - pinnedID := "" - if r.multidial { - if ipp := asAddrPort(r.raddr); ipp.IsValid() { - // cache the dialer ID for the IP:port pair - pinnedID = r.dialerID() - ippPins.Put(ipp, pinnedID) - pinned = true - } - } - - // check if reader is SyscallConn (sendfile) or *net.TCPConn (splice) - // as tcp.ReadFrom otherwise won't be optimized but rather use built-in - // io.Copy which doesn't use a buffer pool unlike core.Stream (go1.25). - _, canOptimizeReadFrom := reader.(*net.TCPConn) - if !canOptimizeReadFrom { - if sc, ok := reader.(syscall.Conn); ok { - _, err := sc.SyscallConn() - canOptimizeReadFrom = err == nil - } - } - - // disable read and write deadlines as io.ReaderFrom does not - // rely on io.Read and io.Write semantics for "r.conn" from which - // deadlines are extended to avoid timeouts (see also: rwconn.go) - optimizedReadFrom := canOptimizeReadFrom - var b int64 - if canOptimizeReadFrom { - switch x := c.(type) { - case *net.TCPConn: - r.SetDeadline(time.Time{}) - b, err = x.ReadFrom(reader) - bytes += b - case *splitter: - r.SetDeadline(time.Time{}) - b, err = x.ReadFrom(reader) - bytes += b - case io.ReaderFrom: - r.SetDeadline(time.Time{}) - b, err = x.ReadFrom(reader) - bytes += b - default: // net.UDPConn, net.PacketConn etc? - optimizedReadFrom = false - } - } - - if !optimizedReadFrom { - // read from reader into c until EOF - b, err = core.Stream(c, reader) - bytes += b - } - - msg := fmt.Sprintf("retrier: readfrom: %s: (can? %t / optimized? %t for %T) done (id:%s/pinned?%t) %s<=%s; sz: %d; after: %s; err: %v", - r.dialerID(), canOptimizeReadFrom, optimizedReadFrom, c, pinnedID, pinned, laddr(c), r.raddr, bytes, core.FmtTimeAsPeriod(start), err) - - if err != nil { - err = log.EE(msg) - } else { - log.V(msg) - } - return -} - -// CloseWrite closes r.conn for writes, the write flag. -func (r *retrier) CloseWrite() error { - r.writeDone.Store(true) - r.mu.Lock() - defer r.mu.Unlock() - core.CloseOp(r.conn, core.CopW) - return nil -} - -// Close closes the connection and the read and write flags. -func (r *retrier) Close() error { - // also close the read and write flags - return core.JoinErr(r.CloseRead(), r.CloseWrite()) -} - -// LocalAddr behaves slightly strangely: its value may change as a -// result of a retry. However, LocalAddr is largely useless for -// TCP client sockets anyway, so nothing should be relying on this. -func (r *retrier) LocalAddr() net.Addr { - r.mu.Lock() - defer r.mu.Unlock() - if c := r.conn; c != nil { - return c.LocalAddr() - } - return zeroNetAddr{} -} - -// RemoteAddr returns the remote address of the connection. -func (r *retrier) RemoteAddr() net.Addr { - return r.raddr -} - -// SetReadDeadline sets the read deadline for the connection -// if the retry is complete, otherwise it does so after the retry. -func (r *retrier) SetReadDeadline(t time.Time) error { - r.mu.Lock() - defer r.mu.Unlock() - r.readDeadline = t - // Don't enforce read deadlines until after the retry - // is complete. Retry relies on setting its own read - // deadline, and we don't want this to interfere. - if r.retryCompleted() { - if c := r.conn; c != nil { - return c.SetReadDeadline(t) - } - return errNoConn - } - return nil -} - -// SetWriteDeadline sets the write deadline for the connection. -func (r *retrier) SetWriteDeadline(t time.Time) error { - r.mu.Lock() - defer r.mu.Unlock() - r.writeDeadline = t - if c := r.conn; c != nil { - return c.SetWriteDeadline(t) - } - return errNoConn -} - -// SetDeadline sets the read and write deadlines for the connection. -// Read deadlines are set eventually depending on the status of retries. -func (r *retrier) SetDeadline(t time.Time) error { - e1 := r.SetReadDeadline(t) - e2 := r.SetWriteDeadline(t) - return core.JoinErr(e1, e2) -} diff --git a/intra/dialers/retrier_test.go b/intra/dialers/retrier_test.go deleted file mode 100644 index a699336d..00000000 --- a/intra/dialers/retrier_test.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build ignore - -// -package dialers - -import ( - "bytes" - "io" - "net" - "testing" - "time" - - "github.com/celzero/firestack/intra/protect" -) - -type setup struct { - t *testing.T - server *net.TCPListener - clientSide DuplexConn - serverSide *net.TCPConn - serverReceived []byte -} - -func makeSetup(t *testing.T) *setup { - addr, err := net.ResolveTCPAddr("tcp", ":0") - if err != nil { - t.Error(err) - } - server, err := net.ListenTCP("tcp", addr) - if err != nil { - t.Error(err) - } - - serverAddr, ok := server.Addr().(*net.TCPAddr) - if !ok { - t.Error("Server isn't TCP?") - } - d := protect.MakeNsRDial("rtest", nil) - clientSide, err := DialWithSplitRetry(d, serverAddr) - if err != nil { - t.Error(err) - } - serverSide, err := server.AcceptTCP() - if err != nil { - t.Error(err) - } - return &setup{t, server, clientSide, serverSide, nil} -} - -const BUFSIZE = 256 - -func makeBuffer() []byte { - buffer := make([]byte, BUFSIZE) - for i := 0; i < BUFSIZE; i++ { - buffer[i] = byte(i) - } - return buffer -} - -func send(src io.Writer, dest io.Reader, t *testing.T) []byte { - buffer := makeBuffer() - n, err := src.Write(buffer) - if err != nil { - t.Error(err) - } - if n != len(buffer) { - t.Errorf("Failed to write whole buffer: %d", n) - } - - buf := make([]byte, len(buffer)) - n, err = dest.Read(buf) - if err != nil { - t.Error(nil) - } - if n != len(buf) { - t.Errorf("Not enough bytes: %d", n) - } - if !bytes.Equal(buf, buffer) { - t.Errorf("Wrong contents") - } - return buf -} - -func (s *setup) sendUp() { - buf := send(s.clientSide, s.serverSide, s.t) - s.serverReceived = append(s.serverReceived, buf...) -} - -func (s *setup) sendDown() { - send(s.serverSide, s.clientSide, s.t) -} - -func closeRead(closed, blocked DuplexConn, t *testing.T) { - closed.CloseRead() - // TODO: Figure out if this is detectable on the opposite side. -} - -func closeWrite(closed, blocked DuplexConn, t *testing.T) { - closed.CloseWrite() - n, err := blocked.Read(make([]byte, 1)) - if err != io.EOF || n > 0 { - t.Errorf("Read should have failed with EOF") - } -} - -func (s *setup) closeReadUp() { - closeRead(s.clientSide, s.serverSide, s.t) -} - -func (s *setup) closeWriteUp() { - closeWrite(s.clientSide, s.serverSide, s.t) -} - -func (s *setup) closeReadDown() { - closeRead(s.serverSide, s.clientSide, s.t) -} - -func (s *setup) closeWriteDown() { - closeWrite(s.serverSide, s.clientSide, s.t) -} - -func (s *setup) close() { - s.server.Close() -} - -func (s *setup) confirmRetry() { - done := make(chan struct{}) - go func() { - buf := make([]byte, len(s.serverReceived)) - n, err := s.clientSide.Read(buf) - if err != nil { - s.t.Error(err) - } - if n != len(buf) { - s.t.Error("Unexpected echo length") - } - close(done) - }() - - var err error - s.serverSide, err = s.server.AcceptTCP() - if err != nil { - s.t.Errorf("Second socket failed") - } - buf := make([]byte, len(s.serverReceived)) - var n int - for n < len(buf) { - var m int - m, err = s.serverSide.Read(buf[n:]) - n += m - if err != nil { - s.t.Error(err) - } - } - if !bytes.Equal(buf, s.serverReceived) { - s.t.Errorf("Replay was corrupted") - } - - n, err = s.serverSide.Write(buf) - if err != nil { - s.t.Error(err) - } - if n != len(buf) { - s.t.Errorf("Couldn't echo all bytes: %d", n) - } - <-done -} - -func (s *setup) checkNoSplit() { - // no-op -} - -func TestNormalConnection(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.sendDown() - s.closeReadUp() - s.closeWriteUp() - s.close() - s.checkNoSplit() -} - -func TestFinRetry(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.serverSide.Close() - s.confirmRetry() - s.sendDown() - s.closeReadUp() - s.closeWriteUp() - s.close() -} - -func TestTimeoutRetry(t *testing.T) { - s := makeSetup(t) - s.sendUp() - // Client should time out and retry after about 1.2 seconds - time.Sleep(2 * time.Second) - s.confirmRetry() - s.sendDown() - s.closeReadUp() - s.closeWriteUp() - s.close() -} - -func TestTwoWriteRetry(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.sendUp() - s.serverSide.Close() - s.confirmRetry() - s.sendDown() - s.closeReadUp() - s.closeWriteUp() - s.close() -} - -func TestFailedRetry(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.serverSide.Close() - s.confirmRetry() - s.closeReadDown() - s.closeWriteDown() - s.close() -} - -func TestDisappearingServer(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.close() - s.serverSide.Close() - // Try to read 1 byte to trigger the retry. - n, err := s.clientSide.Read(make([]byte, 1)) - if n > 0 || err == nil { - t.Error("Expected read to fail") - } - s.clientSide.CloseRead() - s.clientSide.CloseWrite() - s.checkNoSplit() -} - -func TestSequentialClose(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.closeWriteUp() - s.sendDown() - s.closeWriteDown() - s.close() - s.checkNoSplit() -} - -func TestBackwardsUse(t *testing.T) { - s := makeSetup(t) - s.sendDown() - s.closeWriteDown() - s.sendUp() - s.closeWriteUp() - s.close() - s.checkNoSplit() -} - -// Regression test for an issue in which the initial handshake timeout -// continued to apply after the handshake completed. -func TestIdle(t *testing.T) { - s := makeSetup(t) - s.sendUp() - s.sendDown() - // Wait for longer than the 1.2-second response timeout - time.Sleep(2 * time.Second) - // Try to send down some more data. - s.sendDown() - s.close() - s.checkNoSplit() -} diff --git a/intra/dialers/split_and_desync.go b/intra/dialers/split_and_desync.go deleted file mode 100644 index 51a87f7f..00000000 --- a/intra/dialers/split_and_desync.go +++ /dev/null @@ -1,534 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "context" - csprng "crypto/rand" - "io" - "math/rand" - "net" - "net/netip" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "golang.org/x/sys/unix" -) - -const ( - probeSize = 8 - default_ttl = 64 - - desync_http1_1str = "POST / HTTP/1.1\r\nHost: 10.0.0.1\r\nContent-Type: application/octet-stream\r\nContent-Length: 9999999\r\n\r\n" - // from: github.com/bol-van/zapret/blob/c369f11638/nfq/darkmagic.h#L214-L216 - desync_max_ttl = 20 - desync_noop_ttl = 3 - desync_delta_ttl = 1 - - desync_cache_ttl = 30 * time.Second -) - -// ttlcache stores the TTL for a given IP address for a limited time. -// TODO: invalidate cache on network changes. -// TODO: with context.TODO, expmap's reaper goroutine will leak. -var ttlcache = core.NewSieve[netip.Addr, int](context.TODO(), desync_cache_ttl) - -// Combines direct split with TCB Desynchronization Attack -// Inspired by byedpi: github.com/hufrea/byedpi/blob/82e5229df00/desync.c#L69-L123 -type overwriteSplitter struct { - conn *net.TCPConn // underlying connection - - used *core.SigCond // signalled to stop desync writer - await chan struct{} // closed when used is set to false - - ttl int // desync TTL - ip6 bool // IPv6 - payload []byte // must be smaller than 1st written packet - // note: Normal ClientHello generated by browsers is 517 bytes. If kyber is enabled, the ClientHello can be larger. -} - -var _ core.DuplexConn = (*overwriteSplitter)(nil) -var _ core.RetrierConn = (*overwriteSplitter)(nil) - -// exceedsHopLimit checks if cmsgs contains an ICMPv6 hop limit exceeded SockExtendedErr -// -// type SockExtendedErr struct { -// Errno uint32 -// Origin uint8 -// Type uint8 -// Code uint8 -// Pad uint8 -// Info uint32 -// Data uint32 -// } -// -// https://www.rfc-editor.org/rfc/rfc4443.html#section-3.3 -func exceedsHopLimit(cmsgs []unix.SocketControlMessage) bool { - for _, cmsg := range cmsgs { - if cmsg.Header.Level == unix.IPPROTO_IPV6 && cmsg.Header.Type == unix.IPV6_RECVERR { - eeOrigin := cmsg.Data[4] - if eeOrigin == unix.SO_EE_ORIGIN_ICMP6 { - eeType := cmsg.Data[5] - eeCode := cmsg.Data[6] - if eeType == 3 && eeCode == 0 { - return true - } - } - } - } - return false -} - -// exceedsTTL checks if cmsgs contains an ICMPv4 time to live exceeded SockExtendedErr. -// https://www.rfc-editor.org/rfc/rfc792.html#page-6 -func exceedsTTL(cmsgs []unix.SocketControlMessage) bool { - for _, cmsg := range cmsgs { - if cmsg.Header.Level == unix.IPPROTO_IP && cmsg.Header.Type == unix.IP_RECVERR { - eeOrigin := cmsg.Data[4] - if eeOrigin == unix.SO_EE_ORIGIN_ICMP { - eeType := cmsg.Data[5] - eeCode := cmsg.Data[6] - if eeType == 11 && eeCode == 0 { - return true - } - } - } - } - return false -} - -// tracert dials a UDP conn to the target address over a port range basePort to basePort+DESYNC_MAX_TTL, with TTL -// set to 2, 3, ..., DESYNC_MAX_TTL. It does not take ownership of the conn (which must be closed by the caller). -func tracert(d protect.RDialer, ipp netip.AddrPort, basePort int) (*net.UDPConn, int, error) { - udpAddr := net.UDPAddrFromAddrPort(ipp) - udpAddr.Port = 1 // unset port - - isIPv6 := ipp.Addr().Is6() - - // explicitly prefer udp4 for IPv4 to prevent OS from giving cmsg(s) which mix IPPROTO_IPV6 cmsg level - // & IPv4-related cmsg data, because exceedsTTL() returns false when cmsg.Header.Level == IPPROTO_IPV6. - // that is: "udp" dials a dual-stack connection, which we don't want. - proto := "udp4" - if isIPv6 { - proto = "udp6" - } - - var udpFD int - uc, err := protect.AnnounceUDP(d, proto, ":0") - if err != nil { - return uc, udpFD, log.EE("desync: err announcing udp: %v", err) - } - if uc == nil { - return uc, udpFD, log.EE("desync: nil udp conn") - } - - rawConn, err := uc.SyscallConn() - if err != nil { - return uc, udpFD, err - } - if rawConn == nil { - return uc, udpFD, errNoSysConn - } - err = rawConn.Control(func(fd uintptr) { - udpFD = int(fd) - }) - if err != nil { - return uc, udpFD, err - } - - if isIPv6 { - err = unix.SetsockoptInt(udpFD, unix.IPPROTO_IPV6, unix.IPV6_RECVERR, 1) - } else { - err = unix.SetsockoptInt(udpFD, unix.IPPROTO_IP, unix.IP_RECVERR, 1) - } - if err != nil { - return uc, udpFD, log.EE("desync: %s => %s sockopt(recverr) err: %v", laddr(uc), ipp, err) - } - - var msgBuf [probeSize]byte - for ttl := 2; ttl <= desync_max_ttl; ttl += desync_delta_ttl { - _, err = csprng.Read(msgBuf[:]) - if err != nil { - return uc, udpFD, err - } - if isIPv6 { - err = unix.SetsockoptInt(udpFD, unix.IPPROTO_IPV6, unix.IPV6_UNICAST_HOPS, ttl) - } else { - err = unix.SetsockoptInt(udpFD, unix.IPPROTO_IP, unix.IP_TTL, ttl) - } - if err != nil { - return uc, udpFD, err - } - udpAddr.Port = basePort + ttl - _, err = uc.WriteToUDP(msgBuf[:], udpAddr) - if err != nil { - return uc, udpFD, err - } - } - return uc, udpFD, nil -} - -// desyncWithTraceroute estimates the TTL with UDP traceroute, -// then returns a TCP connection that may launch TCB Desynchronization Attack and split the initial upstream segment -// If `payload` is smaller than the initial upstream segment, it launches the attack and splits. -// This traceroute is not accurate, because of time limit (TCP handshake). -// Note: The path the UDP packet took to reach the destination may differ from the path the TCP packet took. -func desyncWithTraceroute(d protect.RDialer, local, remote netip.AddrPort) (*overwriteSplitter, error) { - const maxport = 65535 - measureTTL := true - isIPv6 := remote.Addr().Is6() - basePort := 1 + rand.Intn(maxport-desync_max_ttl) //#nosec G404 - - uc, udpFD, err := tracert(d, remote, basePort) - defer core.Close(uc) - - logeif(err)("desync: dialUDP %s => %s %d: err? %v", local, remote, udpFD, err) - if err != nil { - measureTTL = false - } - - oc, err := desyncWithFixedTtl(d, local, remote, desync_noop_ttl) - if err != nil { - return nil, err - } - if oc == nil { // nilaway - return nil, errNoDesyncConn - } - - var msgBuf [probeSize]byte - - bptr := core.Alloc16() - cmsgBuf := *bptr - cmsgBuf = cmsgBuf[:cap(cmsgBuf)] - defer func() { - *bptr = cmsgBuf - core.Recycle(bptr) - }() - - // after TCP handshake, check received ICMP messages, if measureTTL is true. - for i := 0; i < desync_max_ttl-1 && measureTTL; i += desync_delta_ttl { - _, cmsgN, _, from, err := unix.Recvmsg(udpFD, msgBuf[:], cmsgBuf[:], unix.MSG_ERRQUEUE) - if err != nil { - log.VV("desync: recvmsg %v, err: %v", remote, err) - break - } - - cmsgs, err := unix.ParseSocketControlMessage(cmsgBuf[:cmsgN]) - if err != nil { - log.W("desync: parseSocketControlMessage %v failed: %v", remote, err) - continue - } - - if isIPv6 { - if exceedsHopLimit(cmsgs) { - fromPort := from.(*unix.SockaddrInet6).Port - ttl := fromPort - basePort - if ttl <= desync_max_ttl { - oc.ttl = max(oc.ttl, ttl) - } // else: corrupted packet? - } - } else { - if exceedsTTL(cmsgs) { - fromPort := from.(*unix.SockaddrInet4).Port - ttl := fromPort - basePort - if ttl <= desync_max_ttl { - oc.ttl = max(oc.ttl, ttl) - } // else: corrupted packet? - } - } - } - - // skip or apply desync depending on whether - // the measurement is successful. - avoidDesync := oc.ttl <= desync_noop_ttl - if avoidDesync { - oc.used.Signal() - } - - log.D("desync: done: %v, do desync? %t, ttl: %d", remote, !avoidDesync, oc.ttl) - - return oc, nil -} - -func desyncWithFixedTtl(d protect.RDialer, local, remote netip.AddrPort, initialTTL int) (*overwriteSplitter, error) { - var raddr *net.TCPAddr = net.TCPAddrFromAddrPort(remote) - var laddr *net.TCPAddr // nil is valid - if local.IsValid() { - laddr = net.TCPAddrFromAddrPort(local) - } - - isIPv6 := remote.Addr().Is6() - // skip desync if no measurement is done - avoidDesync := initialTTL <= desync_noop_ttl - - proto := "tcp4" - if isIPv6 { - proto = "tcp6" - } - - tcpConn, err := protect.DialTCP(d, proto, laddr, raddr) - - logeif(err)("desync: dialTCP: %s => %s, do desync? %t, ttl: %d", - laddr, raddr, !avoidDesync, initialTTL) - - if err != nil { - return nil, err - } - if tcpConn == nil { - return nil, errNoConn - } - - s := &overwriteSplitter{ - conn: tcpConn, - used: core.NewSigCond(), - ttl: initialTTL, - payload: []byte(desync_http1_1str), - ip6: isIPv6, - } - if avoidDesync { - s.used.Signal() - } - - return s, nil -} - -// DialWithSplitAndDesync estimates the TTL with UDP traceroute, -// then returns a TCP connection that may launch TCB Desynchronization -// and split the initial upstream segment. -// ref: github.com/bol-van/zapret/blob/c369f11638/docs/readme.eng.md#dpi-desync-attack -func dialWithSplitAndDesync(d protect.RDialer, laddr, raddr net.Addr) (*overwriteSplitter, error) { - remote := asAddrPort(raddr) // must not be invalid - local := asAddrPort(laddr) // can be invalid - - if !remote.IsValid() { - log.E("desync: invalid raddr: conv %s to %s", raddr, remote) - return nil, errNoIps - } - - ttl, ok := ttlcache.Get(remote.Addr()) - if ok { - return desyncWithFixedTtl(d, local, remote, ttl) - } - conn, err := desyncWithTraceroute(d, local, remote) - if err == nil && conn != nil { // go vet (incorrectly) complains conn being nil when err is nil - ttlcache.Put(remote.Addr(), conn.ttl) - } - return conn, err -} - -// Close implements core.DuplexConn. -func (s *overwriteSplitter) Close() error { core.CloseTCP(s.conn); return nil } - -// CloseRead implements core.DuplexConn. -func (s *overwriteSplitter) CloseRead() error { core.CloseTCPRead(s.conn); return nil } - -// CloseWrite implements core.DuplexConn. -func (s *overwriteSplitter) CloseWrite() error { core.CloseTCPWrite(s.conn); return nil } - -// LocalAddr implements core.DuplexConn. -func (s *overwriteSplitter) LocalAddr() net.Addr { return laddr(s.conn) } - -// RemoteAddr implements core.DuplexConn. -func (s *overwriteSplitter) RemoteAddr() net.Addr { return raddr(s.conn) } - -// SetDeadline implements core.DuplexConn. -func (s *overwriteSplitter) SetDeadline(t time.Time) error { - if c := s.conn; c != nil { - return c.SetDeadline(t) - } - return nil // no-op -} - -// SyscallConn implements core.DuplexConn. -func (s *overwriteSplitter) SyscallConn() (syscall.RawConn, error) { - if c := s.conn; c != nil { - return c.SyscallConn() - } - return nil, syscall.EINVAL -} - -// SetKeepAlive implements core.DuplexConn. -func (s *overwriteSplitter) SetKeepAlive(y bool) error { - if c := s.conn; c != nil { - return c.SetKeepAlive(y) - } - return nil // no-op -} - -// SetReadDeadline implements core.DuplexConn. -func (s *overwriteSplitter) SetReadDeadline(t time.Time) error { - if c := s.conn; c != nil { - return c.SetReadDeadline(t) - } - return nil // no-op -} - -// SetWriteDeadline implements core.DuplexConn. -func (s *overwriteSplitter) SetWriteDeadline(t time.Time) error { - if c := s.conn; c != nil { - return c.SetWriteDeadline(t) - } - return nil // no-op -} - -// Read implements core.DuplexConn. -func (s *overwriteSplitter) Read(b []byte) (int, error) { return s.conn.Read(b) } - -// Write implements core.DuplexConn. -// ref: github.com/hufrea/byedpi/blob/82e5229df00/desync.c#L69-L123 -func (s *overwriteSplitter) Write(b []byte) (n int, err error) { - conn := s.conn - laddr := laddr(s.conn) - raddr := raddr(s.conn) - - noop := len(b) == 0 // go vet has us handle this case - short := len(b) < len(s.payload) - signalled := false - used := s.used.Cond() // also true when s.ttl <= desync_noop_ttl - if noop { - n, err = 0, nil - } else if used { - // after the first write, there is no special write behavior. - // used may also be set to true to avoid desync. - n, err = conn.Write(b) - } else if signalled = s.used.Signal(); !signalled { - close(s.await) // wake up any ReadFrom waiting for desync to complete - // set `used` to ensure this code only runs once per conn; - // if !swapped, some other goroutine has already swapped it. - n, err = conn.Write(b) - } else if short { - n, err = conn.Write(b) - } - if used || short || !signalled || noop { - logeif(err)("desync: write: %s => %s; desync done %d; (noop? %t, used? %t, short? %t, race? %t); err? %v", - laddr, raddr, n, noop, used, short, !signalled, err) - return n, err - } - - if len(b) <= len(s.payload) { // same as "short == true" - return n, err // redundant check for nilaway - } - - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - if rawConn == nil { - return 0, errNoSysConn - } - - var sockFD int - err = rawConn.Control(func(fd uintptr) { - sockFD = int(fd) - }) - if err != nil { - return 0, log.EE("desync: %s => %s get sock fd failed; %v", laddr, raddr, err) - } - - fileFD, err := unix.MemfdCreate("haar", unix.O_RDWR) - if err != nil { - return 0, err - } - - defer core.CloseFD(fileFD) - - err = unix.Ftruncate(fileFD, int64(len(s.payload))) - if err != nil { - return 0, err - } - firstSegment, err := unix.Mmap(fileFD, 0, len(s.payload), unix.PROT_WRITE, unix.MAP_SHARED) - if err != nil { - return 0, err - } - defer func() { - _ = unix.Munmap(firstSegment) - }() - - // restrict TTL to ensure s.Payload is seen by censors, but not by the server. - copy(firstSegment, s.payload) - if s.ip6 { - err = unix.SetsockoptInt(sockFD, unix.IPPROTO_IPV6, unix.IPV6_UNICAST_HOPS, s.ttl) - } else { - err = unix.SetsockoptInt(sockFD, unix.IPPROTO_IP, unix.IP_TTL, s.ttl) - } - if err != nil { - return 0, log.EE("desync: %s => %s setsockopt(ttl) err: %v", laddr, raddr, err) - } - var offset int64 = 0 - n1, err := unix.Sendfile(sockFD, fileFD, &offset, len(s.payload)) - if err != nil { - return n1, log.EE("desync: %s => %s sendfile() %d err: %v", laddr, raddr, n1, err) - } - - // also: github.com/hufrea/byedpi/blob/bbe95222/desync.c#L115 - time.Sleep(3 * time.Microsecond) - // restore the first-half of the payload so that it gets picked up on retranmission. - copy(firstSegment, b[:len(s.payload)]) - - // restore default TTL - if s.ip6 { - err = unix.SetsockoptInt(sockFD, unix.IPPROTO_IPV6, unix.IPV6_UNICAST_HOPS, default_ttl) - } else { - err = unix.SetsockoptInt(sockFD, unix.IPPROTO_IP, unix.IP_TTL, default_ttl) - } - if err != nil { - return n1, log.EE("desync: %s => %s setsockopt(ttl) err: %v", laddr, raddr, err) - } - - // write the second segment - n2, err := conn.Write(b[len(s.payload):]) - logeif(err)("desync: write: n1: %d, n2: %d, err: %v", n1, n2, err) - return n1 + n2, err -} - -// ReadFrom reads from the reader and writes to s. -func (s *overwriteSplitter) ReadFrom(reader io.Reader) (bytes int64, err error) { - start := time.Now() - if !s.used.Cond() { - bytes, err = copyOnce(s, reader) - logeif(err)("desync: readfrom: copyOnce; sz: %d %s=>%s; err: %v", - bytes, s.LocalAddr(), s.RemoteAddr(), err) - if err != nil { - return - } - } - // wait for desync to complete - s.used.Wait() - elapsed := time.Since(start) - - b, err := s.conn.ReadFrom(reader) - bytes += b - log.V("desync: readfrom: done; sz: %d %s=>%s; dur: %s, wait: %s; err: %v", - bytes, s.LocalAddr(), s.RemoteAddr(), core.FmtTimeAsPeriod(start), core.FmtPeriod(elapsed), err) - - return -} - -// WriteTo reads from s and writes to w. -// Usually, WriteTo is executing the "download" phase of egressing conn (w). -func (s *overwriteSplitter) WriteTo(w io.Writer) (bytes int64, err error) { - start := time.Now() - waited := s.used.TryWait(uploadTimeoutForDownload) - elapsed := time.Since(start) - - b, err := s.conn.WriteTo(w) - bytes += b - log.V("desync: writeto: done; sz: %d %s<=%s; dur: %s, wait: %s (%t); err: %v", - bytes, s.LocalAddr(), s.RemoteAddr(), core.FmtTimeAsPeriod(start), core.FmtPeriod(elapsed), waited, err) - return -} - -func asAddrPort(a net.Addr) (n netip.AddrPort) { - if a == nil { - return - } - n, _ = netip.ParseAddrPort(a.String()) - return -} diff --git a/intra/dialers/tlsdial.go b/intra/dialers/tlsdial.go deleted file mode 100644 index 134495a0..00000000 --- a/intra/dialers/tlsdial.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "crypto/tls" - "net" - "net/netip" - - "github.com/celzero/firestack/intra/log" -) - -func tlsConnect(d *tls.Dialer, proto string, local, remote netip.AddrPort) (net.Conn, error) { - if d == nil { - log.E("tlsdial: tlsConnect: nil dialer") - return nil, errNoDialer - } else if !ipok(remote.Addr()) { - log.E("tlsdial: tlsConnect: invalid ip", remote) - return nil, errNoIps - } - if local.IsValid() { - cd := new(net.Dialer) - *cd = *d.NetDialer // shallow copy - cd.LocalAddr = net.TCPAddrFromAddrPort(local) - return cd.Dial(proto, remote.String()) - } else { - return d.Dial(proto, remote.String()) - } -} - -func TlsDial(d *tls.Dialer, network, addr string) (net.Conn, error) { - d.Config = ensureSni(d.Config, addr) - return dialtls(d, d.Config, network, "", addr, adaptTlsDial(tlsConnect)) -} diff --git a/intra/dialers/types.go b/intra/dialers/types.go deleted file mode 100644 index 8aefeb79..00000000 --- a/intra/dialers/types.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dialers - -import ( - "crypto/tls" - "net" - "net/netip" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/protect" - "golang.org/x/net/proxy" -) - -// rdials is a union type for protect.RDial, net.Dialer, tls.Dialer -type rdials interface { - *protect.RDial | *protect.RDialer | *tls.Dialer | *proxy.Dialer -} - -// rconns is a union type for net.UDPConn, net.TCPConn, icmp.PacketConn, net.TCPListener -type rconns interface { - *net.Conn | *net.PacketConn | *net.UDPConn | *net.TCPConn | *net.TCPListener -} - -type dialFn[D rdials, C rconns] func(dialer D, network string, local, remote netip.AddrPort) (C, error) -type connFn[D rdials] func(dialer D, network string, local, remote netip.AddrPort) (net.Conn, error) - -// adaptRDial adapts a connectFn[protect.RDial] to a dialFn -func adaptRDial[D *protect.RDial, C *net.Conn](f connFn[D]) dialFn[D, C] { - return func(d D, network string, laddr, raddr netip.AddrPort) (cc C, err error) { - c, err := f(d, network, laddr, raddr) - if err != nil { - clos(c) - return nil, err - } - if c == nil || core.IsNil(c) { // go.dev/play/p/SsmqM00d2oH - return nil, errNilConn - } - return &c, nil - } -} - -// adaptRDialer adapts a connectFn[protect.RDialer] to a dialFn -func adaptRDialer[D *protect.RDialer, C *net.Conn](f connFn[D]) dialFn[D, C] { - return func(d D, network string, laddr, raddr netip.AddrPort) (cc C, err error) { - c, err := f(d, network, laddr, raddr) - if err != nil { - clos(c) - return nil, err - } - if c == nil || core.IsNil(c) { - return nil, errNilConn - } - return &c, nil - } -} - -// adaptTlsDial adapts a connectFn[tls.Dialer] to a dialFn -func adaptTlsDial[D *tls.Dialer, C *net.Conn](f connFn[D]) dialFn[D, C] { - return func(d D, network string, laddr, raddr netip.AddrPort) (cc C, err error) { - c, err := f(d, network, laddr, raddr) - if err != nil { - clos(c) - return nil, err - } - if c == nil || core.IsNil(c) { - return nil, errNilConn - } - return &c, nil - } -} - -func adaptProxyDial[D *proxy.Dialer, C *net.Conn](f connFn[D]) dialFn[D, C] { - return func(d D, network string, laddr, raddr netip.AddrPort) (cc C, err error) { - c, err := f(d, network, laddr, raddr) - if err != nil { - clos(c) - return nil, err - } - if c == nil || core.IsNil(c) { // go.dev/play/p/SsmqM00d2oH - return nil, errNilConn - } - return &c, nil - } -} - -func unPtr[P any, Q any](p *P, q Q) (P, Q) { - // go.dev/play/p/XRrCepATeIi - if p == nil || core.IsNil(p) { - var zz P - return zz, q - } - return *p, q -} diff --git a/intra/dns.go b/intra/dns.go deleted file mode 100644 index 8a32701d..00000000 --- a/intra/dns.go +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "context" - "strconv" - "strings" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dns53" - "github.com/celzero/firestack/intra/dnscrypt" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/doh" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" -) - -func addIPMapper(ctx context.Context, r dnsx.Resolver, protos string) { - dns53.AddIPMapper(r, protos, false /*clear cache*/) - context.AfterFunc(ctx, func() { - dns53.AddIPMapper(nil, "", true /*clear cache*/) - }) -} - -// AddDNSProxy creates and adds a DNS53 transport to the tunnel's resolver. -func AddDNSProxy(t Tunnel, id, ippcsv string) error { - p, perr := t.internalProxies() - r, rerr := t.internalResolver() - if rerr != nil || perr != nil { - return core.JoinErr(rerr, perr) - } - ctx := t.internalCtx() - specialHostname := protect.HostlessPrefix + id - if dns, err := dns53.NewTransportFromHostname(ctx, id, specialHostname, ippcsv, p); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } -} - -func newSystemDNSProxy(ctx context.Context, p ipn.ProxyProvider, ipcsv string) (d dnsx.Transport, err error) { - specialHostname := protect.UidSystem // never resolved by ipmap:LookupNetIP - return dns53.NewTransportFromHostname(ctx, dnsx.System, specialHostname, ipcsv, p) -} - -// SetSystemDNS creates and adds a DNS53 transport of the specified IP addresses. -func SetSystemDNS(t Tunnel, ipcsvx string) error { - r, rerr := t.internalResolver() - p, perr := t.internalProxies() - ctx := t.internalCtx() - ipcsv := ipcsvx - n := len(ipcsv) - if r == nil || p == nil { - log.W("dns: cannot set system dns; n: %d, errs: %v %v", n, rerr, perr) - return core.JoinErr(dnsx.ErrAddFailed, rerr, perr) - } - - if n <= 0 { - log.W("dns: no system dns IPs to set; fallback to Goos") - r.Remove(dnsx.System) - return nil - } - - // if the ipcsv is localhost, use loopback addresses. - // this is the case if kotlin-land is unable to determine - // DNS servers. This is equivalent to using x.Goos Transport. - if strings.HasPrefix(ipcsv, "localhost") { - if settings.Debug { - log.D("dns: system dns is localhost, using loopback") - } - ipcsv = localip4 + "," + localip6 - } - - var ok bool - if sdns, err := newSystemDNSProxy(ctx, p, ipcsv); err == nil { - ok = r.Add(sdns) - } else { - return err - } - - log.I("dns: new system dns from %s; ok? %t", ipcsv, ok) - return nil -} - -func newGoosTransport(ctx context.Context, px ipn.ProxyProvider) (d dnsx.Transport) { - d, _ = dns53.NewGoosTransport(ctx, px) - return -} - -func newBlockAllTransport() dnsx.Transport { - return dns53.NewGroundedTransport(dnsx.BlockAll) -} - -func newFixedTransport() dnsx.Transport { - return dns53.NewErrorerTransport(dnsx.Fixed) -} - -func newPlusTransport(ctx context.Context, r dnsx.Resolver) dnsx.Transport { - return dnsx.NewPlusTransport(ctx, r /*and zero transports*/) -} - -func newDNSCryptTransport(ctx context.Context, px ipn.ProxyProvider, bdg Bridge) (p dnsx.TransportMult) { - p = dnscrypt.NewDcMult(ctx, px, bdg) - return -} - -func newMDNSTransport(ctx context.Context, protos string, px ipn.ProxyProvider) (d dnsx.MDNSTransport) { - return dns53.NewMDNSTransport(ctx, protos, px) -} - -// AddDefaultTransport adds a special default transport to the tunnel's resolver -// It may be either a DoH or a DNS53 transport. -func AddDefaultTransport(t Tunnel, typ, ippOrUrl, ips string) error { - r, rerr := t.GetResolver() - if rerr != nil { - return rerr - } - tr, err := r.Get(dnsx.Default) - if err != nil { - return err - } - defaultransport, ok := tr.(DefaultDNS) - if !ok { - return dnsx.ErrNotDefaultTransport - } - // on error, default transport remains unchanged - return defaultransport.reinit(typ, ippOrUrl, ips) -} - -// AddProxyDNS creates and adds a DNS53 transport as defined in Proxy's configuration. -func AddProxyDNS(t Tunnel, p x.Proxy) error { - pxr, perr := t.internalProxies() - r, rerr := t.internalResolver() - if rerr != nil || perr != nil { - return core.JoinErr(rerr, perr) - } - pid := p.ID() - ctx := t.internalCtx() - ipOrHostCsv := p.DNS() // may return csv(host:port), csv(ip:port), csv(ips), csv(host) - if len(ipOrHostCsv) == 0 { - log.W("dns: no proxy dns for %s @ %s", pid, p.GetAddr()) - return dnsx.ErrNoProxyDNS - } - ipsOrHost := strings.Split(ipOrHostCsv, ",") - if len(ipsOrHost) == 0 { - log.W("dns: no dns for %s @ %s", pid, p.GetAddr()) - return dnsx.ErrNoProxyDNS - } - first := ipsOrHost[0] - ipport, err := xdns.DnsIPPort(first) - hostOrHostport := first // could be multiple hostnames or host:ports, but choose the first - if err != nil { // use hostname - if dns, err := dns53.NewTransportFromHostname(ctx, pid, hostOrHostport, "" /*ip or ip:port csv*/, pxr); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } - // use ipports; register with same id as the proxy p - } else if dns, err := dns53.NewTransportFrom(ctx, pid, ipport, pxr); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } -} - -// AddDoHTransport creates and adds a Transport that connects to the specified DoH server. -// `url` is the URL of a DoH server (no template, POST-only). -func AddDoHTransport(t Tunnel, id, url, ipcsv string) error { - pxr, perr := t.internalProxies() - r, rerr := t.internalResolver() - if rerr != nil || perr != nil { - return core.JoinErr(rerr, perr) - } - ips := ipcsv - ctx := t.internalCtx() - split := []string{} - if len(ips) > 0 { - split = strings.Split(ips, ",") - } - if dns, err := doh.NewTransport(ctx, id, url, split, pxr); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } -} - -// AddODoHTransport creates and adds a Transport that connects to the specified ODoH server. -// `endpoint` is the entry / proxy for the ODoH server, `resolver` is the URL of the target ODoH server. -func AddODoHTransport(t Tunnel, id, endpoint, resolver, epipcsv string) error { - pxr, perr := t.internalProxies() - r, rerr := t.internalResolver() - if rerr != nil || perr != nil { - return core.JoinErr(rerr, perr) - } - epips := epipcsv - ctx := t.internalCtx() - split := []string{} - if len(epips) > 0 { - split = strings.Split(epips, ",") - } - if dns, err := doh.NewOdohTransport(ctx, id, endpoint, resolver, split, pxr); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } -} - -// AddDoTTransport creates and adds a Transport that connects to the specified DoT server. -func AddDoTTransport(t Tunnel, id, url, ipcsv string) error { - pxr, perr := t.internalProxies() - r, rerr := t.internalResolver() - if rerr != nil || perr != nil { - return core.JoinErr(rerr, perr) - } - ctx := t.internalCtx() - split := []string{} - ips := ipcsv - if len(ips) > 0 { - split = strings.Split(ips, ",") - } - if dns, err := dns53.NewTLSTransport(ctx, id, url, split, pxr); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } -} - -// AddDNSCryptTransport creates and adds a DNSCrypt transport to the tunnel's resolver. -func AddDNSCryptTransport(t Tunnel, id, stamp string) (err error) { - r, rerr := t.internalResolver() - if rerr != nil { - return rerr - } - - var tm dnsx.TransportMult - if tm, err = r.GetMultInternal(dnsx.DcProxy); err != nil { - return err - } - // todo: unexpose DcMulti, cast to TransportMult - if p, ok := tm.(*dnscrypt.DcMulti); ok { - if dns, err := dnscrypt.AddTransport(p, id, stamp); err != nil { - return err - } else { - return addDNSTransport(r, dns) - } - } else { - return dnsx.ErrNoDcProxy - } -} - -// AddDNSCryptRelay adds a DNSCrypt relay transport to the tunnel's resolver. -func AddDNSCryptRelay(t Tunnel, stamp string) error { - var tm dnsx.TransportMult - var err error - r, rerr := t.internalResolver() - if rerr != nil { - return rerr - } - if tm, err = r.GetMultInternal(dnsx.DcProxy); err != nil { - return err - } - if p, ok := tm.(*dnscrypt.DcMulti); ok { - // relay transports are not added to the resolver - return dnscrypt.AddRelayTransport(p, stamp) - } else { - return dnsx.ErrNoDcProxy - } -} - -func addDNSTransport(r dnsx.Resolver, t dnsx.Transport) error { - if !r.Add(t) { - return dnsx.ErrAddFailed - } - return nil -} - -func csv2ssv(csv string) string { - return strings.ReplaceAll(csv, ",", ";") -} - -func fetchDNSInfo(r dnsx.Resolver, id string) string { - if tr, rerr := r.GetInternal(id); rerr == nil { - var sb strings.Builder - sb.WriteString(tr.GetAddr()) - sb.WriteString("[") - sb.WriteString(tr.Type()) - sb.WriteString("/") - sb.WriteString(dnsx.Status2Str(tr.Status())) - sb.WriteString("/") - sb.WriteString(strconv.FormatInt(tr.P50(), 10)) - sb.WriteString("ms] ") - for _, ipp := range tr.IPPorts() { - if ipp.IsValid() { - sb.WriteString(ipp.Addr().String()) - sb.WriteString(";") - } - } - return sb.String() - } else { - return rerr.Error() - } -} diff --git a/intra/dns53/dot.go b/intra/dns53/dot.go deleted file mode 100644 index 8421e48f..00000000 --- a/intra/dns53/dot.go +++ /dev/null @@ -1,475 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "net/netip" - "net/url" - "strconv" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" - _ "go4.org/unsafe/assume-no-moving-gc" -) - -const usepool = true - -const echRetryPeriod = 8 * time.Hour - -type dot struct { - ctx context.Context - done context.CancelFunc - - id string // id of the transport - - url string // full url - addrport string // ip:port or hostname:port - port uint16 // port number - host string // hostname from the url - - c *dns.Client - proxies ipn.ProxyProvider // may be nil - relay string // may be empty - skipTLSVerify bool - - pool *core.MultConnPool[uintptr] - usepool bool - - echconfig *core.Volatile[*tls.Config] // echconfig for the endpoint; may be nil - echlastattempt *core.Volatile[time.Time] // last attempt fetching ech cfg - - est core.P2QuantileEstimator - status *core.Volatile[int] -} - -var _ dnsx.Transport = (*dot)(nil) - -// NewTLSTransport returns a DNS over TLS transport, ready for use. -func NewTLSTransport(ctx context.Context, id, rawurl string, addrs []string, px ipn.ProxyProvider) (t *dot, err error) { - tlscfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - SessionTicketsDisabled: false, - } - - // rawurl is either tls:host[:port] or tls://host[:port] or host[:port] - parsedurl, err := url.Parse(rawurl) - if err != nil { - return - } - skipTLSVerify := false - if parsedurl.Scheme != "tls" { - log.I("dot: disabling tls verification for %s", rawurl) - tlscfg.InsecureSkipVerify = true - skipTLSVerify = true - } - var relay string - if px != nil { - if p, _ := px.ProxyFor(id); p != nil { - relay = p.ID() - } - } - ctx, done := context.WithCancel(ctx) - hostname := parsedurl.Hostname() - if len(hostname) <= 0 { - hostname = rawurl - } - // addrs are pre-determined ip addresses for url / hostname - ok := dnsx.RegisterAddrs(id, hostname, addrs) - // add sni to tls config - tlscfg.ServerName = hostname - tlscfg.ClientSessionCache = core.TlsSessionCache() - addrport, port := url2addrport(rawurl) - t = &dot{ - ctx: ctx, - done: done, - id: id, - url: rawurl, - host: hostname, - skipTLSVerify: skipTLSVerify, - addrport: addrport, // may or may not be ipaddr - port: port, - status: core.NewVolatile(x.Start), - proxies: px, - relay: relay, - pool: core.NewMultConnPool[uintptr](ctx), - usepool: usepool, - est: core.NewP50Estimator(ctx), - echconfig: core.NewZeroVolatile[*tls.Config](), - echlastattempt: core.NewZeroVolatile[time.Time](), - } - echcfg := t.getOrCreateEchConfigIfNeeded() - // local dialer: protect.MakeNsDialer(id, ctl) - t.c = dnsclient(tlscfg) - log.I("dot: (%s) setup: %s; relay? %t; resolved? %t, ech? %t", - id, rawurl, len(relay) > 0, ok, echcfg != nil) - return t, nil -} - -func dnsclient(c *tls.Config) *dns.Client { - return &dns.Client{ - Net: "tcp-tls", - Dialer: nil, // unused; dialers from px take precedence - Timeout: dottimeout, - SingleInflight: true, // coalsece queries - TLSConfig: c.Clone(), // may be left unused - } -} - -// todo: ech over user specified dns+proxy -func (t *dot) ech() []byte { - if v, err := dialers.ECH(t.host); err == nil { - log.V("dot: ech(%s): %d", t.host, len(v)) - return v - } - log.W("dot: ech(%s): not found", t.host) - return nil -} - -func (t *dot) echVerifyFn() func(tls.ConnectionState) error { - if t.skipTLSVerify { - return func(info tls.ConnectionState) error { - log.V("doh: skip ech verify for %s via %s", t.addrport, info.ServerName) - return nil // never reject - } - } - return nil // delegate to stdlib -} - -func (t *dot) doQuery(pid string, q *dns.Msg) (response *dns.Msg, rpid string, ech bool, elapsed time.Duration, qerr *dnsx.QueryError) { - if q == nil || !xdns.HasAnyQuestion(q) { - qerr = dnsx.NewBadQueryError(fmt.Errorf("err len(query) %d", xdns.Len(q))) - return - } - - if qerr = dnsx.WillErr(t); qerr != nil { - return - } - - response, rpid, ech, elapsed, qerr = t.sendRequest(pid, q) - - if qerr != nil { // only on send-request errors - response = xdns.Servfail(q) - } - return -} - -func (t *dot) tlsdial(p ipn.Proxy) (dc *dns.Conn, who uintptr, usingech bool, err error) { - who = p.Handle() - - defer func() { - if dc != nil { - // todo: higher timeout for if using proxy dialer - // _ = c.SetDeadline(time.Now().Add(dottimeout * 2)) - if c := dc.Conn; c != nil { - _ = c.SetDeadline(time.Now().Add(dottimeout)) - } - } - }() - - if dc = t.fromPool(who); dc != nil { - return dc, who, false, nil // pooled connections don't track ECH state - } - - var c net.Conn = nil // dot is always tcp - addr := t.addrport // t.addr may be ip or hostname - - // Try ECH first if available - if echcfg := t.getOrCreateEchConfigIfNeeded(); echcfg != nil { - // update ech config which may have been changed by DialWithTls - defer t.echconfig.Store(echcfg) - c, err = dialers.DialWithTls(p.Dialer(), echcfg, "tcp", addr) - } - - if c == nil && core.IsNil(c) { // no ech or ech failed - cfg := t.c.TLSConfig - c, err = dialers.DialWithTls(p.Dialer(), cfg, "tcp", addr) - usingech = false - } - if c != nil && core.IsNotNil(c) { - if tlsConn, ok := c.(*tls.Conn); ok && usingech { - usingech = tlsConn.ConnectionState().ECHAccepted - } - return &dns.Conn{Conn: c}, who, usingech, err - } else { - err = core.OneErr(err, errNoNet) - log.W("dot: tlsdial: (%s) nil conn/err for %s, ech? %t; err? %v", - t.id, addr, usingech, err) - } - return nil, who, false, err -} - -func (t *dot) pxdial(pid string) (*dns.Conn, string, uintptr, bool, error) { - var px ipn.Proxy - if len(t.relay) > 0 { // relay takes precedence - pid = t.relay - } - if t.proxies != nil { // err if t.proxies is nil - var err error - if px, err = t.proxies.ProxyFor(pid); err != nil { - return nil, "", core.Nobody, false, err - } - } - if px == nil { - return nil, "", core.Nobody, false, dnsx.ErrNoProxyProvider - } - pid = px.ID() - rpid := ipn.ViaID(px) - if settings.Debug { - log.V("dot: pxdial: (%s) using relay/proxy %s (via: %s) at %s", - t.id, pid, rpid, px.GetAddr()) - } - - c, who, ech, err := t.tlsdial(px) - return c, rpid, who, ech, err -} - -// toPool takes ownership of c. -func (t *dot) toPool(id uintptr, c *dns.Conn) { - if !t.usepool || id == core.Nobody { - clos(c) - return - } - ok := t.pool.Put(id, c) - logwif(!ok)("dot: pool: (%s) put for %v; ok? %t", t.id, id, ok) -} - -// fromPool returns a conn from the pool, if available. -func (t *dot) fromPool(id uintptr) (c *dns.Conn) { - if !t.usepool || id == core.Nobody { - return - } - - pooled := t.pool.Get(id) - if pooled == nil || core.IsNil(pooled) { - return - } - var ok bool - if c, ok = pooled.(*dns.Conn); !ok { // unlikely - return &dns.Conn{Conn: pooled} - } - if settings.Debug { - log.V("dot: pool: (%s) got conn from %v", t.id, id) - } - return -} - -func clos(c net.Conn) { - core.CloseConn(c) -} - -func (t *dot) sendRequest(pid string, q *dns.Msg) (ans *dns.Msg, rpid string, ech bool, elapsed time.Duration, qerr *dnsx.QueryError) { - var err error - - if q == nil || !xdns.HasAnyQuestion(q) { - qerr = dnsx.NewBadQueryError(errQueryParse) - return - } - - var conn *dns.Conn - var who uintptr - userelay := len(t.relay) > 0 - useproxy := len(pid) != 0 // pid == dnsx.NetNoProxy => ipn.Block - if useproxy || userelay { // ref dns.Client.Dial - conn, rpid, who, ech, err = t.pxdial(pid) - } else { - err = dnsx.ErrNoProxyProvider - } - - if err == nil { - // tls config is not used with this exchange as conn is pre-supplied - ans, elapsed, err = t.c.ExchangeWithConnContext(t.ctx, q, conn) - } // fallthrough - - raddr := remoteAddrIfAny(conn) - if err != nil { - clos(conn) - ok := dialers.Disconfirm2(t.host, raddr) - log.V("dot: sendRequest: (%s) sz: %d, pad: %d, err: %v; disconfirm? %t %s => %s", - t.id, xdns.Size(q), xdns.EDNS0PadLen(q), err, ok, t.host, raddr) - qerr = dnsx.NewSendFailedQueryError(err) - } else if ans == nil { - t.toPool(who, conn) // or close - qerr = dnsx.NewBadResponseQueryError(errNoAns) - } else { - t.toPool(who, conn) // or close - dialers.Confirm2(t.host, raddr) - } - return -} - -func (t *dot) chooseProxy(pids ...string) string { - return dnsx.ChooseHealthyProxyHostPort("dot: "+t.id, t.addrport, t.port, pids, t.proxies) -} - -func (t *dot) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg, err error) { - var qerr *dnsx.QueryError - var elapsed time.Duration - var pid, rpid string - var ech bool - - if r := t.relay; len(r) > 0 { - pid = t.chooseProxy(r) - } else { - _, pids := xdns.Net2ProxyID(network) - pid = t.chooseProxy(pids...) - } - - ans, rpid, ech, elapsed, qerr = t.doQuery(pid, q) - - status := dnsx.Complete - if qerr != nil { - err = qerr.Unwrap() - status = qerr.Status() - log.W("dot: ans? %v err(%v) / ans(%d)", ans, err, xdns.Len(ans)) - } - t.status.Store(status) - - smm.Latency = elapsed.Seconds() - smm.RData = xdns.GetInterestingRData(ans) - smm.RCode = xdns.Rcode(ans) - smm.RTtl = xdns.RTtl(ans) - smm.Server = t.getAddr() - if ech { - smm.Server = dnsx.EchPrefix + smm.Server - } - smm.PID = pid // may be local dnsx.IsLocalProxy - smm.RPID = rpid // may be empty - if err != nil { - smm.Msg = err.Error() - } - smm.Status = status - t.est.Add(smm.Latency) - - if settings.Debug { - log.V("dot: %s ech? %t; len(res): fro %s:%d a:%d/sz:%d/pad:%d, data: %s / status: %d, via: %s, err? %v", - t.id, ech, smm.QName, smm.QType, xdns.Len(ans), xdns.Size(ans), xdns.EDNS0PadLen(ans), smm.RData, smm.Status, smm.PID, err) - } - - return -} - -func (t *dot) ID() string { - return t.id -} - -func (t *dot) Type() string { - return dnsx.DOT -} - -func (t *dot) P50() int64 { - return t.est.Get() -} - -func (t *dot) GetAddr() string { - return t.getAddr() -} - -func (t *dot) GetRelay() x.Proxy { - if r := t.relay; len(r) > 0 { - px, _ := t.proxies.ProxyFor(r) - return px - } - return nil -} - -func (t *dot) getAddr() (addr string) { - if t.echconfig.Load() != nil { - addr = dnsx.EchPrefix + t.addrport - } else if t.skipTLSVerify { - addr = dnsx.NoPkiPrefix + t.addrport - } else { - addr = t.addrport - } - return addr -} - -func (t *dot) IPPorts() (ipps []netip.AddrPort) { - for _, ip := range dialers.For(t.addrport) { - ipps = append(ipps, netip.AddrPortFrom(ip, t.port)) - } - return -} - -func (t *dot) Status() int { - if px := t.GetRelay(); px != nil { - if px.Status() == ipn.TPU { // relay paused => transport paused - return dnsx.Paused - } - } - return t.status.Load() -} - -func (t *dot) Stop() error { - t.status.Store(dnsx.DEnd) - t.done() - return nil -} - -func url2addrport(url string) (string, uint16) { - // url is of type "tls://host:port" or "tls:host:port" or "host:port" or "host" - if len(url) > 6 && url[:6] == "tls://" { - url = url[6:] - } - if len(url) > 4 && url[:4] == "tls:" { - url = url[4:] - } - port := DotPortU16 - // add port 853 if not present - if _, p, err := net.SplitHostPort(url); err != nil { - url = net.JoinHostPort(url, DotPort) - } else { - v, err := strconv.Atoi(p) - if err != nil && v > 0 { - port = uint16(v) - } - } - return url, port -} - -func (t *dot) getOrCreateEchConfigIfNeeded() *tls.Config { - echcfg := t.echconfig.Load() - if echcfg != nil { - return echcfg - } - - prev := t.echlastattempt.Load() - if time.Since(prev) < echRetryPeriod { - return nil - } - refetch := t.echlastattempt.Cas(prev, time.Now()) - if !refetch { - return nil - } - - if ech := t.ech(); len(ech) > 0 { - echcfg = &tls.Config{ - InsecureSkipVerify: t.skipTLSVerify, - MinVersion: tls.VersionTLS13, // must be 1.3 - EncryptedClientHelloConfigList: ech, - SessionTicketsDisabled: false, - ClientSessionCache: core.TlsSessionCache(), - EncryptedClientHelloRejectionVerify: t.echVerifyFn(), - } - t.echconfig.Store(echcfg) - } - - ok := echcfg != nil - logwif(!ok)("dot: %s fetch ech... ok? %t", t.id, ok) - return echcfg -} diff --git a/intra/dns53/dot_test.go b/intra/dns53/dot_test.go deleted file mode 100644 index 076206d0..00000000 --- a/intra/dns53/dot_test.go +++ /dev/null @@ -1,583 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "context" - "encoding/json" - "errors" - "log" - "net" - "net/netip" - "os" - "testing" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/doh" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/ipn/rpn" - ilog "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/x64" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -type fakeResolver struct { - *net.Resolver -} - -func (r fakeResolver) LocalLookup(q []byte) ([]byte, error) { - return r.Lookup(q, protect.UidSelf) -} - -func (r fakeResolver) Lookup(q []byte, _ string, _ ...string) ([]byte, error) { - // return nil, errors.New("lookup: not implemented") - msg := xdns.AsMsg(q) - if msg == nil { - return nil, errors.New("fakeresolver: nil dns msg") - } - if !xdns.HasAQuadAQuestion(msg) { - return nil, errors.New("fakeresolver: A/AAAA only") - } - qname := xdns.QName(msg) - network := "ip4" - if xdns.HasAAAAQuestion(msg) { - network = "ip6" - } - addrs, err := r.Resolver.LookupNetIP(context.TODO(), network, qname) - if err != nil { - return nil, err - } - // make a dns answer for addrs - ans := xdns.EmptyResponseFromMessage(msg) - if ans == nil { - return nil, errors.New("fakeresolver: nil pkt") - } - rrs := make([]dns.RR, 0) - for _, a := range addrs { - if network == "ip4" { - rr := xdns.MakeARecord(qname, a.String(), 30) - rrs = append(rrs, rr) - } else { - rr := xdns.MakeAAAARecord(qname, a.String(), 30) - rrs = append(rrs, rr) - } - } - ans.Answer = rrs - - return ans.Pack() -} - -func (r fakeResolver) LookupFor(q []byte, _ string) ([]byte, error) { - return r.LocalLookup(q) -} - -func (r fakeResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { - // return nil, errors.New("lookup net ip: not implemented") - return r.Resolver.LookupNetIP(ctx, network, host) -} - -func (r fakeResolver) LookupNetIPFor(ctx context.Context, network, host, uid string) ([]netip.Addr, error) { - // return nil, errors.New("lookup net ip for: not implemented") - return r.Resolver.LookupNetIP(ctx, network, host) -} - -func (r fakeResolver) LookupNetIPOn(ctx context.Context, network, host string, tid ...string) ([]netip.Addr, error) { - return nil, errors.New("fakeResolver: lookup net ip on not implemented") -} - -type fakeCtl struct { - protect.Controller -} - -func (*fakeCtl) Bind4(_, _ string, _ int) {} -func (*fakeCtl) Bind6(_, _ string, _ int) {} -func (*fakeCtl) Protect(_ string, _ int) {} - -type fakeObs struct { - x.ProxyListener -} - -func (*fakeObs) OnProxyAdded(string) {} -func (*fakeObs) OnProxyRemoved(string) {} -func (*fakeObs) OnProxiesStopped() {} - -type fakeBdg struct { - protect.Controller - x.DNSListener -} - -var ( - // baseNsOpts = &x.DNSOpts{PIDCSV: dnsx.NetBaseProxy, IPCSV: "", TIDCSV: x.CT + "test0"} - baseTab = &x.Tab{CID: "testcid", Block: false} - autoNsOpts = &x.DNSOpts{PIDCSV: x.RpnWin, IPCSV: "", TIDCSV: x.CT + "test0"} -) - -func (*fakeBdg) OnQuery(_, _, _ string, _ int) *x.DNSOpts { return autoNsOpts } -func (*fakeBdg) OnUpstreamAnswer(_ *x.DNSSummary, _ string) *x.DNSOpts { return nil } -func (*fakeBdg) OnResponse(*x.DNSSummary) {} -func (*fakeBdg) OnDNSAdded(string) {} -func (*fakeBdg) OnDNSRemoved(string) {} -func (*fakeBdg) OnDNSStopped() {} - -func (*fakeBdg) Route(a, b, c, d, e string) *x.Tab { return baseTab } -func (*fakeBdg) OnComplete(*x.ServerSummary) {} - -const minmtu = 1280 -const dualstack = settings.IP46 - -func TestDot(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("nil proxifier") - } - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - q := aquery("skysports.com") - q6 := aaaaquery("skysports.com") - q2 := aquery("yahoo.com") - q26 := aaaaquery("yahoo.com") - - b4, _ := q.Pack() - b6, _ := q6.Pack() - b24, _ := q2.Pack() - b26, _ := q26.Pack() - // smm := &x.DNSSummary{} - // smm6 := &x.DNSSummary{} - _ = xdns.NetAndProxyID("tcp", dnsx.NetBaseProxy) - // tr, _ := NewTLSTransport(ctx, "test0", "max.rethinkdns.com", []string{"213.188.216.9"}, pxr, ctl) - dtr, _ := NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - tr, _ := NewTransport(ctx, "test0", "1.0.0.2", "53", pxr) - if tr == nil || dtr == nil { - t.Fatal("nil dns transports") - } - - natpt := x64.NewNatPt() - resolv := dnsx.NewResolver(ctx, "10.111.222.3:53", dtr, bdg, natpt) - resolv.Add(tr) - r4, _, err := resolv.LocalLookup(b4) - r6, _, err6 := resolv.LocalLookup(b6) - _, _, _ = resolv.LocalLookup(b24) - _, _, _ = resolv.LocalLookup(b26) - time.Sleep(1 * time.Second) - _, _, _ = resolv.LocalLookup(b6) - if err != nil { - // log.Output(2, smm.Str()) - t.Fatal(err) - } - if err6 != nil { - // log.Output(2, smm6.Str()) - t.Fatal(err6) - } - ans := xdns.AsMsg(r4) - ans6 := xdns.AsMsg(r6) - if xdns.Len(ans) == 0 && xdns.Len(ans6) == 0 { - t.Fatal("no ans") - } - log.Output(10, xdns.Ans(ans)) - log.Output(10, xdns.Ans(ans6)) -} - -func TestProxyReaches(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("nil proxifier") - } - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - _ = xdns.NetAndProxyID("tcp", dnsx.NetBaseProxy) - tr, _ := NewTLSTransport(ctx, "test0", "1.1.1.1", nil, pxr) - dtr, _ := NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - if tr == nil || dtr == nil { - t.Fatal("nil dns transports") - } - - natpt := x64.NewNatPt() - resolv := dnsx.NewResolver(ctx, "10.111.222.3", dtr, bdg, natpt) - resolv.Add(tr) - - exit, _ := pxr.ProxyFor(ipn.Exit) - if exit == nil { - t.Fatal("proxy: exit proxy nil") - } - - c1, _ := exit.Dial("tcp", "google.com:443") - c2, _ := exit.Dial("tcp", "cloudflare.com:443") - c3, _ := exit.Dial("tcp", "microsoft.com:443") - core.Close(c1, c2, c3) - if ok := ipn.Reaches(exit, "auto:https"); !ok { - t.Fatal("does not reach auto:https (google/cloudflare/microsoft)") - } - if ok, err := ipn.IcmpReaches(exit, netip.MustParseAddrPort("34.245.245.138:153")); !ok { - t.Fatal(err) // always fails - } - t.Log("proxy reaches") -} - -func TestSEProxy(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("nil proxifier") - } - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - _ = xdns.NetAndProxyID("tcp", dnsx.NetBaseProxy) - - tr, _ := doh.NewTransport(ctx, "test0", "http://zero.rethinkdns.com/dns-query/", []string{"104.21.83.62"}, pxr) - dtr, _ := doh.NewTransport(ctx, x.Default, "http://zero.rethinkdns.com/dns-query/", []string{"172.67.214.246"}, pxr) - if tr == nil || dtr == nil { - t.Fatal("nil dns transports") - } - - natpt := x64.NewNatPt() - resolv := dnsx.NewResolver(ctx, "10.111.222.3:53", dtr, bdg, natpt) - resolv.Add(tr) - - if err := pxr.RegisterSE(); err != nil { - t.Fatal(err) - } - /*if ips, err := pxr.TestSE(); err != nil { - t.Fatal(err) - } else { - ilog.D("se: %v", ips) - }*/ - - autoNsOpts.PIDCSV = ipn.RpnSE - se, _ := pxr.ProxyFor(ipn.RpnSE) - if se == nil { - t.Fatal("proxy: se proxy nil") - } - - if ok := ipn.Reaches(se, "google.com", "tcp"); !ok { - t.Fail() - } - t.Log("proxy reaches") - - q := aquery("skysports.com") - q6 := aaaaquery("skysports.com") - - b4, _ := q.Pack() - b6, _ := q6.Pack() - - r4, _, err := resolv.LocalLookup(b4) - r6, _, err6 := resolv.LocalLookup(b6) - if err != nil { - // log.Output(2, smm.Str()) - t.Fatal(err) - } - if err6 != nil { - // log.Output(2, smm6.Str()) - t.Fatal(err6) - } - ans := xdns.AsMsg(r4) - ans6 := xdns.AsMsg(r6) - if xdns.Len(ans) == 0 && xdns.Len(ans6) == 0 { - t.Fatal("no ans") - } - log.Output(10, xdns.Ans(ans)) - log.Output(10, xdns.Ans(ans6)) -} - -func TestWgReaches(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("testwg: nil proxifier") - } - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - wgid := x.WG + "1111" - autoNsOpts.PIDCSV = wgid - - _ = xdns.NetAndProxyID("tcp", wgid) - - tr, _ := NewTLSTransport(ctx, "test0", "8.8.8.8", nil, pxr) - dtr, _ := NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - if tr == nil || dtr == nil { - t.Fatal("nil dns transports") - } - - natpt := x64.NewNatPt() - resolv := dnsx.NewResolver(ctx, "10.111.222.3:53", dtr, bdg, natpt) - resolv.Add(tr) - - wgconf, err := os.ReadFile("wg.conf") - if err != nil { - t.Fatal(err) - } - - // read wgconf json into regionalwgconf - - rwg := &rpn.RegionalWgConf{} - if err := json.Unmarshal(wgconf, rwg); err != nil { - t.Fatal(err) - } - - ilog.D("testwg: read wg: %s: %d", rwg.Name, len(wgconf)) - - confok := rwg.GenUapiConfig() - if !confok { - t.Fatal("testwg: gen uapi conf failed") - } - - win, err := pxr.AddProxy(wgid, rwg.UapiWgConf) - ko(t, err) - - ilog.D("testwg: setup %s: %d", rwg.Name, len(rwg.UapiWgConf)) - - if win == nil { - t.Fatal("testwg: nil main ws proxy") - } - - settings.SetAutoDialsParallel(false) - settings.SetAutoMode(settings.AutoModeRemote) - - propx, _ := pxr.ProxyFor(wgid) - if propx == nil { - t.Fatal("testwg: nil proxies") - } - - /*ilog.VV("-----------------------MAIN--------------------------") - ilog.I("proxies 1: %t; 2: %t, 3: %t", propx != nil, propx2 != nil, auto != nil) - if ok := ipn.Reaches(propx, "google.com:443", "tcp"); !ok { - t.Fail() - } - ilog.VV("-----------------------MXCO--------------------------") - if ok := ipn.Reaches(propx2, "cloudflare.com:443", "tcp"); !ok { - t.Fail() - } - ilog.VV("-----------------------AUTO--------------------------") - if ok := ipn.Reaches(auto, "x.com:443", "tcp"); !ok { - t.Fail() - }*/ - ilog.VV("-----------------------DNSX--------------------------") - b4, _ := aquery("skysports.com").Pack() - r4, _, err := resolv.LocalLookup(b4) // must use "test0" - - ilog.D("testwg: %v", win.Router().Stat()) - time.Sleep(2 * time.Second) - - ko(t, err) - - ans := xdns.AsMsg(r4) - if xdns.Len(ans) <= 0 { - t.Fatal("testwg: no ans") - } - ilog.D("dns %s", xdns.Ans(ans)) - ilog.VV("-----------------------END0--------------------------") - - t.Log("testwg: proxy reaches") -} - -func TestWinReaches(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("nil proxifier") - } - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - _ = xdns.NetAndProxyID("tcp", ipn.Auto) - - tr, _ := NewTLSTransport(ctx, "test0", "8.8.8.8", nil, pxr) - dtr, _ := NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - if tr == nil || dtr == nil { - t.Fatal("nil dns transports") - } - - natpt := x64.NewNatPt() - resolv := dnsx.NewResolver(ctx, "10.111.222.3:53", dtr, bdg, natpt) - resolv.Add(tr) - - readWinJson := true - entjson, err := os.ReadFile("win.json") - if err != nil { - readWinJson = false - entjson, err = os.ReadFile("ent.json") - } - ko(t, err) - - const did = "deadbeefdeadbeefdeadbeefdeadbeef" // some device id - ilog.D("ws: read ent (sess? %t): %d", readWinJson, len(entjson)) - if wreg, err := pxr.RegisterWin(entjson, did, nil); err != nil { - t.Fatal(err) - } else { - entjson = wreg - _ = os.WriteFile("win.json", entjson, 0644) // same as sess.json - ilog.D("ws: setup %d", len(entjson)) - } - - win, err := pxr.Win() - ko(t, err) - if win == nil { - t.Fatal("nil main ws proxy") - } - - const maxVisited = 10 - visited := make(map[string]struct{}, 0) - locs, err := win.Locations() - ko(t, err) - if locs == nil { - t.Fatalf("expected locations for %s", win.Who()) - } - for i := 0; i < locs.Len(); i++ { - c, err := locs.Get(i) - if err != nil { - continue - } - if _, ok := visited[c.CC]; !ok { - // _, _ = pxr.AddProxy(ipn.RpnPro+c.CC, c.UapiConfig()) - visited[c.CC] = struct{}{} - } - if len(visited) >= maxVisited { - break - } - } - ilog.I("available proxy CCs (limited to 10): %v", visited) - - _, err = win.Fork("US") - ko(t, err) - _, err = win.Fork("GT") - ko(t, err) - - settings.SetAutoDialsParallel(false) - settings.SetAutoMode(settings.AutoModeRemote) - - propx, _ := pxr.ProxyFor(ipn.RpnWin) - propx2, _ := pxr.ProxyFor(ipn.RpnWin + "GT") - auto, _ := pxr.ProxyFor(ipn.Auto) - if propx == nil || propx2 == nil || auto == nil { - t.Fatal("nil US/GT/Auto proxies") - } - - sess, err := win.State() - ko(t, err) - err = os.WriteFile("sess.json", sess, 0644) // same as win.json - ko(t, err) - - autoNsOpts.PIDCSV = ipn.RpnWin - /*ilog.VV("-----------------------MAIN--------------------------") - ilog.I("proxies 1: %t; 2: %t, 3: %t", propx != nil, propx2 != nil, auto != nil) - if ok := ipn.Reaches(propx, "google.com:443", "tcp"); !ok { - t.Fail() - } - ilog.VV("-----------------------MXCO--------------------------") - if ok := ipn.Reaches(propx2, "cloudflare.com:443", "tcp"); !ok { - t.Fail() - } - ilog.VV("-----------------------AUTO--------------------------") - if ok := ipn.Reaches(auto, "x.com:443", "tcp"); !ok { - t.Fail() - }*/ - ilog.VV("-----------------------DNSX--------------------------") - b4, _ := aquery("skysports.com").Pack() - r4, _, err := resolv.LocalLookup(b4) // must use "test0" - - ilog.D("%v", propx2.Router().Stat()) - time.Sleep(2 * time.Second) - - if err != nil { - t.Fatal(err) - } - - ans := xdns.AsMsg(r4) - if xdns.Len(ans) <= 0 { - t.Fatal("no ans") - } - ilog.D("dns", xdns.Ans(ans)) - ilog.VV("-----------------------END0--------------------------") - - t.Log("proxy reaches") -} - -func TestPinger(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - _ = &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("nil proxifier") - } - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - p, err := pxr.ProxyFor(ipn.Exit) - if err != nil || p == nil { - t.Fatal(err) - } - pc, err := p.Probe("udp", "0.0.0.0:0") - if err != nil || pc == nil { - t.Fatal(err) - } - ok, rtt, err := core.Ping(pc, netip.MustParseAddrPort("1.1.1.1:53")) - if !ok { - t.Fatalf("ping failed %v", err) - } - t.Log("ping rtt", rtt) -} - -func aquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeA) - msg.Id = 1234 - return msg -} - -func aaaaquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeAAAA) - msg.Id = 3456 - return msg -} - -func ko(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} diff --git a/intra/dns53/errorer.go b/intra/dns53/errorer.go deleted file mode 100644 index cad3ffea..00000000 --- a/intra/dns53/errorer.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "errors" - "net/netip" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -var errStubTransport = errors.New("dns: stub transport") - -// TODO: Keep a context here so that queries can be canceled. -type errorer struct { - id string - ipport string -} - -var _ dnsx.Transport = (*errorer)(nil) - -// NewErrorerTransport returns a DNS transport that always errors out on queries. -func NewErrorerTransport(id string) *errorer { - t := &errorer{ - id: id, // typically, dnsx.Fixed - ipport: "127.3.3.3:33", - } - log.I("errorer(%s) setup: %s", t.ID(), t.GetAddr()) - return t -} - -func (t *errorer) Query(_ string, q *dns.Msg, smm *x.DNSSummary) (*dns.Msg, error) { - smm.Latency = 0 - smm.RData = xdns.GetInterestingRData(nil) - smm.RCode = xdns.Rcode(nil) - smm.RTtl = xdns.RTtl(nil) - smm.Server = t.GetAddr() - smm.Status = t.Status() - smm.Msg = errStubTransport.Error() - - return nil, errStubTransport -} - -func (t *errorer) ID() string { - return t.id -} - -func (*errorer) Type() string { - return dnsx.DNS53 -} - -func (*errorer) P50() int64 { - return 0 -} - -func (t *errorer) GetAddr() string { - return t.ipport -} - -func (t *errorer) GetRelay() x.Proxy { - return nil -} - -func (t *errorer) IPPorts() []netip.AddrPort { - return dnsx.NoIPPort -} - -func (*errorer) Status() int { - return x.ClientError -} - -func (*errorer) Stop() error { - return nil -} diff --git a/intra/dns53/goos.go b/intra/dns53/goos.go deleted file mode 100644 index 7bc6140e..00000000 --- a/intra/dns53/goos.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "context" - "net" - "net/netip" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -type goosr struct { - ctx context.Context - done context.CancelFunc - r *net.Resolver - rcgo *net.Resolver - // dialer *protect.RDial - exit ipn.Proxy // the only supported proxy is ipn.Exit - - status *core.Volatile[int] -} - -var _ dnsx.Transport = (*goosr)(nil) - -const ttl30s = 30 // in secs - -// NewGoosTransport returns the default Go DNS resolver -func NewGoosTransport(pctx context.Context, pxs ipn.ProxyProvider) (t *goosr, err error) { - // cannot be nil, see: ipn.Exit which the only proxy guaranteed to be connected to the internet; - // ex: ipn.Base routed back within the tunnel (rethink's traffic routed back into rethink) - // but it doesn't work for goos because the traffic to localhost:53 is routed back in as if - // the destination is vpn's own "fake" dns (typically, at 10.111.222.3) - if pxs == nil { - return nil, dnsx.ErrNoProxyProvider - } - // d := protect.MakeNsRDial(dnsx.Goos, ctl) - px, err := pxs.ProxyFor(ipn.Exit) - if err != nil { - log.E("dns53: goosr: no exit proxy: %v", err) - return nil, err - } - ctx, cancel := context.WithCancel(pctx) - tx := &goosr{ - ctx: ctx, - done: cancel, - status: core.NewVolatile(x.Start), - // dialer: d, - exit: px, - } - tx.r = &net.Resolver{ - PreferGo: true, - Dial: tx.pxdial, // dials in to ipn.Exit, always - } - tx.rcgo = &net.Resolver{ // loopbacks into the tunnel in rinr mode - PreferGo: false, - } - log.I("dns53: goosr: setup done") - return tx, nil -} - -func (t *goosr) pxdial(ctx context.Context, network, addr string) (conn net.Conn, err error) { - // addr must be ip:port - log.VV("dns53: goosr: pxdial: using %s proxy for %s:%s => %s", - t.exit.ID(), network, t.exit.GetAddr(), addr) - return t.exit.Dialer().Dial(network, addr) -} - -func (t *goosr) send(msg *dns.Msg) (ans *dns.Msg, elapsed time.Duration, qerr *dnsx.QueryError) { - var err error - var ip netip.Addr - if msg == nil { - qerr = dnsx.NewBadQueryError(errQueryParse) - return - } - if qerr = dnsx.WillErr(t); qerr != nil { - return - } - - start := time.Now() - - host := xdns.QName(msg) - // TODO: zero length host must return NS records for the root zone - if len(host) <= 0 || host == "." { - qerr = dnsx.NewBadQueryError(errNoHost) - elapsed = time.Since(start) - ans = xdns.Servfail(msg) - return - } - - if ip, err = str2ip(host); err == nil { - log.V("dns53: goosr: no-op; host %s is ipaddr", host) - ans, err = xdns.AQuadAForQuery(msg, ip) - } else { - aquadaq := xdns.HasAQuadAQuestion(msg) - - if !aquadaq { // TODO: support queries other than A/AAAA - log.E("dns53: goosr: not A/AAAA query type for %d:%s", xdns.QType(msg), host) - ans = xdns.Servfail(msg) - err = errQueryParse - } else { - proto := "ip4" - if xdns.HasAAAAQuestion(msg) { - proto = "ip6" - } - - if settings.Loopingback.Load() { - if ips, errl := t.r.LookupNetIP(t.ctx, proto, host); errl == nil && xdns.HasAnyAnswer(msg) { - log.D("dns53: goosr: go resolver (why? %v) for %s => %s", errl, host, ips) - ans, err = xdns.AQuadAForQueryTTL(msg, ttl30s, ips...) - } else { - err = errl - } - } else { - if ips, errc := t.rcgo.LookupNetIP(t.ctx, proto, host); errc == nil { - log.D("dns53: goosr: cgo resolver for %s => %s", host, ips) - ans, err = xdns.AQuadAForQueryTTL(msg, ttl30s, ips...) - } else { - err = errc - } - } - // TODO: if len(ips) <= 0 synthesize a NXDOMAIN? - } - } - - elapsed = time.Since(start) - if err != nil { - qerr = dnsx.NewSendFailedQueryError(err) - return - } - - return -} - -func (t *goosr) Query(_ string, q *dns.Msg, smm *x.DNSSummary) (r *dns.Msg, err error) { - r, elapsed, qerr := t.send(q) - if qerr != nil { // only on send-request errors - r = xdns.Servfail(q) - } - - status := dnsx.Complete - if qerr != nil { - err = qerr.Unwrap() - status = qerr.Status() - log.W("dns53: goosr: err(%v) / size(%d)", qerr, xdns.Len(r)) - } - t.status.Store(status) - - smm.Latency = elapsed.Seconds() - smm.RData = xdns.GetInterestingRData(r) - smm.RCode = xdns.Rcode(r) - smm.RTtl = xdns.RTtl(r) - smm.Server = t.getAddr() - smm.Status = status - smm.PID = t.exit.ID() - if err != nil { - smm.Msg = err.Error() - } - - log.V("dns53: goosr: len(res): %d, data: %s, err? %v", - xdns.Len(r), smm.RData, err) - - return r, err -} - -func (t *goosr) ID() string { - return dnsx.Goos -} - -func (t *goosr) Type() string { - return dnsx.DNS53 -} - -func (t *goosr) P50() int64 { - return 1 // always fast -} - -func (t *goosr) GetAddr() string { - return t.getAddr() -} - -func (t *goosr) getAddr() string { - return protect.Localhost + ":53" // dummy -} - -func (t *goosr) GetRelay() x.Proxy { - return nil -} - -func (t *goosr) IPPorts() []netip.AddrPort { - return []netip.AddrPort{netip.AddrPortFrom(netip.IPv6Loopback(), uint16(53))} -} - -func (t *goosr) Status() int { - return t.status.Load() -} - -func (t *goosr) Stop() error { - t.status.Store(dnsx.DEnd) - t.done() - return nil -} diff --git a/intra/dns53/grounded.go b/intra/dns53/grounded.go deleted file mode 100644 index 75129dd5..00000000 --- a/intra/dns53/grounded.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "net/netip" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -// TODO: Keep a context here so that queries can be canceled. -type grounded struct { - id string - ipport string - status int -} - -var _ dnsx.Transport = (*grounded)(nil) - -// NewGroundedTransport returns a DNS transport that blocks all DNS queries. -func NewGroundedTransport(id string) (t dnsx.Transport) { - t = &grounded{ - id: id, // typically, dnsx.BlockAll - ipport: "127.0.0.3:53", - status: dnsx.Start, - } - log.I("grounded(%s) setup: %s", t.ID(), t.GetAddr()) - return -} - -func (t *grounded) Query(_ string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg, err error) { - ans, err = xdns.RefusedResponseFromMessage(q) - if err != nil { - t.status = x.BadResponse - } else { - t.status = x.Complete - } - elapsed := 0 * time.Second - smm.Latency = elapsed.Seconds() - smm.RData = xdns.GetInterestingRData(ans) - smm.RCode = xdns.Rcode(ans) - smm.RTtl = xdns.RTtl(ans) - smm.Server = t.ipport - smm.Status = t.Status() - if err != nil { - smm.Msg = err.Error() - } - - return ans, err -} - -func (t *grounded) ID() string { - return t.id -} - -func (t *grounded) Type() string { - return dnsx.DNS53 -} - -func (t *grounded) P50() int64 { - return 0 -} - -func (t *grounded) GetAddr() string { - return t.ipport -} - -func (t *grounded) GetRelay() x.Proxy { - return nil -} - -func (t *grounded) IPPorts() []netip.AddrPort { - return dnsx.NoIPPort -} - -func (t *grounded) Status() int { - return t.status -} - -func (*grounded) Stop() error { - return nil -} diff --git a/intra/dns53/ipmapper.go b/intra/dns53/ipmapper.go deleted file mode 100644 index d79c0ee1..00000000 --- a/intra/dns53/ipmapper.go +++ /dev/null @@ -1,431 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "context" - "errors" - "net/netip" - "strconv" - "strings" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/protect/ipmap" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -const battl = 10 * time.Second - -var ( - errNoHost = errors.New("no hostname") - errNoAns = errors.New("no answer") - errNoNet = errors.New("unknown network") - - loopback4 = netip.AddrFrom4([4]byte{127, 0, 0, 1}) - loopback6 = netip.IPv6Loopback() -) - -type answer struct { - a []byte - tid, uid string -} - -type ipmapper struct { - id string - r dnsx.ResolverSelf - g dnsx.Gateway - ba *core.Barrier[answer, string] -} - -var _ ipmap.IPMapper = (*ipmapper)(nil) - -// AddIPMapper adds or removes the IPMapper. -func AddIPMapper(r dnsx.Resolver, protos string, clear bool) { - var m ipmap.IPMapper // nil - ok := r != nil - if ok { - m = &ipmapper{ - id: dnsx.IpMapper, - r: r, - g: r.Gateway(), - ba: core.NewBarrier[answer](battl), - } - } // else remove; m is nil - if clear { - dialers.Clear() // note: clears ipset async - } - dialers.Mapper(m) - dialers.IPProtos(protos) -} - -func str2ip(host string) (netip.Addr, error) { - return netip.ParseAddr(host) -} - -// Implements IPMapper. -func (m *ipmapper) LocalLookup(q []byte) ([]byte, error) { - return m.Lookup(q, protect.UidSelf) -} - -// Implements IPMapper. -func (m *ipmapper) Lookup(q []byte, uid string, tids ...string) ([]byte, error) { - return m.queryAny2(q, uid, tids...) -} - -// Implements IPMapper. -func (m *ipmapper) LookupFor(q []byte, uid string) ([]byte, error) { - return m.queryAny2(q, uid) -} - -// Implements IPMapper. -func (m *ipmapper) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { - return m.queryIP(ctx, network, host, core.UNKNOWN_UID_STR) -} - -// Implements IPMapper. -func (m *ipmapper) LookupNetIPFor(ctx context.Context, network, host, uid string) ([]netip.Addr, error) { - return m.queryIP(ctx, network, host, uid) -} - -// Implements IPMapper. -func (m *ipmapper) LookupNetIPOn(ctx context.Context, network, host string, tid ...string) ([]netip.Addr, error) { - return m.queryIP2(ctx, network, host, protect.UidSelf, tid...) -} - -func (m *ipmapper) queryIP(ctx context.Context, network, host string, uid string) ([]netip.Addr, error) { - return m.queryIP2(ctx, network, host, uid) -} - -// todo: use context -func (m *ipmapper) queryIP2(_ context.Context, network, host, uid string, tid ...string) ([]netip.Addr, error) { - if len(host) <= 0 { - return nil, errNoHost - } - if protect.NeverResolve(host) { - return nil, nil - } - if host == protect.Localhost || host == "localhost." { - return []netip.Addr{loopback4, loopback6}, nil - } - // no lookups when host is already an IP - if ip, err := str2ip(host); err == nil { - log.V("ipmapper: lookup: no-op; host %s is ipaddr", host) - return []netip.Addr{ip}, nil - } - - if settings.Debug { - log.V("ipmapper: lookup: host %s:%s for %s on %v", network, host, uid, tid) - } - - var q4, q6 []byte - var err4, err6 error - switch network { - case "ip": - q4, err4 = dnsmsg(host, dns.TypeA) - q6, err6 = dnsmsg(host, dns.TypeAAAA) - case "ip4": - q4, err4 = dnsmsg(host, dns.TypeA) - case "ip6": - q6, err6 = dnsmsg(host, dns.TypeAAAA) - default: - log.E("ipmapper: lookup: unknown net %s query %s", network, host) - return nil, errNoNet - } - - if err4 != nil || err6 != nil { - errs := core.JoinErr(err4, err6) - log.E("ipmapper: lookup: query %s err %v", host, errs) - return nil, errs - } - - var val4, val6 *core.V[answer, string] - if len(tid) > 0 { // always choose one among these tids - val4, _ = m.ba.Do(key4(host, tid...), m.lookupon(q4, uid, tid...)) - val6, _ = m.ba.Do(key6(host, tid...), m.lookupon(q6, uid, tid...)) - } else if uid != core.UNKNOWN_UID_STR { // client code chooses a tid depending on uid & "origin" - val4, _ = m.ba.Do(key4(host, uid), m.lookupfor(q4, uid)) - val6, _ = m.ba.Do(key6(host, uid), m.lookupfor(q6, uid)) - } else { // either Default or System/Goos - val4, _ = m.ba.Do(key4(host, dnsx.Default), m.locallookup(q4)) - val6, _ = m.ba.Do(key6(host, dnsx.Default), m.locallookup(q6)) - } - - var noval4, noval6 bool - var r4, r6 []byte - var tid4, tid6 string - var lerr4, lerr6 error - if val4 == nil { - noval4 = true - } else { - noval4 = len(val4.Val.a) <= 0 - r4 = val4.Val.a // may be nil - lerr4 = val4.Err // may be nil - tid4 = val4.Val.tid // may be empty - } - if val6 == nil { - noval6 = true - } else { - noval6 = len(val6.Val.a) <= 0 - r6 = val6.Val.a // may be nil - lerr6 = val6.Err // may be nil - tid6 = val6.Val.tid // may be empty - } - - if lerr4 != nil && lerr6 != nil { // all errors - errs := core.JoinErr(lerr4, lerr6) - log.E("ipmapper: lookup: %s: err %v", host, errs) - return nil, errs - } else if noval4 && noval6 { // typecast failed or no answer - log.E("ipmapper: lookup: no answers for %s; len(4)? %d len(6)? %d", host, len(r4), len(r6)) - return nil, errNoAns - } else if len(r4) <= 0 && len(r6) <= 0 { // empty answer - errs := core.JoinErr(errNoAns, lerr4, lerr6) - log.E("ipmapper: lookup: no answers for %s (by: %s+%s), err %v", host, tid4, tid6, errs) - return nil, errs - } - - _, ip4 := addrs(r4) - _, ip6 := addrs(r6) - ip4 = m.undoAlgAndOrNat64(ip4, tid4, uid) - ip6 = m.undoAlgAndOrNat64(ip6, tid6, uid) // nat64 cannot really be "undone" for ip6! - ips := append(ip4, ip6...) - - if settings.Debug { - log.D("ipmapper: host %s => ips (out: %v / in: %d+%d); uid: %s, tids: %s+%s; err4: %v, err6: %v", - host, ips, len(r4), len(r6), uid, tid4, tid6, lerr4, lerr6) - } - return ips, nil -} - -func (m *ipmapper) queryAny2(q []byte, uid string, tids ...string) ([]byte, error) { - msg := xdns.AsMsg(q) - if msg == nil { - log.W("ipmapper: not a dns query sz(%d)", len(q)) - return nil, errQueryParse - } - qname := xdns.QName(msg) - if len(qname) <= 0 { - log.W("ipmapper: query: no qname") - return nil, errNoHost - } - qtype := int(xdns.QType(msg)) - qtypestr := strconv.Itoa(qtype) - - if settings.Debug { - log.V("ipmapper: lookup: host %s, uid: %v", qname, uid) - } - - var v *core.V[answer, string] - if len(tids) > 0 { - v, _ = m.ba.Do(key(qname, qtypestr, tids...), m.lookupon(q, uid, tids...)) - } else if uid != core.UNKNOWN_UID_STR { - v, _ = m.ba.Do(key(qname, qtypestr, uid), m.lookupfor(q, uid)) - } else { - v, _ = m.ba.Do(key(qname, qtypestr, dnsx.Default), m.locallookup(q)) - } - - if v == nil || len(v.Val.a) <= 0 || v.Err != nil { - log.W("ipmapper: query: noans? %t [err %v] for %s / typ %d; for: %s [on %v]", - v == nil, v.Err, qname, qtype, uid, tids) - return nil, core.OneErr(v.Err, errNoAns) - } - - return m.undoAlg(v.Val.a, v.Val.tid, uid) -} - -// lookupfor resolves q given a uid. If uid is protect.SelfUid, the client -// code (via DNSListener.OnQuery) may or may not choose dnsx.Default. If uid -// is any other "integer" including "-1" (core.UNKNOWN_UID_STR), the client -// code is free to choose a transport as it sees fit. -func (m *ipmapper) lookupfor(q []byte, uid string) func() (answer, error) { - return func() (answer, error) { - a, tid, err := m.r.LookupFor(q, uid) - return answer{a, tid, uid}, err - } -} - -// lookupon always resolves on one of the chosen tids -// (if empty, it may or may not use dnsx.Default; -// see: dnsx.transport.go:determineTransport) -// uid may be protect.UidSelf or unknown -func (m *ipmapper) lookupon(q []byte, uid string, tids ...string) func() (answer, error) { - return func() (answer, error) { - a, tid, err := m.r.LookupFor2(q, uid, tids...) - return answer{a, tid, uid}, err - } -} - -// locallookup resolves on dnsx.Default and then on dnsx.System or dnsx.Goos -// if dnsx.Default fails. -func (m *ipmapper) locallookup(q []byte) func() (answer, error) { - return func() (answer, error) { - a, tid, err := m.r.LocalLookup(q) - return answer{a, tid, protect.UidSelf}, err - } -} - -func (m *ipmapper) undoAlg(ans []byte, tid, uid string) ([]byte, error) { - gw := m.g - if gw == nil { - if settings.Debug { - log.V("ipmapper: undoAlg: no-op for %s[%s]; no gateway", tid, uid) - } - return ans, nil - } - - msg := &dns.Msg{} - if err := msg.Unpack(ans); err != nil { - log.W("ipmapper: undoAlg: unpack err %v", err) - return ans, nil - } - - qname, possiblyalgips := addrs(ans) // usually only 1 if alg'd - - noips := len(possiblyalgips) <= 0 - is4 := xdns.HasAAnswer(msg) - is6 := !is4 && xdns.HasAAAAQuestion(msg) - - if !is4 && !is6 || noips { - if settings.Debug { - log.VV("ipmapper: undoAlg: no a? (%t), aaaa? (%t), ans? (%t); no-op", - !is4, !is6, noips) - } - return ans, nil - } - - var realips []netip.Addr - var undidAlg bool - for _, maybealgip := range possiblyalgips { - if ips, undid := gw.X(maybealgip, uid, tid); undid { - // expecting homogeneous addr family; ie, all realips - // to be either v4 or v6 - realips = append(realips, ips...) - undidAlg = true - } - } - - if len(realips) <= 0 { - logwif(undidAlg)("ipmapper: undoAlg: no algip => realip; return orig (qname: %s / ips: %d / undidAlg? %t); tid? %s[%s]", - qname, len(possiblyalgips), undidAlg, tid, uid) - // TODO: return error if undidAlg == true? - return ans, nil - } - - var msgout *dns.Msg - var didTranslate bool - if is4 { - msgout, didTranslate = xdns.TranslateRecords(msg, dns.TypeA, func(r dns.RR) (rx []dns.RR, done bool) { - for _, ip4 := range realips { - if x := xdns.CloneA(r, ip4); x != nil { - rx = append(rx, x) - } - } - return rx, len(rx) > 0 // a single translated rrs is enough - }) - } else if is6 { - msgout, didTranslate = xdns.TranslateRecords(msg, dns.TypeAAAA, func(r dns.RR) (rx []dns.RR, done bool) { - for _, ip6 := range realips { - if x := xdns.CloneAAAA(r, ip6); x != nil { - rx = append(rx, x) - } - } - return rx, len(rx) > 0 // a single translated rrs is enough - }) - } // else: msgout is nil - - logwif(!didTranslate || msgout == nil)("ipmapper: undoAlg: %s => ips (out: %v / in: %d); tids: %s[%s]", - qname, realips, xdns.Len(msgout), tid, uid) - - if msgout != nil { - return msgout.Pack() - } - return ans, nil -} - -func (m *ipmapper) undoAlgAndOrNat64(ip64 []netip.Addr, tid, uid string) []netip.Addr { - // unlike common.go:undoAlg, we do not filter out ipaddrs - // based on dialers.Use4/Use6. This is because the ipmapper - // is used for DNS queries, and the dialers are used for - // actual connections. The dialers will filter out ipaddrs - // based on the dialers.Use4/Use6 settings. - gw := m.g - if gw == nil { - if settings.Debug { - log.V("ipmapper: undoAlg: no-op for %v on %s[%s]; no gateway", ip64, tid, uid) - } - return ip64 - } - realips := make([]netip.Addr, 0, len(ip64)) - for _, addr := range ip64 { - if xips, undidAlg := gw.X(addr, uid, tid); len(xips) > 0 { - // may contain duplicates due to how alg maps domains and ips - realips = append(realips, xips...) - continue // skip log.W below - } else { - log.W("ipmapper: undoAlg: no algip => realip? (%s => %v); undidAlg? %t; tid? %s[%s]", - addr, xips, undidAlg, tid, uid) - } - } - if len(realips) <= 0 { - log.W("ipmapper: undoAlg: no algip => realip; return orig (%v); tid? %s[%s]", - ip64, tid, uid) - return core.CopyUniq(ip64) - } - return realips // no dups -} - -func key(name string, typ string, oth ...string) string { - if len(oth) <= 0 { - return name - } - return name + ":" + typ + ":" + strings.Join(oth, ":") -} - -func key4(name string, oth ...string) string { - return key(name, "ip4", oth...) -} - -func key6(name string, oth ...string) string { - return key(name, "ip6", oth...) -} - -// TODO: handle HTTPS/SVCB -func addrs(a []byte) (qname string, ips []netip.Addr) { - msg := xdns.AsMsg(a) - if msg == nil { - return - } - ips = make([]netip.Addr, 0, len(msg.Answer)) - for _, a := range msg.Answer { - switch rr := a.(type) { - case *dns.A: - if ip4, ok := netip.AddrFromSlice(rr.A); ok { - ips = append(ips, ip4.Unmap()) - } - case *dns.AAAA: - if ip6, ok := netip.AddrFromSlice(rr.AAAA); ok { - ips = append(ips, ip6) - } - case *dns.CNAME: - log.V("ipmapper: cname %s => %s", rr.Hdr.Name, rr.Target) - default: - log.V("ipmapper: unexpected ans type: %v... skip", rr) - } - } - return xdns.QName(msg), ips -} - -func dnsmsg(host string, qtype uint16) ([]byte, error) { - return xdns.Question(host, qtype) -} diff --git a/intra/dns53/mdns.go b/intra/dns53/mdns.go deleted file mode 100644 index 5e4c6722..00000000 --- a/intra/dns53/mdns.go +++ /dev/null @@ -1,655 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (c) HashiCorp, Inc. - -package dns53 - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - "runtime/debug" - "strings" - "sync" - "sync/atomic" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -var ( - errNoProtos = errors.New("enable at least one of IPv4 and IPv6 querying") - errBindFail = errors.New("failed to bind to udp port") - errNoMdnsQuery = errors.New("no mdns query") - errNoMdnsAnswer = errors.New("no mdns answer") -) - -type dnssd struct { - ctx context.Context - done context.CancelFunc - dialer ipn.Proxy - id string // ID of this transport - ipport string // IP:Port queries are sent to (v4) - use4 atomic.Bool // Use IPv4 - use6 atomic.Bool // Use IPv6 - status *core.Volatile[int] // Status of this transport - est core.P2QuantileEstimator -} - -var _ dnsx.MDNSTransport = (*dnssd)(nil) - -// NewMDNSTransport returns a DNS transport that sends all DNS queries to mDNS endpoint. -func NewMDNSTransport(pctx context.Context, protos string, pxr ipn.ProxyProvider) *dnssd { - ctx, done := context.WithCancel(pctx) - - // mdns always dials exit (and never loops back) - exit, err := pxr.ProxyFor(ipn.Exit) - if err != nil { - log.E("mdns: no exit proxy: %v", err) - done() - return nil - } - - t := &dnssd{ - ctx: ctx, - done: done, - id: dnsx.Local, - dialer: exit, - ipport: xdns.MDNSAddr4.String(), // ip6: ff02::fb:5353 - status: core.NewVolatile(dnsx.Start), - est: core.NewP50Estimator(ctx), - } - t.use4.Store(use4(protos)) - t.use6.Store(use6(protos)) - log.I("mdns: setup: %s", protos) - return t -} - -func use4(l3 string) bool { - switch l3 { - case settings.IP4, settings.IP46: - return true - default: - return false - } -} - -func use6(l3 string) bool { - switch l3 { - case settings.IP6, settings.IP46: - return true - default: - return false - } -} - -func (t *dnssd) RefreshProto(protos string) { - n4 := use4(protos) - n6 := use6(protos) - o4 := t.use4.Swap(n4) - o6 := t.use6.Swap(n6) - log.I("mdns: proto change: %s; 4(%s => %s) 6(%s => %s)", protos, o4, n4, o6, n6) -} - -func (t *dnssd) oneshotQuery(msg *dns.Msg) (*dns.Msg, *dnsx.QueryError) { - if qerr := dnsx.WillErr(t); qerr != nil { - return nil, qerr - } - service, tld := xdns.ExtractMDNSDomain(msg) - // always buffered; otherwise c.listen may block on writes into ansch / resch. - // go.dev/play/p/gzwnGAFlTDV - resch := make(chan *dnssdanswer, 32) - qctx := &qcontext{ - msg: msg, - svc: service, - tld: tld, - ansch: resch, - // sec 5 rfc6762 for oneshot queries - timeout: time.Second * 3, - // sec 5.4 rfc6762 multicast flooding - unicastonly: true, - } - - var c *client - var err error - qname := fmt.Sprintf("%s.%s", service, tld) - log.D("mdns: oquery: %s", qname) - - if c, err = t.newClient(true); err != nil { - log.E("mdns: oquery: underlying transport: %s", err) - return nil, dnsx.NewTransportQueryError(err) - } - defer core.Close(c) - if qerr := c.query(qctx); qerr != nil { - log.E("mdns: oquery(%s): %v", qname, qerr) - return nil, qerr - } - log.D("mdns: oquery: awaiting response %s", qname) - // return the first response from channel qctx.ansch (same as resch) - for res := range resch { - if res != nil && res.ans != nil { - log.I("mdns: oquery: q(%s) ans(%s) 4(%s) 6(%s)", qname, res.name, res.ip4, res.ip6) - // todo: multiple answers? - return res.ans, nil - } else { - log.D("mdns: oquery: q(%s); ans missing for %v", qname, res) - } - } - log.I("mdns: oquery: no response for %s", qname) - return nil, dnsx.NewNoResponseQueryError(errNoMdnsAnswer) -} - -func (t *dnssd) Query(_ string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg, err error) { - smm.ID = t.id - smm.Type = dnsx.DNS53 - smm.Server = t.ipport - - defer func() { - log.D("mdns: err: %v; summary: %s", err, smm) - }() - - start := time.Now() - - if q == nil || !xdns.HasAnyQuestion(q) { - smm.Status = dnsx.BadQuery - t.status.Store(dnsx.BadQuery) - return - } - - ans, qerr := t.oneshotQuery(q) - if qerr != nil { - err = qerr.Unwrap() - t.status.Store(qerr.Status()) - } else { - t.status.Store(dnsx.Complete) - } - - elapsed := time.Since(start) - smm.Latency = elapsed.Seconds() - smm.RData = xdns.GetInterestingRData(ans) - smm.RCode = xdns.Rcode(ans) - smm.RTtl = xdns.RTtl(ans) - smm.Status = t.Status() - if err != nil { - smm.Msg = err.Error() - } - t.est.Add(smm.Latency) - - return ans, err -} - -func (t *dnssd) ID() string { - return t.id -} - -func (t *dnssd) Type() string { - return dnsx.DNS53 -} - -func (t *dnssd) P50() int64 { - return t.est.Get() -} - -func (t *dnssd) GetAddr() string { - return t.ipport -} - -func (t *dnssd) GetRelay() x.Proxy { - return nil -} - -func (t *dnssd) IPPorts() []netip.AddrPort { - return []netip.AddrPort{ - xdns.MDNSAddr4.AddrPort(), - xdns.MDNSAddr6.AddrPort(), - } -} - -func (t *dnssd) Status() int { - return t.status.Load() -} - -func (t *dnssd) Stop() error { - t.status.Store(dnsx.DEnd) - t.done() - return nil -} - -// from: github.com/hashicorp/mdns/blob/5b0ab6d61/client.go - -// dnssdanswer is returned after dnssd / mdns query -type dnssdanswer struct { - ans *dns.Msg - name string - target string - ip4 net.IP - ip6 net.IP - port int - txt []string - captured bool -} - -// hasip checks if we have all the ip recs we need -func (s *dnssdanswer) hasip() bool { - return (s.ip4 != nil || s.ip6 != nil) -} - -// hassvc checks if we have all the srv recs we need -func (s *dnssdanswer) hassvc() bool { - return s.port != 0 && len(s.txt) > 0 -} - -// qcontext customizes how a mdns lookup is performed -type qcontext struct { - svc string // Service to query for, ex: _foobar._tcp, normalized to lower case - tld string // If blank, assumes "local" - msg *dns.Msg // If not nil, use this message instead of building one - timeout time.Duration // Lookup timeout - ansch chan<- *dnssdanswer // answers acc, must be non-blocking (buffered) - unicastonly bool // Unicast response desired, as per 5.4 in RFC -} - -// Client provides a query interface that can be used to -// search for service providers using mDNS -type client struct { - use4 bool - use6 bool - - unicast4, unicast6 net.PacketConn - multicast4, multicast6 net.PacketConn - - tmu sync.RWMutex // protects tracker map - tracker map[string]*dnssdanswer - msgCh chan *dns.Msg // never closed - - oneshot bool - - once sync.Once - - // mutable fields - - closed atomic.Bool // 0: open, 1: closed -} - -// String implements fmt.Stringer -func (c *client) String() string { - if c == nil { - return "" - } - return fmt.Sprintf("use4/6? %t/%t; oneshot? %t; tracked %d; closed %t", - c.use4, c.use6, c.oneshot, len(c.tracker), c.closed.Load()) -} - -// newClient creates a new mdns unicast and multicast client -func (t *dnssd) newClient(oneshot bool) (*client, error) { - use4 := t.use4.Load() - use6 := t.use6.Load() - if !use4 && !use6 { - return nil, errNoProtos - } - - var uconn4, uconn6 net.PacketConn // bind to higher port for unicast - var mconn4, mconn6 net.PacketConn // bind to port 5353 for multicast - var err error - - if use4 { - uconn4, err = t.dialer.Announce("udp4", "0.0.0.0:0") - if err != nil { - log.E("mdns: new-client: unicast4 bind fail: %v", err) - } - if !oneshot { - // won't work when in rethink-within-rethink (loopback) mode - // TODO: add support for multicast in protect.RDialer / ipn.Proxy - mconn4, err = net.ListenMulticastUDP("udp4", nil, xdns.MDNSAddr4) - if err != nil { - log.E("mdns: new-client: multicast4 bind fail: %v", err) - } - } - } - - if use6 { - uconn6, err = t.dialer.Announce("udp6", "[::]:0") - if err != nil { - log.E("mdns: new-client: unicast6 bind fail: %v", err) - } - if !oneshot { - // TODO: add support for multicast in protect.RDialer / ipn.Proxy - mconn6, err = net.ListenMulticastUDP("udp6", nil, xdns.MDNSAddr6) - if err != nil { - log.E("mdns: new-client: multicast6 bind fail: %v", err) - } - } - } - - has4 := use4 && uconn4 != nil && (oneshot || mconn4 != nil) - has6 := use6 && uconn6 != nil && (oneshot || mconn6 != nil) - if !has4 && !has6 { - log.E("mdns: new-client: oneshot? %t with no4? %t / no6? %t", oneshot, has4, has6) - return nil, errBindFail - } - - c := &client{ - use4: use4, - use6: use6, - multicast4: mconn4, // nil if oneshot - multicast6: mconn6, // nil if oneshot - unicast4: uconn4, - unicast6: uconn6, - tracker: make(map[string]*dnssdanswer), - msgCh: make(chan *dns.Msg, 32), - oneshot: oneshot, - } - return c, nil -} - -// Close cleanups the client -func (c *client) Close() error { - if c.closed.Load() { - return nil // already closed - } - c.once.Do(func() { - c.closed.Store(true) - log.I("mdns: closing client %s", c) - - core.CloseConn(c.unicast4) - core.CloseConn(c.unicast6) - core.CloseConn(c.multicast4) - core.CloseConn(c.multicast6) - }) - - return nil -} - -// query is used to perform a lookup and stream results -func (c *client) query(qctx *qcontext) *dnsx.QueryError { - if !xdns.HasAnyQuestion(qctx.msg) { - return dnsx.NewBadQueryError(errNoMdnsQuery) - } - - if c.use4 { - go c.recv(c.unicast4) - go c.recv(c.multicast4) - } - if c.use6 { - go c.recv(c.unicast6) - go c.recv(c.multicast6) - } - - q := qctx.msg - q.RecursionDesired = false - // RFC 6762, section 18.12. - // - // In the Question Section of a Multicast DNS query, the top bit of the qclass - // field is used to indicate that unicast responses are preferred for this - // particular question. (See Section 5.4.) - if !c.oneshot && qctx.unicastonly && len(q.Question) > 0 { - q.Question[0].Qclass |= 1 << 15 - } - if err := c.send(q); err != nil { - log.E("mdns: query: send query(%s) fail: err(%v)", qctx.svc, err) - return err - } - - core.Go("mdns.listen", func() { c.listen(qctx) }) - - log.D("mdns: query: waiting for ans to %s", qctx.svc) - return nil -} - -// listen listens for answers to the MDNS query, and sends them to qctx.ansch, -// and stops listening after qctx.timeout or the client is closed. -// Must be called from a goroutine. -func (c *client) listen(qctx *qcontext) { - timesup := time.After(qctx.timeout) - qname := fmt.Sprintf("%s.%s.", qctx.svc, qctx.tld) - total := 0 - defer close(qctx.ansch) -loop: - for { - select { - case msg, ok := <-c.msgCh: - if !ok { - // stackoverflow.com/a/13666733 - log.W("mdns: listen: msg channel for %s closed", qname) - break loop - } - var disco *dnssdanswer - xxlans := append(msg.Answer, msg.Extra...) - for _, ans := range xxlans { - ansname, aerr := xdns.AName(ans) - tracked := c.isTracked(ansname) - // expect answers only for the service name client queried for, or - // an already tracked alias (ex: cname targets) - if (aerr != nil) || (c.oneshot && !strings.Contains(ansname, qctx.svc) && !tracked) { - log.V("mdns: listen: ignoring %s ans for %s svc; tracked? %t; err? %v", ansname, qctx.svc, tracked, aerr) - continue - } - log.D("mdns: listen: processing %s ans for %s", ansname, qname) - switch rr := ans.(type) { - case *dns.PTR: - // create new entry for this - disco = c.track(rr.Ptr) - case *dns.SRV: - // check for a target mismatch - if rr.Target != rr.Hdr.Name { - c.alias(rr.Hdr.Name, rr.Target) - } - disco = c.track(rr.Hdr.Name) - disco.target = rr.Target - disco.port = int(rr.Port) - case *dns.TXT: - disco = c.track(rr.Hdr.Name) - disco.txt = rr.Txt - // todo: r.ans = ans ? - case *dns.CNAME: - disco = c.track(rr.Hdr.Name) - disco.target = rr.Target - c.alias(rr.Hdr.Name, rr.Target) - case *dns.A: - disco = c.track(rr.Hdr.Name) - // todo: append to ip4? - disco.ip4 = rr.A - disco.ans = msg - case *dns.AAAA: - disco = c.track(rr.Hdr.Name) - // todo: append to ip6? - disco.ip6 = rr.AAAA - disco.ans = msg - default: - who := qname - if disco != nil { - who += " -> (disco name) " + disco.name - } - log.I("mdns: listen: ignoring ans %s to %s", rr, who) - } - } - - if disco == nil { // no valid answers - log.D("mdns: listen: no valid answers for %s; len? %d", qname, len(xxlans)) - continue - } else if (c.oneshot && disco.hasip()) || // oneshot + received v4 / v6 ips - (!c.oneshot && disco.hasip() && disco.hassvc()) { // v4 / v6 ips and srv - if !disco.captured { - disco.captured = true - log.D("mdns: listen: q: %s; sent ans %s", qname, disco) - qctx.ansch <- disco - c.untrack(disco.name) - total++ - } else { // discard duplicates - log.D("mdns: listen: q: %s; duplicate ans %s", qname, disco) - continue - } - } else if !c.oneshot { // fire off a node specific query - m := new(dns.Msg) - m.SetQuestion(disco.name, dns.TypePTR) - m.RecursionDesired = false - if err := c.send(m); err != nil { - log.E("mdns: listen: failed to ptr query %s: %v", disco.name, err) - } else { - log.D("mdns: listen: sent ptr query for %s", disco.name) - } - } else { - log.D("mdns: listen: waiting for ip / port for %s", disco.name) - } - case <-timesup: - log.W("mdns: listen: timeout for %s", qname) - break loop - } - } - log.D("mdns: listen: done; got answers %d for %s", total, qname) -} - -// send writes q to approp unicast mdns address -func (c *client) send(q *dns.Msg) *dnsx.QueryError { - if buf, err := q.Pack(); err != nil { - log.W("mdns: send: failed to pack query: %v", err) - return dnsx.NewBadQueryError(err) - } else { - qname := xdns.QName(q) - if c.unicast4 != nil { - extend(c.unicast4, mdnstimeout) - if _, err = c.unicast4.WriteTo(buf, xdns.MDNSAddr4); err != nil { - return dnsx.NewSendFailedQueryError(err) - } - log.D("mdns: send: sent query4 %s", qname) - } - if c.unicast6 != nil { - extend(c.unicast6, mdnstimeout) - if _, err = c.unicast6.WriteTo(buf, xdns.MDNSAddr6); err != nil { - return dnsx.NewSendFailedQueryError(err) - } - log.D("mdns: send: sent query6 %s", qname) - } - } - return nil -} - -// recv forwards bytes to msgCh read from conn until error or shutdown. -// Must be called from a goroutine. -func (c *client) recv(conn net.PacketConn) { - if conn == nil { - return - } - - debug.SetPanicOnFault(true) - defer core.Recover(core.DontExit, "mdns.recv") - - bptr := core.Alloc() - buf := *bptr - buf = buf[:cap(buf)] - // buf must be recycled from a deferred fn since exec continues - // on panics and deferred fns are guaranteed to run. - defer func() { - *bptr = buf - core.Recycle(bptr) - }() - - for !c.closed.Load() { - extend(conn, mdnstimeout) - n, raddr, err := conn.ReadFrom(buf) - - if c.closed.Load() { - log.W("mdns: recv: from(%v); closed; bytes(%d), err(%v)", raddr, n, err) - return - } - - if err != nil { - log.E("mdns: recv: read failed: %v", err) - continue - } - msg := new(dns.Msg) - if err := msg.Unpack(buf[:n]); err != nil { - log.E("mdns: recv: unpack failed: %v", err) - continue - } - - timesup := time.After(mdnstimeout) - // ideally, the writer would close the channel, but in this - // case there are potentially 4 writers (2 unicast, 2 multicast) - // also see: go.dev/play/p/gzwnGAFlTDV - select { - case c.msgCh <- msg: - log.V("mdns: recv: from(%v); sent; bytes(%d)", raddr, n) - case <-timesup: - log.V("mdns: recv: from(%v); timeout ch; bytes(%d)", raddr, n) - return - } - } -} - -// untrack removes a name from the tracker; -// name is NOT normalized. -func (c *client) untrack(name string) { - c.tmu.Lock() - defer c.tmu.Unlock() - log.V("mdns: tracker: rmv %s", name) - delete(c.tracker, name) -} - -// track marks a name as being tracked by this client; -// name is NOT normalized. -func (c *client) track(name string) *dnssdanswer { - c.tmu.Lock() - defer c.tmu.Unlock() - - return c.trackLocked(name) -} - -func (c *client) isTracked(name string) bool { - c.tmu.RLock() - defer c.tmu.RUnlock() - _, ok := c.tracker[name] - return ok -} - -// alias sets up mapping between two tracked entries; -// src and dst are NOT normalized. -func (c *client) alias(src, dst string) { - c.tmu.Lock() - defer c.tmu.Unlock() - - if se, ok := c.tracker[dst]; ok { - log.VV("mdns: tracker: discard %v for %s; aliased to %s", se, dst, src) - } - se := c.trackLocked(src) - log.V("mdns: tracker: alias %s <-> %s with %v", src, dst, se) - c.tracker[dst] = se -} - -// trackLocked is the non-locking version of track, called when lock is already held -func (c *client) trackLocked(name string) *dnssdanswer { - if tse, ok := c.tracker[name]; ok { - log.VV("mdns: tracker: exists %s with %v", name, tse) - return tse - } - se := &dnssdanswer{ - name: name, - } - c.tracker[name] = se - log.V("mdns: tracker: start %s with %v", name, se) - return se -} - -func extend(c net.PacketConn, t time.Duration) { - if c != nil && core.IsNotNil(c) { - _ = c.SetDeadline(time.Now().Add(t)) - } -} diff --git a/intra/dns53/upstream.go b/intra/dns53/upstream.go deleted file mode 100644 index c9931f46..00000000 --- a/intra/dns53/upstream.go +++ /dev/null @@ -1,412 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dns53 - -import ( - "context" - "errors" - "net" - "net/netip" - "strings" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" - - _ "go4.org/unsafe/assume-no-moving-gc" -) - -const ( - Port = "53" // default DNS port - PortU16 = uint16(53) // default DNS port as uint16 - DotPort = "853" // default DNS over TLS port - DotPortU16 = uint16(853) // default DNS over TLS port as uint16 - - // timeouts slightly higher than stalls in ipn.Proxies.ProxyTo - timeout = 15 * time.Second // default timeout for DNS53 - dottimeout = 20 * time.Second // default timeout for DNS over TLS - mdnstimeout = 5 * time.Second // default timeout for MDNS -) - -var errQueryParse = errors.New("dns53: err parse query") - -// TODO: Keep a context here so that queries can be canceled. -type transport struct { - ctx context.Context - done context.CancelFunc - - id string - - addrport string // hostname, ip:port, protect.UidSelf:53, protect.System:53, protect.HostlessXYZ:53 - port uint16 - - client *dns.Client - proxies ipn.ProxyProvider // should never be nil - relay string // may be empty - - pool *core.MultConnPool[uintptr] - usepool bool - - est core.P2QuantileEstimator - lastaddr *core.Volatile[string] // last resolved addr - status *core.Volatile[int] // status of the transport -} - -var _ dnsx.Transport = (*transport)(nil) - -// NewTransportFromHostname returns a DNS53 transport serving from hostname, ready for use. -func NewTransportFromHostname(ctx context.Context, id, hostOrHostport string, ipcsv string, px ipn.ProxyProvider) (t *transport, err error) { - // ipcsv may contain port, eg: 10.1.1.3:53 - do, err := settings.NewDNSOptionsFromHostname(hostOrHostport, ipcsv) - if err != nil { - return - } - return newTransport(ctx, id, do, px) -} - -// NewTransport returns a DNS53 transport serving from ip & port, ready for use. -func NewTransport(ctx context.Context, id, ip, port string, px ipn.ProxyProvider) (t *transport, err error) { - ipport := net.JoinHostPort(ip, port) - do, err := settings.NewDNSOptions(ipport) - if err != nil { - return - } - - return newTransport(ctx, id, do, px) -} - -func newTransport(pctx context.Context, id string, do *settings.DNSOptions, px ipn.ProxyProvider) (*transport, error) { - // cannot be nil, see: ipn.Exit which the only proxy guaranteed to be connected to the internet; - // ex: ipn.Base routed back within the tunnel (rethink's traffic routed back into rethink). - if px == nil { - return nil, dnsx.ErrNoProxyProvider - } - ctx, done := context.WithCancel(pctx) - var relay string - if dnsx.CanUseProxy(id) { - if p, _ := px.ProxyFor(id); p != nil { - relay = p.ID() - } - } - tx := &transport{ - ctx: ctx, - done: done, - id: id, - addrport: do.AddrPort(), // may be hostname:port or ip:port - port: do.Port(), - status: core.NewVolatile(dnsx.Start), - lastaddr: core.NewZeroVolatile[string](), - pool: core.NewMultConnPool[uintptr](ctx), - // todo: renable once we know why pooled wireguard dns conns are troublesome - usepool: false, - proxies: px, // never nil; see above - relay: relay, // may be empty - est: core.NewP50Estimator(ctx), - } - ipcsv := do.ResolvedAddrs() - hasips := len(ipcsv) > 0 - ips := strings.Split(ipcsv, ",") // may be nil or empty or ip:port - ok := dnsx.RegisterAddrs(id, tx.addrport, ips) // addrport may be protect.UidSelf or protect.System - log.I("dns53: (%s) pre-resolved %s to %s; ok? %t", id, tx.addrport, ipcsv, ok) - tx.client = &dns.Client{ - Net: "udp", // default transport type - Timeout: timeout, - // instead using custom dialer rdial - // Dialer: d, - // TODO: set it to MTU? or no more than 512 bytes? - // ref: github.com/miekg/dns/blob/b3dfea071/server.go#L207 - // UDPSize: dns.DefaultMsgSize, - } - log.I("dns53: (%s) setup: %s; pre-ips? %t; relay? %t", id, tx.addrport, hasips, len(relay) > 0) - return tx, nil -} - -// NewTransportFrom returns a DNS53 transport serving from ipp, ready for use. -func NewTransportFrom(ctx context.Context, id string, ipp netip.AddrPort, px ipn.Proxies) (t dnsx.Transport, err error) { - do, err := settings.NewDNSOptionsFromNetIp(ipp) - if err != nil { - return - } - - return newTransport(ctx, id, do, px) -} - -func (t *transport) pxdial(network, pid string) (*dns.Conn, string, uintptr, error) { - if t.id == dnsx.Bootstrap || t.id == dnsx.System { // bootstrap/default never be proxied - // never proxy dns53 transport with "bootstrap" id is a clone of dnsx.System - pid = dnsx.NetBaseProxy - } else if len(t.relay) > 0 { // relay takes precedence - pid = t.relay - } - px, err := t.proxies.ProxyFor(pid) - if err != nil { - return nil, "", core.Nobody, err - } else if px == nil { - return nil, "", core.Nobody, dnsx.ErrNoProxyProvider - } - - rpid := ipn.ViaID(px) - who := px.Handle() - if c := t.fromPool(who); c != nil { - return c, rpid, who, nil - } - - if settings.Debug { - log.V("dns53: pxdial: (%s) using %s relay/proxy %s at %s", - t.id, network, px.ID(), px.GetAddr()) - } - - // t.addrport may be hostless / system / self but we expect the - // proxies to be able to handle these from ipmapper, as expected. - pxconn, err := px.Dialer().Dial(network, t.addrport) - if err != nil { - clos(pxconn) - return nil, rpid, core.Nobody, err - } else if pxconn == nil { - log.E("dns53: pxdial: (%s) no %s conn for relay/proxy %s at %s", - t.id, network, px.ID(), px.GetAddr()) - err = errNoNet - return nil, rpid, core.Nobody, err - } - return &dns.Conn{Conn: pxconn}, rpid, who, nil -} - -// toPool takes ownership of c. -func (t *transport) toPool(id uintptr, c *dns.Conn) { - if !t.usepool || id == core.Nobody { - clos(c) - return - } - ok := t.pool.Put(id, c) - logwif(!ok)("dns53: pool: (%s) put for %v; ok? %t", t.id, id, ok) -} - -// fromPool returns a conn from the pool, if available. -func (t *transport) fromPool(id uintptr) (c *dns.Conn) { - if !t.usepool || id == core.Nobody { - return - } - - pooled := t.pool.Get(id) - if pooled == nil || core.IsNil(pooled) { - return - } - var ok bool - if c, ok = pooled.(*dns.Conn); !ok { // unlikely - log.W("dns53: pool: (%s) not a dns.Conn for %d!", t.id, id) - return &dns.Conn{Conn: pooled} - } - if settings.Debug { - log.V("dns53: pool: (%s) got conn for %d", t.id, id) - } - return -} - -func (t *transport) connect(network, pid string) (conn *dns.Conn, rpid string, who uintptr, err error) { - useudp := network == dnsx.NetTypeUDP - userelay := len(t.relay) > 0 - useproxy := len(pid) != 0 // pid == dnsx.NetNoProxy => ipn.Block - - // if udp is unreachable, try tcp: github.com/celzero/rethink-app/issues/839 - // note that some proxies do not support udp (eg pipws, piph2) - if userelay || useproxy { - conn, rpid, who, err = t.pxdial(network, pid) - if err != nil && useudp { - clos(conn) - network = dnsx.NetTypeTCP - conn, rpid, who, err = t.pxdial(network, pid) - } - } else { - err = dnsx.ErrNoProxyProvider - } - return -} - -// ref: github.com/celzero/midway/blob/77ede02c/midway/server.go#L179 -func (t *transport) send(network, pid string, q *dns.Msg) (ans *dns.Msg, rpid string, elapsed time.Duration, qerr *dnsx.QueryError) { - var err error - if q == nil || !xdns.HasAnyQuestion(q) { - qerr = dnsx.NewBadQueryError(errQueryParse) - return - } - if qerr = dnsx.WillErr(t); qerr != nil { - return - } - - qname := xdns.QName(q) - - conn, rpid, who, err := t.connect(network, pid) - - logev(err)("dns53: send: (%s / %s) to %s for %s; px? %s / hop? %s; err? %v", - network, t.id, t.addrport, qname, pid, rpid, err) - - if err != nil { - qerr = dnsx.NewSendFailedQueryError(err) - return - } // else: send query - - lastaddr := remoteAddrIfAny(conn) // may return empty string - ans, elapsed, err = t.client.ExchangeWithConnContext(t.ctx, q, conn) - - if err != nil { - clos(conn) - ok := dialers.Disconfirm2(t.addrport, lastaddr) - log.E("dns53: sendRequest: (%s) for %s (elapsed: %s); err: %v; disconfirm? %t %s => %s", - t.id, qname, core.FmtPeriod(elapsed), err, ok, t.addrport, lastaddr) - qerr = dnsx.NewSendFailedQueryError(err) - } else if ans == nil { - t.toPool(who, conn) // or close - qerr = dnsx.NewBadResponseQueryError(errNoAns) - } else { - t.toPool(who, conn) // or close - dialers.Confirm2(t.addrport, lastaddr) - } - - t.lastaddr.Store(lastaddr) - - return -} - -func (t *transport) chooseProxy(pids ...string) string { - return dnsx.ChooseHealthyProxyHostPort("dns53: "+t.id, t.addrport, t.port, pids, t.proxies) -} - -func (t *transport) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg, err error) { - var pid string - proto, pids := xdns.Net2ProxyID(network) - - if r := t.relay; len(r) > 0 { - pid = t.chooseProxy(r) - } else { - pid = t.chooseProxy(pids...) - } - - ans, rpid, elapsed, qerr := t.send(proto, pid, q) - if qerr != nil { // only on send-request errors - ans = xdns.Servfail(q) - } - - status := dnsx.Complete - if qerr != nil { - err = qerr.Unwrap() - status = qerr.Status() - log.W("dns53: (%s) err(%v) / size(%d)", t.id, err, xdns.Len(ans)) - } - t.status.Store(status) - - smm.Latency = elapsed.Seconds() - smm.RData = xdns.GetInterestingRData(ans) - smm.RCode = xdns.Rcode(ans) - smm.RTtl = xdns.RTtl(ans) - smm.Server = t.getAddr() - smm.PID = pid - smm.RPID = rpid - if err != nil { - smm.Msg = err.Error() - } - smm.Status = status - t.est.Add(smm.Latency) - - if settings.Debug { - log.V("dns53: (%s) len(res): %d, data: %s, via: %s, err? %v", - t.id, xdns.Len(ans), smm.RData, smm.PID, err) - } - - return ans, err -} - -func (t *transport) ID() string { - return t.id -} - -func (t *transport) Type() string { - return dnsx.DNS53 -} - -func (t *transport) P50() int64 { - return t.est.Get() -} - -func (t *transport) GetAddr() string { - return t.getAddr() -} - -func (t *transport) getAddr() string { - addr := t.lastaddr.Load() - if len(addr) == 0 { - // may be protect.UidSelf (for bootstrap/default) or protect.System - addr = t.addrport - } - - prefix := dnsx.PrefixFor(t.id) - if len(prefix) > 0 { - addr = prefix + addr - } - - return addr -} - -func (t *transport) GetRelay() x.Proxy { - if r := t.relay; len(r) > 0 { - px, _ := t.proxies.ProxyFor(r) - return px - } - return nil -} - -func (t *transport) IPPorts() (ipps []netip.AddrPort) { - for _, ip := range dialers.For(t.addrport) { - ipps = append(ipps, netip.AddrPortFrom(ip, t.port)) - } - return -} - -func (t *transport) Status() int { - if px := t.GetRelay(); px != nil { - if px.Status() == ipn.TPU { - return dnsx.Paused - } - } - return t.status.Load() -} - -func (t *transport) Stop() error { - t.status.Store(dnsx.DEnd) - t.done() - return nil -} - -func remoteAddrIfAny(conn *dns.Conn) string { - if conn == nil || conn.Conn == nil { - return "" - } else if addr := conn.RemoteAddr(); addr == nil { - return "" - } else { - return addr.String() - } -} - -func logev(err error) log.LogFn { - if err != nil { - return log.E - } - return log.V -} - -func logwif(cond bool) log.LogFn { - if cond { - return log.W - } - return log.V -} diff --git a/intra/dnscrypt/certs.go b/intra/dnscrypt/certs.go deleted file mode 100644 index cc1ba0d6..00000000 --- a/intra/dnscrypt/certs.go +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package dnscrypt - -import ( - "bytes" - "encoding/binary" - "errors" - "net" - "strings" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/miekg/dns" - "golang.org/x/crypto/ed25519" - - "github.com/celzero/firestack/intra/xdns" -) - -type certinfo struct { - ServerPk [32]byte - SharedKey [32]byte - MagicQuery [xdns.ClientMagicLen]byte - CryptoConstruction xdns.CryptoConstruction - ForwardSecurity bool -} - -type dnsExchangeResponse struct { - response *dns.Msg - rtt time.Duration - priority int - err error -} - -var ( - errCancelled = errors.New("dnscrypt: cancelled") - errFetchingCerts = errors.New("dnscrypt: cannot reach server to fetch certs") -) - -func fetchCurrentDNSCryptCert(proxy *DcMulti, serverName *string, pk ed25519.PublicKey, serverAddress string, providerName string) (certinfo, error) { - exit := proxy.exit - - if exit == nil { - return certinfo{}, log.EE("%v for %s", dnsx.ErrNoProxyProvider, serverAddress) - } - if len(pk) != ed25519.PublicKeySize { - return certinfo{}, log.EE("invalid public key length for %s", serverAddress) - } - if !strings.HasSuffix(providerName, ".") { - providerName = providerName + "." - } - if serverName == nil { - serverName = &providerName - } - - query := dns.Msg{} - query.SetQuestion(providerName, dns.TypeTXT) - if !strings.HasPrefix(providerName, "2.dnscrypt-cert.") { - log.W("dnscrypt: [%v] is not v2, ('%v' doesn't start with '2.dnscrypt-cert.')", *serverName, providerName) - } - log.I("dnscrypt: [%v] Fetching DNSCrypt certificate for [%s] at [%v]", *serverName, providerName, serverAddress) - in, rtt, err := dnsExchange(exit.Dialer(), &query, serverAddress, serverName) - if err != nil || in == nil { - log.W("dnscrypt: [%s] TIMEOUT %v", *serverName, err) - return certinfo{}, log.EE("%v for %s: %v", errFetchingCerts, serverAddress, err) - } - now := uint32(time.Now().Unix()) - certInfo := certinfo{CryptoConstruction: xdns.UndefinedConstruction} - highestSerial := uint32(0) - var certCountStr string - for _, answerRr := range in.Answer { - var txt string - if t, ok := answerRr.(*dns.TXT); !ok { - log.I("dnscrypt: [%v] Extra record of type [%v] found in certificate", *serverName, answerRr.Header().Rrtype) - continue - } else { - txt = strings.Join(t.Txt, "") - } - binCert := packTxtString(txt) - if len(binCert) < 124 { - log.W("dnscrypt: [%v] Certificate too short", *serverName) - continue - } - if !bytes.Equal(binCert[:4], xdns.CertMagic[:4]) { - log.W("dnscrypt: [%v] Invalid cert magic", *serverName) - continue - } - cryptoConstruction := xdns.CryptoConstruction(0) - switch esVersion := binary.BigEndian.Uint16(binCert[4:6]); esVersion { - case 0x0001: - cryptoConstruction = xdns.XSalsa20Poly1305 - case 0x0002: - cryptoConstruction = xdns.XChacha20Poly1305 - default: - log.W("dnscrypt: [%v] Unsupported crypto construction", *serverName) - continue - } - signature := binCert[8:72] - signed := binCert[72:] - if !ed25519.Verify(pk, signed, signature) { - log.W("dnscrypt: [%v] Incorrect signature for provider name: [%v]", *serverName, providerName) - continue - } - serial := binary.BigEndian.Uint32(binCert[112:116]) - tsBegin := binary.BigEndian.Uint32(binCert[116:120]) - tsEnd := binary.BigEndian.Uint32(binCert[120:124]) - if tsBegin >= tsEnd { - log.W("dnscrypt: [%v] certificate ends before it starts (%v >= %v)", *serverName, tsBegin, tsEnd) - continue - } - ttl := tsEnd - tsBegin - if ttl > 86400*7 { - log.I("dnscrypt: [%v] the key validity period for this server is excessively long (%d days), significantly reducing reliability and forward security.", *serverName, ttl/86400) - daysLeft := (tsEnd - now) / 86400 - if daysLeft < 1 { - log.W("dnscrypt: [%v] certificate will expire today -- Switch to a different resolver as soon as possible", *serverName) - } else if daysLeft <= 7 { - log.W("dnscrypt: [%v] certificate is about to expire -- if you don't manage this server, tell the server operator about it", *serverName) - } else if daysLeft <= 30 { - log.I("dnscrypt: [%v] certificate will expire in %d days", *serverName, daysLeft) - } - certInfo.ForwardSecurity = false - } else { - certInfo.ForwardSecurity = true - } - if !proxy.certIgnoreTimestamp { - if now > tsEnd || now < tsBegin { - log.W("dnscrypt: [%v] Certificate not valid at the current date (now: %v is not in [%v..%v])", *serverName, now, tsBegin, tsEnd) - continue - } - } - if serial < highestSerial { - log.W("dnscrypt: [%v] Superseded by a previous certificate", *serverName) - continue - } - if serial == highestSerial { - if cryptoConstruction < certInfo.CryptoConstruction { - log.W("dnscrypt: [%v] Keeping the previous, preferred crypto construction", *serverName) - continue - } else { - log.W("dnscrypt: [%v] Upgrading the construction from %v to %v", *serverName, certInfo.CryptoConstruction, cryptoConstruction) - } - } - if cryptoConstruction != xdns.XChacha20Poly1305 && cryptoConstruction != xdns.XSalsa20Poly1305 { - log.W("dnscrypt: [%v] Cryptographic construction %v not supported", *serverName, cryptoConstruction) - continue - } - var serverPk [32]byte - copy(serverPk[:], binCert[72:104]) - sharedKey := computeSharedKey(cryptoConstruction, &proxy.proxySecretKey, &serverPk, &providerName) - certInfo.SharedKey = sharedKey - highestSerial = serial - certInfo.CryptoConstruction = cryptoConstruction - copy(certInfo.ServerPk[:], serverPk[:]) - copy(certInfo.MagicQuery[:], binCert[104:112]) - log.I("dnscrypt: [%s] OK (DNSCrypt) - rtt: %dms%s", *serverName, rtt.Nanoseconds()/1000000, certCountStr) - certCountStr = " - additional certificate" - } - if certInfo.CryptoConstruction == xdns.UndefinedConstruction { - return certInfo, log.EE("no useable cert found for %s", serverAddress) - } - return certInfo, nil -} - -func isDigit(b byte) bool { return b >= '0' && b <= '9' } - -func dddToByte(s []byte) byte { - return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) -} - -func packTxtString(s string) []byte { - bs := make([]byte, len(s)) - msg := make([]byte, 0) - copy(bs, s) - for i := 0; i < len(bs); i++ { - if bs[i] == '\\' { - i++ - if i == len(bs) { - break - } - if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { - msg = append(msg, dddToByte(bs[i:])) - i += 2 - } else if bs[i] == 't' { - msg = append(msg, '\t') - } else if bs[i] == 'r' { - msg = append(msg, '\r') - } else if bs[i] == 'n' { - msg = append(msg, '\n') - } else { - msg = append(msg, bs[i]) - } - } else { - msg = append(msg, bs[i]) - } - } - return msg -} - -func dnsExchange(dialer protect.RDialer, query *dns.Msg, serverAddress string, serverName *string) (*dns.Msg, time.Duration, error) { - // always use udp to fetch certs since most servers like adguard, cleanbrowsing - // don't support fetching certs over tcp - proto := "udp" - - // add padding to ensure that the cert txt response is large enough - minsz := 480 - cancelChannel := make(chan struct{}) - channel := make(chan dnsExchangeResponse) - var err error - options := 0 - - for tries := range 4 { - queryCopy := query.Copy() - queryCopy.Id += uint16(options) - timeout := time.Duration(200*tries) * time.Millisecond - core.Go2("cert.dnsExchange", func(query *dns.Msg, delay time.Duration) { - - if proto == "udp" { - proto = "tcp" - } else { - proto = "udp" - } - option := dnsExchangeResponse{err: errCancelled} - time.Sleep(delay) - select { - case <-cancelChannel: - return - default: - option = _dnsExchange(dialer, proto, query, serverAddress, minsz) - } - option.priority = 0 - channel <- option - }, queryCopy, timeout) - options++ - } - deadline := time.NewTimer(30 * time.Second) - var bestOption *dnsExchangeResponse - for i := 0; i < options; i++ { - select { - case res := <-channel: - if res.err == nil { - if bestOption == nil { - bestOption = &res - } else if res.rtt < bestOption.rtt { - bestOption = &res - close(cancelChannel) - i = options // break - } - } else { - err = res.err - } - case <-deadline.C: - i = options // break - } - } - if bestOption != nil { - log.D("dnscrypt: cert retrieval for [%v] succeeded via relay?", *serverName) - return bestOption.response, bestOption.rtt, nil - } - - log.I("dnscrypt: no cert, ignoring server: [%v] proto: [%v]", *serverName, proto) - - err = core.OneErr(err, errFetchingCerts) - - return nil, 0, err -} - -// _dnsExchange sends query and returns an answer from serverAddress using dialer. -// It can be called from multiple goroutines. -func _dnsExchange(dialer protect.RDialer, proto string, query *dns.Msg, serverAddress string, paddedLen int) dnsExchangeResponse { - var packet []byte - var rtt time.Duration - - qname := xdns.QName(query) - // FIXME: udp relays do not support fetching certs over relays, and - // doing so leaks client's identity to the actual dns-crypt server! - log.V("dnscrypt: [%s] relay is not used when fetching certs %s", proto, qname) - if proto == "udp" { - qNameLen, padding := len(qname), 0 - if qNameLen < paddedLen { - padding = paddedLen - qNameLen - } - if padding > 0 { - opt := new(dns.OPT) - opt.Hdr.Name = "." - ext := new(dns.EDNS0_PADDING) - ext.Padding = make([]byte, padding) - opt.Option = append(opt.Option, ext) - query.Extra = []dns.RR{opt} - } - binQuery, err := query.Pack() - if err != nil { - return dnsExchangeResponse{err: err} - } - - now := time.Now() - pc, err := dialer.Dial("udp", serverAddress) - if err != nil { - return dnsExchangeResponse{err: err} - } else if pc == nil || core.IsNil(pc) { - return dnsExchangeResponse{err: errNoConn} - } - - defer clos(pc) - if derr := pc.SetDeadline(time.Now().Add(timeout8s)); derr != nil { - return dnsExchangeResponse{err: derr} - } - if _, werr := pc.Write(binQuery); werr != nil { - return dnsExchangeResponse{err: werr} - } - packet = make([]byte, xdns.MaxDNSPacketSize) - length, err := pc.Read(packet) - if err != nil { - return dnsExchangeResponse{err: err} - } - rtt = time.Since(now) - packet = packet[:length] - } else { - binQuery, err := query.Pack() - if err != nil { - return dnsExchangeResponse{err: err} - } - // FIXME: for time-being, tcp validation is used only - // when relay addresses are nil. Uncomment the code - // below when udp transport for dnscrypt-proxy is ready. - /* - if relayTCPAddr != nil && relayForCerts { - proxy.prepareForRelay(tcpAddr.IP, tcpAddr.Port, &binQuery) - upstreamAddr = relayTCPAddr - } - */ - now := time.Now() - var pc net.Conn - pc, err = dialer.Dial("tcp", serverAddress) - if err != nil { - return dnsExchangeResponse{err: err} - } else if pc == nil || core.IsNil(pc) { - return dnsExchangeResponse{err: errNoConn} - } - - defer clos(pc) - if derr := pc.SetDeadline(time.Now().Add(timeout8s)); derr != nil { - return dnsExchangeResponse{err: derr} - } - binQuery, err = xdns.PrefixWithSize(binQuery) - if err != nil { - return dnsExchangeResponse{err: err} - } - if _, werr := pc.Write(binQuery); werr != nil { - return dnsExchangeResponse{err: werr} - } - packet, err = xdns.ReadPrefixed(&pc) - if err != nil { - return dnsExchangeResponse{err: err} - } - rtt = time.Since(now) - } - msg := dns.Msg{} - if err := msg.Unpack(packet); err != nil { - return dnsExchangeResponse{err: err} - } - return dnsExchangeResponse{response: &msg, rtt: rtt, err: nil} -} - -func clos(c net.Conn) { - core.CloseConn(c) -} diff --git a/intra/dnscrypt/crypto.go b/intra/dnscrypt/crypto.go deleted file mode 100644 index 56b53117..00000000 --- a/intra/dnscrypt/crypto.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package dnscrypt - -import ( - "bytes" - crypto_rand "crypto/rand" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - - "github.com/jedisct1/xsecretbox" - - "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/secretbox" -) - -const ( - // NonceSize is what the name suggests - NonceSize = 24 - // TagSize is what the name suggests - TagSize = 16 - // HalfNonceSize is half of NonceSize - HalfNonceSize = NonceSize / 2 - // PublicKeySize is the size of a public key - PublicKeySize = 32 - // QueryOverhead is the amount of request overhead due to client-magic, public key, nonce, and tag - QueryOverhead = xdns.ClientMagicLen + PublicKeySize + HalfNonceSize + TagSize - // ResponseOverhead is the amount of answer overhead due to server-magic, nonce, and tag - ResponseOverhead = len(xdns.ServerMagic) + NonceSize + TagSize -) - -func pad(packet []byte, minSize int) []byte { - packet = append(packet, 0x80) - for len(packet) < minSize { - packet = append(packet, 0) - } - return packet -} - -func unpad(packet []byte) ([]byte, error) { - for i := len(packet); ; { - if i == 0 { // short packet - return nil, errIncorrectPad - } - i-- - if packet[i] == 0x80 { - return packet[:i], nil - } else if packet[i] != 0x00 { // delimiter not found - return nil, errIncorrectPad - } - } -} - -func computeSharedKey(cryptoConstruction xdns.CryptoConstruction, secretKey *[32]byte, serverPk *[32]byte, providerName *string) (sharedKey [32]byte) { - if cryptoConstruction == xdns.XChacha20Poly1305 { - var err error - sharedKey, err = xsecretbox.SharedKey(*secretKey, *serverPk) - if err != nil { - log.W("dnscrypt: [%v] Weak public key", providerName) - } - } else { - box.Precompute(&sharedKey, serverPk, secretKey) - } - return -} - -func encrypt( - serverInfo *serverinfo, - packet []byte, - useudp, userelay bool, -) (sharedKey *[32]byte, encrypted []byte, clientNonce []byte, err error) { - nonce := make([]byte, NonceSize) - clientNonce = make([]byte, HalfNonceSize) - if _, err = crypto_rand.Read(clientNonce); err != nil { - return - } - copy(nonce, clientNonce) - - var publicKey *[PublicKeySize]byte - - sharedKey = &serverInfo.SharedKey - publicKey = serverInfo.ClientPubKey - - var paddedLength int - if useudp { // using udp - paddedLength = xdns.MaxDNSUDPSafePacketSize - } else if userelay { // tcp, with relay - paddedLength = xdns.MaxDNSPacketSize - } else { // tcp, without relay - minQuestionSize := QueryOverhead + len(packet) - // random pad if tcp without relay - var xpad [1]byte - if _, err = crypto_rand.Read(xpad[:]); err != nil { - return - } - minQuestionSize += int(xpad[0]) - paddedLength = xdns.Min(xdns.MaxDNSUDPPacketSize, (xdns.Max(minQuestionSize, QueryOverhead)+1+63) & ^63) - } - - if QueryOverhead+len(packet)+1 > paddedLength { - err = errQueryTooLarge - return - } - - encrypted = append(serverInfo.MagicQuery[:], publicKey[:]...) - encrypted = append(encrypted, nonce[:HalfNonceSize]...) - padded := pad(packet, paddedLength-QueryOverhead) - - if serverInfo.CryptoConstruction == xdns.XChacha20Poly1305 { - encrypted = xsecretbox.Seal(encrypted, nonce, padded, sharedKey[:]) - } else { - var xsalsaNonce [24]byte - copy(xsalsaNonce[:], nonce) - encrypted = secretbox.Seal(encrypted, padded, &xsalsaNonce, sharedKey) - } - return -} - -func decrypt(serverInfo *serverinfo, sharedKey *[32]byte, encrypted []byte, nonce []byte) ([]byte, error) { - serverMagicLen := len(xdns.ServerMagic) - responseHeaderLen := serverMagicLen + NonceSize - if len(encrypted) < responseHeaderLen+TagSize+int(xdns.MinDNSPacketSize) || - len(encrypted) > responseHeaderLen+TagSize+int(xdns.MaxDNSPacketSize) || - !bytes.Equal(encrypted[:serverMagicLen], xdns.ServerMagic[:]) { - return encrypted, errInvalidResponse - } - serverNonce := encrypted[serverMagicLen:responseHeaderLen] - if !bytes.Equal(nonce[:HalfNonceSize], serverNonce[:HalfNonceSize]) { - return encrypted, errNonceUnexpected - } - - var packet []byte - var err error - if serverInfo.CryptoConstruction == xdns.XChacha20Poly1305 { - packet, err = xsecretbox.Open(nil, serverNonce, encrypted[responseHeaderLen:], sharedKey[:]) - } else { - var xsalsaServerNonce [24]byte - copy(xsalsaServerNonce[:], serverNonce) - var ok bool - packet, ok = secretbox.Open(nil, encrypted[responseHeaderLen:], &xsalsaServerNonce, sharedKey) - if !ok { - err = errIncorrectTag - } - } - - if err != nil { - return encrypted, err - } - if len(packet) <= 0 { - return encrypted, errInvalidResponse - } - - packet, err = unpad(packet) - if err != nil || len(packet) < xdns.MinDNSPacketSize { - return encrypted, errIncorrectPad - } - return packet, nil -} diff --git a/intra/dnscrypt/intercept.go b/intra/dnscrypt/intercept.go deleted file mode 100644 index 9c8d6043..00000000 --- a/intra/dnscrypt/intercept.go +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package dnscrypt - -import ( - "errors" - "strings" - - "github.com/celzero/firestack/intra/xdns" - - "github.com/celzero/firestack/intra/log" - "github.com/miekg/dns" -) - -const ( - ActionNone = iota // No action has been taken - ActionContinue // Continue with the request - ActionDrop // Drop the request - ActionSynth // Use synthesized response -) - -const ( - ReturnCodePass = iota - ReturnCodeSynth -) - -type intercept struct { - state *interceptstate -} - -type interceptstate struct { - originalMaxPayloadSize int - maxUnencryptedUDPSafePayloadSize int - maxPayloadSize int - question *dns.Msg - qName string - response *dns.Msg - action int - returnCode int - dnssec bool - blocklists string -} - -// HandleRequest changes the incoming DNS question either to add padding to it or synthesize a pre-determined answer. -func (ic *intercept) handleRequest(msg *dns.Msg) (*dns.Msg, error) { - state := ic.state - if len(msg.Question) != 1 { - return msg, errors.New("unexpected number of questions") - } - qName, err := xdns.NormalizeQName(msg.Question[0].Name) - if err != nil { - return msg, err - } - log.D("dnscrypt: query for [%v]", qName) - state.qName = qName - state.question = msg - - // TODO: Recheck: None of these methods return err - if berr := ic.blockUnqualified(msg); berr != nil { - state.action = ActionDrop - return msg, berr - } - if serr := ic.getSetPayloadSize(msg); serr != nil { - state.action = ActionDrop - return msg, serr - } - return msg, nil -} - -// handleResponse -func (ic *intercept) handleResponse(packet []byte, truncate bool) ([]byte, error) { - state := ic.state - msg := dns.Msg{Compress: true} - if err := msg.Unpack(packet); err != nil { - // HasTCFlag is always false because currently transport is TCP only - if len(packet) >= xdns.MinDNSPacketSize && xdns.HasTCFlag2(packet) { - log.W("dnscrypt: has-tc-flag, retry with tcp, ignore err: %w", err) - err = nil - } - log.E("dnscrypt: has-tc-flag not set, intercept-handle-response err: %w", err) - return packet, err - } - - xdns.RemoveEDNS0Options(&msg) - - packet2, err := msg.PackBuffer(packet) - if err != nil { - log.E("dnscrypt: intercept-handle-response err for pack-buffer: %w", err) - return packet, err - } - - if truncate && len(packet2) > state.maxUnencryptedUDPSafePayloadSize { - return xdns.TruncatedResponse(packet2) - } - - return packet2, nil -} - -// GetSetPayloadSize adjusts the maximum payload size advertised in queries sent to upstream servers. -func (ic *intercept) getSetPayloadSize(msg *dns.Msg) error { - state := ic.state - - if state.action != ActionContinue { - // nothing to do. - return nil - } - - state.originalMaxPayloadSize = 512 - ResponseOverhead - edns0 := msg.IsEdns0() - dnssec := false - if edns0 != nil { - state.maxUnencryptedUDPSafePayloadSize = int(edns0.UDPSize()) - state.originalMaxPayloadSize = xdns.Max(state.maxUnencryptedUDPSafePayloadSize-ResponseOverhead, state.originalMaxPayloadSize) - dnssec = edns0.Do() - } - var options *[]dns.EDNS0 - state.dnssec = dnssec - state.maxPayloadSize = xdns.Min(xdns.MaxDNSUDPPacketSize-ResponseOverhead, xdns.Max(state.originalMaxPayloadSize, state.maxPayloadSize)) - if state.maxPayloadSize > 512 { - extra2 := []dns.RR{} - for _, extra := range msg.Extra { - if extra.Header().Rrtype != dns.TypeOPT { - extra2 = append(extra2, extra) - } else if xoptions := &extra.(*dns.OPT).Option; len(*xoptions) > 0 && options == nil { - options = xoptions - } - } - msg.Extra = extra2 - msg.SetEdns0(uint16(state.maxPayloadSize), dnssec) - if options != nil { - for _, extra := range msg.Extra { - if extra.Header().Rrtype == dns.TypeOPT { - extra.(*dns.OPT).Option = *options - break - } - } - } - } - return nil -} - -// BlockUnqualified blocks unqualified DNS names. -func (ic *intercept) blockUnqualified(msg *dns.Msg) error { - state := ic.state - - if state.action != ActionContinue { - // nothing to do. - return nil - } - - if len(msg.Question) <= 0 { - return nil - } - - question := msg.Question[0] - if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) { - return nil - } - if strings.IndexByte(state.qName, '.') >= 0 { - return nil - } - synth := xdns.EmptyResponseFromMessage(msg) // may be nil - if synth == nil { - return nil - } - synth.Rcode = dns.RcodeNameError - state.response = synth - state.action = ActionSynth - state.returnCode = ReturnCodeSynth - - return nil -} - -func newIntercept() *intercept { - return &intercept{ - state: newInterceptState(), - } -} - -func newInterceptState() *interceptstate { - return &interceptstate{ - action: ActionContinue, - returnCode: ReturnCodePass, - maxPayloadSize: xdns.MaxDNSUDPPacketSize - ResponseOverhead, - question: nil, - qName: "", - maxUnencryptedUDPSafePayloadSize: xdns.MaxDNSUDPSafePacketSize, - dnssec: false, - response: nil, - blocklists: "", - } -} diff --git a/intra/dnscrypt/multiserver.go b/intra/dnscrypt/multiserver.go deleted file mode 100644 index 6e024d3a..00000000 --- a/intra/dnscrypt/multiserver.go +++ /dev/null @@ -1,762 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package dnscrypt - -import ( - "context" - crypto_rand "crypto/rand" - "encoding/binary" - "errors" - "fmt" - "net" - "net/netip" - "runtime/debug" - "strings" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" - "golang.org/x/crypto/curve25519" - - stamps "github.com/jedisct1/go-dnsstamps" -) - -// DcMulti is a dnsx.TransportMult supporting dnscrypt servers and relays -type DcMulti struct { - sync.RWMutex - proxyPublicKey [32]byte - proxySecretKey [32]byte - serversInfo *ServersInfo - certIgnoreTimestamp bool - registeredServers map[string]registeredserver - routes []string - liveServers []string - proxies ipn.ProxyProvider - ctx context.Context - sigterm context.CancelFunc - lastStatus *core.Volatile[int] - lastAddr string - ctl protect.Controller - exit ipn.Proxy // may be nil - est core.P2QuantileEstimator -} - -var ( - certRefreshDelay = 240 * time.Minute - certRefreshDelayAfterFailure = 10 * time.Second -) - -var _ dnsx.TransportMult = (*DcMulti)(nil) -var timeout8s = 8000 * time.Millisecond - -var ( - errNoCert = errors.New("dnscrypt: error refreshing cert") - errQueryTooShort = errors.New("dnscrypt: query size too short") - errQueryTooLarge = errors.New("dnscrypt: query size too large") - errNoServers = errors.New("dnscrypt: server not found") - errNothing = errors.New("dnscrypt: specify at least one resolver endpoint") - errNoDoh = errors.New("dnscrypt: dns-over-https not supported") - errNoRoute = errors.New("dnscrypt: specify atleast one route") - errUnknownProto = errors.New("dnscrypt: unknown protocol") - errInvalidResponse = errors.New("dnscrypt: response too large or too small") - errNonceUnexpected = errors.New("dnscrypt: unexpected nonce") - errIncorrectTag = errors.New("dnscrypt: incorrect tag") - errIncorrectPad = errors.New("dnscrypt: incorrect padding") - errNoConn = errors.New("dnscrypt: no connection") -) - -func chooseAny[T any](s []T) (zz T) { - return core.ChooseOne(s) -} - -func udpExchange(pid string, serverInfo *serverinfo, relayAddrs []*net.UDPAddr, sharedKey *[32]byte, encryptedQuery []byte, clientNonce []byte) (res []byte, relay net.Addr, err error) { - upstreamAddr := serverInfo.UDPAddr - userelay := false - if len(relayAddrs) > 0 { - oneaddr := chooseAny(relayAddrs) - if oneaddr != nil && oneaddr.AddrPort().IsValid() { - upstreamAddr = oneaddr - relay = upstreamAddr - userelay = true - } - } - - pc, err := serverInfo.dialudp(pid, upstreamAddr) - pcnil := pc == nil || core.IsNil(pc) - if err != nil || pcnil { // nilaway: tx.socks5 returns nil conn even if err == nil - err = core.OneErr(err, errNoConn) - log.E("dnscrypt: udp: dialing %s; hasConn? %s(%t); err: %v", serverInfo, pid, pcnil, err) - return - } - - defer clos(pc) - if err = pc.SetDeadline(time.Now().Add(timeout8s)); err != nil { - return - } - if userelay { - prepareForRelay(serverInfo.UDPAddr.IP, serverInfo.UDPAddr.Port, &encryptedQuery) - } - // TODO: use a pool - bptr := core.AllocRegion(xdns.MaxDNSUDPPacketSize) - encryptedResponse := (*bptr)[:xdns.MaxDNSUDPPacketSize] - defer func() { - *bptr = encryptedResponse - core.Recycle(bptr) - }() - for tries := 2; tries > 0; tries-- { - if _, err = pc.Write(encryptedQuery); err != nil { - log.E("dnscrypt: udp: [%s] write err; [%v]", serverInfo.Name, err) - return - } - var length int - length, err = pc.Read(encryptedResponse) - if err == nil { - encryptedResponse = encryptedResponse[:length] - break - } else if tries <= 0 { - log.E("dnscrypt: udp: [%s] read err; quit [%v]", serverInfo.Name, err) - return - } - log.D("dnscrypt: udp: [%s] read err; retry [%v]", serverInfo.Name, err) - } - res, err = decrypt(serverInfo, sharedKey, encryptedResponse, clientNonce) - return -} - -func tcpExchange(pid string, serverInfo *serverinfo, relayAddrs []*net.TCPAddr, sharedKey *[32]byte, encryptedQuery []byte, clientNonce []byte) (res []byte, relay net.Addr, err error) { - upstreamAddr := serverInfo.TCPAddr - userelay := false - if len(relayAddrs) > 0 { - oneaddr := chooseAny(relayAddrs) - if oneaddr != nil && oneaddr.AddrPort().IsValid() { - upstreamAddr = oneaddr - relay = upstreamAddr - userelay = true - } - } - - pc, err := serverInfo.dialtcp(pid, upstreamAddr) - pcnil := pc == nil || core.IsNil(pc) - if err != nil || pcnil { // nilaway: tx.socks5 returns nil conn even if err == nil - err = core.OneErr(err, errNoConn) - log.E("dnscrypt: tcp: dialing %s; hasConn? %s(%t); err: %v", serverInfo, pid, pcnil, err) - return - } - defer clos(pc) - if err = pc.SetDeadline(time.Now().Add(timeout8s)); err != nil { - log.E("dnscrypt: tcp: err deadline: %v", err) - return - } - if userelay { - prepareForRelay(serverInfo.TCPAddr.IP, serverInfo.TCPAddr.Port, &encryptedQuery) - } - encryptedQuery, err = xdns.PrefixWithSize(encryptedQuery) - if err != nil { - log.E("dnscrypt: tcp: prefix(q) %s err: %v", serverInfo, err) - return - } - if _, err = pc.Write(encryptedQuery); err != nil { - log.E("dnscrypt: tcp: err write: %v", serverInfo, err) - return - } - encryptedResponse, err := xdns.ReadPrefixed(&pc) - if err != nil { - log.E("dnscrypt: tcp: read(enc) %s err %v", serverInfo, err) - return - } - res, err = decrypt(serverInfo, sharedKey, encryptedResponse, clientNonce) - return -} - -func prepareForRelay(ip net.IP, port int, eq *[]byte) { - anonymizedDNSHeader := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00} - relayedQuery := append(anonymizedDNSHeader, ip.To16()...) - var tmp [2]byte - binary.BigEndian.PutUint16(tmp[0:2], uint16(port)) - relayedQuery = append(relayedQuery, tmp[:]...) - relayedQuery = append(relayedQuery, *eq...) - *eq = relayedQuery -} - -func query(pid string, packet *dns.Msg, serverInfo *serverinfo, useudp bool) (ans *dns.Msg, relay net.Addr, qerr *dnsx.QueryError) { - var response []byte - if packet == nil || !xdns.HasAnyQuestion(packet) { - qerr = dnsx.NewBadQueryError(errQueryTooShort) - return // nil ans - } - - intercept := newIntercept() - intercepted, err := intercept.handleRequest(packet) - state := intercept.state - - saction := state.action - sr := state.response - if err != nil || saction == ActionDrop { - log.E("dnscrypt: ActionDrop %v", err) - qerr = dnsx.NewBadQueryError(err) - return // nil ans - } - if saction == ActionSynth { - if sr != nil { - ans = sr - log.D("dnscrypt: send synth response") - return // synth ans - } - log.D("dnscrypt: no synth; forward query [udp? %t]...", useudp) - } - - if serverInfo == nil { - qerr = dnsx.NewTransportQueryError(errNoServers) - return // nil ans - } - - if qerr = dnsx.WillErr(serverInfo); qerr != nil { - return // nil ans - } - - query, err := intercepted.Pack() - if err != nil { - log.E("dnscrypt: pack query err: %v", err) - qerr = dnsx.NewBadQueryError(err) - return // nil ans - } - - tcprelays := serverInfo.RelayTCPAddrs.Load() // may return nil - udprelays := serverInfo.RelayUDPAddrs.Load() // may return nil - usetcprelay := len(tcprelays) > 0 - if serverInfo.Proto == stamps.StampProtoTypeDNSCrypt { - sharedKey, encryptedQuery, clientNonce, cerr := encrypt(serverInfo, query, useudp, usetcprelay) - - if cerr != nil { - log.W("dnscrypt: enc fail forwarding to %s; udp? %t, relay? %t", serverInfo, useudp, usetcprelay) - qerr = dnsx.NewInternalQueryError(cerr) - return // nil ans - } - - if useudp { - response, relay, err = udpExchange(pid, serverInfo, udprelays, sharedKey, encryptedQuery, clientNonce) - } - tcpfallback := useudp && err != nil - if tcpfallback { - log.D("dnscrypt: udp failed, trying tcp; relay? %t", usetcprelay) - } - // if udp errored out, try over tcp; or use tcp if udp is disabled - if tcpfallback || !useudp { - useudp = false // switched to tcp - response, relay, err = tcpExchange(pid, serverInfo, tcprelays, sharedKey, encryptedQuery, clientNonce) - } - - if err != nil { - log.W("dnscrypt: querying [udp? %t; tcpfallback?: %t; relay? %t] failed: %v", useudp, tcpfallback, usetcprelay, err) - qerr = dnsx.NewSendFailedQueryError(err) - return // nil ans - } - } else if serverInfo.Proto == stamps.StampProtoTypeDoH { - // FIXME: implement - qerr = dnsx.NewSendFailedQueryError(errNoDoh) - return // nil ans - } else { - qerr = dnsx.NewTransportQueryError(errUnknownProto) - return // nil ans - } - - if len(response) < xdns.MinDNSPacketSize || len(response) > xdns.MaxDNSPacketSize { - log.E("dnscrypt: response from %s too small or too large", serverInfo) - qerr = dnsx.NewBadResponseQueryError(errInvalidResponse) - return // nil ans - } - - response, err = intercept.handleResponse(response, useudp) - - if err != nil { - log.E("dnscrypt: err intercept response for %s: %w", serverInfo, err) - qerr = dnsx.NewBadResponseQueryError(err) - return // nil ans - } - - ans = new(dns.Msg) - if err = ans.Unpack(response); err != nil { - log.E("dnscrypt: err unpack response for %s: %w", serverInfo, err) - qerr = dnsx.NewBadResponseQueryError(err) - return // nil ans? - } - - return // ans "response" -} - -// resolve resolves incoming DNS query, data -func resolve(network string, data *dns.Msg, si *serverinfo, smm *x.DNSSummary) (ans *dns.Msg, err error) { - var qerr *dnsx.QueryError - var anonrelayaddr net.Addr - - before := time.Now() - - proto, pids := xdns.Net2ProxyID(network) - useudp := proto == dnsx.NetTypeUDP - pid := dnsx.NetNoProxy - if si != nil { - if r := si.relay; len(r) > 0 { - pid = si.chooseProxy([]string{r}) - } else { - pid = si.chooseProxy(pids) - } - } - - // ans, si may be nil - ans, anonrelayaddr, qerr = query(pid, data, si, useudp) - - after := time.Now() - - latency := after.Sub(before) - status := dnsx.Complete - - var resolver string - var anonrelay string - if si != nil { - resolver = si.HostName - if anonrelayaddr != nil { // may be nil - anonrelay = anonrelayaddr.String() - } - } - - if qerr != nil { - status = qerr.Status() - err = qerr.Unwrap() - } - - smm.Latency = latency.Seconds() - smm.RData = xdns.GetInterestingRData(ans) - smm.RCode = xdns.Rcode(ans) - smm.RTtl = xdns.RTtl(ans) - smm.Server = resolver - smm.PID = anonrelay // may be empty - smm.Status = status - if err != nil { - smm.Msg = err.Error() - } - if len(anonrelay) <= 0 { - smm.PID = pid - } - - if settings.Debug { - log.V("dnscrypt: len(res): %d, data: %s, via: %s, err? %v", - xdns.Len(ans), smm.RData, smm.PID, err) - } - - return // ans, err -} - -// LiveTransports returns csv of dnscrypt server-names currently in-use -func (proxy *DcMulti) LiveTransports() string { - if len(proxy.liveServers) <= 0 { - return "" - } - return strings.Join(proxy.liveServers[:], ",") -} - -func (proxy *DcMulti) refreshOne(uid string) (bool, error) { - proxy.RLock() - r, ok := proxy.registeredServers[uid] - proxy.RUnlock() - - if !ok { - return false, errNoServers - } - if err := proxy.serversInfo.refreshServer(proxy, r.name, r.stamp); err != nil { - log.E("dnscrypt: refresh failed %s: %s; err: %v", r.name, stamp2str(r.stamp), err) - return false, err - } - if settings.Debug { - log.D("dnscrypt: refresh success %s: %s", r.name, stamp2str(r.stamp)) - } - return true, nil -} - -// Refresh re-registers servers -func (proxy *DcMulti) Refresh() (string, error) { - var servers []*registeredserver - proxy.RLock() - for _, s := range proxy.registeredServers { - sp := &s // stackoverflow.com/a/68247837 - servers = append(servers, sp) - } - proxy.RUnlock() - - for _, registeredServer := range servers { - proxy.serversInfo.registerServer(registeredServer.name, registeredServer.stamp) - } - var err error - proxy.liveServers, err = proxy.serversInfo.refresh(proxy) - if len(proxy.liveServers) > 0 { - proxy.certIgnoreTimestamp = false - } else if err != nil { - // ignore error if live-servers are around - return "", err - } - go proxy.refreshRoutes() - - return proxy.ID(), nil -} - -// start starts this dnscrypt proxy -func (proxy *DcMulti) start() error { - if _, err := crypto_rand.Read(proxy.proxySecretKey[:]); err != nil { - return err - } - curve25519.ScalarBaseMult(&proxy.proxyPublicKey, &proxy.proxySecretKey) - - _, err := proxy.Refresh() - _ = core.Periodic("dcmulti.start", proxy.ctx, certRefreshDelay, func() { - maxtries := 10 - i := 0 - for { - i++ - if i > maxtries { - log.E("dnscrypt: cert refresh failed after %d tries", maxtries) - return - } - select { - case <-proxy.ctx.Done(): - log.I("dnscrypt: cert refresh stopped") - return - default: - } - - hasServers := proxy.serversInfo.len() > 0 - if !hasServers { - log.D("dnscrypt: no servers; next check after %v", certRefreshDelayAfterFailure) - return - } - proxy.liveServers, _ = proxy.serversInfo.refresh(proxy) - if someAlive := len(proxy.liveServers) > 0; someAlive { - log.I("dnscrypt: some servers alive; retry #%d; next check after", - i, certRefreshDelayAfterFailure) - proxy.certIgnoreTimestamp = false - return - } - proxy.certIgnoreTimestamp = true - backoff := time.Duration(i) * time.Second - wait := certRefreshDelayAfterFailure * backoff - log.W("dnscrypt: all servers dead; retry #%d in %v", i, wait) - time.Sleep(wait) - continue - - } - }) - // todo: on error: context.AfterFunc(refreshCtx, proxy.notifyRestart) - return err -} - -// func (proxy *DcMulti) notifyRestart() { -// defer proxy.Stop() -// log.U("DNSCrypt stopped; restart the app") -// } - -// Stop stops this dnscrypt proxy -func (proxy *DcMulti) Stop() error { - proxy.lastStatus.Store(dnsx.DEnd) - proxy.sigterm() - return nil -} - -// refreshRoutes re-adds relay routes to all live/tracked servers. -// Must be called from a goroutine. -func (proxy *DcMulti) refreshRoutes() { - debug.SetPanicOnFault(true) - defer core.Recover(core.Exit11, "dcmulti.refreshRoutes") - - udp, tcp := route(proxy) - if len(udp) <= 0 || len(tcp) <= 0 { - log.I("dnscrypt: refreshRoutes: remove all relays") - } - n := 0 - for _, x := range proxy.serversInfo.getAll() { - if x == nil { - continue - } - // udp, tcp may be empty or nil; which means no relay - x.RelayUDPAddrs.Store(udp) - x.RelayTCPAddrs.Store(tcp) - n++ - } - log.I("dnscrypt: refreshRoutes: %d/%d for %d servers", len(udp), len(tcp), n) -} - -// AddGateways adds relay servers -func (proxy *DcMulti) AddGateways(routescsv string) (int, error) { - if len(routescsv) <= 0 { - return 0, errNoRoute - } - - proxy.Lock() - defer proxy.Unlock() - r := strings.Split(routescsv, ",") - cat := xdns.FindUnique(proxy.routes, r) - proxy.routes = append(proxy.routes, cat...) - - log.I("dnscrypt: added %d/%d; relay? %s", len(cat), len(r), cat) - if len(cat) > 0 { - go proxy.refreshRoutes() - } - return len(cat), nil -} - -// RemoveGateways removes relay servers -func (proxy *DcMulti) RemoveGateways(routescsv string) (int, error) { - if len(routescsv) <= 0 { - return 0, errNoRoute - } - - proxy.Lock() - defer proxy.Unlock() - rm := strings.Split(routescsv, ",") - l := len(proxy.routes) - proxy.routes = xdns.FindUnique(rm, proxy.routes) - n := len(proxy.routes) - - if l != n { // routes changed - go proxy.refreshRoutes() - } - if settings.Debug { - log.V("dnscrypt: removed %d/%d; relays: %s", l-n, l, routescsv) - } - return l - n, nil -} - -func (proxy *DcMulti) removeOne(uid string) int { - proxy.Lock() - delete(proxy.registeredServers, uid) - proxy.Unlock() - - // TODO: handle err - n, err := proxy.serversInfo.unregisterServer(uid) - if settings.Debug { - log.D("dnscrypt: removed %s; %d servers (err? %v)", uid, n, err) - } - return n -} - -// Remove removes a dnscrypt server / relay, if any -func (proxy *DcMulti) Remove(uid string) bool { - // may be a gateway / relay or a dnscrypt server - n := proxy.removeOne(uid) - nr, nerr := proxy.RemoveGateways(uid) - if settings.Debug { - log.D("dnscrypt: removed %s; %d servers; %d relays [err %v]", uid, n, nr, nerr) - } - return true -} - -// RemoveAll removes all dnscrypt servers in the csv -func (proxy *DcMulti) RemoveAll(servernamescsv string) (int, error) { - if len(servernamescsv) <= 0 { - return 0, errNothing - } - - servernames := strings.Split(servernamescsv, ",") - c := 0 - for _, name := range servernames { - if len(name) == 0 { - continue - } - c = proxy.removeOne(name) - } - - log.I("dnscrypt: removed %d servers %s", c, servernamescsv) - return c, nil -} - -func (proxy *DcMulti) addOne(uid, rawstamp string) (string, error) { - stamp, err := stamps.NewServerStampFromString(rawstamp) - if err != nil { - return uid, fmt.Errorf("dnscrypt: stamp error for [%s] def: [%v]", rawstamp, err) - } - if stamp.Proto == stamps.StampProtoTypeDoH { - // TODO: Implement doh - return uid, fmt.Errorf("dnscrypt: doh not supported %s", rawstamp) - } - - proxy.Lock() - proxy.registeredServers[uid] = registeredserver{name: uid, stamp: stamp} - proxy.Unlock() - - if settings.Debug { - log.D("dnscrypt: added [%s] %s", uid, stamp2str(stamp)) - } - return uid, nil -} - -// Add implements dnsx.TransportMult -func (proxy *DcMulti) Add(t x.DNSTransport) bool { - // no-op - return false -} - -// Get implements dnsx.TransportMult -func (proxy *DcMulti) Get(id string) (x.DNSTransport, error) { - if t := proxy.serversInfo.get(id); t != nil { - return t, nil - } - return nil, errNoServers -} - -// AddAll registers additional dnscrypt servers -func (proxy *DcMulti) AddAll(serverscsv string) (int, error) { - if len(serverscsv) <= 0 { - return 0, errNothing - } - - servers := strings.Split(serverscsv, ",") - for i, serverStampPair := range servers { - if len(serverStampPair) == 0 { - return i, fmt.Errorf("dnscrypt: missing stamp for [%s]", serverStampPair) - } - serverStamp := strings.Split(serverStampPair, "#") - if len(serverStamp) < 2 { - return i, fmt.Errorf("dnscrypt: invalid stamp for [%s]", serverStampPair) - } - uid := serverStamp[0] - if _, err := proxy.addOne(uid, serverStamp[1]); err != nil { - return i, fmt.Errorf("dnscrypt: error adding [%s]: %v", uid, err) - } - } - return len(servers), nil -} - -// P50 implements dnsx.TransportMult -func (p *DcMulti) P50() int64 { - return p.est.Get() -} - -// ID implements dnsx.TransportMult -func (p *DcMulti) ID() string { - return dnsx.DcProxy -} - -// Type implements dnsx.TransportMult -func (p *DcMulti) Type() string { - return dnsx.DNSCrypt -} - -// Query implements dnsx.TransportMult -func (p *DcMulti) Query(network string, q *dns.Msg, smm *x.DNSSummary) (r *dns.Msg, err error) { - // TODO: check if server is active (status != DEnd) - r, err = resolve(network, q, p.serversInfo.getOne(), smm) - p.lastStatus.Store(smm.Status) - p.lastAddr = smm.Server - p.est.Add(smm.Latency) - return -} - -// GetAddr returns the last server address -func (p *DcMulti) GetAddr() string { - return p.getAddr() -} - -func (p *DcMulti) getAddr() string { - return p.lastAddr -} - -func (p *DcMulti) GetRelay() x.Proxy { - return nil -} - -func (p *DcMulti) IPPorts() []netip.AddrPort { - // TODO: get ipports from all servers - return dnsx.NoIPPort -} - -// Status implements dnsx.TransportMult -func (p *DcMulti) Status() int { - return p.lastStatus.Load() -} - -func stamp2str(s stamps.ServerStamp) string { - return core.UniqStr(fmt.Sprintf("name:%s, addr:%s, path:%s", s.ProviderName, s.ServerAddrStr, s.Path)) -} - -// NewDcMult creates a dnscrypt proxy -func NewDcMult(pctx context.Context, px ipn.ProxyProvider, ctl protect.Controller) *DcMulti { - ctx, cancel := context.WithCancel(pctx) - exit, err := px.ProxyFor(ipn.Exit) - if err != nil { - log.W("dnscrypt: no exit proxy: %v", err) - } - dc := &DcMulti{ - ctx: ctx, - sigterm: cancel, - routes: nil, - registeredServers: make(map[string]registeredserver), - certIgnoreTimestamp: false, - serversInfo: newServersInfo(), - liveServers: nil, - lastStatus: core.NewVolatile(dnsx.Start), - proxies: px, - lastAddr: "", - ctl: ctl, - exit: exit, // may be nil - est: core.NewP50Estimator(ctx), - } - err = dc.start() - if err != nil { - log.E("dnscrypt: start failed: %v", err) - } - return dc -} - -// AddTransport creates and adds a dnscrypt transport to p -func AddTransport(p *DcMulti, id, serverstamp string) (*serverinfo, error) { - if p == nil { - return nil, dnsx.ErrNoDcProxy - } - if _, err := p.addOne(id, serverstamp); err == nil { - if ok, err := p.refreshOne(id); ok { - log.I("dnscrypt: added %s; %s", id, serverstamp) - if tr := p.serversInfo.get(id); tr != nil { - go p.refreshRoutes() - return tr, nil - } - log.W("dnscrypt: failed to add1 %s; %s", id, serverstamp) - return nil, dnsx.ErrAddFailed - } else { - log.W("dnscrypt: failed to add2 %s; %s", id, serverstamp) - p.removeOne(id) - return nil, core.OneErr(err, errNoCert) - } - } else { - return nil, err - } -} - -// AddRelayTransport creates and adds a relay server to p -func AddRelayTransport(p *DcMulti, relaystamp string) error { - if p == nil { - return dnsx.ErrNoDcProxy - } - if _, err := p.AddGateways(relaystamp); err == nil { - log.I("dnscrypt: added relay %s", relaystamp) - return nil - } else { - return err - } -} diff --git a/intra/dnscrypt/servers.go b/intra/dnscrypt/servers.go deleted file mode 100644 index 4ac84858..00000000 --- a/intra/dnscrypt/servers.go +++ /dev/null @@ -1,500 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package dnscrypt - -import ( - "context" - "encoding/hex" - "errors" - "fmt" - "maps" - "math/rand" - "net" - "net/netip" - "strconv" - "strings" - "sync" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" - - stamps "github.com/jedisct1/go-dnsstamps" - "golang.org/x/crypto/ed25519" -) - -type registeredserver struct { - name string - stamp stamps.ServerStamp -} - -type serverinfo struct { - ctx context.Context - done context.CancelFunc - Proto stamps.StampProtoType - MagicQuery [8]byte - ClientPubKey *[32]byte - ServerPk [32]byte - SharedKey [32]byte - CryptoConstruction xdns.CryptoConstruction - Name string // id of the server - HostName string - UDPAddr *net.UDPAddr - TCPAddr *net.TCPAddr - proxies ipn.ProxyProvider // proxy-provider, may be nil - relay string // proxy relay to use, may be nil - est core.P2QuantileEstimator - - // fields below are mutable - - RelayUDPAddrs *core.Volatile[[]*net.UDPAddr] // anonymous relays, if any - RelayTCPAddrs *core.Volatile[[]*net.TCPAddr] // anonymous relays, if any - status *core.Volatile[int] // status of the last query -} - -var _ dnsx.Transport = (*serverinfo)(nil) - -type ServersInfo struct { - sync.RWMutex - inner map[string]*serverinfo - registeredServers map[string]registeredserver -} - -// newServersInfo returns a new servers-info object -func newServersInfo() *ServersInfo { - return &ServersInfo{ - registeredServers: make(map[string]registeredserver), - inner: make(map[string]*serverinfo), - } -} - -func (serversInfo *ServersInfo) len() int { - serversInfo.RLock() - defer serversInfo.RUnlock() - - return len(serversInfo.registeredServers) -} - -func (serversInfo *ServersInfo) getAll() []*serverinfo { - serversInfo.RLock() - defer serversInfo.RUnlock() - - servers := make([]*serverinfo, 0) - for _, si := range serversInfo.inner { - if si != nil { - servers = append(servers, si) - } - } - if settings.Debug { - log.V("dnscrypt: getAll: servers [%d/%d]", len(servers), len(serversInfo.inner)) - } - return servers -} - -func (serversInfo *ServersInfo) getOne() (serverInfo *serverinfo) { - serversInfo.RLock() - defer serversInfo.RUnlock() - - serversCount := len(serversInfo.inner) - if serversCount <= 0 { - return nil - } - selectAny := false - candidate := rand.Intn(serversCount) -retry: - i := 0 - for _, si := range serversInfo.inner { - if i == candidate || selectAny { - if si != nil && dnsx.WillErr(si) == nil { - if settings.Debug { - log.V("dnscrypt: candidate [%v]", si) // may be nil? - } - serverInfo = si - break - } - } - i++ - } - - if serverInfo == nil && !selectAny { - selectAny = true - goto retry - } - - return serverInfo -} - -func (serversInfo *ServersInfo) get(name string) *serverinfo { - serversInfo.RLock() - defer serversInfo.RUnlock() - serversCount := len(name) - if serversCount <= 0 { - return nil - } - return serversInfo.inner[name] // may be nil -} - -func (serversInfo *ServersInfo) unregisterServer(name string) (int, error) { - serversInfo.Lock() - defer serversInfo.Unlock() - - if si, ok := serversInfo.inner[name]; ok { - si.Stop() - } - - delete(serversInfo.registeredServers, name) - delete(serversInfo.inner, name) - - return len(serversInfo.registeredServers), nil -} - -func (serversInfo *ServersInfo) registerServer(name string, stamp stamps.ServerStamp) { - serversInfo.Lock() - defer serversInfo.Unlock() - - serversInfo.registeredServers[name] = registeredserver{name: name, stamp: stamp} -} - -func (serversInfo *ServersInfo) refresh(proxy *DcMulti) ([]string, error) { - if settings.Debug { - log.D("dnscrypt: refreshing certificates") - } - var liveServers []string - var err error - - // Get a snapshot of registered servers under lock to prevent race conditions - serversInfo.RLock() - copied := make(map[string]registeredserver) - maps.Copy(copied, serversInfo.registeredServers) - serversInfo.RUnlock() - - for _, registeredServer := range copied { - if err = serversInfo.refreshServer(proxy, registeredServer.name, registeredServer.stamp); err == nil { - liveServers = append(liveServers, registeredServer.name) - } else { - log.E("dnscrypt: %s not a live server? %w", registeredServer.stamp, err) - } - } - return liveServers, err -} - -func (serversInfo *ServersInfo) refreshServer(proxy *DcMulti, name string, stamp stamps.ServerStamp) error { - newServer, err := fetchServerInfo(proxy, name, stamp) - if err != nil { - return err - } - if name != newServer.Name { - return fmt.Errorf("[%s] != [%s]", name, newServer.Name) - } - - serversInfo.Lock() - defer serversInfo.Unlock() - if si, ok := serversInfo.inner[name]; ok { - si.Stop() - } - serversInfo.inner[name] = &newServer - serversInfo.registeredServers[name] = registeredserver{name: name, stamp: stamp} - return nil -} - -func fetchServerInfo(proxy *DcMulti, name string, stamp stamps.ServerStamp) (serverinfo, error) { - switch stamp.Proto { - case stamps.StampProtoTypeDNSCrypt: - return fetchDNSCryptServerInfo(proxy, name, stamp) - case stamps.StampProtoTypeDoH: - return fetchDoHServerInfo(proxy, name, stamp) - } - return serverinfo{}, log.EE("unsupported protocol for %s", stamp.ServerAddrStr) -} - -func fetchDNSCryptServerInfo(proxy *DcMulti, name string, stamp stamps.ServerStamp) (serverinfo, error) { - if len(stamp.ServerPk) != ed25519.PublicKeySize { - serverPk, err := hex.DecodeString(strings.ReplaceAll(string(stamp.ServerPk), ":", "")) - if err != nil || len(serverPk) != ed25519.PublicKeySize { - return serverinfo{}, fmt.Errorf("unsupported public key for [%s]: [%s]", name, stamp.ServerPk) - } - log.W("dnscrypt: public key [%s] shouldn't be hex-encoded any more", string(stamp.ServerPk)) - stamp.ServerPk = serverPk - } - - // note: relays are not used to fetch certs due to multiple issues reported by users - certInfo, err := fetchCurrentDNSCryptCert(proxy, &name, stamp.ServerPk, stamp.ServerAddrStr, stamp.ProviderName) - if err != nil { - return serverinfo{}, err - } - var tcpaddr *net.TCPAddr - var udpaddr *net.UDPAddr - s, p := hostport(&stamp) - if ips, err := dialers.Resolve(s); err == nil && len(ips) > 0 { - ipp := netip.AddrPortFrom(ips[0], p) - tcpaddr = net.TCPAddrFromAddrPort(ipp) - udpaddr = net.UDPAddrFromAddrPort(ipp) - } else { - return serverinfo{}, fmt.Errorf("dnscrypt: no ips for [%s]: %v", s, err) - } - if udpaddr == nil || tcpaddr == nil { - return serverinfo{}, log.EE("%v for %s", errNoServers, stamp.ServerAddrStr) - } - px := proxy.proxies - var relay string - if px != nil { - x, _ := px.ProxyFor(name) - if x != nil { - relay = x.ID() - } - } - - ctx, done := context.WithCancel(proxy.ctx) - si := serverinfo{ - ctx: ctx, - done: done, - Proto: stamps.StampProtoTypeDNSCrypt, - MagicQuery: certInfo.MagicQuery, - ClientPubKey: &proxy.proxyPublicKey, - ServerPk: certInfo.ServerPk, - SharedKey: certInfo.SharedKey, - CryptoConstruction: certInfo.CryptoConstruction, - HostName: stamp.ProviderName, - Name: name, - UDPAddr: udpaddr, - TCPAddr: tcpaddr, - RelayTCPAddrs: core.NewZeroVolatile[[]*net.TCPAddr](), // populated later; see proxy.refreshRoutes() - RelayUDPAddrs: core.NewZeroVolatile[[]*net.UDPAddr](), // populated later; see proxy.refreshRoutes() - proxies: px, - relay: relay, - est: core.NewP50Estimator(ctx), - status: core.NewVolatile(dnsx.Start), - } - log.I("dnscrypt: (%s) setup: %s; anonrelay? %t, proxy? %t", name, si.HostName, len(relay) > 0) - return si, nil -} - -func fetchDoHServerInfo(_ *DcMulti, _ string, _ stamps.ServerStamp) (serverinfo, error) { - // FIXME: custom ip-address, user-certs, and cert-pinning not supported - return serverinfo{}, errors.New("unsupported protocol") -} - -func route(proxy *DcMulti) (udpaddrs []*net.UDPAddr, tcpaddrs []*net.TCPAddr) { - proxy.Lock() - relays := proxy.routes - proxy.Unlock() - - udpaddrs = make([]*net.UDPAddr, 0) - tcpaddrs = make([]*net.TCPAddr, 0) - - if len(relays) <= 0 { // no err, no relays - return - } - - for _, rr := range relays { - var rrstamp *stamps.ServerStamp - if len(rr) == 0 { - log.W("dnscrypt: route: skip empty relay") - continue - } else if relayStamp, serr := stamps.NewServerStampFromString(rr); serr == nil { - rrstamp = &relayStamp - } - - if rrstamp == nil { - rrstamp = &stamps.ServerStamp{ - ServerAddrStr: rr, // may be a hostname or ip-address - Proto: stamps.StampProtoTypeDNSCryptRelay, - } - } - - host, port := hostport(rrstamp) - if rrstamp != nil && (rrstamp.Proto == stamps.StampProtoTypeDNSCrypt || - rrstamp.Proto == stamps.StampProtoTypeDNSCryptRelay) { - if ips, err := dialers.Resolve(host); err == nil && len(ips) > 0 { - ipp := netip.AddrPortFrom(ips[0], port) // TODO: randomize? - tcpaddrs = append(tcpaddrs, net.TCPAddrFromAddrPort(ipp)) - udpaddrs = append(udpaddrs, net.UDPAddrFromAddrPort(ipp)) - } else { - log.W("dnscrypt: route: zero ips for relay [%s] for server [%s]; err [%v]", rr, host, err) - } - } else { - log.W("dnscrypt: route: invalid relay [%s]", rr) - } - } - return -} - -func hostport(stamp *stamps.ServerStamp) (string, uint16) { - if stamp == nil { - return "", 0 - } - x := stamp.ServerAddrStr - s, port, err := net.SplitHostPort(x) - if err != nil || len(port) <= 0 { - log.W("dnscrypt: host-port og(%s); err? %v", x, err) - s = x - port = "443" // use default port - } - p, err := strconv.Atoi(port) - if err != nil { - p = 443 // use default port - } - return s, uint16(p) -} - -func (s *serverinfo) String() string { - if s == nil { - return "" - } - - serverid := s.ID() - servername := s.getAddr() - serveraddr := "notcp" - relayaddr := "norelay" - if s.TCPAddr != nil { - serveraddr = s.TCPAddr.String() - } - if a := s.RelayTCPAddrs.Load(); len(a) > 0 { - relayaddr = chooseAny(a).String() - } - - return serverid + ":" + servername + "/" + serveraddr + "<=>" + relayaddr -} - -func (s *serverinfo) ID() string { - return s.Name -} - -func (s *serverinfo) Type() string { - return dnsx.DNSCrypt -} - -func (s *serverinfo) Query(network string, q *dns.Msg, smm *x.DNSSummary) (r *dns.Msg, err error) { - r, err = resolve(network, q, s, smm) - s.status.Store(smm.Status) - - if s.est != nil { - s.est.Add(smm.Latency) - } - if err != nil { - smm.Msg = err.Error() - } - - return -} - -func (s *serverinfo) P50() int64 { - if s.est != nil { - return s.est.Get() - } else { - return 0 - } -} - -func (s *serverinfo) GetAddr() string { - return s.getAddr() -} - -func (s *serverinfo) getAddr() string { - return s.HostName -} - -func (s *serverinfo) GetRelay() x.Proxy { - return s.getRelay() -} - -func (s *serverinfo) getRelay() ipn.Proxy { - if r := s.relay; len(r) > 0 { - px, _ := s.proxies.ProxyFor(r) - return px - } - return nil -} - -func (s *serverinfo) IPPorts() []netip.AddrPort { - if relay := s.RelayUDPAddrs.Load(); len(relay) > 0 { - return addr2ipp(relay...) - } - return addr2ipp(s.UDPAddr) -} - -func (s *serverinfo) Status() int { - if px := s.getRelay(); px != nil { - if px.Status() == ipn.TPU { - return dnsx.Paused - } - } - return s.status.Load() -} - -func (s *serverinfo) Stop() error { - if s != nil { - s.status.Store(dnsx.DEnd) - s.done() - } - return nil -} - -func (s *serverinfo) dialudp(pid string, addr *net.UDPAddr) (net.Conn, error) { - userelay := s.GetRelay() != nil - useproxy := len(pid) != 0 // pid == dnsx.NetNoProxy => ipn.Base - if userelay || useproxy { - return s.dialpx(pid, "udp", addr.String()) - } - return nil, dnsx.ErrNoProxyProvider -} - -func (s *serverinfo) dialtcp(pid string, addr *net.TCPAddr) (net.Conn, error) { - userelay := s.GetRelay() != nil - useproxy := len(pid) != 0 // pid == dnsx.NetNoProxy => ipn.Base - if userelay || useproxy { - return s.dialpx(pid, "tcp", addr.String()) - } - return nil, dnsx.ErrNoProxyProvider -} - -func (s *serverinfo) dialpx(pid, proto string, addr string) (net.Conn, error) { - relay := s.getRelay() - if relay != nil { - // addr is always ip:port; hence protect.dialers are not needed - return relay.Dialer().Dial(proto, addr) - } - pxs := s.proxies - if pxs == nil { - return nil, dnsx.ErrNoProxyProvider - } - px, err := pxs.ProxyFor(pid) - if err == nil { - return px.Dialer().Dial(proto, addr) // ref comment above - } - return nil, err -} - -func (s *serverinfo) chooseProxy(pids []string) string { - return dnsx.ChooseHealthyProxy("dnscrypt: "+s.ID(), s.IPPorts(), pids, s.proxies) -} - -func addr2ipp(u ...*net.UDPAddr) (ipps []netip.AddrPort) { - if len(u) <= 0 { - return dnsx.NoIPPort - } - for _, x := range u { - if x != nil { - ipps = append(ipps, x.AddrPort()) - } - } - return // may be nil -} diff --git a/intra/dnscrypt/servers_test.go b/intra/dnscrypt/servers_test.go deleted file mode 100644 index cf828ddd..00000000 --- a/intra/dnscrypt/servers_test.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnscrypt - -import ( - "context" - "errors" - "log" - "net" - "net/netip" - "testing" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - ilog "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -type fakeCtl struct { - protect.Controller -} - -func (*fakeCtl) Bind4(_, _ string, _ int) {} -func (*fakeCtl) Bind6(_, _ string, _ int) {} -func (*fakeCtl) Protect(_ string, _ int) {} - -type fakeObs struct { - x.ProxyListener -} - -func (*fakeObs) OnProxyAdded(string) {} -func (*fakeObs) OnProxyRemoved(string) {} -func (*fakeObs) OnProxiesStopped() {} - -/* -type fakeBdg struct { - protect.Controller - intra.Bridge -} - -var ( - baseNsOpts = &dnsx.NsOpts{PID: ipn.Base, IPCSV: "", TIDCSV: ""} - baseMark = &intra.Mark{PID: ipn.Base, CID: "testcid", UID: protect.UidSelf} - baseTab = &rnet.Tab{CID: "testcid", Block: false} -) - -func (*fakeBdg) Flow(_ int32, _ int, a, b, c, d, e, f string) *intra.Mark { return baseMark } -func (*fakeBdg) OnSocketClosed(*intra.SocketSummary) {} - -func (*fakeBdg) OnQuery(_ string, _ int) *dnsx.NsOpts { return baseNsOpts } -func (*fakeBdg) OnResponse(*dnsx.Summary) {} - -func (*fakeBdg) Route(a, b, c, d, e string) *rnet.Tab { return baseTab } -func (*fakeBdg) OnComplete(*rnet.ServerSummary) {} -*/ - -type fakeResolver struct { - *net.Resolver -} - -func (r fakeResolver) LocalLookup([]byte) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) Lookup([]byte, string, ...string) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupFor([]byte, string) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIP(_ context.Context, _, _ string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIPFor(_ context.Context, _, _, _ string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIPOn(_ context.Context, _, _ string, _ ...string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -const minmtu = 1280 -const dualstack = settings.IP46 - -func TestOne(t *testing.T) { - ctx := context.TODO() - r := &net.Resolver{} - // create a struct that implements protect.Controller interface - ctl := &fakeCtl{} - obs := &fakeObs{} - // bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctx, dualstack, minmtu, ctl, obs) - if pxr == nil { - t.Fatal("nil proxifier") - } - ilog.SetLevel(0) - resolver := fakeResolver{r} - dialers.Mapper(resolver) - settings.Debug = true - p := NewDcMult(ctx, pxr, ctl) - // csromania fetches certs, but not answers - // csromania := "sdns://AQIAAAAAAAAADTE0Ni43MC42Ni4yMjcgMTNyrVlWMsJBa4cvCY-FG925ZShMbL6aTxkJZDDbqVoeMi5kbnNjcnlwdC1jZXJ0LmNyeXB0b3N0b3JtLmlz" - // dctnl does not fetch certs - // dctnl := "sdns://AQcAAAAAAAAAEzIzLjEzNy4yNDkuMTE2Ojg0NDMgEWD0g0vsKFqwslGBKql8eTiu1RvK2dzZIxLfR7ctlAwXMi5kbnNjcnlwdC1jZXJ0LmRjdC1ubDE" - // pl fetches certs, but not answers - // pl := "sdns://AQcAAAAAAAAAFDE3OC4yMTYuMjAxLjEyODoyMDUzIH9hfLgepVPSNMSbwnnHT3tUmAUNHb8RGv7mmWPGR6FpGzIuZG5zY3J5cHQtY2VydC5kbnNjcnlwdC5wbA" - // swfr := "sdns://AQcAAAAAAAAADjIxMi40Ny4yMjguMTM2IOgBuE6mBr-wusDOQ0RbsV66ZLAvo8SqMa4QY2oHkDJNHzIuZG5zY3J5cHQtY2VydC5mci5kbnNjcnlwdC5vcmc" - // swams := "sdns://AQcAAAAAAAAADTUxLjE1LjEyMi4yNTAg6Q3ZfapcbHgiHKLF7QFoli0Ty1Vsz3RXs1RUbxUrwZAcMi5kbnNjcnlwdC1jZXJ0LnNjYWxld2F5LWFtcw" - // dnsbe is down - // dnsbe := "sdns://AQcAAAAAAAAADzE5My4xOTEuMTg3LjEwNyAzWmXOT_I8k2BKJzxIJ_iYoXRQRWcR0Q1FFyrJWtvogxsyLmRuc2NyeXB0LWNlcnQuZG5zY3J5cHQuYmU" - // adguard family does not fetch certs - // agfam := "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20" - // adguard does not fetch certs - // adguard := "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20" - // cb := "sdns://AQMAAAAAAAAAEzE4NS4yMjguMTY4LjEwOjg0NDMgvKwy-tVDaRcfCDLWB1AnwyCM7vDo6Z-UGNx3YGXUjykRY2xlYW5icm93c2luZy5vcmc" - q912 := "sdns://AQYAAAAAAAAAEzE0OS4xMTIuMTEyLjEyOjg0NDMgZ8hHuMh1jNEgJFVDvnVnRt803x2EwAuMRwNo34Idhj4ZMi5kbnNjcnlwdC1jZXJ0LnF1YWQ5Lm5ldA" - tr, err := AddTransport(p, "test", q912) - if err != nil || tr == nil { - t.Fatal(errors.Join(dnsx.ErrAddFailed, err)) - } - q := aquery("google.com") - smm := &x.DNSSummary{} - netw := xdns.NetAndProxyID("udp", ipn.Base) - // FIXME: querying always fails with EOF - ans, err := tr.Query(netw, q, smm) - if err != nil { - log.Output(2, smm.String()) - t.Fatal(err) - } - if xdns.Len(ans) == 0 { - t.Fatal("empty response") - } - log.Output(10, strDNSAns(ans)) -} - -func aquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeA) - msg.Id = 1234 - return msg -} - -func strDNSAns(a *dns.Msg) string { - if a == nil || len(a.Answer) < 1 { - return "no answer" - } - return a.Answer[0].String() -} diff --git a/intra/dnsx/alg.go b/intra/dnsx/alg.go deleted file mode 100644 index 34788ff3..00000000 --- a/intra/dnsx/alg.go +++ /dev/null @@ -1,2408 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "hash/fnv" - "math" - "net" - "net/netip" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -const ( - timeout = 15 * time.Second - // 2m max ttl for alg/nat ip - ttl2m = 2 * time.Minute - // 8s min ttl for alg/nat ip; chosen to be closer to transport timeouts - ttl8s = 8 * time.Second - // 3s timeout to undo dns64/nat64 - ttl3s = 3 * time.Second - - key4 = ":a" - key6 = ":aaaa" - - notransport = "NoTransport" - - maxiter = 100 // max number alg/nat evict iterations -) - -type iptype int - -const ( - typalg iptype = iota - typreal - typsecondary -) - -var ( - // 100.64.x.x - rfc6598 = []uint8{100, 64, 0, 1} - rfc8215a = []uint16{0x64, 0xff9b, 0x1, 0xda19, 0x100, 0x0, 0x0, 0x0} - rfc3068ip4 = netip.AddrFrom4([4]uint8{192, 88, 99, 114}) - rfc3068ip6 = netip.AddrFrom16([16]byte{0x20, 0x02, 0xF1, 0x3D, 0x00, 0x01, 0xDA, 0x19, 0x01, 0x92, 0x00, 0x88, 0x00, 0x99, 0x01, 0x14}) - // [192.88.99.114, 2002:f13d:1:da19:192:88:99:114] go.dev/play/p/-0MJenRF5pm - fixedRealIPs = []netip.Addr{rfc3068ip4, rfc3068ip6} - - zeroaddr = netip.Addr{} - anyaddr4 = netip.IPv4Unspecified() - anyaddr6 = netip.IPv6Unspecified() - - errAlgNoTransport = errors.New("no alg transport") - errAlgNotAvail = errors.New("no valid alg ips") - errAlgCannotRegister = errors.New("cannot register alg ip") - errAlgCannotSubst = errors.New("cannot substitute alg ip") -) - -func isAlgErr(err error) bool { - return (err == errAlgCannotRegister || err == errAlgNotAvail || err == errAlgCannotSubst) -} - -type Gateway interface { - // given an alg or real ip, retrieves assoc real ips as csv, if any - X(maybeAlg netip.Addr, uid string, tids ...string) (realips []netip.Addr, undidAlg bool) - // given an alg or real ip, retrieves assoc dns names as csv, if any - PTR(maybeAlg netip.Addr, uid, tid string, force bool) (domaincsv string, didForce bool) - // given domain, retrieve assoc alg ips or real ips as csv, if any - RESOLV(domain, uid, tid string) []netip.Addr - // given an alg or real ip, retrieve assoc blocklists as csv, if any - RDNSBL(maybeAlg netip.Addr) (blocklistcsv string) - // translate overwrites ip answers to alg ip & fixed ip answers - translate(tr, fix bool) - // splitTunnel sets per-app split tunneling on/off - splitTunnel() - // fixedTransport returns true if split tunneling is enforced via dnsx.Fixed - fixedTransport() bool - // Query using t1 as primary transport and t2 as secondary and preset as pre-determined ip answers - q(t1, t2 Transport, preset []netip.Addr, network, uid string, q *dns.Msg, s *x.DNSSummary) (og *dns.Msg, algd *dns.Msg, err error) - // onStopped is called when a transport tid is stopped. Gateway invalidates its local caches, if any. - onStopped(tid string) - // S reveals internal state for debugging - S() string -} - -type secans struct { - ips []netip.Addr - smm *x.DNSSummary - pri bool // is secans got from primary transport? -} - -func (sec *secans) initIfNeeded() { - // ptr reciever updates all secans in-place: - // go.dev/play/p/wjBd1TC59zN - if sec.ips == nil { - sec.ips = []netip.Addr{} - } - if sec.smm == nil { - sec.smm = new(x.DNSSummary) - } -} - -type expaddr struct { - ips []netip.Addr - ttl time.Time - dob time.Time -} - -func (a expaddr) String() string { - return fmt.Sprintf("addrs(%v / ttl: %s / dob: %s)", a.ips, core.FmtTimeAsPeriod(a.ttl), core.FmtTimeAsPeriod(a.dob)) -} - -func (a expaddr) sizes() (alive, tot int) { - if a.ips == nil { - return - } - return len(a.get(xalive)), len(a.ips) -} - -func (a expaddr) all() []netip.Addr { - return a.get(xall) -} - -func (a expaddr) alive() []netip.Addr { - return a.get(xalive) -} - -// fresh returns false if a has expired or if a is zero-value. -func (a expaddr) fresh() (rem time.Duration, y bool) { - if a.ttl.IsZero() { - return - } - rem = time.Until(a.ttl) - y = rem > 0 - return -} - -func (a expaddr) after(b expaddr) bool { - if a.dob.IsZero() { - return b.dob.IsZero() - } - return a.dob.After(b.dob) -} - -func (a expaddr) fresherThan(b expaddr) bool { - if a.ttl.IsZero() { - return b.ttl.IsZero() - } - return a.ttl.After(b.ttl) -} - -func (a expaddr) get(s xaddrstatus) (out []netip.Addr) { - if a.ips == nil { - return - } - if s == xall { - return a.ips - } else if s == xalive { - if _, y := a.fresh(); y { - return a.ips - } - } - return -} - -// go.dev/play/p/t24GYIQERsp -func expaddrDobSorter(a, b expaddr) int { - if a.after(b) { - return -1 - } else if b.after(a) { - return 1 - } - return 0 -} - -type xips struct { - // protects pri, aux - pmu sync.RWMutex - // resolved expaddr(v6 or v4) by tid; may be nil - pri map[string]expaddr - // resolved expaddr(v6 or v4) by secondary tid for a uid; may be nil - aux map[string]expaddr -} - -type xaddrstatus bool - -const ( - xalive xaddrstatus = true - xall xaddrstatus = false -) - -type xaddrtyp uint8 - -const ( - xpri xaddrtyp = 0 - xsec xaddrtyp = 1 -) - -func (xat xaddrtyp) String() string { - switch xat { - case xpri: - return "pri" - case xsec: - return "sec" - default: - return "???" - } -} - -func (xst xaddrstatus) String() string { - if xst { - return "alive" - } - return "all" -} - -func makeXipStatus(alive bool) xaddrstatus { - return xaddrstatus(alive) -} - -// NewXips returns a new xips object with primary and secondary ips. -// May return nil if tid is empty. -func NewXips(tid, uid string, pri, sec []netip.Addr, ttl time.Time) *xips { - if len(tid) <= 0 { - log.W("alg: xips: tid cannot be empty") - return nil - } - - x := &xips{ - pri: make(map[string]expaddr), - aux: make(map[string]expaddr), // sec may be nil - } - now := time.Now() - // id == "" for dnsx.NoDNS, in which case pri should be empty - if len(pri) > 0 { // pri may be nil - x.pri[tid] = expaddr{pri, ttl, now} - } - if len(uid) <= 0 { - uid = core.UNKNOWN_UID_STR - } - if len(sec) > 0 { // sec may be same as pri - x.aux[tid+uid] = expaddr{sec, ttl, now} - } - return x -} - -func vals[K comparable, V any](m map[K]V, upto uint) []V { - outv := make([]V, 0, upto) - for _, v := range m { - if upto != 0 && uint(len(outv)) >= upto { - break - } - outv = append(outv, v) - } - return outv -} - -func (p *xips) String() string { - if p == nil { - return "" - } - p.pmu.RLock() - defer p.pmu.RUnlock() - return fmt.Sprintf("xips: pri(%v) sec(%v)", p.pri, p.aux) -} - -// secondary ips for all tids+uids -func (p *xips) secips(s xaddrstatus) (out []netip.Addr) { - // as far as secondary is concerned, only live ips matter; - // expired ips should not be used for allow/deny decisions. - // that is, s should almost always be set to xalive. - return p.allips(xsec, s) -} - -// secondary ip by a given tid+uid pair -func (p *xips) secipsFor(tid, uid string) (out []netip.Addr) { - // as far as secondary is concerned, only live ips matter; - // expired ips should not be used for allow/deny decisions. - return p.ipsFor(tid, uid, xsec, xalive) -} - -func (p *xips) allips(t xaddrtyp, s xaddrstatus) (out []netip.Addr) { - p.pmu.RLock() - defer p.pmu.RUnlock() - - g := p.pri - if t == xsec { - g = p.aux - } - const all = 0 - addrs := append([]expaddr{}, vals(g, all)...) - - for _, v := range core.Sort(addrs, expaddrDobSorter) { - out = append(out, v.get(s)...) - } - return -} - -// realips returns all live primary ips for a given uid. -// returns all tracked live primary ips. -// and returns unspecified ips, if uid had a "block" response. -func (p *xips) realips(uid string, s xaddrstatus) (pri []netip.Addr) { - if p == nil { - return nil - } - - pri = p.allips(xpri, s) - - zz := p.zz(uid) // check for "block" responses for just this uid - if zz { - // if aux had unspecified ips, then append those to - // primary ips. if aux has proper ips or no ips, - // those are made redundant by primary. - // append unspecified ips (zz) to pri regardless of xaddrstatus, s, - // which is set to xall by xLocked if alg is disabled (mod == false), - // but that does not mean the values from aux (secondary) are not - // needed; ie, the "domain block" decision (indicated by aux) must - // be commmunicated to the client at all costs. However, aux itself - // must be xalive (not expired). - pri = append(pri, anyaddr4, anyaddr6) - } - loged(len(pri) == 0)("alg: xips: x(%s): zz(%s)? %t; %v", s, uid, zz, pri) - return pri // may be nil / empty -} - -func (p *xips) zz(uid string) bool { - if p == nil { - return false - } - return anyUnspecified(p.secipsFor(notransport, uid)) -} - -// realipsFor returns all live primary ips for a given tid+uid pair -func (p *xips) realipsFor(tid, uid string, s xaddrstatus) (pri []netip.Addr) { - return p.ipsFor(tid, uid, xpri, s) -} - -func (p *xips) ipsFor(tid, uid string, t xaddrtyp, s xaddrstatus) (out []netip.Addr) { - if p == nil { - return nil - } - - if len(uid) <= 0 { - uid = core.UNKNOWN_UID_STR - } - - if t == xpri && (tid == notransport || tid == NoDNS || len(tid) <= 0) { - out = p.allips(xpri, s) - if settings.Debug { - log.VV("alg: xips: xof(%s,%s): no tid? %s[%s]; returning all %v", t, s, tid, uid, out) - } - return - } - - p.pmu.RLock() - defer p.pmu.RUnlock() - - if t == xsec { // ignore alive for secondary - if tid == notransport || tid == NoDNS || len(tid) <= 0 { - for k, v := range p.aux { - if strings.HasSuffix(k, uid) { - out = append(out, v.get(s)...) - } - } - } else if v, ok := p.aux[tid+uid]; ok { - out = v.get(s) - } - } else if s == xalive { - out = p.pri[tid].alive() - } else { - out = p.pri[tid].all() - } - - if settings.Debug { - log.VV("alg: xips: xof(%s,%s): tid %s + uid %s; %v", t, s, tid, uid, out) - } - return -} - -func (p *xips) rmv(tid string) (done bool) { - if p == nil { - return - } - if len(tid) <= 0 || tid == NoDNS || tid == notransport { - return - } - - i, j := 0, 0 - p.pmu.Lock() - defer p.pmu.Unlock() - xaddr := p.pri[tid] - if _, y := xaddr.fresh(); y { - i++ - xaddr.ttl = time.Now() // mark as expired - p.pri[tid] = xaddr - } - for k, v := range p.aux { - if _, y := v.fresh(); !y { - continue - } - if strings.HasPrefix(k, tid) { - j++ - done = done || true - v.ttl = time.Now() // mark as expired - p.aux[k] = v - } - } - if done = i > 0 || j > 0; done { - if settings.Debug { - log.D("alg: xips: rmv(%s): pri(%d), sec(%d); ok? %t", tid, i, j, done) - } - } - return -} - -// block returns true if any secondary ip is unspecified -func anyUnspecified(ips []netip.Addr) bool { - // unspecified ips expected to be in aux - // primary must never have unspecified ips - for _, ip := range ips { - if ip.IsUnspecified() { - return true - } - } - return false -} - -func anyAddrEqual(ipps []netip.AddrPort, ip netip.Addr) bool { - for _, ipp := range ipps { - if ipp.Addr().Compare(ip) == 0 { - return true - } - } - return false -} - -// each iterates over each primary ip -func (p *xips) each(f func(ip netip.Addr)) { - for _, ip := range p.allips(xpri, xall) { - f(ip) - } - for _, ip := range p.allips(xsec, xall) { - f(ip) - } -} - -func (p *xips) merge(q *xips) (szprialiv, szpri, szsecaliv, szsec int) { - szprialiv, szpri, szsecaliv, szsec = -1, -1, -1, -1 // -1 means no change - - if p == nil || q == nil { - return - } - - p.pmu.Lock() // write lock on p - defer p.pmu.Unlock() - q.pmu.RLock() // read lock on q - defer q.pmu.RUnlock() - - if len(q.pri) > 0 { - szprialiv, szpri = 0, 0 - } - if len(q.aux) > 0 { - szsecaliv, szsec = 0, 0 - } - - for qk, qv := range q.pri { - pv := p.pri[qk] - if _, y := pv.fresh(); !y { - p.pri[qk] = qv // copy v from q into p - szqaliv, szqpri := qv.sizes() - szprialiv += szqaliv - szpri += szqpri - } else { - var ips []netip.Addr - ttl := pv.ttl - dob := pv.dob - if !pv.after(qv) { - // qv is younger, so qv's ips should come first (youngest first ordering) - ips = copyUniq(qv.alive(), pv.alive()) - dob = qv.dob - } else { - // pv is younger, so pv's ips should come first (youngest first ordering) - ips = copyUniq(pv.alive(), qv.alive()) - } - if !pv.fresherThan(qv) { - // ips from aa & v both get assigned the latest ttl - // which is strictly incorrect, but for accounting - // purposes (that is, algip->realip translations), - // it is preferable to keep as many realips around - // as possible (even at the slight cost of correctness). - ttl = qv.ttl - } - v := expaddr{ips, ttl, dob} - p.pri[qk] = v - szqaliv, szqpri := v.sizes() - szprialiv += szqaliv - szpri += szqpri - } - } - - for qk, qv := range q.aux { - pv := p.aux[qk] - if _, y := pv.fresh(); !y { - p.aux[qk] = qv // copy v from q into p - szqaliv, szqsec := qv.sizes() - szsecaliv += szqaliv - szsec += szqsec - } else { - var ips []netip.Addr - ttl := pv.ttl - dob := pv.dob - if !pv.after(qv) { - // qv is younger, so qv's ips should come first (youngest first ordering) - ips = copyUniq(qv.alive(), pv.alive()) - dob = qv.dob - } else { - // pv is younger, so pv's ips should come first (youngest first ordering) - ips = copyUniq(pv.alive(), qv.alive()) - } - if !pv.fresherThan(qv) { - ttl = qv.ttl - } - v := expaddr{ips, ttl, dob} - p.aux[qk] = v - szqaliv, szqsec := v.sizes() - szsecaliv += szqaliv - szsec += szqsec - } - } - - return -} - -// expdomains tracks domains with ttl and birth for ordering; unlike expaddr, -// there is no secondary domains map. -type expdomains struct { - domains []string - ttl time.Time - dob time.Time -} - -func (a expdomains) String() string { - return fmt.Sprintf("domains(%v / ttl: %s / dob: %s)", a.domains, core.FmtTimeAsPeriod(a.ttl), core.FmtTimeAsPeriod(a.dob)) -} - -func (a expdomains) sizes() (alive, tot int) { - if a.domains == nil { - return - } - return len(a.get(xalive)), len(a.domains) // == len(a.get(xall)) -} - -func (a expdomains) get(s xaddrstatus) (out []string) { - if a.domains == nil { - return - } - if s == xall { - return a.domains - } - if s == xalive { - if _, y := a.fresh(); y { - return a.domains - } - } - return -} - -func (a expdomains) fresh() (rem time.Duration, y bool) { - if a.ttl.IsZero() { - return - } - rem = time.Until(a.ttl) - y = rem > 0 - return -} - -func (a expdomains) after(b expdomains) bool { - if a.dob.IsZero() { - return b.dob.IsZero() - } - return a.dob.After(b.dob) -} - -func (a expdomains) fresherThan(b expdomains) bool { - if a.ttl.IsZero() { - return b.ttl.IsZero() - } - return a.ttl.After(b.ttl) -} - -func expdomainsDobSorter(a, b expdomains) int { - if a.after(b) { // a is born after b (ie, a is younger) - return -1 - } else if b.after(a) { // b is younger - return 1 - } - return 0 -} - -// xdomains tracks domains per tid+uid; there is no secondary (aux) store. -type xdomains struct { - pmu *sync.RWMutex - pri map[string]expdomains // tid+uid -> domains -} - -// NewXdomains returns a new xdomains object keyed by tid+uid. -func NewXdomains(tid, uid string, pri []string, ttl time.Time) *xdomains { - if len(tid) <= 0 { - log.W("alg: xdomains: tid cannot be empty") - return nil - } - if len(uid) <= 0 { - uid = core.UNKNOWN_UID_STR - } - - x := &xdomains{ - pmu: new(sync.RWMutex), - pri: make(map[string]expdomains), - } - - if len(pri) > 0 { // pri may be nil - x.pri[tid+uid] = expdomains{domains: pri, ttl: ttl, dob: time.Now()} - } - return x -} - -func (p *xdomains) String() string { - if p == nil { - return "" - } - p.pmu.RLock() - defer p.pmu.RUnlock() - return fmt.Sprintf("xdomains: pri(%v)", p.pri) -} - -func (p *xdomains) domainsFor(tid, uid string, forIP netip.Addr, s xaddrstatus) (out []string) { - if p == nil { - return nil - } - if len(uid) <= 0 { - uid = core.UNKNOWN_UID_STR - } - - doms := make([]expdomains, 0) - ttls := []time.Time{} - key := tid + uid - - p.pmu.RLock() - defer p.pmu.RUnlock() - - if tid == notransport || tid == NoDNS || len(tid) <= 0 { - for k, v := range p.pri { - if !strings.HasSuffix(k, uid) { - continue - } - doms = append(doms, v) - } - for _, v := range core.Sort(doms, expdomainsDobSorter) { - if settings.Debug { - // for debugging: may have expired ttls - ttls = append(ttls, v.ttl) - } - out = append(out, v.get(s)...) - } - } else if v, ok := p.pri[key]; ok { - if settings.Debug { - ttls = append(ttls, v.ttl) - } - out = v.get(s) - } - - if settings.Debug { - log.VV("alg: xdomains: xof(%s/%s): %s => %v [%v]", s, forIP, key, out, core.Map(ttls, core.FmtTimeAsPeriod)) - } - return -} - -func (p *xdomains) rmv(tid string) (done bool) { - if p == nil { - return - } - if len(tid) <= 0 || tid == NoDNS || tid == notransport { - return - } - - keyPrefix := tid - i := 0 - p.pmu.Lock() - defer p.pmu.Unlock() - for k, v := range p.pri { - if !strings.HasPrefix(k, keyPrefix) { - continue - } - if _, y := v.fresh(); y { - i++ - v.ttl = time.Now() // mark as expired - p.pri[k] = v - } - } - done = i > 0 - return -} - -func (p *xdomains) merge(q *xdomains) (szprialiv, szpri int) { - szprialiv, szpri = -1, -1 // -1 means no change - - if p == nil || q == nil { - return - } - - p.pmu.Lock() - defer p.pmu.Unlock() - q.pmu.RLock() - defer q.pmu.RUnlock() - - if len(q.pri) > 0 { - szprialiv, szpri = 0, 0 - } - - for qk, qv := range q.pri { - pv := p.pri[qk] - if _, y := pv.fresh(); !y { - p.pri[qk] = qv // copy v from q into p - szqaliv, szqpri := qv.sizes() - szprialiv += szqaliv - szpri += szqpri - } else { - var doms []string - ttl := pv.ttl - dob := pv.dob - if !pv.after(qv) { - // qv is younger, so qv's domains should come first (youngest first ordering) - doms = copyUniq(qv.get(xalive), pv.get(xalive)) - dob = qv.dob - } else { - // pv is younger, so pv's domains should come first (youngest first ordering) - doms = copyUniq(pv.get(xalive), qv.get(xalive)) - } - if !pv.fresherThan(qv) { - pv.ttl = qv.ttl - } - v := expdomains{domains: doms, ttl: ttl, dob: dob} - p.pri[qk] = v - szqaliv, szqpri := v.sizes() - szprialiv += szqaliv - szpri += szqpri - } - } - - return -} - -type baseans struct { - ips *xips // all ip answers, v6+v4; may be nil - domains *xdomains // all domain names in an answer (incl qname); per tid+uid - blocklists string // csv blocklists containing qname per active config at the time - ttl time.Time // ttl for this alg translation -} - -func (a *baseans) String() string { - if a == nil { - return "" - } - return fmt.Sprintf("xips(%v) domains(%v) blocklists(%s) ttl(%s)", - a.ips, a.domains, a.blocklists, core.FmtTimeAsPeriod(a.ttl)) -} - -// fresh returns false if a has expired or if a is nil -func (a *baseans) fresh() (rem time.Duration, y bool) { - if a == nil || a.ttl.IsZero() { - return - } - rem = time.Until(a.ttl) - y = rem > 0 - return -} - -type algans struct { - algip netip.Addr // generated answer, v6 or v4 - *baseans -} - -func (a *algans) String() string { - if a == nil { - return "" - } - return fmt.Sprintf("algans: algip(%s) base(%s)", a.algip, a.baseans) -} - -func (a *algans) extend(by time.Duration) { - a.ttl = time.Now().Add(by) -} - -func (a *algans) after(b *algans) bool { - if a.ttl.IsZero() { - return b.ttl.IsZero() - } - return a.ttl.After(b.ttl) -} - -// merge merges b into a. -func (a *algans) merge(b *algans) { - if a == nil || b == nil { - log.W("alg: merge: nil algans; a? %s; b? %s", a, b) - return - } - - log.VV("alg: merge: b(%s) into a(%s)", b, a) - - if a.ips == nil { - a.ips = b.ips - } else { - prialiv, totpri, secaliv, totsec := a.ips.merge(b.ips) - logeif(totpri < 0 && totsec < 0)("alg: merge: err ips merge; pri(%d/%d) sec(%d/%d), out(%s)", - prialiv, totpri, secaliv, totsec, a) - } - if a.domains == nil { - a.domains = b.domains - } else { - prialiv, totpri := a.domains.merge(b.domains) - logeif(totpri < 0)("alg: merge: err domains merge; pri(%d/%d), out(%s)", - prialiv, totpri, a) - } - a.blocklists = b.blocklists // TODO: merge? - if b.after(a) { - a.ttl = b.ttl - } -} - -func domainsFor(base *baseans, tid, uid string, forIP netip.Addr, s xaddrstatus) []string { - if base == nil || base.domains == nil { - return nil - } - return base.domains.domainsFor(tid, uid, forIP, s) -} - -type dnsgateway struct { - sync.RWMutex // protects alg, nat, octets, hexes - alg map[string]*algans // domain+type => ans - nat map[netip.Addr]*baseans // algip => baseans - ptr map[netip.Addr]*baseans // primaryip => baseans - octets []uint8 // ip4 octets, 100.x.y.z - hexes []uint16 // ip6 hex, 64:ff9b:1:da19:0100.x.y.z - - // fields below are never reassigned - - fake []netip.AddrPort // fake DNS addrs to ignore for "undoAlg" - rdns RdnsResolver // local and remote rdns blocks - dns64 NatPt // dns64/nat64 - chash bool // use consistent hashing to generate alg ips - - // fields below are mutable - - mod atomic.Bool // modify realip to algip - fix atomic.Bool // enforce split tunneling via dnsx.Fixed - split atomic.Bool // per-app split tunneling -} - -var _ Gateway = (*dnsgateway)(nil) - -// NewDNSGateway returns a DNS ALG, ready for use. -func NewDNSGateway(pctx context.Context, fakeaddrs []netip.AddrPort, outer RdnsResolver, dns64 NatPt) (t *dnsgateway) { - alg := make(map[string]*algans) - nat := make(map[netip.Addr]*baseans) - ptr := make(map[netip.Addr]*baseans) - - t = &dnsgateway{ - alg: alg, - nat: nat, - ptr: ptr, - fake: fakeaddrs, - rdns: outer, - dns64: dns64, - octets: rfc6598, - hexes: rfc8215a, - chash: true, - } - - context.AfterFunc(pctx, t.stop) - log.I("alg: setup done") - return -} - -func (t *dnsgateway) translate(tr, fix bool) { - // fixed transport can only be used if translation is on - fix = tr && fix - prevtr := t.mod.Swap(tr) - prevfix := t.fix.Swap(fix) - log.I("alg: translate? prevtr(%t) > nowtr(%t); prevfix(%t) > nowfix(%t)", prevtr, tr, prevfix, fix) -} - -func (t *dnsgateway) splitTunnel() { - if t.split.Load() { - return - } - t.split.Store(true) - log.I("alg: splitTunnel turned on") -} - -func (t *dnsgateway) fixedTransport() bool { - return t.mod.Load() && t.split.Load() && t.fix.Load() -} - -func (t *dnsgateway) onStopped(tid string) { - if len(tid) <= 0 { - return - } - - t.RLock() - algEntries := make([]*xips, 0, len(t.alg)) - domEntries := make([]*xdomains, 0, len(t.alg)) - for _, algans := range t.alg { - if algans != nil { - if algans.ips != nil { - algEntries = append(algEntries, algans.ips) - } - if algans.domains != nil { - domEntries = append(domEntries, algans.domains) - } - } - } - t.RUnlock() - - // Process outside the lock to avoid holding it too long - n := 0 - for _, xip := range algEntries { - if xip.rmv(tid) { - n++ - } - } - m := 0 - for _, xdom := range domEntries { - if xdom.rmv(tid) { - m++ - } - } - log.I("alg: onStopped(%s): removed %d alg<>realip translations; %d domains", tid, n, m) -} - -// clears alg states -func (t *dnsgateway) stop() { - t.Lock() - defer t.Unlock() - - clear(t.alg) - clear(t.nat) - clear(t.ptr) - t.octets = rfc6598 - t.hexes = rfc8215a -} - -func (t *dnsgateway) fromInternalCache(tid, uid string, q *dns.Msg, typ iptype) (ans *dns.Msg, err error) { - if skipInternalCache(tid) { - return nil, errSkipInternalCache - } - // Skip answering from internal cache when DNSSEC is requested - if typ == typreal && xdns.IsDNSSECRequested(q) { - return nil, errSkipInternalCache - } - a, aaaa := xdns.HasAQuestion(q), xdns.HasAAAAQuestion(q) - if !a && !aaaa { - return nil, errNilCacheResponse - } - - domain := qname(q) - - t.RLock() - cached4s, cached6s, stale, until, _ := t.resolvLocked(domain, typ, tid, uid) - t.RUnlock() - - var cachedips []netip.Addr - if a && len(cached4s) > 0 { - cachedips = cached4s - } else if aaaa && len(cached6s) > 0 { - cachedips = cached6s - } - cachehit := len(cachedips) > 0 - ttlnegative := false - ttl := int64(until / time.Second) - if ttl < 0 { - ttl = int64(ttl8s / time.Second) - ttlnegative = true - } - - logeif(ttlnegative && cachehit)("alg: response for %s by %s[%s] (q4? %t / q6? %t) realip; in cache? %v [ttl: %s / -ve? %t / hit? %t / until: %s] (or stale? %v)", - domain, tid, uid, a, aaaa, cachedips, core.FmtSecs(ttl), ttlnegative, cachehit, core.FmtPeriod(until), stale) - - if !cachehit { - return nil, errNilCacheResponse - } - return xdns.AQuadAForQueryTTL(q, uint32(ttl), cachedips...) -} - -func (t *dnsgateway) qp(t1 Transport, uid, network string, q *dns.Msg, innersummary *x.DNSSummary) (ans *dns.Msg, err error) { - // For A/AAAA queries, check if xips has an answer for the qname. - if ans, err := t.fromInternalCache(idstr(t1), uid, q, typreal); err == nil { - innersummary.ID = idstr(t1) - innersummary.Server = getaddrstr(t1) - innersummary.RData = xdns.GetInterestingRData(ans) - innersummary.RCode = xdns.Rcode(ans) - innersummary.RTtl = xdns.RTtl(ans) - innersummary.Status = Complete - innersummary.Cached = true - return ans, nil - } - return Req(t1, network, q, innersummary) -} - -func (t *dnsgateway) qs(t2 Transport, uid, network string, msg *dns.Msg, t1res <-chan *dns.Msg) <-chan secans { - t2res := make(chan secans, 1) - msg = msg.Copy() // to avoid racing against changes made by caller - - core.Gx("alg.qs."+xdns.QName(msg), func() { - defer close(t2res) - - qname := xdns.QName(msg) - - r, completed := core.Grx("alg.qs."+qname, func(_ context.Context) (secans, error) { - return t.querySecondary(t2, uid, network, msg, t1res), nil - }, timeout) - - if !completed { - log.W("alg: skip; qs timeout; tr2: %s, qname: %s", idstr(t2), qname) - } - - r.initIfNeeded() // r may be nil value on Grx:timeout - - t2res <- r // may be zero secans - }) - return t2res -} - -func (t *dnsgateway) querySecondary(t2 Transport, uid, network string, msg *dns.Msg, t1res <-chan *dns.Msg) (result secans) { - var r *dns.Msg - var err error - - result.initIfNeeded() // result must not be reassigned - - // check if the question is blocked - if msg == nil || !xdns.HasAnyQuestion(msg) { - result.smm.Msg = errNoQuestion.Error() - return // not a valid dns message - } else if ok := xdns.HasAQuadAQuestion(msg) || xdns.HasHTTPQuestion(msg) || xdns.HasSVCBQuestion(msg); !ok { - result.smm.Msg = errNotEnoughAnswers.Error() - return // not a dns question we care about - } else if ans1, blocklists, err2 := t.rdns.blockQ( /*maybe nil*/ t2, nil, msg); err2 == nil { - // if err !is nil, then the question is blocked - if ans1 != nil && len(ans1.Answer) > 0 { - result.ips = append(result.ips, xdns.AAnswer(ans1)...) - result.ips = append(result.ips, xdns.AAAAAnswer(ans1)...) - } // noop: for HTTP/SVCB, the answer is always empty - result.smm.Blocklists = blocklists - result.smm.Status = Complete - return - } - - // no secondary transport; check if there's already an answer to work with - if t2 == nil || core.IsNil(t2) { - r = <-t1res // from primary transport, t1; r may be nil - result.pri = true // secans not from secondary - } else { - // check if there's already a cached answer to work with - // note: secondary ips are not cached per-transport (see xips.sec()) - if r, err = t.fromInternalCache(idstr(t2), uid, msg, typsecondary); err != nil { - // else: query secondary to get answer for q - r, err = Req(t2, network, msg, result.smm) - } else { - result.smm.ID = idstr(t2) - result.smm.Server = getaddrstr(t2) - result.smm.RData = xdns.GetInterestingRData(r) - result.smm.RCode = xdns.Rcode(r) - result.smm.RTtl = xdns.RTtl(r) - result.smm.Status = Complete - result.smm.Cached = true - } - } - - // check if answer r is blocked; r is either from t2 or from <-in - if err != nil || r == nil || !xdns.HasAnyAnswer(r) { // not a valid dns answer - log.V("alg: querySecondary: skip; sec transport %s; nores? %t, err? %v", idstr(t2), r == nil, err) - result.smm.Msg = errNotEnoughAnswers.Error() - return - } else if a, blockedtarget, blocklistnames := t.rdns.blockA( /*may be nil*/ t2, nil, msg, r, result.smm.Blocklists); a != nil { - // if "a" is not nil, then the r is blocked - if len(blocklistnames) > 0 { - result.smm.Blocklists = blocklistnames - result.smm.BlockedTarget = blockedtarget - } - // when rdns.blockA blocks, A/AAAA must be 0.0.0.0/:: - // and HTTPS/SVCB is an empty answer section - // see: xdns.RefusedResponseFromMessage - // if len(a.Answer) > 0 { - // result.ips = append(result.ips, xdns.AAnswer(a)...) - // result.ips = append(result.ips, xdns.AAAAAnswer(a)...) - // } - result.ips = append(result.ips, anyaddr4, anyaddr6) - return - } else { - if len(blocklistnames) > 0 { - result.smm.Blocklists = blocklistnames - result.smm.BlockedTarget = blockedtarget - } - if xdns.AQuadAUnspecified(r) { - // A/AAAA must be 0.0.0.0/::, set UpstreamBlocks to true - result.smm.UpstreamBlocks = true - // discard all other answers - result.ips = append(result.ips, anyaddr4, anyaddr6) - } else if xdns.HasAnyAnswer(r) { - ip4hints := xdns.IPHints(r, dns.SVCB_IPV4HINT) - ip6hints := xdns.IPHints(r, dns.SVCB_IPV6HINT) - result.ips = append(result.ips, xdns.IPs(r)...) - result.ips = append(result.ips, ip4hints...) - result.ips = append(result.ips, ip6hints...) - // TODO: result.targets? - } - return - } -} - -// Implements Gateway -// preset may be nil -func (t *dnsgateway) q(t1, t2 Transport, preset []netip.Addr, network, uid string, q *dns.Msg, smm *x.DNSSummary) (ogmsg *dns.Msg, outmsg *dns.Msg, outerr error) { - var ansin *dns.Msg // answer got from transports - var err error - - usepreset := len(preset) > 0 // preset may be nil - mod := t.mod.Load() // allow alg? - discarduid := !t.split.Load() // do not split tunnel? - uidself := uid == protect.UidSelf // us? - hasblock := isAnyBlockAll(idstr(t2), idstr(t2)) - hasfixed := isAnyFixed(idstr(t1)) // fixed transport? - usefixed := !usepreset && mod && hasfixed // use preset fixed realips? - skipcache := skipInternalCache(idstr(t1), idstr(t2)) - // do not perform alg or use its "xip" caches by the way of ptr, nat, alg - // maps for synthesized response and block-all transport. That's because - // BlockAll and synth responses are "ephemeral" in that, they may be based - // on a UID and not just a tid + qname + qtype (cf: DNSListener.OnQuery) - // but ptr, nat, alg maps are only tied to (tid, qname, qtype) tuple. Also, - // it isn't necessary that if a qname is blocked right now, it will be blocked - // by subsequent OnQuery()s as the rules are dynamic (for instance, a uid+qname - // may be blocked because device is in keyguard/locked state but allowed - // when the user is present (device is unlocked). In the cases where the - // uid is protect.UidSelf (that is, requests sent by dns64.go or ipmapper.go) - // should not be alg'd as the alg'd ips will end up as "realips" in xips caches. - // nb: setting mod = false will achieve the same effect but it goes through - // the effort of setting up alg/ptr/nat caches which is wasteful in this case. - // TODO: handle Loopback scenario for uidself (which probably should be alg'd?) - dontalg := usepreset || skipcache || uidself || hasblock - // usefixed generates fake but static answers for A/AAAA queries and no - // actual resolution request is sent (not even to the cache). It is expected - // that during PreFlow, the proxy layer will again attempt to resolve when - // it sees that an algip has been mapped to this fake (fixed) ip. - synthAns := usepreset || usefixed - hasdnssec := xdns.IsDNSSECRequested(q) - - smm.DO = hasdnssec - - if discarduid { - uid = core.UNKNOWN_UID_STR - } - if hasfixed && !usefixed { - log.W("alg: dnsx.Fixed must be used with mod & without preset, instead using... %s", idstr(t2)) - t1 = t2 // assert t2 != nil? - } else if usefixed { - // fixed ip responses must always be alg'd unlike preset / blockall - // even when uid == protect.UidSelf as dnsx.Fixed overrides all other - // settings & must respond with modded (alg'd) ips. It is another thing - // that protect.UidSelf requests should never need dnsx.Fixed. - dontalg = false - preset = fixedRealIPs - mod = true // assert mod == true? - t1 = t2 // assert t2 != nil? - } - if t1 == nil || core.IsNil(t1) { - log.W("alg: no primary transport; t1 %s, t2 %s, uid %s, self? %t preset? %t fixed? %t synth? %t nouid? %t", - idstr(t1), idstr(t2), uid, uidself, usepreset, usefixed, synthAns, discarduid) - return nil, nil, errAlgNoTransport - } - - // when synthesizing answers, override both t1 and t2: - // discard t2 as with preset we don't care about additional ips and blocklists; - // t1 is not discarded entirely as it is needed to subst ips in https/svcb responses - if synthAns { - t2 = nil // assert t2 == nil? - } - t1res := make(chan *dns.Msg, 1) - innersummary := copySummary(smm) - // todo: use context? - secch := t.qs(t2, uid, network, q, t1res) // t2 may be nil - - if synthAns { - ansin, err = synthesizeOrQuery(preset, t1, q, network, innersummary, usefixed) - } else { - ansin, err = t.qp(t1, uid, network, q, innersummary) - } - t1res <- ansin // ansin may be nil; but that's ok - - // override relevant values in smm - fillSummary(innersummary, smm) - - if err != nil || ansin == nil { - if ansin == nil { - log.I("alg: abort no ans on %s+%s[%s]; self? %t synth? %t; qerr %v", - idstr(t1), idstr(t2), uid, uidself, synthAns, err) - return nil, nil, core.JoinErr(err, errNoAnswer) - } - if !xdns.HasRcodeSuccess(ansin) { - return ansin, nil, err - } - if settings.Debug { - log.D("alg: for %s:%s err but ans ok: %d; do? %t, self? %t synth? %t; qerr %v", - qname(q), qtype(q), xdns.Len(ansin), hasdnssec, uidself, synthAns, err) - } - } - - hasauth64 := false - hasauth := xdns.IsDNSSECAnswerAuthenticated(ansin) - - qname := qname(ansin) - qtyp := qtype(ansin) - smm.QName = qname - smm.QType = qtyp - - // if usefixed is true, then d64 is no-op, as preset fixed ip does have ipv6 - ans64 := t.dns64.D64(network, t1.ID(), uid, ansin) // ans64 may be nil if no D64 or error - if ans64 != nil { - if settings.Debug { - log.D("alg: %s<>%s:%s[%s] %d dns64; dnssec? %t; s/ans(%d)/ans64(%d)", - qname, smm.ID, idstr(t1), uid, qtyp, hasdnssec, xdns.Len(ansin), xdns.Len(ans64)) - } - withDNS64Summary(ans64, smm) - // todo: for uidself, skip dns64? see: ipmapper.go:undoAlgAndOrNat64 - // todo: skip for for undelegated domains like ipv4only.arpa? - ansin = ans64 - // false if ans64 is synthesized or nil - // or its value matches hasauth - hasauth64 = xdns.IsDNSSECAnswerAuthenticated(ans64) - } // else: no dns64, or error; continue with ansin - - smm.AD = hasauth && hasauth64 // true if both ansin and ans64 are authenticated - - hasq := xdns.HasAAAAQuestion(ansin) || xdns.HasAQuestion(ansin) || - xdns.HasSVCBQuestion(ansin) || xdns.HasHTTPQuestion(ansin) - hasans := xdns.HasAnyAnswer(ansin) - rgood := xdns.HasRcodeSuccess(ansin) - // for t1, ansin's already evaluated for ans0000 in querySecondary - // (secans.pri is set to true). ansin may be from t2 (if t2 != nil), - // ans64, which is a modified ansin, depending on settings.PtMode - ans0000 := xdns.AQuadAUnspecified(ansin) // ansin is not nil; ans64 may be nil - - if ans0000 { - smm.UpstreamBlocks = true - } - - // todo: skip alg for undelegated domains like ipv4only.arpa? - if !hasq || !hasans || !rgood || ans0000 || dontalg { - if settings.Debug { - log.D("alg: skip; query %s<>%s[%s]:%s:%d / a:%d + rdata: %s + status: %d, dnssec(do? %t /ad? %t) self(%t) dontalg(%t) hasq(%t) hasans(%t) rgood(%t), ans0000(%t)", - smm.ID, idstr(t1), uid, qname, qtyp, xdns.Len(ansin), smm.RData, smm.Status, smm.DO, smm.AD, uidself, dontalg, hasq, hasans, rgood, ans0000) - } - return ansin, nil, nil - } - - a6 := xdns.AAAAAnswer(ansin) - a4 := xdns.AAnswer(ansin) - ip4hints := xdns.IPHints(ansin, dns.SVCB_IPV4HINT) - ip6hints := xdns.IPHints(ansin, dns.SVCB_IPV6HINT) - // TODO: generate one alg ip per target, synth one rec per target - targets := xdns.Targets(ansin) - realip := make([]netip.Addr, 0) - var algip4, algip6 netip.Addr - - // fetch secondary ips before locks - // these may be from primary when secans.pri is true - secres := <-secch - - // inform kt of secondary blocklists, if any - smm.Blocklists = secres.smm.Blocklists - smm.BlockedTarget = secres.smm.BlockedTarget - smm.UpstreamBlocks = secres.smm.UpstreamBlocks || smm.UpstreamBlocks - - if smm.UpstreamBlocks || len(secres.smm.Msg) > 0 { - smsg := secres.smm.Msg - spri := secres.pri - log.V("alg: %s<>%s[%s]:%s:%d upstream blocks: primary? %t / sec? %t; secres: pri? %t, msg: %s", - smm.ID, idstr(t1), uid, qname, qtyp, secres.smm.UpstreamBlocks, smm.UpstreamBlocks, spri, smsg) - } - - defer func() { - // answers in outmsg may not have been cached at all by xips - // since register may not have happened at all - xdns.BustAndroidCacheIfNeeded(outmsg) - - if isAlgErr(outerr) && !mod { - if settings.Debug { - log.D("alg: %s<>%s[%s]:%s:%d no mod; suppress err %v", - smm.ID, idstr(t1), uid, qname, qtyp, outerr) - } - outerr = nil // ignore alg errors if no modification is desired - } - }() - - ansttl := time.Duration(xdns.RTtl(ansin)) * time.Second - - t.Lock() - defer t.Unlock() - - var algip4hints, algip6hints, algip4s, algip6s netip.Addr - if len(a6) > 0 { - realip = append(realip, a6...) - // choose the first alg ip6; may've been generated by a6 - algip, ipok := t.take6Locked(qname, 0) - if !ipok { - return ansin, nil, errAlgNotAvail - } - algip6s = algip - } - if len(a4) > 0 { - realip = append(realip, a4...) - algip, ipok := t.take4Locked(qname, 0) - if !ipok { - return ansin, nil, errAlgNotAvail - } - algip4s = algip - } - if len(ip4hints) > 0 { - realip = append(realip, ip4hints...) - // choose the first alg ip4; may've been generated by a4 - if !algip4s.IsValid() { - // 0th algip is reserved for A records - algip, ipok := t.take4Locked(qname, 0) - if !ipok { - return ansin, nil, errAlgNotAvail - } - algip4hints = algip - } else { - algip4hints = algip4s - } - } - if len(ip6hints) > 0 { - realip = append(realip, ip6hints...) - if !algip6s.IsValid() { - // 0th algip is reserved for AAAA records - algip, ipok := t.take6Locked(qname, 0) - if !ipok { - return ansin, nil, errAlgNotAvail - } - algip6hints = algip - } else { - algip6hints = algip6s - } - } - - algXlatTtl := xdns.ZeroTTL - substok4 := false - substok6 := false - // substitutions needn't happen when no alg ips to begin with - // but must happen if (real) ips are fixed - mustsubst := false || usefixed - ansmod := xdns.CopyAns(ansin) - // TODO: substitute ips in additional section - if algip4hints.IsValid() { - substok4 = xdns.SubstSVCBRecordIPs( /*out*/ ansmod, dns.SVCB_IPV4HINT, algip4hints, algXlatTtl) || substok4 - mustsubst = true - } - if algip6hints.IsValid() { - substok6 = xdns.SubstSVCBRecordIPs( /*out*/ ansmod, dns.SVCB_IPV6HINT, algip6hints, algXlatTtl) || substok6 - mustsubst = true - } - if algip4s.IsValid() { - substok4 = xdns.SubstARecords( /*out*/ ansmod, algip4s, algXlatTtl) || substok4 - mustsubst = true - } - if algip6s.IsValid() { - substok6 = xdns.SubstAAAARecords( /*out*/ ansmod, algip6s, algXlatTtl) || substok6 - mustsubst = true - } - - if settings.Debug { - log.D("alg: %s<>%s[%s]; %s:%d (split? %t, do? %t / ad? %t) a6(a %d / h %d / s %t) : a4(a %d / h %d / s %t); ttl: %s", - smm.ID, idstr(t1), uid, qname, qtyp, !discarduid, smm.DO, smm.AD, len(a6), len(ip6hints), substok6, len(a4), len(ip4hints), substok4, ansttl) - } - if !substok4 && !substok6 { - if mustsubst { // always true when usefixed is true - err = errAlgCannotSubst - } else { // no algips - err = nil - } - logeif(err != nil)("alg: %s<>%s[%s]: skip; err(%v); ips subst %s:%d; fixed? %t, split? %t", - smm.ID, idstr(t1), uid, err, qname, qtyp, usefixed, !discarduid) - return ansin, nil, err // ansin is nil if no alg ips - } - - var fixedips []netip.Addr - if usefixed { - // if usefixed, then realips are in fact fixedips - fixedips = realip - // empty out realip and secres.ips got from answers - // secres.ips is fixedips anyway since t2 is nil - realip = nil - secres.ips = nil - } - - if algip4s.IsValid() { - algip4 = algip4s - } else { - algip4 = algip4hints - } - if algip6s.IsValid() { - algip6 = algip6s - } else { - algip6 = algip6hints - } - - // always use the ID as set in the summary; which may or may not match - // the primary t1.ID(). For instance, a caching dnsx.Transport (cacher) - // may set DNSSummary.ID of the underlying dnsx.Transport that fetched - // the answer, whose ID != to t1 (cacher) itself. OTOH, dnsx.Resolver - // uses DNSSummary.ID when returning ans to the caller (ex: ipmapper) - tidToReg := smm.ID - if settings.Debug { - log.D("alg: ok; for %s<>%s[%s]:%s:%d (do? %t / ad? %t), domains %s real: %s / fix: %s => subst %s | %s; (mod? %t / fix? %t / synth? %t / split? %t); sec %s; ttl %s", - tidToReg, idstr(t1), uid, qname, qtyp, smm.DO, smm.AD, targets, realip, fixedips, algip4, algip6, mod, usefixed, synthAns, !discarduid, secres.ips, ansttl) - } - - // always register algips, even if mod is false, to maintain a hot cache. - // if client enables mod later, algips will be instantly available. - algok := t.registerLocked(qname, tidToReg, uid, algip4, algip6, realip, ansttl, targets, secres) - - if mod { // if mod is set, send modified answer and summary - withAlgSummary(smm, algip4, algip6) - if algok { - return ansin, ansmod, nil - } // else: err out if alg is requested but algips could not be registered - return ansin, nil, errAlgCannotRegister - } - return ansin, nil, nil -} - -func Netip2Csv(ips []netip.Addr) (csv string) { - if len(ips) <= 0 { - return "" - } - out := make([]string, 0, len(ips)) - for _, ip := range ips { - if ip.IsValid() { - out = append(out, ip.String()) - } - } - return strings.Join(out, ",") -} - -func Csv2Netip(csv string) (ips []netip.Addr) { - out := make([]netip.Addr, 0) - for ip := range strings.SplitSeq(csv, ",") { - if ipaddr, err := netip.ParseAddr(ip); ipaddr.IsValid() && err == nil { - out = append(out, ipaddr) - } - } - return out -} - -func withDNS64Summary(ans64 *dns.Msg, s *x.DNSSummary) { - s.RCode = xdns.Rcode(ans64) - s.RData = xdns.GetInterestingRData(ans64) - s.RTtl = xdns.RTtl(ans64) - if settings.Debug { - prefix := PrefixFor(AlgDNS64) - s.Server = prefix + s.Server - } -} - -func withAlgSummary(s *x.DNSSummary, algips ...netip.Addr) { - if settings.Debug { - // convert algips to ipcsv; any algips may be invalid - ipcsv := Netip2Csv(algips) - - if len(s.RData) > 0 { - s.RData = s.RData + "," + ipcsv - } else { - s.RData = ipcsv - } - prefix := PrefixFor(Alg) - if len(s.Server) > 0 { - s.Server = prefix + s.Server - } else { - s.Server = prefix + notransport - } - } - // if modified alg ips are being returned, then these are not authentic - s.AD = len(algips) > 0 -} - -func (t *dnsgateway) registerLocked(q, tid, uid string, algip4, algip6 netip.Addr, realips []netip.Addr, ttl time.Duration, targets []string, secres secans) bool { - if tid == notransport || tid == NoDNS || len(tid) <= 0 { - log.E("alg: no tid for %s@%s[%s]; real? %d, sec? %d", - q, tid, uid, len(realips), len(secres.ips)) - return false - } - if !algip4.IsValid() && !algip6.IsValid() { // defensive; should not happen - log.E("alg: no algips for %s@%s[%s]; real? %d, sec? %d", - q, tid, uid, len(realips), len(secres.ips)) - return false - } - - // some domain set very low ttl (ex: 1s for news.ycombinator.com) which - // is too short for translations; use a minimum of 15s to account - // for just-in-time re-resolution of the same domain by common.go via - // dialers.ResolverFor(uid) which may be called on new tcp / udp conn. - ttl = max(ttl8s, ttl) - - now := time.Now() - // ttl is used for algans and xips, but the alg'fied dns answer - // has a lower ttl as defined by const algttl (currently, 15s). - ansttl := now.Add(max(ttl2m, ttl)) - xipsttl := now.Add(ttl) - // secres.ips may be empty on timeout errors, or - // or same as realips if t2 is nil; realips can be nil - // if fixedips is being used. - am4 := &baseans{ - ips: NewXips(tid, uid, v4only(realips), v4only(secres.ips), xipsttl), - domains: NewXdomains(tid, uid, targets, xipsttl), // targets may be nil - blocklists: secres.smm.Blocklists, - ttl: ansttl, // extended by 2m on every use - } - am6 := &baseans{ - ips: NewXips(tid, uid, v6only(realips), v6only(secres.ips), xipsttl), - domains: NewXdomains(tid, uid, targets, xipsttl), // targets may be nil - blocklists: secres.smm.Blocklists, - ttl: ansttl, // extended by 2m on every use - } - - // Check if NewXips failed to create valid xips objects - if am4.ips == nil || am6.ips == nil || am4.domains == nil || am6.domains == nil { - log.E("alg: failed to create xips/xdomains for %s@%s[%s]; am4.ips: %v, am6.ips: %v", - q, tid, uid, am4.ips, am6.ips) - return false - } - - newEntry := false - didRegister := false - // register mapping from qname -> algip+realip (alg) and algip -> qname+realip (nat) - for _, ip := range []netip.Addr{algip4, algip6} { // algips may be nil? - var k string - var x *algans - if ip.IsValid() && ip.Is4() { - k = q + key4 + strconv.Itoa(0) - x = &algans{ - algip: ip, - baseans: am4, - } - } else if ip.IsValid() && ip.Is6() { - k = q + key6 + strconv.Itoa(0) - x = &algans{ - algip: ip, - baseans: am6, - } - } // else: ip invalid - if x == nil { // no valid algans - continue - } - if prevans := t.alg[k]; prevans != nil { - // merge x into prevans - prevans.merge(x) - x = prevans - } else { - t.alg[k] = x - t.nat[ip] = x.baseans - newEntry = true - } - // am.ips.realips() may return nil; ex: when preset fixed ips are used - x.ips.each(func(ip netip.Addr) { - // existing am is merged into am4/am6 by t.alg above - // register mapping from realip -> algip+qname (ptr) - t.ptr[ip] = x.baseans - }) - didRegister = true - } - logeif(!didRegister)("alg: algips (reg? %t / new? %t) (alg: %s+%s => real: %s) for %s@%s[%s]; real? %d, sec? %d; until (ans: %s / xips: %s)", - didRegister, newEntry, algip4, algip6, realips, q, tid, uid, len(realips), len(secres.ips), time.Until(ansttl), time.Until(xipsttl)) - - return didRegister -} - -func (t *dnsgateway) take4Locked(q string, idx int) (netip.Addr, bool) { - k := q + key4 + strconv.Itoa(idx) - if ans, ok := t.alg[k]; ok { - ip := ans.algip - if ip.Is4() { - ans.extend(ttl2m) - return ip, true - } else { - // shouldn't happen; if it does, rm erroneous entry - delete(t.alg, k) - delete(t.nat, ip) - ans.ips.each(func(ip netip.Addr) { - if pans := t.ptr[ip]; pans == ans.baseans { - delete(t.ptr, ip) - } - }) - } - } - - if t.chash { - for i := range maxiter { - genip := gen4Locked(k, i) - if !genip.IsGlobalUnicast() { - continue - } - if _, taken := t.nat[genip]; !taken { - return genip, genip.IsValid() - } - } - log.W("alg: gen: no more IP4s (%v)", q) - return zeroaddr, false - } - - gen := true - // 100.x.y.z: 4m+ ip4s - if z := t.octets[3]; z < 254 { - t.octets[3] += 1 // z - } else if y := t.octets[2]; y < 254 { - t.octets[2] += 1 // y - t.octets[3] = 1 // z - } else if x := t.octets[1]; x < 128 { - t.octets[1] += 1 // x - t.octets[2] = 0 // y - t.octets[3] = 1 // z - } else { - i := 0 - for kx, ent := range t.alg { - if i > maxiter { - break - } - if d := time.Since(ent.ttl); d > 0 { - log.I("alg: reuse stale alg %s for %s", kx, k) - delete(t.alg, kx) - delete(t.nat, ent.algip) - ent.ips.each(func(ip netip.Addr) { - if pans := t.ptr[ip]; pans == ent.baseans { - delete(t.ptr, ip) - } - }) - return ent.algip, true - } - i += 1 - } - gen = false - } - if gen { - // 100.x.y.z: big endian is network-order, which netip expects - b4 := [4]byte{t.octets[0], t.octets[1], t.octets[2], t.octets[3]} - genip := netip.AddrFrom4(b4).Unmap() - return genip, genip.IsValid() - } else { - log.W("alg: no more IP4s (%v)", t.octets) - } - return zeroaddr, false -} - -func gen4Locked(k string, hop int) netip.Addr { - s := strconv.Itoa(hop) + k - v22 := hash22(s) - // 100.64.y.z/15 2m+ ip4s - b4 := [4]byte{ - rfc6598[0], // 100 - rfc6598[1] + uint8(v22>>16), // 64 + int(6bits) - uint8((v22 >> 8) & 0xff), // extract next 8 bits - uint8(v22 & 0xff), // extract last 8 bits - } - - // why unmap? github.com/golang/go/issues/53607 - return netip.AddrFrom4(b4).Unmap() -} - -func (t *dnsgateway) take6Locked(q string, idx int) (netip.Addr, bool) { - k := q + key6 + strconv.Itoa(idx) - if ans, ok := t.alg[k]; ok { - ip := ans.algip - if ip.Is6() { - ans.extend(ttl2m) - return ip, true - } else { - // shouldn't happen; if it does, rm erroneous entry - delete(t.alg, k) - delete(t.nat, ip) - ans.ips.each(func(ip netip.Addr) { - if pans := t.ptr[ip]; pans == ans.baseans { - delete(t.ptr, ip) - } - }) - } - } - - if t.chash { - for i := range maxiter { - genip := gen6Locked(k, i) - if _, taken := t.nat[genip]; !taken { - return genip, genip.IsValid() - } - } - log.W("alg: gen: no more IP6s (%v)", q) - return zeroaddr, false - } - - gen := true - // 64:ff9b:1:da19:0100.x.y.z: 281 trillion ip6s - if z := t.hexes[7]; z < 65534 { - t.hexes[7] += 1 // z - } else if y := t.hexes[6]; y < 65534 { - t.hexes[6] += 1 // y - t.hexes[7] = 1 // z - } else if x := t.hexes[5]; x < 65534 { - t.hexes[5] += 1 // x - t.hexes[6] = 0 // y - t.hexes[7] = 1 // z - } else { - // possible that we run out of 200 trillion ips...? - gen = false - } - if gen { - // 64:ff9b:1:da19:0100.x.y.z: big endian is network-order, which netip expects - b16 := [16]byte{} - for i, hx := range t.hexes { - i = i * 2 - binary.BigEndian.PutUint16(b16[i:i+2], hx) - } - genip := netip.AddrFrom16(b16) - return genip, genip.IsValid() - } else { - log.W("alg: no more IP6s (%x)", t.hexes) - } - return zeroaddr, false -} - -func gen6Locked(k string, hop int) netip.Addr { - s := strconv.Itoa(hop) + k - v48 := hash48(s) - // 64:ff9b:1:da19:0100.x.y.z: 281 trillion ip6s - a16 := [8]uint16{ - rfc8215a[0], // 64 - rfc8215a[1], // ff9b - rfc8215a[2], // 1 - rfc8215a[3], // da19 - rfc8215a[4], // 0100 - uint16((v48 >> 32) & 0xffff), // extract the top 16 bits - uint16((v48 >> 16) & 0xffff), // extract the mid 16 bits - uint16(v48 & 0xffff), // extract the last 16 bits - } - b16 := [16]byte{} - for i, hx := range a16 { - i = i * 2 - binary.BigEndian.PutUint16(b16[i:i+2], hx) - } - return netip.AddrFrom16(b16) -} - -func (t *dnsgateway) S() string { - t.RLock() - defer t.RUnlock() - - var sb strings.Builder - sb.WriteString("dnsgateway state:\n") - sb.WriteString(" mod: ") - sb.WriteString(strconv.FormatBool(t.mod.Load())) - sb.WriteString(" / cansplit: ") - sb.WriteString(strconv.FormatBool(t.split.Load())) - sb.WriteString(" / wantsplit: ") - sb.WriteString(strconv.FormatBool(t.fixedTransport())) - sb.WriteString(" / chash: ") - sb.WriteString(strconv.FormatBool(t.chash)) - sb.WriteString(" / adv: ") - sb.WriteString(strconv.Itoa(len(t.alg))) - sb.WriteString(" / nat: ") - sb.WriteString(strconv.Itoa(len(t.nat))) - sb.WriteString(" / ptr: ") - sb.WriteString(strconv.Itoa(len(t.ptr))) - return sb.String() -} - -func (t *dnsgateway) X(maybeAlg netip.Addr, uid string, tids ...string) (ips []netip.Addr, undidAlg bool) { - t.RLock() - defer t.RUnlock() - - if !t.split.Load() { - uid = core.UNKNOWN_UID_STR - } - // stale IPs are okay iff !mod; as then maybeAlg itself is a realip - usestale := !t.mod.Load() - return t.xLocked(maybeAlg, usestale, uid, tids...) // ips may be 0 len -} - -func (t *dnsgateway) PTR(maybeAlg netip.Addr, uid, tid string, force bool) (domains string, didForce bool) { - t.RLock() - defer t.RUnlock() - - if !t.split.Load() { - uid = core.UNKNOWN_UID_STR - } - - // do not use t.ptr (realip -> ans) in mod (alg) mode, unless forced; - // as t.nat (algip -> ans) is a more accurate translation. - useptr := !t.mod.Load() || force - d := t.ptrLocked(maybeAlg, uid, tid, useptr) - if len(d) > 0 { - domains = strings.Join(d, ",") - } // else: algip isn't really an alg ip, nothing to do - return domains, useptr -} - -func (t *dnsgateway) RESOLV(domain, uid, tid string) []netip.Addr { - t.RLock() - defer t.RUnlock() - - typ := typalg - if !t.mod.Load() { - typ = typreal - } - if !t.split.Load() { - uid = core.UNKNOWN_UID_STR - } - // TODO: handle Preset IPs which aren't alg'd - // TODO: for some skipInternalCache(tid) and uid == protect.UidSelf - // alg caches (nat/ptr) won't have any entries - // See: dontalg var in dnsgateway.q() and dnsgateway.xLocked() - if uid == protect.UidSelf { - uid = core.UNKNOWN_UID_STR // wildcard, so xips searches across all UIDs - tid = notransport // wildcard, so xips searches across all TIDs - } - if len(tid) <= 0 { - tid = notransport - } - ip4s, ip6s, _, _, _ := t.resolvLocked(domain, typ, tid, uid) - return append(ip4s, ip6s...) -} - -func (t *dnsgateway) RDNSBL(algip netip.Addr) (blocklists string) { - t.RLock() - defer t.RUnlock() - - return t.rdnsblLocked(algip, !t.mod.Load()) -} - -func (t *dnsgateway) xLocked(maybeAlg netip.Addr, usestale bool, uid string, tids ...string) (realips []netip.Addr, _ bool) { - var until time.Duration - var undidAlg, undidPtr, fresh bool - - xst := makeXipStatus(!usestale) - // alg ips are always unmappped; see take4Locked - unmapped := maybeAlg.Unmap() // aligip may also be origip / realip - - // ignore & return the fake dns address as-is (ex: 10.111.222.3:53) - if anyAddrEqual(t.fake, unmapped) { - return []netip.Addr{maybeAlg}, false - } - - // see: dontalg var in dnsgateway.q() - // TODO: handle preset IPs that won't be in the ptr/nat caches - uidself := uid == protect.UidSelf - skippedcache := skipInternalCache(tids...) - didnotAlg := skippedcache || uidself - - if !didnotAlg { - var ans *baseans - // undidAlg is really "hasAnyAlgEntry"; set it to true - // regardless of len(realips) or freshness of ans.ips - if ans, undidAlg = t.nat[unmapped]; undidAlg { - if len(tids) <= 0 { - realips = ans.ips.realips(uid, xst) - } else { - for _, tid := range tids { - realips = append(realips, ans.ips.realipsFor(tid, uid, xst)...) - } - } - until, fresh = ans.fresh() - } else if ans, undidPtr = t.ptr[unmapped]; undidPtr { - // for IPs (unlike domains), it is okay to fallback on ptr as the - // maybeAlg may be an algip OR realip (latter in the case where an - // app is connecting to a cached IP addr from before t.mod was set) - // nb: both realips & secondaryips may be nil, but that's okay: - // go.dev/play/p/fSjRjMSAS2m - if len(tids) <= 0 { - realips = ans.ips.realips(uid, xst) - } else { - for _, tid := range tids { - realips = append(realips, ans.ips.realipsFor(tid, uid, xst)...) - } - } - until, fresh = ans.fresh() - } - } - - hasrealips := len(realips) > 0 - var unnated []netip.Addr - if !hasrealips { // algip is probably origip / realip - // unnat origip as it itself may have been synthesized from - // our DNS responses by apps doing funky things; like FreeFire - unnated = t.maybeUndoNat64Locked(unmapped) - } else { - unnated = t.maybeUndoNat64Locked(realips...) - } // else: send realips as is - - logeif(!hasrealips && (!usestale && (!undidAlg || !undidPtr)))("alg: dns64: for %v[%s] (didnotAlg? %t / fresh? %t / undidAlg? %t / undidPtr? %t / staleok? %t) algip(%v) => realips(%v) => unnated(%v); until: %s", - tids, uid, didnotAlg, fresh, undidAlg, undidPtr, usestale, unmapped, realips, unnated, until) - - if len(unnated) > 0 { // unnated is already de-duplicated - return unnated, undidAlg - } - - if !hasrealips { - // when realips is empty but one of undidAlg / undidPtr is not false, - // it means the client code may retry re-resolving the corresponding - // domain to freshen up alg mapping; which is to say, sending empty - // realips instead of unmapped as-is is a way to signal that - // the alg mapping is stale. - if undidAlg || undidPtr { - return realips, undidAlg // realips is empty here - } - // no algip, no realips, no unnated; - // ptr + nat alg mapping do not exist / apply; - // return unmapped as-is to the client code. - return []netip.Addr{unmapped}, undidAlg - } - - return copyUniq(realips), undidAlg -} - -func (t *dnsgateway) maybeUndoNat64Locked(realips ...netip.Addr) (unnateds []netip.Addr) { - for _, nip := range realips { - unmapped := nip.Unmap() - if !unmapped.Is6() { - continue - } - // the actual ID of the DNS64 for this whoever responded with "realips" for some unknown - // DNS query is not available. But, we needn't worry about UN-NAT64'ing other resolvers - // except the one we "force" onto the clients (aka dnsx.Local464Resolver). - // whether the active network has ipv4 connectivity is checked by dialers.filter() - ipx4, completed := core.Grx("undoNat64."+unmapped.String(), func(ctx context.Context) (netip.Addr, error) { - // with async+timeout to avoid blocking on mutex - return t.dns64.X64(Local464Resolver, unmapped), nil // ipx4 may be zero addr - }, ttl3s) - - logeif(!completed)("alg: dns64: maybeUndoNat64: nat64 to ip4(%v) for ip6(%v); timedout? %t", - ipx4, nip, !completed) - if !completed || !ipok(ipx4) { // no nat? - continue - } - unmapped4 := ipx4.Unmap() - unnateds = append(unnateds, unmapped4) - } - return copyUniq(unnateds) -} - -func (t *dnsgateway) ptrLocked(maybeAlg netip.Addr, uid, tid string, useptr bool) (domains []string) { - // alg ips are always unmappped; see take4Locked - unmapped := maybeAlg.Unmap() - if len(uid) <= 0 { - uid = core.UNKNOWN_UID_STR - } - if len(tid) <= 0 { - tid = notransport - } - if ans, ok := t.nat[unmapped]; ok { - domains = domainsFor(ans, tid, uid, unmapped, xalive) - } else if ans, ok := t.ptr[unmapped]; useptr && ok { - // translate from realip only if not in mod mode - // for useptr, s/xalive/xall/ - domains = domainsFor(ans, tid, uid, unmapped, xalive /*prefer fresh mapping */) - if len(domains) <= 0 { - domains = domainsFor(ans, tid, uid, unmapped, xall /*useptr == true */) - } - } - return copyUniq(domains) -} - -// resolvLocked returns IPs and related targets for domain -// depending on typ. -// If typ is typalg, returns all algips for domain. -// If typ is typreal, returns all realips for domain resolved by tid. -// If typ is typsecondary, returns all secondaryips for domain (disregarding tid). -// ip4s and ip6s may overlap, and are segregated by the source algip -// family (and not by the family of the resolved IPs themselves). -func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (ip4s, ip6s, staleips []netip.Addr, until time.Duration, targets []string) { - partkey4 := domain + key4 - partkey6 := domain + key6 - until = time.Duration(math.MaxInt64) - - ip4s = make([]netip.Addr, 0) - ip6s = make([]netip.Addr, 0) - targets = make([]string, 0) - staleips = make([]netip.Addr, 0) - switch typ { - case typalg: - for i := range 2 { - k4 := partkey4 + strconv.Itoa(i) - if ans, ok := t.alg[k4]; ok { - if life, fresh := ans.fresh(); fresh { // not stale - ip4s = append(ip4s, ans.algip) - targets = append(targets, domainsFor(ans.baseans, tid, uid, ans.algip, xalive)...) - until = min(until, life) - } else { - staleips = append(staleips, ans.algip) - } - } else { - break - } - } - for i := range 2 { - k6 := partkey6 + strconv.Itoa(i) - if ans, ok := t.alg[k6]; ok { - if life, fresh := ans.fresh(); fresh { // not stale - ip6s = append(ip6s, ans.algip) - targets = append(targets, domainsFor(ans.baseans, tid, uid, ans.algip, xalive)...) - until = min(until, life) - } else { - staleips = append(staleips, ans.algip) - } - } else { - break - } - } - if settings.Debug { - log.V("alg: resolv: %s:%s[%s] => alg ip4 %d, ip6 %d (until: %s); stale %v", - domain, tid, uid, len(ip4s), len(ip6s), until, staleips) - } - case typreal: - for i := range 2 { // a = 0, https/svcb = 1+ - k4 := partkey4 + strconv.Itoa(i) - if ans, ok := t.alg[k4]; ok { - if life, fresh := ans.fresh(); fresh { // not stale - all4s := v4only(ans.ips.realipsFor(tid, uid, xalive)) - ip4s = append(ip4s, all4s...) - targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all4s), xalive)...) - until = min(until, life) - } else { - staleips = append(staleips, ans.ips.realipsFor(tid, uid, xall)...) - } - // all ans{} have all realips; pick the first one - break - } // continue - } - for i := range 2 { // aaaa = 0, https/svcb = 1+ - k6 := partkey6 + strconv.Itoa(i) - if ans, ok := t.alg[k6]; ok { - if life, fresh := ans.fresh(); fresh { // not stale - all6s := v6only(ans.ips.realipsFor(tid, uid, xalive)) - ip6s = append(ip6s, all6s...) - targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all6s), xalive)...) - until = min(until, life) - } else { - staleips = append(staleips, ans.ips.realipsFor(tid, uid, xall)...) - } - // all ans{} have all realips; pick the first one - break - } // continue - } - if settings.Debug { - log.V("alg: resolv: %s:%s[%s] => real(ip4 %d, ip6 %d) until: %s; stale %v", - domain, tid, uid, len(ip4s), len(ip6s), until, staleips) - } - case typsecondary: - for i := range 2 { // a = 0, https/svcb = 1+ - k4 := partkey4 + strconv.Itoa(i) - if ans, ok := t.alg[k4]; ok { - if life, fresh := ans.fresh(); fresh { // not stale - all4s := v4only(ans.ips.secipsFor(tid, uid)) - ip4s = append(ip4s, all4s...) - targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all4s), xalive)...) - until = min(until, life) - } else { - staleips = append(staleips, ans.ips.secips(xall)...) - } - // all ans{} have all secondaryips; pick the first one - break - } // continue - } - for i := range 2 { // aaaa = 0, https/svcb = 1+ - k6 := partkey6 + strconv.Itoa(i) - if ans, ok := t.alg[k6]; ok { - if life, fresh := ans.fresh(); fresh { // not stale - all6s := v6only(ans.ips.secipsFor(tid, uid)) - ip6s = append(ip6s, all6s...) - targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all6s), xalive)...) - until = min(until, life) - } else { - // TODO: stale targets? - staleips = append(staleips, ans.ips.secips(xall)...) - } - // all ans{} have all secondaryips; pick the first one - break - } // continue - } - if settings.Debug { - log.V("alg: resolv: %s:%s[%s] => secondary ip4 %d, ip6 %d (until: %s); stale %v", - domain, tid, uid, len(ip4s), len(ip6s), until, staleips) - } - } - - return -} - -func (t *dnsgateway) rdnsblLocked(algip netip.Addr, useptr bool) (bcsv string) { - // alg ips are always unmappped; see take4Locked - unmapped := algip.Unmap() - if ans, ok := t.nat[unmapped]; ok { - bcsv = ans.blocklists - } else if ans, ok := t.ptr[unmapped]; useptr && ok { - // translate from realip only if not in mod mode - bcsv = ans.blocklists - } - return -} - -// xor fold fnv to 18 bits: www.isthe.com/chongo/tech/comp/fnv -func hash22(s string) uint32 { - h := fnv.New64a() - _, _ = h.Write([]byte(s)) - v64 := h.Sum64() - return (uint32(v64>>22) ^ uint32(v64)) & 0x3FFFFF // 22 bits -} - -// xor fold fnv to 48 bits: www.isthe.com/chongo/tech/comp/fnv -func hash48(s string) uint64 { - h := fnv.New64a() - _, _ = h.Write([]byte(s)) - v64 := h.Sum64() - return (uint64(v64>>48) ^ uint64(v64)) & 0xFFFFFFFFFFFF // 48 bits -} - -func synthesizeOrQuery(preset []netip.Addr, tr Transport, msg *dns.Msg, network string, smm *x.DNSSummary, fixed bool) (*dns.Msg, error) { - // synthesize a response with the given ips - if len(preset) == 0 { - return Req(tr, network, msg, smm) - } - if msg == nil || !xdns.HasAnyQuestion(msg) { - return nil, errNoQuestion - } - algXlatTtl := xdns.ZeroTTL - qname := qname(msg) - qtyp := uint16(qtype(msg)) - is4 := xdns.IsAQType(qtyp) - is6 := !is4 && xdns.IsAAAAQType(qtyp) - isHTTPS := (!is4 && !is6) && xdns.IsHTTPSQType(qtyp) - isSVCB := (!is4 && !is6) && xdns.IsSVCBQType(qtyp) - if is4 || is6 { - // if no ips are of the same family as the question xdns.AQuadAForQuery returns error - ans, err := xdns.AQuadAForQuery(msg, preset...) - if err != nil { // errors on invalid msg, question, or mismatched ips - log.W("alg: synthesize: %s with %v; err(%v); using tr %s", - qname, preset, err, idstr(tr)) - return Req(tr, network, msg, smm) - } - withPresetSummary(smm, false /*req sent?*/, fixed) - smm.ID = Preset - smm.RCode = xdns.Rcode(ans) - smm.RData = xdns.GetInterestingRData(ans) - smm.RTtl = xdns.RTtl(ans) // usually 1 per xdns.AnsTTL - - log.V("alg: synthesize: %s q(4? %t / 6? %t), fixed? %t, rdata(%s)", - qname, is4, is6, fixed, smm.RData) - - return ans, nil // no error - } else if isHTTPS || isSVCB { - ans, err := Req(tr, network, msg, smm) - if err != nil { - return ans, err - } else if ans == nil { // empty answer is ok - return nil, errNoAnswer - } - var ok4, ok6 bool - ip4s, ip6s := splitIPFamilies(preset) - if len(ip4s) > 0 { - ok4 = xdns.SubstSVCBRecordIPs( /*out*/ ans, dns.SVCB_IPV4HINT, ip4s[0], algXlatTtl) - } - if len(ip6s) > 0 { - ok6 = xdns.SubstSVCBRecordIPs( /*out*/ ans, dns.SVCB_IPV6HINT, ip6s[0], algXlatTtl) - } - - withPresetSummary(smm, true /*req sent?*/, fixed) - smm.RCode = xdns.Rcode(ans) - smm.RData = xdns.GetInterestingRData(ans) - smm.RTtl = xdns.RTtl(ans) - - log.D("alg: synthesize: q: %s; (HTTPS? %t / fixed? %t); subst4(%t), subst6(%t); rdata(%s); tr: %s", - qname, isHTTPS, fixed, ok4, ok6, smm.RData, idstr(tr)) - - return ans, nil // no error - } else { - note := log.VV - if fixed { - note = log.W - } - note("alg: synthesize: %s skip; fixed? %t, qtype %d; using tr %s", - qname, fixed, qtyp, idstr(tr)) - return Req(tr, network, msg, smm) - } -} - -// Req sends q to transport t and returns the answer, if any; -// errors are unset if answer is not servfail or empty; -// smm, the in/out parameter, is dns summary as got from t. -func Req(t Transport, network string, q *dns.Msg, smm *x.DNSSummary) (*dns.Msg, error) { - if t == nil || core.IsNil(t) { - return nil, errNoSuchTransport - } - if !xdns.HasAnyQuestion(q) { - return nil, errNoQuestion - } - qname := qname(q) - qtyp := qtype(q) - - if smm == nil { // discard smm - discarded := new(x.DNSSummary) - smm = discarded - } - smm.ID = idstr(t) - if len(smm.QName) <= 0 { - smm.QName = qname - } - if smm.QType <= 0 { - smm.QType = qtyp - } - - r, err := t.Query(network, q, smm) - - if r == nil { - if settings.Debug { - log.V("alg: Req: %s:%d no answer; by: %s, rdata: %s, status: %d; err? %v", - qname, qtyp, smm.ID, smm.RData, smm.Status, err) - } - return nil, err // err may be nil - } - if !xdns.IsServFailOrInvalid(r) { - return r, nil - } - - if settings.Debug { - log.V("alg: Req: %s:%d servfail; by: %s, rdata: %s, status: %d, rcode %d", - qname, qtyp, smm.ID, smm.RData, smm.Status, xdns.Rcode(r)) - } - return r, err -} - -func ChooseHealthyProxy(who string, ipps []netip.AddrPort, pids []string, px ipn.ProxyProvider) (pid string) { - var errs []error - pid = NetNoProxy - if len(pids) > 0 { - pid = pids[0] - } - foundProxy := false - cipp := netip.AddrPort{} - for _, ipp := range ipps { - if !ipp.IsValid() { - continue - } - if p, err := px.ProxyTo(ipp, protect.UidSelf, pids); err == nil { - pid = proxyID(p) - foundProxy = pid != NetNoProxy - cipp = ipp - break - } else { - errs = append(errs, err) - } - } - logeif(!foundProxy)("%s: proxy for %s [among %v]; choosing %s among %v; errs? %v", - who, cipp, ipps, pid, pids, core.JoinErr(errs...)) - return -} - -func ChooseHealthyProxyHostPort(who string, host string, port uint16, pids []string, px ipn.ProxyProvider) (pid string) { - var ipps []netip.AddrPort - - splithost, _, _ := net.SplitHostPort(host) - if len(splithost) > 0 { - host = splithost - } - - if c := dialers.Confirmed(host); c.IsValid() { - ipps = append(ipps, netip.AddrPortFrom(c, port)) - } - - for _, ip := range dialers.For(host) { - if ip.IsValid() { - ipps = append(ipps, netip.AddrPortFrom(ip, port)) - } - } - - return ChooseHealthyProxy(who+" : "+host, ipps, pids, px) -} - -func proxyID(p ipn.Proxy) string { - if p == nil { - return NetNoProxy - } - return p.ID() -} - -func splitIPFamilies(ips []netip.Addr) (ip4s, ip6s []netip.Addr) { - for _, ip := range ips { - if !ip.IsValid() { - continue - } - ip = ip.Unmap() - if ip.Is4() { - ip4s = append(ip4s, ip) - } else if ip.Is6() { - ip6s = append(ip6s, ip) - } - } - return -} - -func v4only(ips []netip.Addr) []netip.Addr { - return core.FilterLeft(ips, func(ip netip.Addr) bool { - return ip.Is4() - }) -} - -func v6only(ips []netip.Addr) []netip.Addr { - return core.FilterLeft(ips, func(ip netip.Addr) bool { - return ip.Is6() - }) -} - -func withPresetSummary(smm *x.DNSSummary, reqSent, fixed bool) { - id := Preset - if fixed { - id = Fixed - } - // override id and type from whatever was set before - smm.ID = id - smm.Type = id - if !reqSent { // other unset fields if req not sent upstream - smm.Latency = 0 - smm.Status = Complete - smm.Server = "127.5.3.9" - } - smm.Server = PrefixFor(id) + smm.Server - smm.Blocklists = "" // blocklists are not honoured - smm.BlockedTarget = "" // no targets are blocked - smm.PID = "" // no relay is used - smm.RPID = "" // no hops either -} - -func idstr(t Transport) string { - if t == nil { - return notransport - } - return t.ID() -} - -func infcsv(ts ...Transport) string { - var s []string - for _, t := range ts { - s = append(s, idstr(t)+":"+getaddrstr(t)) - } - return strings.Join(s, ",") -} - -func getaddrstr(t Transport) string { - if t == nil { - return notransport - } - return t.GetAddr() -} - -func ipok(ip netip.Addr) bool { - return !ip.IsUnspecified() && ip.IsValid() -} - -// flattens and returns a copy with dups removed, if any. -func copyUniq[T comparable](a ...[]T) (out []T) { - return core.CopyUniq(a...) -} - -func logeif(cond bool) log.LogFn { - if cond { - return log.E - } - return log.D -} - -func logwif(cond bool) log.LogFn { - if cond { - return log.W - } - return log.D -} diff --git a/intra/dnsx/cacher.go b/intra/dnsx/cacher.go deleted file mode 100644 index cde56a8e..00000000 --- a/intra/dnsx/cacher.go +++ /dev/null @@ -1,617 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - "context" - "errors" - "fmt" - "hash/fnv" - "math/rand" - "net/netip" - "strconv" - "strings" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -// ideally set by clients via dnsx.cache-opts -const ( - // time to live for a cached response - defttl = 2 * time.Hour - // max bumps before we stop bumping a response - defbumps = 10 - // max size per cache bucket - defsize = 256 - // total cache buckets; can't be more than 256 (uint8 255) - defbuckets = 128 - // min duration between scrubs - scrubgap = 5 * time.Minute - // ttl for expired response - stalettl = 15 // seconds - // ttl for response from requests that were barriered - // (ideally, longer than request timeouts) - battl = 30 * time.Second - // threshold for hangover responses - httl = 10 * time.Second - // how many entries to scrub at a time per cache bucket - maxscrubs = defsize / 4 // 25% of the cache - // separator qname, qtype cache-key - cacheKeySep = ":" -) - -var ( - errNoQuestion = errors.New("dns: no question") - errNoAnswer = errors.New("dns: no answer") - errServFail = errors.New("dns: answer servfail") - errBlocked = errors.New("dns: answer blocked") - errHangover = errors.New("dns: no connectivity") - errNilCacheResponse = errors.New("dns: nil cache response") - errCannotUseStaleCache = errors.New("dns: cannot use stale cache response") - errSkipInternalCache = errors.New("dns: skip internal cache") - errCacheResponseMismatch = errors.New("dns: cache response mismatch") -) - -type Cacher interface { - Transport - Clear() -} - -type cache struct { - mu *sync.RWMutex // protects cache, cres, and scrubtime - c map[string]*cres // query -> response - - scrubtime time.Time // last time cache was scrubbed / purged - - // constants: - - ttl time.Duration // how long to cache the valid dns response - halflife time.Duration // how much to increment ttl on each read - bumps int // max bumps before we stop bumping a response - size int // max size of the cache -} - -type cres struct { - ans *dns.Msg - s *x.DNSSummary - expiry time.Time - bumps int -} - -// todo: 0s answer ttl? native caching in alg? -type ctransport struct { - sync.RWMutex // protects store - Transport // embeds the underlying transport - - ctx context.Context - done context.CancelFunc - store []*cache // cache buckets - ipport string // a fake ip:port - ttl time.Duration // lifetime duration of a cached dns entry - halflife time.Duration // increment ttl on each read - bumps int // max bumps in lifetime of a cached response - size int // max size of a cache bucket - reqbarrier *core.Barrier[*cres, string] // coalesce requests for the same query - hangover *core.Hangover // tracks send failure threshold -} - -var _ Cacher = (*ctransport)(nil) - -func NewDefaultCachingTransport(t Transport) Transport { - return NewCachingTransport(t, defttl) -} - -func NewCachingTransport(t Transport, ttl time.Duration) Transport { - if t == nil { - return nil - } - - // is type casting is a better way to do this? - if cachedTransport(t) { - log.I("cache: (%s) no-op: %s", t.ID(), t.GetAddr()) - return t - } - if strings.HasPrefix(t.GetAddr(), algprefix) { - log.W("cache: (%s) no-op for alg: %s", t.ID(), t.GetAddr()) - return t - } - ctx, done := context.WithCancel(context.Background()) - ct := &ctransport{ - Transport: t, - ctx: ctx, - done: done, - store: make([]*cache, defbuckets), - ipport: "[fdaa:cac::ed:3]:53", - ttl: ttl, - halflife: ttl / 2, - bumps: defbumps, - size: defsize, - reqbarrier: core.NewBarrier[*cres](battl), - hangover: core.NewHangover(), - } - context.AfterFunc(ctx, ct.Clear) - log.I("cache: (%s) setup: %s; opts: %s", ct.ID(), ct.GetAddr(), ct) - return ct -} - -func (c *cres) copy() *cres { - var anscopy *dns.Msg - if c.ans != nil { - anscopy = c.ans.Copy() - } - return &cres{ - ans: anscopy, // may be nil - s: copySummary(c.s), - expiry: c.expiry, // may be zero - bumps: c.bumps, - } -} - -// String implements fmt.Stringer -func (cr *cres) String() string { - if cr == nil { - return "" - } - return fmt.Sprintf("bumps=%d; expiry=%s; s=%s", cr.bumps, timestamp(cr.expiry), cr.s) -} - -func timestamp(t time.Time) string { - return t.Format(time.Stamp) -} - -// String implements fmt.Stringer -func (t *ctransport) String() string { - if t == nil { - return "" - } - return fmt.Sprintf("ttl=%s; bumps=%d; size=%d", t.ttl, t.bumps, t.size) -} - -func hash(s string) uint8 { - h := fnv.New32a() - _, _ = h.Write([]byte(s)) - return uint8(h.Sum32() % defbuckets) -} - -func mkcachekey(q *dns.Msg) (string, uint8, bool) { - if q == nil { - return "", 0, false - } - - qname, err := xdns.NormalizeQName(xdns.QName(q)) - if len(qname) <= 0 || err != nil { - return "", 0, false - } - qtyp := strconv.Itoa(int(xdns.QType(q))) - do := "0" - if xdns.IsDNSSECRequested(q) { - do = "1" - } - - return qname + cacheKeySep + qtyp + cacheKeySep + do, hash(qname), true -} - -// scrubCache deletes expired entries from the cache. -// Must be called from a goroutine. -func (cb *cache) scrubCache() { - // must unlock from deferred since panics are recovered above - cb.mu.Lock() - defer cb.mu.Unlock() - - now := time.Now() - if now.Sub(cb.scrubtime) < scrubgap { - return - } - cb.scrubtime = now - - // scrub the cache if it's getting too big - highload := len(cb.c) >= cb.size*75/100 - - i, j, m := 0, 0, 0 - for k, v := range cb.c { - i++ - if highload && time.Since(v.expiry) > 0 { - // evict expired entries on high load, otherwise keep them - // around for use in cases where transport errors out - delete(cb.c, k) - j++ - } - if i > maxscrubs { - break - } - } - log.I("cache: del: %d; ref: %d; tot: %d / high? %t", j, m, i, highload) -} - -func (cb *cache) freshCopy(key string) (v *cres, ok bool) { - cb.mu.RLock() - defer cb.mu.RUnlock() - - if v, ok = cb.c[key]; !ok { - return - } - - recent := v.bumps <= 2 - alive := time.Since(v.expiry) <= 0 - if v.bumps < cb.bumps { - n := time.Duration(v.bumps) * cb.halflife - // if the expiry time is already n duration in the future, don't incr ttl - // or if the entry is already expired, don't incr ttl - if alive && time.Since(v.expiry.Add(-n)) < 0 { - v.expiry = v.expiry.Add(n) - } - v.bumps += 1 - } - - r50 := rand.Intn(99999) < 50000 // 50% chance of reusing from the cache - return v.copy(), (r50 || recent) && alive -} - -// put caches val against key, and returns true if the cache was updated. -// val must be a valid dns packet with successful rcode with no truncation. -func (cb *cache) put(key string, cc *cres) (ok bool) { - ok = false - if cc == nil { - return - } - - ans := cc.ans - // only cache successful responses - // TODO: implement negative caching - if ans == nil || !xdns.HasRcodeSuccess(ans) || xdns.HasTCFlag(ans) { - return - } - - // do not cache .onion addresses - if strings.Contains(key, ".onion"+cacheKeySep) { - return - } - cb.mu.Lock() - defer cb.mu.Unlock() - - if rand33pc() { // 33% of the time - core.Gx("c.scrubCache", cb.scrubCache) - } - - if len(cb.c) >= cb.size { - log.W("cache: put: cache overflow %d > %d", len(cb.c), cb.size) - } - - // 1. ansttl is 0 for synthesized "block" answers (see xdns.BlockTTL) - // 2. for most empty ans (like qtype:65), ansttl is 0 - ansttl := time.Duration(xdns.RTtl(ans)) * time.Second - - // cache must keep the answer alive for a min of cb.ttl - // because the client code may fail if the cache marks - // a recently incubated answer as stale. - if ansttl < cb.ttl { - ansttl = cb.ttl - } else { // bump up a bit longer than the ttl - ansttl = ansttl + cb.halflife - } - - exp := time.Now().Add(ansttl) - v := &cres{ // TODO: copy is not required? - ans: cc.ans.Copy(), - s: copySummary(cc.s), - expiry: exp, - bumps: 0, - } - cb.c[key] = v - - log.D("cache: put(%s): l(%t/%d); %s", key, xdns.HasAnyAnswer(ans), xdns.Len(ans), v) - - ok = true - return -} - -func asResponse(q *dns.Msg, v *cres, fresh bool) (a *dns.Msg, s *x.DNSSummary, err error) { - s = v.s // v must never be nil - a = v.ans - - if q == nil || !xdns.HasAnyQuestion(q) { - err = errNoQuestion - return - } - if a == nil { // cache ans may be "empty" but should not be nil - err = errNilCacheResponse - return - } - if !fresh && xdns.IsDNSSECAnswerAuthenticated(a) { - // for stale responses, TTL is modified which violates DNSSEC? - err = errCannotUseStaleCache - return - } - aname := qname(a) - qname := qname(q) - if aname != qname { - log.E("cache: asResponse: qname mismatch: a(%s) != q(%s)", aname, qname) - err = errCacheResponseMismatch - return - } - - a.Id = q.Id // OK to change even if dnssec? - // dns 0x20 may mangle the question section, so preserve it - // github.com/jedisct1/edgedns#correct-support-for-the-dns0x20-extension - a.Question = q.Question - a.Response = true // just to be sure - if !fresh { // if the v is not fresh, set the ttl to the minimum - xdns.WithTtl(a, stalettl) // only set for Answer records - } - return -} - -func (t *ctransport) ID() string { - // must match with how wrapping transports like DcProxy / Gateway rely on the ID - return CT + t.Transport.ID() -} - -func (t *ctransport) Type() string { - return t.Transport.Type() -} - -func (t *ctransport) hangoverCheckpoint() { - if t.Status() == SendFailed { - t.hangover.Note() - } else { - t.hangover.Break() - } -} - -func (t *ctransport) fetch(network string, q *dns.Msg, smmout *x.DNSSummary, cb *cache, key string) (*dns.Msg, error) { - sendRequest := func(q2 *dns.Msg, smm2 *x.DNSSummary) (*dns.Msg, error) { - reqsent := false - - defer func() { - // fill after summaries are filled - if !reqsent { - smm2.Cached = true - } - }() - - ccx := &cres{ans: nil, s: copySummary(smm2)} - cc, err := t.reqbarrier.DoIt(key, func() (_ *cres, qerr error) { - reqsent = true - // ans may be nil - ccx.ans, qerr = Req(t.Transport, network, q2, smm2) - ccx.s = copySummary(smm2) // copy summary to cc - t.hangoverCheckpoint() - // cb.put no-ops when ans is nil or rcode != success (0) - cb.put(key, ccx) - return ccx, qerr - }) - - if cc == nil { // may be nil for example when barrier times outs - log.E("cache: barrier: %s; nil return for %s; err? %v", t.ID(), key, err) - cc = ccx - } - - cachedres, fresh := cb.freshCopy(key) // always prefer value from cache - cachehit := cachedres != nil - // nil ans when Transport returns err (no servfail) and cache is empty - cachedans := cachedres != nil && cachedres.ans != nil - - // if there's no network connectivity (in hangover for 10s) don't - // return cached/barriered response, instead return an error - inhangover := t.hangover.Exceeds(httl) - - // expect fresh values, except on verrs - logwif(cachehit && !fresh || err != nil)("cache: barrier: (k: %s) hit? %t / hitans? %t / stale? %t / sent? %t / hangover? %t, barrier: %s (cache: %s); qerr? %v", - key, cachehit, cachedans, !fresh, reqsent, inhangover, cc, cachedres, err) - - if !cachehit || !cachedans { // cc.Val may be uncacheable (ex: rcode != 0) - cachedres = cc // cc (cres) never nil; but cc.ans may be nil - } - - if inhangover { - err = core.JoinErr(err, errHangover) - log.W("cache: barrier: hangover(k: %s); sent? %t, discard ans (has? %t)", - key, reqsent, cachedans) - fillSummary(cachedres.s, smm2) - // mimic send fail - smm2.Msg = err.Error() - smm2.RCode = dns.RcodeBadTime - smm2.Status = SendFailed - // do not return any response (stall / drop silently) - return nil, err - } - - // fres may be nil - fres, cachedsmm, ferr := asResponse(q2, cachedres, fresh) - fillSummary(cachedsmm, smm2) // cachedsmm may itself be smm2 - - return fres, core.JoinErr(err, ferr) - } - - // check if underlying transport can connect fine, if not treat cache - // as stale regardless of its freshness. this avoids scenario when there's - // no network connectivity but cache returns proper responses to queries, - // which results in confused apps that think there's network connectivity, - // that is, these confused apps go bezerk resulting in battery drain. - // has 10s elapsed since the first send failure - trok := t.hangover.Within(httl) - - if v, isfresh := cb.freshCopy(key); trok && v != nil { - var cachedsmm *x.DNSSummary - hasans := v.ans != nil - - log.D("cache: hit(k: %s / stale? %t / ans? %t): %s", key, !isfresh, hasans, v) - r, cachedsmm, err := asResponse(q, v, isfresh) // return cached response, may be stale - if err != nil { - nilOrMismatch := errors.Is(err, errNilCacheResponse) || - errors.Is(err, errCacheResponseMismatch) - - logeif(nilOrMismatch)("cache: hit(k: %s) %s, but err? %v", key, v, err) - - if nilOrMismatch { - // FIXME: this is a hack to fix an issue where the cache - // returns a response that does not match the fqdn in query - // or somehow has wrapper cache-obj but not the dns answer. - cb.mu.Lock() - delete(cb.c, key) // del the corrupted entry - cb.mu.Unlock() - } - // fallthrough to sendRequest - } else if cachedsmm != nil { - if !isfresh { // not fresh, fetch in the background - core.Gx("c.sendRequest: "+key+t.ID(), func() { - _, _ = sendRequest(q.Copy(), copySummary(smmout)) // summary may be cached - }) - } - // change summary fields to reflect cached response, except for latency - fillSummary(cachedsmm, smmout) - smmout.Latency = 0 // don't use cached latency - smmout.Cached = true - return r, nil - } // else: fallthrough to sendRequest - } else { - log.D("cache: miss(k: %s): cached? %t, hangover? %t, stale? %t", - key, v != nil, !trok, !isfresh) - } - - // send request in the foreground, and return the response - return sendRequest(q, smmout) // summary is filled by underlying transport -} - -func (t *ctransport) Query(network string, q *dns.Msg, smm *x.DNSSummary) (*dns.Msg, error) { - var response *dns.Msg - var err error - var cb *cache - - if key, h, ok := mkcachekey(q); ok { - t.Lock() - cb = t.store[h] - if cb == nil { - cb = &cache{ - c: make(map[string]*cres), - mu: &sync.RWMutex{}, - size: t.size, - ttl: t.ttl, - bumps: t.bumps, - halflife: t.halflife, - } - t.store[h] = cb - } - t.Unlock() - - response, err = t.fetch(network, q, smm, cb, key) - } else { - err = errMissingQueryName // not really a transport error - } - - return response, err -} - -func (t *ctransport) P50() int64 { - return 0 -} - -func (t *ctransport) GetAddr() string { - prefix := PrefixFor(CT) - return prefix + t.Transport.GetAddr() -} - -func (t *ctransport) IPPorts() []netip.AddrPort { - return t.Transport.IPPorts() -} - -func (t *ctransport) Status() int { - return t.Transport.Status() -} - -func (t *ctransport) Stop() error { - t.done() - // does not call stop on underlying transport as - // it does not "own" it but merely "decorates" it - return nil -} - -func (t *ctransport) Clear() { - t.Lock() - defer t.Unlock() - - defer clear(t.store) - for _, c := range t.store { - if c != nil { - c.mu.Lock() - clear(c.c) - c.mu.Unlock() - } - } -} - -func copySummary(from *x.DNSSummary) (to *x.DNSSummary) { - to = new(x.DNSSummary) - *to = *from // go.dev/play/p/rcGKAcju0FU - return -} - -// fillSummary copies non-zero values into out. -func fillSummary(s *x.DNSSummary, out *x.DNSSummary) { - if out == nil || s == out { - return - } - - // prefer out - - if len(out.Type) == 0 { - out.Type = s.Type - } - if len(out.ID) <= 0 { - out.ID = s.ID - out.Server = s.Server - out.PID = s.PID - out.RPID = s.RPID - } else if len(out.Server) <= 0 { - out.Server = s.Server - out.PID = s.PID - out.RPID = s.RPID - } - if out.Latency <= 0 { - out.Latency = s.Latency - } - if len(out.UID) <= 0 { - out.UID = s.UID - } - if len(out.QName) == 0 { - // query portions are only filled in if they are empty - out.QName = s.QName - } - if out.QType == 0 { // dns.TypeNone = 0 - out.QType = s.QType - } - if len(out.Region) == 0 { // fill in region if empty - out.Region = s.Region - } - - // prefer s - - if len(s.RData) != 0 { - out.RData = s.RData - out.AD = s.AD - out.DO = s.DO - } - - out.Cached = s.Cached - out.RCode = s.RCode - out.RTtl = s.RTtl - out.Status = s.Status - out.Blocklists = s.Blocklists - out.BlockedTarget = s.BlockedTarget - out.Msg = s.Msg - out.UpstreamBlocks = s.UpstreamBlocks -} - -func rand33pc() bool { - return rand.Intn(99999) < 33000 -} diff --git a/intra/dnsx/overrides.go b/intra/dnsx/overrides.go deleted file mode 100644 index 00a5b6af..00000000 --- a/intra/dnsx/overrides.go +++ /dev/null @@ -1,66 +0,0 @@ -package dnsx - -import ( - "net/netip" - "strings" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -func (h *resolver) isDnsIpPort(addr netip.AddrPort) bool { - for _, dnsaddr := range h.dnsaddrs { - if addr.Compare(dnsaddr) == 0 { - return true - } - } - return false -} - -func (h *resolver) isDnsPort(addr netip.AddrPort) bool { - // isn't h.fakedns.Port always expected to be 53? - for _, dnsaddr := range h.dnsaddrs { - if addr.Port() == dnsaddr.Port() { - return true - } - } - return false -} - -func (h *resolver) isDns(ipp netip.AddrPort) bool { - if !ipp.IsValid() || len(h.dnsaddrs) <= 0 { - log.E("dnsx: missing dst-addr(%v) or dns(%v)", ipp, h.dnsaddrs) - return false - } - dnsmode := settings.DNSMode.Load() - if dnsmode == settings.DNSModeIP { - if yes := h.isDnsIpPort(ipp); yes { - return true - } - } else if dnsmode == settings.DNSModePort { - if yes := h.isDnsPort(ipp); yes { - return true - } - } - return false -} - -func (r *resolver) addDnsAddrs(csvaddr string) { - addrs := strings.Split(csvaddr, ",") - dnsaddrs := make([]netip.AddrPort, 0) - if len(addrs) <= 0 { - log.E("dnsx: missing dnsaddrs(%s)", csvaddr) - return - } - for _, a := range addrs { - if ipp, err := netip.ParseAddrPort(a); ipp.IsValid() && err == nil { - dnsaddrs = append(dnsaddrs, ipp) - } else { - log.W("dnsx: not valid fake udpaddr(%s <=> %s): %v", ipp, a, err) - } - } - if len(dnsaddrs) <= 0 { - log.E("dnsx: no valid dnsaddrs(%s)", csvaddr) - } - r.dnsaddrs = dnsaddrs -} diff --git a/intra/dnsx/plus.go b/intra/dnsx/plus.go deleted file mode 100644 index 8bec445f..00000000 --- a/intra/dnsx/plus.go +++ /dev/null @@ -1,444 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - "context" - "net/netip" - "strings" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -const plusSupportsCachedTransports = false -const plusUsesPreferred = false -const plusUsesSystem = false - -const plusMaxTries = 6 - -const ttl10s = 10 * time.Second - -var fakePlusIpports = []netip.AddrPort{ - netip.MustParseAddrPort("[fdaa:9125::9125:9]:53"), -} - -type plus struct { - mu sync.RWMutex // protects all - transports map[string]Transport // id => transport - - r TransportProviderInternal - ctx context.Context - done context.CancelFunc - ipports []netip.AddrPort - - ba *core.Barrier[[]Transport, string] - - closed *core.Volatile[bool] - last *core.Volatile[Transport] -} - -var _ Transport = (*plus)(nil) -var _ TransportMult = (*plus)(nil) - -func NewPlusTransport(ctx context.Context, r TransportProviderInternal, ts ...Transport) Transport { - ctx, done := context.WithCancel(ctx) - t := &plus{ - ctx: ctx, - transports: make(map[string]Transport, len(ts)), - ba: core.NewBarrier[[]Transport](ttl10s), - r: r, - done: done, - ipports: fakePlusIpports, - closed: core.NewVolatile(false), - last: core.NewZeroVolatile[Transport](), - } - - for _, tr := range ts { - if len(idstr(tr)) > 0 { - t.transports[tr.ID()] = tr - } - } - - log.I("plus: at %s; added: %d/%d", t.getAddr(), len(t.transports), len(ts)) - context.AfterFunc(ctx, t.stopAll) - return t -} - -func (t *plus) stopAll() { - t.closed.Store(true) - - t.mu.Lock() - defer t.mu.Unlock() - - for _, tr := range t.transports { - if err := tr.Stop(); err != nil { - log.E("plus: (%s) stop: %v", t.ID(), err) - } - } - clear(t.transports) - t.last.Store(nil) -} - -// String implements fmt.Stringer -func (t *plus) all() []Transport { - t.mu.RLock() - defer t.mu.RUnlock() - - const all = 0 - return vals(t.transports, all) -} - -func (t *plus) ID() string { - // must match with how wrapping transports like DcProxy / Gateway rely on the ID - return Plus -} - -func (t *plus) Type() string { - return DOH -} - -func (t *plus) latest() Transport { - if t.closed.Load() { - return nil - } - - if l := t.last.Load(); l != nil { - return l - } - - if ts := t.all(); len(ts) > 0 { - return ts[0] - } - return nil -} - -func (t *plus) defaultdns() (Transport, error) { - return t.r.GetInternal(Default) -} - -func (t *plus) systemdns() (Transport, error) { - if !plusUsesSystem { - return nil, errNoSuchTransport - } - return t.r.GetInternal(System) // may return Goos or Default -} - -func (t *plus) preferreddns() (Transport, error) { - if !plusUsesPreferred { - return nil, errNoSuchTransport - } - return t.r.GetInternal(Preferred) // may return Default -} - -func (t *plus) ordered() ([]Transport, error) { - best, preferred, recov, errored, ended := Categorize(t.all()) - - expected := len(best) + len(preferred) + len(recov) + 1 - - ord := make([]Transport, 0, expected) - - if l := t.latest(); l != nil { - ord = append(ord, l) // latest may be nil - } - ord = append(ord, best...) - d, _ := t.defaultdns() - if d != nil { - ord = append(ord, d) - } - sys, _ := t.systemdns() - if sys != nil && idstr(d) != idstr(sys) { - ord = append(ord, sys) - } - p, _ := t.preferreddns() - if p != nil && idstr(d) != idstr(p) { - ord = append(ord, p) - } - ord = append(ord, preferred...) - ord = append(ord, recov...) - - ord = core.CopyUniq(ord) - - prev := ord - strat := settings.PlusStrat.Load() - - refiltered := false -refilter: - switch strat { - case settings.PlusFilterSafest: - ord = core.FilterLeft(ord, IsEncrypted) - case settings.PlusOrderRandom: - ord = core.ShuffleInPlace(ord) - case settings.PlusOrderFastest: - ord = core.Sort(ord, Fastest) - case settings.PlusOrderRobust: - // nothing to do - } - - if !refiltered && len(ord) < plusMaxTries { - ord = core.CopyUniq(ord, errored) - refiltered = true - goto refilter - } - - if len(ord) <= 0 { - log.W("plus: strat %d: zero transports avail [exp: %d]: sys? %s / pref: %s / errored: %v / ended: %v", - strat, expected, idstr(sys), infcsv(preferred...), infcsv(errored...), infcsv(ended...)) - return nil, errNoSuchTransport - } else if len(ord) < len(prev) { - log.VV("plus: strat %d: filtered %d < chosen %d; chosen: %s / pref: %v", - strat, len(ord), len(prev), infcsv(ord...), infcsv(preferred...)) - } - - return ord, nil -} - -func (t *plus) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg, err error) { - if t.closed.Load() { - return nil, NewEndQueryError() - } - - ord, err := t.ba.DoIt("plus.q."+network, t.ordered) - if err != nil { - return nil, err - } - return t.forward(network, q, smm, ord...) -} - -func (t *plus) forward(network string, q *dns.Msg, outSmm *x.DNSSummary, all ...Transport) (finalans *dns.Msg, errs error) { - qname := qname(q) - qtyp := qtype(q) - tries := plusMaxTries - visited := make(map[string]struct{}, len(all)) - finalsmm := copySummary(outSmm) - - defer func() { - fillSummary(finalsmm, outSmm) - if finalans != nil { // suppress errors - log.D("plus: suppressing errors for %s:%d[%s]: %v", qname, qtyp, outSmm.RData, errs) - errs = nil - } - }() - - for _, tr := range all { - cursmm := copySummary(outSmm) - - if len(visited) > tries { - break - } - - if tr == nil { // unlikely - errs = core.JoinErr(errs, errNoSuchTransport) - continue - } - - id := tr.ID() - if plusSupportsCachedTransports { - id, _ = strings.CutPrefix(id, CT) - } - if _, ok := visited[id]; ok { - continue - } - visited[id] = struct{}{} - - ans, err := tr.Query(network, q, cursmm) - - failed := xdns.IsServFailOrInvalid(ans) - noans := !failed && !xdns.HasAnyAnswer(ans) - ipblock := xdns.HasAQuadAQuestion(q) && xdns.AQuadAUnspecified(ans) - // HTTPS/SVCB blocks have 0 answer records when blocked - svcbblock := (xdns.HasHTTPQuestion(q) || xdns.HasSVCBQuestion(q)) && noans - - finalsmm = cursmm - - loged(err != nil || failed)("plus: queried %s for %s:%d; data: %s [noans? %t], code: %d, err? %v", - idstr(tr), qname, qtyp, finalsmm.RData, noans, finalsmm.RCode, err) - - if err != nil || ans == nil { - errs = core.JoinErr(errs, core.OneErr(err, errNoAnswer)) - continue - } - - finalans = ans // may be this is the final answer - - if failed { - errs = core.JoinErr(errs, errServFail) - continue - } - - if ipblock || svcbblock { - errs = core.JoinErr(errs, errBlocked) - continue - } - - // an ipv4 only service/website will always return nxdomain (no answer) - // for ipv6 queries, and so, always treating those as errors is not - // what we want but instead to note the nxdomain response down and see - // if other resolvers return an answer or not (censoring resolvers may - // also return nxdomain or blocked answers, which is why we must continue - // to try other resolvers). - if noans { // servfail, nxdomain, etc. - errs = core.JoinErr(errs, errNoAnswer) - // wind down faster if multiple transports return no answer - if len(visited) <= tries/2 { - continue - } // fallthrough and return current finalans - } - - t.last.Store(tr) - return // current finalans - } - - log.W("plus: [exp: %d / tried: %d]: all transports failed: %v", len(all), len(visited), errs) - return // finalans probably nil -} - -func (t *plus) P50() int64 { - if l := t.latest(); l != nil { - return l.P50() - } - return 0 -} - -func (t *plus) GetAddr() string { - return t.getAddr() -} - -func (t *plus) getAddr() string { - return PrefixFor(t.ID()) + t.ipports[0].String() -} - -func (t *plus) GetRelay() x.Proxy { - return nil -} - -func (t *plus) IPPorts() []netip.AddrPort { - return t.ipports -} - -func (t *plus) Status() int { - if l := t.latest(); l != nil { - return l.Status() - } - return ClientError // see also: bootstrap.go -} - -func (t *plus) Stop() error { - t.done() - return nil -} - -// Add implements TransportMult. -func (t *plus) Add(tr x.DNSTransport) bool { - if tr == nil || core.IsNil(tr) || t.closed.Load() { - return false - } - - newt, ok := tr.(Transport) - if !ok { // unlikely - log.W("plus: add %s: cannot cast %T to Transport", tr.ID(), tr) - return false - } - - cachingTransport := cachedTransport(newt) - oldTransportStopped := false - if !plusSupportsCachedTransports && cachingTransport { - log.W("plus: add %s@%s: err no cached transports", newt.ID(), newt.GetAddr()) - return false - } - - t.mu.Lock() - defer t.mu.Unlock() - - if oldt, ok := t.transports[tr.ID()]; ok { - if oldt == newt { - log.I("plus: add %s@%s: already present", newt.ID(), newt.GetAddr()) - return true - } - core.Gxe("plus.stop."+oldt.ID(), oldt.Stop) - oldTransportStopped = true - } - - t.transports[tr.ID()] = newt - - log.I("plus: add %s@%s; old stopped? %t, cacher? %t", - newt.ID(), newt.GetAddr(), oldTransportStopped, cachingTransport) - return true -} - -// Remove implements TransportMult. -func (t *plus) Remove(id string) (y bool) { - t.mu.Lock() - tr := t.transports[id] - delete(t.transports, id) - t.mu.Unlock() - - if tr != nil { - tr.Stop() - y = true - } - - log.I("plus: remove: %s? %t", id, y) - - return -} - -// Get implements TransportMult. -func (t *plus) Get(id string) (x.DNSTransport, error) { - t.mu.RLock() - defer t.mu.RUnlock() - - if tr, ok := t.transports[id]; ok { - return tr, nil - } - return nil, errNoSuchTransport -} - -func (t *plus) refresh() { - if !plusSupportsCachedTransports { - return - } - for _, t := range t.all() { - // clear caches of cached transports: - if ct := asCachedTransport(t); ct != nil { - ct.Clear() // one at a time ... - } - } -} - -// Refresh implements TransportMult. -func (t *plus) Refresh() (string, error) { - // dialers.Clear in transport.go already clears the cache - // that holds ips <> doh hostnames mapping. - core.Gx("plus.refresh", t.refresh) - return t.LiveTransports(), nil -} - -// LiveTransports implements TransportMult. -func (t *plus) LiveTransports() string { - var ids []string - for _, tr := range t.all() { - if activeTransport(tr) { - ids = append(ids, tr.ID()) - } - } - - return strings.Join(ids, ",") -} - -func loged(cond bool) log.LogFn { - if cond { - return log.E - } - return log.D -} diff --git a/intra/dnsx/queryerror.go b/intra/dnsx/queryerror.go deleted file mode 100644 index 182e2982..00000000 --- a/intra/dnsx/queryerror.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -package dnsx - -import ( - "errors" - - x "github.com/celzero/firestack/intra/backend" -) - -const ( - Start = x.Start - Complete = x.Complete - SendFailed = x.SendFailed - NoResponse = x.NoResponse - BadQuery = x.BadQuery - BadResponse = x.BadResponse - InternalError = x.InternalError - TransportError = x.TransportError - ClientError = x.ClientError - Paused = x.Paused - DEnd = x.DEnd - Unknown = 100 -) - -func Status2Str(status int) string { - switch status { - case Start: - return "Starting" - case Complete: - return "OK" - case SendFailed: - return "Failing" - case NoResponse: - return "No Response" - case BadQuery: - return "Bad Query" - case BadResponse: - return "Misbehaving" - case InternalError: - return "Buggy" - case TransportError: - return "Refusing" - case ClientError: - return "Missing" - case DEnd: - return "End" - case Paused: - return "Paused" - default: - return "Unknown" // 100 - } -} - -var errNop = errors.New("no error") - -type QueryError struct { - status int - err error -} - -func (e *QueryError) Error() string { - if e == nil || e.err == nil { - return "[nil]" - } - return e.err.Error() -} - -func (e *QueryError) Unwrap() error { - if e == nil { - return nil - } - return e.err // may be nil and that's how it should be -} - -func (e *QueryError) Status() int { - if e == nil { - return Unknown // unknown - } - return e.status -} - -func (e *QueryError) strstatus() string { - if e == nil { - return "[nil]" - } - return Status2Str(e.status) -} - -func (e *QueryError) String() string { - if e == nil { - return "[nil]" - } - return e.strstatus() + ":" + e.Error() -} - -func newQueryError(no int, err error) *QueryError { - return &QueryError{no, err} // err may be nil -} - -func NewSendFailedQueryError(err error) *QueryError { - return newQueryError(SendFailed, err) -} - -func NewNoResponseQueryError(err error) *QueryError { - return newQueryError(NoResponse, err) -} - -func NewInternalQueryError(err error) *QueryError { - return newQueryError(InternalError, err) -} - -func NewBadQueryError(err error) *QueryError { - return newQueryError(BadQuery, err) -} - -func NewBadResponseQueryError(err error) *QueryError { - return newQueryError(BadResponse, err) -} - -// with http, for 5xx errors -func NewTransportQueryError(err error) *QueryError { - return newQueryError(TransportError, err) -} - -// with http, for 4xx errors -func NewClientQueryError(err error) *QueryError { - return newQueryError(ClientError, err) -} - -func NewPausedQueryError() *QueryError { - return newQueryError(Paused, errTransportPaused) -} - -func NewEndQueryError() *QueryError { - return newQueryError(DEnd, errTransportEnd) -} diff --git a/intra/dnsx/rethinkdns.go b/intra/dnsx/rethinkdns.go deleted file mode 100644 index d501fa6d..00000000 --- a/intra/dnsx/rethinkdns.go +++ /dev/null @@ -1,683 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - b32 "encoding/base32" - b64 "encoding/base64" - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - - "github.com/miekg/dns" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - - "slices" - - "github.com/celzero/gotrie/trie" -) - -const ( - localBlock = 0 - remoteBlock = 1 - - // blocklist stamp separator for base64 encoding - colonsep = ":" - // blocklist stamp separator for base32 encoding - hyphensep = "-" -) - -// supported blocklist stamp versions: 0 or 1 -const ( - ver1 = "1" - ver0 = "0" -) - -// encoding type, base32 or base64 -const ( - EB32 = x.EB32 - EB64 = x.EB64 -) - -var ( - errRemote = errors.New("op not valid in remote block mode") - errNoStamps = errors.New("no stamp set") - errMissingCsv = errors.New("zero comma-separated flags") - errFlagsMismatch = errors.New("flagcsv does not match loaded flags") - errNotEnoughAnswers = errors.New("req at least one answer") - errTrieArgs = errors.New("missing data, unable to build blocklist") - errNoBlocklistMatch = errors.New("no blocklist applies") -) - -type RdnsResolver interface { - x.RDNSResolver - blockQ(Transport, Transport, *dns.Msg) (*dns.Msg, string, error) - blockA(Transport, Transport, *dns.Msg, *dns.Msg, string) (*dns.Msg, string, string) -} - -// ResolverSelf is for internal resolution needs. -type ResolverSelf interface { - // LocalLookup performs resolution on Default and/or Goos DNSes. - // To be only used by protect.UidSelf. - LocalLookup(q []byte) (a []byte, tid string, err error) - // LookupFor performs resolution for uid. - LookupFor(q []byte, uid string) (a []byte, tid string, err error) - // Lookup performs resolution on chosen Transport for uid. - LookupFor2(q []byte, uid string, chosen ...string) (a []byte, tid string, err error) -} - -type RDNS interface { - x.RDNS - OnDeviceBlock() bool // Mode - blockQuery(*dns.Msg) (string, error) - blockAnswer(*dns.Msg) (string, string, error) -} - -type rethinkdns struct { - // value -> group:name - flags []string - // uname -> group:name - tags map[string]string - mode int - stamp string -} - -type rethinkdnslocal struct { - *rethinkdns - ftrie *trie.FrozenTrie -} - -type listinfo struct { - pos int - name string -} - -var _ RDNS = (*rethinkdnslocal)(nil) -var _ RDNS = (*rethinkdns)(nil) - -func newRDNSRemote(filetagjson string) (*rethinkdns, error) { - flags, tags, err := load(filetagjson) - if err != nil { - return nil, err - } - r := &rethinkdns{ - flags: flags, - tags: tags, - mode: remoteBlock, - } - return r, nil -} - -func newRDNSLocal(t string, rank string, - conf string, filetagjson string) (*rethinkdnslocal, error) { - - if len(t) <= 0 || len(rank) <= 0 || len(conf) <= 0 || len(filetagjson) <= 0 { - return nil, errTrieArgs - } - - ft, err := trie.Build(t, rank, conf, filetagjson, trie.Fmmap) - if err != nil { - return nil, err - } - - flags, tags, err := load(filetagjson) - if err != nil { - return nil, err - } - - // docs.pi-hole.net/ftldns/blockingmode - r := &rethinkdns{ - // pos/index/value ->subgroup:vname - flags: flags, - // uname -> subgroup:vname - tags: tags, - mode: localBlock, - } - rlocal := &rethinkdnslocal{ - rethinkdns: r, - ftrie: ft, - } - - return rlocal, nil -} - -func (r *rethinkdns) OnDeviceBlock() bool { - return r.mode == localBlock -} - -func (r *rethinkdns) GetStamp() (string, error) { - return r.getStamp() -} - -func (r *rethinkdns) getStamp() (s string, err error) { - if !r.OnDeviceBlock() { - err = errRemote - return - } - if len(r.stamp) <= 0 { - s = "" - } else { - s = r.stamp - } - return -} - -func (r *rethinkdns) SetStamp(stamp string) error { - return r.setStamp(stamp) -} - -func (r *rethinkdns) setStamp(stamp string) error { - if !r.OnDeviceBlock() { - return errRemote - } - if len(stamp) <= 0 { - r.stamp = "" - } else { - // normalize also validates the stamp - if nm, err := r.normalizeStamp(stamp); err != nil { - return err - } else { - r.stamp = nm - } - } - return nil -} - -// Returns blockstamp given comma-separated blocklist ids -func (r *rethinkdns) FlagsToStamp(flagscsv string, enctyp int) (string, error) { - return r.flagsToStamp(flagscsv, enctyp) -} - -func (r *rethinkdns) flagsToStamp(flagscsv string, enctyp int) (string, error) { - fstr := strings.Split(flagscsv, ",") - if len(fstr) <= 0 || firstEmpty(fstr) { - return "", errMissingCsv - } - - flags := make([]uint16, len(fstr)) - for i, s := range fstr { - val, err := strconv.Atoi(s) - if err != nil { - return "", err - } - if i >= len(flags) { - return "", errFlagsMismatch - } - flags[i] = uint16(val) - } - - if stamp, err := r.flagtostamp(flags); err != nil { - return "", err - } else { - return encode(ver1, stamp, enctyp) - } -} - -// Returns comma-separated blocklist ids, given a stamp of form version:base64 -func (r *rethinkdns) StampToFlags(stamp string) (string, error) { - return r.stampToFlags(stamp) -} - -func (r *rethinkdns) stampToFlags(stamp string) (string, error) { - blocklists, err := r.stampToBlocklist(stamp) - if err != nil { - return "", err - } - - var blocklistids []string - for _, x := range blocklists { - blocklistids = append(blocklistids, fmt.Sprint(x.pos)) - } - - return strings.Join(blocklistids[:], ","), nil -} - -func (r *rethinkdns) StampToNames(stamp string) (string, error) { - return r.stampToNames(stamp) -} - -func (r *rethinkdns) stampToNames(stamp string) (string, error) { - blocklists, err := r.stampToBlocklist(stamp) - if err != nil { - return "", err - } - - var blocklistnames []string - for _, x := range blocklists { - blocklistnames = append(blocklistnames, x.name) - } - - return strings.Join(blocklistnames[:], ","), nil -} - -func (r *rethinkdns) stampToBlocklist(stamp string) ([]*listinfo, error) { - if len(stamp) <= 0 { - return nil, errNoStamps - } - - // b64 -> 1:YAYBACABEDAgAA== / b32 -> 1-madacabaaeidaiaa - colonidx := strings.Index(stamp, colonsep) - hyphenidx := strings.Index(stamp, hyphensep) - isb32 := hyphenidx >= 0 && (hyphenidx < colonidx || colonidx < 0) - versep := colonsep - enctyp := EB64 - if isb32 { - versep = hyphensep - enctyp = EB32 - } - s := strings.Split(stamp, versep) - if len(s) > 1 { - return r.decode(s[1], s[0], enctyp) - } else { - return r.decode(stamp, "0", enctyp) - } -} - -func (r *rethinkdns) keyToNames(list []string) (v []string) { - for _, l := range list { - x := r.tags[l] - if len(x) > 0 { // TODO: else err? - v = append(v, x) - } - } - return -} - -func (r *rethinkdns) blockQuery(*dns.Msg) (b string, err error) { err = errRemote; return } -func (r *rethinkdns) blockAnswer(*dns.Msg) (t, b string, err error) { err = errRemote; return } - -func (r *rethinkdnslocal) blockQuery(msg *dns.Msg) (blocklists string, err error) { - if len(msg.Question) <= 0 { - err = errMissingQueryName - return - } - - stamp, err := r.getStamp() - if err != nil { - return - } - if len(stamp) <= 0 { - err = errNoStamps - return - } - for _, quest := range msg.Question { - // err when incoming name != ascii, ignore - qname, _ := xdns.NormalizeQName(quest.Name) - qtype := msg.Question[0].Qtype - if !(xdns.IsAAAAQType(qtype) || xdns.IsAQType(qtype) || xdns.IsSVCBQType(qtype) || xdns.IsHTTPSQType(qtype)) { - err = fmt.Errorf("unsupported dns query type %v", qtype) - return - } - block, lists := r.ftrie.DNlookup(qname, stamp) - // TODO: handle empty lists as err? - if block { - blocklists = strings.Join(r.keyToNames(lists), ",") - return - } - } - err = errNoBlocklistMatch - return -} - -func (r *rethinkdnslocal) blockAnswer(msg *dns.Msg) (blockedtarget, blocklists string, err error) { - if msg == nil { - err = errNoAnswer - return - } - ans := msg.Answer - if len(ans) <= 0 { - err = errNotEnoughAnswers - return - } - stamp, err := r.getStamp() - if err != nil { - return - } - if len(stamp) <= 0 { - err = errNoStamps - return - } - - qname := xdns.QName(msg) - // handle cname, https/svcb name cloaking: news.ycombinator.com/item?id=26298339 - // adopted from: github.com/DNSCrypt/dnscrypt-proxy/blob/6e8628f79/dnscrypt-proxy/plugin_block_name.go#L178 - for _, a := range ans { - var target string - switch rr := a.(type) { - case *dns.CNAME: - target = rr.Target - case *dns.SVCB: - if rr.Priority == 0 { - target = rr.Target - } - case *dns.HTTPS: - if rr.Priority == 0 { - target = rr.Target - } - default: - // no-op - } - - if len(target) <= 0 { - continue - } - // if target is ".", then it is a self-reference to the qname - if len(target) == 1 && target[0] == '.' { - target = qname - } - - // ignore err when incoming name != ascii - target, _ = xdns.NormalizeQName(target) - block, lists := r.ftrie.DNlookup(target, stamp) - if block { // TODO: handle empty lists as err? - blockedtarget = target - blocklists = strings.Join(r.keyToNames(lists), ",") - return - } - - log.D("rdns: blockAnswer: no block for target %s, qname %s", target, qname) - } - - err = fmt.Errorf("answers not in blocklist %s", stamp) - return -} - -func load(configjson string) ([]string, map[string]string, error) { - configjson = filepath.Clean(configjson) - data, err := os.ReadFile(configjson) - if err != nil { - return nil, nil, err - } - - var obj map[string]any - err = json.Unmarshal(data, &obj) - if err != nil { - return nil, nil, err - } - - rflags := make([]string, len(obj)) - fdata := make(map[string]string) - // example: - // { - // "XYZ": { - // "value":171, - // "uname":"XYZ", - // "vname":"1Hosts", - // "group":"privacy", - // "subg":"", - // "url":"badmojr.github.io...", - // "show":1, - // "entries":511684 - // } - // ... - // } - for key := range obj { - indata, _ := obj[key].(map[string]any) - if indata == nil { // should not happen - continue - } - findex, _ := indata["value"].(float64) - index := int(findex) - name, _ := indata["vname"].(string) - subgroup, _ := indata["subg"].(string) - group, _ := indata["group"].(string) - - if len(subgroup) <= 0 { - subgroup = group - } - if len(name) <= 0 { - name = subgroup - subgroup = group - } - // 171 -> privacy:1Hosts - rflags[index] = subgroup + ":" + name - // XYZ -> privacy:1Hosts - fdata[key] = subgroup + ":" + name - } - return rflags, fdata, nil -} - -func (r *rethinkdns) decode(stamp, ver string, enctyp int) (info []*listinfo, err error) { - haspad := strings.Contains(stamp, "=") - decoder := b64.RawStdEncoding - decoder32 := b32.StdEncoding.WithPadding(b32.NoPadding) - if haspad { - decoder = b64.StdEncoding - decoder32 = b32.StdEncoding.WithPadding(b32.StdPadding) - } - if ver == ver0 { - stamp, err = url.QueryUnescape(stamp) - } else if ver == ver1 { - decoder = b64.RawURLEncoding - if haspad { - decoder = b64.URLEncoding - } - } else { - err = fmt.Errorf("version %s unsupported", ver) - } - if err != nil { - return nil, err - } - - var buf []byte - if enctyp == EB32 { - buf, err = decoder32.DecodeString(strings.ToUpper(stamp)) - } else { - buf, err = decoder.DecodeString(stamp) - } - if err != nil { - return - } - - var u16 []uint16 - if ver == ver0 { - u16 = str2uint16(string(buf)) - } else if ver == ver1 { - u16 = byte2uint16(buf) - } else { - err = fmt.Errorf("unimplemented header stamp version %v", ver) - return - } - - return r.flagstoinfo(u16) -} - -func (r *rethinkdns) flagstoinfo(flags []uint16) ([]*listinfo, error) { - // flags has to be an array of 16-bit integers. - - // first index always contains the header - header := uint16(flags[0]) - // store of each big-endian position of set bits in header - tagIndices := []int{} - values := make([]*listinfo, 0) - var mask uint16 - - // b1000,0000,0000,0000 - mask = 0x8000 - - // read first 16 header bits from msb to lsb - // and capture indices of set bits in tagIndices - for i := range 16 { - if (header << i) == 0 { - break - } - if (header & mask) == mask { - tagIndices = append(tagIndices, i) - } - mask = mask >> 1 // shift to read the next msb bit - } - // the number of set bits in header must correspond to total - // blocklist "flags" excluding the header at position 0 - if len(tagIndices) != (len(flags) - 1) { - err := fmt.Errorf("%v %v flags and header mismatch", tagIndices, flags) - return nil, err - } - - // for all blocklist flags excluding the header - // figure out the blocklist-ids - for i := 1; i < len(flags); i++ { - // 16 blocklists are represented by one flag; ie, - // one bit per blocklist; flag[0] is the header. - var flag = uint16(flags[i]) - // get the index of the current flag in the header - var index = tagIndices[i-1] - // b1000,0000,0000,0000; 1<<15 - mask = 0x8000 - // for each of the 16 bits in the flag - // capture the set bits and calculate - // its actual decimal value, the blocklist-id - for j := range 16 { - if (flag << j) == 0 { - break - } - if (flag & mask) == mask { - pos := (index * 16) + j - if pos >= len(r.flags) { - // github.com/celzero/firestack/issues/5 - // silently ignore scenarios where stamp encodes many - // more blocklsts than what's currently loaded - continue - } - // from the decimal value which is its - // blocklist-id, fetch its metadata - values = append(values, &listinfo{pos, r.flags[pos]}) - } - mask = mask >> 1 - } - } - return values, nil -} - -// convert int flags (blocklist-ids) to a packed uint16 stamp -func (r *rethinkdns) flagtostamp(fl []uint16) ([]uint16, error) { - const u1 = uint16(1) - res := []uint16{0} - - w := trie.W - uw := uint16(w) - for _, val := range fl { - hindex := val / uw - pos := val % uw - - h := &res[0] - n := uint16(0) - - // only header present in res, append 'n' to it - if len(res) == 1 { - *h |= u1 << (15 - hindex) - n |= u1 << (15 - pos) - res = append(res, n) - continue - } - - mm := int(uw - hindex) - ww := trie.MaskLo[w] - if mm < 0 || len(ww) <= 0 || mm >= len(ww) { - continue // should not happen - } - hmask := *h & ww[mm] - databit := *h >> (15 - hindex) - dataindex := countSetBits(hmask) + 1 - datafound := (databit & 0x1) == 1 - // if !datafound { - // log too verbose - // log.Debugf("!!flag not found: len(res) %d / dataindex %d / found? %t\n", len(res), dataindex, datafound) - // } - if datafound { - // upsert, as in 'n' is updated in-place - n = res[dataindex] - n |= u1 << (15 - pos) - res[dataindex] = n - } else { - *h |= u1 << (15 - hindex) - n |= u1 << (15 - pos) - // insert 'n' between res[:dataindex] and r[dataindex:] - nxt := slices.Clone(res[:dataindex]) - nxt = append(nxt, n) - if dataindex < len(res) { - nxt = append(nxt, res[dataindex:]...) - } - res = nxt - } - // log too verbose - // log.Debugf("done: %d/%x | %x | n:%x / hidx: %d / mask: %x / databit: %x / didx: %d\n", val, res, *h, n, hindex, hmask, databit, dataindex) - } - - return res, nil -} - -func encode(ver string, u16 []uint16, enctyp int) (string, error) { - if ver != ver1 { - return "", fmt.Errorf("version %s unsupported / len(input): %d", ver, len(u16)) - } - - buf := uint16tobyte(u16) - if enctyp == EB32 { - out := b32.StdEncoding.WithPadding(b32.NoPadding).EncodeToString(buf) - return ver + hyphensep + strings.ToLower(out), nil - } - // decode may recv padded or unpadded stamps, but always encode with pad - // as FrozenTrie.DNLookup expects only padded b64url for ver1 - return ver + colonsep + b64.URLEncoding.EncodeToString(buf), nil -} - -// normalizeStamp stamp to base64url padded format if its base32 -func (r *rethinkdns) normalizeStamp(s string) (string, error) { - // b64 -> 1:YAYBACABEDAgAA== / b32 -> 1-madacabaaeidaiaa - // b32 -> b64 - flagscsv, err := r.stampToFlags(s) // validate stamp - if err != nil { - return "", err - } - colonidx := strings.Index(s, colonsep) - hyphenidx := strings.Index(s, hyphensep) - isb32 := hyphenidx >= 0 && (hyphenidx < colonidx || colonidx < 0) - if !isb32 { - return s, nil - } - return r.flagsToStamp(flagscsv, EB64) // encode as b64 -} - -func str2uint16(str string) []uint16 { - runedata := []rune(str) - resp := make([]uint16, len(runedata)) - for key, value := range runedata { - resp[key] = uint16(value) - } - return resp -} - -func byte2uint16(b []byte) []uint16 { - data := make([]uint16, len(b)/2) - for i := range data { - // assuming little endian - data[i] = binary.LittleEndian.Uint16(b[i*2 : (i+1)*2]) - } - return data -} - -func uint16tobyte(u16 []uint16) []byte { - bytes := make([]byte, len(u16)*2) - for i, v := range u16 { - binary.LittleEndian.PutUint16(bytes[i*2:(i+1)*2], v) - } - return bytes -} - -// return the count of set bits in n -func countSetBits(n uint16) int { - return (trie.BitsetTable256[n&0xff] + trie.BitsetTable256[(n>>8)&0xff]) -} diff --git a/intra/dnsx/rethinkdns_test.go b/intra/dnsx/rethinkdns_test.go deleted file mode 100644 index 706772d0..00000000 --- a/intra/dnsx/rethinkdns_test.go +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - "fmt" - "testing" -) - -var ( - v0case0 = "6b%2Bg67y%2Bz7%2Fvv7%2Fvv7ztlaDvgIDkhIDnhYTogKA%3D" - v0case1 = "77%2Bg77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2B%2F77%2Bg" - v1case0 = "1:ENz_PwDwfwD___j_YKE=" // same as fcase0 - v1case1 = "1:4J8-v_8D___8_2DVAPAAQURxIIA=" - v1case2 = "1:4P___________________________-D_" - v1case2a = "1-4d7777777777777777777777777777777776b7y" - v1case3 = "1:ENz_PwDw_wP___j_YKk=" // same as fcase3 - v1case4 = "1:MNz_PwDw_wP___j_BABgqQ==" // same as fcase4 -) - -var ( - fcase0 = []uint16{ // same as v1case0 - 15, 16, 17, 18, 186, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 176, 183, 3, 4, 185, 2, 19, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 178, - } - fcase3 = []uint16{ - 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 178, 15, 16, 17, 18, 186, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 176, 183, 3, 4, 185, 2, 19, 54, 55, 56, 180, - } - fcase4 = []uint16{ - 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 178, 15, 16, 17, 18, 186, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 176, 183, 3, 4, 185, 2, 19, 54, 55, 56, 180, 173, - } -) - -func TestGeneric(tester *testing.T) { - r, f := load1() - b := &rethinkdns{ - flags: r, - tags: f, - } - - // decode v0 to blocklist-info - _, err := b.decode(v0case0, ver0, EB64) - ok("v0case0", err) - _, err = b.decode(v0case1, ver0, EB64) - ok("v0case1", err) - fmt.Println("-------------------------------------------------") - - // blockstamp to flags (csv) - f0, err := b.StampToFlags(v1case0) - ko(tester, err) - s0, err := b.FlagsToStamp(f0, EB64) - ko(tester, err) - fmt.Println("case0; ok?", s0 == v1case0, "\t\t", f0, s0) - fmt.Println("-------------------------------------------------") - - f1, err := b.StampToFlags(v1case1) - ko(tester, err) - s1, err := b.FlagsToStamp(f1, EB64) - ko(tester, err) - fmt.Println("case1; ok?", s1 == v1case1, "\t\t", f1, s1) - fmt.Println("-------------------------------------------------") - - f2, err := b.StampToFlags(v1case2) - ko(tester, err) - f2a, err := b.StampToFlags(v1case2a) - ko(tester, err) - s2, err := b.FlagsToStamp(f2, EB64) - ko(tester, err) - s21, err := b.FlagsToStamp(f2, EB32) - ko(tester, err) - s2a, err := b.FlagsToStamp(f2a, EB64) - ko(tester, err) - - fmt.Println("case2/2a; ok?", s2 == v1case2, s21 == v1case2a, s2a == v1case2, "\t\t", f2, s2) - fmt.Println("-------------------------------------------------") - - f3, err := b.StampToFlags(v1case3) - ko(tester, err) - s3, err := b.FlagsToStamp(f3, EB64) - ko(tester, err) - fmt.Println("case3; ok?", s3 == v1case3, "\t\t", f3, s3) - fmt.Println("-------------------------------------------------") - - f4, err := b.StampToFlags(v1case4) - ko(tester, err) - s4, err := b.FlagsToStamp(f4, EB64) - ko(tester, err) - fmt.Println("case4; ok?", s4 == v1case4, "\t\t", f4, s4) - fmt.Println("-------------------------------------------------") - - // flag to blockstamp test - ustamp0, err := b.flagtostamp(fcase0) - ko(tester, err) - stamp0, err := encode(ver1, ustamp0, EB64) - ko(tester, err) - fmt.Println("fcase0; ok?", stamp0 == v1case0, "\t\t", ustamp0, stamp0) - fmt.Println("-------------------------------------------------") - - ustamp3, err := b.flagtostamp(fcase3) - ko(tester, err) - stamp3, err := encode(ver1, ustamp3, EB64) - ko(tester, err) - fmt.Println("fcase3 ok?", stamp3 == v1case3, "\t\t", ustamp3, stamp3) - fmt.Println("-------------------------------------------------") - - ustamp4, err := b.flagtostamp(fcase4) - ko(tester, err) - stamp4, err := encode(ver1, ustamp4, EB64) - ko(tester, err) - fmt.Println("fcase4 ok?", stamp4 == v1case4, "\t\t", ustamp4, stamp4) - fmt.Println("-------------------------------------------------") - - err = b.SetStamp(v1case2a) // v1case2 is its base64 representation - ko(tester, err) - gstamp0, err := b.GetStamp() // always returns as base64 - ko(tester, err) - fmt.Println("gcase0 ok?", gstamp0 == v1case2, "\t\t", gstamp0) -} - -func load1() ([]string, map[string]string) { - obj := map[int]string{ - 0: "MTF", - 1: "KBI", - 2: "YAC", - 3: "HBP", - 4: "NIM", - 5: "YWG", - 6: "SMQ", - 7: "AQX", - 8: "BTG", - 9: "GUN", - 10: "KSH", - 11: "WAS", - 12: "AZY", - 13: "GWB", - 14: "YMG", - 15: "CZM", - 16: "HYS", - 17: "XIF", - 18: "TQN", - 19: "ZVO", - 20: "YOM", - 21: "THR", - 22: "RPW", - 23: "AMG", - 24: "WTJ", - 25: "ZXU", - 26: "FJG", - 27: "NYS", - 28: "OKG", - 29: "KNP", - 30: "FLI", - 31: "RYX", - 32: "CIH", - 33: "PTE", - 34: "KEA", - 35: "CMR", - 36: "DDO", - 37: "VLM", - 38: "JEH", - 39: "XLX", - 40: "OQW", - 41: "FXC", - 42: "HZJ", - 43: "SWK", - 44: "VAM", - 45: "AOS", - 46: "FAL", - 47: "CZK", - 48: "FZB", - 49: "PYW", - 50: "JXA", - 51: "KOR", - 52: "DEP", - 53: "RFX", - 54: "DTT", - 56: "RAF", - 55: "VZP", - 57: "THG", - 58: "YVH", - 59: "XQV", - 60: "PIB", - 61: "EEN", - 62: "GDA", - 63: "MAD", - 64: "NAK", - 65: "BPZ", - 66: "HWO", - 67: "YUC", - 68: "IKY", - 69: "LSS", - 70: "NOE", - 71: "PLR", - 72: "FIT", - 73: "LHX", - 74: "FOF", - 75: "DYA", - 76: "JAN", - 77: "FHQ", - 78: "CMC", - 79: "RKG", - 80: "XMK", - 81: "GAX", - 82: "RFI", - 83: "AZR", - 84: "CEN", - 85: "SPR", - 86: "MZT", - 87: "NHM", - 88: "GLV", - 89: "NUY", - 90: "EDM", - 91: "ZFC", - 92: "DOP", - 93: "XGC", - 94: "OHE", - 95: "MYS", - 96: "IAJ", - 97: "EAQ", - 98: "AOC", - 99: "XAT", - 100: "OSE", - 101: "IBB", - 102: "EGX", - 103: "HZD", - 104: "FLW", - 105: "ULZ", - 106: "OFY", - 107: "MLE", - 108: "YER", - 109: "DMC", - 110: "IJO", - 111: "OWW", - 112: "EMY", - 113: "XKM", - 114: "CQT", - 115: "ANW", - 116: "DGE", - 117: "BBS", - 118: "OKW", - 119: "ONV", - 120: "CDE", - 121: "PAL", - 122: "DBP", - 123: "MHP", - 124: "EPR", - 125: "OUU", - 126: "YXS", - 127: "UQK", - 128: "GVI", - 129: "TXJ", - 130: "DPY", - 131: "DUC", - 132: "WYE", - 133: "CGF", - 134: "JRV", - 135: "EOK", - 136: "HQL", - 137: "NNH", - 138: "KRM", - 139: "QKN", - 140: "MPR", - 141: "EOO", - 142: "MDE", - 143: "WWI", - 144: "TTI", - 145: "GFJ", - 146: "WOD", - 147: "YJR", - 148: "WIB", - 149: "NUI", - 150: "XIO", - 151: "OBW", - 152: "YBO", - 153: "TTW", - 154: "NML", - 155: "MIN", - 156: "IFD", - 157: "AMI", - 158: "TZF", - 159: "VKE", - 160: "PWQ", - 161: "KUA", - 162: "FHW", - 163: "AGZ", - 164: "IVN", - 165: "FIB", - 166: "FGF", - 167: "FLL", - 168: "IVO", - 169: "ALQ", - 170: "FHM", - 171: "AA1", - 172: "AA2", - 173: "AA3", - 174: "AA4", - 175: "AA5", - 176: "AA6", - 177: "AA7", - 178: "AA8", - 179: "AA9", - 180: "AB0", - 181: "AB1", - 182: "AB2", - 183: "AB3", - 184: "AB4", - 185: "AB5", - 186: "AB6", - 187: "AB7", - 188: "AB8", - } - - rflags := make([]string, len(obj)) - fdata := make(map[string]string) - for key := range obj { - val := obj[key] - rflags[key] = val - fdata[val] = val - } - return rflags, fdata -} - -func ko(t *testing.T, err error) { - if err != nil { - t.Error(err) - } -} - -func ok(tag string, err error) { - if err != nil { - fmt.Println(tag, err) - } -} diff --git a/intra/dnsx/transport.go b/intra/dnsx/transport.go deleted file mode 100644 index 26424ac3..00000000 --- a/intra/dnsx/transport.go +++ /dev/null @@ -1,1571 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "io" - "net/netip" - "slices" - "strings" - "sync" - "sync/atomic" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/protect/ipmap" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -const ( - // DNS transport types - DOH = x.DOH - DNSCrypt = x.DNSCrypt - DNS53 = x.DNS53 - DOT = x.DOT - ODOH = x.ODOH - - // DNS decorators - CT = x.CT - - // DNS transport IDs - Goos = x.Goos - System = x.System - Local = x.Local - Default = x.Default - Preferred = x.Preferred - Preset = x.Preset - Fixed = x.Fixed - BlockFree = x.BlockFree - Bootstrap = x.Bootstrap - BlockAll = x.BlockAll - Alg = x.Alg - DcProxy = x.DcProxy - Plus = x.Plus - IpMapper = x.IpMapper - NoDNS = "" - - // DNS request origin indicators - OriginInternal = x.OriginInternal - OriginTunnel = x.OriginTunnel - - invalidQname = "invalid.query" - - // preferred network to use with t.Query - NetTypeUDP = "udp" - NetTypeTCP = "tcp" - // preferred forwarding network, if any - // ipn.Base is treated as a no-proxy - NetBaseProxy = x.Base - NetAutoProxy = x.Auto - NetNoProxy = x.Block - NetExitProxy = x.Exit - - ttl10m = 10 * time.Minute - - listenerTimeout = 3 * time.Second - - // pseudo transport ID to tag dns64 responses - AlgDNS64 = "dns64" -) - -var ( - selfprefix = protect.UidSelf + "." - systemprefix = protect.UidSystem + "." - algprefix = "alg." - cacheprefix = "cached." - plusprefix = "plus." - d64prefix = "64." - defaultprefix = "d." - presetprefix = "pre." - fixedprefix = "fix." - EchPrefix = "ech." - NoPkiPrefix = "nopki." - - NoIPPort []netip.AddrPort = nil -) - -var ( - ErrNotDefaultTransport = errors.New("dns: not a default transport") - ErrNoDcProxy = errors.New("dns: no dnscrypt-proxy") - ErrNoProxyProvider = errors.New("dns: no proxy provider") - ErrNoProxyDNS = errors.New("dns: no proxy") - ErrAddFailed = errors.New("dns: add failed") - errNoSuchTransport = errors.New("dns: missing transport") - errTransportEnd = errors.New("dns: transport ended") - errTransportPaused = errors.New("dns: transport paused") - errOnQueryTimeout = errors.New("dns: timeout fetching prefs") - errOnUpstreamAnswerTimeout = errors.New("dns: timeout fetching prefs for upstream answer") - errBlockFreeTransport = errors.New("dns: block free transport") - errNoRdns = errors.New("dns: no rdns") - errTransportNotMult = errors.New("dns: not a multi-transport") - errTransportNotMDNS = errors.New("dns: not an mdns transport") - errMissingQueryName = errors.New("dns: no query name") - errResolverClosed = errors.New("dns: closed for business") -) - -type MDNSTransport interface { - Transport - RefreshProto(protos string) -} - -// Transport represents a DNS query transport. This interface is exported by gobind, -// so it has to be very simple. -type Transport interface { - x.DNSTransport - // Given a DNS query (including ID), returns a DNS response with matching - // ID, or an error if no response was received. The error may be accompanied - // by a SERVFAIL response if appropriate. - Query(network string, q *dns.Msg, summary *x.DNSSummary) (*dns.Msg, error) - // IPPorts returns all ip:ports of this server. - IPPorts() []netip.AddrPort - // Stop closes the transport. - Stop() error -} - -// TransportMult is a hybrid: transport and a multi-transport. -type TransportMult interface { - x.DNSTransportMult - Transport -} - -type TransportMultInternal interface { - x.DNSTransportMult -} - -type TransportMultProviderInternal interface { - x.DNSTransportMultProvider - // GetMultInternal returns multi-transport, if available - GetMultInternal(id string) (TransportMult, error) -} - -type TransportProviderInternal interface { - x.DNSTransportProvider - // GetInternal returns the internal transport interface for the given ID. - GetInternal(id string) (Transport, error) - - // special purpose pre-defined transports: - - // Gateway implements a DNS ALG transport - Gateway() Gateway - // MDNS returns the mdns transport, if available; error otherwise. - MDNS() (MDNSTransport, error) -} - -type Resolver interface { - TransportProviderInternal - TransportMultProviderInternal - TransportMultInternal - ResolverSelf - RdnsResolver - NatPt - - // IsDnsAddr returns true if the ip:port is resolver's fake endpoint - IsDnsAddr(ipport netip.AddrPort) bool - // Serve reads DNS query from conn and writes DNS answer to conn - Serve(proto string, conn protect.Conn, uid string) (rx, tx int64, errs []error) - - // StopAll stops all transports. - StopAll() - - // S reveals internal state for debugging - S() string -} - -type resolver struct { - sync.RWMutex // protects transports - NatPt - ctx context.Context - done context.CancelFunc - dnsaddrs []netip.AddrPort - transports map[string]Transport - gateway Gateway - localdomains x.RadixTree - listener x.DNSListener - smms chan *x.DNSSummary - - once sync.Once - closed atomic.Bool - - // mutable fields - rmu sync.RWMutex // protects rdnsr and rdnsl - rdnsl *rethinkdnslocal - rdnsr *rethinkdns -} - -var _ Resolver = (*resolver)(nil) -var _ x.DNSResolver = (*resolver)(nil) - -func NewResolver(pctx context.Context, fakeaddrs string, dtr x.DNSTransport, l x.DNSListener, pt NatPt) *resolver { - ctx, cancel := context.WithCancel(pctx) - r := &resolver{ - ctx: ctx, - done: cancel, - NatPt: pt, - listener: l, - smms: make(chan *x.DNSSummary, 64), - transports: make(map[string]Transport), - localdomains: ipmap.UndelegatedDomainsTrie, - } - r.loadaddrs(fakeaddrs) - r.gateway = NewDNSGateway(ctx, r.dnsaddrs, r, pt) - if dtr.ID() != Default { - log.W("dns: not default; ignoring %s @ %s", dtr.ID(), dtr.GetAddr()) - } else if tr, ok := dtr.(Transport); !ok { - log.W("dns: not a transport; ignoring", dtr.ID(), dtr.GetAddr()) - } else { - ctr := NewCachingTransport(tr, ttl10m) - r.Lock() - r.transports[idstr(tr)] = tr // regular - if ctr != nil { - r.transports[idstr(ctr)] = ctr // cached - } else { - log.W("dns: no caching transport for %s", tr.ID()) - } - r.Unlock() - } - log.I("dns: new! gw? %t; default? %s", r.gateway != nil, dtr.GetAddr()) - - core.Go("r.Listener", r.sendSummaries) - context.AfterFunc(ctx, r.StopAll) - return r -} - -// sendSummaries sends summaries to the listener. -// Must be run in a goroutine. -func (r *resolver) sendSummaries() { - for smm := range r.smms { - r.listener.OnResponse(smm) - } -} - -func (r *resolver) queueSummary(smm *x.DNSSummary) { - if smm == nil { - return - } - select { - case <-r.ctx.Done(): - log.W("dns: fwd: smms closed; dropping %s", smm) - default: - select { - case <-r.ctx.Done(): - case r.smms <- smm: - default: - log.W("dns: fwd: smms full; dropping %s", smm) - } - } -} - -func (r *resolver) Gateway() Gateway { - return r.gateway -} - -func (r *resolver) MDNS() (MDNSTransport, error) { - r.RLock() - defer r.RUnlock() - if t, ok := r.transports[Local]; ok { - if mdnst, ok := t.(MDNSTransport); ok { - return mdnst, nil - } - return nil, errTransportNotMDNS - } - return nil, errNoSuchTransport -} - -func (r *resolver) Translate(tr, fix bool) { - r.gateway.translate(tr, fix) -} - -// stopIfExistsLocked stops the transport if it exists, -// then deletes it from the map. -func (r *resolver) stopIfExistsLocked(id string) { - if t, ok := r.transports[id]; ok && t != nil { - core.Go("r.gateway.stopTid", func() { - err := t.Stop() - r.gateway.onStopped(id) - log.VV("dns: stop: %s; err? %v", id, err) - }) - delete(r.transports, id) - } -} - -// Implements Resolver -func (r *resolver) Add(dt x.DNSTransport) (ok bool) { - if r.closed.Load() { - log.W("dns: add: closed for business") - return false - } - if dt == nil || core.IsNil(dt) { - log.D("dns: cannot add nil transports") - return false - } - t, ok := dt.(Transport) - if !ok { // unlikely - return false - } - tid := idstr(t) - if tid == Default || cachedTransport(t) { - log.W("dns: cannot re-add default/cached transports; ignoring: %s", t.GetAddr()) - return false - } - - // add transports that are prefixed with "Plus" to the plus() - // multi-transport, while the supervisor plus() multi-transport - // itself must be added to r.transports (below) - if tid != Plus && isPlus(tid) { - plus, err := r.plus() - if err != nil { - log.W("dns: plus: cannot add %s; %v", tid, err) - return false - } - plus.Add(t) // onDNSAdded listener is not called - return true - } - - caching := false - switch t.Type() { - case DNS53, DNSCrypt, DOH, DOT, ODOH: - r.Lock() - // stop existing transport if different - if oldt := r.transports[tid]; t != oldt { - r.stopIfExistsLocked(tid) - r.transports[tid] = t - } - // always recreate caching transport - if ct := NewCachingTransport(t, ttl10m); ct != nil { - ctid := idstr(ct) - r.stopIfExistsLocked(ctid) - r.transports[ctid] = ct - caching = true - } - r.Unlock() - - if tid == System || tid == Goos { - // always add64 after having added the system transport - core.Gx("r.Add64", func() { r.Add64(tid) }) - } - - core.Go("r.onAdd", func() { r.listener.OnDNSAdded(tid) }) - log.I("dns: add transport %s@%s; caching? %t", - t.ID(), t.GetAddr(), caching) - - return true - default: - log.E("dns: unknown transport(%s) type: %s", t.ID(), t.Type()) - } - return false -} - -func (r *resolver) GetMult(id string) (x.DNSTransportMult, error) { - return r.GetMultInternal(id) -} - -func (r *resolver) GetMultInternal(id string) (TransportMult, error) { - if r.closed.Load() { - return nil, errResolverClosed - } - r.RLock() - t, ok := r.transports[id] - r.RUnlock() - - if ok { - if tm, ok := t.(TransportMult); ok { - return tm, nil - } - return nil, errTransportNotMult - } - return nil, errNoSuchTransport -} - -func (r *resolver) dcProxy() (TransportMult, error) { - return r.GetMultInternal(DcProxy) -} - -func (r *resolver) plus() (TransportMult, error) { - return r.GetMultInternal(Plus) -} - -func (r *resolver) GetInternal(id string) (Transport, error) { - if r.closed.Load() { - return nil, errResolverClosed - } - - if t := r.determineTransport(id); t == nil || core.IsNil(t) { - return nil, errNoSuchTransport - } else { - return t, nil - } -} - -func (r *resolver) S() string { - return r.gateway.S() -} - -func (r *resolver) Get(id string) (x.DNSTransport, error) { - return r.GetInternal(id) -} - -func (r *resolver) Remove(tid string) (ok bool) { - if r.closed.Load() { - log.W("dns: remove: closed for business") - return false - } - - id := tid - // these IDs are reserved for internal use - if isReserved(id) { - log.I("dns: removing reserved transport %s", id) - } - - r.RLock() - _, hasTransport := r.transports[id] - r.RUnlock() - - if hasTransport { - if id == System || id == Goos { - core.Gx("r.Remove64", func() { r.Remove64(id) }) - } - r.Lock() - r.stopIfExistsLocked(id) - r.stopIfExistsLocked(CT + id) - r.Unlock() - - log.I("dns: removed transport %s", id) - } - - if tm, err := r.dcProxy(); err == nil { // remove from dc-proxy, if any - hasTransport = tm.Remove(tid) || hasTransport - hasTransport = tm.Remove(CT+id) || hasTransport - } - - if tm, err := r.plus(); err == nil { // remove from plus, if any - hasTransport = tm.Remove(tid) || hasTransport - hasTransport = tm.Remove(CT+id) || hasTransport - } - - if hasTransport { - core.Go("r.onRemove", func() { r.listener.OnDNSRemoved(id) }) - } - - return hasTransport -} - -func (r *resolver) IsDnsAddr(ipport netip.AddrPort) bool { - return r.isDns(ipport) -} - -// LookupFor2 implements ResolverSelf. -func (r *resolver) LookupFor2(q []byte, uid string, tids ...string) ([]byte, string, error) { - if len(q) <= 0 { - return nil, NoDNS, errNoQuestion - } - if uid == core.UNKNOWN_UID_STR { - uid = protect.UidSelf - } - // if len(tids) == 0, use transport from preferences - return r.forward(q, OriginInternal, uid, tids...) -} - -// LookupFor implements ResolverSelf. -func (r *resolver) LookupFor(q []byte, uid string) ([]byte, string, error) { - if len(q) <= 0 { - return nil, NoDNS, errNoQuestion - } - - // prechose tids preferred & fixed when uid is set to 0, -1, or 1051 (common - // android system components that send DNS requests on behalf of actual apps/uids) - // to use "fixed" transport to later uncover the actual requesting app/uid during - // tcp/udp flows (specifically, with preflow) - if (uid == core.UNKNOWN_UID_STR || uid == core.DNS_UID_STR || uid == core.ANDROID_UID_STR) && r.gateway.fixedTransport() { - return r.forward(q, OriginInternal, uid, Preferred, Fixed) - } - - return r.forward(q, OriginInternal, uid) -} - -// LocalLookup implements ResovlerSelf. -func (r *resolver) LocalLookup(q []byte) ([]byte, string, error) { - if r.closed.Load() { - return nil, NoDNS, errResolverClosed - } - - loopingBack := settings.Loopingback.Load() - defaultIsSystemDNS := r.isDefaultSystemDNS() - - // including dns64 and/or alg - ans, tid, err := r.forward(q, OriginInternal, protect.UidSelf, Default) - if !defaultIsSystemDNS || loopingBack { - return ans, tid, err - } // else: retry with Goos/System, if needed - - // msg may be nil - if msg := xdns.AsMsg(ans); err != nil || xdns.IsNXDomain(msg) || !xdns.HasRcodeSuccess(msg) { - log.I("dns: nxdomain via Default (err? %v); attempting Goos for %s", err, xdns.QName(msg)) - ans, tid, err = r.forward(q, OriginInternal, protect.UidSelf, Goos) // Goos is System; see: determineTransport - } // else: rcode success and nil err; do not fallback on Goos/System - - return ans, tid, err -} - -func (r *resolver) forward(q []byte, who, uid string, chosenids ...string) (res0 []byte, tid0 string, err0 error) { - starttime := time.Now() - ogsmm := &x.DNSSummary{ - ID: NoDNS, - UID: uid, // may be overwritten to by Cacher via fillSummary - QName: invalidQname, - Status: Start, - Msg: errNop.Error(), - } - - msg, err := unpack(q) - if err != nil { - log.W("dns: fwd: for %s; %d not a dns packet %v", uid, len(q), err) - ogsmm.Latency = time.Since(starttime).Seconds() - ogsmm.Status = BadQuery - ogsmm.Msg = err.Error() - r.queueSummary(ogsmm) - return nil, NoDNS, err - } - - // figure out transport to use - qname := qname(msg) - qtyp := qtype(msg) - ogsmm.QName = qname - ogsmm.QType = qtyp - ogsmm.Targets = qname - - if len(qname) <= 0 { // unexpected; github.com/celzero/rethink-app/issues/1210 - ogsmm.Latency = time.Since(starttime).Seconds() - ogsmm.Status = BadQuery - ogsmm.Msg = errMissingQueryName.Error() - r.queueSummary(ogsmm) - return nil, NoDNS, errMissingQueryName - } - - pref, oqcompleted := core.Grx("r.onQuery", func(_ context.Context) (*x.DNSOpts, error) { - return r.listener.OnQuery(who, uid, qname, qtyp), nil - }, listenerTimeout) - if !oqcompleted || pref == nil { - log.W("dns: fwd: for %s; no preferences (%t) for %s:%d", uid, pref == nil, qname, qtyp) - ogsmm.Latency = time.Since(starttime).Seconds() - ogsmm.Status = ClientError - ogsmm.Msg = errOnQueryTimeout.Error() - r.queueSummary(ogsmm) - return nil, NoDNS, errOnQueryTimeout - } - - prefuid := pref.UID - senduid := uid // may be prefuid when oguid is unknown - run := 0 - if ogsmm.UID == core.UNKNOWN_UID_STR { - ogsmm.UID = prefuid - senduid = prefuid - } - - smm := copySummary(ogsmm) - - // TODO? do not use defer func() and do copy: go.dev/play/p/oGUJepa3VUo - defer func() { - r.queueSummary(smm) // always call up to the listener - }() - -runagain: - run++ - *smm = *copySummary(ogsmm) - - log.V("dns: fwd: 1 for %s (%s); query %s:%d, r%d; [prefs:%v; chosen:%v]", uid, who, qname, qtyp, run, pref, chosenids) - - id, sid, pids, presetIPs := r.preferencesFrom(qname, uint16(qtyp), pref, chosenids...) - t := r.determineTransport(id) // id may be empty if pref is nil - - log.V("dns: fwd: 2 for %s; query %s:%d, r%d; [prefs:%v; chosen:%v]; id? %s, sid? %s, pid? %s, ips? %v", - uid, qname, qtyp, run, pref, chosenids, id, sid, pids, presetIPs) - - if t == nil || core.IsNil(t) { - smm.Latency = time.Since(starttime).Seconds() - smm.Status = TransportError - smm.Msg = strings.Join(append(chosenids, id, sid, errNoSuchTransport.Error()), ";") - return nil, NoDNS, errNoSuchTransport - } - var t2 Transport - if len(sid) > 0 { - t2 = r.determineTransport(sid) - } - - smm.Type = t.Type() - smm.ID = idstr(t) - - res1, blocklists, err := r.blockQ(t, t2, msg) // skips if the t, t2 are alg/block-free - if err == nil { - if pref.NOBLOCK { // only add blocklists and do not actually block - smm.Blocklists = blocklists - } else { // block the query - b, e := res1.Pack() - smm.Latency = time.Since(starttime).Seconds() - smm.Status = Complete - smm.Blocklists = blocklists - smm.RData = xdns.GetInterestingRData(res1) - if e != nil { - smm.Msg = e.Error() - } else { - smm.Msg = errNop.Error() - } - log.V("dns: fwd: 3 for %s; r%d, query blocked %s:%d by %s", uid, run, qname, qtyp, blocklists) - return b, smm.ID, e - } - } else { - log.V("dns: fwd: 4 for %s; r%d, query NOT blocked %s:%d; why? %v", uid, run, qname, qtyp, err) - } - - var res2 []byte - var nonalg, ans1 *dns.Msg // alg'd answer - - // t, t2 could be different from user-selected sid & pid - // when sid and pid fallback on Default or System DNS - // in which case, selected proxy must be overriden - netid := xdns.NetAndProxyID(NetTypeUDP, pids) - - // with t2 as the secondary transport, which could be nil - nonalg, ans1, err = r.gateway.q(t, t2, presetIPs, netid, uid, msg, smm) - - if smm.Latency <= 0 { - smm.Latency = time.Since(starttime).Seconds() - } - smm.UID = senduid // reset uid as it may have been cleared by cacher - - if nonalg == nil || err != nil { // TODO: servfail? - if isAlgErr(err) { // alg errs not set when gw.translate is off - log.W("dns: fwd: for %s; r%d, alg error %s for %s:%d", uid, run, err, qname, qtyp) - smm.Status = NoResponse - } else if smm.Status == Start { - smm.Status = InternalError - } - err = core.OneErr(err, errNoAnswer) - smm.Msg = err.Error() - // both err and res2 are set when res2 has rcode error or servfail - // summary latency, ips, response, status already set by transport t - return res2, smm.ID, err - } - - err = nil // discard alg errs if any - if ans1 == nil { // ans1 is nil when alg is disabled - ans1 = nonalg - } - - res2, err = ans1.Pack() - if err != nil { - smm.Status = BadResponse // TODO: servfail? - smm.Msg = err.Error() - return res2, smm.ID, err - } - - smm.Targets = xdns.GetTargets(ans1) - - ans2, blockedtarget, blocklistnames := r.blockA(t, t2, msg, nonalg, smm.Blocklists) - - isnewans := ans2 != nil - hasblocklists := len(blocklistnames) > 0 - hasmsg := len(smm.Msg) > 0 - - if hasblocklists { // blocklists added even if pref.NOBLOCK is set - smm.Blocklists = blocklistnames - smm.BlockedTarget = blockedtarget - } - if !hasmsg { - smm.Msg = errNop.Error() // no error - } - // do not block, only add blocklists if NOBLOCK is set - if !pref.NOBLOCK && isnewans { - // overwrite if new answer - ans1 = ans2 - // summary latency, response, status, ips also set by transports - smm.RTtl = xdns.RTtl(ans2) - smm.RCode = xdns.Rcode(ans2) - smm.RData = xdns.GetInterestingRData(ans2) - smm.Status = Complete - res2, err = ans2.Pack() - if err != nil { - smm.RTtl = 0 - smm.RCode = dns.RcodeFormatError - smm.Status = BadResponse // TODO: servfail? - smm.Msg = err.Error() - } - - log.V("dns: fwd: 5 for %s[%s]; query %s:%d, r%d, smm[data: %s, status: %d] blocked", - smm.ID, uid, qname, qtyp, run, smm.RData, smm.Status) - return res2, smm.ID, err - } - - realips := Netip2Csv(xdns.IPs(nonalg)) - ansblocked := xdns.AQuadAUnspecified(ans1) - - if settings.Debug { - log.V("dns: fwd: 6 for %s[%s]; query %s:%d, r%d, ips: %s; smm[data: %s, status: %d]; new-ans? %t, blocklists? %t, blocked? %t", - smm.ID, uid, qname, qtyp, run, realips, smm.RData, smm.Status, isnewans, hasblocklists, ansblocked) - } - - if run == 1 { - pref2, ouacompleted := core.Grx("r.onUA."+qname, func(_ context.Context) (*x.DNSOpts, error) { - return r.listener.OnUpstreamAnswer(smm, realips), nil - }, listenerTimeout) - if !ouacompleted { - log.W("dns: fwd: for %s[%s]; preferences2 missing for %s:%d; ips? %s", smm.ID, uid, qname, qtyp, realips) - smm.Status = ClientError - smm.Msg = errOnUpstreamAnswerTimeout.Error() - smm.ID = NoDNS - return nil, NoDNS, errOnUpstreamAnswerTimeout - } - - if pref2 != nil && len(pref2.TIDCSV) > 0 && pref2.TIDCSV != pref.TIDCSV { - pref = pref2 - goto runagain // re-run with new pids - } - - log.V("dns: fwd: 7 for %s[%s], r%d; preferences2 skipped for %s:%d [ips? %s]: %v", - smm.ID, uid, run, qname, qtyp, realips, pref2) - } - - // return transport ID match w/ ID used by alg.go:registerLocked (alg/nat/ptr caches) - // as ipmapper uses this ID (tid0) subsequently to undoAlg - return res2, smm.ID, nil -} - -// Serve implements Resolver. -func (r *resolver) Serve(proto string, c protect.Conn, uid string) (rx, tx int64, errs []error) { - if r.closed.Load() { - err := log.EE("dns: serve: closed for business") - errs = append(errs, err) - return - } - - // if Serve (which is called by common.go:dnsOverride) calls in with a uid - // that is not UNKNOWN_UID_STR, then we know that the query is from an app - // and we can presume per app split tunnel is working as expected. - if len(uid) > 0 && uid != core.ANDROID_UID_STR && uid != core.UNKNOWN_UID_STR && uid != core.DNS_UID_STR { - r.gateway.splitTunnel() - } - - switch proto { - case NetTypeTCP: - rx, tx, errs = r.accept(c, uid) - case NetTypeUDP: - rx, tx, errs = r.reply(c, uid) - default: - err := log.EE("dns: unknown proto: %s", proto) - errs = append(errs, err) - } - return -} - -func (r *resolver) determineTransport(id string) Transport { - if len(id) <= 0 { - return nil - } - if id == Default || id == CT+Default { - r.RLock() - d := r.transports[Default] - r.RUnlock() - return d - } - - var id0, id1 string - if id == Local || id == CT+Local { // mdns never cached - id0 = Local - } else if id == Alg { - // if no firewall is setup, alg isn't possible - if settings.BlockMode.Load() == settings.BlockModeNone { - id0 = CT + Preferred - } else { - id0 = CT + BlockFree - id1 = CT + Preferred - } - } else if id == System || id == CT+System || id == Goos || id == CT+Goos { - // fallback on Goos if System is unavailable - // but unlike "System", "Goos" does not support - // other than A / AAAA queries - // cf: undelegated.go:requiresGoosOrLocal() - if id == CT+System || id == CT+Goos { - id0 = CT + System - id1 = CT + Goos - } else { - id0 = System - id1 = Goos - } - } else if isPlus(id) { - id0 = Plus // replace a plus transport with its mult equivalent - } else { - id0 = id - } - - var t0, t1, tf Transport - r.RLock() - t0 = r.transports[id0] - if len(id1) > 0 { - t1 = r.transports[id1] - } - tf = r.transports[Default] - r.RUnlock() - - mayusedefault := canUseDefaultDNS(id0) - if t0 != nil && (t1 == nil || !mayusedefault || activeTransport(t0)) { - return t0 - } else if t1 != nil && (!mayusedefault || activeTransport(t1)) { - return t1 - } else if tf != nil && mayusedefault { - log.W("dns: fwd: %s is missing; using default", id0) - return tf // todo: assert tf != nil? - } - - return nil -} - -// dnstcp queries the transport and writes answers to w, prefixed by length. -func (r *resolver) dnstcp(q []byte, w io.WriteCloser, uid string) (written int, err error) { - ans, _, err := r.forward(q, OriginTunnel, uid) - - rlen := len(ans) - if rlen <= 0 && err != nil { - clos(w) // close on client err - return - } - - if written, err = writePrefixed(w, ans, rlen); err != nil { - clos(w) // close on write back err - } else if written != rlen { // do not close on incomplete writes - err = fmt.Errorf("dns: tcp: for %s incomplete write: n(%d) != r(%d)", uid, written, rlen) - } - return -} - -// dnsudp queries the transport and writes answers to w. -func (r *resolver) dnsudp(q []byte, w io.WriteCloser, uid string) (written int, err error) { - ans, _, err := r.forward(q, OriginTunnel, uid) - - rlen := len(ans) - if rlen <= 0 && err != nil { - clos(w) // close on client err - return - } - - if written, err = w.Write(ans); err != nil { - clos(w) // close on write back err - } else if written != rlen { - // do not close on incomplete writes - err = fmt.Errorf("dns: udp: for %s incomplete write: n(%d) != r(%d)", uid, written, rlen) - } - - return -} - -// reply DNS-over-UDP from a stub resolver. -func (r *resolver) reply(c protect.Conn, uid string) (rx, tx int64, errs []error) { - defer clos(c) - - var rxv, txv atomic.Int64 - - var wg sync.WaitGroup - start := time.Now() - cnt := 0 - for { - qptr := core.Alloc() - q := *qptr - q = q[:cap(q)] - free := func() { - *qptr = q - core.Recycle(qptr) - } - - tm := time.Now().Add(ttl2m) - _ = c.SetDeadline(tm) - - if n, err := c.Read(q); err != nil { - log.VV("dns: udp: for %s done; tot: %d, t: %s, err: %v", - uid, cnt, core.FmtTimeAsPeriod(start), err) - free() - break - } else { - core.Gx("r.reply.do", func() { - wg.Add(1) - defer wg.Done() - defer free() - m, err := r.dnsudp(q[:n], c, uid) - logeif(err != nil)("dns: udp: for %s err! tot: %d, t: %s, %v", - uid, cnt, core.FmtTimeAsPeriod(start), err) - rxv.Add(int64(m)) - txv.Add(int64(n)) - errs = append(errs, err) - }) - } - cnt++ - } - wg.Wait() - rx = rxv.Load() - tx = txv.Load() - log.VV("dns: udp: for %s done; tot: %d (rx: %d, tx: %d), t: %s", uid, cnt, rx, tx, core.FmtTimeAsPeriod(start)) - return -} - -// Accept a DNS-over-TCP socket from a stub resolver, and connect the socket -// to this DNSTransport. -func (r *resolver) accept(c io.ReadWriteCloser, uid string) (rx, tx int64, errs []error) { - defer clos(c) - - var rxv, txv atomic.Int64 - - var wg sync.WaitGroup - start := time.Now() - cnt := 0 - qlbuf := make([]byte, 2) - for { - n, err := c.Read(qlbuf) - if n == 0 { - log.D("dns: tcp: for %s query socket shutdown", uid) - break - } - if err != nil { - log.W("dns: tcp: for %s err reading from socket: %v", uid, err) - break // close on read errs - } - // TODO: inform the listener? - if n < 2 { - log.W("dns: tcp: for %s incomplete query length", uid) - break // close on incorrect lengths - } - qlen := binary.BigEndian.Uint16(qlbuf) - - qptr := core.AllocRegion(int(qlen)) - q := *qptr - q = q[:cap(q)] - free := func() { - *qptr = q - core.Recycle(qptr) - } - - n, err = c.Read(q) - if err != nil { - log.D("dns: tcp: for %s done; err: %v", uid, err) - free() - break // close on read errs - } - if n != int(qlen) { - free() - log.W("dns: tcp: for %s incomplete query: %d < %d; tot: %d, t: %s", - uid, n, qlen, cnt, core.FmtTimeAsPeriod(start)) - break // close on incomplete reads - } - core.Gx("r.accept.do", func() { - wg.Add(1) - defer wg.Done() - defer free() - m, err := r.dnstcp(q[:n], c, uid) - logeif(err != nil)("dns: tcp: for %s err! tot: %d, t: %s, %v", - uid, cnt, core.FmtTimeAsPeriod(start), err) - errs = append(errs, err) - txv.Add(int64(n)) - rxv.Add(int64(m)) - }) - cnt++ - } - wg.Wait() - rx = rxv.Load() - tx = txv.Load() - log.VV("dns: tcp: for %s done; tot: %d (rx: %d, tx: %d), t: %s", uid, cnt, rx, tx, core.FmtTimeAsPeriod(start)) - // TODO: Cancel outstanding queries. - return -} - -// StopAll implements TransportMult. -// StopAll stops all transports and closes the resolver. -func (r *resolver) StopAll() { - r.once.Do(func() { - defer core.Go("r.onStop", func() { r.listener.OnDNSStopped() }) - r.done() - - if dc, err := r.dcProxy(); err == nil { - _ = dc.Stop() - } - - if p, err := r.plus(); err == nil { - _ = p.Stop() - } - - // Stop all transports in a separate goroutine to avoid blocking - core.Go("r.stopAllTransports", func() { - r.Lock() - for _, tr := range r.transports { - _ = tr.Stop() - // r.gateway.onStopped(id) is not required - // as the entire setup is closed and going away - } - clear(r.transports) - r.Unlock() - }) - - close(r.smms) // close listener chan - }) -} - -func (r *resolver) all() []Transport { - r.RLock() - defer r.RUnlock() - out := make([]Transport, 0, len(r.transports)) - for _, t := range r.transports { - out = append(out, t) - } - return out -} - -func (r *resolver) refresh() { - for _, t := range r.all() { - // clear caches of cached transports: - if ct := asCachedTransport(t); ct != nil { - ct.Clear() // one at a time ... - } - } -} - -func (r *resolver) Refresh() (string, error) { - return r.refreshAll() -} - -func (r *resolver) refreshAll() (string, error) { - if r.closed.Load() { - return "", errResolverClosed - } - - log.I("dns: refresh transports") - - core.Gx("r.refresh", r.refresh) - core.Gx("r.refresh.clearcache", dialers.Clear) - s := tr2csv(r.all()) - if dc, err := r.dcProxy(); err == nil { - if x, err := dc.Refresh(); err == nil { - s += "," + x - } - } - if p, err := r.plus(); err == nil { - if x, err := p.Refresh(); err == nil { - s += "," + x - } - } - return trimcsv(s), nil -} - -func (r *resolver) LiveTransports() string { - if r.closed.Load() { - log.W("dns: liveTransports: closed for business") - return "" - } - s := tr2csv(r.all()) - if dc, err := r.dcProxy(); err == nil { - x := dc.LiveTransports() - s += "," + x - } - if p, err := r.plus(); err == nil { - x := p.LiveTransports() - s += "," + x - } - return trimcsv(s) -} - -func (r *resolver) preferencesFrom(qname string, qtyp uint16, s *x.DNSOpts, chosenids ...string) (id1, id2, pidcsv string, ips []netip.Addr) { - var x []string // primary - var xx []string // secondary - if s == nil { // should never happen; but it has during testing (on End()) - log.W("dns: pref: no ns opts for %s", qname) - return // no-op - } else { - x = strings.Split(s.TIDCSV, ",") - xx = strings.Split(s.TIDSECCSV, ",") - if y := strings.Split(s.IPCSV, ","); len(y) > 0 { - ips = make([]netip.Addr, 0, len(y)) - for _, a := range y { - a = strings.TrimSpace(a) - if len(a) <= 0 { - continue - } - ip, err := netip.ParseAddr(a) - if err != nil || !ip.IsValid() { - log.W("dns: pref: skip bad ip %s for %s", a, qname) - continue - } - ips = append(ips, ip) // unmap? - } - } - if len(ips) > 0 { - ip4s, ip6s := splitIPFamilies(ips) - if xdns.IsAQType(qtyp) { - ips = ip4s - } else if xdns.IsAAAAQType(qtyp) { - ips = ip6s - } else if xdns.IsHTTPSQType(qtyp) || xdns.IsSVCBQType(qtyp) { - // ips are substituted in after answers are received - // so qtype checks are not sufficient - // see: synthesizeOrQuery - } else { - ips = nil // mismatch in query type and ip family - } - } - } - - if len(x) <= 0 { // x may be nil - log.W("dns: pref: no tids for %s", qname) - // no-op - } else { - id1 = r.chooseOne(x...) - id2 = r.chooseOne(xx...) // mostly, just 0 or 1 secondary - } - - if !firstEmpty(chosenids) && len(chosenids) > 0 { - // chosen ID overrides all except: - if (isPlus(id1) || isPlus(id2)) && isAnyDefault(chosenids...) { - // Plus overrides Default - id1 = Plus - id2 = "" - log.D("dns: pref: use Plus instead of Default for %s", qname) - } else { - id1 = chosenids[0] // never empty - id2 = "" // wipe out id2 if not set; use just id1 - if len(chosenids) > 1 { - id2 = chosenids[1] // may be empty, but that's ok - } - log.D("dns: pref: use chosen tr(%s, %s) for %s", id1, id2, qname) - } - } else if isAnyIPUnspecified(ips) || isAnyBlockAll(x...) { - // BlockAll must appear in primary TIDCSV - id1 = BlockAll // just one transport, BlockAll, if set - id2 = "" - } else if reqid := r.requiresGoosOrLocal(qname); len(reqid) > 0 { - // use approp transport given a qname - log.D("dns: pref: use suggested tr(%s) for %s", reqid, qname) - id1 = reqid - id2 = "" - } else if isAnyFixed(x...) || isAnyFixed(xx...) { - if id1 != Fixed && id1 != cacheprefix+Fixed { // Fixed must always be the primary transport - id2 = id1 - id1 = Fixed - } - if len(id2) <= 0 { - id2 = Preferred - } - log.VV("dns: pref: use fixed tr(%s, %s) for %s", id1, id2, qname) - // s.NOBLOCK must be respected - // s.PIDCSV must be respected - } - if len(ips) > 0 { - log.D("dns: pref: preset ips (no block) %v for %s", ips, qname) - s.NOBLOCK = true // skip blocks if ips are set (even if unspecified ips) - id2 = "" // no secondary transport - } - if isAnyLocal(id1, id2) { // use one transport, Local, if set - id1 = Local - id2 = "" - } - - if len(s.PIDCSV) > 0 { - pidcsv = overrideProxyIfNeeded(s.PIDCSV, id1, id2) - } else { - pidcsv = NetNoProxy - } - return -} - -func (r *resolver) requiresGoosOrLocal(qname string) (id string) { - if strings.HasSuffix(qname, ".local") || xdns.IsMDNSQuery(qname) { - id = Local - } else if !settings.SystemDNSForUndelegatedDomains.Load() { - // todo: remove this once we let users "pin" domains to resolvers - // github.com/celzero/rethink-app/issues/1153 - // skip override when preventing DNS capture on port53 is turned off - } else if len(qname) > 0 && r.localdomains.HasAny(qname) { - id = Goos // system is primary; see: transport.go:determineTransports() - } - return -} - -func (r *resolver) chooseOne(ids ...string) (theone string) { - if len(ids) <= 0 { - return "" - } - if isAnyPlus(ids...) { // prefer Plus, if set - return Plus - } - if len(ids) == 1 { - return ids[0] - } - - trs := make([]Transport, 0, len(ids)) - r.RLock() - for _, id := range ids { - id = strings.TrimSpace(id) - if t := r.transports[id]; t != nil { - trs = append(trs, t) - } - } - r.RUnlock() - - best, preferred, recoverables, errored, ended := Categorize(trs) - if settings.Debug { - defer func() { - loged(len(theone) <= 0)("dns: pref: chose: %s from best(%v) prefer(%v) recov(%v) err(%v) dead(%v)", - theone, best, preferred, recoverables, errored, ended) - }() - } - - if len(best) > 0 { - return idstr(best[0]) - } else if len(preferred) > 0 { - return idstr(preferred[0]) - } else if len(recoverables) > 0 { - return idstr(core.ChooseOne(recoverables)) - } else if len(errored) > 0 { - return idstr(core.ChooseOne(errored)) - } - log.E("dns: pref: no transports for %v [all ended? %v]", ids, ended) - return "" -} - -func Categorize(ts []Transport) (best []Transport, preferred []Transport, recoverables []Transport, errored []Transport, ended []Transport) { - for _, t := range ts { - switch t.Status() { - case Complete: - best = append(best, t) - case Start, NoResponse, BadQuery: - preferred = append(preferred, t) - case BadResponse: - preferred = append(preferred, t) - case InternalError, TransportError: - recoverables = append(recoverables, t) - case DEnd, Paused, Unknown: // discard non-active transports - ended = append(ended, t) - default: // ClientError, SendFailed - errored = append(errored, t) - } - } - return -} - -func (r *resolver) loadaddrs(csvaddr string) { - r.addDnsAddrs(csvaddr) -} - -func writePrefixed(w io.Writer, b []byte, l int) (int, error) { - const pre = 2 - sz := l + pre - bptr := core.AllocRegion(sz) - buf := *bptr - buf = buf[:cap(buf)] - - defer func() { - *bptr = buf - core.Recycle(bptr) - }() - - binary.BigEndian.PutUint16(buf, uint16(l)) - // Use a combined write (pre+b) to ensure atomicity. - // Otherwise, writes from two responses could be interleaved. - copy(buf[pre:], b) - n, err := w.Write(buf[:sz]) - return max(0, n-pre), err -} - -// meta proxies like pid == Auto are not considered local. -func IsLocalProxy(pid string) bool { - return len(pid) <= 0 || - pid == NetBaseProxy || - pid == NetExitProxy || - pid == NetNoProxy -} - -// return true ipn.Auto -func IsAutoProxy(pid string) bool { - return pid == NetAutoProxy -} - -// RegisterAddrs registers IP ports with all dialers for a given hostname. -// If id is dnsx.Bootstrap, the hostname is "protected" from re-resolutions. -// hostname is a domain name, and as a special case, can be protect.UidSelf or protect.UidSystem. -func RegisterAddrs(id, hostname string, ipps []string) (ok bool) { - var ipset *ipmap.IPSet - var addrs []netip.Addr - id, _ = strings.CutPrefix(id, CT) - if isProtected(id) { - log.I("dns: protected %s! %s => %v", id, hostname, ipps) - ipset, ok = dialers.NewProtected(hostname, ipps) - } else { - log.I("dns: regular %s! %s => %v", id, hostname, ipps) - ipset, ok = dialers.New(hostname, ipps) - } - if ipset != nil { - addrs = ipset.Addrs() - } - log.I("dns: reg regular/protected done %s! %s[+%v] => %v", id, hostname, ipps, addrs) - return -} - -func Fastest(a, b Transport) int { - if a == nil || b == nil { - return 0 // unlikely - } - return int(a.P50() - b.P50()) -} - -func IsEncrypted(t Transport) bool { - return t != nil && isEncrypted(t.Type()) -} - -func isEncrypted(t string) bool { - return t == DOT || t == DOH || t == DNSCrypt || t == ODOH -} - -func isProtected(id string) bool { - return id == Bootstrap || id == System || id == Default || id == Local || isPlus(id) -} - -func isReserved(id string) bool { - switch id { - case Default, Goos, System, Local, Alg, Plus, DcProxy, BlockAll, Preferred, Bootstrap, BlockFree, Fixed, Preset: - return true - case CT + Default, CT + Goos, CT + System, CT + Local, CT + Alg, CT + Plus, CT + DcProxy, CT + BlockAll, - CT + Bootstrap, CT + Preferred, CT + BlockFree, CT + Fixed, CT + Preset: - return true - } - return false -} - -func canUseDefaultDNS(id string) bool { - switch id { - // system can never be subst by default; only by Goos - case System, CT + System: - return false - // no other transport can do what mdns does - case Local, CT + Local: - return false - case Alg, Preferred, Plus, BlockFree: - return settings.DefaultDNSAsFallback.Load() - case CT + Alg, CT + Preferred, CT + Plus, CT + BlockFree: - return settings.DefaultDNSAsFallback.Load() - } - return false -} - -func isTransportID(match string, ids ...string) bool { - return slices.Contains(ids, match) -} - -func isAnyBlockAll(ids ...string) bool { - return isTransportID(BlockAll, ids...) -} - -func isAnyFixed(ids ...string) bool { - return isTransportID(Fixed, ids...) -} - -func isAnyIPUnspecified(ips []netip.Addr) bool { - for _, ip := range ips { - if ip.IsUnspecified() { - return true - } - } - return false -} - -func isAnyLocal(ids ...string) bool { - return isTransportID(Local, ids...) -} - -func isAnyPlus(ids ...string) bool { - return slices.ContainsFunc(ids, isPlus) -} - -func isAnyDefault(ids ...string) bool { - return isTransportID(Default, ids...) -} - -func CanUseProxy(id string) bool { - switch id { - case Goos, CT + Goos, Local, CT + Local, System, CT + System: - return false - case Preset, CT + Preset: - return false - case Default, CT + Default, Bootstrap, CT + Bootstrap: - return canProxyDefault() - } - return true -} - -func overrideProxyIfNeeded(pid string, ids ...string) string { - for _, id := range ids { - switch id { - // notes: - // 1. Goos is anyway hard-coded to use NetExitProxy. - // 2. Plus simply delegates queries to underlying servers, - // which may or may not use proxies. - case Goos, Local: // exit - return NetExitProxy - case CT + Goos, CT + Local: // exit - return NetExitProxy - case System, Preset: // base - return NetBaseProxy - case CT + System, CT + Preset: // base - return NetBaseProxy - case Default, CT + Default, Bootstrap, CT + Bootstrap: // may be proxy - return proxyForDefault(pid) - } - } - return pid // as-is -} - -func skipBlock(tr ...Transport) bool { - for _, t := range tr { - if t == nil { - continue - } - switch idstr(t) { // Plus/CT+Plus to skip blocks conditionally? - case BlockFree, Alg: - return true - case CT + BlockFree, CT + Alg: - return true - case Default, CT + Default, Bootstrap, CT + Bootstrap: - return canBlockDefault() - } - } - return false -} - -func (r *resolver) isDefaultSystemDNS() (y bool) { - if dtr, _ := r.GetInternal(Default); dtr != nil { - // todo: a better way to determine whether Default is SystemDNS - // Default is usually SystemDNS if it is of type DNS53 - y = dtr.Type() == DNS53 - } - return -} - -func canBlockDefault() bool { - // TODO: check for gateway.split? - return settings.Loopingback.Load() -} - -func canProxyDefault() bool { - // TODO: check for gateway.split? - // TODO: do not allow proxying when Default is mapped to Goos/System - return settings.Loopingback.Load() -} - -func proxyForDefault(pid string) string { - if canProxyDefault() { - return pid - } - return NetBaseProxy -} - -func skipInternalCache(tids ...string) bool { - return isAnyBlockAll(tids...) -} - -func unpack(q []byte) (*dns.Msg, error) { - msg := &dns.Msg{} - err := msg.Unpack(q) - return msg, err -} - -func qname(msg *dns.Msg) string { - n := xdns.QName(msg) - n, _ = xdns.NormalizeQName(n) - return n -} - -func qtype(msg *dns.Msg) int { - return int(xdns.QType(msg)) -} - -func tr2csv(ts []Transport) string { - s := "" - for _, t := range ts { - if activeTransport(t) { - s += idstr(t) + "," - } - } - return trimcsv(s) -} - -func trimcsv(s string) string { - return strings.Trim(s, ",") -} - -func PrefixFor(id string) string { - switch id { - case CT: - return cacheprefix - case System, CT + System: - return systemprefix - case Bootstrap, CT + Bootstrap: - return selfprefix - case Alg, CT + Alg: - return algprefix - case AlgDNS64, CT + AlgDNS64: - return d64prefix - case Default, CT + Default: - return defaultprefix - case Plus, CT + Plus: - return plusprefix - case Preset: - return presetprefix - case Fixed: - return fixedprefix - } - return "" -} - -func asCachedTransport(t Transport) Cacher { - if ct, ok := t.(Cacher); ok { - return ct - } - return nil -} - -func cachedTransport(t Transport) bool { - return strings.HasSuffix(idstr(t), CT) || - strings.HasPrefix(t.GetAddr(), cacheprefix) -} - -func WillErr(t Transport) *QueryError { - switch t.Status() { - case DEnd: - return NewEndQueryError() - case Paused: - return NewPausedQueryError() - } - return nil -} - -func isPlus(id string) bool { - return strings.HasPrefix(id, Plus) || strings.HasPrefix(id, CT+Plus) -} - -func activeTransport(t Transport) bool { - st := t.Status() - return st != DEnd && st != Paused && st != Unknown -} - -func clos(c io.Closer) { - core.CloseOp(c, core.CopRW) -} - -func firstEmpty(arr []string) bool { - return len(arr) <= 0 || len(arr[0]) <= 0 -} diff --git a/intra/dnsx/wall.go b/intra/dnsx/wall.go deleted file mode 100644 index a314c186..00000000 --- a/intra/dnsx/wall.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -func (r *resolver) setRdnsLocal(rlocal *rethinkdnslocal) { - r.rmu.Lock() - defer r.rmu.Unlock() - // rlocal can be nil - r.rdnsl = rlocal -} - -func (r *resolver) setRdnsRemote(rremote *rethinkdns) { - r.rmu.Lock() - defer r.rmu.Unlock() - // rremote can be nil - r.rdnsr = rremote -} - -func (r *resolver) getRdnsLocal() *rethinkdnslocal { - r.rmu.RLock() - defer r.rmu.RUnlock() - return r.rdnsl -} - -func (r *resolver) getRdnsRemote() *rethinkdns { - r.rmu.RLock() - defer r.rmu.RUnlock() - return r.rdnsr -} - -// Implements RdnsResolver -func (r *resolver) SetRdnsLocal(t, rd, conf, filetag string) error { - if len(t) <= 0 || len(rd) <= 0 { - log.I("transport: unset rdns local") - r.setRdnsLocal(nil) - return nil - } - if r.closed.Load() { - return errResolverClosed - } - - rlocal, err := newRDNSLocal(t, rd, conf, filetag) - r.setRdnsLocal(rlocal) - return err -} - -// Implements RdnsResolver -func (r *resolver) SetRdnsRemote(filetag string) error { - if len(filetag) <= 0 { - log.I("transport: unset rdns remote") - r.setRdnsRemote(nil) - return nil - } - if r.closed.Load() { - return errResolverClosed - } - - rremote, err := newRDNSRemote(filetag) - r.setRdnsRemote(rremote) - return err -} - -// Implements RdnsResolver -func (r *resolver) GetRdnsLocal() (x.RDNS, error) { - if r.closed.Load() { - return nil, errResolverClosed - } - - rlocal := r.getRdnsLocal() - - if rlocal != nil { - // a non-ftrie version for across the jni boundary - return rlocal.rethinkdns, nil - } - return nil, errNoRdns -} - -// Implements RdnsResolver -func (r *resolver) GetRdnsRemote() (x.RDNS, error) { - if r.closed.Load() { - return nil, errResolverClosed - } - - rremote := r.getRdnsRemote() - if rremote != nil { - // a non-ftrie version for across the jni boundary - return rremote, nil - } - return nil, errNoRdns -} - -// blockQ returns a refused ans if q is blocked by local blocklists; nil, otherwise. -// If t, t2 are non-nil, it skips local blocks for alg and blockfree transports. -func (r *resolver) blockQ(t, t2 Transport, msg *dns.Msg) (ans *dns.Msg, blocklists string, err error) { - if skipBlock(t, t2) { - return nil, "", errBlockFreeTransport - } - - qname := xdns.QName(msg) - b := r.getRdnsLocal() - - if b == nil || !b.OnDeviceBlock() { - if settings.Debug { - log.V("wall: no local blockerQ; letting through %s", qname) - } - return nil, "", errNoRdns - } - // OnDeviceBlock() is true; enforce blocklists - ans, blocklists, err = applyBlocklists(b, msg) - if err != nil { - // block skipped because err is set - log.D("wall: skip local for %s blockQ for %s with err %s", qname, blocklists, err) - } - return -} - -func applyBlocklists(b RDNS, q *dns.Msg) (ans *dns.Msg, blocklists string, err error) { - blocklists, err = b.blockQuery(q) - if err != nil { - return - } - if len(blocklists) <= 0 { - err = errNoBlocklistMatch - return - } - - ans, err = xdns.RefusedResponseFromMessage(q) - return -} - -// blockA blocks the answer if it is blocked by local blocklists. -// If blocklistStamp is not empty, it resolves them to blocklist names, if valid; -// and treats as if q was blocked by remote blocklists, effectively skipping local blocks. -// t, t2 can be nil. if non-nil, they are used to skip local blocks for alg and blockfree. -// If blocklistStamp is empty, it resolves the answer to blocklist names, if blocked by local blocklists. -// If blocklistStamp is empty and the answer is not blocked by local blocklists, it returns nil. -// If blocklistStamp is empty and the answer is blocked by local blocklists, it returns a refused response. -func (r *resolver) blockA(t, t2 Transport, q, ans *dns.Msg, blocklistStamp string) (finalans *dns.Msg, blockedtarget, blocklistNames string) { - br := r.getRdnsRemote() - b := r.getRdnsLocal() - - var err error - qname := xdns.QName(q) - - if len(blocklistStamp) > 0 && br != nil { // remote block resolution, if any - blocklistNames, err = br.stampToNames(blocklistStamp) - if err == nil { - log.D("wall: for %s blocklists %s", qname, blocklistNames) - return - } else { - log.D("wall: could not resolve blocklist-stamp(%s) for %s, err: %v", blocklistStamp, qname, err) - } // continue to local block resolution - } else { - log.D("wall: no blockA for %s; blocklist-stamp? (%d) / rdnsr? (%t)", qname, len(blocklistStamp), br != nil) - } - - if skipBlock(t, t2) { - return // skip local blocks for alg and blockfree - } - - // local block resolution, if any - if b == nil { - if settings.Debug { - log.V("wall: no local blockerA; letting through %s", qname) - } - return - } - - if !b.OnDeviceBlock() { - if settings.Debug { - log.D("wall: no local blockA for %s", qname) - } - return - } - - if blockedtarget, blocklistNames, err = b.blockAnswer(ans); err != nil { - if settings.Debug { - log.D("wall: answer for %s not blocked %v", qname, err) - } - return - } - - if len(blocklistNames) <= 0 { - if settings.Debug { - log.D("wall: answer %s not blocked blocklist empty", qname) - } - return - } - - finalans, err = xdns.RefusedResponseFromMessage(q) - if err != nil { - log.W("wall: could not pack %s blocked dns answer %v", qname, err) - return - } - - return -} diff --git a/intra/dnsx/x64.go b/intra/dnsx/x64.go deleted file mode 100644 index ae530e50..00000000 --- a/intra/dnsx/x64.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package dnsx - -import ( - "net/netip" - - "github.com/miekg/dns" -) - -// ref: datatracker.ietf.org/doc/html/rfc8880 -const Rfc7050WKN = "ipv4only.arpa." -const AnyResolver = "__anyresolver" -const UnderlayResolver = "__underlay" // used by transport dnsx.System -const OverlayResolver = "__overlay" // "net.DefaultResolver" dnsx.Goos -const Local464Resolver = "__local464" // preset "forced" DNS64/NAT64 - -type NatPt interface { - DNS64 - NAT64 -} - -type DNS64 interface { - // Add64 registers DNS64 resolver f to id. - Add64(id string) bool - // Remove64 deregisters any current resolver from id. - Remove64(id string) bool - // ResetNat64Prefix sets the NAT64 prefix for transport id to ip6prefix. - ResetNat64Prefix(ip6prefix string) bool - // D64 synthesizes ans64 (AAAA) from ans6 if required, using resolver f. - // Returned ans64 is nil if no DNS64 synthesis is needed (not AAAA). - // Returned ans64 is ans6 if it already has AAAA records. - D64(network, id, uid string, ans6 *dns.Msg) *dns.Msg -} - -type NAT64 interface { - // Returns true if ip is a NAT64 address from transport id. - IsNat64(id string, ip netip.Addr) bool - // Translates ip to IPv4 using the NAT64 prefix for transport id. - // As a special case, ip is zero addr, output is always IPv4 zero addr. - X64(id string, ip netip.Addr) netip.Addr -} diff --git a/intra/doh/client_auth.go b/intra/doh/client_auth.go deleted file mode 100644 index d6c4c546..00000000 --- a/intra/doh/client_auth.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2020 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build ignore - -package doh - -import ( - "crypto" - "crypto/ecdsa" - "crypto/tls" - "crypto/x509" - "errors" - "io" - - "github.com/celzero/firestack/intra/log" -) - -// ClientAuth interface for providing TLS certificates and signatures. -type ClientAuth interface { - // GetClientCertificate returns the client certificate (if any). - // May block as the first call may cause certificates to load. - // Returns a DER encoded X.509 client certificate. - GetClientCertificate() []byte - // GetIntermediateCertificate returns the chaining certificate (if any). - // It does not block or cause certificates to load. - // Returns a DER encoded X.509 certificate. - GetIntermediateCertificate() []byte - // Request a signature on a digest. - Sign(digest []byte) []byte -} - -// clientAuthWrapper manages certificate loading and usage during TLS handshakes. -// Implements crypto.Signer. -type clientAuthWrapper struct { - signer ClientAuth -} - -// GetClientCertificate returns the client certificate chain as a tls.Certificate. -// Returns an empty Certificate on failure, permitting the handshake to -// continue without authentication. -// Implements tls.Config GetClientCertificate(). -func (ca *clientAuthWrapper) GetClientCertificate( - info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - if ca.signer == nil { - log.W("Client certificate requested but not supported") - return &tls.Certificate{}, nil - } - cert := ca.signer.GetClientCertificate() - if cert == nil { - log.W("Unable to fetch client certificate") - return &tls.Certificate{}, nil - } - chain := [][]byte{cert} - intermediate := ca.signer.GetIntermediateCertificate() - if intermediate != nil { - chain = append(chain, intermediate) - } - leaf, err := x509.ParseCertificate(cert) - if err != nil { - log.W("Unable to parse client certificate: %v", err) - return &tls.Certificate{}, nil - } - _, isECDSA := leaf.PublicKey.(*ecdsa.PublicKey) - if !isECDSA { - // RSA-PSS and RSA-SSA both need explicit signature generation support. - log.W("Only ECDSA client certificates are supported") - return &tls.Certificate{}, nil - } - return &tls.Certificate{ - Certificate: chain, - PrivateKey: ca, - Leaf: leaf, - }, nil -} - -// Public returns the public key for the client certificate. -func (ca *clientAuthWrapper) Public() crypto.PublicKey { - if ca.signer == nil { - return nil - } - cert := ca.signer.GetClientCertificate() - leaf, err := x509.ParseCertificate(cert) - if err != nil { - log.W("Unable to parse client certificate: %v", err) - return nil - } - return leaf.PublicKey -} - -// Sign a digest. -func (ca *clientAuthWrapper) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - if ca.signer == nil { - return nil, errors.New("no client certificate") - } - signature := ca.signer.Sign(digest) - if signature == nil { - return nil, errors.New("failed to create signature") - } - return signature, nil -} - -func newClientAuthWrapper(signer ClientAuth) clientAuthWrapper { - return clientAuthWrapper{ - signer: signer, - } -} diff --git a/intra/doh/client_auth_test.go b/intra/doh/client_auth_test.go deleted file mode 100644 index 0cc349e1..00000000 --- a/intra/doh/client_auth_test.go +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright 2020 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build ignore - -package doh - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/rand" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "fmt" - "log" - "net" - "testing" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - ilog "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/rnet" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/x64" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -// PEM encoded test leaf certificate with ECDSA public key. -var ecCertificate string = `-----BEGIN CERTIFICATE----- -MIIBpTCCAQ4CAiAAMA0GCSqGSIb3DQEBCwUAMD4xCzAJBgNVBAYTAlVTMQswCQYD -VQQIDAJDQTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEKMAgGA1UECgwBWDAeFw0y -MDExMDQwNTU2MTZaFw0zMDExMDIwNTU2MTZaMD4xCzAJBgNVBAYTAlVTMQswCQYD -VQQIDAJDQTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEKMAgGA1UECgwBWDBZMBMG -ByqGSM49AgEGCCqGSM49AwEHA0IABNFVWlOs0tnaLgiutLbPISCd5Fn9UJz6oDen -prTOrHz11PiO/XiqwpJY8yO72QappL/7RYV+uw9hJfU+YOE3tZQwDQYJKoZIhvcN -AQELBQADgYEAdy6CNPvIA7DrS6WrN7N4ZjHjeUtjj2w8n5abTHhvANEvIHI0DARI -AoJJWp4Pe41mzFhROzo+U/ofC2b+ukA8sYqoio4QUxlSW3HkzUAR4HZMi8Risvo3 -OxSR9Lw/mGvZrJ8xr070EwnsD+cCZLfYQ0mSKDM9uPfI3YrgCVKyUwE= ------END CERTIFICATE-----` - -// PKCS8 encoded test ECDSA private key. -var ecKey string = `-----BEGIN PRIVATE KEY----- -MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIlI6NB+skAYL36XP -JvE+x5Nlbn0wvw2hlSqIqADiZhShRANCAATRVVpTrNLZ2i4IrrS2zyEgneRZ/VCc -+qA3p6a0zqx89dT4jv14qsKSWPMju9kGqaS/+0WFfrsPYSX1PmDhN7WU ------END PRIVATE KEY-----` - -// PEM encoded test leaf certificate with RSA public key. -// Doubles as an intermediate depending on the test. -var rsaCertificate string = `-----BEGIN CERTIFICATE----- -MIICWDCCAcGgAwIBAgIUS36guwZMKNO0ADReGLi0cZq8fOowDQYJKoZIhvcNAQEL -BQAwPjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1Nb3VudGFp -biBWaWV3MQowCAYDVQQKDAFYMB4XDTIwMTEwNDA1NDgyNVoXDTMwMTEwMjA1NDgy -NVowPjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1Nb3VudGFp -biBWaWV3MQowCAYDVQQKDAFYMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDd -eznqVu1Rn0m8KR4mX/qVv6uytzZ+juqW5VD55D+w9N6JryPpFHPi4VIm8PKLXp3X -GvY9mc8r+0Ow1qJZYoc/X0Na1c79bv9xwbD3aK28FlAs1+cmyesaFhCWa0bYAvcy -mqQGYhObEWb46E5AANV82CitDE9C1aXRT4SvkLnc6wIDAQABo1MwUTAdBgNVHQ4E -FgQUnUib8BhOHqjq9+gqPQ+ePyEW9zwwHwYDVR0jBBgwFoAUnUib8BhOHqjq9+gq -PQ+ePyEW9zwwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQAx/uZG -Gmb5w/u4UkdH7wnoOUNx6GwdraqtQWnFaXb87PmuVAjBwSAnzes2mlp/Vbcd6tYs -pPuHrxOcWgw/aRV6rK3vJZIH3DGvy1pNphGgegEcG88nrUCDcQqPLxvPJ8bmbaee -Tf+l5U2OHC3Yifb4FDOv47kGmq5VeWiYdp60/A== ------END CERTIFICATE-----` - -// PKCS8 encoded test RSA private key. -var rsaKey string = `-----BEGIN PRIVATE KEY----- -MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAN17OepW7VGfSbwp -HiZf+pW/q7K3Nn6O6pblUPnkP7D03omvI+kUc+LhUibw8otendca9j2Zzyv7Q7DW -ollihz9fQ1rVzv1u/3HBsPdorbwWUCzX5ybJ6xoWEJZrRtgC9zKapAZiE5sRZvjo -TkAA1XzYKK0MT0LVpdFPhK+QudzrAgMBAAECgYEAoCdhI8Ej7qe+S993u8wfiXWG -FL9DGpUBsYe03F5eZ/lJikopL3voqKDCJQKKgJk0jb0jXjwAgQ86TX+G+hezL5jp -xOOfMmTYgMwnUuFYN1gHAd+TnYB9G1qSQr9TOw3K9Rf4q2x09GhLP75qdr+qzmIR -YGle5ZSP0LqKNkpGNUECQQD+6CxOO8+knnzIFvqkUyNDVFR5ALRNpb53TGVITNf3 -ysT32oJ75ButA0l4q/jsL+MeLLvrHkJOHN+ydLaZOUkbAkEA3m5cICisW9lsT+Rj -glXykkbj3Ougldy7rhPivAaS7clk8cl8cDcIvHna1mDlhSanUu/s4TFEXBLnSzee -XLNIcQJBAJ0n3TD6lSEkCUB/UlX/X81B77aOZZs9pXj9o6/4mGoQHHHGyQ3C7AE1 -9pUsSZKsT3UqFU124WAxUwU+CdnbxKMCQB/QrUC0UKL6oHF0+37DCGU/2ovY8Ck/ -X2Dw2zeFwTJd4iBrb28lkAxVaaXMSkgXVUuZoco8H8kDsy2hEPe1dSECQQCPw5Yg -2gdmdpUk+QetqqhSuuIDwILHU9m3CoX3rY+njaR5LOWDz3utC9Ogo+4wdIMamP/o -2SAWPAZPqDUbtqGH ------END PRIVATE KEY-----` - -// fakeClientAuth implements the ClientAuth interface for testing. -type fakeClientAuth struct { - certificate *x509.Certificate - intermediate *x509.Certificate - key crypto.PrivateKey -} - -func (ca *fakeClientAuth) GetClientCertificate() []byte { - if ca.certificate == nil { - // Interface uses nil for errors to support binding. - return nil - } - return ca.certificate.Raw -} - -func (ca *fakeClientAuth) GetIntermediateCertificate() []byte { - if ca.intermediate == nil { - return nil - } - return ca.intermediate.Raw -} - -func (ca *fakeClientAuth) Sign(digest []byte) []byte { - if ca.key == nil { - return nil - } - if k, isECDSA := ca.key.(*ecdsa.PrivateKey); isECDSA { - signature, err := ecdsa.SignASN1(rand.Reader, k, digest) - if err != nil { - return nil - } - return signature - } - // Unsupported key type - return nil -} - -func newFakeClientAuth(certificate, intermediate, key []byte) (*fakeClientAuth, error) { - ca := &fakeClientAuth{} - if certificate != nil { - certX509, err := x509.ParseCertificate(certificate) - if err != nil { - return nil, fmt.Errorf("certificate: %v", err) - } - ca.certificate = certX509 - } - if intermediate != nil { - intX509, err := x509.ParseCertificate(intermediate) - if err != nil { - return nil, fmt.Errorf("intermediate: %v", err) - } - ca.intermediate = intX509 - } - if key != nil { - key, err := x509.ParsePKCS8PrivateKey(key) - if err != nil { - return nil, fmt.Errorf("private key: %v", err) - } - ca.key = key - } - return ca, nil -} - -func newCertificateRequestInfo() *tls.CertificateRequestInfo { - return &tls.CertificateRequestInfo{ - Version: tls.VersionTLS13, - } -} - -func newToBeSigned(message []byte) ([]byte, crypto.SignerOpts) { - digest := sha256.Sum256(message) - opts := crypto.SignerOpts(crypto.SHA256) - return digest[:], opts -} - -// Simulate a TLS handshake that requires a client cert and signature. -func TestSign(t *testing.T) { - certDer, _ := pem.Decode([]byte(ecCertificate)) - keyDer, _ := pem.Decode([]byte(ecKey)) - intDer, _ := pem.Decode([]byte(rsaCertificate)) - ca, err := newFakeClientAuth(certDer.Bytes, intDer.Bytes, keyDer.Bytes) - if err != nil { - t.Fatal(err) - } - wrapper := newClientAuthWrapper(ca) - // TLS stack requests the client cert. - req := newCertificateRequestInfo() - cert, err := wrapper.GetClientCertificate(req) - if err != nil { - t.Fatal("Expected to get a client certificate") - } - if cert == nil { - // From the crypto.tls docs: - // If GetClientCertificate returns an error, the handshake will - // be aborted and that error will be returned. Otherwise - // GetClientCertificate must return a non-nil Certificate. - t.Error("GetClientCertificate must return a non-nil certificate") - } - if len(cert.Certificate) != 2 { - t.Fatal("Certificate chain is the wrong length") - } - if !bytes.Equal(cert.Certificate[0], certDer.Bytes) { - t.Error("Problem with certificate chain[0]") - } - if !bytes.Equal(cert.Certificate[1], intDer.Bytes) { - t.Error("Problem with certificate chain[1]") - } - // TLS stack requests a signature. - digest, opts := newToBeSigned([]byte("hello world")) - signature, err := wrapper.Sign(rand.Reader, digest, opts) - if err != nil { - t.Fatal(err) - } - // Verify the signature. - pub, ok := wrapper.Public().(*ecdsa.PublicKey) - if !ok { - t.Fatal("Expected public key to be ECDSA") - } - if !ecdsa.VerifyASN1(pub, digest, signature) { - t.Fatal("Problem verifying signature") - } -} - -// Simulate a client that does not use an intermediate certificate. -func TestSignNoIntermediate(t *testing.T) { - certDer, _ := pem.Decode([]byte(ecCertificate)) - keyDer, _ := pem.Decode([]byte(ecKey)) - ca, err := newFakeClientAuth(certDer.Bytes, nil, keyDer.Bytes) - if err != nil { - t.Fatal(err) - } - wrapper := newClientAuthWrapper(ca) - // TLS stack requests a client cert. - req := newCertificateRequestInfo() - cert, err := wrapper.GetClientCertificate(req) - if err != nil { - t.Error("Expected to get a client certificate") - } - if cert == nil { - t.Error("GetClientCertificate must return a non-nil certificate") - } - if len(cert.Certificate) != 1 { - t.Error("Certificate chain is the wrong length") - } - if !bytes.Equal(cert.Certificate[0], certDer.Bytes) { - t.Error("Problem with certificate chain[0]") - } - // TLS stack requests a signature - digest, opts := newToBeSigned([]byte("hello world")) - signature, err := wrapper.Sign(rand.Reader, digest, opts) - if err != nil { - t.Error(err) - } - // Verify the signature. - pub, ok := wrapper.Public().(*ecdsa.PublicKey) - if !ok { - t.Error("Expected public key to be ECDSA") - } - if !ecdsa.VerifyASN1(pub, digest, signature) { - t.Error("Problem verifying signature") - } -} - -// Simulate a client that does not have a certificate. -func TestNoAuth(t *testing.T) { - ca, err := newFakeClientAuth(nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wrapper := newClientAuthWrapper(ca) - // TLS stack requests a client cert. - req := newCertificateRequestInfo() - cert, err := wrapper.GetClientCertificate(req) - if err != nil { - t.Error("Expected to get a client certificate") - } - if cert == nil { - t.Error("GetClientCertificate must return a non-nil certificate") - } - if len(cert.Certificate) != 0 { - t.Error("Certificate chain is the wrong length") - } - // TLS stack requests a signature. This should not happen in real life - // because cert.Certificate is empty. - public := wrapper.Public() - if public != nil { - t.Error("Expected public to be nil") - } - digest, opts := newToBeSigned([]byte("hello world")) - _, err = wrapper.Sign(rand.Reader, digest, opts) - if err == nil { - t.Error("Expected Sign() to fail") - } -} - -// Simulate a client that has an RSA certificate. -func TestRSACertificate(t *testing.T) { - certDer, _ := pem.Decode([]byte(rsaCertificate)) - keyDer, _ := pem.Decode([]byte(rsaKey)) - ca, err := newFakeClientAuth(certDer.Bytes, nil, keyDer.Bytes) - if err != nil { - t.Fatal(err) - } - wrapper := newClientAuthWrapper(ca) - // TLS stack requests a client cert. We should not return one because - // we don't support RSA. - req := newCertificateRequestInfo() - cert, err := wrapper.GetClientCertificate(req) - if err != nil { - t.Error("Expected to get a client certificate") - } - if cert == nil { - t.Error("GetClientCertificate must return a non-nil certificate") - } - if len(cert.Certificate) != 0 { - t.Error("Unexpectedly loaded an RSA certificate") - } - // TLS stack requests a signature. This should not happen in real life - // because cert.Certificate is empty. - digest, opts := newToBeSigned([]byte("hello world")) - _, err = wrapper.Sign(rand.Reader, digest, opts) - if err == nil { - t.Error("Expected Sign() to fail") - } -} - -// Simulate a nil loader. -func TestNilLoader(t *testing.T) { - wrapper := newClientAuthWrapper(nil) - // TLS stack requests the client cert. - req := newCertificateRequestInfo() - cert, err := wrapper.GetClientCertificate(req) - if err != nil { - t.Fatal(err) - } - if cert == nil { - // From the crypto.tls docs: - // If GetClientCertificate returns an error, the handshake will - // be aborted and that error will be returned. Otherwise - // GetClientCertificate must return a non-nil Certificate. - t.Error("GetClientCertificate must return a non-nil certificate") - } - if len(cert.Certificate) != 0 { - t.Fatal("Expected an empty certificate chain") - } - // TLS stack requests a signature. This should not happen in real life - // because cert.Certificate is empty. - digest, opts := newToBeSigned([]byte("hello world")) - _, err = wrapper.Sign(rand.Reader, digest, opts) - if err == nil { - t.Error("Expected Sign() to fail") - } -} - -type fakeCtl struct { - protect.Controller -} - -func (*fakeCtl) Bind4(_, _ string, _ int) {} -func (*fakeCtl) Bind6(_, _ string, _ int) {} -func (*fakeCtl) Protect(_ string, _ int) {} - -type fakeObs struct { - x.ProxyListener -} - -func (*fakeObs) OnProxyAdded(string) {} -func (*fakeObs) OnProxyRemoved(string) {} -func (*fakeObs) OnProxiesStopped() {} - -type fakeBdg struct { - protect.Controller - x.DNSListener -} - -var ( - baseNsOpts = &x.DNSOpts{PID: dnsx.NetBaseProxy, IPCSV: "", TIDCSV: x.CT + "test0"} - baseTab = &rnet.Tab{CID: "testcid", Block: false} -) - -func (*fakeBdg) OnQuery(_ string, _ int) *x.DNSOpts { return baseNsOpts } -func (*fakeBdg) OnResponse(*x.DNSSummary) {} -func (*fakeBdg) OnDNSAdded(string) {} -func (*fakeBdg) OnDNSRemoved(string) {} -func (*fakeBdg) OnDNSStopped() {} - -func (*fakeBdg) Route(a, b, c, d, e string) *rnet.Tab { return baseTab } -func (*fakeBdg) OnComplete(*rnet.ServerSummary) {} - -func TestDoh(t *testing.T) { - netr := &net.Resolver{} - // create a struct that implements protect.Controller interface - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := ipn.NewProxifier(ctl, obs) - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - q := aquery("skysports.com") - q6 := aaaaquery("skysports.com") - b4, _ := q.Pack() - b6, _ := q6.Pack() - // smm := &x.DNSSummary{} - // smm6 := &x.DNSSummary{} - _ = xdns.NetAndProxyID("tcp", dnsx.NetBaseProxy) - tm := &settings.TunMode{ - DNSMode: settings.DNSModePort, - BlockMode: settings.BlockModeNone, - PtMode: settings.PtModeAuto, - } - tr, _ := NewTransport("test0", "https://8.8.8.8/dns-query", nil, pxr, ctl) - dtr, _ := NewTransport(x.Default, "https://1.1.1.1/dns-query", nil, pxr, ctl) - - natpt := x64.NewNatPt(tm) - resolv := dnsx.NewResolver("10.111.222.3", tm, dtr, bdg, natpt) - resolv.Add(tr) - r4, err := resolv.Forward(b4) - r6, err6 := resolv.Forward(b6) - time.Sleep(1 * time.Second) - _, _ = resolv.Forward(b6) - if err != nil { - // log.Output(2, smm.Str()) - t.Fatal(err) - } - if err6 != nil { - // log.Output(2, smm6.Str()) - t.Fatal(err6) - } - ans := xdns.AsMsg(r4) - ans6 := xdns.AsMsg(r6) - if xdns.Len(ans) == 0 && xdns.Len(ans6) == 0 { - t.Fatal("no ans") - } - log.Output(10, xdns.Ans(ans)) - log.Output(10, xdns.Ans(ans6)) -} - -func aquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeA) - msg.Id = 1234 - return msg -} - -func aaaaquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeAAAA) - msg.Id = 3456 - return msg -} diff --git a/intra/doh/doh.go b/intra/doh/doh.go deleted file mode 100644 index 7cbb96b3..00000000 --- a/intra/doh/doh.go +++ /dev/null @@ -1,905 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package doh - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/base64" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/http/httptrace" - "net/netip" - "net/url" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/cloudflare/odoh-go" - "github.com/miekg/dns" -) - -const dohmimetype = "application/dns-message" - -const DohPortU16 = uint16(443) - -const maxEOFTries = uint8(2) - -const purgethreshold = 1 * time.Minute - -const echRetryPeriod = 8 * time.Hour - -var errNoClient error = errors.New("no doh client") - -type odohtransport struct { - omu sync.RWMutex // protects odohConfig - odohproxyurl string // proxy url - odohproxyname string // proxy hostname - odohproxyport uint16 // proxy port - odohtargetname string // target hostname - odohtargetpath string // target path - odohConfig *odoh.ObliviousDoHConfig - odohConfigExpiry time.Time - preferWK bool // prefer .well-known over svcb/https probe -} - -// TODO: Keep a context here so that queries can be canceled. -type transport struct { - *odohtransport // stackoverflow.com/a/28505394 - ctx context.Context - done context.CancelFunc - id string - typ string // dnsx.DOH / dnsx.ODOH - url string // endpoint URL - hostname string // endpoint hostname - port uint16 - skipTLSVerify bool // skips tls verification - tlsconfig *tls.Config // preset tlsconfig for the endpoint - echconfig *core.Volatile[*tls.Config] // echconfig for the endpoint; may be nil - echrejects atomic.Uint32 // number of running ech rejections - echlastattempt *core.Volatile[time.Time] // last attempt fetching ech cfg - pxcmu sync.RWMutex // protects pxclients - pxclients map[string]*proxytransport // todo: use weak pointers for Proxy - lastpurge *core.Volatile[time.Time] // last scrubbed time for stale pxclients - preferGET bool // saw 405 Method Not Allowed - proxies ipn.ProxyProvider // proxy provider, may be nil - relay string // dial doh via relay, may be empty - status *core.Volatile[int] - est core.P2QuantileEstimator -} - -var _ dnsx.Transport = (*transport)(nil) - -// NewTransport returns a POST-only DoH transport. -// `id` identifies this transport. -// `rawurl` is the DoH template in string form. -// `addrs` is a list of IP addresses to bootstrap dialers. -// `px` is the proxy provider, may be nil (eg for id == dnsx.Default) -func NewTransport(ctx context.Context, id, rawurl string, addrs []string, px ipn.ProxyProvider) (*transport, error) { - return newTransport(ctx, dnsx.DOH, id, rawurl, "", addrs, px) -} - -// NewTransport returns a POST-only Oblivious DoH transport. -// `id` identifies this transport. -// `endpoint` is the ODoH proxy that liaisons with the target. -// `target` is the ODoH resolver. -// `addrs` is a list of IP addresses to bootstrap endpoint dialers. -// `px` is the proxy provider, never nil. -func NewOdohTransport(ctx context.Context, id, endpoint, target string, addrs []string, px ipn.ProxyProvider) (*transport, error) { - return newTransport(ctx, dnsx.ODOH, id, endpoint, target, addrs, px) -} - -func newTransport(ctx context.Context, typ, id, rawurl, otargeturl string, addrs []string, px ipn.ProxyProvider) (*transport, error) { - isodoh := typ == dnsx.ODOH - - var renewed bool - var relay string - if px != nil { - if p, _ := px.ProxyFor(id); p != nil { - relay = p.ID() - } - } - - ctx, done := context.WithCancel(ctx) - - t := &transport{ - ctx: ctx, - done: done, - id: id, - typ: typ, - proxies: px, // may be nil - relay: relay, // may be empty - status: core.NewVolatile(dnsx.Start), - pxclients: make(map[string]*proxytransport), - echconfig: core.NewZeroVolatile[*tls.Config](), - echlastattempt: core.NewZeroVolatile[time.Time](), - lastpurge: core.NewVolatile(time.Now()), - est: core.NewP50Estimator(ctx), - } - if !isodoh { - parsedurl, err := url.Parse(rawurl) - if err != nil { - return nil, err - } - // use of "http" is an indication to turn-off TLS verification - // for, odoh rawurl represents a proxy, which can operate on http - if parsedurl.Scheme == "http" { - log.I("doh: disabling tls verification for %s", rawurl) - parsedurl.Scheme = "https" - t.skipTLSVerify = true - } - if parsedurl.Scheme != "https" { - return nil, fmt.Errorf("unsupported scheme %s", parsedurl.Scheme) - } - // for odoh, rawurl represents a proxy, which is optional - if len(parsedurl.Hostname()) == 0 { - return nil, fmt.Errorf("no hostname in %s", rawurl) - } - t.url = parsedurl.String() - t.hostname = parsedurl.Hostname() - t.port = DohPortU16 - if port, _ := strconv.ParseUint(parsedurl.Port(), 10, 16); port > 0 { - t.port = uint16(port) - } - // addrs are pre-determined ip addresses for url / hostname - renewed = dnsx.RegisterAddrs(t.id, t.hostname, addrs) - } else { - t.odohtransport = &odohtransport{} - - proxy := rawurl // may be empty - configurl, err := url.Parse(odohconfigdns) - if err != nil || configurl == nil || configurl.Hostname() == "" { - return nil, core.JoinErr(errNoOdohConfigUrl, err) - } - targeturl, err := url.Parse(otargeturl) - if err != nil || targeturl == nil || targeturl.Hostname() == "" { - return nil, core.JoinErr(errNoOdohTarget, err) - } - proxyurl, _ := url.Parse(proxy) // ignore err as proxy may be empty - - // addrs are proxy addresses if proxy is not empty, otherwise target addresses - if proxyurl != nil && proxyurl.Hostname() != "" { - renewed = dnsx.RegisterAddrs(id, proxyurl.Hostname(), addrs) - if len(proxyurl.Path) <= 1 { // should not be "" or "/" - proxyurl.Path = odohproxypath // default: "/proxy" - } - t.odohproxyurl = proxyurl.String() - t.odohproxyname = proxyurl.Hostname() - t.odohproxyport = DohPortU16 - if port, _ := strconv.ParseUint(proxyurl.Port(), 10, 16); port > 0 { - t.odohproxyport = uint16(port) - } - } else { - renewed = dnsx.RegisterAddrs(id, targeturl.Hostname(), addrs) - } - - t.url = configurl.String() // odohconfigdns - t.hostname = configurl.Hostname() // 1.1.1.1 - t.port = DohPortU16 // TODO: grab port from configUrl - t.odohtargetname = targeturl.Hostname() - if len(targeturl.Path) > 1 { // should not be "" or "/" - t.odohtargetpath = targeturl.Path - } else { - t.odohtargetpath = odohtargetpath // default: "/dns-query" - } - log.I("doh: ODOH for %s -> %s", proxy, otargeturl) - } - - echcfg := t.getOrCreateEchConfigIfNeeded() - - // TODO: ClientAuth - // Supply a client certificate during TLS handshakes. - // if auth != nil { - // signer := newClientAuthWrapper(auth) - // t.tlsconfig = &tls.Config{ - // GetClientCertificate: signer.GetClientCertificate, - // ServerName: t.hostname, - // } - // } - - t.tlsconfig = &tls.Config{ - InsecureSkipVerify: t.skipTLSVerify, - MinVersion: tls.VersionTLS12, - // SNI (hostname) must always be inferred from http-request - // ServerName: t.hostname, - SessionTicketsDisabled: false, - ClientSessionCache: core.TlsSessionCache(), - } - - log.I("doh: new transport(%s): %s; relay? %t; addrs? %v; resolved? %t, ech? %t", - t.typ, t.url, len(relay) > 0, addrs, renewed, echcfg != nil) - return t, nil -} - -type proxytransport struct { - p ipn.Proxy - c *http.Client - c3 *http.Client -} - -func (t *transport) ech() []byte { - // host http.Client connects to may change on redirects - name := t.hostname - if t.typ == dnsx.ODOH && len(t.odohproxyname) > 0 { - name = t.odohproxyname - } - if len(name) <= 0 { - return nil - } else if v, err := dialers.ECH(name); err != nil { - log.W("doh: ech(%s): %v", name, err) - return nil - } else { - log.V("doh: ech(%s): sz %d", name, len(v)) - return v - } -} - -func (t *transport) echVerifyFn() func(tls.ConnectionState) error { - if t.skipTLSVerify { - return func(info tls.ConnectionState) error { - log.V("doh: skip ech verify for %s via %s", t.hostname, info.ServerName) - return nil // never reject - } - } - return nil // delegate to stdlib -} - -func h2(d protect.DialFn, c *tls.Config) *http.Transport { - return &http.Transport{ - Dial: d, - ForceAttemptHTTP2: true, - IdleConnTimeout: 3 * time.Minute, - TLSHandshakeTimeout: 7 * time.Second, - // Android's DNS-over-TLS sets it to 30s - ResponseHeaderTimeout: 20 * time.Second, - // SNI (hostname) must always be inferred from http-request - TLSClientConfig: c, - } -} - -// always called from a go-routine -func (t *transport) purgeProxyClients() { - lastpurge := t.lastpurge.Load() - if time.Since(lastpurge) <= purgethreshold { - return - } - if ok := t.lastpurge.Cas(lastpurge, time.Now()); !ok { - log.I("doh: purge proxy clients: race...") - return - } - t.pxcmu.Lock() - defer t.pxcmu.Unlock() - for id, pxtr := range t.pxclients { - if pxtr == nil { - continue - } else if pxtr.p == nil { - delete(t.pxclients, id) - continue - } else if orig, err := t.proxies.ProxyFor(id); err != nil { - delete(t.pxclients, id) - log.W("doh: purge proxy clients: %s %v", id, err) - continue - } else { - diff := pxtr.p != orig - note := log.V - if diff { - note = log.I - delete(t.pxclients, id) - continue - } - note("doh: purge proxy clients: remove? %t %s", diff, id) - } - } -} - -func (t *transport) getOrCreateEchConfigIfNeeded() *tls.Config { - echcfg := t.echconfig.Load() - if echcfg != nil { - return echcfg - } - - prev := t.echlastattempt.Load() - if time.Since(prev) < echRetryPeriod { - return nil - } - refetch := t.echlastattempt.Cas(prev, time.Now()) - if !refetch { - return nil - } - - if ech := t.ech(); len(ech) > 0 { - echcfg = &tls.Config{ - InsecureSkipVerify: t.skipTLSVerify, - MinVersion: tls.VersionTLS13, // must be 1.3 - EncryptedClientHelloConfigList: ech, - SessionTicketsDisabled: false, - ClientSessionCache: core.TlsSessionCache(), - EncryptedClientHelloRejectionVerify: t.echVerifyFn(), - } - t.echconfig.Store(echcfg) - } - - ok := echcfg == nil - logeif(!ok)("doh: %s fetch ech... ok? %t", t.ID(), ok) - return echcfg -} - -func (t *transport) httpClientsFor(p ipn.Proxy) (c3, c *http.Client) { - pid := p.ID() - t.pxcmu.RLock() - pxtr, ok := t.pxclients[pid] - same := pxtr != nil && pxtr.p.Handle() == p.Handle() - if ok && same { - c = pxtr.c - c3 = pxtr.c3 - } - t.pxcmu.RUnlock() - - pdial := p.Dialer().Dial - if c != nil { // use existing clients - if c3 == nil { - if echcfg := t.getOrCreateEchConfigIfNeeded(); echcfg != nil { - c3 = new(http.Client) - c3.Transport = h2(pdial, echcfg) - t.updateHttpClientsFor(p, c, c3) - } - } - return c3, c // c3 may be nil - } - - var client http.Client - var client3 *http.Client - client.Transport = h2(pdial, t.tlsconfig) - if echcfg := t.echconfig.Load(); echcfg != nil { - client3 = new(http.Client) - client3.Transport = h2(pdial, echcfg) - } - - // last writer wins - t.updateHttpClientsFor(p, &client, client3) - - // check if other proxies need to be purged - core.Gx("doh.purgepx", t.purgeProxyClients) - - return client3, &client -} - -// updateHttpClientsFor only updates non-nil http clients dialing via Proxy p. -func (t *transport) updateHttpClientsFor(p ipn.Proxy, c, c3 *http.Client) { - if c == nil && c3 == nil { - log.E("doh: %s cannot set/update, all clients nil", t.ID()) - return - } - - pid := p.ID() - - t.pxcmu.Lock() - defer t.pxcmu.Unlock() - - if pt := t.pxclients[pid]; pt != nil { - if c != nil { - pt.c = c - } - if c3 != nil { - pt.c3 = c3 - } - } else { - if c == nil { - log.E("doh: %s cannot set http client to nil", t.ID()) - return - } - t.pxclients[pid] = &proxytransport{ - p: p, - c: c, - c3: c3, // may be nil - } - } -} - -// Given a raw DNS query (including the query ID), this function sends the -// query. If the query is successful, it returns the response and a nil qerr. Otherwise, -// it returns a SERVFAIL response and a qerr with a status value indicating the cause. -// Independent of the query's success or failure, this function also returns the -// address of the server on a best-effort basis, or nil if the address could not -// be determined. -func (t *transport) doDoh(pid string, q *dns.Msg) (response *dns.Msg, rpid, blocklists, region string, ech bool, elapsed time.Duration, qerr *dnsx.QueryError) { - start := time.Now() - if qerr = dnsx.WillErr(t); qerr != nil { - return - } - - padQuery(q) - // zero out the query id - id := q.Id - q.Id = 0 - - req, err := t.asDohRequest(q) - if err != nil { - log.D("doh: failed to create request: %v", err) - elapsed = time.Since(start) - qerr = dnsx.NewInternalQueryError(err) - return - } - - response, rpid, blocklists, region, ech, elapsed, qerr = t.send(pid, req) - - // restore dns query id - q.Id = id - if response != nil { - response.Id = id - } else { // override response with servfail - response = xdns.Servfail(q) - } - return -} - -func (t *transport) fetch(pid string, req *http.Request) (*http.Response, string, bool, error) { - ustr := req.URL.String() - - uerr := func(e error) *url.Error { - if e == nil { - return nil - } - if e, ok := e.(*url.Error); ok { - return e - } - return &url.Error{ - Op: req.Method, - URL: ustr, - Err: e, - } - } - - r, rpid, echdialer, err := t.multifetch(req, pid) - if err != nil { - log.W("doh: fetch: %s, mayech? %t / echdialer? %t, err: %v", - ustr, t.echconfig.Load() != nil, echdialer, err) - return r, rpid, echdialer, uerr(err) - } - return r, rpid, echdialer, nil -} - -func (t *transport) multifetch(req *http.Request, pid string) (res *http.Response, rpid string, echdialer bool, err error) { - px, err := t.prepare(pid) - if err != nil || px == nil { - return nil, "", false, core.OneErr(err, dnsx.ErrNoProxyProvider) - } - - rpid = ipn.ViaID(px) - c3, c0 := t.httpClientsFor(px) // c3 may be nil - - if settings.Debug { - log.VV("doh: using proxy %s+%s@%s ech? %t / other? %t", - px.ID(), rpid, px.GetAddr(), c3 != nil, c0 != nil) - } - - clients := []*http.Client{c3, c0} - - var cont, sent bool - for _, c := range clients { - if c == nil { // c may be nil (ex: if no ech) - continue - } - cont = true - sent = false - for i := uint8(0); cont && i < maxEOFTries; i++ { - cont = false - sent = true - if res, err = c.Do(req); err == nil { - return res, rpid, c == c3, nil // res is never nil here - } - if eerr := new(tls.ECHRejectionError); errors.As(err, &eerr) { - cont = true - ech := eerr.RetryConfigList - - echcfg := t.echconfig.Load() - useech := echcfg != nil - if len(ech) <= 0 && useech { - ech = t.ech() // todo: use t.echconfig.ServerName? - log.I("doh: fetch #%d: err %v; grab new ech? %t", - i, eerr, len(ech) > 0) - } - if len(ech) > 0 && useech { - echcfg.EncryptedClientHelloConfigList = ech - c.Transport = h2(px.Dialer().Dial, echcfg) - t.echconfig.Store(echcfg) // update ech config - t.echlastattempt.Store(time.Now()) - t.updateHttpClientsFor(px, nil, c) // update c3 - } - n := t.echrejects.Add(1) - log.I("doh: fetch #%d: ech rejected; retry? %t, ech? %t; total rejects: %d", - i, len(ech) > 0, useech, n) - } else if uerr, ok := err.(*url.Error); ok { - eof := uerr.Err == io.EOF - if eof && res != nil { - log.D("doh: fetch #%d: EOF; but res exists! %t", i) - return res, rpid, c == c3, nil - } // continue if EOF - cont = eof || uerr.Err == io.ErrUnexpectedEOF - } // terminate if not EOF - log.W("doh: fetch #%d (cont? %t) px: %s[%s]; err: %v", i, cont, pid, rpid, err) - } - } - if !sent && err == nil { // should never happen - log.E("doh: fetch: no client sent request %d", len(clients)) - } - return nil, rpid, false, core.OneErr(err, errNoClient) -} - -func (t *transport) prepare(pid string) (px ipn.Proxy, err error) { - userelay := len(t.relay) > 0 - hasproxy := t.proxies != nil - useproxy := len(pid) != 0 // if pid == dnsx.NetNoProxy, then px is ipn.Block - useech := t.echconfig.Load() != nil - - if userelay || useproxy { - if userelay { // relay takes precedence - pid = t.relay - } - if hasproxy { - px, err = t.proxies.ProxyFor(pid) - } else { - err = dnsx.ErrNoProxyProvider - } - } else { - err = dnsx.ErrNoProxyProvider - log.W("doh: no proxy %s ech? %t; err: %v", pid, useech, err) - } - return -} - -func (t *transport) do(pid string, req *http.Request) (ans []byte, rpid, blocklists, region string, withech bool, elapsed time.Duration, qerr *dnsx.QueryError) { - var server net.Addr - var conn net.Conn - start := time.Now() - // either t.hostname or t.odohtargetname or t.odohproxy - hostname := req.URL.Hostname() - - // Error cleanup function. If the query fails, this function will close the - // underlying socket and disconfirm the server IP. Empirically, sockets often - // become unresponsive after a network change, causing timeouts on all requests. - defer func() { - elapsed = time.Since(start) - - // server addr would be of relay / proxy (ex: 127.0.0.1:9050) if used - usedrelay := len(t.relay) > 0 - usedproxy := !dnsx.IsLocalProxy(pid) // pid == dnsx.NetNoProxy => ipn.Block - hasserveraddr := server != nil && !usedrelay && !usedproxy - - if hostname != t.hostname { - log.I("doh: redirected %s => %s", t.hostname, hostname) - t.hostname = hostname - } - if hasserveraddr { - if qerr == nil { - // record a working IP address for this server - dialers.Confirm3(hostname, server) - return - } else { - ok := dialers.Disconfirm3(hostname, server) - log.D("doh: disconfirming %s, %s done? %t", hostname, server, ok) - } - } - if qerr != nil { - log.I("doh: close failing doh conn %s; why? %v", hostname, qerr) - core.CloseConn(conn) - } - }() - - // Add a trace to the request in order to expose the server's IP address. - // Only GotConn performs any action; the other methods just provide debug logs. - // GotConn runs before client.Do() returns, so there is no data race when - // reading the variables it has set. - trace := httptrace.ClientTrace{ - GotConn: func(info httptrace.GotConnInfo) { - log.V("doh: got-conn(%v)", info) - if info.Conn == nil { - return - } - conn = info.Conn - // info.Conn is a DuplexConn, so RemoteAddr is actually a TCPAddr. - // if the conn is proxied, then RemoteAddr is that of the proxy - server = conn.RemoteAddr() - }, - ConnectStart: func(network, addr string) { - start = time.Now() // re...start - log.VV("doh: connect-start(%s, %s)", network, addr) - }, - TLSHandshakeDone: func(state tls.ConnectionState, err error) { - log.VV("doh: %s tls%d (resumed? %t, done? %t, ech? %t); err? %v", - state.ServerName, state.Version, state.DidResume, state.HandshakeComplete, state.ECHAccepted, err) - withech = state.ECHAccepted - }, - WroteRequest: func(info httptrace.WroteRequestInfo) { - log.VV("doh: wrote-req(%v)", info) - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), &trace)) - - log.VV("doh: sending query to: %s", t.hostname) - - res, rpid, echdialer, err := t.fetch(pid, req) - - withech = withech || echdialer - - if err != nil || res == nil { - qerr = dnsx.NewSendFailedQueryError(err) - return - } - - blocklists, region = t.rdnsHeaders(&res.Header) - // todo: check if content-type is [doh|odoh] mime type - - ans, err = io.ReadAll(res.Body) - if err != nil { - qerr = dnsx.NewSendFailedQueryError(err) - return - } - core.Close(res.Body) - if settings.Debug { - log.V("doh: closed response of sz %d; used ech? %t", len(ans), withech) - } - - // update the hostname, which could have changed due to a redirect - // for ex, 1.1.1.1 or cloudflare-dns.com => one.one.one.one - hostname = res.Request.URL.Hostname() - - sc := res.StatusCode - if sc != http.StatusOK { // 4xx - if sc >= http.StatusBadRequest && sc < http.StatusInternalServerError { - qerr = dnsx.NewClientQueryError(fmt.Errorf("http-status: %d", sc)) - } else { - qerr = dnsx.NewTransportQueryError(fmt.Errorf("http-status: %d", sc)) - } - if !t.preferGET { // flip on 404 or 405; then remain on GET - t.preferGET = sc == http.StatusMethodNotAllowed || sc == http.StatusNotFound - } - return - } - - return -} - -func (t *transport) send(pid string, req *http.Request) (msg *dns.Msg, rpid, blocklists, region string, ech bool, elapsed time.Duration, qerr *dnsx.QueryError) { - var ans []byte - var err error - ans, rpid, blocklists, region, ech, elapsed, qerr = t.do(pid, req) - if qerr != nil { - return - } - msg, err = xdns.AsMsg2(ans) - if msg == nil { - qerr = dnsx.NewBadResponseQueryError(fmt.Errorf("parse err: %v", err)) - return - } - return -} - -func (t *transport) rdnsHeaders(h *http.Header) (blocklistStamp, region string) { - if h == nil { // should not be nil - return - } - blocklistStamp = h.Get(xdns.GetBlocklistStampHeaderKey()) - // X-Nile-Region:[sin] - region = h.Get(xdns.GetRethinkDNSRegionHeaderKey1()) - if len(region) <= 0 { - // Cf-Ray:[d1e2a3d4b5e6e7f8-SIN] - if ck := h.Get(xdns.GetRethinkDNSRegionHeaderKey2()); len(ck) > 0 { - _, region, _ = strings.Cut(ck, "-") - } - } - // too long: - // log.VV("doh: header %s; region %s; stamp %v", h, region, blocklistStamp) - return -} - -func (t *transport) asDohRequest(msg *dns.Msg) (req *http.Request, err error) { - var q []byte - q, err = msg.Pack() - if err != nil { - return - } - if t.preferGET { - url := t.url + "?dns=" + base64.RawURLEncoding.EncodeToString(q) - req, err = http.NewRequest(http.MethodGet, url, nil) - } else { - req, err = http.NewRequest(http.MethodPost, t.url, bytes.NewBuffer(q)) - } - if err != nil { - return - } - req.Header.Set("content-type", dohmimetype) - req.Header.Set("accept", dohmimetype) - if settings.SetUserAgent.Load() { - req.Header.Set("user-agent", settings.IntraUa) - } - return -} - -func (t *transport) ID() string { - return t.id -} - -func (t *transport) Type() string { - return t.typ -} - -func (t *transport) chooseProxy(pids ...string) string { - host, port := t.hostport() - return dnsx.ChooseHealthyProxyHostPort("doh: "+t.id, host, port, pids, t.proxies) -} - -func (t *transport) hostport() (addr string, port uint16) { - addr = t.hostname - port = t.port - if t.typ == dnsx.ODOH && len(t.odohproxyname) > 0 { - addr = t.odohproxyname - port = t.odohproxyport - } - return -} - -func (t *transport) Query(network string, q *dns.Msg, smm *x.DNSSummary) (r *dns.Msg, err error) { - var rpid, pid, blocklists, region string - var ech bool - var elapsed time.Duration - var qerr *dnsx.QueryError - - canproxy := dnsx.CanUseProxy(t.id) - if !canproxy { // bootstrap/default may not be proxied - pid = dnsx.NetBaseProxy - } else if r := t.relay; len(r) > 0 { - pid = t.chooseProxy(r) - } else { - _, pids := xdns.Net2ProxyID(network) - pid = t.chooseProxy(pids...) - } - - if t.typ == dnsx.DOH { - r, rpid, blocklists, region, ech, elapsed, qerr = t.doDoh(pid, q) - } else { - r, ech, elapsed, qerr = t.doOdoh(pid, q) - } - - smm.Server = t.getAddr() - if ech { - smm.Server = dnsx.EchPrefix + smm.Server - } - - status := dnsx.Complete - - if qerr != nil { - status = qerr.Status() - err = qerr.Unwrap() - } - t.status.Store(status) - - t.est.Add(elapsed.Seconds()) - smm.Latency = elapsed.Seconds() - smm.RData = xdns.GetInterestingRData(r) - smm.RCode = xdns.Rcode(r) - smm.RTtl = xdns.RTtl(r) - smm.Status = status - smm.Region = region - // TODO: smm.BlockedTarget - smm.Blocklists = blocklists - if t.typ == dnsx.ODOH && len(t.odohproxyname) > 0 { - smm.PID = t.odohproxyname // odoh proxy - smm.RPID = pid // other proxy, if any - } else { - smm.PID = pid // proxy, if any - smm.RPID = rpid // hopping proxy, if any - } - if err != nil { - smm.Msg = err.Error() - } - if settings.Debug { - log.V("doh: (p/px/via/can? %s/%s/%s/%t); a:%d/sz:%d/pad:%d, q: %s:%d, data: %s, code: %d, via: %s, err? %v", - network, pid, rpid, canproxy, xdns.Len(r), xdns.Size(r), xdns.EDNS0PadLen(r), smm.QName, smm.QType, smm.RData, smm.RCode, smm.PID, err) - } - return r, err -} - -func (t *transport) P50() int64 { - return t.est.Get() -} - -func (t *transport) GetAddr() string { - return t.getAddr() -} - -func (t *transport) getAddr() string { - addr := t.hostname - if t.typ == dnsx.ODOH { - addr = t.odohtargetname - } - - if t.skipTLSVerify { - addr = dnsx.NoPkiPrefix + addr - } - // doh transports could be "dnsx.Bootstrap" - prefix := dnsx.PrefixFor(t.id) - if len(prefix) > 0 { - addr = prefix + addr - } - return addr -} - -func (t *transport) GetRelay() x.Proxy { - if r := t.relay; len(r) > 0 { - px, _ := t.proxies.ProxyFor(r) - return px - } - return nil -} - -func (t *transport) IPPorts() (ipps []netip.AddrPort) { - addr := t.hostname - port := t.port - if t.typ == dnsx.ODOH && len(t.odohproxyname) > 0 { - addr = t.odohproxyname - port = t.odohproxyport - } - for _, ip := range dialers.For(addr) { - ipps = append(ipps, netip.AddrPortFrom(ip, port)) - } - return // may be nil -} - -func (t *transport) Status() int { - if px := t.GetRelay(); px != nil { - if px.Status() == ipn.TPU { // relay paused => transport paused - return dnsx.Paused - } - } - return t.status.Load() -} - -func (t *transport) Stop() error { - t.status.Store(dnsx.DEnd) - t.done() - return nil -} - -func logeif(cond bool) log.LogFn { - if cond { - return log.E - } - return log.D -} diff --git a/intra/doh/doh_test.go b/intra/doh/doh_test.go deleted file mode 100644 index a2a2d015..00000000 --- a/intra/doh/doh_test.go +++ /dev/null @@ -1,848 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build ignore - -package doh - -import ( - "bytes" - "encoding/binary" - "errors" - "io" - "io/ioutil" - "net" - "net/http" - "net/http/httptrace" - "net/url" - "reflect" - "testing" - - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/xdns" - "golang.org/x/net/dns/dnsmessage" -) - -var testURL = "https://dns.google/dns-query" -var ips = []string{ - "8.8.8.8", - "8.8.4.4", - "2001:4860:4860::8888", - "2001:4860:4860::8844", -} -var parsedURL *url.URL - -var simpleQuery dnsmessage.Message = dnsmessage.Message{ - Header: dnsmessage.Header{ - ID: 0xbeef, - Response: false, - OpCode: 0, - Authoritative: false, - Truncated: false, - RecursionDesired: true, - RecursionAvailable: false, - RCode: 0, - }, - Questions: []dnsmessage.Question{ - { - Name: dnsmessage.MustNewName("www.example.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }}, - Answers: []dnsmessage.Resource{}, - Authorities: []dnsmessage.Resource{}, - Additionals: []dnsmessage.Resource{}, -} - -func mustPack(m *dnsmessage.Message) []byte { - packed, err := m.Pack() - if err != nil { - panic(err) - } - return packed -} - -func mustUnpack(q []byte) *dnsmessage.Message { - var m dnsmessage.Message - err := m.Unpack(q) - if err != nil { - panic(err) - } - return &m -} - -var simpleQueryBytes []byte = mustPack(&simpleQuery) - -var compressedQueryBytes []byte = []byte{ - 0xbe, 0xef, // ID - 0x01, // QR, OPCODE, AA, TC, RD - 0x00, // RA, Z, RCODE - 0x00, 0x02, // QDCOUNT = 2 - 0x00, 0x00, // ANCOUNT = 0 - 0x00, 0x00, // NSCOUNT = 0 - 0x00, 0x00, // ARCOUNT = 0 - // Question 1 - 0x03, 'f', 'o', 'o', - 0x03, 'b', 'a', 'r', - 0x00, - 0x00, 0x01, // QTYPE: A query - 0x00, 0x01, // QCLASS: IN - // Question 2 - 0xc0, 12, // Pointer to beginning of "foo.bar." - 0x00, 0x01, // QTYPE: A query - 0x00, 0x01, // QCLASS: IN -} - -var uncompressedQueryBytes []byte = []byte{ - 0xbe, 0xef, // ID - 0x01, // QR, OPCODE, AA, TC, RD - 0x00, // RA, Z, RCODE - 0x00, 0x02, // QDCOUNT = 2 - 0x00, 0x00, // ANCOUNT = 0 - 0x00, 0x00, // NSCOUNT = 0 - 0x00, 0x00, // ARCOUNT = 0 - // Question 1 - 0x03, 'f', 'o', 'o', - 0x03, 'b', 'a', 'r', - 0x00, - 0x00, 0x01, // QTYPE: A query - 0x00, 0x01, // QCLASS: IN - // Question 2 - 0x03, 'f', 'o', 'o', - 0x03, 'b', 'a', 'r', - 0x00, - 0x00, 0x01, // QTYPE: A query - 0x00, 0x01, // QCLASS: IN -} - -func init() { - parsedURL, _ = url.Parse(testURL) -} - -// Check that the constructor works. -func TestNewTransport(t *testing.T) { - _, err := NewTransport("test0", testURL, ips, nil) - if err != nil { - t.Fatal(err) - } -} - -// Check that the constructor rejects unsupported URLs. -func TestBadUrl(t *testing.T) { - _, err := NewTransport("test0", "ftp://www.example.com", nil, nil) - if err == nil { - t.Error("Expected error") - } - _, err = NewTransport("test1", "https://www.example", nil, nil) - if err == nil { - t.Error("Expected error") - } -} - -// Check for failure when the query is too short to be valid. -func TestShortQuery(t *testing.T) { - var qerr *dnsx.QueryError - doh, _ := NewTransport("test0", testURL, ips, nil) - _, err := doh.Query("", []byte{}, nil) - if err == nil { - t.Error("Empty query should fail") - } else if !errors.As(err, &qerr) { - t.Errorf("Wrong error type: %v", err) - } else if qerr.Status() != dnsx.BadQuery { - t.Errorf("Wrong error status: %d", qerr.Status()) - } - - _, err = doh.Query([]byte{1}) - if err == nil { - t.Error("One byte query should fail") - } else if !errors.As(err, &qerr) { - t.Errorf("Wrong error type: %v", err) - } else if qerr.Status() != dnsx.BadQuery { - t.Errorf("Wrong error status: %d", qerr.Status()) - } -} - -// Send a DoH query to an actual DoH server -func TestQueryIntegration(t *testing.T) { - queryData := []byte{ - 111, 222, // [0-1] query ID - 1, 0, // [2-3] flags, RD=1 - 0, 1, // [4-5] QDCOUNT (number of queries) = 1 - 0, 0, // [6-7] ANCOUNT (number of answers) = 0 - 0, 0, // [8-9] NSCOUNT (number of authoritative answers) = 0 - 0, 0, // [10-11] ARCOUNT (number of additional records) = 0 - // Start of first query - 7, 'y', 'o', 'u', 't', 'u', 'b', 'e', - 3, 'c', 'o', 'm', - 0, // null terminator of FQDN (DNS root) - 0, 1, // QTYPE = A - 0, 1, // QCLASS = IN (Internet) - } - - testQuery := func(queryData []byte) { - - doh, err := NewTransport("test", testURL, ips, nil) - if err != nil { - t.Fatal(err) - } - resp, err2 := doh.Query(dnsx.NetTypeUDP, queryData, nil) - if err2 != nil { - t.Fatal(err2) - } - if resp[0] != queryData[0] || resp[1] != queryData[1] { - t.Error("Query ID mismatch") - } - if len(resp) <= len(queryData) { - t.Error("Response is short") - } - } - - testQuery(queryData) - - paddedQueryBytes, err := AddEdnsPadding(simpleQueryBytes) - if err != nil { - t.Fatal(err) - } - - testQuery(paddedQueryBytes) -} - -type testRoundTripper struct { - http.RoundTripper - req chan *http.Request - resp chan *http.Response - err error -} - -func makeTestRoundTripper() *testRoundTripper { - return &testRoundTripper{ - req: make(chan *http.Request), - resp: make(chan *http.Response), - } -} - -func (r *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if r.err != nil { - return nil, r.err - } - r.req <- req - return <-r.resp, nil -} - -// Check that a DNS query is converted correctly into an HTTP query. -func TestRequest(t *testing.T) { - doh, _ := NewTransport(testURL, ips, nil, nil, nil) - transport := doh.(*transport) - rt := makeTestRoundTripper() - transport.client.Transport = rt - go doh.Query("", simpleQueryBytes, nil) - req := <-rt.req - if req.URL.String() != testURL { - t.Errorf("URL mismatch: %s != %s", req.URL.String(), testURL) - } - reqBody, err := ioutil.ReadAll(req.Body) - if err != nil { - t.Error(err) - } - if len(reqBody)%PaddingBlockSize != 0 { - t.Errorf("reqBody has unexpected length: %d", len(reqBody)) - } - // Parse reqBody into a Message. - newQuery := mustUnpack(reqBody) - // Ensure the converted request has an ID of zero. - if newQuery.Header.ID != 0 { - t.Errorf("Unexpected request header id: %v", newQuery.Header.ID) - } - // Check that all fields except for Header.ID and Additionals - // are the same as the original. Additionals may differ if - // padding was added. - if !queriesMostlyEqual(simpleQuery, *newQuery) { - t.Errorf("Unexpected query body:\n\t%v\nExpected:\n\t%v", newQuery, simpleQuery) - } - contentType := req.Header.Get("Content-Type") - if contentType != "application/dns-message" { - t.Errorf("Wrong content type: %s", contentType) - } - accept := req.Header.Get("Accept") - if accept != "application/dns-message" { - t.Errorf("Wrong Accept header: %s", accept) - } -} - -// Check that all fields of m1 match those of m2, except for Header.ID -// and Additionals. -func queriesMostlyEqual(m1 dnsmessage.Message, m2 dnsmessage.Message) bool { - // Make fields we don't care about match, so that equality check is easy. - m1.Header.ID = m2.Header.ID - m1.Additionals = m2.Additionals - return reflect.DeepEqual(m1, m2) -} - -// Check that a DOH response is returned correctly. -func TestResponse(t *testing.T) { - doh, _ := NewTransport("test0", testURL, ips, nil) - transport := doh.(*transport) - rt := makeTestRoundTripper() - transport.client.Transport = rt - - // Fake server. - go func() { - <-rt.req - r, w := io.Pipe() - rt.resp <- &http.Response{ - StatusCode: 200, - Body: r, - Request: &http.Request{URL: parsedURL}, - } - // The DOH response should have a zero query ID. - var modifiedQuery dnsmessage.Message = simpleQuery - modifiedQuery.Header.ID = 0 - w.Write(mustPack(&modifiedQuery)) - w.Close() - }() - - resp, err := doh.Query("", simpleQueryBytes, nil) - if err != nil { - t.Error(err) - } - - // Parse the response as a DNS message. - respParsed := mustUnpack(resp) - - // Query() should reconstitute the query ID in the response. - if respParsed.Header.ID != simpleQuery.Header.ID || - !queriesMostlyEqual(*respParsed, simpleQuery) { - t.Errorf("Unexpected response %v", resp) - } -} - -// Simulate an empty response. (This is not a compliant server -// behavior.) -func TestEmptyResponse(t *testing.T) { - doh, _ := NewTransport("test0", testURL, ips, nil) - transport := doh.(*transport) - rt := makeTestRoundTripper() - transport.client.Transport = rt - - // Fake server. - go func() { - <-rt.req - // Make an empty body. - r, w := io.Pipe() - w.Close() - rt.resp <- &http.Response{ - StatusCode: 200, - Body: r, - Request: &http.Request{URL: parsedURL}, - } - }() - - _, err := doh.Query("", simpleQueryBytes, nil) - var qerr *dnsx.QueryError - if err == nil { - t.Error("Empty body should cause an error") - } else if !errors.As(err, &qerr) { - t.Errorf("Wrong error type: %v", err) - } else if qerr.Status() != dnsx.BadResponse { - t.Errorf("Wrong error status: %d", qerr.Status()) - } -} - -// Simulate a non-200 HTTP response code. -func TestHTTPError(t *testing.T) { - doh, _ := NewTransport("test0", testURL, ips, nil) - transport := doh.(*transport) - rt := makeTestRoundTripper() - transport.client.Transport = rt - - go func() { - <-rt.req - r, w := io.Pipe() - rt.resp <- &http.Response{ - StatusCode: 500, - Body: r, - Request: &http.Request{URL: parsedURL}, - } - w.Write([]byte{0, 0, 8, 9, 10}) - w.Close() - }() - - _, err := doh.Query("", simpleQueryBytes, nil) - var qerr *dnsx.QueryError - if err == nil { - t.Error("Empty body should cause an error") - } else if !errors.As(err, &qerr) { - t.Errorf("Wrong error type: %v", err) - } else if qerr.Status() != dnsx.TransportError { - t.Errorf("Wrong error status: %d", qerr.Status()) - } -} - -// Simulate an HTTP query error. -func TestSendFailed(t *testing.T) { - doh, _ := NewTransport("test0", testURL, ips, nil) - transport := doh.(*transport) - rt := makeTestRoundTripper() - transport.client.Transport = rt - - rt.err = errors.New("test") - _, err := doh.Query("", simpleQueryBytes, nil) - var qerr *dnsx.QueryError - if err == nil { - t.Error("Send failure should be reported") - } else if !errors.As(err, &qerr) { - t.Errorf("Wrong error type: %v", err) - } else if qerr.Status() != dnsx.SendFailed { - t.Errorf("Wrong error status: %d", qerr.Status()) - } else if !errors.Is(qerr, rt.err) { - t.Errorf("Underlying error is not retained") - } -} - -type fakeListener struct { - dnsx.DNSListener - summary *dnsx.Summary -} - -func (l *fakeListener) OnQuery(domain string, qtype int, sug string) string { - return "" -} - -func (l *fakeListener) OnResponse(summ *dnsx.Summary) { - l.summary = summ -} - -type fakeConn struct { - net.TCPConn - remoteAddr *net.TCPAddr -} - -func (c *fakeConn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -// Check that the DNSListener is called with a correct summary. -func TestListener(t *testing.T) { - listener := &fakeListener{} - doh, _ := NewTransport("test0", testURL, ips, nil) - transport := doh.(*transport) - rt := makeTestRoundTripper() - transport.client.Transport = rt - - go func() { - req := <-rt.req - trace := httptrace.ContextClientTrace(req.Context()) - trace.GotConn(httptrace.GotConnInfo{ - Conn: &fakeConn{ - remoteAddr: &net.TCPAddr{ - IP: net.ParseIP("192.0.2.2"), - Port: 443, - }}}) - - r, w := io.Pipe() - rt.resp <- &http.Response{ - StatusCode: 200, - Body: r, - Request: &http.Request{URL: parsedURL}, - } - w.Write([]byte{0, 0, 8, 9, 10}) - w.Close() - }() - - doh.Query("", simpleQueryBytes, listener.summary) - s := listener.summary - if s.Server != "192.0.2.2" { - t.Errorf("Wrong server IP string: %s", s.Server) - } - if s.Status != dnsx.Complete { - t.Errorf("Wrong status: %d", s.Status) - } -} - -type socket struct { - r io.ReadCloser - w io.WriteCloser -} - -func (c *socket) Read(b []byte) (int, error) { - return c.r.Read(b) -} - -func (c *socket) Write(b []byte) (int, error) { - return c.w.Write(b) -} - -func (c *socket) Close() error { - e1 := c.r.Close() - e2 := c.w.Close() - if e1 != nil { - return e1 - } - return e2 -} - -func makePair() (io.ReadWriteCloser, io.ReadWriteCloser) { - r1, w1 := io.Pipe() - r2, w2 := io.Pipe() - return &socket{r1, w2}, &socket{r2, w1} -} - -type fakeTransport struct { - dnsx.Transport - query chan []byte - response chan []byte - err error -} - -func (t *fakeTransport) Query(q []byte) ([]byte, error) { - t.query <- q - if t.err != nil { - return nil, t.err - } - return <-t.response, nil -} - -func (t *fakeTransport) GetURL() string { - return "fake" -} - -func (t *fakeTransport) Close() { - t.err = errors.New("closed") - close(t.query) - close(t.response) -} - -func newFakeTransport() *fakeTransport { - return &fakeTransport{ - query: make(chan []byte), - response: make(chan []byte), - } -} - -// Test a successful query over TCP -func TestAccept(t *testing.T) { - doh := newFakeTransport() - client, server := makePair() - - // Start the forwarder running. - go Accept(doh, server) - - lbuf := make([]byte, 2) - // Send Query - queryData := simpleQueryBytes - binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) - n, err := client.Write(lbuf) - if err != nil { - t.Fatal(err) - } - if n != 2 { - t.Error("Length write problem") - } - n, err = client.Write(queryData) - if err != nil { - t.Fatal(err) - } - if n != len(queryData) { - t.Error("Query write problem") - } - - // Read query - queryRead := <-doh.query - if !bytes.Equal(queryRead, queryData) { - t.Error("Query mismatch") - } - - // Send fake response - responseData := []byte{1, 2, 8, 9, 10} - doh.response <- responseData - - // Get Response - n, err = client.Read(lbuf) - if err != nil { - t.Fatal(err) - } - if n != 2 { - t.Error("Length read problem") - } - rlen := binary.BigEndian.Uint16(lbuf) - resp := make([]byte, int(rlen)) - n, err = client.Read(resp) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(responseData, resp) { - t.Error("Response mismatch") - } - - client.Close() -} - -// Sends a TCP query that results in failure. When a query fails, -// Accept should close the TCP socket. -func TestAcceptFail(t *testing.T) { - doh := newFakeTransport() - client, server := makePair() - - // Start the forwarder running. - go Accept(doh, server) - - lbuf := make([]byte, 2) - // Send Query - queryData := simpleQueryBytes - binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) - client.Write(lbuf) - client.Write(queryData) - - // Indicate that the query failed - doh.err = errors.New("fake error") - - // Read query - queryRead := <-doh.query - if !bytes.Equal(queryRead, queryData) { - t.Error("Query mismatch") - } - - // Accept should have closed the socket. - n, _ := client.Read(lbuf) - if n != 0 { - t.Error("Expected to read 0 bytes") - } -} - -// Sends a TCP query, and closes the socket before the response is sent. -// This tests for crashes when a response cannot be delivered. -func TestAcceptClose(t *testing.T) { - doh := newFakeTransport() - client, server := makePair() - - // Start the forwarder running. - go Accept(doh, server) - - lbuf := make([]byte, 2) - // Send Query - queryData := simpleQueryBytes - binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) - client.Write(lbuf) - client.Write(queryData) - - // Read query - queryRead := <-doh.query - if !bytes.Equal(queryRead, queryData) { - t.Error("Query mismatch") - } - - // Close the TCP connection - client.Close() - - // Send fake response too late. - responseData := []byte{1, 2, 8, 9, 10} - doh.response <- responseData -} - -// Test failure due to a response that is larger than the -// maximum message size for DNS over TCP (65535). -func TestAcceptOversize(t *testing.T) { - doh := newFakeTransport() - client, server := makePair() - - // Start the forwarder running. - go Accept(doh, server) - - lbuf := make([]byte, 2) - // Send Query - queryData := simpleQueryBytes - binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) - client.Write(lbuf) - client.Write(queryData) - - // Read query - <-doh.query - - // Send oversize response - doh.response <- make([]byte, 65536) - - // Accept should have closed the socket because the response - // cannot be written. - n, _ := client.Read(lbuf) - if n != 0 { - t.Error("Expected to read 0 bytes") - } -} - -func TestComputePaddingSize(t *testing.T) { - if computePaddingSize(100-kOptPaddingHeaderLen, 100) != 0 { - t.Errorf("Expected no padding") - } - if computePaddingSize(200-kOptPaddingHeaderLen, 100) != 0 { - t.Errorf("Expected no padding") - } - if computePaddingSize(190-kOptPaddingHeaderLen, 100) != 10 { - t.Errorf("Expected to pad up to next block") - } -} - -func TestAddEdnsPaddingIdempotent(t *testing.T) { - padded, err := AddEdnsPadding(simpleQueryBytes) - if err != nil { - t.Fatal(err) - } - paddedAgain, err := AddEdnsPadding(padded) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(padded, paddedAgain) { - t.Errorf("Padding should be idempotent\n%v\n%v", padded, paddedAgain) - } -} - -// Sanity check that packing |compressedQueryBytes| constructs the same query -// byte-for-byte. -func TestDnsMessageCompressedQuerySanityCheck(t *testing.T) { - m := mustUnpack(compressedQueryBytes) - packedBytes := mustPack(m) - if len(packedBytes) != len(compressedQueryBytes) { - t.Errorf("Packed query has different size than original:\n %v\n %v", packedBytes, compressedQueryBytes) - } -} - -// Sanity check that packing |uncompressedQueryBytes| constructs a smaller -// query byte-for-byte, since label compression is enabled by default. -func TestDnsMessageUncompressedQuerySanityCheck(t *testing.T) { - m := mustUnpack(uncompressedQueryBytes) - packedBytes := mustPack(m) - if len(packedBytes) >= len(uncompressedQueryBytes) { - t.Errorf("Compressed query is not smaller than uncompressed query") - } -} - -// Check that we correctly pad an uncompressed query to the nearest block. -func TestAddEdnsPaddingUncompressedQuery(t *testing.T) { - if len(uncompressedQueryBytes)%PaddingBlockSize == 0 { - t.Errorf("uncompressedQueryBytes does not require padding, so this test is invalid") - } - padded, err := AddEdnsPadding(uncompressedQueryBytes) - if err != nil { - panic(err) - } - if len(padded)%PaddingBlockSize != 0 { - t.Errorf("AddEdnsPadding failed to correctly pad uncompressed query") - } -} - -// Check that we correctly pad a compressed query to the nearest block. -func TestAddEdnsPaddingCompressedQuery(t *testing.T) { - if len(compressedQueryBytes)%PaddingBlockSize == 0 { - t.Errorf("compressedQueryBytes does not require padding, so this test is invalid") - } - padded, err := AddEdnsPadding(compressedQueryBytes) - if err != nil { - panic(err) - } - if len(padded)%PaddingBlockSize != 0 { - t.Errorf("AddEdnsPadding failed to correctly pad compressed query") - } -} - -// Try to pad a query that already contains an OPT record, but no padding option. -func TestAddEdnsPaddingCompressedOptQuery(t *testing.T) { - optQuery := simpleQuery - optQuery.Additionals = make([]dnsmessage.Resource, len(simpleQuery.Additionals)) - copy(optQuery.Additionals, simpleQuery.Additionals) - - optQuery.Additionals = append(optQuery.Additionals, - dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName("."), - Class: dnsmessage.ClassINET, - TTL: 0, - }, - Body: &dnsmessage.OPTResource{ - Options: []dnsmessage.Option{}, - }, - }, - ) - paddedOnWire, err := AddEdnsPadding(mustPack(&optQuery)) - if err != nil { - t.Errorf("Failed to pad query with OPT but no padding: %v", err) - } - if len(paddedOnWire)%PaddingBlockSize != 0 { - t.Errorf("AddEdnsPadding failed to correctly pad query with OPT but no padding") - } -} - -// Try to pad a query that already contains an OPT record with padding. The -// query should be unmodified by AddEdnsPadding. -func TestAddEdnsPaddingCompressedPaddedQuery(t *testing.T) { - paddedQuery := simpleQuery - paddedQuery.Additionals = make([]dnsmessage.Resource, len(simpleQuery.Additionals)) - copy(paddedQuery.Additionals, simpleQuery.Additionals) - - paddedQuery.Additionals = append(paddedQuery.Additionals, - dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName("."), - Class: dnsmessage.ClassINET, - TTL: 0, - }, - Body: &dnsmessage.OPTResource{ - Options: []dnsmessage.Option{ - { - Code: OptResourcePaddingCode, - Data: make([]byte, 5), - }, - }, - }, - }, - ) - originalOnWire := mustPack(&paddedQuery) - - paddedOnWire, err := AddEdnsPadding(mustPack(&paddedQuery)) - if err != nil { - t.Errorf("Failed to pad padded query: %v", err) - } - - if !bytes.Equal(originalOnWire, paddedOnWire) { - t.Errorf("AddEdnsPadding tampered with a query that was already padded") - } -} - -func TestServfail(t *testing.T) { - sf := xdns.Servfail(simpleQueryBytes) - servfail := mustUnpack(sf) - expectedHeader := dnsmessage.Header{ - ID: 0xbeef, - Response: true, - OpCode: 0, - Authoritative: false, - Truncated: false, - RecursionDesired: true, - RecursionAvailable: true, - RCode: 2, - } - if servfail.Header != expectedHeader { - t.Errorf("Wrong header: %v != %v", servfail.Header, expectedHeader) - } - if servfail.Questions[0] != simpleQuery.Questions[0] { - t.Errorf("Wrong question: %v", servfail.Questions[0]) - } -} diff --git a/intra/doh/odoh.go b/intra/doh/odoh.go deleted file mode 100644 index 37eb4526..00000000 --- a/intra/doh/odoh.go +++ /dev/null @@ -1,321 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -package doh - -import ( - "bytes" - "errors" - "io" - "net/http" - "net/url" - "time" - - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/cloudflare/odoh-go" - "github.com/miekg/dns" -) - -// adopted from: github.com/folbricht/routedns/pull/118 -// and: github.com/cloudflare/odoh-client/blob/4762219808/commands/request.go - -// constants from: github.com/cloudflare/odoh-client-go/blob/8d45d054d3/commands/common.go#L4 -const odohmimetype = "application/oblivious-dns-message" -const odohconfigdns = "https://1.1.1.1/dns-query" -const odohtargetscheme = "https" -const odohconfigwkpath = "/.well-known/odohconfigs" -const odohtargetpath = "/dns-query" -const odohproxypath = "/proxy" // dns-query in latest spec -const odohttlsec = 3600 // 1hr - -var ( - errMissingOdohQuery = errors.New("no odoh request") - errMissingOdohCfgQuery = errors.New("no odoh config request") - errMissingOdohCfgResponse = errors.New("no odoh config response") - errZeroOdohCfgs = errors.New("no odoh configs found") - errNoOdohConfigUrl = errors.New("no odoh config url") - errNoOdohTarget = errors.New("no odoh target") -) - -// targets: github.com/DNSCrypt/dnscrypt-resolvers/blob/master/v3/odoh-servers.md -// endpoints: github.com/DNSCrypt/dnscrypt-resolvers/blob/master/v3/odoh-relays.md -func (d *transport) doOdoh(pid string, q *dns.Msg) (res *dns.Msg, ech bool, elapsed time.Duration, qerr *dnsx.QueryError) { - var ans []byte - viaproxy := len(d.odohproxyurl) > 0 - - odohmsg, odohctx, err := d.buildTargetQuery(q) - if err != nil { - log.W("odoh: build target query err: %v", err) - qerr = dnsx.NewInternalQueryError(err) - return - } - - oq := odohmsg.Marshal() - req, err := d.asOdohRequest(oq) - if err != nil { - qerr = dnsx.NewInternalQueryError(err) - return - } - - ans, _, _, _, ech, elapsed, qerr = d.do(pid, req) - if settings.Debug { - log.V("odoh: send; proxy? %t, ech? %t, elapsed: %s; err? %v", - viaproxy, ech, elapsed, qerr) - } - if qerr != nil { - // datatracker.ietf.org/doc/rfc9230 section 4.3 and section 7 - // 401 authorization error on hpke failure - // 400 bad request on padding or other failures - // these are "transport errors", in which case we should retry - // but for now, invalidate cached odoh config, if any - if qerr.Status() == dnsx.ClientError { - d.omu.Lock() - d.odohConfig = nil - d.odohConfigExpiry = time.Now() - d.omu.Unlock() - } - res = xdns.Servfail(q) // servfail on the original query - return - } - - oans, err := odoh.UnmarshalDNSMessage(ans) - if err != nil { - qerr = dnsx.NewBadResponseQueryError(err) - return - } - - ans, err = odohctx.OpenAnswer(oans) - if err != nil { - qerr = dnsx.NewInternalQueryError(err) - return - } - - log.V("odoh: success; res: %d", len(ans)) - res = new(dns.Msg) // unpack into a new msg - if err = res.Unpack(ans); err != nil { - qerr = dnsx.NewBadResponseQueryError(err) - return - } - return -} - -func (d *transport) asOdohRequest(q []byte) (req *http.Request, err error) { - viaproxy := len(d.odohproxyurl) > 0 - // ref: github.com/cloudflare/odoh-client-go/blob/8d45d054d3/commands/request.go#L53 - if viaproxy { - req, err = http.NewRequest(http.MethodPost, d.odohproxyurl, bytes.NewBuffer(q)) - if err != nil { - return - } - if req == nil || req.URL == nil { - err = errMissingOdohQuery - return - } - query := req.URL.Query() - if query == nil { - query = make(url.Values) - } - query.Add("targethost", d.odohtargetname) - query.Add("targetpath", d.odohtargetpath) - req.URL.RawQuery = query.Encode() - } else { - req, err = http.NewRequest(http.MethodPost, d.odohTargetUrl(), bytes.NewBuffer(q)) - if err != nil { - return - } - } - req.Header.Set("user-agent", "") - req.Header.Set("content-type", odohmimetype) - req.Header.Add("accept", odohmimetype) - return -} - -func (d *transport) buildTargetQuery(msg *dns.Msg) (m odoh.ObliviousDNSMessage, ctx odoh.QueryContext, err error) { - ocfg, err := d.fetchTargetConfig() - if err != nil { - return - } - if ocfg == nil { - err = errZeroOdohCfgs - return - } - q, err := msg.Pack() - if err != nil { - return - } - - key := ocfg.Contents - pad := xdns.ComputePaddingSize(msg) - oq := odoh.CreateObliviousDNSQuery(q, uint16(pad)) - if settings.Debug { - log.V("odoh: build-target: odoh qlen: %d / pad: %d", len(oq.DnsMessage), len(oq.Padding)) - } - return key.EncryptQuery(oq) -} - -// Get the current (cached) target config or refresh it if expired. -func (d *transport) fetchTargetConfig() (cfg *odoh.ObliviousDoHConfig, err error) { - d.omu.RLock() - ok1 := d.odohConfig != nil - ok2 := time.Now().Before(d.odohConfigExpiry) - d.omu.RUnlock() - - if ok1 && ok2 { // return cached config - log.V("odoh: fetch-target: using cached config for %s", d.odohtargetname) - return d.odohConfig, nil - } - - var exp time.Time - cfg, exp, err = d.refresh() - d.omu.Lock() - d.odohConfig, d.odohConfigExpiry = cfg, exp // may be nil, 0 on error - d.omu.Unlock() - - log.V("odoh: fetch-target: using refereshed config for %s; expiring: %s", d.odohtargetname, exp) - return -} - -func (d *transport) refresh() (cfg *odoh.ObliviousDoHConfig, exp time.Time, err error) { - first := d.refreshTargetKeyDNS - second := d.refreshTargetKeyWellKnown - if d.preferWK { - first = d.refreshTargetKeyWellKnown - second = d.refreshTargetKeyDNS - } - - if cfg, exp, err = first(); err != nil { - d.preferWK = !d.preferWK - if cfg, exp, err = second(); err != nil { - return - } - } - log.V("odoh: fetch-target: %s; expiring: %s", d.odohtargetname, exp) - return -} - -func (d *transport) refreshTargetKeyWellKnown() (ocfg *odoh.ObliviousDoHConfig, exp time.Time, err error) { - var req *http.Request - var resp *http.Response - - req, err = http.NewRequest(http.MethodGet, d.odohConfigUrl(), nil) - if err != nil { - return - } - if req == nil { - err = errMissingOdohCfgQuery - return - } - // may use insecure TLS if user opts in; ref: d.tlsconfig - resp, _, _, err = d.fetch(dnsx.NetBaseProxy, req) - if err != nil { - return - } - if resp == nil || resp.Body == nil { - err = errMissingOdohCfgResponse - return - } - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return - } - - ocfgs, err := odoh.UnmarshalObliviousDoHConfigs(bodyBytes) - if err != nil { - log.W("odoh: refresh-target-wk: unmarshal config err: %v", err) - return - } else if len(ocfgs.Configs) <= 0 { - log.W("odoh: refresh-target-wk: no configs found") - err = errZeroOdohCfgs - return - } - ocfg = &ocfgs.Configs[0] - exp = time.Now().Add(odohttlsec * time.Second) - log.V("odoh: refresh-target-wk: %s; %v; expiring: %s", d.odohtargetname, ocfg, exp) - return -} - -func (d *transport) refreshTargetKeyDNS() (ocfg *odoh.ObliviousDoHConfig, exp time.Time, err error) { - cmsg := new(dns.Msg) - cmsg.SetQuestion(dns.Fqdn(d.odohtargetname), dns.TypeHTTPS) - - var cres *dns.Msg - // fetch odoh-config from the default dns - if cres, err = dialers.Query(cmsg); err != nil { - var req *http.Request - // fetch odoh-config from odohconfigdns - if req, err = d.asDohRequest(cmsg); err == nil { - cres, _, _, _, _, _, err = d.send(dnsx.NetBaseProxy, req) - } - } - - if err != nil { - log.E("odoh: refresh-target: query err %v", err) - return - } - - if cres == nil || !xdns.HasAnyAnswer(cres) { - log.W("odoh: refresh-target: no config ans") - err = errMissingOdohCfgResponse - return - } - - for _, rec := range cres.Answer { - https, ok := rec.(*dns.HTTPS) - if !ok { - log.V("odoh: refresh-target: config not a https record; next") - continue - } - ttlsec := time.Duration(rec.Header().Ttl) * time.Second - for _, kv := range https.Value { - if kv.Key() != 32769 { // up until draft-06, the key was 0x8001 - log.D("odoh: refresh-target: unexpected https record key; next") - continue - } - var ocfgs odoh.ObliviousDoHConfigs - if svcblocal, ok := kv.(*dns.SVCBLocal); ok { - ocfgs, err = odoh.UnmarshalObliviousDoHConfigs(svcblocal.Data) - if err != nil { - log.W("odoh: refresh-target: unmarshal config err: %v", err) - return - } else if len(ocfgs.Configs) <= 0 { - log.W("odoh: refresh-target: no configs found") - err = errZeroOdohCfgs - return - } - ocfg = &ocfgs.Configs[0] - exp = time.Now().Add(ttlsec) - log.V("odoh: refresh-target: %s; %v; expiring: %s", d.odohtargetname, ocfg, exp) - return - } else { - log.D("odoh: refresh-target: not a svcblocal value; next") - } - } - } - - log.W("odoh: refresh-target: no config in https/svcb %d", len(cres.Answer)) - log.V("odoh: refresh-target: dns ans %v", cres.Answer) - err = errMissingOdohCfgResponse - return -} - -func (d *transport) odohTargetUrl() string { - u := new(url.URL) - u.Scheme = odohtargetscheme - u.Path = d.odohtargetpath - u.Host = d.odohtargetname - return u.String() -} - -func (d *transport) odohConfigUrl() string { - u := new(url.URL) - u.Scheme = odohtargetscheme - u.Path = odohconfigwkpath - u.Host = d.odohtargetname - return u.String() -} diff --git a/intra/doh/padding.go b/intra/doh/padding.go deleted file mode 100644 index 8fd29359..00000000 --- a/intra/doh/padding.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package doh - -import ( - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -// padQuery adds EDNS padding (RFC7830) to msg. -func padQuery(msg *dns.Msg) { - defer core.Recover(core.DontExit, msg) - xdns.AddEDNS0PaddingIfNoneFound(msg) -} diff --git a/intra/icmp.go b/intra/icmp.go deleted file mode 100644 index 2c52e5c5..00000000 --- a/intra/icmp.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "context" - "net" - "net/netip" - "time" - - "golang.org/x/sys/unix" - - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/netstack" -) - -type icmpHandler struct { - *baseHandler -} - -var _ netstack.GICMPHandler = (*icmpHandler)(nil) - -func NewICMPHandler(pctx context.Context, resolver dnsx.Resolver, prox ipn.ProxyProvider, listener Listener) netstack.GICMPHandler { - h := &icmpHandler{ - baseHandler: newBaseHandler(pctx, "icmp", resolver, prox, listener), - } - - core.Gx("icmp.ps", h.processSummaries) - - log.I("icmp: new handler created") - return h -} - -// Ping implements netstack.GICMPHandler. Takes ownership of msg. -// Nb: to send icmp pings, root access is required; and so, -// send "unprivileged" icmp pings via udp reqs; which do -// work on Vanilla Android, because ping_group_range is -// set to 0 2147483647 -// ref: cs.android.com/android/platform/superproject/+/master:system/core/rootdir/init.rc;drc=eef0f563fd2d16343aa1ac01eebe98126f26e352;l=297 -// ref: androidxref.com/9.0.0_r3/xref/libcore/luni/src/test/java/libcore/java/net/InetAddressTest.java#265 -// see: sturmflut.github.io/linux/ubuntu/2015/01/17/unprivileged-icmp-sockets-on-linux/ -// ex: github.com/prometheus-community/pro-bing/blob/0bacb2d5e/ping.go#L703 -func (h *icmpHandler) Ping(msg []byte, source, target netip.AddrPort) (echoed bool) { - var px ipn.Proxy = nil - var err error - var tx, rx int - var rtt time.Duration - - // flow is alg/nat-aware, do not change target or any addrs - res, undidAlg, realips, doms := h.onFlow(source, target) - - h.maybeReplaceDest(res, &target) - - preferred, _, _ := filterFamilyForDialing(realips) - dst := oneRealIPPort(preferred, target, !undidAlg) - // on Android, uid is always "unknown" for icmp - cid, uid, _, pids := h.judge(res) - smm := icmpSummary(cid, uid) - - defer func() { - smm.PID = pidstr(px) - smm.RPID = ipn.ViaID(px) - smm.Tx = int64(tx) - smm.Rx = int64(rx) - smm.Rtt = rtt.Milliseconds() - smm.Target = dst.Addr().String() - h.queueSummary(smm.done(err)) // err may be nil - }() - - if h.status.Load() == HDLEND { - err = log.EE("t.icmp: handler ended (%s => %s)", source, target) - return false // not handled - } - - if isAnyBlockPid(pids) { - smm.PID = ipn.Block - if undidAlg && len(realips) <= 0 && len(doms) > 0 { - err = errNoIPsForDomain - } else { - err = errIcmpFirewalled - } - log.I("t.icmp: egress: firewalled %s => %s", source, target) - // sleep for a while to avoid busy conns? will also block netstack - // see: netstack/dispatcher.go:newReadvDispatcher - // time.Sleep(blocktime) - return false // denied - } - - if px, err = h.prox.ProxyTo(dst, uid, pids); err != nil || px == nil { - err = log.EE("t.icmp: egress: no proxy(%s); err %v", pids, err) - return false // denied - } - - rttstart := time.Now() - proto, anyaddr := anyaddrFor(dst) - - uc, err := px.Dialer().Probe(proto, anyaddr) - defer core.Close(uc) - ucnil := uc == nil || core.IsNil(uc) - - // nilaway: tx.socks5 returns nil conn even if err == nil - if err != nil || ucnil { - err = core.OneErr(err, unix.ENETUNREACH) - err = log.EE("t.icmp: egress: dial(%s); hasConn? %s(%t); err %v", - dst, pids, !ucnil, err) - return false // unhandled - } - - h.conntracker.Track(cid, uid, pidstr(px), uc) - defer h.conntracker.Untrack(cid) - - awaited := core.Await(func() { - h.listener.PostFlow(smm.postMark()) - }, onFlowTimeout) - - tx = len(msg) - // todo: construct ICMP header? github.com/prometheus-community/pro-bing/blob/0bacb2d5e7/ping.go#L717 - reply, from, err := core.Echo(uc, msg, net.UDPAddrFromAddrPort(dst), target.Addr().Is4()) - rx = len(reply) - rtt = time.Since(rttstart) - // todo: ignore non-ICMP replies in b: github.com/prometheus-community/pro-bing/blob/0bacb2d5e7/ping.go#L630 - logev(err)("t.icmp: ingress: read(%v <= %v / %v) ping done (send: %d, recv: %d, rtt: %s); postflow? %t; err? %v", - source, from, dst, tx, rx, core.FmtPeriod(rtt), awaited, err) - - // TODO: on timeout errs, return false? - return true // echoed -} diff --git a/intra/ipn/auto.go b/intra/ipn/auto.go deleted file mode 100644 index 562c9f48..00000000 --- a/intra/ipn/auto.go +++ /dev/null @@ -1,520 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "net" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" -) - -const ( - ttl30s = 30 * time.Second - shortdelay = 100 * time.Millisecond - delayForUnhealthyProxies = 2 * time.Second -) - -// exit is a proxy that always dials out to the internet. -type auto struct { - NoDNS - ProtoAgnostic - SkipRefresh - CantPause - GW - pxr ProxyProvider - addr string - - via *core.WeakRef[Proxy] // via dialer - viaID *core.Volatile[string] // via ID - - exp *core.Sieve[string, int] - ba *core.Barrier[bool, string] - status *core.Volatile[int] -} - -// NewAutoProxy returns a new exit proxy. -func NewAutoProxy(ctx context.Context, pxr Proxies) *auto { - var err error - - h := &auto{ - pxr: pxr, - viaID: core.NewZeroVolatile[string](), - addr: "127.5.51.52:5321", - exp: core.NewSieve[string, int](ctx, ttl30s), - ba: core.NewBarrier[bool](ttl30s), - status: core.NewVolatile(TUP), - } - h.via, err = core.NewWeakRef(h.viafor, viaok) - if err != nil { - panic(err) // unlikely - } - return h -} - -func (h *auto) viafor() *Proxy { - return viafor(idstr(h), h.viaID.Load(), h.pxr) -} - -func (h *auto) swapVia(new Proxy) Proxy { - return swapVia(idstr(h), new, h.viaID, h.via) -} - -// Handle implements Proxy. -func (h *auto) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *auto) DialerHandle() (mix uintptr) { - remoteOnly := settings.AutoAlwaysRemote() - if !remoteOnly { - if exit, _ := h.pxr.ProxyFor(Exit); exit != nil { - mix ^= exit.DialerHandle() - } - if exit64, _ := h.pxr.ProxyFor(Rpn64); exit64 != nil { - mix ^= exit64.DialerHandle() - } - } - if win, _ := h.pxr.mainRpnProxyOf(RpnWin); win != nil { - mix ^= win.DialerHandle() - } - - return mix -} - -// Dial implements Proxy. -func (h *auto) Dial(network, addr string) (protect.Conn, error) { - return h.dial(network, "", addr) -} - -// DialBind implements Proxy. -func (h *auto) DialBind(network, local, remote string) (protect.Conn, error) { - return h.dial(network, local, remote) -} - -func (h *auto) dial(network, laddr, raddr string) (protect.Conn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - - exit, exerr := h.pxr.ProxyFor(Exit) - exit64, ex64err := h.pxr.ProxyFor(Rpn64) - win, winerr := h.pxr.mainRpnProxyOf(RpnWin) - - pxrerrs := core.JoinErr(exerr, winerr, ex64err) - - if usevia(h.viaID) { - if v, vok := h.via.Get(); !vok { - if removeViaOnErrors { - h.Hop(nil, false /*dryrun*/) // stale; unset - } - log.W("proxy: auto: via(%s) failing...", idhandle(v)) - } - } - - remoteOnly := settings.AutoAlwaysRemote() - parallelDial := settings.AutoDialsParallel.Load() - - var c protect.Conn - var err error - - // non-parallel dial states - who := -1 - previdx, recent := h.exp.Get(raddr) - delpin := false - - // parallel dial states - tothealthy := -1 - totdials := -1 - - if !parallelDial { - rpns := []Proxy{exit, exit64, win} - healthy := core.Map( - core.FilterLeft( - rpns, - func(p Proxy) bool { - if p == nil || core.IsNil(p) { - return false // nil proxies out - } - if remoteOnly && local(idstr(p)) { - return false // local proxies out - } - if err := healthy(p); err != nil { - log.D("auto: dial; %s %s not ok; %v: %s", p.ID(), network, err, raddr) - return false // not healthy out - } - return true // ok - }), - func(p Proxy) protect.RDialer { - return p.Dialer() - }) - - tothealthy = len(healthy) - if len(healthy) > 0 { - // dial healthy proxies - c, err = dialAny(healthy, network, laddr, raddr) - totdials = len(healthy) - } else { - // no healthy proxies; fail open - d := core.Map(rpns, func(p Proxy) protect.RDialer { - if p == nil || core.IsNil(p) { - return nil // nil proxies out - } - if remoteOnly && local(idstr(p)) { - return nil // local proxies out - } - return p.Dialer() - }) - totdials = len(d) - if len(d) > 0 { - // dialAny delegates to dialers.DialAny which pins IPs - // to proxies (against their IDs) for 30s. - c, err = dialAny(core.WithoutNils(d), network, laddr, raddr) - } else { - c, err = nil, core.OneErr(pxrerrs, errNoProxyHealthy) - } - } - } else { - c, who, err = core.Race( - network+".dial-auto."+raddr, - tlsHandshakeTimeout, - func(ctx context.Context) (protect.Conn, error) { - const myidx = 0 - if exit == nil { // exit must always be present - return nil, exerr - } - if !remoteOnly { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: // dial ahead - } - } else { - return nil, errNotRemote - } - if recent { - if previdx != myidx { - return nil, errNotPinned - } - // ip pinned to this proxy - return h.dialAlways(exit, network, laddr, raddr) - } - return h.dialIfReachable(exit, network, laddr, raddr) - }, func(ctx context.Context) (protect.Conn, error) { - const myidx = 1 - if exit64 == nil { - return nil, ex64err - } - if remoteOnly { - return nil, errNotRemote - } - if recent { - if previdx != myidx { - return nil, errNotPinned - } - // ip pinned to this proxy - return h.dialAlways(exit64, network, laddr, raddr) - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(shortdelay * myidx): // 300ms - } - return h.dialIfHealthy(exit64, network, laddr, raddr) - }, func(ctx context.Context) (protect.Conn, error) { - const myidx = 3 - if win == nil { - return nil, winerr - } - if recent { - if previdx != myidx { - return nil, errNotPinned - } - // ip pinned to this proxy - return h.dialAlways(win, network, laddr, raddr) - } - - // wait only if exit was used - if !remoteOnly { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(shortdelay * myidx): // 500ms - } - } - return h.dialIfHealthy(win, network, laddr, raddr) - }, - ) - - if err != nil || c == nil || core.IsNil(c) { - h.exp.Del(raddr) - c = nil - delpin = true // remove pin - } else { - h.exp.Put(raddr, who) - } - } - - defer localDialStatus(h.status, err) - - kaenabled := maybeKeepAlive(c) - logei(err)("proxy: auto: w(%d) pin(%t+%t/%d), dial(%s) %s, ka? %t / parallel? %t / remote? %t; tot(healthy %d / dials %d); errs? %v+%v", - who, recent, !delpin, previdx, network, raddr, kaenabled, parallelDial, remoteOnly, tothealthy, totdials, err, pxrerrs) - - return c, err -} - -// Announce implements Proxy. -func (h *auto) Announce(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - - exit, exerr := h.pxr.ProxyFor(Exit) - win, winerr := h.pxr.mainRpnProxyOf(RpnWin) - - previdx, recent := h.exp.Get(local) - - // TODO: announceIfHealthy - c, who, err := core.Race( - network+".announce-auto."+local, - tlsHandshakeTimeout, - func(ctx context.Context) (protect.PacketConn, error) { - const myidx = 0 - if exit == nil { - return nil, exerr - } - if recent { - if previdx != myidx { - return nil, errNotPinned - } - // ip pinned to this proxy - return h.announceIfHealthy(exit, network, local) - } - return h.announceIfHealthy(exit, network, local) - }, func(ctx context.Context) (protect.PacketConn, error) { - const myidx = 1 - if win == nil { - return nil, winerr - } - if recent { - if previdx != myidx { - return nil, errNotPinned - } - // ip pinned to this proxy - return h.announceIfHealthy(win, network, local) - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(shortdelay * myidx): // 100ms - } - return h.announceIfHealthy(win, network, local) - }, - ) - defer localDialStatus(h.status, err) - - log.I("proxy: auto: w(%d) listen(%s) to %s; err? %v", who, network, local, err) - return c, err -} - -// Accept implements Proxy. -func (h *auto) Accept(network, local string) (l protect.Listener, err error) { - if err := candial(h.status); err != nil { - return nil, err - } - if settings.AutoAlwaysRemote() { - log.E("proxy: auto: accept(%s) on %s remote-dial unimplemented", network, local) - return nil, errNotRemote - } - exit, err := h.pxr.ProxyFor(Exit) - if err == nil { - l, err = exit.Dialer().Accept(network, local) - } - defer localDialStatus(h.status, err) - - log.I("proxy: auto: accept(%s) on %s; err? %v", network, local, err) - return l, err -} - -// Probe implements Proxy. -func (h *auto) Probe(network, local string) (pc protect.PacketConn, err error) { - if err := candial(h.status); err != nil { - return nil, err - } - if settings.AutoAlwaysRemote() { - log.E("proxy: auto: probe(%s) on %s remote-dial unimplemented", network, local) - return nil, errNotRemote - } - // todo: rpnwg, rpnamz, rpnwin - exit, err := h.pxr.ProxyFor(Exit) - if err == nil { - pc, err = exit.Dialer().Probe(network, local) - } - defer localDialStatus(h.status, err) - - log.I("proxy: auto: probe(%s) on %s; err? %v", network, local, err) - return pc, err -} - -// Dialer implements Proxy. -func (h *auto) Dialer() protect.RDialer { - return h -} - -// ID implements x.Proxy. -func (h *auto) ID() string { - return Auto -} - -// Type implements x.Proxy. -func (h *auto) Type() string { - return RPN -} - -// Router implements x.Proxy. -func (h *auto) Router() x.Router { - return h -} - -// Reaches implements x.Router. -func (h *auto) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (h *auto) Hop(p Proxy, dryrun bool) error { - if p == nil { - if !dryrun { - old := h.swapVia(nil) - log.I("proxy: auto: hop(%s) removed", idhandle(old)) - } - return nil - } - if p.Status() == END { - return errProxyStopped - } - - var win Proxy - var waerr, winerr error - old := h.swapVia(p) - if win, winerr = h.pxr.mainRpnProxyOf(RpnWin); win != nil { - winerr = win.Hop(p, dryrun) - } - - errs := core.JoinErr(waerr, winerr) // may be nil - - logei(errs)("proxy: auto: hop(%s) => %s; errs? %v", - idhandle(old), idhandle(p), errs) - - return errs -} - -func (h *auto) Via() (x.Proxy, error) { - if v := h.via.Load(); v != nil { - return v, nil - } - return nil, errNoHop -} - -// GetAddr implements x.Proxy. -func (h *auto) GetAddr() string { - return h.addr -} - -// Status implements x.Proxy. -func (h *auto) Status() int { - return h.status.Load() -} - -// Stop implements x.Proxy. -func (h *auto) Stop() error { - h.status.Store(END) - h.exp.Clear() - log.I("proxy: auto: stopped") - return nil -} - -// dialIfReachable currently aliases dialIfHealthy. -func (h *auto) dialIfReachable(p Proxy, network, local, remote string) (net.Conn, error) { - // remote is oftimes a hostname; in which case hasroute would error out (as it - // works with ip addresses only). The alternative is to get the ipmap from dialers pkg - // but that would be redundant to what the individual proxy implementations already do. - // if !hasroute(p, remote) { - // return nil, fmt.Errorf("auto; %s: %v", p.ID(), errNoRouteToHost) - // } - // some IPs never respond to ping; ex: 34.245.245.138:443, 63.32.2.144:80 - // even if they respond over tcp/udp on the same ip:port. - // ipp, _ := netip.ParseAddrPort(remote) - // if reachable, err := h.ba.DoIt(p.ID()+remote, remote), icmpReachesWork(p, ipp)); err != nil { - // return nil, fmt.Errorf("auto; %s ping %s: %v", p.ID(), remote, err) - // } else if !reachable { - // return nil, fmt.Errorf("auto; %s: %v: %s", p.ID(), errNoRouteToHost, remote) - // } - return h.dialIfHealthy(p, network, local, remote) -} - -func (*auto) dialAlways(p Proxy, network, local, remote string) (net.Conn, error) { - err := healthy(p) - if err != nil { - log.E("auto dial; %s %s not ok; to %s; err: %v", idstr(p), network, remote, err) - } - if len(local) > 0 { - return p.Dialer().DialBind(network, local, remote) - } - return p.Dialer().Dial(network, remote) -} - -func (a *auto) dialIfHealthy(p Proxy, network, local, remote string) (net.Conn, error) { - if err := healthy(p); err != nil { - log.E("auto dial; %s %s not ok; %v: %s", p.ID(), network, err, remote) - time.Sleep(delayForUnhealthyProxies) - } - if len(local) > 0 { - return p.Dialer().DialBind(network, local, remote) - } - return p.Dialer().Dial(network, remote) -} - -func (*auto) announceIfHealthy(p Proxy, network, local string) (net.PacketConn, error) { - if err := healthy(p); err != nil { - log.E("auto announce; %s %s not ok; %v: %s", p.ID(), network, err, local) - time.Sleep(delayForUnhealthyProxies) - } - return p.Dialer().Announce(network, local) -} - -func maybeKeepAlive(c net.Conn) (keepingalive bool) { - keepingalive, _ = maybeKeepAlive2(c) - return -} - -func maybeKeepAlive2(c net.Conn) (keepingalive, ok bool) { - if c == nil || core.IsNil(c) { - return - } - - if opts := settings.GetDialerOpts(); opts.LowerKeepAlive { - // adjust socket's keepalive config - lowered := core.SetKeepAliveConfigSockOpt(c) - keepingalive = lowered - ok = lowered - return - } - // disable socket keepalive - disabled := core.DisableKeepAlive(c) - keepingalive = !disabled - ok = disabled - return -} diff --git a/intra/ipn/base.go b/intra/ipn/base.go deleted file mode 100644 index 6d9e54bf..00000000 --- a/intra/ipn/base.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" -) - -const fakeBaseAddr = "127.8.4.5:3690" - -// base is no-op proxy that dials into the underlying network, -// which typically is wifi or mobile but may also be a tun device. -type base struct { - NoDNS - ProtoAgnostic - SkipRefresh - CantPause - GW - id string - addr string - outbound *protect.RDial // outbound dialer - via *core.WeakRef[Proxy] // via dialer - viaID *core.Volatile[string] // via proxy ID - px ProxyProvider - status *core.Volatile[int] - done context.CancelFunc -} - -// Base returns a base proxy. -func NewBaseProxy(ctx context.Context, c protect.Controller, px ProxyProvider) *base { - return newBasicProxy(Base, fakeBaseAddr, ctx, c, px) -} - -func newBasicProxy(id, addr string, ctx context.Context, c protect.Controller, px ProxyProvider) *base { - ctx, done := context.WithCancel(ctx) - h := &base{ - id: id, - addr: addr, - px: px, - outbound: protect.MakeNsRDial(Base, ctx, c), - viaID: core.NewZeroVolatile[string](), - status: core.NewVolatile(TUP), - done: done, - } - var err error - h.via, err = core.NewWeakRef(h.viafor, viaok) - if err != nil { - panic(err) // unlikely - } - return h -} - -func NewBasicProxy(id string, ctx context.Context, c protect.Controller, px ProxyProvider) Proxy { - return newBasicProxy(id, fakeBaseAddr, ctx, c, px) -} - -func (h *base) viafor() *Proxy { - return viafor(idstr(h), h.viaID.Load(), h.px) -} - -func (h *base) swapVia(new Proxy) (old Proxy) { - return swapVia(idstr(h), new, h.viaID, h.via) -} - -// Handle implements Proxy. -func (h *base) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *base) DialerHandle() uintptr { - return core.Loc(h.outbound) -} - -// Dial implements Proxy. -func (h *base) Dial(network, addr string) (c protect.Conn, err error) { - return h.dial(network, "", addr) -} - -// DialBind implements Proxy. -func (h *base) DialBind(network, local, remote string) (c protect.Conn, err error) { - return h.dial(network, local, remote) -} - -func (h *base) dial(network, local, remote string) (c protect.Conn, err error) { - if err := candial(h.status); err != nil { - return nil, err - } - - who := idstr(h) - if usevia(h.viaID) { - if v, vok := h.via.Get(); vok { // dial via another proxy - who = idstr(v) - c, err = v.DialBind(network, local, remote) - } else { - err = errNoHop - if removeViaOnErrors { - h.Hop(nil, false /*dryrun*/) // stale; unset - } - log.W("proxy: base: via(%s) failing...", idhandle(v)) - } - } else { - if settings.Loopingback.Load() { // loopback (rinr) mode - // TODO: test if binding to local address works in rinr mode - c, err = dialers.DialBind(h.outbound, network, local, remote) - } else { - c, err = localDialStrat(h.outbound, network, local, remote) - } - } - defer localDialStatus(h.status, err) - - kaenabled := maybeKeepAlive(c) - log.I("proxy: base: dial(%s) to %s=>%s (via %s), ka? %t; err? %v", - network, local, remote, who, kaenabled, err) - return -} - -// Announce implements Proxy. -func (h *base) Announce(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - c, err := dialers.ListenPacket(h.outbound, network, local) - defer localDialStatus(h.status, err) - log.I("proxy: base: announce(%s) on %s; err? %v", network, local, err) - return c, err -} - -// Accept implements Proxy. -func (h *base) Accept(network, local string) (protect.Listener, error) { - if err := candial(h.status); err != nil { - return nil, err - } - return dialers.Listen(h.outbound, network, local) -} - -// Probe implements Proxy. -func (h *base) Probe(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - c, err := dialers.Probe(h.outbound, network, local) - defer localDialStatus(h.status, err) - log.I("proxy: base: probe(%s) on %s; err? %v", network, local, err) - return c, err -} - -func (h *base) Dialer() protect.RDialer { - return h -} - -func (h *base) ID() string { - return Base -} - -func (h *base) Type() string { - return NOOP -} - -func (h *base) Router() x.Router { - return h -} - -// Reaches implements x.Router. -func (h *base) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (h *base) Hop(p Proxy, dryrun bool) error { - if p == nil { - if !dryrun { - old := h.swapVia(nil) - log.I("proxy: base: hop(%s) removed", idhandle(old)) - } - return nil - } - if p.Status() == END { - return errProxyStopped - } - if idstr(p) != GlobalH1 { - return errHopGlobalProxy - } - - if !dryrun { - old := h.swapVia(nil) - log.I("proxy: base: hop %s => %s", idhandle(old), idhandle(p)) - } - return nil -} - -// Via implements x.Router. -func (h *base) Via() (x.Proxy, error) { - if v := h.via.Load(); v != nil { - return v, nil - } - return nil, errNoHop -} - -func (h *base) GetAddr() string { - return h.addr -} - -func (h *base) Status() int { - return h.status.Load() -} - -func (h *base) Stop() error { - h.status.Store(END) - h.done() - log.I("proxy: base: stopped") - return nil -} diff --git a/intra/ipn/exit.go b/intra/ipn/exit.go deleted file mode 100644 index 749f3d9e..00000000 --- a/intra/ipn/exit.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - crand "crypto/rand" - "encoding/hex" - "math/rand/v2" - "net" - "strconv" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" -) - -const ( - fakeExitAddr = "127.0.0.127" - fakeExitPort = "1337" -) - -// exit is a proxy that always dials out to the internet. -type exit struct { - NoDNS - ProtoAgnostic - SkipRefresh - CantPause - GWNoVia - id string - addr string - outbound *protect.RDial // outbound dialer - status *core.Volatile[int] - done context.CancelFunc -} - -// NewExitProxy returns a new exit proxy. -func NewExitProxy(ctx context.Context, c protect.Controller) *exit { - return newExitProxy(Exit, net.JoinHostPort(fakeExitAddr, fakeExitPort), ctx, c) -} - -func NewExitProxyWithID(id, addr string, ctx context.Context, c protect.Controller) *exit { - if len(id) <= 0 { - id = hex8() - } - if len(addr) <= 0 { - addr = net.JoinHostPort(fakeExitAddr, no65535()) - } - return newExitProxy(id, addr, ctx, c) -} - -func newExitProxy(id, addr string, ctx context.Context, c protect.Controller) *exit { - ctx, done := context.WithCancel(ctx) - h := &exit{ - id: id, - addr: addr, - outbound: protect.MakeNsRDial(Exit, ctx, c), - status: core.NewVolatile(TUP), - done: done, - } - return h -} - -// Handle implements Proxy. -func (h *exit) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *exit) DialerHandle() uintptr { - return core.Loc(h.outbound) -} - -// Dial implements Proxy. -func (h *exit) Dial(network, addr string) (protect.Conn, error) { - return h.dial(network, "", addr) -} - -// DialBind implements Proxy. -func (h *exit) DialBind(network, local, remote string) (protect.Conn, error) { - return h.dial(network, local, remote) -} - -func (h *exit) dial(network, local, remote string) (protect.Conn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - // exit always splits - c, err := localDialStrat(h.outbound, network, local, remote) - defer localDialStatus(h.status, err) - - kaenabled := maybeKeepAlive(c) - log.I("proxy: %s: dial(%s) %s => %s, ka? %t; err? %v", - h.id, network, local, remote, kaenabled, err) - return c, err -} - -// Announce implements Proxy. -func (h *exit) Announce(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - c, err := dialers.ListenPacket(h.outbound, network, local) - defer localDialStatus(h.status, err) - log.I("proxy: %s: announce(%s) on %s; err? %v", h.id, network, local, err) - return c, err -} - -// Accept implements Proxy. -func (h *exit) Accept(network, local string) (protect.Listener, error) { - if err := candial(h.status); err != nil { - return nil, err - } - return dialers.Listen(h.outbound, network, local) -} - -// Probe implements Proxy. -func (h *exit) Probe(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - c, err := dialers.Probe(h.outbound, network, local) - defer localDialStatus(h.status, err) - log.I("proxy: %s: probe(%s) on %s; err? %v", h.id, network, local, err) - return c, err -} - -// Dialer implements Proxy. -func (h *exit) Dialer() protect.RDialer { - return h -} - -// ID implements x.Proxy. -func (h *exit) ID() string { - return h.id -} - -// Type implements x.Proxy. -func (h *exit) Type() string { - return INTERNET -} - -// Router implements x.Proxy. -func (h *exit) Router() x.Router { - return h -} - -// Reaches implements x.Router. -func (h *exit) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// GetAddr implements x.Proxy. -func (h *exit) GetAddr() string { - return h.addr -} - -// Status implements x.Proxy. -func (h *exit) Status() int { - return h.status.Load() -} - -// Stop implements x.Proxy. -func (h *exit) Stop() error { - h.status.Store(END) - h.done() - log.I("proxy: %s: stopped", h.id) - return nil -} - -func localDialStatus(status *core.Volatile[int], err error) bool { - cur := status.Load() - if cur == END || cur == TPU { - return false - } - if err != nil { - return status.Cas(cur, TKO) - } - return status.Cas(cur, TOK) -} - -func idhandle(p Proxy) string { - if p == nil || core.IsNil(p) { - return "" - } - return idstr(p) + "@" + strconv.Itoa(int(p.Handle())) -} - -func idstr(p x.Proxy) string { - if p == nil || core.IsNil(p) { - return "" - } - return p.ID() -} - -func typstr(p x.Proxy) string { - if p == nil || core.IsNil(p) { - return "" - } - return p.Type() -} - -// create a random hex character string of length 8 -func hex8() string { - b := make([]byte, 4) - if _, err := crand.Read(b); err != nil { - return "deadbeef" - } - return hex.EncodeToString(b) -} - -func no65535() string { - no := max(rand.IntN(65535), 1024) - return strconv.Itoa(no) -} diff --git a/intra/ipn/exit64.go b/intra/ipn/exit64.go deleted file mode 100644 index 06fc3e2d..00000000 --- a/intra/ipn/exit64.go +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "net" - "net/netip" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/ipn/rpn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" -) - -var ( - anyaddr4 = netip.IPv4Unspecified() - anyaddr6 = netip.IPv6Unspecified() -) - -// exit64 is a proxy that always dials out to the internet -// over well-known preset public NAT64 prefixes. -type exit64 struct { - NoDNS - ProtoAgnostic - SkipRefresh - GWNoVia - - rpn.RpnForever - rpn.RpnStateless - rpn.RpnCountryless - - outbound *protect.RDial // outbound dialer - addr string - since time.Time - status *core.Volatile[int] - done context.CancelFunc -} - -var _ RpnAcc = (*exit64)(nil) - -// NewExit64Proxy returns a new exit64 proxy. -func NewExit64Proxy(ctx context.Context, c protect.Controller) *exit64 { - ctx, done := context.WithCancel(ctx) - h := &exit64{ - addr: "127.64.64.127:6464", - // "Exit" as "id" to have all its sockets "protected" - outbound: protect.MakeNsRDial(Exit, ctx, c), - status: core.NewVolatile(TUP), - since: time.Now(), - done: done, - } - return h -} - -// Handle implements Proxy. -func (h *exit64) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *exit64) DialerHandle() uintptr { - return core.Loc(h.outbound) -} - -// Dial implements Proxy. -func (h *exit64) Dial(network, addr string) (protect.Conn, error) { - return h.dial(network, "", addr) -} - -// DialBind implements Proxy. -func (h *exit64) DialBind(network, local, remote string) (protect.Conn, error) { - return h.dial(network, local, remote) -} - -func (h *exit64) dial(network, local, remote string) (protect.Conn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - - addr64 := addr4to6(remote) - local64 := anyaddr4to6(local) - if len(addr64) <= 0 || (len(local) > 0 && len(local64) <= 0) { - return nil, errNoAuto464XLAT - } - - // exit64 always splits - c, err := localDialStrat(h.outbound, network, local64, addr64) - defer localDialStatus(h.status, err) - - kaenabled := maybeKeepAlive(c) - log.I("proxy: exit64: dial(%s) %s via %s to %s, ka? %t; err? %v", - network, local64, remote, addr64, kaenabled, err) - - return c, err -} - -// Announce implements Proxy. -func (h *exit64) Announce(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - var local64 string - if ipp, _ := netip.ParseAddrPort(local); ipp.IsValid() { - if ipp.Addr().Is4() { - local64 = netip.AddrPortFrom(netip.IPv6Unspecified(), ipp.Port()).String() - } else { - local64 = local - } - } - if len(local64) <= 0 { - return nil, errNoAuto464XLAT - } - - c, err := dialers.ListenPacket(h.outbound, network, local64) - defer localDialStatus(h.status, err) - - log.I("proxy: exit64: announce(%s) via %s on %s; err? %v", network, local64, local, err) - return c, err -} - -// Accept implements Proxy. -func (h *exit64) Accept(network, local string) (protect.Listener, error) { - if err := candial(h.status); err != nil { - return nil, err - } - var local64 string - if ipp, _ := netip.ParseAddrPort(local); ipp.IsValid() { - if ipp.Addr().Is4() { - local64 = netip.AddrPortFrom(netip.IPv6Unspecified(), ipp.Port()).String() - } else { - local64 = local - } - } - if len(local64) <= 0 { - return nil, errNoAuto464XLAT - } - - l, err := dialers.Listen(h.outbound, network, local) - defer localDialStatus(h.status, err) - - log.I("proxy: exit64: accept(%s) via %s on %s; err? %v", network, local64, local, err) - return l, err -} - -// Probe implements Proxy. -func (h *exit64) Probe(network, local string) (protect.PacketConn, error) { - if err := candial(h.status); err != nil { - return nil, err - } - var local64 string - if ipp, _ := netip.ParseAddrPort(local); ipp.IsValid() { - if ipp.Addr().Is4() { - local64 = netip.AddrPortFrom(netip.IPv6Unspecified(), ipp.Port()).String() - } else { - local64 = local - } - } - if len(local64) <= 0 { - return nil, errNoAuto464XLAT - } - - c, err := dialers.Probe(h.outbound, network, local) - defer localDialStatus(h.status, err) - - log.I("proxy: exit64: probe(%s) via %s on %s; err? %v", network, local64, local, err) - return c, err -} - -// Dialer implements Proxy. -func (h *exit64) Dialer() protect.RDialer { - return h -} - -// ID implements Proxy. -func (h *exit64) ID() string { - return Rpn64 -} - -// Type implements Proxy. -func (h *exit64) Type() string { - return INTERNET -} - -// Router implements Proxy. -func (h *exit64) Router() x.Router { - return h -} - -// Reaches implements x.Router. -func (h *exit64) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// GetAddr implements Proxy. -func (h *exit64) GetAddr() string { - return h.addr -} - -// Status implements Proxy. -func (h *exit64) Status() int { - return h.status.Load() -} - -// Since implements x.Proxy. -func (h *exit64) Pause() bool { - st := h.status.Load() - if st == END { - log.W("proxy: exit64: pause called when stopped") - return false - } - - ok := h.status.Cas(st, TPU) - log.I("proxy: exit64: paused? %t", ok) - return ok -} - -// Resume implements x.Proxy. -func (h *exit64) Resume() bool { - st := h.status.Load() - if st != TPU { - log.W("proxy: exit64: resume called when not paused; status %d", st) - return false - } - - ok := h.status.Cas(st, TUP) - go h.Refresh() // no-op since SkipRefresh - log.I("proxy: exit64: resumed? %t", ok) - return ok -} - -// Stop implements Proxy. -func (h *exit64) Stop() error { - h.status.Store(END) - h.done() - log.I("proxy: exit64: stopped") - return nil -} - -// Who implements x.RpnAcc. -func (h *exit64) Who() string { - return Rpn64 -} - -// Provider implements RpnAcc. -func (*exit64) ProviderID() string { return Rpn64 } - -// go.dev/play/p/GtLCDAXeeLJ -func addr4to6(addr string) string { - // check if addr is an IPv4 address - ipport, err := netip.ParseAddrPort(addr) - if err != nil { // hostname? - resolved := dialers.For(addr) - ok := len(resolved) > 0 - - for _, ip := range resolved { - if !ip.IsValid() || ip.Is6() { - continue - } - ipport = netip.AddrPortFrom(ip, ipport.Port()) - break // break on first valid IPv4 ipport - } - - logeif(!ok)("proxy: auto: exit64: addr64: is host? %s; chosen? %v, resolved? %v; err: %v", - addr, ipport, resolved, err) - - if !ipport.IsValid() { - return "" - } - } - - ip4 := ipport.Addr() - if !ip4.Is4() { - log.VV("proxy: auto: exit64: addr64: chosen addr not v4(%s)", addr) - return "" - } - // embed IPv4 in IPv6 - ippre := core.ChooseOne(rpn.Net6to4) - ip6 := ip4to6(ippre, ip4) - if !ip6.IsValid() { - log.W("proxy: auto: exit64: addr64: failed to embed(%s) in v6(%s)", ip4, ippre) - return "" - } - return netip.AddrPortFrom(ip6, ipport.Port()).String() -} - -func ip4to6(prefix96 netip.Prefix, ip4 netip.Addr) netip.Addr { - if !prefix96.IsValid() || !ip4.IsValid() { - return netip.Addr{} - } - startingAddress := prefix96.Masked().Addr() - addrLen := startingAddress.BitLen() / 8 // == 128 / 8 == 16 - prefixLen := prefix96.Bits() / 8 // == 96 / 8 == 12 - hostLen := (addrLen - prefixLen) // == 16 - 12 == 4 - s6 := startingAddress.As16() - s4 := ip4.As4() - n := copy(s6[prefixLen:], s4[:hostLen]) - if n != hostLen { - log.W("proxy: auto: exit64: ip4to6(%v, %v) failed; pre:%d host:%d for net:%v ip4:%v", - s6, s4, prefixLen, hostLen, prefix96, ip4) - return netip.Addr{} - } - return netip.AddrFrom16(s6) -} - -func anyaddr4to6(addr string) string { - if _, port, err := net.SplitHostPort(addr); err == nil { - return net.JoinHostPort(anyaddr6.String(), port) - } - return addr -} - -func logeif(e bool) log.LogFn { - if e { - return log.E - } - if settings.Debug { - return log.D - } - return log.N -} - -func logei(err error) log.LogFn { - if err != nil { - return log.E - } - return log.I -} diff --git a/intra/ipn/ground.go b/intra/ipn/ground.go deleted file mode 100644 index 91c3861a..00000000 --- a/intra/ipn/ground.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/protect" -) - -// ground is a proxy that does nothing. -type ground struct { - NoDNS - ProtoAgnostic - SkipRefresh - GWNoVia - CantPause - NoClient - addr string -} - -var _ Proxy = (*ground)(nil) - -// NewGroundProxy returns a new ground proxy. -func NewGroundProxy() *ground { - h := &ground{ - GWNoVia: ProxyNoGateway, - addr: "[::]:0", - } - return h -} - -// Handle implements Proxy. -func (h *ground) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *ground) DialerHandle() uintptr { - return h.Handle() -} - -// Dial implements Proxy. -func (h *ground) Dial(network, addr string) (protect.Conn, error) { - return nil, errNoProxyResponse -} - -// DialBind implements Proxy. -func (h *ground) DialBind(network, local, remote string) (protect.Conn, error) { - return nil, errNoProxyResponse -} - -// Announce implements Proxy. -func (h *ground) Announce(network, local string) (protect.PacketConn, error) { - return nil, errNoProxyResponse -} - -// Accept implements Proxy. -func (h *ground) Accept(network, local string) (protect.Listener, error) { - return nil, errNoProxyResponse -} - -// Probe implements Proxy. -func (h *ground) Probe(network, local string) (protect.PacketConn, error) { - return nil, errNoProxyResponse -} - -func (h *ground) Dialer() protect.RDialer { - return h // no-op dialer -} - -func (h *ground) ID() string { - return Block -} - -func (h *ground) Type() string { - return NOOP -} - -func (h *ground) Router() x.Router { - return h -} - -func (h *ground) GetAddr() string { - return h.addr -} - -func (h *ground) Status() int { - return TKO -} - -func (h *ground) Stop() error { - return nil -} diff --git a/intra/ipn/h1/auth.go b/intra/ipn/h1/auth.go deleted file mode 100644 index 7283459a..00000000 --- a/intra/ipn/h1/auth.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2016 Michal Witkowski. All Rights Reserved. - -package h1 - -import "encoding/base64" - -// code adopted from: github.com/mwitkow/go-http-dialer/blob/378f744fb2/auth.go#L1 - -const ( - hdrProxyAuthResp = "Proxy-Authorization" - hdrProxyAuthReq = "Proxy-Authenticate" -) - -// ProxyAuthorization allows for plugging in arbitrary implementations of the "Proxy-Authorization" handler. -type ProxyAuthorization interface { - // Type represents what kind of Authorization, e.g. "Bearer", "Token", "Digest". - Type() string - - // Initial allows you to specify an a-priori "Proxy-Authenticate" response header, attached to first request, - // so you don't need to wait for an additional challenge. If empty string is returned, "Proxy-Authenticate" - // header is added. - InitialResponse() string - - // ChallengeResponse returns the content of the "Proxy-Authenticate" response header, that has been chose as - // response to "Proxy-Authorization" request header challenge. - ChallengeResponse(challenge string) string -} - -type basicAuth struct { - username string - password string -} - -// AuthBasic returns a ProxyAuthorization that implements "Basic" protocol while ignoring realm challenges. -func AuthBasic(username string, password string) *basicAuth { - return &basicAuth{username: username, password: password} -} - -func (b *basicAuth) Type() string { - return "Basic" -} - -func (b *basicAuth) InitialResponse() string { - return b.authString() -} - -func (b *basicAuth) ChallengeResponse(challenge string) string { - // challenge can be realm="proxy.com" - // TODO(mwitkow): Implement realm lookup in AuthBasicWithRealm. - return b.authString() -} - -func (b *basicAuth) authString() string { - resp := b.username + ":" + b.password - return base64.StdEncoding.EncodeToString([]byte(resp)) -} diff --git a/intra/ipn/h1/dialer.go b/intra/ipn/h1/dialer.go deleted file mode 100644 index 497a50d8..00000000 --- a/intra/ipn/h1/dialer.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2016 Michal Witkowski. All Rights Reserved. - -package h1 - -import ( - "bufio" - "crypto/tls" - "fmt" - "net" - "net/http" - "net/url" - "strings" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" -) - -// from: github.com/mwitkow/go-http-dialer/blob/378f744fb2/dialer.go - -type Opt func(*HttpTunnel) - -func New(proxyUrl *url.URL, opts ...Opt) *HttpTunnel { - t := &HttpTunnel{} - t.parseProxyUrl(proxyUrl) - for _, opt := range opts { - opt(t) - } - _, ok := dialers.New(t.hostname, nil) - log.I("http: new dialer for %s; resolved? %t", t.hostname, ok) - return t -} - -// WithTls sets the tls.Config to be used (e.g. CA certs) when connecting to an HTTP proxy over TLS. -func WithTls(tlsConfig *tls.Config) Opt { - return func(t *HttpTunnel) { - t.tlsConfig = tlsConfig - } -} - -// WithDialer allows the customization of the underlying net.Dialer used for establishing TCP connections to the proxy. -func WithDialer(dialer protect.RDialer) Opt { - return func(t *HttpTunnel) { - t.d = dialer - } -} - -// WithProxyAuth allows you to add ProxyAuthorization to calls. -func WithProxyAuth(auth ProxyAuthorization) Opt { - return func(t *HttpTunnel) { - t.auth = auth - } -} - -// HttpTunnel represents a configured HTTP Connect Tunnel dialer. -type HttpTunnel struct { - d protect.RDialer - isTls bool - hostname string // host - proxyAddr string // host or host:port - tlsConfig *tls.Config - auth ProxyAuthorization -} - -func (t *HttpTunnel) parseProxyUrl(proxyUrl *url.URL) { - t.hostname = proxyUrl.Hostname() - t.proxyAddr = proxyUrl.Host - if strings.ToLower(proxyUrl.Scheme) == "https" { - if !strings.Contains(t.proxyAddr, ":") { - t.proxyAddr = t.proxyAddr + ":443" - } - t.isTls = true - } else { - if !strings.Contains(t.proxyAddr, ":") { - t.proxyAddr = t.proxyAddr + ":8080" - } - t.isTls = false - } -} - -func (t *HttpTunnel) dialProxy() (net.Conn, error) { - if !t.isTls { - return dialers.ProxyDial(t.d, "tcp", t.proxyAddr) - } - return dialers.DialWithTls(t.d, t.tlsConfig, "tcp", t.proxyAddr) -} - -// Dial implements proxy.Dialer. -// Returns a conn to address that HTTP CONNECT reached. -func (t *HttpTunnel) Dial(network string, address string) (net.Conn, error) { - if !strings.Contains(network, "tcp") { // tcp4, tcp6, tcp - return nil, fmt.Errorf("http1: tunnel: network type '%v' unsupported (only 'tcp')", network) - } - conn, err := t.dialProxy() - if err != nil { - return nil, fmt.Errorf("http1: tunnel: failed dialing to proxy: %v", err) - } - req := &http.Request{ - Method: "CONNECT", - URL: &url.URL{Opaque: address}, - Host: address, // This is weird - Header: make(http.Header), - } - if t.auth != nil { - if creds := t.auth.InitialResponse(); len(creds) > 0 { - req.Header.Set(hdrProxyAuthResp, t.auth.Type()+" "+creds) - } - } - resp, err := t.doRoundtrip(conn, req) - if err != nil { - clos(conn) - return nil, err - } - // Retry request with auth, if available. - if resp.StatusCode == http.StatusProxyAuthRequired && t.auth != nil { - responseHdr, err := t.performAuthChallengeResponse(resp) - if err != nil { - clos(conn) - return nil, err - } - req.Header.Set(hdrProxyAuthResp, t.auth.Type()+" "+responseHdr) - resp, err = t.doRoundtrip(conn, req) - if err != nil { - clos(conn) - return nil, err - } - } - - if resp.StatusCode != http.StatusOK { - clos(conn) - return nil, fmt.Errorf("http1: tunnel: failed proxying %d: %s", resp.StatusCode, resp.Status) - } - return conn, nil -} - -func clos(c net.Conn) { - core.CloseConn(c) -} - -func (t *HttpTunnel) doRoundtrip(conn net.Conn, req *http.Request) (*http.Response, error) { - if err := req.Write(conn); err != nil { - return nil, fmt.Errorf("http1: tunnel: failed writing request: %v", err) - } - // Doesn't matter, discard this bufio. - br := bufio.NewReader(conn) - return http.ReadResponse(br, req) -} - -func (t *HttpTunnel) performAuthChallengeResponse(resp *http.Response) (string, error) { - respAuthHdr := resp.Header.Get(hdrProxyAuthReq) - if !strings.Contains(respAuthHdr, t.auth.Type()+" ") { - return "", fmt.Errorf("http1: tunnel: expected '%v' Proxy authentication, got: '%v'", t.auth.Type(), respAuthHdr) - } - splits := strings.SplitN(respAuthHdr, " ", 2) - if len(splits) <= 1 { - return "", fmt.Errorf("http1: tunnel: malformed Proxy-Authenticate header: '%v'", respAuthHdr) - } - challenge := splits[1] - return t.auth.ChallengeResponse(challenge), nil -} diff --git a/intra/ipn/http1.go b/intra/ipn/http1.go deleted file mode 100644 index 91d1ab3c..00000000 --- a/intra/ipn/http1.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "crypto/tls" - "net/url" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - tx "github.com/celzero/firestack/intra/ipn/h1" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/net/proxy" -) - -type http1 struct { - NoFwd // no forwarding/listening - NoDNS // no dns - SkipRefresh // no refresh - GW // dual stack gateway - - id string - outbound proxy.Dialer - via *core.WeakRef[Proxy] - viaID *core.Volatile[string] - px ProxyProvider - opts *settings.ProxyOptions - lastdial time.Time - status *core.Volatile[int] -} - -func NewHTTPProxy(id string, ctx context.Context, c protect.Controller, px ProxyProvider, po *settings.ProxyOptions) (*http1, error) { - var err error - if po == nil { - log.W("proxy: err setting up http1 w(%v): %v", po, err) - return nil, errMissingProxyOpt - } - - u, err := url.Parse(po.Url()) - if err != nil { - log.W("proxy: http1: err proxy opts(%v): %v", po, err) - return nil, errProxyScheme - } - - d := protect.MakeNsRDial(id, ctx, c) - - opts := make([]tx.Opt, 0) - optdialer := tx.WithDialer(d) - opts = append(opts, optdialer) - if po.Scheme == "https" && len(po.Host) > 0 { - opttls := tx.WithTls(&tls.Config{ - ServerName: po.Host, - MinVersion: tls.VersionTLS12, - }) - opts = append(opts, opttls) - } - if po.HasAuth() { - optauth := tx.WithProxyAuth(tx.AuthBasic(po.Auth.User, po.Auth.Password)) - opts = append(opts, optauth) - } - - hp := tx.New(u, opts...) - - h := &http1{ - outbound: hp, // does not support udp - px: px, - viaID: core.NewZeroVolatile[string](), - status: core.NewVolatile(TUP), - id: id, - opts: po, - } - h.via, err = core.NewWeakRef(h.viafor, viaok) - - logeif(err != nil)("proxy: http1: created %s with opts(%s); err? %v", - h.ID(), po, err) - - return h, nil -} - -func (h *http1) viafor() *Proxy { - return viafor(h.id, h.viaID.Load(), h.px) -} - -func (h *http1) swapVia(new Proxy) Proxy { - return swapVia(h.id, new, h.viaID, h.via) -} - -// Handle implements Proxy. -func (h *http1) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *http1) DialerHandle() uintptr { - return core.Loc(h.outbound) -} - -// Dial implements Proxy. -func (h *http1) Dial(network, addr string) (c protect.Conn, err error) { - if err := candial(h.status); err != nil { - return nil, err - } - - h.lastdial = time.Now() - - who := idstr(h) - if usevia(h.viaID) { - if v, vok := h.via.Get(); vok { // dial via another proxy - who = idstr(v) - c, err = v.Dial(network, addr) - } else { - err = errNoHop - if removeViaOnErrors { - h.Hop(nil, false /*dryrun*/) // stale; unset - } - log.W("http1: via(%s) failing...", idhandle(v)) - } - } else { - // actually, dialers.ProxyDial not needed, because - // tx.HttpTunnel.Dial() supports dialing into hostnames - c, err = dialers.ProxyDial(h.outbound, network, addr) - } - defer localDialStatus(h.status, err) - - log.I("proxy: http1: dial(%s) from %s => %s (via %s); err? %v", network, h.GetAddr(), addr, who, err) - return -} - -// DialBind implements Proxy. -func (h *http1) DialBind(network, local, remote string) (c protect.Conn, err error) { - log.D("http1: dialbind(%s) from %s to %s not supported", network, local, remote) - // TODO: error instead? - return h.Dial(network, remote) -} - -func (h *http1) Dialer() protect.RDialer { - return h -} - -func (h *http1) ID() string { - return h.id -} - -func (h *http1) Type() string { - return HTTP1 -} - -func (h *http1) Router() x.Router { - return h -} - -// Reaches implements x.Router. -func (h *http1) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (h *http1) Hop(p Proxy, dryrun bool) error { - if h.id == GlobalH1 { - return errNop // global proxy exits as-is - } - - if p == nil { - if !dryrun { - old := h.swapVia(nil) - log.I("proxy: http1: hop(%s) removed", idhandle(old)) - } - return nil - } - if p.Status() == END { - return errProxyStopped - } - - if !dryrun { - old := h.swapVia(p) - log.I("http1: hop %s => %s", idhandle(old), idhandle(p)) - } - return nil -} - -// Via implements x.Router. -func (h *http1) Via() (x.Proxy, error) { - if v := h.via.Load(); v != nil { - return v, nil - } - return nil, errNoHop -} - -// GetAddr implements Proxy. -func (h *http1) GetAddr() string { - return h.opts.IPPort -} - -// Status implements Proxy. -func (h *http1) Status() int { - s := h.status.Load() - if s != END && idling(h.lastdial) { - return TZZ - } - return s -} - -// Pause implements x.Proxy. -func (h *http1) Pause() bool { - st := h.status.Load() - if st == END { - log.W("proxy: http1: pause called when stopped") - return false - } - - ok := h.status.Cas(st, TPU) - log.I("proxy: http1: paused? %t", ok) - return ok -} - -// Resume implements x.Proxy. -func (h *http1) Resume() bool { - st := h.status.Load() - if st != TPU { - log.W("proxy: http1: resume called when not paused; status %d", st) - return false - } - - ok := h.status.Cas(st, TUP) - go h.Refresh() // no-op since SkipRefresh - log.I("proxy: http1: resumed? %t", ok) - return ok -} - -// Stop implements Proxy. -func (h *http1) Stop() error { - h.status.Store(END) - log.I("proxy: http1: stopped %s", h.id) - return nil -} - -// OnProtoChange implements Proxy. -func (h *http1) OnProtoChange(_ LinkProps) (string, bool) { - if err := candial(h.status); err != nil { - return "", false - } - return h.opts.FullUrl(), true -} diff --git a/intra/ipn/multihost/map.go b/intra/ipn/multihost/map.go deleted file mode 100644 index 183173e6..00000000 --- a/intra/ipn/multihost/map.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package multihost - -import ( - "fmt" - "net/netip" - "net/url" - "strings" - "sync" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" -) - -type MHMap struct { - sync.RWMutex - k string // uniq identifier - uniq map[*MH]struct{} - byIpp map[netip.AddrPort]*MH // ip:port => MH - byHostport map[string]*MH // host:port => MH -} - -func (m *MHMap) All() (all []*MH) { - if m == nil { - return - } - - m.RLock() - defer m.RUnlock() - for h := range m.uniq { - all = append(all, h) - } - return -} - -func (m *MHMap) Get(hostOrIpport string) (h *MH, _ error) { - if m == nil { - return nil, errMhNotFound - } - m.RLock() - defer m.RUnlock() - - host, port, err := normalize(hostOrIpport) // port may be 0 - if err != nil || len(host) <= 0 { - log.D("multihost: %s map: get for %s => %s:%d / err: %v", m.k, hostOrIpport, host, port, err) - return nil, core.JoinErr(err, url.InvalidHostError(hostOrIpport)) - } - - ipp, err := netip.ParseAddrPort(hostOrIpport) - if err == nil { // is ip:port - h = m.byIpp[ipp] - } else { // may be host:port - h = m.byHostport[hostOrIpport] - } - - ok := h != nil - logeif(!ok)("multihost: %s map: get: for %s [%s]; ok? %t, by ip? %t; parse-err: %v", - m.k, hostOrIpport, ipp, ok, err == nil, err) - - if h == nil { - return nil, core.JoinErr(err, errMhNotFound) - } - return h, nil -} - -func (m *MHMap) Put(h *MH) (ok bool) { - if h == nil { - log.W("multihost: %s map: put: nil? %t", m.k, h == nil) - return - } - - m.Lock() - defer m.Unlock() - return m.putLocked(h) -} - -func (m *MHMap) putLocked(h *MH) (ok bool) { - if h == nil { - return false - } - - if _, dup := m.uniq[h]; dup { - log.W("multihost: %s map: put: dup; call refresh instead?", m.k) - return h.Len() > 0 - } - - ipps := h.Addrs() - names := h.Names() - ok = len(ipps) > 0 || len(names) > 0 - - if ok { // overwrites all existing - m.uniq[h] = struct{}{} - for _, ipp := range ipps { - m.byIpp[ipp] = h - } - for _, name := range names { - m.byHostport[name] = h - } - } - - logeif(!ok)("multihost: %s map: %s put: ipps %d, names %d; ok? %t", - m.k, h.o, len(ipps), len(names), ok) - - return -} - -func (m *MHMap) Del(h *MH) (ok bool) { - if h == nil { - log.W("multihost: %s map: del: nil? %t", m.k, h == nil) - return - } - - m.Lock() - defer m.Unlock() - return m.delLocked(h) -} - -func (m *MHMap) delLocked(h *MH) (ok bool) { - ipps := h.Addrs() - names := h.Names() - ok = len(ipps) > 0 || len(names) > 0 - - if ok { - delete(m.uniq, h) - for _, ip := range ipps { - if x := m.byIpp[ip]; x == h { - delete(m.byIpp, ip) - } - } - for _, name := range names { - if x := m.byHostport[name]; x == h { - delete(m.byHostport, name) - } - } - } - - logeif(!ok)("multihost: %s map: %s del: ipps %d, names %d", - m.k, h.o, len(ipps), len(names)) - - return -} - -func (m *MHMap) Len() (n int64) { - if m == nil { - return - } - - m.RLock() - defer m.RUnlock() - for h := range m.uniq { - n += int64(h.Len()) - } - return -} - -func (m *MHMap) Refresh() (n int64) { - if m == nil { - return - } - - m.Lock() - defer m.Unlock() - for h := range m.uniq { - m.delLocked(h) - n += int64(h.Refresh()) - m.putLocked(h) - } - return -} - -func (m *MHMap) MaybeRefresh() (n int64) { - if m == nil { - return - } - - m.Lock() - defer m.Unlock() - for h := range m.uniq { - if _, stale := h.stale(); stale { - m.delLocked(h) - n += int64(h.Refresh()) - m.putLocked(h) - } - } - return -} - -func (m *MHMap) String() string { - if m == nil { - return "" - } - - m.RLock() - defer m.RUnlock() - if len(m.uniq) <= 0 { - return m.k + ": " - } - - var sb strings.Builder - sb.WriteString(m.k + ": ") - i := 0 - for h := range m.uniq { - sb.WriteString(fmt.Sprintf("#%d ", i)) - sb.WriteString(h.String()) - sb.WriteString(" / ") - i++ - } - return sb.String() -} - -func Flatten(m []*MH) (addrs []netip.AddrPort) { - if len(m) > 0 { - for _, h := range m { - addrs = append(addrs, h.Addrs()...) - } - } - return -} - -func NewMap(id string) *MHMap { - return &MHMap{ - k: id, - uniq: make(map[*MH]struct{}), - byIpp: make(map[netip.AddrPort]*MH), - byHostport: make(map[string]*MH), - } -} diff --git a/intra/ipn/multihost/multihost.go b/intra/ipn/multihost/multihost.go deleted file mode 100644 index b1070910..00000000 --- a/intra/ipn/multihost/multihost.go +++ /dev/null @@ -1,483 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package multihost - -import ( - "errors" - "net" - "net/netip" - "slices" - "strconv" - "strings" - "sync" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" -) - -const refreshInterval time.Duration = 2 * time.Minute - -var ( - errNoIps = errors.New("multihost: no ips") - errMhNotFound = errors.New("multihost: not found") - errInvalidPort = errors.New("multihost: invalid port") -) - -var zeroaddr = netip.AddrPort{} - -type MHAddOp int - -const ( - // Reset replaces the existing IPs with the new IPs. - Reset MHAddOp = iota - // Append appends the new IPs to the existing IPs. - Append -) - -func (op MHAddOp) String() string { - switch op { - case Reset: - return "reset" - case Append: - return "append" - default: - return "unknown" - } -} - -// MH is a list of hostnames and/or ip addresses for one endpoint. -type MH struct { - sync.RWMutex // protects names and addrs - - o string // owner tag - names []string // host:port - addrs []netip.AddrPort // ip:port - preresolved []netip.AddrPort // ip:port; pre-resolved - mtime time.Time // modified time -} - -// New returns a new multihost with the given id. -func New(id string) *MH { - return &MH{ - o: id, - names: make([]string, 0), - addrs: make([]netip.AddrPort, 0), - preresolved: make([]netip.AddrPort, 0), - mtime: time.Now(), - } -} - -func (h *MH) String() string { - if h == nil { - return "" - } - return h.o + "[" + strings.Join(h.Names(), ",") + - " | " + strings.Join(h.straddrs(), ",") + "]" + - " @ " + core.FmtTimeAsPeriod(h.Mtime()) -} - -func (h *MH) Mtime() time.Time { - if h == nil { - return time.Time{} - } - h.RLock() - defer h.RUnlock() - return h.mtime -} - -func (h *MH) straddrs() []string { - a := make([]string, 0) - for _, ip := range h.Addrs() { - if ip.Addr().IsUnspecified() || !ip.IsValid() { - continue - } - a = append(a, ip.String()) - } - return a -} - -// Names returns a copy of the list of hostnames or host:ports. -func (h *MH) Names() []string { - if h == nil { - return nil - } - h.RLock() - defer h.RUnlock() - // Return a copy to prevent external modification - return slices.Clone(h.names) -} - -// Returns ip:port, where ports may be 0. -func (h *MH) Addrs() []netip.AddrPort { - if h == nil { - return nil - } - h.RLock() - defer h.RUnlock() - - return slices.Concat(h.addrs, h.preresolved) -} - -func (h *MH) splitFamily() (out4, out6, og []netip.AddrPort) { - out4 = make([]netip.AddrPort, 0) - out6 = make([]netip.AddrPort, 0) - og = h.Addrs() - - for _, ip := range og { - if ip.Addr().IsUnspecified() || !ip.IsValid() { - continue - } - if ip.Addr().Is4() { - out4 = append(out4, ip) - } else if ip.Addr().Is6() { - out6 = append(out6, ip) - } - // Note: IsLoopback, IsLinkLocalUnicast, etc. addresses are still included - // Consider if these should be filtered based on use case - } - return -} - -// PreferredAddrs returns the list of IPs per the dialer's preference. -func (h *MH) PreferredAddrs() []netip.AddrPort { - if h == nil { - return nil - } - - out4, out6, og := h.splitFamily() - - out := make([]netip.AddrPort, 0, len(og)) - if dialers.Use4() { - out = append(out, out4...) - } - if dialers.Use6() { // ipv4 addrs followed by ipv6 - out = append(out, out6...) - } - if len(out) <= 0 { // fail open - return slices.Clone(og) // Return copy to prevent modification - } - return out -} - -// prefers v4; see: github.com/WireGuard/wireguard-android/blob/4ba87947a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java#L97 -func (h *MH) PreferredAddr() netip.AddrPort { - if h == nil { - log.W("multihost: PreferredAddr: nil multihost") - return zeroaddr - } - - addrs := h.Addrs() - if len(addrs) == 0 { - log.W("multihost: %s: no addresses available", h.o) - return zeroaddr - } - - out6 := zeroaddr - fallback4 := zeroaddr - fallback6 := zeroaddr - has4Or46 := dialers.Use4() - has6Or46 := dialers.Use6() - hasOnly6 := has6Or46 && !has4Or46 - - for _, ip := range addrs { - if ip.Addr().IsUnspecified() || !ip.IsValid() { - continue - } - if ip.Addr().Is4() && has4Or46 { - return ip // the first v4 addr - } else if ip.Addr().Is4() && !fallback4.IsValid() { - fallback4 = ip // note the first valid v4 addr - } - if ip.Addr().Is6() { - if hasOnly6 { - return ip // the first v6 addr - } - if has6Or46 && !out6.IsValid() { - out6 = ip // note the first valid v6 addr - } else if !fallback6.IsValid() { - fallback6 = ip // note the first valid v6 addr - } - } - } - - if out6.IsValid() { - return out6 - } - - log.W("multihost: %s: no preferred; v4(use? %t, fallback? %s), v6(use? %t, fallback? %s)", - h.o, has4Or46, fallback4, has6Or46, fallback6) - if fallback4.IsValid() { - return fallback4 - } - return fallback6 // may be zero addr or unspecified -} - -func (h *MH) Len() int { - if h == nil { - return 0 - } - - h.RLock() - defer h.RUnlock() - // names may exist without addrs and vice versa - return max(len(h.addrs)+len(h.preresolved), len(h.names)) -} - -// Refresh resets the list of IPs, hostnames, and re-resolves the hostname. -// It returns the total number of IPs, or -1 on error. -func (h *MH) Refresh() int { - if h == nil { - log.W("multihost: refresh: nil") - return -1 - } - if names := h.Names(); len(names) > 0 { - // reset all ips; resolve from names - return h.Set(names) - } // nothing to refresh - return h.Len() -} - -// SoftRefresh appends to the list of IPs, hostnames by re-resolving the hostname. -// It returns the total number of IPs, or -1 on error. -func (h *MH) SoftRefresh() int { - if h == nil { - log.W("multihost: soft refresh: nil") - return -1 - } - - if names, stale := h.stale(); len(names) > 0 && stale { - // resolve ip from domain names (auto removes dups); then append - return h.Add(names) - } - return h.Len() -} - -func (h *MH) stale() ([]string, bool) { - if h == nil { - return nil, false - } - h.RLock() - thres := h.mtime.Add(refreshInterval) - names := slices.Clone(h.names) // Return copy - h.RUnlock() - return names, time.Since(thres) > 0 -} - -// Add appends to the existing list of IPs, hostnames, and hostname's IPs if resolved. -func (h *MH) Add(domainsOrIps []string) int { - return h.add(domainsOrIps, Append) -} - -// Set replaces the existing list of IPs, hostnames, and hostname's IPs if resolved. -func (h *MH) Set(domainsOrIps []string) int { - return h.add(domainsOrIps, Reset) -} - -// Add appends the list of de-duplicated IPs, hostnames, and hostname's IPs as resolved. -// It returns the total number of IPs. -func (h *MH) add(domainsOrIps []string, op MHAddOp) int { - if h == nil { - log.E("multihost: add: nil multihost") - return -1 - } - - id := h.o - if len(domainsOrIps) <= 0 { - log.D("multihost: %s add: no domains or ips; existing n? %d", id, h.Len()) - return 0 - } - - names, pre, addrs, err := resolv(id, domainsOrIps) - if err != nil { // errs are okay - log.W("multihost: %s add: resolution errs: %v", id, err) - } - - h.Lock() - defer h.Unlock() - - if op == Reset { // reset whatever is non-empty - if len(names) > 0 { - h.names = names - } - if len(addrs) > 0 { - h.addrs = addrs - } - if len(pre) > 0 { - h.preresolved = pre - } - } else if op == Append { - h.names = append(h.names, names...) - h.addrs = append(h.addrs, addrs...) - h.preresolved = append(h.preresolved, pre...) - } else { - log.E("multihost: %s add: %v => %v [+ %v]; unknown op %d", id, names, addrs, pre, op) - return -1 - } - - h.mtime = time.Now() - // remove dups from h.addrs and h.names - h.uniqAddrsLocked() - h.uniqPreLocked() - h.uniqNamesLocked() - log.D("multihost: %s add: op %s; names: %v (new: %v) => resolved: %v (new: %v) + pre: %v (new: %v)", - h.o, op, h.names, names, h.addrs, addrs, h.preresolved, pre) - return len(h.addrs) + len(h.preresolved) -} - -// resolv resolves the given domains or ips and returns the names, pre-resolved IPs, and resolved IPs. -// It returns an error if any of the domains or ips could not be resolved. -// It also returns the names as is, even if they could not be resolved. -// The returned pre-resolved IPs are those that were already resolved before calling this function. -// The returned resolved IPs are those that were resolved during this call. -// Caller may ignore err if you are okay with some domains or ips not being resolved. -func resolv(id string, domainsOrIps []string) (names []string, pre []netip.AddrPort, addrs []netip.AddrPort, err error) { - names = make([]string, 0, len(domainsOrIps)) - pre = make([]netip.AddrPort, 0) // pre-resolved - addrs = make([]netip.AddrPort, 0) - var errs []error - - for _, ep := range domainsOrIps { - // ep is host or ip or host:port or ip:port - dip, port, parseErr := normalize(ep) // port may be 0 - if parseErr != nil { - log.W("multihost: %s failed to parse endpoint %s: %v", id, ep, parseErr) - errs = append(errs, parseErr) - continue - } - if len(dip) <= 0 { - log.D("multihost: %s add, skipping empty host: %s:%d", id, dip, port) - continue - } - if ip, parseIPErr := netip.ParseAddr(dip); parseIPErr != nil { // may be hostname - names = append(names, ep) // add hostname regardless of resolution success - log.D("multihost: %s resolving: %q", id, ep) - if resolvedips, resolveErr := dialers.Resolve(dip); resolveErr == nil && len(resolvedips) > 0 { - reps := addrport(port, resolvedips...) - addrs = append(addrs, reps...) - log.V("multihost: %s resolved: %q => %s", id, dip, reps) - } else { - // err may be nil even on zero answers - resolveErr = core.OneErr(resolveErr, errNoIps) - log.W("multihost: %s no ips for %q; err? %v", id, dip, resolveErr) - errs = append(errs, resolveErr) - } - } else { // may be ip - // Validate IP before adding - if !ip.IsValid() { - ipErr := core.OneErr(errInvalidPort, nil) // Use core.OneErr to create error - log.W("multihost: %s invalid IP: %s", id, dip) - errs = append(errs, ipErr) - continue - } - pre = append(pre, addrport(port, ip)...) - } - } - - err = core.JoinErr(errs...) - - return -} - -// dip can be host or ip or host:port or ip:port -func normalize(dip string) (string, uint16, error) { - dip = strings.TrimSpace(dip) - if len(dip) <= 0 { - return "", 0, errNoIps - } - if hostOrIP, portstr, err := net.SplitHostPort(dip); err == nil { - port, err := strconv.ParseUint(portstr, 10, 16) - if err != nil { - log.D("multihost: normalize(%s), invalid port; err: %v", dip, err) - return "", 0, core.OneErr(errInvalidPort, err) - } - if port > 65535 { - return "", 0, errInvalidPort - } - return hostOrIP, uint16(port), nil - } - return dip, 0, nil -} - -// 0 port is valid -func addrport(port uint16, ips ...netip.Addr) []netip.AddrPort { - if len(ips) == 0 { - return nil - } - a := make([]netip.AddrPort, 0, len(ips)) - for _, ip := range ips { - if !ip.IsValid() { - log.D("multihost: addrport: skipping invalid IP: %s", ip) - continue - } - a = append(a, netip.AddrPortFrom(ip, port)) - } - return a -} - -func (h *MH) EqualAddrs(other *MH) bool { - const eq = true - const noteq = false - if h == nil && other == nil { - return eq - } - if h == nil || other == nil { - return noteq - } - - us := core.CopyUniq(h.Addrs()) - them := core.CopyUniq(other.Addrs()) - - if len(us) != len(them) { - return noteq - } - - for _, u := range us { - found := false - for _, t := range them { - if u.Compare(t) == 0 { - found = true - break - } - } - if !found { - log.D("multihost: %s != %s; missing %s", h.o, other.o, u) - return noteq - } - } - log.V("multihost: %s == %s", h.o, other.o) - return eq -} - -func (h *MH) uniqNamesLocked() { - if h == nil { - return - } - h.names = core.CopyUniq(h.names) -} - -func (h *MH) uniqAddrsLocked() { - if h == nil { - return - } - h.addrs = core.CopyUniq(h.addrs) -} - -func (h *MH) uniqPreLocked() { - if h == nil { - return - } - h.preresolved = core.CopyUniq(h.preresolved) -} - -func logeif(cond bool) log.LogFn { - if cond { - return log.E - } - return log.D -} diff --git a/intra/ipn/multihost/multihost_test.go b/intra/ipn/multihost/multihost_test.go deleted file mode 100644 index 18879d1e..00000000 --- a/intra/ipn/multihost/multihost_test.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package multihost - -import ( - "context" - "errors" - "net" - "net/netip" - "testing" - - "github.com/celzero/firestack/intra/dialers" - ilog "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -type fakeResolver struct { - *net.Resolver -} - -func (r fakeResolver) Lookup(q []byte, _ string, _ ...string) ([]byte, error) { - // return nil, errors.New("lookup: not implemented") - msg := xdns.AsMsg(q) - if msg == nil { - return nil, errors.New("fakeresolver: nil dns msg") - } - if !xdns.HasAQuadAQuestion(msg) { - return nil, errors.New("fakeresolver: A/AAAA only") - } - qname := xdns.QName(msg) - network := "ip4" - if xdns.HasAAAAQuestion(msg) { - network = "ip6" - } - addrs, err := r.Resolver.LookupNetIP(context.TODO(), network, qname) - if err != nil { - return nil, err - } - // make a dns answer for addrs - ans := xdns.EmptyResponseFromMessage(msg) - if ans == nil { - return nil, errors.New("fakeresolver: nil pkt") - } - rrs := make([]dns.RR, 0) - for _, a := range addrs { - if network == "ip4" { - rr := xdns.MakeARecord(qname, a.String(), 30) - rrs = append(rrs, rr) - } else { - rr := xdns.MakeAAAARecord(qname, a.String(), 30) - rrs = append(rrs, rr) - } - } - ans.Answer = rrs - - return ans.Pack() -} - -func (r fakeResolver) LocalLookup(q []byte) ([]byte, error) { - return r.Lookup(q, protect.UidSelf) -} - -func (r fakeResolver) LookupFor(q []byte, _ string) ([]byte, error) { - return r.Lookup(q, protect.UidSelf) -} - -func (r fakeResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { - return r.Resolver.LookupNetIP(ctx, network, host) -} - -func (r fakeResolver) LookupNetIPFor(ctx context.Context, network, host, uid string) ([]netip.Addr, error) { - return r.Resolver.LookupNetIP(ctx, network, host) -} - -func (r fakeResolver) LookupNetIPOn(ctx context.Context, network, host string, tid ...string) ([]netip.Addr, error) { - return r.Resolver.LookupNetIP(ctx, network, host) -} - -func TestMultihostMap(t *testing.T) { - ilog.SetLevel(0) - settings.Debug = true - - dialers.Mapper(&fakeResolver{}) - - h0 := New("h0") - h0domains := []string{ - "one.one.one.one:53", - "dns.google:443", - } - h0.Add(h0domains) - - h1 := New("h1") - h1ips := []string{ - "1.1.1.1:53", - "2.2.2.2:23", - "3.3.3.3:33", - } - h1.Add(h1ips) - - h2 := New("h2") - h2ips := []string{ - "[2000:b:0::5:0]:53", - "[2000:d:e::a:d]:23", - "[2000:b:e::e:f]:33", - } - h2.Add(h2ips) - - m := NewMap("testmap") - - if !m.Put(h0) { - t.Fatal("expected to put h0") - } - - _, xperr0 := m.Get("one.one.one.one:443") // empty; wrong port - _, unerr0 := m.Get("dns.google:443") - if xperr0 == nil { - t.Errorf("expected error, got nil") - t.Fail() - } - if unerr0 != nil { - t.Errorf("expected no error, got %v", unerr0) - t.Fail() - } - - if !m.Put(h1) { - t.Fatal("expected to put h1") - } - _, xperr1 := m.Get("1.1.1.1") // empty - _, unerr1 := m.Get("1.1.1.1:53") - if xperr1 == nil { - t.Errorf("expected error, got nil") - t.Fail() - } - if unerr1 != nil { - t.Errorf("expected no error, got %v", unerr1) - t.Fail() - } - _, xperr2 := m.Get("[2000:d:e::a:d]:23") // empty - if !m.Put(h2) { - t.Fatal("expected to put h2") - } - _, unerr2 := m.Get("[2000:d:e::a:d]:23") // empty - if xperr2 == nil { - t.Errorf("expected error, got nil") - t.Fail() - } - if unerr2 != nil { - t.Errorf("expected no error, got %v", unerr2) - t.Fail() - } - // ilog.D(m.String()) // only prints ips -} diff --git a/intra/ipn/nop.go b/intra/ipn/nop.go deleted file mode 100644 index 3fe6347f..00000000 --- a/intra/ipn/nop.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "errors" - "net/netip" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/protect" -) - -var ( - errProbeNotSupported = errors.New("proxy: probe not supported") - errAnnounceNotSupported = errors.New("proxy: announce not supported") -) - -const nodns = "" // no DNS - -// todo: impl a baseproxy for common fns ID, GetAddr, Status, Via, Hop etc -type GWNoVia struct { - NoVia - GW -} - -// GW is a no-op/stub gateway that is either dualstack or not and has dummy stats. -type GW struct { - nov4, nov6 bool // is dualstack - stats x.RouterStats // zero stats -} - -var _ x.Router = (*GWNoVia)(nil) - -// IP4 implements x.Router. -func (w *GW) IP4() bool { return !w.nov4 } - -// IP6 implements x.Router. -func (w *GW) IP6() bool { return !w.nov6 } - -// MTU implements x.Router. -func (w *GW) MTU() (int, error) { return NOMTU, errNoMtu } - -// Stat implements x.Router. -func (w *GW) Stat() *x.RouterStats { - if !w.nov4 || !w.nov6 { - w.stats.LastOK = now() // always OK - } - return &w.stats -} - -// Contains implements x.Router. -func (w *GW) Contains(ippOrCidr string) bool { - prefix, err := core.IP2Cidr2(ippOrCidr) - if err != nil { - return false - } - return w.ok(prefix.Addr()) -} - -func (w *GW) ok(ip netip.Addr) bool { return w.ok4(ip) || w.ok6(ip) } -func (w *GW) ok4(ip netip.Addr) bool { return w.IP4() && ip.IsValid() && ip.Is4() } -func (w *GW) ok6(ip netip.Addr) bool { return w.IP6() && ip.IsValid() && ip.Is6() } - -// Reaches implements Router. -func (w *GW) Reaches(hostportOrIPPortCsvStr string) bool { - hostportOrIPPortCsv := hostportOrIPPortCsvStr - - if len(hostportOrIPPortCsv) <= 0 { - return true - } - ips := dialers.For(hostportOrIPPortCsv) - for _, ip := range ips { - if w.ok(ip) { - return true - } - } - return false -} - -// ProxyNoGateway is a Router that routes nothing. -var ProxyNoGateway = GWNoVia{GW: GW{nov4: true, nov6: true}} - -// ProtoAgnostic is a proxy that does not care about protocol changes. -type ProtoAgnostic struct{} - -// OnProtoChange implements Proxy. -func (ProtoAgnostic) OnProtoChange(_ LinkProps) (string, bool) { return "", false } - -// SkipRefresh is a proxy that does not need to be refreshed or pinged on network changes. -type SkipRefresh struct{} - -// Refresh implements Proxy. -func (SkipRefresh) Refresh() error { return nil } - -func (SkipRefresh) onNotOK() (didRefresh bool, allOK bool) { return false, true } - -// Ping implements Proxy. -func (SkipRefresh) Ping() bool { return false } - -type CantPause struct{} - -// Pause implements Proxy. -func (CantPause) Pause() bool { return false } - -// Resume implements Proxy. -func (CantPause) Resume() bool { return false } - -// NoFwd is a proxy that does not support listening or forwarding. -type NoFwd struct{} - -// Announce implements Proxy. -func (NoFwd) Announce(network, local string) (protect.PacketConn, error) { - return nil, errAnnounceNotSupported -} - -// Accept implements Proxy. -func (NoFwd) Accept(network, local string) (protect.Listener, error) { - return nil, errAnnounceNotSupported -} - -// Probe implements Proxy. -func (NoFwd) Probe(string, string) (protect.PacketConn, error) { - return nil, errProbeNotSupported -} - -type NoDNS struct{} - -func (NoDNS) DNS() string { - return nodns -} - -type NoVia struct{} - -func (NoVia) Via() (x.Proxy, error) { return nil, errNop } -func (NoVia) Hop(Proxy, bool) error { return errNop } - -var errNop = errors.New("proxy: nop") - -type NoClient struct{} - -func (NoClient) Client() x.Client { return nil } - -type NoProxy struct { - NoDNS - ProtoAgnostic - SkipRefresh - NoFwd - CantPause - GWNoVia -} - -func (NoProxy) Handle() uintptr { return core.Nobody } -func (NoProxy) DialerHandle() uintptr { return core.Nobody } -func (NoProxy) ID() string { return "" } -func (NoProxy) Type() string { return "" } -func (NoProxy) Router() x.Router { return nil } -func (NoProxy) Reaches(string) bool { return false } -func (NoProxy) Dial(string, string) (protect.Conn, error) { return nil, errNop } -func (NoProxy) DialBind(string, string, string) (protect.Conn, error) { return nil, errNop } -func (NoProxy) Dialer() protect.RDialer { return nil } -func (NoProxy) Status() int { return 0 } -func (NoProxy) GetAddr() string { return "" } -func (NoProxy) Stop() error { return nil } -func (NoProxy) Client() x.Client { return nil } diff --git a/intra/ipn/piph2.go b/intra/ipn/piph2.go deleted file mode 100644 index 984d90ad..00000000 --- a/intra/ipn/piph2.go +++ /dev/null @@ -1,597 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "crypto/hmac" - "crypto/sha256" - "crypto/tls" - "encoding/hex" - "io" - "net" - "net/http" - "net/http/httptrace" - "net/textproto" - "net/url" - "strconv" - "strings" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/net/http2" -) - -type piph2 struct { - NoFwd // no forwarding/listening - NoDNS // no dns - ProtoAgnostic // since dial, dialts are proto aware - SkipRefresh // no refresh - GW // dual stack gateway - - url string // h2 proxy url - hostname string // h2 proxy hostname - port int // h2 proxy port - token string // hex, client token - toksig string // hex, authorizer signed client token - rsasig string // hex, authorizer unblinded signature - client http.Client // h2 client, see trType - outbound *protect.RDial // h2 dialer - px ProxyProvider - via *core.WeakRef[Proxy] // hop dialer - viaID *core.Volatile[string] // hop proxy ID - opts *settings.ProxyOptions - - done context.CancelFunc - - // mutable fields - lastdial *core.Volatile[time.Time] // last dial time - status *core.Volatile[int] // proxy status: TOK, TKO, END -} - -// github.com/posener/h2conn/blob/13e7df33ed1/conn.go -type pipconn struct { - id string // some identifier - rch <-chan io.ReadCloser // reader provider - wch chan<- int64 // first write len(data) - ok bool // r is ok to read from - r io.ReadCloser // reader, nil until ok is true - w io.WriteCloser // writer - laddr net.Addr // local address, may be nil - raddr net.Addr // remote address -} - -var _ core.TCPConn = (*pipconn)(nil) - -func (c *pipconn) Read(b []byte) (int, error) { - log.V("piph2: read(%v/%s) waiting?(%t)", len(b), c.id, !c.ok) - if !c.ok { - c.r = <-c.rch // nil on error - c.ok = true - } - if core.IsNil(c.r) { - log.E("piph2: read(%v/%s) not ok", len(b), c.id) - return 0, io.EOF - } - return c.r.Read(b) -} - -func (c *pipconn) Write(b []byte) (int, error) { - c.wch <- int64(len(b)) - log.V("piph2: write(%v/%s) read-waiting?(%t)", len(b), c.id, !c.ok) - if c.w == nil { - log.E("piph2: write(%v/%s) not ok", len(b), c.id) - return 0, io.EOF - } - return c.w.Write(b) -} - -func (c *pipconn) Close() (err error) { - log.D("piph2: close(%s); waiting?(%t)", c.id, c.ok) - c.CloseRead() - c.CloseWrite() - return nil -} - -func (c *pipconn) CloseRead() error { core.Close(c.r); return nil } -func (c *pipconn) CloseWrite() error { core.Close(c.w); return nil } -func (c *pipconn) LocalAddr() net.Addr { return c.laddr } -func (c *pipconn) RemoteAddr() net.Addr { return c.raddr } -func (c *pipconn) SetDeadline(t time.Time) error { return nil } -func (c *pipconn) SetReadDeadline(t time.Time) error { return nil } -func (c *pipconn) SetWriteDeadline(t time.Time) error { return nil } - -func (t *piph2) dialtls(network, addr string, cfg *tls.Config) (net.Conn, error) { - rawConn, err := t.dial(network, addr) - if err != nil || rawConn == nil || core.IsNil(rawConn) { - return nil, core.JoinErr(err, errNoProxyConn) - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - if cfg == nil { - cfg = &tls.Config{ - ServerName: hostname, - MinVersion: tls.VersionTLS12, - SessionTicketsDisabled: false, - ClientSessionCache: core.TlsSessionCache(), - } - } else if cfg.ServerName == "" { - if cfg = cfg.Clone(); cfg != nil { - cfg.ServerName = hostname - if cfg.ClientSessionCache == nil && !cfg.SessionTicketsDisabled { - cfg.ClientSessionCache = core.TlsSessionCache() - } - } - } - - conn := tls.Client(rawConn, cfg) - if err := conn.HandshakeContext(context.Background()); err != nil { - log.D("piph2: dialtls(%s) handshake error: %v", addr, err) - core.CloseConn(rawConn) - return nil, err - } - return conn, nil -} - -// dial dials proxy addr using the proxydialer via dialers.SplitDial, -// which is aware of proto changes. -func (t *piph2) dial(network, addr string) (c net.Conn, err error) { - who := idstr(t) - if usevia(t.viaID) { - if v, vok := t.via.Get(); vok { // dial via another proxy - who = idstr(v) - c, err = v.Dial(network, addr) - } else { - err = errNoHop - if removeViaOnErrors { - t.Hop(nil, false /*dryrun*/) // stale; unset - } - log.W("piph2: via(%s) failing...", idhandle(v)) - } - } else { - if settings.Loopingback.Load() { // no split in loopback (rinr) mode - c, err = dialers.Dial(t.outbound, network, addr) - } else { - c, err = dialers.SplitDial(t.outbound, network, addr) - } - } - defer localDialStatus(t.status, err) - logei(err)("piph2: dial(%s) %s (via %s); err? %v", network, addr, who, err) - return -} - -func NewPipProxy(ctx context.Context, ctl protect.Controller, px ProxyProvider, po *settings.ProxyOptions) (*piph2, error) { - if po == nil { - return nil, errMissingProxyOpt - } - - parsedurl, err := url.Parse(po.Url()) - if err != nil { - return nil, err - } - // may be "piph2" - if parsedurl.Scheme != "https" { - parsedurl.Scheme = "https" - } - portStr := parsedurl.Port() - var port int - if len(portStr) > 0 { - port, err = strconv.Atoi(portStr) - if err != nil { - return nil, err - } - } else { - port = 443 - } - - splitpath := strings.Split(parsedurl.Path, "/") - if len(splitpath) < 3 { - return nil, errNoSig - } - trType := splitpath[1] - if trType != "h2" && trType != "h3" { - return nil, errProxyConfig - } - rsasig := splitpath[2] - // todo: check if the len(rsasig) is 64/128 hex chars? - if len(rsasig) == 0 { - return nil, errNoSig - } - ctx, done := context.WithCancel(ctx) - t := &piph2{ - url: parsedurl.String(), - hostname: parsedurl.Hostname(), - port: port, - outbound: protect.MakeNsRDial(RpnH2, ctx, ctl), - px: px, - viaID: core.NewZeroVolatile[string](), - token: po.Auth.User, - toksig: po.Auth.Password, - rsasig: rsasig, - status: core.NewVolatile(TUP), - done: done, - lastdial: core.NewVolatile(time.Time{}), - opts: po, - } - t.via, err = core.NewWeakRef(t.viafor, viaok) - if err != nil { - return nil, err - } - - _, ok := dialers.New(t.hostname, po.Addrs) // po.Addrs may be nil or empty - if !ok { - log.W("piph2: zero bootstrap ips %s", t.hostname) - } - - if trType == "h3" { - // github.com/quic-go/quic-go v0.36.1 - // t.client.Transport = &http3.RoundTripper{} - log.W("piph2: h3 not supported yet") - t.client.Transport = &http2.Transport{ - DialTLS: t.dialtls, - } - } else if trType == "h2" { - // h2 is duplex: github.com/golang/go/issues/19653#issuecomment-341539160 - t.client.Transport = &http2.Transport{ - DialTLS: t.dialtls, - } - } else { - t.client.Transport = &http.Transport{ - Dial: t.dial, - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: tlsHandshakeTimeout, - ResponseHeaderTimeout: responseHeaderTimeout, - } - } - - return t, nil -} - -func (t *piph2) viafor() *Proxy { - return viafor(idstr(t), t.viaID.Load(), t.px) -} - -func (t *piph2) swapVia(new Proxy) Proxy { - return swapVia(idstr(t), new, t.viaID, t.via) -} - -// ID implements Proxy. -func (t *piph2) ID() string { - return RpnH2 -} - -// Type implements Proxy. -func (t *piph2) Type() string { - return PIPH2 -} - -// GetAddr implements Proxy. -func (t *piph2) GetAddr() string { - return t.hostname + ":" + strconv.Itoa(t.port) -} - -// Router implements Proxy. -func (t *piph2) Router() x.Router { - return t -} - -// Reaches implements x.Router. -func (t *piph2) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(t, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (t *piph2) Hop(p Proxy, dryrun bool) error { - if p == nil { - if !dryrun { - old := t.swapVia(nil) - log.I("piph2: hop(%s) removed", idhandle(old)) - } - return nil - } - if p.Status() == END { - return errProxyStopped - } - - if !dryrun { - old := t.swapVia(p) - log.I("piph2: hop %s => %s", idhandle(old), idhandle(p)) - } - return nil -} - -// Via implements x.Router. -func (t *piph2) Via() (x.Proxy, error) { - if v := t.via.Load(); v != nil { - return v, nil - } - return nil, errNoHop -} - -// Start implements Proxy. -func (t *piph2) Stop() error { - t.status.Store(END) - t.done() - return nil -} - -// Status implements Proxy. -func (t *piph2) Status() int { - st := t.status.Load() - if st != END && idling(t.lastdial.Load()) { - return TZZ - } - return st -} - -// Since implements x.Proxy. -func (h *piph2) Pause() bool { - st := h.status.Load() - if st == END { - log.W("proxy: piph2: pause called when stopped") - return false - } - - ok := h.status.Cas(st, TPU) - log.I("proxy: piph2: paused? %t", ok) - return ok -} - -// Resume implements x.Proxy. -func (h *piph2) Resume() bool { - st := h.status.Load() - if st != TPU { - log.W("proxy: piph2: resume called when not paused; status %d", st) - return false - } - - ok := h.status.Cas(st, TUP) - go h.Refresh() // no-op since SkipRefresh - log.I("proxy: piph2: resumed? %t", ok) - return ok -} - -// Scenario 4: privacypass.github.io/protocol -func (t *piph2) claim(msg string) []string { - if len(t.token) == 0 || len(t.toksig) == 0 { - return nil - } - // hmac msg keyed by token's sig - msgmac := hmac256(hex2byte(msg), hex2byte(t.toksig)) - return []string{t.token, byte2hex(msgmac)} -} - -// Handle implements Proxy. -func (t *piph2) Handle() uintptr { - return core.Loc(t) -} - -// DialerHandle implements Proxy. -func (t *piph2) DialerHandle() uintptr { - return core.Loc(t.outbound) -} - -// Dial implements Proxy. -func (t *piph2) Dial(network, addr string) (protect.Conn, error) { - return t.forward(network, addr) -} - -// DialBind implements Proxy. -func (t *piph2) DialBind(network, local, remote string) (protect.Conn, error) { - log.D("piph2: dialbind(%s) from %s to %s not supported", network, local, remote) - // TODO: error instead? - return t.forward(network, remote) -} - -func (t *piph2) forward(network, addr string) (protect.Conn, error) { - if err := candial(t.status); err != nil { - return nil, errProxyStopped - } - if network != "tcp" { - return nil, errUnexpectedProxy - } - - u, err := url.Parse(t.url) - if err != nil { - return nil, err - } - domain, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - if !strings.HasSuffix(u.Path, "/") { - u.Path += "/" - } - u.Path += domain + "/" + port + "/" + t.rsasig - - // ref: github.com/ginuerzh/gost/blob/1c62376e0880e/http2.go#L221 - // and: github.com/golang/go/issues/17227#issuecomment-249424243 - readable, writable := io.Pipe() - // multipart? stackoverflow.com/questions/39761910 - // mpw := multipart.NewWriter(writable) - // todo: buffered chan may slow down the client - incomingCh := make(chan io.ReadCloser, 1) - wlenCh := make(chan int64, 1) - oconn := &pipconn{ - id: u.Path, - rch: incomingCh, - wch: wlenCh, - w: writable, // never nil - } - - // github.com/golang/go/issues/26574 - req, err := http.NewRequest(http.MethodPut, u.String(), io.NopCloser(readable)) - - if err != nil { - log.E("piph2: req err: %v", err) - t.status.Store(TKO) - closePipe(readable, writable) - return nil, err - } - - msg := fixedMsgHex // 16 bytes; fixed - if uniqClaimPerUrl { - msg = hexurl(u.Path) // 32 bytes; per url - } else { - u.Path = u.Path + "/" + msg - } - - trace := httptrace.ClientTrace{ - GetConn: func(hostPort string) { - log.V("piph2: %s GetConn(%s)", u.Path, hostPort) - }, - GotConn: func(info httptrace.GotConnInfo) { - if info.Conn == nil { - return - } - oconn.laddr = info.Conn.LocalAddr() - oconn.raddr = info.Conn.RemoteAddr() - log.D("piph2: GotConn([%v -> %v] (via %v))", oconn.laddr, addr, oconn.raddr) - }, - PutIdleConn: func(err error) { - log.V("piph2: %s PutIdleConn(%v)", u.Path, err) - }, - GotFirstResponseByte: func() { - log.V("piph2: %s GotFirstResponseByte()", u.Path) - }, - Got100Continue: func() { - log.V("piph2: %s Got100Continue()", u.Path) - }, - Got1xxResponse: func(code int, header textproto.MIMEHeader) error { - log.V("piph2: %s Got1xxResponse(%d, %v)", u.Path, code, header) - return nil - }, - DNSStart: func(info httptrace.DNSStartInfo) { - log.V("piph2: %s DNSStart(%v)", u.Path, info) - }, - DNSDone: func(info httptrace.DNSDoneInfo) { - log.V("piph2: %s DNSDone(%v)", u.Path, info) - }, - ConnectStart: func(network, addr string) { - log.V("piph2: %s ConnectStart(%s, %s)", u.Path, network, addr) - }, - ConnectDone: func(network, addr string, err error) { - log.V("piph2: %s ConnectDone(%s, %s, %v)", u.Path, network, addr, err) - }, - TLSHandshakeStart: func() { - log.V("piph2: %s TLSHandshakeStart()", u.Path) - }, - TLSHandshakeDone: func(state tls.ConnectionState, err error) { - log.V("piph2: %s TLSHandshakeDone(%v, %v)", u.Path, state, err) - }, - WroteHeaders: func() { - log.V("piph2: %s WroteHeaders()", u.Path) - }, - WroteRequest: func(info httptrace.WroteRequestInfo) { - log.V("piph2: %s WroteRequest(%v)", u.Path, info) - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), &trace)) - - log.D("piph2: req %s", u) - // infinite length? doesn't work with cloudflare - // req.ContentLength = -1 - req.Close = false // allow keep-alive - // github.com/stripe/stripe-go/pull/711 - req.GetBody = func() (io.ReadCloser, error) { - log.V("piph2: %s GetBody()", u.Path) - return io.NopCloser(readable), nil - } - req.Header.Set("User-Agent", "") - // sse? community.cloudflare.com/t/184219 - // pack binary data into utf-8? - // stackoverflow.com/a/31661586 - // go.dev/play/p/NPsulbF2y9X - // req.Header.Set("Content-Type", "text/event-stream") - msgmac := t.claim(msg) - req.Header.Set("Content-Type", "application/octet-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - if msgmac != nil { - req.Header.Set("x-nile-pip-claim", msgmac[0]) - req.Header.Set("x-nile-pip-mac", msgmac[1]) - // msg is implicitly hex(sha256(url.Path)) - // req.Header.Set("x-nile-pip-msg", msg) - } - - t.lastdial.Store(time.Now()) - core.Go("piph2.Dial", func() { - // fixme: currently, this hangs forever when upstream is cloudflare - // setting the content-length to the first len(first-write-bytes) works - // with cloudflare, but then golang's h2 client isn't happy about sending - // more data than what's defined in content-length: - // github.com/golang/go/issues/32728 - req.ContentLength = <-wlenCh - res, err := t.client.Do(req) - if err != nil || res == nil { - log.E("piph2: path(%s) send err: %v", u.Path, err) - t.status.Store(TKO) - incomingCh <- nil - closePipe(readable, writable) - } else if res.StatusCode != http.StatusOK { - log.E("piph2: path(%s) recv bad: %v", u.Path, res.Status) - core.Close(res.Body) - t.status.Store(TKO) - incomingCh <- nil - closePipe(readable, writable) - } else { - log.D("piph2: duplex %s", u) - // github.com/posener/h2conn/blob/13e7df33ed1/client.go - res.Request = req - t.status.Store(TOK) - incomingCh <- res.Body - } - }) - - t.status.Store(TOK) - return oconn, nil -} - -// Dialer implements Proxy. -func (t *piph2) Dialer() protect.RDialer { - return t -} - -func closePipe(ps ...io.Closer) { - for _, c := range ps { - core.CloseOp(c, core.CopAny) - } -} - -func hmac256(m, k []byte) []byte { - mac := hmac.New(sha256.New, k) - mac.Write(m) - return mac.Sum(nil) -} - -func hexurl(p string) string { - digest := sha256.Sum256([]byte(p)) - return hex.EncodeToString(digest[:]) -} - -func hex2byte(s string) []byte { - b, err := hex.DecodeString(s) - if err != nil { - log.E("piph2: hex2byte: err %v", err) - } - return b -} - -func byte2hex(b []byte) string { - return hex.EncodeToString(b) -} diff --git a/intra/ipn/pipws.go b/intra/ipn/pipws.go deleted file mode 100644 index 86fd0711..00000000 --- a/intra/ipn/pipws.go +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "crypto/tls" - "errors" - "net" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "github.com/coder/websocket" -) - -const ( - writeTimeout = 10 * time.Second - uniqClaimPerUrl = false // generate a new claim per url - fixedMsgHex = "aecdcde241e3196f2252738c11467baf" // some fixed hex; 16 bytes -) - -type pipws struct { - NoFwd // no forwarding/listening - NoDNS // no dns - ProtoAgnostic // since dial is proto aware - SkipRefresh // no refresh - GW // dual stack gateway - - url string // ws proxy url - hostname string // ws proxy hostname - port int // ws proxy port - token string // hex, raw client token - toksig string // hex, authorizer (rdns) signed client token - rsasighash string // hex, authorizer sha256(unblinded signature) - echcfg *tls.Config // ech config - client http.Client // ws client - client3 *http.Client // ws client for ech - outbound *protect.RDial // ws dialer - px ProxyProvider - via *core.WeakRef[Proxy] // hop dialer - viaID *core.Volatile[string] - lastdial time.Time // last dial time - - done context.CancelFunc // cancel func - - status *core.Volatile[int] // proxy status: TOK, TKO, END - opts *settings.ProxyOptions -} - -var _ core.TCPConn = (*pipwsconn)(nil) - -// pipwsconn minimally adapts net.Conn to the core.TCPConn interface -type pipwsconn struct { - net.Conn -} - -func (c *pipwsconn) CloseRead() error { return c.Close() } -func (c *pipwsconn) CloseWrite() error { return c.Close() } - -// connects to the ws proxy at addr over tcp; used by t.client -// dial is aware of proto changes via dialers.SplitDial -func (t *pipws) dial(network, addr string) (c net.Conn, err error) { - who := idstr(t) - if usevia(t.viaID) { - if v, vok := t.via.Get(); vok { // dial via another proxy - who = idstr(v) - c, err = v.Dial(network, addr) - } else { - err = errNoHop - if removeViaOnErrors { - t.Hop(nil, false /*dryrun*/) // stale; unset - } - log.W("pipws: via(%s) failing...", idhandle(v)) - } - } else { - if settings.Loopingback.Load() { // no split in loopback (rinr) mode - c, err = dialers.Dial(t.outbound, network, addr) - } else { - c, err = dialers.SplitDial(t.outbound, network, addr) - } - } - defer localDialStatus(t.status, err) - logei(err)("pipws: dial(%s) to %s (via: %s); err? %v", network, addr, who, err) - return -} - -func (t *pipws) wsconn(rurl, msg string) (c net.Conn, res *http.Response, err error) { - var ws *websocket.Conn - ctx := context.TODO() - msgmac := t.claim(msg) // msg is hex(sha256(url.Path)) or fixedMsgHex - hdrs := http.Header{} - hdrs.Set("User-Agent", "") - if msgmac != nil { - hdrs.Set("x-nile-pip-msg", msg) - hdrs.Set("x-nile-pip-claim", msgmac[0]) // client token (po.User) - hdrs.Set("x-nile-pip-mac", msgmac[1]) // hmac derived from token-sig (po.Password) - // msg is implicitly hex(sha256(url.Path)) - // hdrs.Set("x-nile-pip-msg", msg) - } - - log.D("pipws: connecting to %s", rurl) - - if c3 := t.client3; c3 != nil { // ech with tls v3 - ws, res, err = websocket.Dial(ctx, rurl, &websocket.DialOptions{ - HTTPClient: c3, - HTTPHeader: hdrs, - }) - - if eerr := new(tls.ECHRejectionError); errors.As(err, &eerr) { - closeWs(ws, "ech rejected") - manual := false - ech := eerr.RetryConfigList - if len(ech) <= 0 { - ech = t.ech() - manual = true - } - log.I("pipws: ech rejected; new? %d / manual? %t, err: %v", - len(ech), manual, eerr) - if len(ech) > 0 { // retry with new ech - t.echcfg.EncryptedClientHelloConfigList = ech - // TODO: is this necessary given echcfg is already set? - t.client3.Transport = t.h2(t.echcfg) - // retry with new ech - ws, res, err = websocket.Dial(ctx, rurl, &websocket.DialOptions{ - HTTPClient: t.client3, - HTTPHeader: hdrs, - }) - } - } - } - // err nil when there's no ech; err non-nil when ech fails - if err != nil || ws == nil || res == nil { // fallback or use tls v2 - closeWs(ws, "fallback") - - log.D("pipws: fallback to tls v2; err? %v", rurl, err) // err maybe nil - ws, res, err = websocket.Dial(ctx, rurl, &websocket.DialOptions{ - // todo: igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ - // compression does not work with Workers - // CompressionMode: websocket.CompressionNoContextTakeover, - HTTPClient: &t.client, - HTTPHeader: hdrs, - }) - } - if err != nil || ws == nil || res == nil { - closeWs(ws, "dial err") - err = core.OneErr(err, errNoProxyConn) - log.E("pipws: dialing %s (ws? %t, hres? %t); err: %v\n", - rurl, ws == nil, res == nil, err) - return - } - - conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) - c = &pipwsconn{conn} - return -} - -// NewPipWsProxy creates a new pipws proxy with the given id, controller, and proxy options. -// The proxy options must contain a valid URL, and the URL must have a path with the format "/ws/". -// The proxy options must also contain a valid auth user (raw client token) and -// password (expiry + signed raw client token). -func NewPipWsProxy(ctx context.Context, ctl protect.Controller, px ProxyProvider, po *settings.ProxyOptions) (*pipws, error) { - if po == nil { - return nil, errMissingProxyOpt - } - - parsedurl, err := url.Parse(po.Url()) - if err != nil { - return nil, err - } - // may be "pipws" - if parsedurl.Scheme != "wss" { - parsedurl.Scheme = "wss" - } - portStr := parsedurl.Port() - var port int - if len(portStr) <= 0 { - portStr = "443" - } - port, err = strconv.Atoi(portStr) - if err != nil { - return nil, err - } - - splitpath := strings.Split(parsedurl.Path, "/") - // todo: check if the len(rsasig) is 64/128 hex chars? - if len(splitpath) < 3 { - return nil, errNoSig - } - if (splitpath[1] != "ws" && splitpath[1] != "wss") || len(splitpath[3]) <= 0 { - return nil, errProxyConfig - } - - ctx, done := context.WithCancel(ctx) - t := &pipws{ - url: parsedurl.String(), - hostname: parsedurl.Hostname(), - port: port, - outbound: protect.MakeNsRDial(RpnWs, ctx, ctl), - px: px, - viaID: core.NewZeroVolatile[string](), - token: po.Auth.User, - toksig: po.Auth.Password, - rsasighash: splitpath[2], - status: core.NewVolatile(TUP), - done: done, - opts: po, - } - t.via, err = core.NewWeakRef(t.viafor, viaok) - if err != nil { - return nil, err - } - - _, ok := dialers.New(t.hostname, po.Addrs) // po.Addrs may be nil or empty - if !ok { - log.W("pipws: zero bootstrap ips %s", t.hostname) - } - - tlscfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - SessionTicketsDisabled: false, - ClientSessionCache: core.TlsSessionCache(), - } - ech := t.ech() - if len(ech) > 0 { - t.client3 = new(http.Client) - t.echcfg = &tls.Config{ - MinVersion: tls.VersionTLS13, // must be 1.3 - EncryptedClientHelloConfigList: ech, - SessionTicketsDisabled: false, - ClientSessionCache: core.TlsSessionCache(), - } - t.client3.Transport = t.h2(t.echcfg) - } - t.client.Transport = t.h2(tlscfg) - - log.I("pipws: host: %s:%s, sig: %s, ech? %t", t.hostname, portStr, t.rsasighash[:6], t.client3 != nil) - return t, nil -} - -func (t *pipws) h2(cfg *tls.Config) *http.Transport { - return &http.Transport{ - Dial: t.dial, - TLSHandshakeTimeout: writeTimeout, - ResponseHeaderTimeout: writeTimeout, - TLSClientConfig: cfg, - } -} - -func (t *pipws) viafor() *Proxy { - return viafor(idstr(t), t.viaID.Load(), t.px) -} - -func (t *pipws) swapVia(new Proxy) Proxy { - return swapVia(idstr(t), new, t.viaID, t.via) -} - -// ID implements x.Proxy. -func (t *pipws) ID() string { - return RpnWs -} - -// Type implements x.Proxy. -func (t *pipws) Type() string { - return PIPWS -} - -// GetAddr implements x.Proxy. -func (t *pipws) GetAddr() string { - return t.hostname + ":" + strconv.Itoa(t.port) -} - -// Router implements x.Proxy. -func (t *pipws) Router() x.Router { - return t -} - -// Reaches implements x.Router. -func (t *pipws) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(t, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (h *pipws) Hop(p Proxy, dryrun bool) error { - if p == nil { - if !dryrun { - old := h.swapVia(nil) - log.I("pipws: hop(%s) removed", idhandle(old)) - } - return nil - } - if p.Status() == END { - return errProxyStopped - } - - if !dryrun { - old := h.swapVia(p) - log.I("pipws: hop %s => %s", idhandle(old), idhandle(p)) - } - return nil -} - -// Via implements x.Router. -func (h *pipws) Via() (x.Proxy, error) { - if v := h.via.Load(); v != nil { - return v, nil - } - return nil, errNoHop -} - -// Stop implements x.Proxy. -func (t *pipws) Stop() error { - t.status.Store(END) - t.done() - log.I("pipws: stopped") - return nil -} - -// Status implements Proxy. -func (t *pipws) Status() int { - s := t.status.Load() - if s != END && idling(t.lastdial) { - return TZZ - } - return s -} - -// Pause implements x.Proxy. -func (h *pipws) Pause() bool { - st := h.status.Load() - if st == END { - log.W("proxy: pipws: pause called when stopped") - return false - } - - ok := h.status.Cas(st, TPU) - log.I("proxy: pipws: paused? %t", ok) - return ok -} - -// Resume implements x.Proxy. -func (h *pipws) Resume() bool { - st := h.status.Load() - if st != TPU { - log.W("proxy: pipws: resume called when not paused; status %d", st) - return false - } - - ok := h.status.Cas(st, TUP) - go h.Refresh() - - log.I("proxy: pipws: resumed? %t", ok) - return ok -} - -// Scenario 4: privacypass.github.io/protocol -func (t *pipws) claim(msg string) []string { - if len(t.token) == 0 || len(t.toksig) == 0 { - return nil - } - // hmac(msg aka url.path) keyed to hmac-signed(token) - msgmac := hmac256(hex2byte(msg), hex2byte(t.toksig)) - return []string{t.token, byte2hex(msgmac)} -} - -// Handle implements Proxy. -func (t *pipws) Handle() uintptr { - return core.Loc(t) -} - -// DialerHandle implements Proxy. -func (t *pipws) DialerHandle() uintptr { - return core.Loc(t.outbound) -} - -// Dial connects to addr via wsconn over this ws proxy -func (t *pipws) Dial(network, addr string) (protect.Conn, error) { - return t.forward(network, addr) -} - -// DialBind implements Proxy. -func (t *pipws) DialBind(network, local, remote string) (protect.Conn, error) { - log.D("pipws: dialbind(%s) from %s to %s not supported", network, local, remote) - // TODO: error instead? - return t.forward(network, remote) -} - -func (t *pipws) forward(network, addr string) (protect.Conn, error) { - if t.status.Load() == END { - return nil, errProxyStopped - } - // tcp, tcp4, tcp6 - if !strings.Contains(network, "tcp") { - return nil, errUnexpectedProxy - } - - u, err := url.Parse(t.url) - if err != nil { - return nil, err - } - domain, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - if !strings.HasSuffix(u.Path, "/") { - u.Path += "/" - } - u.Path += domain + "/" + port + "/" + t.rsasighash - - msg := fixedMsgHex // 16 bytes; fixed - if uniqClaimPerUrl { - msg = hexurl(u.Path) // 32 bytes; per url - } else { - u.Path = u.Path + "/" + msg - } - - rurl := u.String() - c, res, err := t.wsconn(rurl, msg) - t.lastdial = time.Now() - if err != nil || res == nil { // nilaway - err = core.OneErr(err, errNoProxyConn) - core.CloseConn(c) - log.E("pipws: req %s err: %v", rurl, err) - t.status.Store(TKO) - return nil, err - } - if res.StatusCode != 101 { - core.CloseConn(c) - log.E("pipws: %s res not ws %d", rurl, res.StatusCode) - t.status.Store(TKO) - return nil, err - } - - log.D("pipws: duplex %s", rurl) - - t.status.Store(TOK) - return c, nil -} - -// Dialer implements Proxy. -func (t *pipws) Dialer() protect.RDialer { - return t -} - -func (t *pipws) ech() []byte { - name := t.hostname - if len(name) <= 0 { - return nil - } else if v, err := dialers.ECH(name); err != nil { - log.W("pipws: ech(%s): %v", name, err) - return nil - } else { - log.V("pipws: ech(%s): sz %d", name, len(v)) - return v - } -} - -func closeWs(ws *websocket.Conn, reason string) { - if ws != nil { - _ = ws.Close(websocket.StatusNormalClosure, reason) - } -} diff --git a/intra/ipn/proxies.go b/intra/ipn/proxies.go deleted file mode 100644 index 40b9f98c..00000000 --- a/intra/ipn/proxies.go +++ /dev/null @@ -1,1493 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "errors" - "fmt" - "math/rand" - "net" - "net/netip" - "slices" - "strconv" - "strings" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/ipn/rpn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" -) - -const ( - Block = x.Block - Base = x.Base - Exit = x.Exit - Auto = x.Auto - Ingress = x.Ingress // dummy - OrbotS5 = x.OrbotS5 - OrbotH1 = x.OrbotH1 - GlobalH1 = x.GlobalH1 - RpnWin = x.RpnWin - RpnWs = x.RpnWs - Rpn64 = x.Rpn64 - RpnH2 = x.RpnH2 - - SOCKS5 = x.SOCKS5 - HTTP1 = x.HTTP1 - WG = x.WG - WGFAST = x.WGFAST - PIPH2 = x.PIPH2 - PIPWS = x.PIPWS - NOOP = x.NOOP - INTERNET = x.INTERNET - RPN = x.RPN - - TPU = x.TPU - TNT = x.TNT - TZZ = x.TZZ - TUP = x.TUP - TOK = x.TOK - TKO = x.TKO - END = x.END - - NOMTU = 0 - MAXMTU = 65535 - AUTOMTU = "auto" - AUTOMTU2 = "(auto)" -) - -type pxstatus int - -func (s pxstatus) String() string { - switch s { - case TKO: - return "notok" - case TOK: - return "ok" - case TUP: - return "up" - case TZZ: - return "idle" - case TNT: - return "unresponsive" - case TPU: - return "paused" - case END: - return "ended" - default: - return "unknown" - } -} - -var ( - errProxyScheme = errors.New("proxy: unsupported scheme") - errUnexpectedProxy = errors.New("proxy: unexpected type") - errAddProxy = errors.New("proxy: add failed") - errAddProxyAsRpn = errors.New("proxy: cannot add rpn proxy") - errProxyNotFound = errors.New("proxy: not found") - errGetProxyTimeout = errors.New("proxy: get timeout") - errProxyAllDown = errors.New("proxy: all down") - errNoProxyHealthy = errors.New("proxy: none healthy") - errMissingProxyOpt = errors.New("proxy: opts nil") - errNoProxyConn = errors.New("proxy: not a tcp/udp conn") - errNotUDPConn = errors.New("proxy: not a udp conn") - errProxyStopped = errors.New("proxy: stopped") - errProxyPaused = errors.New("proxy: paused") - errProxyRoute = errors.New("proxy: no route to host") - errProxyConfig = errors.New("proxy: invalid config") - errNoProxyResponse = errors.New("proxy: blocked or no response") - errNoSig = errors.New("proxy: auth missing sig") - errNoMtu = errors.New("proxy: missing mtu") - errNoOpts = errors.New("proxy: no opts") - errNoAuto464XLAT = errors.New("auto: no 464xlat") - errNotPinned = errors.New("auto: another proxy pinned") - errInvalidAddr = errors.New("proxy: invaild ip:port") - errMissingProxyID = errors.New("proxy: missing proxy id") - errHopDefaultRoutes = errors.New("proxy: hop must route all ip4/ip6") - errHopHopping = errors.New("proxy: hop must not be hopping") - errNoHop = errors.New("proxy: no hop") - errHopSelf = errors.New("proxy: hop looping back onto hop") - errHopWireGuard = errors.New("proxy: hop must be wireguard") - errHopMtuInsufficient = errors.New("proxy: hop mtu insufficient") - errHopProxyRoutes = errors.New("proxy: no routes to hop") - errHop4Gateway = errors.New("proxy: hop cannot route ip4") - errHop6Gateway = errors.New("proxy: hop cannot route ip6") - errHopGlobalProxy = errors.New("proxy: hop must be global proxy") - errHopNotConnected = errors.New("proxy: set but not connected over hop") - errNilWinCfg = errors.New("proxy: win cfg nil") - errNilWinDevice = errors.New("proxy: missing win device id") - errNotRpnProxy = errors.New("proxy: rpn not found") - errNotRpnID = errors.New("proxy: not rpn id") - errNotRpnAcc = errors.New("proxy: not rpn account") - errNotRemote = errors.New("proxy: not a remote proxy") -) - -const ( - udptimeoutsec = 5 * 60 // 5m - tcptimeoutsec = (2 * 60 * 60) + (40 * 60) // 2h40m - tlsHandshakeTimeout = 30 * time.Second // some proxies take a long time to handshake - responseHeaderTimeout = 60 * time.Second - tzzTimeout = 2 * time.Minute // time between new connections before proxies transition to idle - lastOKThreshold = 10 * time.Minute // time between last OK and now before pinging & un-pinning - ageThreshold = 10 * time.Second // time for proxy to start up - pintimeout = 10 * time.Minute // time to keep a pin - alwaysPin = true // always pin to a proxy no matter the errors - maxFailingPinTrackTTl = 30 * time.Second // max period to track a failing to-be-pinned proxy - maxStallPeriodSec = 10 // max duration to stall a failing proxy - maxWaitPeriodSec = 3 // max duration to wait for a missing proxy to be added - getproxytimeout = 5 * time.Second -) - -// type checks -var _ Proxy = (*base)(nil) -var _ Proxy = (*exit)(nil) -var _ Proxy = (*exit64)(nil) -var _ Proxy = (*auto)(nil) -var _ Proxy = (*socks5)(nil) -var _ Proxy = (*http1)(nil) -var _ Proxy = (*wgproxy)(nil) -var _ Proxy = (*ground)(nil) -var _ Proxy = (*pipws)(nil) -var _ Proxy = (*piph2)(nil) - -type Proxy interface { - x.Proxy - protect.RDialer - - // DialerHandle uniquely identifies the concrete type backing this proxy's dialer. - // Useful as a phantom reference to this dialer. - // github.com/hashicorp/terraform/blob/325d18262/internal/configs/configschema/decoder_spec.go#L32 - DialerHandle() uintptr - // Handle uniquely identifies the concrete type backing this proxy. - Handle() uintptr - // Dialer returns the dialer for this proxy, which is an - // adapter for protect.RDialer interface, but with the caveat that - // not all Proxy instances implement DialTCP and DialUDP, though are - // guaranteed to implement Dial. - Dialer() protect.RDialer - // onNotOK is called by clients when the proxy is not responsive. - onNotOK() (refreshed, allok bool) - // onProtoChange returns true if the proxy must be re-added with cfg on proto changes. - OnProtoChange(lp LinkProps) (cfg string, readd bool) - // Gateway sets proxy p as the gateway for this router. - Hop(p Proxy, dryrun bool) error -} - -type Rpn interface { - x.Rpn - rpnProxyProvider - // addRpnProxy adds an RPN proxy to this multi-transport. - addRpnProxy(acc RpnAcc, cc string) (Proxy, error) - // removeRpnProxy removes an RPN proxy from this multi-transport. - removeRpnProxy(acc RpnAcc, cc string) bool -} - -type rpnProxyProvider interface { - // mainRpnProxyFor returns the main (default) RPN proxy from this multi-transport. - mainRpnProxyOf(provider string) (RpnProxy, error) - // rpnProxyFor returns a country-specific RPN proxy from this multi-transport. - rpnProxyFor(provider, cc string) (Proxy, error) - // AutoActive returns true if any of the RPN proxies are in-use by ipn.Auto. - AutoActive() bool -} - -type ProxyProvider interface { - rpnProxyProvider - // ProxyFor returns a transport from this multi-transport. - ProxyFor(id string) (Proxy, error) - // ProxyTo returns the proxy to use for ipp from given pids. - ProxyTo(ipp netip.AddrPort, uid string, pids []string) (Proxy, error) -} - -type Proxies interface { - x.Proxies - ProxyProvider - Rpn - // RefreshProto broadcasts proto change to all active proxies. - // l3 if left empty, will use last recorded value; same for mtu <= 0. - RefreshProto(l3 string, mtu int, force bool) - // LiveProxies returns a csv of active proxies. - LiveProxies() string - // Reverser sets the reverse proxy for all proxies. - Reverser(r netstack.GConnHandler) error -} - -type proxifier struct { - sync.RWMutex - NoVia - - ctx context.Context - p map[string]Proxy - - rpnmu sync.RWMutex // protects rp - rp map[string]RpnProxy // main rpn proxies - - hmu sync.RWMutex // protects hp - hp map[string][]string // hopproxy => [proxyid] - - ctl protect.Controller // dial control provider - obs x.ProxyListener // proxy observer - - lp LinkProps // link properties; protected by mu - - staller *core.ExpMap[string, string] // uid+dst(domainOrIP) -> stallSecs - - ipPins *core.Sieve[netip.AddrPort, string] // ipp -> proxyid - uidPins *core.Sieve2K[string, netip.AddrPort, string] // uid -> [dst -> proxyid] - - // immutable proxies - exit *exit // exit proxy, never changes - exit64 *exit64 // rpn64 proxy, never changes - base *base // base proxy, never changes - grounded *ground // grounded proxy, never changes - auto *auto // auto proxy, never changes - - extc *rpn.BaseClient // external wg registration, never changes - - lastWinErr *core.Volatile[error] // win registration error -} - -type LinkProps struct { - l3 string // ip4, ip6, ip46 - mtu int - rev netstack.GConnHandler // downstream; may be nil -} - -func (lp LinkProps) String() string { - return fmt.Sprintf("l3:%s/mtu:%d/rev:%X", lp.l3, lp.mtu, lp.rev) -} - -var _ Proxies = (*proxifier)(nil) -var _ x.Rpn = (*proxifier)(nil) -var _ x.Router = (*proxifier)(nil) -var _ protect.RDialer = (Proxy)(nil) -var _ Proxy = (*NoProxy)(nil) -var _ x.Router = (*NoProxy)(nil) - -// NewProxifier returns a new Proxifier instance. -func NewProxifier(pctx context.Context, l3 string, mtu int, c protect.Controller, o x.ProxyListener) *proxifier { - if c == nil || o == nil { - return nil - } - - pxr := &proxifier{ - ctx: pctx, - p: make(map[string]Proxy), - ctl: c, - obs: o, - - lp: LinkProps{l3: l3, mtu: mtu}, - - hp: make(map[string][]string), - - rp: make(map[string]RpnProxy), - lastWinErr: core.NewZeroVolatile[error](), - } - - pxr.exit = NewExitProxy(pctx, c) - pxr.exit64 = NewExit64Proxy(pctx, c) - pxr.base = NewBaseProxy(pctx, c, pxr) - pxr.grounded = NewGroundProxy() - pxr.auto = NewAutoProxy(pctx, pxr) - pxr.staller = core.NewExpiringMap[string, string](pctx) - pxr.ipPins = core.NewSieve[netip.AddrPort, string](pctx, pintimeout) - pxr.uidPins = core.NewSieve2K[string, netip.AddrPort, string](pctx, pintimeout) - - pxr.extc = rpn.NewExtClient(pxr.base) - - pxr.add(pxr.exit) // fixed - pxr.add(pxr.base) // fixed - pxr.add(pxr.grounded) // fixed - pxr.add(pxr.auto) // fixed - - if _, err := pxr.addRpnProxy2(pxr.exit64, pxr.exit64); err != nil { // fixed - // TODO: lastExit64Err? - log.W("proxy: rpn64: add: %v", err) - } - - log.I("proxy: new") - - context.AfterFunc(pctx, pxr.stopProxies) - - return pxr -} - -// add adds a proxy to the proxifier and invokes OnProxyAdded. -// It returns true if the proxy was added successfully. -// It stops old proxy if a new one with the same ID is added. -func (px *proxifier) add(p Proxy) (ok bool) { - var old Proxy - id := idstr(p) - - px.Lock() - defer px.Unlock() - - defer func() { - if ok { - core.Go("pxr.add: "+id, func() { - px.obs.OnProxyAdded(p.ID()) - }) - // new proxy, invoke Stop on old proxy - if old != nil && !Same(old, p) { - // holding px.lock, so exec stop in a goroutine - core.Go("pxr.add.stop: "+id, func() { - if oldVia, _ := old.Router().Via(); oldVia != nil { - px.Hop(oldVia.ID(), p.ID()) - } - _ = old.Stop() - // onRmv is not sent here, as one has just been added - }) - } - } - }() - - old = px.p[id] - if immutable(id) { - switch id { - case Exit: - if x, typeok := p.(*exit); typeok { - px.exit = x - px.p[id] = p - ok = true - } - case Base: - if x, typeok := p.(*base); typeok { - px.base = x - px.p[id] = p - ok = true - } - case Block: - if x, typeok := p.(*ground); typeok { - px.grounded = x - px.p[id] = p - ok = true - } - case Rpn64: - if x, typeok := p.(*exit64); typeok { - px.exit64 = x - px.p[id] = p - // do not call addRpnProxy from here - // it will result in endless recursive - // calls leading back here - ok = true - } - case Auto: - if x, typeok := p.(*auto); typeok { - px.auto = x - px.p[id] = p - ok = true - } - } - } else { - px.p[id] = p - ok = true - } - - logeif(!ok)("proxy: add: proxy %s (%s => %s); added? %t", id, idhandle(old), idhandle(p), ok) - return ok -} - -// RemoveProxy implements x.Proxies. -func (px *proxifier) RemoveProxy(id string) bool { - defer core.Recover(core.Exit11, "pxr.RemoveProxy."+id) - - return px.removeProxy(id, false /*force remove?*/) -} - -func (px *proxifier) removeProxy(id string, force bool) bool { - if isInternal(id) && !force { - log.D("proxy: remove: %s; not allowed", id) - return false - } - - px.Lock() - defer px.Unlock() - - perma := immutable(id) - if p, ok := px.p[id]; ok { - if !perma { - delete(px.p, id) - } - core.Go("pxr.removeProxy: "+id, func() { - px.unmapHopFrom(p, false /*dryrun*/) - - _ = p.Stop() - if !perma { - px.obs.OnProxyRemoved(id) - log.I("proxy: removed %s", id) - } else { - px.obs.OnProxyStopped(id) - log.I("proxy: stopped (not removed) %s", id) - } - }) - return true - } - return false -} - -// ProxyTo implements Proxies. -// May return both a Proxy and an error, in which case, the error -// denotes that while the Proxy is not healthy, it is still registered. -func (px *proxifier) ProxyTo(ipp netip.AddrPort, uid string, pids []string) (theone Proxy, err error) { - waitedForMissingProxy := false - - ippstr := ipp.String() - e := func(err error) error { - return fmt.Errorf("%v for %s to %s among %v", err, uid, ippstr, pids) - } - if len(pids) <= 0 || firstEmpty(pids) { - return nil, e(errMissingProxyID) - } - if !ipp.IsValid() { - return nil, e(errMissingAddress) - } - - stalledSec := uint32(0) - - if len(pids) == 1 { // there's no other pid to choose from - retryPin: - p, err := px.pinID(uid, ipp, pids[0]) // repin - if err != nil || p == nil { - err = core.OneErr(err, errProxyNotFound) - if !waitedForMissingProxy { - // wait for the missing proxy to be added before returning error - waitedForMissingProxy = true - stalledSec = px.stall(uid + ippstr) - if stalledSec < maxWaitPeriodSec { - time.Sleep(time.Duration(maxWaitPeriodSec-stalledSec) * time.Second) - stalledSec = maxWaitPeriodSec - } - goto retryPin - } - } - logev(err)("proxy: pin: %s+%s; pin pid0: %s (stalled? %ds / waited? %t); err? %v", - uid, ippstr, pids[0], stalledSec, waitedForMissingProxy, err) - if p != nil { - if !hasroute(p, ippstr) { - px.delpin(uid, ipp) - return nil, e(core.JoinErr(err, errProxyRoute)) - } // there is only one pid to route to - - // alwaysPin is set to true, so wipe out err; return p, even if err is not nil - // alwaysPin helps client code verify for itself just why this proxy won't work... - if alwaysPin { - return p, nil - } - } - return nil, e(err) - } - - var lopinned string - - pinnedpid, pinok := px.getpin(uid, ipp) - chosen := has(pids, pinnedpid) - lo := local(pinnedpid) - - log.VV("proxy: pin: %s+%s; pinned: %s (ok? %t); chosen? %t / local? %t; from pids: %v", - uid, ippstr, pinnedpid, pinok, chosen, lo, pids) - - if !pinok { // discard pinnedpid if pin has expired - pinnedpid = "" - } - - if pinok && chosen && lo { - // always favour remote proxy pins over local, if any - lopinned = pinnedpid - } else if pinok && chosen { - p, err := px.pinID(uid, ipp, pinnedpid) // repin - if p != nil && err == nil { - if hasroute(p, ippstr) { - return p, nil - } - px.delpin(uid, ipp) // del pin if no route - } // else: pinnedpid not ok (ex: END/TPU) or no route - log.W("proxy: pin: %s+%s; chosen and pinned: %s (but err? %v); hasproxy? %t (or no route)", - uid, ippstr, pinnedpid, err, p != nil) - } else if pinok && !chosen { - px.delpin(uid, ipp) - } - - var notok []Proxy - notokproxies := make([]string, 0) - endproxies := make([]string, 0) - pausedproxies := make([]string, 0) - norouteproxies := make([]string, 0) - missproxies := make([]string, 0) - loproxies := make([]string, 0) - if len(lopinned) > 0 { // lopinned may be empty - loproxies = append(loproxies, lopinned) - } - - defer func() { - logev(err)("proxy: pin: %s+%s; chosen? %s; stalled? %ds; local: %v; miss: %v; notok: %v; noroute: %v; paused %v; ended %v", - uid, ipp, idstr(theone), stalledSec, loproxies, missproxies, notokproxies, norouteproxies, pausedproxies, endproxies) - }() - -retrySearch: - for _, pid := range pids { - if pinok && pid == pinnedpid { // already tried above - continue - } - if local(pid) { // skip local; prefer remote - loproxies = append(loproxies, pid) - continue // process later - } - - p, err := px.proxyFor(pid) - if err != nil || p == nil { // proxy 404 - // TODO: errors.Is(err, errProxyNotFound)? - missproxies = append(missproxies, pid) - continue - } - - st := p.Status() - if st == TPU { - pausedproxies = append(pausedproxies, pid) - continue - } else if st == END { - endproxies = append(endproxies, pid) - continue - } - - if noop(typstr(p)) { - loproxies = append(loproxies, pid) - continue - } - - if hasroute(p, ippstr) { - err := px.pin(uid, ipp, p) // repin & ping if needed - if err == nil { - log.VV("proxy: pin: %s+%s; pinned: %s; from pids: %v", - uid, ippstr, pid, pids) - return p, nil - } // else: proxy not ok - notokproxies = append(notokproxies, pid) - notok = append(notok, p) - } else { // else: proxy cannot route; split-tunnel - norouteproxies = append(norouteproxies, pid) - } - } - - // can route but not healthy; choose any one on random - if len(notok) > 0 { - // stall to allow a non-healthy proxy to recover - stalledSec = px.stall(uid + ippstr) - return core.ChooseOne(notok), nil - } - - // lopinned is always the first element, if any. - for _, pid := range loproxies { - // ignore err, as it unlikely for local proxies - // that are always available, and are presumed to - // be gateways (route all ips) - if p, _ := px.pinID(uid, ipp, pid); p != nil { // repin - return p, nil - } - missproxies = append(missproxies, pid) - } - - if len(missproxies) > 0 && !waitedForMissingProxy { - // wait for the missing proxy to be added before returning error - waitedForMissingProxy = true - stalledSec = px.stall(uid + ippstr) - if stalledSec < maxWaitPeriodSec { - time.Sleep(time.Duration(maxWaitPeriodSec-stalledSec) * time.Second) - stalledSec = maxWaitPeriodSec - } - log.W("proxy: pin: %s+%s; missing: %v; notok: %v; noroute: %v; paused: %v; ended: %v; waited: %ds", - uid, ippstr, missproxies, notokproxies, norouteproxies, pausedproxies, endproxies, stalledSec) - pids = missproxies - missproxies = make([]string, 0) - goto retrySearch - } - - if len(notokproxies) > 0 { - return nil, e(errNoProxyHealthy) - } else if len(missproxies) > 0 { - return nil, e(errProxyNotFound) - } else if len(norouteproxies) > 0 { - return nil, e(errProxyRoute) - } else if len(endproxies) > 0 { - return nil, e(errProxyStopped) - } else if len(pausedproxies) > 0 { - return nil, e(errProxyPaused) - } - - return nil, e(errProxyAllDown) -} - -func (px *proxifier) stall(k string) (secs uint32) { - if n := px.staller.Get(k); n <= 3 { - secs = (rand.Uint32() % 3) + 1 // up to 3s - } else { - secs = n - } - px.staller.Set(k, maxFailingPinTrackTTl) // track uid=>target for 30s - if secs = min(maxStallPeriodSec, secs); secs > 0 { // max up to 10s - w := time.Duration(secs) * time.Second - time.Sleep(w) - } - return -} - -func (px *proxifier) pinID(uid string, ipp netip.AddrPort, id string) (Proxy, error) { - p, err := px.proxyFor(id) - if err != nil || p == nil { - err = core.OneErr(err, errProxyNotFound) - return p, fmt.Errorf("proxy: pin: id %s; err: %v", id, err) - } - err = px.pin(uid, ipp, p) - return p, err -} - -func (px *proxifier) pin(uid string, ipp netip.AddrPort, p Proxy) error { - pid := idstr(p) - - err := healthy(p) // called to ensure p is ready-to-go - if err == nil { - px.uidPins.Put(uid, ipp, pid) - px.ipPins.Put(ipp, pid) - } - logev(err)("proxy: pin: ok? %t; %s from %s; err? %v", - err == nil, ipp, pid, err) - - if err != nil { - return fmt.Errorf("proxy: pin: %s; err: %v", pid, err) - } - return nil -} - -func (px *proxifier) delpin(uid string, ipp netip.AddrPort) { - px.uidPins.Del(uid, ipp) - px.ipPins.Del(ipp) -} - -func (px *proxifier) getpin(uid string, ipp netip.AddrPort) (string, bool) { - if id, ok := px.uidPins.Get(uid, ipp); ok { - return id, ok - } - return px.ipPins.Get(ipp) -} - -func (px *proxifier) clearpins() (int, int) { - totips := px.ipPins.Clear() - totuids := px.uidPins.Clear() - - return totips, totuids -} - -// ProxyFor returns the proxy for the given id or an error. -// As a special case, if it takes longer than getproxytimeout, it returns an error. -// ProxyFor implements Proxies. -func (px *proxifier) ProxyFor(id string) (Proxy, error) { - p, err := px.proxyFor(id) - if !errors.Is(err, errProxyNotFound) || !isWellknown(id) { - // return proxy not found for non-wellknown proxy ids immediately without waiting - // because the constructor's of dns transports call into ProxyFor with their own IDs - // (ex: dnsx.Default / dnsx.Preferred) to auto-setup the transporting over proxy - // (ex: when WireGuard DNS53 transports are setup). Waiting for "maxWaitPeriodSec" - // then delays construction of the transport & in case of dnsx.Default specifically, - // it results in prolonged intra.NewTunnel creation, which is sensitive to delays, - // as it is expected to be called from the main service thread of the Android client. - return p, err - } - - log.W("proxy: for: %s; not found; waiting for %ds...", id, maxWaitPeriodSec) - time.Sleep(time.Duration(maxWaitPeriodSec) * time.Second) - return px.proxyFor(id) -} - -func (px *proxifier) proxyFor(id string) (Proxy, error) { - defer core.Recover(core.Exit11, "pxr.proxyFor."+id) - - if len(id) <= 0 { - return nil, errMissingProxyID - } - - if immutable(id) { // fast path for immutable proxies - if id == Exit { - return px.exit, nil - } else if id == Base { - return px.base, nil - } else if id == Block { - return px.grounded, nil - } else if id == Auto { - return px.auto, nil - } else if id == Rpn64 { - return px.exit64, nil - } // Ingress do not have a fast path - } - - if isRPN(id) { - rpn, _ := core.Grx("pxr.mainRpnProxyFor: "+id, func(_ context.Context) (RpnProxy, error) { - // id here must be non-countrycode "rpn provider" - // ex: x.RpnWin; not "rpn+cc": x.RpnWin+US, x.RpnWin+MX - if p, err := px.mainRpnProxyOf(id); err == nil { - return p, nil - } - return nil, errNotRpnID - }, getproxytimeout/2) - if rpn != nil && core.IsNotNil(rpn) { - _ = healthy(rpn) - return rpn, nil - } // else: search for id in px.p, which includes rpn+cc proxies - } - - // go.dev/play/p/xCug1W3OcMH - p, completed := core.Grx("pxr.ProxyFor: "+id, func(_ context.Context) (Proxy, error) { - px.RLock() - defer px.RUnlock() - - return px.p[id], nil - }, getproxytimeout) - - if !completed { - log.W("proxy: for: %s; timeout!", id) - // possibly a deadlock, so return an error - return nil, errGetProxyTimeout - } - if p == nil || core.IsNil(p) { - log.W("proxy: for: %s; not found", id) - return nil, errProxyNotFound - } - if isWG(idstr(p)) { - _ = healthy(p) // ping or refresh - } - return p, nil -} - -func (px *proxifier) AutoActive() bool { - return settings.AutoActive() -} - -func (px *proxifier) mainRpnProxyOf(provider string) (RpnProxy, error) { - if !isRPN(provider) { - return nil, errNotRpnID - } - px.rpnmu.RLock() - rp := px.rp[provider] - px.rpnmu.RUnlock() - if rp == nil { - return nil, errNotRpnProxy - } - return rp, nil -} - -func (px *proxifier) rpnProxyFor(provider, cc string) (Proxy, error) { - id := provider + cc - p, err := px.ProxyFor(id) - if p == nil { - return nil, core.OneErr(err, errProxyNotFound) - } - return p, err -} - -// GetProxy implements x.Proxies. -func (px *proxifier) GetProxy(id string) (x.Proxy, error) { - return px.ProxyFor(id) -} - -// TestHop implements Proxies. -func (px *proxifier) TestHop(via, origin string) string { - defer core.Recover(core.Exit11, "pxr.TestHop."+via+">>"+origin) - if err := px.hop(via, origin, true /*dryrun*/); err != nil { - return err.Error() - } - return "" // all ok -} - -// Hop implements x.Proxies. -func (px *proxifier) Hop(via, origin string) error { - return px.hop(via, origin, false /*dryrun*/) -} - -func (px *proxifier) hop(via, origin string, dryrun bool) error { - defer core.Recover(core.Exit11, "pxr.Hop."+via+">>"+origin) - - if len(origin) <= 0 { - return errMissingProxyID - } - origPx, err := px.ProxyFor(origin) - if err != nil || origPx == nil { - return core.OneErr(err, errProxyNotFound) - } - - oldViaPx, _ := origPx.Router().Via() // may be nil - - if len(via) <= 0 { // remove hop if needed - err = origPx.Hop(nil, dryrun) - _ = px.unmapHop(oldViaPx, origPx, err != nil || dryrun) - return err - } - - if via == origin { - return errHopSelf - } - - viaPx, err := px.ProxyFor(via) - if err != nil || viaPx == nil { - return core.OneErr(err, errProxyNotFound) - } - if viaPx.Status() == END || origPx.Status() == END { - return errProxyStopped - } - - if idstr(oldViaPx) == idstr(viaPx) { - if !dryrun { - core.Gxe("pxr.hop.refresh."+idstr(origPx), origPx.Refresh) - } - log.I("proxy: hop: %s => %s (no change)", origin, via) - return nil // no change - } - - viaRouter := viaPx.Router() - if viaViaVia, _ := viaRouter.Via(); viaViaVia != nil { - log.W("proxy: triple hop: %s => %s => %s; not allowed", origin, via, viaViaVia.ID()) - return errHopHopping - } else if !viaRouter.IP4() && !viaRouter.IP6() { - // via must either route all ip4 or all ip6; ideally both - return errHopDefaultRoutes - } - - _ = px.unmapHop(oldViaPx, origPx, dryrun) - err = origPx.Hop(viaPx, dryrun) - _ = px.mapHop(viaPx, origPx, err != nil || dryrun) - - return err -} - -func (px *proxifier) mapHop(hop x.Proxy, orig x.Proxy, dryrun bool) (mapped bool) { - hopID := idstr(hop) - origID := idstr(orig) - if len(hopID) <= 0 || len(origID) <= 0 { - return - } - - px.hmu.Lock() - defer px.hmu.Unlock() - in := px.hp[hopID] // in may be nil - out := addElem(in, origID) - if !dryrun { - px.hp[hopID] = out - } - log.I("proxy: mapHop: %s => %s; remaining origins: %v", hopID, origID, out) - return len(out) > len(in) -} - -func (px *proxifier) unmapHopFrom(orig x.Proxy, dryrun bool) (unmapped bool) { - via, _ := orig.Router().Via() // may be nil - return px.unmapHop(via, orig, dryrun) -} - -func (px *proxifier) unmapHop(hop x.Proxy, orig x.Proxy, dryrun bool) (unmapped bool) { - hopID := idstr(hop) - origID := idstr(orig) - if len(hopID) <= 0 || len(origID) <= 0 { - return - } - - px.hmu.Lock() - defer px.hmu.Unlock() - - if in, ok := px.hp[hopID]; ok { - out := removeElem(in, origID) - if !dryrun { - if len(out) <= 0 { - log.I("proxy: unmapHop: %s => %s; no more origins, removing hop", hopID, origID) - delete(px.hp, hopID) // remove hop if no origins left - } else { - log.I("proxy: unmapHop: %s => %s; remaining origins: %v", hopID, origID, out) - px.hp[hopID] = out - } - } - unmapped = len(out) < len(in) - } - return -} - -func (px *proxifier) refreshHopOriginsIfAny(hop Proxy, why string) (n int) { - hopID := idstr(hop) - if len(hopID) <= 0 { - return - } - - px.hmu.RLock() - origins := slices.Clone(px.hp[hopID]) // Create a copy to avoid race - px.hmu.RUnlock() - - if len(origins) <= 0 { - log.D("proxy: refreshHopOrigins for %s: no-op", why) - return - } - - px.RLock() - for _, origin := range origins { - if p := px.p[origin]; p != nil { - n++ - core.Gxe("pxr.hop.refresh."+idstr(p), p.Refresh) - } - } - px.RUnlock() - - log.I("proxy: refreshHopOrigins for %s: %d[%v]", why, n, origins) - return -} - -// Router implements x.Proxy. -func (px *proxifier) Router() x.Router { - return px -} - -// Rpn implements x.Proxies. -func (px *proxifier) Rpn() x.Rpn { - return px -} - -func (px *proxifier) stopProxies() { - px.rpnmu.Lock() - n := len(px.rp) - for _, rp := range px.rp { - curpRp := rp - id := idstr(curpRp) - core.Go("pxr.stopProxies.purgeRpn: "+id, func() { - _ = curpRp.PurgeAll() - }) - } - clear(px.rp) - px.rpnmu.Unlock() - - px.hmu.Lock() - clear(px.hp) - px.hmu.Unlock() - - px.Lock() - defer px.Unlock() - - l := len(px.p) - for _, p := range px.p { - curp := p - id := idstr(curp) - - core.Go("pxr.stopProxies: "+id, func() { - _ = curp.Stop() - }) - } - clear(px.p) - px.staller.Clear() - px.ipPins.Clear() - px.uidPins.Clear() - - core.Go("pxr.onStop", func() { px.obs.OnProxiesStopped() }) - log.I("proxy: stopped and removed %d+%d", n, l) -} - -// RefreshProxies implements x.Proxies. -func (px *proxifier) RefreshProxies() string { - // TODO: remove error in the return value - defer core.Recover(core.Exit11, "pxr.RefreshProxies") - - ptot, ptotu := px.clearpins() - - px.Lock() - defer px.Unlock() - - tot := len(px.p) - log.I("proxy: refresh pxs: %d / removed pins: %d %d", tot, ptot, ptotu) - - var which = make([]string, 0, len(px.p)) - for _, p := range px.p { - curp := p - id := idstr(curp) - which = append(which, id) - // some proxy.Refershes may be slow due to network requests, hence - // preferred to run in a goroutine to avoid blocking the caller. - // ex: wgproxy.Refresh -> multihost.Refersh -> dialers.Resolve - core.Gx("pxr.RefreshProxies: "+id, func() { - if err := curp.Refresh(); err != nil { - log.E("proxy: refresh (%s/%s/%s) failed: %v", id, curp.Type(), curp.GetAddr(), err) - } - }) - } - - log.I("proxy: refreshed %d / %d: %v", len(which), tot, which) - - return strings.Join(which, ",") -} - -// LiveProxies implements x.Proxies. -func (px *proxifier) LiveProxies() string { - px.RLock() - defer px.RUnlock() - - out := make([]string, 0, len(px.p)) - for id := range px.p { - out = append(out, id) - } - return strings.Join(out, ",") -} - -// RefreshProto implements x.Proxies. -func (px *proxifier) RefreshProto(l3 string, mtu int, force bool) { - defer core.Recover(core.Exit11, "pxr.RefreshProto") - // must unlock from deferred since panics are recovered above - px.Lock() - defer px.Unlock() - - if len(l3) <= 0 { - l3 = px.lp.l3 // keep existing - } - if mtu <= 0 { - mtu = px.lp.mtu // keep existing - } - - if !force && px.lp.l3 == l3 && px.lp.mtu == mtu { - log.D("proxy: refreshProto (forced? %t): (%s == %s & %d == %d) unchanged", - force, px.lp.l3, l3, px.lp.mtu, mtu) - return - } - - newlp := LinkProps{l3: l3, mtu: mtu, rev: px.lp.rev} // copy - px.lp = newlp - for _, p := range px.p { - curp := p - id := idstr(curp) - core.Gx("pxr.RefreshProto: "+id, func() { - // always run in a goroutine (or there is a deadlock) - // wgproxy.onProtoChange -> multihost.Refresh -> dialers.Resolve - // -> ipmapper.LookupIPNet -> resolver.LocalLookup -> transport.Query - // -> ipn.ProxyFor -> px.Lock() -> deadlock - if cfg, readd := curp.OnProtoChange(newlp); readd { - // px.addProxy -> px.add -> px.Lock() -> deadlock - _, err := px.forceAddProxy(id, cfg) - // TODO: preserve hop? - log.I("proxy: refreshProto (forced? %t): (%s/%s/%s) re-add; err? %v", - force, id, curp.Type(), curp.GetAddr(), err) - } - }) - } -} - -func (px *proxifier) Reverser(rhdl netstack.GConnHandler) error { - px.Lock() - defer px.Unlock() - - px.lp.rev = rhdl - return nil -} - -// IP4 implements x.Router. -func (px *proxifier) IP4() bool { - px.RLock() - defer px.RUnlock() - - for _, p := range px.p { - if local(idstr(p)) || noop(typstr(p)) { - continue - } - if r := p.Router(); r != nil && !r.IP4() { - return false - } - } - return len(px.p) > 0 -} - -// IP6 implements x.Router. -func (px *proxifier) IP6() bool { - px.RLock() - defer px.RUnlock() - - for _, p := range px.p { - if local(idstr(p)) || noop(typstr(p)) { - continue - } - if r := p.Router(); r != nil && !r.IP6() { - return false - } - } - - return len(px.p) > 0 -} - -// MTU implements x.Router. -func (px *proxifier) MTU() (out int, err error) { - px.RLock() - defer px.RUnlock() - - out = MAXMTU - only4 := false - minmtu := minmtu6 - for _, p := range px.p { - if local(idstr(p)) || noop(typstr(p)) { - continue - } - r := p.Router() // never nil - only4 = only4 || r.IP4() && !r.IP6() - if only4 && minmtu > minmtu4 { - minmtu = minmtu4 - } - if hopping(r) { // skip proxies hopping via another - continue - } // inner tunnel MTUs should not have any bearing on outer MTU - if m, err1 := r.MTU(); err1 == nil { - out = min(out, max(m, minmtu)) - } // else: NOMTU - } - if out == MAXMTU || out == NOMTU { // unchanged or unknown - err = errNoMtu - } - return out, err -} - -// Stat implements x.Router. -func (px *proxifier) Stat() *x.RouterStats { - px.RLock() - defer px.RUnlock() - - s := new(x.RouterStats) - for _, p := range px.p { - pid := idstr(p) - ptyp := typstr(p) - if local(pid) || isInternal(pid) || noop(ptyp) { - continue - } - if r := p.Router(); r != nil { - s = accStats(s, r.Stat()) - } - } - return s -} - -func accStats(a, b *x.RouterStats) (c *x.RouterStats) { - c = new(x.RouterStats) - if a == nil && b == nil { - return c - } else if a == nil { - return b - } else if b == nil { - return a - } - // c.Addr? c.Extra? c.LastErr, c.LastRxErr, c.LastTxErr - c.Tx = a.Tx + b.Tx - c.Rx = a.Rx + b.Rx - c.ErrRx = a.ErrRx + b.ErrRx - c.ErrTx = a.ErrTx + b.ErrTx - c.LastOK = max(a.LastOK, b.LastOK) - c.LastTx = max(a.LastTx, b.LastTx) - c.LastRx = max(a.LastRx, b.LastRx) - c.LastGoodRx = max(a.LastGoodRx, b.LastGoodRx) - c.LastGoodTx = max(a.LastGoodTx, b.LastGoodTx) - c.LastRefresh = max(a.LastRefresh, b.LastRefresh) - // todo: a.Since or b.Since may be zero - c.Since = min(a.Since, b.Since) - c.Status = strings.Join([]string{a.Status, b.Status}, ";") - c.StatusReason = strings.Join([]string{a.StatusReason, b.StatusReason}, ";") - return c -} - -// Contains implements x.Router. -func (px *proxifier) Contains(ipprefix string) bool { - px.RLock() - defer px.RUnlock() - - for _, p := range px.p { - // always present local proxies route either everything or - // nothing: not useful for making routing decisions - if local(idstr(p)) || noop(typstr(p)) { - continue - } - if r := p.Router(); r != nil && r.Contains(ipprefix) { - return true - } - } - return false -} - -// Reaches implements x.Router. -func (px *proxifier) Reaches(urlOrHostPortOrIPPortCsv string) bool { - px.RLock() - defer px.RUnlock() - - for _, p := range px.p { - if r := p.Router(); r != nil && r.Reaches(urlOrHostPortOrIPPortCsv) { - return true - } - } - return false -} - -func (px *proxifier) EntitlementFrom(entitlementOrStateJson []byte, id, did string) (ent x.RpnEntitlement, err error) { - switch id { - case RpnWin: - ent, err = px.extc.MakeWsEntitlement(entitlementOrStateJson, did) - default: - err = errNotRpnAcc - } - return -} - -// RegisterWin implements x.Rpn. -func (px *proxifier) RegisterWin(entitlementOrState []byte, did string, ops *x.RpnOps) (stateJson []byte, err error) { - defer func() { - px.lastWinErr.Store(err) // may be nil - }() - - if len(did) <= 0 { - return nil, errNilWinDevice - } - - if ops == nil { - ops = new(x.RpnOps) - } - - existingStateJson := entitlementOrState - restore := len(existingStateJson) > 0 - - win, err := px.registerWin(existingStateJson, did, *ops) - if err != nil || core.IsNil(win) { - log.E("proxy: ws: make failed: %v", err) - return nil, core.JoinErr(err, errNilWinCfg) - } - - state, err := win.State() - if err != nil { - // TODO: RpnAcc may be stateless, in which case err is expected & could be ignored - return nil, err - } - - // TODO: create a new proxy type for win, so Refresh() could be sent to /connect - // TODO: best location: github.com/Windscribe/browser-extension/blob/ed83749ad1/modules/ext/src/utils/getBestLocation.js - rp, err := px.addRpnProxy(win, anycc(win)) - if err != nil || rp == nil { - log.E("proxy: ws: add wg for %s failed: %v", win.Who(), err) - return nil, core.JoinErr(err, errNotRpnProxy) - } - - log.I("proxy: ws: registered: %s / %d; new? %t; ops: %+v", win.Who(), len(state), !restore, ops) - return state, nil -} - -func (px *proxifier) registerWin(entitlementOrStateJson []byte, did string, ops x.RpnOps) (RpnAcc, error) { - return px.extc.MakeWsWgFrom(entitlementOrStateJson, did, ops) -} - -// UnregisterWin implements x.Rpn. -func (px *proxifier) UnregisterWin() bool { - return px.unregisterRpn(RpnWin) -} - -func (px *proxifier) unregisterRpn(provider string) bool { - rp, _ := px.mainRpnProxyOf(provider) - if rp == nil { - return false - } - - n := rp.PurgeAll() // n == 1 for single country rpn - - px.rpnmu.Lock() - delete(px.rp, provider) - px.rpnmu.Unlock() - - log.I("proxy: %s: unregistered; forks: %d", provider, n) - return true -} - -// Win implements x.Rpn. -func (px *proxifier) Win() (x.RpnProxy, error) { - win, err := px.mainRpnProxyOf(RpnWin) - if win == nil { - return nil, core.JoinErr(err, px.lastWinErr.Load()) - } - return win, nil -} - -// Pip implements x.Rpn. -func (px *proxifier) Pip() (x.RpnProxy, error) { - // TODO: Register and Unregister for Pip - // TODO: Pip asRpnProxy (with multi-country support) - return px.mainRpnProxyOf(RpnWs) -} - -// Exit64 implements x.Rpn. -func (px *proxifier) Exit64() (x.RpnProxy, error) { - return px.mainRpnProxyOf(Rpn64) -} - -// TestWin implements x.Rpn. -func (px *proxifier) TestWin() (string, error) { - return px.testWin() -} - -func (px *proxifier) testWin() (string, error) { - v4, v6, err := rpn.WinEndpoints() - if err != nil { - log.W("proxy: ws: err testing endpoints: %v", err) - return "", err - } - - n := 0 - const maxpings = 5 - oks := make([]string, 0, len(v4)) - for _, ip := range append(v4, v6...) { - ipstr := ip.String() - // base can route back into netstack (settings.LoopingBack) - // in which case all endpoints will "seem" reachable. - // exit, however, never routes back into netstack and has - // the true, unhindered path to the underlying network. - if Reaches(px.exit, ipstr, "tcp") { - oks = append(oks, ipstr) - n++ - } - if n >= maxpings { - break // stop after maxpings - } - } - - if len(oks) <= 0 { - log.E("proxy: ws: no reachable addrs among %v", v4) - return "", core.JoinErr(errNoSuitableAddress, px.lastWinErr.Load()) - } - return strings.Join(oks, ","), nil -} - -// TestExit64 implements x.Rpn. -func (px *proxifier) TestExit64() (string, error) { - return px.testExit64() -} - -func (px *proxifier) testExit64() (ips string, errs error) { - v6, err := rpn.Exit64Endpoints() - if err != nil { - log.W("proxy: exit64: err testing endpoints %v", err) - return "", err - } - - oks := make([]string, 0, len(v6)) - for _, ip := range v6 { - ipstr := ip.String() - // base can route back into netstack (settings.LoopingBack) - // in which case all endpoints will "seem" reachable. - // exit, however, never routes back into netstack and has - // the true, unhindered path to the underlying network. - if Reaches(px.exit, ipstr, "icmp") { - oks = append(oks, ipstr) - } - } - - if len(oks) <= 0 { - log.E("proxy: exit64: no reachable addrs among %v", v6) - return "", errNoSuitableAddress - } - return strings.Join(oks, ","), nil -} - -func IsAnyLocalProxy(ids ...string) bool { - return core.IsAny(ids, local) -} - -// Base, Block, Exit, Rpn64, Ingress -func local(id string) bool { - return id == Base || id == Block || id == Exit || id == Rpn64 || id == Ingress -} - -func automatic(id string) bool { - return id == Auto -} - -func noop(typ string) bool { - return typ == NOOP -} - -// TODO: check for hops on "noop" transports; if those -// are NOT hoppping, then those are NOT remote, either -func Remote(id string) bool { - return !local(id) || !automatic(id) -} - -func hopping(r x.Router) bool { - hop, _ := r.Via() - return hop != nil -} - -func immutable(id string) bool { - return local(id) || id == Auto -} - -func isInternal(id string) bool { - return isRPN(id) || immutable(id) -} - -func isWellknown(id string) bool { - return isInternal(id) || isWG(id) || isOrbot(id) || isGlobalH1(id) || isPip(id) -} - -func isRPN(id string) bool { - return strings.Contains(id, RPN) // RPN is a suffix -} - -func isWG(id string) bool { - return strings.HasPrefix(id, WG) || strings.HasPrefix(id, WGFAST) -} - -func isOrbot(id string) bool { - return id == OrbotH1 || id == OrbotS5 -} - -func isGlobalH1(id string) bool { - return id == GlobalH1 -} - -func isPip(id string) bool { - return strings.HasPrefix(id, PIPH2) || strings.HasPrefix(id, PIPWS) -} - -func idling(t time.Time) bool { - return time.Since(t) > tzzTimeout -} - -func localDialStrat(d *protect.RDial, network, local, remote string) (protect.Conn, error) { - return dialers.SplitDialBind(d, network, local, remote) -} - -func dialAny(all []protect.RDialer, network, local, remote string) (protect.Conn, error) { - return dialers.DialAny(all, str2addr(network, local), str2addr(network, remote)) -} - -func str2addr(network, addrport string) net.Addr { - ip, port, err := net.SplitHostPort(addrport) - if err != nil { - return nil - } - portno, err := strconv.Atoi(port) - if err != nil { - return nil - } - switch network { - case "tcp", "tcp4", "tcp6": - return &net.TCPAddr{ - IP: net.ParseIP(ip), - Port: portno, - } - case "udp", "udp4", "udp6": - fallthrough - default: - return &net.UDPAddr{ - IP: net.ParseIP(ip), - Port: portno, - } - } -} - -func firstEmpty(arr []string) bool { - return len(arr) <= 0 || len(arr[0]) <= 0 -} diff --git a/intra/ipn/proxy.go b/intra/ipn/proxy.go deleted file mode 100644 index 2b86f874..00000000 --- a/intra/ipn/proxy.go +++ /dev/null @@ -1,857 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "errors" - "fmt" - "math/rand" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "slices" - "strconv" - "strings" - "syscall" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -// must be kept in sync with rpn's Conf() impls (like in yegor.go) -const anyCountryCode = "**" // random country -const noCountryForOldMen = "" // zz - -func anycc(acc RpnAcc) string { - cc := anyCountryCode - if !acc.MultiCountry() { - cc = noCountryForOldMen - } - return cc -} - -func (pxr *proxifier) NewSocks5Proxy(id, user, pwd, ip, port string) (p *socks5, err error) { - opts := settings.NewAuthProxyOptions("socks5", user, pwd, ip, port, nil) - return NewSocks5Proxy(id, pxr.ctx, pxr.ctl, pxr, opts) -} - -func (pxr *proxifier) Underlay(id string, c x.Controller) x.Proxy { - return newBasicProxy(id, fakeBaseAddr, pxr.ctx, c, pxr) -} - -// AddProxy implements Proxifier. -func (pxr *proxifier) AddProxy(id, txt string) (x.Proxy, error) { - defer core.Recover(core.Exit11, "prx.AddProxy."+id) - - pid := id - if isRPN(pid) { // must call addRpnProxy instead - return nil, errAddProxyAsRpn - } - - return pxr.addProxy(pid, txt) -} - -// cc may be a fully qualified ID in case of removing the main proxy. -func (pxr *proxifier) removeRpnProxy(acc RpnAcc, cc string) bool { - if acc == nil || core.IsNil(acc) { - return false - } - typ := acc.ProviderID() - if !isRPN(typ) { - return false - } - if !acc.MultiCountry() && cc != noCountryForOldMen { - log.W("proxy: rpn: remove: %s not multi-country; [%s] ignored", typ, cc) - cc = noCountryForOldMen - } - - log.I("proxy: rpn: remove: %s[%s]", typ, cc) - - rpnid := cc // cc itself may be a fully qualified id if removing main proxy - if !strings.HasPrefix(cc, typ) { - rpnid = typ + cc - } - - return pxr.removeProxy(rpnid, true /*force*/) -} - -// cc may be a fully qualified ID in case when re-adding the main proxy. -func (pxr *proxifier) addRpnProxy(acc RpnAcc, cc string) (Proxy, error) { - if acc == nil || core.IsNil(acc) { - return nil, errNotRpnAcc - } - - typ := acc.ProviderID() - if !isRPN(typ) { - return nil, errNotRpnID - } - - if !acc.MultiCountry() && cc != noCountryForOldMen { - log.W("proxy: rpn: add: %s not multi-country; [%s] ignored", typ, cc) - cc = noCountryForOldMen - } - - log.I("proxy: rpn: add: %s[%s]", typ, cc) - - // cc may be "typcity;cc" (see var rpnid below) - // but we need cc to be "city;cc" (ref struct RpnServer.Key) - cc, _ = strings.CutPrefix(cc, typ) - - txt, err := acc.Conf(cc) - if err != nil { - return nil, err - } - - rpnid := typ + cc - - p, err := pxr.forceAddProxy(rpnid, txt) - if p == nil { - pxr.postAddRpnProxyError(acc) // remove from pxr.rp if exists - return nil, core.JoinErr(err, errAddProxy) - } - - return pxr.postAddRpnProxy(p, acc) -} - -// TODO: on add / update a via proxy; refresh all dependent origins -func (pxr *proxifier) addRpnProxy2(p Proxy, acc RpnAcc) (Proxy, error) { - proxyid := idstr(p) - providerid := acc.ProviderID() - if !isRPN(proxyid) || !isRPN(providerid) { - return nil, errNotRpnProxy - } - - ok := pxr.add(p) - if !ok { - pxr.postAddRpnProxyError(acc) // remove from pxr.rp if exists - return nil, errAddProxy - } - - who := "postAddRpnProxy." + proxyid - // TODO: setup hop from mainCountryCode to forked rpn proxies - core.Gx(who, func() { pxr.refreshHopOriginsIfAny(p, who) }) - - return pxr.postAddRpnProxy(p, acc) -} - -func (pxr *proxifier) postAddRpnProxy(p Proxy, acc RpnAcc) (_ Proxy, err error) { - proxyid := idstr(p) - provider := acc.ProviderID() - - pxr.rpnmu.Lock() - rp := pxr.rp[provider] - pxr.rpnmu.Unlock() - - // add rpn proxy iff rpn proxy isn't multicountry (in which case only one - // instance of it can exist and hence it is being re-added if already present) - // or, if it is multicountry, add it only if the proxy is the main proxy, - // as forked children countries only need be added as plain-old proxies - // which is done before calling this function (ie, a no-op) - if rp == nil { - rp, err = asRpnProxy(p, acc, pxr) - if rp == nil { // should not happen; unexpected! - defer pxr.removeProxy(proxyid, true /*force*/) - return nil, core.JoinErr(err, errAddProxyAsRpn) - } - // TODO: setup hop from mainCountryCode to forked rpn proxies - pxr.rpnmu.Lock() - pxr.rp[provider] = rp // removed on unregister - pxr.rpnmu.Unlock() - log.I("proxy: rpn: add: post: registered %s as rpn proxy for %s", proxyid, provider) - } else if idstr(p) == idstr(rp) { - log.I("proxy: rpn: add: post: %s already registered for %s; emplacing...", proxyid, provider) - core.Gx("emplace."+idstr(p), func() { rp.Emplace(p) }) // may fail - } - - return p, nil -} - -func (pxr *proxifier) postAddRpnProxyError(acc RpnAcc) (removed bool) { - return pxr.unregisterRpn(acc.ProviderID()) // unregisters if it exists -} - -func (pxr *proxifier) forceAddProxy(id, txt string) (p Proxy, err error) { - return pxr.addOrUpdateProxy(id, txt, true /*force*/) -} - -func (pxr *proxifier) addProxy(id, txt string) (p Proxy, err error) { - return pxr.addOrUpdateProxy(id, txt, false /*force*/) -} - -func (pxr *proxifier) addOrUpdateProxy(id, txt string, force bool) (p Proxy, err error) { - if len(id) <= 0 { - return nil, errAddProxy - } - - defer func() { - if err != nil { - core.Gx("addProxy.refreshHop"+id, func() { pxr.refreshHopOriginsIfAny(p, "addProxy."+id) }) - } - }() - - // wireguard proxies have IDs starting with "wg" - if isWG(id) { - pxr.RLock() - lp := pxr.lp - pxr.RUnlock() - if force { - p, err = NewWgProxy(id, pxr.ctl, pxr, lp, txt) - } else if p, _ = pxr.proxyFor(id); p != nil { - if wgp, ok := p.(WgProxy); ok && wgp.update(id, txt) { - newcfg, readd := wgp.OnProtoChange(lp) - if readd || len(newcfg) > 0 { - log.W("proxy: add: cannot update wg(%s); readd it!", id) - } else { - log.I("proxy: add: updated wg %s/%s/%s", id, lp, p.GetAddr()) - return - } - } // else: recreate - } - if !force && p == nil { - // txt is both wg ifconfig and peercfg - p, err = NewWgProxy(id, pxr.ctl, pxr, lp, txt) - } - } else if len(txt) <= 0 { - p = NewBasicProxy(id, pxr.ctx, pxr.ctl, pxr) - err = nil - } else { - var strurl string - var usr string - var pwd string - var u *url.URL - // go.dev/play/p/2DTBGO0Wisj - // scheme://usr:pwd@domain.tld:port/p/a/t/h?q&u=e&r=y#f,r - u, err = url.Parse(txt) - if err != nil { - return nil, err - } - - if u.User != nil { - usr = u.User.Username() // usr - pwd, _ = u.User.Password() // pwd - } - strurl = u.Host + u.RequestURI() // domain.tld:port/p/a/t/h?q&u=e&r=y#f,r - addrs := strings.Split(u.Fragment, ",") - // opts may be nil - opts := settings.NewAuthProxyOptions(u.Scheme, usr, pwd, strurl, u.Port(), addrs) - - p, err = pxr.fromOpts(id, opts) // opts may be nil - } - - if err != nil { - log.P("proxy: add: %s failed; cfg: %v", id, txt) - log.W("proxy: add: %s failed; force? %t; err: %v", id, force, err) - return nil, err - } else if p == nil { - log.P("proxy: add: %s nil; cfg: %v", id, txt) - log.W("proxy: add: %s nil; force? %t; txt: %d", id, force, len(txt)) - return nil, errAddProxy - } else if ok := pxr.add(p); !ok { - return nil, errAddProxy - } - - log.I("proxy: add: force? %t; done %s/%s/%s", force, p.ID(), p.Type(), p.GetAddr()) - return -} - -func (pxr *proxifier) fromOpts(id string, opts *settings.ProxyOptions) (Proxy, error) { - if opts == nil { - return nil, errNoOpts - } - - var p Proxy = nil - var err error = nil - switch opts.Scheme { - case "socks5": - p, err = NewSocks5Proxy(id, pxr.ctx, pxr.ctl, pxr, opts) - case "http": - fallthrough - case "https": - p, err = NewHTTPProxy(id, pxr.ctx, pxr.ctl, pxr, opts) - case "piph2": - // todo: assert id == RpnH2 - p, err = NewPipProxy(pxr.ctx, pxr.ctl, pxr, opts) - case "pipws": - // todo: assert id == RpnWs - p, err = NewPipWsProxy(pxr.ctx, pxr.ctl, pxr, opts) - case "wg": - err = fmt.Errorf("proxy: id must be prefixed with %s in %s for [%s]", WG, id, opts) - default: - err = errProxyScheme - } - return p, err -} - -func Reaches(p Proxy, urlOrHostPortOrIPPortCsv string, protos ...string) bool { - if p == nil { - return false - } - st := p.Status() - if err := candial2(st); err != nil { - log.W("proxy: %s reaches: err %v, status(%s)", idstr(p), err, pxstatus(st)) - return false - } - if len(urlOrHostPortOrIPPortCsv) <= 0 { - return true - } - - // auto:[ip,http,https]:[v4,v6] - if strings.HasPrefix(urlOrHostPortOrIPPortCsv, "auto") { - const autoSize = 5 - prefs := strings.Split(urlOrHostPortOrIPPortCsv, ":") - scheme := "https" - if len(prefs) >= 2 { - scheme = prefs[1] - } - ipfrag := "" - if len(prefs) >= 3 { - // TODO: change "ipv4", "ipv6", "tcp4", "tcp6", "udp4", "udp6" to "v4", "v6" - ipfrag = prefs[2] // "v4", "v6", "" - } - - protos = []string{} - switch scheme { - case "http", "https": - urls := []string{} - for _, h := range dialers.SampleHosts(autoSize, ipfrag) { - u := url.URL{ - Scheme: scheme, - Host: h, - Fragment: ipfrag, // if empty, connectivity over v4+v6 is attempted - } - urls = append(urls, u.String()) - } - log.I("proxy: %s reaches: auto:http for %v urls", idstr(p), urls) - urlOrHostPortOrIPPortCsv = strings.Join(urls, ",") - case "ip": - ips := make([]netip.Addr, 0, autoSize) - log.I("proxy: %s reaches: auto:ip for %v ips", idstr(p), ips) - - if ipfrag == "v4" { - protos = append(protos, "tcp4", "udp4") - } else if ipfrag == "v6" { - protos = append(protos, "tcp6", "udp6") - } else { - protos = append(protos, "tcp", "udp") - } - for _, ip := range dialers.SampleIPs(autoSize, ipfrag) { - if ipfrag == "v4" && ip.Is4() { - ips = append(ips, ip) - } else if ipfrag == "v6" && ip.Is6() { - ips = append(ips, ip) - } else if ipfrag == "" { - ips = append(ips, ip) // both v4 and v6 - } - - } - // default port for ip:port is 80 if left unspecified (see below) - urlOrHostPortOrIPPortCsv = strings.Join(core.Map(ips, func(ip netip.Addr) string { return ip.String() }), ",") - default: - log.E("proxy: %s reaches: auto:%s for %v protos; unsupported scheme", idstr(p), scheme, protos) - return false - } - } - - pid := idstr(p) - hostportOrIPPort := strings.Split(urlOrHostPortOrIPPortCsv, ",") - if urls, oth := extractHttpURLs(urlOrHostPortOrIPPortCsv); len(urls) > 0 { - log.V("proxy: %s reaches: testing for %v", idstr(p), urls) - - hostportOrIPPort = oth - tests := make([]core.WorkCtx[bool], 0) - for _, u := range urls { - tests = append(tests, httpsReachesWorkCtx(p, u)) - } - threeSecsPerTest := time.Duration(len(tests)) * 3 * time.Second - - ok, who := core.First("reach.http."+pid, threeSecsPerTest, tests...) - - logeif(!ok)("proxy: %s #%d reaches: %v verdict (https): reachable? %t", - pid, who, urlOrHostPortOrIPPortCsv, ok) - - if !ok || len(oth) <= 0 { - return ok - } - } - - // Original logic for host:port or ip:port - hastcp := has(protos, "tcp") || has(protos, "tcp4") || has(protos, "tcp6") - hasudp := has(protos, "udp") || has(protos, "udp4") || has(protos, "udp6") - hasicmp := has(protos, "icmp") || has(protos, "icmp4") || has(protos, "icmp6") - - if !hastcp && !hasudp && !hasicmp { // fail open - hastcp = true - hasudp = true - hasicmp = false - protos = []string{"tcp", "udp"} - } - // upstream := dnsx.Default - // if pdns := p.DNS(); len(pdns) > 0 { - // upstream = pdns - // } - ipps := make([]netip.AddrPort, 0) - for _, x := range hostportOrIPPort { - host, port, err := net.SplitHostPort(x) - if err != nil { - port = "80" - } else { - x = host - } - on, _ := strconv.ParseUint(port, 10, 16) - if on == 0 { - on = 80 - } - if len(x) > 0 { // x may be ip, host - ips := dialers.For(x) - for _, ip := range ips { - ipp := netip.AddrPortFrom(ip, uint16(on)) - ipps = append(ipps, ipp) - } - } - } - - n := 0 - log.V("proxy: %s reaches: testing for %s", pid, ipps) - tests := make([][]core.WorkCtx[bool], 0) - for _, ipp := range ipps { - fns := make([]core.WorkCtx[bool], 0) - ippstr := ipp.String() - if hastcp { - fns = append(fns, tcpReachesWorkCtx(p, ippstr)) - } - if hasudp { - fns = append(fns, udpReachesWorkCtx(p, ippstr)) - } - if hasicmp { - fns = append(fns, icmpReachesWorkCtx(p, ipp)) - } - tests = append(tests, fns) - n += len(fns) - } - - if n <= 0 { - log.W("proxy: %s reaches: %v / %v; no tests for %s", - pid, urlOrHostPortOrIPPortCsv, ipps, protos) - return false - } - - ok, who, errs := core.Race("reach"+"."+pid, getproxytimeout, every(pid, tests)...) - - logeif(!ok)("proxy: %s #%d reaches: %v => %v verdict (%s): reachable? %t; errs? %v", - pid, who, urlOrHostPortOrIPPortCsv, ipps, protos, ok, errs) - - return ok -} - -func httpclient(p Proxy, url *url.URL) (client *http.Client) { - v4, v6 := true, true - switch url.Fragment { - case "tcp", "udp": - case "v4", "ipv4", "upd4", "tcp4": - v6 = false // only v4 - case "v6", "ipv6", "upd6", "tcp6": - v4 = false // only v6 - default: - } - // Lightweight transport for one-time use - client = &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - Dial: func(network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - if url.Scheme == "https" { - port = "443" - } else { - port = "80" - } - } else { - addr = host - } - on, _ := strconv.ParseUint(port, 10, 16) - if on == 0 { - if url.Scheme == "https" { - on = 443 - } else { - on = 80 - } - } - ipps := make([]netip.AddrPort, 0) - ips := dialers.For(addr) - for _, ip := range ips { - if v4 && ip.Is4() || v6 && ip.Is6() { - ipp := netip.AddrPortFrom(ip, uint16(on)) - ipps = append(ipps, ipp) - } - } - - logeif(len(ipps) == 0)("proxy: %s reaches: dial(%s, %s [among %v]) for %s", - idstr(p), network, addr, ipps, url) - - if len(ipps) <= 0 { - return nil, errNoSuitableAddress - } - - n := rand.Intn(len(ipps)) - - // filter out the revelant IPs ourselves as dialers pkg does not - // respect ip-specific network type "tcp4" or "tcp6" - // see: cdial.go:commondial2() - return p.Dial(network, ipps[n].String()) - }, - // Disable connection pooling for one-time use - DisableKeepAlives: true, - MaxIdleConns: -1, - MaxIdleConnsPerHost: -1, - // Short timeouts for quick failure detection - ResponseHeaderTimeout: 3 * time.Second, - // TODO: Prefer h1 to simplify conn handling? - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 3 * time.Second, - }, - } - return -} - -func every(who string, tests [][]core.WorkCtx[bool]) []core.WorkCtx[bool] { - all := make([]core.WorkCtx[bool], 0, len(tests)) - for _, t := range tests { - t := t - all = append(all, func(ctx context.Context) (bool, error) { - okays, errs := core.All("reach.all."+who, getproxytimeout, t...) - // overall is false if any okays is false, or if all errs are not nil - overall := core.IsAll(errs, func(err error) bool { return err == nil }) && - core.IsAll(okays, func(ok bool) bool { return ok }) - return overall, core.JoinErr(errs...) - }) - } - return all -} - -func AnyAddrForUDP(ipp netip.AddrPort) (proto, anyaddr string) { - anyaddr = "0.0.0.0:0" - proto = "udp4" - if ipp.Addr().Is6() { - proto = "udp6" - anyaddr = "[::]:0" - } - return -} - -func tcpReachesWorkCtx(p Proxy, ippstr string) core.WorkCtx[bool] { - return func(_ context.Context) (bool, error) { - return tcpReaches(p, ippstr) - } -} - -func tcpReaches(p Proxy, ippstr string) (bool, error) { - start := time.Now() - c, err := p.Dial("tcp", ippstr) - defer core.CloseConn(c) - - rtt := time.Since(start) - ok := err == nil - // net.OpError => os.SyscallError => syscall.Errno - if syserr := new(os.SyscallError); errors.As(err, &syserr) { - ok = ok || syserr.Err == syscall.ECONNREFUSED - } - - log.V("proxy: %s reaches: tcp: %s ok? %t, rtt: %s; err: %v", - p.ID(), ippstr, ok, rtt, err) - if ok { // wipe out err as it makes core.Race discard "ok" - err = nil - } - return ok, err -} - -func udpReachesWorkCtx(p Proxy, ippstr string) core.WorkCtx[bool] { - return func(_ context.Context) (bool, error) { - return udpReaches(p, ippstr) - } -} - -func udpReaches(p Proxy, ippstr string) (bool, error) { - start := time.Now() - c, err := p.Dial("udp", ippstr) - defer core.CloseConn(c) - - rtt := time.Since(start) - ok := err == nil - // net.OpError => os.SyscallError => syscall.Errno - if syserr := new(os.SyscallError); errors.As(err, &syserr) { - ok = ok || syserr.Err == syscall.ECONNREFUSED - } - - log.V("proxy: %s reaches: udp: %s ok? %t, rtt: %s; err: %v", - p.ID(), ippstr, ok, rtt, err) - if ok { // wipe out err as it makes core.Race discard "ok" - err = nil - } - return ok, err -} - -func icmpReachesWorkCtx(p Proxy, ipp netip.AddrPort) core.WorkCtx[bool] { - return func(_ context.Context) (bool, error) { - return IcmpReaches(p, ipp) - } -} - -func IcmpReaches(p Proxy, ipp netip.AddrPort) (bool, error) { - if !ipp.IsValid() { - return false, errInvalidAddr - } - - proto, anyaddr := AnyAddrForUDP(ipp) - c, err := p.Probe(proto, anyaddr) - defer core.CloseConn(c) - - if c == nil || err != nil { - err = core.OneErr(err, errNotUDPConn) - return false, err - } - - ok, rtt, err := core.Ping(c, ipp) - - // net.OpError => os.SyscallError => syscall.Errno - if syserr := new(os.SyscallError); errors.As(err, &syserr) { - ok = ok || syserr.Err == syscall.ECONNREFUSED - } - - log.V("proxy: %s reaches: icmp: %s ok? %t, rtt: %v; err: %v", - p.ID(), ipp, ok, rtt, err) - if ok { // wipe out err as it makes core.Race discard "ok" - err = nil - } - return ok, err -} - -func viaCanBind(orig Proxy, hop Proxy) error { - pxCan4 := orig.Router().IP4() - hopCan4 := hop.Router().IP4() - pxCan6 := orig.Router().IP6() - hopCan6 := hop.Router().IP6() - if pxCan4 { // suffices if px's ip4 is routable over hop - if !hopCan4 { - return errHop4Gateway - } // else: do not need to check for ip6 routes - } else if pxCan6 { // ip6 ok & px does not need ip4 - if !hopCan6 { - return errHop6Gateway - } // else: can at least route ip6, which is enough for px - } else { // unlikely that px cannot do both ip4 & ip6 - return errHopProxyRoutes - } - return nil -} - -func hasroute(p Proxy, ipp string) bool { - if p == nil { - return false - } - return p.Router().Contains(ipp) -} - -func healthy(p Proxy) error { - if p == nil { - return errProxyNotFound - } - - pid := idstr(p) - typ := typstr(p) - if local(pid) || noop(typ) { // fast path for local proxies which are always ok - return nil - } - - if err := candial2(p.Status()); err != nil { - return err - } // TODO: err on TNT, TKO? - - // TODO: via, _ := p.Router().Via() - - stat := p.Router().Stat() - now := now() - age := now - stat.Since - - oldEnough := age > ageThreshold.Milliseconds() - lastOK := stat.LastOK - lastOKNeverOK := lastOK <= 0 - lastOKBeyondThres := now-lastOK > lastOKThreshold.Milliseconds() - if (oldEnough && lastOKNeverOK) || lastOKBeyondThres { - core.Gx("healthy.TNT."+pid, func() { p.onNotOK() }) // not ok for too long - return fmt.Errorf("proxy: %s not ok; age: %s / lastOKNeverOK? %t / lastOKBeyondThres? %t", - pid, core.FmtMillis(age), lastOKNeverOK, lastOKBeyondThres) - } else if now-lastOK > tzzTimeout.Milliseconds() { - core.Gx("healthy.TZZ."+pid, func() { p.Ping() }) - } else if p.Status() != TOK { - core.Gx("healthy.TOK."+pid, func() { p.Ping() }) - } - - return nil // ok -} - -func has[T comparable](pids []T, pid T) bool { - return slices.Contains(pids, pid) -} - -func Same(a, b Proxy) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - return a.Handle() == b.Handle() -} - -func ViaID(p Proxy) string { - const novia = "" - if p == nil { - return novia - } - v, _ := p.Router().Via() - if v == nil { - return novia - } - // TODO: change all equality checks on ID() to use idstr - vid := idstr(v) - pid := idstr(p) - if vid == pid { - log.W("proxy: %s via %s; loop detected", p.ID(), v.ID()) - return novia - } - return vid -} - -func candial2(st int) error { - if st == END { - return errProxyStopped - } - if st == TPU { - return errProxyPaused - } - return nil -} - -func candial(state *core.Volatile[int]) error { - return candial2(state.Load()) -} - -func candserve2(st int) error { - if st == END { - return errProxyStopped - } - return nil -} - -func canserve(state *core.Volatile[int]) error { - return candserve2(state.Load()) -} - -func usevia(viaID *core.Volatile[string]) bool { - return viaID != nil && len(viaID.Load()) > 0 -} - -func viafor(who, viaID string, px ProxyProvider) *Proxy { - if len(viaID) <= 0 { - return nil - } - via, err := px.ProxyFor(viaID) - logei(err)("proxy: %s: viafor %s; err? %v", who, idhandle(via), err) - - if err != nil || via == nil || core.IsNil(via) { - return nil - } - return &via -} - -func swapVia(who string, new Proxy, on *core.Volatile[string], ref *core.WeakRef[Proxy]) (oldRef Proxy) { - newID := idstr(new) - oldRef = ref.Load() // old may be nil - oldID := on.Tango(newID) // newID/oldID may be empty - if idstr(oldRef) != oldID { - log.W("proxy: wg: %s setVia(%s) old(%s != %s)", - who, newID, idstr(oldRef), oldID) - return nil - } - log.I("proxy: wg: %s setVia(%s); rmv old(%s)", who, newID, oldID) - return oldRef -} - -func viaok(p *Proxy) bool { - return p != nil && core.IsNotNil(*p) && (*p).Status() != END -} - -// removeElem removes the all occurrences of rmv from s. -func removeElem[T comparable](s []T, rmv T) []T { - return core.WithoutElem(s, rmv) -} - -// addElem adds add to s if it is not already present. -func addElem[T comparable](s []T, add T) []T { - return core.WithElem(s, add) -} - -// extractHttpURLs extracts valid URLs from comma-separated input -func extractHttpURLs(csv string) (urls []*url.URL, oth []string) { - for x := range strings.SplitSeq(csv, ",") { - x = strings.TrimSpace(x) - if len(x) == 0 { - continue - } - // Check if it's a URL (contains scheme) - if u, err := url.Parse(x); err == nil && strings.Contains(u.Scheme, "http") { - urls = append(urls, u) - } else { - oth = append(oth, x) - } - } - return -} - -func httpsReachesWorkCtx(p Proxy, url *url.URL) core.WorkCtx[bool] { - return func(ctx context.Context) (bool, error) { - return httpsReaches(idstr(p), httpclient(p, url), url) - } -} - -func httpsReaches(who string, c *http.Client, url *url.URL) (bool, error) { - start := time.Now() - - req, err := http.NewRequest("HEAD", url.String(), nil) - if err != nil { - return false, fmt.Errorf("proxy: reaches: err creating req: %w", err) - } - - setua := settings.SetUserAgent.Load() - if setua { - req.Header.Set("User-Agent", settings.AndroidCcUa) - } - - statuscode := -1 - resp, err := c.Do(req) - if resp != nil { - defer core.Close(resp.Body) - statuscode = resp.StatusCode - } - - ok := statuscode > 0 - - logeif(!ok)("proxy: %s reaches: %v (ua? %t); ok? %t, status: %d, rtt: %s; err: %v", - who, url, setua, ok, statuscode, core.FmtPeriod(time.Since(start)), err) - - if ok { - err = nil // wipe out err as it makes core.Race discard "ok" - } - return ok, err -} diff --git a/intra/ipn/proxy_test.go b/intra/ipn/proxy_test.go deleted file mode 100644 index 30eef79b..00000000 --- a/intra/ipn/proxy_test.go +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build ignore -// +build ignore - -// cyclic imports - -package ipn - -import ( - "context" - "errors" - "log" - "net" - "net/netip" - "testing" - "time" - - // Removed cyclic import - // x "github.com/celzero/firestack/intra/backend" - // "github.com/celzero/firestack/intra/dns53" - // "github.com/celzero/firestack/intra/rnet" - - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dns53" - "github.com/celzero/firestack/intra/dnsx" - ilog "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/rnet" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/x64" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -type fakeResolver struct{ *net.Resolver } - -func (r fakeResolver) Lookup([]byte) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupOn([]byte, ...string) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIPFor(ctx context.Context, network, host, uid string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -type fakeCtl struct { - protect.Controller -} - -func (*fakeCtl) Bind4(_, _ string, _ int) {} -func (*fakeCtl) Bind6(_, _ string, _ int) {} -func (*fakeCtl) Protect(_ string, _ int) {} - -type fakeObs struct { - x.ProxyListener -} - -func (*fakeObs) OnProxyAdded(string) {} -func (*fakeObs) OnProxyRemoved(string) {} -func (*fakeObs) OnProxiesStopped() {} - -type fakeBdg struct { - protect.Controller - x.DNSListener -} - -var ( - baseNsOpts = &x.DNSOpts{PIDCSV: dnsx.NetNoProxy, IPCSV: "", TIDCSV: x.CT + "test0"} - baseTab = &rnet.Tab{CID: "testcid", Block: false} -) - -func (*fakeBdg) OnQuery(_, _ string, _ int) *x.DNSOpts { return baseNsOpts } -func (*fakeBdg) OnResponse(*x.DNSSummary) {} -func (*fakeBdg) OnDNSAdded(string) {} -func (*fakeBdg) OnDNSRemoved(string) {} -func (*fakeBdg) OnDNSStopped() {} - -func (*fakeBdg) Route(a, b, c, d, e string) *rnet.Tab { return baseTab } -func (*fakeBdg) OnComplete(*rnet.ServerSummary) {} - -func TestDot(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := NewProxifier(ctx, ctl, obs) - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - q := aquery("skysports.com") - q6 := aaaaquery("skysports.com") - q2 := aquery("yahoo.com") - q26 := aaaaquery("yahoo.com") - - b4, _ := q.Pack() - b6, _ := q6.Pack() - b24, _ := q2.Pack() - b26, _ := q26.Pack() - // smm := &x.DNSSummary{} - // smm6 := &x.DNSSummary{} - _ = xdns.NetAndProxyID("tcp", Base) - tm := settings.NewTunMode( - settings.DNSModePort, - settings.BlockModeNone, - settings.PtModeAuto, - ) - - // tr, _ := NewTLSTransport(ctx, "test0", "max.rethinkdns.com", []string{"213.188.216.9"}, pxr, ctl) - dtr, _ := dns53.NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - tr, _ := dns53.NewTransport(ctx, "test0", "1.0.0.2", "53", pxr) - - natpt := x64.NewNatPt(tm, bdg) - resolv := dnsx.NewResolver(ctx, "10.111.222.3:53", tm, dtr, bdg, natpt) - resolv.Add(tr) - r4, _, err := resolv.Forward(b4) - r6, _, err6 := resolv.Forward(b6) - _, _, _ = resolv.Forward(b24) - _, _, _ = resolv.Forward(b26) - time.Sleep(1 * time.Second) - _, _, _ = resolv.Forward(b6) - if err != nil { - // log.Output(2, smm.Str()) - t.Fatal(err) - } - if err6 != nil { - // log.Output(2, smm6.Str()) - t.Fatal(err6) - } - ans := xdns.AsMsg(r4) - ans6 := xdns.AsMsg(r6) - if xdns.Len(ans) == 0 && xdns.Len(ans6) == 0 { - t.Fatal("no ans") - } - log.Output(10, xdns.Ans(ans)) - log.Output(10, xdns.Ans(ans6)) -} - - -func TestProxyReaches(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := NewProxifier(ctx, ctl, obs) - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - _ = xdns.NetAndProxyID("tcp", Base) - tm := settings.NewTunMode( - settings.DNSModePort, - settings.BlockModeNone, - settings.PtModeAuto, - ) - - tr, _ := dns53.NewTLSTransport(ctx, "test0", "1.1.1.1", nil, pxr) - dtr, _ := dns53.NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - - natpt := x64.NewNatPt(tm, bdg) - resolv := dnsx.NewResolver(ctx, "10.111.222.3", tm, dtr, bdg, natpt) - resolv.Add(tr) - - var projson []byte - var err error - if projson, err = pxr.RegisterWin(nil, "", nil); err != nil { - t.Fatal(err) - } - if ips, err := pxr.TestWin(); err != nil { - t.Fatal(err) - } else { - ilog.D("se: %v", ips) - } - - pxr.AddProxy(RpnWin) - - se, _ := pxr.ProxyFor(RpnWin) - if ok := Reaches(se, "google.com", "tcp"); !ok { - t.Fail() - } - t.Log("proxy reaches") -} - -func TestWindscribeReaches(t *testing.T) { - netr := &fakeResolver{} - ctx := context.TODO() - ctl := &fakeCtl{} - obs := &fakeObs{} - bdg := &fakeBdg{Controller: ctl} - pxr := NewProxifier(ctx, ctl, obs) - ilog.SetLevel(0) - settings.Debug = true - dialers.Mapper(netr) - - _ = xdns.NetAndProxyID("tcp", Base) - tm := settings.NewTunMode( - settings.DNSModePort, - settings.BlockModeNone, - settings.PtModeAuto, - ) - - tr, _ := dns53.NewTLSTransport(ctx, "test0", "1.1.1.1", nil, pxr) - dtr, _ := dns53.NewTransport(ctx, x.Default, "1.1.1.1", "53", pxr) - - natpt := x64.NewNatPt(tm, bdg) - resolv := dnsx.NewResolver(ctx, "10.111.222.3", tm, dtr, bdg, natpt) - resolv.Add(tr) - - exit, _ := pxr.ProxyFor(Exit) - if ok := Reaches(exit, "1.1.1.1:153"); !ok { - t.Fail() - } - t.Log("proxy reaches") -} - -func aquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeA) - msg.Id = 1234 - return msg -} - -func aaaaquery(d string) *dns.Msg { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(d), dns.TypeAAAA) - msg.Id = 3456 - return msg -} diff --git a/intra/ipn/pxclient.go b/intra/ipn/pxclient.go deleted file mode 100644 index 44530ba4..00000000 --- a/intra/ipn/pxclient.go +++ /dev/null @@ -1,404 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/netip" - "net/url" - "strconv" - "strings" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" -) - -const ( - defaultTraceURL = "https://sky.rethinkdns.com/cdn-cgi/trace" - defaultWarpURL = "https://redir.nile.workers.dev/p/warp" - defaultMullvadV4URL = "https://ipv4.am.i.mullvad.net/json" - defaultMullvadV6URL = "https://ipv6.am.i.mullvad.net/json" - maxIPBodySize = int64(128 * 1024) - httpTimeout = 10 * time.Second -) - -// test hooks -var ( - traceURL = defaultTraceURL - warpURL = defaultWarpURL - mullvadV4URL = defaultMullvadV4URL - mullvadV6URL = defaultMullvadV6URL - - skipTraceForTesting = false - skipWarpForTesting = false - skipMullvadForTesting = false -) - -type proxyClient struct { - p Proxy -} - -var _ x.Client = (*proxyClient)(nil) - -func newProxyClient(p Proxy) x.Client { - return &proxyClient{p: p} -} - -// IP4 implements x.Client. -func (c *proxyClient) IP4() (*x.IPMetadata, error) { - return fetchIPMetadata(c.p, "tcp4") -} - -// IP6 implements x.Client. -func (c *proxyClient) IP6() (*x.IPMetadata, error) { - return fetchIPMetadata(c.p, "tcp6") -} - -func fetchIPMetadata(p Proxy, network string) (*x.IPMetadata, error) { - meta := &x.IPMetadata{ID: idstr(p)} - mullvadURL := mullvadV4URL - if network == "tcp6" { - mullvadURL = mullvadV6URL - } - - if trace, err1 := fetchTrace(p, network); err1 == nil { - applyTrace(meta, trace) - meta.ProviderURL = traceURL - } else if warp, err2 := fetchWarp(p, network); err2 == nil { - applyWarp(meta, warp) - meta.ProviderURL = warpURL - } else if mull, err3 := fetchMullvad(p, network, mullvadURL); err3 == nil { - applyMullvad(meta, mull) - meta.ProviderURL = mullvadURL - } else { - perr := fmt.Errorf("proxy: client: %s ip lookup failed", idstr(p)) - return nil, core.JoinErr(perr, err1, err2, err3) - } - - if len(meta.IP) <= 0 { - return nil, fmt.Errorf("proxy: client: %s ip lookup failed", idstr(p)) - } - - return meta, nil -} - -// fetchTrace fetches the Cloudflare trace data via the given proxy. -// fl=765f119 -// h=sky.rethinkdns.com -// ip=dead:beef::dead:beef -// ts=1766262434 -// visit_scheme=https -// uag=.../... -// colo=GRU -// sliver=none -// http=http/2 -// loc=BR -// tls=TLSv1.3 -// sni=plaintext -// warp=off -// gateway=off -// rbi=off -// kex=X25519 -func fetchTrace(p Proxy, network string) (map[string]string, error) { - if skipTraceForTesting { - return nil, errors.New("testing: trace skipped") - } - - body, err := fetch(p, network, traceURL) - if err != nil { - return nil, err - } - - kv := make(map[string]string) - for line := range strings.SplitSeq(string(body), "\n") { - if strings.TrimSpace(line) == "" { - continue - } - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - kv[parts[0]] = parts[1] - } - - if len(kv) == 0 { - return nil, errors.New("proxy: client: empty trace response") - } - - return kv, nil -} - -// { -// "vcode":"...", -// "minvcode":"...", -// "cansell":false, -// "ip":"dead:beef::dead:beef", -// "country":"br", -// "asorg":"NETWORKS", -// "city":"Sรฃo Paulo", -// "colo":"BR", -// "region":"Sรฃo Paulo State", -// "postalcode":"01000-000", -// "addrs":[], -// "status":"ok", -// "pubkey": {jwk} -// } -type warpResp struct { - IP string `json:"ip"` - Country string `json:"country"` - City string `json:"city"` - Region string `json:"region"` - ASNOrg string `json:"asorg"` - Latitude float64 `json:"latitude"` - Longitude float64 `json:"longitude"` -} - -func fetchWarp(p Proxy, network string) (*warpResp, error) { - if skipWarpForTesting { - return nil, errors.New("testing: warp skipped") - } - - body, err := fetch(p, network, warpURL) - if err != nil { - return nil, err - } - - var resp warpResp - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - - if resp.IP == "" { - return nil, errors.New("proxy: client: empty warp response") - } - - return &resp, nil -} - -// { -// "ip":"w.x.y.z", -// "country":"Brazil", -// "city":"Sรฃo Paulo", -// "longitude":-46.6333, -// "latitude":-23.5505, -// "mullvad_exit_ip":false, -// "blacklisted":{"blacklisted":false,"results":[]}, -// "organization":"Example Org" -// } -type mullvadResp struct { - IP string `json:"ip"` - Country string `json:"country"` - City string `json:"city"` - Longitude float64 `json:"longitude"` - Latitude float64 `json:"latitude"` - Organization string `json:"organization"` -} - -func fetchMullvad(p Proxy, network, url string) (*mullvadResp, error) { - if skipMullvadForTesting { - return nil, errors.New("testing: mullvad skipped") - } - - body, err := fetch(p, network, url) - if err != nil { - return nil, err - } - - var resp mullvadResp - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - - if resp.IP == "" { - return nil, errors.New("proxy: client: empty mullvad response") - } - - return &resp, nil -} - -func applyTrace(meta *x.IPMetadata, kv map[string]string) { - if ip, ok := kv["ip"]; ok { - meta.IP = ip - } - if cc, ok := kv["loc"]; ok { - meta.CC = strings.ToUpper(cc) - } -} - -func applyWarp(meta *x.IPMetadata, resp *warpResp) { - meta.IP = resp.IP - if resp.Country != "" { - meta.CC = strings.ToUpper(resp.Country) - } - if resp.City != "" { - meta.City = resp.City - } else if resp.Region != "" { - meta.City = resp.Region - } - if resp.ASNOrg != "" { - meta.ASNOrg = resp.ASNOrg - } - if resp.Latitude != 0 { - meta.Lat = resp.Latitude - } - if resp.Longitude != 0 { - meta.Lon = resp.Longitude - } -} - -func applyMullvad(meta *x.IPMetadata, resp *mullvadResp) { - if resp.IP != "" { - meta.IP = resp.IP - } - if resp.Country != "" && meta.CC == "" { - meta.CC = resp.Country - } - if resp.City != "" { - meta.City = resp.City - } - if resp.Organization != "" { - meta.ASNOrg = resp.Organization - } - if resp.Latitude != 0 { - meta.Lat = resp.Latitude - } - if resp.Longitude != 0 { - meta.Lon = resp.Longitude - } -} - -func fetch(p Proxy, network, rawurl string) ([]byte, error) { - parsed, err := url.Parse(rawurl) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), httpTimeout) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil) - if err != nil { - return nil, err - } - - log.VV("proxy: client: %s fetching %s via %s...", idstr(p), rawurl, network) - - // TODO: pool clients - client := httpClient(p, network, parsed) - resp, err := client.Do(req) - if err != nil { - return nil, err - } - if resp == nil { - return nil, errors.New("proxy: client: ip lookup nil response") - } - defer core.Close(resp.Body) - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return nil, fmt.Errorf("proxy: client: ip lookup %s: status %s", rawurl, resp.Status) - } - - log.VV("proxy: client: %s fetched %s via %s with status %s; reading body...", idstr(p), rawurl, network, resp.Status) - - data, err := io.ReadAll(io.LimitReader(resp.Body, maxIPBodySize)) - if err != nil { - return nil, err - } - - return data, nil -} - -func httpClient(p Proxy, network string, u *url.URL) *http.Client { - return &http.Client{ - Timeout: httpTimeout, - Transport: &http.Transport{ - DialContext: func(ctx context.Context, _, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - - if port == "" { - switch { - case u.Port() != "": - port = u.Port() - case u.Scheme == "https": - port = "443" - default: - port = "80" - } - } - - on, _ := strconv.Atoi(port) - if on <= 0 { - if u.Scheme == "https" { - on = 443 - } else { - on = 80 - } - } - - ips := dialers.For(host) - filtered := make([]netip.Addr, 0, len(ips)) - for _, ip := range ips { - if network == "tcp4" && ip.Is4() { - filtered = append(filtered, ip) - } - if network == "tcp6" && ip.Is6() { - filtered = append(filtered, ip) - } - } - - if len(filtered) == 0 { - return nil, errNoSuitableAddress - } - - log.VV("proxy: client: %s resolved %s to %v on port %d for %s", idstr(p), host, filtered, on, network) - - var lastErr error - for _, ip := range filtered { - dest := netip.AddrPortFrom(ip, uint16(on)).String() - if conn, err := p.Dial(network, dest); err == nil { - log.VV("proxy: client: %s dialed %s @ %s on %s", idstr(p), host, dest, network) - return conn, nil - } else { - log.E("proxy: client: %s failed to dial %s @ %s on %s: %v", idstr(p), host, dest, network, err) - lastErr = err - } - } - - if lastErr == nil { - lastErr = errNoSuitableAddress - } - return nil, lastErr - }, - TLSHandshakeTimeout: httpTimeout / 2, - ResponseHeaderTimeout: httpTimeout - 2, - DisableKeepAlives: true, - ForceAttemptHTTP2: true, - }, - } -} - -func (h *base) Client() x.Client { return newProxyClient(h) } -func (h *exit) Client() x.Client { return newProxyClient(h) } -func (h *exit64) Client() x.Client { return newProxyClient(h) } -func (h *auto) Client() x.Client { return newProxyClient(h) } -func (h *socks5) Client() x.Client { return newProxyClient(h) } -func (h *http1) Client() x.Client { return newProxyClient(h) } -func (h *wgproxy) Client() x.Client { return newProxyClient(h) } -func (h *pipws) Client() x.Client { return newProxyClient(h) } -func (h *piph2) Client() x.Client { return newProxyClient(h) } diff --git a/intra/ipn/pxclient_test.go b/intra/ipn/pxclient_test.go deleted file mode 100644 index c8787448..00000000 --- a/intra/ipn/pxclient_test.go +++ /dev/null @@ -1,237 +0,0 @@ -package ipn - -import ( - "context" - "errors" - "net" - "net/http" - "net/http/httptest" - "net/netip" - "strings" - "testing" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/protect/ipmap" -) - -type fakeProxy struct{ id string } - -type systemMapper struct{} - -func (systemMapper) LocalLookup(_ []byte) ([]byte, error) { - return nil, errors.New("wire lookup not supported") -} -func (systemMapper) Lookup(_ []byte, _ string, _ ...string) ([]byte, error) { - return nil, errors.New("wire lookup not supported") -} -func (systemMapper) LookupFor(_ []byte, _ string) ([]byte, error) { - return nil, errors.New("wire lookup not supported") -} -func (systemMapper) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { - return net.DefaultResolver.LookupNetIP(ctx, network, host) -} -func (systemMapper) LookupNetIPFor(ctx context.Context, network, host, _ string) ([]netip.Addr, error) { - return net.DefaultResolver.LookupNetIP(ctx, network, host) -} -func (systemMapper) LookupNetIPOn(ctx context.Context, network, host string, _ ...string) ([]netip.Addr, error) { - return net.DefaultResolver.LookupNetIP(ctx, network, host) -} - -func (f *fakeProxy) Dial(network, addr string) (protect.Conn, error) { return net.Dial(network, addr) } -func (f *fakeProxy) DialBind(network, local, remote string) (protect.Conn, error) { - if local == "" { - return f.Dial(network, remote) - } - return net.Dial(network, remote) -} -func (f *fakeProxy) Announce(string, string) (protect.PacketConn, error) { - return nil, errAnnounceNotSupported -} -func (f *fakeProxy) Accept(string, string) (protect.Listener, error) { - return nil, errAnnounceNotSupported -} -func (f *fakeProxy) Probe(string, string) (protect.PacketConn, error) { - return nil, errProbeNotSupported -} -func (f *fakeProxy) Dialer() protect.RDialer { return f } -func (f *fakeProxy) DialerHandle() uintptr { return 0 } -func (f *fakeProxy) Handle() uintptr { return 0 } -func (f *fakeProxy) ID() string { return f.id } -func (f *fakeProxy) Type() string { return NOOP } -func (f *fakeProxy) Router() x.Router { return &GWNoVia{} } -func (f *fakeProxy) Client() x.Client { return newProxyClient(f) } -func (f *fakeProxy) onNotOK() (bool, bool) { return false, true } -func (f *fakeProxy) OnProtoChange(LinkProps) (string, bool) { return "", false } -func (f *fakeProxy) Hop(Proxy, bool) error { return nil } -func (f *fakeProxy) Status() int { return TOK } -func (f *fakeProxy) GetAddr() string { return "" } -func (f *fakeProxy) DNS() string { return "" } -func (f *fakeProxy) Ping() bool { return true } -func (f *fakeProxy) Pause() bool { return false } -func (f *fakeProxy) Resume() bool { return false } -func (f *fakeProxy) Stop() error { return nil } -func (f *fakeProxy) Refresh() error { return nil } - -func restoreDefaultURLs(t *testing.T) func() { - prevTrace, prevWarp := traceURL, warpURL - prevV4, prevV6 := mullvadV4URL, mullvadV6URL - traceURL, warpURL = defaultTraceURL, defaultWarpURL - mullvadV4URL, mullvadV6URL = defaultMullvadV4URL, defaultMullvadV6URL - return func() { - traceURL, warpURL = prevTrace, prevWarp - mullvadV4URL, mullvadV6URL = prevV4, prevV6 - } -} - -func newIPv4Server(t *testing.T) *httptest.Server { - t.Helper() - ln, err := net.Listen("tcp4", "127.0.0.1:0") - if err != nil { - t.Fatalf("ipv4 listen: %v", err) - } - return newServerWithListener(t, ln) -} - -func newIPv6Server(t *testing.T) (*httptest.Server, bool) { - t.Helper() - ln, err := net.Listen("tcp6", "[::1]:0") - if err != nil { - return nil, false - } - return newServerWithListener(t, ln), true -} - -func newServerWithListener(t *testing.T, ln net.Listener) *httptest.Server { - t.Helper() - handler := http.NewServeMux() - handler.HandleFunc("/cdn-cgi/trace", func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("fl=765f119\nloc=US\ncolo=DFW\nip=1.2.3.4\n")) - }) - handler.HandleFunc("/json", func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"ip":"1.2.3.4","country":"United States","city":"Dallas","longitude":-96.8,"latitude":32.8,"organization":"Example Org"}`)) - }) - - srv := httptest.NewUnstartedServer(handler) - srv.Listener = ln - srv.Start() - t.Cleanup(srv.Close) - return srv -} - -func TestProxyClientIP4(t *testing.T) { - srv := newIPv4Server(t) - - prevTrace, prevMull := traceURL, mullvadV4URL - traceURL, mullvadV4URL = srv.URL+"/cdn-cgi/trace", srv.URL+"/json" - defer func() { traceURL, mullvadV4URL = prevTrace, prevMull }() - - p := &fakeProxy{id: "test-ipv4"} - meta, err := newProxyClient(p).IP4() - if err != nil { - t.Fatalf("ip4 err: %v", err) - } - - if meta.IP != "1.2.3.4" { - t.Fatalf("ip mismatch: %v", meta.IP) - } - if meta.CC != "US" { - t.Fatalf("cc mismatch: %v", meta.CC) - } - if meta.City != "Dallas" { - t.Fatalf("city mismatch: %v", meta.City) - } - if meta.ASNOrg != "Example Org" { - t.Fatalf("asn org mismatch: %v", meta.ASNOrg) - } - if meta.ProviderURL != mullvadV4URL { - t.Fatalf("provider mismatch: %v", meta.ProviderURL) - } -} - -func TestProxyClientIP6(t *testing.T) { - srv, ok := newIPv6Server(t) - if !ok { - t.Skip("ipv6 not available") - } - - prevTrace, prevMull := traceURL, mullvadV6URL - traceURL, mullvadV6URL = srv.URL+"/cdn-cgi/trace", srv.URL+"/json" - defer func() { traceURL, mullvadV6URL = prevTrace, prevMull }() - - p := &fakeProxy{id: "test-ipv6"} - meta, err := newProxyClient(p).IP6() - if err != nil { - t.Fatalf("ip6 err: %v", err) - } - - if meta.IP != "1.2.3.4" { - t.Fatalf("ip mismatch: %v", meta.IP) - } - if meta.CC != "US" { - t.Fatalf("cc mismatch: %v", meta.CC) - } - if meta.ProviderURL != mullvadV6URL { - t.Fatalf("provider mismatch: %v", meta.ProviderURL) - } -} - -func TestProxyClientIP4Live(t *testing.T) { - defer restoreDefaultURLs(t)() - skipWarpForTesting = true - skipTraceForTesting = true - skipMullvadForTesting = false - dialers.Mapper(ipmap.NewIPMapFor(systemMapper{})) - - p := &fakeProxy{id: "live-ipv4"} - meta, err := newProxyClient(p).IP4() - if err != nil { - t.Fatalf("live ip4 err: %v", err) - } - - if meta.IP == "" { - t.Fatal("live ip4: empty ip") - } - if ip, err := netip.ParseAddr(meta.IP); err != nil || !ip.Is4() { - t.Fatalf("live ip4: not v4 ip: %v", meta.IP) - } - if meta.ProviderURL == "" { - t.Fatal("live ip4: empty provider") - } - if strings.Contains(strings.ToLower(meta.ProviderURL), "example.org") { - t.Fatalf("live ip4: unexpected provider: %v", meta.ProviderURL) - } -} - -func TestProxyClientIP6Live(t *testing.T) { - defer restoreDefaultURLs(t)() - skipWarpForTesting = false - skipTraceForTesting = true - skipMullvadForTesting = true - dialers.Mapper(ipmap.NewIPMapFor(systemMapper{})) - - p := &fakeProxy{id: "live-ipv6"} - meta, err := newProxyClient(p).IP6() - - if err != nil { - // skip on environments without ipv6 connectivity - if strings.Contains(err.Error(), "no suitable address") || - strings.Contains(strings.ToLower(err.Error()), "ipv6") { - t.Skipf("ipv6 live lookup skipped: %v", err) - } - t.Fatalf("live ip6 err: %v", err) - } - t.Log(meta) - - if meta.IP == "" { - t.Fatal("live ip6: empty ip") - } - if ip, err := netip.ParseAddr(meta.IP); err != nil || !ip.Is6() { - t.Fatalf("live ip6: not v6 ip: %v", meta.IP) - } - if meta.ProviderURL == "" { - t.Fatal("live ip6: empty provider") - } -} diff --git a/intra/ipn/rpn.go b/intra/ipn/rpn.go deleted file mode 100644 index 4649b3b9..00000000 --- a/intra/ipn/rpn.go +++ /dev/null @@ -1,554 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "errors" - "fmt" - "strings" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn/rpn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" -) - -type RpnProxy interface { - x.RpnProxy - Proxy - Emplace(Proxy) error - PurgeAll() (n uint32) -} - -type RpnAcc = rpn.RpnAcc - -// and kick-off an update if the acc is expired? -type rpnp struct { - mu sync.RWMutex // protects Proxy & kids - - // parent proxy - p Proxy - // Rpn-specific accounting - // TODO: unembed to type assert RpnAcc impl - RpnAcc - - // rpn proxy manager - pxr Rpn - - // forked child proxy IDs, may be empty (returned proxy IDs may have stopped) - kids map[string]struct{} -} - -var _ RpnProxy = (*rpnp)(nil) -var _ RpnAcc = (*rpnp)(nil) // (useless) assertion always succeeds, see above -var _ Proxy = (*rpnp)(nil) // (useless) assertion always succeeds, see above - -var ( - errRpnMissing = errors.New("proxy: rpn: missing") - errRpnBadArgs = errors.New("proxy: rpn: bad args") - errRpnBadEmplace = errors.New("proxy: rpn: emplace: bad args") - errRpnBadCC = errors.New("proxy: rpn: bad country code") - errRpnIDsMismatch = errors.New("proxy: rpn: provider x proxy mismatch") - errRpnMainProxyMissing = errors.New("proxy: rpn: cannot fork; main proxy missing") - errRpnMainProxyStopped = errors.New("proxy: rpn: cannot fork; main proxy stopped") - errRpnNotForked = errors.New("proxy: rpn: not forked") -) - -// nb: client code isn't really expecting error from asRpnProxy. -func asRpnProxy(e Proxy, acc RpnAcc, pxr Rpn) (RpnProxy, error) { - if e == nil || acc == nil || pxr == nil { - return nil, errRpnBadArgs - } - - proxyid := idstr(e) // must be of form "provider-id + country-code" - providerid := acc.ProviderID() - if !strings.HasPrefix(proxyid, providerid) { - log.W("proxy: rpn: make: %s <> %s mismatch", proxyid, providerid) - return nil, errRpnIDsMismatch - } - log.D("proxy: rpn: make: %s[%s]", providerid, proxyid) - return &rpnp{sync.RWMutex{}, e, acc, pxr, make(map[string]struct{}, 0)}, nil -} - -func (r *rpnp) ensureProxy() Proxy { - r.mu.RLock() - defer r.mu.RUnlock() - if r.p == nil { - panic(fmt.Sprintf("proxy: rpn: missing main for %s using provider %s", r.RpnAcc.Who(), r.RpnAcc.ProviderID())) - } - return r.p -} - -func (r *rpnp) currentProxy() Proxy { - r.mu.RLock() - defer r.mu.RUnlock() - return r.p -} - -func (r *rpnp) requireProxy() (Proxy, error) { - if p := r.currentProxy(); p != nil { - return p, nil - } - return nil, errRpnMissing -} - -// ID implements x.Proxy. -func (r *rpnp) ID() string { - return r.ensureProxy().ID() -} - -// Type implements x.Proxy. -func (r *rpnp) Type() string { - return r.ensureProxy().Type() -} - -// Router implements x.Proxy. -func (r *rpnp) Router() x.Router { - return r.ensureProxy().Router() -} - -// Client implements x.Proxy. -func (r *rpnp) Client() x.Client { - return r.ensureProxy().Client() -} - -// GetAddr implements x.Proxy. -func (r *rpnp) GetAddr() string { - return r.ensureProxy().GetAddr() -} - -// DNS implements x.Proxy. -func (r *rpnp) DNS() string { - return r.ensureProxy().DNS() -} - -// Status implements x.Proxy. -func (r *rpnp) Status() int { - return r.ensureProxy().Status() -} - -// Ping implements x.Proxy. -func (r *rpnp) Ping() bool { - if p := r.currentProxy(); p != nil { - return p.Ping() - } - return false -} - -// Pause implements x.Proxy. -func (r *rpnp) Pause() bool { - if p := r.currentProxy(); p != nil { - return p.Pause() - } - return false -} - -// Resume implements x.Proxy. -func (r *rpnp) Resume() bool { - if p := r.currentProxy(); p != nil { - return p.Resume() - } - return false -} - -// Stop implements x.Proxy. -func (r *rpnp) Stop() error { - p, err := r.requireProxy() - if err != nil { - return err - } - return p.Stop() -} - -// Refresh implements x.Proxy. -func (r *rpnp) Refresh() error { - p, err := r.requireProxy() - if err != nil { - return err - } - return p.Refresh() -} - -// DialerHandle implements Proxy. -func (r *rpnp) DialerHandle() uintptr { - return r.ensureProxy().DialerHandle() -} - -// Handle implements Proxy. -func (r *rpnp) Handle() uintptr { - return r.ensureProxy().Handle() -} - -// Dialer implements Proxy. -func (r *rpnp) Dialer() protect.RDialer { - return r -} - -// onNotOK implements Proxy. -func (r *rpnp) onNotOK() (bool, bool) { - if p := r.currentProxy(); p != nil { - return p.onNotOK() - } - return false, false -} - -// OnProtoChange implements Proxy. -func (r *rpnp) OnProtoChange(lp LinkProps) (string, bool) { - if p := r.currentProxy(); p != nil { - return p.OnProtoChange(lp) - } - return "", false -} - -// Hop implements Proxy. -func (r *rpnp) Hop(p Proxy, dryrun bool) error { - main, err := r.requireProxy() - if err != nil { - return err - } - return main.Hop(p, dryrun) -} - -// Dial implements Proxy. -func (r *rpnp) Dial(network, addr string) (protect.Conn, error) { - if p, err := r.requireProxy(); err == nil { - return p.Dial(network, addr) - } else { - return nil, err - } -} - -// DialBind implements Proxy. -func (r *rpnp) DialBind(network, local, remote string) (protect.Conn, error) { - if p, err := r.requireProxy(); err == nil { - return p.DialBind(network, local, remote) - } else { - return nil, err - } -} - -// Probe implements Proxy. -func (r *rpnp) Announce(network, local string) (protect.PacketConn, error) { - if p, err := r.requireProxy(); err == nil { - return p.Announce(network, local) - } else { - return nil, err - } -} - -// Accept implements Proxy. -func (r *rpnp) Accept(network, local string) (protect.Listener, error) { - if p, err := r.requireProxy(); err == nil { - return p.Accept(network, local) - } else { - return nil, err - } -} - -// Probe implements Proxy. -func (r *rpnp) Probe(network, local string) (protect.PacketConn, error) { - if p, err := r.requireProxy(); err == nil { - return p.Probe(network, local) - } else { - return nil, err - } -} - -// Emplace implements RpnProxy. -func (r *rpnp) Emplace(new Proxy) (err error) { - if new == nil { - log.W("proxy: rpn: emplace: no-op as new proxy nil") - return errRpnBadEmplace - } - - r.mu.Lock() - defer r.mu.Unlock() - - old := r.p - oldid := idstr(old) - newid := idstr(new) - - defer func() { - if err != nil { - core.Go("rpn.emplace."+oldid, func() { - n := r.PurgeAll() // purge all kids on error - log.E("proxy: rpn: emplace: %s[%s] failed; purged %d kids; emplace err: %v", oldid, newid, n, err) - }) - } - }() - - if oldid != newid { - log.W("proxy: rpn: emplace: %s <> %s mismatch", oldid, newid) - } - - r.p = new - - log.D("proxy: rpn: emplace: %s[%s]", r.RpnAcc.ProviderID(), newid) - return nil -} - -// Fork implements x.RpnProxy. -func (r *rpnp) Fork(cc string) (x.Proxy, error) { - return r.fork(cc) -} - -// cc may be a fully qualified ID (in case of re-forking the main proxy), too. -func (r *rpnp) fork(cc string) (x.Proxy, error) { - // do not hold lock while calling into pxr as it can callback via Emplace. - main, err := r.requireProxy() - if err != nil || main == nil { - return nil, core.OneErr(err, errRpnMainProxyMissing) - } - acc := r.RpnAcc - - mainpid := idstr(main) - if len(mainpid) <= 0 { - return nil, errMissingProxyID - } - if main.Status() == END { - // TODO: PurgeAll? - return nil, errRpnMainProxyStopped - } - - provider := acc.ProviderID() - if mainpid == provider+cc || // true when cc == noCountryForOldMen or anyCountryCode - mainpid == cc || // true when cc is fully-qualified ID of the main proxy - (cc == noCountryForOldMen && !acc.MultiCountry()) || - (cc == anyCountryCode && acc.MultiCountry()) { - // re-forking main proxy (which may not be multi-country acc) via Update() => forkAll() - log.I("proxy: rpn: fork: %s main cc %s; re-adding...", provider, cc) - // expect Emplace to be called - return r.pxr.addRpnProxy(acc, cc) // re-generates conf and re-adds - } - - if len(cc) < 2 { - return nil, errRpnBadCC - } - cc = strings.ToUpper(cc) - if !acc.MultiCountry() { - return nil, log.EE("proxy: rpn: fork: %s not multi-country %s", cc, provider) - } - - log.I("proxy: rpn: fork: %s[%s]", provider, cc) - - // re-adds + updates if the proxy already exists - kid, err := r.pxr.addRpnProxy(acc, cc) - - if kid != nil { - r.mu.Lock() - r.kids[cc] = struct{}{} - r.mu.Unlock() - } - - return kid, err -} - -func (r *rpnp) forkMain() error { - main, err := r.requireProxy() - if err != nil || main == nil { - return log.EE("proxy: rpn: forkMain: main missing; err? %v", err) - } - - mainpid := idstr(main) - - _, err = r.fork(mainpid) // re-adds main proxy (via Emplace) - - logei(err)("proxy: rpn: forkMain: %s; err? %v", mainpid, err) - return err -} - -func (r *rpnp) forkAll() error { - provider := r.RpnAcc.ProviderID() - kids := r.flattenKids() - log.I("proxy: rpn: forkAll: %s[%v]", provider, kids) - - errs := make([]error, 0) // may contain nil errors - - e := r.forkMain() - - errs = append(errs, e) - - for _, cc := range kids { - _, e := r.fork(cc) - loged(e)("proxy: rpn: forkAll: forked %s[%s]; err? %v", provider, cc, e) - errs = append(errs, e) - } - return core.JoinErr(errs...) -} - -func (r *rpnp) Redo() (err error) { - return r.forkAll() -} - -func (r *rpnp) PingAll() (csvpids string, err error) { - start := time.Now() - provider := r.RpnAcc.ProviderID() - kids := r.flattenKids() - main, err := r.requireProxy() - - logei(err)("proxy: rpn: pingAll: %s[%v]; got main? %t; err: %v", - provider, kids, main != nil, err) - - if err != nil { - return - } - - mainpinged := main.Ping() - if !mainpinged { - log.W("proxy: rpn: pingAll: main proxy %s failed ping", provider) - } - - kidspinged := make([]string, 0, len(kids)) - errs := make([]error, 0) - for _, cc := range kids { - p, rerr := r.pxr.rpnProxyFor(provider, cc) - if rerr != nil { - errs = append(errs, rerr) - continue - } - if !p.Ping() { - log.W("proxy: rpn: pingAll: proxy for %s[%s] failed ping", provider, cc) - } else { - kidspinged = append(kidspinged, cc) - } - } - - err = core.JoinErr(errs...) - logei(err)("proxy: rpn: pingAll: %s[%v] done in %s; main pinged? %t / kids pinged? %v; err? %v", - provider, kids, core.FmtTimeAsPeriod(start), mainpinged, kidspinged, errs) - - return strings.Join(kidspinged, ","), err -} - -func (r *rpnp) PurgeAll() (n uint32) { - for _, cc := range r.flattenKids() { - if r.purge(cc) { - n++ - } - } - - if r.purgeMain() { - n++ - } - return -} - -func (r *rpnp) purgeMain() bool { - main, err := r.requireProxy() - mainpid := idstr(main) - logei(err)("proxy: rpn: purgeMain: %s; err? %v", mainpid, err) - if err != nil { - return false - } - return r.pxr.removeRpnProxy(r.RpnAcc, mainpid) -} - -// Purge implements x.RpnProxy. -func (r *rpnp) Purge(cc string) bool { - return r.purge(cc) -} - -func (r *rpnp) purge(cc string) bool { - main, err := r.requireProxy() - if err != nil { - log.W("proxy: rpn: purge: no main proxy %s", err) - return false - } - acc := r.RpnAcc - - provider := acc.ProviderID() - mainpid := idstr(main) - cc = strings.ToUpper(cc) - - if !acc.MultiCountry() { - log.D("proxy: rpn: purge: %s not multi-country %s", cc, provider) - return false - } else if cc == mainpid || provider+cc == mainpid { - log.W("proxy: rpn: purge: %s is main; call PurgeAll instead", cc, provider) - return false - } else if len(cc) < 2 { - log.W("proxy: rpn: purge: %s bad country code; not purging", cc) - return false - } - - rmv := r.pxr.removeRpnProxy(acc, cc) - - r.mu.Lock() - delete(r.kids, cc) - r.mu.Unlock() - - log.D("proxy: rpn: purge: %s[%s]? %t", provider, cc, rmv) - return rmv -} - -// Get implements x.RpnProxy. -func (r *rpnp) Get(cc string) (x.Proxy, error) { - return r.get(cc) -} - -func (r *rpnp) get(cc string) (x.Proxy, error) { - acc := r.RpnAcc - rpnid := acc.ProviderID() - - if cc == noCountryForOldMen && !acc.MultiCountry() { - return r, nil - } - if !acc.MultiCountry() { - return nil, log.EE("proxy: rpn: get: %s not multi-country %s", cc, rpnid) - } - if len(cc) < 2 { - log.W("proxy: rpn: get: %s bad country code", cc) - return nil, errRpnBadCC - } - cc = strings.ToUpper(cc) - - r.mu.RLock() - main := r.p - _, gotCC := r.kids[cc] - r.mu.RUnlock() - - if rpnid+cc == idstr(main) { - // return r as-is; r.p is always got after r.mu.RLock() - return r, nil - } - if !gotCC { - return nil, errRpnNotForked - } - return r.pxr.rpnProxyFor(rpnid, cc) -} - -// Kids implements x.RpnProxy. -func (r *rpnp) Kids() (csvpids string) { - return r.kidsCsv() -} - -func (r *rpnp) kidsCsv() string { - return strings.Join(r.flattenKids(), ",") -} - -func (r *rpnp) flattenKids() (ccs []string) { - r.mu.RLock() - defer r.mu.RUnlock() - - ccs = make([]string, 0, len(r.kids)) - for cc := range r.kids { - ccs = append(ccs, cc) - } - return -} - -// Update implements RpnAcc. -func (r *rpnp) Update(ops *x.RpnOps) (newState []byte, err error) { - newState, err = r.RpnAcc.Update(ops) - if err == nil { - core.Gxe("rpn.fork."+r.ProviderID(), r.forkAll) - } - return -} diff --git a/intra/ipn/rpn/cfg.go b/intra/ipn/rpn/cfg.go deleted file mode 100644 index f9e89a46..00000000 --- a/intra/ipn/rpn/cfg.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT - -// from: github.com/bepass-org/warp-plus/blob/19ac233cc/warp/endpoint.go - -package rpn - -import ( - "crypto/tls" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/netip" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/protect" -) - -// developers.cloudflare.com/1.1.1.1/ip-addresses/ -const cfdns4 = "1.1.1.1" - -const gw4 = "0.0.0.0/0" // netip.ParsePrefix("0.0.0.0/0") - -// preset 6to4 NATs; from: nat64.xyz -var Net6to4 = []netip.Prefix{ - netip.MustParsePrefix("2a00:1098:2b::/96"), // kasper - netip.MustParsePrefix("2a00:1098:2c:1::/96"), // kasper - netip.MustParsePrefix("2a01:4f8:c2c:123f:64::/96"), // kasper - netip.MustParsePrefix("2a01:4f9:c010:3f02:64::/96"), // kasper - netip.MustParsePrefix("2001:67c:2960:6464::/96"), // level66 - netip.MustParsePrefix("2001:67c:2b0:db32:0:1::/96"), // trex -} - -var ( - errRpnCountryless = errors.New("rpn is not multi-country") - errRpnStateless = errors.New("rpn has no state or config") - errRpnUpdateless = errors.New("rpn cannot be updated only registered") - - errZeroRandomEp = errors.New("warp: zero random endpoint") -) - -type RpnAcc interface { - x.RpnAcc - ProviderID() string // x.RpnWg, x.RpnPro, x.RpnAmz, x.RpnWin - MultiCountry() bool - Conf(key string) (string, error) -} - -var _ RpnAcc = (*WsClient)(nil) - -type BaseClient struct { - d protect.RDialer - h2 http.Client -} - -var dob = time.Now() -var neverEver = time.Date(5253, time.March, 6, 0, 0, 0, 0, time.UTC) - -type RpnForever struct{} - -func (RpnForever) Created() int64 { return dob.UnixMilli() } -func (RpnForever) Expires() int64 { return neverEver.UnixMilli() } - -type RpnMultiCountry struct{} - -func (RpnMultiCountry) MultiCountry() bool { return true } - -type RpnCountryless struct{} - -func (c RpnCountryless) MultiCountry() bool { return false } -func (c RpnCountryless) Locations() (x.RpnServers, error) { return nil, errRpnCountryless } - -type RpnStateless struct { - RpnUpdateless -} - -func (RpnStateless) Updated() int64 { return neverEver.UnixMilli() } -func (RpnStateless) State() ([]byte, error) { return nil, errRpnStateless } -func (RpnStateless) Conf(cc string) (string, error) { return "", errRpnStateless } - -type RpnUpdateless struct{} - -func (RpnUpdateless) Ops() *x.RpnOps { return nil } -func (RpnUpdateless) Update(*x.RpnOps) ([]byte, error) { return nil, errRpnUpdateless } - -type RpnMultiCountryServers struct { - all []x.RpnServer -} - -var _ x.RpnServers = (*RpnMultiCountryServers)(nil) - -func (s *RpnMultiCountryServers) Get(i int) (*x.RpnServer, error) { - if i < 0 || i >= len(s.all) { - return nil, fmt.Errorf("rpn: %d out of range [0, %d)", i, len(s.all)) - } - return &s.all[i], nil -} - -func (s *RpnMultiCountryServers) Len() int { - return len(s.all) -} - -func (s *RpnMultiCountryServers) Json() ([]byte, error) { - if s == nil || len(s.all) <= 0 { - return nil, fmt.Errorf("rpn: no servers") - } - // go.dev/play/p/Cxy0imeHYKx - b, err := json.Marshal(s.all) - if err != nil { - return nil, fmt.Errorf("rpn: json: %w", err) - } - return b, nil -} - -func WinEndpoints() (v4 []netip.AddrPort, v6 []netip.AddrPort, err error) { - var v4ok, v6ok bool - for _, u := range []string{svchost, wsMyIp2, wsMyIp} { - // svchost is a host, but url.Parse will work - for _, ip := range dialers.ResolveForUrl(u) { - if ipok(ip) { - if ip.Is4() { - v4 = append(v4, netip.AddrPortFrom(ip, uint16(80))) - v4ok = true - } else if ip.Is6() { - v6 = append(v6, netip.AddrPortFrom(ip, uint16(80))) - v6ok = true - } - } - } - } - if !v4ok && !v6ok { - err = errZeroRandomEp - } - return -} - -func Exit64Endpoints() (v6 []netip.Addr, errs error) { - for _, cidr6 := range Net6to4 { - if ip6, err := core.RandomIPFromPrefix(cidr6); err == nil { - if ipok(ip6) { - v6 = append(v6, ip6) - } // else: discard - } else { - errs = core.JoinErr(errs, err) - } - } - if len(v6) <= 0 { - return nil, core.JoinErr(errs, errZeroRandomEp) - } - return v6, nil -} - -func NewExtClient(d protect.RDialer) *BaseClient { - w := &BaseClient{d: d} - w.h2.Transport = &http.Transport{ - Dial: d.Dial, - ForceAttemptHTTP2: true, - ResponseHeaderTimeout: 15 * time.Second, - IdleConnTimeout: 30 * time.Second, - TLSClientConfig: &tls.Config{ - ClientSessionCache: core.TlsSessionCache(), - }, - } - return w -} - -func ipok(ip netip.Addr) bool { - return ip.IsValid() && !ip.IsUnspecified() -} - -func fmtUnixMillis(ms int64) string { - return core.FmtUnixMillisAsTimestamp(ms) -} - -func fmtTime(t time.Time) string { - return core.FmtTimeAsPeriod(t) -} diff --git a/intra/ipn/rpn/regional.go b/intra/ipn/rpn/regional.go deleted file mode 100644 index 54850fbf..00000000 --- a/intra/ipn/rpn/regional.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rpn - -import ( - "encoding/base64" - "encoding/hex" - "fmt" - "net" - "strings" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -type RegionalWgConf struct { - // WsServerList.CountryCode (uppercased) - CC string `json:"CC"` - // WsServerGroup.City - City string `json:"City"` - // City (Nick) - Name string `json:"Name"` - // WsServerGroup.Health (0-100, lower is better) - Load int32 `json:"Load"` - // WsServerGroup.LinkSpeed (100, 1000, 10000 in mbps) - Link int32 `json:"Link"` - // len(WsServerGroup.Nodes) (number of nodes in this group) - Count int32 `json:"Count"` - // WsServerList.PremiumOnly == 1 - Premium bool `json:"Premium"` - - ClientAddr4 string `json:"ClientAddr4"` - ClientAddr6 string `json:"ClientAddr6"` - ClientPrivKey string `json:"ClientPrivKey"` - ClientPubKey string `json:"ClientPubKey"` - ClientDNS4 string `json:"ClientDNS4"` - ClientDNS6 string `json:"ClientDNS6"` - - PskKey string `json:"PskKey"` - - ServerPubKey string `json:"ServerPubKey"` - ServerIPPort4 string `json:"ServerIPPort4"` - ServerIPPort6 string `json:"ServerIPPort6"` - ServerDomainPort string `json:"ServerDomainPort"` - AllowedIPs []string `json:"AllowedIPs"` // csv - - UapiWgConf string `json:"uapiwgconf,omitempty"` // generated -} - -func (rwg *RegionalWgConf) String() string { - if rwg == nil { - return "" - } - return fmt.Sprintf("%s, %s: %s", rwg.City, rwg.CC, rwg.Name) -} - -func toHex(b64 string) string { - b, err := base64.StdEncoding.DecodeString(b64) - if err != nil { - return "" - } - return hex.EncodeToString(b) -} - -func (rwg *RegionalWgConf) GenUapiConfig() (didGenerate bool) { - return rwg.genUapiConfig() -} - -func (rwg *RegionalWgConf) genUapiConfig() (didGenerate bool) { - // github.com/WireGuard/wireguard-android/blob/4ba87947ae/tunnel/src/main/java/com/wireguard/config/Config.java#L179 - // github.com/WireGuard/wireguard-android/blob/4ba87947ae/tunnel/src/main/java/com/wireguard/config/Interface.java#L257 - // allowedips must be individual entries in uapi, but our custom impl can handle csv - // see: wgproxy.go:wgIfConfigOf => wgproxy.go:loadIPNets - allowedips := rwg.AllowedIPs - if len(allowedips) <= 0 { - allowedips = []string{gw4} - } - - // not added: listen_port, persistent_keepalive_interval - rwg.UapiWgConf = fmt.Sprintf(`private_key=%s -replace_peers=true -address=%s -dns=%s -mtu=(auto) -public_key=%s`, - toHex(rwg.ClientPrivKey), - rwg.ClientAddr4, - rwg.ClientDNS4, - toHex(rwg.ServerPubKey), - ) - if len(rwg.ServerIPPort4) > 0 { - rwg.UapiWgConf += "\nendpoint=" + rwg.ServerIPPort4 - } - if len(rwg.ServerIPPort6) > 0 { - rwg.UapiWgConf += "\nendpoint=" + rwg.ServerIPPort6 - } - if len(rwg.ServerDomainPort) > 0 { - rwg.UapiWgConf += "\nendpoint=" + rwg.ServerDomainPort - } - if len(rwg.PskKey) > 0 { - rwg.UapiWgConf += "\npreshared_key=" + toHex(rwg.PskKey) - } - if len(rwg.ClientAddr6) > 0 { - rwg.UapiWgConf += "\naddress=" + rwg.ClientAddr6 - } - if len(rwg.ClientDNS6) > 0 { - rwg.UapiWgConf += "\ndns=" + rwg.ClientDNS6 - } - for _, ip := range allowedips { - rwg.UapiWgConf += fmt.Sprintf("\nallowed_ip=%s", ip) - } - - return true -} - -func (rwg *RegionalWgConf) addrCsv() string { - var addrs []string - if len(rwg.ClientAddr4) > 0 { - addrs = append(addrs, rwg.ClientAddr4) - } - if len(rwg.ClientAddr6) > 0 { - addrs = append(addrs, rwg.ClientAddr6) - } - return strings.Join(addrs, ",") -} - -// MakeUapiConfig builds an on-the-fly UAPI config string that overlays -// the credentials from a permanent config (private key, address, DNS, -// preshared key, allowed IPs) onto this regional config's server endpoints. -// rwg.UapiWgConf is NOT modified; the generated string is returned directly. -// Returns ("", false) when rwg/perma is nil or PrivateKey/Address is absent. -func (rwg *RegionalWgConf) MakeUapiConfig(creds *WsWgCreds, port string) (string, bool) { - if rwg == nil || creds == nil || len(creds.PrivateKey) <= 0 { - return "", false - } - - addr4 := rwg.ClientAddr4 - if len(creds.Address) > 0 { - addr4 = creds.Address // perma address - } - dns4 := rwg.ClientDNS4 - if len(creds.DNS) > 0 { - dns4 = creds.DNS // perma dns - } - - clientpriv := creds.PrivateKey - psk := creds.PresharedKey - peerpub := rwg.ServerPubKey - - if len(clientpriv) <= 0 || len(peerpub) <= 0 { - log.E("rpn: regconf: cannot gen; empty priv (%t) or peer pub (%t) key", len(clientpriv) <= 0, len(peerpub) <= 0) - return "", false - } - - // github.com/WireGuard/wireguard-android/blob/4ba87947ae/tunnel/src/main/java/com/wireguard/config/Config.java#L179 - // github.com/WireGuard/wireguard-android/blob/4ba87947ae/tunnel/src/main/java/com/wireguard/config/Interface.java#L257 - // allowedips must be individual entries in uapi, but our custom impl can handle csv - // see: wgproxy.go:wgIfConfigOf => wgproxy.go:loadIPNets - allowedips := []string{gw4} - if len(creds.AllowedIPs) > 0 { - parts := strings.Split(creds.AllowedIPs, ",") - allowedips = make([]string, 0, len(parts)) - for _, p := range parts { - if t := strings.TrimSpace(p); len(t) > 0 { - allowedips = append(allowedips, t) - } - } - } - - // port may be empty - ipp4str := changeport(rwg.ServerIPPort4, port) - ipp6str := changeport(rwg.ServerIPPort6, port) - domstr := changeport(rwg.ServerDomainPort, port) - - if settings.Debug { - log.V("rpn: regconf: gen for %s/%s (port? %s); endpoint: %s %s %s; psk? %t; allowed: %v", - addr4, dns4, port, ipp4str, ipp6str, domstr, len(psk) > 0, allowedips) - } - - // not added: listen_port, persistent_keepalive_interval - conf := fmt.Sprintf(`private_key=%s -replace_peers=true -address=%s -dns=%s -mtu=(auto) -public_key=%s`, - toHex(clientpriv), - addr4, - dns4, - toHex(peerpub), - ) - if len(ipp4str) > 0 { - conf += "\nendpoint=" + ipp4str - } - if len(ipp6str) > 0 { - conf += "\nendpoint=" + ipp6str - } - if len(domstr) > 0 { - conf += "\nendpoint=" + domstr - } - if len(psk) > 0 { - conf += "\npreshared_key=" + toHex(psk) - } - for _, ip := range allowedips { - conf += fmt.Sprintf("\nallowed_ip=%s", ip) - } - - return conf, true -} - -func changeport(endpoint, newPort string) string { - if len(endpoint) <= 0 || len(newPort) <= 0 || newPort == "0" { - return endpoint - } - host, _, err := net.SplitHostPort(endpoint) - if err != nil { - return endpoint // malformed, return as is - } - return net.JoinHostPort(host, newPort) -} diff --git a/intra/ipn/rpn/yegor.go b/intra/ipn/rpn/yegor.go deleted file mode 100644 index b0df2110..00000000 --- a/intra/ipn/rpn/yegor.go +++ /dev/null @@ -1,2211 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rpn - -import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "math/rand/v2" - "net" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -// github.com/Windscribe/browser-extension/blob/ed83749ad/modules/ext/src/utils/constants.js#L31 -const ( - svchost = "svc.rethinkdns.com" - svchosttest = "redir.nile.workers.dev" - wsMyIp = "https://checkip.windscribe.com/" - wsMyIp2 = "https://checkip.totallyacdn.com/" - // didTokenHeader is the HTTP response/request header for a device-id token - // issued by svchost / svchosttest. Format: "a hextoken:expiryepochsec". - didTokenHeader = "x-rethink-app-did-token" -) - -const ( - // "/init" registers this client's identity with remote. Calling multiple times - // registers identity multiple times (different Preshared Keys are generated), and so, - // calling just the once is sufficient. Though, if called multiple times, only the latest - // config (address + preshared key) it generates is valid, while the rest won't work. - // - // if the dynamic wg interface was released, connecting to WG using current config (old interface), - // the handshakes will fail. This is a trigger to re-run the "/connect" call to get a new interface, - // and handshake again. There is no harm blindly running "/connect" before every WG connection attempt, - // as it will reuse the interface if its still reserved. Avoid running "/init" multiple times as after - // you hit the limit, you will get an error like this: - // { - // "errorCode": 1313, - // "errorMessage": "You have reached your limit of WireGuard keys. Do you want to delete your oldest key?", - // "errorDescription": "Maximum number of pub keys reached", - // "logStatus": null - // } - // - // To recover from this, either supply the optional "force_init=1" field, - // which will delete the oldest keypair, or delete all keypais using - // the PUT "/Users" method + "delete_credentials=1" field. - // - // If "/connect" is attempted using such a deleted key (still stored locally, but not useless) - // the API returns: - // { - // "errorCode": 1311, - // "errorMessage": "Invalid WireGuard public key was provided", - // "errorDescription": "WG pub key is unknown", - // "logStatus": null - // } - // - // This error is a trigger to run a new "/init" API call to generate a new keypair, - // as you're effectively in a clean slate (how all accounts start out). - // - // Another error 1312 - "Could not select a new WireGuard interface IP" is returned - // by "/connect" if the "/init"d client IP was released by the server and assigned - // to another user? - // - // Once local (client) has keypair, do not make the "/init" call again unless on errors - // as described above. - wswginitpath = "WgConfigs/init" - // To setup WG interface after "/init", run "/connect", which reserves an interface on remote - // A full WG config must now be built form WsServerList and "/init"d keypair. - // This interface reservation is active while connected, and up to 4m after disconnect. - // That is, the keys are released 4-5 mins after the the last handshake - // occurs from the perspective of the server. This usually happens when a user disconnects - // in the app, or the connection is severed for any reason. - // If more time passes, this interface is released into the pool and will no longer work, - // requiring a new "/connect" API call. - // If possible, hook into WireGuard's "verbose logging" - // to detect a handshake failure and not wait for a handshake timeout. - wswgconnectpath = "WgConfigs/connect" - // wswgpermanentpath generates a permanent (static) WireGuard config tied to a client-supplied - // public key. Unlike /init+/connect, this config does not expire automatically. - // POST with form fields: port, wg_pubkey. Bearer token is client-supplied. - wswgpermanentpath = "WgConfigs/permanent" - // wswglistkeyspath lists all active permanent WireGuard configs for the account (max 5). - // GET with Bearer token client-supplied. - wswglistkeyspath = "WgConfigs/list_keys" - // TTL reservation time using param "wg_ttl", not for longer than an hour, if needed. - // github.com/Windscribe/Android-App/blob/3f9c2ab98a70fa/base/src/main/java/com/windscribe/vpn/repository/WgConfigRepository.kt#L143 - wgttl = "3600" // an hour in seconds - - wssessionpath = "Session/" - // wsportpath = "PortMap/" - wslocpath = "serverlist/mob-v2/1/" // + $loc_hash - - // for ovpn (unused): - // wspxpath = "ServerCredentials/" - // wsbestloc = "BestLocation/" - - wsMinServerLinkSpeed = 1 // 1mbps - wsMaxServerHealth = 100 // min is 0 - allPerRegionWgConfs = true // when false, only maxPerRegionWgConfs*2 are chosen - maxPerRegionWgConfs = 4 - maxAnyWgConfs = 8 - wsMaxPermaWgKeys = 5 - - disablePermaCreds = true -) - -// github.com/Windscribe/Android-App/blob/746d505dc69/base/src/main/java/com/windscribe/vpn/constants/NetworkErrorCodes.kt -const ( - ekeylimit = 1313 - ekeyinvalid = 1311 - enoaddr = 1312 -) - -const ( - confKeySep = ";" - - // onlyPremiumServers removes non-premium servers from the server list, - // which may reduce its count substantially. - onlyPremiumServers = false -) - -// github.com/Windscribe/Android-App/blob/746d505dc69/base/src/main/res/raw/port_map.txt#L76 -var wswgports = []string{ /*0th & 1st pos used by wsRandomPort */ "65142", "1194", "53", "123", "443", "80"} - -var ( - errWsBadGatewayArgs = errors.New("ws: cannot make gw; missing args") - errWsNoConfig = errors.New("ws: no config") - errWsNoJsonConfig = errors.New("ws: no json config") - errWsNoSession = errors.New("ws: no session info") - errWsNoClient = errors.New("ws: no client") - errWsNoEntitlement = errors.New("ws: missing entitlement") - errWsNoToken = errors.New("ws: missing token") - errWsNoCid = errors.New("ws: missing cid") - errWsNoDid = errors.New("ws: missing device id") - errWsDidMismatch = errors.New("ws: device id mismatch") - errWsNoResponse = errors.New("ws: no response") - errWsNoLocHash = errors.New("ws: no loc hash") - errWsNoServerList = errors.New("ws: no server list") - errWsBadServerList = errors.New("ws: invalid server list") - errWsRetryUpdate = errors.New("ws: retry update") - errWsNoCcConfig = errors.New("ws: not available in that location") -) - -/* - { - "data": { - "portmap": [ - {...}, {...} - ], - } - "metadata": { ... } -*/ -type WsPortMapResponse struct { - Data struct { - PortMap []WsPortMap `json:"portmap"` - } `json:"data"` - Metadata WsMetadata `json:"metadata"` -} - -/* - { - "errorCode": 502, - "errorMessage": "1 arguments had validation errors", - "errorDescription": "Argument did not validate", - "logStatus": null, - "validationFailuresArray": { - "0": { - "wg_pubkey": { - "lengthMin": { - "validationValue": "44" - } - } - }, - "validationErrorMessageArray": [ - "wg pubkey is too short. Minimum value is 44 characters" - ] - } - } -*/ -type WsErrorResponse struct { - Code int `json:"errorCode"` - Msg string `json:"errorMessage"` - Desc string `json:"errorDescription"` - LogStatus string `json:"logStatus"` - Failures map[string]any `json:"validationFailuresArray"` - // RPN errors - Error string `json:"error"` - Details string `json:"details,omitempty"` -} - -/* - { - "serviceRequestId": "1752328061104032696", - "hostName": "staging", - "duration": "0.00278ms", - "logStatus": null, - "md5": "f90552b29b73b6899f7f00dc0c9fe5f4" - } -*/ -type WsMetadata struct { - ServiceRequestId string `json:"serviceRequestId"` - HostName string `json:"hostName"` - Duration string `json:"duration"` - LogStatus string `json:"logStatus"` - MD5 string `json:"md5"` -} - -/* - { - "portmap": [ - { - "protocol": "wg", - "heading": "WireGuard", - "use": "ip3", - "ports": [ - "443", - "80", - "53", - "123", - "1194", - "65142" - ] - }, - { - "protocol": "ikev2", - "heading": "IKEv2", - "use": "hostname", - "ports": [ - "500" - ], - "legacy_ports": [ - "500" - ] - }, - { - "protocol": "udp", - "heading": "UDP", - "use": "ip2", - "ports": [ - "443", - "80", - ], - "legacy_ports": [ - "443" - ] - }, - { - "protocol": "tcp", - "heading": "TCP", - "use": "ip2", - "ports": [ - "443", - "587", - "21", - "1194" - ], - "legacy_ports": [ - "1194" - ] - }, - { - "protocol": "stunnel", - "heading": "Stealth", - "use": "ip3", - "ports": [ - "443", - "587", - "21", - "22", - "80", - "123", - "8443" - ], - "legacy_ports": [ - "8443" - ] - }, - { - "protocol": "wstunnel", - "heading": "WStunnel", - "use": "hostname", - "ports": [ - "443" - ], - "legacy_ports": [ - "443" - ] - } - ] - } -*/ -type WsPortMap struct { - Protocol string `json:"protocol"` - Heading string `json:"heading"` - Use string `json:"use"` - Ports []string `json:"ports"` - LegacyPorts []string `json:"legacy_ports,omitempty"` -} - -/* - { - "revision": 2672, - "revision_hash": "6ad01549d6643292d8021f19a14b82310ff44c90", - "changed": 1, - "fc": 1 - } -*/ -type WsInfo struct { - Revision int `json:"revision"` - RevisionHash string `json:"revision_hash"` - Changed int `json:"changed"` - FeatureConfig int `json:"fc"` -} - -/* - { - "id": 90, - "name": "Guatemala", - "country_code": "gt", - "status": 1, - "premium_only": 1, - "short_name": "GT", - "p2p": 1, - "tz": "America/Toronto", - "tz_offset": "-4,EST", - "loc_type": "normal", - "dns_hostname": "gt.windscribe.dev", - "groups": [ - { - "id": 110, - "city": "San Marcos", - "nick": "Trafficker", - "pro": 1, - "gps": "14.97,-91.80", - "tz": "America/Guatemala", - "wg_pubkey": "NXPIQ0kD2ww9VkgFqMWvJs7ZLthd5PS259/yJPyDyz0=", - "wg_endpoint": "gua-110-wg.whiskergalaxy.dev", - "ovpn_x509": "gua-110.windscribe.dev", - "ping_ip": "66.66.66.70", - "ping_host": "http://gt-stg-001.whiskergalaxy.dev:6464/latency", - "link_speed": "100", - "nodes": [ - { - "ip": "66.66.66.70", - "ip2": "66.66.66.71", - "ip3": "66.66.66.72", - "hostname": "gt-stg-001.whiskergalaxy.dev", - "weight": 1, - "health": 0 - } - ], - "health": 0 - } - ] - }, - { - "id": 38, - "name": "Canada East", - "country_code": "CA", - "status": 1, - "premium_only": 0, - "short_name": "CA", - "p2p": 0, - "tz": "America/Toronto", - "tz_offset": "-4,EST", - "loc_type": "normal", - "dns_hostname": "ca.windscribe.dev", - "groups": [ - { ... } - ] - }, - { - "id": 52, - "name": "United States", - "country_code": "US", - "status": 1, - "premium_only": 0, - "short_name": "US", - "p2p": 0, - "tz": "America/Chicago", - "tz_offset": "-6,CET", - "loc_type": "normal", - "dns_hostname": "us.windscribe.dev", - "groups": [ - { ... } - ] - }, - { - "id": 84, - "name": "The Best Korea", - "country_code": "KP", - "status": 2, - "premium_only": 1, - "short_name": "KP", - "p2p": 1, - "tz": "Asia/Pyongyang", - "tz_offset": "9,PYT", - "loc_type": "normal", - "groups": [ - { ... } - ] - } -*/ -type WsServerList struct { - ID int `json:"id"` - Name string `json:"name"` - CountryCode string `json:"country_code"` - // Status is for country level location record, - // flips to a value thats not 1 only if all servers - // in the whole country are unavailable. - Status int `json:"status"` - PremiumOnly int `json:"premium_only"` - ShortName string `json:"short_name"` - // p2p is 0 as a signal that common torrent trackers are null routed - // on these machines and torrenting is discouraged. Nothing prevents - // users from still doing so, especially on private trackers. - // This flag has no impact on port forwarding. - P2P int `json:"p2p"` - TZ string `json:"tz"` - TZOffset string `json:"tz_offset"` - LocType string `json:"loc_type"` - DNSHostname string `json:"dns_hostname,omitempty"` - Groups []WsServerGroup `json:"groups"` -} - -/* - { - "id": 87, - "city": "Boston", - "nick": "The Wahlberg", - "pro": 1, - "gps": "42.36,-71.06", - "tz": "EST", - "wg_pubkey": "M/kVfITvFaz8i8msJi7C3jsgk45vnbnYLJIOl/zBsVk=", - "wg_endpoint": "bos-87-wg.whiskergalaxy.dev", - "ovpn_x509": "bos-87.windscribe.dev", - "ping_ip": "181.215.52.122", - "ping_host": "http://us-014.whiskergalaxy.dev:6464/latency", - "link_speed": "1000", - "nodes": [ - { ... } - ], - "health": 0 - } -*/ -type WsServerGroup struct { - ID int `json:"id"` - City string `json:"city"` - Nick string `json:"nick"` - Pro int `json:"pro"` - GPS string `json:"gps"` - TZ string `json:"tz"` - // public key for all servers in this datacenter / building - WgPubKey string `json:"wg_pubkey"` - // WireGuard hostname to connect to, which will connect to a - // random WireGuard server in this datacenter: Its A/AAAA records - // contains all "ip3" (WireGuard-only) IPs for individual Nodes - // (all hosts in this datacenter). - WgEndpoint string `json:"wg_endpoint"` - OvpnX509 string `json:"ovpn_x509"` - PingIP string `json:"ping_ip"` - // GET - // {"rtt": "5775"} = 5.7ms - PingHost string `json:"ping_host"` - // 100, 1000, 10000 (in mbps) etc; - LinkSpeed string `json:"link_speed"` - // Nodes are online servers that can be connected to in a datacneter. - // If empty, the datacenter is offline. - Nodes []WsServerNode `json:"nodes"` - // Health is a measure of load, between 0 and 100. Lower is better. - Health int `json:"health"` -} - -/* - { - "ip": "181.215.52.122", - "ip2": "181.215.52.123", - "ip3": "181.215.52.124", - "hostname": "us-014.whiskergalaxy.dev", - "weight": 1, - "health": 0 - }, - { - "ip": "181.215.52.2", - "ip2": "181.215.52.3", - "ip3": "181.215.52.4", - "hostname": "us-004.whiskergalaxy.dev", - "weight": 1, - "health": 0 - } -*/ -type WsServerNode struct { - IP string `json:"ip"` - IP2 string `json:"ip2,omitempty"` - // Init direct connections to "ip3" + port to connect to - // a specific node (host) skipping DNS. - IP3 string `json:"ip3,omitempty"` - Hostname string `json:"hostname"` - Weight int `json:"weight"` - Health int `json:"health"` -} - -/* - { - "data": [ - { ... }, { ... }, - ], - "info": { ... }, - "metadata": { ... } - } -*/ -type WsServerListResponse struct { - Data []WsServerList `json:"data"` - Info WsInfo `json:"info"` - Metadata WsMetadata `json:"metadata"` -} - -/* - { - "user_id": "l7xfy9c1", - "session_auth_hash": "id:typ:epochsec:sig1:sig2", - "username": "celz_l7xfy9c1", - "traffic_used": 0, - "traffic_max": -1, - "status": 1, - "email": null, - "email_status": 0, - "billing_plan_id": 120, - "is_premium": 1, - "rebill": 0, - "premium_expiry_date": "2025-07-15", - "reg_date": 1749999508, - "last_reset": "2025-07-12", // or null - "loc_rev": 2672, - "loc_hash": "6ad01549d66..." - } -*/ -type WsSession struct { - UserID string `json:"user_id"` - // Encrypted account session token authenticates this user with the API to get - // and set any state, incl WG configuration which are unique and bound - // to each session token. Session token is the "bearer token". - // Unencrypted session/bearer token is of shape, id:type:timestamp:sig1:sig2. - // However it could change to a new format in the future. - SessionToken string `json:"session_auth_hash"` - Username string `json:"username"` - // TrafficUsed shows byte count of data used since LastReset date? - TrafficUsed int64 `json:"traffic_used"` - TrafficMax int64 `json:"traffic_max"` - // Status is 1 under normal circumstances. Any other state means - // this user is banned (or a free account has expired). - // Bans are extremely rare, and there is no appeal process. - // Bans are issued on repeated abuse, never right away. - // This user is permanently disabled if banned. - Status int `json:"status"` - Email string `json:"email"` - EmailStatus int `json:"email_status"` - BillingPlanID int `json:"billing_plan_id"` - IsPremium int `json:"is_premium"` - Rebill int `json:"rebill"` - // This will downgrade on this date, unless its renewed. - // ex: "2025-07-22" go.dev/play/p/IoeH1Ee6cZ3 - ExpiryDate string `json:"premium_expiry_date"` - RegDate int64 `json:"reg_date"` - LastReset string `json:"last_reset"` // can be null - LocRev int `json:"loc_rev"` - // Latest revision hash of the server list. - LocHash string `json:"loc_hash"` -} - -/* - { - "data": WsSession , - "metadata": { ... } - } -*/ -type WsSessionResponse struct { - Data WsSession `json:"data"` - Metadata WsMetadata `json:"metadata"` -} - -/* - { - "username": "string", - "password": "string" - } -*/ -type WsProxyCreds struct { - Username string `json:"username"` - Password string `json:"password"` -} - -/* - { - "data": WsProxyCreds, - metadata: { ... } - } -*/ -type WsProxyCredsResponse struct { - Data WsProxyCreds `json:"data"` - Metadata WsMetadata `json:"metadata"` -} - -/* - { - "PrivateKey": "stdbase64", - "PublicKey": "tsoZzRelDNFe/xF6eQz+xxzjmgS0xKfxEmlqsZKPNgs=", - "PresharedKey": "stdbase64", - "AllowedIPs": "0.0.0.0/0", - "Address": "100.64.236.203/32", // omitempty: absent in /init responses; present in /permanent - "DNS": "10.255.255.1" // omitempty: absent in /init responses; present in /permanent - } -*/ -type WsWgCreds struct { - PrivateKey string `json:"PrivateKey,omitempty"` // base64; locally generated for /init, server-generated for /permanent - PublicKey string `json:"PublicKey"` // base64; locally generated for /init, server-generated for /permanent - PresharedKey string `json:"PresharedKey"` // base64; only the latest key is valid (generated by remote) - AllowedIPs string `json:"AllowedIPs"` // e.g. "0.0.0.0/0" or "0.0.0.0/0, ::/0" - // Address and DNS are populated only by the /permanent endpoint; empty for /init responses. - // So dynamic creds, after /init, will have to /connect before populating these - Address string `json:"Address,omitempty"` // CIDR notation, e.g. "100.64.236.203/32" - DNS string `json:"DNS,omitempty"` // IP address, e.g. "10.255.255.1" -} - -/* - { - "config": WsWgCreds, - "debug": { - "init": "generated: tsoZzRelDNFe/xF6eQz+xxzjmgS0xKfxEmlqsZKPNgs=" - }, - "success": 1 - } - -// Also used for /WgConfigs/permanent responses, where config additionally -// carries Address and DNS (populated in WsWgCreds via the omitempty fields). -*/ -type WsWgCredsData struct { - Config WsWgCreds `json:"config"` - Debug map[string]string `json:"debug,omitempty"` - Success int `json:"success"` -} - -/* - { - "data": WsWgCredsData, - "metadata": { ... } - } -*/ -type WsWgCredsResponse struct { - Data WsWgCredsData `json:"data"` - Metadata WsMetadata `json:"metadata"` -} - -/* - { - "Address": "100.65.61.145/32", - "DNS": "10.255.255.2" - } -*/ -type WsWgInterface struct { - Address string `json:"Address"` // cidr notation - DNS string `json:"DNS"` // ip address -} - -/* - { - "config": WsWgInterface, - "debug": { - "pub_key": "supplied: tsoZzRelDNFe/xF6eQz+xxzjmgS0xKfxEmlqsZKPNgs=", - "interface": "generated + attached: 100.65.61.145" - }, - "success": 1 - } -*/ -type WsWgConnectData struct { - Config WsWgInterface `json:"config"` - Debug map[string]string `json:"debug,omitempty"` // optional - Success int `json:"success"` -} - -/* - { - "data": WsWgConnectData, - "metadata": { ... } - } -*/ -type WsWgConnectResponse struct { - Data WsWgConnectData `json:"data"` - Metadata WsMetadata `json:"metadata"` -} - -// WsWgPermanentConfig is an alias for WsWgCreds; the /permanent endpoint returns -// the same envelope as /init (WsWgCredsData / WsWgCredsResponse) but additionally -// populates the Address and DNS fields added to WsWgCreds. -type WsWgPermanentConfig = WsWgCreds - -/* - { - "data": { - "pub_keys": ["WzPsW3p+t5rkbZ2zg/QciGN3vMQVKciP/csQzIZ0ohE=", ...], - "success": 1 - }, - "metadata": { ... } - } -*/ -type WsWgListKeysData struct { - PubKeys []string `json:"pub_keys"` - Success int `json:"success"` -} - -type WsWgListKeysResponse struct { - Data WsWgListKeysData `json:"data"` - Metadata WsMetadata `json:"metadata"` -} - -type WsClient struct { - RpnMultiCountry - - http *http.Client - ops *core.Volatile[x.RpnOps] // current ops; retained across subsequent Conf() calls - - configExt *core.Volatile[*WsWgConfig] - configExtUpdateTime *core.Volatile[time.Time] -} - -type WsWgConfig struct { - Entitlement *WsEntitlement `json:"entitlement"` // entitlement info - Session *WsSession `json:"session"` - Configs []*RegionalWgConf `json:"configs"` - Servers []WsServerList `json:"servers"` // all servers in the server list - Creds *WsWgCreds `json:"creds"` // base64 encoded private key - PermaCreds *WsWgPermanentConfig `json:"permacreds,omitempty"` // permanent WG config; nil if not yet fetched -} - -/* -{ -"kind": "ws#v1", -"cid": "hex", // Identifier -"sid": "id:epochsec:parentcidsig", // profile identifier, if any -"sessiontoken": enc("id:typ:epochsec:sig1:sig2"), -"expiry": "2025-07-15T00:00:00Z", // Expiry date of the entitlement -"status": "valid" // "valid" | "invalid" | "banned" | "expired" | "unknown" -"allowRestore": false, // true if this entitlement can be restored -"test": false // true if this is a test entitlement -} -*/ -type WsEntitlement struct { - Kind string `json:"kind"` // e.g. "ws#v1" - Cid string `json:"cid"` // Client ID - Did string `json:"did,omitempty"` // Device ID, if any - Pid string `json:"pid,omitempty"` // Share ID - SessionToken string `json:"sessiontoken"` // Encrypted session token - // Expiry date of the entitlement; go.dev/play/p/d2gshytEF61 - Exp time.Time `json:"expiry"` - AccStatus string `json:"status"` // "valid" | "invalid" | "banned" | "expired" | "unknown" - AllowCrossDevice bool `json:"allowRestore"` // true if this entitlement can be restored - TestDomain bool `json:"test"` // true if this is a test entitlement - // DidToken is the device-id token (x-rethink-app-did-token) issued by svchost, - // format "hextoken:expiryepochsec". - DidToken string `json:"didtoken,omitempty"` -} - -var _ x.RpnAcc = (*WsClient)(nil) -var _ x.RpnEntitlement = (*WsEntitlement)(nil) - -func (e *WsEntitlement) ProviderID() string { - return x.RpnWin -} - -func (e *WsEntitlement) DID() string { - return e.Did -} - -func (e *WsEntitlement) CID() string { - return e.Cid -} - -func (e *WsEntitlement) Token() string { - return e.SessionToken -} - -func (e *WsEntitlement) Expiry() string { - return e.Exp.Format(time.RFC3339) -} - -func (e *WsEntitlement) Status() string { - return e.AccStatus -} - -func (e *WsEntitlement) AllowRestore() bool { - return e.AllowCrossDevice -} - -func (e *WsEntitlement) Test() bool { - return e.TestDomain -} - -func (e *WsEntitlement) Json() ([]byte, error) { - if e == nil { - return nil, errWsNoEntitlement - } - var w core.ByteWriter - enc := json.NewEncoder(&w) - if err := enc.Encode(e); err != nil { - return nil, fmt.Errorf("ws: entitlement encode err: %w", err) - } - // Bytes not recycled as these are crossing into cgo - return w.Bytes(), nil -} - -func (a *WsWgConfig) Json() ([]byte, error) { - if a == nil { - return nil, errWsNoConfig - } - - var w core.ByteWriter - if err := a.writeJson(&w); err != nil { - return nil, err - } - // Bytes not recycled as these are crossing into cgo - return w.Bytes(), nil -} - -func (a *WsWgConfig) writeJson(w io.Writer) error { - if a == nil { - return errWsNoConfig - } - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - return enc.Encode(a) -} - -func (a *WsClient) config() *WsWgConfig { - if a == nil { - return nil - } - return a.configExt.Load() -} - -// Who implements x.RpnAcc. -func (a *WsClient) Who() string { - if a == nil { - return "" - } - c := a.config() - if c == nil || c.Session == nil { - return "" - } - status := strconv.Itoa(c.Session.Status) - return status + ":" + c.Session.UserID + "+" + trunc8(byte2hex(sha(c.Session.SessionToken))) + "@" + a.kid() -} - -// ProviderID implements RpnAcc. -func (*WsClient) ProviderID() string { return x.RpnWin } - -// State implements x.RpnAcc. -func (a *WsClient) State() ([]byte, error) { - if a == nil { - return nil, errWsNoClient - } - c := a.config() - if c == nil { - return nil, errWsNoConfig - } - return c.Json() -} - -// Created implements x.RpnAcc. -func (a *WsClient) Created() int64 { - if a == nil { - return -1 - } - c := a.config() - if c == nil { - return 0 - } - createdAt := time.Unix(int64(c.Session.RegDate), 0) - return createdAt.UnixMilli() -} - -func (a *WsClient) Updated() int64 { - if a == nil { - return -1 - } - c := a.config() // must have config - if c == nil { - return 0 - } - updatedAt := a.configExtUpdateTime.Load() - return updatedAt.UnixMilli() -} - -// Ops implements x.RpnAcc. -func (a *WsClient) Ops() *x.RpnOps { - ops := a.ops.Load() - return &ops -} - -// Expires implements x.RpnAcc. -func (a *WsClient) Expires() int64 { - if a == nil { - return -1 - } - c := a.config() - if c == nil { - return 0 - } - - refreshAt, err := time.Parse(time.DateOnly, c.Session.ExpiryDate) - if err != nil { - log.W("ws: expires: cannot parse %s; err: %v", c.Session.ExpiryDate, err) - return -2 - } - - return refreshAt.UnixMilli() -} - -func (a *WsClient) Locations() (x.RpnServers, error) { - if a == nil { - return nil, errWsNoClient - } - c := a.config() - if c == nil { - return nil, errWsNoConfig - } - if len(c.Configs) <= 0 { - return nil, errWsNoCcConfig - } - visited := make(map[string]bool, len(c.Configs)) - s := make([]x.RpnServer, 0, len(c.Configs)/maxPerRegionWgConfs) - for i, rc := range c.Configs { - if rc == nil { - log.W("ws: locations: config %d is nil", i) - continue - } - if len(rc.ServerPubKey) <= 0 { - log.D("ws: locations: config#%d has no wg conf", i) - continue - } - if len(rc.CC) <= 0 { - log.W("ws: locations: config#%d has no cc", i) - continue - } - if !visited[rc.Name] { - s = append(s, x.RpnServer{ - CC: rc.CC, - City: rc.City, - Name: rc.Name, - Load: rc.Load, - Link: rc.Link, - Count: rc.Count, - Premium: rc.Premium, - // cc is always suffixed; see proxy.go:proxifier.postAddRpnProxy - Key: strings.Join([]string{rc.City, rc.CC}, confKeySep), - Addrs: strings.Join([]string{rc.ServerDomainPort, rc.addrCsv()}, ","), - }) - } - visited[rc.Name] = true - } - return &RpnMultiCountryServers{s}, nil -} - -// Update implements x.RpnAcc. -func (a *WsClient) Update(ops *x.RpnOps) (newstate []byte, err error) { - if a == nil { - return nil, errWsNoClient - } - c := a.config() - if c == nil { - return nil, errWsNoConfig - } - if ops == nil { - ops = a.Ops() - } - start := time.Now() - b, refreshed, err := makeWsWgFrom(a.http, c, *ops, true /*updating*/) - if err != nil || !refreshed { - log.E("ws: update: refreshed? %t; err: %v", refreshed, err) - return nil, core.OneErr(err, errWsRetryUpdate) - } - - // if configs have changed, the current proxies using those, if any, - // will need to be updated. - if _, err := a.shallowCopyConfig(b); err != nil { - return nil, log.EE("ws: update: shallow copy err: %v", err) - } - log.I("ws: update: refreshed? %t; took %v", refreshed, core.FmtTimeAsPeriod(start)) - return a.State() -} - -func (a *WsClient) shallowCopyConfig(b *WsClient) (copied bool, err error) { - if a == nil || b == nil { - return false, nil // no-op - } - bc := b.config() - if bc == nil { - log.E("ws: shallowcopy: storing nil config...") - return false, errWsNoConfig - } - a.configExt.Store(bc) - a.configExtUpdateTime.Store(time.Now()) - a.ops.Store(*b.Ops()) - return true, nil -} - -// Conf implements RpnAcc. -func (a *WsClient) Conf(cc string) (string, error) { - cfg := a.config() - if cfg == nil { - return "", errWsNoConfig - } - usePerma := !disablePermaCreds && a.Ops().Perma() - if usePerma && cfg.PermaCreds == nil { - usePerma = false - log.E("ws: conf: permacreds requested but nil; using dynamic creds") - } - portstr := "" - if port := a.Ops().Port(); port > 0 { - portstr = fmt.Sprintf("%d", port) // port may be 0 - } - city := "" - if cccsv := strings.Split(cc, confKeySep); len(cccsv) >= 2 { - city = cccsv[0] - cc = cccsv[1] - } - visited := make(map[string]struct{}, 0) - // in sync with anyCountryCode / noCountryForOldMen vars in proxy.go - chooseAny := cc == "**" || len(cc) <= 0 - hasCity := len(city) > 0 - tot := 0 - c := 0 - out := make([]string, 0, maxPerRegionWgConfs) - ids := make([]string, 0, maxPerRegionWgConfs) - for _, rc := range cfg.Configs { - // TODO: strings.HasSuffix(rc.Cc, cc) replaced with ==? - if (chooseAny || strings.HasSuffix(rc.CC, cc)) && (!hasCity || rc.City == city) { - if chooseAny { - if _, ok := visited[rc.CC]; ok { - continue - } - visited[rc.CC] = struct{}{} - if c > 2 { - // choose only low load and high link speed servers - gbps10 := rc.Link >= 10000 - healthy50 := rc.Load <= 50 - gbps1 := rc.Link >= 1000 - healthy20 := rc.Load <= 20 - if !gbps10 && !gbps1 { - continue - } - if (gbps10 && !healthy50) || (gbps1 && !healthy20) { - continue - } - } - if c >= maxAnyWgConfs { - // generate maxAnyWgConfs across all CCs. - break - } - } else { - if !hasCity && c >= maxPerRegionWgConfs { - // not chooseAny; city not specified; - // generate maxPerRegionWgConfs for this cc. - break - } - } - - var confstr string - var confok bool - if usePerma && cfg.PermaCreds != nil { - confstr, confok = rc.MakeUapiConfig(cfg.PermaCreds, portstr) - } else { - confstr, confok = rc.MakeUapiConfig(cfg.Creds, portstr) - } - if confok { - out = append(out, confstr) - ids = append(ids, strings.Join([]string{rc.CC, rc.City, rc.Name}, "/")) - c++ - } - } - tot++ - } - if len(out) > 0 { - r := rand.IntN(len(out)) - log.I("ws: conf: cc %s(%s): %d/%d => chosen (any? %t): %d[%s] (port: %s)", cc, city, c, len(out), chooseAny, r, ids[r], portstr) - return out[r], nil - } - log.E("ws: conf: cc %s(%s) not found (tot: %d)", cc, city, tot) - return "", errWsNoCcConfig -} - -// unused on the control plane, so use a fixed but valid hostname -func fixedValidWsEndpoint(test bool) string { - if test { - return "ca.windscribe.dev" - } - return "ca.windscribe.com" -} - -func baseurl(test bool, cid, did string) *url.URL { - u := url.URL{ - Scheme: "https", - Host: svchosttest, - } - q := u.Query() - q.Set("cid", cid) - q.Set("did", did) - if test { - q.Set("rpn", "wstest") - q.Set("test", "") // value for the test param does not matter - } else { - q.Set("rpn", "ws") - } - u.RawQuery = q.Encode() - - return &u -} - -func assetsurl(test bool, cid, did string) *url.URL { - u := url.URL{ - Scheme: "https", - Host: svchosttest, - } - q := u.Query() - q.Set("cid", cid) - q.Set("did", did) - if test { - q.Set("rpn", "wsassetstest") - q.Set("test", "") // value for the test param does not matter - } else { - q.Set("rpn", "wsassets") - } - u.RawQuery = q.Encode() - - return &u -} - -func authHeader(req *http.Request, t string) { - if req == nil { - return - } - req.Header.Set("Authorization", "Bearer "+t) -} - -// didHeader sets the didTokenHeader request header if tok is non-empty. -func didHeader(req *http.Request, tok string) { - if req != nil && len(tok) > 0 { - req.Header.Set(didTokenHeader, tok) - } -} - -func logDidToken(tok string) { - if len(tok) <= 0 { - log.W("ws: didtoken: empty token") - return - } - // token format is "a hextoken:expiryepochsec"; parse the epoch to log expiry. - parts := strings.SplitN(tok, ":", 2) - if len(parts) < 2 { - log.W("ws: didtoken: unknown format") - return - } - expSec, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) - if err != nil { - log.E("ws: didtoken: cannot parse expiry epoch: %v", err) - } - expTime := time.Unix(expSec, 0) - log.I("ws: didtoken: expiry %s; expired? %t", fmtTime(expTime), expTime.Before(time.Now())) -} - -// updateDidTokenIfNeeded reads didTokenHeader from the response, logs its expiry, -// and โ€“ when it differs from ent.DidToken โ€“ overwrites it in ent. -func updateDidTokenIfNeeded(ent *WsEntitlement, res *http.Response) { - if ent == nil || res == nil { - return - } - incoming := res.Header.Get(didTokenHeader) - logDidToken(incoming) - if len(incoming) > 0 && ent.DidToken != incoming { - ent.DidToken = incoming - } -} - -func wsErr(res *http.Response, op string) error { - _, err := wsErr2(res, op) - return err -} - -func wsErr2(res *http.Response, op string) (*WsErrorResponse, error) { - if res == nil { - return nil, log.EE("ws: %s: %v", op, errWsNoResponse) - } - code := res.StatusCode - body, err := io.ReadAll(res.Body) - if err != nil { - return nil, log.EE("ws: %s: (%d) read body err: %v", op, code, err) - } - - var wsErr WsErrorResponse - err = json.Unmarshal(body, &wsErr) - if err != nil { - return nil, log.EE("ws: %s: (%d) unmarshal err: %v; body: %s", op, code, err, truncate2k(body)) - } - - if len(wsErr.Error) > 0 { - wsErr.Msg += "/" + wsErr.Error - } - if len(wsErr.Desc) <= 0 { - wsErr.Desc += "/" + wsErr.Details - } - if len(wsErr.Msg) <= 0 { - wsErr.Msg = string(truncate2k(body)) - } - - return &wsErr, log.EE("ws: %s: (%d) error %d: %s; why: %s", op, code, wsErr.Code, wsErr.Msg, wsErr.Desc) -} - -func wsRes[T any](res *http.Response, out *T, op string) (*T, error) { - if res == nil { - return nil, log.EE("ws: %s: %v", op, errWsNoResponse) - } - if out == nil { - return nil, log.EE("ws: %s: out is nil", op) - } - body, err := io.ReadAll(res.Body) - if err != nil { - return nil, log.EE("ws: %s: read res err: %v", op, err) - } - - err = json.Unmarshal(body, out) - if err != nil { - return nil, log.EE("ws: %s: unmarshal err: %v; res: %s", op, err, truncate2k(body)) - } - - if settings.Debug { - log.V("ws: wgconfs: %s: res json: %+v", op, out) - } - - return out, nil -} - -func getSession(h *http.Client, ent *WsEntitlement) (*WsSession, error) { - if ent == nil { - return nil, errWsNoSession - } - tok := ent.SessionToken - if len(tok) <= 0 { - return nil, errWsNoToken - } - cid := ent.Cid - if len(cid) <= 0 { - return nil, errWsNoCid - } - did := ent.Did - tokst := tokenState(tok) - /* - curl -x GET '.../Session' - -H 'Authorization: Bearer id:typ:epochsec:sig1:sig2' - */ - u := baseurl(ent.TestDomain, cid, did).JoinPath(wssessionpath) - req, err := http.NewRequest("GET", u.String(), nil) - if err != nil { - return nil, log.EE("ws: getsess: make req err: %v", err) - } - authHeader(req, tok) - didHeader(req, ent.DidToken) - - if settings.Debug { - log.V("ws: getsess: req: %s tok %s", u.String(), tokst) - } - - res, err := h.Do(req) - if err != nil || res == nil { - return nil, log.EE("ws: getsess: res err (nil? %t / tok? %s): %v", res == nil, tokst, err) - } - defer core.Close(res.Body) - updateDidTokenIfNeeded(ent, res) - if res.StatusCode != http.StatusOK { - return nil, wsErr(res, "getsess/"+tokst) - } - - var wsSess WsSessionResponse - _, err = wsRes(res, &wsSess, "getsess/"+tokst) - if err != nil { - return nil, err - } - - return &wsSess.Data, nil -} - -func skipWsServer(server WsServerList) (bool, string) { - if onlyPremiumServers && server.PremiumOnly != 1 { // skip non-premium servers - return true, "not premium" - } else if server.Status != 1 { // skip servers that are not okay - return true, "status not okay" - } else if len(server.Groups) <= 0 { - return true, "no groups" // skip servers without groups - } // else if: skip server.P2P == 0? - return false, "" // this server is okay to use -} - -func wsRandomPort() string { - // return a random port from the list of WireGuard ports - // return wswgports[rand.Int32N(int32(len(wswgports)))] - if rand.Uint()%2 == 0 { - return wswgports[0] - } - return wswgports[1] -} - -func wsRandomIP3(nodes []WsServerNode) string { - if len(nodes) <= 0 { - return "" - } - return nodes[rand.Int32N(int32(len(nodes)))].IP3 -} - -func hasIP3(nodes []WsServerNode) bool { - if len(nodes) <= 0 { - return false - } - for _, node := range nodes { - if len(node.IP3) > 0 { - return true - } - } - return false -} - -func convertToRegionalWgConfs(id *WsWgCreds, list []WsServerList, test bool, port string) ([]*RegionalWgConf, error) { - if id == nil || len(id.DNS) <= 0 || len(id.Address) <= 0 || len(list) <= 0 { - return nil, fmt.Errorf("regional configs err: DNS/Addr/creds? %t; servers? %d", - id != nil, len(list)) - } - - tot := make(map[string]int) - out := make([]*RegionalWgConf, 0, len(list)) - for _, server := range list { - if !test { - if skip, why := skipWsServer(server); skip { - log.VV("ws: conf: convert skip; %s: %s", server.CountryCode, why) - continue - } - } - - cc := server.CountryCode - portStr := port - if len(portStr) <= 0 { - portStr = wsRandomPort() - } - sorted := core.Sort(server.Groups, func(a, b WsServerGroup) int { - ia, _ := strconv.ParseInt(a.LinkSpeed, 10, 64) - ib, _ := strconv.ParseInt(b.LinkSpeed, 10, 64) - // max(..., wsMinServerLinkSpeed) preserves actual link speed (100/1000/10000 mbps) - // rather than clamping everything to 1 with min. - // Score: higher = healthier (lower load) AND faster link; best servers first. - la := max(int(ia), wsMinServerLinkSpeed) * (wsMaxServerHealth - a.Health) - lb := max(int(ib), wsMinServerLinkSpeed) * (wsMaxServerHealth - b.Health) - if la > lb { // descending: highest composite score (healthiest + fastest) first - return -1 - } else if la < lb { - return 1 - } - return 0 - }) - - for _, group := range sorted { - servername := group.City + " (" + group.Nick + ")" - if len(group.WgPubKey) <= 0 || len(group.WgEndpoint) <= 0 { - continue // skip servers without wg - } - noip3 := !hasIP3(group.Nodes) - if len(group.Nodes) <= 0 || noip3 { - log.W("ws: wgconfs: no nodes in %s (%s); ip3? %t", group.City, group.Nick, noip3) - continue // skip servers without nodes - } - if !allPerRegionWgConfs && tot[cc] >= maxPerRegionWgConfs*2 { - log.D("ws: wgconfs: skip! %s (%s) has %d configs already", - cc, servername, tot[cc]) - break // we have enough configs for this region - } - tot[cc] = tot[cc] + 1 - dnsaddr := id.DNS - if len(dnsaddr) <= 0 { - dnsaddr = cfdns4 - } - // Use any IPv4 permutation of AllowedIPs. The API only sends a hint. - // IPv6s are firewalled. - allowed := []string{gw4} - linkspeed, lerr := strconv.Atoi(group.LinkSpeed) - out = append(out, &RegionalWgConf{ - CC: strings.ToUpper(server.CountryCode), - City: strings.ToUpper(group.City), - Name: servername, - Load: int32(group.Health), - Link: int32(linkspeed), - Count: int32(len(group.Nodes)), - Premium: server.PremiumOnly == 1, - ClientAddr4: id.Address, - ClientPrivKey: id.PrivateKey, - ClientPubKey: id.PublicKey, - ClientDNS4: dnsaddr, - PskKey: id.PresharedKey, - ServerPubKey: group.WgPubKey, - ServerDomainPort: net.JoinHostPort(group.WgEndpoint, portStr), - ServerIPPort4: net.JoinHostPort(wsRandomIP3(group.Nodes), portStr), - AllowedIPs: allowed, - }) - if settings.Debug { - log.VV("ws: wgconfs: gen for %s (%s) [load: %d; link: %s; count: %d]; total for %s: %d; errs? %v", - group.City, group.Nick, group.Health, group.LinkSpeed, len(group.Nodes), cc, tot[cc], lerr) - } - } - } - - if len(out) <= 0 { - return nil, errWsBadServerList - } - - return out, nil -} - -func tokenState(t string) (s string) { - l := strconv.Itoa(len(t)) - if len(t) <= 0 { - s = "notok-" - } else if len(strings.Split(t, ":")) > 4 { - s = "plaintok-" + l - } else { - s = "enctok-" + l - } - return -} - -func (a *WsWgConfig) tokenState() string { - if a == nil { - return "no-cfg" - } - s1, s2 := "no-ent", "no-sess" - if ent := a.Entitlement; ent != nil { - s1 = "ent-" + tokenState(ent.SessionToken) - } - if sess := a.Session; sess != nil { - s2 = "sess-" + tokenState(sess.SessionToken) - } - return s1 + " | " + s2 -} - -func getServerList(h *http.Client, sess *WsSession, ent *WsEntitlement) (*WsServerListResponse, error) { - if sess == nil || ent == nil { - return nil, errWsNoSession - } - lochash := sess.LocHash - if len(lochash) <= 0 { - return nil, errWsNoLocHash - } - bearer := sess.SessionToken - if len(bearer) <= 0 { - return nil, errWsNoToken - } - cid := ent.Cid - if len(cid) <= 0 { - return nil, errWsNoCid - } - did := ent.Did - test := ent.TestDomain - - // curl -x GET '.../serverlist/mob-v2/1/' - u := assetsurl(test, cid, did).JoinPath(wslocpath, lochash) - locreq, err := http.NewRequest("GET", u.String(), nil) - if err != nil { - return nil, log.EE("ws: wgconfs: req err: %v", err) - } - didHeader(locreq, ent.DidToken) - - if settings.Debug { - log.V("ws: wgconfs: req: %s tok %s", u.String(), tokenState(bearer)) - } - - locres, err := h.Do(locreq) - if err != nil || locres == nil { - return nil, log.EE("ws: wgconfs: res err (nil? %t): %v", locres == nil, err) - } - - defer core.Close(locres.Body) - updateDidTokenIfNeeded(ent, locres) - if locres.StatusCode != http.StatusOK { - return nil, wsErr(locres, "wgconfs") - } - - var wsServerList WsServerListResponse - - return wsRes(locres, &wsServerList, "wgconfs") -} - -// initAndConnectCreds registers creds via /WgConfigs/init and then either: -// - dynamic creds (perma=false): reserves a WireGuard interface via /WgConfigs/connect. -// - perma creds (perma=true): registers the init'd pubkey via /WgConfigs/permanent. -// -// For perma=true, if existingCreds is non-nil and its pubkey is still present in -// /WgConfigs/list_keys, existingCreds is returned as-is (no /init or /permanent needed). -// If the key is no longer listed, a fresh /init + /permanent cycle is performed. -// The /init error handling is identical for both dynamic and perma paths. -func initAndConnectCreds(h *http.Client, existingCreds *WsWgCreds, perma bool, sess *WsSession, ent *WsEntitlement, forceInit bool) (*WsWgCreds, error) { - if sess == nil || ent == nil { - return nil, errWsNoSession - } - bearer := sess.SessionToken - if len(bearer) <= 0 { - return nil, errWsNoToken - } - cid := ent.Cid - if len(cid) <= 0 { - return nil, errWsNoCid - } - test := ent.TestDomain - tokst := "sess-" + tokenState(bearer) - - if perma && disablePermaCreds { - log.W("ws: wgconfs: perma creds disabled; skipping...") - return nil, nil // perma creds disabled; no API calls, no error - } - - force := "0" // 0 when forced registration (which deletes older keys) is not needed - - // For perma creds, check whether the existing pubkey is still registered on the server. - // If found, reuse it directly without any /init or /permanent API call. - // If not found (key was dropped from the server list), discard the stale creds so the - // key-gen + /init + /permanent path below produces a fresh registration. - if perma && existingCreds != nil && len(existingCreds.PublicKey) > 0 { - var kerr error - for range 2 { - var keys *WsWgListKeysResponse - keys, kerr = listKeys(h, ent, bearer) - if kerr != nil || keys == nil { - log.E("ws: wgconfs: perma: list keys err (nil? %t / tok? %s): %v", keys == nil, tokst, kerr) - wsBriefPauseBeforeRetry() - continue - } - for _, k := range keys.Data.PubKeys { - if k == existingCreds.PublicKey { - log.I("ws: wgconfs: perma: existing key %s active (%s); reusing", trunc8(existingCreds.PublicKey), tokst) - return existingCreds, nil - } - } - if len(keys.Data.PubKeys) >= wsMaxPermaWgKeys { - force = "1" // does not yet work - } - break - } - - // TODO: creds not generated by the server are not returned by listKeys anyway - // if kerr != nil { - // return nil, log.EE("ws: wgconfs: perma: failed key verification; err: %v", kerr) - // } - // pubkey no longer in list; discard existing creds so /init generates a fresh keypair - forceInit = false - } // fallthrough to WgConfigs/init the credential - - runkey := 0 - runinit := 0 - runconnect := 0 - keyed := 0 -keyagain: - useExistingCreds := existingCreds != nil && keyed == 0 && !forceInit - runkey += 1 - - var priv x.WgKey - if !useExistingCreds { - var err error - priv, err = x.NewWgPrivateKey() - if err != nil { - return nil, log.EE("ws: wgconfs: gen key #%d (perma? %t) err: %v", runkey, perma, err) - } - } else { - var err error - // use the existing key, which is already registered - priv, err = x.NewWgPrivateKeyOf(existingCreds.PrivateKey) - if err != nil { - return nil, log.EE("ws: wgconfs: existing key #%d (perma? %t) err: %v", runkey, perma, err) - } - } - pub := priv.Mult() - pubkeybase64 := pub.Base64() - - log.I("ws: wgconfs: gen creds: pubkey: %s, existing key #%d? %t; force? %t; perma? %t", - trunc8(pubkeybase64), runkey, useExistingCreds, forceInit, perma) - -initagain: - keyNeedsInit := !useExistingCreds || force == "1" - runinit += 1 - - details := fmt.Sprintf("pub: %s, keyed#%d? %t; usingExisting#%d? %t; forceinit? %t; perma? %t", - trunc8(pubkeybase64), runkey, keyNeedsInit, runinit, useExistingCreds, force == "1", perma) - - var creds *WsWgCreds - if keyNeedsInit { - // POST 'https://api-staging.windscribe.com/WgConfigs/init' \ - // --header 'Content-Type: application/x-www-form-urlencoded' \ - // --header 'Authorization: Bearer id:typ:epochsec:sig1:sig2' \ - // --data-urlencode 'force_init=1' - // --data-urlencode 'wg_pubkey=base64' - initdata := url.Values{} - initdata.Set("wg_pubkey", pubkeybase64) - initdata.Set("force_init", force) - u := baseurl(test, cid, ent.Did).JoinPath(wswginitpath) - initreq, err := http.NewRequest("POST", u.String(), strings.NewReader(initdata.Encode())) - if err != nil { - return nil, log.EE("ws: wgconfs: %s req err: %v", details, err) - } - initreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - authHeader(initreq, bearer) - didHeader(initreq, ent.DidToken) - - if settings.Debug { - log.V("ws: wgconfs: %s init req: %s; tok %s; force %s", details, u.String(), tokst, force) - } - - initres, err := h.Do(initreq) - - if err != nil || initres == nil { - return nil, log.EE("ws: wgconfs: %s res err (nil? %t / tok? %s): %v", details, initres == nil, tokst, err) - } - updateDidTokenIfNeeded(ent, initres) - - if initres.StatusCode != http.StatusOK { - wserr, err := wsErr2(initres, "wsinit") - core.Close(initres.Body) - if wserr != nil && wserr.Code == ekeylimit { - if force != "1" { - log.I("ws: wgconfs: redo init with force %s; err: %v", details, err) - force = "1" - goto initagain - } - } - log.E("ws: wgconfs: init %s; err: %v", details, err) - return nil, err - } - - defer core.Close(initres.Body) - - var wgCreds WsWgCredsResponse - _, err = wsRes(initres, &wgCreds, "wgconfs") - if err != nil { - return nil, err - } - - d := wgCreds.Data - creds = &d.Config - if d.Success != 1 { - return nil, log.EE("ws: wgconfs: %s success != 1; debug: %v", details, d.Debug) - } - if len(d.Config.PrivateKey) <= 0 { // private key is generated locally (by the client) - d.Config.PrivateKey = priv.Base64() - if len(d.Config.PublicKey) > 0 && d.Config.PublicKey != pubkeybase64 { // registered public key must match the local one - return nil, log.EE("ws: wgconfs: pubkey mismatch; expected %s, got %s", - pubkeybase64, d.Config.PublicKey) - } - d.Config.PublicKey = pubkeybase64 - } // TODO: else panic? - } else { - creds = existingCreds - } - - if creds == nil || len(creds.PublicKey) <= 0 || len(creds.PrivateKey) <= 0 { - return nil, log.EE("ws: wgconfs: missing pub/priv creds %s", details) - } - - log.I("ws: wgconfs: got creds;" + details) - - if perma { - return createPermaCreds(h, ent, bearer, pubkeybase64) - } - - someEndpoint := fixedValidWsEndpoint(test) - -connectagain: - runconnect += 1 - // github.com/Windscribe/Android-App/blob/746d505dc69/base/src/main/java/com/windscribe/vpn/backend/utils/WindVpnController.kt#L159 - /* - curl -x POST '.../WgConfigs/connect' \ - --data-urlencode 'hostname=<>' \ - --data-urlencode 'wg_pubkey=stdbase64==' - --data-urlencode 'wg_ttl=3600' - -H 'Content-Type: application/x-www-form-urlencoded' \ - -H 'Authorization: Bearer id:typ:epochsec:sig1:sig2' - */ - cdata := url.Values{} - // The "hostname" for WgConfigs/connect call is requested, but currently it - // does nothing as we never made use of this server side. - cdata.Set("hostname", someEndpoint) - cdata.Set("wg_pubkey", pubkeybase64) - cdata.Set("wg_ttl", wgttl) - u := baseurl(test, cid, ent.Did).JoinPath(wswgconnectpath) - creq, err := http.NewRequest("POST", u.String(), strings.NewReader(cdata.Encode())) - if err != nil { - return nil, log.EE("ws: wgconfs: %s connect#%d req err: %v", details, runconnect, err) - } - creq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - authHeader(creq, sess.SessionToken) - didHeader(creq, ent.DidToken) - - if settings.Debug { - log.V("ws: wgconfs: %s connect#%d req: %s tok %s", details, runconnect, u.String(), tokst) - } - - cres, err := h.Do(creq) - if err != nil || cres == nil { - return nil, log.EE("ws: wgconfs: %s connect#%d res err (nil? %t / tok? %s): %v", - details, runconnect, cres == nil, tokst, err) - } - updateDidTokenIfNeeded(ent, cres) - if cres.StatusCode != http.StatusOK { - wserr, err := wsErr2(cres, "wsconnect") - core.Close(cres.Body) - if wserr != nil && wserr.Code == ekeyinvalid { // the key was deleted! - if keyed == 0 { - keyed = 1 - goto keyagain // try again with a non-default key - } - } else if wserr != nil && wserr.Code == enoaddr && runconnect < 2 { - time.Sleep(3 * time.Second) // wait a bit before retrying once - goto connectagain // retry connect - } - return nil, err - } - - var wgConnect WsWgConnectResponse - _, err = wsRes(cres, &wgConnect, "wgconfs") - defer core.Close(cres.Body) - if err != nil { - return nil, log.EE("ws: wgconfs: %s connect#%d res err: %v", - details, runconnect, err) - } - - // TODO: goto connectagain if runconnect < 2? - if wgConnect.Data.Success != 1 { - return nil, log.EE("ws: wgconfs: %s connect#%d success != 1; debug: %v", - details, runconnect, wgConnect.Data.Debug) - } - // TODO: goto connectagain if runconnect < 2? - if len(wgConnect.Data.Config.Address) <= 0 || len(wgConnect.Data.Config.DNS) <= 0 { - return nil, log.EE("ws: wgconfs: %s connect#%d missing config; debug: %v", - details, runconnect, wgConnect.Data.Debug) - } - - if len(creds.Address) <= 0 { - creds.Address = wgConnect.Data.Config.Address - } - if len(creds.DNS) <= 0 { - creds.DNS = wgConnect.Data.Config.DNS - } - - log.I("ws: wgconfs: got connect data; %s; config addr: %s, dns: %s; perma? %t", - details, wgConnect.Data.Config.Address, wgConnect.Data.Config.DNS, perma) - - return creds, nil -} - -func genWgConfs(h *http.Client, existingCreds *WsWgCreds, existingPermaCreds *WsWgPermanentConfig, sess *WsSession, servers []WsServerList, ent *WsEntitlement, ops x.RpnOps) (*WsWgCreds, *WsWgPermanentConfig, []*RegionalWgConf, error) { - if sess == nil || ent == nil { - return nil, nil, nil, errWsNoSession - } - forceInit := ops.Rotate() - port := "" - if ops.Port() > 0 { - port = strconv.FormatUint(uint64(ops.Port()), 10) - } - if len(sess.LocHash) <= 0 { - return nil, nil, nil, errWsNoLocHash - } - bearer := sess.SessionToken - if len(bearer) <= 0 { - return nil, nil, nil, errWsNoToken - } - if len(ent.Cid) <= 0 { - return nil, nil, nil, errWsNoCid - } - test := ent.TestDomain - tokst := "sess-" + tokenState(bearer) - - creds, err := initAndConnectCreds(h, existingCreds, false /*dynamic*/, sess, ent, forceInit) - if err != nil || creds == nil { - return nil, nil, nil, core.OneErr(err, errWsNoConfig) - } - - // TODO: if wgConnectData.Config.Address has not changed and existingCreds is non-nil, - // then we do not have to generate regional configs again (unless location hash has changed). - regconfs, err := convertToRegionalWgConfs(creds, servers, test, port) - if err != nil || len(regconfs) <= 0 { - return nil, nil, nil, log.EE("ws: wgconfs: (test? %t / tok? %s) no regions found: %v", test, tokst, err) - } - - // attempt to generate or reuse a permanent WG config (best-effort; non-fatal) - var permaCreds *WsWgPermanentConfig - if !disablePermaCreds { - var permaErr error - permaCreds, permaErr = initAndConnectCreds(h, existingPermaCreds /*may be nil*/, true /*perma*/, sess, ent, false /*forceInit is not useful for perma*/) - if permaErr != nil || permaCreds == nil { - log.W("ws: wgconfs: permacreds err (ops: %v): %v", &ops, permaErr) - permaCreds = existingPermaCreds // keep existing on error - } - } - - log.I("ws: wgconfs: ok (test? %t / tok? %s) found %d regions", test, tokst, len(regconfs)) - - return creds, permaCreds, regconfs, nil -} - -func (a *WsClient) kid() string { - if a == nil { - return "" - } - c := a.config() - if c == nil { - return "" - } - pub := c.Creds.PublicKey - if len(pub) <= 0 { - return "" - } - return trunc8(pub) -} - -func trunc8(s string) string { - if len(s) <= 8 { - return s[:3] - } - if len(s) <= 16 { - return s[:2] + ".." + s[len(s)-2:] - } - return s[:4] + ".." + s[len(s)-4:] -} - -func newWsGw(c *WsWgConfig, h *http.Client, o x.RpnOps) (*WsClient, error) { - if h == nil || c == nil || c.Session == nil || c.Creds == nil { - return nil, errWsBadGatewayArgs - } - a := &WsClient{ - http: h, - ops: core.NewVolatile(o), - configExt: core.NewVolatile(c), - configExtUpdateTime: core.NewVolatile(time.Now()), - } - - log.I("ws: gw: for %s/%s; ops: %s; from: %s until: %s", - a.Who(), c.tokenState(), a.Ops(), fmtUnixMillis(a.Created()), fmtUnixMillis(a.Expires())) - - return a, nil -} - -// setDidIfNeeded assigns did to ent.Did. If ent.Did is already set and does not match -// the incoming did, errWsDidMismatch is returned and ent is left unchanged. -func setDidIfNeeded(ent *WsEntitlement, did string) error { - if ent == nil { - return errWsNoEntitlement - } - if existing := ent.Did; len(existing) > 0 && existing != did { - log.E("ws: did mismatch: existing %s, incoming %s", trunc8(existing), trunc8(did)) - return errWsDidMismatch - } - ent.Did = did - return nil -} - -func (w *BaseClient) MakeWsWg(entitlement []byte, did string, ops x.RpnOps) (*WsClient, error) { - if len(entitlement) <= 0 { - return nil, errWsNoEntitlement - } - if len(did) <= 0 { - return nil, errWsNoDid - } - - var ent WsEntitlement - err := json.Unmarshal(entitlement, &ent) - if err != nil { - return nil, err - } - - // TODO: if ent already has did set; then err on mismatch? - if err := setDidIfNeeded(&ent, did); err != nil { - return nil, err - } - return makeWsWg(&w.h2, &ent, ops) -} - -func makeWsWg(h *http.Client, ent *WsEntitlement, ops x.RpnOps) (*WsClient, error) { - if ent == nil || len(ent.SessionToken) <= 0 { - log.E("ws: makeWsWg: entitlement is nil") - return nil, errWsNoEntitlement - } - - sess, err := getSession(h, ent) - if err != nil { - return nil, err - } - - servers, err := getServerList(h, sess, ent) - if err != nil { - return nil, err - } - - creds, permaCreds, wgconfs, err := genWgConfs(h, nil, nil, sess, servers.Data, ent, ops) - if err != nil { - return nil, err - } - - cfg := &WsWgConfig{ - Entitlement: ent, - Session: sess, - Configs: wgconfs, - Servers: servers.Data, - Creds: creds, - PermaCreds: permaCreds, // may be nil - } - - return newWsGw(cfg, h, ops) -} - -func (w *BaseClient) MakeWsEntitlement(entitlementOrStateJson []byte, did string) (x.RpnEntitlement, error) { - if len(entitlementOrStateJson) <= 0 { - return nil, errWsNoEntitlement - } - if len(did) <= 0 { - return nil, errWsNoDid - } - - var ent WsEntitlement - err1 := json.Unmarshal(entitlementOrStateJson, &ent) - if err1 == nil { - if err := setDidIfNeeded(&ent, did); err != nil { - return nil, err - } - return &ent, nil - } - var existingConf WsWgConfig - err2 := json.Unmarshal(entitlementOrStateJson, &existingConf) - if err2 == nil && existingConf.Entitlement != nil && len(existingConf.Entitlement.SessionToken) > 0 { - ent := existingConf.Entitlement - if err := setDidIfNeeded(ent, did); err != nil { - return nil, err - } - return ent, nil - } - return nil, core.JoinErr(err1, err2) -} - -func (w *BaseClient) MakeWsWgFrom(entitlementOrWsConfigJson []byte, did string, ops x.RpnOps) (*WsClient, error) { - if len(entitlementOrWsConfigJson) <= 0 { - return nil, errWsNoJsonConfig - } - if len(did) <= 0 { - return nil, errWsNoDid - } - - var existingConf WsWgConfig - err := json.Unmarshal(entitlementOrWsConfigJson, &existingConf) - - sz := len(entitlementOrWsConfigJson) - hasEnt := existingConf.Entitlement != nil - hasTok := hasEnt && len(existingConf.Entitlement.SessionToken) > 0 - if err != nil || !hasEnt || !hasTok { - // may be this is an entitlement and not conf? - log.W("ws: make: unmarshal config (sz %d / hasEnt %t / hasTok %t) err? %v; retry as entitlement", - sz, hasEnt, hasTok, err) - return w.MakeWsWg(entitlementOrWsConfigJson, did, ops) - } - if err := setDidIfNeeded(existingConf.Entitlement, did); err != nil { - return nil, err - } - return w.makeWsWgFrom(&existingConf, ops) -} - -func (w *BaseClient) makeWsWgFrom(existingConf *WsWgConfig, ops x.RpnOps) (*WsClient, error) { - ws, _, err := makeWsWgFrom(&w.h2, existingConf, ops, false /*not updating*/) - return ws, err -} - -func makeWsWgFrom(h *http.Client, existingConf *WsWgConfig, ops x.RpnOps, updating bool) (ws *WsClient, refreshedSess bool, err error) { - existingEnt := existingConf.Entitlement - if existingEnt == nil || len(existingEnt.SessionToken) <= 0 { - err = errWsNoEntitlement - return - } - - performingUpdate := updating - // performingUpdate is set for "Update" calls only; that is, when remote api call fails to - // either init or init+connect, we can safely errors out on the "Update"; - - existingSess := existingConf.Session - existingCreds := existingConf.Creds - noExistingCreds := existingCreds == nil - noExistingSess := existingSess == nil || len(existingSess.SessionToken) <= 0 - if noExistingCreds || noExistingSess { - log.W("ws: make: no existing creds? %t; no existing sess? %t; getting new ws wg", noExistingCreds, noExistingSess) - ws, err = makeWsWg(h, existingEnt, ops) - refreshedSess = true - return - } - - tokst := existingConf.tokenState() - existingToken := existingSess.SessionToken - existingLocHash := existingSess.LocHash - if existingEnt.SessionToken != existingToken { - log.W("ws: make: entitlement does not match session; tok? %s", tokst) - } - - usingExitingSess := false - newSess, err := getSession(h, existingEnt) - if err == nil { - existingConf.Session = newSess // update session with the latest info - refreshedSess = true - } else { - usingExitingSess = true - log.W("ws: make: get session err: %v; using existing; tok? %s", err, tokst) - newSess = existingConf.Session // use existing session - } - - exp, err := time.Parse(time.DateOnly, newSess.ExpiryDate) - if err != nil { - err = log.EE("ws: make: parsing expiry %s (newSess? %t); err: %v", newSess.ExpiryDate, !usingExitingSess, err) - return - } - - active := exp.After(time.Now()) - existingServers := existingConf.Servers - // skip server refresh if ops requests it; but honour loc hash change regardless - downloadServerList := (existingLocHash != newSess.LocHash) || ops.FetchServers() - if active { - maybeNewServers := existingServers - hasnew := false - if downloadServerList { - newServersRes, err := getServerList(h, newSess, existingEnt) - - loge(err)("ws: make: lochash changed %s != %s / exlen(%d); fetch err? %v", - existingLocHash, newSess.LocHash, len(existingServers), err) - - if err == nil && newServersRes != nil && len(newServersRes.Data) > 0 { - maybeNewServers = newServersRes.Data - hasnew = true - } - - if len(maybeNewServers) <= 0 { // no new servers, no existing servers; bail - return nil, refreshedSess, core.OneErr(err, errWsNoServerList) - } - } - - // create wg confs from new or existing server list - // always reconfigure (as /WgConfigs/connect must be done once every wg_ttl, which is 60m) - maybeNewCreds, maybeNewPermaCreds, maybeNewWgConfs, uerr := genWgConfs(h, existingCreds, existingConf.PermaCreds, newSess, maybeNewServers, existingConf.Entitlement, ops) - loge(uerr)("ws: make: gen wg confs; tok? %s; downloadloc? %t / hasnewloc? %t len (%d/%d); ops: %v; err? %v", - tokst, downloadServerList, hasnew, len(existingServers), len(maybeNewServers), &ops, uerr) - - if uerr == nil { - existingConf.Servers = maybeNewServers - existingConf.Configs = maybeNewWgConfs - existingConf.Creds = maybeNewCreds - existingConf.PermaCreds = maybeNewPermaCreds // may be nil - } else if performingUpdate { - // error out early as this was meant to create an update config for later use - // but it itself is not the currently active config aka "existingConf" - return nil, refreshedSess, uerr - } - } else { - log.W("ws: make: session expired at %s (newSess? %t); tok? %s", fmtTime(exp), !usingExitingSess, tokst) - } - - ws, err = newWsGw(existingConf, h, ops) - return -} - -// listKeys calls GET WgConfigs/list_keys and returns the parsed response. -func listKeys(h *http.Client, ent *WsEntitlement, bearer string) (*WsWgListKeysResponse, error) { - if len(bearer) <= 0 { - return nil, errWsNoToken - } - tokst := tokenState(bearer) - // curl --location --request GET '.../WgConfigs/list_keys' \ - // --header 'Authorization: Bearer ' - u := baseurl(ent.TestDomain, ent.Cid, ent.Did).JoinPath(wswglistkeyspath) - req, err := http.NewRequest("GET", u.String(), nil) - if err != nil { - return nil, log.EE("ws: listkeys: req err: %v", err) - } - authHeader(req, bearer) - didHeader(req, ent.DidToken) - - if settings.Debug { - log.V("ws: listkeys: req: %s tok %s", u.String(), tokst) - } - - res, err := h.Do(req) - if err != nil || res == nil { - return nil, log.EE("ws: listkeys: do err (nil? %t / tok? %s): %v", res == nil, tokst, err) - } - defer core.Close(res.Body) - updateDidTokenIfNeeded(ent, res) - if res.StatusCode != http.StatusOK { - return nil, wsErr(res, "listkeys/"+tokst) - } - - var out WsWgListKeysResponse - _, err = wsRes(res, &out, "listkeys/"+tokst) - if err != nil { - return nil, err - } - pubkeys := core.Map(out.Data.PubKeys, func(pub string) string { return trunc8(pub) }) - log.I("ws: listkeys: ok (tok? %s); %d keys: %v", tokst, len(out.Data.PubKeys), pubkeys) - return &out, nil -} - -// createPermaCreds calls POST WgConfigs/permanent to create a permanent WG config. -// If pubkey is empty the server generates both the private and public keys. -func createPermaCreds(h *http.Client, ent *WsEntitlement, bearer, pubkey string) (*WsWgPermanentConfig, error) { - if len(bearer) <= 0 { - return nil, errWsNoToken - } - port := wsRandomPort() // some port; doesn't matter which one - tokst := tokenState(bearer) - - // curl --location --request POST '.../WgConfigs/permanent' \ - // --header 'Authorization: Bearer ' \ - // --data-urlencode 'port=443' \ - // --data-urlencode 'wg_pubkey=...' (optional) - data := url.Values{} - data.Set("port", port) - if len(pubkey) > 0 { // creds vended by the server - data.Set("wg_pubkey", pubkey) - } - - u := baseurl(ent.TestDomain, ent.Cid, ent.Did).JoinPath(wswgpermanentpath) - req, err := http.NewRequest("POST", u.String(), strings.NewReader(data.Encode())) - if err != nil { - return nil, log.EE("ws: conf: perma: req err: %v", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - authHeader(req, bearer) - didHeader(req, ent.DidToken) - - if settings.Debug { - log.V("ws: conf: perma: req: %s tok %s; port %s", u.String(), tokst, port) - } - - res, err := h.Do(req) - if err != nil || res == nil { - return nil, log.EE("ws: conf: perma: do err (nil? %t / tok? %s): %v", res == nil, tokst, err) - } - defer core.Close(res.Body) - updateDidTokenIfNeeded(ent, res) - if res.StatusCode != http.StatusOK { - return nil, wsErr(res, "confperma/"+tokst) - } - - var out WsWgCredsResponse - _, err = wsRes(res, &out, "confperma/"+tokst) - if err != nil { - return nil, err - } - if out.Data.Success != 1 { - return nil, log.EE("ws: conf: perma: success != 1; tok? %s; debug: %v", tokst, out.Data.Debug) - } - - cfg := out.Data.Config - log.I("ws: conf: perma: ok (tok? %s); pubkey: %s", tokst, trunc8(cfg.PublicKey)) - return &cfg, nil -} - -func wsBriefPauseBeforeRetry() { - time.Sleep(2200 * time.Millisecond) -} - -func loge(err error) log.LogFn { - if err == nil { - return log.I - } - return log.E -} - -func sha(p string) []byte { - return shab([]byte(p)) -} - -func shab(b []byte) []byte { - digest := sha256.Sum256(b) - return digest[:] -} - -func byte2hex(b []byte) string { - return hex.EncodeToString(b) -} - -func truncate2k(b []byte) []byte { - if len(b) <= 2048 { - return b - } - return b[:2048] -} diff --git a/intra/ipn/socks5.go b/intra/ipn/socks5.go deleted file mode 100644 index b1919f2d..00000000 --- a/intra/ipn/socks5.go +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ipn - -import ( - "context" - "errors" - "net" - "net/netip" - "strconv" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/ipn/multihost" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - tx "github.com/txthinking/socks5" - "golang.org/x/net/proxy" -) - -type socks5 struct { - NoFwd // no forwarding/listening - NoDNS // no dns - SkipRefresh // no refresh - GW // dual stack gateway - - id string // unique identifier - opts *settings.ProxyOptions // connect options - d protect.RDialer // dialer to this upstream proxy - outbound []proxy.Dialer // outbound dialers via this upstream proxy - px ProxyProvider // proxy provider - viaID *core.Volatile[string] // hop id - via *core.WeakRef[Proxy] // hop proxy - lastdial time.Time // last time this transport attempted a connection - status *core.Volatile[int] // status of this transport - done context.CancelFunc // cancel func -} - -type socks5tcpconn struct { - *tx.Client -} - -type socks5udpconn struct { - *tx.Client -} - -var _ core.TCPConn = (*socks5tcpconn)(nil) -var _ core.UDPConn = (*socks5udpconn)(nil) -var _ net.Conn = (*socks5tcpconn)(nil) // needed by golang/http transport -var _ net.Conn = (*socks5udpconn)(nil) - -func (c *socks5tcpconn) CloseRead() error { - if c.Client != nil && c.Client.TCPConn != nil { - core.CloseOp(c.Client.TCPConn, core.CopR) - return nil - } - return errNoProxyConn -} - -func (c *socks5tcpconn) CloseWrite() error { - if c.Client != nil && c.Client.TCPConn != nil { - core.CloseOp(c.Client.TCPConn, core.CopW) - return nil - } - return errNoProxyConn -} - -// WriteFrom writes b to TUN using addr as the source. -func (c *socks5udpconn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - if c.Client != nil && c.Client.UDPConn != nil { - if uconn, ok := c.Client.UDPConn.(*net.UDPConn); ok { - return uconn.WriteTo(b, addr) - } - return c.Client.UDPConn.Write(b) - } - return 0, errNoProxyConn -} - -// ReceiveTo is incoming TUN packet b to be sent to addr. -func (c *socks5udpconn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - if c.Client != nil && c.Client.UDPConn != nil { - if uconn, ok := c.Client.UDPConn.(*net.UDPConn); ok { - return uconn.ReadFrom(b) - } - return 0, nil, errNotUDPConn - } - return 0, nil, errNoProxyConn -} - -func NewSocks5Proxy(id string, ctx context.Context, ctl protect.Controller, px ProxyProvider, po *settings.ProxyOptions) (_ *socks5, err error) { - tx.Debug = settings.Debug - if po == nil { - log.W("proxy: err setting up socks5(%v): %v", po, err) - return nil, errMissingProxyOpt - } - - ctx, done := context.WithCancel(ctx) - - portnumber, _ := strconv.Atoi(po.Port) - mh := multihost.New(id) - mh.Add([]string{po.Host, po.IP}) // resolves if ip is name - - var clients []proxy.Dialer - // x.net.proxy doesn't yet support udp - // github.com/golang/net/blob/62affa334/internal/socks/socks.go#L233 - // if po.Auth.User and po.Auth.Password are empty strings, the upstream - // socks5 server may throw err when dialing with golang/net/x/proxy; - // although, txthinking/socks5 deals gracefully with empty auth strings - // fproxy, err = proxy.SOCKS5("udp", po.IPPort, po.Auth, proxy.Direct) - for _, ip := range mh.PreferredAddrs() { - ipport := netip.AddrPortFrom(ip.Addr(), uint16(portnumber)) - c, cerr := tx.NewClient(ipport.String(), po.Auth.User, po.Auth.Password, tcptimeoutsec, udptimeoutsec) - if cerr != nil { - err = errors.Join(err, cerr) - } else { - clients = append(clients, c) - } - } - - if len(clients) == 0 && err != nil { - defer done() - log.W("proxy: err creating socks5 for %v (opts: %v): %v", - mh, po, err) - return nil, err - } - - // always with a network namespace aware dialer - dialer := protect.MakeNsRDial(id, ctx, ctl) - h := &socks5{ - id: id, - d: dialer, - px: px, - outbound: clients, - viaID: core.NewZeroVolatile[string](), - opts: po, - done: done, - } - - tx.DialTCP = h.txdial // h.outbound uses this - tx.DialUDP = h.txdial // h.outbound uses this - - via, err := core.NewWeakRef(h.viafor, viaok) - if err != nil { - defer done() - log.W("proxy: socks5: %s err via: %v", h.ID(), err) - return nil, err - } - h.via = via - - log.D("proxy: socks5: created %s with clients(%d), opts(%s)", - h.id, len(clients), po) - - return h, nil -} - -func (h *socks5) viafor() *Proxy { - return viafor(h.id, h.viaID.Load(), h.px) -} - -func (h *socks5) swapVia(new Proxy) (old Proxy) { - return swapVia(h.id, new, h.viaID, h.via) -} - -func (h *socks5) txdial(n, src, dst string) (c net.Conn, err error) { - who := idstr(h) - if usevia(h.viaID) { - if v, vok := h.via.Get(); vok { - who = idstr(v) - c, err = v.DialBind(n, src, dst) - } else { - err = errNoHop - if removeViaOnErrors { - h.Hop(nil, false /*dryrun*/) // stale; unset - } - log.W("proxy: socks5: %s via(%s) failing...", h.id, idhandle(v)) - } - } else { - c, err = h.d.DialBind(n, src, dst) - } - logei(err)("proxy: socks5: %s dial(%s) from %s => %s (via %s); err? %v", h.id, n, h.GetAddr(), dst, who, err) - return -} - -// Handle implements Proxy. -func (h *socks5) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *socks5) DialerHandle() uintptr { - return core.Loc(h.d) -} - -// Dial implements Proxy. -func (h *socks5) Dial(network, addr string) (c protect.Conn, err error) { - return h.dial(network, "", addr) -} - -// DialBind implements Proxy. -func (h *socks5) DialBind(network, local, remote string) (c protect.Conn, err error) { - log.D("proxy: socks5: %s dialbind(%s) %s => %s; not supported", - h.ID(), network, local, remote) - return h.dial(network, local, remote) -} - -// todo: bind to local -func (h *socks5) dial(network, _, remote string) (c protect.Conn, err error) { - if err := candial(h.status); err != nil { - return nil, err - } - - h.lastdial = time.Now() - // todo: tx.Client can only dial in to ip:port and not host:port even for server addr - // tx.Client.Dial does not support dialing into client addr as hostnames - if c, err = dialers.ProxyDials(h.outbound, network, remote); err == nil { - // github.com/txthinking/socks5/blob/39268fae/client.go#L15 - if uc, ok := c.(*tx.Client); ok { - if uc.UDPConn != nil { // a udp conn will always have an embedded tcp conn - c = &socks5udpconn{uc} - } else if uc.TCPConn != nil { // a tcp conn will never also have a udp conn - c = &socks5tcpconn{uc} - } else { - log.W("proxy: socks5: %s conn not tcp nor udp %s => %s", - h.ID(), h.GetAddr(), remote) - core.CloseConn(c) - c = nil - err = errNoProxyConn - } - } else { - log.W("proxy: socks5: %s conn not a tx.Client(%s) %s => %s", - h.ID(), network, h.GetAddr(), remote) - core.CloseConn(c) - c = nil - err = core.OneErr(err, errNoProxyConn) - } - } else { - log.W("proxy: socks5: %s dial(%s) failed %s => %s: %v", - h.ID(), network, h.GetAddr(), remote, err) - } - defer localDialStatus(h.status, err) - return -} - -// Dialer implements Proxy. -func (h *socks5) Dialer() protect.RDialer { - return h -} - -// ID implements x.Proxy. -func (h *socks5) ID() string { - return h.id -} - -// Type implements x.Proxy. -func (h *socks5) Type() string { - return SOCKS5 -} - -// Router implements x.Proxy. -func (h *socks5) Router() x.Router { - return h -} - -// Reaches implements x.Router. -func (h *socks5) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (h *socks5) Hop(p Proxy, dryrun bool) error { - if p == nil { - if !dryrun { - old := h.swapVia(nil) - log.I("socks5: hop(%s) removed", idhandle(old)) - } - return nil - } - if p.Status() == END { - return errProxyStopped - } - - if !dryrun { - old := h.swapVia(p) - log.I("socks5: hop %s => %s", idhandle(old), idhandle(p)) - } - return nil -} - -// Via implements x.Router. -func (h *socks5) Via() (x.Proxy, error) { - if v := h.via.Load(); v != nil { - return v, nil - } - return nil, errNoHop -} - -// GetAddr implements x.Proxy. -func (h *socks5) GetAddr() string { - return h.opts.IPPort -} - -// Status implements Proxy. -func (h *socks5) Status() int { - s := h.status.Load() - if s != END && idling(h.lastdial) { - return TZZ - } - return s -} - -// Pause implements x.Proxy. -func (h *socks5) Pause() bool { - st := h.status.Load() - if st == END { - log.W("proxy: socks5: pause called when stopped") - return false - } - - ok := h.status.Cas(st, TPU) - log.I("proxy: socks5: paused? %t", ok) - return ok -} - -// Resume implements x.Proxy. -func (h *socks5) Resume() bool { - st := h.status.Load() - if st != TPU { - log.W("proxy: socks5: resume called when not paused; status %d", st) - return false - } - - ok := h.status.Cas(st, TUP) - go h.Refresh() // no-op since SkipRefresh - log.I("proxy: socks5: resumed? %t", ok) - return ok -} - -// Stop implements Proxy. -func (h *socks5) Stop() error { - h.status.Store(END) - h.done() - log.I("proxy: socks5: stopped %s", h.id) - return nil -} - -// OnProtoChange implements Proxy. -func (h *socks5) OnProtoChange(_ LinkProps) (string, bool) { - if err := candial(h.status); err != nil { - return "", false - } - return h.opts.FullUrl(), true -} diff --git a/intra/ipn/wg/amnezia.go b/intra/ipn/wg/amnezia.go deleted file mode 100644 index b3e067b6..00000000 --- a/intra/ipn/wg/amnezia.go +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package wg - -import ( - "crypto/rand" - "encoding/binary" - "fmt" - - "github.com/celzero/firestack/intra/log" - "golang.zx2c4.com/wireguard/device" -) - -// https://github.com/amnezia-vpn/amneziawg-go/pull/2/files - -const ( - // TODO: re-enable after figuring out how to account for - // changing the message header values for cookies mac1 & mac2 - // here: github.com/amnezia-vpn/amneziawg-go/blob/27e661d68e/device/send.go#L167 - disableAmenzia = true - sNoop = 0 // no-op size -) - -// Jc (Junk packet count) - number of packets with random data that are sent before the start of the session -// Jmin (Junk packet minimum size) - minimum packet size for Junk packet. That is, all randomly generated packets will have a size no smaller than Jmin. -// Jmax (Junk packet maximum size) - maximum size for Junk packets -// S1 (Init packet junk size) - the size of random data that will be added to the init packet, the size of which is initially fixed. -// S2 (Response packet junk size) - the size of random data that will be added to the response packet, the size of which is initially fixed. -// H1 (Init packet magic header) - the header of the first byte of the handshake -// H2 (Response packet magic header) - header of the first byte of the handshake response -// H4 (Transport packet magic header) - header of the packet of the data packet -// H3 (Underload packet magic header) - UnderLoad packet header." -type Amnezia struct { - id string - - Jc, Jmin, Jmax uint16 // unused: junk packet count, min, max - - S1, S2 uint16 // handshake init/resp pkt sizes - H1, H2, H3, H4 uint32 // modified msg types [4]byte -} - -func NewAmnezia(id string) *Amnezia { - return &Amnezia{ - id: id, - } -} - -func (a *Amnezia) String() string { - if a == nil { - return "" - } - if disableAmenzia { - return "" - } - if !a.Set() { - return "" - } - return fmt.Sprintf("%s: amnezia: jc(%d), jmin(%d), jmax(%d), s1(%d), s2(%d), h1(%d), h2(%d), h3(%d), h4(%d)", - a.id, a.Jc, a.Jmin, a.Jmax, a.S1, a.S2, a.H1, a.H2, a.H3, a.H4) -} - -func (a *Amnezia) Set() bool { - if a == nil || disableAmenzia { - return false - } - - return a.S1 > 0 || a.S2 > 0 || a.H1 > 0 || a.H2 > 0 || a.H3 > 0 || a.H4 > 0 -} - -func (a *Amnezia) Same(b *Amnezia) bool { - if a == nil && b == nil { - return false - } else if a == nil || b == nil { - return false - } - - return a.S1 == b.S1 && - a.S2 == b.S2 && - a.H1 == b.H1 && - a.H2 == b.H2 && - a.H3 == b.H3 && - a.H4 == b.H4 -} - -func (a *Amnezia) send(pktptr *[]byte) (ok bool) { - if a == nil || !a.Set() { - return - } - - pkt := *pktptr - - n := len(pkt) - if n < device.MinMessageSize { - return - } - - typ := binary.LittleEndian.Uint32(pkt) - - *pktptr, _ = a.instate(pkt) - - a.logIfNeeded("send", typ, n, len(*pktptr)) - - return true -} - -func (a *Amnezia) recv(pkt []byte, upto int) (out []byte, ok bool) { - if a == nil || !a.Set() { - return - } - if upto < device.MinMessageSize { - return - } - - var typ uint32 - // h := uint16(device.MessageTransportOffsetReceiver) - pkt, typ = a.strip(pkt[:upto]) - strippedSz := len(pkt) - - switch typ { - case device.MessageInitiationType, a.H1: - typ = device.MessageInitiationType - binary.LittleEndian.PutUint32(pkt, device.MessageInitiationType) - case device.MessageResponseType, a.H2: - typ = device.MessageResponseType - binary.LittleEndian.PutUint32(pkt, device.MessageResponseType) - case device.MessageCookieReplyType, a.H3: - typ = device.MessageCookieReplyType - binary.LittleEndian.PutUint32(pkt, device.MessageCookieReplyType) - case device.MessageTransportType, a.H4: // must be default? - typ = device.MessageTransportType - binary.LittleEndian.PutUint32(pkt, device.MessageTransportType) - default: - log.W("wg: %s: amnezia: recv: unexpected type %d", a.id, typ) - // TODO: error? - } - - a.logIfNeeded("recv", typ, strippedSz, upto) - - return pkt, true -} - -func (a *Amnezia) instate(pkt []byte) ([]byte, uint32) { - n := len(pkt) - - defaultType := binary.LittleEndian.Uint32(pkt) - - pad := uint16(0) - obsType := uint32(0) - maybeInstate := false - - switch defaultType { - case device.MessageInitiationType: - if n == device.MessageInitiationSize { - // github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/send.go#L130 - pad = a.S1 - obsType = a.H1 - maybeInstate = obsType > 0 || pad > 0 - } - case device.MessageResponseType: - if n == device.MessageResponseSize { - // github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/send.go#L198 - pad = a.S2 - obsType = a.H2 - maybeInstate = obsType > 0 || pad > 0 - } - case device.MessageCookieReplyType: - if n == device.MessageCookieReplySize { - pad = sNoop - obsType = a.H3 - maybeInstate = obsType > 0 - } - case device.MessageTransportType: - if n >= device.MinMessageSize { - pad = sNoop - obsType = a.H4 - maybeInstate = obsType > 0 - } - } - - log.VV("wg: %s: amnezia: instate: msg size: %d, msg typ: (d: %d, o: %d), pad? %t, s1/s2: %d/%d, do? %t", - a.id, n, defaultType, obsType, pad > 0, a.S1, a.S2, maybeInstate) - - useObsType := obsType > 0 && defaultType != obsType - if useObsType { - binary.LittleEndian.PutUint32(pkt, obsType) - } - // pad may be 0 - if random, err := blob(pad); err != nil { // unlikely - log.E("wg: %s: amnezia: instate: pad err %v", a.id, err) - } else if len(random) > 0 && len(random) == int(pad) { - pkt = append(random, pkt...) - } - - if useObsType { - return pkt, obsType - } - return pkt, defaultType -} - -func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) { - size := uint16(len(pkt)) - h := uint16(device.MessageTransportOffsetReceiver) - // assume the correct msg type is in just the first byte: - // github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56 - defaultType := binary.LittleEndian.Uint32(pkt[:h]) - - var discard uint16 = 0 - var possibleType uint32 = 0 - maybeStrip := false - - // ref: github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/device.go#L765 - if size == a.S1+device.MessageInitiationSize { - discard = a.S1 - possibleType = a.H1 - maybeStrip = discard > 0 - } else if size == a.S2+device.MessageResponseSize { - discard = a.S2 - possibleType = a.H2 - maybeStrip = discard > 0 - } // else: default - - log.VV("wg: %s: amnezia: strip: msg size: %d, msg typ (d: %d / o: %d), s1/s2: %d/%d, do? %t", - a.id, size, defaultType, possibleType, a.S1, a.S2, maybeStrip) - - if maybeStrip { - hdr := pkt[discard : discard+h] - obsType := binary.LittleEndian.Uint32(hdr) - if obsType == possibleType { - return pkt[discard:], obsType - } // else: msg type mismatch, but size matched - log.W("wg: %s: amnezia: strip: mismatched msg type %d != %d", a.id, obsType, possibleType) - } // else: nothing to discard - - return pkt, defaultType -} - -func (a *Amnezia) logIfNeeded(dir string, typ uint32, n int, newn int) { - switch typ { - case device.MessageInitiationType: - notok := n != device.MessageInitiationSize - logif(notok)("wg: %s: amnezia: %s: err initiation %d != %d (=> %d)", - a.id, dir, n, device.MessageInitiationSize, newn) - case device.MessageResponseType: - notok := n != device.MessageResponseSize - logif(notok)("wg: %s: amnezia: %s: err response %d != %d (=> %d)", - a.id, dir, n, device.MessageResponseSize, newn) - case device.MessageCookieReplyType: - notok := n != device.MessageCookieReplySize - logif(notok)("wg: %s: amnezia: %s: err cookie %d != %d (=> %d)", - a.id, dir, n, device.MessageCookieReplySize, newn) - case device.MessageTransportType: - notok := n < device.MinMessageSize - logif(notok)("wg: %s: amnezia: %s: err data %d < %d (=> %d)", - a.id, dir, n, device.MinMessageSize, newn) - default: - log.W("wg: %s: amnezia: %s: unexpected type %d; sz(pkt): %d => %d", - a.id, dir, typ, n, newn) - } -} - -// ref: github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/util.go#L1 -func blob(sz uint16) ([]byte, error) { - if sz == 0 { - return nil, nil - } - - junk := make([]byte, sz) - n, err := rand.Read(junk) - return junk[:n], err -} - -func logif(cond bool) log.LogFn { - if cond { - return log.D - } - return log.N -} diff --git a/intra/ipn/wg/controlfns.go b/intra/ipn/wg/controlfns.go deleted file mode 100644 index a83e70b5..00000000 --- a/intra/ipn/wg/controlfns.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package wg - -import ( - "fmt" - "runtime" - "syscall" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/protect" - "golang.org/x/sys/unix" -) - -// from: github.com/WireGuard/wireguard-go/blob/12269c276/conn/controlfns.go - -// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is -// the max supported by a default configuration of macOS. Some platforms will -// silently clamp the value to other maximums, such as linux clamping to -// net.core.{r,w}mem_max (see _linux.go for additional implementation that works -// around this limitation) -const socketBufferSize = 7 << 20 - -// controlFns is a list of functions that are called from the listen config -// that can apply socket options. -var controlFns = []protect.ControlFn{} - -// A net.ListenConfig (protect.MakeNsListenConfigExt) must apply the controlFns -// to the socket prior to bind. This is used to apply socket buffer sizing and -// packet information OOB configuration for sticky sockets. -// from: github.com/WireGuard/wireguard-go/blob/12269c276/conn/controlfns_linux.go -func init() { - controlFns = append(controlFns, - // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by - // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to - // fail silently - the result of failure is lower performance on very fast - // links or high latency links. - func(network, address string, c syscall.RawConn) error { - return c.Control(func(fd uintptr) { - // Set up to *mem_max - _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) - _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) - // Set beyond *mem_max if CAP_NET_ADMIN - _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) - _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) - }) - }, - - // Enable receiving of the packet information (IP_PKTINFO for IPv4, - // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. - func(network, address string, c syscall.RawConn) error { - var errc, errs error - switch network { - case "udp4": - if runtime.GOOS != "android" { - errc = c.Control(func(fd uintptr) { - errs = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) - }) - } - case "udp6": - errc = c.Control(func(fd uintptr) { - if runtime.GOOS != "android" { - errs = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) - if errs != nil { - return - } - } - errs = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) - }) - default: - errs = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) - } - loge(core.OneErr(errs, errc))("wg: control: done; IP_PKTINFO/IPV6_RECVPKTINFO") - return errs // discard errc - }, - - // Attempt to enable UDP_GRO - func(network, address string, c syscall.RawConn) error { - _ = c.Control(func(fd uintptr) { - _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) - }) - return nil - }, - ) -} diff --git a/intra/ipn/wg/gso.go b/intra/ipn/wg/gso.go deleted file mode 100644 index 73e68c80..00000000 --- a/intra/ipn/wg/gso.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package wg - -import ( - "errors" - "fmt" - "net" - "os" - "runtime" - "unsafe" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "golang.org/x/sys/unix" -) - -// from: github.com/WireGuard/wireguard-go/blob/12269c27/conn/gso_linux.go - -// TODO: GSO/GRO and mmsgs in pkg net: github.com/golang/go/issues/45886 - -const sizeOfGSOData = 2 - -// gsoControlSize returns the recommended buffer size for pooling UDP -// offloading control data. -var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) - -// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. -func getGSOSize(control []byte) (int, error) { - var ( - hdr unix.Cmsghdr - data []byte - rem = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) - if err != nil { - return 0, fmt.Errorf("error parsing socket control message: %w", err) - } - if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { - var gso uint16 - copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) - return int(gso), nil - } - } - return 0, nil -} - -// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing -// data in control untouched. -func setGSOSize(control *[]byte, gsoSize uint16) { - existingLen := len(*control) - avail := cap(*control) - existingLen - space := unix.CmsgSpace(sizeOfGSOData) - if avail < space { - return - } - *control = (*control)[:cap(*control)] - gsoControl := (*control)[existingLen:] - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) - hdr.Level = unix.SOL_UDP - hdr.Type = unix.UDP_SEGMENT - hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) - // github.com/WireGuard/wireguard-go/commit/f502ec3fad116d11109529bcf283e464f4822c18 - copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) - *control = (*control)[:existingLen+space] -} - -// from: github.com/WireGuard/wireguard-go/blob/12269c276/conn/features_linux.go -func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { - rc, err := conn.SyscallConn() - if err != nil { - log.W("wg: gso: syscall err: %v", err) - return - } - if rc == nil || core.IsNil(rc) { - log.W("wg: gso: syscall conn nil") - return - } - - var opt int - var errSyscallTx, errSyscallRx error - err = rc.Control(func(fd uintptr) { - _, errSyscallTx = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) - opt, errSyscallRx = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) - - }) - - if err != nil { - log.W("wg: gso: no support; err: %v", err) - return - } - - txOffload = errSyscallTx == nil - rxOffload = errSyscallRx == nil && opt == 1 - - log.I("wg: gso: txOffload: %t (errTx: %v), rxOffload: %t (opt: %d; errRx: %v)", - txOffload, rxOffload, errSyscallTx, opt, errSyscallRx) - return txOffload, rxOffload -} - -func supportsBatchRw() bool { - return runtime.GOOS == "linux" || runtime.GOOS == "android" -} - -// from: github.com/WireGuard/wireguard-go/blob/12269c276/conn/errors_linux.go# -func shouldDisableUDPGSOOnError(err error) bool { - if err == nil { - return false - } - var serr *os.SyscallError - if errors.As(err, &serr) { - // EIO is returned by udp_send_skb() if the device driver does not have - // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. - // See: - // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 - // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 - eio := serr != nil && serr.Err == unix.EIO - if eio { - log.W("wg: gso: EIO: %v", eio) - } - return eio - } - return false -} diff --git a/intra/ipn/wg/sosticky.go b/intra/ipn/wg/sosticky.go deleted file mode 100644 index 4c3ce5a7..00000000 --- a/intra/ipn/wg/sosticky.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package wg - -import "net/netip" - -// "sticky socks" disabled on Android: github.com/WireGuard/wireguard-go/commit/3a9e75374f - -func (e *StdNetEndpoint2) SrcIP() netip.Addr { - return netip.Addr{} -} - -func (e *StdNetEndpoint2) SrcIfidx() int32 { - return 0 -} - -func (e *StdNetEndpoint2) SrcToString() string { - return "" -} - -// getSrcFromControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func getSrcFromControl(control []byte, ep *StdNetEndpoint2) { -} - -// setSrcControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func setSrcControl(control *[]byte, ep *StdNetEndpoint2) { -} - -// stickyControlSize returns the recommended buffer size for pooling sticky -// offloading control data; for linux: stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) -const stickyControlSize = 0 - -// no netlink on Androids: github.com/WireGuard/wireguard-go/blob/12269c2761/device/sticky_linux.go#L28 -// for linux: StdNetSupportsStickySockets = true -const StdNetSupportsStickySockets = false diff --git a/intra/ipn/wg/stats.go b/intra/ipn/wg/stats.go deleted file mode 100644 index cacb36d3..00000000 --- a/intra/ipn/wg/stats.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package wg - -import ( - "errors" - "fmt" - "strconv" - "strings" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -// from: github.com/WireGuard/wireguard-android/blob/4ba87947ae/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java -// from: github.com/WireGuard/wireguard-android/blob/4ba87947ae/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java#L119 - -var ( - errNoSuchPeer = errors.New("wg: no such peer") - errAllStatsNotOK = errors.New("wg: all stats not OK") - - baTtl = 30 * time.Second - baNegTtl = 2 * time.Second - ba = core.NewBarrier2[*ifstats, uintptr](baTtl, baNegTtl) -) - -// peerstats represents the statistics for a peer. -type peerstats struct { - RxBytes int64 - TxBytes int64 - LatestHandshakeEpochMillis int64 -} - -// ifstats holds the statistics for peers. -type ifstats struct { - o string - stats map[string]peerstats - lastTouched time.Time -} - -func (s *ifstats) String() string { - if s == nil { - return "" - } - - o := s.o - d := s.lastTouched.UnixMilli() - rx := s.TotalRx() - tx := s.TotalTx() - hdshk := s.LatestRecentHandshake() - return fmt.Sprintf("ifstats{o: %s, lastTouched: %d, rx: %d, tx: %d, lastOK: %d}", - o, d, rx, tx, hdshk) -} - -// newStats creates a new Statistics instance. -func newStats(owner string) *ifstats { - return &ifstats{ - o: owner, - stats: make(map[string]peerstats), - lastTouched: time.Now(), - } -} - -// add adds a new peer's statistics to the map. -func (s *ifstats) add(key string, rx, tx, latestHandshake int64) bool { - if settings.Debug { - log.VV("wg: ReadStats: %s: add %s, %d, %d, %d", s.o, key, rx, tx, latestHandshake) - } - s.stats[key] = peerstats{RxBytes: rx, TxBytes: tx, LatestHandshakeEpochMillis: latestHandshake} - return latestHandshake > 0 -} - -// IsStale checks if the statistics are older than 15 minutes. -func (s *ifstats) IsStale() bool { - return time.Since(s.lastTouched) > 15*time.Minute -} - -// Peer retrieves the statistics for a specific peer. -func (s *ifstats) Peer(key string) (peerstats, error) { - if stats, ok := s.stats[key]; ok { - return stats, nil - } - return peerstats{}, errNoSuchPeer -} - -// Peers returns all the keys (peers) in the statistics map. -func (s *ifstats) Peers() []string { - keys := make([]string, 0, len(s.stats)) - for key := range s.stats { - keys = append(keys, key) - } - return keys -} - -// TotalRx calculates the total received bytes. -func (s *ifstats) TotalRx() int64 { - var total int64 - for _, stats := range s.stats { - total += stats.RxBytes - } - return total -} - -// TotalTx calculates the total transmitted bytes. -func (s *ifstats) TotalTx() int64 { - var total int64 - for _, stats := range s.stats { - total += stats.TxBytes - } - return total -} - -func (s *ifstats) LatestRecentHandshake() int64 { - least := int64(0) - for _, stats := range s.stats { - least = max(least, stats.LatestHandshakeEpochMillis) - } - if settings.Debug { - log.VV("wg: ReadStats: %s: LatestRecentHandshake: %s, Peers: %d", - s.o, core.FmtUnixMillisAsPeriod(least), len(s.stats)) - } - return least -} - -func ReadStats(who string, id uintptr, cfn core.Work[string]) *ifstats { - v, err := ba.DoIt(id, func() (*ifstats, error) { - cfg, err := cfn() - if err != nil || len(cfg) <= 0 { - log.W("wg: ReadStats: %s: %d: ipcget: %v", who, id, err) - return nil, err - } - return readStats(who, cfg) - }) - if err != nil { // v is nil when ba.Do timesout or no handshake yet - log.W("wg: ReadStats: %s nil for %d, err: %v", who, id, err) - } - return v -} - -// readStats parses a configuration string and returns a Statistics instance. -func readStats(who, config string) (*ifstats, error) { - stats := newStats(who) - var key string - var rx, tx, latestHandshakeMillis int64 - var anyStatOK bool - - // see: github.com/WireGuard/wireguard-go/blob/12269c27/device/uapi.go#L51 - lines := strings.Split(config, "\n") - n := len(lines) - k := 0 - for _, line := range lines { - if strings.HasPrefix(line, "public_key=") { - if key != "" { - k++ - anyStatOK = stats.add(key, rx, tx, latestHandshakeMillis) || anyStatOK - } - rx = 0 - tx = 0 - latestHandshakeMillis = 0 - key = line[11:] - } else if strings.HasPrefix(line, "rx_bytes=") { - if key == "" { - continue - } - rx, _ = strconv.ParseInt(line[9:], 10, 64) - } else if strings.HasPrefix(line, "tx_bytes=") { - if key == "" { - continue - } - tx, _ = strconv.ParseInt(line[9:], 10, 64) - } else if strings.HasPrefix(line, "last_handshake_time_sec=") { - if key == "" { - continue - } - sec, _ := strconv.ParseInt(line[24:], 10, 64) - latestHandshakeMillis += sec * 1000 - } else if strings.HasPrefix(line, "last_handshake_time_nsec=") { - if key == "" { - continue - } - nsec, _ := strconv.ParseInt(line[25:], 10, 64) - latestHandshakeMillis += nsec / 1000000 - } - } - if key != "" { - k++ - anyStatOK = stats.add(key, rx, tx, latestHandshakeMillis) || anyStatOK - } - stats.lastTouched = time.Now() - - if settings.Debug { - log.V("wg: ReadStats: %s: %d peers, %d lines, any OK? %t", who, k, n, anyStatOK) - } - if !anyStatOK { - return stats, errAllStatsNotOK // negative ttl on barrier - } - - return stats, nil -} diff --git a/intra/ipn/wg/wgconn.go b/intra/ipn/wg/wgconn.go deleted file mode 100644 index ac7efe09..00000000 --- a/intra/ipn/wg/wgconn.go +++ /dev/null @@ -1,752 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package wg - -import ( - "crypto/rand" - "encoding/binary" - "errors" - "fmt" - "io" - mrand "math/rand/v2" - "net" - "net/netip" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn/multihost" - "github.com/celzero/firestack/intra/settings" - - "github.com/celzero/firestack/intra/log" - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" -) - -// from: github.com/WireGuard/wireguard-go/blob/ebbd4a433/conn/bind_std.go - -const maxbindtries = 50 -const wgtimeout = 60 * time.Second - -// github.com/WireGuard/wireguard-go/blob/19ac233cc6/wireguard/device/send.go#L96 -var ( - // quic? - // github.com/hiddify/hiddify-sing-box/blob/17127b0535d/outbound/wireguard.go#L217 - mlist = []byte{0xDC, 0xDE, 0xD3, 0xD9, 0xD0, 0xEC, 0xEE, 0xE3} - // github.com/WireGuard/wireguard-go/blob/12269c27617/device/send.go#L456 - wgheader = []byte{ // size 18 - /*00-03*/ 0x05, 0x00, 0x00, 0x00, // fieldType - /*04-07*/ 0x01, 0x08, 0x00, 0x00, // fieldReceiver - /*08-11*/ 0x00, 0x00, 0x00, 0x00, // fieldNonce - /*11-15*/ 0x00, 0x00, 0x00, 0x00, // fieldNonce - /*16-17*/ 0x44, 0xD0, // ??? - } - - anyaddr6 = netip.IPv6Unspecified() - anyaddr4 = netip.IPv4Unspecified() -) - -const ( - minFloodPkts = 3 - maxFloodPkts = minFloodPkts * 10 - maxFloodDuration = 3 * time.Second - minFloodInterval = 1 * time.Minute // flood once every min - - minFloodPktLen = 28 // bytes; must be > len(wgheader) - maxFloodPktLen = 138 // must be >> minFloodPktLen; < device.MessageInitiationSize? -) - -var ( - errInvalidEndpoint = errors.New("wg: bind: no endpoint") - errNoLocalAddr = errors.New("wg: bind: no local address") - errNoRawConn = errors.New("wg: bind: no raw conn") - errNotUDP = errors.New("wg: bind: not a UDP conn") - errNoListen = errors.New("wg: bind: listen failed") - errEnded = errors.New("wg: bind: proxy ended") -) - -type floodkind int - -const ( - fkHandshake floodkind = iota - fkKeepalive -) - -func (k floodkind) String() string { - switch k { - case fkHandshake: - return "handshake" - case fkKeepalive: - return "keepalive" - default: - return "unknown" - } -} - -type rwobserver func(op PktDir, err error) (ended bool) -type connector func(network, to string) (net.PacketConn, error) - -type PktDir string - -const ( - Rcv PktDir = "recv" // data received - Snd PktDir = "send" // data sent - Crc PktDir = "notr" // not transport data (recv) - Csn PktDir = "nots" // not transport data (send) - Con PktDir = "conn" // e.g. dial, announce, accept - Opn PktDir = "open" // open conn to the wg endpoint - Clo PktDir = "clos" // close conn to the wg endpoint - Drp PktDir = "drop" // ignored packet -) - -// Rcv or Crc -func (op PktDir) Read() bool { - return op == Rcv || op == Crc -} - -// Snd or Csn -func (op PktDir) Write() bool { - return op == Snd || op == Csn -} - -type StdNetBind struct { - id string - connect connector - pm *core.Volatile[*multihost.MHMap] // peer ip:port or host => preferred-addrs - - amnezia *core.Volatile[*Amnezia] // may return nil *Amnezia - floodBa *core.Barrier[int, netip.AddrPort] - - mu sync.Mutex // protects following fields - ipv4 net.PacketConn // (*net.UDPConn or *gonet.UDPConn) - ipv6 net.PacketConn // (*net.UDPConn or *gonet.UDPConn) - - // keeps wireguard's recv routine running by not returning errors yet dropping packets - blackhole4 bool - blackhole6 bool - - epmu sync.RWMutex - eps map[net.Addr]StdNetEndpoint // peer-addr => std-net-endpoint - - observer rwobserver - sendAddr *core.Volatile[netip.AddrPort] // may be invalid - - closed atomic.Bool // wgconn has been closed - ended atomic.Bool // observer / connector are done -} - -// TODO: get d, ep, f, rb through an Opts bag? -func NewEndpoint(id string, d connector, pm *core.Volatile[*multihost.MHMap], f rwobserver, a *core.Volatile[*Amnezia]) *StdNetBind { - s := &StdNetBind{ - id: id, - connect: d, - pm: pm, - observer: f, - amnezia: a, - floodBa: core.NewKeyedBarrier[int, netip.AddrPort](minFloodInterval), - eps: make(map[net.Addr]StdNetEndpoint), - sendAddr: core.NewZeroVolatile[netip.AddrPort](), - } - return s -} - -type StdNetEndpoint struct { - netip.AddrPort - addr *net.UDPAddr -} - -var invalidStdNetEndpoint = StdNetEndpoint{} - -var ( - _ conn.Bind = (*StdNetBind)(nil) - _ conn.Endpoint = StdNetEndpoint{} -) - -func (e *StdNetBind) ParseEndpoint(s string) (conn.Endpoint, error) { - /* - host, portstr, err := net.SplitHostPort(s) - if err != nil { - log.E("wg: bind: %s invalid endpoint in(%s); err: %v", e.id, s, err) - return nil, err - } - port, err := strconv.Atoi(portstr) - if err != nil { - log.E("wg: bind: %s invalid port in(%s); err: %v", e.id, s, err) - return nil, err - } - */ - // d.Add([]string{host}) // resolves host if needed - d, err := e.pm.Load().Get(s) - if err != nil || d == nil /*nilaway; can't happen*/ { - log.E("wg: bind: parse: %s invalid endpoint in(%s); err: %v", e.id, s, err) - return nil, err - } - - // do what tailscale does, and share a preferred endpoint regardless of "s" - // github.com/tailscale/tailscale/blob/3a6d3f1a5b7/wgengine/magicsock/magicsock.go#L2568 - ipport := d.PreferredAddr() - if !ipport.IsValid() || ipport.Addr().IsUnspecified() { - log.E("wg: bind: parse: %s invalid endpoint; chosen(%v) => in(%s) => out(%s, %s)", e.id, ipport, s, d.Names(), d.Addrs()) - // erroring out from here prevents PostConfig (handshake for this peer endpoint will always be zero) - // github.com/WireGuard/wireguard-go/blob/12269c276173/device/uapi.go#L183 - return nil, errInvalidEndpoint - } - - log.I("wg: bind: %s new shared endpoint for %s %v", e.id, s, ipport) - - // todo: add stdnetendpoint to s.eps - return StdNetEndpoint{ipport, udpaddr(ipport)}, nil -} - -func (StdNetEndpoint) ClearSrc() {} // not supported - -func (e StdNetEndpoint) DstIP() netip.Addr { - return e.Addr() -} - -func (e StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} // not supported -} - -func (e StdNetEndpoint) DstToBytes() []byte { - b, _ := e.MarshalBinary() - return b -} - -func (e StdNetEndpoint) DstToString() string { - return e.String() -} - -func (e StdNetEndpoint) SrcToString() string { - return "" -} - -func udpaddr(ipp netip.AddrPort) *net.UDPAddr { - if ipp.IsValid() { - return net.UDPAddrFromAddrPort(ipp) - } - return nil -} - -func (s *StdNetBind) RemoteAddr() netip.AddrPort { - return s.sendAddr.Load() -} - -func (s *StdNetBind) listenNet(network string, port int) (net.PacketConn, int, error) { - if s.ended.Load() { - return nil, 0, errEnded - } - - anyaddr := anyaddr6 - if network == "udp4" { - anyaddr = anyaddr4 - } - saddr := net.JoinHostPort(anyaddr.String(), fmt.Sprintf("%d", port)) - - conn, err := s.connect(network, saddr) - if err != nil { - log.E("wg: bind: listen: %s %s: on(%v); err: %v", s.id, network, saddr, err) - return nil, 0, err - } - if conn == nil { - log.E("wg: bind: listen: %s %s: on(%v); conn nil", s.id, network, saddr) - return nil, 0, errNoListen - } - - laddr := conn.LocalAddr() - if laddr == nil { - log.E("wg: bind: listen: %s %s: on(%v); local-addr nil", s.id, network, saddr) - return nil, 0, errNoLocalAddr - } - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - if uaddr == nil { - return nil, 0, errNoLocalAddr - } - if settings.Debug { - log.VV("wg: bind: listen: %s %s: on(%v)", s.id, network, laddr) - } - // typecast is safe, because "network" is always udp[4|6]; see: Open - return conn, uaddr.Port, nil -} - -func (s *StdNetBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.ended.Load() { - return nil, 0, errEnded - } - - s.closed.Store(false) - - var err error - var tries int - - if s.ipv4 != nil || s.ipv6 != nil { - log.W("wg: bind: open: %s already open at :%d", s.id, uport) - return nil, 0, conn.ErrBindAlreadyOpen - } - - // Attempt to open ipv4 and ipv6 listeners on the same port. - // If uport is 0, we can retry on failure. -again: - port := int(uport) - var ipv4, ipv6 net.PacketConn - - ipv4, port, err = s.listenNet("udp4", port) - no4 := errors.Is(err, syscall.EAFNOSUPPORT) - log.D("wg: bind: open: %s #%d listen4(%d); no4? %t err? %v", s.id, tries, port, no4, err) - if err != nil && !no4 { - return nil, 0, err - } - - // Listen on the same port as we're using for ipv4. - ipv6, port, err = s.listenNet("udp6", port) - busy := errors.Is(err, syscall.EADDRINUSE) - no6 := errors.Is(err, syscall.EAFNOSUPPORT) - log.D("wg: bind: open: %s #%d listen6(%d); busy? %t no6? %t err? %v", s.id, tries, port, busy, no6, err) - if uport == 0 && busy && tries < maxbindtries { - clos(ipv4) - tries++ - goto again - } - if err != nil && !no6 { - clos(ipv4) - return nil, 0, err - } - - var fns []conn.ReceiveFunc - if ipv4 != nil { - s.ipv4 = ipv4 - fns = append(fns, s.makeReceiveFn(ipv4)) - } - if ipv6 != nil { - s.ipv6 = ipv6 - fns = append(fns, s.makeReceiveFn(ipv6)) - } - - log.I("wg: bind: open: %s opened port(requested %d => using %d) for v4? %t v6? %t", - s.id, uport, port, ipv4 != nil, ipv6 != nil) - if len(fns) == 0 { - return nil, 0, syscall.EAFNOSUPPORT - } - - var eerr error = nil - if s.ended.Load() { - eerr = errEnded - } - return fns, uint16(port), eerr -} - -// Pause implements wgconn -func (s *StdNetBind) Pause() bool { - s.mu.Lock() - defer s.mu.Unlock() - - s.blackhole4 = true - s.blackhole6 = true - - // by the time resume comes about, the internal wireguard send/recv routines may have been stopped - // or the keepalives blackholed for long enough that a new connection needs to be established. - log.I("wg: bind: pr: %s pausing... v4? %t v6? %t", s.id, s.ipv4 != nil, s.ipv6 != nil) - - return true -} - -// Resume implements wgconn -func (s *StdNetBind) Resume() bool { - s.mu.Lock() - defer s.mu.Unlock() - - s.blackhole4 = false - s.blackhole6 = false - - log.I("wg: bind: pr: %s resuming... v4? %t v6? %t", s.id, s.ipv4 != nil, s.ipv6 != nil) - - return true -} - -// Closed implements wgconn -func (s *StdNetBind) Closed() bool { - return s.closed.Load() -} - -func (s *StdNetBind) Close() error { - // Do NOT do a pre-lock s.closed.Load() check here. - // Open() writes s.closed under s.mu; a read outside the lock races with it. - // The CAS below (also under s.mu) is the only correct guard. - s.mu.Lock() - defer s.mu.Unlock() - - if s.closed.CompareAndSwap(false, true) { - var err1, err2 error - var addr1, addr2 net.Addr - v4, v6 := s.ipv4, s.ipv6 - if v4 != nil { - addr1 = v4.LocalAddr() - err1 = v4.Close() - s.ipv4 = nil - } - if v6 != nil { - addr2 = v6.LocalAddr() - err2 = v6.Close() - s.ipv6 = nil - } - // resume if paused, so wireguard routines calling into send/recv error out - s.blackhole4 = false - s.blackhole6 = false - - s.ended.Store(s.observer(Clo, nil)) - - log.I("wg: bind: close: %s addrs %v + %v; err4? %v err6? %v", s.id, addr1, addr2, err1, err2) - return core.JoinErr(err1, err2) - } - log.W("wg: bind: close: %s racing... ignored", s.id) - return nil -} - -func (s *StdNetBind) makeReceiveFn(uc net.PacketConn) conn.ReceiveFunc { - // github.com/WireGuard/wireguard-go/blob/469159ecf/device/device.go#L531 - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { - defer core.Recover(core.Exit11, "wgconn.recv."+s.id) - - anyProcessed := false // true when numMsgs > 0 (ex: no error) - anyTransportTyp := false - defer func() { - op := Rcv - if !anyTransportTyp && anyProcessed { - op = Crc // processed packets not transport data - } - if s.observer(op, err) { - s.ended.Store(true) - } - }() - - amnezia := s.amnezia.Load() - usingamz := amnezia.Set() - overwritten := false - - numMsgs := 0 - b := bufs[0] // usually sized device.MaxMessageSize - - extend(uc, wgtimeout) - n, addr, err := uc.ReadFrom(b) - if err == nil { - b, overwritten = amnezia.recv(b, n) - numMsgs++ - } - - for i := range numMsgs { - anyProcessed = true - if overwritten { - copy(bufs[i], b) - sizes[i] = len(b) - } else { // bufs remained unchanged - sizes[i] = n - } - anyTransportTyp = anyTransportTyp || transportType(bufs[i]) - eps[i] = s.asEndpoint(addr) - } - - if err != nil && !timedout(err) { - log.E("wg: bind: recv: %s recvfrom(%v): %d / ov? %t<=%t / trans? %t / err? %v", - s.id, addr, n, usingamz, overwritten, anyTransportTyp, err) - } else if settings.Debug { - log.V("wg: bind: recv: %s recvfrom(%v): %d / ov? %t<=%t / trans? %t / err? %v", - s.id, addr, n, usingamz, overwritten, anyTransportTyp, err) - } - return numMsgs, err - } -} - -func timedout(err error) bool { - if err == nil { - return false - } - x, ok := err.(net.Error) - return ok && x.Timeout() -} - -func (s *StdNetBind) Send(buf [][]byte, peer conn.Endpoint) (err error) { - defer core.Recover(core.Exit11, "wgconn.send."+s.id) - - anyProcessed := false - anyTransportTyp := false - defer func() { - op := Snd - if !anyTransportTyp && anyProcessed { - op = Csn // processed packet not transport data - } - if s.observer(op, err) { - s.ended.Store(true) - } - }() - - // the peer endpoint - ep, ok := peer.(StdNetEndpoint) - if !ok { - log.E("wg: bind: send: %s wrong endpoint type: %T", s.id, peer) - return conn.ErrWrongEndpointType - } - dstIpp := ep.AddrPort - - s.mu.Lock() - blackhole := s.blackhole4 - uc := s.ipv4 - noconn := uc == nil - if dstIpp.Addr().Is6() { - blackhole = s.blackhole6 - uc = s.ipv6 - noconn = uc == nil - } - s.mu.Unlock() - - var floodWg = settings.FloodWireGuard.Load() - var flooded, overwritten bool - var nn int - var errs error - for _, data := range buf { - bufok := len(data) > 0 - - if settings.Debug { - log.VV("wg: bind: send: %s addr(%v) floodwg? %t, blackhole? %t; noconn? %t; hasbuf? %t", - s.id, dstIpp, floodWg, blackhole, noconn, bufok) - } - - if blackhole || !bufok { - return nil - } - if noconn || uc == nil { - return syscall.EAFNOSUPPORT - } - - amnezia := s.amnezia.Load() - anyProcessed = true - anyTransportTyp = anyTransportTyp || transportType(data) - - datalen := len(data) // grab the length before we overwrite it - - overwritten = amnezia.send(&data) - - if !flooded && (floodWg || amnezia.Set()) { - if datalen == device.MessageInitiationSize { - s.flood(uc, ep, fkHandshake) // was probably a handshake - flooded = true - } else if datalen == device.MessageKeepaliveSize { - s.flood(uc, ep, fkKeepalive) // was probably a keepalive - flooded = true - } - } - - extend(uc, wgtimeout) - n, serr := uc.WriteTo(data, ep.addr) - - errs = core.JoinErr(errs, serr) - nn += n - } - - s.sendAddr.Store(dstIpp) - - loge(err)("wg: bind: send: %s addr(%v) parcels(%d) tx(%d) (flooded? %t (enabled? %t) / overw? %t / trans? %t); err? %v", - s.id, dstIpp, len(buf), nn, flooded, floodWg, overwritten, anyTransportTyp, errs) - - return errs -} - -// flood c with random-sized, non-sense (unencrypted) packets. -// this is okay to do because wireguard silently drops packets that won't decrypt. -// github.com/WireGuard/wireguard-go/blob/19ac233cc6/wireguard/device/send.go#L96 -// github.com/GFW-knocker/wireguard/blob/8bd9f582b4/device/send.go#L98 -func (s *StdNetBind) flood(c net.PacketConn, dst StdNetEndpoint, why floodkind) (int, error) { - return s.floodBa.DoIt(dst.AddrPort, func() (int, error) { - hdrlen := len(wgheader) - hdr := make([]byte, hdrlen) - copy(hdr, wgheader) - - hdr[0] = mlist[mrand.UintN(uint(len(mlist)))] - _, _ = rand.Read(hdr[6:14]) - - tot := max(mrand.Uint64N(maxFloodPkts+1), minFloodPkts) - // go.dev/play/p/NkLihAUTqUO - maxWaitMs := maxFloodDuration.Milliseconds() / int64(tot) - expectedsent := make([]int, tot) - - // gfw-knocker generates much smaller pkts (18 + (10 to 30)) - // github.com/GFW-knocker/wireguard/blob/8bd9f582b4/device/send.go#L141 - padlen := uint64(maxFloodPktLen - hdrlen) - pkt := make([]byte, maxFloodPktLen) - var n int - for i := range tot { - sz := max(mrand.Uint64N(padlen+1), minFloodPktLen) - _, _ = rand.Read(pkt[hdrlen:sz]) - copy(pkt[0:], hdr) - - extend(c, wgtimeout) - sent, err := c.WriteTo(pkt, dst.addr) - - expectedsent[i] = hdrlen + int(sz) - n += sent - - if err != nil { - log.E("wg: bind: flood: %s %s %s: expected sent?(%v) / tot(%d); %v", - s.id, why, dst, expectedsent[:i], n, err) - return n, err - } - - wait := time.Duration(mrand.Int64N(maxWaitMs)) * time.Millisecond - time.Sleep(wait) - } - - if settings.Debug { - log.D("wg: bind: flood %s %s %s: expected sent?(%v) / tot(%d)", - s.id, why, dst, expectedsent, n) - } - return n, nil - }) -} - -func (s *StdNetBind) BatchSize() int { - return 1 -} - -// from: github.com/WireGuard/wireguard-go/blob/1417a47c8/conn/mark_unix.go -func (s *StdNetBind) SetMark(mark uint32) (err error) { - // s.ipv4 and s.ipv6 are written by Open/Close under s.mu; read them here - // under the same lock to avoid a data race. - s.mu.Lock() - uc4, _ := s.ipv4.(core.ControlConn) // may be nil - uc6, _ := s.ipv6.(core.ControlConn) // may be nil - s.mu.Unlock() - var operr error - var raw4, raw6 syscall.RawConn - fwmarkIoctl := 36 /* unix.SO_MARK */ - if uc4 != nil { - if raw4, err = uc4.SyscallConn(); err == nil { - if raw4 == nil { - log.W("wg: bind: mark: %s setmark4: raw conn nil", s.id) - return errNoRawConn - } - if err = raw4.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }); err == nil { - err = operr - } - } // else: return err - } - if err == nil && uc6 != nil { - if raw6, err = uc6.SyscallConn(); err == nil { - if raw6 == nil { - log.W("wg: bind: mark: %s setmark6: raw conn nil", s.id) - return errNoRawConn - } - if err = raw6.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }); err == nil { - err = operr - } - } // else: return err - } - log.I("wg: bind: mark: %s err? %v", s.id, err) - return nil -} - -// asEndpoint returns an Endpoint containing ap. -// pooling disabled due to data race: -// github.com/WireGuard/wireguard-go/commit/334b605e726 -func (s *StdNetBind) asEndpoint(ap net.Addr) conn.Endpoint { - // TODO: avoid allocations - if ap == nil { - return invalidStdNetEndpoint - } - s.epmu.RLock() - ep, ok := s.eps[ap] - s.epmu.RUnlock() - - if ok { - return ep - } - - s.epmu.Lock() - defer s.epmu.Unlock() - ep, ok = s.eps[ap] - if ok { - return ep - } - - if tcp, ok := ap.(*net.TCPAddr); ok { - ipp := tcp.AddrPort() - ep = StdNetEndpoint{ipp, udpaddr(ipp)} - } else if udp, ok := ap.(*net.UDPAddr); ok { - ipp := udp.AddrPort() - ep = StdNetEndpoint{ipp, udpaddr(ipp)} // copy udp addr - } else if ipp, err := netip.ParseAddrPort(ap.String()); err == nil { - ep = StdNetEndpoint{ipp, udpaddr(ipp)} - } - s.eps[ap] = ep // ep may be zero value - return ep -} - -func loge(err error) log.LogFn { - l := log.N // no-op - if err != nil { - l = log.W - } else if settings.Debug { - l = log.V - } - return l -} - -func extend(c core.MinConn, t time.Duration) { - if c != nil && core.IsNotNil(c) { - _ = c.SetDeadline(time.Now().Add(t)) - } -} - -func clos(c io.Closer) { - core.CloseOp(c, core.CopRW) -} - -func transportType(unobs []byte) (y bool) { - return messageType(unobs, device.MessageTransportType) -} - -// messageType reports whether unobs is of type t message. -// "unobs" must be free of Amnezia-like obfuscations. -func messageType(unobs []byte, t uint32) (y bool) { - var typ uint32 - n := len(unobs) - - defer func() { - if settings.Debug && !y { - log.V("wg: bind: messageType: len(%d) msgt(%d) == t(%d)? %t", n, typ, t, y) - } - }() - - if n < device.MinMessageSize { - return - } - - typ = binary.LittleEndian.Uint32(unobs) - y = typ == t - return -} diff --git a/intra/ipn/wg/wgconn2.go b/intra/ipn/wg/wgconn2.go deleted file mode 100644 index 4c6a73da..00000000 --- a/intra/ipn/wg/wgconn2.go +++ /dev/null @@ -1,818 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -package wg - -import ( - "errors" - "fmt" - "net" - "net/netip" - "sync" - "syscall" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn/multihost" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" -) - -// commit: github.com/WireGuard/wireguard-go/commit/3bb8fec7e - -// StdNetBind2 implements Bind for all platforms. -// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable -// methods for sending and receiving multiple datagrams per-syscall. See the -// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. -type StdNetBind2 struct { - mu sync.Mutex // protects following fields - id string - connect connector - observer rwobserver - sendAddr *core.Volatile[netip.AddrPort] // may be invalid - pm *core.Volatile[*multihost.MHMap] // peer ip:port or host => preferred-addrs - amnezia *core.Volatile[*Amnezia] // unused; amnezia/warp config, if any - - ipv4 *net.UDPConn - ipv6 *net.UDPConn - - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux - ipv4TxOffload bool - ipv4RxOffload bool - ipv6TxOffload bool - ipv6RxOffload bool - - // these two fields are not guarded by mu - udpAddrPool sync.Pool - msgsPool sync.Pool - - blackhole4 bool - blackhole6 bool -} - -type StdNetEndpoint2 struct { - // AddrPort is the endpoint destination. - netip.AddrPort - // src is the current source address. - src []byte -} - -type batchReader interface { - ReadBatch(ms []ipv6.Message, flags int) (int, error) -} - -type batchWriter interface { - WriteBatch(ms []ipv6.Message, flags int) (int, error) -} - -type setGSOFunc func(control *[]byte, gsoSize uint16) - -type getGSOFunc func(control []byte) (int, error) - -type ErrUDPGSODisabled struct { - onLaddr string - RetryErr error // err, if any, on retry; may be nil -} - -const ( - // Exceeding these values results in EMSGSIZE. They account for layer3 and - // layer4 headers. IPv6 does not need to account for itself as the payload - // length field is self excluding. - maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 - maxIPv6PayloadLen = 1<<16 - 1 - 8 - - // this is a hard limit imposed by the kernel. - udpSegmentMaxDatagrams = 64 - - // github.com/WireGuard/wireguard-go/blob/12269c276/device/queueconstants_android.go#L13 - IdealBatchSize = conn.IdealBatchSize -) - -var ( - zeroaddr net.Addr = &net.UDPAddr{} - zeroaddrport = netip.AddrPort{} - - // If compilation fails here these are no longer the same underlying type. - _ ipv6.Message = ipv4.Message{} - - _ conn.Bind = (*StdNetBind2)(nil) - _ conn.Endpoint = &StdNetEndpoint2{} -) - -var errSplitOverflow = errors.New("wg: splitting coalesced packet resulted in overflow") - -func (e ErrUDPGSODisabled) Error() string { - return fmt.Sprintf("disabled udp gso on %s, NIC(s) may not support checksum offload", e.onLaddr) -} - -func (e ErrUDPGSODisabled) Unwrap() error { - return e.RetryErr -} - -func NewEndpoint2(id string, d connector, pm *core.Volatile[*multihost.MHMap], f rwobserver, a *core.Volatile[*Amnezia]) *StdNetBind2 { - return &StdNetBind2{ - id: id, - connect: d, - observer: f, - pm: pm, - amnezia: a, - - udpAddrPool: sync.Pool{ - New: func() any { - return &net.UDPAddr{ - IP: make([]byte, 16), - } - }, - }, - - sendAddr: core.NewZeroVolatile[netip.AddrPort](), - - msgsPool: sync.Pool{ - New: func() any { - // ipv6.Message and ipv4.Message are interchangeable as they are - // both aliases for x/net/internal/socket.Message. - msgs := make([]ipv6.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) - } - return &msgs - }, - }, - } -} - -func (e *StdNetBind2) ParseEndpoint(s string) (conn.Endpoint, error) { - /*host, portstr, err := net.SplitHostPort(s) - if err != nil { - log.E("wg: bind2: %s invalid endpoint in(%s); err: %v", e.id, s, err) - return nil, err - } - port, err := strconv.Atoi(portstr) - if err != nil { - log.E("wg: bind2: %s invalid port in(%s); err: %v", e.id, s, err) - return nil, err - }*/ - // d.Add([]string{host}) // resolves host if needed - d, err := e.pm.Load().Get(s) - if err != nil || d == nil /*nilaway; can't happen*/ { - log.E("wg: bind2: %s parse: invalid endpoint in(%s); err: %v", e.id, s, err) - return nil, err - } - - // do what tailscale does, and share a preferred endpoint regardless of "s" - // github.com/tailscale/tailscale/blob/3a6d3f1a5b7/wgengine/magicsock/magicsock.go#L2568 - ipport := d.PreferredAddr() - if !ipport.IsValid() || ipport.Addr().IsUnspecified() { - log.E("wg: bind2: %s parse: invalid endpoint addr %v in(%s); out(%s, %s)", e.id, ipport, s, d.Names(), d.Addrs()) - // erroring out from here prevents PostConfig (handshake for this peer endpoint will always be zero) - // github.com/WireGuard/wireguard-go/blob/12269c276173/device/uapi.go#L183 - return nil, errInvalidEndpoint - } - - log.I("wg: bind2: %s new endpoint for %s, %v", e.id, s, ipport) - return asEndpoint2(ipport), nil -} - -func (e *StdNetEndpoint2) ClearSrc() { - if len(e.src) > 0 { - // truncate src, no need to reallocate. - e.src = e.src[:0] - } -} - -func (e *StdNetEndpoint2) DstIP() netip.Addr { - return e.AddrPort.Addr() -} - -// See sosticky for implementations of SrcIP and SrcIfidx. - -func (e *StdNetEndpoint2) DstToBytes() []byte { - b, _ := e.AddrPort.MarshalBinary() - return b -} - -func (e *StdNetEndpoint2) DstToString() string { - return e.AddrPort.String() -} - -func (s *StdNetBind2) RemoteAddr() netip.AddrPort { - return s.sendAddr.Load() -} - -func (s *StdNetBind2) listenNet(network string, port int) (*net.UDPConn, int, error) { - var anyaddr netip.Addr - if network == "udp4" { - anyaddr = netip.IPv4Unspecified() - } else { - anyaddr = netip.IPv6Unspecified() - } - saddr := net.JoinHostPort(anyaddr.String(), fmt.Sprintf("%d", port)) - - conn, err := s.connect(network, saddr) - if err != nil { - log.E("wg: bind2: %s %s: listen(%v); err: %v", s.id, network, saddr, err) - return nil, 0, err - } - if conn == nil { - log.E("wg: bind2: %s %s: listen(%v); conn nil", s.id, network, saddr) - return nil, 0, errNoListen - } - - laddr := conn.LocalAddr() - if laddr == nil { - log.E("wg: bind2: %s %s: listen(%v); local-addr nil", s.id, network, saddr) - return nil, 0, errNoLocalAddr - } - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - log.E("wg: bind2: %s %s: listen(%v); resolve-addr err: %v", s.id, network, saddr, err) - return nil, 0, err - } - if uaddr == nil { - log.E("wg: bind2: %s %s: listen(%v); resolve-addr nil", s.id, network, saddr) - return nil, 0, errNoLocalAddr - } - if settings.Debug { - log.VV("wg: bind2: %s %s: listen(%v)", s.id, network, laddr) - } - // typecast is safe, because "network" is always udp[4|6]; see: Open - if udpconn, ok := conn.(*net.UDPConn); ok { - return udpconn, uaddr.Port, nil - } else { - clos(conn) - return nil, 0, errNotUDP - } -} - -func (s *StdNetBind2) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { - s.mu.Lock() - defer s.mu.Unlock() - - var err error - var tries int - - if s.ipv4 != nil || s.ipv6 != nil { - log.W("wg: bind2: %s already open", s.id) - return nil, 0, conn.ErrBindAlreadyOpen - } - - // Attempt to open ipv4 and ipv6 listeners on the same port. - // If uport is 0, we can retry on failure. -again: - port := int(uport) - var v4conn, v6conn *net.UDPConn - - // v4 - v4conn, port, err = s.listenNet("udp4", port) - no4 := errors.Is(err, syscall.EAFNOSUPPORT) - loge(err)("wg: bind2: %s #%d: listen4(%d); no4? %t err? %v", s.id, tries, port, no4, err) - if err != nil && !no4 { - return nil, 0, err - } - - // v6: Listen on the same port as we're using for ipv4. - v6conn, port, err = s.listenNet("udp6", port) - busy := errors.Is(err, syscall.EADDRINUSE) - no6 := errors.Is(err, syscall.EAFNOSUPPORT) - loge(err)("wg: bind2: %s #%d listen6(%d); busy? %t no6? %t err? %v", s.id, tries, port, busy, no6, err) - if uport == 0 && busy && tries < maxbindtries { - clos(v4conn) - tries++ - goto again - } - if err != nil && !no6 { - clos(v4conn) - return nil, 0, err - } - - canBatch := supportsBatchRw() - - var fns []conn.ReceiveFunc - if v4conn != nil { - s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) - if canBatch { - s.ipv4PC = ipv4.NewPacketConn(v4conn) - } - s.ipv4 = v4conn - fns = append(fns, s.makeReceiveIPv4()) - } - if v6conn != nil { - s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) - if canBatch { - s.ipv6PC = ipv6.NewPacketConn(v6conn) - } - s.ipv6 = v6conn - fns = append(fns, s.makeReceiveIPv6()) - } - - log.I("wg: bind2: %s supports batch read/write? %t; has4? %t; has6 %t", s.id, canBatch, s.ipv4PC != nil, s.ipv6PC != nil) - log.I("wg: bind2: %s opened port(%d) for v4? %t / v6? %t", s.id, port, v4conn != nil, v6conn != nil) - - if len(fns) == 0 { - log.W("wg: bind2: %s no listeners", s.id) - return nil, 0, syscall.EAFNOSUPPORT - } - - return fns, uint16(port), nil -} - -func (s *StdNetBind2) receiveIP( - br batchReader, - conn *net.UDPConn, - rxOffload bool, - bufs [][]byte, - sizes []int, - eps []conn.Endpoint, -) (n int, err error) { - defer func() { - s.observer(Rcv, err) - }() - - if conn == nil && br == nil { - log.E("wg: bind2: %s receiveIP: no conns hasbatch? %t; hasconn? %t", s.id, br != nil, conn != nil) - return 0, syscall.EINVAL - } - - msgs := s.getMessages() - defer s.putMessages(msgs) - if msgs == nil || len(*msgs) <= 0 { // unlikely - log.E("wg: bind2: %s no messages", s.id) - return 0, syscall.ENOMEM - } - - for i := range bufs { - if i >= len(*msgs) { // unlikely as IdealBatchSize is a hard limit - log.E("wg: bind2: %s receiveIP: limit: %d; too many messages (%d)", s.id, len(*msgs), len(bufs)) - // TODO: process bufs in next batch? - break - } - msg := &(*msgs)[i] - msg.Buffers[0] = bufs[i] - msg.OOB = msg.OOB[:cap(msg.OOB)] - } - - batch := false - var numMsgs int - if br != nil { - if rxOffload { - readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) - numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) - waddr := msgAddr(msgs) - if err != nil { - log.E("wg: bind2: %s GRO: readAt(%d) addr(%v) numMsgs(%d) err(%v)", s.id, readAt, waddr, numMsgs, err) - return 0, err - } - - numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) - if err != nil { - log.E("wg: bind2: %s GRO: splitCoalescedMessages(at: %d; from: %v) numMsgs(%d) err(%v)", s.id, readAt, waddr, numMsgs, err) - return 0, err - } - } else { - batch = true - numMsgs, err = br.ReadBatch(*msgs, 0) - if err != nil { - log.E("wg: bind2: %s ReadBatch(sz: %d; from: %s) numMsgs(%d) err(%v)", s.id, len(*msgs), msgAddr(msgs), numMsgs, err) - return 0, err - } - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - log.E("wg: bind2: %s ReadMsgUDP(sz: %d; from: %v) err(%v)", s.id, msg.N, msg.Addr, err) - return 0, err - } - numMsgs = 1 - } - // TODO: loop not needed for non-Linux as getSrcFromControl is a no-op - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - if sizes[i] == 0 && settings.Debug { - log.V("wg: bind2: %s zero-sized message from %v", s.id, msg.Addr) - continue - } - uaddr, ok := msg.Addr.(*net.UDPAddr) - if !ok { // unlikely - log.E("wg: bind2: %s invalid addr type %T %v", s.id, msg.Addr, msg.Addr) - continue - } - ep := &StdNetEndpoint2{AddrPort: uaddr.AddrPort()} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) // no-op on Android - eps[i] = ep - } - if settings.Debug { - log.VV("wg: bind2: %s received (batch? %t, offload? %t) %d messages", s.id, batch, rxOffload, numMsgs) - } - return numMsgs, nil -} - -func (s *StdNetBind2) makeReceiveIPv4() conn.ReceiveFunc { - // assign on stack to avoid closure related nil checks - rawc := s.ipv4PC - c := s.ipv4 - offload := s.ipv4RxOffload - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { - return s.receiveIP(rawc, c, offload, bufs, sizes, eps) - } -} - -func (s *StdNetBind2) makeReceiveIPv6() conn.ReceiveFunc { - // assign on stack to avoid closure related nil checks - rawc := s.ipv6PC - c := s.ipv6 - offload := s.ipv6RxOffload - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { - return s.receiveIP(rawc, c, offload, bufs, sizes, eps) - } -} - -func (s *StdNetBind2) putMessages(msgs *[]ipv6.Message) { - for i := range *msgs { - // TODO: msg.Buffers[0] = msg.Buffers[0][:0] - msg := &(*msgs)[i] - msg.OOB = (*msgs)[i].OOB[:0] - *msg = ipv6.Message{Buffers: msg.Buffers, OOB: msg.OOB} - } - s.msgsPool.Put(msgs) -} - -func (s *StdNetBind2) getMessages() *[]ipv6.Message { - m, ok := s.msgsPool.Get().(*[]ipv6.Message) - if !ok { // unlikely - log.W("wg: bind2: %s failed to get from msgpool", s.id) - x := make([]ipv6.Message, IdealBatchSize) - m = &x - } - return m -} - -func (s *StdNetBind2) putUDPAddr(ua *net.UDPAddr) { - s.udpAddrPool.Put(ua) -} - -func (s *StdNetBind2) getUDPAddr() *net.UDPAddr { - ua, ok := s.udpAddrPool.Get().(*net.UDPAddr) - if !ok { // unlikely - log.W("wg: bind2: %s failed to get from udpAddrPool", s.id) - ua = &net.UDPAddr{IP: make([]byte, 16)} - } - return ua -} - -// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and -// rename the IdealBatchSize constant to BatchSize. -func (s *StdNetBind2) BatchSize() int { - if supportsBatchRw() { - return IdealBatchSize - } - return 1 -} - -func (s *StdNetBind2) Pause() bool { - s.mu.Lock() - defer s.mu.Unlock() - - s.blackhole4 = true - s.blackhole6 = true - - log.I("wg: bind2: %s pausing... v4? %t, v6? %t", s.id, s.ipv4 != nil, s.ipv6 != nil) - return true -} - -func (s *StdNetBind2) Resume() bool { - s.mu.Lock() - defer s.mu.Unlock() - - s.blackhole4 = false - s.blackhole6 = false - - log.I("wg: bind2: %s resuming... v4? %t, v6? %t", s.id, s.ipv4 != nil, s.ipv6 != nil) - - return true -} - -func (s *StdNetBind2) Closed() bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.ipv4 == nil && s.ipv6 == nil -} - -func (s *StdNetBind2) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - - var err4, err6 error - c4, c6 := s.ipv4, s.ipv6 - if c4 != nil { - err4 = c4.Close() - s.ipv4 = nil - s.ipv4PC = nil - } - if c6 != nil { - err6 = c6.Close() - s.ipv6 = nil - s.ipv6PC = nil - } - s.blackhole4 = false - s.blackhole6 = false - s.ipv4TxOffload = false - s.ipv4RxOffload = false - s.ipv6TxOffload = false - s.ipv6RxOffload = false - - log.I("wg: bind2: %s close; err4? %v err6? %v", s.id, err4, err6) - - return core.JoinErr(err4, err6) -} - -func (s *StdNetBind2) Send(bufs [][]byte, peer conn.Endpoint) (err error) { - defer func() { - target := &ErrUDPGSODisabled{} - if errors.As(err, target) { - s.observer(Snd, target.Unwrap()) - } else { - s.observer(Snd, err) - } - }() - - ep, ok := peer.(*StdNetEndpoint2) - if !ok { // unlikely - log.E("wg: bind2: %s wrong endpoint type %T", s.id, peer) - return conn.ErrWrongEndpointType - } - - s.mu.Lock() - blackhole := s.blackhole4 - offload := s.ipv4TxOffload - c := s.ipv4 - var br batchWriter = s.ipv4PC - is6 := false - if peer.DstIP().Is6() { - blackhole = s.blackhole6 - offload = s.ipv6TxOffload - c = s.ipv6 - br = s.ipv6PC - is6 = true - } - s.mu.Unlock() - - if blackhole { - return nil - } - if c == nil { - return syscall.EAFNOSUPPORT - } - - msgs := s.getMessages() // from msgspool - defer s.putMessages(msgs) - - if msgs == nil || len(*msgs) <= 0 { - log.E("wg: bind2: %s no messages", s.id) - return syscall.ENOMEM - } - - dst := addrport(peer, !is6) - if !addrok(dst.Addr()) { - log.E("wg: bind2: %s invalid destination %v", s.id, dst) - return syscall.EINVAL - } - - ua := s.getUDPAddr() // from udpAddrPool - defer s.putUDPAddr(ua) - *ua = *net.UDPAddrFromAddrPort(dst) - s.sendAddr.Store(dst) - - var retried bool -retry: - if offload { - n := coalesceMessages(ua, ep, bufs, *msgs, setGSOSize) - // send coalesced msgs; ie, len(*msgs) <= len(bufs) - if err = s.send(c, br, (*msgs)[:n], "gso"); err != nil { - log.E("wg: bind2: %s gso: send(%d/%d) to %s; err(%v)", s.id, n, len(bufs), ua, err) - } - - if shouldDisableUDPGSOOnError(err) { // err may be nil - offload = false - s.mu.Lock() - if is6 { - s.ipv6TxOffload = false - } else { - s.ipv4TxOffload = false - } - s.mu.Unlock() - retried = true - log.E("wg: bind2: %s gso: disabled on %s / v4? %t; err %v", s.id, ua, !is6, err) - goto retry - } - } else { - for i := range bufs { - msg := &(*msgs)[i] - // TODO: msg.N = len(bufs[i]) - msg.Addr = ua - msg.Buffers[0] = bufs[i] - setSrcControl(&msg.OOB, ep) // no-op on Android - } - // send all msgs - if err = s.send(c, br, (*msgs)[:len(bufs)], "batch"); err != nil { - log.E("wg: bind2: %s gso: send(%d) to %s (retry? %t); err(%v)", s.id, len(bufs), ua, retried, err) - } - } - if retried { - x := zeroaddr - if a := c.LocalAddr(); a != nil { - x = a - } - log.W("wg: bind2: %s disabled udp gso on %s; err %v", s.id, x, err) - return ErrUDPGSODisabled{onLaddr: x.String(), RetryErr: err} - } - return err -} - -func (s *StdNetBind2) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message, who string) (err error) { - var n, start int - - batch := pc != nil && core.IsNotNil(pc) - if batch { - for { - n, err = pc.WriteBatch(msgs[start:], 0) - if err != nil || n == len(msgs[start:]) { - break - } - start += n - } - } else { - for _, msg := range msgs { - addr, ok := msg.Addr.(*net.UDPAddr) - if !ok { // unlikely - log.E("wg: bind2: %s send: %s wrong addr type %s %T", s.id, who, msg.Addr, msg.Addr) - continue - } - _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, addr) - if err != nil { - log.E("wg: bind2: %s send: %s to %v; err %v", s.id, who, addr, err) - break - } - n += 1 - } - } - if err != nil { - log.E("wg: bind2: %s send: %s batch? %t; n(%d); err? %v", s.id, who, batch, n, err) - } else { - log.V("wg: bind2: %s send: %s batch? %t; n(%d); err? %v", s.id, who, batch, n, err) - } - return err -} - -// asEndpoint2 returns an Endpoint containing ap. -// pooling disabled due to data race: -// github.com/WireGuard/wireguard-go/commit/334b605e726 -func asEndpoint2(ap netip.AddrPort) *StdNetEndpoint2 { - return &StdNetEndpoint2{ - AddrPort: ap, - } -} - -// from: github.com/WireGuard/wireguard-go/blob/1417a47c8/conn/mark_unix.go -func (s *StdNetBind2) SetMark(mark uint32) (err error) { - // no-op for now - return nil -} - -func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint2, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { - var ( - base = -1 // index of msg we are currently coalescing into - gsoSize int // segmentation size of msgs[base] - dgramCnt int // number of dgrams coalesced into msgs[base] - endBatch bool // tracking flag to start a new batch on next iteration of bufs - ) - maxPayloadLen := maxIPv4PayloadLen - if ep.DstIP().Is6() { - maxPayloadLen = maxIPv6PayloadLen - } - for i, buf := range bufs { - if i > 0 { - curmsg := &msgs[base] - msgLen := len(buf) - baseLenBefore := len(curmsg.Buffers[0]) - freeBaseCap := cap(curmsg.Buffers[0]) - baseLenBefore - if !endBatch && - msgLen+baseLenBefore <= maxPayloadLen && - msgLen <= gsoSize && - msgLen <= freeBaseCap && - dgramCnt < udpSegmentMaxDatagrams { - curmsg.Buffers[0] = append(curmsg.Buffers[0], buf...) - if i == len(bufs)-1 { - setGSO(&curmsg.OOB, uint16(gsoSize)) - } - dgramCnt++ - if msgLen < gsoSize { - // A smaller than gsoSize packet on the tail is legal, but - // it must end the batch. - endBatch = true - } - continue - } - } - if dgramCnt > 1 { - setGSO(&msgs[base].OOB, uint16(gsoSize)) - } - // Reset prior to incrementing base since we are preparing to start a - // new potential batch. - endBatch = false - base++ - gsoSize = len(buf) - nextmsg := &msgs[base] - setSrcControl(&nextmsg.OOB, ep) // no-op on Android - nextmsg.Buffers[0] = buf - nextmsg.Addr = addr - dgramCnt = 1 - } - return base + 1 -} - -func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { - for i := firstMsgAt; i < len(msgs); i++ { - msg := &msgs[i] - if msg.N == 0 { - return n, err - } - var ( - gsoSize int - start int - end = msg.N - numToSplit = 1 - ) - gsoSize, err = getGSO(msg.OOB[:msg.NN]) - if err != nil { - return n, err - } - if gsoSize > 0 { - numToSplit = (msg.N + gsoSize - 1) / gsoSize - end = gsoSize - } - for j := 0; j < numToSplit; j++ { - if n > i { - return n, errSplitOverflow - } - copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) - msgs[n].N = copied - msgs[n].Addr = msg.Addr - start = end - end += gsoSize - if end > msg.N { - end = msg.N - } - n++ - } - if i != n-1 { - // It is legal for bytes to move within msg.Buffers[0] as a result - // of splitting, so we only zero the source msg len when it is not - // the destination of the last split operation above. - msg.N = 0 - } - } - return n, nil -} - -func msgAddr(msgs *[]ipv6.Message) net.Addr { - if msgs == nil || len(*msgs) <= 0 { - return zeroaddr - } - return (*msgs)[0].Addr -} - -func addrport(ep conn.Endpoint, as4 bool) netip.AddrPort { - if a, ok := ep.(*StdNetEndpoint2); ok { - if as4 { - addr4 := a.AddrPort.Addr().Unmap() - return netip.AddrPortFrom(addr4, a.Port()) - } else { - addr6 := a.AddrPort.Addr() - return netip.AddrPortFrom(addr6, a.Port()) - } - } - return zeroaddrport -} - -func addrok(a netip.Addr) bool { - return a.IsValid() || !a.IsUnspecified() -} diff --git a/intra/ipn/wgnet.go b/intra/ipn/wgnet.go deleted file mode 100644 index 82bd88b4..00000000 --- a/intra/ipn/wgnet.go +++ /dev/null @@ -1,285 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -// from: github.com/WireGuard/wireguard-go/blob/5819c6af/tun/netstack/tun.go - -package ipn - -import ( - "context" - "errors" - "net" - "net/netip" - "strconv" - "strings" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" -) - -var ( - errNoSuchHost = errors.New("no such host") - errNumericPort = errors.New("port must be numeric") - errNoSuitableAddress = errors.New("no suitable address found") - errMissingAddress = errors.New("missing address") -) - -// intra/tcp expects dst conns to confirm to core.TCPConn -var _ core.TCPConn = (*gonet.TCPConn)(nil) - -// intra/udp expects dst conns to confirm to core.UDPConn -var _ core.UDPConn = (*gonet.UDPConn)(nil) - -// -------------------------------------------------------------------- -// dns dialer -// -------------------------------------------------------------------- - -func (tnet *wgtun) LookupContextHost(ctx context.Context, host string) ([]netip.Addr, error) { - if len(host) <= 0 || (!tnet.hasV6.Load() && !tnet.hasV4.Load()) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - zlen := len(host) - if strings.IndexByte(host, ':') != -1 { - if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { - zlen = zidx - } - } - if ip, err := netip.ParseAddr(host[:zlen]); err == nil { - return []netip.Addr{ip}, nil - } - - // dialers.Resolve returns from cache (which may be stale) - if ips, err := dialers.Resolve(host, tnet.ID()); len(ips) <= 0 { - if err == nil { - err = errNoSuchHost - } - return nil, &net.DNSError{Err: err.Error(), Name: host, IsNotFound: true} - } else { - log.D("wg: %s dns: dial: lookup succeeded %q: %v", tnet.id, host, ips) - return ips, nil - } -} - -// -------------------------------------------------------------------- -// generic dialer -// -------------------------------------------------------------------- - -func (tnet *wgtun) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return tnet.dial(ctx, network, "", address) -} - -func (tnet *wgtun) dial(ctx context.Context, network, local, remote string) (net.Conn, error) { - var acceptV4, acceptV6 bool - switch network { - case "tcp", "udp", "ping": - acceptV4 = true - acceptV6 = true - case "tcp4", "udp4", "ping4": - acceptV4 = true - case "tcp6", "udp6", "ping6": - acceptV6 = true - default: - log.W("wg: %s dial: unknown network %q for %s => %s", tnet.id, network, local, remote) - return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} - } - - var host string - var port int - if network == "ping" || network == "ping4" || network == "ping6" { - host = remote - } else { - var sport string - var err error - host, sport, err = net.SplitHostPort(remote) - if err != nil { - log.W("wg: %s dial: invalid address %q: %v", tnet.id, remote, err) - return nil, &net.OpError{Op: "dial", Err: err} - } - port, err = strconv.Atoi(sport) - if err != nil || port < 0 || port > 65535 { - log.W("wg: %s dial: invalid port %q: %v", tnet.id, sport, err) - return nil, &net.OpError{Op: "dial", Err: errNumericPort} - } - } - - // allAddrs may be nil but shouldn't be when reserr is not nil - allAddrs, reserr := tnet.LookupContextHost(ctx, host) - if reserr != nil { - log.W("wg: %s dial: lookup failed %q: %v", tnet.id, host, reserr) - return nil, &net.OpError{Op: "dial", Err: reserr} - } - - var addrs []netip.AddrPort - for _, ip := range allAddrs { - if (ip.Is4() && acceptV4) || (ip.Is6() && acceptV6) { - addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) - } - } - if len(addrs) == 0 && len(allAddrs) != 0 { - log.W("wg: %s dial: no suitable address for %q / %v", tnet.id, host, allAddrs) - return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} - } - - var laddr4, laddr6 netip.AddrPort - if _, port, err := net.SplitHostPort(local); err == nil { - portno, _ := strconv.Atoi(port) - laddr4 = netip.AddrPortFrom(anyaddr4, uint16(portno)) - laddr6 = netip.AddrPortFrom(anyaddr6, uint16(portno)) - } - - var errs error - for i, raddr := range addrs { - laddr := laddr6 // laddr6 may be invalid - if raddr.Addr().Is4() { - laddr = laddr4 // laddr4 may be invalid - } - var c net.Conn - var err error - switch network { - case "tcp", "tcp4", "tcp6": - c, err = tnet.DialTCPAddrPort(ctx, laddr, raddr) - case "udp", "udp4", "udp6": - c, err = tnet.DialUDPAddrPort(laddr, raddr) - case "ping", "ping4", "ping6": - c, err = tnet.DialPing(laddr, raddr) - } - log.I("wg: %s dial: %s: #%d %s => %s", tnet.id, network, i, laddr, raddr) - if err == nil { - dialers.Confirm(host, raddr.Addr()) - return c, nil - } - dialers.Disconfirm(host, raddr.Addr()) - errs = core.JoinErr(errs, err) - } - errs = core.OneErr(errs, errMissingAddress) - log.W("wg: %s dial: %s: %s failed: %v", tnet.id, network, addrs, errs) - return nil, errs -} - -// -------------------------------------------------------------------- -// tcp and udp dialers -// -------------------------------------------------------------------- - -func fullAddrFrom(by string, ipp netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, bool) { - var protoNumber tcpip.NetworkProtocolNumber - var nsdaddr tcpip.Address - if !ipp.IsValid() { - // TODO: use unspecified address like in PingConn? - return tcpip.FullAddress{}, 0, false - } - if ipp.Addr().Is4() { - protoNumber = ipv4.ProtocolNumber - nsdaddr = tcpip.AddrFrom4(ipp.Addr().As4()) - } else { - protoNumber = ipv6.ProtocolNumber - nsdaddr = tcpip.AddrFrom16(ipp.Addr().As16()) - } - log.VV("wg: dial: %s translate ipp: %v => %v", by, ipp, nsdaddr) - return tcpip.FullAddress{ - NIC: wgnic, - Addr: nsdaddr, - Port: ipp.Port(), // may be 0 - }, protoNumber, true -} - -func (tnet *wgtun) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { - if faddr, protocol, ok := fullAddrFrom("tcp", addr); ok { - return gonet.DialContextTCP(ctx, tnet.stack, faddr, protocol) - } - log.W("wg: %s: tcp: dial: invalid addr %s", tnet.id, addr) - return nil, errInvalidAddr -} - -func (tnet *wgtun) DialTCPAddrPort(ctx context.Context, laddr, raddr netip.AddrPort) (*gonet.TCPConn, error) { - remote, protocol, _ := fullAddrFrom("tcp:remote", raddr) // prefer "proto" from remote - local, _, _ := fullAddrFrom("tcp:local", laddr) - // return gonet.DialTCP(tnet.stack, remote, protocol) - return gonet.DialTCPWithBind( - ctx, - tnet.stack, - local, // may be zero value - remote, // should not be zero value - protocol, - ) -} - -func (tnet *wgtun) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { - if fa, pn, ok := fullAddrFrom("tcp:listen", addr); ok { - return gonet.ListenTCP(tnet.stack, fa, pn) - } - log.W("wg: %s: tcp: listen: invalid addr %s", tnet.id, addr) - return nil, errInvalidAddr -} - -func (tnet *wgtun) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { - var src, dst *tcpip.FullAddress - var protocol tcpip.NetworkProtocolNumber - - if srcaddr, srcprotocol, ok := fullAddrFrom("udp:local", laddr); ok { - protocol = srcprotocol - if !srcaddr.Addr.Unspecified() { - src = &srcaddr - } // else: unbound; src must be left nil - } // else: laddr not valid - if dstaddr, dstprotocol, ok := fullAddrFrom("udp:remote", raddr); ok { - protocol = dstprotocol - if !dstaddr.Addr.Unspecified() { - dst = &dstaddr - } // else: unconnected; dst must be left nil - } // else: raddr not valid - - // iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml - if protocol == 0 { // gonet.DialUDP panics on unsupported protos - log.W("wg: %s: udp: dial: zero proto; %s => %s", tnet.id, laddr, raddr) - return nil, errInvalidAddr - } - - // if src is non-nil, addrs are acquired on wgnic; - // ep.Bind => ep.BindAndThen => ep.net.BindAndThen => ep.checkV4Mapped - // github.com/google/gvisor/blob/932d9dc6/pkg/tcpip/stack/addressable_endpoint_state.go#L644 - return gonet.DialUDP(tnet.stack, src, dst, protocol) -} - -func (tnet *wgtun) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { - return tnet.DialUDPAddrPort(laddr, netip.AddrPort{}) -} - -func (tnet *wgtun) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { - var src, dst netip.AddrPort - if laddr != nil { - src = laddr.AddrPort() - } - if raddr != nil { - dst = raddr.AddrPort() - } - - return tnet.DialUDPAddrPort(src, dst) -} - -func (tnet *wgtun) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { - return tnet.DialUDP(laddr, nil) -} - -func (tnet *wgtun) ListenPing(laddr netip.Addr) (*netstack.GICMPConn, error) { - return netstack.DialPingAddr(tnet.stack, wgnic, laddr, netip.Addr{}) -} - -func (tnet *wgtun) DialPing(local, remote netip.AddrPort) (*netstack.GICMPConn, error) { - return netstack.DialPingAddr(tnet.stack, wgnic, local.Addr(), remote.Addr()) -} diff --git a/intra/ipn/wgproxy.go b/intra/ipn/wgproxy.go deleted file mode 100644 index d7cea7a8..00000000 --- a/intra/ipn/wgproxy.go +++ /dev/null @@ -1,1930 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - -// from: github.com/WireGuard/wireguard-go/blob/5819c6af/tun/netstack/tun.go - -package ipn - -import ( - "bufio" - "context" - "encoding/base64" - "encoding/binary" - "errors" - "fmt" - "net" - "net/netip" - "os" - "strconv" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn/multihost" - "github.com/celzero/firestack/intra/ipn/wg" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" - - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - // github.com/WireGuard/wireguard-go/blob/12269c276/device/queueconstants_android.go#L14 - // epsize is the size of the channel endpoint. - epsize = 4096 - // eventssize is the size of the events channel. - eventssize = 64 - // wgnic is the id of the WireGuard network interface. - wgnic = 999 - // missing wg interface address. - noaddr = "" - // min mtu for ipv6 - minmtu6 = 1280 - // min mtu for ipv4 - minmtu4 = 576 - - pingThresholdMillis = 5 * 1000 // 5s - arbitraryWaitForViaHandshake = 5 * time.Second - markTNTAfterMillis = 20 * 1000 // TNT after 20s of no rcv after snd - - removeViaOnErrors = false - resetDeviceOnTNT = false - // reset device if it is in TUP state (resuming...) - resetDeviceOnTUP = false - - FAST = x.WGFAST - - refreshInterval = 2 * time.Minute // refresh interval between onNotOKs - minRefreshInterval = 5 * time.Second // hard refresh interval; roughly one re-send handshake timeout - - noviaid = "" -) - -type wgifopts struct { - ifaddrs, allowed []netip.Prefix - willreplacepeers bool - peers map[string]device.NoisePublicKey - dns *multihost.MH - eps *multihost.MHMap - mtu int - amnezia *wg.Amnezia -} - -type wgtun struct { - ctx context.Context - done context.CancelFunc - - id string // id - - addrs []netip.Prefix // interface addresses - stack *stack.Stack // stack fakes tun device for wg - ep *channel.Endpoint // reads and writes packets to/from stack - ingress chan *buffer.View // pipes ep writes to wg - events chan tun.Event // wg specific tun (interface) events - finalize chan struct{} // close signal for incomingPacket - once sync.Once // closer fn; exec exactly once - preferOffload bool // UDP GRO/GSO offloads - since int64 // start time in unix millis - - px ProxyProvider - - // mutable fields - - via *core.WeakRef[Proxy] - viaID *core.Volatile[string] - viaUp *core.Volatile[bool] // using via? - direct protect.RDialer - - hasV4, hasV6 atomic.Bool // interface has ipv4/ipv6 routes? - - ignoreTUNClose atomic.Bool // set when re-using existing wgtun+wgep but with a new Device - - desiredmtu atomic.Uint32 // desired mtu - netmtu atomic.Uint32 // underlay network mtu - - rev *core.Volatile[netstack.GConnHandler] // reverser for local packets - - dns *core.Volatile[*multihost.MH] // dns resolver for this interface - remote *core.Volatile[*multihost.MHMap] // peer (remote endpoint) addrs - amnezia *core.Volatile[*wg.Amnezia] // amnezia/warp config, if any - - rt x.IpTree // route table for this interface - - uapicfg *core.Volatile[string] // stores the last UAPI-formatted peer config - - refreshBa *core.Barrier[bool, string] // 2mins refresh barrier - - // TODO: move status to a state-machine for all proxies - status *core.Volatile[int] // status of this interface - statusReason core.Volatile[string] // last state transition reason - latestRefresh atomic.Int64 // last refresh time in unix millis - latestPing atomic.Int64 // last ping time in unix millis - latestErr core.Volatile[error] // last open/dial err - latestRxErr core.Volatile[error] // last rx error - latestTxErr core.Volatile[error] // last tx error - latestRead atomic.Int64 // last read time in unix millis - latestWrite atomic.Int64 // last write time in unix millis - latestGoodRead atomic.Int64 // last successful read time in unix millis - latestGoodWrite atomic.Int64 // last successful write time in unix millis - latestGoodRx atomic.Int64 // last successful rx time in unix millis - latestGoodTx atomic.Int64 // last successful tx time in unix millis - latestRx atomic.Int64 // last (successful or not) rx time in unix millis - latestTx atomic.Int64 // last (successful or not) tx time in unix millis - errRx atomic.Int64 // rx error count - errTx atomic.Int64 // tx error count -} - -type wgconn interface { - conn.Bind - RemoteAddr() netip.AddrPort - Pause() bool - Resume() bool - Closed() bool -} - -var _ WgProxy = (*wgproxy)(nil) - -type wgproxy struct { - *wgtun // implements Proxy and tun.Device - *device.Device // administers tun.Device and conn.Bind - wgep wgconn // implements conn.Bind via wgconn -} - -type WgProxy interface { - Proxy - tun.Device - update(id, txt string) (updated bool) -} - -// Handle implements Proxy. -func (h *wgproxy) Handle() uintptr { - return core.Loc(h) -} - -// DialerHandle implements Proxy. -func (h *wgproxy) DialerHandle() uintptr { - via, up := h.getViaWithStatus() - if up { - return via.Handle() - } - return core.Loc(h.direct) -} - -// Dial implements Proxy. -func (h *wgproxy) Dial(network, address string) (c protect.Conn, err error) { - // ProxyDial resolves address if needed; then dials into all resolved ips. - // return dialers.ProxyDial(h.wgtun, network, address) - return h.wgtun.Dial(network, address) -} - -// DialBind implements Proxy. -func (h *wgproxy) DialBind(network, local, remote string) (c protect.Conn, err error) { - // return dialers.ProxyDialBindh.wgtun, network, local, remote) - return h.wgtun.DialBind(network, local, remote) -} - -// Announce implements Proxy. -func (h *wgproxy) Announce(network, local string) (net.PacketConn, error) { - return h.wgtun.Announce(network, local) -} - -// Accept implements Proxy. -func (h *wgproxy) Accept(network, local string) (net.Listener, error) { - return h.wgtun.Accept(network, local) -} - -// BatchSize implements WgProxy -func (w *wgproxy) BatchSize() int { - return w.wgtun.BatchSize() -} - -// Close implements WgProxy -func (w *wgproxy) Close() error { - // w.wgtun.Close() called by device.Close() via device.tun.Close() - w.Device.Close() - return nil -} - -// Stop implements Proxy -func (w *wgproxy) Stop() error { - log.I("proxy: wg: stopping(%s); status(%s)", w.tag(), pxstatus(w.status.Load())) - return w.Close() -} - -// GetAddr implements x.Proxy -func (h *wgproxy) GetAddr() string { - dst := h.wgep.RemoteAddr() - if !dst.IsValid() { - return noaddr - } - return dst.String() -} - -// onProtoChange implements Proxy -func (w *wgproxy) OnProtoChange(lp LinkProps) (string, bool) { - oldmtu := w.netmtu.Swap(uint32(lp.mtu)) - oldrev := w.rev.Tango(lp.rev) - setRev := settings.ExperimentalWireGuard.Load() - w.setupReverserIfNeeded(setRev) - log.V("proxy: wg: %s; lp changed; setReverser? %t, l3: %s, mtu %d=>%d, rev %X => %X", - w.tag(), setRev, lp.l3, lp.mtu, oldmtu, oldrev, lp.rev) - if err := w.Refresh(); err != nil { - log.W("proxy: wg: %s; lp changed; err: %v", w.tag(), err) - // TODO: return w.cfg, true - } - return "", false // do not re-add this refreshed wg -} - -// Ping implements Proxy -// As backpressure, pings are sent once in a 5s period. -func (w *wgproxy) Ping() bool { - status := w.status.Load() - if err := candial2(status); err != nil { - log.V("proxy: wg: %s ping: err %v, status(%d)", w.tag(), err, pxstatus(status)) - return false - } - - var viaOK bool - if via := w.getViaIfDialed(); via != nil { - viaOK = via.Ping() - } - - now := now() - then := w.latestPing.Load() - neversent := then == 0 - recent := then+pingThresholdMillis < now - if (neversent || !recent) && w.latestPing.CompareAndSwap(then, now) { - // keepalive are empty packets but always padded to 16 bytes - // github.com/bepass-org/warp-plus/blob/12269c2761/wireguard/device/noise-protocol.go#L67 - // github.com/wireguard/wireguard-go/blob/12269c2761/wireguard/device/send.go#L543 - // WireGuard: Next Generation Kernel Network Tunnel, rev e2da747, section 6.5 - w.Device.SendKeepalivesToPeersWithCurrentKeypair() - log.D("proxy: wg: %s ping: via OK? %t", w.tag(), viaOK) - return true - } else { - log.VV("proxy: wg: %s ping: skipped; soon? %t / neversent? %t / concurrent %d; via OK? %t", - w.tag(), !recent, neversent, then, viaOK) - } - return false -} - -func waitForViaHandshake() { - time.Sleep(arbitraryWaitForViaHandshake) -} - -func waitForDeviceUp() { - waitForViaHandshake() -} - -// onNotOK implements Proxy. -func (w *wgproxy) onNotOK() (didRefresh, allok bool) { - s := w.status.Load() - if err := candial2(s); err != nil { // stopped or paused - log.E("proxy: wg: %s onNotOK: %s; status %s; why? %v", w.tag(), pxstatus(s), err) - return - } - - // TODO: skip on s == TUP? - if w.tooyoung() { - log.VV("proxy: wg: %s onNotOK: too young; status %s", w.tag(), pxstatus(s)) - return - } - - var didPing, viaDidRefresh, viaOK bool - - if via := w.getViaIfDialed(); via != nil { - viaDidRefresh, viaOK = via.onNotOK() - } - - var err error - if viaDidRefresh { - waitForViaHandshake() // wait for via to be OK - err = w.Refresh() - didRefresh = true - allok = err == nil - } else { - allok, err = w.refreshBa.DoIt(w.who(), func() (bool, error) { - rerr := w.Refresh() - didRefresh = true - return rerr == nil, rerr - }) - } - if !didRefresh { // attempt Ping if refresh skipped by the barrier - allok = allok && w.Ping() // ping / sendkeepalive is async - didPing = true - } - loged(err)("proxy: wg: %s; onNotOK: refresh? %t+%t; ping? %t; ok? %t+%t; status? %s; err? %v", - w.tag(), viaDidRefresh, didRefresh, didPing, viaOK, allok, pxstatus(s), err) - return -} - -func (w *wgproxy) tooyoung() bool { - return now()-w.since < ageThreshold.Milliseconds() -} - -// Refresh implements Proxy. -func (w *wgproxy) Refresh() (err error) { - status := w.status.Load() - - // todo: Refresh may be called by hop-related changes which may result in one Refresh calls too many. - if err := candial2(status); err != nil { - log.W("proxy: wg: %s refresh failed; status(%s)", w.tag(), pxstatus(status)) - return err - } - - // TODO: skip on s == TUP? - if w.tooyoung() { - log.VV("proxy: wg: %s refresh skipped; too young; status(%s)", w.tag(), pxstatus(status)) - return // TODO: err? - } - - lastRefresh := w.latestRefresh.Load() - if now()-lastRefresh < minRefreshInterval.Milliseconds() { - log.VV("proxy: wg: %s refresh skipped; done recently; status(%s)", w.tag(), pxstatus(status)) - return // TODO: err? - } - - w.latestRefresh.Store(now()) - resetDevice := (resetDeviceOnTNT && status == TNT) || - (resetDeviceOnTUP && status == TUP) - - w.latestPing.Store(0) // reset latest ping time - - n := w.dns.Load().Refresh() - nn := w.remote.Load().Refresh() - - via := w.getVia() - viaOK, didWait := false, false - if via != nil { - var viaDidRefresh bool - if viaDidRefresh, viaOK = via.onNotOK(); viaDidRefresh { - waitForViaHandshake() - didWait = true - } - } - - if err = w.resetMtu(via); err == nil { - // for now, never reset since resetDeviceOnTNT is false - resetDevice = resetDevice && w.wgtun.ignoreTUNClose.CompareAndSwap(false, true) - if resetDevice { - var newdev *device.Device - // Close the old device before creating the new one. - // w.Device.Down() already set bind.ipv4/ipv6 to nil, so Close() is a - // near no-op on the bind here. Doing it in the other order would have - // Close() re-enter Down() and close the bind that newdevice just opened. - w.Device.Close() // tun.Close() is ignored via ignoreTUNClose - w.events <- tun.EventUp - if newdev, err = newdevice(w.wgtun, w.wgep); err == nil { - w.Device = newdev // TODO: core.Volatile[device.Device] - } else { - w.wgtun.ignoreTUNClose.Store(false) // next Close() must not be silently ignored - } - } else { - // err = w.Device.Down() - // prefer sending commands over the events channel to prevent - // racing Up/Down calls via Refresh and other funcs that could - // be called concurrenctly by client code and/or internal code. - w.events <- tun.EventDown - // err = w.Device.Up() - w.events <- tun.EventUp - - waitForDeviceUp() // arbitrary wait for device to be up before sending ipcset - - // Re-apply peer config so wireguard device uses freshly resolved endpoint IPs. - // remote.Refresh() above may have updated IPs; Device.Up() alone does not - // re-call ParseEndpoint, so peers would keep sending handshakes to stale IPs. - w.redoPeers() - } - } - // not required since wgconn:NewBind() is namespace aware - // bindok := bindWgSockets(w.ID(), w.remote.AnyAddr(), w.wgdev, w.ctl) - logei(err)("proxy: wg: %s: refresh done; len(dns): %d, len(peer): %d; viaOK? %t, didWait? %t / reset? %t / status: %s => %s; err? %v", - w.tag(), n, nn, viaOK, didWait, resetDevice, pxstatus(status), pxstatus(w.Status()), err) - return -} - -func (w *wgproxy) redoPeers() { - w.wgtun.ipcset(w.Device) -} - -func (h *wgtun) ipcset(d *device.Device) { - if cfg := h.uapicfg.Load(); len(cfg) > 0 { - cpcfg := cfg // copies string - _, _ = wgIfConfigOf(h.id, &cpcfg) // removes non-uapi fields - ipcerr := d.IpcSet(cpcfg) - logei(ipcerr)("proxy: wg: %s: ipcset: re-apply; err %v", h.tag(), ipcerr) - return - } - log.E("proxy: wg: %s: ipcset: missing uapicfg", h.tag()) -} - -func (h *wgproxy) Dialer() protect.RDialer { - return h -} - -func preferOffload(id string) bool { - return strings.HasPrefix(id, FAST) -} - -func stripPrefixIfNeeded(id string) string { - return strings.TrimPrefix(id, FAST) -} - -// canUpdate checks if the existing tunnel can be updated in-place; -// that is, incoming interface config is compatible with the existing tunnel, -// regardless of whether peer config has changed (which can be updated in-place). -func (w *wgproxy) update(id, txt string) (ok bool) { - const reused = true // can update in-place; reuse existing tunnel - const anew = false // cannot update in-place; create new tunnel - status := w.status.Load() - if status == END { - log.W("proxy: wg: update(%s<>%s): END; status(%s)", id, w.who(), pxstatus(status)) - return anew - } - if status == TNT { - log.W("proxy: wg: update(%s<>%s): TNT; status(%s) - marking session as un-updatable", id, w.who(), pxstatus(status)) - return anew - } - if w.wgep.Closed() { - log.W("proxy: wg: update(%s<>%s): conn closed; status(%s)", id, w.who(), pxstatus(status)) - return anew - } - - incomingPrefersOffload := preferOffload(id) - if incomingPrefersOffload != w.preferOffload { - log.W("proxy: wg: update(%s<>%s): failed; preferOffload() %t != %t", id, w.who(), incomingPrefersOffload, w.preferOffload) - return anew - } - - // str copy: go.dev/play/p/eO814kGGNtO - cptxt := txt - opts, err := wgIfConfigOf(w.id, &cptxt) - if err != nil { - log.W("proxy: wg: update(%s<>%s): err: %v", id, w.who(), err) - return anew - } - - if opts.willreplacepeers { - log.W("proxy: wg: update(%s<>%s): cannot proceed; peers will be replaced", id, w.who()) - return anew - } - - if err := w.setRoutes(opts.ifaddrs); err != nil { - log.W("proxy: wg: update(%s<>%s): failed; setRoutes: %v", id, w.who(), err) - return anew - } - - if settings.Debug { - if !w.amnezia.Load().Same(opts.amnezia) { - log.D("proxy: wg: update(%s<>%s): failed; amnezia %v != %v", - id, w.who(), opts.amnezia, w.amnezia.Load()) - } - if opts.dns != nil && !opts.dns.EqualAddrs(w.dns.Load()) { - log.D("proxy: wg: update(%s<>%s): failed; new/mismatched dns", id, w.who()) - } // nb: client code MUST re-add wg DNS, not our responsibility - } - - maybeNewMtu := calcTunMtu(opts.mtu) // only for logging - - // reusing existing tunnel (interface config unchanged) - // but peer config may have changed! - log.I("proxy: wg: update: (%s<>%s): reuse; mtu: %d=>%d, allowed: %d=>%d; peers: %d; dns: %d=>%d; endpoint: %d=>%d", - id, w.who(), w.ep.MTU(), maybeNewMtu, w.rt.Len(), len(opts.allowed), len(opts.peers), w.dns.Load().Len(), opts.dns.Len(), - w.remote.Load().Len() /*remote.Load may return nil*/, opts.eps.Len()) - - w.allowedIPs(opts.allowed) - w.remote.Store(opts.eps) // requires refresh (wg.Conn:ParseEndpoint must be re-called) - w.remote.Load().Refresh() // resolve endpoints now so ParseEndpoint below sees valid IPs - w.dns.Store(opts.dns) // requires refresh (client must also re-add via intra.AddDNSProxy) - w.desiredmtu.Store(uint32(opts.mtu)) // requires reset; [NOMTU, MAXMTU) - w.amnezia.Store(opts.amnezia) - w.resetMtu(w.getVia()) - - ipcerr := w.Device.IpcSet(cptxt) - if ipcerr != nil { - log.W("proxy: updating wg(%s<>%s) ipcset; err %v", id, w.who(), ipcerr) - return anew - } - // w.Device is assumed to be Up - w.uapicfg.Store(txt) // persist the updated UAPI peer config - - return reused -} - -func (w *wgtun) allowedIPs(allowed []netip.Prefix) { - w.rt.Clear() - for _, ipnet := range allowed { - w.rt.Set(ipnet.String(), w.id) - } - // TODO: remove IPs on peer update -} - -func wglogger(w *wgtun) *device.Logger { - tag := WG + ":" + w.id + ":" + core.LocStr(w) - logger := &device.Logger{ - Verbosef: log.Of(tag, log.V2), - Errorf: log.Of(tag, log.E2), - } - return logger -} - -func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) { - txt := *txtptr - pcfg := strings.Builder{} - r := bufio.NewScanner(strings.NewReader(txt)) - opts.dns = multihost.New(id + "dns") - opts.eps = multihost.NewMap(id + "endpoint") - opts.peers = make(map[string]device.NoisePublicKey) - opts.amnezia = wg.NewAmnezia(id) - opts.mtu = MAXMTU // auto - - var currentPeer *multihost.MH - for r.Scan() { - line := r.Text() - if len(line) <= 0 { - // Blank line means terminate operation. - if (len(opts.ifaddrs) <= 0) || (opts.dns.Len() <= 0) || (opts.mtu <= 0) { - err = errProxyConfig - } - return - } - k, v, ok := strings.Cut(line, "=") - if !ok { - err = fmt.Errorf("proxy: wg: %s failed to parse line %q", id, line) - return - } - k = strings.ToLower(strings.TrimSpace(k)) - v = strings.ToLower(strings.TrimSpace(v)) - - // process interface & peer config; Address, DNS, ListenPort, MTU, Allowed IPs, Endpoint - // github.com/WireGuard/wireguard-android/blob/713947e432/tunnel/src/main/java/com/wireguard/config/Interface.java#L232 - // github.com/WireGuard/wireguard-android/blob/713947e432/tunnel/src/main/java/com/wireguard/config/Peer.java#L176 - switch k { - case "replace_peers": - opts.willreplacepeers = v == "true" || v == "1" - log.D("proxy: wg: %s ifconfig: skipping key %q", id, k) - pcfg.WriteString(line + "\n") - case "address": // may exist more than once - if err = loadIPNets(&opts.ifaddrs, v); err != nil { - return - } - case "dns": // may exist more than once: github.com/celzero/rethink-app/issues/1298 - n := loadMH(opts.dns, v) - aerr := loadIPNets(&opts.allowed, v) - log.D("proxy: wg: %s ifconfig: dns(%d) %s; allowed err? %v", id, n, v, aerr) - case "mtu": - maxxed := false - if len(v) <= 0 || v == AUTOMTU || v == AUTOMTU2 { - opts.mtu = MAXMTU - maxxed = true - } else if opts.mtu, err = strconv.Atoi(v); err != nil { - return - } - if opts.mtu < NOMTU { // negative - opts.mtu = MAXMTU - maxxed = true - } - log.D("proxy: wg: %s ifconfig: mtu %s => %d; maxxed? %t", - id, v, opts.mtu, maxxed) - case "allowed_ip": // may exist more than once - if err = loadIPNets(&opts.allowed, v); err != nil { - return - } - // peer config: carry over allowed_ips - log.D("proxy: wg: %s ifconfig: skipping key %q", id, k) - pcfg.WriteString(line + "\n") - case "endpoint": // may exist more than once - // TODO: endpoint could be v4 or v6 or a hostname - n := loadMH(currentPeer, v) // append endpoints - log.D("proxy: wg: %s ifconfig: endpoints(%d) %s", id, n, v) - - // peer config: carry over endpoints - log.D("proxy: wg: %s ifconfig: skipping key %q", id, k) - pcfg.WriteString(line + "\n") - case "public_key": - var exx error - var peerkey device.NoisePublicKey - if exx = peerkey.FromHex(v); exx == nil { - opts.peers[v] = peerkey - } - // peer config: carry over public keys - log.D("proxy: wg: %s ifconfig: processing key %q=%s, err? %v", id, k, pfxsfx(v), exx) - pcfg.WriteString(line + "\n") - finalizeMH(opts.eps, currentPeer) - if len(v) > 8 { - v = v[:8] - } - // a public_key line points to a transition to a new peer - // github.com/WireGuard/wireguard-go/blob/12269c2761/device/uapi.go#L295 - currentPeer = multihost.New(id + ":" + v) // next peer - case "client_id": - // only for warp: blog.cloudflare.com/warp-technical-challenges - // When we begin a WireGuard session we include our clientid field - // which is provided by our authentication server which has to be - // communicated with to begin a WARP session. - // Though the open source Cloudflare WARP boring-tun impl does not do so: - // github.com/cloudflare/boringtun/blob/64a2fc7c63/boringtun/src/noise/handshake.rs#L734 - if b, err := base64.StdEncoding.DecodeString(v); err == nil && len(b) == 3 { - // github.com/WireGuard/wireguard-go/blob/12269c2761/device/send.go#L456 - // github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56 - h1 := append([]byte{device.MessageInitiationType}, b...) - h2 := append([]byte{device.MessageResponseType}, b...) - h3 := append([]byte{device.MessageCookieReplyType}, b...) - h4 := append([]byte{device.MessageTransportType}, b...) - // overwrite the 3 reserved bytes on all packets - // github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138 - opts.amnezia.H1 = binary.LittleEndian.Uint32(h1) - opts.amnezia.H2 = binary.LittleEndian.Uint32(h2) - opts.amnezia.H3 = binary.LittleEndian.Uint32(h3) - opts.amnezia.H4 = binary.LittleEndian.Uint32(h4) - log.D("proxy: wg: %s ifconfig: clientid(%d) %v", id, len(b), b) - } else { - log.W("proxy: wg: %s ifconfig: clientid(%v) %d == 3?; err: %v", - id, v, len(b), err) - } - case "jc": - // github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/uapi.go#L286 - jc, _ := strconv.Atoi(v) - opts.amnezia.Jc = uint16(jc) - case "jmin": - jmin, _ := strconv.Atoi(v) - opts.amnezia.Jmin = uint16(jmin) - case "jmax": - jmax, _ := strconv.Atoi(v) - opts.amnezia.Jmax = uint16(jmax) - case "s1": - s1, _ := strconv.Atoi(v) - opts.amnezia.S1 = uint16(s1) - case "s2": - s2, _ := strconv.Atoi(v) - opts.amnezia.S2 = uint16(s2) - case "h1": - h1, _ := strconv.ParseUint(v, 10, 32) - opts.amnezia.H1 = uint32(h1) - case "h2": - h2, _ := strconv.ParseUint(v, 10, 32) - opts.amnezia.H2 = uint32(h2) - case "h3": - h3, _ := strconv.ParseUint(v, 10, 32) - opts.amnezia.H3 = uint32(h3) - case "h4": - h4, _ := strconv.ParseUint(v, 10, 32) - opts.amnezia.H4 = uint32(h4) - default: - log.D("proxy: wg: %s ifconfig: skipping key %q", id, k) - pcfg.WriteString(line + "\n") - } - } - finalizeMH(opts.eps, currentPeer) - *txtptr = pcfg.String() - if err == nil && len(opts.ifaddrs) <= 0 || opts.dns.Len() <= 0 || opts.mtu <= NOMTU { - err = errProxyConfig - } - loged(err)("proxy: wg: %s; addr: %d, dns: %d, mtu: %d, eps: %d; amnezia: %s", - id, len(opts.ifaddrs), opts.dns.Len(), opts.mtu, opts.eps.Len(), opts.amnezia) - return -} - -func finalizeMH(m *multihost.MHMap, currentPeer *multihost.MH) bool { - if currentPeer == nil { - return false - } - return m.Put(currentPeer) -} - -func loadMH(mh *multihost.MH, v string) int { - if mh == nil || len(v) <= 0 { - return 0 - } - vv := strings.Split(v, ",") - return mh.Add(vv) // vv may be host:port, ip:port, host, or ip -} - -func loadIPNets(out *[]netip.Prefix, v string) (err error) { - var ip netip.Addr - // may be a csv: "172.1.0.2/32, 2000:db8::2/128" - for str := range strings.SplitSeq(v, ",") { - var ipnet netip.Prefix - str = strings.TrimSpace(str) - if ip, err = netip.ParseAddr(str); err != nil { - if ipnet, err = netip.ParsePrefix(str); err != nil { - return - } - *out = append(*out, ipnet) - } else { // add prefix to address - if ipnet, err = ip.Prefix(ip.BitLen()); err != nil { - return - } - *out = append(*out, ipnet) - } - } - return -} - -// ref: github.com/WireGuard/wireguard-android/blob/713947e432/tunnel/tools/libwg-go/api-android.go#L76 -func NewWgProxy(id string, ctl protect.Controller, px ProxyProvider, lp LinkProps, cfg string) (*wgproxy, error) { - ogcfg := cfg - opts, err := wgIfConfigOf(id, &cfg) - if err != nil { - log.E("proxy: wg: %s failure getting opts from config %v", id, err) - return nil, err - } - - wgtun, err := makeWgTun(id, ogcfg, ctl, px, lp, opts) - if err != nil { - log.E("proxy: wg: %s failed to create tun %v", id, err) - return nil, err - } - - var wgep wgconn - if wgtun.preferOffload { - wgep = wg.NewEndpoint2(wgtun.who(), wgtun.serve, wgtun.remote, wgtun.listener, wgtun.amnezia) - } else { - wgep = wg.NewEndpoint(wgtun.who(), wgtun.serve, wgtun.remote, wgtun.listener, wgtun.amnezia) - } - - wgdev, err := newdevice(wgtun, wgep) - if err != nil { - return nil, err - } - - w := &wgproxy{ - wgtun, // stack - wgdev, // device - wgep, // endpoint - } - - log.D("proxy: wg: new %s; addrs(%v) mtu(%d/%d) peers(%d) / v4(%t) v6(%t)", - wgtun.tag(), opts.ifaddrs, opts.mtu, w.ep.MTU(), len(opts.peers), wgtun.IP4(), wgtun.IP6()) - - return w, nil -} - -func newdevice(wgtun *wgtun, wgep wgconn) (*device.Device, error) { - wgdev := device.NewDevice(wgtun, wgep, wglogger(wgtun)) - - wgtun.ipcset(wgdev) // apply initial config to device - - // github.com/WireGuard/wireguard-android/blob/713947e432/tunnel/tools/libwg-go/api-android.go#L99 - wgdev.DisableSomeRoamingForBrokenMobileSemantics() - - // not needed: tun.EventUp is already queued by makeWgTun() - // which will be consumed by wireguard's RoutineTUNEventReader - // started by device.NewDevice() - // err = wgdev.Up() - // TODO: wait for wgconn to open? - log.I("proxy: wg: %s new device created %s", wgtun.tag(), core.LocStr(wgdev)) - return wgdev, nil -} - -func (t *wgtun) setupReverserIfNeeded(set bool) (didSet bool) { - s := t.stack - id := t.id - rev := t.rev.Load() - - if rev != nil && set { - // inbound (aka reverse outbound) - netstack.OutboundTCP(id, s, rev.TCP()) - netstack.OutboundUDP(id, s, rev.UDP()) - log.I("proxy: wg: %s rev @ %X enabled", t.tag(), rev) - return true - } // do not use reverser - - logeif(set)("proxy: wg: %s remove rev; must set? %t", t.tag(), set) - - netstack.OutboundTCP(id, s, nil) // unset - netstack.OutboundUDP(id, s, nil) // unset - return false -} - -func (w *wgtun) swapVia(new Proxy) (old Proxy) { - return swapVia(w.who(), new, w.viaID, w.via) -} - -func (w *wgtun) viafor() *Proxy { - return viafor(w.who(), w.viaID.Load(), w.px) -} - -func (w *wgtun) getVia() (v Proxy) { - return w.via.Load() -} - -func (w *wgtun) getViaWithStatus() (v Proxy, up bool) { - up = w.viaUp.Load() - v = w.getVia() - return -} - -func (w *wgtun) getViaIfDialed() Proxy { - if v, up := w.getViaWithStatus(); up { - return v - } - return nil -} - -// who concats id of this proxy & status of its via. -func (w *wgtun) who() string { - return w.id + ":" + core.LocStr(w) -} - -func (w *wgtun) tag() string { - return w.who() + " (" + w.viaStatus() + ")" -} - -func (w *wgtun) viaStatus() (s string) { - v, up := w.getViaWithStatus() - if vid := idstr(v); len(vid) > 0 { - s += vid - if up { - s += "/up" - } else { - s += "/down" - } - } else { - if vid = w.viaID.Load(); len(vid) > 0 { - s += vid + "/mia" - } else { - s += "novia/zz" - } - } - return s -} - -func (t *wgtun) maybeSpoof(spoof bool) { - log.I("proxy: wg: %s spoofing? %t", t.tag(), spoof) - // github.com/xjasonlyu/tun2socks/blob/31468620e/core/stack.go#L80 - _ = t.stack.SetSpoofing(wgnic, spoof) - // github.com/tailscale/tailscale/blob/c4d0237e5c/wgengine/netstack/netstack.go#L345-L350 - _ = t.stack.SetPromiscuousMode(wgnic, spoof) -} - -// ref: github.com/WireGuard/wireguard-go/blob/469159ecf7/tun/netstack/tun.go#L54 -func makeWgTun(id, cfg string, ctl protect.Controller, px ProxyProvider, lp LinkProps, ifopts wgifopts) (*wgtun, error) { - ctx, done := context.WithCancel(context.Background()) - - allowIncoming := settings.ExperimentalWireGuard.Load() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: !allowIncoming, - } - - minmtu := minmtu6 // ip6 or ip6 or ip4+ip6 - if lp.l3 == settings.IP4 { - minmtu = minmtu4 // ip4 - } - - tunMtu := reconcileMtu(lp.mtu, ifopts.mtu, minmtu) - if tunMtu <= NOMTU { - done() - return nil, errNoMtu - } - - s := stack.New(opts) - ep := channel.New(epsize, uint32(tunMtu), "") - netstack.SetNetstackOpts(s) - - t := &wgtun{ - ctx: ctx, - done: done, - id: stripPrefixIfNeeded(id), - addrs: ifopts.ifaddrs, - ep: ep, - stack: s, - events: make(chan tun.Event, eventssize), - ingress: make(chan *buffer.View, epsize), - finalize: make(chan struct{}), // always unbuffered - direct: protect.MakeNsRDial(id, ctx, ctl), - px: px, - viaID: core.NewZeroVolatile[string](), - viaUp: core.NewZeroVolatile[bool](), - dns: core.NewVolatile(ifopts.dns), - rev: core.NewVolatile(lp.rev), - remote: core.NewVolatile(ifopts.eps), // may be nil - rt: x.NewIpTree(), // must be set to allowedaddrs - amnezia: core.NewVolatile(ifopts.amnezia), - status: core.NewVolatile(TUP), - preferOffload: preferOffload(id), - refreshBa: core.NewBarrier[bool](refreshInterval), - uapicfg: core.NewVolatile(cfg), - since: now(), - } - t.latestRefresh.Store(t.since) - t.desiredmtu.Store(uint32(ifopts.mtu)) - t.netmtu.Store(uint32(lp.mtu)) - t.allowedIPs(ifopts.allowed) - - if viaref, verr := core.NewWeakRef(t.viafor, viaok); verr != nil { - done() - return nil, fmt.Errorf("wg: %s create tun (via ref): %v", t.id, verr) - } else { - t.via = viaref - } - - // TODO: wgnic := s.NextNICID() - // see WriteNotify below - ep.AddNotify(t) - - if err := s.CreateNIC(wgnic, ep); err != nil { - done() - ep.Close() - return nil, fmt.Errorf("wg: %s create nic: %v", t.who(), err) - } - - settings.ExperimentalWireGuard.On(ctx, func(yn bool) { - t.maybeSpoof(yn) - t.setupReverserIfNeeded(yn) - }) - - if err := t.setRoutes(ifopts.ifaddrs); err != nil { - done() - ep.Close() - return nil, err - } - - // commence the wireguard state machine the second Device is created - t.events <- tun.EventUp - - if4, if6 := netstack.StackAddrs(s, wgnic) - log.I("proxy: wg: %s tun: created; handleLocal[%t]; dns[%s]; dst[%s]; mtu[%d]; ifaddrs[%v / %v]; amnezia[%t]", - t.tag(), !allowIncoming, ifopts.dns, ifopts.eps, tunMtu, if4, if6, ifopts.amnezia.Set()) - - return t, nil -} - -func (t *wgtun) setRoutes(ifaddrs []netip.Prefix) error { - has4, has6 := false, false - processed := make(map[netip.Prefix]bool) - // clear existing addresses - if addr4, err := t.stack.GetMainNICAddress(wgnic, ipv4.ProtocolNumber); err == nil { - log.I("proxy: wg: %s replacing permanent addr4(%d) %v", t.tag(), wgnic, addr4.Address) - t.stack.RemoveAddress(wgnic, addr4.Address) - } - if addr6, err := t.stack.GetMainNICAddress(wgnic, ipv6.ProtocolNumber); err == nil { - log.I("proxy: wg: %s replacing permanent addr6(%d) %v", t.tag(), wgnic, addr6.Address) - t.stack.RemoveAddress(wgnic, addr6.Address) - } - for _, ipnet := range ifaddrs { - ip := ipnet.Addr() - if processed[ipnet] { - log.W("proxy: wg: %s skipping duplicate ip %v for ifaddr %v", - t.tag(), ip, ipnet) - continue - } - processed[ipnet] = true - - var protoid tcpip.NetworkProtocolNumber - var nsaddr tcpip.Address - if ip.Is4() { - protoid = ipv4.ProtocolNumber - nsaddr = tcpip.AddrFrom4(ip.As4()) - has4 = true - } else if ip.Is6() { - protoid = ipv6.ProtocolNumber - nsaddr = tcpip.AddrFrom16(ip.As16()) - has6 = true - } - ap := tcpip.AddressWithPrefix{ - Address: nsaddr, - PrefixLen: ipnet.Bits(), - } - protoaddr := tcpip.ProtocolAddress{ - Protocol: protoid, - AddressWithPrefix: ap, - } - if err := t.stack.AddProtocolAddress(wgnic, protoaddr, stack.AddressProperties{}); err != nil { - return fmt.Errorf("wg: %s add (v4? %t) addr(%v): %v", t.tag(), has4, ip, err) - } - - log.I("proxy: wg: %s added (v4? %t) ifaddr(%v)", t.tag(), has4, ap) - } - - if has4 || t.hasV4.Load() { - t.hasV4.Store(true) - t.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: wgnic}) - } else { - t.hasV4.Store(false) - t.stack.RemoveRoutes(func(r tcpip.Route) bool { - return r.Destination == header.IPv4EmptySubnet && r.NIC == wgnic - }) - } - if has6 || t.hasV6.Load() { - t.hasV6.Store(true) - t.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: wgnic}) - } else { - t.hasV6.Store(false) - t.stack.RemoveRoutes(func(r tcpip.Route) bool { - return r.Destination == header.IPv6EmptySubnet && r.NIC == wgnic - }) - } - return nil -} - -// implements tun.Device - -// Name implements tun.Device. -func (tun *wgtun) Name() (string, error) { - return tun.id, nil -} - -// File implements tun.Device. -func (tun *wgtun) File() *os.File { - return nil -} - -// Events implements tun.Device. -func (tun *wgtun) Events() <-chan tun.Event { - return tun.events -} - -// Read implements tun.Device. -func (tun *wgtun) Read(buf [][]byte, sizes []int, offset int) (int, error) { - view, ok := <-tun.ingress - if !ok { - log.W("wg: %s tun: read closed", tun.tag()) - return 0, os.ErrClosed - } - - n, err := view.Read(buf[0][offset:]) - if err != nil { - log.W("wg: %s tun: read(%d): %v", - tun.tag(), n, err) - return 0, err - } - - if settings.Debug { - log.VV("wg: %s tun: read(%d)", tun.tag(), n) - } - sizes[0] = n - return 1, nil -} - -// Write implements tun.Device. -func (tun *wgtun) Write(bufs [][]byte, offset int) (int, error) { - for _, buf := range bufs { - pkt := buf[offset:] - if len(pkt) == 0 { - log.D("wg: %s tun: write: empty packet", tun.tag()) - continue - } - - sz := len(pkt) - b := buffer.MakeWithData(pkt) - pko := stack.PacketBufferOptions{Payload: b} - pkb := stack.NewPacketBuffer(pko) - defer pkb.DecRef() - protoid := pkt[0] >> 4 - switch protoid { - case 4: // IPv4 - tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) // write to ep - case 6: // IPv6 - tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) // write to ep - default: - log.W("wg: %s tun: write: unknown proto %d; discard %d", - tun.tag(), protoid, sz) - return 0, syscall.EAFNOSUPPORT - } - if settings.Debug { - log.VV("wg: %s tun: write: sz(%d); proto %d", - tun.tag(), sz, protoid) - } - } - - return len(bufs), nil -} - -// WriteNotify is called by channel notifier on readable events -// github.com/google/gvisor/blob/acf460d0d73/pkg/tcpip/link/channel/channel.go#L31 -func (tun *wgtun) WriteNotify() { - pkt := tun.ep.Read() - if pkt == nil { - return - } - - view := pkt.ToView() - pkt.DecRef() - - sz := view.Size() - - select { - case <-tun.finalize: // dave.cheney.net/2013/04/30/curious-channels - log.I("wg: %s tun: write: finalize; dropped pkt; sz(%d)", - tun.tag(), sz) - default: - select { - case <-tun.finalize: - case tun.ingress <- view: // closed chans panic on send: groups.google.com/g/golang-nuts/c/SDIBFSkDlK4 - if settings.Debug { - log.VV("wg: %s tun: write: notify sz(%d)", - tun.tag(), sz) - } - default: // ingress is full and finalize is blocked - e := tun.status.Load() - log.W("wg: %s tun: write: closed? %s; dropped pkt; sz(%d)", - tun.tag(), pxstatus(e), sz) - } - } -} - -func (tun *wgtun) Close() error { - // wgproxy inherits h.status: go.dev/play/p/HeU5EvzAjnv - if tun.status.Load() == END { - log.W("wg: %s tun: already closed?", tun.tag()) - return errProxyStopped - } - if tun.ignoreTUNClose.CompareAndSwap(true, false) { - log.I("wg: %s tun: ignore close this once", tun.tag()) - return nil // ignore - } - - var err error - tun.once.Do(func() { - prev := tun.status.Swap(END) // TODO: move this to wgproxy.Close()? - - log.D("wg: %s tun: (prev status: %s) closing...", tun.tag(), pxstatus(prev)) - - tun.done() // unblock inject and dialers - - tun.stack.RemoveNIC(wgnic) - // if tun.events != nil { - // panics; is it closed by device.Device.Close()? - // close(tun.events) } - close(tun.ingress) - - tun.viaID.Store(noviaid) // via is nil - tun.viaUp.Store(false) - - // github.com/tailscale/tailscale/blob/836f932e/wgengine/netstack/netstack.go#L223 - - // stack closes the endpoint, too via nic.go#remove? - // tun.ep.Close() - // destroy waits for the stack to close - tun.stack.Destroy() - log.I("wg: %s tun: closed", tun.tag()) - }) - return err -} - -// Implements Router. -// TODO: use wgtun as a receiver for Stats() -// Never returns nil. -func (w *wgproxy) Stat() (out *x.RouterStats) { - start := time.Now() - defer log.VV("proxy: wg: %s stats: end (duration: %s)", w.tag(), core.FmtTimeAsPeriod(start)) - - out = new(x.RouterStats) - - out.Addrs = w.ifaddrs() // may be empty - out.Rx = -1 - out.Tx = -2 - out.LastOK = -3 - out.ErrRx = w.errRx.Load() - out.ErrTx = w.errTx.Load() - out.LastErr = estr(w.latestErr.Load()) - out.LastRxErr = estr(w.latestRxErr.Load()) - out.LastTxErr = estr(w.latestTxErr.Load()) - out.LastRx = w.latestRx.Load() - out.LastTx = w.latestTx.Load() - out.LastGoodRx = w.latestGoodRx.Load() - out.LastGoodTx = w.latestGoodTx.Load() - out.LastRefresh = w.latestRefresh.Load() - out.Since = w.since - out.Status = pxstatus(w.status.Load()).String() - out.StatusReason = w.statusReason.Load() - - if w.status.Load() == END { - log.W("proxy: wg: %s stats: stopped", w.tag()) - return // zz - } - - stat := wg.ReadStats(w.id, w.Handle(), w.Device.IpcGet) - if stat != nil { // unlikely - out.Rx = stat.TotalRx() - out.Tx = stat.TotalTx() - out.LastOK = stat.LatestRecentHandshake() - } - - if settings.Debug { - out.Extra = w.remote.Load().String() + "\n" + w.dns.Load().String() - - log.VV("proxy: wg: %s stats: rx: %d, tx: %d, r: %s (rlastok: %s), w: %s (wlastok: %s), lastok: %s", - w.tag(), out.Rx, out.Tx, - core.FmtUnixMillisAsPeriod(out.LastRx), core.FmtUnixMillisAsPeriod(out.LastGoodRx), - core.FmtUnixMillisAsPeriod(out.LastTx), core.FmtUnixMillisAsPeriod(out.LastGoodTx), - core.FmtUnixMillisAsPeriod(out.LastOK)) - } - return out -} - -func (w *wgtun) ifaddrs() string { - ifs := w.addrs - if len(ifs) > 0 { - s := core.Map(ifs, func(a netip.Prefix) string { return a.String() }) - return strings.Join(s, ",") - } - return noaddr -} - -// MTU implements tun.Device. -func (tun *wgtun) MTU() (int, error) { - return calcNetMtu(int(tun.ep.MTU())), nil -} - -// BatchSize implements tun.Device. -func (tun *wgtun) BatchSize() int { - if tun.preferOffload { - return conn.IdealBatchSize - } - return 1 -} - -// Dial implements proxy.Dialer and protect.RDialer -func (h *wgtun) Dial(network, address string) (c net.Conn, err error) { - // wgproxy.Dial => dialers.ProxyDial => wgtun.Dial - if err := candial(h.status); err != nil { - return nil, err - } - - log.D("wg: %s dial: start %s %s", h.tag(), network, address) - - // DialContext resolves addr if needed; then dialing into all resolved ips. - c, err = h.DialContext(h.ctx, network, address) - defer h.listener(wg.Con, err) // status updated by h.listener - - log.I("wg: %s dial: end %s %s; err %v", h.tag(), network, address, err) - return -} - -// DialBind implements proxy.Dialer and protect.RDialer -func (h *wgtun) DialBind(network, local, remote string) (c net.Conn, err error) { - // wgproxy.DialBind => wgtun.Dial - if err := candial(h.status); err != nil { - return nil, err - } - - log.D("wg: %s dialbind: start %s %s=>%s", h.tag(), network, local, remote) - - // DialContext resolves addr if needed; then dialing into all resolved ips. - c, err = h.DialContext(h.ctx, network, remote) - defer h.listener(wg.Con, err) // status updated by h.listener when creating conns - - log.I("wg: %s dialbind: end %s %s=>%s; err %v", h.tag(), network, local, remote, err) - return -} - -// Announce implements protect.RDialer -func (h *wgtun) Announce(network, local string) (pc net.PacketConn, err error) { - // wgproxy.Dial => dialers.ProxyListenPacket => protect.AnnounceUDP => wgtun.Announce - if err := candial(h.status); err != nil { - return nil, err - } - - log.D("wg: %s announce: start %s %s", h.tag(), network, local) - - var addr netip.AddrPort - if addr, err = netip.ParseAddrPort(local); err == nil { - pc, err = h.ListenUDPAddrPort(addr) - defer h.listener(wg.Con, err) - } // else: expect local to always be ipaddr - - log.I("wg: %s announce: end %s %s; err %v", h.tag(), network, local, err) - return -} - -// Accept implements protect.RDialer -func (h *wgtun) Accept(network, local string) (ln net.Listener, err error) { - // wgproxy.Dial => dialers.ProxyListen => protect.AcceptTCP => wgtun.Accept - if err := candial(h.status); err != nil { - return nil, err - } - - log.D("wg: %s accept: start %s %s", h.tag(), network, local) - - var addr netip.AddrPort - if addr, err = netip.ParseAddrPort(local); err == nil { - ln, err = h.ListenTCPAddrPort(addr) - defer h.listener(wg.Con, err) - } // else: expect local to always be ipaddr - - log.I("wg: %s accept: end %s %s; err %v", h.tag(), network, local, err) - return -} - -// Probe implements protect.RDialer -func (h *wgtun) Probe(network, local string) (pc net.PacketConn, err error) { - // wgproxy.Dial => dialers.ProxyListen => protect.AcceptTCP => wgtun.Accept - if err := candial(h.status); err != nil { - return nil, err - } - - log.D("wg: %s probe: start %s %s", h.tag(), network, local) - - var addr netip.AddrPort - if addr, err = netip.ParseAddrPort(local); err == nil { - pc, err = h.ListenUDPAddrPort(addr) - defer h.listener(wg.Con, err) - } // else: expect local to always be ipaddr - - log.I("wg: %s probe: end %s %s; err %v", h.tag(), network, local, err) - return -} - -// ID implements x.Proxy. -func (h *wgtun) ID() string { - return h.id -} - -// Type implements x.Proxy. -func (h *wgtun) Type() string { - return WG -} - -// Router implements Proxy. -// TODO: make wgtun a Router; see Stats() -func (h *wgproxy) Router() x.Router { - return h -} - -// Reaches implements x.Router. -// TODO: make wgtun a Router; see Stats() -func (h *wgproxy) Reaches(hostportOrIPPortCsv string) bool { - return Reaches(h, hostportOrIPPortCsv) -} - -// Hop implements Proxy. -func (h *wgproxy) Hop(via Proxy, dryrun bool) (err error) { - var old Proxy - - defer func() { - if dryrun { - return - } - - log.I("wg: %s hop: old(%s) => new(%s); err? %v", - h.id, idhandle(old), idhandle(via), err) - - if Same(old, via) { - return - } - if err == nil { - core.Gxe("wg.hop.refresh."+h.id, h.Refresh) // reconnect - } - }() - - if via == nil { - if !dryrun { - old = h.swapVia(nil) - // undo MTU enforced due to any prior hops - if old != nil { - err = h.resetMtu(nil) - } - log.I("wg: %s hop: %s removed; mtu reset err? %v", - h.id, idhandle(old), err) - } - return nil - } else if Same(h, via) { - return errHopSelf - } - - if via.Status() == END { - return errProxyStopped - } - - if !isWG(idstr(via)) { // for now, only wg can hop another wg - return errHopWireGuard - } - - if err := viaCanBind(h, via); err != nil { - return err - } - - // mtu needed to tunnel this wg - if err := h.maybeResetMtu(via, dryrun); err != nil { - return err // could not set mtu - } - - // hop must be able to route all of orig's peers - if err := h.viaCanRoute(via, dryrun); err != nil { - return err // via cannot not route peers - } - - if !dryrun { - old = h.swapVia(via) - } - return nil -} - -// Via implements x.Router. -func (h *wgproxy) Via() (x.Proxy, error) { - if v, up := h.getViaWithStatus(); v == nil { - return nil, errNoHop - } else if !up { - return nil, errHopNotConnected - } else { - return v, nil - } -} - -// Stats implements Proxy. -func (h *wgtun) Status() int { - return h.status.Load() -} - -// Pause implements x.Proxy. -func (h *wgproxy) Pause() (paused bool) { - defer func() { - if paused { - h.wgep.Pause() - } - }() - - st := h.status.Load() - if st == END { - log.W("wg: %s pause called when stopped", h.tag()) - return false - } - - paused = h.status.Cas(st, TPU) - log.I("wg: %s paused? %t", h.tag(), paused) - - return -} - -// Resume implements x.Proxy. -func (h *wgproxy) Resume() (resumed bool) { - st := h.status.Load() - if st != TPU { - log.W("wg: %s resume called when not paused; status %d", h.tag(), st) - return false - } - - resumed = h.status.Cas(st, TUP) - if resumed { - h.wgep.Resume() - } - core.Gxe("wg.resume.refresh."+h.id, h.Refresh) // refresh unconditionally - - log.I("wg: %s resumed? %t", h.tag(), resumed) - - return -} - -// DNS implements x.Proxy. -func (h *wgtun) DNS() string { - return h.dnsResolvers() -} - -func (h *wgtun) dnsResolvers() string { - var s string - // prefer hostnames over IPs: - // hostnames may resolve to different IPs on different networks; - // tunnel could use hostnames to implement "refresh" - dnsm := h.dns.Load() - if dnsm != nil { - names := dnsm.Names() - for _, hostname := range names { - s += hostname + "," - } - log.D("wg: %s dns hostnames (in: %d); out: %s", h.tag(), names, s) - if len(s) > 0 { // return names, if any - return strings.TrimRight(s, ",") - } - - addrs := dnsm.Addrs() - for _, dns := range addrs { - if dns.Addr().IsUnspecified() || !dns.IsValid() { - continue - } - // may be private, link local, etc - s += dns.Addr().Unmap().String() + "," - } - - log.D("wg: %s dns ipaddrs (in: %v); out: %s", h.tag(), addrs, s) - if len(s) > 0 { // return ipaddrs, if any - return strings.TrimRight(s, ",") - } - - log.W("wg: %s dns: not found (names: %v; addrs: %s)", h.tag(), names, addrs) - } else { // unlikely as wireguard config is considered invalid if DNS not set - log.E("wg: %s dns: nil", h.tag()) - } - return "" // nodns -} - -// Implements x.Router. -func (h *wgtun) IP4() bool { return h.hasV4.Load() } -func (h *wgtun) IP6() bool { return h.hasV6.Load() } - -// Contains implements x.Router. -func (h *wgtun) Contains(ippOrCidr string) bool { - var err error - y1, y2 := false, false - canroute6 := h.IP6() - canroute4 := h.IP4() - - y1, err = h.rt.HasAny(ippOrCidr) - if y1 { - y2 = true // assume all okay - if cidr, err := core.IP2Cidr2(ippOrCidr); err == nil { - is6 := cidr.Addr().Is6() - is4 := cidr.Addr().Is4() - y2 = (is6 && canroute6) || (is4 && canroute4) - } // fallback onto y1 on errs. - } // y2 is also false. - - logev(err)("wg: %s router: (4/6? %t/%t) %s; allowed? %t / contains? %t; err? %v", - h.tag(), canroute4, canroute6, ippOrCidr, y1, y2, err) - - return y1 && y2 -} - -func (h *wgtun) serve(network, local string) (pc net.PacketConn, err error) { - if err := canserve(h.status); err != nil { - return nil, err - } - - // todo: dial into both direct & via if via cannot handle all routes? - who := h.who() - var v Proxy // may be nil - hasvia, usingvia := false, false - if hasvia = usevia(h.viaID); hasvia { - if v, usingvia = h.via.Get(); v != nil && usingvia { - if rerr := h.viaCanRoute(v, false /*dryrun*/); rerr != nil { - usingvia = false - err = rerr - } else { - // TODO: use Dial if announce fails to "port-forward" on via - pc, err = v.Announce(network, local) - } - } else { - usingvia = false - err = errNoHop - } - } else { // dial direct - pc, err = h.direct.Announce(network, local) - } - - if !usingvia { - // wgproxy.Refresh() is not needed since serve is called - // at the time of wgproxy.Device.Up() anyway. - if hasvia { - log.W("wg: %s via(%s) failing... %v", who, idhandle(v), err) - if removeViaOnErrors { - // todo: call h.Hop(nil) instead? - h.swapVia(nil) // stale; unset - } - } - } - - h.viaUp.Store(usingvia) - defer h.listener(wg.Opn, err) - - logei(err)("wg: %s serve: %s (via? %s %t / usingVia? %t); err? %v", - who, local, idstr(v), hasvia, usingvia, err) - return -} - -func (h *wgtun) listener(op wg.PktDir, err error) (ended bool) { - s := h.status.Load() - cur := s - ended = s == END - - if op != wg.Clo { - if op.Read() { - h.latestRxErr.Store(err) - } else if op.Write() { - h.latestTxErr.Store(err) - } else { - h.latestErr.Store(err) - } - } - - if s == END || s == TPU { // stopped or paused - h.statusReason.Store("TXX: paused or stopped") - log.E("wg: %s listener: %s; status %s; ignoring1", h.tag(), op, pxstatus(s)) - return - } - - if s == TUP && op != wg.Opn { // ignore all else but open - h.statusReason.Store("TUP: waiting for wgconn") - return - } - - why := "" - - defer func() { - h.statusReason.Store(why) - updated := cur == s - ended = s == END - if !updated { - updated = h.status.Cas(cur, s) - } - logeif(!updated)("wg: %s listener: %s; status %s => %s; transition? %t, statusupdated? %t, why: %s", - h.tag(), op, pxstatus(cur), pxstatus(s), cur != s, !updated, why) - }() - - if op == wg.Clo { - why = "TNT: closed; prev: " + pxstatus(s).String() - s = TNT - return - } - - now := now() - age := now - h.since - if err != nil { // failing - s = TKO - why = "TKO: " + err.Error() - if op == wg.Opn { // could not open conn to wg endpoint - s = TNT - why = "TNT: could not open conn" - } else if op.Read() && timedout(err) { - s = TZZ // writes and reads have succeeded in the recent past - why = "TZZ: read timeout" - } else if errors.Is(err, net.ErrClosed) { - // github.com/WireGuard/wireguard-go/blob/f333402bd9cb/device/receive.go#L112 - // on net.ErrClosed, wg stops recieving routine for all peers; this among - // other things mean that the wg.Device is effectively down and would not - // recieve any incoming messages (nor outgoing as those use the same socket) - s = TNT - why = "TNT: closed " + string(op) - } - - if op == wg.Rcv && !timedout(err) { // read error - h.errRx.Add(1) - h.latestRx.Store(now) - } else if op == wg.Snd { // write error - h.errTx.Add(1) - h.latestTx.Store(now) - } // else: not a transport message - - if op.Read() { - h.latestRead.Store(now) - } else if op.Write() { - h.latestWrite.Store(now) - } - } else { // ok - s = TOK - why = "TOK: ok" - if op == wg.Rcv { // read ok - h.latestGoodRx.Store(now) - h.latestRx.Store(now) - why = "TOK: read ok" - } else if op == wg.Snd { // write ok - h.latestGoodTx.Store(now) - h.latestTx.Store(now) - why = "TOK: write ok" - } // else: not a transport message - - if op.Read() { - h.latestGoodRead.Store(now) - } else if op.Write() { - h.latestGoodWrite.Store(now) - } - } - - // s may also be TOK (for successful handshakes but not for transport data) - if age > ageThreshold.Milliseconds() && (s == TOK || s == TKO) { - lastSuccessfulRead := h.latestGoodRead.Load() - lastSuccessfulWrite := h.latestGoodWrite.Load() - lastRead := h.latestRx.Load() - lastWrite := h.latestTx.Load() - - deviationMs := (max(lastSuccessfulWrite, lastSuccessfulRead) - - min(lastSuccessfulWrite, lastSuccessfulRead)) - readElapsedMs := lastRead - lastSuccessfulRead // never negative - writeElapsedMs := lastWrite - lastSuccessfulWrite // never negative - - hasNewWrites := lastWrite > age - hasNewReads := lastRead > age - - // too much time since last good write and good reads - readWriteDeviation := (hasNewReads || hasNewWrites) && deviationMs > markTNTAfterMillis - // too much time since last attempted read was good - readThres := hasNewReads && readElapsedMs > markTNTAfterMillis - // too much time since last attempted write was good - writeThres := hasNewWrites && writeElapsedMs > markTNTAfterMillis - - // if status is !ok (TKO), no reads since last write, mark as unresponsive - // if status is ok (TOK) but writes have not yet happened - // then reads (Rcv) are expected to timeout; so ignore them - if !hasNewReads && !hasNewWrites { - why = "TZZ: idling after start/refresh" - s = TZZ // possibly idling - } else if readThres || writeThres || readWriteDeviation { - why = fmt.Sprintf("TZZ: r !ok? %t, w !ok? %t, rw apart? %t; overriding: %s", - readThres, writeThres, readWriteDeviation, why) - s = TNT - } - } - - if s == TNT { - // listener is called from wgconn and must retrun without performing blocking ops - core.Go("wg.listener.refresh."+h.id, func() { - m := h.dns.Load().SoftRefresh() - if n := h.remote.Load().MaybeRefresh(); n > 0 { - log.I("wg: %s listener: %s, state: %s; refreshed %d dns / %d peers; why: %s", - h.tag(), op, pxstatus(s), m, n, why) - } - // TODO: h.redoPeers() - }) - } - return -} - -// func Handle(), GetAddr(), Dialer(), Reaches(), Stop(), -// OnProtoChange(), Ping(), Stats(), Router() impl by wgproxy. - -// now returns the current time in unix millis -func now() int64 { - return time.Now().UnixMilli() -} - -func (w *wgproxy) resetMtu(via Proxy) error { - return w.maybeResetMtu(via, false /*dryrun*/) -} - -func (w *wgtun) viaCanRoute(via Proxy, dryrun bool) error { - weCan4 := w.IP4() - hopCan4 := via.Router().IP4() - weCan6 := w.IP6() - hopCan6 := via.Router().IP6() - - check4 := weCan4 && hopCan4 - check6 := weCan6 && hopCan6 - - if !check4 && !check6 { - return errHopProxyRoutes - } - - viaRouter := via.Router() - all := w.remote.Load().All() - for _, p := range multihost.Flatten(all) { - ip := p.Addr() - if (ip.Is4() && check4) || (ip.Is6() && check6) { - if !viaRouter.Contains(ip.String()) { - return log.EE("wg: %s proxy: viaCanRoute: via %s cannot route peer %s; dryrun? %t", - w.tag(), idstr(via), p, dryrun) - } - } - } - - log.D("wg: %s proxy: viaCanRoute: via %s can route all peers %v (dryrun? %t / 4? %t / 6? %t)", - w.tag(), idstr(via), all, dryrun, check4, check6) - return nil -} - -func (w *wgproxy) maybeResetMtu(via Proxy, dryrun bool) error { - // mtu needed to tunnel this wg - mtuNeededByUs := int(w.desiredmtu.Load()) - mtuAvailFromNet := int(w.netmtu.Load()) - mtuAvailable := mtuAvailFromNet - hopping := false // tunnled via another proxy - viaid := idstr(via) - - note := log.I - if dryrun { - note = log.D - } - - if via != nil { - // mtu affordable by via (routerMtu = endpointMtu + wgHeader) - if mtuAvailFromHop, err := via.Router().MTU(); err != nil { - return err - } else { - mtuAvailable = calcTunMtu(mtuAvailFromHop) - hopping = true - note("wg: %s proxy: hopping %s; mtu(needed: %d / net: %d); hopmtu(avail: %d / tot: %d)", - w.tag(), viaid, mtuNeededByUs, mtuAvailFromNet, mtuAvailable, mtuAvailFromHop) - } - } - - has4 := w.IP4() - has6 := w.IP6() - minmtu := minmtu4 - if has6 { - minmtu = minmtu6 - } - - if mtuNeededByUs > mtuAvailable { - note("wg: %s (4? %t / 6? %t) proxy: maybe hopping %t %s; mtu(needed: %d >> avail: %d << min: %d); set to avail", - w.tag(), has4, has6, hopping, viaid, mtuNeededByUs, mtuAvailable, minmtu) - mtuNeededByUs = mtuAvailable - } // else: mtu needed is well within the hop's / network's capacity - - if mtuAvailable < minmtu { - return log.EE("wg: (4? %t / 6? %t) %s proxy: hopping? %t %s; needs1 %d; avail(%d) < min(%d); %v", - w.tag(), has4, has6, hopping, viaid, mtuNeededByUs, mtuAvailable, minmtu, errHopMtuInsufficient) - } - - finalMtu := reconcileMtu(mtuAvailable, mtuNeededByUs, minmtu) - if finalMtu <= NOMTU { - return log.EE("wg: %s (4? %t / 6? %t) proxy: hopping? %t %s; needs2 %d or avail %d <= NOMTU(%d); %v", - w.tag(), has4, has6, hopping, viaid, mtuNeededByUs, mtuAvailable, finalMtu, errHopMtuInsufficient) - } - - if !dryrun { - w.ep.SetMTU(uint32(finalMtu)) - w.wgtun.events <- tun.EventMTUUpdate - } - note("wg: %s (4? %t / 6? %t) proxy: hopping %s; mtu(needed:%d, avail: %d => final: %d); hopping? %t, dryrun? %t", - w.tag(), has4, has6, viaid, mtuNeededByUs, mtuAvailable, finalMtu, hopping, dryrun) - return nil -} - -// Returns wg header (80 bytes) minus min(underlay, overlay). -// Returns NOMTU if underlay is <= NOMTU. -// Returns minoverlay if overlay is <= NOMTU. -func reconcileMtu(underlay, overlay, minoverlay int) int { - if underlay < overlay { // underlay may be NOMTU - return max(underlay-80, NOMTU) // underlay may be way smaller than overlay - } - return calcTunMtu2(overlay, minoverlay) // overlay may be NOMTU, but that's okay -} - -// May return NOMTU if netmtu-size(wg header) is <= NOMTU. -func calcTunMtu(netmtu int) int { - return calcTunMtu2(netmtu, NOMTU) -} - -// github.com/tailscale/tailscale/blob/92d3f64e95/net/tstun/mtu.go -func calcTunMtu2(netmtu, min int) int { - // uint32(mtu) - 80 is the maximum payload size of a WireGuard packet. - return max(min-80, netmtu-80) // 80 is the overhead of the WireGuard header -} - -func calcNetMtu(tunmtu int) int { - return max(minmtu6, tunmtu+80) // 80 is the overhead of the WireGuard header -} - -func timedout(err error) bool { - x, ok := err.(net.Error) - return ok && x.Timeout() -} - -func logev(err error) log.LogFn { - if err != nil { - return log.E - } - if settings.Debug { - return log.VV - } - return log.N -} - -func loged(err error) log.LogFn { - if err != nil { - return log.E - } - if settings.Debug { - return log.D - } - return log.N -} - -func estr(err error) string { - if err != nil { - return err.Error() - } - return "" -} - -func pfxsfx(s string) string { - if len(s) <= 4 { - return s - } - if len(s) <= 8 { - return s[:4] - } - if len(s) <= 16 { - return s[:4] + ".." + s[len(s)-4:] - } - return s[:6] + ".." + s[len(s)-6:] -} diff --git a/intra/listener.go b/intra/listener.go deleted file mode 100644 index 4cd9ef33..00000000 --- a/intra/listener.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "fmt" - "net/netip" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn" -) - -// SocketSummary reports information about each TCP socket -// or a non-DNS UDP association, or ICMP echo when it is closed. -type SocketSummary struct { - // tcp, udp, or icmp. - Proto string - // Unique ID for this socket. - ID string - // Proxy ID that handled this socket. - PID string - // Relay Proxy ID that tunneled PID. - RPID string - // UID of the app that owns this socket (sans ICMP). - UID string - // Source IP. - Source string - // Remote IP, if dialed in. - Target string - // Total bytes downloaded. - Rx int64 - // Total bytes uploaded. - Tx int64 - // Duration in milliseconds. - Duration int64 - // Tracks start time; unexported. - start time.Time - // Round-trip time (millis). - Rtt int64 - // Err or other messages, if any. - Msg string -} - -type SocketListener interface { - // Preflow is called before a new connection is established; return owner "uid", which is - // later used by dnsx.Resolver to determine the DNS transport to use for that "uid". - Preflow(protocol, uid int32, src, dst string) *PreMark - // Flow is called on a new connection; return Proxy IDs to forward the connection - // to a pre-registered proxy; "Base" or "Exit" to allow the connection; "Block" to block it. - // "connid" is used to uniquely identify a connection across all proxies, and a summary of the - // connection is sent back to a pre-registered listener. - // protocol is 6 for TCP, 17 for UDP, 1 for ICMP. - // uid is -1 in case owner-uid of the connection couldn't be determined. - // src and dst are string'd representation of net.TCPAddr and net.UDPAddr. - // origdsts is a comma-separated list of original source IPs, this may be same as dst. - // origdsts may contain unspecified IPv4 or IPv6 addresses, which denote that the domain - // was blocked by a rdns blocklist (but the resolution was allowed to go through). Listener - // may choose to "Block" this connection based on that information. - // domains is a comma-separated list of domain names associated with origdsts, if any. - // probableDomains is a comma-separated list of probable domain names associated with origdsts, if any. - // blocklists is a comma-separated list of rdns blocklist names that apply, if any. - Flow(protocol, uid int32, src, dst, origdsts, domains, probableDomains, blocklists string) *Mark - // Inflow is called on a new incoming connection. Returned *Mark values have no discernable effect on these connections, - // except for the CID field, which is sent back via OnSocketClosed, and "Block" proxy which - // will drop this connection on the floor. - Inflow(protocol, uid int32, src, dst string) *Mark - // PostFlow is called after a flow is marked by Flow or Inflow. - // It denotes the final Mark that was applied to the flow. - // The only major discernable effect is PIDCSV has a single PID. - PostFlow(m *Mark) - // OnSocketClosed reports summary after a socket closes. - OnSocketClosed(*SocketSummary) -} - -type PreMark struct { - // UID of the app which owns the flow. - // Set it to "-1" if unknown. - UID string - // Is the UID us (our app / process)? - IsUidSelf bool -} - -type Mark struct { - // PIDCSV is a list of proxies to forward the flow over. - PIDCSV string - // CID uniquely identifies the flow. - CID string - // UID of the app which owns the flow. - UID string - // Preferred IP (for egress) to use for the flow (not guaranteed - // as the flow may prefer IPv4 or IPv6 and the IP may not be of that family). - IP string -} - -const ( - ProtoTypeUDP = "udp" - ProtoTypeTCP = "tcp" - ProtoTypeICMP = "icmp" -) - -var ( - optionsBlock = &Mark{PIDCSV: ipn.Block} - optionsExit = &Mark{PIDCSV: ipn.Exit} -) - -var errNone noerror - -type noerror struct{} - -var _ error = noerror{} - -func (noerror) Error() string { return "no error" } - -func icmpSummary(id, uid string) *SocketSummary { - return &SocketSummary{ - Proto: ProtoTypeICMP, - ID: id, - UID: uid, - start: time.Now(), - Msg: errNone.Error(), - } -} - -func tcpSummary(id, uid string, src, dst netip.Addr) *SocketSummary { - return &SocketSummary{ - Proto: ProtoTypeTCP, - ID: id, - UID: uid, - Source: src.String(), - Target: dst.String(), - start: time.Now(), - Msg: errNone.Error(), - } -} - -func udpSummary(id, uid string, src, dst netip.Addr) *SocketSummary { - s := tcpSummary(id, uid, src, dst) - s.Proto = ProtoTypeUDP - return s -} - -func (s *SocketSummary) postMark() *Mark { - if s == nil { - return nil - } - return &Mark{ - PIDCSV: s.PID, - CID: s.ID, - UID: s.UID, - IP: s.Target, - } -} - -// String implements fmt.Stringer. -func (s *SocketSummary) String() string { - if s != nil { - return fmt.Sprintf("socket-summary: %s: id=%s pid=%s:%s uid=%s to=%s down=%s up=%s dur=%s synack=%s msg=%s", - s.Proto, s.ID, s.PID, s.RPID, s.UID, s.Target, core.FmtBytes(uint64(s.Rx)), core.FmtBytes(uint64(s.Tx)), core.FmtMillis(s.Duration), core.FmtMillis(s.Rtt), s.Msg) - } - return "" -} - -func (s *SocketSummary) elapsed() { - if s != nil { - s.Duration = time.Since(s.start).Milliseconds() - } -} - -func (s *SocketSummary) done(errs ...error) *SocketSummary { - if s == nil { - return nil - } - - defer func() { - if len(s.Msg) <= 0 { - s.Msg = errNone.Error() - } - }() - - s.elapsed() - - if len(errs) <= 0 { - return s - } - - err := core.UniqErr(errs...) // errs may be nil - if err != nil { - if s.Msg == errNone.Error() { - s.Msg = err.Error() - } else { - s.Msg = s.Msg + "; " + err.Error() - } - } - return s -} diff --git a/intra/log/fconsole.go b/intra/log/fconsole.go deleted file mode 100644 index 22d0428f..00000000 --- a/intra/log/fconsole.go +++ /dev/null @@ -1,93 +0,0 @@ -package log - -import ( - "bytes" - "io" - "os" - "syscall" - "unsafe" -) - -var newline = []byte{'\n'} - -type fconsole struct { - // takes ownership of w - w *os.File -} - -var _ FilebasedConsole = (*fconsole)(nil) - -func newfconsole(w *os.File) *fconsole { - return &fconsole{w: w} -} - -func (p *fconsole) Close() error { - if w := p.w; w != nil { - _ = w.Close() - p.w = nil - } - return nil -} - -// File returns the underlying os.File. Caller -// must dup() if it intends to use it beyond -// the lifetime of the console. -func (p *fconsole) File() *os.File { - if p == nil { - return nil - } - return p.w -} - -func (p *fconsole) Log(lvl LogLevel, msg Logmsg) { - if p == nil || p.w == nil { - return - } - p.write(lvl, msg) -} - -func setNonblock(f *os.File) error { - if f == nil { - return nil - } - return syscall.SetNonblock(int(f.Fd()), true) -} - -func (f *fconsole) write(lvl LogLevel, m Logmsg) error { - if len(m) == 0 { - return nil - } - - w := f.w - if w == nil { - return io.ErrClosedPipe - } - if len(m) <= 0 { - return nil - } - l := []byte(lvl.s()) - p := unsafe.StringData(m) - b := unsafe.Slice(p, len(m)) - // levels like STACKTRACE may not prefix the expected tag - // ("F " in the STACKTRACE case), but file-based logger - // always expects it for every line - if !bytes.HasPrefix(b, l) { - w.Write(l) - } - n, err := w.Write(b) - // go.dev/play/p/NbJimcpoS0o - if !bytes.HasSuffix(b, newline) { - w.Write(newline) - } - if err != nil { - if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR { - // non-blocking write would block; drop the log - return nil - } - return err - } - if n < len(m) { - return io.ErrShortWrite - } - return nil -} diff --git a/intra/log/log.go b/intra/log/log.go deleted file mode 100644 index 4b89ab74..00000000 --- a/intra/log/log.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// MIT License -// -// Copyright (c) 2018 eycorsican -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -package log - -import ( - "context" - "errors" - "fmt" - "io" - "os" -) - -// based on: github.com/eycorsican/go-tun2socks/blob/301549c43/common/log/log.go#L5 -var Glogger Logger = defaultLogger() - -// caller -> intra/log.go*2 (this file) -> intra/logger.go -> golang/log.go -const ( - nextframe = 1 - // see: pkg.go.dev/log#Output - callerat = 2 -) - -type Logmsg = string - -// Console is an external logger. -type Console interface { - // Log logs a multi-line msg. - Log(level LogLevel, msg Logmsg) -} - -type FilebasedConsole interface { - Console - io.Closer - File() *os.File -} - -type conMsg struct { - m Logmsg - t LogLevel -} - -type LogFn func(string, ...any) -type LogFn2 func(int, string, ...any) - -func init() { - Glogger.SetLevel(INFO) - Glogger.SetConsoleLevel(STACKTRACE) -} - -func SetLevel(level LogLevel) { - Glogger.SetLevel(level) -} - -func SetConsoleLevel(level LogLevel) { - Glogger.SetConsoleLevel(level) -} - -func ConsoleReady(ctx context.Context) { - Glogger.ConsoleReady(ctx) -} - -// SetConsole sets external console to redirect log output to. -func SetConsole(consoleCtx context.Context, c Console) { - Glogger.SetConsole(c) - - context.AfterFunc(consoleCtx, func() { - Glogger.SetConsole(nil) - }) -} - -// NewFilebased sets a pipe-backed console and returns reader and writer. -// Caller is expected to hand off the read fd and read until EOF. -// Caller owns the reader and write and must close both. -func NewFilebased() (reader *os.File, writer FilebasedConsole, err error) { - var r, w *os.File - r, w, err = os.Pipe() // r is owned by us - if err != nil { - return - } - if err = setNonblock(w); err != nil { - _ = r.Close() - _ = w.Close() - return - } - - p := newfconsole(w) // p takes ownership of w - return r, p, nil // caller must dup r -} - -func Of(tag string, l LogFn2) LogFn { - if l != nil { - return func(msg string, args ...any) { - // caller -> LogFn (parent fn) -> intra/log.go*2(this file) -> intra/logger.go -> golang/log.go - l(callerat, tag+" "+msg, args...) - } - } - return N -} - -// N is a no-op logger. -func N(string, ...any) {} - -// N2 is a no-op logger. -func N2(int, string, ...any) {} - -// V logs a verbose message. -func V(msg string, args ...any) { - V2(callerat, msg, args...) -} - -// VV logs a very verbose message. -func VV(msg string, args ...any) { - VV2(callerat, msg, args...) -} - -// D logs a debug message. -func D(msg string, args ...any) { - D2(callerat, msg, args...) -} - -// I logs an info message. -func I(msg string, args ...any) { - I2(callerat, msg, args...) -} - -// W logs a warning message. -func W(msg string, args ...any) { - W2(callerat, msg, args...) -} - -func WE(msg string, args ...any) (err error) { - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - W2(callerat, msg) - return errors.New(msg) -} - -// E logs an error message. -func E(msg string, args ...any) { - E2(callerat, msg, args...) -} - -func EE(msg string, args ...any) (err error) { - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - E2(callerat, msg) - return errors.New(msg) -} - -// P logs a private message. -func P(msg string, args ...any) { - Glogger.Piif(callerat, msg, args...) -} - -// Wtf logs a fatal message. -func Wtf(msg string, args ...any) { - Glogger.Fatalf(callerat, msg, args...) -} - -// C logs the stack trace of the current goroutine to Console. -func C(msg string, scratch []byte) { - E2(callerat, "----START----") - Glogger.Stack( /*console-only*/ 0, msg, scratch) - E2(callerat, "----STOPP----") -} - -// R logs msg to as error to log if c is false, or to console otherwise. -func R(c bool, msg string, args ...any) { - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - Glogger.Trace(c, msg) -} - -// U logs a user message (notifies the user). -func U(msg string) { - Glogger.Usr(msg) -} - -// T logs the stack trace of the current goroutine. -func T(msg string, args ...any) { - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - E2(callerat, "----START----") - Glogger.Stack(callerat, msg, make([]byte, 4096)) - E2(callerat, "----STOPP----") -} - -// TALL logs the stack trace of all active goroutines. -func TALL(msg string, atleast64k []byte) { - E2(callerat, "----START----") - Glogger.Stack(callerat, msg, atleast64k /*may be nil*/) - E2(callerat, "----STOPP----") -} - -func VV2(at int, msg string, args ...any) { - Glogger.VeryVerbosef(at+nextframe, msg, args...) -} - -func V2(at int, msg string, args ...any) { - Glogger.Verbosef(at+nextframe, msg, args...) -} - -func D2(at int, msg string, args ...any) { - Glogger.Debugf(at+nextframe, msg, args...) -} - -func I2(at int, msg string, args ...any) { - Glogger.Infof(at+nextframe, msg, args...) -} - -func W2(at int, msg string, args ...any) { - Glogger.Warnf(at+nextframe, msg, args...) -} - -func E2(at int, msg string, args ...any) { - Glogger.Errorf(at+nextframe, msg, args...) -} - -func LevelOf(level int32) LogLevel { - dlvl := NONE - switch l := LogLevel(level); l { - case VVERBOSE: - dlvl = VVERBOSE - case VERBOSE: - dlvl = VERBOSE - case DEBUG: - dlvl = DEBUG - case INFO: - dlvl = INFO - case WARN: - dlvl = WARN - case ERROR: - dlvl = ERROR - case STACKTRACE: - dlvl = STACKTRACE - case NONE: - dlvl = NONE - default: - } - return dlvl -} diff --git a/intra/log/logger.go b/intra/log/logger.go deleted file mode 100644 index 64ddbadc..00000000 --- a/intra/log/logger.go +++ /dev/null @@ -1,820 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// MIT License -// -// Copyright (c) 2018 eycorsican -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -package log - -import ( - "context" - "fmt" - "hash/fnv" - golog "log" - "os" - "reflect" - "runtime" - "strings" - "sync" - "sync/atomic" - "unsafe" -) - -type Logger interface { - SetLevel(level LogLevel) - SetConsoleLevel(level LogLevel) - SetConsole(c Console) - ConsoleReady(ctx context.Context) - Usr(msg string) - Printf(msg string, args ...any) - VeryVerbosef(at int, msg string, args ...any) - Verbosef(at int, msg string, args ...any) - Debugf(at int, msg string, args ...any) - Piif(at int, msg string, args ...any) - Infof(at int, msg string, args ...any) - Warnf(at int, msg string, args ...any) - Errorf(at int, msg string, args ...any) - Fatalf(at int, msg string, args ...any) - Trace(c bool, t string) - Stack(at int, msg string, scratch []byte) -} - -// based on github.com/eycorsican/go-tun2socks/blob/301549c43/common/log/simple/logger.go -type simpleLogger struct { - tag string - - level LogLevel // golog (internal log) level - - c atom[Console] - clevel LogLevel // may be different from level - cmsgC chan *conMsg // never closed - cskips atomic.Uint32 // number of dropped console msgs - - stmu sync.Mutex // guards stcount - stcount map[uint64]uint32 // stack trace counter for identical traces - - o *golog.Logger - e *golog.Logger - q *ring[string] // todo: use []byte instead of string for gc? - - clock - skips -} - -type atom[T any] atomic.Value - -func (a *atom[T]) get() (zz T) { - if a == nil { - return - } - aa := (*atomic.Value)(a) - if t, ok := aa.Load().(T); ok { - return t - } - return zz -} - -func (a *atom[T]) set(t T) (ok bool) { - if a == nil { - return - } - if isNil(t) { - zz := &atom[T]{} - *a = *zz - return - } - old := a.get() - if !typeEq(old, t) { - r := &atom[T]{} - *a = *r - } - aa := (*atomic.Value)(a) - return aa.CompareAndSwap(old, t) -} - -const pcbuckets = 512 - -// a clock-like spam rate limiter -// maps level+pc to its age in ticks -type clock struct { - l2 [NONE + 1][pcbuckets]uatom[uint8] // level+pc clock - l1 [NONE + 1]uatom[uint8] // level clock -} - -// number of per-level dropped (spammy) logs -type skips [NONE + 1]atomic.Uint32 - -type uatom[T uint8 | uint16] atomic.Uint32 - -func (a *uatom[T]) v() T { - aa := (*atomic.Uint32)(a) - return T(aa.Load()) -} - -func (a *uatom[T]) inc() T { - aa := (*atomic.Uint32)(a) - return T(aa.Add(1)) -} - -func (a *uatom[T]) cas(old, new T) bool { - aa := (*atomic.Uint32)(a) - return aa.CompareAndSwap(uint32(old), uint32(new)) -} - -var _ Logger = (*simpleLogger)(nil) - -type LogLevel uint32 - -const ( - VVERBOSE LogLevel = iota // VVERBOSE is the most verbose log level. - VERBOSE // VERBOSE is the verbose log level. - DEBUG // DEBUG is the debug log level. - INFO // INFO is the informational log level. - WARN // WARN is the warning log level. - ERROR // ERROR is the error log level. - STACKTRACE // STACKTRACE is the stack trace log level. - USR // USR is interactive log (e.g. as user prompt). - NONE // NONE no-ops the logger. -) - -func (l LogLevel) s() string { - switch l { - case VVERBOSE: - return "Y " - case VERBOSE: - return "V " - case DEBUG: - return "D " - case INFO: - return "I " - case WARN: - return "W " - case ERROR: - return "E " - case STACKTRACE: - return "F " - case USR: - return "U " - case NONE: - return " " - default: - return "? " - } -} - -const defaultLevel = INFO -const defaultClevel = STACKTRACE - -var _ Logger = (*simpleLogger)(nil) - -// runtime crashes "E Go ..." are sent to logd / /dev/log from here: -// github.com/golang/go/blob/3fd729b2a1/src/runtime/write_err_android.go#L13 -// github.com/golang/mobile/blob/fa72addaaa/internal/mobileinit/mobileinit_android.go#L52 -// const logcatLineSize = 1024 - -const spamConsole = false // send spammy logs to console -const logPiif = false // enable sensitive logs - -// qSize is the number of recent log msgs to keep in the ring buffer. -const qSize = 512 - -// minQSize is the minimum most number of recent log msgs to actually log. -// by default, all msgs in the qSize'd ring buffer are logged. -const minQSize = 16 - -// consoleChSize is the size of the console channel. -const consoleChSize = 1024 - -// minBytesForFullStacktrace is the size needed for a full stacktrace. -const minBytesForFullStacktrace = 16 << 10 // 16KB - -// similarTraceThreshold is the no. of similar stacktraces to report before suppressing. -const similarTraceThreshold = 8 - -// similarUsrMsgThreshold is the no. of similar user msgs to report before suppressing. -const similarUsrMsgThreshold = 3 - -// charsPerLine is max no. of characters per log line. -// less than 1024: github.com/golang/mobile/blob/2553ed8ce2/internal/mobileinit/mobileinit_android.go#L52 -const charsPerLine = 800 - -// prependTrace if true, prepends trace information to log msg; appends, otherwise. -const prependTrace = false - -// spamMsgThreshold is the min. no. of spammy msgs to report. -var spammsgThreshold = [NONE + 1]uint32{ - VVERBOSE: 256 >> 1, // 128 - VERBOSE: 256 >> 2, // 64 - DEBUG: 256 >> 3, // 32 - INFO: 256 >> 4, // 16 - WARN: 256 >> 5, // 8 - ERROR: 256 >> 6, // 4 - STACKTRACE: 256 >> 7, // 2 - USR: 256 >> 8, // 1 - NONE: 256 >> 9, // 0 -} - -const fileunknown = "?f?" -const callerunknown = "?c?" - -const defaultFlags = 0 // no flags - -func defaultLogger() *simpleLogger { - l := &simpleLogger{ - level: defaultLevel, - clevel: defaultClevel, - cmsgC: make(chan *conMsg, consoleChSize), - stcount: make(map[uint64]uint32), - // gomobile pipes stderr & stdout to logcat - // github.com/golang/mobile/blob/fa72addaaa/internal/mobileinit/mobileinit_android.go#L74-L92 - e: golog.New(os.Stderr, "", defaultFlags), - o: golog.New(os.Stdout, "", defaultFlags), - q: newRing[string](context.TODO(), qSize), - } - return l -} - -// NewLogger creates a new Glogger with the given tag. -func NewLogger(tag string) *simpleLogger { - l := defaultLogger() - if len(tag) <= 0 { // if tag is empty, leave it as is - return l - } - if !strings.HasSuffix(tag, "/") { - tag += "/ " // does not end with a /, add a / + space - } else if !strings.HasSuffix(tag, " ") { - tag += " " // does not end with a space, add space - } - l.tag = tag - return l -} - -// SetLevel sets the log level. -func (l *simpleLogger) SetLevel(n LogLevel) { - l.level = n -} - -// SetLevel sets the log level. -func (l *simpleLogger) SetConsoleLevel(n LogLevel) { - l.clearStCounts() - l.clevel = n -} - -// SetConsole sets the external log console. -func (l *simpleLogger) SetConsole(c Console) { - l.clearStCounts() - - l.c.set(c) // c may point to nil impl -} - -func (l *simpleLogger) ConsoleReady(ctx context.Context) { - go l.consoleDispatcher(ctx) - // TODO: close l.msgC when ctx is done - // TODO: wireup ctx to l.q -} - -func (l *simpleLogger) clearStCounts() { - l.stmu.Lock() - defer l.stmu.Unlock() - clear(l.stcount) -} - -// xor fold fnv to 48 bits: www.isthe.com/chongo/tech/comp/fnv -func fhash(b []byte) uint64 { - h := fnv.New64a() - _, _ = h.Write(b) - return h.Sum64() -} - -func (l *simpleLogger) incrStCount(id string) (c uint32) { - l.stmu.Lock() - defer l.stmu.Unlock() - - if len(id) > 500 { - id = id[:500] - } - loc := fhash([]byte(id)) - c = l.stcount[loc] - l.stcount[loc]++ - return c -} - -// consoleDispatcher sends msgs from l.msgC to external log console. -// It may drop logs on high load (50% for conNorm, 80% for conErr). -// Must be called once from a goroutine. -func (l *simpleLogger) consoleDispatcher(ctx context.Context) { - for m := range l.cmsgC { - select { - case <-ctx.Done(): - return - default: - } - if m == nil || len(m.m) <= 0 { // no msg - continue - } - load := (len(l.cmsgC) / cap(l.cmsgC) * 100) // load percentage - if c := l.c.get(); c != nil && !isNil(c) { // look for l.c on every msg - switch m.t { - case NONE: - // drop - case VVERBOSE, VERBOSE, DEBUG, INFO: - if load < 50 { - c.Log(m.t, m.m) - continue - } // drop - case WARN, ERROR: - if load < 5 { - if d := l.cskips.Swap(0); d > 0 { - c.Log(WARN, Logmsg(l.msgstr(WARN, "backpressure... dropped %d msgs", d))) - } - } - if load < 80 { - c.Log(m.t, m.m) - continue - } // drop - case STACKTRACE: - c.Log(m.t, m.m) - continue - case USR: - c.Log(m.t, m.m) - continue - } - } // dropped - l.cskips.Add(1) - } -} - -// consoleQueue sends msg m to l.msgC, dropping if full. -func (l *simpleLogger) consoleQueue(m *conMsg) { - select { - case l.cmsgC <- m: - default: // drop - } -} - -func (l *simpleLogger) Usr(msg string) { - if l.level <= USR { - if count := l.incrStCount(msg); count > similarUsrMsgThreshold { - return - } - l.consoleQueue(&conMsg{Logmsg(msg), USR}) - } -} - -// Printf exists to satisfy rnet/http's Logger interface -func (l *simpleLogger) Printf(msg string, args ...any) { - l.Debugf(callerat, msg, args...) -} - -func (l *simpleLogger) VeryVerbosef(at int, msg string, args ...any) { - l.writelog(VVERBOSE, at+nextframe, msg, args...) -} - -func (l *simpleLogger) Verbosef(at int, msg string, args ...any) { - l.writelog(VERBOSE, at+nextframe, msg, args...) -} - -func (l *simpleLogger) Debugf(at int, msg string, args ...any) { - l.writelog(DEBUG, at+nextframe, msg, args...) -} - -func (l *simpleLogger) Piif(at int, msg string, args ...any) { - if logPiif { - l.writelog(INFO, at+nextframe, msg, args...) - } -} - -func (l *simpleLogger) Infof(at int, msg string, args ...any) { - l.writelog(INFO, at+nextframe, msg, args...) -} - -func (l *simpleLogger) Warnf(at int, msg string, args ...any) { - l.writelog(WARN, at+nextframe, msg, args...) -} - -func (l *simpleLogger) Errorf(at int, msg string, args ...any) { - l.writelog(ERROR, at+nextframe, msg, args...) -} - -func (l *simpleLogger) Fatalf(at int, msg string, args ...any) { - // todo: log to console? - l.err(at+nextframe, l.msgstr(STACKTRACE, msg, args...)) - os.Exit(1) -} - -// emitStack sends stacktrace to console or log. -// Empty msgs are ignored. Log level (ex: "F ") is -// prepend to each log line when sent to console. -func (l *simpleLogger) emitStack(at int, msgs ...string) { - sendtoconsole := at <= callerat - c := l.c.get() - hasc := c != nil && !isNil(c) - - for _, msg := range msgs { - if len(msg) <= 0 { - continue - } - if !sendtoconsole { - l.err(at+nextframe, msg) - } else if hasc { - // c.Stack() on the same go routine, since - // the caller (ex: core.Recover) may exit - // immediately once simpleLogger.Stack() returns - c.Log(STACKTRACE, Logmsg(msg)) - } else { - // msg, which is unsafely type-coerced from []byte, - // is pooled; but the caller owns []byte and so it - // cannot be used asynchronously (ex: over channels). - // l.toConsole(&conMsg{msg, STACKTRACE}) - l.cskips.Add(1) - break // terminate the loop - } - } -} - -func (l *simpleLogger) Trace(c bool, t string) { - if len(t) <= 0 { - return - } - at := callerat // emits to console - if !c { - at += nextframe // emits to stdout - } - l.emitStack(at, t) -} - -func (l *simpleLogger) Stack(at int, msg string, scratch []byte) { - at += nextframe - if len(l.tag) > 0 { - msg = l.tag + msg - } - - if l.level > STACKTRACE { - l.emitStack(at, msg, "stacktrace disabled") - return - } else if len(scratch) <= 0 { - l.emitStack(at, msg, "stacktrace no scratch") - return - } - - count := l.incrStCount(msg) - msg = msg + fmt.Sprintf(" (#%d)", count) - if count > similarTraceThreshold { - l.emitStack(at, msg, "stacktrace suppressed") - return - } - - // full stacktrace iff large enough scratch - full := len(scratch) > minBytesForFullStacktrace // 16KB - n := runtime.Stack(scratch, full) - - if n == len(scratch) { - msg += "[trunc]" - } - - prev := l.queued(full) - - // byt2str accepted proposal: github.com/golang/go/issues/19367 - // previous discussion: github.com/golang/go/issues/25484 - trace := unsafe.String(&scratch[0], n) - l.emitStack(at, msg, trace, prev) -} - -func (l *simpleLogger) queued(all bool) (appendix string) { - maxlines := qSize - if !all { - maxlines = minQSize - } - i := 0 - // todo: interned strings github.com/golang/go/issues/62483 - lines := make([]string, maxlines) - for recent := range l.q.Iter() { - lines[i] = recent - i++ - if i >= len(lines) { - break - } - } - if i > 0 { - appendix = strings.Join(lines[:i], "\n") - } - return -} - -func (l *simpleLogger) msgstr(lvl LogLevel, f string, args ...any) (msg string) { - level := lvl.s() - - if len(f) <= 0 { - return level + l.tag + "" - } - if len(args) <= 0 { - return level + l.tag + f - } - msg = fmt.Sprintf(f, args...) - if len(msg) <= charsPerLine { // excl tag+level - return level + l.tag + msg - } - - var s strings.Builder - for i := 0; i < len(msg); i += charsPerLine { - if i > 0 { - s.WriteByte('\n') - } - s.WriteString(level) - if len(l.tag) > 0 { - s.WriteString(l.tag) - } - end := min(i+charsPerLine, len(msg)) - s.WriteString(msg[i:end]) - } - return s.String() -} - -// out logs to stdout and pushes msg into ring buffer. -// ref: github.com/golang/mobile/blob/c713f31d/internal/mobileinit/mobileinit_android.go#L51 -func (l *simpleLogger) out(msg string) { - _ = l.o.Output(0 /*not used*/, msg) // may error - l.q.Push(msg) -} - -// err logs to stderr and pushes msg into ring buffer. -func (l *simpleLogger) err(at int, msg string) { - _, file := caller(at + nextframe) - msg = file + msg - _ = l.e.Output(0 /*unused*/, msg) // may error - l.q.Push(msg) -} - -func caller(at int) (pc uintptr, who string) { - return caller2(at+nextframe, ":", ": ") -} - -func caller2(at int, sep1, sep2 string) (pc uintptr, who string) { - pc, file, line, _ := runtime.Caller(at) - if len(file) <= 0 { - file = fileunknown - } else { - file = shortfile(file) + sep1 + fmt.Sprint(line) + sep2 - } - return pc, file -} - -// go.dev/play/p/h9Woqcp0Xz0 -func callers(at, until int, sep1, sep2 string) (pcs []uintptr, files []string, skipped int) { - if until <= 0 { - return []uintptr{0}, []string{fileunknown}, 0 - } else if until == 1 { - pc, who := caller2(at+nextframe, sep1, "") - return []uintptr{pc}, []string{who}, 0 - } - - rpc := make([]uintptr, until) - n := runtime.Callers(at+nextframe, rpc) - if n < 1 { - return []uintptr{0}, []string{fileunknown}, until - } - - pcs = make([]uintptr, 0, until) - files = make([]string, 0, until) - frames := runtime.CallersFrames(rpc) - for i := range until { - frame, more := frames.Next() - pc := frame.PC // may be 0 - file := frame.File - line := frame.Line - fn := frame.Function - if len(file) <= 0 { // more is false when file is empty - file = fileunknown - } else { - file = shortfile(file) + sep1 + fmt.Sprint(line) - } - if len(fn) <= 0 { - fn = callerunknown - } else { - // ex: fn = "github.com/celzero/firestack/intra/dnsx.ChooseHealthyProxyHostPort" - fn = shortfile(fn) - } - - file += sep2 + fn - pcs = append(pcs, pc) - files = append(files, file) - if !more { - break - } - skipped = until - i - } - return -} - -func tracecaller(s string) bool { - if len(s) <= 0 || (strings.HasSuffix(s, callerunknown) && strings.HasPrefix(s, fileunknown)) { - return false - } - // ex: asm_arm64.s:1223>async.go:49>async.go:121>proxy.go:789 - if strings.Contains(s, "asm_") && strings.Contains(s, ".s") { - return false // asm files are not useful - } - return true -} - -func shortfile(file string) string { - if i := strings.LastIndexByte(file, '/'); i >= 0 { - file = file[i+1:] - } - return file -} - -func (l *simpleLogger) writelog(lvl LogLevel, at int, msg string, args ...any) { - ll := l.level <= lvl - cc := l.clevel <= lvl - - pc, file1 := caller(at + nextframe) - trace := "" - - isspam := l.spammy(lvl, pc) - if isspam { - l.skips[lvl].Add(1) - } - - if n := l.skips[lvl].Load(); n > spammsgThreshold[lvl] { - swapped := l.skips[lvl].CompareAndSwap(n, 0) - if swapped && (cc || ll) { - spammsg := l.msgstr(lvl, file1+"spammy... %d msgs; dropped? %t", n, !spamConsole) - if ll { - l.out(spammsg) - } - // print spammsg only if spamming is not allowed - if cc && !spamConsole { - l.consoleQueue(&conMsg{Logmsg(spammsg), lvl}) - } - } - } - - if ll || cc { - _, x, _ := callers(at+nextframe, 9, ":", "@") - switch lvl { - case USR, STACKTRACE, NONE: // no-op - case VVERBOSE: - if len(x) >= 10 && tracecaller(x[9]) { - trace += x[9] + ">" - } - fallthrough - case VERBOSE: - if len(x) >= 9 && tracecaller(x[8]) { - trace += x[8] + ">" - } - fallthrough - case DEBUG, ERROR, WARN, INFO: - if len(x) >= 8 && tracecaller(x[7]) { - trace += x[7] + ">" - } - if len(x) >= 7 && tracecaller(x[6]) { - trace += x[6] + ">" - } - if len(x) >= 6 && tracecaller(x[5]) { - trace += x[5] + ">" - } - if len(x) >= 5 && tracecaller(x[4]) { - trace += x[4] + ">" - } - // err - if len(x) >= 4 && tracecaller(x[3]) { - trace += x[3] + ">" - } - // warn - if len(x) >= 3 && tracecaller(x[2]) { - trace += x[2] + ">" - } - // info - if len(x) >= 2 && tracecaller(x[1]) { - trace += x[1] + ">" - } - fallthrough - default: - if tracecaller(x[0]) { // x[0] == file1 without fn info - trace += x[0] - } - } - if prependTrace { - if len(trace) > 0 { - trace += "\t" // end-of-trace marker - } - msg = l.msgstr(lvl, trace+msg, args...) - } else { - msg += "\t" // end-of-msg marker - msg = l.msgstr(lvl, msg+trace, args...) - } - if ll { - // go's internal logger grabs mutex before every write - l.out(msg) - } - if cc && (!isspam || spamConsole) { - l.consoleQueue(&conMsg{Logmsg(msg), lvl}) - } - } -} - -// go.dev/play/p/6CkoACJ1bYz -func (l *simpleLogger) spammy(lvl LogLevel, pc uintptr) (y bool) { - resyncAttempts := 0 - -top: - t := (l.clock.l1[lvl]).inc() // tick the level clock - - if pc == 0 { - return false - } - // won't work: l2 := l.clock.l2[lvl] - // go.dev/play/p/QgqEdE7KIAZ - bkt := pc % pcbuckets - - defer func() { - if !y { // age bkt when not spammy - (l.clock.l2[lvl][bkt]).inc() - } - }() - - v := (l.clock.l2[lvl][bkt]).v() - - // reset if pc clock (l2) out ticks level clock (l1); - // ie, t has probably overflowed to next generation - // and in that generation, bkt has not yet born. - // and so reset bkt to 0 or any value < t - if v > t { - resyncd := (l.clock.l2[lvl][bkt]).cas(t, 0) // set to t/2? - if resyncd { - return false // not spammy - } // else: someone else won the race - resyncAttempts++ - if resyncAttempts <= 3 { - goto top - } // else: so many calls that atomic updates won't go through - // assume spammy as that's most likely to be the case - return false - } - - tt := uint16(t) // tolerable ticks - if t < 256>>4 { // for upto 16 ticks - tt = tt * 70 / 100 // allow upto 70% of ticks - } else if t < 256>>3 { // for upto 32 ticks - tt = tt * 60 / 100 // allow upto 60% of ticks - } else if t < 256>>2 { // for upto 64 ticks - tt = tt * 50 / 100 // allow upto 50% of ticks - } else if t < 256>>1 { // for upto 128 ticks - tt = tt * 40 / 100 // allow upto 40% of ticks - } else { // for upto 256 ticks - tt = tt * 30 / 100 // allow upto 30% of ticks - } - return uint16(v) > tt -} - -// Cannot import pkg core here. -// from: intra/core/typ.go:isNil -func isNil(x any) bool { - // from: stackoverflow.com/a/76595928 - if x == nil { - return true - } - v := reflect.ValueOf(x) - k := v.Kind() - switch k { - case reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice: - return v.IsNil() - } - return false -} - -// from: intra/core/typ.go:typeEq -func typeEq(a, b any) bool { - if isNil(a) { - return false - } else if isNil(b) { - return false - } - return reflect.TypeOf(a) == reflect.TypeOf(b) -} diff --git a/intra/log/ring.go b/intra/log/ring.go deleted file mode 100644 index fd9549d2..00000000 --- a/intra/log/ring.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package log - -import ( - "context" - "sync" -) - -// A thread-safe ring buffer implementation -type ring[T any] struct { - sync.RWMutex - ctx context.Context - b []T // buffer - inC chan T // input channel - head int - tail int -} - -// NewRing creates a new ring buffer with the given capacity -func newRing[T any](ctx context.Context, capacity int) *ring[T] { - r := &ring[T]{ - ctx: ctx, - b: make([]T, capacity), - inC: make(chan T, capacity/2), - } - go r.process() - context.AfterFunc(ctx, func() { close(r.inC) }) - return r -} - -// Push adds an element to the ring buffer -func (r *ring[T]) Push(v T) (ok bool) { - select { - case <-r.ctx.Done(): - default: - select { - case <-r.ctx.Done(): - case r.inC <- v: - return true - default: // over cap, drop - } - } - return -} - -// process reads from the input channel and adds elements to the ring buffer. -// Must be run in a goroutine. -func (r *ring[T]) process() { - for v := range r.inC { - r.Lock() - r.b[r.head] = v - r.head = (r.head + 1) % len(r.b) - if r.head == r.tail { - r.tail = (r.tail + 1) % len(r.b) - } - r.Unlock() - } -} - -// Pop removes and returns the oldest element from the ring buffer -func (r *ring[T]) Pop() (v T) { - r.Lock() - defer r.Unlock() - - if r.head == r.tail { - return - } - v = r.b[r.tail] - r.tail = (r.tail + 1) % len(r.b) - return v -} - -// Len returns the number of elements in the ring buffer -func (r *ring[T]) Len() int { - r.RLock() - defer r.RUnlock() - - if r.head >= r.tail { - return r.head - r.tail - } - return len(r.b) - r.tail + r.head -} - -// Cap returns the capacity of the ring buffer -func (r *ring[T]) Cap() int { - r.RLock() - defer r.RUnlock() - - return len(r.b) -} - -// Peek returns the oldest element from the ring buffer without removing it -func (r *ring[T]) Peek() (v T) { - r.RLock() - defer r.RUnlock() - - if r.head == r.tail { - return - } - return r.b[r.tail] -} - -// Reset resets the ring buffer -func (r *ring[T]) Reset() { - r.Lock() - defer r.Unlock() - - r.head = 0 - r.tail = 0 -} - -// Iter returns a channel that yields all elements in the ring buffer -func (r *ring[T]) Iter() <-chan T { - ch := make(chan T, r.Cap()) - - go func() { - defer close(ch) - r.RLock() - defer r.RUnlock() - - for i := r.tail; i != r.head; i = (i + 1) % len(r.b) { - ch <- r.b[i] - } - }() - return ch -} diff --git a/intra/netstack/dispatchers.go b/intra/netstack/dispatchers.go deleted file mode 100644 index ff7b6e04..00000000 --- a/intra/netstack/dispatchers.go +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package netstack - -import ( - "math/rand" - "sync" - "sync/atomic" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/rawfile" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// Adopted from: github.com/google/gvisor/blob/f2b01a6e4/pkg/tcpip/link/fdbased/packet_dispatchers.go - -const wrapttl = 5 * time.Second // wrapttl is the time to wait for wrapup() to complete. - -// is the dispatcher thread safe? -const threadSafe = false - -// bufcfg defines the shape of the vectorised view used to read packets from the NIC. -var bufcfg = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} - -type iovecBuffer struct { - // serializes access to all members - sync.Mutex - - // buffer is the actual buffer that holds the packet contents. Some contents - // are reused across calls to pullBuffer if number of requested bytes is - // smaller than the number of bytes allocated in the buffer. - views []*buffer.View - - // iovecs are initialized with base pointers/len of the corresponding - // entries in the views defined above, except when GSO is enabled - // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header - // which is stripped before the views are passed up the stack for further - // processing. - iovecs []unix.Iovec - - // sizes is an array of buffer sizes for the underlying views. sizes is - // immutable. - sizes []int - - // unused: skipsVnetHdr is true if virtioNetHdr is to skipped. - // skipsVnetHdr bool - - // pulledIndex is the index of the last []byte buffer pulled from the - // underlying buffer storage during a call to pullBuffers. It is -1 - // if no buffer is pulled. - pulledIndex int -} - -func newIovecBuffer(sizes []int) *iovecBuffer { - b := &iovecBuffer{ - views: make([]*buffer.View, len(sizes)), - sizes: sizes, - // Setting pulledIndex to the length of sizes will allocate all - // the buffers. - pulledIndex: len(sizes), - } - niov := len(sizes) - b.iovecs = make([]unix.Iovec, niov) - return b -} - -func (b *iovecBuffer) nextIovecs() []unix.Iovec { - if threadSafe { - b.Lock() - defer b.Unlock() - } - - vnetHdrOff := 0 - - for i := range b.views { - if b.views[i] != nil { - break - } - v := buffer.NewViewSize(b.sizes[i]) - b.views[i] = v - b.iovecs[i+vnetHdrOff] = unix.Iovec{Base: v.BasePtr()} - b.iovecs[i+vnetHdrOff].SetLen(v.Size()) - } - return b.iovecs -} - -func (b *iovecBuffer) release() { - if threadSafe { - b.Lock() - defer b.Unlock() - } - - for _, v := range b.views { - if v != nil { - v.Release() - v = nil - } - } -} - -// pullBuffer extracts the enough underlying storage from b.buffer to hold n -// bytes. It removes this storage from b.buffer, returns a new buffer -// that holds the storage, and updates pulledIndex to indicate which part -// of b.buffer's storage must be reallocated during the next call to -// nextIovecs. -func (b *iovecBuffer) pullBuffer(n int) (pulled buffer.Buffer, ok bool) { - var views []*buffer.View - c := 0 - - needsUnlock := false - if threadSafe { - b.Lock() - needsUnlock = true - } - // Remove the used views from the buffer. - for i, v := range b.views { - if v == nil { - continue - } - c += v.Size() - if c >= n { - b.views[i].CapLength(v.Size() - (c - n)) - views = append(views, b.views[:i+1]...) - break - } - } - for i := range views { - b.views[i] = nil - } - if needsUnlock { - b.Unlock() - } - - for i, v := range views { - if err := pulled.Append(v); err != nil { - log.W("ns: dispatch: iov: err append view# %d: %v", i, err) - continue - } - } - ok = pulled.Size() >= int64(n) - pulled.Truncate(int64(n)) - return -} - -// readVDispatcher uses readv() system call to read inbound packets and -// dispatches them. -type readVDispatcher struct { - e *endpoint // e is the endpoint this dispatcher is attached to. - buf *iovecBuffer // buf is the iovec buffer that contains packets. - closed atomic.Bool // closed is set to true when fd is closed. - once sync.Once // Ensures stop() is called only once. - mgr *supervisor -} - -var _ linkDispatcher = (*readVDispatcher)(nil) - -// newReadVDispatcher creates a new linkDispatcher that vector reads packets from -// fd and dispatches them to endpoint e. It assumes ownership of fd but not of e. -func newReadVDispatcher(f *fds, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{ - e: e, - buf: newIovecBuffer(bufcfg), - mgr: newSupervisor(e, f.tun()), - } - d.mgr.start() - - log.I("ns: dispatch: newReadVDispatcher: tun(%s)", f) - return d, nil -} - -// swap atomically swaps existing fd for this new one. -// On error, it closes fd. -func (d *readVDispatcher) prepare(f *fds) { - if !d.closed.Load() { - d.mgr.note(f.tun()) // used for diagnostics only - } -} - -// stop stops the dispatcher once. Safe to call multiple times. -func (d *readVDispatcher) stop() { - d.once.Do(func() { - d.closed.Store(true) - d.mgr.stop() - log.I("ns: dispatch: closed!") - }) -} - -const abort = false // abort indicates that the dispatcher should stop. -const cont = true // cont indicates that the dispatcher should continue delivering packets despite an error. - -// dispatch reads packets from the current file descriptor in d.fds and dispatches it to netstack. -func (d *readVDispatcher) dispatch(fd *fds) (bool, tcpip.Error) { - return d.io(fd) -} - -// wrapup reads packets from fds and dispatches it to netstack -// and closes fds on error or on timeout. If settings.Loopingback is true, -// it closes fds immediately. Must be called in a goroutine. -func (d *readVDispatcher) wrapup(fds *fds, noMoreThan30s time.Duration) { - if !fds.ok() { // fds may be nil - return - } - defer fds.stop() - - // Loopback is set to true when VPN is in lockdown mode (block connections - // without vpn). It is observed that by closing the previous tun after delay - // results in "connection was reset" errors in netstack's TCP handler, which - // only go away if the tunnel is recreated (via stop/start or pause/unpause). - // This behaviour is seen when the device either connects to internet after - // quite a while or the device switches to a new network (on Android 14+). - if !threadSafe || settings.Loopingback.Load() { - log.I("ns: tun(%d): wrapup: immediate (loopback)", fds.tun()) - return - } - - awaited := core.Await(func() { - log.I("ns: tun(%d): drain: start w timeout in %s", fds.tun(), core.FmtPeriod(noMoreThan30s)) - for { - cont, err := d.io(fds) - if fd := fds.tun(); !cont { - log.W("ns: tun(%d): drain: exit; err? %v", fd, err) - return - } else if err != nil { - log.W("ns: tun(%d): drain: continue on err: %v", fd, err) - } // else: continue draining - } - }, min(30*time.Second, noMoreThan30s)) - - logei(!awaited)("ns: tun(%d): drain: timeout!", fds.tun()) -} - -// io reads packets from fds and dispatches it to netstack. -// not thread safe (see: threadSafe). -func (d *readVDispatcher) io(fds *fds) (bool, tcpip.Error) { - done := d.closed.Load() - if done { - log.W("ns: tun(%d): dispatch: done! %t", fds.tun(), done) - d.buf.release() // not thread safe - return abort, new(tcpip.ErrAborted) - } - - if fds == nil || !fds.ok() { // nil check for nilaway - log.W("ns: tun(%d): dispatch: fd closed or invalid!", fds.tun()) - return abort, new(tcpip.ErrNoSuchFile) - } - - iov := d.buf.nextIovecs() // not thread safe - if len(iov) == 0 { - log.E("ns: tun(%d): dispatch: iov == 0", fds.tun()) - return abort, new(tcpip.ErrBadBuffer) - } - - if settings.Debug { - log.VV("ns: tun(%d): dispatch: start; iov: %d", fds.tun(), len(iov)) - } - - start := time.Now() - - // github.com/google/gvisor/blob/d59375d82/pkg/tcpip/link/fdbased/packet_dispatchers.go#L186 - n, errno := rawfile.BlockingReadvUntilStopped(fds.eve(), fds.tun(), iov) - - fds.lastRead.Store(time.Now().UnixMilli()) // update last read time - - if settings.Debug { - log.VV("ns: tun(%d): dispatch: after %s, got(iov: %d / bytes: %d / tot: %d), err(%v)", - fds.tun(), core.FmtPeriod(time.Since(start)), len(iov), n, fds.read.Load(), errno) - } - - if n <= 0 || errno != 0 { - if errno == 0 { - return abort, new(tcpip.ErrNoSuchFile) - } - return abort, tcpip.TranslateErrno(errno) - } - - fds.read.Add(int64(n)) // update read bytes - - b, ok := d.buf.pullBuffer(n) // not thread safe - if !ok { - log.E("ns: tun(%d): dispatch: pullBuffer err; n: %d", fds.tun(), n) - return abort, new(tcpip.ErrBadBuffer) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: b, - }) - defer pkt.DecRef() - - var iseth = d.e.hdrSize > 0 // hdrSize always zero; unused - if iseth { - if !d.e.parseHeader(pkt) { - return abort, new(tcpip.ErrNotPermitted) - } - pkt.NetworkProtocolNumber = header.Ethernet(pkt.LinkHeader().Slice()).Type() - } - - start = time.Now() - if settings.Debug { - log.VV("ns: tun(%d): dispatch: pkt sz: %d", fds.tun(), pkt.Size()) - } - - d.mgr.queuePacket(pkt, iseth) - d.mgr.wakeReady() - - if settings.Debug { - log.VV("ns: tun(%d): dispatch: done after %s; pkt sz: %d", - fds.tun(), core.FmtPeriod(time.Since(start)), pkt.Size()) - } - - return cont, nil -} - -func rand10pc() bool { - return rand.Intn(999999) < 99999 -} - -func rand1pc() bool { - return rand.Intn(999999) < 9999 -} diff --git a/intra/netstack/fdbased.go b/intra/netstack/fdbased.go deleted file mode 100644 index ae9b01f8..00000000 --- a/intra/netstack/fdbased.go +++ /dev/null @@ -1,628 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package netstack provides the implemention of data-link layer endpoints -// backed by boundary-preserving file descriptors (e.g., TUN devices, -// seqpacket/datagram sockets). -// -// Adopted from: github.com/google/gvisor/blob/f33d034/pkg/tcpip/link/fdbased/endpoint.go -// since fdbased isn't built when building for android (it is only built for linux). -package netstack - -import ( - "errors" - "fmt" - "runtime/debug" - "sync/atomic" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/rawfile" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.InjectableLinkEndpoint = (*endpoint)(nil) -var _ stack.LinkEndpoint = (*endpoint)(nil) - -// placeholder FD for whenever existing FD wrapped in struct fds is closed. -const invalidfd int = -1 - -// wrapttl is the time to wait for the dispatcher to wrap up (close a previous FD). -const waitttl = wrapttl - -var errNeedsNewEndpoint = errors.New("ns: needs new endpoint") - -// linkDispatcher reads packets from the link FD and dispatches them to the -// NetworkDispatcher. -type linkDispatcher interface { - stop() - prepare(fd *fds) - dispatch(fd *fds) (bool, tcpip.Error) - wrapup(prev *fds, ttl time.Duration) -} - -type endpoint struct { - sync.RWMutex - // fds is the set of file descriptors each identifying one inbound/outbound - // channel. The endpoint will dispatch from all inbound channels as well as - // hash outbound packets to specific channels based on the packet hash. - fds *core.Volatile[*fds] - - // mtu (maximum transmission unit) is the maximum size of a packet. - mtu atomic.Uint32 - - // hdrSize specifies the link-layer header size. If set to 0, no header - // is added/removed; otherwise an ethernet header is used. - hdrSize int - - // addr is the address of the endpoint. - addr tcpip.LinkAddress - - // caps holds the endpoint capabilities. - caps stack.LinkEndpointCapabilities - - // dispatches packets from the link FD (tun device) - // to the network stack. Protected by the endpoint's mutex. - inboundDispatcher linkDispatcher - // the nic this endpoint is attached to. Protected by the endpoint's mutex. - dispatcher stack.NetworkDispatcher - - // wg keeps track of running goroutines. - wg core.RollingWaitGroup - - // maxSyscallHeaderBytes has the same meaning as - // Options.MaxSyscallHeaderBytes. - maxSyscallHeaderBytes uintptr - - // writevMaxIovs is the maximum number of iovecs that may be passed to - // rawfile.NonBlockingWriteIovec, as possibly limited by - // maxSyscallHeaderBytes. (No analogous limit is defined for - // rawfile.NonBlockingSendMMsg, since in that case the maximum number of - // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch - // encounters a packet whose iovec count is limited by - // maxSyscallHeaderBytes, it falls back to writing the packet using writev - // via WritePacket.) - writevMaxIovs int -} - -// Options specify the details about the fd-based endpoint to be created. -type Options struct { - // FDs is a set of FDs used to read/write packets. - FDs []int - - // MTU is the mtu to use for this endpoint. - MTU uint32 - - // EthernetHeader if true, indicates that the endpoint should read/write - // ethernet frames instead of IP packets. - EthernetHeader bool - - // Address is the link address for this endpoint. Only used if - // EthernetHeader is true. - Address tcpip.LinkAddress - - // SaveRestore if true, indicates that this NIC capability set should - // include CapabilitySaveRestore - SaveRestore bool - - // DisconnectOk if true, indicates that this NIC capability set should - // include CapabilityDisconnectOk. - DisconnectOk bool - - // TXChecksumOffload if true, indicates that this endpoints capability - // set should include CapabilityTXChecksumOffload. - TXChecksumOffload bool - - // RXChecksumOffload if true, indicates that this endpoints capability - // set should include CapabilityRXChecksumOffload. - RXChecksumOffload bool - - // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes - // of struct iovec, msghdr, and mmsghdr that may be passed by each host - // system call. - MaxSyscallHeaderBytes int -} - -// New creates a new fd-based endpoint. -// -// Makes fd non-blocking, but does not take ownership of fd, which must remain -// open for the lifetime of the returned endpoint (until after the endpoint has -// stopped being using and Wait returns). -func newFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) { - caps := stack.LinkEndpointCapabilities(0) - if opts.RXChecksumOffload { - caps |= stack.CapabilityRXChecksumOffload - } - - if opts.TXChecksumOffload { - caps |= stack.CapabilityTXChecksumOffload - } - - hdrSize := 0 - if opts.EthernetHeader { - hdrSize = header.EthernetMinimumSize - caps |= stack.CapabilityResolutionRequired - } - - if opts.SaveRestore { - caps |= stack.CapabilitySaveRestore - } - - if len(opts.FDs) == 0 { - return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") - } - - if opts.MaxSyscallHeaderBytes < 0 { - return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative") - } - - e := &endpoint{ - mtu: atomic.Uint32{}, - fds: core.NewVolatile(invalidFds), - caps: caps, - addr: opts.Address, - hdrSize: hdrSize, - // MaxSyscallHeaderBytes remains unused - maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes), - writevMaxIovs: rawfile.MaxIovs, - } - if e.maxSyscallHeaderBytes != 0 { - if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs { - e.writevMaxIovs = max - } - } - - // Create per channel dispatchers; usually only one. - if len(opts.FDs) != 1 { - return nil, fmt.Errorf("len(opts.FDs) = %d, expected 1", len(opts.FDs)) - } - - e.SetMTU(opts.MTU) - - if err := e.swap(opts.FDs[0], true); err != nil { - return nil, err - } - - return e, nil -} - -func createInboundDispatcher(e *endpoint, f *fds) (linkDispatcher, error) { - // By default use the readv() dispatcher as it works with all kinds of - // FDs (tap/tun/unix domain sockets and af_packet). - d, err := newReadVDispatcher(f, e) - if err != nil { - return nil, fmt.Errorf("newReadVDispatcher(%s, %+v) = %v", f, e, err) - } - return d, nil -} - -func (e *endpoint) Stat() (zz EpStat) { - fds := e.fds.Load() - if fds == nil { - return - } - - t := time.Now() - if death := fds.death.Load(); death > 0 { - t = time.UnixMilli(death) - } - - age := t.Sub(time.UnixMilli(fds.since.Load())) - - return EpStat{ - Fd: fds.tunFd, // f.tun() returns invalidfd if f.tunFd is closed - Alive: !fds.closed.Load(), - Age: core.FmtPeriod(age), - Read: core.FmtBytes(uint64(fds.read.Load())), - Written: core.FmtBytes(uint64(fds.written.Load())), - LastRead: core.FmtUnixMillisAsPeriod(fds.lastRead.Load()), - LastWrite: core.FmtUnixMillisAsPeriod(fds.lastWrite.Load()), - } -} - -func (e *endpoint) Dispose() (err error) { - e.Lock() - defer e.Unlock() - - prevfd := e.fds.Swap(invalidFds) // prevfd may be invalidfd - - if e.inboundDispatcher == nil { - prevfd.stop() // prevfd may be invalidfd - log.W("ns: tun(%d): Dispose, no inbound dispatcher", prevfd.tun()) - // nothing to do - return nil - } - - // e.inboundDispatcher.prepare() will not close prevfd - // dispatchLoop() will auto-exit on invalidfd - e.inboundDispatcher.wrapup(prevfd, wrapttl) - e.inboundDispatcher.prepare(invalidFds) - - return nil -} - -// Implements FdSwapper. -func (e *endpoint) Swap(fd, mtu int) (err error) { - e.SetMTU(uint32(mtu)) - return e.swap(fd, false) -} - -func (e *endpoint) swap(fd int, force bool) (err error) { - e.Lock() - defer e.Unlock() - - prevfd := e.fds.Load() - if !force && !prevfd.ok() { - return errNeedsNewEndpoint - } // if prevfd is nil, then we're creating a new endpoint - - f, err := newTun(fd) // fd may be invalid (ex: -1) - if err != nil || f == nil { - clos(fd) - return log.EE("ns: tun(%d): swap: err: %v / %v; using invalidfd", fd, err) - } - - e.fds.Store(f) // commence WritePackets() on fd - - log.D("ns: tun(%s): swap: fd %s => %d; err? %v", prevfd, prevfd, fd, err) - - if e.inboundDispatcher == nil { // prevfd must be 0 value if inbound is nil - prevfd.stop() // prevfd may be invalid - e.inboundDispatcher, err = createInboundDispatcher(e, f) - } else { - // closes prevfd, which may be invalidfd - e.inboundDispatcher.wrapup(prevfd, wrapttl) - e.inboundDispatcher.prepare(f) - } - - hasDispatcher := e.dispatcher != nil - if err == nil && hasDispatcher { // attached? - log.I("ns: tun(%s): (%s => %d) swap: restart looper %t for new fd", - prevfd, prevfd, fd, hasDispatcher) - go dispatchLoop(e.inboundDispatcher, f, &e.wg) - } else { // wait for Attach to be called eventually - log.E("ns: tun(%s): (%s => %d) swap: no dispatcher? %t for new fd; err %v", - prevfd, prevfd, fd, !hasDispatcher, err) - } - return -} - -// Attach launches the goroutine that reads packets from the file descriptor and -// dispatches them via the provided dispatcher. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - log.D("ns: attaching nic... %t", dispatcher != nil) - - e.Lock() - defer e.Unlock() - - rx := e.inboundDispatcher - - fds := e.fds.Load() - fd := fds.tun() - - attach := dispatcher != nil // nil means the NIC is being removed. - pipe := rx != nil // nil means there's no read dispatcher. - exists := e.dispatcher != nil // nil means the NIC is already detached. - - // Attach is called when the NIC is being created and then enabled. - // stack.CreateNIC -> nic.newNIC -> ep.Attach - if dispatcher == nil && e.dispatcher != nil { - log.I("ns: tun(%d): attach: detach dispatcher (and inbound? %t)", fd, pipe) - allLoopersExited := true - if rx != nil { - core.Gx("ns.stop", rx.stop) // avoid mutex - fds.stop() - - allLoopersExited = e.wait(waitttl) // on all inboundDispatcher w/ mutex locked? - } - e.dispatcher = nil - e.inboundDispatcher = nil // rx - e.fds.Store(invalidFds) - logei(!allLoopersExited)("ns: tun(%d): attach: done detaching dispatcher; all loopers done? %t", - fd, allLoopersExited) - return - } - - if dispatcher != nil && e.dispatcher == nil { - log.I("ns: tun(%d): attach: new dispatcher & looper", fd) - e.dispatcher = dispatcher - if e.inboundDispatcher == nil && fds.ok() { // unlikely - var err error - e.inboundDispatcher, err = createInboundDispatcher(e, fds) - logeif(err)("ns: tun(%d): attach: just-in-time createInboundDispatcher; err? %v", fd, err) - if e.inboundDispatcher != nil { - e.inboundDispatcher.prepare(fds) - } - rx = e.inboundDispatcher - } - go dispatchLoop(rx, fds, &e.wg) - return - } - - if dispatcher != nil { - log.W("ns: tun(%d): attach: discard? %t; but switch to new anyway", fd, exists) - e.dispatcher = dispatcher - return - } - - log.W("ns: tun(%d): attach: discard? %t; hadDispatcher? %t hadInbound? %t", fd, exists, attach, pipe) -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *endpoint) IsAttached() bool { - d, _ := e.getDispatchers() - return d != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *endpoint) MTU() uint32 { - return e.mtu.Load() -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.caps -} - -// MaxHeaderLength returns the maximum size of the link-layer header. -func (e *endpoint) MaxHeaderLength() uint16 { - return uint16(e.hdrSize) -} - -// LinkAddress returns the link address of this endpoint. -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - e.RLock() - defer e.RUnlock() - return e.addr -} - -// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop -// reading from its FD. -func (e *endpoint) Wait() { - (&e.wg).Wait() -} - -func (e *endpoint) wait(d time.Duration) bool { - // wait on e.Wait() until ttl expires - return core.Await(func() { e.Wait() }, d) -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *endpoint) AddHeader(pkt *stack.PacketBuffer) { - if e.hdrSize > 0 && pkt != nil { - // Add ethernet header if needed. - eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) - eth.Encode(&header.EthernetFields{ - SrcAddr: pkt.EgressRoute.LocalLinkAddress, - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - }) - } -} - -func (e *endpoint) parseHeader(pkt *stack.PacketBuffer) bool { - if pkt == nil { - return false - } - _, ok := pkt.LinkHeader().Consume(e.hdrSize) - return ok -} - -// ParseHeader implements stack.LinkEndpoint.ParseHeader. -func (e *endpoint) ParseHeader(pkt *stack.PacketBuffer) bool { - if pkt == nil { - return false - } - if e.hdrSize > 0 { - return e.parseHeader(pkt) - } - return true -} - -// fd returns the file descriptor associated with the endpoint. -func (e *endpoint) fd() int { - return e.fds.Load().tun() -} - -// writePackets writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -// Way more simplified than og impl, ref: github.com/google/gvisor/issues/7125 -func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { - if pkts.Len() == 0 { - return 0, nil - } - - // Preallocate to avoid repeated reallocation as we append to batch. - // batchSz is 47 because when SWGSO is in use then a single 65KB TCP - // segment can get split into 46 segments of 1420 bytes and a single 216 - // byte segment. - const batchSz = 47 - fds := e.fds.Load() - fd := fds.tun() // if closed, returns invalidfd - - if fd == invalidfd { - log.E("ns: tun(-1): WritePackets (to tun): fd invalid (pkts: %d)", pkts.Len()) - return 0, &tcpip.ErrNoSuchFile{} - } - - batch := make([]unix.Iovec, 0, batchSz) - packets, written := 0, 0 - total := pkts.Len() - - defer func() { - fds.written.Add(int64(written)) // update written bytes - fds.lastWrite.Store(time.Now().UnixMilli()) - }() - - for _, pkt := range pkts.AsSlice() { - views := pkt.AsSlices() - numIovecs := len(views) - if len(batch)+numIovecs > rawfile.MaxIovs { - // writes in to fd, up to len(batch) not cap(batch) - if errno := rawfile.NonBlockingWriteIovec(fd, batch); errno != 0 { - log.W("ns: tun(%d): WritePackets (to tun): err(%v), sent(%d)/total(%d)", fd, errno, written, total) - return written, tcpip.TranslateErrno(errno) - } - // mark processed packets as written - written += packets - // truncate batch - batch = batch[:0] - // reset processed packets count - packets = 0 - } - for _, v := range views { - batch = rawfile.AppendIovecFromBytes(batch, v, rawfile.MaxIovs) - } - packets += 1 - } - if len(batch) > 0 { - if errno := rawfile.NonBlockingWriteIovec(fd, batch); errno != 0 { - log.W("ns: tun(%d): WritePackets (to tun): err(%v), sent(%d)/total(%d)", fd, errno, packets, total) - return written, tcpip.TranslateErrno(errno) - } - written += packets - } - - if settings.Debug { - log.VV("ns: tun(%d): WritePackets (to tun): written(%d)/total(%d)", fd, written, total) - } - return written, nil -} - -/* -func (e *endpoint) notifyRestart() { - // deferred fns here should not end up calling the caller of notifyRestart to avoid - // infinite recursion (callerFn -> someotherFn -> panic -> notifyRestart -> callerFn) - // defer e.Attach(nil) - log.U("Network stopped; restart the app") -} -*/ - -// dispatchLoop reads packets from the file descriptor in a loop and dispatches -// them to the network stack (linkDispatcher). Must be run as a goroutine. -func dispatchLoop(inbound linkDispatcher, f *fds, wg *core.RollingWaitGroup) tcpip.Error { - debug.SetPanicOnFault(true) - // defer core.RecoverFn("ns.e.dispatch", e.notifyRestart) - defer core.Recover(core.Exit11, "ns.e.dispatch") - - wg.Add(1) - defer wg.Done() - - if inbound == nil || core.IsNil(inbound) { - defer f.stop() - log.W("ns: tun(%d): dispatchLoop: inbound nil", f.tun()) - return &tcpip.ErrUnknownDevice{} - } - - start := time.Now() - log.I("ns: tun(%d): dispatchLoop: start", f.tun()) - for { - cont, err := inbound.dispatch(f) - if err != nil { - logei(cont)("ns: tun(%d): dispatchLoop: dur: %s; continue? %t; err: %v", - f.tun(), core.FmtTimeAsPeriod(start), cont, err) - } - if !cont { - defer f.stop() - return err - } // else: continue dispatching - } -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (e *endpoint) ARPHardwareType() header.ARPHardwareType { - if e.hdrSize > 0 { - return header.ARPHardwareEther - } - return header.ARPHardwareNone -} - -func (e *endpoint) SetLinkAddress(addr tcpip.LinkAddress) { - e.Lock() - defer e.Unlock() - e.addr = addr -} - -func (e *endpoint) SetMTU(mtu uint32) { - e.mtu.Store(mtu) -} - -func (e *endpoint) getDispatchers() (stack.NetworkDispatcher, *fds) { - e.RLock() - defer e.RUnlock() - return e.dispatcher, e.fds.Load() -} - -// InjectInbound ingresses a netstack-inbound packet. -func (e *endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - d, fds := e.getDispatchers() - fd := fds.tun() - - log.VV("ns: tun(%d): inject-inbound (from tun) %s; %d", fd, fds, protocol) - if d != nil && pkt != nil { - d.DeliverNetworkPacket(protocol, pkt) - } else { - log.W("ns: tun(%d): inject-inbound (from tun) %d pkt?(%t) dropped: endpoint not attached", fd, protocol, pkt != nil) - } -} - -// Unused: InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. -// InjectOutbound egresses a tun-inbound packet. -func (e *endpoint) InjectOutbound(dest tcpip.Address, packet *buffer.View) tcpip.Error { - f := e.fds.Load() - fd := f.tun() - - if !f.ok() { - log.E("ns: tun(%d): inject-outbound (to tun) to dst(%v): endpoint not attached", fd, dest) - return &tcpip.ErrUnknownDevice{} - } - - b := packet.AsSlice() - sz := int64(len(b)) - defer f.written.Add(sz) // update written bytes - defer f.lastWrite.Store(time.Now().UnixMilli()) - - if settings.Debug { - log.VV("ns: tun(%d): inject-outbound (to tun) to dst(%v) sz(%d)", fd, dest, sz) - } - - errno := rawfile.NonBlockingWrite(fd, b) - return tcpip.TranslateErrno(errno) -} - -// Close implements stack.LinkEndpoint. -func (e *endpoint) Close() { - log.W("ns: tun(%d): Close!", e.fd()) - e.Attach(nil) -} - -// SetOnCloseAction implements stack.LinkEndpoint. -func (*endpoint) SetOnCloseAction(func()) {} diff --git a/intra/netstack/fds.go b/intra/netstack/fds.go deleted file mode 100644 index f81007ab..00000000 --- a/intra/netstack/fds.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package netstack - -import ( - "fmt" - "strconv" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "golang.org/x/sys/unix" -) - -var invalidFds = &fds{stopFd: stopFd{efd: invalidfd}, tunFd: invalidfd} - -// stopFd is an eventfd used to signal the stop of a dispatcher. -type stopFd struct { - efd int -} - -func newStopFd() (stopFd, error) { - efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK) - if err != nil { - return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err) - } - return stopFd{efd: efd}, nil -} - -// stop writes to the eventfd and notifies the dispatcher to stop. It does not -// block. -func (s *stopFd) stop() error { - if s.efd == invalidfd || s.efd <= 2 { - return nil - } - increment := []byte{1, 0, 0, 0, 0, 0, 0, 0} - if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil { - // There are two possible errors documented in eventfd(2) for writing: - // 1. We are writing 8 bytes and not 0xffffffffffffff, thus no EINVAL. - // 2. stop is only supposed to be called once, it can't reach the limit, - // thus no EAGAIN. - return fmt.Errorf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment)) - } - return nil -} - -type fds struct { - stopFd stopFd - tunFd int - - read atomic.Int64 // number of bytes read - written atomic.Int64 // number of bytes written - since atomic.Int64 // when fd was created - death atomic.Int64 // age in millis - lastRead atomic.Int64 // last read time in millis - lastWrite atomic.Int64 // last write time in millis - - closed atomic.Bool - once sync.Once // ensures that stop() is called only once -} - -// Takes ownership of fd, which must be a valid TUN file descriptor. -// Never returns nil, but may return an invalid fds (with tunFd = -1). -func newTun(fd int) (*fds, error) { - if fd == invalidfd { - return invalidFds, nil - } - err := unix.SetNonblock(fd, true) - if err != nil { - clos(fd) - return invalidFds, err - } - stopFd, err := newStopFd() - if err != nil { - clos(fd) - return invalidFds, err - } - f := &fds{ - stopFd: stopFd, - tunFd: fd, - } - f.since.Store(time.Now().UnixMilli()) - return f, nil -} - -func (f *fds) ok() bool { - return f != nil && f.tun() != invalidfd && !f.closed.Load() -} - -func (f *fds) eve() int { - if f != nil && f.stopFd.efd > 2 { - return f.stopFd.efd - } - return invalidfd -} - -func (f *fds) tun() int { - if f != nil && f.tunFd > 2 { - return f.tunFd - } - return invalidfd -} - -func (f *fds) stop() { - if f.ok() { - f.once.Do(func() { - defer f.closed.Store(true) - - now := time.Now().UnixMilli() - f.death.Store(now) - age := now - f.since.Load() - - err1 := f.stopFd.stop() - err2 := syscall.Close(f.tunFd) - logeif(err1)("ns: dispatch: fds: stop: eve(%d) tun(%d) age(%s); errs? %v %v", - f.stopFd.efd, f.tunFd, core.FmtMillis(age), err1, err2) - }) - } else { - log.W("ns: dispatch: fds: stop: no-op") - } -} - -func (f *fds) String() string { - return strconv.Itoa(f.tunFd) -} - -func clos(fd int) { - if fd > 0 || fd != invalidfd { - _ = syscall.Close(fd) - } -} diff --git a/intra/netstack/forwarders.go b/intra/netstack/forwarders.go deleted file mode 100644 index ab7017a8..00000000 --- a/intra/netstack/forwarders.go +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2024 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package netstack - -import ( - "encoding/binary" - "fmt" - "math/rand" - "sync" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// adopted from: github.com/google/gvisor/blob/a244eff8ad/pkg/tcpip/link/fdbased/processors.go - -const maxForwarders = 8 - -// disable log.wtf driven exit (it did not work the last time it was tested) -const testwtf = false - -type fiveTuple struct { - srcAddr, dstAddr []byte - srcPort, dstPort uint16 - proto tcpip.NetworkProtocolNumber -} - -func (t fiveTuple) String() string { - return fmt.Sprintf("%d | :%d => :%d", t.proto, t.srcPort, t.dstPort) -} - -// tcpipConnectionID returns a tcpip connection id tuple based on the data found -// in the packet. It returns true if the packet is not associated with an active -// connection (e.g ARP, NDP, etc). The method assumes link headers have already -// been processed if they were present. -func tcpipConnectionID(pkt *stack.PacketBuffer) (fiveTuple, bool) { - tup := fiveTuple{} - h, ok := pkt.Data().PullUp(1) - if !ok { - // Skip this packet. - return tup, true - } - - const tcpSrcDstPortLen = 4 - switch header.IPVersion(h) { - case header.IPv4Version: - hdrLen := header.IPv4(h).HeaderLength() - h, ok = pkt.Data().PullUp(int(hdrLen) + tcpSrcDstPortLen) - if !ok { - return tup, true - } - ipHdr := header.IPv4(h[:hdrLen]) - tcpHdr := header.TCP(h[hdrLen:][:tcpSrcDstPortLen]) - - tup.srcAddr = ipHdr.SourceAddressSlice() - tup.dstAddr = ipHdr.DestinationAddressSlice() - // All fragment packets need to be processed by the same goroutine, so - // only record the TCP ports if this is not a fragment packet. - if ipHdr.IsValid(pkt.Data().Size()) && !ipHdr.More() && ipHdr.FragmentOffset() == 0 { - tup.srcPort = tcpHdr.SourcePort() - tup.dstPort = tcpHdr.DestinationPort() - } - tup.proto = header.IPv4ProtocolNumber - case header.IPv6Version: - h, ok = pkt.Data().PullUp(header.IPv6FixedHeaderSize + tcpSrcDstPortLen) - if !ok { - return tup, true - } - ipHdr := header.IPv6(h) - - var tcpHdr header.TCP - if tcpip.TransportProtocolNumber(ipHdr.NextHeader()) == header.TCPProtocolNumber { - tcpHdr = header.TCP(h[header.IPv6FixedHeaderSize:][:tcpSrcDstPortLen]) - } else { - // Slow path for IPv6 extension headers :(. - dataBuf := pkt.Data().ToBuffer() - dataBuf.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataBuf) - defer it.Release() - for { - hdr, done, err := it.Next() - if done || err != nil { - break - } - if hdr != nil { - hdr.Release() - } // todo: else, break? - } - h, ok = pkt.Data().PullUp(int(it.HeaderOffset()) + tcpSrcDstPortLen) - if !ok { - return tup, true - } - tcpHdr = header.TCP(h[it.HeaderOffset():][:tcpSrcDstPortLen]) - } - tup.srcAddr = ipHdr.SourceAddressSlice() - tup.dstAddr = ipHdr.DestinationAddressSlice() - tup.srcPort = tcpHdr.SourcePort() - tup.dstPort = tcpHdr.DestinationPort() - tup.proto = header.IPv6ProtocolNumber - default: - return tup, true - } - return tup, false -} - -type processor struct { - mu sync.Mutex - // +checklocks:mu - pkts stack.PacketBufferList - - e stack.InjectableLinkEndpoint - icmp *icmpResponder - sleeper sleep.Sleeper - packetWaker sleep.Waker - closeWaker sleep.Waker - - testcrash bool -} - -// start starts the processor goroutine; thread-safe. -func (p *processor) start(wg *sync.WaitGroup) { - defer wg.Done() - defer p.sleeper.Done() - for { - switch w := p.sleeper.Fetch(true); { - case w == &p.packetWaker: - p.deliverPackets() - case w == &p.closeWaker: - // must unlock via deferred since panics are recovered above - p.mu.Lock() - defer p.mu.Unlock() - p.pkts.Reset() - return - } - } -} - -// deliverPackets delivers packets to the endpoint; thread-safe. -func (p *processor) deliverPackets() { - testpanic := !p.testcrash && settings.PanicAtRandom.Load() && rand10pc() - if testpanic { - defer core.Recover(core.Exit11, "ns.forwarder.deliverPackets") - } - - p.mu.Lock() - defer p.mu.Unlock() - for p.pkts.Len() > 0 { - pkt := p.pkts.PopFront() - p.mu.Unlock() - if pkt != nil { - if !p.icmp.respond(pkt) { - p.e.InjectInbound(pkt.NetworkProtocolNumber, pkt) - } - pkt.DecRef() - } - p.mu.Lock() - } - - if testpanic { - panic("ns: tun: forwarder: deliverPackets rand10pc") - } else if testwtf && !p.testcrash && settings.PanicAtRandom.Load() && rand1pc() { - p.testcrash = true - core.RuntimeWtf("ns: tun: forwarder: test fatal\n") - var mu sync.Mutex - mu.Unlock() // ka-boom - } -} - -// supervisor handles starting, closing, and queuing packets on processor -// goroutines. -type supervisor struct { - processors []processor - icmp *icmpResponder - seed uint32 - wg sync.WaitGroup - sid *core.Volatile[int] // tun fd for diagnostics - ready []bool -} - -// newSupervisor creates a new supervisor for the processors of endpoint e. -func newSupervisor(e stack.InjectableLinkEndpoint, sid int) *supervisor { - icmp := newICMPResponder(e) - - m := &supervisor{ - seed: rand.Uint32(), - sid: core.NewVolatile(sid), - ready: make([]bool, maxForwarders), - processors: make([]processor, maxForwarders), - icmp: &icmp, - wg: sync.WaitGroup{}, - } - - m.wg.Add(maxForwarders) - for i := range m.processors { - p := &m.processors[i] - p.sleeper.AddWaker(&p.packetWaker) - p.sleeper.AddWaker(&p.closeWaker) - p.e = e - p.icmp = &icmp - } - - return m -} - -// tunid returns a unique identifier (usually current tun fd); used for diagnostics only. -func (m *supervisor) tunid() int { - return m.sid.Load() -} - -// note notes the new tun fd (used for diagnostics only). -func (m *supervisor) note(sid int) { - m.sid.Store(sid) -} - -// start starts the processor goroutines if the processor manager is configured -// with more than one processor. -func (m *supervisor) start() { - if settings.Debug { - log.D("ns: tun(%d): forwarder: starting %d procs %d", m.tunid(), len(m.processors), m.seed) - } - if m.canDeliverInline() { - return - } - for i := range m.processors { - p := &m.processors[i] - core.Gx1("ns.forwarder.start", p.start, &m.wg) - } -} - -// id returns a hash value based on the given five tuple. -// Will return 0 if the hash could not be computed. -func (m *supervisor) id(t *fiveTuple) uint32 { - if t == nil { // never nil, but nilaway complains. - return 0 - } - var payload [4]byte - binary.LittleEndian.PutUint16(payload[0:], t.srcPort) - binary.LittleEndian.PutUint16(payload[2:], t.dstPort) - - h := jenkins.Sum32(m.seed) - if _, err := h.Write(payload[:]); err != nil { - return 0 - } - if len(t.srcAddr) > 0 { - if _, err := h.Write(t.srcAddr); err != nil { - return 0 - } - } // else: should never happen - if len(t.dstAddr) > 0 { - if _, err := h.Write(t.dstAddr); err != nil { - return 0 - } - } // else: should never happen - return h.Sum32() -} - -// queuePacket queues a packet to be delivered to the appropriate processor. -func (m *supervisor) queuePacket(pkt *stack.PacketBuffer, hasEthHeader bool) { - sz := uint32(len(m.processors)) - sid := m.tunid() - var pIdx uint32 - tup, nonConnectionPkt := tcpipConnectionID(pkt) - if !hasEthHeader { - if nonConnectionPkt { - log.D("ns: tun(%d): forwarder: drop non-connection pkt (sz: %d)", sid, pkt.Size()) - // If there's no eth header this should be a standard tcpip packet. If - // it isn't the packet is invalid so drop it. - return - } - pkt.NetworkProtocolNumber = tup.proto - } - if m.canDeliverInline() || nonConnectionPkt || settings.SingleThreaded.Load() { - // If the packet is not associated with an active connection, use the - // first processor. - pIdx = 0 - } else { - pIdx = m.id(&tup) % sz - } - // despite uint32, pIdx goes negative? github.com/celzero/firestack/issues/59 - // go.dev/ref/spec#Integer_overflow? - if pIdx > sz { - log.W("ns: tun(%d): forwarder: invalid processor index %d, %s", sid, pIdx, tup) - pIdx = 0 - } - p := &m.processors[pIdx] - - if settings.Debug { - log.VV("ns: tun(%d): forwarder: q on proc %d, %s", sid, pIdx, tup) - } - - p.mu.Lock() - defer p.mu.Unlock() - p.pkts.PushBack(pkt.IncRef()) // enqueue. - m.ready[pIdx] = true // ready to deliver enqueued packets. -} - -// stop stops all processor goroutines. -func (m *supervisor) stop() { - m.icmp.stop() - sid := m.tunid() - start := time.Now() - log.D("ns: tun(%d): forwarder: stopping %d procs", sid, len(m.processors)) - if !m.canDeliverInline() { - for i := range m.processors { - p := &m.processors[i] - p.closeWaker.Assert() - } - m.wg.Wait() - } // else: no goroutines to stop or wait for. - log.D("ns: tun(%d): forwarder: stopped %d procs in %s", sid, len(m.processors), core.FmtTimeAsPeriod(start)) -} - -// wakeReady wakes up all processors that have a packet queued. If there is only -// one processor, the method delivers the packet inline without waking a -// goroutine. -func (m *supervisor) wakeReady() { - for i, ready := range m.ready { - if !ready { - continue - } - p := &m.processors[i] - if m.canDeliverInline() || settings.SingleThreaded.Load() { - p.deliverPackets() - } else { - p.packetWaker.Assert() - } - m.ready[i] = false - } -} - -// canDeliverInline returns true if the supervisor is configured to deliver -// packets inline. That is, when only one processor is active, deliver -// packets inline. sleeper/waker are no-ops. -func (m *supervisor) canDeliverInline() bool { - return len(m.processors) <= 1 -} diff --git a/intra/netstack/hdl.go b/intra/netstack/hdl.go deleted file mode 100644 index 52b99848..00000000 --- a/intra/netstack/hdl.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "net" - "net/netip" - "strings" - - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type gconns interface { - *GUDPConn | *GTCPConn | *GICMPConn -} - -type GBaseConnHandler interface { - // OpenConns returns the number of active connections. - OpenConns() string - // CloseConns closes conns by ids, or all if ids is empty. - CloseConns([]string) []string - // end closes the handler and all its connections. - End() -} - -type GSpecConnHandler[T gconns] interface { - GBaseConnHandler - // Proxy copies data between conn and dst (egress). - // must not block forever as it may block netstack - // see: netstack/dispatcher.go:newReadvDispatcher - Proxy(in T, src, dst netip.AddrPort) bool - // ReverseProxy copies data between conn and dst (ingress). - ReverseProxy(out T, in net.Conn, src, dst netip.AddrPort) bool - // Error notes the error in connecting src to dst; retrying if necessary. - Error(in T, src, dst netip.AddrPort, err error) -} - -type GMuxConnHandler[T gconns] interface { - // ProxyMux proxies data between conn and multiple destinations - // (endpoint-independent mapping). - ProxyMux(in T, src, dst netip.AddrPort, dmx DemuxerFn) bool -} - -type GEchoConnHandler interface { - // Ping informs if ICMP Echo from src to dst is replied to - Ping(msg []byte, src, dst netip.AddrPort) bool -} - -type GConnHandler interface { - Src() []netip.Prefix - TCP() GTCPConnHandler // TCP returns the TCP handler. - UDP() GUDPConnHandler // UDP returns the UDP handler. - ICMP() GICMPHandler // ICMP returns the ICMP handler. - CloseConns(csv string) string // CloseConns closes the connections with the given IDs, or all if empty. -} - -type gconnhandler struct { - src []netip.Prefix - tcp GTCPConnHandler - udp GUDPConnHandler - icmp GICMPHandler -} - -var _ GConnHandler = (*gconnhandler)(nil) - -func NewGConnHandler(addrs []netip.Prefix, tcp GTCPConnHandler, udp GUDPConnHandler, icmp GICMPHandler) GConnHandler { - return &gconnhandler{ - src: addrs, - tcp: tcp, - udp: udp, - icmp: icmp, - } -} - -func (g *gconnhandler) Src() []netip.Prefix { - // TODO? slices.Clone(g.src) - return g.src -} - -func (g *gconnhandler) TCP() GTCPConnHandler { - return g.tcp -} - -func (g *gconnhandler) UDP() GUDPConnHandler { - return g.udp -} - -func (g *gconnhandler) ICMP() GICMPHandler { - return g.icmp -} - -func (g *gconnhandler) CloseConns(csv string) string { - var cids []string = nil // nil closes all conns - if len(csv) > 0 { - // split returns [""] (slice of length 1) if csv is empty - // and so, avoid splitting on empty csv, and let cids be nil - cids = strings.Split(csv, ",") - } - - var t []string - var u []string - var i []string - if tcp := g.tcp; tcp != nil { - t = tcp.CloseConns(cids) - } - if udp := g.udp; udp != nil { - u = udp.CloseConns(cids) - } - if icmp := g.icmp; icmp != nil { - i = icmp.CloseConns(cids) - } - s := make([]string, 0, len(t)+len(u)+len(i)) - s = append(s, t...) - s = append(s, u...) - s = append(s, i...) - return strings.Join(s, ",") -} - -// src/dst addrs are flipped -// fdbased.Attach -> ... -> nic.DeliverNetworkPacket -> ... -> nic.DeliverTransportPacket: -// github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/stack/nic.go#L831-L837 - -func localAddrPort(id stack.TransportEndpointID) netip.AddrPort { - // todo: unmap? - return localUDPAddr(id).AddrPort() -} - -func remoteAddrPort(id stack.TransportEndpointID) netip.AddrPort { - // todo: unmap? - return remoteUDPAddr(id).AddrPort() -} - -func remoteUDPAddr(id stack.TransportEndpointID) *net.UDPAddr { - return &net.UDPAddr{ - IP: nsaddr2ip(id.RemoteAddress), - Port: int(id.RemotePort), - } -} - -func localUDPAddr(id stack.TransportEndpointID) *net.UDPAddr { - return &net.UDPAddr{ - IP: nsaddr2ip(id.LocalAddress), - Port: int(id.LocalPort), - } -} - -func nsaddr2ip(addr tcpip.Address) net.IP { - b := addr.AsSlice() - return net.IP(b) -} - -func addrport2nsaddr(ipp netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - var proto tcpip.NetworkProtocolNumber - var addr tcpip.Address - if ipp.Addr().Is4() { - proto = ipv4.ProtocolNumber - addr = tcpip.AddrFrom4(ipp.Addr().As4()) - } else { - proto = ipv6.ProtocolNumber - addr = tcpip.AddrFrom16(ipp.Addr().As16()) - } - return tcpip.FullAddress{ - NIC: settings.NICID, - Addr: addr, - Port: ipp.Port(), - }, proto -} diff --git a/intra/netstack/icmp.go b/intra/netstack/icmp.go deleted file mode 100644 index e844c41c..00000000 --- a/intra/netstack/icmp.go +++ /dev/null @@ -1,545 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checksum" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -type GICMPHandler interface { - GBaseConnHandler - GEchoConnHandler -} - -type icmpForwarder struct { - o string - s *stack.Stack - h GICMPHandler -} - -// github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv4/icmp.go -// github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv6/icmp.go -func OutboundICMP(id string, s *stack.Stack, hdl GICMPHandler) { - // remove default handlers - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, nil) - s.SetTransportProtocolHandler(icmp.ProtocolNumber6, nil) - - if hdl == nil { - log.E("icmp: %s: no handler", id) - return - } - - forwarder := newIcmpForwarder(id, s, hdl) - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, forwarder.reply4) - s.SetTransportProtocolHandler(icmp.ProtocolNumber6, forwarder.reply6) - // TODO: the handler must only be set for the "main" netstack - setICMPEchoHandler(forwarder) -} - -func newIcmpForwarder(owner string, s *stack.Stack, h GICMPHandler) *icmpForwarder { - return &icmpForwarder{owner, s, h} -} - -// sendICMP: github.com/google/gvisor/blob/8035cf9ed/pkg/tcpip/transport/tcp/testing/context/context.go#L404 -// parseICMP: github.com/google/gvisor/blob/8035cf9ed/pkg/tcpip/header/parse/parse.go#L194 -// makeICMP: github.com/google/gvisor/blob/8035cf9ed/pkg/tcpip/tests/integration/iptables_test.go#L2100 -func (f *icmpForwarder) reply4(id stack.TransportEndpointID, pkt *stack.PacketBuffer) (handled bool) { - var err tcpip.Error - - if settings.Debug { - log.VV("icmp: v4: %s: packet? %v", f.o, pkt) - } - - if pkt == nil || pkt.Size() <= 0 { - log.E("icmp: v4: %s: nil packet (%t) or size 0", f.o, pkt == nil) - return // not handled - } - - src := remoteAddrPort(id) - dst := localAddrPort(id) - - l4hdr := pkt.TransportHeader().Slice() - l3hdr := pkt.NetworkHeader().Slice() - if len(l4hdr) < header.ICMPv4MinimumSize || len(l3hdr) < header.IPv4MinimumSize { - log.E("icmp: v4: %s: invalid packet size %d; l4hdr: %d / l3hdr: %d", f.o, pkt.Size(), len(l4hdr), len(l3hdr)) - return // not handled - } - - // ref: github.com/google/gvisor/blob/acf460d0d735/pkg/tcpip/stack/conntrack.go#L933 - hdr := header.ICMPv4(l4hdr) - if hdr.Type() != header.ICMPv4Echo { - // netstack handles other msgs except echo / ping - log.D("icmp: v4: %s: type %v passthrough", f.o, hdr.Type()) - return // not handled - } - ipHdr := header.IPv4(l3hdr) - replyData := stack.PayloadSince(pkt.TransportHeader()) - localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast - - // see: github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv4/icmp.go#L371 - // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP - // source address MUST be one of its own IP addresses (but not a broadcast - // or multicast address). - localAddr := ipHdr.DestinationAddress() - if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) { - localAddr = tcpip.Address{} - } - - l3 := pkt.Network() // same as ipHdr; l3.Dst == id.LocalAddr and l3.Src == id.RemoteAddr - route, err := f.s.FindRoute(pkt.NICID, localAddr, l3.SourceAddress(), pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil || replyData.Size() <= 0 { - log.W("icmp: v4: %s: no route on %v to %s <= %s; sz: (l3hdr: %d / l4hdr: %d / payload: %d)", - f.o, pkt.NICID, l3.DestinationAddress(), l3.SourceAddress(), len(l3hdr), len(l4hdr), replyData.Size()) - return // not handled - } - - // github.com/google/gvisor/blob/9b4a7aa00/pkg/tcpip/network/ipv6/icmp.go#L1180 - data, derr := l4l7(pkt, route.MTU()) - if derr != nil { - log.E("icmp: v4: %s: err getting payload: %v", f.o, derr) - return // not handled - } - - if settings.Debug { - log.D("icmp: v4: %s: type %v/%v sz [%v]; src(%v) => dst(%v)", - f.o, hdr.Type(), hdr.Code(), len(data), src, dst) - } - - // always forward in a goroutine to avoid blocking netstack - // see: netstack/dispatcher.go:newReadvDispatcher - pkt.IncRef() - - core.Go("icmp4.pinger."+f.o, func() { - defer replyData.Release() - defer route.Release() - defer pkt.DecRef() - - if !f.h.Ping(data, src, dst) { // unreachable - err = f.icmpErr4(pkt, header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable) - } else { // reachable - originRef := replyData.AsSlice() - replyBuf := buffer.NewViewSize(len(originRef)) - replyRef := replyBuf.AsSlice() - - copy(replyRef[4:], originRef[4:]) - replyICMPHdr := header.ICMPv4(replyRef) - replyICMPHdr.SetType(header.ICMPv4EchoReply) - replyICMPHdr.SetCode(0) // EchoReply must have Code=0. - replyICMPHdr.SetChecksum(^checksum.Checksum(replyRef, 0)) - - var replyPkt *stack.PacketBuffer = stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()), - Payload: buffer.MakeWithView(replyBuf), - }) - replyPkt.NICID = pkt.NICID - defer replyPkt.DecRef() - - if settings.Debug { - log.D("icmp: v4: %s: ok type %v/%v sz[%d] from %v <= %v", - f.o, replyICMPHdr.Type(), replyICMPHdr.Code(), len(replyICMPHdr), src, dst) - } - // github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv4/icmp.go#L794 - err = route.WritePacket(stack.NetworkHeaderParams{ - Protocol: header.ICMPv4ProtocolNumber, - TTL: route.DefaultTTL(), - }, replyPkt) - } - loge(err)("icmp: v4: %s: wrote reply to tun; err? %v", f.o, err) - }) - - return true // handled -} - -func (f *icmpForwarder) reply6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) (handled bool) { - if settings.Debug { - log.VV("icmp: v6: %s: packet? %v", f.o, pkt) - } - - if pkt == nil || pkt.Size() <= 0 { - log.E("icmp: v6: %s: nil packet (%t) or sz <= 0", f.o, pkt == nil) - return // not handled - } - - l4hdr := pkt.TransportHeader().Slice() - if len(l4hdr) < header.ICMPv6MinimumSize { - log.E("icmp: v6: %s: invalid packet size %d; l4hdr: %d", f.o, pkt.Size(), len(l4hdr)) - return // not handled - } - - hdr := header.ICMPv6(l4hdr) - if hdr.Type() != header.ICMPv6EchoRequest { - log.D("icmp: v6: %s: type %v/%v passthrough", f.o, hdr.Type(), hdr.Code()) - return // netstack to handle other msgs except echo / ping - } - - l3 := pkt.Network() // l3.Dst == id.LocalAddr and l3.Src == id.RemoteAddr - route, err := f.s.FindRoute(pkt.NICID, l3.DestinationAddress(), l3.SourceAddress(), pkt.NetworkProtocolNumber, false) - if err != nil { - log.W("icmp: v6: %s: no route on %v to %s <= %s", f.o, pkt.NICID, l3.DestinationAddress(), l3.SourceAddress()) - return // not handled - } - - src := remoteAddrPort(id) - dst := localAddrPort(id) - // github.com/google/gvisor/blob/9b4a7aa00/pkg/tcpip/network/ipv6/icmp.go#L1180 - data, derr := l4l7(pkt, route.MTU()) - if derr != nil || len(data) <= 0 { - log.E("icmp: v6: %s: payload (sz: %d) err: %v", f.o, len(data), derr) - return // not handled - } - - if settings.Debug { - log.D("icmp: v6: %s: type %v/%v sz[l4: %d / payload: %d] from src(%v) => dst(%v)", - f.o, hdr.Type(), hdr.Code(), len(l4hdr), len(data), src, dst) - } - - // always forward in a goroutine to avoid blocking netstack - // see: netstack/dispatcher.go:newReadvDispatcher - pkt.IncRef() - - core.Go("icmp6.pinger."+f.o, func() { - defer route.Release() - defer pkt.DecRef() - - var err tcpip.Error - if !f.h.Ping(data, src, dst) { // unreachable - err = f.icmpErr6(id, pkt, header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable) - } else { // reachable - replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, - Payload: pkt.Data().ToBuffer(), - }) - defer replyPkt.DecRef() - replyHdr := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) - replyPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(replyHdr, hdr) - replyHdr.SetType(header.ICMPv6EchoReply) - replyData := replyPkt.Data() - replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: replyHdr, - Src: route.LocalAddress(), // or id.LocalAddress - Dst: route.RemoteAddress(), // or id.RemoteAddress - PayloadCsum: replyData.Checksum(), - PayloadLen: replyData.Size(), - })) - - if settings.Debug { - log.D("icmp: v6: %s: ok type %v/%v sz[%d] from %v <= %v", - f.o, replyHdr.Type(), replyHdr.Code(), len(replyHdr), src, dst) - } - // github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv6/icmp.go#L694 - replyclass, _ := l3.TOS() - err = route.WritePacket(stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: route.DefaultTTL(), - TOS: replyclass, - }, replyPkt) - } - loge(err)("icmp: v6: %s: wrote reply to tun; err? %v", f.o, err) - }) - - return true -} - -// from: github.com/google/gvisor/blob/19ab27f98/pkg/tcpip/network/ipv4/icmp.go#L609 -func (f *icmpForwarder) icmpErr4(pkt *stack.PacketBuffer, icmpType header.ICMPv4Type, icmpCode header.ICMPv4Code) tcpip.Error { - origIPHdr := header.IPv4(pkt.NetworkHeader().Slice()) - origIPHdrSrc := origIPHdr.SourceAddress() - origIPHdrDst := origIPHdr.DestinationAddress() - - // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in - // response to a non-initial fragment, but it currently can not happen. - if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any { - log.W("icmp: v4: %s: skip broadcast/multicast dst(%s) <= src(%s)", f.o, origIPHdrDst, origIPHdrSrc) - return &tcpip.ErrAddressFamilyNotSupported{} - } - - transportHeader := pkt.TransportHeader().Slice() - - // Don't respond to icmp error packets. - if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { - // We need to decide to explicitly name the packets we can respond to or - // the ones we can not respond to. The decision is somewhat arbitrary and - // if problems arise this could be reversed. It was judged less of a breach - // of protocol to not respond to unknown non-error packets than to respond - // to unknown error packets so we take the first approach. - if len(transportHeader) < header.ICMPv4MinimumSize { - log.D("icmp: v4: %s: l4 header too small: %d", f.o, len(transportHeader)) - return &tcpip.ErrMalformedHeader{} - } - x := header.ICMPv4(transportHeader) - switch x.Type() { - case - header.ICMPv4EchoReply, - header.ICMPv4Echo, - header.ICMPv4Timestamp, - header.ICMPv4TimestampReply, - header.ICMPv4InfoRequest, - header.ICMPv4InfoReply: - default: - // Assume any type we don't know about may be an error type. - log.W("icmp: v4: %s: skip ICMP error packet %d", f.o, x.Type()) - return &tcpip.ErrNotSupported{} - } - } - - var pointer byte = 0 // only needed for param problem packets - switch icmpCode { - case header.ICMPv4NetProhibited: - case header.ICMPv4HostProhibited: - case header.ICMPv4AdminProhibited: - case header.ICMPv4PortUnreachable: - case header.ICMPv4ProtoUnreachable: - case header.ICMPv4NetUnreachable: // or: header.ICMPv4TTLExceeded, header.ICMPv4CodeUnused - case header.ICMPv4HostUnreachable: // or: header.ICMPv4ReassemblyTimeout - case header.ICMPv4FragmentationNeeded: - default: - log.W("icmp: v4: %s: unsupported code %d", f.o, icmpCode) - return &tcpip.ErrNotSupported{} - } - - // origIPDst == id.LocalAddr and origIPSrc == id.RemoteAddr - route, err := f.s.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, pkt.NetworkProtocolNumber, false) - if err != nil { - log.W("icmp: v4: %s: no route on %v to %s <= %s", f.o, pkt.NICID, origIPHdrDst, origIPHdrSrc) - return &tcpip.ErrNoNet{} - } - defer route.Release() - - // Now work out how much of the triggering packet we should return. - // As per RFC 1812 Section 4.3.2.3 - // - // ICMP datagram SHOULD contain as much of the original - // datagram as possible without the length of the ICMP - // datagram exceeding 576 bytes. - // - // NOTE: The above RFC referenced is different from the original - // recommendation in RFC 1122 and RFC 792 where it mentioned that at - // least 8 bytes of the payload must be included. Today linux and other - // systems implement the RFC 1812 definition and not the original - // requirement. We treat 8 bytes as the minimum but will try send more. - mtu := int(route.MTU()) - const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize - if mtu > maxIPData { - mtu = maxIPData - } - available := mtu - header.ICMPv4MinimumSize - needed := len(origIPHdr) + header.ICMPv4MinimumErrorPayloadSize - payloadLen := len(origIPHdr) + len(transportHeader) + pkt.Data().Size() - - if available < needed { - log.W("icmp: v4: %s: no space for orig IP header has: %d < want: %d; total %d", - f.o, available, needed, payloadLen) - return &tcpip.ErrNoBufferSpace{} - } - - if payloadLen > available { - payloadLen = available - } - - // The buffers used by pkt may be used elsewhere in the system. - // For example, an AF_RAW or AF_PACKET socket may use what the transport - // protocol considers an unreachable destination. Thus we deep copy pkt to - // prevent multiple ownership and SR errors. The new copy is a vectorized - // view with the entire incoming IP packet reassembled and truncated as - // required. This is now the payload of the new ICMP packet and no longer - // considered a packet in its own right. - - payload, perr := l3l4(pkt, int64(payloadLen)) - if perr != nil { - log.E("icmp: v4: %s: err getting payload: %v", f.o, perr) - return &tcpip.ErrNoBufferSpace{} - } - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, - Payload: payload, - }) - icmpPkt.IncRef() - defer icmpPkt.DecRef() - - icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - - icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - icmpHdr.SetCode(icmpCode) - icmpHdr.SetType(icmpType) - icmpHdr.SetPointer(pointer) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().Checksum())) - - werr := route.WriteHeaderIncludedPacket(icmpPkt) - - loge(werr)("icmp: v4: %s: sent %d bytes to tun; err? %v", f.o, icmpPkt.Size(), werr) - - return werr -} - -// from: github.com/google/gvisor/blob/19ab27f98/pkg/tcpip/network/ipv6/icmp.go#L1055 -func (f *icmpForwarder) icmpErr6(id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType header.ICMPv6Type, icmpCode header.ICMPv6Code) tcpip.Error { - origIPHdr := header.IPv6(pkt.NetworkHeader().Slice()) - origIPHdrSrc := origIPHdr.SourceAddress() - origIPHdrDst := origIPHdr.DestinationAddress() - - // Only send ICMP error if the address is not a multicast v6 - // address and the source is not the unspecified address. - // - // There are exceptions to this rule. - // See: point e.3) RFC 4443 section-2.4 - // - // (e) An ICMPv6 error message MUST NOT be originated as a result of - // receiving the following: - // - // (e.1) An ICMPv6 error message. - // - // (e.2) An ICMPv6 redirect message [IPv6-DISC]. - // - // (e.3) A packet destined to an IPv6 multicast address. (There are - // two exceptions to this rule: (1) the Packet Too Big Message - // (Section 3.2) to allow Path MTU discovery to work for IPv6 - // multicast, and (2) the Parameter Problem Message, Code 2 - // (Section 3.4) reporting an unrecognized IPv6 option (see - // Section 4.2 of [IPv6]) that has the Option Type highest- - // order two bits set to 10). - // - allowResponseToMulticast := false // TODO: reason.respondsToMulticast() - isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) - if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { - log.W("icmp: v6: %s: skip multicast dst(%s) <= src(%s)", f.o, origIPHdrDst, origIPHdrSrc) - return &tcpip.ErrAddressFamilyNotSupported{} - } - - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { - if typ := header.ICMPv6(pkt.TransportHeader().Slice()).Type(); typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { - log.W("icmp: v6: %s: skip ICMP error packet %d", f.o, typ) - return nil - } - } - - var pointer uint32 = 0 // TODO: must be set for param problem packets - switch icmpCode { - // TODO: handle ICMPv6ParamProblem; determine reason.code, reason.pointer - case header.ICMPv6Prohibited: // ICMPv6DstUnreachable - case header.ICMPv6PortUnreachable: // ICMPv6DstUnreachable - case header.ICMPv6NetworkUnreachable: // ICMPv6DstUnreachable - // or: ICMPv6HopLimitExceeded/ICMPv6UnusedCode -> ICMPv6TimeLimitExceeded - // or: ICMPv6ReassemblyTimeout -> ICMPv6PacketTooBig - case header.ICMPv6AddressUnreachable: // ICMPv6DstUnreachable - default: - log.W("icmp: v6: %s: unsupported code %d", f.o, icmpCode) - return &tcpip.ErrNotSupported{} - } - - // origIPDst == id.LocalAddr and origIPSrc == id.RemoteAddr - route, err := f.s.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, pkt.NetworkProtocolNumber, false) - if err != nil { - log.W("icmp: v6: %s: no route on %v to %s <= %s", f.o, pkt.NICID, origIPHdrDst, origIPHdrSrc) - return &tcpip.ErrNoNet{} - } - defer route.Release() - - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - - // As per RFC 4443 section 2.4 - // - // (c) Every ICMPv6 error message (type < 128) MUST include - // as much of the IPv6 offending (invoking) packet (the - // packet that caused the error) as possible without making - // the error message packet exceed the minimum IPv6 MTU - // [IPv6]. - mtu := int(route.MTU()) - const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize - if mtu > maxIPv6Data { - mtu = maxIPv6Data - } - available := mtu - header.ICMPv6ErrorHeaderSize - needed := header.IPv6MinimumSize - payloadLen := network.Size() + transport.Size() + pkt.Data().Size() - if available < needed { - log.W("icmp: v6: %s: no space for orig IP header; has: %d < want: %d; total %d", - f.o, available, needed, payloadLen) - return &tcpip.ErrNoBufferSpace{} - } - if payloadLen > available { - payloadLen = available - } - - payload, perr := l3l4(pkt, int64(payloadLen)) - if perr != nil { - log.E("icmp: v6: %s: err getting payload: %v", f.o, perr) - return &tcpip.ErrNoBufferSpace{} - } - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize, - Payload: payload, - }) - icmpPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - icmpPkt.IncRef() - defer icmpPkt.DecRef() - - icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetType(icmpType) - icmpHdr.SetCode(icmpCode) - icmpHdr.SetTypeSpecific(pointer) - - pktData := icmpPkt.Data() - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: id.LocalAddress, - Dst: id.RemoteAddress, - PayloadCsum: pktData.Checksum(), - PayloadLen: pktData.Size(), - })) - - werr := route.WriteHeaderIncludedPacket(icmpPkt) - - loge(werr)("icmp: v6: %s: sent %d bytes to tun; err? %v", f.o, icmpPkt.Size(), werr) - - return werr -} - -func loge(err tcpip.Error) (f log.LogFn) { - f = log.E - if err == nil { - f = log.V - } - return -} - -func l4l7(pkt *stack.PacketBuffer, sz uint32) ([]byte, error) { - r := make([]byte, 0, sz) - din := buffer.MakeWithData(r) - l4 := pkt.TransportHeader().View() - err := din.Append(l4) - if err != nil { - log.E("icmp: l4l7: err appending transport header: %v", err) - return nil, err - } - l7 := pkt.Data().ToBuffer() - din.Merge(&l7) // l4 + l7 - return din.Flatten(), nil -} - -func l3l4(pkt *stack.PacketBuffer, sz int64) (b buffer.Buffer, err error) { - l3 := pkt.NetworkHeader().View() - l4 := pkt.TransportHeader().View() - combined := buffer.MakeWithView(l3) - if err = combined.Append(l4); err == nil { - payload := pkt.Data().ToBuffer() - combined.Merge(&payload) - combined.Truncate(sz) - b = combined - } - return -} diff --git a/intra/netstack/icmpconn.go b/intra/netstack/icmpconn.go deleted file mode 100644 index 53661a1d..00000000 --- a/intra/netstack/icmpconn.go +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// from: github.com/WireGuard/wireguard-go/blob/5819c6af/tun/netstack/tun.go - -package netstack - -import ( - "bytes" - "errors" - "fmt" - "net" - "net/netip" - "os" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - errStub = errors.New("not implemented") - errIPProtoMismatch = fmt.Errorf("ping write: mismatched protocols") - errWrongAddr = errors.New("ping write: wrong address type") - errNotASyscallConn = errors.New("ping: not a syscall conn") -) - -type GICMPConn struct { - nic tcpip.NICID - src PingAddr - dst PingAddr - is6 bool - wq waiter.Queue - ep tcpip.Endpoint - deadline *time.Timer -} - -var _ core.ICMPConn = (*GICMPConn)(nil) - -type PingAddr struct{ addr netip.Addr } - -func (ipp PingAddr) String() string { - return ipp.addr.String() -} - -func (ipp PingAddr) Network() string { - if ipp.addr.Is4() { - return "ping4" - } else if ipp.addr.Is6() { - return "ping6" - } - return "ping" -} - -func (ipp PingAddr) Addr() netip.Addr { - return ipp.addr -} - -func PingAddrFromAddr(addr netip.Addr) *PingAddr { - return &PingAddr{addr} -} - -func DialPingAddr(s *stack.Stack, nic tcpip.NICID, laddr, raddr netip.Addr) (*GICMPConn, error) { - if !laddr.IsValid() && !raddr.IsValid() { - return nil, errors.New("ping dial: invalid address") - } - v6 := laddr.Is6() || raddr.Is6() - bind := laddr.IsValid() - if !bind { - if v6 { - laddr = netip.IPv6Unspecified() - } else { - laddr = netip.IPv4Unspecified() - } - } - - tn := icmp.ProtocolNumber4 - pn := ipv4.ProtocolNumber - if v6 { - tn = icmp.ProtocolNumber6 - pn = ipv6.ProtocolNumber - } - - var wq waiter.Queue - ep, tcpipErr := s.NewEndpoint(tn, pn, &wq) - if tcpipErr != nil || ep == nil { - return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) - } - pc := &GICMPConn{ - nic: nic, - src: PingAddr{laddr}, - is6: v6, - ep: ep, - deadline: time.NewTimer(time.Hour << 10), - } - pc.deadline.Stop() - - if bind { - fa, _ := fullAddrFrom(nic, netip.AddrPortFrom(laddr, 0)) - if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { - return nil, fmt.Errorf("ping bind: %s", tcpipErr) - } - } - - if raddr.IsValid() { - pc.dst = PingAddr{raddr} - fa, _ := fullAddrFrom(nic, netip.AddrPortFrom(raddr, 0)) - if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { - return nil, fmt.Errorf("ping connect: %s", tcpipErr) - } - } // unconnected - - return pc, nil -} - -// SyscallConn implements core.ICMPConn. -func (pc *GICMPConn) SyscallConn() (syscall.RawConn, error) { - return nil, errNotASyscallConn -} - -func (pc *GICMPConn) LocalAddr() net.Addr { - return pc.src -} - -func (pc *GICMPConn) RemoteAddr() net.Addr { - return pc.dst -} - -func (pc *GICMPConn) Close() error { - pc.deadline.Reset(0) - if ep := pc.ep; ep != nil { - go ep.Close() // Close holds ep.mu - } - return nil -} - -func (pc *GICMPConn) SetWriteDeadline(t time.Time) error { - return errStub -} - -func (pc *GICMPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - var ip netip.Addr - switch v := addr.(type) { - case *PingAddr: - ip = v.addr - case *net.IPAddr: - ip, _ = netip.AddrFromSlice(v.IP) - case *net.UDPAddr: - ip, _ = netip.AddrFromSlice(v.IP) - default: - return 0, errWrongAddr - } - if !((ip.Is4() && !pc.is6) || (ip.Is6() && pc.is6)) { - return 0, errIPProtoMismatch - } - - buf := bytes.NewReader(p) - remote, _ := fullAddrFrom(pc.nic, netip.AddrPortFrom(ip, 0)) - // won't block, no deadlines - n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ - To: &remote, - }) - - // may overflow on 32-bit systems - return int(n64), e(tcpipErr) // may be nil -} - -func (pc *GICMPConn) Write(p []byte) (n int, err error) { - return pc.WriteTo(p, &pc.dst) -} - -func (pc *GICMPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) - pc.wq.EventRegister(&e) - defer pc.wq.EventUnregister(&e) - - select { - case <-pc.deadline.C: - return 0, nil, os.ErrDeadlineExceeded - case <-notifyCh: - } - - w := tcpip.SliceWriter(p) - - res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ - NeedRemoteAddr: true, - }) - if tcpipErr != nil { - return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) - } - - remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) - return res.Count, &PingAddr{remoteAddr}, nil -} - -func (pc *GICMPConn) Read(p []byte) (n int, err error) { - n, _, err = pc.ReadFrom(p) - return -} - -func (pc *GICMPConn) SetDeadline(t time.Time) error { - // pc.SetWriteDeadline is unimplemented - - return pc.SetReadDeadline(t) -} - -func (pc *GICMPConn) SetReadDeadline(t time.Time) error { - pc.deadline.Reset(time.Until(t)) - return nil -} - -func fullAddrFrom(nic tcpip.NICID, ipp netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - var protoNumber tcpip.NetworkProtocolNumber - var nsdaddr tcpip.Address - if !ipp.IsValid() { - // TODO: use unspecified address like in PingConn? - return tcpip.FullAddress{}, 0 - } - if ipp.Addr().Is4() { - protoNumber = ipv4.ProtocolNumber - nsdaddr = tcpip.AddrFrom4(ipp.Addr().As4()) - } else { - protoNumber = ipv6.ProtocolNumber - nsdaddr = tcpip.AddrFrom16(ipp.Addr().As16()) - } - log.V("wg: dial: translate ipp: %v -> %v", ipp, nsdaddr) - return tcpip.FullAddress{ - NIC: nic, - Addr: nsdaddr, - Port: ipp.Port(), // may be 0 - }, protoNumber -} diff --git a/intra/netstack/icmpecho.go b/intra/netstack/icmpecho.go deleted file mode 100644 index 0d0bdf8e..00000000 --- a/intra/netstack/icmpecho.go +++ /dev/null @@ -1,281 +0,0 @@ -package netstack - -import ( - "fmt" - "net/netip" - "sync/atomic" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/core/wire" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const minICMPPacketSize = header.ICMPv4MinimumSize + header.IPv4MinimumSize - -// const typicalICMPEchoPayloadSize = 64 // or 56 -// const expectedICMPPacketSize = header.IPv6MinimumSize + header.ICMPv6MinimumSize + typicalICMPEchoPayloadSize - -// TODO: get rid of the global in favor of passing the handler via the responder. -// hdlEcho stores the ICMP handler used by the dispatcher-level ICMP -// interception path. -var hdlEcho = core.NewZeroVolatile[*icmpForwarder]() - -func setICMPEchoHandler(h *icmpForwarder) { - hdlEcho.Store(h) -} - -// icmpResponder handles ICMP packets directly from the TUN dispatcher, -// bypassing gVisor netstack when possible. -// -// The responder uses the existing ICMP handler (which ultimately uses the core -// ICMP implementation) to forward the ping and then injects the response back -// into the TUN device. -type icmpResponder struct { - ep stack.LinkEndpoint - open atomic.Bool -} - -func (r *icmpResponder) stop() { - r.open.Store(false) -} - -func newICMPResponder(ep stack.LinkEndpoint) (r icmpResponder) { - if ep == nil || core.IsNil(ep) { - return - } - r.ep = ep - r.open.Store(true) - return -} - -// returns true if the responder is enabled and debug mode is on. -func (r *icmpResponder) ok() bool { - return settings.Debug && r != nil && r.open.Load() -} - -func (r *icmpResponder) respond(pkt *stack.PacketBuffer) (handled bool) { - if !r.ok() { - return - } - h := hdlEcho.Load() - if h == nil { - log.E("icmp: responder: no handler") - return - } - return r.handle(h, pkt.NICID, pkt) -} - -// handle returns true if the packet is ICMP and is handled (or dropped) by the -// bypass path. -func (r *icmpResponder) handle(h *icmpForwarder, nic tcpip.NICID, pkt *stack.PacketBuffer) (handled bool) { - if !r.ok() { - return - } - - inSize := pkt.Size() - if inSize <= minICMPPacketSize { - // too verbose: log.VV("icmp: responder: packet too small: %d", inSize) - // Too small to be a valid ICMP echo request. - return - } - - useIcmpForwarder := settings.ExperimentalWireGuard.Load() - - c := pkt.Clone() - defer c.DecRef() - - v := c.ToView() - b := v.ToSlice() - n := len(b) - - notok := n <= 0 || h == nil - if settings.Debug || notok { - logwv(notok)("icmp: responder: read to writer (sz: %d / %d / %d); h? %t / fwd? %t", - n, v.Size(), inSize, h != nil, useIcmpForwarder) - } - if notok { - return - } - - parsed := wire.Pool.Get() - parsed.Decode(b) - - // Only echo requests are handled; other ICMP packets are dropped to avoid - // feeding them back into netstack. - if parsed.IPProto != wire.ICMPv4 && parsed.IPProto != wire.ICMPv6 { - if settings.Debug { - log.VV("icmp: responder: unsupported proto: %d / echo: %t @ %d; h: %s; content: %s", - parsed.IPProto, parsed.IsEchoRequest(), parsed.EchoIDSeq(), parsed.ICMPHeaderString(), trunc(parsed.Buffer(), 8)) - } - wire.Pool.Put(parsed) - return - } - - src := parsed.Src - dst := parsed.Dst - has := parsed.HasTransportData() - - logwv(!has)("icmp: responder: request ipv%d; %s => %s; h: %s; ok? %t", - parsed.IPVersion, src, dst, parsed.ICMPHeaderString(), has) - - if !has { - wire.Pool.Put(parsed) - return - } - - if !parsed.IsEchoRequest() { - if settings.Debug { - log.VV("icmp: responder: not echo request ipv%d; %s => %s; h: %s; %x", - parsed.IPVersion, src, dst, parsed.ICMPHeaderString(), parsed.Buffer()) - } - wire.Pool.Put(parsed) - return - } - - if useIcmpForwarder { - wire.Pool.Put(parsed) - return icmpForward(h, pkt, src, dst) - } else { - // Process asynchronously to avoid blocking the dispatcher loop. - core.Gx("icmp.responder", func() { - r.process(h, nic, parsed, src, dst) - }) - } - - return true -} - -func icmpForward(h *icmpForwarder, pkt *stack.PacketBuffer, src, dst netip.AddrPort) bool { - // local is dst / remote is src; see: netstack/icmp/icmp.go:func (h *icmpForwarder) reply4 - // and netstack/icmp/icmp.go:func (h *icmpForwarder) reply6 - local := dst.Addr().AsSlice() - remote := src.Addr().AsSlice() - - notok := len(local) == 0 || len(remote) == 0 - logwv(notok)("icmp: responder: forward: (sz: %d) empty addr? %s => %s", pkt.Size(), src, dst) - if notok { - return false - } - - var id stack.TransportEndpointID - id.LocalAddress = tcpip.AddrFromSlice(local) - id.RemoteAddress = tcpip.AddrFromSlice(remote) - // ICMP does not use ports, so they remain zero. - - return icmpForward2(h, pkt, id) -} - -func icmpForward2(h *icmpForwarder, pkt *stack.PacketBuffer, id stack.TransportEndpointID) bool { - pkt = pkt.Clone() - defer pkt.DecRef() - - switch pkt.NetworkProtocolNumber { - case header.IPv4ProtocolNumber: - v, got := core.Await1(func() bool { defer pkt.DecRef(); return h.reply4(id, pkt.IncRef()) }, 5*time.Second) - return got && v - case header.IPv6ProtocolNumber: - v, got := core.Await1(func() bool { defer pkt.DecRef(); return h.reply6(id, pkt.IncRef()) }, 5*time.Second) - return got && v - } - - log.W("icmp: responder: unsupported proto: %d; %s => %s", - pkt.NetworkProtocolNumber, id.RemoteAddress, id.LocalAddress) - return false -} - -// process handles the ICMP echo request and injects the reply back into the TUN. -// The parsed packet is released back to the pool after processing. -func (r *icmpResponder) process(h *icmpForwarder, nic tcpip.NICID, pkt *wire.Parsed, src, dst netip.AddrPort) { - defer wire.Pool.Put(pkt) - - icmpMsg := pkt.Transport() - payload, truncated := pkt.Payload() - notok := truncated || len(icmpMsg) <= 0 - - if notok || settings.Debug { - logwv(notok)("icmp: responder: truncated? %t or missing? %t ICMPv%d; %s => %s; id: %d; h: %s; sz: %d", - truncated, len(icmpMsg) <= 0, pkt.IPVersion, src, dst, pkt.EchoIDSeq(), pkt.ICMPHeaderString(), len(payload)) - } - - if notok { - return - } - - pinged := h.h.Ping(icmpMsg, src, dst) - - resp, proto, l4proto, tag, err := r.echoReply(pkt, payload, pinged) - notok = err != nil || len(resp) == 0 - - if notok || settings.Debug { - logwv(notok)("icmp: responder: reply %s <= %s (sz: %d / id: %d); ping? %t; res: %s; err? %v", - src, dst, len(resp), pkt.EchoIDSeq(), pinged, tag, err) - } - - if err != nil || len(resp) == 0 { - return - } - - r.inject(nic, proto, l4proto, resp) -} - -func (r *icmpResponder) echoReply(pkt *wire.Parsed, d []byte, ok bool) ([]byte, tcpip.NetworkProtocolNumber, tcpip.TransportProtocolNumber, string, error) { - // github.com/tailscale/tailscale/blob/7de1b0b33082cc/wgengine/netstack/netstack.go#L1201-L1212 - switch pkt.IPVersion { - case 4: - icmpHdr := pkt.ICMP4Header() - (&icmpHdr).ToResponse() - if !ok { - (&icmpHdr).Type = wire.ICMP4Unreachable - (&icmpHdr).Code = wire.ICMP4HostUnreachable - } - tag := icmpHdr.Stringer() - return wire.Generate(&icmpHdr, d), header.IPv4ProtocolNumber, header.ICMPv4ProtocolNumber, tag, nil - case 6: - icmpHdr := pkt.ICMP6Header() - (&icmpHdr).ToResponse() - if !ok { - (&icmpHdr).Type = wire.ICMP6Unreachable - (&icmpHdr).Code = wire.ICMP6NoRoute - } - tag := icmpHdr.Stringer() - // github.com/tailscale/tailscale/blob/7de1b0b33082cc/wgengine/userspace.go#L577 - return wire.Generate(&icmpHdr, d), header.IPv6ProtocolNumber, header.ICMPv6ProtocolNumber, tag, nil - default: - return nil, 0, 0, "", fmt.Errorf("unsupported ip version: %d", pkt.IPVersion) - } -} - -func (r *icmpResponder) inject(nic tcpip.NICID, proto tcpip.NetworkProtocolNumber, l4proto tcpip.TransportProtocolNumber, packet []byte) { - ep := r.ep - if ep == nil || !r.ok() { - return - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), - }) - pkt.NICID = nic - pkt.NetworkProtocolNumber = proto - pkt.TransportProtocolNumber = l4proto - - var list stack.PacketBufferList - list.PushBack(pkt) - defer list.DecRef() - - sz := pkt.Size() - n, err := ep.WritePackets(list) - logeif(e(err))("icmp: responder: inject %d to tun (n: %d; sz: %d); err? %v", proto, n, sz, err) -} - -func trunc(b []byte, n int) string { - if len(b) <= n { - return fmt.Sprintf("%x", b) - } - return fmt.Sprintf("%x...%x", b[:n], b[len(b)-n:]) -} diff --git a/intra/netstack/netstack.go b/intra/netstack/netstack.go deleted file mode 100644 index fcddb52c..00000000 --- a/intra/netstack/netstack.go +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "errors" - "fmt" - "strconv" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -// enable forwarding of packets on the interface -const nicfwd = false - -// ref: github.com/brewlin/net-protocol/blob/ec64e5f899/internal/endpoint/endpoint.go#L20 -func Up(s *stack.Stack, ep SeamlessEndpoint, h GConnHandler) (tcpip.NICID, error) { - nic := tcpip.NICID(settings.NICID) - who := strconv.Itoa(ep.Stat().Fd) - - // fetch existing routes before adding removing nic, which wipes out routes - existingroutes := s.GetRouteTable() - newnic := false - // also closes its netstack protos (ip4, ip6), closes link-endpoint (ep), if any - if ferr := s.RemoveNIC(nic); ferr != nil { - _, newnic = ferr.(*tcpip.ErrUnknownNICID) - log.I("netstack: %s: new nic? %t; remove nic? err(%v)", who, newnic, ferr) - } else { - log.I("netstack: %s: removed nic(%d)", who, nic) - } - - // TODO? Pause and resume? - // if newnic { - // s.Pause() - // defer s.Resume() - // } - - SetNetstackOpts(s) - - if newnic { - // github.com/google/gvisor/blob/a7dcce93/pkg/tcpip/sample/tun_tcp_connect/main.go - OutboundTCP(who, s, h.TCP()) - OutboundUDP(who, s, h.UDP()) - OutboundICMP(who, s, h.ICMP()) - } - - if settings.Debug { - tcp := s.TransportProtocolInstance(tcp.ProtocolNumber) != nil // 6 - udp := s.TransportProtocolInstance(udp.ProtocolNumber) != nil // 17 - icmp4 := s.TransportProtocolInstance(icmp.ProtocolNumber4) != nil // 1 - icmp6 := s.TransportProtocolInstance(icmp.ProtocolNumber6) != nil // 58 - log.D("netstack: %s: transport instances: icmp4/6? %t/%t, tcp/udp %t/%t", - who, icmp4, icmp6, tcp, udp) - } - - // creates and enables a fake nic for netstack s - // netstack protos (ip4, ip6) enabled and ep is attached to nic - if nerr := s.CreateNIC(nic, ep); nerr != nil { - return nic, e(nerr) - } - // add addrs to this nic just attached to netstack s - if err := addIfAddrs(s, nic, h); err != nil { - return nic, err - } - - // ref: github.com/xjasonlyu/tun2socks/blob/31468620e/core/stack.go#L80 - // allow spoofing packets tuples - if nerr := s.SetSpoofing(nic, true); nerr != nil { - return nic, e(nerr) - } - // ref: github.com/xjasonlyu/tun2socks/blob/31468620e/core/stack.go#L94 - // allow all packets sent to our fake nic through to netstack - if nerr := s.SetPromiscuousMode(nic, true); nerr != nil { - return nic, e(nerr) - } - - if4, _ := s.GetMainNICAddress(nic, ipv4.ProtocolNumber) - if6, _ := s.GetMainNICAddress(nic, ipv6.ProtocolNumber) - - s.SetNICForwarding(nic, ipv4.ProtocolNumber, nicfwd) - s.SetNICForwarding(nic, ipv6.ProtocolNumber, nicfwd) - // s.SetNICMulticastForwarding(nic, ipv4.ProtocolNumber, nicfwd) - // s.SetNICMulticastForwarding(nic, ipv6.ProtocolNumber, nicfwd) - // use existing routes if the nic is being recycled - useExistingRoutes := !newnic && len(existingroutes) > 0 - if useExistingRoutes { - s.SetRouteTable(existingroutes) - } - - log.I("netstack: %s: up(%d)! new? %t; addrs: %v %v; routes? %s / existing? %t: %s", - who, nic, newnic, if4, if6, s.GetRouteTable(), useExistingRoutes, existingroutes) - - return nic, nil -} - -func e(err tcpip.Error) error { - if err != nil { - return errors.New(err.String()) - } - return nil -} - -func addIfAddrs(s *stack.Stack, nic tcpip.NICID, hdl GConnHandler) error { - // TODO: make ifaddrs configurable like fakedns is - // The NIC is set in Spoofing mode. When the UDP Endpoint uses a non-local - // address to "Connect", netstack generates a temporary addressState to - // build a route, which can be primary but is always ephemeral. When this - // UDP Endpoint uses a multicast address to "Connect", netstack selects - // any available primary addressState to build a route. However, when the - // NIC is in the just-initialized or idle state, no primary addressState - // is readily available, and "Connect" fails. And so, permanent addresses, - // e.g. 10.111.222.1/24 and fd66:f83a:c650::1/120, are assigned to the NIC, - // which are only used to build routes for multicast response (and should - // for any other connection that is "ingressing" into netstack). - // - // In fact, for multicast, the sender normally does not expect a response. - // So, the ep.net.Connect is unnecessary. - - ifaddr4, ifaddr6 := HandlerAddrs(hdl) - if !ifaddr4.IsValid() && !ifaddr6.IsValid() { - return fmt.Errorf("netstack: %d no ifaddrs from handler", nic) - } - - asMainAddr := stack.AddressProperties{PEB: stack.CanBePrimaryEndpoint} - - // go.dev/play/p/Clg4geOwXMf - if ifaddr4.IsValid() { - nsaddr4 := tcpip.AddrFrom4(ifaddr4.Addr().As4()) - ap4 := tcpip.AddressWithPrefix{ - Address: nsaddr4, // 10.111.222.1 - PrefixLen: ifaddr4.Bits(), // 24 - } - protoaddr4 := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: ap4, - } - // at: github.com/google/gvisor/blob/1f4299ee3f/pkg/tcpip/stack/addressable_endpoint_state.go#L177 - if err := s.AddProtocolAddress(nic, protoaddr4, asMainAddr); err != nil { - return fmt.Errorf("netstack: %d add addr(%v): %v", nic, ifaddr6, err) - } - } - - if ifaddr6.IsValid() { - nsaddr6 := tcpip.AddrFrom16(ifaddr6.Addr().As16()) - ap6 := tcpip.AddressWithPrefix{ - Address: nsaddr6, // fd66:f83a:c650::1 - PrefixLen: ifaddr6.Bits(), // 120 - } - protoaddr6 := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: ap6, - } - if err := s.AddProtocolAddress(nic, protoaddr6, asMainAddr); err != nil { - return fmt.Errorf("netstack: %d add addr(%v): %v", nic, ifaddr4, err) - } - } - - log.I("netstack: %d ifaddrs 4(%v) 6(%v)", nic, ifaddr4, ifaddr6) - return nil -} - -func Route(s *stack.Stack, l3 string) { - // TODO? s.Pause() - // defer s.Resume() - - which := l3 - switch l3 { - case settings.IP46: - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: settings.NICID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: settings.NICID, - }, - }) - case settings.IP6: - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: settings.NICID, - }, - }) - case settings.IP4: - fallthrough - default: - which = settings.IP4 - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: settings.NICID, - }, - }) - } - // s.AddTCPProbe(func(state *stack.TCPEndpointState) {}) - log.I("netstack: route(ask:%s; set: %s); done", l3, which) -} - -// also: github.com/google/gvisor/blob/adbdac747/runsc/boot/loader.go#L1132 -// github.com/FlowerWrong/tun2socks/blob/1045a49618/cmd/netstack/main.go -// github.com/zen-of-proxy/go-tun2io/blob/c08b329b8/tun2io/util.go -// github.com/WireGuard/wireguard-go/blob/42c9af4/tun/netstack/tun.go -// github.com/telepresenceio/telepresence/pull/2709 -func NewNetstack() (s *stack.Stack) { - o := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - ipv6.NewProtocol, - // arp.NewProtocol, unused - }, - TransportProtocols: []stack.TransportProtocolFactory{ - icmp.NewProtocol4, - icmp.NewProtocol6, - tcp.NewProtocol, - udp.NewProtocol, - }, - // HandleLocal if the packets must be forwarded to another nic within this stack, or - // to let this stack forward packets to the OS' network stack. - // also: github.com/Jigsaw-Code/outline-go-tun2socks/blob/5416729062/tunnel/tunnel.go#L45 - // HandleLocal: true, - } - - s = stack.New(o) - log.I("netstack: new stack4 and stack6") - return -} diff --git a/intra/netstack/rev.go b/intra/netstack/rev.go deleted file mode 100644 index a3c9a492..00000000 --- a/intra/netstack/rev.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "context" - "net" - "net/netip" - "strconv" - "sync/atomic" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type revbase[T gconns] struct { - o string // owner - ended atomic.Bool -} - -type revip struct { - if4 netip.Prefix - if6 netip.Prefix - stackip4 netip.Addr - stackip6 netip.Addr -} - -type revtcp struct { - *revbase[*GTCPConn] - revip - revstack *stack.Stack - reverser GTCPConnHandler -} - -type revudp struct { - *revbase[*GUDPConn] - revip - revstack *stack.Stack - reverser GUDPConnHandler -} - -type revicmp struct { - *revbase[*GICMPConn] - revstack *stack.Stack - revep stack.LinkEndpoint - reverser GICMPHandler -} - -var _ GTCPConnHandler = (*revtcp)(nil) -var _ GUDPConnHandler = (*revudp)(nil) -var _ GICMPHandler = (*revicmp)(nil) - -func NewReverseGConnHandler(pctx context.Context, to *stack.Stack, of tcpip.NICID, ep SeamlessEndpoint, via GConnHandler) *gconnhandler { - id := strconv.Itoa(ep.Stat().Fd) - ip4, ip6 := StackAddrs(to, of) - if4, if6 := HandlerAddrs(via) - ifaddrs := revip{ - if4: if4, - if6: if6, - stackip4: ip4, - stackip6: ip6, - } - h := &gconnhandler{ - src: via.Src(), - tcp: newReverseTCP(id, to, of, ifaddrs, via.TCP()), - udp: newReverseUDP(id, to, of, ifaddrs, via.UDP()), - icmp: newReverseICMP(id, to, ep, via.ICMP()), - } - log.I("rev: %s: newReverseGConnHandler %d @ %d on %v", id, of, core.Loc(to), ifaddrs) - context.AfterFunc(pctx, h.end) - return h -} - -func newReverseTCP(id string, s *stack.Stack, nic tcpip.NICID, ifaddrs revip, h GTCPConnHandler) *revtcp { - log.I("rev: %s: nic %d newReverseTCP %v", id, nic, ifaddrs) - return &revtcp{ - revbase: &revbase[*GTCPConn]{o: id}, - revip: ifaddrs, - revstack: s, - reverser: h, - } -} - -func newReverseUDP(id string, s *stack.Stack, nic tcpip.NICID, ifaddrs revip, h GUDPConnHandler) *revudp { - log.I("rev: %s: nic %d newReverseUDP %v", id, nic, ifaddrs) - return &revudp{ - revbase: &revbase[*GUDPConn]{o: id}, - revip: ifaddrs, - revstack: s, - reverser: h, - } -} - -func newReverseICMP(id string, s *stack.Stack, ep stack.LinkEndpoint, h GICMPHandler) *revicmp { - return &revicmp{ - revbase: &revbase[*GICMPConn]{o: id}, - revstack: s, - revep: ep, - reverser: h, - } -} - -// GConnHandler - -func (g *gconnhandler) end() { - if t := g.tcp; t != nil { - t.End() - } - if u := g.udp; u != nil { - u.End() - } - if i := g.icmp; i != nil { - i.End() - } - log.I("rev: gconnhandler end") -} - -// Base - -func (b *revbase[T]) ReverseProxy(out T, in net.Conn, src, dst netip.AddrPort) bool { - // TODO: stub - log.E("rev: %s: revbase: %T ReverseProxy not implemented %v <= %v", b.o, out, src, dst) - return false -} - -func (b *revbase[T]) Error(in T, src, dst netip.AddrPort, err error) { - log.E("rev: %s: revbase: %T Error %v <= %v: %v", b.o, in, src, dst, err) -} - -func (*revbase[T]) OpenConns() string { - // TODO: stub - return "" -} - -func (*revbase[T]) CloseConns([]string) []string { - // TODO: stub - return nil -} - -func (r *revbase[T]) End() { - r.ended.Store(true) -} - -// TCP - -func (t *revtcp) Proxy(in *GTCPConn, src, dst netip.AddrPort) bool { - end := t.ended.Load() - log.D("rev: %s: revtcp: Proxy %v <= %v; end? %t", t.o, src, dst, end) - if end { - return false - } - // dst is local (just the port number assuming listening sockets) - // to t.revstack to dial into; src is remote to t.revstack - // ex: src 1.1.1.1:5555 / dst 10.0.1.1:1111 - err := InboundTCP(t.o, t.revstack, in, t.revipp(dst), src, t.reverser) - logeif(err)("rev: %s: revtcp: Proxy %v <= %v; err? %v", t.o, src, dst, err) - return err == nil -} - -// ip local to revstack -func (r *revtcp) revipp(ipp netip.AddrPort) netip.AddrPort { - if ipp.Addr().Is6() { - return netip.AddrPortFrom(r.stackip6, ipp.Port()) - } - return netip.AddrPortFrom(r.stackip4, ipp.Port()) -} - -// UDP - -func (u *revudp) Proxy(in *GUDPConn, src, dst netip.AddrPort) bool { - end := u.ended.Load() - log.D("rev: %s: revudp: Proxy %v <= %v; end? %t", u.o, src, dst, end) - if end { - return false - } - // see: revtcp.Proxy - err := InboundUDP(u.o, u.revstack, in, u.revipp(dst), src, u.reverser) - logeif(err)("rev: %s: revudp: Proxy %v <= %v; err? %v", u.o, src, dst, err) - return err == nil -} - -func (u *revudp) ProxyMux(in *GUDPConn, src, dst netip.AddrPort, mux DemuxerFn) bool { - end := u.ended.Load() - log.D("rev: %s: revudp: ProxyMux %v <= %v; end? %t", u.o, src, dst, end) - if end { - return false - } - // TODO: impl mux/demux - err := InboundUDP(u.o, u.revstack, in, u.revipp(dst), src, u.reverser) - logeif(err)("rev: %s: revudp: ProxyMux %v <= %v; err? %v", u.o, src, dst, err) - return err == nil -} - -// ip local to revstack -func (r *revudp) revipp(ipp netip.AddrPort) netip.AddrPort { - if ipp.Addr().Is6() { - return netip.AddrPortFrom(r.stackip6, ipp.Port()) - } - return netip.AddrPortFrom(r.stackip4, ipp.Port()) -} - -// ICMP - -func (i *revicmp) Ping(msg []byte, src, dst netip.AddrPort) bool { - // TODO: stub - log.E("rev: %s: revicmp: Ping not implemented %v <= %v; err? %v", i.o, src, dst) - return false -} - -func logeif(err error) log.LogFn { - if err != nil { - return log.E - } - return log.V -} - -func logei(cond bool) log.LogFn { - if cond { - return log.E - } - return log.I -} - -func logwv(cond bool) log.LogFn { - if cond { - return log.W - } - return log.V -} - -func StackAddrs(s *stack.Stack, nic tcpip.NICID) (netip.Addr, netip.Addr) { - zeromainaddr := tcpip.AddressWithPrefix{} - ip4 := netip.IPv4Unspecified() - ip6 := netip.IPv6Unspecified() - mainaddr4, err4 := s.GetMainNICAddress(nic, header.IPv4ProtocolNumber) - mainaddr6, err6 := s.GetMainNICAddress(nic, header.IPv6ProtocolNumber) - if err4 != nil || err6 != nil { - log.E("rev: StackAddrs %v; err: %v", nic, err4) - } - // comparable? github.com/google/gvisor/blob/1e97c039b/pkg/tcpip/adapters/gonet/gonet.go#L509 - if !mainaddr4.Address.Equal(zeromainaddr.Address) { - ip4 = netip.AddrFrom4(mainaddr4.Address.As4()) - } - if !mainaddr6.Address.Equal(zeromainaddr.Address) { - ip6 = netip.AddrFrom16(mainaddr6.Address.As16()) - } - log.V("netstack: StackAddrs %v %v", ip4, ip6) - return ip4, ip6 -} - -func HandlerAddrs(hdl GConnHandler) (ifaddr4 netip.Prefix, ifaddr6 netip.Prefix) { - if hdl == nil { - return - } - // TODO: add multiple addrs of same family - for _, x := range hdl.Src() { - if x.Addr().Is4() && !ifaddr4.IsValid() { - ifaddr4 = x - } else if x.Addr().Is6() && !ifaddr6.IsValid() { - ifaddr6 = x - } - if ifaddr4.IsValid() && ifaddr6.IsValid() { - break - } - } - return ifaddr4, ifaddr6 -} diff --git a/intra/netstack/seamless.go b/intra/netstack/seamless.go deleted file mode 100644 index 54ac7a8b..00000000 --- a/intra/netstack/seamless.go +++ /dev/null @@ -1,361 +0,0 @@ -package netstack - -import ( - "errors" - "fmt" - "io" - "strconv" - "strings" - "syscall" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// Clearing endpoint on Dispose will result in all stack.Linkpoint APIs -// returning zero values (which may trip netstack?) -const clearEndpointOnDispose = false - -type FdSwapper interface { - // Swap closes existing FDs; uses new fd. - Swap(fd, mtu int) error - // Dispose closes all existing FDs. - Dispose() error - // Stat returns EpStat (fd, age, read, written, lastRead, lastWrite). - Stat() EpStat -} - -type EpStat struct { - // Fd is the file descriptor of the endpoint. - Fd int - // Alive indicates whether the endpoint is alive. - Alive bool - // Age is the age of the endpoint. - Age string - // Read is the number of bytes read from the endpoint. - Read string - // Written is the number of bytes written to the endpoint. - Written string - // LastRead is the last time the endpoint was read from. - LastRead string - // LastWrite is the last time the endpoint was written to. - LastWrite string -} - -func (s EpStat) String() string { - if s.Fd == 0 { - return "" - } - return fmt.Sprintf("Fd: %d,Alive: %t,Age: %s,R: %s,W: %s,LastRead: %s,LastWrite: %s", - s.Fd, - s.Alive, - s.Age, - s.Read, - s.Written, - s.LastRead, - s.LastWrite) -} - -type SeamlessEndpoint interface { - stack.LinkEndpoint - FdSwapper -} - -type magiclink struct { - e *core.Volatile[SeamlessEndpoint] - d *core.Volatile[stack.NetworkDispatcher] - s io.WriteCloser -} - -var _ stack.LinkEndpoint = (*magiclink)(nil) -var _ stack.NetworkDispatcher = (*magiclink)(nil) -var _ stack.GSOEndpoint = (*magiclink)(nil) -var _ SeamlessEndpoint = (*magiclink)(nil) - -var errMissingSink = errors.New("magic: pcap sink is nil") - -// ref: github.com/google/gvisor/blob/91f58d2cc/pkg/tcpip/sample/tun_tcp_echo/main.go#L102 -func NewEndpoint(dev, mtu int, sink io.WriteCloser) (ep SeamlessEndpoint, err error) { - defer func() { - if err != nil { - _ = syscall.Close(dev) - } - log.I("netstack: new endpoint(fd:%d / mtu:%d); err? %v", dev, mtu, err) - }() - - if sink == nil { - return nil, errMissingSink - } - - umtu := uint32(mtu) - opt := Options{ - FDs: []int{dev}, - MTU: umtu, - } - - if ep, err = newFdbasedInjectableEndpoint(&opt); err != nil { - return nil, err - } - // ref: github.com/google/gvisor/blob/aeabb785278/pkg/tcpip/link/sniffer/sniffer.go#L111-L131 - v := core.NewVolatile(ep) - d := core.NewZeroVolatile[stack.NetworkDispatcher]() - - return &magiclink{v, d /*nil*/, sink}, nil -} - -func Pcap2Stdout(y bool) (ok bool) { - if y { - ok = logPackets.CompareAndSwap(0, 1) - } else { - ok = logPackets.CompareAndSwap(1, 0) - } - log.I("netstack: pcap stdout(%t): done?(%t)", y, ok) - return -} - -func Pcap2File(y bool) (ok bool) { - if y { - ok = writePCAP.CompareAndSwap(0, 1) - } else { - ok = writePCAP.CompareAndSwap(1, 0) - } - log.I("netstack: pcap file(%t): done?(%t)", y, ok) - return -} - -// PCAP logging modes: -// - stdout: packets are logged to stdout -// - file: packets are logged to a file -// - none: no packets are logged -func PcapModes() string { - var modes []string - if logPackets.Load() == 1 { - modes = append(modes, "stdout") - } - if writePCAP.Load() == 1 { - modes = append(modes, "file") - } - if len(modes) == 0 { - return "none" - } - return strings.Join(modes, ",") -} - -// Swap implements SeamlessEndpoint. -func (l *magiclink) Swap(fd, mtu int) (err error) { - e := l.e.Load() - hasSwappedFd := false - needsNewEndpoint := e == nil - if e != nil { - err = e.Swap(fd, mtu) - hasSwappedFd = err == nil - needsNewEndpoint = errors.Is(err, errNeedsNewEndpoint) - } - - if hasSwappedFd || !needsNewEndpoint { - logei(!hasSwappedFd)("netstack: magic(%d); swap: ok? %t; err? %v", - fd, hasSwappedFd, err) - return err - } - - umtu := uint32(mtu) - opt := Options{ - FDs: []int{fd}, - MTU: umtu, - } - - ep, err := newFdbasedInjectableEndpoint(&opt) - if err != nil || ep == nil { - log.E("netstack: magic(%d); swap: ep missing? %t; err %v", fd, ep == nil, err) - return core.OneErr(err, errMissingEp) - } - - // attach eventually runs a dispatchLoop which kickstarts the endpoint's - // delivery of packets to netstack's dispatcher. - d := l.d.Load() - if d == nil { - ep.Attach(nil) // attach the new endpoint to the dispatcher - } else { - ep.Attach(l) // attach the new endpoint to the existing dispatcher - } - - // swap endpoints after the dispatchLoop has had the chance to start - // to avoid cases where clients end up calling ep.Wait() before dispatchLoop - // could begin (as it is responsible for keeping ep alive) - if old := l.e.Tango(ep); old != nil { - core.Go("magic."+strconv.Itoa(fd), old.Close) - } - - logei(d == nil)("netstack: magic(%d) mtu: %d; swap: new ep... dispatch? %t", - fd, umtu, d != nil) - - return nil -} - -// Dispose implements SeamlessEndpoint. -func (l *magiclink) Dispose() error { - if e := l.e.Load(); e != nil { - if clearEndpointOnDispose { - // will result in stack.LinkEndpoint impls return zero values - // unsure if this will trip netstack in to thinking if this - // endpoint is kaput, when in reality, this endpoint can swap - // in a healthy endpoint at a later point in time, which then - // we'd expect netstack to use as expected. - l.e.Store(nil) - } - return e.Dispose() - } - log.W("netstack: magic: dispose; no endpoint") - return nil -} - -// Stat implements SeamlessEndpoint. -func (l *magiclink) Stat() EpStat { - if e := l.e.Load(); e != nil { - return e.Stat() - } - return EpStat{} -} - -func (l *magiclink) MTU() uint32 { - if e := l.e.Load(); e != nil { - return e.MTU() - } - return 0 -} - -func (l *magiclink) SetMTU(mtu uint32) { - if e := l.e.Load(); e != nil { - e.SetMTU(mtu) - } -} - -func (l *magiclink) MaxHeaderLength() uint16 { - if e := l.e.Load(); e != nil { - return e.MaxHeaderLength() - } - return 0 -} - -func (l *magiclink) LinkAddress() tcpip.LinkAddress { - if e := l.e.Load(); e != nil { - return e.LinkAddress() - } - return "" -} - -func (l *magiclink) SetLinkAddress(addr tcpip.LinkAddress) { - if e := l.e.Load(); e != nil { - e.SetLinkAddress(addr) - } -} - -func (l *magiclink) Capabilities() stack.LinkEndpointCapabilities { - if e := l.e.Load(); e != nil { - return e.Capabilities() - } - return 0 -} - -func (l *magiclink) Attach(dispatcher stack.NetworkDispatcher) { - l.d.Store(dispatcher) // update the dispatcher - if e := l.e.Load(); e != nil { - if dispatcher == nil { - e.Attach(nil) // detach - } else { - e.Attach(l) - } - } - log.I("netstack: magic: attach dispatcher? %t", dispatcher == nil) -} - -func (l *magiclink) IsAttached() bool { - if e := l.e.Load(); e != nil { - return e.IsAttached() - } - return false -} - -func (l *magiclink) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - l.DumpPacket(DirectionRecv, protocol, pkt) - if d := l.d.Load(); d != nil { - d.DeliverNetworkPacket(protocol, pkt) - return - } - log.E("netstack: magic: deliver network packet (sz: %d); no dispatcher", pkt.Size()) -} - -// unused -func (l *magiclink) DeliverLinkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - l.DumpPacket(DirectionRecv, protocol, pkt) - if d := l.d.Load(); d != nil { - d.DeliverLinkPacket(protocol, pkt) - return - } - log.E("netstack: magic: deliver link packet (sz: %d); no dispatcher", pkt.Size()) -} - -func (l *magiclink) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { - if l.doPCAP() { - for _, pkt := range pkts.AsSlice() { - if pkt != nil { // nilaway - l.DumpPacket(DirectionSend, pkt.NetworkProtocolNumber, pkt) - } - } - } - if e := l.e.Load(); e != nil { - return e.WritePackets(pkts) - } - log.E("netstack: magic: write packets; no endpoint") - return 0, &tcpip.ErrNotPermitted{} -} - -func (l *magiclink) Wait() { - if e := l.e.Load(); e != nil { - if e.IsAttached() { - e.Wait() // may panic in case of WaitGroup reuse issues - } else { - log.W("netstack: magic: wait; dispatcher not attached; has dispatcher? %t", l.d.Load() != nil) - } - } -} - -func (l *magiclink) ARPHardwareType() header.ARPHardwareType { - if e := l.e.Load(); e != nil { - return e.ARPHardwareType() - } - return 0 -} - -func (l *magiclink) AddHeader(pkt *stack.PacketBuffer) { - if e := l.e.Load(); e != nil { - e.AddHeader(pkt) - } -} - -func (l *magiclink) ParseHeader(pkt *stack.PacketBuffer) bool { - if e := l.e.Load(); e != nil { - return e.ParseHeader(pkt) - } - return false -} - -func (l *magiclink) Close() { - if e := l.e.Load(); e != nil { - e.Close() - } -} - -func (l *magiclink) SetOnCloseAction(f func()) { - if e := l.e.Load(); e != nil { - e.SetOnCloseAction(f) - } -} - -func (l *magiclink) GSOMaxSize() uint32 { return 0 } - -// SupportedGSO returns the supported segmentation offloading. -func (l *magiclink) SupportedGSO() stack.SupportedGSO { return stack.GSONotSupported } diff --git a/intra/netstack/snooper.go b/intra/netstack/snooper.go deleted file mode 100644 index af1205b7..00000000 --- a/intra/netstack/snooper.go +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package netstack - -import ( - "encoding/binary" - "fmt" - "io" - "sync/atomic" - "time" - - "github.com/celzero/firestack/intra/core" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/header/parse" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// from: github.com/google/gvisor/blob/596e8d22/pkg/tcpip/link/sniffer/sniffer.go - -var logPackets atomic.Uint32 -var writePCAP atomic.Uint32 - -const logPrefix = "" - -// SnapLen is the maximum bytes of a packet to be saved. Packets with a length -// less than or equal to snapLen will be saved in their entirety. Longer -// packets will be truncated to snapLen. -// TODO: MTU instead of SnapLen? Must match pcapsink.begin() -const SnapLen uint32 = 2048 // in bytes; some sufficient value - -// A Direction indicates whether the packing is being sent or received. -type Direction int - -const ( - // DirectionSend indicates a sent packet. - DirectionSend = iota - // DirectionRecv indicates a received packet. - DirectionRecv -) - -func (dr Direction) String() string { - switch dr { - case DirectionSend: - return "send" - case DirectionRecv: - return "recv" - default: - return "unknown" - } -} - -func zoneOffset() (int32, error) { - date := time.Date(0, 0, 0, 0, 0, 0, 0, time.Local) - _, offset := date.Zone() - return int32(offset), nil -} - -func WritePCAPHeader(w io.Writer) error { - offset, err := zoneOffset() - if err != nil { - return err - } - return binary.Write(w, binary.LittleEndian, core.PcapHeader{ - // From https://wiki.wireshark.org/Development/LibpcapFileFormat - MagicNumber: 0xa1b2c3d4, - - VersionMajor: 2, - VersionMinor: 4, - Thiszone: offset, - Sigfigs: 0, - Snaplen: SnapLen, - Network: 101, // LINKTYPE_RAW - }) -} - -func (l *magiclink) doPCAP() bool { - if logPackets.Load() == 1 { - return true - } - if writePCAP.Load() == 1 { - return l.s != nil - } - return false -} - -// DumpPacket logs a packet, depending on configuration, to stderr and/or a -// pcap file. ts is an optional timestamp for the packet. -func (l *magiclink) DumpPacket(dir Direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if pkt == nil { // nilaway - return - } - if logPackets.Load() == 1 { - LogPacket(logPrefix, dir, protocol, pkt) - } - if writePCAP.Load() == 1 && l.s != nil { - packet := core.PcapPacket{ - Packet: pkt, - MaxCaptureLen: int(SnapLen), - } - packet.Timestamp = time.Now() - b, err := packet.MarshalBinary() - if err != nil { - log.Warningf("snoop: pkt err %v", err) - } - if _, err := l.s.Write(b); err != nil { - log.Warningf("snoop: write err %v", err) - } - } -} - -// LogPacket logs a packet to stdout. -func LogPacket(prefix string, dir Direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - // Figure out the network layer info. - var transProto uint8 - var src tcpip.Address - var dst tcpip.Address - var size uint16 - var id uint32 - var fragmentOffset uint16 - var moreFragments bool - - clone := core.TrimmedClone(pkt) - defer clone.DecRef() - switch protocol { - case header.IPv4ProtocolNumber: - if ok := parse.IPv4(clone); !ok { - return - } - - ipv4 := header.IPv4(clone.NetworkHeader().Slice()) - fragmentOffset = ipv4.FragmentOffset() - moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments - src = ipv4.SourceAddress() - dst = ipv4.DestinationAddress() - transProto = ipv4.Protocol() - size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - id = uint32(ipv4.ID()) - - case header.IPv6ProtocolNumber: - proto, fragID, fragOffset, fragMore, ok := parse.IPv6(clone) - if !ok { - return - } - - ipv6 := header.IPv6(clone.NetworkHeader().Slice()) - src = ipv6.SourceAddress() - dst = ipv6.DestinationAddress() - transProto = uint8(proto) - size = ipv6.PayloadLength() - id = fragID - moreFragments = fragMore - fragmentOffset = fragOffset - - case header.ARPProtocolNumber: - if !parse.ARP(clone) { - return - } - - arp := header.ARP(clone.NetworkHeader().Slice()) - log.Infof( - "%s%s arp %s (%s) -> %s (%s) valid:%t", - prefix, - dir, - tcpip.AddrFromSlice(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), - tcpip.AddrFromSlice(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), - arp.IsValid(), - ) - return - default: - log.Infof("%s%s unknown network protocol: %d", prefix, dir, protocol) - return - } - - // Figure out the transport layer info. - transName := "unknown" - srcPort := uint16(0) - dstPort := uint16(0) - details := "" - switch tcpip.TransportProtocolNumber(transProto) { - case header.ICMPv4ProtocolNumber: - transName = "icmp" - hdr, ok := clone.Data().PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) - icmpType := "unknown" - if fragmentOffset == 0 { - switch icmp.Type() { - case header.ICMPv4EchoReply: - icmpType = "echo reply" - case header.ICMPv4DstUnreachable: - icmpType = "destination unreachable" - case header.ICMPv4SrcQuench: - icmpType = "source quench" - case header.ICMPv4Redirect: - icmpType = "redirect" - case header.ICMPv4Echo: - icmpType = "echo" - case header.ICMPv4TimeExceeded: - icmpType = "time exceeded" - case header.ICMPv4ParamProblem: - icmpType = "param problem" - case header.ICMPv4Timestamp: - icmpType = "timestamp" - case header.ICMPv4TimestampReply: - icmpType = "timestamp reply" - case header.ICMPv4InfoRequest: - icmpType = "info request" - case header.ICMPv4InfoReply: - icmpType = "info reply" - } - } - log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, dir, transName, src, dst, icmpType, size, id, icmp.Code()) - return - - case header.ICMPv6ProtocolNumber: - transName = "icmp" - hdr, ok := clone.Data().PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) - icmpType := "unknown" - switch icmp.Type() { - case header.ICMPv6DstUnreachable: - icmpType = "destination unreachable" - case header.ICMPv6PacketTooBig: - icmpType = "packet too big" - case header.ICMPv6TimeExceeded: - icmpType = "time exceeded" - case header.ICMPv6ParamProblem: - icmpType = "param problem" - case header.ICMPv6EchoRequest: - icmpType = "echo request" - case header.ICMPv6EchoReply: - icmpType = "echo reply" - case header.ICMPv6RouterSolicit: - icmpType = "router solicit" - case header.ICMPv6RouterAdvert: - icmpType = "router advert" - case header.ICMPv6NeighborSolicit: - icmpType = "neighbor solicit" - case header.ICMPv6NeighborAdvert: - icmpType = "neighbor advert" - case header.ICMPv6RedirectMsg: - icmpType = "redirect message" - } - log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, dir, transName, src, dst, icmpType, size, id, icmp.Code()) - return - - case header.UDPProtocolNumber: - transName = "udp" - if ok := parse.UDP(clone); !ok { - break - } - - udp := header.UDP(clone.TransportHeader().Slice()) - if fragmentOffset == 0 { - srcPort = udp.SourcePort() - dstPort = udp.DestinationPort() - details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) - size -= header.UDPMinimumSize - } - - case header.TCPProtocolNumber: - transName = "tcp" - if ok := parse.TCP(clone); !ok { - break - } - - tcp := header.TCP(clone.TransportHeader().Slice()) - if fragmentOffset == 0 { - offset := int(tcp.DataOffset()) - if offset < header.TCPMinimumSize { - details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) - break - } - if size := clone.Data().Size() + len(tcp); offset > size && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) - break - } - - srcPort = tcp.SourcePort() - dstPort = tcp.DestinationPort() - size -= uint16(offset) - - // Initialize the TCP flags. - flags := tcp.Flags() - details = fmt.Sprintf("flags:%s seqnum:%d ack:%d win:%d xsum:0x%x", flags, tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) - if flags&header.TCPFlagSyn != 0 { - details += fmt.Sprintf(" options:%+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) - } else { - details += fmt.Sprintf(" options:%+v", tcp.ParsedOptions()) - } - } - - default: - log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, dir, src, dst, transProto) - return - } - - if pkt.GSOOptions.Type != stack.GSONone { - details += fmt.Sprintf(" gso:%#v", pkt.GSOOptions) - } - - log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:0x%04x %s", prefix, dir, transName, src, srcPort, dst, dstPort, size, id, details) -} diff --git a/intra/netstack/stackopts.go b/intra/netstack/stackopts.go deleted file mode 100644 index db2c0556..00000000 --- a/intra/netstack/stackopts.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -func SetNetstackOpts(s *stack.Stack) { - // TODO: other stack otps? - // github.com/xjasonlyu/tun2socks/blob/31468620e/core/option/option.go#L69 - - // TODO: setup protocol opts? - // github.com/google/gvisor/blob/ef9e8d91/test/benchmarks/tcp/tcp_proxy.go#L233 - sack := tcpip.TCPSACKEnabled(true) - _ = s.SetTransportProtocolOption(tcp.ProtocolNumber, &sack) - - // from: github.com/tailscale/tailscale/commit/83808029d8c - // See https://github.com/tailscale/tailscale/issues/9707 - // RACKs lead to spurious retransmissions and a reduced congestion window. - tcpRecoveryOpt := tcpip.TCPRecovery(0) - _ = s.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt) - - // from: github.com/telepresenceio/telepresence/blob/ab7dda7d55/pkg/vif/stack.go#L232 - // Enable Receive Buffer Auto-Tuning, see: github.com/google/gvisor/issues/1666 - bufauto := tcpip.TCPModerateReceiveBufferOption(true) - _ = s.SetTransportProtocolOption(tcp.ProtocolNumber, &bufauto) - - // probably a bad idea? github.com/tailscale/tailscale/blob/9d9a70d81d/wgengine/netstack/netstack.go#L330 - // coder.com/blog/delivering-5x-faster-throughput-in-coder-2-12-0 - // ccopt := tcpip.CongestionControlOption("cubic") - // _ = s.SetTransportProtocolOption(tcp.ProtocolNumber, &ccopt) - - ttl := tcpip.DefaultTTLOption(128) - s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &ttl) - s.SetNetworkProtocolOption(ipv6.ProtocolNumber, &ttl) - - // github.com/tailscale/tailscale/blob/c4d0237e5c/wgengine/netstack/netstack_tcpbuf_default.go - tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: tcp.MinBufferSize, - Default: tcp.DefaultSendBufferSize, - Max: 8 << 20, // 8MiB - } - tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{ - Min: tcp.MinBufferSize, - Default: tcp.DefaultReceiveBufferSize, - Max: 6 << 20, // 6MiB - } - // github.com/tailscale/tailscale/blob/c4d0237e5c/wgengine/netstack/netstack.go#L329 - _ = s.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt) - _ = s.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt) -} diff --git a/intra/netstack/stat.go b/intra/netstack/stat.go deleted file mode 100644 index 9698cb8b..00000000 --- a/intra/netstack/stat.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "errors" - "strings" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var errNoStack = errors.New("netstat: no stack") - -func Stat(s *stack.Stack) (out *x.NetStat, err error) { - if s == nil { - return nil, errNoStack - } - - out = new(x.NetStat) - - stat := s.Stats() - allinfo := s.NICInfo() - tcp := stat.TCP - udp := stat.UDP - icmp := stat.ICMP - ip := stat.IP - nic := stat.NICs - - // nicinfo - if len(allinfo) > 0 { - if info, ok := allinfo[settings.NICID]; ok { - out.NICIn.Name = info.Name - out.NICIn.Mtu = int32(info.MTU) - out.NICIn.HwAddr = info.LinkAddress.String() - addrs := make([]string, 0, len(info.ProtocolAddresses)) - for _, addr := range info.ProtocolAddresses { - addrs = append(addrs, addr.AddressWithPrefix.String()) - } - out.NICIn.Addrs = strings.Join(addrs, ",") - out.NICIn.Arp = int32(info.ARPHardwareType) - out.NICIn.Up = info.Flags.Up - out.NICIn.Running = info.Flags.Running - out.NICIn.Lo = info.Flags.Loopback - out.NICIn.Promisc = info.Flags.Promiscuous - out.NICIn.Forwarding4 = info.Forwarding[ipv4.ProtocolNumber] - out.NICIn.Forwarding6 = info.Forwarding[ipv6.ProtocolNumber] - } else { - out.NICIn.Name = "missing" - } - } else { - out.NICIn.Name = "unknown" - } - - // nic - out.NICSt.Rx = core.FmtBytes(uint64(nic.Rx.Bytes.Value())) - out.NICSt.RxPkts = int64(nic.Rx.Packets.Value()) - out.NICSt.Tx = core.FmtBytes(uint64(nic.Tx.Bytes.Value())) - out.NICSt.TxPkts = int64(nic.Tx.Packets.Value()) - out.NICSt.Drops = int64(nic.TxPacketsDroppedNoBufferSpace.Value()) - out.NICSt.Invalid = int64(nic.MalformedL4RcvdPackets.Value()) - out.NICSt.L3Unknown = int64(summation(nic.UnknownL3ProtocolRcvdPacketCounts)) - out.NICSt.L4Unknown = int64(summation(nic.UnknownL4ProtocolRcvdPacketCounts)) - out.NICSt.L4Drops = int64(stat.DroppedPackets.Value()) - // ip - out.IPSt.InvalidDst = int64(ip.InvalidDestinationAddressesReceived.Value()) - out.IPSt.InvalidSrc = int64(ip.InvalidSourceAddressesReceived.Value()) - out.IPSt.InvalidFrag = int64(ip.MalformedFragmentsReceived.Value()) - out.IPSt.InvalidPkt = int64(ip.MalformedPacketsReceived.Value()) - out.IPSt.Errs = int64(ip.OutgoingPacketErrors.Value()) - out.IPSt.Rcv = int64(ip.PacketsReceived.Value()) - out.IPSt.Snd = int64(ip.PacketsSent.Value()) - out.IPSt.ErrSnd = out.IPSt.Snd - int64(ip.PacketsDelivered.Value()) - out.IPSt.ErrRcv = out.IPSt.Rcv - int64(ip.ValidPacketsReceived.Value()) - // ip forwarding - router := ip.Forwarding - out.FWDSt.Errs = int64(router.Errors.Value()) - out.FWDSt.Timeouts = int64(router.ExhaustedTTL.Value()) - out.FWDSt.Unrch = int64(router.HostUnreachable.Value()) - out.FWDSt.PTB = int64(router.PacketTooBig.Value()) - out.FWDSt.Drops = int64(router.InitializingSource.Value() + - router.ExtensionHeaderProblem.Value() + - router.LinkLocalDestination.Value() + - router.LinkLocalSource.Value() + - router.NoMulticastPendingQueueBufferSpace.Value() + - router.OutgoingDeviceNoBufferSpace.Value()) - out.FWDSt.NoRoute = int64(router.Unrouteable.Value()) - out.FWDSt.NoHop = int64(router.UnknownOutputEndpoint.Value()) - // icmp - out.ICMPSt.Snd4 = int64(icmp.V4.PacketsSent.EchoRequest.Value()) - out.ICMPSt.Snd6 = int64(icmp.V6.PacketsSent.EchoRequest.Value()) - out.ICMPSt.Rcv4 = int64(icmp.V4.PacketsReceived.EchoReply.Value()) - out.ICMPSt.Rcv6 = int64(icmp.V6.PacketsReceived.EchoReply.Value()) - out.ICMPSt.UnrchRcv4 = int64(icmp.V4.PacketsReceived.DstUnreachable.Value()) - out.ICMPSt.UnrchRcv6 = int64(icmp.V6.PacketsReceived.DstUnreachable.Value()) - out.ICMPSt.UnrchSnd4 = int64(icmp.V4.PacketsSent.DstUnreachable.Value()) - out.ICMPSt.UnrchSnd6 = int64(icmp.V6.PacketsSent.DstUnreachable.Value()) - out.ICMPSt.Drops4 = int64(icmp.V4.PacketsSent.Dropped.Value()) - out.ICMPSt.Drops6 = int64(icmp.V6.PacketsSent.Dropped.Value()) - out.ICMPSt.Invalid4 = int64(icmp.V4.PacketsReceived.Invalid.Value()) - out.ICMPSt.Invalid6 = int64(icmp.V6.PacketsReceived.Invalid.Value()) - out.ICMPSt.TimeoutSnd4 = int64(icmp.V4.PacketsSent.TimeExceeded.Value()) - out.ICMPSt.TimeoutSnd6 = int64(icmp.V6.PacketsSent.TimeExceeded.Value()) - out.ICMPSt.TimeoutRcv4 = int64(icmp.V4.PacketsReceived.TimeExceeded.Value()) - out.ICMPSt.TimeoutRcv6 = int64(icmp.V6.PacketsReceived.TimeExceeded.Value()) - // udp - out.UDPSt.ErrChecksum = int64(udp.ChecksumErrors.Value()) - out.UDPSt.ErrRcv = int64(udp.MalformedPacketsReceived.Value()) - out.UDPSt.ErrSnd = int64(udp.PacketsReceived.Value()) - out.UDPSt.Rcv = int64(udp.PacketsReceived.Value()) - out.UDPSt.Snd = int64(udp.PacketsSent.Value()) - out.UDPSt.PortFail = int64(udp.UnknownPortErrors.Value()) - out.UDPSt.Drops = int64(udp.ReceiveBufferErrors.Value()) - // tcp - out.TCPSt.Active = int64(tcp.ActiveConnectionOpenings.Value()) - out.TCPSt.Passive = int64(tcp.PassiveConnectionOpenings.Value()) - out.TCPSt.ErrChecksum = int64(tcp.ChecksumErrors.Value()) - out.TCPSt.Est = int64(tcp.CurrentEstablished.Value()) - out.TCPSt.Con = int64(tcp.CurrentConnected.Value()) - out.TCPSt.EstClo = int64(tcp.EstablishedClosed.Value()) - out.TCPSt.EstRst = int64(tcp.EstablishedResets.Value()) - out.TCPSt.EstTo = int64(tcp.EstablishedTimedout.Value()) - out.TCPSt.ConFail = int64(tcp.FailedConnectionAttempts.Value()) - out.TCPSt.PortFail = int64(tcp.FailedPortReservations.Value()) - out.TCPSt.ErrRcv = int64(tcp.InvalidSegmentsReceived.Value()) - out.TCPSt.AckDrop = int64(tcp.ListenOverflowAckDrop.Value()) - out.TCPSt.Snd = int64(tcp.SegmentsSent.Value()) - out.TCPSt.Rcv = int64(tcp.ValidSegmentsReceived.Value()) - out.TCPSt.ErrSnd = int64(tcp.SegmentSendErrors.Value()) - out.TCPSt.SynDrop = int64(tcp.ListenOverflowSynDrop.Value()) - out.TCPSt.Retrans = int64(tcp.Retransmits.Value()) - out.TCPSt.Timeouts = int64(tcp.Timeouts.Value()) - out.TCPSt.Drops = int64(tcp.ForwardMaxInFlightDrop.Value()) - - return out, nil -} - -func summation(m *tcpip.IntegralStatCounterMap) (sum uint64) { - if m == nil { - return - } - for _, proto := range m.Keys() { - if v, ok := m.Get(proto); ok { - sum += v.Value() - } - } - return sum -} diff --git a/intra/netstack/tcp.go b/intra/netstack/tcp.go deleted file mode 100644 index c1d9c090..00000000 --- a/intra/netstack/tcp.go +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "context" - "io" - "net" - "net/netip" - "sync" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// ref: github.com/tailscale/tailscale/blob/cfb5bd0559/wgengine/netstack/netstack.go#L236-L237 -const rcvwnd = 0 - -const maxInFlight = 512 // arbitrary - -// retry connect when early connect (done when no happy eyeballs) fails? -const retryLateConnect = false - -var ( - // defaults: github.com/google/gvisor/blob/fa49677e141db/pkg/tcpip/transport/tcp/protocol.go#L73 - // idle: 2h; count: 9; interval: 75s - defaultKeepAliveIdle = tcpip.KeepaliveIdleOption(10 * time.Minute) - defaultKeepAliveInterval = tcpip.KeepaliveIntervalOption(5 * time.Second) - defaultKeepAliveCount = 4 // unacknowledged probes - // github.com/tailscale/tailscale/blob/65fe0ba7b5/cmd/derper/derper.go#L75-L78 - // blog.cloudflare.com/when-tcp-sockets-refuse-to-die => Idle + (Interval * Count) - usrTimeout = tcpip.TCPUserTimeoutOption(10*time.Minute + (4 * 5 * time.Second)) -) - -type GTCPConnHandler interface { - GSpecConnHandler[*GTCPConn] -} - -var _ core.TCPConn = (*GTCPConn)(nil) - -type GTCPConn struct { - o string // owner - stack *stack.Stack - c *core.Volatile[*gonet.TCPConn] // conn exposes TCP semantics atop endpoint - src netip.AddrPort // local addr (remote addr in netstack) - dst netip.AddrPort // remote addr (local addr in netstack) - req *tcp.ForwarderRequest // egress request as a TCP state machine - once sync.Once -} - -// s is the netstack to use for dialing (reads/writes). -// in is the incoming connection to netstack, s. -// to (src) is remote. -// from (dst) is local (to netstack, s). -// h is the handler that handles connection in into netstack, s, by -// dialing to from (dst) from to (src). -func InboundTCP(who string, s *stack.Stack, in net.Conn, to, from netip.AddrPort, h GTCPConnHandler) error { - newgc := makeGTCPConn(who, s, nil /*not a forwarder req*/, to, from) - - // early syn/ack is okay if happy eyeballs isn't strictly required - if !settings.HappyEyeballs.Load() { - open, err := newgc.tryConnect() - - if settings.Debug { - logeif(err)("ns: tcp: %s: inbound: tryConnect err src(%v) => dst(%v); open? %t / retry? %t, err(%v)", - newgc.o, to, from, open, retryLateConnect, err) - } - // TODO: call in a go routine if settings.SingleThreaded is set - if !retryLateConnect && (err != nil || !open) { - if err == nil { - err = errMissingEp - } - h.Error(newgc, to, from, err) // error - return err - } - } - h.ReverseProxy(newgc, in, to, from) - return nil -} - -// OutboundTCP sets up a TCP forwarder h to handle TCP packets. -// If h is nil, s uses the (built-in) default TCP forwarding logic. -func OutboundTCP(who string, s *stack.Stack, h GTCPConnHandler) { - if fwd := tcpForwarder(who, s, h); fwd != nil { - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - } else { // unset - log.I("ns: tcp: %s: forwarder: nil handler; unsetting forwarder...", who) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, nil) - } -} - -// nic.deliverNetworkPacket -> no existing matching endpoints -> tcpForwarder.HandlePacket -// ref: github.com/google/gvisor/blob/e89e736f1/pkg/tcpip/adapters/gonet/gonet_test.go#L189 -func tcpForwarder(who string, s *stack.Stack, h GTCPConnHandler) *tcp.Forwarder { - if h == nil { - return nil - } - - return tcp.NewForwarder(s, rcvwnd, maxInFlight, func(req *tcp.ForwarderRequest) { - if req == nil { - log.E("ns: tcp: %s: forwarder: nil request", who) - return - } - id := req.ID() - // src 10.111.222.1:38312 / [fd66:f83a:c650::1]:15753 - src := remoteAddrPort(id) - // dst 213.188.195.179:80 - dst := localAddrPort(id) - - // read/writes are routed using 5-tuple to the same conn (endpoint) - // demuxer.handlePacket -> find matching endpoint -> queue-packet -> send/recv conn (ep) - // ref: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/stack/transport_demuxer.go#L180 - gtcp := makeGTCPConn(who, s, req, src, dst) - - // setup endpoint right away, so that netstack's internal state is consistent - // in case there are multiple forwarders dispatching from the TUN device. - if !settings.HappyEyeballs.Load() { // syn-ack before delivering to handler? - opened, err := gtcp.tryConnect() - - if settings.Debug { - logeif(err)("ns: tcp: %s: forwarder: tryConnect err src(%v) => dst(%v); open? %t, err(%v)", - who, src, dst, opened, err) - } - // TODO: call in a go routine if settings.SingleThreaded is set - if !retryLateConnect && (err != nil || !opened) { - h.Error(gtcp, src, dst, core.OneErr(err, errMissingEp)) // error - } else { // gtcp may be connected - h.Proxy(gtcp, src, dst) - } - } else { - // call the handler in-line, blocking the netstack "processor", - // however; handler must r/w to/from src/dst async after connect. - h.Proxy(gtcp, src, dst) - } - }) -} - -func makeGTCPConn(who string, s *stack.Stack, req *tcp.ForwarderRequest, src, dst netip.AddrPort) *GTCPConn { - // set sock-opts? github.com/xjasonlyu/tun2socks/blob/31468620e/core/tcp.go#L82 - return >CPConn{ - o: who, - stack: s, - c: core.NewZeroVolatile[*gonet.TCPConn](), - src: src, - dst: dst, - req: req, // may be nil - } -} - -func (g *GTCPConn) ok() bool { - return g.conn() != nil -} - -func (g *GTCPConn) conn() *gonet.TCPConn { - return g.c.Load() -} - -func (g *GTCPConn) Establish() (open bool, err error) { - rst, err := g.synack(true) - - log.VV("ns: tcp: %s: forwarder: connect src(%v) => dst(%v); fin? %t", - g.o, g.LocalAddr(), g.RemoteAddr(), rst) - return !rst, err -} - -func (g *GTCPConn) tryConnect() (open bool, err error) { - rst, err := g.synack(false) - - log.VV("ns: tcp: %s: forwarder: proxy src(%v) => dst(%v); fin? %t", g.o, g.LocalAddr(), g.RemoteAddr(), rst) - return !rst, err // open or closed -} - -// complete must be called at least once, otherwise the conn counts towards -// maxInFlight and may cause silent tcp conn drops. -func (g *GTCPConn) complete(rst bool) { - g.once.Do(func() { - req := g.req - log.D("ns: tcp: %s: forwarder: complete src(%v) => dst(%v); req? %t, rst? %t", - g.o, g.LocalAddr(), g.RemoteAddr(), req != nil, rst) - if req != nil { - req.Complete(rst) - } - }) -} - -func (g *GTCPConn) synack(complete bool) (rst bool, err error) { - if g.ok() { // already setup - return false, nil // open, err free - } - - defer func() { - // complete when either g is opened or complete is set - if complete || !rst { - g.complete(rst) - } - }() - - if g.req != nil { // egressing (process netstack's req from tun) - wq := new(waiter.Queue) - // the passive-handshake (SYN) may not successful for a non-existent route (say, ipv6) - if ep, err := g.req.CreateEndpoint(wq); err != nil || ep == nil { - log.E("ns: tcp: %s: connect: (outbound) synack(complete? %t / ep? %t) src(%v) => dst(%v); err(%v)", - g.o, complete, ep != nil, g.LocalAddr(), g.RemoteAddr(), err) - // prevent potential half-open TCP connection leak. - // hopefully doesn't break happy-eyeballs datatracker.ietf.org/doc/html/rfc8305#section-5 - // TCP RST here is indistinguishable to an app from being firewalled. - return true, e(err) // close, err - } else { - keepalive(ep) - conn := gonet.NewTCPConn(wq, ep) - g.c.Store(conn) - } - } else { // ingressing (process a conn into tun) - if settings.Debug { - log.V("ns: tcp: %s: dial: (inbound) creating endpoint for %v => %v", g.o, g.LocalAddr(), g.RemoteAddr()) - } - src, proto := addrport2nsaddr(g.dst) // remote addr is local addr in netstack - dst, _ := addrport2nsaddr(g.src) // local addr is remote addr in netstack - bg := context.Background() - if conn, err := gonet.DialTCPWithBind(bg, g.stack, src, dst, proto); err != nil { - log.E("ns: tcp: %s: dial: (inbound) synack(complete? %t) src(%v) => dst(%v); err(%v)", - g.o, complete, g.LocalAddr(), g.RemoteAddr(), err) - return true, err // close, err - } else { - g.c.Store(conn) - } - } - - return false, nil // open, err free -} - -func keepalive(ep tcpip.Endpoint) { - if settings.GetDialerOpts().LowerKeepAlive { - // github.com/tailscale/tailscale/issues/4522 (low keepalive) - // github.com/tailscale/tailscale/pull/6147 (high keepalive) - // github.com/tailscale/tailscale/issues/6148 (other changes) - sockopt(ep, &defaultKeepAliveIdle, &defaultKeepAliveInterval, &usrTimeout) - ep.SetSockOptInt(tcpip.KeepaliveCountOption, defaultKeepAliveCount) - // github.com/tailscale/tailscale/commit/1aa75b1c9ea2 - ep.SocketOptions().SetKeepAlive(true) // applies netstack defaults - } -} - -func sockopt(ep tcpip.Endpoint, opts ...tcpip.SettableSocketOption) { - for _, opt := range opts { - if opt != nil { - _ = ep.SetSockOpt(opt) - } - } -} - -// gonet conn local and remote addresses may be nil -// ref: github.com/tailscale/tailscale/blob/8c5c87be2/wgengine/netstack/netstack.go#L768-L775 -// and: github.com/google/gvisor/blob/ffabadf0/pkg/tcpip/transport/tcp/endpoint.go#L2759 -func (g *GTCPConn) LocalAddr() net.Addr { - if c := g.conn(); c != nil { - // client local addr is remote to the gonet adapter - if addr := c.RemoteAddr(); addr != nil { - return addr - } - } - return net.TCPAddrFromAddrPort(g.src) -} - -func (g *GTCPConn) RemoteAddr() net.Addr { - if c := g.conn(); c != nil { - // client remote addr is local to the gonet adapter - if addr := c.LocalAddr(); addr != nil { - return addr - } - } - return net.TCPAddrFromAddrPort(g.dst) -} - -func (g *GTCPConn) Write(data []byte) (int, error) { - if c := g.conn(); c != nil { - return c.Write(data) - } - return 0, netError(g, "tcp", g.o+":write", io.ErrClosedPipe) -} - -func (g *GTCPConn) Read(data []byte) (int, error) { - if c := g.conn(); c != nil { - return c.Read(data) - } - return 0, netError(g, "tcp", g.o+":read", io.ErrNoProgress) -} - -func (g *GTCPConn) CloseWrite() error { - if c := g.conn(); c != nil { - return c.CloseWrite() - } - return netError(g, "tcp", g.o+":close", net.ErrClosed) -} - -func (g *GTCPConn) CloseRead() error { - if c := g.conn(); c != nil { - return c.CloseRead() - } - return netError(g, "tcp", g.o+":close", net.ErrClosed) -} - -func (g *GTCPConn) SetDeadline(t time.Time) error { - if c := g.conn(); c != nil { - return c.SetDeadline(t) - } else { - return nil // no-op to confirm with netstack's gonet impl - } -} - -func (g *GTCPConn) SetReadDeadline(t time.Time) error { - if c := g.conn(); c != nil { - return c.SetReadDeadline(t) - } - return nil // no-op to confirm with netstack's gonet impl -} - -func (g *GTCPConn) SetWriteDeadline(t time.Time) error { - if c := g.conn(); c != nil { - return c.SetWriteDeadline(t) - } - return nil // no-op to confirm with netstack's gonet impl -} - -// Abort aborts the connection by sending a RST segment. -func (g *GTCPConn) Abort() { - g.complete(true) // complete if needed - go core.Close(g.conn()) -} - -func (g *GTCPConn) Close() error { - g.Abort() - return nil // g.conn.Close always returns nil; see gonet.TCPConn.Close -} - -// from: netstack gonet -func netError(c net.Conn, proto, op string, err error) *net.OpError { - return &net.OpError{ - Op: op, - Net: proto, - Source: c.LocalAddr(), - Addr: c.RemoteAddr(), - Err: err, - } -} diff --git a/intra/netstack/udp.go b/intra/netstack/udp.go deleted file mode 100644 index f74fcbfc..00000000 --- a/intra/netstack/udp.go +++ /dev/null @@ -1,324 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package netstack - -import ( - "errors" - "io" - "net" - "net/netip" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - errMissingEp = errors.New("not connected to any endpoint") - errFilteredOut = errors.New("no eif; filtered out") -) - -type DemuxerFn func(in net.Conn, to netip.AddrPort) error - -type GUDPConnHandler interface { - GSpecConnHandler[*GUDPConn] - GMuxConnHandler[*GUDPConn] -} - -var _ core.UDPConn = (*GUDPConn)(nil) - -type GUDPConn struct { - o string // owner - stack *stack.Stack - - // conn exposes UDP semantics atop endpoint - c *core.Volatile[*gonet.UDPConn] - // local addr (remote addr in netstack) - // ex: 10.111.222.1:20716; same as endpoint.GetRemoteAddress - src netip.AddrPort - // remote addr (local addr in netstack) - // ex: 10.111.222.3:53; same as endpoint.GetLocalAddress - dst netip.AddrPort - - req *udp.ForwarderRequest // egress request as UDP - - eim bool // endpoint is muxed - eif bool // endpoint is transparent -} - -// ref: github.com/google/gvisor/blob/e89e736f1/pkg/tcpip/adapters/gonet/gonet_test.go#L373 -func makeGUDPConn(who string, s *stack.Stack, r *udp.ForwarderRequest, src, dst netip.AddrPort) *GUDPConn { - return &GUDPConn{ - o: who, - stack: s, - c: core.NewZeroVolatile[*gonet.UDPConn](), - src: src, - dst: dst, - req: r, - eim: settings.EndpointIndependentMapping.Load(), - eif: settings.EndpointIndependentFiltering.Load(), - } -} - -// OutboundUDP sets up a UDP forwarder h for outbound UDP packets. -// If h is nil, s uses the (built-in) default UDP forwarding logic. -func OutboundUDP(who string, s *stack.Stack, h GUDPConnHandler) { - if fwd := udpForwarder(who, s, h); fwd != nil { - s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - } else { // unset - log.I("ns: udp: %s: forwarder: nil handler; unsetting forwarder...", who) - s.SetTransportProtocolHandler(udp.ProtocolNumber, nil) - } -} - -func InboundUDP(who string, s *stack.Stack, in net.Conn, to, from netip.AddrPort, h GUDPConnHandler) error { - newgc := makeGUDPConn(who, s, nil /*not a forwarder req*/, to, from) - if !settings.HappyEyeballs.Load() { // ref comment in netstack/tcp.go - err := newgc.Establish() - - if settings.Debug { - logeif(err)("ns: udp: %s: inbound: dial: %v; src(%v) dst(%v)", - who, err, to, from) - } - - // TODO: call in a go routine if settings.SingleThreaded is set - if !retryLateConnect && err != nil { - h.Error(newgc, to, from, err) - return err - } - } - h.ReverseProxy(newgc, in, to, from) - return nil -} - -// Perhaps udp conns shouldn't be closed as eagerly as its tcp counterpart -// Netstack's udp conn is apparently a 'connected udp' socket and it goes through a -// lot of motions, from what I can tell, to support both unconnected and connected -// udp sockets. This is untested and unconfirmed speculation from us, but unless -// intra/udp.go refrains from closing this udp conn, we'll never find out I guess. -// ref: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/stack/transport_demuxer.go#L590 -// and: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/stack/transport_demuxer.go#L75 -// and: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/transport/udp/endpoint.go#L903 -// via: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/adapters/gonet/gonet.go#L315 -// fin: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/transport/udp/endpoint.go#L220 -// but: github.com/google/gvisor/blob/be6ffa7/pkg/tcpip/transport/udp/endpoint.go#L180 -func udpForwarder(who string, s *stack.Stack, h GUDPConnHandler) *udp.Forwarder { - if h == nil { - return nil - } - - return udp.NewForwarder(s, func(req *udp.ForwarderRequest) (handled bool) { - if req == nil { - log.E("ns: udp: %s: forwarder: nil request", who) - return - } - - // owner app tun ns h - // repr socket packet endpoint socket - // type udp fd gudpconn core.minconn - // - // (src, dst) :1111, :53 :1111, :53 :53, :1111 :9999, :53 - // - // write :1111 => :53 :1111, :53 :53 => :1111 :9999 => :53 - // \ / - // \ / - // (pipe) \ / - // / \ - // / \ - // / \ - // read :1111 <= :53 :1111, :53 :53 <= :1111 :9999 <= :53 - id := req.ID() - // src 10.111.222.1:20716; same as endpoint.GetRemoteAddress - src := remoteAddrPort(id) - // dst 10.111.222.3:53; same as endpoint.GetLocalAddress - // but it may not always be the true dst (for now it is), - // especially if the resulting udp-conn is setup to handle - // multiple dst in the unconnected udp case. - dst := localAddrPort(id) - - gc := makeGUDPConn(who, s, req, src, dst) - - demux := func(ingress net.Conn, newdst netip.AddrPort) error { - if newdst.Compare(dst) == 0 { - log.D("ns: udp: %s: demuxer: no-op; src(%v) same as dst(%v)", - who, src, newdst) - return nil - } - if !gc.eif { - return errFilteredOut - } - return InboundUDP(who, s, ingress, src, newdst, h) - } - - // setup to recv right away, so that netstack's internal state is consistent - // in case there are multiple forwarders dispatching from the TUN device. - if !settings.HappyEyeballs.Load() { - err := gc.Establish() - - if settings.Debug { - logeif(err)("ns: udp: %s: forwarder: connect: %v; src(%v) dst(%v)", - who, err, src, dst) - } - // TODO: call in a go routine if settings.SingleThreaded is set - if !retryLateConnect && err != nil { - h.Error(gc, src, dst, err) - return false // not handled - } - handle(h, gc, src, dst, demux) // gc may be connected - return true // handled - } else { - // handler must connect sync; blocking netstack's processor - // but perform other ops like r/w to/from src/dst async. - return handle(h, gc, src, dst, demux) - } - }) -} - -func handle(h GUDPConnHandler, gc *GUDPConn, src, dst netip.AddrPort, demux DemuxerFn) (ok bool) { - if gc.eim { - ok = h.ProxyMux(gc, src, dst, demux) - } else { - ok = h.Proxy(gc, src, dst) - } - return -} - -func (g *GUDPConn) ok() bool { - return g.conn() != nil -} - -func (g *GUDPConn) conn() *gonet.UDPConn { - return g.c.Load() -} - -func (g *GUDPConn) StatefulTeardown() (fin bool) { - _ = g.Establish() // establish circuit then teardown - _ = g.Close() // then shutdown - return true // always fin -} - -func (g *GUDPConn) Establish() error { - if g.ok() { // already setup - return nil - } - - if g.req == nil { // ingressing (a network conn inbound to tun) - src, proto := addrport2nsaddr(g.dst) // remote addr is local addr in netstack - dst, _ := addrport2nsaddr(g.src) // local addr is remote addr in netstack - // ingress socket w/ gonet.DialUDP - if conn, err := gonet.DialUDP(g.stack, &src, &dst, proto); err != nil { - log.E("ns: udp: %s: dial: (inbound) endpoint for %v => %v; err(%v)", - g.o, g.src, g.dst, err) - return err - } else { - g.c.Store(conn) - } - } else { // egressing (netstack's conn from tun outbound to network) - if settings.Debug { - log.V("ns: udp: %s: connect: creating endpoint for %v => %v", g.o, g.src, g.dst) - } - - wq := new(waiter.Queue) - if ep, err := g.req.CreateEndpoint(wq); err != nil || ep == nil { - // ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host) - // 'bad local addrs' on missing NIC, 'invalid state' if could not be bound/connected - log.E("ns: udp: %s: connect: (outbound) endpoint(ok? %t) for %v => %v; err(%v)", - g.o, ep != nil, g.src, g.dst, err) - return e(err) - } else { - g.c.Store(gonet.NewUDPConn(wq, ep)) - } - } - return nil -} - -func (g *GUDPConn) LocalAddr() (addr net.Addr) { - if c := g.conn(); c != nil { - addr = c.RemoteAddr() - } - if addr == nil { // remoteaddr may be nil, even if g.ok() - addr = net.UDPAddrFromAddrPort(g.src) - } - return -} - -func (g *GUDPConn) RemoteAddr() (addr net.Addr) { - if c := g.conn(); c != nil { - addr = c.LocalAddr() - } - if addr == nil { // localaddr may be nil, even if g.ok() - addr = net.UDPAddrFromAddrPort(g.dst) - } - return -} - -func (g *GUDPConn) Write(data []byte) (int, error) { - if c := g.conn(); c != nil { - // nb: write-deadlines set by intra.udp - // addr: 10.111.222.3:17711; g.LocalAddr(g.udp.remote): 10.111.222.3:17711; g.RemoteAddr(g.udp.local): 10.111.222.1:53 - // ep(state 3 / info &{2048 17 {53 10.111.222.3 17711 10.111.222.1} 1 10.111.222.3 1} / stats &{{{1}} {{0}} {{{0}} {{0}} {{0}} {{0}}} {{{0}} {{0}} {{0}}} {{{0}} {{0}}} {{{0}} {{0}} {{0}}}}) - // 3: status:datagram-connected / {2048=>proto, 17=>transport, {53=>local-port localip 17711=>remote-port remoteip}=>endpoint-id, 1=>bind-nic-id, ip=>bind-addr, 1=>registered-nic-id} - // g.ep may be nil: log.V("ns: writeFrom: from(%v) / ep(state %v / info %v / stats %v)", addr, g.ep.State(), g.ep.Info(), g.ep.Stats()) - return c.Write(data) - } - return 0, netError(g, "udp", g.o+":write", io.ErrClosedPipe) -} - -func (g *GUDPConn) Read(data []byte) (int, error) { - if c := g.conn(); c != nil { - return c.Read(data) - } - return 0, netError(g, "udp", g.o+":read", io.ErrNoProgress) -} - -func (g *GUDPConn) WriteTo(data []byte, addr net.Addr) (int, error) { - if c := g.conn(); c != nil { - return c.WriteTo(data, addr) - } - return 0, netError(g, "udp", g.o+":writeTo", net.ErrWriteToConnected) -} - -func (g *GUDPConn) ReadFrom(data []byte) (int, net.Addr, error) { - if c := g.conn(); c != nil { - return c.ReadFrom(data) - } - return 0, nil, netError(g, "udp", g.o+":readFrom", io.ErrNoProgress) -} - -func (g *GUDPConn) SetDeadline(t time.Time) error { - if c := g.conn(); c != nil { - return c.SetDeadline(t) - } // else: no-op as with netstack's gonet impl - return nil -} - -func (g *GUDPConn) SetReadDeadline(t time.Time) error { - if c := g.conn(); c != nil { - return c.SetReadDeadline(t) - } // else: no-op as with netstack's gonet impl - return nil -} - -func (g *GUDPConn) SetWriteDeadline(t time.Time) error { - if c := g.conn(); c != nil { - return c.SetWriteDeadline(t) - } // else: no-op as with netstack's gonet impl - return nil -} - -// Close closes the connection. -func (g *GUDPConn) Close() error { - go core.Close(g.conn()) - return nil -} diff --git a/intra/netstack/waitgroup_test.go b/intra/netstack/waitgroup_test.go deleted file mode 100644 index 9f7f5aab..00000000 --- a/intra/netstack/waitgroup_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package netstack - -import ( - "os" - "sync" - "testing" - "time" -) - -// TestWaitGroupRaceCondition tests that the WaitGroup reuse issue is fixed. -// This test reproduces the scenario where an endpoint is swapped while -// another goroutine is waiting on the old endpoint. -func TestWaitGroupRaceCondition(t *testing.T) { - // Create a temp file to simulate a TUN device - tmpFile, err := os.CreateTemp("", "test_tun") - if err != nil { - t.Skip("Cannot create temp file for test") - } - defer os.Remove(tmpFile.Name()) - defer tmpFile.Close() - - fd := int(tmpFile.Fd()) - - // Create a magiclink endpoint - endpoint, err := NewEndpoint(fd, 1500, &testSink{}) - if err != nil { - t.Fatalf("Failed to create endpoint: %v", err) - } - defer endpoint.Dispose() - - magicLink, ok := endpoint.(*magiclink) - if !ok { - t.Fatalf("Expected magiclink, got %T", endpoint) - } - - // Start multiple goroutines that will call Wait() on the endpoint - // while we swap endpoints in the background - var wg sync.WaitGroup - errors := make(chan error, 10) - - for i := 0; i < 5; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - defer func() { - if r := recover(); r != nil { - errors <- r.(error) - } - }() - - // Call Wait() multiple times to increase chance of race condition - for j := 0; j < 10; j++ { - magicLink.Wait() - time.Sleep(time.Millisecond) - } - }(i) - } - - // Swap endpoints multiple times while Wait() is being called - go func() { - for i := 0; i < 5; i++ { - // Create another temp file for swapping - tmpFile2, err := os.CreateTemp("", "test_tun2") - if err != nil { - continue - } - fd2 := int(tmpFile2.Fd()) - - // Swap to new fd - magicLink.Swap(fd2, 1500) - time.Sleep(time.Millisecond * 5) - - tmpFile2.Close() - os.Remove(tmpFile2.Name()) - } - }() - - // Wait for all goroutines to complete - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - // Check if any errors occurred - select { - case err := <-errors: - t.Fatalf("WaitGroup reuse panic occurred: %v", err) - default: - // Success - no panic occurred - } - case <-time.After(time.Second * 10): - t.Fatal("Test timed out") - } -} - -// TestStackTraceScenario tests the specific scenario from the original stack trace -// where magiclink.Wait() is called during endpoint swapping. -func TestStackTraceScenario(t *testing.T) { - // Create a temp file to simulate a TUN device - tmpFile, err := os.CreateTemp("", "test_tun") - if err != nil { - t.Skip("Cannot create temp file for test") - } - defer os.Remove(tmpFile.Name()) - defer tmpFile.Close() - - fd := int(tmpFile.Fd()) - - // Create a magiclink endpoint - endpoint, err := NewEndpoint(fd, 1500, &testSink{}) - if err != nil { - t.Fatalf("Failed to create endpoint: %v", err) - } - defer endpoint.Dispose() - - magicLink, ok := endpoint.(*magiclink) - if !ok { - t.Fatalf("Expected magiclink, got %T", endpoint) - } - - // Simulate the exact scenario from the stack trace: - // seamless.go:312>fdbased.go:413 - magiclink.Wait() calls endpoint.Wait() - panicked := false - done := make(chan struct{}) - - // Start a goroutine that continuously calls Wait() like the tunnel waiter - go func() { - defer func() { - if r := recover(); r != nil { - panicked = true - } - close(done) - }() - - for i := 0; i < 100; i++ { - magicLink.Wait() - time.Sleep(time.Millisecond) - } - }() - - // Concurrently perform rapid endpoint swaps - for i := 0; i < 10; i++ { - tmpFile2, err := os.CreateTemp("", "test_tun2") - if err != nil { - continue - } - fd2 := int(tmpFile2.Fd()) - - // Rapid swap - this should not cause WaitGroup reuse panic - magicLink.Swap(fd2, 1500) - - tmpFile2.Close() - os.Remove(tmpFile2.Name()) - time.Sleep(time.Millisecond * 2) - } - - // Wait for the wait goroutine to complete - select { - case <-done: - if panicked { - t.Fatal("WaitGroup reuse panic occurred in stack trace scenario") - } - case <-time.After(time.Second * 15): - t.Fatal("Test timed out") - } -} - -// testSink is a simple implementation of io.WriteCloser for testing -type testSink struct{} - -func (ts *testSink) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (ts *testSink) Close() error { - return nil -} diff --git a/intra/netstat/procfs.go b/intra/netstat/procfs.go deleted file mode 100644 index 734a43f4..00000000 --- a/intra/netstat/procfs.go +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// Code relicensed from opensnitch with permissions from evilsocket. -package netstat - -import ( - "bufio" - "encoding/binary" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "regexp" - "strconv" - "strings" - "sync" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" -) - -const ( - crlftabspace = "\r\n\t " - cachettl = 30000 // millis -) - -var ( - parser = regexp.MustCompile(`(?i)` + - `\d+:\s+` + // sl - // source - `([a-f0-9]{8,32}):([a-f0-9]{4})\s+` + - // destination - `([a-f0-9]{8,32}):([a-f0-9]{4})\s+` + - `[a-f0-9]{2}\s+` + // st - // transfer queue, receive queue - `[a-f0-9]{8}:[a-f0-9]{8}\s+` + - // tr tm->when - `[a-f0-9]{2}:[a-f0-9]{8}\s+` + - // retrnsmt - `[a-f0-9]{8}\s+` + - // uid - `(\d+)\s+` + - // timeout - `\d+\s+` + - // inode - `(\d+)\s+` + - // the rest... - `.+`) - - cache = NewProcNetCache() - - zeroip4 = netip.IPv4Unspecified() - zeroip6 = netip.IPv6Unspecified() - zeroPort = 0 -) - -// ProcNetEntry represents a single line as fetched from /proc/net/* -type ProcNetEntry struct { - Protocol string - SrcIP netip.Addr - SrcPort int - DstIP netip.Addr - DstPort int - UserID int - INode int - ctime time.Time -} - -type ProcNetCache struct { - pool *sync.Map // string, *ProcNetEntry{} - lastcleanup time.Time -} - -func NewProcNetCache() ProcNetCache { - return ProcNetCache{ - pool: new(sync.Map), - lastcleanup: time.Now(), - } -} - -// NewProcNetEntry creates an Entry -func NewProcNetEntry(protocol string, srcIP netip.Addr, srcPort int, dstIP netip.Addr, dstPort int, userID int, iNode int) ProcNetEntry { - return ProcNetEntry{ - Protocol: protocol, - SrcIP: srcIP, - SrcPort: srcPort, - DstIP: dstIP, - DstPort: dstPort, - UserID: userID, - INode: iNode, - ctime: time.Now(), - } -} - -func (p *ProcNetEntry) String() string { - return p.Protocol + p.SrcIP.String() + strconv.Itoa(p.SrcPort) + p.DstIP.String() + strconv.Itoa(p.DstPort) -} - -func (p *ProcNetEntry) Same(q *ProcNetEntry) bool { - if p == nil || q == nil { - return false - } - - if p.Protocol != q.Protocol { - return false - } - - // unmap: github.com/golang/go/issues/53607 - src1 := p.SrcIP.Unmap() - src2 := q.SrcIP.Unmap() - dst1 := p.DstIP.Unmap() - dst2 := q.DstIP.Unmap() - - if src1.Is6() && !src2.Is6() { - return false - } - if dst1.Is6() && !dst2.Is6() { - return false - } - - zeroip := zeroip4 - if src1.Is6() { - zeroip = zeroip6 - } - - // github.com/M66B/NetGuard/blob/1fe3a04ae/app/src/main/jni/netguard/ip.c#L393 - skipSrcIP := false - skipDstIP := false - skipDstPort := false - if zeroip.Compare(src1) == 0 || zeroip.Compare(src2) == 0 { - skipSrcIP = true - } - if zeroip.Compare(dst1) == 0 || zeroip.Compare(dst2) == 0 { - skipDstIP = true - } - if zeroPort == p.DstPort || zeroPort == q.DstPort { - skipDstPort = true - } - - return (skipSrcIP || src1.Compare(src2) == 0) && - p.SrcPort == q.SrcPort && - (skipDstIP || dst1.Compare(dst2) == 0) && - (skipDstPort || p.DstPort == q.DstPort) -} - -func trim(s string) string { - return strings.Trim(s, crlftabspace) -} - -func decToInt(n string) int { - d, err := strconv.ParseInt(n, 10, 64) - if err != nil { - log.E("Error while parsing %s to int: %s", n, err) - } - return int(d) -} - -func hexToInt(h string) int { - d, err := strconv.ParseInt(h, 16, 64) - if err != nil { - log.E("Error while parsing %s to int: %s", h, err) - } - return int(d) -} - -func hexToInt2(h string) (uint, uint) { - if len(h) > 16 { - d, err := strconv.ParseUint(h[:16], 16, 64) - if err != nil { - log.E("Error while parsing %s to int: %s", h[:16], err) - } - d2, err := strconv.ParseUint(h[16:], 16, 64) - if err != nil { - log.E("Error while parsing %s to int: %s", h[16:], err) - } - return uint(d), uint(d2) - } - d, err := strconv.ParseUint(h, 16, 64) - if err != nil { - log.E("Error while parsing %s to int: %s", h[:16], err) - } - return uint(d), 0 - -} - -func hexToIP(h string) netip.Addr { - hi, lo := hexToInt2(h) - var ip net.IP - if lo != 0 { - lomsb := uint32(lo >> 32) - himsb := uint32(hi >> 32) - - // see: netip.Unmap - // stackoverflow.com/questions/22751035 - // hi: 0000 0000 0000 0000 0000 0000 0000 0000 - // lo: 0000 0000 0000 0000 wwww xxxx yyyy zzzz - if hi == 0 && lomsb == 0 { - ip = make(net.IP, 4) // v4in6 - binary.LittleEndian.PutUint32(ip, uint32(lo)) - } else { - ip = make(net.IP, 16) - // ip addresses are stored in network byte order - binary.LittleEndian.PutUint32(ip, himsb) - binary.LittleEndian.PutUint32(ip[4:], uint32(hi)) - // if v4in6: github.com/golang/go/blob/2bed2797/src/net/ip.go#L195-L196 - // mark: 0000 0000 0000 0000 1111 1111 1111 1111 - // mark := uint32(0xffff) - // binary.BigEndian.PutUint32(ip[8:], mark) - binary.LittleEndian.PutUint32(ip[8:], lomsb) - binary.LittleEndian.PutUint32(ip[12:], uint32(lo)) - } - } else { - ip = make(net.IP, 4) - binary.LittleEndian.PutUint32(ip, uint32(hi)) - } - return toUnmappedAddr(ip) -} - -func toUnmappedAddr(ip net.IP) netip.Addr { - ipp, _ := netip.AddrFromSlice(ip[:]) - return ipp.Unmap() -} - -// ParseProcNet scans /proc/net/* returns a list of entries, one entry per line scanned -func ParseProcNet(protocol string) ([]ProcNetEntry, error) { - filename := filepath.Clean(fmt.Sprintf("/proc/net/%s", protocol)) - fd, err := os.Open(filename) - if err != nil { - return nil, err - } - defer core.CloseFile(fd) - - entries := make([]ProcNetEntry, 0) - scanner := bufio.NewScanner(fd) - for lineno := 0; scanner.Scan(); lineno++ { - // skip column names - if lineno == 0 { - continue - } - - line := trim(scanner.Text()) - m := parser.FindStringSubmatch(line) - if m == nil { - log.W("Could not parse netstat line from %s: %s", filename, line) - continue - } - - entries = append(entries, NewProcNetEntry( - protocol, - hexToIP(m[1]), - hexToInt(m[2]), - hexToIP(m[3]), - hexToInt(m[4]), - decToInt(m[5]), - decToInt(m[6]), - )) - } - - go cleanupPool() - - return entries, nil -} - -// cleanupPool removes entries from the pool that are older than cachettl. -// Must be called from a goroutine. -func cleanupPool() { - defer core.Recover(core.Exit11, "procfs.cleanupPool") - - if time.Since(cache.lastcleanup).Milliseconds() <= cachettl { - return - } - cache.lastcleanup = time.Now() - - cache.pool.Range(func(k, v any) bool { - if e, ok := v.(*ProcNetEntry); ok { - if invalidProcNetEntry(e) { - deleteProcNetEntryFromPool(e) - } - } - return true - }) -} - -func invalidProcNetEntry(p *ProcNetEntry) bool { - if p == nil { - return true - } - - e := getProcNetEntryFromPool(p) - if e == nil { - return true - } - - return time.Since(e.ctime).Milliseconds() > cachettl -} - -func deleteProcNetEntryFromPool(p *ProcNetEntry) { - if p == nil { - return - } - - cache.pool.Delete(p.String()) -} - -func addProcNetEntryToPool(p *ProcNetEntry) { - if p == nil { - return - } - - cache.pool.Store(p.String(), p) -} - -func getProcNetEntryFromPool(p *ProcNetEntry) *ProcNetEntry { - if p == nil { - return nil - } - - if v, ok := cache.pool.Load(p.String()); !ok { - return nil - } else if e, ok := v.(*ProcNetEntry); !ok { - return nil - } else { - return e - } -} - -// findProcNetEntryForProtocol parses /proc/net/* and return the line matching the argument five-tuple -// (protocol, source, sport, destination, dport) as NewProcNetEntry. -func findProcNetEntryForProtocol(protocol string, src, dst netip.AddrPort) *ProcNetEntry { - - n := NewProcNetEntry(protocol, src.Addr().Unmap(), int(src.Port()), dst.Addr().Unmap(), int(dst.Port()), 0, 0) - e := &n // groups.google.com/g/golang-nuts/c/reaIlFdibWU?pli=1 - - if f := getProcNetEntryFromPool(e); e.Same(f) { - if !invalidProcNetEntry(f) { - return f - } - deleteProcNetEntryFromPool(f) - } - - entries, err := ParseProcNet(protocol) - if err != nil { - log.W("Error while searching for %s netstat entry: %s", protocol, err) - return nil - } - - for _, ent := range entries { - ep := &ent // stackoverflow.com/a/68247837 - cached := getProcNetEntryFromPool(ep) - if invalidProcNetEntry(cached) { - addProcNetEntryToPool(ep) - } - // return on first match since e.Same is pretty lax and deliberately - // not exact at matching the various procnet entries - if e.Same(ep) { - return ep - } - } - - return nil -} - -// FindProcNetEntry searches for netstat entries in v4 and v6 tables. -func FindProcNetEntry(protocol string, src, dst netip.AddrPort) *ProcNetEntry { - if entry := findProcNetEntryForProtocol(protocol, src, dst); entry != nil { - return entry - } - - ipv6Suffix := "6" - if !strings.HasSuffix(protocol, ipv6Suffix) { - otherProtocol := protocol + ipv6Suffix - return findProcNetEntryForProtocol(otherProtocol, src, dst) - } - - return nil -} diff --git a/intra/protect/icmplistener.go b/intra/protect/icmplistener.go deleted file mode 100644 index 0aa7af59..00000000 --- a/intra/protect/icmplistener.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package protect - -import ( - "context" - "net" - "os" - "strconv" - "syscall" - - "github.com/celzero/firestack/intra/core" -) - -const ( - protocolICMP = 1 - protocolIPv6ICMP = 58 -) - -type icmplistener struct { - Control ControlFn -} - -var _ core.ICMPConn = (*icmpConn)(nil) - -type icmpConn struct { - net.PacketConn -} - -func (s *icmpConn) SyscallConn() (syscall.RawConn, error) { - return sysconn(s.PacketConn) -} - -// listenICMP listens for incoming ICMP packets addressed to -// address for non-privileged datagram-oriented ICMP endpoints. -// network must be "udp4" or "udp6". The endpoint allows to read, -// write a few limited ICMP messages such as echo request & reply. -// -// Examples: -// -// listenICMP("udp4", "192.168.0.1") -// listenICMP("udp4", "0.0.0.0") -// listenICMP("udp6", "fe80::1%en0") -// listenICMP("udp6", "::") -// -// from: cs.opensource.google/go/x/net/+/refs/tags/v0.28.0:icmp/listen_posix.go -func (ln *icmplistener) listenICMP(_ context.Context, network, address string) (core.ICMPConn, error) { - var family, proto int - switch network { - case "udp4": - family, proto = syscall.AF_INET, protocolICMP - case "udp6": - family, proto = syscall.AF_INET6, protocolIPv6ICMP - default: - return nil, errNoICMPL3 - } - - // todo: controller bind4, bind6 - var cerr error - var c net.PacketConn - s, err := syscall.Socket(family, syscall.SOCK_DGRAM, proto) - if err != nil { - return nil, os.NewSyscallError("socket", err) - } - sa, err := sockaddr(family, address) - if err != nil { - syscall.Close(s) - return nil, err - } - if err := syscall.Bind(s, sa); err != nil { - syscall.Close(s) - return nil, os.NewSyscallError("bind", err) - } - // why? github.com/golang/go/issues/15021#issuecomment-308562480 - f := os.NewFile(uintptr(s), "datagram-oriented icmp") - c, cerr = net.FilePacketConn(f) // expecting a *net.UDPConn - f.Close() - if cerr != nil { - clos(c) - return nil, cerr - } - if cfn := ln.Control; cfn != nil { - var sc syscall.RawConn - if sc, err = sysconn(c); err == nil { - err = cfn(network, address, sc) - } - if err != nil { - clos(c) - return nil, err - } - } - return &icmpConn{c}, nil -} - -func sysconn(c net.PacketConn) (syscall.RawConn, error) { - switch v := c.(type) { - case *net.UDPConn: - return v.SyscallConn() - case *net.IPConn: - return v.SyscallConn() - case *net.UnixConn: - return v.SyscallConn() - default: - return nil, errNoSysConn - } -} - -// from: cs.opensource.google/go/x/net/+/refs/tags/v0.28.0:icmp/helper_posix.go -// todo: do not resolve address -func sockaddr(family int, address string) (syscall.Sockaddr, error) { - switch family { - case syscall.AF_INET: - a, err := net.ResolveIPAddr("ip4", address) - if err != nil { - return nil, err - } - if a == nil { // nilaway - return nil, net.InvalidAddrError("bad ipv4 address") - } - if len(a.IP) == 0 { - a.IP = net.IPv4zero - } - if a.IP = a.IP.To4(); a.IP == nil { - return nil, net.InvalidAddrError("non-ipv4 address") - } - sa := &syscall.SockaddrInet4{} - copy(sa.Addr[:], a.IP) - return sa, nil - case syscall.AF_INET6: - a, err := net.ResolveIPAddr("ip6", address) - if err != nil { - return nil, err - } - if a == nil { // nilaway - return nil, net.InvalidAddrError("bad ipv6 address") - } - if len(a.IP) == 0 { - a.IP = net.IPv6unspecified - } - if a.IP.Equal(net.IPv4zero) { - a.IP = net.IPv6unspecified - } - if a.IP = a.IP.To16(); a.IP == nil || a.IP.To4() != nil { - return nil, net.InvalidAddrError("non-ipv6 address") - } - sa := &syscall.SockaddrInet6{ZoneId: zoneToUint32(a.Zone)} - copy(sa.Addr[:], a.IP) - return sa, nil - default: - return nil, net.InvalidAddrError("unexpected family") - } -} - -// from: cs.opensource.google/go/x/net/+/refs/tags/v0.28.0:icmp/helper_posix.go -func zoneToUint32(zone string) uint32 { - if zone == "" { - return 0 - } - if ifi, err := net.InterfaceByName(zone); err == nil { - return uint32(ifi.Index) - } - n, err := strconv.Atoi(zone) - if err != nil { - return 0 - } - return uint32(n) -} diff --git a/intra/protect/ipmap/ipmap.go b/intra/protect/ipmap/ipmap.go deleted file mode 100644 index d3879373..00000000 --- a/intra/protect/ipmap/ipmap.go +++ /dev/null @@ -1,902 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipmap - -import ( - "context" - "net" - "net/netip" - "strings" - "sync" - "sync/atomic" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/xdns" -) - -const maxFailLimit = 8 - -var zeroaddr = netip.Addr{} - -// Type of IPSet. -type IPSetType int - -const ( - Protected IPSetType = iota - Regular - IPAddr - AutoType -) - -func (h IPSetType) String() string { - switch h { - case Protected: - return "Protected" - case Regular: - return "Regular" - case IPAddr: - return "ipaddr" - case AutoType: - return "Auto" - default: - return "Unknown" - } -} - -var UndelegatedDomainsTrie = newUndelegatedDomainTrie() - -func newUndelegatedDomainTrie() x.RadixTree { - t := x.NewRadixTree() - for _, domain := range core.UndelegatedDomains { - t.Add(domain) - } - return t -} - -// IPMapper is an interface for resolving hostnames to IP addresses. -// For internal used by firestack. -type IPMapper interface { - // Shorthand for Lookup(q, protect.UidSelf) - LocalLookup(q []byte) ([]byte, error) - // Lookup resolves q over one of the tids. If tids is empty, either - // dnsx.Default, and if that fails, dnsx.System or dnsx.Goos tids. - Lookup(q []byte, uid string, tids ...string) ([]byte, error) - // LookupFor resolves q over client-code preferred tid conveyed via - // DNSOpts returned from DNSListener.OnQuery. As a special case, UID - // may be protect.UidSelf ("rethink") or core.UNKNOWN_UID_STR ("-1") - // but otherwise it is usually a Linux user-id assigned to a process - // which presumably is requesting this lookup. - LookupFor(q []byte, uid string) ([]byte, error) - // LookupNetIP is like Lookup but with empty tids. - LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) - // LookupNetIPFor is like LookupFor - LookupNetIPFor(ctx context.Context, network, host, uid string) ([]netip.Addr, error) - // LookupNetIPOn is like Lookup but with tids set to some preset IDs - // on behalf of protect.UidSelf (rethink). - LookupNetIPOn(ctx context.Context, network, host string, tids ...string) ([]netip.Addr, error) -} - -// IPMap maps hostnames to IPSets. -type IPMap interface { - IPMapper - // Resolves hostOrIP and adds the resulting IPs to its IPSet. - // hostOrIP may be host:port, or ip:port, or host, or ip. - Add(hostOrIP string) *IPSet - // Get creates an IPSet for this hostname populated with the IPs - // discovered by resolving it. Subsequent calls to Get return the - // same IPSet. Never returns nil. - // hostOrIP may be host:port, or ip:port, or host, or ip. - Get(hostOrIP string) *IPSet - // GetAny creates an IPSet for this hostname, which may be empty. - // Subsequent calls to GetAny return the same IPSet. Never returns nil. - // hostOrIP may be host:port, or ip:port, or host, or ip. - GetAny(hostOrIP string) *IPSet - // GetMany returns a list of sampled IPs from the ipmap cache. - // ipver is one of "v4", "v6", or "" (for both). - GetMany(n uint8, ipver string) []netip.Addr - // MakeIPSet creates an IPSet for this hostname bootstrapped with given IPs - // or IP:Ports. Subsequent calls to MakeIPSet return a new, overridden IPSet. - // hostOrIP may be host:port, or ip:port, or host, or ip. - MakeIPSet(hostOrIP string, ipps []string, typ IPSetType) *IPSet - // Reverse lookup; returns hostnames for the given IP address. - ReverseGet(ip netip.Addr) []string - // ReverseGetMany returns a list of sampled hostnames from the ipmap cache. - // ipver is one of "v4", "v6", or "" (for both). - ReverseGetMany(n uint8, ipver string) []string - // With sets the default resolver to use for hostname resolution. - With(r IPMapper) - // Clear removes all IPSets from the map. - Clear() -} - -type ipmap struct { - sync.RWMutex // protects m, p, ip - - // hostOrIP => ips - m map[string]*IPSet // regular type - p map[string]*IPSet // protected ips; immutable, never cleared - ip map[string]*IPSet // ipaddrs - - // ip => host - rptr x.IpTree // regular => hostname - pptr x.IpTree // protected => hostname - - r *core.Volatile[IPMapper] // resolver -} - -// IPSet represents an unordered collection of IP addresses for a single host. -// One IP can be marked as confirmed to be working correctly. -type IPSet struct { - mu sync.RWMutex // Protects ips. - ips []netip.Addr // All known IPs for the server. - typ IPSetType // Regular, Protected, or AutoType - - r IPMapper // For hostname resolution, never nil - seed []string // Bootstrap ips or ip:ports; may be nil; is immutable. - - confirmed *core.Volatile[netip.Addr] // netip.Addr confirmed to be working. - fails atomic.Uint32 // Number of times the confirmed IP has failed. - - any4 atomic.Bool // Whether this set has IPv4 addresses. - any6 atomic.Bool // Whether this set has IPv6 addresses. -} - -func NewIPMap() *ipmap { - return NewIPMapFor(nil) -} - -// NewIPMapFor returns a fresh IPMap with r as its nameserver. -func NewIPMapFor(r IPMapper) *ipmap { - return &ipmap{ - m: make(map[string]*IPSet), - p: make(map[string]*IPSet), - ip: make(map[string]*IPSet), - - rptr: x.NewIpTree(), - pptr: x.NewIpTree(), - - r: core.NewVolatile(r), // r may be nil - } -} - -func (m *ipmap) With(r IPMapper) { - log.I("ipmap: new resolver; ok? %t", r != nil) - m.r.Store(r) // may be nil -} - -func (m *ipmap) Clear() { - m.Lock() - defer m.Unlock() - - sz := len(m.m) + len(m.p) - purge := make(chan *IPSet, sz) - defer close(purge) - - core.Go("ipmap.goclear", func() { - n := 0 - for s := range purge { - s.clear() - n++ - } - log.D("ipmap: clear: done %d/%d sets", n, sz) - }) - - n := 0 - for _, s := range m.m { // regular / auto - purge <- s // preserves seed addrs - n++ - } - for _, s := range m.p { // protected - purge <- s // only clears confirmed ip - n++ - } - - // Clear the maps after sending all items to the purge channel - clear(m.m) - // Don't clear m.p as it contains protected entries - // Don't clear m.ip as it contains IP address entries - - go m.rptr.Clear() - // ipaddr type is not "cleared" - log.I("ipmap: clear: requested %d/%d sets", n, sz) -} - -// Implements IPMapper. -func (m *ipmap) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { - r := m.r.Load() // actual ipmapper implementation - if r == nil { - return nil, &net.DNSError{Err: "no resolver", Name: host, Server: "localhost"} - } - return r.LookupNetIP(ctx, network, host) -} - -// Implements IPMapper. -func (m *ipmap) LocalLookup(q []byte) ([]byte, error) { - return m.Lookup(q, protect.UidSelf) -} - -// Implements IPMapper. -func (m *ipmap) Lookup(q []byte, uid string, tids ...string) ([]byte, error) { - r := m.r.Load() // actual ipmapper implementation - if r == nil { - return nil, &net.DNSError{Err: "no resolver", Name: "Lookup", Server: "localhost"} - } - return r.Lookup(q, uid, tids...) -} - -// Implements IPMapper. -func (m *ipmap) LookupFor(q []byte, uid string) ([]byte, error) { - r := m.r.Load() // actual ipmapper implementation - if r == nil { - return nil, &net.DNSError{Err: "no resolver", Name: "LookupFor", Server: "localhost"} - } - return r.LookupFor(q, uid) -} - -// Implements IPMapper. -func (m *ipmap) LookupNetIPFor(ctx context.Context, network, host, uid string) ([]netip.Addr, error) { - r := m.r.Load() // actual ipmapper implementation - if r == nil { - return nil, &net.DNSError{Err: "no resolver", Name: host, Server: "localhost"} - } - return r.LookupNetIPFor(ctx, network, host, uid) -} - -// Implements IPMapper. -func (m *ipmap) LookupNetIPOn(ctx context.Context, network, host string, tid ...string) ([]netip.Addr, error) { - r := m.r.Load() // actual ipmapper implementation - if r == nil { - return nil, &net.DNSError{Err: "no resolver", Name: host, Server: "localhost"} - } - return r.LookupNetIPOn(ctx, network, host, tid...) -} - -func (m *ipmap) Add(hostOrIP string) *IPSet { - s := m.get(hostOrIP, AutoType) - if _, ok := s.add(hostOrIP); ok { - log.I("ipmap: Add: resolving %s", hostOrIP) - m.revmap(hostOrIP, s, nil) - } else { - log.W("ipmap: Add: zero ips for %s", hostOrIP) - } - return s -} - -func (m *ipmap) ReverseGetMany(n uint8, ipver string) []string { - hosts := make([]string, 0, n) - - m.RLock() - defer m.RUnlock() - - hasDesiredIPFamily := func(ips *IPSet, ipver string) bool { - if ipver == "v4" { - return ips.has4() - } - if ipver == "v6" { - return ips.has6() - } - return true - } - - possiblyPublicHost := func(host string) bool { - if xdns.IsMDNSQuery(host) { - return false - } - if UndelegatedDomainsTrie.HasAny(host) { - return false - } - if _, err := netip.ParseAddr(host); err == nil { - return false // not a host, but an IP address - } - return strings.Contains(host, ".") - } - // TODO: use hosts with public prefixes - for host, ips := range m.m { - if len(hosts) >= int(n) { - break - } - if possiblyPublicHost(host) && hasDesiredIPFamily(ips, ipver) { - // append if not an IP address - hosts = append(hosts, host) - } - } - for host, ips := range m.p { - if len(hosts) >= int(n) { - break - } - if possiblyPublicHost(host) && hasDesiredIPFamily(ips, ipver) { - // append if not an IP address - hosts = append(hosts, host) - } - } - - log.I("ipmap: ReverseGetMany: sampled %d hosts", len(hosts)) - return hosts -} - -func (m *ipmap) ReverseGet(ip netip.Addr) []string { - q := ip.String() - - s, _ := m.rptr.Get(q) - hosts := s - if len(hosts) > 0 { - return strings.Split(hosts, x.Vsep) - } - - s, _ = m.pptr.Get(q) - hosts = s - if len(hosts) > 0 { - return strings.Split(hosts, x.Vsep) - } - return nil -} - -func (m *ipmap) Get(hostOrIP string) *IPSet { - s := m.get(hostOrIP, AutoType) - if s.Empty() { - log.I("ipmap: Get: resolving %s", hostOrIP) - if _, ok := s.add(hostOrIP); !ok { - log.W("ipmap: Get: zero ips for %s", hostOrIP) - } else { - m.revmap(hostOrIP, s, nil) - } - } - log.D("ipmap: Get: %s => %s", hostOrIP, s.ips) - return s -} - -func (m *ipmap) GetAny(hostOrIP string) *IPSet { - return m.get(hostOrIP, AutoType) // may be empty -} - -func (m *ipmap) get(hostOrIP string, typ IPSetType) (s *IPSet) { - if host, _, err := net.SplitHostPort(hostOrIP); err == nil { - hostOrIP = host - } - - m.RLock() - sp := m.p[hostOrIP] - sr := m.m[hostOrIP] - si := m.ip[hostOrIP] - m.RUnlock() - - if sp != nil || typ == Protected { - s = sp // may be nil or empty - typ = Protected // discard Regular or AutoType - } else if si != nil || typ == IPAddr { - // TODO: assert hostOrIP is a valid IP or IP:Port - s = si - typ = IPAddr - } else { // Regular or AutoType - s = sr // may be nil or empty - typ = Regular // discard AutoType - } - - if s == nil { - s = m.makeIPSet(hostOrIP, nil, typ) // typ is never AutoType - } - - return s -} - -func (m *ipmap) GetMany(n uint8, ipver string) []netip.Addr { - m.RLock() - defer m.RUnlock() - - ips := make([]netip.Addr, 0, n) - - desiredfamily := func(ip netip.Addr) bool { - if ipver == "v4" { - return ip.Is4() - } - if ipver == "v6" { - return ip.Is6() - } - return ip.IsValid() // both - } - oneip := func(s *IPSet) (zz netip.Addr) { - confirmed := s.confirmed.Load() - if desiredfamily(confirmed) && confirmed.IsGlobalUnicast() { - return confirmed - } - for _, ip := range s.ips { - if desiredfamily(ip) && ip.IsGlobalUnicast() { - return ip - } - } - return - } - for _, s := range m.m { - if len(ips) >= int(n) { - break - } - if ip := oneip(s); ip.IsValid() { - ips = append(ips, ip) - } - } - for _, s := range m.p { - if len(ips) >= int(n) { - break - } - if ip := oneip(s); ip.IsValid() { - ips = append(ips, ip) - } - } - for _, s := range m.ip { - if len(ips) >= int(n) { - break - } - if ip := oneip(s); ip.IsValid() { - ips = append(ips, ip) - } - } - - log.I("ipmap: GetMany: sampled %d ips", len(ips)) - return ips -} - -func (m *ipmap) MakeIPSet(hostOrIP string, ipps []string, typ IPSetType) *IPSet { - if host, _, err := net.SplitHostPort(hostOrIP); err == nil { - hostOrIP = host - } - if len(ipps) <= 0 && typ == Protected { - ip, err := netip.ParseAddr(hostOrIP) - if err != nil || !ip.IsValid() { - // TODO: error? - log.T("ipmap: renew: %s; empty seed for Protected!", hostOrIP) - ipps = nil // fallback to AutoType - } else { - ipps = []string{hostOrIP} - log.I("ipmap: renew: %s Protected type seeded from hostOrIP: %v", hostOrIP, ipps) - } - } else { - // TODO: hostOrIP must be IP (or IP:Port) if typ == IPAddr - log.D("ipmap: renew: %s / seed: %v / typ: %s", hostOrIP, ipps, typ) - } - return m.makeIPSet(hostOrIP, ipps, typ) -} - -func (m *ipmap) makeIPSet(hostname string, ipps []string, ogtyp IPSetType) *IPSet { - var ip netip.Addr - var err error - typ := ogtyp - if len(ipps) == 0 { - typ = AutoType - ipps = []string{} - } - - mm := m.m // Regular or AutoType - if protect.NeverResolve(hostname) || typ == Protected { - mm = m.p - typ = Protected // discard Regular or AutoType - } else if ip, err = netip.ParseAddr(hostname); err == nil && !ip.IsUnspecified() && ip.IsValid() { - mm = m.ip - ogtyp = IPAddr // reset (avoid err log below) - typ = IPAddr // may be set to AutoType above - } else { - typ = Regular // discard AutoType & IPAddr type - } - - logeif(typ != ogtyp)("ipmap: makeIPSet: %s, seed: %v, typ: %s, ogtyp: %s", hostname, ipps, typ, ogtyp) - - s := &IPSet{ - confirmed: core.NewZeroVolatile[netip.Addr](), - typ: typ, - r: m, // m stays constant, but m.r may change - seed: core.CopyUniq(ipps), - fails: atomic.Uint32{}, - } - if typ == IPAddr { - log.D("ipmap: makeIPSet: %s for %s, confirmed addr %s", hostname, typ, ip) - s.confirmed.Store(ip) - // s.ips is empty for typ == IPAddr - } else { - s.confirmed.Store(zeroaddr) - } - - totalseeds := s.bootstrap() // seed addrs only - - // if typ is Protected, then seeds must never be empty - if typ == Protected && totalseeds <= 0 { - log.W("ipmap: makeIPSet: zero seeds; %s for type %s discarded", hostname, typ) - } else { - m.Lock() - prev := mm[hostname] // prev may be nil - mm[hostname] = s // overwrites existing - m.Unlock() - m.revmap(hostname, s, prev) - } - - return s -} - -func (m *ipmap) revmap(hostOrIP string, new *IPSet, old *IPSet) { - if host, _, err := net.SplitHostPort(hostOrIP); err == nil { - hostOrIP = host - } - if maybeip, _ := netip.ParseAddr(hostOrIP); maybeip.IsValid() { - return // no-op - } - host := hostOrIP - - add := new.Addrs() // new may be nil or addrs() may be empty - remove := old.Addrs() // old may be nil or addrs() may be empty - - addtree, rmvtree := m.rptr, m.rptr - if new != nil { - if new.typ == Protected { - addtree = m.pptr - } - } - if old != nil { - if old.typ == Protected { - rmvtree = m.pptr - } - } - - var errs []error - r, a := 0, 0 - for _, ip := range remove { - if ip.IsValid() { - q := ip.String() - if rmvtree.Esc(q, host) { - r++ - } - } - } - for _, ip := range add { - if ip.IsValid() { - q := ip.String() - if err := addtree.Add(q, host); err == nil { - a++ - } else { - errs = append(errs, err) - } - } - } - log.D("ipmap: rev: for %s: added %d/%d; removed %d/%d; errs: %v", - host, a, len(add), r, len(remove), core.UniqErr(errs...)) -} - -// Reports whether ip is in the set. Must be called under RLock. -func (s *IPSet) hasLocked(ip netip.Addr) bool { - for _, oldIP := range s.ips { - if oldIP.Compare(ip) == 0 { - return true - } - } - return false -} - -// Adds an IP to the set if it is not present. Must be called under Lock. -func (s *IPSet) addLocked(ips ...netip.Addr) { - if s.typ == IPAddr { // nothing to do as this ipset only has one ipaddr - if len(ips) > 0 { - log.W("ipmap: addLocked: ipaddr type; ignoring %d ips", len(ips)) - } - return - } - - if len(ips) <= 0 { - if s.typ == Regular { - log.I("ipmap: addLocked: remove all") - s.ips = nil // remove all ips - s.any4.Store(false) - s.any6.Store(false) - } else { - log.E("ipmap: addLocked: Protected/IPAddr type; ignoring empty add; seed: %v", s.seed) - } - return - } - new4, new6 := false, false - for i, ip := range ips { - // always unmapped; github.com/golang/go/issues/53607 - ip = ip.Unmap() - nouns := !ip.IsUnspecified() - valip := ip.IsValid() - newip := !s.hasLocked(ip) - if nouns && valip && newip { - s.ips = append(s.ips, ip) - new4 = new4 || ip.Is4() - new6 = new6 || ip.Is6() - } else { - log.D("ipmap: add #%d: fail %s; uns? %t, val? %t, !new? %t", i, ip, !nouns, valip, newip) - } - } - - s.any4.Store(new4 || s.any4.Load()) - s.any6.Store(new6 || s.any6.Load()) -} - -// Returns bootstrap ips or ip:ports. -func (s *IPSet) Seed() []string { - return s.seed -} - -// add one or more IP addresses to the set. -// The hostname can be a domain name or an IP address. -func (s *IPSet) add(hostOrIP string) ([]netip.Addr, bool) { - if s.typ == IPAddr { // nothing to do as this ipset only has one ipaddr - return nil, false - } - - if host, _, err := net.SplitHostPort(hostOrIP); err == nil { - hostOrIP = host - } - r := s.r - if r == nil { // unlikely; s.r is never nil - log.W("ipmap: Add: no resolver for %s", hostOrIP) - return nil, false - } - - ctx := context.Background() - - var resolved []netip.Addr - var err error - if s.typ == Protected { - // dnsx.System is "never resolved" and hence can be used to resolve - // "protected" IPSets like the one used by bootstrap's DoH (x.Default) - // see: protect.NeverResolve and dnsx.RegisterAddrs - resolved, err = r.LookupNetIPOn(ctx, "ip", hostOrIP, x.System) - } else if s.typ == Regular || s.typ == AutoType { - resolved, err = r.LookupNetIP(ctx, "ip", hostOrIP) - } - - if err != nil { - log.W("ipmap: Add: err resolving %s: %v", hostOrIP, err) - return nil, false - } else { - log.D("ipmap: Add: resolved? %s => %s", hostOrIP, resolved) - } - - if len(resolved) > 0 { - s.mu.Lock() - s.addLocked(resolved...) // resolved may be nil - s.mu.Unlock() - } - - ok := !s.Empty() - if ok { - s.fails.Store(0) // reset fails, since we have a new ips - } - - return resolved, ok // resolved may be nil, even if ok == true -} - -// Adds seed IP addresses to the set, if any. -func (s *IPSet) bootstrap() (n int) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, ipstr := range s.seed { - ipstr = strings.TrimSpace(ipstr) - if len(ipstr) <= 0 { - continue - } - if ip, err := netip.ParseAddr(ipstr); err == nil { - s.addLocked(ip) - n += 1 - } else { - if ipport, err2 := netip.ParseAddrPort(ipstr); err2 == nil { - s.addLocked(ipport.Addr()) - n += 1 - } else { - log.W("ipmap: seed: invalid ipstr %s: err1 %v / err2 %v", ipstr, err, err2) - } - } - } - return n -} - -func (s *IPSet) has4() bool { - if s == nil { - return false - } - if s.typ == IPAddr { // ipaddr always has one ip - return s.confirmed.Load().Is4() - } - return s.any4.Load() -} - -func (s *IPSet) has6() bool { - if s == nil { - return false - } - if s.typ == IPAddr { // ipaddr always has one ip - return s.confirmed.Load().Is6() - } - return s.any6.Load() -} - -// Empty reports whether the set is empty. -func (s *IPSet) Empty() bool { - if s == nil { - return true - } - // typ == IPAddr is never empty! - return s.Size() == 0 -} - -func (s *IPSet) Size() uint32 { - if s.typ == IPAddr { // IPAddr type always has one ip (confirmed) - return 1 - } - - s.mu.RLock() - defer s.mu.RUnlock() - return uint32(len(s.ips)) -} - -// Addrs returns a copy of the IP set as a slice in random order. -// The slice is owned by the caller, but the elements are owned by the set. -func (s *IPSet) Addrs() []netip.Addr { - if s == nil { - return []netip.Addr{} - } - - if s.typ == IPAddr { // fast path for ipaddrs - return []netip.Addr{s.confirmed.Load()} - } - - s.mu.RLock() - sz := len(s.ips) - if sz <= 0 { - s.mu.RUnlock() - return []netip.Addr{} - } - c := make([]netip.Addr, 0, sz) - c = append(c, s.ips...) - s.mu.RUnlock() - - return core.ShuffleInPlace(c) -} - -func (s *IPSet) Protected() bool { - return s.typ == Protected -} - -func (s *IPSet) OneIPOnly() bool { - return s.typ == IPAddr -} - -// Confirmed returns the confirmed IP address, or zeroaddr if there is no such address. -func (s *IPSet) Confirmed() netip.Addr { - return s.confirmed.Load() -} - -// Confirm marks ip as the confirmed address. -// No-op if current confirmed IP is the same as ip. -// No-op if s is of type Protected and ip isn't in seed addrs. -// No-op if s is of type IPAddr. -func (s *IPSet) Confirm(ip netip.Addr) { - if s.typ == IPAddr { // ipaddr fast path, no-op - return - } - - // do not reset fails, as confirmed ipaddrs may be repeatedly - // disconfirmed by upstream clients (for example; dialers may - // confirm an ip if it successfully dials, but upstream clients - // like dot/doh/dns53 may disconfirm them on HTTP/DNS errors). - // We'd want to keep incrementing failures, so an eventual - // reset can happen once a generous maxFailLimit is exhausted. - // s.fails.Store(0) - if ip.Compare(s.confirmed.Load()) == 0 { - return // no-op - } - - // since mutex are acquired, perform ops asynchronously - core.Gx("ipset.cfm."+ip.String(), func() { - s.mu.RLock() - newIP := !s.hasLocked(ip) - s.mu.RUnlock() - - // Protected IPSets should stay consistent with its seed addrs - // and must not add or confirm unseeded IPs. This happens in cases - // where an IP from a previous Protected IPSet may be confirmed at - // a time after the IPSet has been updated to a new one. For example, - // if UidSelf / UidSystem has been changed to new System DNS IPs, - // a goroutine using the previous UidSelf / UidSystem IPSet may - // end up confirming IP address in the new one. - if s.typ == Protected && newIP { - s.confirmed.Store(zeroaddr) // reset instead - return - } - - s.confirmed.Store(ip) - - // Add this IP to the set if it hasn't been seen before. - if s.typ != Protected && newIP { - s.mu.Lock() - s.addLocked(ip) // Add is O(N) - s.mu.Unlock() - } - }) -} - -// Reset clears existing IPs for Regular and Protected types, -// while it is a no-op for type IPAddr. -func (s *IPSet) Reset() *IPSet { - s.clear() - return s -} - -func (s *IPSet) clear() { - if s.typ == IPAddr { // no-op for ipaddr - return - } - - if s.typ != Protected { - s.mu.Lock() - s.addLocked() // removes all ips - s.mu.Unlock() - } - s.confirmed.Store(zeroaddr) - s.fails.Store(0) -} - -// Disconfirm sets the confirmed address to zeroaddr if the current confirmed address -// is the provided ip. -func (s *IPSet) Disconfirm(ip netip.Addr) (done bool) { - if s.typ == IPAddr { // no-op for ipaddr - return false - } - - c := s.confirmed.Load() - if ip.Compare(c) == 0 { - s.confirmed.Store(zeroaddr) - done = true - } - - // Size and clear require a lock, do so asynchronously - core.Go("ipset.dis."+ip.String(), func() { - // if s is not empty, act on disconfirm - if sz := s.Size(); sz > 0 { - tot := s.fails.Load() - // either the confirmed was disconfirmed above - // or s never had a confirmed ip, but still - // Disconfirm() was called, indicating a failure - if done || c.Compare(zeroaddr) == 0 { - tot = s.fails.Add(1) - } - - if tot > max(2*sz, maxFailLimit) { - // empty out the set, may be refilled by Get() - if s.fails.CompareAndSwap(tot, 0) { - s.clear() - } - } - } - }) - return -} - -func logeif(cond bool) log.LogFn { - if cond { - return log.E - } - return log.D -} diff --git a/intra/protect/ipmap/ipmap_test.go b/intra/protect/ipmap/ipmap_test.go deleted file mode 100644 index 27a53d33..00000000 --- a/intra/protect/ipmap/ipmap_test.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipmap - -import ( - "context" - "errors" - "net" - "net/netip" - "sync/atomic" - "testing" -) - -func TestGetTwice(t *testing.T) { - m := NewIPMap() - a := m.Get("example") - b := m.Get("example") - if a == b { - t.Error("Matched Get returned different objects") - } -} - -func TestGetInvalid(t *testing.T) { - m := NewIPMap() - s := m.Get("example") - if !s.Empty() { - t.Error("Invalid name should result in an empty set") - } - if len(s.Addrs()) != 0 { - t.Error("Empty set should be empty") - } -} - -func TestGetDomain(t *testing.T) { - m := NewIPMap() - s := m.Get("www.google.com") - if s.Empty() { - t.Error("Google lookup failed") - } - ips := s.Addrs() - if len(ips) == 0 { - t.Fatal("IP set is empty") - } else if !ips[0].IsValid() { - t.Error("nil IP in set") - } -} - -func TestGetIP(t *testing.T) { - m := NewIPMap() - s := m.Get("192.0.2.1") - if s.Empty() { - t.Error("IP parsing failed") - } - ips := s.Addrs() - if len(ips) != 1 { - t.Errorf("Wrong IP set size %d", len(ips)) - } else if ips[0].Unmap().String() != "192.0.2.1" { - t.Error("Wrong IP") - } -} - -func TestAddDomain(t *testing.T) { - m := NewIPMap() - s := m.Get("example") - s.add("www.google.com") - if s.Empty() { - t.Error("Google lookup failed") - } - ips := s.Addrs() - if len(ips) == 0 { - t.Fatal("IP set is empty") - } else if !ips[0].IsValid() { - t.Error("nil IP in set") - } -} -func TestAddIP(t *testing.T) { - m := NewIPMap() - s := m.Get("example") - s.add("192.0.2.1") - ips := s.Addrs() - if len(ips) != 1 { - t.Errorf("Wrong IP set size %d", len(ips)) - } else if ips[0].Unmap().String() != "192.0.2.1" { - t.Error("Wrong IP") - } -} - -func TestConfirmed(t *testing.T) { - m := NewIPMap() - fqdn := "www.google.com" - s := m.Get(fqdn) - if s.Confirmed().IsValid() { - t.Error("Confirmed should start out nil") - } - - ips := s.Addrs() - if len(ips) == 0 { - t.Fatalf("Empty IPSet for %s", fqdn) - return - } - s.Confirm(ips[0]) - if ips[0].Compare(s.Confirmed()) != 0 { - t.Error("Confirmation failed") - } - - s.Disconfirm(ips[0]) - if s.Confirmed().IsValid() { - t.Error("Confirmed should now be nil") - } -} - -func TestConfirmNew(t *testing.T) { - m := NewIPMap() - s := m.Get("example") - s.add("192.0.2.1") - // Confirm a new address. - s.Confirm(netip.MustParseAddr("192.0.2.2")) - if !s.Confirmed().IsValid() || s.Confirmed().String() != "192.0.2.2" { - t.Error("Confirmation failed") - } - ips := s.Addrs() - if len(ips) != 2 { - t.Errorf("New address not added to the set; %v", ips) - } -} - -func TestDisconfirmMismatch(t *testing.T) { - m := NewIPMap() - fqdn := "www.google.com" - s := m.Get(fqdn) - ips := s.Addrs() - if len(ips) == 0 { - t.Fatalf("Empty IPSet for %s", fqdn) - return - } - s.Confirm(ips[0]) - - // Make a copy - otherIP := netip.MustParseAddr(ips[0].String()) - // Alter it - otherIP = otherIP.Next() - // Disconfirm. This should have no effect because otherIP - // is not the confirmed IP. - s.Disconfirm(otherIP) - - if ips[0].Compare(s.Confirmed()) != 0 { - t.Error("Mismatched disconfirmation") - } -} - -type fakeResolver struct { - *net.Resolver -} - -func (r fakeResolver) LocalLookup([]byte) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) Lookup([]byte, string, ...string) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupFor([]byte, string) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIP(_ context.Context, _, _ string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIPFor(_ context.Context, _, _, _ string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -func (r fakeResolver) LookupNetIPOn(_ context.Context, _, _ string, _ ...string) ([]netip.Addr, error) { - return nil, errors.New("not implemented") -} - -func TestResolver(t *testing.T) { - var dialCount int32 - r := &net.Resolver{ - PreferGo: true, - Dial: func(context context.Context, network, address string) (net.Conn, error) { - atomic.AddInt32(&dialCount, 1) - return nil, errors.New("Fake dialer") - }, - } - resolver := &fakeResolver{r} - m := NewIPMapFor(resolver) - s := m.Get("www.google.com") - if !s.Empty() { - t.Error("Google lookup should have failed due to fake dialer") - } - if atomic.LoadInt32(&dialCount) == 0 { - t.Error("Fake dialer didn't run") - } -} diff --git a/intra/protect/protect.go b/intra/protect/protect.go deleted file mode 100644 index 130a1e51..00000000 --- a/intra/protect/protect.go +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protect - -import ( - "context" - "net" - "net/netip" - "strings" - "syscall" - "time" - - b "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" -) - -// See: ipmap.LookupNetIP; UidSelf -> dnsx.Default; UidSystem -> dnsx.System -const ( - UidSelf = b.UidSelf - UidSystem = b.UidSystem - Localhost = b.Localhost - - // hostless is a special placeholder prefix for dns53 transport that - // has multiple IP:port addresses and is "protected" (never resolved nor - // cleaned up on refresh, but only "updated"). - HostlessPrefix = "hostless." - - // When advertised routes are actually null routed (no reply) - // this timeout will help short circuit dial attempts on it. - defaultConnectTimeout = 10 * time.Second - - // if true, only protects the socket from routing loops & binds to active network. - onlyProtectWildcardAddrs = false -) - -// never resolve system/default/"hostless" resolver; expected to have seeded ips -func NeverResolve(hostname string) bool { - return hostname == UidSelf || hostname == UidSystem || strings.HasPrefix(hostname, HostlessPrefix) -} - -type Controller = b.Controller -type Protector = b.Protector - -type ControlFn func(network, addr string, c syscall.RawConn) (err error) - -// returns true if addr is a global unicast address; and yn on error. -func maybeGlobalUnicast(addr string, yn bool) bool { - if ipport, err := netip.ParseAddrPort(addr); err == nil { - return ipport.Addr().IsGlobalUnicast() - } else if ip, err := netip.ParseAddr(addr); err == nil { - return ip.IsGlobalUnicast() - } // ignore addr; it may be a wildcard or just hostname - return yn -} - -// Binds a socket to a particular network interface. -func ifbind(who string, ctl Controller) func(string, string, syscall.RawConn) error { - return func(network, addr string, c syscall.RawConn) (err error) { - // addr may be a wildcard aka ":", in which case dst is a zero address. - log.VV("control: netbinder: %s: %s(%s); err? %v", who, network, addr, err) - return c.Control(func(fd uintptr) { - sock := int(fd) - if onlyProtectWildcardAddrs && !maybeGlobalUnicast(addr, true) { - ctl.Protect(who, sock) - return - } - switch network { - case "tcp6", "udp6": - ctl.Bind6(who, addr, sock) - case "tcp4", "udp4": - ctl.Bind4(who, addr, sock) - case "tcp", "udp": // unexpected dual-stack socket - fallthrough // Control usually qualifies protocol family for the fd - default: - ctl.Protect(who, sock) - } - }) - } -} - -// unused: Binds a socket to a local ip. -func ipbind(p Protector) func(string, string, syscall.RawConn) error { - return func(network, addr string, c syscall.RawConn) (err error) { - src := p.UIP(addr) - ipaddr, _ := netip.AddrFromSlice(src) - origaddr, perr := netip.ParseAddrPort(addr) - log.VV("control: ipbinder: %s(%s/%w), bindto(%s); err? %v", - network, addr, origaddr, ipaddr, perr) - - if onlyProtectWildcardAddrs && !maybeGlobalUnicast(addr, true) { - // todo: protect fd? - return nil - } - - bind6 := func(fd uintptr) error { - sc := &syscall.SockaddrInet6{Addr: ipaddr.As16()} - return syscall.Bind(int(fd), sc) - } - bind4 := func(fd uintptr) error { - sc := &syscall.SockaddrInet4{Addr: ipaddr.As4()} - return syscall.Bind(int(fd), sc) - } - - return c.Control(func(fd uintptr) { - switch network { - case "tcp6", "udp6": - // TODO: zone := origaddr.Addr().Zone() - err = bind6(fd) - case "tcp4", "udp4": - err = bind4(fd) - case "tcp", "udp": // unexpected dual-stack socket? - fallthrough // see: networkBinder - default: - // no-op - // protect fd? - } - if err != nil { - log.E("protect: fail to bind ip(%s) to socket %v", ipaddr, err) - } - }) - } -} - -// unused: Creates a dialer that binds to a particular ip. -func MakeDialer(p Protector) *net.Dialer { - x := netdialer() - if p != nil && core.IsNotNil(p) { - x.Control = ipbind(p) - } - x.Timeout = defaultConnectTimeout // overriden by deadlines - return x -} - -// unused: Creates a listener that binds to a particular ip. -func MakeListenConfig(p Protector) *net.ListenConfig { - x := netlistener() - if p != nil && core.IsNotNil(p) { - x.Control = ipbind(p) - } - return x -} - -// Creates a net.Dialer that can bind to any active interface. -func MakeNsDialer(who string, c Controller) *net.Dialer { - x := netdialer() - if c != nil && core.IsNotNil(c) { - x.Control = ifbind(who, c) - } - x.Timeout = defaultConnectTimeout // overriden by deadlines - return x -} - -// Creates a RDial that can bind to any active interface. -func MakeNsRDial(who string, ctx context.Context, c Controller) *RDial { - return &RDial{ - owner: who, - ctx: ctx, - dialer: MakeNsDialer(who, c), - listen: MakeNsListener(who, c), - listenICMP: MakeNsICMPListener(who, c), - } -} - -// Creates a RDial that can bind to any active interface, with additional control fns. -func MakeNsRDialExt(who string, ctx context.Context, ctl Controller, ext ...ControlFn) *RDial { - dialer := MakeNsDialer(who, ctl) - dialer.Control = func(network, address string, c syscall.RawConn) error { - for _, fn := range ext { - if err := fn(network, address, c); err != nil { - return err - } - } - if ctl != nil && core.IsNotNil(ctl) { - if err := ifbind(who, ctl)(network, address, c); err != nil { - return err - } - } - return nil - } - listener := MakeNsListenConfigExt(who, ctl, ext) - icmplistener := MakeNsICMPListenerExt(who, ctl, ext) - return &RDial{ - owner: who, - ctx: ctx, - dialer: dialer, - listen: listener, - listenICMP: icmplistener, - } -} - -// Creates a listener that can bind to any active interface. -func MakeNsListener(who string, c Controller) *net.ListenConfig { - x := netlistener() - if c != nil && core.IsNotNil(c) { - x.Control = ifbind(who, c) - } - return x -} - -// Creates a listener that can bind to any active interface, with additional control fns. -func MakeNsListenConfigExt(who string, ctl Controller, ext []ControlFn) *net.ListenConfig { - x := netlistener() - x.Control = func(network, address string, c syscall.RawConn) error { - for _, fn := range ext { // must do prior to ctl.bind - if err := fn(network, address, c); err != nil { - return err - } - } - if ctl != nil && core.IsNotNil(ctl) { - if err := ifbind(who, ctl)(network, address, c); err != nil { - return err - } - } - return nil - } - return x -} - -func MakeNsICMPListener(who string, c Controller) *icmplistener { - x := icmpListener() - if c != nil && core.IsNotNil(c) { - x.Control = ifbind(who, c) - } - return x -} - -func MakeNsICMPListenerExt(who string, ctl Controller, ext []ControlFn) *icmplistener { - x := icmpListener() - x.Control = func(network, address string, c syscall.RawConn) error { - for _, fn := range ext { // must do prior to ctl.bind - if err := fn(network, address, c); err != nil { - return err - } - } - if ctl != nil && core.IsNotNil(ctl) { - if err := ifbind(who, ctl)(network, address, c); err != nil { - return err - } - } - return nil - } - return x -} - -// Creates a plain old dialer -func netdialer() *net.Dialer { - x := &net.Dialer{} - // todo: x.KeepAliveConfig = kacfg - return x -} - -// Creates a plain old listener -func netlistener() *net.ListenConfig { - x := &net.ListenConfig{} - // todo: x.KeepAliveConfig = kacfg - return x -} - -// Creates a icmp listener over UDP -func icmpListener() *icmplistener { - return &icmplistener{} -} diff --git a/intra/protect/protect_test.go b/intra/protect/protect_test.go deleted file mode 100644 index 1784a3c6..00000000 --- a/intra/protect/protect_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package protect - -import ( - "context" - "errors" - "net" - "sync" - "syscall" - "testing" -) - -// The fake protector just records the file descriptors it was given. -type fakeProtector struct { - mu sync.Mutex - fds []int32 -} - -// Implements Protector. -func (p *fakeProtector) UIP(n string) []byte { - return net.IPv4(127, 0, 0, 1) -} - -func (p *fakeProtector) Protect(fd int32) bool { - p.mu.Lock() - p.fds = append(p.fds, fd) - p.mu.Unlock() - return true -} - -func (p *fakeProtector) GetResolvers() string { - return "8.8.8.8,2001:4860:4860::8888" -} - -// This interface serves as a supertype of net.TCPConn and net.UDPConn, so -// that they can share the verifyMatch() function. -type hasSyscallConn interface { - SyscallConn() (syscall.RawConn, error) -} - -func verifyMatch(t *testing.T, conn hasSyscallConn, p *fakeProtector) { - rawconn, err := conn.SyscallConn() - if err != nil || rawconn == nil { - t.Fatal(errors.Join(err, &net.OpError{Op: "SyscallConn"})) - } - _ = rawconn.Control(func(fd uintptr) { - if len(p.fds) == 0 { - t.Fatalf("No file descriptors") - } - if int32(fd) != p.fds[0] { - t.Fatalf("File descriptor mismatch: %d != %d", fd, p.fds[0]) - } - }) -} - -func TestDialTCP(t *testing.T) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - go l.Accept() - - p := &fakeProtector{} - d := MakeDialer(p) - if d.Control == nil { - t.Errorf("Control function is nil") - } - - conn, err := d.Dial("tcp", l.Addr().String()) - if err != nil || conn == nil { - t.Fatal(errors.Join(err, &net.OpError{Op: "Dial"})) - return - } - verifyMatch(t, conn.(*net.TCPConn), p) - clos(conn) - clos(l) -} - -func TestListenUDP(t *testing.T) { - udpaddr, err := net.ResolveUDPAddr("udp", "localhost:0") - if err != nil { - t.Fatal(err) - } - - p := &fakeProtector{} - c := MakeListenConfig(p) - - conn, err := c.ListenPacket(context.Background(), udpaddr.Network(), udpaddr.String()) - if err != nil || conn == nil { - t.Fatal(errors.Join(err, &net.OpError{Op: "ListenPacket"})) - return - } - verifyMatch(t, conn.(*net.UDPConn), p) - clos(conn) -} - -func TestLookupIPAddr(t *testing.T) { - p := &fakeProtector{} - d := MakeDialer(p) - d.Resolver.LookupIPAddr(context.Background(), "foo.test.") - // Verify that Protect was called. - if len(p.fds) == 0 { - t.Fatal("Protect was not called") - } -} - -func TestNilDialer(t *testing.T) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - go l.Accept() - - d := MakeDialer(nil) - conn, err := d.Dial("tcp", l.Addr().String()) - if err != nil || conn == nil { - t.Fatal(errors.Join(err, &net.OpError{Op: "Dial"})) - return - } - - clos(conn) - clos(l) -} - -func TestNilListener(t *testing.T) { - udpaddr, err := net.ResolveUDPAddr("udp", "localhost:0") - if err != nil { - t.Fatal(err) - } - - c := MakeListenConfig(nil) - conn, err := c.ListenPacket(context.Background(), udpaddr.Network(), udpaddr.String()) - if err != nil || conn == nil { - t.Fatal(errors.Join(err, &net.OpError{Op: "Dial"})) - } - - clos(conn) -} diff --git a/intra/protect/xdial.go b/intra/protect/xdial.go deleted file mode 100644 index ca3a1c11..00000000 --- a/intra/protect/xdial.go +++ /dev/null @@ -1,343 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package protect - -import ( - "context" - "errors" - "io" - "net" - "net/netip" - "strconv" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" -) - -var ( - anyaddr4 = netip.IPv4Unspecified() - anyaddr6 = netip.IPv6Unspecified() - alwaysDualStack = true -) - -// type alias: go.dev/blog/alias-names / archive.vn/IZjgc - -// Adapter to keep gomobile happy as it can't export net.Conn -type Conn = net.Conn - -type PacketConn = net.PacketConn - -type MinConn = core.MinConn - -type Listener = net.Listener - -type DialFn func(network, addr string) (net.Conn, error) - -type RDialer interface { - ID() string - // Dial creates a connection to the given address, - // the resulting net.Conn must be a *net.TCPConn if - // network is "tcp" or "tcp4" or "tcp6" and must be - // a *net.UDPConn if network is "udp" or "udp4" or "udp6". - Dial(network, addr string) (Conn, error) - // DialBind is like Dial but creates a connection to - // the remote address bounded from the local port (not ip). - // If local is invalid ip:port (ip must be present but not used), - // it delegates to Dial(network, remote). - DialBind(network, local, remote string) (Conn, error) - // Announce announces the local address. network must be - // packet-oriented ("udp" or "udp4" or "udp6"). - Announce(network, local string) (PacketConn, error) - // Accept creates a listener on the local address. network - // must be stream-oriented ("tcp" or "tcp4" or "tcp6"). - Accept(network, local string) (Listener, error) - // Probe listens on the local address for ICMP packets sent - // over UDP. Network must be "udp" or "udp4" or "udp6". - Probe(network, local string) (PacketConn, error) -} - -// RDial adapts dialers and listeners to RDialer. -// It always discards bind address. -type RDial struct { - owner string // owner tag - ctx context.Context - // local dialer - dialer *net.Dialer // may be nil; used by exit, base, grounded - listen *net.ListenConfig // may be nil; used by exit, base, grounded - listenICMP *icmplistener // may be nil; used by exit, base, grounded -} - -var _ RDialer = (*RDial)(nil) - -var ( - errNoDialer = errors.New("not a dialer") - errNoRAddr = errors.New("missing remote addr") - errNoTCP = errors.New("not a tcp dialer") - errNoUDP = errors.New("not a udp dialer") - errNoUDPMux = errors.New("not a udp announcer") - errNoTCPMux = errors.New("not a tcp announcer") - errNoICMPL3 = errors.New("not an ip:icmp listener") - errNoSysConn = errors.New("no syscall.Conn") - errAnnounce = errors.New("cannot announce network") - errAccept = errors.New("cannot accept network") -) - -func (d *RDial) context() context.Context { - if d.ctx != nil { - return d.ctx - } - return context.Background() -} - -// ID implements RDialer. -func (d *RDial) ID() string { - if d.owner != "" { - return d.owner - } - return "xdial" // ownerless -} - -// Dial implements RDialer. -func (d *RDial) Dial(network, addr string) (net.Conn, error) { - return d.dialer.DialContext(d.context(), network, addr) -} - -func (d *RDial) cloneDialer() *net.Dialer { - var rd *net.Dialer = new(net.Dialer) - // shallow copy: go.dev/play/p/tuadSFN3glj - *rd = *d.dialer - return rd -} - -// DialBind implements RDialer. -func (d *RDial) DialBind(network, local, remote string) (net.Conn, error) { - var onlyport netip.AddrPort - rd := d.cloneDialer() - - if _, port, err := net.SplitHostPort(local); err == nil { - // uport may be 0, which is "valid" - uport, _ := strconv.Atoi(port) // should not error - - anyaddr := anyaddr6 - if !alwaysDualStack { - anyaddr = anyaddr4 - } - switch network { - case "tcp4": - anyaddr = anyaddr4 - case "tcp6": - anyaddr = anyaddr6 - } - if !alwaysDualStack { - // ipp invalid when local is without ip; ex: ":port" - if ipp, _ := netip.ParseAddrPort(local); ipp.Addr().Is4() { - anyaddr = anyaddr4 - } - } - // ip addr binding is left upto dialer's Control - // which is "namespace" aware (on Android) - onlyport = netip.AddrPortFrom(anyaddr, uint16(uport)) - } else { // okay for local to be invalid; called by retrier.DialTCP - log.VV("xdial: DialBind: (o: %s); %s %s=>%s; why: laddr nil", - d.owner, network, local, remote) - } - - switch network { - case "tcp", "tcp4", "tcp6": - if alwaysDualStack { - network = "tcp" - } - if onlyport.IsValid() { // valid even when port is 0 - rd.LocalAddr = net.TCPAddrFromAddrPort(onlyport) - log.V("xdial: DialBind: (o: %s); %s %s=>%s", - d.owner, network, rd.LocalAddr, remote) - } - case "udp", "udp4", "udp6": - if alwaysDualStack { - network = "udp" - } - if onlyport.IsValid() { // valid even when port is 0 - rd.LocalAddr = net.UDPAddrFromAddrPort(onlyport) - log.V("xdial: DialBind: (o: %s); %s %s=>%s", - d.owner, network, rd.LocalAddr, remote) - } - default: - log.W("xdial: DialBind: (o: %s); %s %s=>%s; err: unsupported network", - d.owner, network, local, remote) - } - - // equivalent to d.dial() if LocalAddr is not set - return rd.Dial(network, remote) -} - -// Accept implements RDialer interface. -func (d *RDial) Accept(network, local string) (net.Listener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, errAccept - } - return d.listen.Listen(d.context(), network, local) -} - -// Announce implements RDialer. -func (d *RDial) Announce(network, local string) (net.PacketConn, error) { - if network != "udp" && network != "udp4" && network != "udp6" { - log.T("xdial: Announce: invalid network %s", network) - return nil, errAnnounce - } - // todo: check if local is a local address or empty (any) - // diailing (proxy.Dial/net.Dial/etc) on wildcard addresses (ex: ":8080" or "" or "localhost:1025") - // is not equivalent to listening/announcing. see: github.com/golang/go/issues/22827 - if pc, err := d.listen.ListenPacket(d.context(), network, local); err == nil { - switch x := pc.(type) { - case *net.UDPConn: - return x, nil - default: - log.T("xdial: Announce (o: %s): addr(%s) failed; %T is not net.UDPConn; other errs: %v", - d.owner, local, x, err) - clos(pc) - return nil, errNoUDPMux - } - } else { - return nil, err - } -} - -// Probe implements RDialer. -func (d *RDial) Probe(network, local string) (PacketConn, error) { - if network == "udp" { - ip, _ := netip.ParseAddrPort(local) - ipok := ip.IsValid() - if ipok && ip.Addr().Is4() { - network = "udp4" - } else if ipok && ip.Addr().Is6() { - network = "udp6" - } - } - if network != "udp4" && network != "udp6" { - return nil, errAnnounce - } - // todo: check if local is a local address or empty (any) - // drop port if present - if ip, _, err := net.SplitHostPort(local); err == nil { - local = ip - } - - return d.listenICMP.listenICMP(d.context(), network, local) -} - -func Dial(d RDialer, laddr, raddr net.Addr) (Conn, error) { - if d == nil { - return nil, errNoDialer - } - if raddr == nil { - return nil, errNoRAddr - } - if laddr == nil { - return d.Dial(raddr.Network(), raddr.String()) - } - return d.DialBind(raddr.Network(), laddr.String(), raddr.String()) -} - -// DialTCP creates a net.TCPConn to raddr. -// Helper method for d.Dial("tcp", laddr.String(), raddr.String()) -func (d *RDial) DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - return DialTCP(d, network, laddr, raddr) -} - -// DialTCP creates a net.TCPConn to raddr bound to laddr using dialer d; laddr may be nil. -func DialTCP(d RDialer, network string, laddr, raddr net.Addr) (*net.TCPConn, error) { - if c, err := d.DialBind(network, addr2str(laddr), addr2str(raddr)); err != nil { - return nil, err - } else if tc, ok := c.(*net.TCPConn); ok { - return tc, nil - } else { - log.T("xdial: DialTCP: (%s) to %s => %s, %T is not %T (ok? %t); other errs: %v", - d.ID(), laddr, raddr, c, tc, ok, err) - // some proxies like wgproxy, socks5 do not vend *net.TCPConn - // also errors if retrier (core.DuplexConn) is looped back here - clos(c) - return nil, errNoTCP - } -} - -// DialUDP creates a net.UDPConn to raddr. -// Helper method for d.Dial("udp", laddr.String(), raddr.String()) -func (d *RDial) DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - return DialUDP(d, network, laddr, raddr) -} - -// DialUDP creates a net.UDPConn to raddr bound to laddr using dialer d; laddr may be left nil. -func DialUDP(d RDialer, network string, laddr, raddr net.Addr) (*net.UDPConn, error) { - if c, err := d.DialBind(network, addr2str(laddr), addr2str(raddr)); err != nil { - return nil, err - } else if uc, ok := c.(*net.UDPConn); ok { - return uc, nil - } else { - log.T("xdial: DialUDP: (%s) to %s => %s, %T is not %T (ok? %t); other errs: %v", - d.ID(), laddr, raddr, c, uc, ok, err) - // some proxies like wgproxy, socks5 do not vend *net.UDPConn - clos(c) - return nil, errNoUDP - } -} - -// AnnounceUDP announces the local address. network must be "udp" or "udp4" or "udp6". -// Helper method for d.Announce("udp", local) -func (d *RDial) AnnounceUDP(network, local string) (*net.UDPConn, error) { - return AnnounceUDP(d, network, local) -} - -// AnnounceUDP announces the local address. network must be "udp" or "udp4" or "udp6". -func AnnounceUDP(d RDialer, network, local string) (*net.UDPConn, error) { - if c, err := d.Announce(network, local); err != nil { - return nil, err - } else if uc, ok := c.(*net.UDPConn); ok { - return uc, nil - } else { - log.T("xdial: AnnounceUDP: (%s) from %s, %T is not %T (ok? %t); other errs: %v", - d.ID(), local, c, uc, ok, err) - clos(c) - return nil, errNoUDPMux - } -} - -// AcceptTCP creates a listener on the local address. network must be "tcp" or "tcp4" or "tcp6". -// Helper method for d.Accept("tcp", local) -func (d *RDial) AcceptTCP(network string, local string) (*net.TCPListener, error) { - return AcceptTCP(d, network, local) -} - -// AcceptTCP creates a listener on localaddr. network must be "tcp" or "tcp4" or "tcp6". -func AcceptTCP(d RDialer, network string, localaddr string) (*net.TCPListener, error) { - if ln, err := d.Accept(network, localaddr); err != nil { - return nil, err - } else if tl, ok := ln.(*net.TCPListener); ok { - return tl, nil - } else { - log.T("xdial: AcceptTCP: (%s) from %s, %T is not %T (ok? %t); other errs: %v", - d.ID(), localaddr, ln, tl, ok, err) - clos(ln) - return nil, errNoTCPMux - } -} - -// ProbeICMP listens on the local address for ICMP packets sent over UDP. -// network must be "udp" or "udp4" or "udp6". Helper method for d.Probe("udp", local) -func (d *RDial) ProbeICMP(network, local string) (net.PacketConn, error) { - return d.Probe(network, local) -} - -func clos(c io.Closer) { - core.Close(c) -} - -func addr2str(a net.Addr) string { - if a == nil || core.IsNil(a) { - return "" - } - return a.String() -} diff --git a/intra/rnet/http.go b/intra/rnet/http.go deleted file mode 100644 index 9040030e..00000000 --- a/intra/rnet/http.go +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rnet - -import ( - "context" - "io" - "net" - "net/http" - "net/url" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - tx "github.com/elazarl/goproxy" -) - -type dialFn func(network, addr string) (net.Conn, error) - -type httpx struct { - id string - host string - dialer *net.Dialer - svc *http.Server - hdl *httpxhandle - listener ServerListener - usetls bool - - // mutable fields below - sync.Mutex // protects tx.ProxyHttpServer - *tx.ProxyHttpServer // changed by Hop() - - status *core.Volatile[int] // status of the server -} - -type httpxhandle struct { - *AuthHandle - px *core.Volatile[ipn.Proxy] -} - -func newHttpServer(id, x string, ctl protect.Controller, listener ServerListener) (*httpx, error) { - var host string - var usr string - var pwd string - - // ex: "http://u:p@host:8080"; "http://u:p@:8080"; "http://:8080"; "http://host" - u, err := url.Parse(x) - if err != nil { - return nil, err - } - host = u.Host // host - if u.User != nil { // usr, pwd - usr = u.User.Username() // may be empty - pwd, _ = u.User.Password() // may be empty - } - dialer := protect.MakeNsDialer(id, ctl) - hdl := &httpxhandle{ - AuthHandle: &AuthHandle{usr: usr, pwd: pwd}, - px: core.NewZeroVolatile[ipn.Proxy](), - } - hproxy := tx.NewProxyHttpServer() - hproxy.Logger = log.Glogger - hproxy.Tr = &http.Transport{ - DialContext: dialer.DialContext, // overriden by Hop() - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 20 * time.Second, - } - // todo: dial to connect endpoint as defined by the underlying network or the OS - hproxy.ConnectDial = nil - hproxy.ConnectDialWithReq = nil - - svc := &http.Server{Addr: host, Handler: hproxy, ReadHeaderTimeout: 10 * time.Second} - usetls := u.Scheme == "https" - hasauth := len(usr) > 0 || len(pwd) > 0 - if hasauth { - // todo: listener with summary and route - hproxy.OnRequest(hdl.notok()).HandleConnectFunc(hdl.denyConnect) - hproxy.OnRequest(hdl.notok()).DoFunc(hdl.denyRequest) - } - - log.I("svchttp: new %s listening at %s; tls? %t / auth? %t", id, host, usetls, hasauth) - hx := &httpx{ - ProxyHttpServer: hproxy, - id: id, - usetls: usetls, - host: host, - dialer: dialer, - hdl: hdl, - svc: svc, - listener: listener, - status: core.NewVolatile(SOK), - } - hproxy.OnRequest().HandleConnectFunc(hx.routeConnect) - hproxy.OnRequest().DoFunc(hx.route) - hproxy.OnResponse().DoFunc(hx.summarize) - - return hx, nil -} - -type AuthHandle struct { - usr string - pwd string -} - -func (au *AuthHandle) notok() tx.ReqConditionFunc { - return func(req *http.Request, ctx *tx.ProxyCtx) bool { - if len(au.usr) == 0 && len(au.pwd) == 0 { - return false // no auth; do not handle - } - u, p, ok := req.BasicAuth() - return !ok || u != au.usr || p != au.pwd // handle if match - } -} - -func (au *AuthHandle) denyConnect(host string, ctx *tx.ProxyCtx) (*tx.ConnectAction, string) { - act := &tx.ConnectAction{Action: tx.ConnectProxyAuthHijack, TLSConfig: tx.TLSConfigFromCA(&tx.GoproxyCa)} - return act, host // "host" is unused when action is ConnectProxyAuthHijack -} - -func (au *AuthHandle) denyRequest(req *http.Request, ctx *tx.ProxyCtx) (*http.Request, *http.Response) { - return req, tx.NewResponse(req, tx.ContentTypeText, http.StatusUnauthorized, "Unauthorized") -} - -// using ctx.UserData to store summary -// github.com/elazarl/goproxy/blob/2592e75ae0/examples/goproxy-httpdump/httpdump.go#L254 -// and counting bytes via a wrapped read-writer -// github.com/elazarl/goproxy/blob/2592e75ae0/examples/goproxy-stats/main.go#L61 -func (h *httpx) route(req *http.Request, ctx *tx.ProxyCtx) (*http.Request, *http.Response) { - src := req.RemoteAddr - sid := h.id - pid := h.pid() - tab := h.listener.SvcRoute(sid, pid, "tcp", src, req.Host) - log.D("svchttp: route: tab(%v) id(%s) p(%s) src(%s) dst(%s)", tab, h.id, pid, src, req.Host) - if tab.Block { - return req, tx.NewResponse(req, tx.ContentTypeText, http.StatusForbidden, "Forbidden") - } - ctx.UserData = serverSummary(h.Type(), sid, pid, tab.CID) - return req, nil -} - -func (h *httpx) summarize(res *http.Response, ctx *tx.ProxyCtx) *http.Response { - req := res.Request - if ctx.UserData == nil { - if req != nil { - log.W("svchttp: summarize for %s<-%s missing; n: %d", req.Host, req.RemoteAddr, req.ContentLength) - } else { - log.W("svchttp: summarize missing") - } - } - ssu, ok := ctx.UserData.(*ServerSummary) - if !ok { - log.W("svchttp: summarize: invalid userdata %v", ctx.UserData) - return res - } - ssu.Rx = res.ContentLength - if req != nil { - ssu.Tx = req.ContentLength - } - ssu.done(errNop) - go h.listener.OnSvcComplete(ssu.ServerSummary) - return res -} - -func (h *httpx) routeConnect(host string, ctx *tx.ProxyCtx) (*tx.ConnectAction, string) { - src := h.svc.Addr - dst := ctx.Req.Host - sid := h.id - pid := h.pid() - tab := h.listener.SvcRoute(sid, pid, "tcp", src, host) - log.D("svchttp: routeConnect: tab(%v) id(%s) p(%s) src(%s) dst(%s)", tab, h.id, pid, src, dst) - if tab.Block { - return tx.RejectConnect, host - } - ctx.UserData = serverSummary(h.Type(), sid, pid, tab.CID) - hijackact := &tx.ConnectAction{Action: tx.ConnectHijack, Hijack: h.hijackConnect} - return hijackact, host -} - -// from: https://github.com/elazarl/goproxy/blob/2592e75ae0/https.go#L126-L154 -func (h *httpx) hijackConnect(req *http.Request, client net.Conn, ctx *tx.ProxyCtx) { - ssu, _ := ctx.UserData.(*ServerSummary) - host := req.Host - addr, port, err := net.SplitHostPort(req.Host) - if err != nil { - log.W("svchttp: hijackConnect: host(%s) not valid addr/port err %v", host, err) - } else if len(port) <= 0 { - host = net.JoinHostPort(addr, "80") - } - target, err := h.Tr.DialContext(context.Background(), "tcp", host) - if err != nil { - http502(client, err, ssu) - return - } - log.D("Accepting CONNECT to %s; cid: %s", host, ssu.CID) - n, err := client.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n")) - if err != nil { - log.W("svchttp: hijackConnect: failed client write (%d); err %v", n, err) - http502(client, err, ssu) - return - } - - go func() { - wg := &sync.WaitGroup{} - wg.Add(2) - - dst, ok1 := target.(*net.TCPConn) - src, ok2 := client.(*net.TCPConn) - if ok1 && ok2 { - go pipetcp(dst, src, ssu, wg) - go pipetcp(src, dst, ssu, wg) - wg.Wait() - } else { - go pipeconn(target, client, ssu, wg) - go pipeconn(client, target, ssu, wg) - wg.Wait() - clos(client, target) - } - h.listener.OnSvcComplete(ssu.ServerSummary) - }() -} - -func clos(cs ...io.Closer) { - core.Close(cs...) -} - -func http502(w io.WriteCloser, err1 error, ssu *ServerSummary) { - _, err2 := io.WriteString(w, "HTTP/1.1 502 Bad Gateway\r\n\r\n") - err3 := w.Close() - if ssu != nil { - ssu.done(err1, err2, err3) - } - log.D("svchttp: http502: done http-connect; errs? %v", core.JoinErr(err1, err2, err3)) -} - -func pipeconn(dst net.Conn, src net.Conn, ssu *ServerSummary, wg *sync.WaitGroup) { - var err error - defer wg.Done() - defer ssu.done(err) // done handles nil ssu - - _, err = core.Pipe(dst, src) - log.D("svchttp: pipeconn: done; err src(%s) -> dst(%s); err? %v", src.RemoteAddr(), dst.RemoteAddr(), err) -} - -func pipetcp(dst, src *net.TCPConn, ssu *ServerSummary, wg *sync.WaitGroup) { - _, err1 := core.Pipe(dst, src) - log.D("svchttp: pipetcp: done; src (%s) -> dst(%s); err? %v", src.RemoteAddr(), dst.RemoteAddr(), err1) - err2 := dst.CloseWrite() - err3 := src.CloseRead() - if ssu != nil { - ssu.done(err1, err2, err3) - } - wg.Done() -} - -func (h *httpx) Hop(p x.Proxy) error { - if h.status.Load() == END { - log.D("svchttp: hop: %s not running", h.ID()) - return errServerEnd - } - - dialer := h.dialer.Dial - if p == nil || core.IsNil(p) { - h.hdl.px.Store(nil) // clear - // h.ProxyHttpServer.Tr.DialContext = h.dialer.DialContext - } else if pp, ok := p.(ipn.Proxy); ok { - h.hdl.px.Store(pp) - dialer = pp.Dialer().Dial - } else { - log.E("svchttp: hop: %s; failed: %T not ipn.Proxy", h.ID(), p) - return errNotProxy - } - - log.D("svchttp: hop: %s over proxy? %t via %s", h.ID(), p != nil, h.GetAddr()) - - h.swap(dialer) - return nil -} - -func (h *httpx) swap(f dialFn) { - h.Lock() - defer h.Unlock() - // todo: reads are not synchronized! - h.ProxyHttpServer.Tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { - return f(network, addr) - } -} - -func (h *httpx) Start() error { - if h.status.Load() == END { - return errSvcRunning - } - h.status.Store(SOK) - go func() { - if h.usetls { - h.status.Store(END) - log.E("svchttp: %s cannot start; tls is unimplemented", h.ID()) - return - } - err := h.svc.ListenAndServe() - log.I("svchttp: %s exited; err? %v", h.ID(), err) - h.status.Store(END) - }() - log.I("svchttp: %s started %s", h.ID(), h.GetAddr()) - return nil -} - -func (h *httpx) Stop() error { - err := h.svc.Close() - // err := h.svc.Shutdown(context.Background()) - h.status.Store(END) - log.I("svchttp: %s stopped; err? %v", h.ID(), err) - return err -} - -func (h *httpx) Refresh() error { - err1 := h.Stop() - time.Sleep(3 * time.Second) // arbitrary wait - err2 := h.Start() - - log.I("svchttp: %s refreshed; errs? %v; %v", h.ID(), err1, err2) - - if err2 != nil { - return err2 - } - return err1 -} - -func (h *httpx) pid() (x string) { - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - x = px.ID() - } - return -} - -func (h *httpx) ID() string { - return h.id -} - -func (h *httpx) GetAddr() string { - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - return px.GetAddr() - } - return h.host -} - -func (h *httpx) Status() int { - return h.status.Load() -} - -func (h *httpx) Type() string { - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - return PXHTTP // proxied - } - return SVCHTTP // direct -} diff --git a/intra/rnet/listener.go b/intra/rnet/listener.go deleted file mode 100644 index 3815aafe..00000000 --- a/intra/rnet/listener.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rnet - -import ( - "errors" - "fmt" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" -) - -var errNop = errors.New("no error") - -type ServerSummary struct { - *x.ServerSummary - start time.Time // Tracks start time; unexported. -} - -func (s *ServerSummary) done(errs ...error) { - if s == nil { - return - } - - s.Duration = time.Since(s.start).Milliseconds() - - err := core.JoinErr(errs...) // errs may be nil - if err != nil { - if s.Msg == errNop.Error() { - s.Msg = err.Error() - } else { - s.Msg = s.Msg + "; " + err.Error() - } - } - if len(s.Msg) <= 0 { - s.Msg = errNop.Error() - } -} - -func (s *ServerSummary) String() string { - if s == nil { - return "" - } - return fmt.Sprintf("type: %s, sid: %s, pid: %s, cid: %s, upload: %d, download: %d, duration: %d, msg: %s", - s.Type, s.SID, s.PID, s.CID, s.Tx, s.Rx, s.Duration, s.Msg) -} - -func serverSummary(typ, sid, pid, cid string) *ServerSummary { - return &ServerSummary{ - ServerSummary: &x.ServerSummary{ - Type: typ, - SID: sid, - PID: pid, - CID: cid, - Msg: errNop.Error(), - }, - start: time.Now(), - } -} diff --git a/intra/rnet/servers.go b/intra/rnet/servers.go deleted file mode 100644 index 1e3f9a04..00000000 --- a/intra/rnet/servers.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rnet - -import ( - "context" - "errors" - "fmt" - "sync" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" -) - -const ( - // type of services - SVCSOCKS5 = x.SVCSOCKS5 - SVCHTTP = x.SVCHTTP - PXSOCKS5 = x.PXSOCKS5 - PXHTTP = x.PXHTTP - - // status of proxies - SUP = x.SUP - SOK = x.SOK - SKO = x.SKO - END = x.SOP -) - -var ( - errNoServer = errors.New("svc: no such server") - errSvcRunning = errors.New("svc: service is running") - errNotUdp = errors.New("svc: not udp conn") - errNotTcp = errors.New("svc: not tcp conn") - errNoAddr = errors.New("svc: no address") - errServerEnd = errors.New("svc: server stopped") - errProxyEnd = errors.New("svc: proxy stopped") - errProxyPaused = errors.New("svc: proxy paused") - errNotProxy = errors.New("svc: not a proxy") - errBlocked = errors.New("svc: blocked") - - udptimeoutsec = 5 * 60 // 5m - tcptimeoutsec = (2 * 60 * 60) + (40 * 60) // 2h40m -) - -// todo: github.com/txthinking/brook/blob/master/pac.go - -type Server x.Server - -type Services x.Services - -type ServerListener x.ServerListener - -var _ Services = (*services)(nil) -var _ Server = (*httpx)(nil) -var _ Server = (*socks5)(nil) - -type services struct { - sync.RWMutex - servers map[string]Server - proxies ipn.Proxies - listener ServerListener - ctl protect.Controller -} - -func NewServices(pctx context.Context, proxies ipn.Proxies, ctl protect.Controller, listener ServerListener) *services { - if listener == nil || ctl == nil { - return nil - } - svc := &services{ - servers: make(map[string]Server), - ctl: ctl, - proxies: proxies, - listener: listener, - } - context.AfterFunc(pctx, svc.stopServers) - return svc -} - -func (s *services) AddServer(id, url string) (svc x.Server, err error) { - s.RemoveServer(id) - - switch id { - case SVCSOCKS5, PXSOCKS5: - svc, err = newSocks5Server(id, url, s.ctl, s.listener) - case SVCHTTP, PXHTTP: - svc, err = newHttpServer(id, url, s.ctl, s.listener) - default: - return nil, errors.ErrUnsupported - } - - if err != nil { - return nil, err - } - - s.Lock() - s.servers[id] = svc - s.Unlock() - - // if the server has a namesake proxy, bridge them - err = s.Bridge(id, id) - - log.I("svc: add: %s > %s; err? %v", id, url, err) - - return svc, err -} - -func (s *services) Bridge(serverid, proxyid string) (err error) { - svc, err := s.GetServer(serverid) - - if err != nil { - log.W("svc: bridge: no server %s; err? %v", serverid, err) - return - } - // remove existing bridge, if any - if len(proxyid) <= 0 { - err = svc.Hop(nil) - log.I("svc: bridge: remove all hops for %s; err? %v", serverid, err) - return - } - - px, err := s.proxies.ProxyFor(proxyid) - if err != nil { - log.W("svc: bridge: no proxy %s for %s; err? %v", proxyid, serverid, err) - return - } - - svcstr := fmt.Sprintf("%s/%s [%d] at %s", serverid, svc.Type(), svc.Status(), svc.GetAddr()) - pxstr := fmt.Sprintf("%s/%s [%d] at %s", proxyid, px.Type(), px.Status(), px.GetAddr()) - - err = svc.Hop(px) - - log.I("svc: bridge: %s with %s; hop err? %v", svcstr, pxstr, err) - - return -} - -func (s *services) RemoveServer(id string) bool { - if svc, err := s.GetServer(id); err == nil { - _ = svc.Stop() - delete(s.servers, id) - return true - } - return false -} - -func (s *services) GetServer(id string) (x.Server, error) { - s.RLock() - defer s.RUnlock() - - if svc, ok := s.servers[id]; ok { - return svc, nil - } - return nil, errNoServer -} - -func (s *services) stopServers() { - s.Lock() - defer s.Unlock() - - n := len(s.servers) - for _, svc := range s.servers { - _ = svc.Stop() - } - log.I("svc: stopped servers: %d", n) -} - -func (s *services) RefreshServers() string { - s.Lock() - defer s.Unlock() - - var csv string - for _, svc := range s.servers { - sid := svc.ID() - if err := svc.Refresh(); err != nil { - log.W("svc: refresh %s; err: %v", sid, err) - continue - } - if csv == "" { - csv = sid - } else { - csv += "," + sid - } - } - return csv -} - -func (s *services) RemoveAll() { - s.stopServers() - - s.Lock() - clear(s.servers) - s.Unlock() -} diff --git a/intra/rnet/socks5.go b/intra/rnet/socks5.go deleted file mode 100644 index 882dc79b..00000000 --- a/intra/rnet/socks5.go +++ /dev/null @@ -1,589 +0,0 @@ -// Copyright (c) 2023 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rnet - -import ( - "context" - "fmt" - "io" - "net" - "net/url" - "sync" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - tx "github.com/txthinking/socks5" -) - -var _ tx.Handler = (*socks5)(nil) - -type socks5 struct { - *tx.Server - sync.Mutex // protects tx.Dial - - id string - url string - outbound *protect.RDial - hdl *socks5handler - listener ServerListener - - smu sync.RWMutex // protects summaries - summaries map[*tx.UDPExchange]*ServerSummary - - done context.CancelFunc - - // mutable fields below - - status *core.Volatile[int] // SOK, SKO, END -} - -type socks5handler struct { - *tx.DefaultHandle - px *core.Volatile[ipn.Proxy] -} - -// newSocks5Server creates a new socks5 server with the given id, url, controller, and listener. -// It should not be used if ipn/socks5 is also active. -func newSocks5Server(id, x string, ctl protect.Controller, listener ServerListener) (*socks5, error) { - var host string - var usr string - var pwd string - - u, err := url.Parse(x) - if err != nil { - return nil, err - } - host = u.Host // host - if u.User != nil { // usr, pwd - usr = u.User.Username() // may be empty - pwd, _ = u.User.Password() // may be empty - } - - ctx, done := context.WithCancel(context.Background()) - dialer := protect.MakeNsRDial(id, ctx, ctl) - // tx.DialTCP and tx.DialUDP may already been set by ipn.sock5 - tx.DialTCP = func(n string, _, d string) (net.Conn, error) { - return dialer.Dial(n, d) - } - // todo: support connecting from src - tx.DialUDP = func(n string, _, d string) (net.Conn, error) { - return dialer.Dial(n, d) - } - - // unused in our case; usage: github.com/txthinking/brook/issues/988 - remoteip := "" - hdl := &socks5handler{ - DefaultHandle: &tx.DefaultHandle{}, // not used; see dial, TCPHandle, and UDPHandle - px: core.NewZeroVolatile[ipn.Proxy](), - } - server, _ := tx.NewClassicServer(host, remoteip, usr, pwd, tcptimeoutsec, udptimeoutsec) - - hasauth := len(usr) > 0 || len(pwd) > 0 - log.I("svcsocks5: new %s listening at %s; auth?", id, host, hasauth) - return &socks5{ - Server: server, - id: id, - url: host, - outbound: dialer, - hdl: hdl, - listener: listener, - summaries: make(map[*tx.UDPExchange]*ServerSummary), - status: core.NewVolatile(SOK), - done: done, - }, nil -} - -func (h *socks5) Hop(p x.Proxy) error { - if h.status.Load() == END { - log.D("svcsocks5: hop: %s not running", h.ID()) - return errServerEnd - } - - var dialer protect.RDialer = h.outbound - if p == nil || core.IsNil(p) { - h.hdl.px.Store(nil) // clear - // dialer = h.rdial - } else if pp, ok := p.(ipn.Proxy); ok { - h.hdl.px.Store(pp) - dialer = pp.Dialer() - } else { - log.E("svcsocks5: hop: %s; failed: %T not ipn.Proxy", h.ID(), p) - return errNotProxy - } - log.D("svcsocks5: hop: %s over proxy? %t via %s", h.ID(), p != nil, h.GetAddr()) - - h.swap(dialer) - return nil -} - -func (h *socks5) swap(rd protect.RDialer) { - h.Lock() - defer h.Unlock() - // todo: var tx.DialTCP/tx.DialUDP (reads) not synchronized - // tx.DialTCP and tx.DialUDP may already been set by ipn.sock5 - tx.DialTCP = func(n string, _, d string) (net.Conn, error) { - return rd.Dial(n, d) - } - // todo: support connecting from src - tx.DialUDP = func(n string, _, d string) (net.Conn, error) { - return rd.Dial(n, d) - } -} - -func (h *socks5) Start() error { - if h.status.Load() != END { - return errSvcRunning - } - h.status.Store(SOK) - go func() { - err := h.Server.ListenAndServe(h) - log.I("svcsocks5: %s exited; err? %v", h.ID(), err) - h.status.Store(END) - }() - log.I("svcsocks5: %s started %s", h.ID(), h.GetAddr()) - return nil -} - -func (h *socks5) Stop() error { - err := h.Server.Shutdown() - h.status.Store(END) - h.done() - log.I("svcsocks5: %s stopped; err? %v", h.ID(), err) - return err -} - -func (h *socks5) Refresh() error { - err1 := h.Stop() - time.Sleep(3 * time.Second) // arbitrary wait - err2 := h.Start() - - log.I("svcsocks5: %s refreshed; errs? %v; %v", h.ID(), err1, err2) - - if err2 != nil { - return err2 - } - return err1 -} - -func (h *socks5) ID() string { - return h.id -} - -func (h *socks5) GetAddr() string { - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - return px.GetAddr() - } - return h.url -} - -func (h *socks5) Status() int { - return h.status.Load() -} - -func (h *socks5) Type() string { - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - return PXSOCKS5 // proxied - } - return SVCSOCKS5 // direct -} - -// Implements tx.Handler -func (h *socks5) TCPHandle(server *tx.Server, ingress *net.TCPConn, req *tx.Request) error { - if err := h.candial(); err == nil { - return h.tcphandle(server, ingress, req) - } else { - return err - } -} - -// Implement tx.Handler -func (h *socks5) UDPHandle(server *tx.Server, ingress *net.UDPAddr, pkt *tx.Datagram) error { - if err := h.candial(); err == nil { - return h.udphandle(server, ingress, pkt) - } else { - return err - } -} - -func (h *socks5) dial(network, src, dst string) (cid string, conn net.Conn, err error) { - if err = h.candial(); err != nil { - return - } - tab := h.route(network, src, dst) - if tab.Block { - err = errBlocked - return - } - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - conn, err = px.Dialer().Dial(network, dst) - } else { - conn, err = h.outbound.Dial(network, dst) - } - return tab.CID, conn, err -} - -func (h *socks5) pid() (x string) { - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - x = px.ID() - } - return -} - -func (h *socks5) route(network, src, dst string) *x.Tab { - return h.listener.SvcRoute(h.id, h.pid(), network, src, dst) -} - -func (h *socks5) candial() error { - if h.Status() != END { - return errProxyEnd // no - } - if px := h.hdl.px.Load(); px != nil && core.IsNotNil(px) { - st := px.Status() - if st == ipn.END { - return errProxyEnd // no - } else if st == ipn.TPU { - return errProxyPaused // no - } // fallthrough - } - return nil // yes -} - -func (h *socks5) setDeadline(c net.Conn, secs int) error { - if secs == 0 { // no op - return nil - } - ttl := time.Duration(secs) * time.Second - return c.SetDeadline(time.Now().Add(ttl)) -} - -type pipefin struct { - ex int64 // bytes exchanged - err error // error, if any -} - -func (h *socks5) pipe(r, w net.Conn, finch chan<- pipefin) { - bptr := core.Alloc16() - bf := *bptr - bf = bf[:cap(bf)] - defer func() { - *bptr = bf - core.Recycle(bptr) - }() - ex := int64(0) - laddr := r.LocalAddr() - raddr := w.RemoteAddr() - for { - if err := h.setDeadline(r, tcptimeoutsec); err != nil { - finch <- pipefin{ex, err} - break - } - n, err := r.Read(bf[:]) - ex += int64(n) - if err != nil { - log.E("svcsocks5: tcp: %s; read %s; err: %v", h.ID(), laddr, err) - finch <- pipefin{ex, err} - break - } - if _, err := w.Write(bf[0:n]); err != nil { - log.E("svcsocks5: tcp: %s; write %s; err: %v", h.ID(), raddr, err) - finch <- pipefin{ex, err} - break - } - log.V("svcsocks5: tcp: %s; %s -> %s; %d bytes", h.ID(), laddr, raddr, n) - } -} - -// Adopted from tx.DefaultHandle with the only changes are -// 1. ipn.Proxy as the dialer -// 2. buffers are allocated from core.Alloc() -func (h *socks5) tcphandle(s *tx.Server, ingress *net.TCPConn, r *tx.Request) (err error) { - if r.Cmd == tx.CmdConnect { - var cid string - var egress *net.TCPConn - cid, egress, err = h.Connect(r, ingress) - ssu := serverSummary(h.Type(), h.ID(), h.pid(), cid) - defer func() { - ssu.done(err) - go h.listener.OnSvcComplete(ssu.ServerSummary) - }() - - log.D("svcsocks5: proxy-tcp: %s; socks5-connect %s", cid, r.Address()) - - if err != nil { - h.status.Store(SKO) - log.E("svcsocks5: proxy-tcp: %s; connect %s; err: %v", cid, r.Address(), err) - return err - } - // c is closed by the caller - defer clos(egress) - - finrxch := make(chan pipefin, 1) - fintxch := make(chan pipefin, 1) - go h.pipe(egress, ingress, finrxch) // read from egress, write to ingress - go h.pipe(ingress, egress, fintxch) // read from ingress, write to egress - finrx := <-finrxch - fintx := <-fintxch - - err = core.JoinErr(finrx.err, fintx.err) - - ssu.Rx = finrx.ex - ssu.Tx = fintx.ex - - return err - } - if r.Cmd == tx.CmdUDP { - log.D("svcsocks5: proxy-tcp via udp: %s; socks5-tcp-udp %s", h.ID(), r.Address()) - caddr, err := r.UDP(ingress, s.ServerAddr) - if err != nil { - h.status.Store(SKO) - return err - } - - ch := make(chan byte) - defer close(ch) - - s.AssociatedUDP.Set(caddr.String(), ch, -1) - defer s.AssociatedUDP.Delete(caddr.String()) - - n, err := core.Pipe(io.Discard, ingress) - - log.D("svcsocks: tcp: %s tcp that udp %s associated closed %d; err? %v", h.ID(), caddr, n, err) - return nil - } - return tx.ErrUnsupportCmd -} - -func (h *socks5) udphandle(s *tx.Server, addr *net.UDPAddr, pkt *tx.Datagram) (err error) { - cid := h.ID() // init connection id to server id, for logging purposes - src := addr.String() - var ch chan byte - - if s.LimitUDP { // always false, for now - any, ok := s.AssociatedUDP.Get(src) - if !ok { - return fmt.Errorf("udp addr %s not associated with tcp", src) - } - ch, ok = any.(chan byte) - if !ok { - return fmt.Errorf("udp addr %s not associated with tcp; ch missing", src) - } - } - - send := func(egress *tx.UDPExchange, data []byte) error { - ueladdr := egress.RemoteConn.LocalAddr() - ueraddr := egress.RemoteConn.RemoteAddr() - uecaddr := egress.ClientAddr - - ssu := h.getSummary(egress) - - select { - case _, ok := <-ch: - return fmt.Errorf("udp addr %s not associated with tcp; ch ok? %t", src, ok) - default: - // writing to egress conn - n, werr := egress.RemoteConn.Write(data) - if ssu != nil { - ssu.Tx += int64(n) - } - log.D("svcsocks5: udp: %s; data sent; (err: %v / summary? %t)? client: %s server: %s remote: %s sz: %d", cid, werr, ssu != nil, uecaddr, ueladdr, ueraddr, n) - if werr != nil { - return werr - } - } - return nil - } - - dst := pkt.Address() - tuple4 := src + dst - var egress *tx.UDPExchange - iue, ok := s.UDPExchanges.Get(tuple4) - if ok { - if egress, ok = iue.(*tx.UDPExchange); ok { - return send(egress, pkt.Data) - } - } - - ssu := serverSummary(h.Type(), h.ID(), h.pid(), cid) - defer func() { - ssu.done(err) - go h.listener.OnSvcComplete(ssu.ServerSummary) - }() - - log.D("svcsocks5: udp: %s; dst %s", cid, dst) - cid, uc, err := h.dial("udp", src, dst) - if err != nil { - return err - } - - rc, ok := uc.(*net.UDPConn) - if !ok { - return errNotUdp - } - - egress = &tx.UDPExchange{ - ClientAddr: addr, // same as src - RemoteConn: rc, - } - - h.setSummary(egress, ssu) - - log.D("svcsocks5: udp: %s; remote conn for client: %s server: %s remote: %s", cid, addr, egress.RemoteConn.LocalAddr(), pkt.Address()) - if err := send(egress, pkt.Data); err != nil { - log.E("svcsocks5: udp: %s; send pkt %d to remote: %s; err %v", cid, len(pkt.Data), egress.RemoteConn.RemoteAddr(), err) - - h.delSummary(egress) - - clos(egress.RemoteConn) // TODO: clos(egress) instead? - return err - } - s.UDPExchanges.Set(src+dst, egress, -1) - - go func(ue *tx.UDPExchange, dst string) { - bptr := core.Alloc() - b := *bptr - b = b[:cap(b)] - defer func() { - h.delSummary(ue) - - clos(ue.RemoteConn) - s.UDPExchanges.Delete(src + dst) - - *bptr = b - core.Recycle(bptr) - }() - - ueladdr := ue.RemoteConn.LocalAddr() - ueraddr := ue.RemoteConn.RemoteAddr() - uecaddr := ue.ClientAddr - for { - select { - case _, ok = <-ch: - log.D("svcsocks5: udp: %s; tcp to udp addr %s associated closed; ch ok? %t", cid, uecaddr, ok) - return - default: - if err := h.setDeadline(ue.RemoteConn, s.UDPTimeout); err != nil { - return - } - // reading from egress - n, err := ue.RemoteConn.Read(b[:]) - if err != nil { - log.E("svcsocks5: udp: %s; read err: %v", cid, err) - return - } - ssu.Rx += int64(n) - log.D("svcsocks5: udp: %s; got data; client: %s server: %s remote: %s data: %d", cid, uecaddr, ueladdr, ueraddr, n) - a, addr, port, err := tx.ParseAddress(dst) - if err != nil { - log.E("svcsocks5: udp: %s; parse-addr err? %v", cid, err) - return - } - d1 := tx.NewDatagram(a, addr, port, b[:n]) - // writing to ingress - if _, err := s.UDPConn.WriteToUDP(d1.Bytes(), ue.ClientAddr); err != nil { - log.E("svcsocks5: udp: %s; write err: %v", cid, err) - return - } - log.V("svcsocks5: udp: %s; data sent; client: %s server: %s remote: %s data: %#v %#v %#v %#v %#v %#d datagram address: %s", cid, uecaddr, ueladdr, ueraddr, d1.Rsv, d1.Frag, d1.Atyp, d1.DstAddr, d1.DstPort, len(d1.Data), d1.Address()) - } - } - }(egress, dst) - return nil -} - -func (h *socks5) Connect(r *tx.Request, w *net.TCPConn) (cid string, rc *net.TCPConn, err error) { - log.D("svcsocks5: tcp: %s; dial", h.ID(), r.Address()) - raddr := w.RemoteAddr() - if raddr == nil { - log.W("svcsocks5: tcp: %s; err no remote addr", h.ID()) - h.status.Store(SKO) - err = errNoAddr - return - } - - var tc net.Conn // egress - cid, tc, err = h.dial("tcp", raddr.String(), r.Address()) - if err != nil { - h.status.Store(SKO) - - log.W("svcsocks5: tcp: %s; dial remote %s; err: %v", cid, r.Address(), err) - var p *tx.Reply - if r.Atyp == tx.ATYPIPv4 || r.Atyp == tx.ATYPDomain { - p = tx.NewReply(tx.RepHostUnreachable, tx.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) - } else { - p = tx.NewReply(tx.RepHostUnreachable, tx.ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) - } - if _, err = p.WriteTo(w); err != nil { - log.E("svcsocks5: tcp: %s; write-to remote %s; err: %v", cid, r.Address(), err) - return - } - return - } - - var ok bool - rc, ok = tc.(*net.TCPConn) - if !ok { - h.status.Store(SKO) - err = errNotTcp - return - } - laddr := rc.LocalAddr() - if laddr == nil { - log.W("svcsocks5: tcp: %s; err no local addr", cid, laddr) - h.status.Store(SKO) - err = errNoAddr - return - } - - a, addr, port, perr := tx.ParseAddress(laddr.String()) - if perr != nil { - log.W("svcsocks5: tcp: %s; parse-addr err? %v", cid, err) - var p *tx.Reply - if r.Atyp == tx.ATYPIPv4 || r.Atyp == tx.ATYPDomain { - p = tx.NewReply(tx.RepHostUnreachable, tx.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) - } else { - p = tx.NewReply(tx.RepHostUnreachable, tx.ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) - } - if _, err = p.WriteTo(w); err != nil { - log.E("svcsocks5: tcp: %s; write-to remote %s; err: %v", cid, r.Address(), err) - return - } - err = perr - return - } - - p := tx.NewReply(tx.RepSuccess, a, addr, port) - if _, err = p.WriteTo(w); err != nil { - log.E("svcsocks5: tcp: %s; write-to remote %s; err: %v", cid, r.Address(), err) - return - } - return -} - -func (h *socks5) getSummary(c *tx.UDPExchange) *ServerSummary { - h.smu.RLock() - defer h.smu.RUnlock() - - return h.summaries[c] -} - -func (h *socks5) setSummary(c *tx.UDPExchange, s *ServerSummary) { - h.smu.Lock() - defer h.smu.Unlock() - - h.summaries[c] = s -} - -func (h *socks5) delSummary(c *tx.UDPExchange) { - h.smu.Lock() - defer h.smu.Unlock() - - delete(h.summaries, c) -} diff --git a/intra/rwconn.go b/intra/rwconn.go deleted file mode 100644 index 812471de..00000000 --- a/intra/rwconn.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "io" - "net" - "syscall" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/settings" -) - -// rwext wraps MinConn and extends deadline to minimum(min, settings.DialerOpts) -// on every read and write. -type rwext struct { - net.Conn // underlying conn - minidle uint32 // min idle timeout in secs -} - -// TODO? var _ core.DuplexCloser = (*rwext)(nil) -var _ core.RetrierConn = (*rwext)(nil) -var _ core.ControlConn = (*rwext)(nil) - -func (rw rwext) SetTimeout() (secs int, didSet bool) { - r, w := rw.deadlines() - secs = max(int(r), int(w)) - if r > 0 { - // always returns false for udp conns - didSet = core.SetTimeoutSockOpt(rw.Unwrap(), secs*1000) - } - if !didSet { - if dx, ok := rw.Unwrap().(*demuxconn); ok { - // udp demuxconn: set on underlying conn - extendr(dx, time.Second*time.Duration(r)) - extendw(dx, time.Second*time.Duration(w)) - didSet = true - } - } - return -} - -func (rw rwext) Unwrap() net.Conn { - return rw.Conn -} - -func (rw rwext) Read(b []byte) (n int, err error) { - rw.extendr() - return rw.Conn.Read(b) -} - -func (rw rwext) Write(b []byte) (n int, err error) { - rw.extendw() - return rw.Conn.Write(b) -} - -// ReadFrom implements core.RetrierConn. -func (rw rwext) ReadFrom(r io.Reader) (n int64, err error) { - switch c := rw.Unwrap().(type) { - case io.ReaderFrom: - // disable read and write deadlines for rw.Conn as - // io.ReaderFrom does not support io.Reader+io.Writer - // semantics which rwext relies on to extend deadlines. - rw.extendForever() - return c.ReadFrom(r) - default: - } - // nb: stream rw (which extends deadlines) not rw.Conn - return core.Stream(rw, r) -} - -// WriteTo implements core.RetrierConn. -func (rw rwext) WriteTo(w io.Writer) (n int64, err error) { - switch c := rw.Unwrap().(type) { - case io.WriterTo: - // disable read and write deadlines for rw.Conn as - // io.WriterTo does not support io.Reader+io.Writer - // semantics which rwext relies on to extend deadlines. - rw.extendForever() - return c.WriteTo(w) - default: - } - // nb: stream rw (which extends deadlines) not rw.Conn - return core.Stream(w, rw) -} - -// SyscallConn implements core.ControlConn. -func (rw rwext) SyscallConn() (syscall.RawConn, error) { - if sc, ok := rw.Unwrap().(syscall.Conn); ok { - return sc.SyscallConn() - } - return nil, syscall.EINVAL -} - -func (rw rwext) deadlines() (r, w uint32) { - dopt := settings.GetDialerOpts() - // -ve ints go higher than 2^31 w/ uint: go.dev/play/p/Rrqk_V8a7W0 - return max(rw.minidle, uint32(dopt.ReadTimeoutSec)), - max(rw.minidle, uint32(dopt.WriteTimeoutSec)) -} - -func (rw rwext) extendForever() { - extendc(rw, 0, 0) -} - -func (rw rwext) extendw() { - _, w := rw.deadlines() - tw := time.Second * time.Duration(w) - - extendw(rw.Conn, tw) -} - -func (rw rwext) extendr() { - r, _ := rw.deadlines() - tr := time.Second * time.Duration(r) - - extendr(rw.Conn, tr) -} diff --git a/intra/settings/config.go b/intra/settings/config.go deleted file mode 100644 index 1015a100..00000000 --- a/intra/settings/config.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package settings - -import ( - "sync/atomic" - - "github.com/celzero/firestack/intra/core" -) - -// NICID is the default network interface card ID for the network stack. -const NICID = 0x01 - -// Debug is a global flag to enable debug behaviour. -var Debug bool = false - -// Loopingback is a global flag to adjust netstack behaviour -// wrt preventing split dialing, closing tunfd without delay etc. -var Loopingback = atomic.Bool{} - -// SingleThreaded is a global flag to run Netstack's packet forwarder -// in a single-threaded mode. -var SingleThreaded = atomic.Bool{} - -// PortForward is a global flag to enable bound to the same port -// for the outgoing conn as the incoming sockisfied conn. -var PortForward = atomic.Bool{} - -// HappyEyeballs is a global flag to enable Happy Eyeballs algorithm -// for dual-stack (IPv4+IPv6) connections. -var HappyEyeballs = atomic.Bool{} - -// ExperimentalWireGuard is a global flag to enable experimental -// settings for WireGuard. -var ExperimentalWireGuard = core.NewForeverFlow(false) - -// FloodWireGuard is a global flag to enable flooding WireGuard -// tunnel with randomly sized non-null packets. -var FloodWireGuard = atomic.Bool{} - -// EndpointIndependentMapping is a global flag to enable endpoint-independent -// mapping for UDP as per RFC 4787. -var EndpointIndependentMapping = atomic.Bool{} - -// EndpointIndependentFiltering is a global flag to enable endpoint-independent -// filtering for UDP as per RFC 4787. -var EndpointIndependentFiltering = atomic.Bool{} - -// SystemDNSForUndelegatedDomains is a global flag to always use System DNS -// for undelegated domains. -var SystemDNSForUndelegatedDomains = atomic.Bool{} - -// DefaultDNSAsFallback is a global flag to allow using the Default transport -// as a fallback when the Preferred transport is missing or paused or ended. -var DefaultDNSAsFallback = atomic.Bool{} - -// SetUserAgent is a global flag to set User-Agent for DoH requests -// to "Intra" and for HTTP "Reaches" checks to the Android default. -var SetUserAgent = atomic.Bool{} - -// Android's default user-agent as set for connectivity checks -// PROBE_HTTPS https://www.google.com/generate_204 time=183ms ret=204 request={Connection=[close], User-Agent=[Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.32 Safari/537.36]} -// headers={null=[HTTP/1.1 204 No Content], Alt-Svc=[h3=":443"; ma=2592000,h3-29=":443"; ma=2592000], Connection=[close], Content-Length=[0], Cross-Origin-Resource-Policy=[cross-origin], -// Date=[Fri, 27 Jun 2025 10:56:24 GMT], X-Android-Received-Millis=[1751021784573], X-Android-Response-Source=[NETWORK 204], X-Android-Selected-Protocol=[http/1.1], X-Android-Sent-Millis=[1751021784495]} -// also: android-developers.googleblog.com/2024/12/user-agent-reduction-on-android-webview.html -const AndroidCcUa = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.32 Safari/537.36" -const IntraUa = "Intra" - -// PanicAtRandom is a global flag to panic the network engine -// every once in a while (for testing). -var PanicAtRandom = atomic.Bool{} - -// OwnTunFd is a global flag to indicate that the TUN fd is fully owned by netstack. -// that is, he TUN FD won't be dup'd and will be closed after use. -var OwnTunFd = atomic.Bool{} diff --git a/intra/settings/dialeropts.go b/intra/settings/dialeropts.go deleted file mode 100644 index 58f6c933..00000000 --- a/intra/settings/dialeropts.go +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package settings - -import ( - "strconv" - "strings" -) - -// DialerOpts define dialer options. -type DialerOpts struct { - // Strat is the dialing strategy. - Strat int32 - // Retry is the retry strategy. - Retry int32 - // LowerKeepAlive is the flag to enable low TCP keep-alive. - // Currently, 600s for idle, 5s for interval, and 4 probes. - LowerKeepAlive bool - // Read timeout for outgoing tcp & udp connections. - ReadTimeoutSec int32 - // Write timeout for outgoing tcp & udp connections. - WriteTimeoutSec int32 -} - -func (d DialerOpts) String() string { - s := func() string { - switch d.Strat { - case SplitAuto: - return "SplitAuto" - case SplitTCP: - return "SplitTCP" - case SplitTCPOrTLS: - return "SplitTCPOrTLS" - case SplitDesync: - return "SplitDesync" - case SplitNever: - return "SplitNever" - default: - return "Unknown" - } - }() - r := func() string { - switch d.Retry { - case RetryNever: - return "RetryNever" - case RetryWithSplit: - return "RetryWithSplit" - case RetryAfterSplit: - return "RetryAfterSplit" - default: - return "Unknown" - } - }() - ka := func() string { - if d.LowerKeepAlive { - return "LowerKeepAlive" - } - return "DefaultKeepAlive" - }() - tmo := func() string { - return strconv.Itoa(int(d.ReadTimeoutSec)) + - "s," + strconv.Itoa(int(d.WriteTimeoutSec)) + - "s" - }() - - return strings.Join([]string{s, r, ka, tmo}, ",") -} - -// Dial strategies -const ( - // SplitAuto is the default dial strategy; chosen by the engine. - SplitAuto int32 = iota - // SplitTCPOrTLS splits first TCP segment or fragments the TLS SNI header. - SplitTCPOrTLS - // SplitTCP splits the first TCP segment. - SplitTCP - // SplitDesync splits the first TCP segment after desynchronizing the connection - // by sending a different, but fixed, first TCP segement to the censor. - SplitDesync - // SplitNever doesn't muck; connects as-is. - SplitNever -) - -// Retry strategies -const ( - // RetryAfterSplit retries connection as-is after split fails. - RetryAfterSplit int32 = iota - // RetryWithSplit ("auto" mode) connects as-is, but retries with split. - RetryWithSplit - // RetryNever never retries. - RetryNever -) - -var dialerOpts = &DialerOpts{} - -// SetDialerOpts sets the dialer options to use. -func SetDialerOpts(strat, retry, timeoutsec int32, keepalive bool) bool { - s := dialerOpts - ok := true - switch strat { - case SplitTCP, SplitTCPOrTLS, SplitDesync, SplitAuto, SplitNever: - s.Strat = strat - default: - s.Strat = SplitAuto - ok = false - } - switch retry { - case RetryNever, RetryWithSplit, RetryAfterSplit: - s.Retry = retry - default: - s.Retry = RetryAfterSplit - ok = false - } - s.LowerKeepAlive = keepalive - if timeoutsec < 0 { - timeoutsec = 0 - } - s.ReadTimeoutSec = timeoutsec - s.WriteTimeoutSec = timeoutsec - return ok -} - -// GetDialerOpts returns current dialer options. -func GetDialerOpts() DialerOpts { - return *dialerOpts -} diff --git a/intra/settings/dnsopts.go b/intra/settings/dnsopts.go deleted file mode 100644 index 2de8dcf3..00000000 --- a/intra/settings/dnsopts.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package settings - -import ( - "errors" - "net" - "net/netip" - "strconv" - "strings" - "sync/atomic" - - "github.com/celzero/firestack/intra/log" -) - -var errDnsOptArg = errors.New("dnsopt: invalid arg") - -// DNSOptions define https or socks5 proxy options -type DNSOptions struct { - idOrHostportOrIpport string // host:port or ip:port; may be empty - hostips string // ips only, comma separated; may be empty - port uint16 // port only; 53 if not set -} - -func (d *DNSOptions) String() string { - if d == nil { - return "" - } - return d.AddrPort() -} - -// AddrPort returns the ip:port or host:port; may return empty string. -func (d *DNSOptions) AddrPort() string { - if len(d.idOrHostportOrIpport) > 0 { - return d.idOrHostportOrIpport - } - return "" -} - -func (d *DNSOptions) Port() uint16 { - return d.port -} - -func (d *DNSOptions) ResolvedAddrs() string { - return d.hostips // TODO: may be ip:port -} - -// NewDNSOptions returns a new DNSOpitons object. -func NewDNSOptions(ipport string) (*DNSOptions, error) { - var ipp netip.AddrPort - var err error - ip, port, err := net.SplitHostPort(ipport) - if err != nil { - return nil, err - } - if ipp, err = addrport(ip, port); err == nil { - return &DNSOptions{ - idOrHostportOrIpport: ipp.String(), - port: ipp.Port(), - }, nil - } - log.D("dnsopt(%s:%s); err(%v)", ip, port, err) - return nil, err -} - -func NewDNSOptionsFromNetIp(ipp netip.AddrPort) (*DNSOptions, error) { - if !ipp.IsValid() { - return nil, errors.New("dnsopt: empty ipport") - } - return &DNSOptions{ - idOrHostportOrIpport: ipp.String(), - port: ipp.Port(), - }, nil -} - -func NewDNSOptionsFromHostname(idOrHostOrHostPort, ipOrIPPortCsv string) (*DNSOptions, error) { - if len(idOrHostOrHostPort) <= 0 { - return nil, errDnsOptArg - } - - idOrDomain, port, _ := net.SplitHostPort(idOrHostOrHostPort) - - if len(idOrDomain) <= 0 { - idOrDomain = idOrHostOrHostPort - } - - portFromHostPort := len(port) > 0 - portu16 := uint16(53) - if portFromHostPort { - if u64, _ := strconv.ParseUint(port, 10, 16); u64 > 0 { - portu16 = uint16(u64) - } else { - port = "53" - portFromHostPort = false // as if len(port) == 0 - } - } - - ips := make([]string, 0) - ports := make([]uint16, 0) - for ipp := range strings.SplitSeq(ipOrIPPortCsv, ",") { - if addr, err := netip.ParseAddrPort(ipp); err == nil { - ips = append(ips, addr.Addr().String()) - if port := addr.Port(); port > 0 { - ports = append(ports, port) - } - } else if addr, err := netip.ParseAddr(ipp); err == nil { - ips = append(ips, addr.String()) - } else { - log.W("dnsopt: invalid ip/ipport for %s; ipp(%s); err(%v)", idOrHostOrHostPort, ipp, err) - } - } - - portFromIPPort := len(ports) > 0 - if portFromHostPort { - // skip other checks - } else if portFromIPPort { - // TODO: support multiple ports? - port = strconv.Itoa(int(ports[0])) - portu16 = ports[0] - } else { - // default port - port = "53" - portu16 = 53 - } - - log.I("dnsopt: for %s; len(ips) = %d; port = %s; portFromHostPort? %t; portFromIPPort? %t", - idOrHostOrHostPort, len(ips), port, portFromHostPort, portFromIPPort) - return &DNSOptions{ - idOrHostportOrIpport: net.JoinHostPort(idOrDomain, port), - hostips: strings.Join(ips, ","), // may be empty - port: portu16, - }, nil -} - -// Parse ip and port; where ip can be either ip:port or ip -func addrport(ip string, port string) (ipp netip.AddrPort, err error) { - var ipaddr netip.Addr - var p int - if ipaddr, err = netip.ParseAddr(ip); err == nil { - if p, err = strconv.Atoi(port); err == nil { - ipp = netip.AddrPortFrom(ipaddr.Unmap(), uint16(p)) - return ipp, nil - } - } else if ipp, err = netip.ParseAddrPort(ip); err == nil { - return ipp, nil - } - return ipp, err -} - -const ( - // Use among encrypted dns transports wrapped by dnsx.Plus - PlusFilterSafest = iota - // Use dns transports randomly wrapped by dnsx.Plus - PlusOrderRandom - // Prefer faster (p50 latency) dns transports wrapped by dnsx.Plus - PlusOrderFastest - // Prefer working dns transports wrapped by dnsx.Plus - PlusOrderRobust -) - -var PlusStrat = atomic.Int32{} - -// SetPlusStrategy returns the order strategy for Plus DNS transports. -func SetPlusStrategy(new int) bool { - if new < PlusFilterSafest || new > PlusOrderRobust { - log.W("dnsopt: invalid plus order strategy %d", new) - return false - } - old := PlusStrat.Swap(int32(new)) - log.I("dnsopt: set plus order strategy to %d <= %d", new, old) - return true -} diff --git a/intra/settings/dsopts.go b/intra/settings/dsopts.go deleted file mode 100644 index 1400078c..00000000 --- a/intra/settings/dsopts.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package settings - -// msb to lsb: ipv6, ipv4, lwip(1) or netstack(0) -const ( - Ns4 = 0b010 // 2 - Ns46 = 0b110 // 6 - Ns6 = 0b100 // 4 -) - -// IP4, IP46, IP6 are string'd repr of Ns4, Ns46, Ns6 -const ( - IP4 = "4" - IP46 = "46" - IP6 = "6" -) - -// L3 returns the string'd repr of engine. -func L3(engine int) string { - switch engine { - case Ns46: - return IP46 - case Ns6: - return IP6 - default: - return IP4 - } -} diff --git a/intra/settings/proxyopts.go b/intra/settings/proxyopts.go deleted file mode 100644 index a18d8f62..00000000 --- a/intra/settings/proxyopts.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package settings - -import ( - "net" - "net/url" - "strings" - "sync/atomic" - - "github.com/celzero/firestack/intra/log" - "golang.org/x/net/proxy" -) - -// ProxyOptions define https or socks5 proxy options -type ProxyOptions struct { - Auth *proxy.Auth - IP string // just the ip - Host string // just the hostname (no port) - Port string // just the port number - IPPort string // may be a url or ip:port - Scheme string // http, https, socks5, pip - Addrs []string // list of ips if ipport is a url; may be nil -} - -// NewAuthProxyOptions returns a new ProxyOptions object with authentication object. -func NewAuthProxyOptions(scheme, username, password, ip, port string, addrs []string) *ProxyOptions { - var ippstr string - var ipstr string - var host string - ip = strings.TrimSuffix(ip, "/") - ipp, err := addrport(ip, port) - if err != nil { - log.I("proxyopt: scheme %s; ipport(%s:%s) is url?(%v)", scheme, ip, port, err) - if len(ip) > 0 { - // port is discarded, and expected to be in ip/url - ippstr = ip - host, port, _ = net.SplitHostPort(ip) - } else if len(port) > 0 { - // incoming ip,port is a wildcard address - ippstr = ":" + port - } else { - return nil - } - } else { - ippstr = ipp.String() - ipstr = ipp.Addr().String() - } - if len(username) <= 0 || len(password) <= 0 { - log.I("proxyopt: no user(%s) and/or pwd(%d)", username, len(password)) - } - if len(scheme) <= 0 { - scheme = "http" - } - // todo: query unescape username and password? - auth := proxy.Auth{ - User: username, - Password: password, - } - return &ProxyOptions{ - Auth: &auth, - Host: host, // may be empty or hostname (without port) - IP: ipstr, // may be empty or ipaddr - Port: port, // port number - IPPort: ippstr, // may be ip4:port, [ip::6]:port, host:port, or :port - Scheme: scheme, - Addrs: addrs, // may be empty - } -} - -// NewProxyOptions returns a new ProxyOptions object. -func NewProxyOptions(ip string, port string) *ProxyOptions { - return NewAuthProxyOptions("" /*scheme*/, "" /*user*/, "" /*password*/, ip, port /*addrs*/, nil) -} - -func (p *ProxyOptions) String() string { - if p == nil { - return "" - } - return p.Auth.User + "," + p.Auth.Password + "," + p.IPPort -} - -// HasAuth returns true if p has auth params. -func (p *ProxyOptions) HasAuth() bool { - return len(p.Auth.User) > 0 && len(p.Auth.Password) > 0 -} - -// FullUrl returns the full url with auth. -func (p *ProxyOptions) FullUrl() string { - if p.HasAuth() { - // superuser.com/a/532530 - usr := url.QueryEscape(p.Auth.User) - pwd := url.QueryEscape(p.Auth.Password) - return p.Scheme + "://" + usr + ":" + pwd + "@" + p.IPPort - } else if len(p.Auth.User) > 0 { - usr := url.QueryEscape(p.Auth.User) - return p.Scheme + "://" + usr + "@" + p.IPPort - } - return p.Url() -} - -// Url returns the url without auth. -func (p *ProxyOptions) Url() string { - return p.Scheme + "://" + p.IPPort -} - -// AutoMode is a global variable to instruct if backend.Auto proxy -// is in local, remote, or hybrid mode. In local mode, backend.Auto -// uses local proxies (ex: ipn.Exit) only. In remote mode, -// backend.Auto uses remote proxies (ex: RPN). -var AutoMode atomic.Int32 - -type AutoModeType int32 - -const ( - // local mode: backend.Auto uses local proxies (ex: ipn.Exit) only. - AutoModeLocal int32 = iota - // remote mode: backend.Auto uses remote proxies (ex: RPN) only. - AutoModeRemote - // hybrid mode: backend.Auto uses local and remote proxies. - AutoModeHybrid -) - -func (m AutoModeType) String() string { - switch int32(m) { - case AutoModeLocal: - return "local" - case AutoModeRemote: - return "remote" - case AutoModeHybrid: - return "hybrid" - default: - return "unknown" - } -} - -// SetAutoMode sets the global AutoMode variable to m. -// Indicates if backend.Auto proxy is in local, remote, or hybrid mode. -func SetAutoMode(m int32) (prev int32) { - m = max(m, AutoModeLocal) - m = min(m, AutoModeHybrid) - return AutoMode.Swap(m) -} - -func AutoModeStr() string { - return AutoModeType(AutoMode.Load()).String() -} - -// backend.Auto must use remote proxies and never use local (ex: ipn.Exit) ones. -func AutoAlwaysRemote() bool { - return AutoMode.Load() == AutoModeRemote -} - -// backend.Auto is effecively not active. -func AutoActive() bool { - return AutoMode.Load() != AutoModeLocal -} - -// AutoDialsParallel is a global variable to instruct ipn.Auto proxy -// to use parallel dialing for all proxies. -var AutoDialsParallel atomic.Bool - -// SetAutoDialsParallel puts backend.Auto in parallel-dial mode if y is true. -// That is, backend.Auto will dial all (available) RPN proxies in parallel. -func SetAutoDialsParallel(y bool) (prev bool) { - return AutoDialsParallel.Swap(y) -} diff --git a/intra/settings/tunopts.go b/intra/settings/tunopts.go deleted file mode 100644 index aeaa559f..00000000 --- a/intra/settings/tunopts.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) 2025 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package settings - -import ( - "strconv" - "strings" - "sync/atomic" -) - -// TODO: These modes could be covered by bit-flags instead. - -const ( - // DNSModeNone does not redirect DNS queries sent to the tunnel. - DNSModeNone int32 = 0 - // DNSModeIP redirects DNS requests sent to the IP endpoint set by VPN. - DNSModeIP int32 = 1 - // DNSModePort redirects all DNS requests on port 53. - DNSModePort int32 = 2 -) - -const ( - // BlockModeNone filters no packet. - BlockModeNone int32 = 0 - // BlockModeFilter filters packets on connection establishment. - BlockModeFilter int32 = 1 - // BlockModeSink blackholes all packets. - BlockModeSink int32 = 2 - // BlockModeFilterProc determines owner-uid of a tcp/udp connection - // from procfs before filtering - BlockModeFilterProc int32 = 3 -) - -const ( - // PtModeAuto does not enforce (but may still use) 6to4 protocol translation. - PtModeAuto int32 = 0 - // PtModeForce64 enforces 6to4 protocol translation. - PtModeForce64 int32 = 1 - // Android implements 464Xlat out-of-the-box, so this zero userspace impl - PtModeNo46 int32 = 2 -) - -// Converts a given DNS/Block/Pt mode to its string representation. -// typ is one of "dns", "block", "pt"; mode is the value to convert. -func Mode2String(typ string, mode int32) string { - str := func() string { - switch strings.ToLower(typ) { - case "dns": - switch mode { - case DNSModeNone: - return "none" - case DNSModeIP: - return "ip" - case DNSModePort: - return "port" - } - case "block": - switch mode { - case BlockModeNone: - return "none" - case BlockModeFilter: - return "filter" - case BlockModeSink: - return "sink" - case BlockModeFilterProc: - return "filterproc" - } - case "pt": - switch mode { - case PtModeAuto: - return "auto" - case PtModeForce64: - return "force64" - case PtModeNo46: - return "no46" - } - } - return "unknown" - }() - return strings.Join([]string{typ, str, strconv.Itoa(int(mode))}, " ") -} - -// DNSMode specifies the kind of DNS traffic to be trapped and routed to DoH servers -var DNSMode atomic.Int32 - -// BlockMode instructs change in firewall behaviour. -var BlockMode atomic.Int32 - -// PtMode determines 6to4 translation heuristics. -var PtMode atomic.Int32 - -// SetMode re-assigns d to DNSMode, b to BlockMode, pt to NatPtMode. -func SetTunMode(d, b, pt int32) { - DNSMode.Store(d) - BlockMode.Store(b) - PtMode.Store(pt) -} - -// DefaultTunMode returns a new default TunMode with -// IP-only DNS capture and replay (not all DNS traffic but -// only the DNS traffic sent to [tcp/udp]handler.fakedns -// is captured and replayed to the remote DoH server) -// and with firewall disabled. -func DefaultTunMode() { - SetTunMode(DNSModeIP, BlockModeNone, PtModeNo46) -} - -// DupTunFd instructs whether the TUN fd should be duplicated -// (netstack to own a clone of the TUN fd, and will not -// assume ownership of the TUN fd shared with it). -func DupTunFd(yn bool) (prev bool) { - return !OwnTunFd.Swap(!yn) -} - -func init() { - DefaultTunMode() -} diff --git a/intra/tcp.go b/intra/tcp.go deleted file mode 100644 index a6f1acb2..00000000 --- a/intra/tcp.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Derived from go-tun2socks's "direct" handler under the Apache 2.0 license. - -package intra - -import ( - "context" - "errors" - "net" - "net/netip" - "sync" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" -) - -type tcpHandler struct { - *baseHandler - nat *tcpNat -} - -type tcpNat struct { - sync.Mutex - m map[string]map[netip.AddrPort]netip.AddrPort // proxyID => src => ext -} - -func newTCPNat() *tcpNat { - return &tcpNat{m: make(map[string]map[netip.AddrPort]netip.AddrPort)} -} - -func (t *tcpNat) assoc(pid string, src, ext netip.AddrPort) { - if t == nil || len(pid) == 0 || !sameFamily(src.Addr(), ext.Addr()) { - return - } - t.Lock() - defer t.Unlock() - - m := t.m[pid] - if m == nil { - m = make(map[netip.AddrPort]netip.AddrPort) - t.m[pid] = m - } - m[src] = ext -} - -func (t *tcpNat) lookup(pid string, src netip.AddrPort) (zz netip.AddrPort, ok bool) { - if t == nil || len(pid) == 0 || !src.IsValid() { - return - } - t.Lock() - defer t.Unlock() - - if m := t.m[pid]; m != nil { - if ext, ok := m[src]; ok { - return ext, true - } - } - return -} - -type ioinfo struct { - bytes int64 - err error -} - -const ( - retryTimeout = 15 * time.Second - - // onFlowTimeout takes in to account "testWithBackoff" on Kolin side (which is around 9s) - onFlowTimeout = 10 * time.Second - onPreFlowTimeout = 5 * time.Second - onInFlowTimeout = 5 * time.Second -) - -var ( - errTcpFirewalled = errors.New("tcp: firewalled") - errTcpSetupConn = errors.New("tcp: could not create conn") -) - -var _ netstack.GTCPConnHandler = (*tcpHandler)(nil) - -// NewTCPHandler returns a TCP forwarder with Intra-style behavior. -// Connections to `fakedns` are redirected to DOH. -// All other traffic is forwarded using `dialer`. -// `listener` is provided with a summary of each socket when it is closed. -func NewTCPHandler(pctx context.Context, resolver dnsx.Resolver, prox ipn.ProxyProvider, listener SocketListener) netstack.GTCPConnHandler { - if listener == nil || core.IsNil(listener) { - log.W("tcp: using noop listener") - listener = nooplistener - } - - h := &tcpHandler{ - baseHandler: newBaseHandler(pctx, dnsx.NetTypeTCP, resolver, prox, listener), - nat: newTCPNat(), - } - - core.Gx("tcp.ps", h.processSummaries) - - log.I("tcp: new handler created") - return h -} - -// Error implements netstack.GTCPConnHandler. -// It must be called from a goroutine. -func (h *tcpHandler) Error(gconn *netstack.GTCPConn, src, dst netip.AddrPort, err error) { - err = log.EE("tcp: error: %s => %s; err %v", src, dst, err) - if !src.IsValid() || !dst.IsValid() { - return - } - res, undidAlg, realips, domains := h.onFlow(src, dst) - - h.maybeReplaceDest(res, &dst) - - cid, uid, fid, pids := h.judge(res) - smm := tcpSummary(cid, uid, src.Addr(), dst.Addr()) - - if isAnyBlockPid(pids) { - smm.PID = ipn.Block - - if undidAlg && len(realips) <= 0 && len(domains) > 0 { - err = core.JoinErr(err, errNoIPsForDomain) - } else { - err = core.JoinErr(errTcpFirewalled, err) - } - core.Go("tcp.stall."+fid, func() { - defer clos(gconn) - defer h.queueSummary(smm.done(err)) - secs := h.stall(fid) - log.I("tcp: error: %s firewalled from %s => %s (dom: %s / real: %s) for %s; stall? %ds; err %v", - cid, src, dst, domains, realips, uid, secs, err) - }) - return - } - - h.queueSummary(smm.done(err)) -} - -func (h *tcpHandler) ReverseProxy(gconn *netstack.GTCPConn, in net.Conn, to, from netip.AddrPort) (open bool) { - fm := h.onInflow(to, from) - cid, uid, _, pids := h.judge(fm) - smm := tcpSummary(cid, uid, to.Addr(), from.Addr()) - - if settings.Debug { - log.VV("tcp: %s [%s]: reverse: %s => %s; pids: %v", cid, uid, from, to, pids) - } - - if isAnyBlockPid(pids) { - log.I("tcp: %s [%s]: reverse: block %s => %s", cid, uid, from, to) - clos(gconn, in) - h.queueSummary(smm.done(errUdpInFirewalled)) - return true - } // else: pid is ipn.Ingress - - // handshake; since we assume a duplex-stream from here on - if open, err := gconn.Establish(); !open { - err = log.EE("tcp: %s [%s]: reverse: gconn.Est, err %v; %s => %s for %s", - cid, uid, err, to, from, uid) - h.queueSummary(smm.done(err)) - return false - } - - core.Go("tcp.reverse:"+cid, func() { - h.forward(gconn, rwext{in, tcptimeout}, smm) - }) - return true -} - -func (h *tcpHandler) handshakeIfNeededOrClose(gconn *netstack.GTCPConn, smm *SocketSummary) (bool, error) { - const allow bool = true // allowed - const deny bool = !allow // blocked - - if _, err := gconn.Establish(); err != nil { - err = log.EE("tcp: %s handshake err %v; %s => %s for %s", - smm.ID, err, smm.Source, smm.Target, smm.UID) - // clos(gconn) - // h.queueSummary(smm.done(err)) - return deny, err // == !open - } - return allow, nil -} - -func (h *tcpHandler) natAssoc(pid string, src netip.AddrPort, addr net.Addr) { - if ext := netAddrPort(addr); ext.IsValid() { - h.nat.assoc(pid, src, ext) - } -} - -func (h *tcpHandler) natLookup(pid string, src, target netip.AddrPort) (zz netip.AddrPort) { - if ext, ok := h.nat.lookup(pid, src); ok && sameFamily(ext.Addr(), target.Addr()) { - return ext - } - return -} - -func netAddrPort(addr net.Addr) (zz netip.AddrPort) { - if addr == nil { - return - } - if v, ok := addr.(*net.TCPAddr); ok { - return v.AddrPort() - } else if ap, err := netip.ParseAddrPort(addr.String()); err == nil { - return ap - } - return -} - -func sameFamily(a, b netip.Addr) bool { - if !a.IsValid() || !b.IsValid() { - return false - } - return a.Is4() == b.Is4() -} - -// Proxy implements netstack.GTCPConnHandler -// It must be called from a goroutine. -func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort) (open bool) { - const allow bool = true // allowed - const deny bool = !allow // blocked - var smm *SocketSummary - var err error - - defer core.Recover(core.Exit11, "tcp.Proxy") - - if !src.IsValid() || !target.IsValid() { - log.E("tcp: nil addr %s => %s; close err? %v", src, target, err) - clos(gconn) // gconn may be nil - return deny - } - - // flow/dns-override are nat-aware, as in, they can deal with - // nat-ed ips just fine, and so, use target as-is instead of ipx4 - res, undidAlg, realips, domains := h.onFlow(src, target) - - h.maybeReplaceDest(res, &target) - - // TODO: use res.IP only if set - filtered, excluded, fallingback := filterFamilyForDialingWithFailSafe(realips) - actualTargets := makeIPPorts(filtered, target, !undidAlg, 0) - cid, uid, fid, pids := h.judge(res, domains, target.String()) - - if len(actualTargets) <= 0 { // unlikely - actualTargets = []netip.AddrPort{target} - } - - // actualTargets[0] may be same as target - smm = tcpSummary(cid, uid, src.Addr(), actualTargets[0].Addr()) - - if h.status.Load() == HDLEND { - err = log.EE("tcp: proxy: %s end %s => %s [%v]", cid, src, target, actualTargets) - clos(gconn) - h.queueSummary(smm.done(err)) - return deny - } - - if isAnyBlockPid(pids) { - smm.PID = ipn.Block - if undidAlg && len(realips) <= 0 && len(domains) > 0 { - err = errNoIPsForDomain - } else { - err = errTcpFirewalled - } - core.Go("tcp.stall."+fid, func() { - defer clos(gconn) - defer h.queueSummary(smm.done(err)) - secs := h.stall(fid) - log.I("tcp: %s firewalled from %s => %s (dom: %s / real: %s) for %s; stall? %ds", - cid, src, target, domains, realips, uid, secs) - }) - return deny - } - - is6 := target.Addr().Is6() || src.Addr().Is6() - happyeyeballs := settings.HappyEyeballs.Load() - delayForHappyEyeballs := happyeyeballs && is6 - - if isAnyBasePid(pids) && h.isDNS(target) { // see udp.go:Connect - synack, synackerr := h.handshakeIfNeededOrClose(gconn, smm) - if !synack { - // if IPv6, stall a bit more so apps doing HappyEyeballs will try IPv4 - if delayForHappyEyeballs { - time.Sleep(400 * time.Millisecond) - } - clos(gconn) - h.queueSummary(smm.done(synackerr)) - return deny - } - if h.dnsOverride(gconn, uid, smm) { - // SocketSummary not sent here; x.DNSSummary supercedes it. - // conn closed by the overriding dns resolver code - return allow - } // else not a dns request - } // if ipn.Exit then let it connect as-is (aka exit) - - if settings.Debug { - log.VV("tcp: %s proxying %s => %s [%v] (excluded: %v) for %s; pids: %s", - cid, src, target, actualTargets, excluded, uid, pids) - } - - cont := true - // pick all realips to connect to - for i, dstipp := range actualTargets { - // dstipp may be v4 or v6 regardless of target addr - targetstr := dstipp.Addr().String() - - var px ipn.Proxy = nil - px, err = h.prox.ProxyTo(dstipp, uid, pids) - - // last chosen (but not dialed in) proxy (which error) - smm.Target = targetstr // addr may be invalid - smm.PID = pidstr(px) // px may be nil - smm.RPID = ipn.ViaID(px) - - if err != nil || px == nil { - err = log.WE("tcp: dial: #%d: %s proxy(%s) to dst(%s) for %s; err %v", - i, cid, pidstr(px), dstipp, uid, err) - continue - } - - if cont, err = h.handle(px, gconn, src, dstipp, delayForHappyEyeballs, smm); err == nil { - return allow // smm instead queued by handle() => forward() - } else { - end := time.Since(smm.start) - err = log.WE("tcp: dial: #%d: %s failed; addr(%s) / fallback? %t / cont? %t / he? %t; for uid %s (%s); w err(%v)", - i, cid, dstipp, fallingback, cont, happyeyeballs, uid, core.FmtPeriod(end), err) - if !cont || end > retryTimeout { - break // return err - } // else: continue; try the next realip - } - } - - // if IPv6, stall a bit more so apps doing HappyEyeballs will try IPv4 - if delayForHappyEyeballs { - time.Sleep(400 * time.Millisecond) - } - h.queueSummary(smm.done(err)) - clos(gconn) // denied - return deny -} - -// handle connects to the target via the proxy, and pipes data between the src, target; thread-safe. -func (h *tcpHandler) handle(px ipn.Proxy, gconn *netstack.GTCPConn, src, target netip.AddrPort, errOnNoRoute bool, smm *SocketSummary) (cont bool, err error) { - cont = true - stop := !cont - targetstr := target.String() - - if errOnNoRoute { - if canroute := px.Router().Contains(targetstr); !canroute { - // make sure to not delay in HappyEyeballs scenario? - return cont, log.WE("proxy(%s) has no route to %s (<= %s)", pidstr(px), targetstr, src) - } - } - - var bindAddr netip.AddrPort - pid := pidstr(px) - eim := settings.EndpointIndependentMapping.Load() - portfwd := settings.PortForward.Load() - canportfwd := portfwd && ipn.Remote(pid) - - if eim { // bindAddr may be invalid - bindAddr = h.natLookup(pid, src, target) - } - if !bindAddr.IsValid() && canportfwd { // port forwarding overriden by eim - bindAddr = makeAnyAddrPort(src) - } - - var pc protect.Conn - var dst net.Conn - - start := time.Now() - - if settings.Debug { - log.VV("tcp: %s dial %s: attempt(eim? %t / fwd? %t / canfwd? %t): %s [%s [%s]] => %s for %s", - smm.ID, pid, eim, portfwd, canportfwd, src, gconn.LocalAddr(), bindAddr, targetstr, smm.UID) - } - - dialbindOK := false - // github.com/google/gvisor/blob/5ba35f516b5c2/test/benchmarks/tcp/tcp_proxy.go#L359 - // ref: stackoverflow.com/questions/63656117 - // ref: stackoverflow.com/questions/40328025 - if bindAddr.IsValid() { - pc, err = px.Dialer().DialBind("tcp", bindAddr.String(), targetstr) - dialbindOK = err == nil - logwif(!dialbindOK)("tcp: %s dialbind ok? %t (%s [%s] => %s via %s); err? %v", - smm.ID, dialbindOK, src, bindAddr, targetstr, pid, err) - } - if !dialbindOK { - pc, err = px.Dialer().Dial("tcp", targetstr) - } - if err == nil { - smm.Rtt = time.Since(start).Milliseconds() - switch uc := pc.(type) { - case *net.TCPConn: // usual - dst = uc - case *gonet.TCPConn: // from wgproxy - dst = uc - case core.DuplexConn: // using retrier (local proxies like: exit & base) - dst = uc - case core.TCPConn: // from confirming proxy dialers - dst = uc - case net.Conn: // from non-confirming proxy dialers - // TODO: log warn? - dst = uc - default: - err = errTcpSetupConn - } - } - - // pc.RemoteAddr may be that of the proxy, not the actual dst - // ex: pc.RemoteAddr is 127.0.0.1 for Orbot - smm.Target = target.Addr().String() - smm.PID = pidstr(px) - smm.RPID = ipn.ViaID(px) - - if err != nil { - clos(pc) - log.W("tcp: err dialing %s proxy(%s) %v [%v] => %v (bind? %t) for %s: %v", - smm.ID, smm.PID, src, bindAddr, smm.Target, dialbindOK, smm.UID, err) - return cont, err - } - - if _, synackerr := h.handshakeIfNeededOrClose(gconn, smm); synackerr != nil { - clos(pc) - return stop, synackerr - } - - core.Go("tcp.forward."+smm.ID, func() { - h.listener.PostFlow(smm.postMark()) - h.forward(gconn, rwext{dst, tcptimeout}, smm) // src always *gonet.TCPConn - // TODO assoc if forward was successful - if eim { - h.natAssoc(smm.PID, src, dst.LocalAddr()) - } - }) - return cont, nil // handled; takes ownership of src -} diff --git a/intra/tun2socks.go b/intra/tun2socks.go deleted file mode 100644 index 03440776..00000000 --- a/intra/tun2socks.go +++ /dev/null @@ -1,395 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package intra - -import ( - "context" - "os" - "path/filepath" - "runtime" - "runtime/debug" - "sync/atomic" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/rnet" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/sys/unix" - - "github.com/celzero/firestack/intra/log" -) - -// pkg.go.dev/runtime#hdr-Environment_Variables -type traceout string - -type Console x.Console -type Controller x.Controller -type ProxyListener x.ProxyListener -type DNSListener x.DNSListener -type ServerListener rnet.ServerListener - -const ( - one traceout = "single" // offending go routine - usr traceout = "all" // all user go routines - sys traceout = "system" // all user + system go routines - abrt traceout = "crash" // GOOS-specific crash after tracing -) - -func (t traceout) s() string { return string(t) } - -const minMemLimit = 512 * 1024 * 1024 // 512MiB -const maxMemLimit = 4 * 1024 * 1024 * 1024 // 4GiB - -func init() { - // increase garbage collection frequency: archive.is/WQBf7 - debug.SetGCPercent(50) - debug.SetMemoryLimit(maxMemLimit) - debug.SetPanicOnFault(true) -} - -// SetupConsole wires up firestack's logger to bdg. -func SetupConsole(console Console) { - ctx := context.Background() - - logch := make(chan bool, 1) - crashch := make(chan bool, 1) - go func() { - logfd := false - if r, c, err := log.NewFilebased(); err == nil { - closeall := func() { - core.Close(c) - core.Close(r) - } - if logfd = console.LogFD(int(r.Fd())); logfd { - log.SetConsole(ctx, c) - context.AfterFunc(ctx, closeall) - } else { - closeall() - } - } - if !logfd { - log.SetConsole(ctx, &clogAdapter{console}) - } - log.D("tun: <<< console >>>; log out ok; fd? %t", logfd) - logch <- logfd - }() - - go func() { - crashfd := pipeCrashOutput(console) - crashch <- crashfd - log.D("tun: <<< console >>>; crash out ok; fd? %t", crashfd) - }() - - logfd := <-logch - crashfd := <-crashch - crashpiped.Store(crashfd) - - log.ConsoleReady(ctx) - - log.I("tun: <<< console >>>; logger: ok; fds (log? %t / crash? %t)", logfd, crashfd) -} - -// Connect creates firestack-administered tunnel. -// `fd` is the TUN device. The tunnel acquires an additional reference to it, which is -// released by Disconnect(), so the caller must close `fd` and Disconnect() to close the TUN device. -// `linkmtu` is the MTU of the underlying link (actual network). If <= 0, it is assumed to be same as `tunmtu`. -// `tunmtu` is the MTU of the TUN device. This can be "faked", ie set to values larger than linkmtu. Typically, its value is same as `linkmtu`. -// `ifaddrs` is a comma-separated list of interface addresses with prefix lengths, "ip/prefixlen". -// `fakedns` is a comman-separated list of the nameservers that the system believes it is using, in "host:port" style. -// `bdg` is a kotlin object that implements the Bridge interface. -// `dtr` is the DefaultDNS (see: intra.NewDefaultDNS); can be nil. Changeable via intra.AddDefaultTransport. -// Throws an exception if the TUN file descriptor cannot be opened, or if the tunnel fails to -// connect. -func Connect(fd, linkmtu, tunmtu int, ifaddrs, fakedns string, dtr DefaultDNS, bdg Bridge) (t Tunnel, err error) { - if linkmtu <= 0 { - NewTunnel(fd, tunmtu, ifaddrs, fakedns, dtr, bdg) - } - return NewTunnel2(fd, linkmtu, tunmtu, ifaddrs, fakedns, dtr, bdg) -} - -// Connect2 is like Connect, but assumes defaults for linkmtu, ifaddrs, and fakedns -// as -1, ["10.111.222.1/24", "fd66:f83a:c650::0/120"], and ["10.111.222.3", "fd66:f83a:c650::3"] -// respectively. -func Connect2(fd, tunmtu int, dtr DefaultDNS, bdg Bridge) (t Tunnel, err error) { - // usually, 10.111.222.0/24 / [fd66:f83a:c650::1]/120 - // github.com/celzero/rethink-app/blob/59aa0daae/app/src/main/java/com/celzero/bravedns/service/BraveVPNService.kt#L2813 - return Connect(fd, -1, tunmtu, "10.111.222.1/24,fd66:f83a:c650::1/120", "10.111.222.3,fd66:f83a:c650::3", nil, bdg) -} - -// Connect3 is like Connect2, but does not require passing a Default DNS resolver. -// The tunnel will instead attempt to use the system DNS resolver (best effort). -func Connect3(fd, tunmtu int, bdg Bridge) (t Tunnel, err error) { - return Connect2(fd, tunmtu, nil, bdg) -} - -// ControlledRouter creates a [backend.Router] over a [backend.Internet] proxy (like [backend.Exit]), -// but one that uses custom Controller c. id and addrport are used only for -// diagnostics and logging, and could be left empty. Typical usage is to use -// Router.Reaches() to check if a host:port is reachable over this Controller c. -func ControlledRouter(c Controller, id, addrport string) x.Router { - return ipn.NewExitProxyWithID(id, addrport, context.Background(), c).Router() -} - -// Change log level to very verbose (0), verbose (1), debug (2), info (3), warn (4), error (5), -// stacktraces (6), user notifications (7), or no logs (8). gologLevel and consolelogLevel can -// be set independently; ex: LogLevel(2, 6) or LogLevel(8, 0) etc. -func LogLevel(gologLevel, consolelogLevel int32) { - dlvl := log.LevelOf(gologLevel) - clvl := log.LevelOf(consolelogLevel) - log.SetLevel(dlvl) - log.SetConsoleLevel(clvl) - - dbg := dlvl <= log.DEBUG || clvl <= log.DEBUG - verbose := dlvl <= log.VERBOSE || clvl <= log.VERBOSE - settings.Debug = dbg - - // turn off runtime's internal "secure mode" to enable tracebacks - prevsm := core.SecureMode(false /*off*/) - // traceback is always set to "crash" for c-shared / c-archive buildmodes - // github.com/golang/go/blob/fed3b0a298/src/runtime/runtime1.go#L586 - // gomobile builds a c-shared gojnilib: - // github.com/golang/mobile/blob/2553ed8ce2/cmd/gomobile/bind_androidapp.go#L393 - prevtraceback, _ := core.GetRuntimeEnviron("GOTRACEBACK") - newtraceback := one.s() - if verbose { - newtraceback = sys.s() - } else if dbg { - newtraceback = usr.s() - } - core.SetRuntimeEnviron("GOTRACEBACK", newtraceback) - debug.SetTraceback(newtraceback) - curtraceback, _ := core.GetRuntimeEnviron("GOTRACEBACK") - - core.RuntimeFinishDebugVarsSetup() - - gotracelevel, gotraceall, gotracecrash := core.RuntimeGotraceback() - - log.I("tun: new levels; golog: %d, consolelog: %d; debug? %t; traceback: %s => %s => %s (l: %d / a? %t / c? %t); sm? %t", - dlvl, clvl, dbg, prevtraceback, newtraceback, curtraceback, gotracelevel, gotraceall, gotracecrash, prevsm) -} - -// FlightRecorder starts Go runtime's flight recorder if y is true, -// and stops it if y is false. The contents of the flight recorder -// (limited to 15s) is written to log.Console on panics. Thread-safe. -// go.dev/blog/flight-recorder -func FlightRecorder(y bool) (bool, error) { - return core.Record(y) -} - -// LowMem triggers garbage collection cycle & allows for -// setting maximum memory limit, if limit > 0. -// github.com/golang/proposal/blob/master/design/48409-soft-memory-limit.md -func LowMem(limitBytes int64) { - limitBytes = max(limitBytes, minMemLimit) - prevLimit := debug.SetMemoryLimit(limitBytes) - go debug.FreeOSMemory() - log.I("tun: lowmem; limits => new: %d, prev: %d", limitBytes, prevLimit) -} - -// Slowdown sets the TUN forwarder in single-threaded mode. -func Slowdown(y bool) { - ok := settings.SingleThreaded.CompareAndSwap(!y, y) - log.I("tun: slowdown? %t / ok? %t", y, ok) -} - -// ExperimentalWireGuard enables/disables experimental features for WireGuard like allowing incoming packets. -func ExperimentalWireGuard(y bool) { - // todo: move to its own method - wgok := settings.ExperimentalWireGuard.CompareAndSwap(!y, y) - // PortForwarding does not work on Android as-is. - // fwdok := settings.PortForward.CompareAndSwap(!y, y) - fwdok := false - log.I("tun: experimental settings? %t / wg? %t, portfwd? %t", y, wgok, fwdok) -} - -// FloodWireGuard enables/disables flooding WireGuard tunnels with randomly sized non-null packets. -func FloodWireGuard(y bool) { - ok := settings.FloodWireGuard.CompareAndSwap(!y, y) - log.I("tun: flood wireguard? %t / ok? %t", y, ok) -} - -// Loopback informs the network stack that it must deal with packets -// originating from its own process routed back into the tunnel. -func Loopback(y bool) { - ok := settings.Loopingback.CompareAndSwap(!y, y) - log.I("tun: loopback? %t / ok? %t", y, ok) -} - -// If set, use SystemDNS to resolve undelegated (.lan, .internal, .arpa etc) domains. -func UndelegatedDomains(useSystemDNS bool) { - ok := settings.SystemDNSForUndelegatedDomains.CompareAndSwap(!useSystemDNS, useSystemDNS) - log.I("tun: resolve undelegated with system DNS? %t / ok? %t", useSystemDNS, ok) -} - -// DefaultDNSAsFallback allows using the Default transport as a fallback when -// the Preferred transport is missing or paused or ended. -func DefaultDNSAsFallback(y bool) { - ok := settings.DefaultDNSAsFallback.CompareAndSwap(!y, y) - log.I("tun: allow default DNS as fallback? %t / ok? %t", y, ok) -} - -// Transparency enables/disables endpoint-independent mapping/filtering. -// Currently applies only for UDP (RFC 4787). -func Transparency(eim, eif bool) { - settings.EndpointIndependentMapping.Store(eim) - settings.EndpointIndependentFiltering.Store(eif) - settings.SetUserAgent.Store(eim || eif) - log.I("tun: eim? %t / eif? %t", eim, eif) -} - -// Build returns the build information. -func Build(full bool) (v string) { - if !full { - v = core.Version() - } else { - v = core.BuildInfo() - } - log.V("tun: build version %s", v) - return v -} - -// PrintStack logs the stack trace of all active goroutines -// to stdout if onConsole is false, otherwise to Console. -// For testing only. -func PrintStack(onConsole bool) { - bptr := core.LOB() - b := *bptr - b = b[:cap(b)] - defer func() { - *bptr = b - core.Recycle(bptr) - }() - if onConsole { - log.C("tun: debug trace (not a crash)", b) - } else { - log.TALL("tun: debug trace (not a crash)", b) - } -} - -// PrintFlightRecord dumps the contents of the flight recorder -// to Console if get is false, or returns the dumped bytes. -// For testing only. Thread-safe. -func PrintFlightRecord(get bool) []byte { - if got, b := core.DumpRecorder(!get /* onConsole */); get && got { - return b.Bytes() - } - return nil -} - -// PanicAtRandom instructs portions under test to panic at random. -// For testing only. -func PanicAtRandom(y bool) { - settings.PanicAtRandom.Store(y) - log.I("tun: panic at random? %t", y) -} - -// Crash causes a crash by panicking on an out-of-bounds slice access. For testing only. -func Crash(afterMs int64) { - go func() { - log.I("tun: crashing in %s", core.FmtMillis(afterMs)) - time.Sleep(time.Duration(afterMs) * time.Millisecond) - var i []int - i[10] = 1 // panic: runtime error: index out of range [10] with length 10 - }() -} - -// global references to keep go's finalizer from cleaning up the FDs -var crashReader, crashWriter, crashRWErr = os.Pipe() -var crashpiped atomic.Bool - -func pipeCrashOutput(c Console) (ok bool) { - if crashRWErr != nil { - log.E("tun: crashout: err pipe: %v", crashRWErr) - return false - } - pipeBuffer256k(crashWriter) - pipeBuffer256k(crashReader) - // defer core.Close(crashReader) // close iff r is dup'd by client code - defer core.Close(crashWriter) // always close as w is dup'd by the runtime - if setCrashFd(crashWriter) && c.CrashFD(int(crashReader.Fd())) { - return true - } - setCrashFd(nil) - return false -} - -// setCrashFd sets dup(f) as output file to write go runtime crashes in to. -func setCrashFd(f *os.File) (ok bool) { - // f is dup()ed by debug.SetCrashOutput before use - err := debug.SetCrashOutput(f, debug.CrashOptions{}) - logei(err)("tun: crashout: set %s, err? %v", fname(f), err) - return err == nil -} - -// SetCrashOutput set crash output to file at fp; returns true if so. -// Disables crash output if fp cannot be opened; and returns false. -func SetCrashOutput(fp string) bool { - p := crashpiped.Swap(false) - ok := setCrashFd(nil) - // if fd not owned by client code - // if p { defer core.Close(crashReader) } - - fout, err := os.OpenFile(filepath.Clean(fp), os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) - - logei(err)("tun: crashout: closed? %t; was piped? %t; f: %s; err? %v", ok, p, fp, err) - - if err == nil { - return setCrashFd(fout) - } - return false -} - -func pipeBuffer256k(f *os.File) bool { - if f == nil { - return false - } - const b256k = 4 * 64 * 1024 - fd := f.Fd() - nom := f.Name() - // kernel may round this up to the nearest page size multiple? - x, err := unix.FcntlInt(fd, unix.F_SETPIPE_SZ, b256k) - if err != nil { - log.W("tun: crashout: pipe (%s %d) err set size %d: %v", nom, fd, x, err) - return false - } - - x, err = unix.FcntlInt(fd, unix.F_GETPIPE_SZ, 0) - if err != nil { - log.W("tun: crashout: pipe (%s %d) err get size: %v", nom, fd, err) - return false - } - - runtime.KeepAlive(f) - log.W("tun: crashout: pipe (%s %d) buffer %s", nom, fd, core.FmtBytes(uint64(x))) - return true -} - -func fname(f *os.File) string { - if f == nil { - return "" - } - return f.Name() -} diff --git a/intra/tunnel.go b/intra/tunnel.go deleted file mode 100644 index b03a1a1f..00000000 --- a/intra/tunnel.go +++ /dev/null @@ -1,628 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package intra - -import ( - "context" - "errors" - "fmt" - "net/netip" - "os" - "runtime" - "strconv" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "github.com/celzero/firestack/intra/rnet" - "github.com/celzero/firestack/intra/settings" - "github.com/celzero/firestack/intra/x64" - "github.com/celzero/firestack/tunnel" -) - -const mktunTimeout = 8 * time.Second - -var bar = core.NewKeyedBarrier[*x.NetStat, string](30 * time.Second) - -var ( - errNoStatCache = errors.New("netstat: stat in cache is nil") - errClosed = errors.New("tunnel closed for business") - errMakeTunnel = errors.New("could not make tunnel") -) - -type Bridge interface { - Listener - Controller -} - -// Listener receives usage statistics when a UDP or TCP socket is closed, -// or a DNS query is completed. -type Listener interface { - SocketListener - DNSListener - ServerListener - ProxyListener -} - -// Tunnel represents an Intra session. -type Tunnel interface { - tunnel.Tunnel - internalCtx() context.Context - - // Get the resolver. - GetResolver() (x.DNSResolver, error) - // Get the internal resolver. - internalResolver() (dnsx.Resolver, error) - // Get proxies. - GetProxies() (x.Proxies, error) - // Get the internal proxies. - internalProxies() (ipn.Proxies, error) - // Get local services. - GetServices() (x.Services, error) - - // SetLinkAndRoutes sets the tun fd as link with mtu & engine as routes for the tunnel. - // where engine is one of the constants (Ns4, Ns6, Ns46) defined in package settings. - SetLinkAndRoutes(fd, mtu, engine int) error - // SetLinkAndRoutes2 is like SetLinkAndRoutes except it runs the tunnel with tunmtu - // & proxies with linkmtu. tunmtu may be a "fake" MTU (assigned to the TUN device) and - // can be quite large, while linkmtu must be the actual MTU of the underlying network - // or min of MTUs of all available underlying networks. - SetLinkAndRoutes2(fd, tunmtu, linkmtu, engine int) error - // SetLinkMtu sets the link MTU (which must the MTU matching the underlying network or - // min of MTUs of all available underlying networks). Link MTU is different from - // tun MTU which must match the TUN device's MTU. Link MTU is used as a hint by - // some Proxy implementations (eg. WireGuard). - SetLinkMtu(linkmtu int) (didchange bool) - // Restart restarts the tunnel with the given fd, linkmtu & tunmtu, and engine. - // fd is the TUN device file descriptor. - // linkmtu is the MTU of the underlying network, tunmtu is the MTU of the TUN device. - // linkmtu can be different from tunmtu. If linkmtu <= 0, it is assumed to be same as tunmtu. - // engine is one of the constants (Ns4, Ns6, Ns46) defined in package settings. - Restart(fd, linkmtu, tunmtu, engine int) error - - // Close connections by pid, cid, uid. - CloseConns(activecsv string) (closedcsv string) - - // Sets pcap output to fpcap which is the absolute filepath - // to which a PCAP file will be written to. - // If len(fpcap) is 0, no PCAP file will be written. - // If len(fpcap) is 1, PCAP be written to stdout. - SetPcap(fpcap string) error - // NIC, IP, TCP, UDP, and ICMP stats. - Stat() (*x.NetStat, error) -} - -type rtunnel struct { - t *core.Volatile[tunnel.Tunnel] - ctx context.Context - done context.CancelFunc - handlers netstack.GConnHandler - proxies ipn.Proxies - resolver dnsx.Resolver - services rnet.Services - linkmtu *core.Volatile[int] - closed atomic.Bool - once sync.Once -} - -var _ Tunnel = (*rtunnel)(nil) - -type clogAdapter struct { - b Console -} - -var _ log.Console = (*clogAdapter)(nil) - -func (l *clogAdapter) Log(lvl log.LogLevel, msg log.Logmsg) { - if bdg := l.b; bdg != nil { - bdg.Log(int32(lvl), msg) // adopt the log message - } -} - -func NewTunnel(fd, tunmtu int, ifaddrs, fakedns string, dtr DefaultDNS, bdg Bridge) (t Tunnel, err error) { - return NewTunnel2(fd, tunmtu, tunmtu, ifaddrs, fakedns, dtr, bdg) -} - -func NewTunnel2(fd, linkmtu, tunmtu int, ifaddrs, fakedns string, dtr DefaultDNS, bdg Bridge) (t Tunnel, err error) { - defer core.Recover(core.Exit11, "i.newTunnel") - - log.D("tun: <<< new >>>; start; fd: %d, tunmtu: %d, linkmtu: %d, ifaddrs: %s, fakedns: %s, dtr? %t, bdg? %t", - fd, tunmtu, linkmtu, ifaddrs, fakedns, dtr != nil, bdg != nil) - - if dtr == nil || core.IsNil(dtr) { - dtr, err = NewBuiltinDefaultDNS() - log.D("tun: using builtin default dns; err? %v", err) - err = nil // used only for logging - } - - if bdg == nil || dtr == nil { - return nil, fmt.Errorf("tun: no bridge? %t or default-dns? %t", bdg == nil, dtr == nil) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - if err != nil { - cancel() - } - }() - - countdown := make(chan struct{}) - defer close(countdown) - - ontimeout := func() { - log.E("tun: <<< new >>>; timed out ...") - cancel() - } - - go core.EitherOr(countdown, ontimeout, mktunTimeout) - - const dualstack = settings.IP46 - - natpt := x64.NewNatPt2(ctx) - proxies := ipn.NewProxifier(ctx, dualstack, linkmtu, bdg, bdg) - services := rnet.NewServices(ctx, proxies, bdg, bdg) - - if proxies == nil || services == nil { - return nil, fmt.Errorf("tun: no proxies? %t or services? %t", - proxies == nil, services == nil) - } - - // kickstart may call into ProxyFor which has a multi-second wait time - // when proxies are not found - if err := dtr.kickstart(proxies); err != nil { - log.W("tun: <<< new >>>; kickstart err(%v)", err) - return nil, err - } - - log.D("tun: <<< new >>>; proxies, svcs, bootstrap: ok") - - resolver := dnsx.NewResolver(ctx, fakedns, dtr, bdg, natpt) - resolver.Add(newGoosTransport(ctx, proxies)) // os-resolver; fixed - resolver.Add(newBlockAllTransport()) // fixed - resolver.Add(newFixedTransport()) // fixed - resolver.Add(newPlusTransport(ctx, resolver)) // fixed - resolver.Add(newDNSCryptTransport(ctx, proxies, bdg)) // fixed - resolver.Add(newMDNSTransport(ctx, dualstack, proxies)) // fixed - - log.D("tun: <<< new >>>; resolvers: ok") - - dialers.IPProtos(dualstack) // assume dual-stack - addIPMapper(ctx, resolver, dualstack) // namespace aware os-resolver for pkg dialers - - var src []netip.Prefix - for s := range strings.SplitSeq(ifaddrs, ",") { - if p, err := netip.ParsePrefix(s); p.IsValid() && err == nil { - src = append(src, p) - } else { - log.W("tun: <<< new >>>; invalid ifaddr %s; err? %v", s, err) - } - } - - // usually, 10.111.222.0/24 / [fd66:f83a:c650::1]/120 - // github.com/celzero/rethink-app/blob/59aa0daae/app/src/main/java/com/celzero/bravedns/service/BraveVPNService.kt#L2813 - if len(src) <= 0 { // default - src = []netip.Prefix{netip.MustParsePrefix("10.111.222.1/24"), netip.MustParsePrefix("fd66:f83a:c650::1/120")} - } - - tcph := NewTCPHandler(ctx, resolver, proxies, bdg) - udph := NewUDPHandler(ctx, resolver, proxies, bdg) - icmph := NewICMPHandler(ctx, resolver, proxies, bdg) - hdl := netstack.NewGConnHandler(src, tcph, udph, icmph) - - log.D("tun: <<< new >>>; protocol handlers: ok") - - gt, revhdl, err := tunnel.NewGTunnel(ctx, fd, tunmtu, dualstack, hdl) - - if gt == nil || err != nil { - log.W("tun: <<< new >>>; err(%v)", err) - return nil, core.OneErr(err, errMakeTunnel) - } - - log.D("tun: <<< new >>>; netstack: ok") - - // TODO: err on reverser errors too? - rerr := proxies.Reverser(revhdl) - - t = &rtunnel{ - t: core.NewVolatile[tunnel.Tunnel](gt), - ctx: ctx, - done: cancel, - handlers: hdl, - proxies: proxies, - resolver: resolver, - services: services, - linkmtu: core.NewVolatile(linkmtu), - } - - log.I("tun: <<< new >>>; tunnel ok; reverser? %v", rerr) - return t, nil -} - -func (t *rtunnel) Disconnect() { - defer core.Recover(core.Exit11, "intra.Disconnect") - - if t.closed.Load() { - log.I("tun: <<< disconnect >>> already closed") - return - } - t.once.Do(func() { - t.closed.Store(true) - t.done() - log.I("tun: <<< disconnect >>>") - }) -} - -func (t *rtunnel) SetLinkMtu(linkmtu int) (didchange bool) { - prev := t.linkmtu.Swap(linkmtu) - mtudiff := prev != linkmtu - logiif(mtudiff)("tun: set link mtu; set(%d) <= prev(%d); refresh protos? %t", linkmtu, prev, mtudiff) - if mtudiff { - core.Gx("i.setLinkMtuRefresh", func() { - t.proxies.RefreshProto("" /*use existing*/, linkmtu, false /*force*/) - }) - } - return mtudiff -} - -func (t *rtunnel) SetLinkAndRoutes(fd, tunmtu, engine int) error { - return t.SetLinkAndRoutes2(fd, tunmtu, tunmtu, engine) -} - -func (t *rtunnel) SetLinkAndRoutes2(fd, tunmtu, linkmtu, engine int) error { - if t.closed.Load() { - log.W("tun: <<< set link and route >>>; already closed") - return errClosed - } - - tunnel := t.t.Load() - - mtudiff := t.linkmtu.Swap(linkmtu) != linkmtu - l3 := settings.L3(engine) - l3diff := dialers.IPProtos(l3) - - err := tunnel.SetLinkAndRoutes(fd, tunmtu, engine) // route is always dual-stack - - if l3diff { - if mdns, err := t.resolver.MDNS(); err == nil { - mdns.RefreshProto(l3) - } - } - - if l3diff || mtudiff { - // TODO: skip refresh on err? - core.Gx("i.setLinkAndRoutesRefresh", func() { - // dialers.IPProtos must always precede calls to other refreshes - // as it carries the global state for dialers and ipn/multihost - t.proxies.RefreshProto(l3, linkmtu, false /*force*/) - }) - } - - return err -} - -func (t *rtunnel) Restart(fd, linkmtu, tunmtu, engine int) error { - if t.closed.Load() { - log.W("tun: <<< restart >>>; for: %d, intra closed", fd) - return errClosed - } - - if linkmtu <= 0 { - linkmtu = tunmtu - } - - countdown := make(chan struct{}) - defer close(countdown) - - ontimeout := func() { - log.E("tun: <<< restart >>>; for: %d, timed out ...", fd) - t.done() - } - - go core.EitherOr(countdown, ontimeout, mktunTimeout) - - dualstack := settings.IP46 - l3 := settings.L3(engine) - l3diff := dialers.IPProtos(l3) - - old := t.t.Load() - old.Disconnect() // could have been disconnected by the client already - - gt, revhdl, err := tunnel.NewGTunnel(t.ctx, fd, tunmtu, dualstack, t.handlers) - - if err != nil || gt == nil || core.IsNil(gt) { - log.W("tun: <<< restart >>>; for: %d, new? %t / mtu? %d; err(%v)", fd, tunmtu, gt != nil, err) - return core.OneErr(err, errMakeTunnel) - } - - // TODO: CompareAndSwap - if !t.t.Cas(old, gt) { // gt never nil - gt.Disconnect() // close the new tunnel - log.W("tun: <<< restart >>>; for: %d (mtu: %d), cas failed; old %X, new %X", fd, tunmtu, old, gt) - } - - // TODO: err on reverser errors too? - rerr := t.proxies.Reverser(revhdl) - - log.I("tun: <<< restart >>>; for: %d (linkmtu: %d / tunmtu: %d), netstack ok; rev err? %v", fd, linkmtu, tunmtu, rerr) - - if l3diff { - if mdns, err := t.resolver.MDNS(); err == nil { - mdns.RefreshProto(l3) - } - } - core.Gx("i.RestartRefresh", func() { - // Refresh proxies to update to the new reverser - t.proxies.RefreshProto(l3, linkmtu, true /*force; reverser changed*/) // also updates reverser - }) - - return err -} - -func (t *rtunnel) internalCtx() context.Context { - return t.ctx -} - -func (t *rtunnel) GetResolver() (x.DNSResolver, error) { - return t.internalResolver() -} - -func (t *rtunnel) internalResolver() (dnsx.Resolver, error) { - ko := t.closed.Load() - if ko || t.resolver == nil { - log.W("tun: <<< get internal resolver >>>; already closed? %t / %t", ko, t.resolver == nil) - return nil, errClosed - } - - return t.resolver, nil -} - -func (t *rtunnel) GetProxies() (x.Proxies, error) { - return t.internalProxies() -} - -func (t *rtunnel) internalProxies() (ipn.Proxies, error) { - ko := t.closed.Load() - if ko || t.proxies == nil { - log.W("tun: <<< get internal proxies >>>; already closed; %t / %t", ko, t.proxies == nil) - return nil, errClosed - } - - return t.proxies, nil -} - -func (t *rtunnel) GetServices() (x.Services, error) { - ko := t.closed.Load() - - if ko || t.proxies == nil { - log.W("tun: <<< get svc >>>; already closed; %t / %t", ko, t.services == nil) - return nil, errClosed - } - - return t.services, nil -} - -func (t *rtunnel) Stat() (*x.NetStat, error) { - if settings.Debug { - // if debugging, bypass the barrier - return t.stat() - } - - v, err := bar.DoIt("stat", func() (*x.NetStat, error) { - return t.stat() - }) - - if err != nil { - return nil, err - } else if v == nil { - return nil, errNoStatCache - } - - return v, nil -} - -func (t *rtunnel) stat() (*x.NetStat, error) { - log.VV("tun: stat: start") - defer log.VV("tun: stat: done") - - tunnel := t.t.Load() - - // NICInfo, NICStat, IPStat, IPFwdStat, TCPStat, UDPStat, ICMPStat, TUNStat - out, err := tunnel.Stat() - - if err != nil { - return nil, err - } - // rdns info - out.RDNSIn.Open = !t.closed.Load() - out.RDNSIn.Debug = settings.Debug - out.RDNSIn.Recording = core.Recording() - out.RDNSIn.Looping = settings.Loopingback.Load() - out.RDNSIn.Slowdown = settings.SingleThreaded.Load() - out.RDNSIn.NewWireGuard = boolstr(settings.ExperimentalWireGuard.Load(), settings.FloodWireGuard.Load()) - out.RDNSIn.HappyEyeballs = settings.HappyEyeballs.Load() - out.RDNSIn.EIMEIF = boolstr(settings.EndpointIndependentMapping.Load(), settings.EndpointIndependentFiltering.Load()) - out.RDNSIn.OwnTunFd = settings.OwnTunFd.Load() - out.RDNSIn.PortForward = settings.PortForward.Load() - out.RDNSIn.Transparency = settings.EndpointIndependentFiltering.Load() - out.RDNSIn.PanicTest = settings.PanicAtRandom.Load() - out.RDNSIn.SetUserAgent = settings.SetUserAgent.Load() - out.RDNSIn.SystemDNSForUndelegated = settings.SystemDNSForUndelegatedDomains.Load() - out.RDNSIn.DefaultDNSAsFallback = settings.DefaultDNSAsFallback.Load() - out.RDNSIn.Dialer4 = dialers.Use4() - out.RDNSIn.Dialer6 = dialers.Use6() - out.RDNSIn.DialerOpts = csv2ssv(settings.GetDialerOpts().String()) - out.RDNSIn.AutoMode = settings.AutoModeStr() - out.RDNSIn.AutoDialsParallel = settings.AutoDialsParallel.Load() - out.RDNSIn.LinkMTU = core.FmtBytes(uint64(t.linkmtu.Load())) - - firewall := settings.Mode2String("block", settings.BlockMode.Load()) - dns := settings.Mode2String("dns", settings.DNSMode.Load()) - pt := settings.Mode2String("pt", settings.PtMode.Load()) - out.RDNSIn.TunMode = fmt.Sprintf("%s;%s;%s", firewall, dns, pt) - - var mm runtime.MemStats - runtime.ReadMemStats(&mm) // stw & expensive - out.GOSt.Alloc = core.FmtBytes(mm.Alloc) - out.GOSt.TotalAlloc = core.FmtBytes(mm.TotalAlloc) - out.GOSt.Sys = core.FmtBytes(mm.Sys) - out.GOSt.Lookups = int64(mm.Lookups) - out.GOSt.Mallocs = int64(mm.Mallocs) - out.GOSt.Frees = int64(mm.Frees) - out.GOSt.HeapAlloc = core.FmtBytes(mm.HeapAlloc) - out.GOSt.HeapSys = core.FmtBytes(mm.HeapSys) - out.GOSt.HeapIdle = core.FmtBytes(mm.HeapIdle) - out.GOSt.HeapInuse = core.FmtBytes(mm.HeapInuse) - out.GOSt.HeapReleased = core.FmtBytes(mm.HeapReleased) - out.GOSt.HeapObjects = int64(mm.HeapObjects) - out.GOSt.StackInuse = core.FmtBytes(mm.StackInuse) - out.GOSt.StackSys = core.FmtBytes(mm.StackSys) - out.GOSt.MSpanInuse = core.FmtBytes(mm.MSpanInuse) - out.GOSt.MSpanSys = core.FmtBytes(mm.MSpanSys) - out.GOSt.MCacheInuse = core.FmtBytes(mm.MCacheInuse) - out.GOSt.MCacheSys = core.FmtBytes(mm.MCacheSys) - out.GOSt.BuckHashSys = core.FmtBytes(mm.BuckHashSys) - out.GOSt.GCSys = core.FmtBytes(mm.GCSys) - out.GOSt.OtherSys = core.FmtBytes(mm.OtherSys) - out.GOSt.NextGC = core.FmtTimeNs(mm.NextGC) - out.GOSt.LastGC = core.FmtTimeNs(mm.LastGC) - out.GOSt.PauseSecs = core.Nano2Sec(mm.PauseTotalNs) - out.GOSt.NumGC = int32(mm.NumGC) - out.GOSt.NumForcedGC = int32(mm.NumForcedGC) - out.GOSt.GCCPUFraction = fmt.Sprintf("%0.4f", mm.GCCPUFraction) - out.GOSt.EnableGC = mm.EnableGC - out.GOSt.DebugGC = mm.DebugGC - - out.GOSt.NumGoroutine = int64(runtime.NumGoroutine()) - out.GOSt.NumCgo = int64(runtime.NumCgoCall()) - out.GOSt.NumCPU = int64(runtime.NumCPU()) - - l, all, crash := core.RuntimeGotraceback() - out.GOSt.Trac = fmt.Sprintf("%d; all? %t; crash? %t", l, all, crash) - - sm1, sm2 := core.RuntimeSecureMode() - uid := fmt.Sprintf("uid=%d", syscall.Getuid()) - pid := fmt.Sprintf("pid=%d", syscall.Getpid()) - pgsz := fmt.Sprintf("pgsz=%d", os.Getpagesize()) - sec := fmt.Sprintf("sec=%t/%t", sm1, sm2) - out.GOSt.Args = strings.Join(append(os.Args, uid, pid, pgsz, sec), ";") - out.GOSt.Env = strings.Join(core.RuntimeEnviron(), ";") - out.GOSt.Pers, _ = os.Executable() - - if r := t.resolver; r != nil { - out.RDNSIn.DNSPreferred = fetchDNSInfo(r, x.Preferred) - out.RDNSIn.DNSDefault = fetchDNSInfo(r, x.Default) - out.RDNSIn.DNSSystem = fetchDNSInfo(r, x.System) - dns := make([]string, 0, 3) - if csv := r.LiveTransports(); len(csv) > 0 { - for tr := range strings.SplitSeq(csv, ",") { - dns = append(dns, fetchDNSInfo(r, tr)) - } - } - out.RDNSIn.DNS = strconv.Itoa(len(dns)) + "\n" + strings.Join(dns, ";") - out.RDNSIn.ALG = t.resolver.S() - } - if p := t.proxies; p != nil { - rr := p.Router() - ss := rr.Stat() - out.RDNSIn.Proxies = csv2ssv(p.LiveProxies()) - out.RDNSIn.ProxiesHas4 = rr.IP4() - out.RDNSIn.ProxiesHas6 = rr.IP6() - if ss.LastOK > 0 { - out.RDNSIn.ProxyLastOK = core.FmtUnixMillisAsPeriod(ss.LastOK) - } else { - out.RDNSIn.ProxyLastOK = "unknown" - } - if ss.Since > 0 { - out.RDNSIn.ProxySince = core.FmtUnixMillisAsPeriod(ss.Since) - } else { - out.RDNSIn.ProxySince = "down" - } - out.RDNSIn.ProxyStatus = ss.Status - } - - out.GOMet.M = core.Metrics() - - return out, nil -} - -// CloseConns implements Tunnel. -func (t *rtunnel) CloseConns(activecsv string) (closedcsv string) { - defer core.Recover(core.Exit11, "i.CloseConns") - - return t.handlers.CloseConns(activecsv) -} - -// Enabled implements Tunnel. -func (t *rtunnel) Enabled() bool { - tunnel := t.t.Load() - return tunnel.Enabled() -} - -// IsConnected implements Tunnel. -func (t *rtunnel) IsConnected() bool { - tunnel := t.t.Load() - return tunnel.IsConnected() -} - -// Mtu implements Tunnel. -func (t *rtunnel) Mtu() int32 { - tunnel := t.t.Load() - return tunnel.Mtu() -} - -// SetPcap implements Tunnel. -func (t *rtunnel) SetPcap(fpcap string) error { - tunnel := t.t.Load() - return tunnel.SetPcap(fpcap) -} - -// Unlink implements Tunnel. -func (t *rtunnel) Unlink() error { - tunnel := t.t.Load() - return tunnel.Unlink() -} - -func boolstr(b ...bool) string { - var sb strings.Builder - for i, v := range b { - if i > 0 { - sb.WriteString("; ") - } - if v { - sb.WriteString("y") - } else { - sb.WriteString("n") - } - } - return sb.String() -} diff --git a/intra/udp.go b/intra/udp.go deleted file mode 100644 index 6b6a58e1..00000000 --- a/intra/udp.go +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Assumes connected udp; see also: github.com/pion/transport/blob/03c807b/udp/conn.go - -package intra - -import ( - "context" - "errors" - "net" - "net/netip" - "time" - - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/ipn" - "github.com/celzero/firestack/intra/netstack" -) - -type udpHandler struct { - *baseHandler - mux *muxTable // EIM/EIF table -} - -var ( - errNoIPsForDomain = errors.New("dns: no ips") - errIcmpFirewalled = errors.New("icmp: firewalled") - errUdpFirewalled = errors.New("udp: firewalled") - errUdpInFirewalled = errors.New("udp: ingress firewalled") - errUdpSetupConn = errors.New("udp: could not create conn") - errUdpIncomingDrop = errors.New("udp: at capacity; packet in dropped") - errUdpUnconnected = errors.New("udp: cannot connect") - errUdpNoTarget = errors.New("udp: no target addr") -) - -const ( - // RFC 4787 REQ-5 requires a timeout no shorter than 5 minutes; but most - // routers do not keep udp mappings for that long (usually just for 30s) - udptimeout = 2 * 60 // seconds - // TODO: 2h 40m? - tcptimeout = 0 // no timeout -) - -var _ netstack.GUDPConnHandler = (*udpHandler)(nil) - -// NewUDPHandler makes a UDP handler with Intra-style DNS redirection: -// All packets are routed directly to their destination. -// `timeout` controls the effective NAT mapping lifetime. -// `config` is used to bind new external UDP ports. -// `listener` receives a summary about each UDP binding when it expires. -func NewUDPHandler(pctx context.Context, resolver dnsx.Resolver, prox ipn.ProxyProvider, listener SocketListener) netstack.GUDPConnHandler { - if listener == nil || core.IsNil(listener) { - log.W("udp: using noop listener") - listener = nooplistener - } - h := &udpHandler{ - baseHandler: newBaseHandler(pctx, dnsx.NetTypeUDP, resolver, prox, listener), - mux: newMuxTable(), - } - - core.Gx("udp.ps", h.processSummaries) - - log.I("udp: new handler created") - return h -} - -func (h *udpHandler) ReverseProxy(gconn *netstack.GUDPConn, in net.Conn, to, from netip.AddrPort) (ok bool) { - fm := h.onInflow(to, from) - cid, uid, _, pids := h.judge(fm) - smm := udpSummary(cid, uid, to.Addr(), from.Addr()) - - if settings.Debug { - log.VV("udp: %s [%s]: reverse: %s => %s; pids: %v", cid, uid, from, to, pids) - } - - if isAnyBlockPid(pids) { - log.I("udp: %s [%s]: reverse: block %s => %s", cid, uid, from, to) - clos(gconn, in) // blocked - h.queueSummary(smm.done(errUdpInFirewalled)) - return true // ok - } // else: pid should only be ipn.Ingress - - if err := gconn.Establish(); err != nil { // gconn.Establish() failed - err = log.EE("udp: %s [%s]: reverse: %s gconn.Est, err %s => %s: %v", cid, uid, to, from, err) - clos(gconn, in) // teardown - h.queueSummary(smm.done(err)) - return false // not ok - } - - core.Go("udp.reverse:"+cid, func() { - h.forward(gconn, rwext{in, udptimeout}, smm) - }) - return true -} - -// ProxyMux implements netstack.GUDPConnHandler -func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) { - defer core.Recover(core.Exit11, "udp.ProxyMux") - return h.proxy(gconn, src, dst, dmx) -} - -// Error implements netstack.GUDPConnHandler. -// Must be called from a goroutine. -func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, target netip.AddrPort, err error) { - defer clos(gconn) // if open - - log.W("udp: error: %v => %v; err %v", src, target, err) - if !src.IsValid() || !target.IsValid() { - return - } - res, undidAlg, realips, domains := h.onFlow(src, target) - h.maybeReplaceDest(res, &target) - cid, uid, fid, pids := h.judge(res) - smm := udpSummary(cid, uid, src.Addr(), target.Addr()) - - if isAnyBlockPid(pids) { - smm.PID = ipn.Block - if undidAlg && len(realips) <= 0 && len(domains) > 0 { - err = core.JoinErr(errNoIPsForDomain, err) - } else { - err = core.JoinErr(errUdpFirewalled, err) - } - core.Go("udp.stall."+fid, func() { - defer clos(gconn) - defer h.queueSummary(smm.done(err)) - secs := h.stall(fid) - log.I("udp: error: %s [%s] firewalled from %s => %s (dom: %s / real: %s) for %s; stall? %ds", - cid, uid, src, target, domains, realips, uid, secs) - }) - return - } - - h.queueSummary(smm.done(err)) -} - -// Proxy implements netstack.GUDPConnHandler; thread-safe. -// Must be called from a goroutine. -func (h *udpHandler) Proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) { - defer core.Recover(core.Exit11, "udp.Proxy") - return h.proxy(gconn, src, dst, nil) -} - -// proxy connects src to dst over a proxy; thread-safe. -func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) { - // remote, smm, err may all be nil - remote, smm, err := h.Connect(gconn, src, dst, dmx) - - if err != nil { - clos(gconn, remote) // teardown - // smm may be nil; in which case this is a no-op - h.queueSummary(smm.done(err)) // no-op if smm is nil - return false // not ok - } else if remote == nil || smm == nil { // dnsOverride or ipn.Block - h.queueSummary(smm.done(err)) // no-op if smm is nil - // do not close gconn here; it is either - // connected (overridden) or disconnected (blocked) already - // no summary for dns queries; for blocked connection, - // summary is queued in Connect() - return true // ok - } - - cid := smm.ID - core.Go("udp.forward."+cid, func() { - h.listener.PostFlow(smm.postMark()) - h.forward(gconn, rwext{remote, udptimeout}, smm) - }) - return true // ok -} - -// Connect connects the proxy server; thread-safe. -func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, dmx netstack.DemuxerFn) (pc net.Conn, smm *SocketSummary, err error) { - mux := dmx != nil - - // flow is alg/nat-aware, do not change target or any addrs - res, undidAlg, realips, domains := h.onFlow(src, target) - - h.maybeReplaceDest(res, &target) - - filtered, _, fallingback := filterFamilyForDialingWithFailSafe(realips) - actualTargets := makeIPPorts(filtered, target, !undidAlg, 0) - cid, uid, fid, pids := h.judge(res, domains, target.String()) - - if len(actualTargets) <= 0 { // unlikely - actualTargets = []netip.AddrPort{target} - } - smm = udpSummary(cid, uid, src.Addr(), actualTargets[0].Addr()) - - if h.status.Load() == HDLEND { - err = log.EE("udp: connect: %s [%s] %v => %v, end", cid, uid, src, target) - return nil, smm, err // disconnect, no nat - } - - if !target.IsValid() { // must call h.Bind? - log.E("udp: connect: %s [%s] %s => %s; invalid dst", cid, uid, src, target) - return nil, smm, errUdpUnconnected - } - - if isAnyBlockPid(pids) { - smm.PID = ipn.Block - if undidAlg && len(realips) <= 0 && len(domains) > 0 { - err = errNoIPsForDomain - } else { - err = errUdpFirewalled - } - core.Go("udp.stall."+fid, func() { - defer clos(gconn) - defer h.queueSummary(smm.done(err)) - secs := h.stall(fid) - log.I("udp: %s [%s] firewalled from %s => %s (dom: %s / real: %s) for %s; stall? %ds", - cid, uid, src, target, domains, realips, uid, secs) - }) - return nil, smm, nil // disconnect override, no dst - } - - // connect gconn right away, since we assume a duplex-stream from here on - if err = gconn.Establish(); err != nil { - log.W("udp: connect: %s [%s] gconn.Est, mux? %t, %s => %s err: %v", - cid, uid, mux, src, target, err) - return nil, smm, err // disconnect - } - - // requests meant for ipn.Exit are always routed untouched to target - // and never to whatever is set as DNS upstream. - // Ex: If kotlin-land initiates a DNS query (with InetAddress), - // it is routed to the tunnel's fake DNS addr, which is trapped by - // by h.dnsOverride that forwards it to one of the dnsx Transports. - // These dnsx Transports route the query back into the tunnel when - // Rethink-within-Rethink (settings.LoopingBack) routing is enabled. - // If this dnsx Transport is inturn forwarding queries to ANY DNS upstream - // on port 53 (dns53) (see h.resolver.isDns), then the request is trapped - // again & routed back to the dnsx Transport. To avoid this loop, when - // Rethink-within-Rethink routing is enabled (settings.LoopingBack), - // kotlin-land is expected to mark ipn.Base for queries to be trapped - // and sent to user-preferred dnsx Transport, and ipn.Exit for queries - // to be dialed as an outgoing protected connection. In practice, when - // Rethink-within-Rethink routing is enabled and a DNS connection - // as seen (with Flow) is owned by Rethink, then expect the conn - // to be marked ipn.Base for queries sent to tunnel's fake DNS addr - // and ipn.Exit for anywhere else. - if isAnyBasePid(pids) && h.isDNS(target) { - if h.dnsOverride(gconn, uid, smm) { - // socket/session closed by the overriding dns resolver - return nil, nil, nil // connect override, no dst - } // else: not a dns query or target is not a dns addr - } // else: proxy src to dst - - var pxid, rxid, lastselected string - var px ipn.Proxy - var errs error - var selectedTarget netip.AddrPort - - portfwd := settings.PortForward.Load() - canportfwd := portfwd - if mux { - if muxpid := h.mux.pid(src); len(muxpid) > 0 && containsPid(pids, muxpid) { - if settings.Debug { - log.D("udp: connect: %s [%s] mux: %s => %s using muxed-pid %s; all pids %s", - cid, uid, src, target, muxpid, pids) - } - pids = []string{muxpid} - } // else: mxr will dial this conn with a different pid - } - - if settings.Debug { - log.VV("udp: connect: %s [%s] proxying %s => %s [%v]; pids: %s, mux? %t / fwd? %t", - cid, uid, src, target, actualTargets, pids, mux, canportfwd) - } - - // note: fake-dns-ips shouldn't be un-nated / un-alg'd - for i, dstipp := range actualTargets { - rttstart := time.Now() - - px, err = h.prox.ProxyTo(dstipp, uid, pids) - - if px != nil { // last chosen (but not dialed in) proxy - pxid = pidstr(px) - rxid = ipn.ViaID(px) - lastselected = dstipp.Addr().String() - } - selectedTarget = dstipp - - if err != nil || px == nil { - log.W("udp: connect: #%d: %s [%s] failed to get proxy from %s: %v", i, cid, uid, pxid, err) - errs = err // disconnect if loop terminates - continue - } - - canportfwd = portfwd && ipn.Remote(pxid) - - if mux { // mux is not supported by all proxies (few like Exit, Base, WG support it) - pc, err = h.mux.associate(cid, pxid, uid, src, selectedTarget, px.Dialer().Announce, vendor(dmx), canportfwd) - } else { - if settings.Debug { - log.VV("udp: connect: #%d: attempt: %s [%s] proxy(%s) to dst(%s); mux? %t / fwd? %t", - i, cid, uid, pxid, selectedTarget, mux, canportfwd) - } - - if canportfwd { - boundSrc := makeAnyAddrPort(src) - pc, err = px.Dialer().DialBind("udp", boundSrc.String(), selectedTarget.String()) - } else { - pc, err = px.Dialer().Dial("udp", selectedTarget.String()) - } - } - if err == nil { - errs = nil // reset errs - break - } // else try the next realip - - errs = err // store just the last err; complicates logging - end := time.Since(smm.start) - smm.Rtt = time.Since(rttstart).Milliseconds() - log.W("udp: connect: #%d: %s [%s] failed; mux? %t / fwd? %t, addr(%s) / fallback? %t; (rtt:%dms, dur:%s) w err(%v)", - i, cid, uid, mux, canportfwd, dstipp, fallingback, smm.Rtt, core.FmtPeriod(end), err) - if end > retryTimeout { - break - } - } - - if len(pxid) > 0 { // last chosen proxy which may have errored - smm.PID = pxid - smm.RPID = rxid - smm.Target = lastselected // may be invalid - } - - if !selectedTarget.IsValid() { - log.E("udp: connect: %s [%s] no target addr for %s from %v", - cid, uid, target, actualTargets) - return nil, smm, errUdpNoTarget - } - - // pc.RemoteAddr may be that of the proxy, not the actual dst - // ex: pc.RemoteAddr is 127.0.0.1 for Orbot - smm.Target = selectedTarget.Addr().String() - - if errs != nil { - return nil, smm, errs // disconnect - } else if px == nil || pc == nil || core.IsNil(pc) { - log.W("udp: connect: %s [%s] no proxy/egress-conn (mux? %t / fwd? %t) for addr(%s/%s)", - cid, uid, mux, canportfwd, target, selectedTarget) - return nil, smm, errUdpSetupConn // disconnect - } - - var laddr net.Addr - switch x := pc.(type) { - case *net.UDPConn: // direct - laddr = x.LocalAddr() - case core.UDPConn: // connected - laddr = x.LocalAddr() - case net.Conn: // muxed - laddr = x.LocalAddr() - default: - clos(pc) - log.E("udp: connect: %s [%s] proxy(%s) does not impl core.UDPConn(%s/%s); mux? %t", - cid, uid, pxid, target, selectedTarget, mux, canportfwd, uid) - return nil, smm, errUdpSetupConn // disconnect - } - - log.I("udp: connect: %s [%s] (proxy? %s@%s) %v => %s/%s; fallback? %t / mux? %t / fwd? %t", - cid, uid, pxid, px.GetAddr(), laddr, target, selectedTarget, fallingback, mux, canportfwd) - - return pc, smm, nil // connect -} diff --git a/intra/udpmux.go b/intra/udpmux.go deleted file mode 100644 index f080bb37..00000000 --- a/intra/udpmux.go +++ /dev/null @@ -1,681 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package intra - -import ( - "fmt" - "io" - "net" - "net/netip" - "os" - "sync" - "sync/atomic" - "time" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" -) - -// from: github.com/pion/transport/blob/03c807b/udp/conn.go - -const maxtimeouterrors = 3 -const maxInFlight = 512 -const maxOverflow = maxInFlight / 4 - -type flowkind int32 - -var ( - ingress flowkind = 0 - egress flowkind = 1 -) - -func (f flowkind) String() string { - if f == ingress { - return "ingress" - } - return "egress" -} - -type sender interface { - id() string - sendto([]byte, net.Addr) (int, error) - extend(time.Time) -} - -type stats struct { - start time.Time // set only once; on ctor - - dxcount atomic.Uint32 - tx atomic.Uint32 - rx atomic.Uint32 -} - -func (s *stats) String() string { - if s == nil { - return "" - } - return fmt.Sprintf("tx: %d, rx: %d, conns: %d, dur: %s", s.tx.Load(), s.rx.Load(), s.dxcount.Load(), core.FmtTimeAsPeriod(s.start)) -} - -type vendor func(fwd net.Conn, dst netip.AddrPort) error - -// muxer muxes multiple connections grouped by remote addr over net.PacketConn -type muxer struct { - // cid, pid, mxconn, stats are immutable (never reassigned) - mxconn net.PacketConn - cid string // connection id of mxconn - pid string // proxy id mxconn is listening on - uid string // user id owner of mxconn - stats *stats - - until *core.Volatile[time.Time] // deadline extension - - dxconns chan *demuxconn // never closed - dxconnWG *sync.WaitGroup // wait group for demuxed conns - - vnd vendor // endpoint-independent mapping (new routes in netstack) - - doneCh chan struct{} // stop vending, reading, and routing - once sync.Once // muxer.stop() once - cb core.Finally // muxer.stop()/cleanup callback (in a new goroutine) - - rmu sync.RWMutex // protects routes - routes map[netip.AddrPort]*demuxconn // remote addr => demuxed conn -} - -// demuxconn writes to raddr and reads from the muxer -type demuxconn struct { - cid string // connection id - - out sender // promiscuous sender (remuxer) - key netip.AddrPort // promiscuous factor (same as raddr) - raddr net.Addr // remote address connected to - laddr net.Addr // local address connected from - - inCh chan *slice // incoming data from muxer, never closed - overflowCh chan *slice // overflow data, never closed - - closed chan struct{} // close signal - readClosed atomic.Bool // true if read side is closed - writeClosed atomic.Bool // true if write side is closed - once sync.Once // close once - - wt *time.Ticker // write deadline - rt *time.Ticker // read deadline - wto time.Duration // write timeout - rto time.Duration // read timeout -} - -// slice is a byte slice v and its recycler fin. -type slice struct { - v []byte - fin core.Finally -} - -var _ sender = (*muxer)(nil) -var _ core.UDPConn = (*demuxconn)(nil) -var _ core.DuplexCloser = (*demuxconn)(nil) - -// newMuxer creates a muxer/demuxer for a connectionless conn. -func newMuxer(cid, pid, uid string, conn net.PacketConn, vnd vendor, f core.Finally) *muxer { - x := &muxer{ - cid: cid, // same as cid of the first demuxed conn - pid: pid, - uid: uid, - mxconn: conn, - stats: &stats{start: time.Now()}, - until: core.NewZeroVolatile[time.Time](), - routes: make(map[netip.AddrPort]*demuxconn), - rmu: sync.RWMutex{}, - dxconns: make(chan *demuxconn), - doneCh: make(chan struct{}), - dxconnWG: &sync.WaitGroup{}, - cb: f, - vnd: vnd, - } - core.Gx("udpmux.read."+x.pid+x.cid, x.readers) - core.Gx("udpmux.await."+x.pid+x.cid, x.awaiters) - return x -} - -// awaiters waits for a demuxed conns to close, then cleans the state up. -func (x *muxer) awaiters() { - for { - select { - case c := <-x.dxconns: - if settings.Debug { - log.D("udp: mux: %s awaiter: watching %s => %s", c.cid, c.laddr, c.raddr) - } - x.dxconnWG.Add(1) // accept - core.Gx("udpmux.vend.close", func() { - <-c.closed // conn closed - x.unroute(c) - x.dxconnWG.Done() // unaccept - }) - case <-x.doneCh: - log.I("udp: mux: %s awaiter: done", x.cid) - return - } - } -} - -// stop closes conns in the backlog, stops accepting new conns, -// closes muxconn, and waits for demuxed conns to close. -func (x *muxer) stop() error { - if settings.Debug { - log.D("udp: mux: %s stop", x.cid) - } - - var err error - x.once.Do(func() { - close(x.doneCh) - x.drain() - err = x.mxconn.Close() // close the muxed conn - - x.dxconnWG.Wait() // all conns close / error out - core.Go("udpmux.cb", x.cb) // dissociate - log.I("udp: mux: %s stopped; stats: %s", x.cid, x.stats) - }) - - return err -} - -func (x *muxer) drain() { - x.rmu.Lock() - defer x.rmu.Unlock() - - defer clear(x.routes) - if settings.Debug { - log.D("udp: mux: %s drain: closing %d demuxed conns", x.cid, len(x.routes)) - } - for _, c := range x.routes { - clos(c) // will unroute as well - } -} - -// readers has to tasks: -// 1. Dispatching incoming packets to the correct Conn. -// It can therefore not be ended until all Conns are closed. -// 2. Creating a new Conn when receiving from a new remote. -func (x *muxer) readers() { - // todo: recover must call "recycle()" if it wasn't. - defer func() { - _ = x.stop() // stop muxer - }() - - timeouterrors := 0 - for { - bptr := core.Alloc16() - b := *bptr - b = b[:cap(b)] - // todo: if panics are recovered above, recycle() may never be called - recycle := func() { - *bptr = b - core.Recycle(bptr) - } - - // on muxer.stop(), x.doneCh is also closed and x.mxconn.ReadFrom will error out. - n, who, err := x.mxconn.ReadFrom(b) - - x.stats.tx.Add(uint32(n)) // upload - - if timedout(err) { - timeouterrors++ - if timeouterrors < maxtimeouterrors { - // extend by preset (min) udp timeout - x.extend(time.Now().Add(time.Second * udptimeout)) - if settings.Debug { - log.D("udp: mux: %s read timeout(%d): %v", x.cid, timeouterrors, err) - } - recycle() - continue - } - } - if err != nil { - log.I("udp: mux: %s read done n(%d): %v", x.cid, n, err) - recycle() - return - } - - timeouterrors = 0 // reset on successful reads - - if who == nil || n == 0 { - log.W("udp: mux: %s read done n(%d): nil remote addr; skip", x.cid, n) - recycle() - continue - } - - const todoCid = "todo" - // may be an existing route or a new route; - // recycle() if who is invalid or x is closed. - if dst := x.route(todoCid, addr2netip(who), ingress); dst != nil { - select { - case dst.inCh <- &slice{v: b[:n], fin: recycle}: // incomingCh is never closed - default: // dst probably closed, but not yet unrouted - err = errUdpIncomingDrop - recycle() - } - logev(err)("udp: mux: %s read: n(%d) from %v <= %v; dropped? %v", - dst.cid, n, dst.laddr, who, err) - } else { // dst may be nil when x.doneCh is closed by muxer.stop(). - recycle() - } // looping back is okay, as x.mxconn.ReadFrom should error out. - } -} - -func (x *muxer) findRoute(to netip.AddrPort) *demuxconn { - x.rmu.RLock() - defer x.rmu.RUnlock() - return x.routes[to] -} - -func (x *muxer) route(cid string, to netip.AddrPort, flo flowkind) *demuxconn { - if !to.IsValid() { - log.W("udp: mux: %s route: %s invalid addr %s", cid, flo, to) - return nil - } - - if conn := x.findRoute(to); conn != nil { - return conn - } - - x.rmu.Lock() - defer x.rmu.Unlock() - - conn, ok := x.routes[to] - if conn == nil || !ok { - // new routes created here won't really exist in netstack if - // settings.EndpointIndependentMapping or settings.EndpointIndependentFiltering - // is set to false. - conn = x.newLocked(cid, to) - select { - case <-x.doneCh: - clos(conn) - log.W("udp: mux: %s route: %s for %s; muxer closed", conn.cid, flo, to) - return nil - case x.dxconns <- conn: - n := x.stats.dxcount.Add(1) - x.routes[to] = conn - // if egress, a demuxed conn is already vended/sockisifed via netstack - // (see: udpHandler:ProxyMux) and so it need not be vended again. Even - // if it were to be, it'd fail with "port/addr already in use" - // ex: route: egress vend failure 1.1.1.1:53; err connect udp 10.111.222.1:42182: port is in use - if flo == ingress { - core.Go("udpmux.vend", func() { // a fork in the road - if verr := x.vnd(conn, to); verr != nil { - clos(conn) - log.E("udp: mux: %s route: %s vend failure %s; err %v", conn.cid, flo, to, verr) - } - }) - } - log.I("udp: mux: %s route: %s #%d new for %s; stats: %s", - conn.cid, flo, n, to, x.stats) - } - } - return conn -} - -func (x *muxer) unroute(c *demuxconn) { - // don't really expect to handle panic w/ core.Recover - x.rmu.Lock() - defer x.rmu.Unlock() - - log.I("udp: mux: %s unrouting... %s => %s", x.cid, c.laddr, c.raddr) - delete(x.routes, c.key) -} - -func (x *muxer) id() string { return x.cid } - -func (x *muxer) sendto(p []byte, addr net.Addr) (int, error) { - // on closed(x.doneCh), x.mxconn is closed and writes will fail - n, err := x.mxconn.WriteTo(p, addr) - x.stats.rx.Add(uint32(n)) // download - return n, err -} - -func (x *muxer) extend(t time.Time) { - c := x.until.Load() - if t.IsZero() { - extend(x.mxconn, 0) - x.until.Store(t) - } else if c.IsZero() || c.Before(t) { - // extend if t is after existing deadline at x.until - extend(x.mxconn, time.Until(t)) - x.until.Store(t) - } -} - -// new creates a demuxed conn to r. -func (x *muxer) newLocked(cid string, r netip.AddrPort) *demuxconn { - dopt := settings.GetDialerOpts() // TODO: update timeouts when opts change - readtimeout := time.Second * time.Duration(max(udptimeout, dopt.ReadTimeoutSec)) - writetimeout := time.Second * time.Duration(max(udptimeout, dopt.WriteTimeoutSec)) - - if len(cid) > 0 { - cid = x.cid + ":" + cid - } - return &demuxconn{ - cid: cid, - out: x, // muxer - laddr: x.mxconn.LocalAddr(), // listen addr - raddr: net.UDPAddrFromAddrPort(r), // remote addr - key: r, // key (same as raddr) - inCh: make(chan *slice, maxInFlight), // read from muxer - overflowCh: make(chan *slice, maxOverflow), // overflow from read - closed: make(chan struct{}), // always unbuffered - wt: time.NewTicker(writetimeout), - rt: time.NewTicker(readtimeout), - wto: writetimeout, - rto: readtimeout, - } -} - -// Read implements core.UDPConn.Read -func (c *demuxconn) Read(p []byte) (int, error) { - sz := len(p) - if c.readClosed.Load() { - return 0, log.EE("udp: mux: %s demux: read: %v <= %v; closed (sz: %d)", - c.out.id(), c.laddr, c.raddr, sz) - } - - defer c.rt.Reset(c.rto) - select { - case <-c.rt.C: - log.W("udp: mux: %s demux: read: %v <= %v; timeout (sz: %d)", - c.out.id(), c.laddr, c.raddr, sz) - return 0, os.ErrDeadlineExceeded - case <-c.closed: - log.W("udp: mux: %s demux: read: %v <= %v; closed (sz: %d)", - c.out.id(), c.laddr, c.raddr, sz) - return 0, net.ErrClosed - case sx := <-c.overflowCh: - return c.io(&p, sx) - case sx := <-c.inCh: - return c.io(&p, sx) - } -} - -// Write implements core.UDPConn.Write -func (c *demuxconn) Write(p []byte) (n int, err error) { - sz := len(p) - - if c.writeClosed.Load() { - return 0, log.EE("udp: mux: %s demux: write: %v => %v; closed (sz: %d)", - c.out.id(), c.laddr, c.raddr, sz) - } - - defer c.wt.Reset(c.wto) - select { - case <-c.wt.C: - log.W("udp: mux: %s demux: write: %v => %v; timeout (sz: %d)", - c.out.id(), c.laddr, c.raddr, sz) - return 0, os.ErrDeadlineExceeded - case <-c.closed: - log.W("udp: mux: %s demux: write: %v => %v; closed (sz: %d)", - c.out.id(), c.laddr, c.raddr, sz) - return 0, net.ErrClosed - default: - n, err = c.out.sendto(p, c.raddr) - logev(err)("udp: mux: %s demux: write: %v => %v; done(sz: %d/%d); err? %v", - c.out.id(), c.laddr, c.raddr, n, sz, err) - return n, err - } -} - -// ReadFrom implements core.UDPConn.ReadFrom (unused) -func (c *demuxconn) ReadFrom(p []byte) (int, net.Addr, error) { - n, err := c.Read(p) - return n, c.raddr, err -} - -// WriteTo implements core.UDPConn.WriteTo (unused) -func (c *demuxconn) WriteTo(p []byte, to net.Addr) (int, error) { - // todo: check if "to" is the same as c.raddr - // if to != c.raddr { - // return 0, net.ErrWriteToConnected - // } - return c.Write(p) -} - -// Implements core.DuplexCloser. -func (c *demuxconn) CloseRead() (err error) { - if c.readClosed.CompareAndSwap(false, true) { - closeall := c.writeClosed.Load() - if closeall { - err = c.Close() // both sides closed - } - logev(err)("udp: mux: %s demux: %s => %s close read (and conn? %t), err: %v", - c.cid, c.laddr, c.raddr, closeall, err) - return - } - return net.ErrClosed -} - -// Implements core.DuplexCloser. -func (c *demuxconn) CloseWrite() (err error) { - if c.writeClosed.CompareAndSwap(false, true) { - closeall := c.readClosed.Load() - if closeall { - err = c.Close() // both sides closed - } - logev(err)("udp: mux: %s demux: %s => %s close write (and conn? %t), err: %v", - c.cid, c.laddr, c.raddr, closeall, err) - return err - } - return net.ErrClosed -} - -// Close implements core.UDPConn.Close -func (c *demuxconn) Close() error { - if settings.Debug { - log.D("udp: mux: %s demux %s => %s close, inC: %d, overC: %d", - c.out.id(), c.laddr, c.raddr, len(c.inCh), len(c.overflowCh)) - } - c.once.Do(func() { - close(c.closed) // sig close - - c.readClosed.Store(true) - c.writeClosed.Store(true) - - defer c.wt.Stop() - defer c.rt.Stop() - for { - select { - case sx := <-c.inCh: - sx.fin() - case sx := <-c.overflowCh: - sx.fin() - default: - log.I("udp: mux: %s demux from %s => %s closed", c.out.id(), c.laddr, c.raddr) - return - } - } - }) - return nil -} - -// LocalAddr implements core.UDPConn.LocalAddr -func (c *demuxconn) LocalAddr() net.Addr { - return c.laddr -} - -// RemoteAddr implements core.UDPConn.RemoteAddr -func (c *demuxconn) RemoteAddr() net.Addr { - return c.raddr -} - -// SetDeadline implements core.UDPConn.SetDeadline -func (c *demuxconn) SetDeadline(t time.Time) error { - werr := c.SetWriteDeadline(t) - rerr := c.SetReadDeadline(t) - return core.JoinErr(werr, rerr) -} - -// SetReadDeadline implements core.UDPConn.SetReadDeadline -func (c *demuxconn) SetReadDeadline(t time.Time) error { - if d := time.Until(t); d > 0 { - c.rto = d - c.rt.Reset(d) - c.out.extend(t) - } else { - c.out.extend(time.Time{}) // no deadline - c.rt.Stop() - } - return nil -} - -// SetWriteDeadline implements core.UDPConn.SetWriteDeadline -func (c *demuxconn) SetWriteDeadline(t time.Time) error { - if d := time.Until(t); d > 0 { - c.wto = d - c.wt.Reset(d) - c.out.extend(t) - } else { - c.out.extend(time.Time{}) // no deadline - c.wt.Stop() - } - // Write deadline of underlying connection should not be changed - // since the connection can be shared. - return nil -} - -func (c *demuxconn) io(out *[]byte, in *slice) (int, error) { - id := c.out.id() - // todo: handle the case where len(b) > len(p) - n := copy(*out, in.v) - q := len(in.v) - n - if q > 0 { - select { - case <-c.closed: - log.W("udp: mux: %s demux: read: %v <= %v drop(sz: %d)", id, c.laddr, c.raddr, q) - in.fin() - case c.overflowCh <- &slice{v: in.v[n:], fin: in.fin}: - log.W("udp: mux: %s demux: read: %v <= %v overflow(sz: %d)", id, c.laddr, c.raddr, q) - default: - log.E("udp: mux: %s demux: read: %v <= %v dropped(sz: %d)", id, c.laddr, c.raddr, q) - in.fin() - return n, io.ErrShortWrite - } - } else { - if settings.Debug { - log.VV("udp: mux: %s demux: read: %v <= %v done(sz: %d)", id, c.laddr, c.raddr, n) - } - in.fin() - } - return n, nil -} - -func timedout(err error) bool { - x, ok := err.(net.Error) - return ok && x.Timeout() -} - -type muxTable struct { - sync.Mutex - t map[string]map[netip.AddrPort]*muxer // pid -> [src -> dst] endpoint independent nat -} - -type assocFn func(net, dst string) (net.PacketConn, error) - -func newMuxTable() *muxTable { - return &muxTable{t: make(map[string]map[netip.AddrPort]*muxer)} -} - -func (e *muxTable) pid(src netip.AddrPort) string { - e.Lock() - defer e.Unlock() - for _, pxm := range e.t { - if mxr := pxm[src]; mxr != nil { - return mxr.pid - } - } - return "" -} - -func (e *muxTable) associate(cid, pid, uid string, src, dst netip.AddrPort, mk assocFn, v vendor, portfwd bool) (_ net.Conn, err error) { - e.Lock() // lock - - pxm := e.t[pid] - if pxm == nil { - pxm = make(map[netip.AddrPort]*muxer) - e.t[pid] = pxm - } - - mxr := pxm[src] - if mxr == nil { - // dst may be of a different family than src (4to6, 6to4 etc) - // and so, rely on dst to determine the family to listen on. - proto := "udp" - anyaddr := anyaddr6 - anyport := uint16(0) - if dst.Addr().Is4() && !dialers.Use6() { - proto = "udp4" - anyaddr = anyaddr4 - } - anyaddrport := netip.AddrPortFrom(anyaddr, anyport) - if portfwd { - anyaddrport = netip.AddrPortFrom(anyaddr, src.Port()) - } - - pc, err := mk(proto, anyaddrport.String()) - - if err != nil { - core.Close(pc) - e.Unlock() // unlock - return nil, err // return - } - - mxr = newMuxer(cid, pid, uid, pc, v, func() { - e.dissociate(cid, pid, src) - }) - pxm[src] = mxr - log.I("udp: mux: %s new assoc for %s %s via %s; fwd? %t", - cid, pid, src, anyaddrport, portfwd) - } - - if mxr.pid != pid { - e.Unlock() - return nil, log.EE("udp: mux: %s assoc proxy mismatch: %s != %s or %s != %s", - cid, mxr.pid, pid, mxr.uid, uid) - } else if mxr.uid != uid && - (uid != UNKNOWN_UID_STR || mxr.uid != UNKNOWN_UID_STR) { - e.Unlock() - return nil, log.EE("udp: mux: %s assoc uid mismatch: %s != %s or %s != %s", - cid, mxr.pid, pid, mxr.uid, uid) - } - - e.Unlock() // unlock - // do not hold e.lock on calls into mxr - c := mxr.route(cid, dst, egress) - if c == nil { - return nil, log.EE("udp: mux: %s vend: no conn for %s", cid, dst) - } - return c, nil -} - -func (e *muxTable) dissociate(cid, pid string, src netip.AddrPort) { - log.I("udp: mux: %s (%s) dissoc for %s", cid, pid, src) - - e.Lock() - defer e.Unlock() - pxm := e.t[pid] // may be nil and that's okay - delete(pxm, src) -} - -func addr2netip(addr net.Addr) (zz netip.AddrPort) { - if addr == nil { - return // zz - } - ipp, err := netip.ParseAddrPort(addr.String()) - if err != nil { - log.W("udp: mux: addr2netip: %v", err) - return // zz - } - return ipp // may be invalid -} diff --git a/intra/x64/dns64.go b/intra/x64/dns64.go deleted file mode 100644 index 8743176c..00000000 --- a/intra/x64/dns64.go +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2022 -// Frank Denis - -// adopted from: github.com/DNSCrypt/dnscrypt-proxy/blob/df3fb0c9/dnscrypt-proxy/plugin_dns64.go -package x64 - -import ( - "context" - "errors" - "net" - "sync" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/dialers" - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/xdns" - "github.com/miekg/dns" -) - -var ( - rfc7050WKA1 = net.IPv4(192, 0, 0, 170) - rfc7050WKA2 = net.IPv4(192, 0, 0, 171) - - // nslookup ipv4only.arpa 2606:4700:4700::64 - // Non-authoritative answer: - // Address: 192.0.0.171 - // Address: 192.0.0.170 - // Address: 64:ff9b::c000:aa - // Address: 64:ff9b::c000:ab - // _, rfc6052WKP, _ = net.ParseCIDR("64:ff9b::/96") - // _, rfc8215WKP, _ = net.ParseCIDR("64:ff9b:1:fffe::/96") - - ipv6bits = 8 * net.IPv6len - - errQuery = errors.New("invalid dns64 query") - errAns = errors.New("invalid dns64 answer") - errEmpty = errors.New("missing dns64 IPv6 prefixes") - errNotFound = errors.New("resolver did not send dns64 ipv6 prefixes") - errNoSuchServer = errors.New("resolver not registered") - - emptyStruct = struct{}{} - - arpa64 = questionArpa64() -) - -type dns64 struct { - sync.RWMutex - - ctx context.Context - // dns-resolver -> nat64-ips - ip64 map[string][]*net.IPNet - // dns-resolver -> unique nat64-ips - uniqIP64 map[string]map[string]struct{} -} - -func newDns64(ctx context.Context) *dns64 { - d := &dns64{ - ctx: ctx, - ip64: make(map[string][]*net.IPNet), - uniqIP64: make(map[string]map[string]struct{}), - } - core.Gx("dns64.init", d.init) - return d -} - -func (d *dns64) init() { - if err := d.ofLocal464(); err != nil { // unlikely - log.W("dns64: err reg local(%v)", err) - } -} - -func questionArpa64() *dns.Msg { - msg := new(dns.Msg) - msg.SetQuestion(dnsx.Rfc7050WKN, dns.TypeAAAA) - return msg -} - -// register adds a new dns resolver to the dns64 map; thread-safe. -func (d *dns64) register(id string) { - d.Lock() - defer d.Unlock() - if l, ok := d.ip64[id]; ok { - log.W("dns64: overwrite existing ip64(%v) for resolver(%s)", l, id) - } - d.ip64[id] = make([]*net.IPNet, 0) - d.uniqIP64[id] = make(map[string]struct{}) -} - -func (d *dns64) AddResolver(id, r string) (ok bool) { - switch id { - case dnsx.OverlayResolver: - return d.ofOverlay() == nil - case dnsx.Local464Resolver: - return d.ofLocal464() == nil - } - - d.register(id) - - defer func() { - if !ok { - d.RemoveResolver(id) - } - }() - - ans, err := dialers.Query(arpa64, r) - - if err != nil || ans == nil || !xdns.HasAnyAnswer(ans) { - log.W("dns64: udp: could not query %s[%s] or empty ans; err %v", r, id, err) - return - } - - if ans.Truncated { // should never be the case for DOH, ODOH, DOT - // else if: returned response is truncated dns ans, retry over tcp - ans, err = dialers.Query(arpa64, r) - if err != nil { - log.W("dns64: tcp: could not query resolver %s[%s]; err %v", r, id, err) - return - } - if !xdns.HasAnyAnswer(ans) { - log.W("dns64: tcp: invalid response from resolver %s[%s]", r, id) - return - } - } - - ips := make([]net.IP, 0) - for _, answer := range ans.Answer { - if answer.Header().Rrtype == dns.TypeAAAA { - if ipv6, ok := answer.(*dns.AAAA); ok { - ips = append(ips, ipv6.AAAA) - } - } - } - - if err := d.add(id, ips); err == nil { - return true - } - - return -} - -func (d *dns64) RemoveResolver(id string) bool { - d.Lock() - defer d.Unlock() - delete(d.ip64, id) - delete(d.uniqIP64, id) - return true -} - -// TODO: handle svcb/https ipv4hint/ipv6hint -// datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-10#section-7.4 -func (d *dns64) eval(network string, force64 bool, ansin *dns.Msg, r, uid string) *dns.Msg { - qname := xdns.QName(ansin) - // if question is AAAA, then answer must have AAAA; for example CNAME, - // records pointing no where must not be considered as AAAA answers - // but instead must be sent to DNS64 for translation - // q: www.skysports.com AAAA - // ans: www.skysports.com CNAME www.skysports.akadns.net; - // www.skysports.akadns.net CNAME www.skysports.com.edgekey.net; - // www.skysports.com.edgekey.net CNAME e16115.j.akamaiedge.net - // hasaaaq(true) hasans(true) rgood(true) ans0000(false) - hasq6 := xdns.HasAAAAQuestion(ansin) - hasans6 := xdns.HasAAAAAnswer(ansin) - ans00006 := xdns.AQuadAUnspecified(ansin) - hasauth := xdns.IsDNSSECAnswerAuthenticated(ansin) - // treat as if v6 answer missing if enforcing 6to4 - if !hasq6 || ((hasauth || hasans6) && !force64) || ans00006 { - // nb: has-aaaa-answer should cover for cases where - // the response is blocked by dnsx.RDNS - log.D("dns64: for(%s %s), no-op q(%s), q6(%t), ans6(%t), force64(%t), ans0000(%t)", - network, uid, qname, hasq6, hasans6, force64, ans00006) - return nil - } - - id := id64(r) - ip64 := d.get(id) - if len(ip64) <= 0 { - if ip64 = d.get(dnsx.UnderlayResolver); len(ip64) <= 0 { - if ip64 = d.get(dnsx.OverlayResolver); len(ip64) <= 0 { - ip64 = d.get(dnsx.Local464Resolver) - } - } - log.D("dns64: attempt underlay/local464 resolver [%s@%s] ip64 (ad? %t) w len(%d)", r, uid, hasauth, len(ip64)) - } else { - log.V("dns64: for %s, no resolver id(%s[%s]) registered (ad? %t)", uid, r, id, hasauth) - } - - ans4, err := d.query64(network, ansin, r, uid) - rgood := xdns.HasRcodeSuccess(ans4) - hasans := xdns.HasAnyAnswer(ans4) - ans0000 := xdns.AQuadAUnspecified(ans4) - if err != nil || ans4 == nil || !hasans || ans0000 { - log.W("dns64: skip: for %s, query(n:%s / a? %t) on resolver(%s[%s]/%s), code(good? %t / blocked? %t), err(%v)", - uid, qname, hasans, r, id, network, rgood, ans0000, err) - return nil - } - - ans64, didTranslate := xdns.TranslateRecords(ans4, dns.TypeA, func(r dns.RR) (rx []dns.RR, stop bool) { - if len(ip64) <= 0 { // can never be the case, see Local464Resolver - return nil, false - } - for _, ipnet := range ip64 { - if x := xdns.MaybeToQuadA(r, ipnet); x != nil { - rx = append(rx, x) - } - } - return - }) - - logwif(!didTranslate)("dns64: %s for %s: translated on %s[%s] response(%d)", - qname, uid, r, id, xdns.Len(ans64)) - - if !didTranslate { - // may be there were no A records in ans4; or, - // xdns.MaybeToQuadA failed for every A ans4 record - return nil - } - return ans64 -} - -// query64 answers IPv4 query from the given IPv6 query. -func (d *dns64) query64(network string, msg6 *dns.Msg, r, uid string) (*dns.Msg, error) { - msg4 := xdns.Request4FromResponse6(msg6) // may be nil - if msg4 == nil || !xdns.HasAnyQuestion(msg4) { - return nil, errQuery - } - - proto, _ := xdns.Net2ProxyID(network) - - q4 := xdns.QName(msg4) - - res, err := dialers.QueryFor(msg4, uid, r) - - hasAns := xdns.HasAnyAnswer(res) - log.D("dns64: for %s over %s: %s q(%s) / a(%t) / e(%v) / e-not-nil(%t)", - uid, proto, r, q4, hasAns, err, err != nil) - if err != nil { - return nil, err - } - if !hasAns { - return nil, errAns - } - // res.Truncated never likely happens w/ DOH, ODOH, DOT? - if res.Truncated && proto != dnsx.NetTypeTCP { - // else if: returned response is truncated dns ans, retry over tcp - res, err = dialers.QueryFor(msg4, uid, r) - - hasAns = xdns.HasAnyAnswer(res) - log.D("dns64: tcp: for %s over %s: q(%s) / a(%d) / e(%v) / e-not-nil(%t)", - uid, r, q4, hasAns, err, err != nil) - if err != nil { - return nil, err - } else if !hasAns { - return nil, errAns - } - } - return res, err -} - -func (d *dns64) ofOverlay() error { - ips, err := net.DefaultResolver.LookupIP(d.ctx, "ip6", dnsx.Rfc7050WKN) - log.I("dns64: ipv4only.arpa w underlying network resolver") - - if err != nil { - return err - } - - if len(ips) <= 0 { - return errNotFound - } - - d.register(dnsx.OverlayResolver) - return d.add(dnsx.OverlayResolver, ips) -} - -func (d *dns64) ofLocal464() error { - d.register(dnsx.Local464Resolver) - // send a copy of localip64 as d.add mutates its entries in-place - // this addr64, hopefully, isn't used by any other dns world-wide - localip64 := []net.IP{ - net.ParseIP("64:ff9b:1:fffe::192.0.0.170"), - } - return d.add(dnsx.Local464Resolver, localip64) -} - -// add adds the nat64 prefixes to the dns64 map; thread-safe. -func (d *dns64) add(serverid string, nat64 []net.IP) error { - - if len(nat64) <= 0 { - log.W("dns64: no nat64 ips for %s", serverid) - return errEmpty - } - - for _, ipv6 := range nat64 { - log.D("dns64: id(%s); add? nat64 ip(%s / %d)", serverid, ipv6, len(ipv6)) - if len(ipv6) != net.IPv6len { - continue - } - - endByte := 0 - if wka := net.IPv4(ipv6[12], ipv6[13], ipv6[14], ipv6[15]); wka.Equal(rfc7050WKA1) || wka.Equal(rfc7050WKA2) { //96 - endByte = 12 - } else if wka := net.IPv4(ipv6[9], ipv6[10], ipv6[11], ipv6[12]); wka.Equal(rfc7050WKA1) || wka.Equal(rfc7050WKA2) { //64 - endByte = 8 - } else if wka := net.IPv4(ipv6[7], ipv6[9], ipv6[10], ipv6[11]); wka.Equal(rfc7050WKA1) || wka.Equal(rfc7050WKA2) { //56 - endByte = 7 - } else if wka := net.IPv4(ipv6[6], ipv6[7], ipv6[9], ipv6[10]); wka.Equal(rfc7050WKA1) || wka.Equal(rfc7050WKA2) { //48 - endByte = 6 - } else if wka := net.IPv4(ipv6[5], ipv6[6], ipv6[7], ipv6[9]); wka.Equal(rfc7050WKA1) || wka.Equal(rfc7050WKA2) { //40 - endByte = 5 - } else if wka := net.IPv4(ipv6[4], ipv6[5], ipv6[6], ipv6[7]); wka.Equal(rfc7050WKA1) || wka.Equal(rfc7050WKA2) { //32 - endByte = 4 - } - - if endByte <= 0 { - log.I("dns64: id(%s), e(%d); no valid ipv4only.arpa in ans6(%v)", serverid, endByte, ipv6) - continue - } - - endBit := endByte * 8 - ipxx := new(net.IPNet) - // prefix ipv6 until the endByte, followed by all-zeros - // 64:ff9b:1::WKA -> 64:ff9b:1:: - ipxx.IP = append(ipv6[:endByte], net.IPv6zero[endByte:]...) - ipxx.Mask = net.CIDRMask(endBit, ipv6bits) - - if err := d.addNat64Prefix(serverid, ipxx); err != nil { - return err - } - } - - ip64 := d.get(serverid) - - if len(ip64) == 0 { - log.I("dns64: id(%s) has zero nat64 prefixes", serverid) - return errEmpty - } else { - return nil - } -} - -func (d *dns64) get(serverid string) []*net.IPNet { - d.RLock() - defer d.RUnlock() - return d.ip64[serverid] -} - -func (d *dns64) addNat64Prefix(id string, ipxx *net.IPNet) error { - d.Lock() - defer d.Unlock() - - ip64, ok1 := d.ip64[id] - uniq, ok2 := d.uniqIP64[id] - if !ok1 || !ok2 { - log.W("dns64: no server found server(%s)", id) - return errNoSuchServer - } - - // ipxx.String -> 64:ff9b:1::/mask - ipxxstr := ipxx.String() - _, exists := uniq[ipxxstr] - if !exists { - ip64 = append(ip64, ipxx) - uniq[ipxxstr] = emptyStruct - log.I("dns64: add ipnet [%s] for server(%s)", ipxx, id) - } else { - log.D("dns64: prefix6(%v) for server(%s) exists!", id, ipxx) - } - // nil / empty lists are valid values in map[string][]*net.IP - d.ip64[id] = ip64 - - return nil -} - -func logwif(cond bool) log.LogFn { - if cond { - return log.W - } - return log.V -} diff --git a/intra/x64/nat64.go b/intra/x64/nat64.go deleted file mode 100644 index 4d428219..00000000 --- a/intra/x64/nat64.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package x64 - -import ( - "context" - "net" - - "github.com/celzero/firestack/intra/log" -) - -type nat64 struct { -} - -func newNat64(_ context.Context) *nat64 { - return &nat64{} -} - -// IsNat64 Implements NAT64. -func (n *nat64) IsNat64(prefix64 *net.IPNet, ip6 net.IP) bool { - return prefix64.Contains(ip6) -} - -// xAddr translates ip6 to IPv4 discarding prefix64. -// If prefix64 or ip6 is not valid, it returns zerovalueaddr. -// If ip6 is unspecified, it returns unspecified IPv4. -func (n *nat64) xAddr(prefix64 *net.IPNet, ip6 net.IP) net.IP { - return ip6to4(prefix64, ip6) -} - -// ip6to4 converts ip6 to IPv4 discarding prefix64. -func ip6to4(prefix64 *net.IPNet, ip6 net.IP) net.IP { - if ip6.IsUnspecified() { - return net.IPv4zero - } - ip4 := make(net.IP, net.IPv4len) - bitmask, _ := prefix64.Mask.Size() // prefix64 expected to be never nil - startByte := bitmask / 8 - - if startByte+net.IPv4len > len(ip6) { - log.W("natpt: too long; cannot convert ip64(%v) / prefix64(%v) to ip4", ip6, prefix64) - return nil - } - - for i := 0; i < net.IPv4len; i++ { - i6 := startByte + i - // skip byte 8, datatracker.ietf.org/doc/html/rfc6052#section-2.2 - if i6 == 8 { - startByte++ - } - - ip4[i] = ip6[startByte+i] - } - return ip4 -} diff --git a/intra/x64/natpt.go b/intra/x64/natpt.go deleted file mode 100644 index 28733344..00000000 --- a/intra/x64/natpt.go +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2022 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package x64 - -import ( - "context" - "maps" - "net" - "net/netip" - - "github.com/celzero/firestack/intra/dnsx" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/settings" - "github.com/miekg/dns" -) - -// app | interface | pt | who | internet? -// ---- | -------- | -------- | ---- | -------- -// ip4 | ip4 | - | - | y -// ip4 | ip6 | 464xlat | os | y -// ---- | -------- | -------- | ---- | -------- -// ip6 | ip6 | - | - | y -// ip6 | ip4 | nat64 | rdns | y -// ---- | -------- | -------- | ---- | -------- -// ip4+6 | ip6 | 464xlat | os | y -// ip4+6 | ip4 | happyeye | app | y -// ---- | -------- | -------- | ---- | -------- -// ip4+6 | ip4+6 | bind | rdns | y -// ip4+6 | ip6+4 | bind | rdns | y -// -// datatracker.ietf.org/doc/html/rfc8305#section-7 -// nicmx.github.io/Jool/en/intro-xlat.html -type natPt struct { - *nat64 - *dns64 - ip4s []net.IP - ip6s []net.IP -} - -var _ dnsx.NatPt = (*natPt)(nil) - -var ( - unspecified4 = netip.IPv4Unspecified() - zerovalueaddr = netip.Addr{} -) - -func NewNatPt() *natPt { - return NewNatPt2(context.Background()) -} - -// NewNatPt returns a new NatPt. -func NewNatPt2(ctx context.Context) *natPt { - log.I("natpt: new; mode(%v)", settings.PtMode.Load()) - return &natPt{ - nat64: newNat64(ctx), - dns64: newDns64(ctx), - ip4s: nil, - ip6s: nil, - } -} - -// D64 Implements DNS64. -func (pt *natPt) D64(network, id, uid string, ans6 *dns.Msg) *dns.Msg { - ptmode := settings.PtMode.Load() - if ptmode != settings.PtModeNo46 { // do64 - force64 := ptmode == settings.PtModeForce64 - return pt.dns64.eval(network, force64, ans6, id, uid) - } - return nil -} - -// IsNat64 Implements NAT64. -func (n *natPt) IsNat64(id string, ip netip.Addr) bool { - prefixes := n.nat64PrefixForResolver(id) - return match(prefixes, addr2ip(ip)) != nil -} - -// X64 Implements NAT64. -func (n *natPt) X64(id string, ip6 netip.Addr) (ip4 netip.Addr) { - id = id64(id) - if !ip6.Is6() { - log.D("natpt: not ip6: %v", ip6) - return - } - - // blocked domains (with zero IPv6 addr) should always be translated - // to blocked IPv4 addr regardless of NAT64 prefix - if ip6.IsUnspecified() { - log.D("natpt: ip6(%v) is unspecified", ip6) - return unspecified4 - } - - rawip := addr2ip(ip6) - if id == dnsx.AnyResolver { - n.RLock() - all := make(map[string][]*net.IPNet, len(n.ip64)) - maps.Copy(all, n.ip64) - n.RUnlock() - - for tid, prefixes := range all { - if len(prefixes) <= 0 { - continue - } - if x := match(prefixes, rawip); x != nil { - return ip2addr(n.xAddr(x, rawip)) - } else { - log.V("natpt: no matching prefix64 for ip(%v) in id(%s/%d)", ip6, tid, len(prefixes)) - } - } - log.D("natpt: no prefix64 found for resolver(%s)", ip6, id) - return zerovalueaddr - } - - prefixes := n.nat64PrefixForResolver(id) - if len(prefixes) <= 0 { - log.D("natpt: no prefix64 found for resolver(%s)", ip6, id) - return zerovalueaddr - } - if x := match(prefixes, rawip); x != nil { - return ip2addr(n.xAddr(x, rawip)) - } else { - log.VV("natpt: no matching prefix64 for ip(%v) in id(%s/%d)", ip6, id, len(prefixes)) - } - return zerovalueaddr -} - -// Add64 implements DNS64. -func (h *natPt) Add64(id string) bool { - return h.dns64.AddResolver(id64(id), id) -} - -// Remove64 implements DNS64. -func (h *natPt) Remove64(id string) bool { - return h.dns64.RemoveResolver(id64(id)) -} - -func (n *natPt) ResetNat64Prefix(ip6prefix string) bool { - var err error - var ipnet *net.IPNet - if _, ipnet, err = net.ParseCIDR(ip6prefix); err == nil { - n.dns64.register(dnsx.UnderlayResolver) // wipe the slate clean - if err = n.dns64.addNat64Prefix(dnsx.UnderlayResolver, ipnet); err == nil { - return true - } - } - log.W("natpt: could not add underlay nat64 prefix: %s; err %v", ip6prefix, err) - return false -} - -// Returns the first matching local-interface net.IP for the network -func (n *natPt) UIP(network string) []byte { - switch network { - case "tcp6", "udp6": - if len(n.ip6s) > 0 { - return n.ip6s[0] - } - return net.IPv6zero - default: - if len(n.ip4s) > 0 { - return n.ip4s[0] - } - return net.IPv4zero - } -} - -func (n *natPt) nat64PrefixForResolver(id string) []*net.IPNet { - return n.get(id) -} - -// match returns the first matching prefix for ip in nets. -func match(nets []*net.IPNet, ip net.IP) *net.IPNet { - for _, p := range nets { - if p.Contains(ip) { - return p - } - } - return nil -} - -func ID64(t dnsx.Transport) string { - return id64(t.ID()) -} - -func id64(tid string) string { - switch tid { - case dnsx.System: - return dnsx.UnderlayResolver - case dnsx.Goos: - return dnsx.OverlayResolver - default: - return tid - } -} - -func addr2ip(ip netip.Addr) net.IP { - return net.IP(ip.AsSlice()) -} - -func ip2addr(ip net.IP) netip.Addr { - x, _ := netip.AddrFromSlice(ip) - return x.Unmap() -} diff --git a/intra/xdns/common.go b/intra/xdns/common.go deleted file mode 100644 index 7946ad8a..00000000 --- a/intra/xdns/common.go +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package xdns - -import ( - "encoding/binary" - "errors" - "fmt" - "net" - "net/netip" - "net/url" - "strings" - - "github.com/miekg/dns" -) - -type CryptoConstruction uint16 - -const ( - UndefinedConstruction CryptoConstruction = iota - XSalsa20Poly1305 - XChacha20Poly1305 -) - -const ( - ClientMagicLen = 8 - // X-Nile-Flags:[1:AAIAAQ] - blocklistHeaderKey = "x-nile-flags" // "x-bl-fl" - // github.com/serverless-dns/serverless-dns/blob/f247f75d31a/src/core/io-state.js#L188 - // X-Nile-Region:[sin] - rethinkdnsRegionHeaderKey = "x-nile-region" - // Cf-Ray:[d1e2a3d4b5e6e7f8-SIN] - cfRayHeaderKey = "cf-ray" -) - -var ( - CertMagic = [4]byte{0x44, 0x4e, 0x53, 0x43} - ServerMagic = [8]byte{0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38} -) - -const ( - MinDNSPacketSize = 12 + 5 - MaxDNSPacketSize = 4096 - MaxDNSUDPPacketSize = 4096 - MaxDNSUDPSafePacketSize = 1252 - - // 0 TTL means no caching: - // cs.android.com/android/platform/superproject/main/+/main:packages/modules/DnsResolver/res_cache.cpp;l=770;drc=5483e926ea7753866350b1681fef8f3214708261 - ZeroTTL = uint32(0) - - // Network MTU - MaxMTU = 0xffff // 65k, ought to be enough for everybody - - // disable Android dnsproxyd caches - BustDnsproxydResNetCache = true -) - -var ( - // 0 TTL means no caching - blockTTL = uint32(5) - // some short-lived TTL for synthesized answers - ansTTL = uint32(15) -) - -func init() { - if BustDnsproxydResNetCache { - // setting these to 0 trips apps like GrayJay? github.com/futo-org/grayjay-android/issues/2605 - blockTTL = ZeroTTL - ansTTL = ZeroTTL - } -} - -var ( - ip4zero = net.IPv4zero - ip6zero = net.IPv6unspecified - dnsport = uint16(53) -) - -const ( - mdnsip4 = "224.0.0.251" - mdnsip6 = "ff02::fb" - mdnsPort = 5353 - - arpa4suffix = "254.169.in-addr.arpa" - // "8.e.f.ip6.arpa.", "9.e.f.ip6.arpa.", "a.e.f.ip6.arpa.", and "b.e.f.ip6.arpa." - arpa6suffix = "e.f.ip6.arpa" - localsuffix = "local" -) - -var ( - MDNSAddr4 = &net.UDPAddr{ - IP: net.ParseIP(mdnsip4), - Port: mdnsPort, - } - MDNSAddr6 = &net.UDPAddr{ - IP: net.ParseIP(mdnsip6), - Port: mdnsPort, - } -) - -var ( - errMassivePkt = errors.New("packet too large") - errRdnsUrlMissing = errors.New("url missing") - errNoAns = errors.New("no answer record") - errNoPacket = errors.New("nil dns msg") - errNotAscii = errors.New("name not ASCII string") -) - -// Net2ProxyID splits network string into proto and pid; -// proto is the network protocol and pid is the proxy ID. -// May return empty strings. -func Net2ProxyID(network string) (proto string, pids []string) { - x := strings.Split(network, ":") - if len(x) <= 0 { - return // empty - } - if len(x) >= 1 { - proto = x[0] - } - if len(x) >= 2 { - pids = strings.Split(x[1], ",") - if firstEmpty(pids) { - pids = nil - } - } - return -} - -// Bust cache as needed and if ans is not authenticated. -func BustAndroidCacheIfNeeded(ans *dns.Msg) bool { - if BustDnsproxydResNetCache && !IsDNSSECAnswerAuthenticated(ans) { - // TODO: skip negative records (SOA, NXDOMAIN, etc) - return WithTtl(ans, ZeroTTL, dns.TypeA, dns.TypeAAAA) - } - return false -} - -// NetAndProxyID joins proto and pid into a network string. -// proto is the network protocol and pid is the proxy ID. -// May return just the separator ":", if both proto, pid are empty. -func NetAndProxyID(proto string, pidcsv ...string) string { - return fmt.Sprintf("%s:%s", proto, strings.Join(pidcsv, ",")) -} - -func PrefixWithSize(packet []byte) ([]byte, error) { - packetLen := len(packet) - if packetLen > MaxMTU { - return packet, errMassivePkt - } - packet = append(append(packet, 0), 0) - copy(packet[2:], packet[:len(packet)-2]) - binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)-2)) - return packet, nil -} - -func Min(a, b int) int { - if a < b { - return a - } - return b -} - -func Max(a, b int) int { - if a > b { - return a - } - return b -} - -func StringReverse(s string) string { - r := []rune(s) - for i, j := 0, len(r)-1; i < len(r)/2; i, j = i+1, j-1 { - r[i], r[j] = r[j], r[i] - } - return string(r) -} - -// returns unique strings in n not in s as new array -func FindUnique(s []string, n []string) (u []string) { - if len(s) == 0 { - return n - } - if len(n) == 0 { - return u - } - - for _, e := range n { - uniq := true - for _, x := range s { - if e == x { - uniq = false - break - } - } - if uniq { - u = append(u, e) - } - } - - return -} - -// TODO: merge this with doh.Accept -func ReadPrefixed(conn *net.Conn) ([]byte, error) { - buf := make([]byte, 2+MaxDNSPacketSize) - packetLength, pos := -1, 0 - for { - readnb, err := (*conn).Read(buf[pos:]) - if err != nil { - return buf, err - } - pos += readnb - if pos >= 2 && packetLength < 0 { - packetLength = int(binary.BigEndian.Uint16(buf[0:2])) - if packetLength > MaxDNSPacketSize-1 { - return buf, errors.New("resp packet too large") - } - if packetLength < MinDNSPacketSize { - return buf, fmt.Errorf("resp packet too short %d", packetLength) - } - } - if packetLength >= 0 && pos >= 2+packetLength { - return buf[2 : 2+packetLength], nil - } - } -} - -// TODO: Move to dnsx? -func GetBlocklistStampFromURL(rawurl string) (string, error) { - if len(rawurl) <= 0 { - return "", errRdnsUrlMissing - } - // TODO: validate if the domain is bravedns.com/rethinkdns.com? - u, err := url.Parse(rawurl) - if err != nil { - return "", err - } - // p://url.tld or p://url.tld/ - if len(u.Path) <= 1 { - return "", errors.New("no path") - } - s := strings.TrimLeft(u.Path, "/") - i := strings.Index(s, ":") // stamps with ":" are versioned - if i == -1 { - return url.QueryEscape(s), nil - } else { // versioned stamps use path-escape - return url.PathEscape(s), nil - } - // url => path => split: - // url/p/q/r => /p/q/r/ => [' ', 'p', 'q', 'r', ' '] - // url/ => / => [' ', ' '] - // url/a/ => /a/ => [' ', 'a', ' '] - // url => "" => [''] - /* - FIXME: breaks when encoded("/") is in u.Path as split matches it too - ex:74CA77%2B%2F77%2B%2F77%2B%2F77CA -> splits-to -> [74CA77+ 77+ 77+ 77CA] - since %2F is a "/" - p := strings.Split(u.Path, "/") - if (len(p) <= 1) { - return "", errors.New("empty path") - } else if (p[1] != "dns-query" && len(p[1]) > 0) { - return p[1], nil // TODO: validate stamp? - } else if (len(p) >= 3 && len(p[2]) > 0) { - return p[2], nil // validate? - } - return "", errors.New("first two path positions missing stamp") - */ -} - -func DnsIPPort(s string) (ipp netip.AddrPort, err error) { - var ip netip.Addr - if ipp, err = netip.ParseAddrPort(s); err != nil { - if ip, err = netip.ParseAddr(s); err == nil { - ipp = netip.AddrPortFrom(ip, dnsport) - } - } - return -} - -func firstEmpty(arr []string) bool { - return len(arr) <= 0 || len(arr[0]) <= 0 -} diff --git a/intra/xdns/dnsutil.go b/intra/xdns/dnsutil.go deleted file mode 100644 index 4d10bf58..00000000 --- a/intra/xdns/dnsutil.go +++ /dev/null @@ -1,1549 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// ISC License -// -// Copyright (c) 2018-2021 -// Frank Denis - -package xdns - -import ( - "fmt" - "net" - "net/http" - "net/netip" - "strings" - "unicode/utf8" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/miekg/dns" -) - -const paddingBlockSize = 128 // RFC8467 recommendation - -// OPTION-CODE + OPTION-LENGTH -const optPaddingHeaderLen int = 2 + 2 - -func AsMsg(packet []byte) *dns.Msg { - msg, err := AsMsg2(packet) - if err != nil { - log.W("dnsutil: as msg err: %v", err) - } - return msg -} - -func AsMsg2(packet []byte) (*dns.Msg, error) { - if len(packet) < MinDNSPacketSize { - return nil, errNoPacket - } - msg := &dns.Msg{} - if err := msg.Unpack(packet); err != nil { - log.D("dnsutil: failed to unpack msg: %v", err) - return nil, err - } - return msg, nil -} - -func RequestFromResponse(msg *dns.Msg) *dns.Msg { - req := &dns.Msg{ - Compress: true, - } - req.SetQuestion(QName(msg), QType(msg)) - req.RecursionDesired = true - req.CheckingDisabled = false - req.AuthenticatedData = false - req.Authoritative = false - req.Id = msg.Id - return req -} - -func Request4FromResponse6(msg6 *dns.Msg) *dns.Msg { - if !HasAnyQuestion(msg6) { - return nil - } - msg4 := &dns.Msg{ - Compress: true, - } - msg4.SetQuestion(QName(msg6), dns.TypeA) - msg4.RecursionDesired = true - msg4.CheckingDisabled = false - msg4.AuthenticatedData = false - msg4.Authoritative = false - msg4.Id = msg6.Id - return msg4 -} - -func Request4FromRequest6(msg6 *dns.Msg) *dns.Msg { - if !HasAnyQuestion(msg6) { - return nil - } - msg4 := msg6.Copy() - msg4.SetQuestion(QName(msg6), dns.TypeA) - return msg4 -} - -func EmptyResponseFromMessage(srcMsg *dns.Msg) *dns.Msg { - if !HasAnyQuestion(srcMsg) { - return nil - } - dstMsg := dns.Msg{ - MsgHdr: srcMsg.MsgHdr, // copy id, flags, etc - Compress: true, - } - dstMsg.Question = srcMsg.Question - dstMsg.Response = true - if srcMsg.RecursionDesired { - dstMsg.RecursionAvailable = true - } - dstMsg.RecursionDesired = false - dstMsg.CheckingDisabled = false - dstMsg.AuthenticatedData = false - if edns0 := srcMsg.IsEdns0(); edns0 != nil { - dstMsg.SetEdns0(edns0.UDPSize(), edns0.Do()) - } - return &dstMsg -} - -func TruncatedResponse(packet []byte) ([]byte, error) { - if len(packet) <= 0 { - return nil, errNoAns - } - srcMsg := &dns.Msg{} - if err := srcMsg.Unpack(packet); err != nil { - return nil, err - } - dstMsg := EmptyResponseFromMessage(srcMsg) // may be nil - if dstMsg == nil { - return nil, errNoAns - } - dstMsg.Truncated = true - return dstMsg.Pack() -} - -func HasTCFlag(msg *dns.Msg) bool { - if msg == nil { - return false - } - return msg.Truncated -} - -func HasTCFlag2(packet []byte) bool { - if len(packet) < 2 { - return false - } - return packet[2]&2 == 2 -} - -func QName(msg *dns.Msg) string { - if msg == nil || len(msg.Question) <= 0 || !HasAnyQuestion(msg) { - return "" - } - q := msg.Question[0] - return q.Name -} - -func AName(ans dns.RR) (string, error) { - if ans != nil { - if ah := ans.Header(); ah != nil { - n := ah.Name - return NormalizeQName(n) - } - } - return "", errNoAns -} - -func QType(msg *dns.Msg) uint16 { - if HasAnyQuestion(msg) && len(msg.Question) > 0 { - return msg.Question[0].Qtype - } - return dns.TypeNone -} - -func Rcode(msg *dns.Msg) int { - if msg != nil { - return msg.Rcode - } - return dns.RcodeFormatError -} - -func WithTtl(msg *dns.Msg, secs uint32, typ ...uint16) (ok bool) { - if !HasAnyAnswer(msg) { - return ok - } - for _, a := range msg.Answer { - if a == nil { - continue - } - if a.Header().Ttl <= 0 || a.Header().Ttl == secs { - continue - } - resetTtl := len(typ) <= 0 - for _, t := range typ { - if a.Header().Rrtype == t { - resetTtl = true - break - } - } - if resetTtl { - a.Header().Ttl = secs - ok = true - } - } - return ok -} - -func RTtl(msg *dns.Msg) int { - maxttl := uint32(0) - if msg == nil || !HasAnyAnswer(msg) { - return int(maxttl) - } - - for _, a := range msg.Answer { - if a.Header().Ttl > 0 { - ttl := a.Header().Ttl - if maxttl < ttl { - maxttl = ttl - } - } - } - return int(maxttl) -} - -func GetTargets(msg *dns.Msg) string { - if msg == nil { - return "--" - } - - if !msg.Response { - return QName(msg) - } - - targets := make(map[string]struct{}, len(msg.Answer)) - for _, a := range msg.Answer { - nom := a.Header().Name - if len(nom) > 0 { - targets[nom] = struct{}{} - } - } - var sb strings.Builder - sb.Grow(len(targets)) - for k := range targets { - sb.WriteString(k) - sb.WriteString(",") - } - return strings.TrimSuffix(sb.String(), ",") -} - -func GetInterestingRData(msg *dns.Msg) string { - if msg == nil { - return "--" - } - var ipcsv string - ip4s := IPHints(msg, dns.SVCB_IPV4HINT) - ip6s := IPHints(msg, dns.SVCB_IPV6HINT) - data := make([]string, 0) - if len(ip4s) > 0 { - data = append(data, netips2str(ip4s)...) - } - if len(ip6s) > 0 { - data = append(data, netips2str(ip6s)...) - } - if len(data) > 0 { - ipcsv += strings.Join(data, ",") - log.D("dnsutil: RData: %s", ipcsv) - } - for _, a := range msg.Answer { - switch r := a.(type) { - case *dns.A: - if len(ipcsv) > 0 { - ipcsv += "," + ip2str(r.A) - } else { - ipcsv += ip2str(r.A) - } - case *dns.AAAA: - if len(ipcsv) > 0 { - ipcsv += "," + ip2str(r.AAAA) - } else { - ipcsv += ip2str(r.AAAA) - } - case *dns.NS: - return r.Ns - case *dns.TXT: - if len(r.Txt) > 0 { - return r.Txt[0] - } - return r.String() - case *dns.SOA: - return r.Mbox - case *dns.HINFO: - return r.Os - case *dns.SRV: - return r.Target - case *dns.CAA: - return r.Value - case *dns.MX: - return r.Mx - case *dns.RP: - return r.Mbox - case *dns.DNSKEY: - return r.PublicKey - case *dns.DS: - return r.Digest - case *dns.RRSIG: - return r.SignerName - case *dns.SVCB: - // if no hints, simply dump the entire kv list - if len(ip4s) <= 0 && len(ip6s) <= 0 { - if len(ipcsv) > 0 { - ipcsv += "," + r.String() - } else { - log.V("dnsutil: RData: svcb(%s)", r) - return svcbstr(r) - } - } else { - log.D("dnsutil: RData: ignored svcb(%s) for ipcsv(%s)", r, ipcsv) - } - continue - case *dns.HTTPS: - // if no hints, simply dump the entire kv list - if len(ip4s) <= 0 && len(ip6s) <= 0 { - if len(ipcsv) > 0 { - ipcsv += "," + r.String() - } else { - log.V("dnsutil: RData: https(%s)", r) - return httpsstr(r) - } - } else { - // https(sky.rethinkdns.com. 300 IN HTTPS 1 . - // alpn="h3,h2" - // ipv4hint="104.21.83.62,172.67.214.246" - // ech="AEX+DQBB4gAgACBdYSRjAsOpA+y22/VDM2YR/3fxGdNuepJpi9gJZm8nPgAEAAEAAQASY2xvdWRmbGFyZS1lY2guY29tAAA=" - // ipv6hint="2606:4700:3030::6815:533e,2606:4700:3030::ac43:d6f6") - // for ipcsv(104.21.83.62,172.67.214.246,2606:4700:3030::6815:533e,2606:4700:3030::ac43:d6f6) - log.D("dnsutil: RData: ignored https(%s) for ipcsv(%s)", r, ipcsv) - } - continue - case *dns.NSEC: - return r.NextDomain - case *dns.NSEC3: - return r.NextDomain - case *dns.NSEC3PARAM: - return r.Salt - case *dns.TLSA: - return r.Certificate - case *dns.OPT: - if len(ipcsv) > 0 { - ipcsv += "," + r.String() - } else { - return r.String() - } - case *dns.APL: - if len(ipcsv) > 0 { - ipcsv += "," + r.String() - } else { - return r.String() - } - case *dns.SSHFP: - return r.FingerPrint - case *dns.DNAME: - return r.Target - case *dns.NAPTR: - return r.Service - case *dns.CERT: - return r.Certificate - case *dns.DLV: - return r.Digest - case *dns.DHCID: - return r.Digest - case *dns.SMIMEA: - return r.Certificate - case *dns.NINFO: - var str string - if len(r.ZSData) > 0 { - str = r.ZSData[0] - } else { - str = r.String() - } - if len(ipcsv) > 0 { - ipcsv += "," + str - } else { - return str - } - case *dns.RKEY: - return r.PublicKey - case *dns.TKEY: - return r.OtherData - case *dns.TSIG: - return r.OtherData - case *dns.URI: - return r.Target - case *dns.HIP: - return r.PublicKey - case *dns.CDS: - return r.Digest - case *dns.OPENPGPKEY: - return r.PublicKey - case *dns.SPF: - var str string - if len(r.Txt) > 0 { - return r.Txt[0] - } else { - str = r.String() - } - if len(ipcsv) > 0 { - ipcsv += "," + str - } else { - return str - } - case *dns.NSAPPTR: - return r.Ptr - case *dns.TALINK: - return r.NextName - case *dns.CSYNC: - if len(ipcsv) > 0 { - ipcsv += "," + r.String() - } else { - return r.String() - } - case *dns.ZONEMD: - return r.Digest - default: - // no-op - continue - } - } - if len(ipcsv) > 0 { - return strings.TrimSuffix(ipcsv, ",") - } else { - return "--" - } -} - -func Targets(msg *dns.Msg) (targets []string) { - if msg == nil { - return targets - } - touched := make(map[string]struct{}) - if qname, err := NormalizeQName(QName(msg)); err == nil { - targets = append(targets, qname) - touched[qname] = struct{}{} - } - for _, a := range msg.Answer { - var target string - switch r := a.(type) { - case *dns.A: - target = r.Header().Name - case *dns.AAAA: - target = r.Header().Name - case *dns.CNAME: - target = r.Target - case *dns.SVCB: - // discard "." and "" targets - if r.Priority == 0 && len(r.Target) > 1 { - target = r.Target - } - case *dns.HTTPS: - // discard "." and "" targets - if r.Priority == 0 && len(r.Target) > 1 { - target = r.Target - } - default: - // no-op - } - if len(target) <= 0 { - continue - } else if _, ok := dns.IsDomainName(target); !ok { - // discard targets not domain names such as "." - continue - } else if x, err := NormalizeQName(target); err == nil { - if _, has := touched[x]; !has { - targets = append(targets, x) - touched[x] = struct{}{} - } - } - } - return targets -} - -func NormalizeQName(str string) (string, error) { - if len(str) == 0 || str == "." { - return ".", nil - } - hasUpper := false - str = strings.TrimSuffix(str, ".") - strLen := len(str) - for i := range strLen { - c := str[i] - if c >= utf8.RuneSelf { - return str, errNotAscii - } - hasUpper = hasUpper || ('A' <= c && c <= 'Z') - } - if !hasUpper { - return str, nil - } - var b strings.Builder - b.Grow(len(str)) - for i := range strLen { - c := str[i] - if 'A' <= c && c <= 'Z' { - c += 'a' - 'A' - } - b.WriteByte(c) - } - return b.String(), nil -} - -func RemoveEDNS0Options(msg *dns.Msg) bool { - if msg == nil { - return false - } - edns0 := msg.IsEdns0() - if edns0 == nil { - return false - } - edns0.Option = []dns.EDNS0{} - return true -} - -func ensureEDNS0(msg *dns.Msg) *dns.OPT { - edns0 := msg.IsEdns0() - if edns0 == nil { - msg.SetEdns0(uint16(MaxDNSPacketSize), false) - return msg.IsEdns0() - } - return edns0 -} - -// Create an appropriately-sized padding option. -func optPadding(sz int) *dns.EDNS0_PADDING { - return &dns.EDNS0_PADDING{ - Padding: make([]byte, sz), - } -} - -// Compute the number of padding bytes needed, excluding headers. -func ComputePaddingSize(msg *dns.Msg) int { - if msg == nil { - return 0 - } - - // msgLen is the length of a raw DNS message that contains an - // OPT RR with no RFC7830 padding option, and that the message is fully - // label-compressed. - msgLen := msg.Len() - // always add a new padding header inside the OPT RR's data. - extraPadding := optPaddingHeaderLen - - padSize := paddingBlockSize - (msgLen+extraPadding)%paddingBlockSize - return padSize % paddingBlockSize -} - -func AddEDNS0PaddingIfNoneFound(msg *dns.Msg) { - if msg == nil { - return - } - - edns0 := ensureEDNS0(msg) - if edns0 == nil { - return - } - - if edns0padlen(edns0) >= 0 { // -1 = no edns0 padding rr - return - } - - paddingLen := ComputePaddingSize(msg) - if paddingLen <= 0 { - return - } - - edns0.Option = append(edns0.Option, optPadding(paddingLen)) -} - -// IsDNSSECRequested checks if the DNSSEC OK (DO) bit is set in the DNS query. -func IsDNSSECRequested(q *dns.Msg) bool { - if q != nil { - if edns0 := q.IsEdns0(); edns0 != nil { - return edns0.Do() - } - } - return false -} - -// IsDNSSECAnswerAuthenticated checks if the DNSSEC authenticated bit is set in the DNS answer. -func IsDNSSECAnswerAuthenticated(a *dns.Msg) bool { - if a != nil { - return a.AuthenticatedData - } - return false -} - -func CopyAns(a *dns.Msg) *dns.Msg { - if a == nil { - return nil - } - out := a.Copy() - out.AuthenticatedData = false - out.RecursionDesired = false - out.CheckingDisabled = false - // TODO: msg.Ns = nil - // TODO: msg.Extra = nil - return out -} - -func Question(domain string, qtyp uint16) ([]byte, error) { - msg := &dns.Msg{} - msg.SetQuestion(dns.Fqdn(domain), qtyp) - return msg.Pack() -} - -func BlockResponseFromMessage(q []byte) (*dns.Msg, error) { - r := &dns.Msg{} - if err := r.Unpack(q); err != nil { - return r, err - } - return RefusedResponseFromMessage(r) -} - -func RefusedResponseFromMessage(srcMsg *dns.Msg) (dstMsg *dns.Msg, err error) { - if srcMsg == nil { - return nil, errNoPacket - } - dstMsg = EmptyResponseFromMessage(srcMsg) // may be nil - if dstMsg == nil { - return nil, errNoPacket - } - dstMsg.Rcode = dns.RcodeSuccess - ttl := blockTTL - - questions := srcMsg.Question - if len(questions) == 0 { - log.W("dnsutil: no q in msg %s", srcMsg) - return - } - - question := questions[0] - sendHInfoResponse := true - - if question.Qtype == dns.TypeA { - rr := new(dns.A) - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: ttl, - } - rr.A = ip4zero.To4() - if rr.A != nil { - dstMsg.Answer = []dns.RR{rr} - sendHInfoResponse = false - } - } else if question.Qtype == dns.TypeAAAA { - rr := new(dns.AAAA) - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: ttl, - } - rr.AAAA = ip6zero.To16() - if rr.AAAA != nil { - dstMsg.Answer = []dns.RR{rr} - sendHInfoResponse = false - } - } else if IsSVCBQuestion(&question) || IsHTTPQuestion(&question) { - // NODATA datatracker.ietf.org/doc/draft-ietf-dnsop-svcb-https/11 pg 37 - // prefetch.net/blog/2016/09/28/the-subtleties-between-the-nxdomain-noerror-and-nodata-dns-response-codes/ - dstMsg.Answer = nil - // NOEXTRA datatracker.ietf.org/doc/draft-ietf-dnsop-svcb-https/11 pg 16 sec 4.2 - dstMsg.Extra = nil - sendHInfoResponse = false - } - - if sendHInfoResponse { - hinfo := new(dns.HINFO) - hinfo.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeHINFO, - Class: dns.ClassINET, - Ttl: ttl, - } - hinfo.Cpu = "These are not the queries you are" - hinfo.Os = "looking for" - dstMsg.Answer = []dns.RR{hinfo} - } - - return -} - -func AQuadAForQuery(q *dns.Msg, ips ...netip.Addr) (a *dns.Msg, err error) { - return AQuadAForQueryTTL(q, ansTTL, ips...) -} - -func AQuadAForQueryTTL(q *dns.Msg, ttl uint32, ips ...netip.Addr) (a *dns.Msg, err error) { - if q == nil { - return nil, errNoPacket - } - a = EmptyResponseFromMessage(q) // may return nil - if a == nil { - return nil, errNoPacket - } - a.Rcode = dns.RcodeSuccess - - questions := q.Question - if len(questions) == 0 { - log.W("dnsutil: no q in msg %s", q) - return - } - - hasanswers := false - question := questions[0] - - for _, ip := range ips { - ipun := ip.Unmap() - is4 := ipun.Is4() - is6 := ip.Is6() - - if question.Qtype == dns.TypeA && is4 { - rr := new(dns.A) - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: ttl, - } - rr.A = ipun.AsSlice() - if len(rr.A) > 0 { - hasanswers = true - a.Answer = append(a.Answer, rr) - } - } else if question.Qtype == dns.TypeAAAA && is6 { - rr := new(dns.AAAA) - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: ttl, - } - rr.AAAA = ip.AsSlice() - if len(rr.AAAA) > 0 { - hasanswers = true - a.Answer = append(a.Answer, rr) - } - } - } - if !hasanswers { - log.E("dnsutil: unexpected q %d(%s) for ans(%s)", question.Qtype, question.Name, ips) - return nil, errNoAns - } - - return -} - -func HasRcodeSuccess(msg *dns.Msg) bool { - return msg != nil && msg.Rcode == dns.RcodeSuccess -} - -func HasAnyAnswer(msg *dns.Msg) bool { - return msg != nil && len(msg.Answer) > 0 -} - -func IsNXDomain(msg *dns.Msg) bool { - return msg != nil && msg.Rcode == dns.RcodeNameError -} - -func IsARecord(rr dns.RR) bool { - return rr != nil && core.IsNotNil(rr) && rr.Header().Rrtype == dns.TypeA -} - -func HasAAnswer(msg *dns.Msg) bool { - for _, answer := range msg.Answer { - if answer.Header().Rrtype == dns.TypeA { - rec, ok := answer.(*dns.A) - if ok && len(rec.A) >= net.IPv4len { - return true - } - } - } - return false -} - -func HasAAAAAnswer(msg *dns.Msg) bool { - for _, answer := range msg.Answer { - if answer.Header().Rrtype == dns.TypeAAAA { - rec, ok := answer.(*dns.AAAA) - if ok && len(rec.AAAA) == net.IPv6len { - return true - } - } - } - return false -} - -func SubstAAAARecords(out *dns.Msg, subip6s netip.Addr, ttl uint32) bool { - if out == nil || !subip6s.IsValid() { - return false - } - // substitute ips in any a / aaaa records - touched := make(map[string]struct{}) - rrs := make([]dns.RR, 0) - i := 0 - for _, answer := range out.Answer { - switch rec := answer.(type) { - case *dns.AAAA: - // one aaaa rec per name - if _, ok := touched[rec.Hdr.Name]; !ok { - name := rec.Hdr.Name - ip6 := subip6s.String() // todo: use different ips for different names - touched[rec.Hdr.Name] = struct{}{} - if aaaanew := MakeAAAARecord(name, ip6, ttl); aaaanew != nil { - rrs = append(rrs, aaaanew) - i++ - } else { - log.D("dnsutil: subst AAAA rec fail for %s %s %d", name, ip6, ttl) - } - } - default: - // append cnames and other records as is - rrs = append(rrs, rec) - } - } - if len(rrs) > 0 { - out.Answer = rrs - } - return len(touched) > 0 -} - -func SubstARecords(out *dns.Msg, subip4s netip.Addr, ttl uint32) bool { - if out == nil || !subip4s.IsValid() { - return false - } - // substitute ips in any a / aaaa records - touched := make(map[string]struct{}) - rrs := make([]dns.RR, 0) - i := 0 - for _, answer := range out.Answer { - switch rec := answer.(type) { - case *dns.A: - // one a rec per name - if _, ok := touched[rec.Hdr.Name]; !ok { - name := rec.Hdr.Name - ip4 := subip4s.Unmap().String() // todo: use different ips for different names - touched[rec.Hdr.Name] = struct{}{} - if anew := MakeARecord(name, ip4, ttl); anew != nil { - rrs = append(rrs, anew) - i++ - } else { - log.D("dnsutil: subst A rec fail for %s %s %d", name, ip4, ttl) - } - } - default: - // append cnames and other records as is - rrs = append(rrs, rec) - } - } - if len(rrs) > 0 { - out.Answer = rrs - } - return len(touched) > 0 -} - -func TranslateRecords(ansin *dns.Msg, typ uint16, translate func(dns.RR) (x []dns.RR, stop bool)) (ansout *dns.Msg, didTranslate bool) { - if !HasAnyAnswer(ansin) { - return - } - ansout = EmptyResponseFromMessage(ansin) // may be nil - if ansout == nil { - return - } - rrout := make([]dns.RR, 0, len(ansin.Answer)) - for _, rr := range ansin.Answer { - if rr.Header().Rrtype != typ { - // could be a CNAME record which must be preserved as-is - // to maintain the integrity of the response; as MaybeToQuadA - // will reject any non-A records. - // qname: a.com - // ans: a.com -> cname -> b.com -> ipv4 - // translated: a.com -> cname -> b.com -> ipv4 - rrout = append(rrout, rr) - } else { - rrx, stop := translate(rr) - if rrx == nil { - rrout = append(rrout, rr) - if stop { - break - } - } else { - didTranslate = true - rrout = append(rrout, rrx...) - if stop { - break - } - } - } - } - ansout.Answer = append(ansout.Answer, rrout...) - return -} - -func svcbstr(r *dns.SVCB) (s string) { - if r == nil { - return - } - for _, kv := range r.Value { - k := kv.Key().String() - v := kv.String() - s += fmt.Sprintf("%s=%s ", k, v) - } - return s -} - -func httpsstr(r *dns.HTTPS) (s string) { - if r == nil { - return - } - for _, kv := range r.Value { - k := kv.Key().String() - v := kv.String() - s += fmt.Sprintf("%s=%s ", k, v) - } - return strings.TrimSpace(s) -} - -func SubstSVCBRecordIPs(out *dns.Msg, x dns.SVCBKey, subiphints netip.Addr, ttl uint32) bool { - if out == nil || !subiphints.IsValid() { - return false - } - // substitute ip hints in https / svcb records - i := 0 - for _, answer := range out.Answer { - switch rec := answer.(type) { - case *dns.SVCB: - for j, kv := range rec.Value { - k := kv.Key() - // replace with a single ip hint - if k == x && x == dns.SVCB_IPV6HINT { - rec.Value[j] = &dns.SVCBIPv6Hint{ - Hint: []net.IP{subiphints.AsSlice()}, - } - rec.Hdr.Ttl = ttl - i++ - } else if k == x && x == dns.SVCB_IPV4HINT { - rec.Value[j] = &dns.SVCBIPv4Hint{ - Hint: []net.IP{subiphints.AsSlice()}, - } - rec.Hdr.Ttl = ttl - i++ - } - } - case *dns.HTTPS: - if rec.Priority == 0 || len(rec.Target) > 1 { - // no kv pairs to process for https records when pri is 0 - // datatracker.ietf.org/doc/draft-ietf-dnsop-svcb-https/ section 1.2 - continue - } - for j, kv := range rec.Value { - k := kv.Key() - // replace with a single ip hint - if k == x && x == dns.SVCB_IPV6HINT { - rec.Value[j] = &dns.SVCBIPv6Hint{ - Hint: []net.IP{subiphints.AsSlice()}, - } - rec.Hdr.Ttl = ttl - i++ - } else if k == x && x == dns.SVCB_IPV4HINT { - rec.Value[j] = &dns.SVCBIPv4Hint{ - Hint: []net.IP{subiphints.AsSlice()}, - } - rec.Hdr.Ttl = ttl - i++ - } - } - } - } - if i > 0 { - // datatracker.ietf.org/doc/draft-ietf-dnsop-svcb-https/11 pg 16 sec 4.2 - // remove additional records, as they may further have svcb or a / aaaa records - out.Extra = nil - } - return i > 0 -} - -func IPs(msg *dns.Msg) []netip.Addr { - return AQuadAAnswers(msg) -} - -func IPHints(msg *dns.Msg, x dns.SVCBKey) []netip.Addr { - if msg == nil { - return nil - } - qname, _ := NormalizeQName(QName(msg)) - if !HasSVCBQuestion(msg) && !HasHTTPQuestion(msg) { - log.N("dnsutil: svcb/https(%s): no record(%d)", qname, len(msg.Answer)) - return nil - } - - // extract ip hints from https / svcb records - // tools.ietf.org/html/draft-ietf-dnsop-svcb-https-02#section-8.1 - ips := []netip.Addr{} - for _, answer := range msg.Answer { - if !(answer.Header().Rrtype == dns.TypeHTTPS) && !(answer.Header().Rrtype == dns.TypeSVCB) { - continue - } - switch rec := answer.(type) { - case *dns.SVCB: - for _, kv := range rec.Value { - log.V("dnsutil: svcb(%s): current k(%v)/v(%s)", qname, kv.Key(), kv) - if kv.Key() != x { - continue - } - // ipcsv may be "" or a csv of ips - ipcsv := kv.String() - for ipstr := range strings.SplitSeq(ipcsv, ",") { - if v, err := netip.ParseAddr(ipstr); err == nil { - ips = append(ips, v) - } else { - log.W("dnsutil: svcb(%s): could not parse iphint %v", qname, ipstr) - } - } - } - case *dns.HTTPS: - for _, kv := range rec.Value { - log.V("dnsutil: https(%s): current k(%v)/v(%s)", qname, kv.Key(), kv) - if kv.Key() != x { - continue - } - // ipcsv may be "" or a csv of ips - ipcsv := kv.String() - for ipstr := range strings.SplitSeq(ipcsv, ",") { - if v, err := netip.ParseAddr(ipstr); err == nil { - ips = append(ips, v) - } else { - log.W("dnsutil: https(%s): could not parse iphint %v", qname, ipstr) - } - } - } - } - } - note := log.D - if len(ips) > 0 { - note = log.VV - } - note("dnsutil: svcb/https(%s): ip hints %v from %d answers", qname, ips, len(msg.Answer)) - return ips -} - -func AQuadAAnswers(msg *dns.Msg) (ips []netip.Addr) { - if msg == nil { - return ips - } - for _, answer := range msg.Answer { - switch rec := answer.(type) { - case *dns.A: - if ipaddr, ok := netip.AddrFromSlice(rec.A); ok { - ips = append(ips, ipaddr) - } - case *dns.AAAA: - if ipaddr, ok := netip.AddrFromSlice(rec.AAAA); ok { - ips = append(ips, ipaddr) - } - } - } - return ips -} - -func AAnswer(msg *dns.Msg) []netip.Addr { - a4 := []netip.Addr{} - if msg == nil { - return a4 - } - for _, answer := range msg.Answer { - if answer.Header().Rrtype == dns.TypeA { - if rec, ok := answer.(*dns.A); ok { - if ipaddr, ok := netip.AddrFromSlice(rec.A); ok { - a4 = append(a4, ipaddr) - } - } - } - } - return a4 -} - -func AAAAAnswer(msg *dns.Msg) []netip.Addr { - a6 := []netip.Addr{} - if msg == nil { - return a6 - } - for _, answer := range msg.Answer { - if answer.Header().Rrtype == dns.TypeAAAA { - if rec, ok := answer.(*dns.AAAA); ok { - if ipaddr, ok := netip.AddrFromSlice(rec.AAAA); ok { - a6 = append(a6, ipaddr) - } - } - } - } - return a6 -} - -// whether the qtype code is a aaaa qtype -func IsAAAAQType(qtype uint16) bool { - return qtype == dns.TypeAAAA -} - -// whether the qtype code is a A qtype -func IsAQType(qtype uint16) bool { - return qtype == dns.TypeA -} - -// whether the qtype code is a https qtype -func IsHTTPSQType(qtype uint16) bool { - return qtype == dns.TypeHTTPS -} - -// whether the qtype code is a svcb qtype -func IsSVCBQType(qtype uint16) bool { - return qtype == dns.TypeSVCB -} - -func HasAnyQuestion(msg *dns.Msg) bool { - return !(msg == nil || len(msg.Question) <= 0) -} - -// whether the given msg (ans/query) has a AAAA question section -func HasAAAAQuestion(msg *dns.Msg) bool { - if !HasAnyQuestion(msg) || len(msg.Question) <= 0 { - return false - } - q := msg.Question[0] - return q.Qclass == dns.ClassINET && IsAAAAQType(q.Qtype) -} - -// whether the given msg (ans/query) has a A question section -func HasAQuestion(msg *dns.Msg) bool { - if !HasAnyQuestion(msg) || len(msg.Question) <= 0 { - return false - } - q := msg.Question[0] - return q.Qclass == dns.ClassINET && IsAQType(q.Qtype) -} - -// whether question q is a svcb question -func IsSVCBQuestion(q *dns.Question) bool { - return q != nil && IsSVCBQType(q.Qtype) -} - -// whether question q is a https question -func IsHTTPQuestion(q *dns.Question) bool { - return q != nil && IsHTTPSQType(q.Qtype) -} - -// whether the given msg (ans/query) has a a/aaaa question section -func HasAQuadAQuestion(msg *dns.Msg) bool { - return HasAAAAQuestion(msg) || HasAQuestion(msg) -} - -// whether the given msg (ans/query) has a svcb question section -func HasSVCBQuestion(msg *dns.Msg) (ok bool) { - if !HasAnyQuestion(msg) || len(msg.Question) <= 0 { - return false - } else { - q := msg.Question[0] - ok = IsSVCBQuestion(&q) - log.N("dnsutil: svcb: %v ok? %t", q, ok) - } - return -} - -// whether the given msg (ans/query) has a https question section -func HasHTTPQuestion(msg *dns.Msg) (ok bool) { - if !HasAnyQuestion(msg) || len(msg.Question) <= 0 { - return false - } else { - q := msg.Question[0] - ok = IsHTTPQuestion(&q) - log.N("dnsutil: https: %v ok? %t", q, ok) - } - return -} - -func MakeARecord(name string, ip4 string, ttl uint32) *dns.A { - if len(ip4) <= 0 || len(name) <= 0 { - return nil - } - - b := net.ParseIP(ip4) - if len(b) <= 0 { - return nil - } - - rec := new(dns.A) - rec.Hdr = dns.RR_Header{ - Name: name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: ttl, - } - rec.A = b - return rec -} - -func MakeAAAARecord(name string, ip6 string, ttl uint32) *dns.AAAA { - if len(ip6) <= 0 || len(name) <= 0 { - return nil - } - - b := net.ParseIP(ip6) - if len(b) <= 0 { - return nil - } - - rec := new(dns.AAAA) - rec.Hdr = dns.RR_Header{ - Name: name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: ttl, - } - rec.AAAA = b - return rec -} - -// MaybeToQuadA translates an A record to a AAAA record if the prefix is not nil. -// The ttl of the new record is the max of the original ttl and minttl. -// If the prefix is nil or answer has an empty A record, it returns nil. -func MaybeToQuadA(answer dns.RR, prefix *net.IPNet) *dns.AAAA { - header := answer.Header() - if prefix == nil || header.Rrtype != dns.TypeA { - return nil - } - ipxx, aok := answer.(*dns.A) - if !aok || ipxx == nil || ipxx.A == nil { - return nil - } - ipv4 := ipxx.A.To4() - if ipv4 == nil { // TODO: do not translate bogons? - return nil - } - ttl := max(ansTTL, header.Ttl) - - ipv6 := ip4to6(*prefix, ipv4) - - trec := new(dns.AAAA) - trec.Hdr = dns.RR_Header{ - Name: header.Name, - Rrtype: dns.TypeAAAA, - Class: header.Class, - Ttl: ttl, - } - trec.AAAA = ipv6 - return trec -} - -func CloneA(base dns.RR, ip4 netip.Addr) *dns.A { - header := base.Header() - if !ip4.IsValid() || !ip4.Is4() || header.Rrtype != dns.TypeA { - return nil - } - ipxx, aok := base.(*dns.A) - if !aok || ipxx == nil || ipxx.A == nil { - return nil // only clone if the record has A data - } - c := new(dns.A) - c.Hdr = dns.RR_Header{ - Name: header.Name, - Rrtype: dns.TypeA, - Class: header.Class, - Ttl: max(ansTTL, header.Ttl), - } - c.A = ip4.Unmap().AsSlice() - return c -} - -func CloneAAAA(base dns.RR, ip6 netip.Addr) *dns.AAAA { - header := base.Header() - if !ip6.IsValid() || !ip6.Is6() || header.Rrtype != dns.TypeAAAA { - return nil - } - ipxx, aok := base.(*dns.AAAA) - if !aok || ipxx == nil || ipxx.AAAA == nil { - return nil // only clone if the record has AAAA data - } - c := new(dns.AAAA) - c.Hdr = dns.RR_Header{ - Name: header.Name, - Rrtype: dns.TypeAAAA, - Class: header.Class, - Ttl: max(ansTTL, header.Ttl), - } - c.AAAA = ip6.AsSlice() - return c -} - -func ToIp6Hint(answer dns.RR, prefix *net.IPNet) dns.RR { - header := answer.Header() - if prefix == nil { - log.W("dnsutil: toIp6Hint: prefix missing?") - return nil - } - var kv []dns.SVCBKeyValue - switch header.Rrtype { - case dns.TypeHTTPS: - if x, ok := answer.(*dns.HTTPS); ok { - kv = x.Value - } - case dns.TypeSVCB: - if x, ok := answer.(*dns.SVCB); ok { - kv = x.Value - } - default: - log.W("dnsutil: toIp6Hint: not a svcb/https record/1") - return nil - } - - if len(kv) <= 0 { - return nil - } - ttl := max(ansTTL, header.Ttl) - - hint4 := make([]string, 0) - rest := make([]dns.SVCBKeyValue, 0) - for _, x := range kv { - if x.Key() == dns.SVCB_IPV6HINT { - // ipv6hint found, no need to translate ipv4s - return nil - } else if x.Key() == dns.SVCB_IPV4HINT { - ipstr := x.String() - if len(ipstr) <= 0 { - continue - } - hint4 = append(hint4, strings.Split(ipstr, ",")...) - } else { - rest = append(rest, x) - } - } - - hint6 := new(dns.SVCBIPv6Hint) - for _, x := range hint4 { - ip4 := net.ParseIP(x) - if ip4 == nil { - log.W("dnsutil: invalid https/svcb ipv4hint %s", x) - continue - } - hint6.Hint = append(hint6.Hint, ip4to6(*prefix, ip4)) - } - - if header.Rrtype == dns.TypeSVCB { - trec := new(dns.SVCB) - trec.Hdr = dns.RR_Header{ - Name: header.Name, - Rrtype: header.Rrtype, - Class: header.Class, - Ttl: ttl, - } - trec.Value = append(rest, hint6) - return trec - } else if header.Rrtype == dns.TypeHTTPS { - trec := new(dns.HTTPS) - trec.Hdr = dns.RR_Header{ - Name: header.Name, - Rrtype: header.Rrtype, - Class: header.Class, - Ttl: ttl, - } - trec.Value = append(rest, hint6) - return trec - } else { - // should never happen - log.E("dnsutil: toIp6Hint: not a svcb/https record/2") - return nil - } -} - -func ip4to6(prefix6 net.IPNet, ip4 net.IP) net.IP { - ip6 := make(net.IP, net.IPv6len) - if len(prefix6.IP) <= 0 || len(ip4) <= 0 { - return ip6 // all zeros? - } - copy(ip6, prefix6.IP) - n, _ := prefix6.Mask.Size() - ipShift := n / 8 - for i := range net.IPv4len { - // skip byte 8, datatracker.ietf.org/doc/html/rfc6052#section-2.2 - if ipShift+i == 8 { - ipShift++ - } - ip6[ipShift+i] = ip4[i] - } - return ip6 -} - -func AQuadAUnspecified(msg *dns.Msg) bool { - if msg == nil { - return false - } - ans := msg.Answer - for _, rr := range ans { - switch v := rr.(type) { - case *dns.AAAA: - if net.IPv6zero.Equal(v.AAAA) { - return true - } - case *dns.A: - if net.IPv4zero.Equal(v.A) { - return true - } - } - } - return false -} - -func Len(msg *dns.Msg) int { - if msg == nil { - return 0 - } - if msg.Response { - return len(msg.Answer) + len(msg.Extra) - } - return len(msg.Question) -} - -func Size(msg *dns.Msg) int { - if msg == nil { - return 0 - } - return msg.Len() -} - -func EDNS0PadLen(msg *dns.Msg) int { - if msg == nil { - return -1 - } - return edns0padlen(msg.IsEdns0()) -} - -func edns0padlen(edns0 *dns.OPT) int { - if edns0 == nil { - return -1 - } - for _, opt := range edns0.Option { - if opt == nil { - continue - } - if rr, ok := opt.(*dns.EDNS0_PADDING); ok { - return len(rr.Padding) - } - } - return -1 -} - -func Ans(msg *dns.Msg) (s string) { - if msg != nil { - a := msg.Answer - if len(a) > 0 { - for _, rr := range a { - if rr != nil { - s += rr.String() + " " - } - } - } - } - return -} - -func IsServFailOrInvalid(msg *dns.Msg) bool { - if msg == nil { - return true // invalid - } - return msg.Rcode == dns.RcodeServerFailure // servfail -} - -// Servfail returns a SERVFAIL response to the query q. -func Servfail(q *dns.Msg) *dns.Msg { - if q == nil { - log.W("dnsutil: servfail: error reading q") - return nil - } - msg := q.Copy() - msg.Response = true - msg.RecursionAvailable = true - msg.Rcode = dns.RcodeServerFailure - msg.Extra = nil - return msg -} - -// GetBlocklistStampHeaderKey returns the http-header key for blocklists stamp -func GetBlocklistStampHeaderKey() string { - return http.CanonicalHeaderKey(blocklistHeaderKey) -} - -// GetBlocklistStampHeaderKey1 returns the http-header key for region set by rdns upstream on Fly -func GetRethinkDNSRegionHeaderKey1() string { - return http.CanonicalHeaderKey(rethinkdnsRegionHeaderKey) -} - -// GetBlocklistStampHeaderKey2 returns the http-header key for region set by rdns upstream on Cloudflare -func GetRethinkDNSRegionHeaderKey2() (r string) { - return http.CanonicalHeaderKey(cfRayHeaderKey) -} - -func IsMDNSQuery(qname string) bool { - svc, _ := extractMDNSDomain(qname) - // todo: check if tld is valid (local, arpa4, arpa6) - return len(svc) > 0 -} - -func ExtractMDNSDomain(msg *dns.Msg) (svc, tld string) { - if !HasAnyQuestion(msg) { - return - } - svc, _ = NormalizeQName(QName(msg)) // ex: _http._tcp.local. - return extractMDNSDomain(svc) -} - -func extractMDNSDomain(qname string) (svc, tld string) { - // ref: go.dev/play/p/kqdF0nbJj2B - // qname is assumed normalized (lower-case, without fqdn trailing dot) - // example.local. -> example.local - // rfc6762 sec 4; 254.169.in-addr.arpa - tldarpa4 := strings.LastIndex(qname, arpa4suffix) - tldarpa6 := strings.LastIndex(qname, arpa6suffix) - tldlocal := strings.LastIndex(qname, localsuffix) - if tldlocal > 0 && tldlocal == len(qname)-len(localsuffix) { - svc = qname[:tldlocal-1] // remove trailing dot; example. -> example - tld = localsuffix - } else if tldarpa4 > 0 { - svc = qname[:tldarpa4-1] // remove trailing dot - tld = arpa4suffix - } else if tldarpa6 > 0 { - // 1.1.1.1.a.e.f.ip6.arpa. -> a.e.f.ip6.arpa - tld = qname[tldarpa6-2:tldarpa6] + arpa6suffix - // 1.1.1.1.a.e.f.ip6.arpa. -> 1.1.1.1 - svc = qname[:tldarpa6-3] - } - return -} - -func netips2str(addrs []netip.Addr) []string { - var str []string - for _, x := range addrs { - str = append(str, core.UniqStringer(x)) - } - return str -} - -func ip2str(ip fmt.Stringer) string { - if ip == nil { - return "" - } - return core.UniqStringer(ip) -} diff --git a/jitpack.yml b/jitpack.yml index e2b3d701..f7919cdf 100644 --- a/jitpack.yml +++ b/jitpack.yml @@ -2,28 +2,9 @@ jdk: - openjdk14 env: PACK: "aar" - CLASSFULL: "full" - CLASSDBG: "debug" FOUT: "firestack.aar" - FOUTDBG: "firestack-debug.aar" - BOUT: "build/intra/tun2socks.aar" - BOUTDBG: "build/intra/tun2socks-debug.aar" - SOURCES: "build/intra/tun2socks-sources.jar" - NDKVER: "28.2.13676358" - SDKVER: "36" + BOUT: "build/android/tun2socks.aar" before_install: -# - sdk install java 17.0.8-jbr -# - sdk use java 17.0.8-jbr -# developer.android.com/tools/releases/platform-tools - - yes | $ANDROID_HOME/tools/bin/sdkmanager --licenses -# gomobile: failed to find android SDK platform: open /opt/android-sdk-linux/platforms: no such file or directory - - $ANDROID_HOME/tools/bin/sdkmanager "platforms;android-${SDKVER}" - - $ANDROID_HOME/tools/bin/sdkmanager "build-tools;${SDKVER}.0.0" -# ndk is at v22, for now: github.com/jitpack/jitpack.io/issues/4638 -# developer.android.com/ndk/downloads -# github.com/android/ndk-samples/wiki/Configure-NDK-Path#the-sdkmanager-command-line-tool -# github.com/AgregoreWeb/gomobile-android-docker/blob/3720e727fa/Dockerfile - - $ANDROID_HOME/tools/bin/sdkmanager --install "ndk;${NDKVER}" - ./make-aar install: - ./mvn-install @@ -31,7 +12,3 @@ install: # cmds: jitpack.io/docs/BUILDING/#custom-commands # envs: jitpack.io/docs/BUILDING/#environment-variables # bout: github.com/celzero/outline-go-tun2socks/blob/88be3c35/Makefile#L13 -# tool: developer.android.com/studio/releases/build-tools -# sdkm: developer.android.com/studio/command-line/sdkmanager -# ndkm: developer.android.com/ndk/downloads -# inst: github.com/wordpress-mobile/AztecEditor-Android/blob/5d983f8/jitpack.yml diff --git a/make-aar b/make-aar index 3589b715..d79b2d5c 100755 --- a/make-aar +++ b/make-aar @@ -1,10 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 RethinkDNS and its authors. -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. set -eux # refs: @@ -12,90 +7,32 @@ set -eux # gitlab.com/fdroid/fdroiddata/-/blob/81c14003f/metadata/com.tailscale.ipn.yml # gitlab.com/fdroid/fdroiddata/-/blob/d6c5315a/metadata/org.calyxinstitute.vpn.yml -# defaults; can be overridden for ex by jitpack.yml -PACK="${PACK:-aar}" -# final out -FOUT="${FOUT:-firestack.aar}" -FOUTDBG="${FOUTDBG:-firestack-debug.aar}" -# build out -BOUT="${BOUT:-build/intra/tun2socks.aar}" -BOUTDBG="${BOUTDBG:-build/intra/tun2socks-debug.aar}" -# artifact classifier -CLASSFULL="${CLASSFULL:-full}" # unused -CLASSDBG="${CLASSDBG:-debug}" -# artifact bytecode sources -SOURCES="${SOURCES:-build/intra/tun2socks-sources.jar}" -# android sdk/ndk versions (ANDROID_NDK_LATEST_HOME is set on github) -LATESTNDKHOME=${ANDROID_NDK_LATEST_HOME:-0} -# default: stackoverflow.com/a/2013573 -NDKVER="${NDKVER:-0}" -SDKVER="${SDKVER:-36}" +# download golang +curl -Lso go.tar.gz https://golang.org/dl/go1.15.4.linux-amd64.tar.gz +echo "eb61005f0b932c93b424a3a4eaa67d72196c79129d9a3ea8578047683e2c80d5 go.tar.gz" | sha256sum -c - -ARG1="${1:-go}" -ARG2="${2:-debug}" +# setup golang +mkdir -p golang +tar -C golang -xzf go.tar.gz +export GOPATH="$(pwd)" +export GO_LANG="$(pwd)/golang/go/bin" +export GO_COMPILED="$(pwd)/bin" +export PATH="$GO_LANG:$GO_COMPILED:$PATH" -# debug -printenv +# init gomobile +go get golang.org/x/mobile/cmd/gomobile +gomobile init -ls -ltr $ANDROID_HOME/** -head $ANDROID_HOME/ndk-bundle/source.properties || true +# download firestack +git clone "https://github.com/celzero/outline-go-tun2socks.git" -b "$VERSION" go-firestack +cd go-firestack -# gomobile picks up the latest ndk by walking the ndk dir -# in $ANDROID_HOME (on gh-actions: /usr/local/lib/android/sdk/ndk) -# if NDKVER is set, override gomobile's behaviour by force setting -# $ANDROID_NDK_HOME & $ANDROID_NDK_ROOT to the requested NDK version. -if [ "$NDKVER" != "0" ]; then - ANDROID_NDK_HOME="${ANDROID_HOME}/ndk/${NDKVER}" - ANDROID_NDK_ROOT="${ANDROID_NDK_HOME}" - # ls will fail if NDKVER is missing, which is what we want! - ls -ltr $ANDROID_NDK_HOME - head $ANDROID_NDK_HOME/source.properties || true -elif [ "$LATESTNDKHOME" != "0" ]; then - # use the latest NDK version, set by github actions - ANDROID_NDK_HOME="${ANDROID_NDK_LATEST_HOME}" - ANDROID_NDK_ROOT="${ANDROID_NDK_HOME}" - # ls will fail if ANDROID_NDK_LATEST_HOME is missing, which is what we want! - ls -ltr $ANDROID_NDK_HOME - head $ANDROID_NDK_HOME/source.properties || true -fi - -if [ "$ARG1" = "go" ]; then - # download from go.dev/dl - curl -Lso go.tar.gz https://go.dev/dl/go1.26.0.linux-amd64.tar.gz - echo "aac1b08a0fb0c4e0a7c1555beb7b59180b05dfc5a3d62e40e9de90cd42f88235 go.tar.gz" | sha256sum -c - - - # HOME=/home/jitpack - # PWD=/home/jitpack/build - # setup go, /opt isn't writeable - export GOPATH="$HOME/golang" - mkdir -p $GOPATH - # golang in pwd confuses "go mod", as firestack source is in the same dir - tar -C $GOPATH -xzf go.tar.gz - export GO_LANG="$GOPATH/go/bin" - export GO_COMPILED="$GOPATH/bin" - export PATH="$GO_LANG:$GO_COMPILED:$PATH" -fi - -# go debug -go version -go env - -# checkout tagged branch? -# git checkout -b "$VERSION" +# godeps +go get -d ./... # gomobile aar +./build_android.sh intra -if [ "$ARG2" = "debug" ]; then - # default: with debug builds - make clean && make intra && make intradebug - # rename - mv ./"$BOUT" ./"$FOUT" - mv ./"$BOUTDBG" ./"$FOUTDBG" -else - make clean && make intra - # rename - mv ./"$BOUT" ./"$FOUT" -fi +# rename +mv ./"$BOUT" ./"$FOUT" -# ls cwd -ls -ltr diff --git a/mvn-install b/mvn-install index 77cb0d6e..4f033da5 100755 --- a/mvn-install +++ b/mvn-install @@ -1,10 +1,4 @@ #!/bin/bash -# -# Copyright (c) 2021 RethinkDNS and its authors. -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. set -eux @@ -14,30 +8,9 @@ set -eux # repo: jitpack.io/com/github/$GROUP/$ARTIFACT/$VERSION/ # logs: jitpack.io/com/github/$GROUP/$ARTIFACT/$VERSION/build.log -# from: github.com/ignoramous/outline-go-tun2socks/tree/jpac2 -# maven.apache.org/plugins/maven-install-plugin/examples/installing-secondary-artifacts.html mvn install:install-file \ -Dfile=${FOUT} \ -Dpackaging=${PACK} \ -Dversion=${VERSION} \ -DgroupId=${GROUP} \ - -DartifactId=${ARTIFACT} \ - -Dsources=${SOURCES} -# -Dclassifier=${CLASSFULL} \ - -# usage: docs.gradle.org/current/userguide/dependency_management.html -# mvn install:install-file -# -Dfile=firestack.aar -# -Dpackaging=aar -# -Dversion=9744e5fad9 -# -DgroupId=com.github.celzero -# -DartifactId=firestack -# -Dsources=build/intra/tun2socks-sources.jar -mvn -X install:install-file \ - -Dfile=${FOUTDBG} \ - -Dpackaging=${PACK} \ - -Dversion=${VERSION} \ - -DgroupId=${GROUP} \ - -DartifactId=${ARTIFACT} \ - -Dclassifier=${CLASSDBG} \ - -Dsources=${SOURCES} \ No newline at end of file + -DartifactId=${ARTIFACT} diff --git a/ossrhpom.xml b/ossrhpom.xml deleted file mode 100644 index 99e561ba..00000000 --- a/ossrhpom.xml +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - 4.0.0 - com.celzero - firestack - 0.1 - jar - - Firestack - Userspace wireguard and network monitor. - https://github.com/celzero/firestack - - - UTF-8 - - - - - Mozilla Public License 2.0 - https://opensource.org/license/mpl-2-0 - - - - - - The Rethink DNS Open Source Project - hello@celzero.com - Celzero - https://celzero.com - - - - - scm:git:git://github.com/celzero/firestack.git - scm:git:ssh://github.com:celzero/firestack.git - https://github.com/celzero/firestack/tree/main - - - - - - ossrh - https://central.sonatype.com/repository/maven-snapshots/ - - - ossrh - https://ossrh-staging-api.central.sonatype.com/service/local/staging/deploy/maven2/ - - - \ No newline at end of file diff --git a/tools/runtime_write_err_android.patch b/tools/runtime_write_err_android.patch deleted file mode 100644 index 82d0145d..00000000 --- a/tools/runtime_write_err_android.patch +++ /dev/null @@ -1,22 +0,0 @@ -diff --git a/src/runtime/write_err_android.go b/src/runtime/write_err_android.go -index bcc934e54c0461..3551641a89dad9 100644 ---- a/src/runtime/write_err_android.go -+++ b/src/runtime/write_err_android.go -@@ -10,7 +10,7 @@ import ( - ) - - var ( -- writeHeader = []byte{6 /* ANDROID_LOG_ERROR */, 'G', 'o', 0} -+ writeHeader = []byte{7 /* ANDROID_LOG_FATAL*/, 'G', 'o', 'W', 't', 'f', 0} - writePath = []byte("/dev/log/main\x00") - writeLogd = []byte("/dev/socket/logdw\x00") - -@@ -149,7 +149,7 @@ func writeLogdHeader() int { - // hdr[3:11] log_time defined in - // hdr[3:7] sec unsigned uint32, little endian. - // hdr[7:11] nsec unsigned uint32, little endian. -- hdr[0] = 0 // LOG_ID_MAIN -+ hdr[0] = 4 // LOG_ID_CRASH - sec, nsec, _ := time_now() - byteorder.LEPutUint32(hdr[3:7], uint32(sec)) - byteorder.LEPutUint32(hdr[7:11], uint32(nsec)) diff --git a/tools/tools.go b/tools/tools.go deleted file mode 100644 index 2e174757..00000000 --- a/tools/tools.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build tools -// +build tools - -// See github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module -// and github.com/go-modules-by-example/index/blob/master/010_tools/README.md - -package tools - -import ( - _ "github.com/crazy-max/xgo" - _ "github.com/tailscale/depaware/depaware" - _ "golang.org/x/mobile/cmd/gomobile" -) diff --git a/tunnel/depaware.txt b/tunnel/depaware.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/tunnel/dialer.go b/tunnel/dialer.go deleted file mode 100644 index 28a29584..00000000 --- a/tunnel/dialer.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package tunnel - -import ( - "net" - "net/netip" - - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/protect" - "github.com/celzero/firestack/intra/settings" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" -) - -var _ protect.RDialer = (*gtunnel)(nil) - -// ID implements protect.RDialer. -func (h *gtunnel) ID() string { - return "gtunnel" -} - -// Dial implements protect.RDialer. -func (t *gtunnel) Dial(network, addr string) (protect.Conn, error) { - taddr, proto := fulladdr(addr) // taddr may be nil - switch network { - case "tcp", "tcp4", "tcp6": - if taddr == nil { - taddr = &tcpip.FullAddress{} - } - return gonet.DialTCP(t.stack, *taddr, proto) - case "udp", "udp4", "udp6": - return gonet.DialUDP(t.stack, nil, taddr, proto) - } - - log.E("tun: dial: invalid network: %s to %s", network, addr) - return nil, &net.OpError{ - Op: "tun: dial", - Net: network, - Source: netaddr(addr), - Addr: nil, - Err: net.UnknownNetworkError(network), - } -} - -// Dial implements protect.RDialer. -func (t *gtunnel) DialBind(network, local, remote string) (protect.Conn, error) { - taddr, proto := fulladdr(remote) // taddr may be nil - laddr, _ := fulladdr(local) // stack must allow spoofing - switch network { - case "tcp", "tcp4", "tcp6": - if taddr == nil { // todo: error? - taddr = &tcpip.FullAddress{} - } - if laddr == nil { // ok - laddr = &tcpip.FullAddress{} - } - return gonet.DialTCPWithBind(t.ctx, t.stack, *laddr, *taddr, proto) - case "udp", "udp4", "udp6": - return gonet.DialUDP(t.stack, laddr, taddr, proto) - } - - log.E("tun: dial: invalid network: %s to %s<=%s", network, remote, local) - return nil, &net.OpError{ - Op: "tun: dialbind", - Net: network, - Source: netaddr(remote), - Addr: netaddr(local), - Err: net.UnknownNetworkError(network), - } -} - -// Announce implements protect.RDialer. -func (t *gtunnel) Announce(network, local string) (protect.PacketConn, error) { - taddr, proto := fulladdr(local) // taddr may be nil - switch network { - case "udp", "udp4", "udp6": - return gonet.DialUDP(t.stack, taddr, nil, proto) - } - - log.E("tun: announce: invalid network: %s to %s", network, local) - return nil, &net.OpError{ - Op: "tun: announce", - Net: network, - Addr: netaddr(local), - Source: nil, - Err: net.UnknownNetworkError(network), - } -} - -// Accept implements protect.RDialer. -func (t *gtunnel) Accept(network, local string) (protect.Listener, error) { - taddr, proto := fulladdr(local) // taddr may be nil - if taddr == nil { - log.E("tun: accept: invalid addr: %s", local) - return nil, &net.AddrError{Err: "tun: dial: invalid addr", Addr: local} - } - switch network { - case "tcp", "tcp4", "tcp6": - return gonet.ListenTCP(t.stack, *taddr, proto) - } - - log.E("tun: accept: invalid network: %s to %s", network, local) - return nil, &net.OpError{ - Op: "tun: accept", - Net: network, - Addr: netaddr(local), - Source: nil, - Err: net.UnknownNetworkError(network), - } -} - -// Probe implements protect.RDialer. -func (t *gtunnel) Probe(network, local string) (protect.PacketConn, error) { - // TODO: implement probe - return nil, &net.OpError{Op: "probe", - Net: network, - Addr: netaddr(local), - Source: nil, - Err: net.UnknownNetworkError(network), - } -} - -func fulladdr(addr string) (a *tcpip.FullAddress, pn tcpip.NetworkProtocolNumber) { - ipp, err := netip.ParseAddrPort(addr) - if ipp.Addr().Is4() { - pn = ipv4.ProtocolNumber - } else { - pn = ipv6.ProtocolNumber - } - if err != nil || !ipp.IsValid() { // unlikely - log.V("tun: dial: invalid addr: proto(%d) %s; err? %v", pn, addr, err) - return nil, pn - } - return fullAddrFrom(ipp), pn -} - -func fullAddrFrom(ipp netip.AddrPort) *tcpip.FullAddress { - var nsdaddr tcpip.Address - if !ipp.IsValid() { - return nil - } - if ipp.Addr().Is4() { - nsdaddr = tcpip.AddrFrom4(ipp.Addr().As4()) - } else { - nsdaddr = tcpip.AddrFrom16(ipp.Addr().As16()) - } - log.V("tun: dial: translate ipp: %v -> %v", ipp, nsdaddr) - return &tcpip.FullAddress{ - NIC: settings.NICID, - Addr: nsdaddr, - Port: ipp.Port(), // may be 0 - } -} - -// netaddr is a net.Addr that returns "any" as its network. -// Only used for error reporting. -type netaddr string - -func (n netaddr) Network() string { - return "any" -} - -func (n netaddr) String() string { - return string(n) -} diff --git a/tunnel/sink.go b/tunnel/sink.go deleted file mode 100644 index 47c1d04b..00000000 --- a/tunnel/sink.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) 2024 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package tunnel - -import ( - "context" - "io" - - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" -) - -type pcapsink struct { - ctx context.Context - done context.CancelFunc - sink *core.Volatile[io.WriteCloser] - inC chan []byte // always buffered -} - -// nowrite rejects all writes. -type nowrite struct{} - -var _ io.WriteCloser = (*nowrite)(nil) -var _ io.WriteCloser = (*pcapsink)(nil) - -func (*nowrite) Write(b []byte) (int, error) { return len(b), nil } -func (*nowrite) Close() error { return nil } - -func newSink(pctx context.Context) *pcapsink { - ctx, cancel := context.WithCancel(pctx) - // go.dev/play/p/4qANL9VSDXb - p := new(pcapsink) - p.ctx = ctx - p.done = cancel - p.sink = core.NewVolatile[io.WriteCloser](zerowriter) - p.log(false) // no log - p.fout(false) // no file out - p.inC = make(chan []byte, 128) - core.Go("pcap.w", func() { p.writeAsync() }) - context.AfterFunc(ctx, func() { - defer close(p.inC) // signal writeAsync to exit - p.recycle() - }) - return p -} - -func (p *pcapsink) Write(b []byte) (int, error) { - select { - case <-p.ctx.Done(): // closed - default: - select { - case <-p.ctx.Done(): // closed - case p.inC <- b: - return len(b), nil - default: // drop - return len(b), nil - } - } - return 0, io.ErrClosedPipe // err here may panic netstack's sniffer -} - -// writeAsync consumes [p.in] until close. -func (p *pcapsink) writeAsync() { - for b := range p.inC { // winsy spider - w := p.sink.Load() // always re-load current writer - if w != nil && w != zerowriter { - n, err := w.Write(b) - log.VV("tun: pcap: writeAsync: n: %d, err? %v", n, err) - } // else: no op - } -} - -func (p *pcapsink) recycle() error { - p.log(false) // detach - err := p.file(nil) // detach - return err -} - -func (p *pcapsink) Close() error { - p.done() - return nil -} - -func (p *pcapsink) file(f io.WriteCloser) (err error) { - if f == nil || core.IsNil(f) { - f = zerowriter - } - - old := p.sink.Tango(f) // old may be nil - core.CloseOp(old, core.CopRW) - - y := f != zerowriter - if y { - // from: github.com/google/gvisor/blob/596e8d22/pkg/tcpip/link/sniffer/sniffer.go#L93 - err = netstack.WritePCAPHeader(f) // write pcap header before any packets - log.I("tun: pcap: begin: writeHeader; err(%v)", err) - } - p.fout(y) - return -} - -func (p *pcapsink) log(y bool) bool { - return netstack.Pcap2Stdout(y) -} - -func (p *pcapsink) fout(y bool) bool { - return netstack.Pcap2File(y) -} - -func (p *pcapsink) mode() string { - return netstack.PcapModes() -} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go deleted file mode 100644 index 0803aefa..00000000 --- a/tunnel/tunnel.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright (c) 2020 RethinkDNS and its authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. -// -// This file incorporates work covered by the following copyright and -// permission notice: -// -// Copyright 2019 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tunnel - -import ( - "context" - "errors" - "fmt" - "os" - "path/filepath" - "strconv" - "sync" - "sync/atomic" - "syscall" - "time" - - x "github.com/celzero/firestack/intra/backend" - "github.com/celzero/firestack/intra/core" - "github.com/celzero/firestack/intra/log" - "github.com/celzero/firestack/intra/netstack" - "github.com/celzero/firestack/intra/settings" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// Tunnel represents a session on a TUN device. -type Tunnel interface { - // IsConnected indicates whether the tunnel is in a connected state. - IsConnected() bool - // Disconnect disconnects the tunnel. - Disconnect() - // Enabled checks if the tunnel is up and running. - Enabled() bool - // Mtu returns the current MTU of the tunnel (tun MTU). - Mtu() int32 - // Creates a new link using fd (tun device). - SetLinkAndRoutes(fd, tunmtu, engine int) error - // Unsets existing link and closes the fd (tun device). - Unlink() error - // Set or unset the pcap sink - SetPcap(fpcap string) error - // NIC, IP, TCP, UDP, and ICMP stats. - Stat() (*x.NetStat, error) -} - -type gtunnel struct { - ctx context.Context - done context.CancelFunc - stack *stack.Stack // a tcpip stack - ep netstack.SeamlessEndpoint // endpoint for the stack - sid *core.Volatile[int] // session id (almost always tunnel fd) - hdl netstack.GConnHandler // tcp, udp, and icmp handlers - pcapio *pcapsink // pcap output, if any - closed atomic.Bool // open/close? - once sync.Once -} - -var _ Tunnel = (*gtunnel)(nil) - -var ( - errInvalidTunFd = errors.New("invalid tun fd") - zerowriter = &nowrite{} -) - -func (t *gtunnel) Mtu() int32 { - if t.IsConnected() { - // return int32(t.stack.NICInfo()[0].MTU) - return int32(t.ep.MTU()) - } - return -1 -} - -func (t *gtunnel) waitForEndpoint(ctx context.Context) { - defer core.Recover(core.Exit11, "g.wait") - - const maxchecks = 5 - const betweenChecks = 3 * time.Second - const uptimeThreshold = 3 * time.Second - - waitStart := time.Now() - i := 0 - - defer func() { - log.I("tun: waiter: done; #%d, %s", i, core.FmtTimeAsPeriod(waitStart)) - }() - - for i < maxchecks && !t.closed.Load() { - // wait a bit to let the endpoint settle - time.Sleep(betweenChecks) - start := time.Now() - runid := "g." + strconv.Itoa(i) - - select { - case <-ctx.Done(): - t.Disconnect() // may already be disconnected - log.D("tun: waiter: ctx done; #%d", i) - return - case <-core.SigFin(runid, t.ep.Wait): // wait until endpoint closes - log.D("tun: waiter: endpoint not running; #%d", i) - } - - // if the endpoint was up for more than uptimeThreshold, - // reset the counter and do another set of maxchecks - // as a new endpoint may have been created in between - // see: SetLink -> t.ep.Swap - if uptime := time.Since(start); uptime >= uptimeThreshold { - i = 0 // good ep just closed, restart maxchecks - } else { // no endpoint / bad endpoint still closed - // ep.Wait was super quick, and it is possible - // no endpoint will show up in the next few checks - // but if it does, then i is reset to 0 anyway - i++ - } - } - if !t.closed.Load() { - // the endpoint closed without a Disconnect, this may happen - // in cases where a panic was recovered and endpoint was - // closed without a t.ep.Swap or t.stack.Destroy - log.U(fmt.Sprintf("Deactivated! Down after %s", core.FmtTimeAsPeriod(waitStart))) - // todo: disconnect parent tunnel - t.Disconnect() // may already be disconnected - } -} - -func (t *gtunnel) Disconnect() { - defer core.Recover(core.Exit11, "g.Disconnect") - - // no core.Recover here as the tunnel is disconnecting anyway - t.once.Do(func() { - t.closed.Store(true) - // go t.Unlink() // may block? takes more time? - t.stack.Destroy() - log.I("tun: netstack closed") - }) -} - -func (t *gtunnel) Enabled() bool { - s := t.stack - - // nic may be down even if tunnel is up, when SetLink is in between - // removing existing nic and creating a new one. - return s != nil && s.CheckNIC(settings.NICID) -} - -func (t *gtunnel) IsConnected() bool { - return !t.closed.Load() -} - -// fd must be non-blocking. -func NewGTunnel(pctx context.Context, fd, mtu int, l3 string, hdl netstack.GConnHandler) (t *gtunnel, rev netstack.GConnHandler, err error) { - myfd, err := maybeDup(fd) // tunnel will own myfd - if err != nil { - return nil, nil, err - } - - ctx, done := context.WithCancel(pctx) - - sink := newSink(ctx) - stack := netstack.NewNetstack() // always dual-stack - // NewEndpoint takes ownership of myfd; closes it on errors - ep, eerr := netstack.NewEndpoint(myfd, mtu, sink) - if eerr != nil { - done() - return nil, nil, eerr - } - - var nic tcpip.NICID - if l3 != settings.IP46 { - l3 = settings.IP46 // always dual-stack - log.W("tun: new netstack(%d) l3 is %s needed %s", fd, l3, settings.IP46) - } - netstack.Route(stack, l3) - // Enabled() may temporarily return false when Up() is in progress. - if nic, err = netstack.Up(stack, ep, hdl); err != nil { // attach new endpoint - done() - return nil, nil, err - } - - rev = netstack.NewReverseGConnHandler(ctx, stack, nic, ep, hdl) - - log.I("tun: new netstack(%d) up; fd(%d=>%d), mtu(%d)", nic, fd, myfd, mtu) - - t = >unnel{ - ctx: ctx, - done: done, - stack: stack, - ep: ep, - sid: core.NewVolatile(fd), // fd is the og tun device - hdl: hdl, - pcapio: sink, - closed: atomic.Bool{}, - once: sync.Once{}, - } - - core.Go1("tun.awaiter", t.waitForEndpoint, ctx) - - return -} - -func (t *gtunnel) SetPcap(fp string) error { - defer core.Recover(core.Exit11, "g.SetPcap") - - pcap := t.pcapio - - ignored := pcap.recycle() // close any existing pcap sink - if len(fp) == 0 { - log.I("tun: pcap closed (ignored-err? %v)", ignored) - return nil // nothing else to do; pcap closed - } else if len(fp) == 1 { - // if fdpcap is 0, 1, or 2 then pcap is written to stdout - ok := pcap.log(true) - log.I("tun: pcap(%s)/log(%t)", fp, ok) - return nil // fdbased will write to stdout - } else if fout, err := os.OpenFile(filepath.Clean(fp), os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600); err == nil { - ignored = pcap.file(fout) // attach - log.I("tun: pcap(%s)/file(%v) (ignored-err? %v)", fp, fout, ignored) - return nil // sniffer will write to fout - } else { - log.E("tun: pcap(%s); (err? %v)", fp, err) - return err // no pcap - } -} - -func (t *gtunnel) Unlink() error { - defer core.Recover(core.Exit11, "g.Unlink") - - return t.ep.Dispose() -} - -func (t *gtunnel) SetLinkAndRoutes(fd, mtu, engine int) (err error) { - defer core.Recover(core.Exit11, "g.SetLinkAndRoutes") - - if err := t.setLink(fd, mtu); err != nil { - return err - } - return t.setRoute(engine) -} - -func (t *gtunnel) setLink(fd, mtu int) (err error) { - defer func() { - if err != nil { - t.sid.Store(-1) // reset sid - } else { - t.sid.Store(fd) // set sid to fd - } - }() - - myfd, err := maybeDup(fd) // endpoint owns myfd - if err != nil { - log.E("tun: new link; err %v", err) - return err - } - - err = t.ep.Swap(myfd, mtu) // swap fd and mtu - - log.I("tun: new link, fd(%d => %d) mtu(%d); err? %v", fd, myfd, mtu, err) - return err -} - -func (t *gtunnel) setRoute(engine int) error { - // netstack route is never changed; always dual-stack - netstack.Route(t.stack, settings.IP46) - doHappyEyeballs := engine == settings.Ns46 - ok := settings.HappyEyeballs.CompareAndSwap(!doHappyEyeballs, doHappyEyeballs) - log.I("tun: new route; (no-op) got %s but set %s; enable happy eyeballs? %t / ok? %t", - settings.L3(engine), settings.IP46, doHappyEyeballs, ok) - return nil -} - -func (t *gtunnel) Stat() (*x.NetStat, error) { - st, err := netstack.Stat(t.stack) - if err == nil && st != nil { - st.TUNSt.Open = !t.closed.Load() - st.TUNSt.Up = t.ep.IsAttached() - st.TUNSt.Sid = t.sid.Load() // session id (tunnel fd) - st.TUNSt.Mtu = int32(t.ep.MTU()) - st.TUNSt.PcapMode = t.pcapio.mode() - st.TUNSt.EpStats = t.ep.Stat().String() - - if t := t.hdl.TCP(); t != nil { - st.RDNSIn.OpenConnsTCP = t.OpenConns() - } - if u := t.hdl.UDP(); u != nil { - st.RDNSIn.OpenConnsUDP = u.OpenConns() - } - if i := t.hdl.ICMP(); i != nil { - st.RDNSIn.OpenConnsICMP = i.OpenConns() - } - } - return st, err -} - -// copy so golang gc may not close orig fd -func maybeDup(fd int) (int, error) { - if fd < 0 { - return 0, errInvalidTunFd - } - if settings.OwnTunFd.Load() { - // if OwnTunFd is true, then do not dup the fd - // as netstack owns the TUN fd and will not - // assume ownership of the TUN fd shared with it. - log.I("tun: assuming fd ownership %d", fd) - return fd, nil - } - - // ref: github.com/mdlayher/socket/blob/9c51a391b/conn.go#L309 - // fctnl(2) to dup the fd & set cloexec in one syscall - newfd, err := unix.FcntlInt(uintptr(fd), unix.F_DUPFD_CLOEXEC, 0) - if err == nil { // success - return newfd, nil - } else if err == unix.EINVAL { // fallback - // Mirror the standard library: avoid racing a fork/exec with dup - // so that child does not inherit socket fds unexpectedly. - syscall.ForkLock.RLock() - defer syscall.ForkLock.RUnlock() - - newfd, err := unix.Dup(fd) - if err == nil { - unix.CloseOnExec(newfd) - } - return newfd, err - } // other errors? - return 0, os.NewSyscallError("fcntl", err) -}